summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2020-10-09 16:22:19 +0200
committerDaniel Friesel <daniel.friesel@uos.de>2020-10-09 16:22:19 +0200
commit8e0d3674c9e8fcbc90eafaa390154550bbc234a6 (patch)
tree070ce2d02677fa1edf1fa38c5c828e25b907a65d /lib
parent36a4a6f2c8d7f73279a43e0926eeb09da07bff63 (diff)
add sub-state generation
Diffstat (limited to 'lib')
-rw-r--r--lib/model.py7
-rw-r--r--lib/pelt.py134
2 files changed, 139 insertions, 2 deletions
diff --git a/lib/model.py b/lib/model.py
index d7a9bc9..57507e7 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -944,7 +944,12 @@ class PTAModel:
penalty = self.pelt.get_penalty_value(
self.by_param[k]["power_traces"]
)
- print(f" penalty: {penalty}")
+ print(
+ f" we found {penalty[1]} changepoints with penalty {penalty[0]}"
+ )
+ self.pelt.calc_raw_states(
+ self.by_param[k]["power_traces"], penalty[0]
+ )
return None, None
diff --git a/lib/pelt.py b/lib/pelt.py
index ddc2324..7f2e922 100644
--- a/lib/pelt.py
+++ b/lib/pelt.py
@@ -28,6 +28,97 @@ def norm_signal(signal, scaler=25):
return normed_signal
+# Scheint Einfluss auf die gefundene Anzahl CHangepoints zu haben. Hrmpf.
+# Kann aber auch shitty downsampling sein. Diese Funktion ist _sehr_ quick&dirty.
+def coarse_signal(signal, divisor=10):
+ ret = list()
+ for i in range((len(signal) // divisor)):
+ ret.append(np.mean(signal[i * divisor : (i + 1) * divisor]))
+ return np.array(ret)
+
+
+# returns the changepoints found on signal with penalty penalty.
+# model, jump and min_dist are directly passed to PELT
+def calc_pelt(signal, penalty, model="l1", jump=5, min_dist=2, plotting=False):
+ # default params in Function
+ if model is None:
+ model = "l1"
+ if jump is None:
+ jump = 5
+ if min_dist is None:
+ min_dist = 2
+ if plotting is None:
+ plotting = False
+ # change point detection. best fit seemingly with l1. rbf prods. RuntimeErr for pen > 30
+ # https://ctruong.perso.math.cnrs.fr/ruptures-docs/build/html/costs/index.html
+ # model = "l1" #"l1" # "l2", "rbf"
+ algo = ruptures.Pelt(model=model, jump=jump, min_size=min_dist).fit(signal)
+
+ if penalty is not None:
+ bkps = algo.predict(pen=penalty)
+ if plotting:
+ fig, ax = ruptures.display(signal, bkps)
+ plt.show()
+ return bkps
+
+ print_error("No Penalty specified.")
+ sys.exit(-1)
+
+
+# calculates the raw_states for measurement measurement. num_measurement is used to identify the
+# return value
+# penalty, model and jump are directly passed to pelt
+def calc_raw_states_func(num_measurement, measurement, penalty, model, jump):
+ # extract signal
+ signal = np.array(measurement)
+ # norm signal to remove dependency on absolute values
+ normed_signal = norm_signal(signal)
+ # calculate the breakpoints
+ bkpts = calc_pelt(normed_signal, penalty, model=model, jump=jump)
+ calced_states = list()
+ start_time = 0
+ end_time = 0
+ # calc metrics for all states
+ for bkpt in bkpts:
+ # start_time of state is end_time of previous one
+ # (Transitions are instantaneous)
+ start_time = end_time
+ end_time = bkpt
+ power_vals = signal[start_time:end_time]
+ mean_power = np.mean(power_vals)
+ std_dev = np.std(power_vals)
+ calced_state = (start_time, end_time, mean_power, std_dev)
+ calced_states.append(calced_state)
+ num = 0
+ new_avg_std = 0
+ # calc avg std for all states from this measurement
+ for s in calced_states:
+ # print_info("State " + str(num) + " starts at t=" + str(s[0])
+ # + " and ends at t=" + str(s[1])
+ # + " while using " + str(s[2])
+ # + "uW with sigma=" + str(s[3]))
+ num = num + 1
+ new_avg_std = new_avg_std + s[3]
+ # check case if no state has been found to avoid crashing
+ if len(calced_states) != 0:
+ new_avg_std = new_avg_std / len(calced_states)
+ else:
+ new_avg_std = 0
+ change_avg_std = None # measurement["uW_std"] - new_avg_std
+ # print_info("The average standard deviation for the newly found states is "
+ # + str(new_avg_std))
+ # print_info("That is a reduction of " + str(change_avg_std))
+ return num_measurement, calced_states, new_avg_std, change_avg_std
+
+
+# parallelize calc over all measurements
+def calc_raw_states(arg_list):
+ with Pool() as pool:
+ # collect results from pool
+ result = pool.starmap(calc_raw_states_func, arg_list)
+ return result
+
+
class PELT:
def __init__(self, **kwargs):
# Defaults von Janis
@@ -54,7 +145,10 @@ class PELT:
self, signals, model="l1", min_dist=2, range_min=0, range_max=100, S=1.0
):
# Janis macht hier noch kein norm_signal. Mit sieht es aber genau so brauchbar aus.
- signal = norm_signal(signals[0])
+ # TODO vor der Penaltybestimmung die Auflösung der Daten auf z.B. 1 kHz
+ # verringern. Dann geht's deutlich schneller und superkurze
+ # Substates interessieren uns ohnehin weniger
+ signal = coarse_signal(norm_signal(signals[0]))
algo = ruptures.Pelt(model=model, jump=self.jump, min_size=min_dist).fit(signal)
queue = list()
for i in range(range_min, range_max + 1):
@@ -102,6 +196,44 @@ class PELT:
knee = (knee[0] * 1, knee[1])
return knee
+ def calc_raw_states(self, signals, penalty, opt_model=None):
+ raw_states_calc_args = list()
+ for num_measurement, measurement in enumerate(signals):
+ raw_states_calc_args.append(
+ (num_measurement, measurement, penalty, opt_model, self.jump)
+ )
+
+ raw_states_list = [None] * len(signals)
+ raw_states_res = calc_raw_states(raw_states_calc_args)
+
+ # extracting result and putting it in correct order -> index of raw_states_list
+ # entry still corresponds with index of measurement in measurements_by_states
+ # -> If measurements are discarded the used ones are easily recognized
+ for ret_val in raw_states_res:
+ num_measurement = ret_val[0]
+ raw_states = ret_val[1]
+ avg_std = ret_val[2]
+ change_avg_std = ret_val[3]
+ # FIXME: Wieso gibt mir meine IDE hier eine Warning aus? Der Index müsste doch
+ # int sein oder nicht? Es scheint auch vernünftig zu klappen...
+ raw_states_list[num_measurement] = raw_states
+ # print(
+ # "The average standard deviation for the newly found states in "
+ # + "measurement No. "
+ # + str(num_measurement)
+ # + " is "
+ # + str(avg_std)
+ # )
+ # print("That is a reduction of " + str(change_avg_std))
+ for i, raw_state in enumerate(raw_states):
+ print(
+ f"Measurement #{num_measurement} sub-state #{i}: {raw_state[0]} -> {raw_state[1]}, mean {raw_state[2]}"
+ )
+ # l_signal = measurements_by_config['offline'][num_measurement]['uW']
+ # l_bkpts = [s[1] for s in raw_states]
+ # fig, ax = rpt.display(np.array(l_signal), l_bkpts)
+ # plt.show()
+
"""
# calculates and returns the necessary penalty for signal. Parallel execution with num_processes many processes
# jump, min_dist are passed directly to PELT. S is directly passed to kneedle.