summaryrefslogtreecommitdiff
path: root/bin
diff options
context:
space:
mode:
authorjfalkenhagen <jfalkenhagen@uos.de>2020-07-10 16:19:08 +0200
committerjfalkenhagen <jfalkenhagen@uos.de>2020-07-10 16:19:08 +0200
commit98a7873ec1ce265e6d229af4fa8416b3a9ef018a (patch)
treef28987e8f7ba6d7ab65e82cc758495441fed5f75 /bin
parent71b981c13c007d33f4042823703f98e41ff56770 (diff)
bin/Proof_Of_Concept_PELT.py: Calculation of raw_states is now parallelized.
Diffstat (limited to 'bin')
-rw-r--r--bin/Proof_Of_Concept_PELT.py121
1 files changed, 79 insertions, 42 deletions
diff --git a/bin/Proof_Of_Concept_PELT.py b/bin/Proof_Of_Concept_PELT.py
index 92d09fa..bcbd53e 100644
--- a/bin/Proof_Of_Concept_PELT.py
+++ b/bin/Proof_Of_Concept_PELT.py
@@ -6,7 +6,6 @@ import re
from multiprocessing import Pool, Manager
from kneed import KneeLocator
from sklearn.cluster import AgglomerativeClustering
-from scipy.signal import find_peaks
import matplotlib.pyplot as plt
import ruptures as rpt
import numpy as np
@@ -287,6 +286,50 @@ def calculate_penalty_value(signal, model="l1", jump=5, min_dist=2, range_min=0,
# return False
+# raw_states_calc_args.append((num_measurement, measurement, penalty, opt_model
+# , opt_jump))
+def calc_raw_states_func(num_trace, measurement, penalty, model, jump):
+ signal = np.array(measurement['uW'])
+ normed_signal = norm_signal(signal)
+ bkpts = calc_pelt(normed_signal, penalty, model=model, jump=jump)
+ calced_states = list()
+ start_time = 0
+ end_time = 0
+ for bkpt in bkpts:
+ # start_time of state is end_time of previous one
+ # (Transitions are instantaneous)
+ start_time = end_time
+ end_time = bkpt
+ power_vals = signal[start_time: end_time]
+ mean_power = np.mean(power_vals)
+ std_dev = np.std(power_vals)
+ calced_state = (start_time, end_time, mean_power, std_dev)
+ calced_states.append(calced_state)
+ num = 0
+ new_avg_std = 0
+ for s in calced_states:
+ # print_info("State " + str(num) + " starts at t=" + str(s[0])
+ # + " and ends at t=" + str(s[1])
+ # + " while using " + str(s[2])
+ # + "uW with sigma=" + str(s[3]))
+ num = num + 1
+ new_avg_std = new_avg_std + s[3]
+ new_avg_std = new_avg_std / len(calced_states)
+ change_avg_std = measurement['uW_std'] - new_avg_std
+ # print_info("The average standard deviation for the newly found states is "
+ # + str(new_avg_std))
+ # print_info("That is a reduction of " + str(change_avg_std))
+ return num_trace, calced_states, new_avg_std, change_avg_std
+
+
+def calc_raw_states(arg_list, num_processes=8):
+ m = Manager()
+ with Pool(processes=num_processes) as p:
+ # collect results from pool
+ result = p.starmap(calc_raw_states_func, arg_list)
+ return result
+
+
# Very short benchmark yielded approx. 3 times the speed of solution not using sort
# TODO: Decide whether median is really the better baseline than mean
def needs_refinement(signal, thresh):
@@ -477,10 +520,9 @@ if __name__ == '__main__':
signal = measurement['uW']
# mean = measurement['uW_mean']
# TODO: Decide if median is really the better baseline than mean
- if needs_refinement(signal, opt_refinement_thresh):
+ if needs_refinement(signal, opt_refinement_thresh) and not refine:
print_info("Refinement is necessary!")
refine = True
- break
if not refine:
print_info("No refinement necessary for state '" + measurements_by_state['name']
+ "' with params: " + str(measurements_by_state['parameter']))
@@ -499,45 +541,34 @@ if __name__ == '__main__':
penalty = penalty[0]
else:
penalty = opt_pen_override
- # calc and save all bkpts for the given state and param config
- raw_states_list = list()
- for measurement in measurements_by_state['offline']:
- signal = np.array(measurement['uW'])
- normed_signal = norm_signal(signal)
- bkpts = calc_pelt(normed_signal, penalty, model=opt_model, jump=opt_jump)
- calced_states = list()
- start_time = 0
- end_time = 0
- for bkpt in bkpts:
- # start_time of state is end_time of previous one
- # (Transitions are instantaneous)
- start_time = end_time
- end_time = bkpt
- power_vals = signal[start_time: end_time]
- mean_power = np.mean(power_vals)
- std_dev = np.std(power_vals)
- calced_state = (start_time, end_time, mean_power, std_dev)
- calced_states.append(calced_state)
- num = 0
- new_avg_std = 0
- for s in calced_states:
- print_info("State " + str(num) + " starts at t=" + str(s[0])
- + " and ends at t=" + str(s[1])
- + " while using " + str(s[2])
- + "uW with sigma=" + str(s[3]))
- num = num + 1
- new_avg_std = new_avg_std + s[3]
- new_avg_std = new_avg_std / len(calced_states)
- change_avg_std = measurement['uW_std'] - new_avg_std
- print_info("The average standard deviation for the newly found states is "
- + str(new_avg_std))
+ # build arguments for parallel execution
+ print_info("Starting raw_states calculation.")
+ raw_states_calc_args = []
+ for num_measurement, measurement in enumerate(measurements_by_state['offline']):
+ raw_states_calc_args.append((num_measurement, measurement, penalty,
+ opt_model, opt_jump))
+
+ raw_states_list = [None] * len(measurements_by_state['offline'])
+ raw_states_res = calc_raw_states(raw_states_calc_args, opt_num_processes)
+ # extract each result and store it in the correct order -> the index of a raw_states_list
+ # entry still corresponds to the index of the measurement in measurements_by_state['offline']
+ # -> if measurements are discarded, the correct ones are still easily identified
+ for ret_val in raw_states_res:
+ num_trace = ret_val[0]
+ raw_states = ret_val[1]
+ avg_std = ret_val[2]
+ change_avg_std = ret_val[3]
+ # TODO: Why does my IDE emit a warning here? The index should be an
+ # int, shouldn't it? It also seems to work properly...
+ raw_states_list[num_trace] = raw_states
+ print_info("The average standard deviation for the newly found states in "
+ + "measurement No. " + str(num_trace) + " is " + str(avg_std))
print_info("That is a reduction of " + str(change_avg_std))
- raw_states_list.append(calced_states)
+ print_info("Finished raw_states calculation.")
num_states_array = [int()] * len(raw_states_list)
i = 0
- for x in raw_states_list:
+ for i, x in enumerate(raw_states_list):
num_states_array[i] = len(x)
- i = i + 1
avg_num_states = np.mean(num_states_array)
num_states_dev = np.std(num_states_array)
print_info("On average " + str(avg_num_states)
@@ -558,7 +589,7 @@ if __name__ == '__main__':
i = 0
cluster_labels_list = []
num_cluster_list = []
- for raw_states in raw_states_list:
+ for num_trace, raw_states in enumerate(raw_states_list):
# iterate through raw states from measurements
if len(raw_states) == num_raw_states:
# build array with power values to cluster these
@@ -580,12 +611,14 @@ if __name__ == '__main__':
# plt.show()
# TODO: Automatic detection of number of clusters. Aktuell noch MAGIC NUMBER
# im distance_threshold
- cluster = AgglomerativeClustering(n_clusters=None, compute_full_tree=True, affinity='euclidean',
- linkage='ward', distance_threshold=opt_refinement_thresh*100)
+ cluster = AgglomerativeClustering(n_clusters=None, compute_full_tree=True,
+ affinity='euclidean',
+ linkage='ward',
+ distance_threshold=opt_refinement_thresh * 100)
# cluster = AgglomerativeClustering(n_clusters=5, affinity='euclidean',
# linkage='ward')
cluster.fit_predict(value_to_cluster)
- print_info("Cluster labels:\n" + str(cluster.labels_))
+ # print_info("Cluster labels:\n" + str(cluster.labels_))
# plt.scatter(value_to_cluster[:, 0], value_to_cluster[:, 1], c=cluster.labels_, cmap='rainbow')
# plt.show()
# TODO: Problem: Der Algorithmus nummeriert die Zustände nicht immer gleich... also bspw.:
@@ -593,6 +626,9 @@ if __name__ == '__main__':
cluster_labels_list.append(cluster.labels_)
num_cluster_list.append(cluster.n_clusters_)
i = i + 1
+ else:
+ print_info("Discarding measurement No. " + str(num_trace) + " because it "
+ + "did not recognize the number of raw_states correctly.")
if i != len(raw_states_list):
if i / len(raw_states_list) <= 0.5:
print_warning("Only used " + str(i) + "/" + str(len(raw_states_list))
@@ -603,6 +639,7 @@ if __name__ == '__main__':
print_info("Used " + str(i) + "/" + str(len(raw_states_list))
+ " Measurements for refinement. "
"Others did not recognize number of states correctly.")
+ # TODO: DEBUG stuff
sys.exit()
else:
print_info("Used all available measurements.")