diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/model.py | 7 | ||||
-rw-r--r-- | lib/pelt.py | 134 |
2 files changed, 139 insertions, 2 deletions
diff --git a/lib/model.py b/lib/model.py index d7a9bc9..57507e7 100644 --- a/lib/model.py +++ b/lib/model.py @@ -944,7 +944,12 @@ class PTAModel: penalty = self.pelt.get_penalty_value( self.by_param[k]["power_traces"] ) - print(f" penalty: {penalty}") + print( + f" we found {penalty[1]} changepoints with penalty {penalty[0]}" + ) + self.pelt.calc_raw_states( + self.by_param[k]["power_traces"], penalty[0] + ) return None, None diff --git a/lib/pelt.py b/lib/pelt.py index ddc2324..7f2e922 100644 --- a/lib/pelt.py +++ b/lib/pelt.py @@ -28,6 +28,97 @@ def norm_signal(signal, scaler=25): return normed_signal +# Scheint Einfluss auf die gefundene Anzahl CHangepoints zu haben. Hrmpf. +# Kann aber auch shitty downsampling sein. Diese Funktion ist _sehr_ quick&dirty. +def coarse_signal(signal, divisor=10): + ret = list() + for i in range((len(signal) // divisor)): + ret.append(np.mean(signal[i * divisor : (i + 1) * divisor])) + return np.array(ret) + + +# returns the changepoints found on signal with penalty penalty. +# model, jump and min_dist are directly passed to PELT +def calc_pelt(signal, penalty, model="l1", jump=5, min_dist=2, plotting=False): + # default params in Function + if model is None: + model = "l1" + if jump is None: + jump = 5 + if min_dist is None: + min_dist = 2 + if plotting is None: + plotting = False + # change point detection. best fit seemingly with l1. rbf prods. RuntimeErr for pen > 30 + # https://ctruong.perso.math.cnrs.fr/ruptures-docs/build/html/costs/index.html + # model = "l1" #"l1" # "l2", "rbf" + algo = ruptures.Pelt(model=model, jump=jump, min_size=min_dist).fit(signal) + + if penalty is not None: + bkps = algo.predict(pen=penalty) + if plotting: + fig, ax = ruptures.display(signal, bkps) + plt.show() + return bkps + + print_error("No Penalty specified.") + sys.exit(-1) + + +# calculates the raw_states for measurement measurement. num_measurement is used to identify the +# return value +# penalty, model and jump are directly passed to pelt +def calc_raw_states_func(num_measurement, measurement, penalty, model, jump): + # extract signal + signal = np.array(measurement) + # norm signal to remove dependency on absolute values + normed_signal = norm_signal(signal) + # calculate the breakpoints + bkpts = calc_pelt(normed_signal, penalty, model=model, jump=jump) + calced_states = list() + start_time = 0 + end_time = 0 + # calc metrics for all states + for bkpt in bkpts: + # start_time of state is end_time of previous one + # (Transitions are instantaneous) + start_time = end_time + end_time = bkpt + power_vals = signal[start_time:end_time] + mean_power = np.mean(power_vals) + std_dev = np.std(power_vals) + calced_state = (start_time, end_time, mean_power, std_dev) + calced_states.append(calced_state) + num = 0 + new_avg_std = 0 + # calc avg std for all states from this measurement + for s in calced_states: + # print_info("State " + str(num) + " starts at t=" + str(s[0]) + # + " and ends at t=" + str(s[1]) + # + " while using " + str(s[2]) + # + "uW with sigma=" + str(s[3])) + num = num + 1 + new_avg_std = new_avg_std + s[3] + # check case if no state has been found to avoid crashing + if len(calced_states) != 0: + new_avg_std = new_avg_std / len(calced_states) + else: + new_avg_std = 0 + change_avg_std = None # measurement["uW_std"] - new_avg_std + # print_info("The average standard deviation for the newly found states is " + # + str(new_avg_std)) + # print_info("That is a reduction of " + str(change_avg_std)) + return num_measurement, calced_states, new_avg_std, change_avg_std + + +# parallelize calc over all measurements +def calc_raw_states(arg_list): + with Pool() as pool: + # collect results from pool + result = pool.starmap(calc_raw_states_func, arg_list) + return result + + class PELT: def __init__(self, **kwargs): # Defaults von Janis @@ -54,7 +145,10 @@ class PELT: self, signals, model="l1", min_dist=2, range_min=0, range_max=100, S=1.0 ): # Janis macht hier noch kein norm_signal. Mit sieht es aber genau so brauchbar aus. - signal = norm_signal(signals[0]) + # TODO vor der Penaltybestimmung die Auflösung der Daten auf z.B. 1 kHz + # verringern. Dann geht's deutlich schneller und superkurze + # Substates interessieren uns ohnehin weniger + signal = coarse_signal(norm_signal(signals[0])) algo = ruptures.Pelt(model=model, jump=self.jump, min_size=min_dist).fit(signal) queue = list() for i in range(range_min, range_max + 1): @@ -102,6 +196,44 @@ class PELT: knee = (knee[0] * 1, knee[1]) return knee + def calc_raw_states(self, signals, penalty, opt_model=None): + raw_states_calc_args = list() + for num_measurement, measurement in enumerate(signals): + raw_states_calc_args.append( + (num_measurement, measurement, penalty, opt_model, self.jump) + ) + + raw_states_list = [None] * len(signals) + raw_states_res = calc_raw_states(raw_states_calc_args) + + # extracting result and putting it in correct order -> index of raw_states_list + # entry still corresponds with index of measurement in measurements_by_states + # -> If measurements are discarded the used ones are easily recognized + for ret_val in raw_states_res: + num_measurement = ret_val[0] + raw_states = ret_val[1] + avg_std = ret_val[2] + change_avg_std = ret_val[3] + # FIXME: Wieso gibt mir meine IDE hier eine Warning aus? Der Index müsste doch + # int sein oder nicht? Es scheint auch vernünftig zu klappen... + raw_states_list[num_measurement] = raw_states + # print( + # "The average standard deviation for the newly found states in " + # + "measurement No. " + # + str(num_measurement) + # + " is " + # + str(avg_std) + # ) + # print("That is a reduction of " + str(change_avg_std)) + for i, raw_state in enumerate(raw_states): + print( + f"Measurement #{num_measurement} sub-state #{i}: {raw_state[0]} -> {raw_state[1]}, mean {raw_state[2]}" + ) + # l_signal = measurements_by_config['offline'][num_measurement]['uW'] + # l_bkpts = [s[1] for s in raw_states] + # fig, ax = rpt.display(np.array(l_signal), l_bkpts) + # plt.show() + """ # calculates and returns the necessary penalty for signal. Parallel execution with num_processes many processes # jump, min_dist are passed directly to PELT. S is directly passed to kneedle. |