diff options
author | jfalkenhagen <jfalkenhagen@uos.de> | 2020-07-10 16:19:08 +0200 |
---|---|---|
committer | jfalkenhagen <jfalkenhagen@uos.de> | 2020-07-10 16:19:08 +0200 |
commit | 98a7873ec1ce265e6d229af4fa8416b3a9ef018a (patch) | |
tree | f28987e8f7ba6d7ab65e82cc758495441fed5f75 /bin | |
parent | 71b981c13c007d33f4042823703f98e41ff56770 (diff) |
bin/Proof_Of_Concept_PELT.py: Calculation of raw_states is now parallelized.
Diffstat (limited to 'bin')
-rw-r--r-- | bin/Proof_Of_Concept_PELT.py | 121 |
1 files changed, 79 insertions, 42 deletions
diff --git a/bin/Proof_Of_Concept_PELT.py b/bin/Proof_Of_Concept_PELT.py index 92d09fa..bcbd53e 100644 --- a/bin/Proof_Of_Concept_PELT.py +++ b/bin/Proof_Of_Concept_PELT.py @@ -6,7 +6,6 @@ import re from multiprocessing import Pool, Manager from kneed import KneeLocator from sklearn.cluster import AgglomerativeClustering -from scipy.signal import find_peaks import matplotlib.pyplot as plt import ruptures as rpt import numpy as np @@ -287,6 +286,50 @@ def calculate_penalty_value(signal, model="l1", jump=5, min_dist=2, range_min=0, # return False +# raw_states_calc_args.append((num_measurement, measurement, penalty, opt_model +# , opt_jump)) +def calc_raw_states_func(num_trace, measurement, penalty, model, jump): + signal = np.array(measurement['uW']) + normed_signal = norm_signal(signal) + bkpts = calc_pelt(normed_signal, penalty, model=model, jump=jump) + calced_states = list() + start_time = 0 + end_time = 0 + for bkpt in bkpts: + # start_time of state is end_time of previous one + # (Transitions are instantaneous) + start_time = end_time + end_time = bkpt + power_vals = signal[start_time: end_time] + mean_power = np.mean(power_vals) + std_dev = np.std(power_vals) + calced_state = (start_time, end_time, mean_power, std_dev) + calced_states.append(calced_state) + num = 0 + new_avg_std = 0 + for s in calced_states: + # print_info("State " + str(num) + " starts at t=" + str(s[0]) + # + " and ends at t=" + str(s[1]) + # + " while using " + str(s[2]) + # + "uW with sigma=" + str(s[3])) + num = num + 1 + new_avg_std = new_avg_std + s[3] + new_avg_std = new_avg_std / len(calced_states) + change_avg_std = measurement['uW_std'] - new_avg_std + # print_info("The average standard deviation for the newly found states is " + # + str(new_avg_std)) + # print_info("That is a reduction of " + str(change_avg_std)) + return num_trace, calced_states, new_avg_std, change_avg_std + + +def calc_raw_states(arg_list, num_processes=8): + m = Manager() + with Pool(processes=num_processes) as p: + # collect results from pool + result = p.starmap(calc_raw_states_func, arg_list) + return result + + # Very short benchmark yielded approx. 3 times the speed of solution not using sort # TODO: Decide whether median is really the better baseline than mean def needs_refinement(signal, thresh): @@ -477,10 +520,9 @@ if __name__ == '__main__': signal = measurement['uW'] # mean = measurement['uW_mean'] # TODO: Decide if median is really the better baseline than mean - if needs_refinement(signal, opt_refinement_thresh): + if needs_refinement(signal, opt_refinement_thresh) and not refine: print_info("Refinement is necessary!") refine = True - break if not refine: print_info("No refinement necessary for state '" + measurements_by_state['name'] + "' with params: " + str(measurements_by_state['parameter'])) @@ -499,45 +541,34 @@ if __name__ == '__main__': penalty = penalty[0] else: penalty = opt_pen_override - # calc and save all bkpts for the given state and param config - raw_states_list = list() - for measurement in measurements_by_state['offline']: - signal = np.array(measurement['uW']) - normed_signal = norm_signal(signal) - bkpts = calc_pelt(normed_signal, penalty, model=opt_model, jump=opt_jump) - calced_states = list() - start_time = 0 - end_time = 0 - for bkpt in bkpts: - # start_time of state is end_time of previous one - # (Transitions are instantaneous) - start_time = end_time - end_time = bkpt - power_vals = signal[start_time: end_time] - mean_power = np.mean(power_vals) - std_dev = np.std(power_vals) - calced_state = (start_time, end_time, mean_power, std_dev) - calced_states.append(calced_state) - num = 0 - new_avg_std = 0 - for s in calced_states: - print_info("State " + str(num) + " starts at t=" + str(s[0]) - + " and ends at t=" + str(s[1]) - + " while using " + str(s[2]) - + "uW with sigma=" + str(s[3])) - num = num + 1 - new_avg_std = new_avg_std + s[3] - new_avg_std = new_avg_std / len(calced_states) - change_avg_std = measurement['uW_std'] - new_avg_std - print_info("The average standard deviation for the newly found states is " - + str(new_avg_std)) + # build arguments for parallel excecution + print_info("Starting raw_states calculation.") + raw_states_calc_args = [] + for num_measurement, measurement in enumerate(measurements_by_state['offline']): + raw_states_calc_args.append((num_measurement, measurement, penalty, + opt_model, opt_jump)) + + raw_states_list = [None] * len(measurements_by_state['offline']) + raw_states_res = calc_raw_states(raw_states_calc_args, opt_num_processes) + # extracting result and putting it in correct order -> index of raw_states_list + # entry still corresponds with index of measurement in measurements_by_states + # -> If measurements are discarded the correct ones are easily recognized + for ret_val in raw_states_res: + num_trace = ret_val[0] + raw_states = ret_val[1] + avg_std = ret_val[2] + change_avg_std = ret_val[3] + # TODO: Wieso gibt mir meine IDE hier eine Warning aus? Der Index müsste doch + # int sein oder nicht? Es scheint auch vernünftig zu klappen... + raw_states_list[num_trace] = raw_states + print_info("The average standard deviation for the newly found states in " + + "measurement No. " + str(num_trace) + " is " + str(avg_std)) print_info("That is a reduction of " + str(change_avg_std)) - raw_states_list.append(calced_states) + print_info("Finished raw_states calculation.") num_states_array = [int()] * len(raw_states_list) i = 0 - for x in raw_states_list: + for i, x in enumerate(raw_states_list): num_states_array[i] = len(x) - i = i + 1 avg_num_states = np.mean(num_states_array) num_states_dev = np.std(num_states_array) print_info("On average " + str(avg_num_states) @@ -558,7 +589,7 @@ if __name__ == '__main__': i = 0 cluster_labels_list = [] num_cluster_list = [] - for raw_states in raw_states_list: + for num_trace, raw_states in enumerate(raw_states_list): # iterate through raw states from measurements if len(raw_states) == num_raw_states: # build array with power values to cluster these @@ -580,12 +611,14 @@ if __name__ == '__main__': # plt.show() # TODO: Automatic detection of number of clusters. Aktuell noch MAGIC NUMBER # im distance_threshold - cluster = AgglomerativeClustering(n_clusters=None, compute_full_tree=True, affinity='euclidean', - linkage='ward', distance_threshold=opt_refinement_thresh*100) + cluster = AgglomerativeClustering(n_clusters=None, compute_full_tree=True, + affinity='euclidean', + linkage='ward', + distance_threshold=opt_refinement_thresh * 100) # cluster = AgglomerativeClustering(n_clusters=5, affinity='euclidean', # linkage='ward') cluster.fit_predict(value_to_cluster) - print_info("Cluster labels:\n" + str(cluster.labels_)) + # print_info("Cluster labels:\n" + str(cluster.labels_)) # plt.scatter(value_to_cluster[:, 0], value_to_cluster[:, 1], c=cluster.labels_, cmap='rainbow') # plt.show() # TODO: Problem: Der Algorithmus nummeriert die Zustände nicht immer gleich... also bspw.: @@ -593,6 +626,9 @@ if __name__ == '__main__': cluster_labels_list.append(cluster.labels_) num_cluster_list.append(cluster.n_clusters_) i = i + 1 + else: + print_info("Discarding measurement No. " + str(num_trace) + " because it " + + "did not recognize the number of raw_states correctly.") if i != len(raw_states_list): if i / len(raw_states_list) <= 0.5: print_warning("Only used " + str(i) + "/" + str(len(raw_states_list)) @@ -603,6 +639,7 @@ if __name__ == '__main__': print_info("Used " + str(i) + "/" + str(len(raw_states_list)) + " Measurements for refinement. " "Others did not recognize number of states correctly.") + # TODO: DEBUG Kram sys.exit() else: print_info("Used all available measurements.") |