diff options
Diffstat (limited to 'bin/Proof_Of_Concept_PELT.py')
-rw-r--r-- | bin/Proof_Of_Concept_PELT.py | 111 |
1 files changed, 97 insertions, 14 deletions
diff --git a/bin/Proof_Of_Concept_PELT.py b/bin/Proof_Of_Concept_PELT.py index d4878c1..80f7c04 100644 --- a/bin/Proof_Of_Concept_PELT.py +++ b/bin/Proof_Of_Concept_PELT.py @@ -10,6 +10,11 @@ import getopt import re from dfatool.dfatool import RawData +from sklearn.cluster import AgglomerativeClustering +from scipy.cluster.hierarchy import dendrogram, linkage + +# py bin\Proof_Of_Concept_PELT.py --filename="..\data\TX.json" --jump=1 --pen_override=10 --refinement_thresh=100 + def plot_data_from_json(filename, trace_num, x_axis, y_axis): with open(filename, 'r') as f: @@ -60,7 +65,7 @@ def find_knee_point(data_x, data_y, S=1.0, curve='convex', direction='decreasing def calc_pelt(signal, model='l1', jump=5, min_dist=2, range_min=0, range_max=50, num_processes=8, refresh_delay=1, - refresh_thresh=5, S=1.0, pen_override=None, plotting=False): + refresh_thresh=5, S=1.0, pen_override=None, pen_modifier=None, plotting=False): # default params in Function if model is None: model = 'l1' @@ -82,6 +87,8 @@ def calc_pelt(signal, model='l1', jump=5, min_dist=2, range_min=0, range_max=50, S = 1.0 if plotting is None: plotting = False + if pen_modifier is None: + pen_modifier = 1 # change point detection. best fit seemingly with l1. rbf prods. RuntimeErr for pen > 30 # https://ctruong.perso.math.cnrs.fr/ruptures-docs/build/html/costs/index.html # model = "l1" #"l1" # "l2", "rbf" @@ -98,7 +105,7 @@ def calc_pelt(signal, model='l1', jump=5, min_dist=2, range_min=0, range_max=50, for i in range(range_min, range_max + 1): args.append((algo, i, q)) - print('[INFO]starting kneepoint calculation.') + print_info('starting kneepoint calculation.') # init Pool with num_proesses with Pool(num_processes) as p: # collect results from pool @@ -114,7 +121,7 @@ def calc_pelt(signal, model='l1', jump=5, min_dist=2, range_min=0, range_max=50, last_percentage = percentage percentage = round(size / (range_max - range_min) * 100, 2) if percentage >= last_percentage + 2 or i >= refresh_thresh: - print('[INFO]Current progress: ' + str(percentage) + '%') + print_info('Current progress: ' + str(percentage) + '%') i = 0 else: i += 1 @@ -133,6 +140,8 @@ def calc_pelt(signal, model='l1', jump=5, min_dist=2, range_min=0, range_max=50, # plt.vlines(knee[0], 0, max(fitted_bkps_val), linestyles='dashed') # print("knee: " + str(knee[0])) # plt.show() + # modify knee according to options. Defaults to 1 * knee + knee = (knee[0] * pen_modifier, knee[1]) else: # use forced pen value for plotting if specified. Else use only pen in range if pen_override is not None: @@ -237,15 +246,21 @@ def needs_refinement(signal, thresh): def print_info(str): - print("[INFO]" + str) + str_lst = str.split(sep='\n') + for str in str_lst: + print("[INFO]" + str) def print_warning(str): - print("[WARNING]" + str) + str_lst = str.split(sep='\n') + for str in str_lst: + print("[WARNING]" + str) def print_error(str): - print("ERROR" + str, file=sys.stderr) + str_lst = str.split(sep='\n') + for str in str_lst: + print("[ERROR]" + str, file=sys.stderr) if __name__ == '__main__': @@ -265,6 +280,7 @@ if __name__ == '__main__': "refresh_thresh= " "S= " "pen_override= " + "pen_modifier= " "plotting= " "refinement_thresh= " ) @@ -280,6 +296,7 @@ if __name__ == '__main__': opt_refresh_thresh = None opt_S = None opt_pen_override = None + opt_pen_modifier = None opt_plotting = False opt_refinement_thresh = None try: @@ -353,6 +370,12 @@ if __name__ == '__main__': except ValueError as verr: print(verr, file=sys.stderr) sys.exit(2) + if 'pen_modifier' in opt: + try: + opt_pen_modifier = float(opt['pen_modifier']) + except ValueError as verr: + print(verr, file=sys.stderr) + sys.exit(2) if 'refinement_thresh' in opt: try: opt_refinement_thresh = int(opt['refinement_thresh']) @@ -390,7 +413,7 @@ if __name__ == '__main__': print_info("No refinement necessary for state '" + measurements_by_state['name'] + "'") else: # calc and save all bkpts for the given state and param config - state_list = list() + raw_states_list = list() for measurement in measurements_by_state['offline']: signal = np.array(measurement['uW']) normed_signal = np.zeros(shape=len(signal)) @@ -398,7 +421,7 @@ if __name__ == '__main__': normed_signal[i] = signal[i] / 1000 bkpts = calc_pelt(normed_signal, model=opt_model, range_min=opt_range_min, range_max=opt_range_max, num_processes=opt_num_processes, jump=opt_jump, S=opt_S, - pen_override=opt_pen_override) + pen_override=opt_pen_override, pen_modifier=opt_pen_modifier) calced_states = list() start_time = 0 end_time = 0 @@ -420,12 +443,12 @@ if __name__ == '__main__': new_avg_std = new_avg_std + s[3] new_avg_std = new_avg_std / len(calced_states) change_avg_std = measurement['uW_std'] - new_avg_std - print_info("The average standard deviation for the newly found states is " + str(new_avg_std) - + ".\n[INFO]That is a reduction of " + str(change_avg_std)) - state_list.append(calced_states) - num_states_array = np.zeros(shape=len(measurements_by_state['offline'])) + print_info("The average standard deviation for the newly found states is " + str(new_avg_std)) + print_info("That is a reduction of " + str(change_avg_std)) + raw_states_list.append(calced_states) + num_states_array = [int()] * len(raw_states_list) i = 0 - for x in state_list: + for x in raw_states_list: num_states_array[i] = len(x) i = i + 1 avg_num_states = np.mean(num_states_array) @@ -435,10 +458,70 @@ if __name__ == '__main__': # TODO: MAGIC NUMBER if num_states_dev > 1: print_warning("The number of states varies strongly across measurements. Consider choosing a " - "larger value for S.") + "larger value for S or using the pen_modifier option.") time.sleep(5) # TODO: Wie bekomme ich da jetzt raus, was die Wahrheit ist? # Einfach Durchschnitt nehmen? + # Preliminary decision: Further on only use the traces, which have the most frequent state count + counts = np.bincount(num_states_array) + num_raw_states = np.argmax(counts) + print_info("Choose " + str(num_raw_states) + " as number of raw_states.") + i = 0 + cluster_labels_list = [] + num_cluster_list = [] + for raw_states in raw_states_list: + # iterate through raw states from measurements + if len(raw_states) == num_raw_states: + # build array with power values to cluster these + value_to_cluster = np.zeros((num_raw_states, 2)) + j = 0 + for s in raw_states: + value_to_cluster[j][0] = s[2] + value_to_cluster[j][1] = 0 + j = j + 1 + # linked = linkage(value_to_cluster, 'single') + # + # labelList = range(1, 11) + # + # plt.figure(figsize=(10, 7)) + # dendrogram(linked, + # orientation='top', + # distance_sort='descending', + # show_leaf_counts=True) + # plt.show() + # TODO: Automatic detection of number of clusters. Aktuell noch MAGIC NUMBER + # cluster = AgglomerativeClustering(n_clusters=None, compute_full_tree=True, affinity='euclidean', + # linkage='ward', distance_threshold=opt_refinement_thresh) + cluster = AgglomerativeClustering(n_clusters=5, affinity='euclidean', linkage='ward') + cluster.fit_predict(value_to_cluster) + print_info("Cluster labels:\n" + str(cluster.labels_)) + # plt.scatter(value_to_cluster[:, 0], value_to_cluster[:, 1], c=cluster.labels_, cmap='rainbow') + # plt.show() + # TODO: Problem: Der Algorithmus nummeriert die Zustände nicht immer gleich... also bspw.: + # mal ist das tatsächliche Transmit mit 1 belabelt und mal mit 3 + cluster_labels_list.append(cluster.labels_) + num_cluster_list.append(cluster.n_clusters_) + i = i + 1 + if i != len(raw_states_list): + print_info("Used " + str(i) + "/" + str(len(raw_states_list)) + + " Measurements for state clustering. " + "Others did not recognize number of states correctly.") + num_states = np.argmax(np.bincount(num_cluster_list)) + resulting_sequence = [None] * num_raw_states + i = 0 + for x in resulting_sequence: + j = 0 + test_list = [] + for arr in cluster_labels_list: + if num_cluster_list[j] != num_states: + j = j + 1 + else: + test_list.append(arr[i]) + j = j + 1 + resulting_sequence[i] = np.argmax(np.bincount(test_list)) + i = i + 1 + print(resulting_sequence) + # TODO: TESTING PURPOSES exit() |