diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-02-23 10:32:00 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-02-23 10:32:00 +0100 |
commit | 5cba61eb83ffb4954c924a789305bc0114b6c7ca (patch) | |
tree | 61ac32c7f8b52ccbd8d3ee1c2888183225f0889e | |
parent | e537dafe711dfbf1cc643442a55668bd285a3c6e (diff) |
fix drift compensation and reduce pelt + caching overhead
-rw-r--r-- | lib/lennart/DataProcessor.py | 56 | ||||
-rw-r--r-- | lib/pelt.py | 48 |
2 files changed, 60 insertions, 44 deletions
diff --git a/lib/lennart/DataProcessor.py b/lib/lennart/DataProcessor.py index 0fda100..6df813e 100644 --- a/lib/lennart/DataProcessor.py +++ b/lib/lennart/DataProcessor.py @@ -216,9 +216,7 @@ class DataProcessor: """Use ruptures (e.g. Pelt, Dynp) to determine transition timestamps.""" from dfatool.pelt import PELT - # TODO die Anzahl Changepoints ist a priori bekannt, es könnte mit ruptures.Dynp statt ruptures.Pelt besser funktionieren. - # Vielleicht sollte man auch "rbf" statt "l1" nutzen. - # "rbf" und "l2" scheinen ähnlich gut zu funktionieren, l2 ist schneller. + # "rbf" und "l2" scheinen ähnlich gut zu funktionieren, l2 ist schneller. l1 ist wohl noch besser. # PELT does not find changepoints for transitions which span just four or five data points (i.e., transitions shorter than ~2ms). # Workaround: Double the data rate passed to PELT by interpolation ("stretch=2") pelt = PELT(with_multiprocessing=False, stretch=2, min_dist=1) @@ -229,6 +227,10 @@ class DataProcessor: # TODO auch Kandidatenbestimmung per Ableitung probieren # (-> Umgebungsvariable zur Auswahl) + pelt_traces = list() + timestamps = list() + candidate_weights = list() + for i, expected_start_ts in enumerate(expected_transition_start_timestamps): expected_end_ts = sync_timestamps[2 * i + 1] # assumption: maximum deviation between expected and actual timestamps is 5ms. @@ -239,30 +241,34 @@ class DataProcessor: et_timestamps_end = bisect_right( self.et_timestamps, expected_end_ts + 10e-3 ) - timestamps = self.et_timestamps[et_timestamps_start : et_timestamps_end + 1] - energy_data = self.et_power_values[ - et_timestamps_start : et_timestamps_end + 1 - ] - candidate_weight = dict() + timestamps.append( + self.et_timestamps[et_timestamps_start : et_timestamps_end + 1] + ) + pelt_traces.append( + self.et_power_values[et_timestamps_start : et_timestamps_end + 1] + ) # TODO for greedy mode, perform changepoint detection between greedy steps # (-> the expected changepoint area is well-known, Dynp with 1/2 changepoints # should work much better than "somewhere in these 20ms there should be a transition") - if 0: - penalties = (None,) - elif os.getenv("DFATOOL_DRIFT_COMPENSATION_PENALTY"): - penalties = (int(os.getenv("DFATOOL_DRIFT_COMPENSATION_PENALTY")),) - else: - penalties = (1, 2, 5, 10, 15, 20) - for penalty in penalties: - for changepoint in pelt.get_changepoints( - energy_data, penalty=penalty, num_changepoints=1 - ): - if changepoint in candidate_weight: - candidate_weight[changepoint] += 1 + if os.getenv("DFATOOL_DRIFT_COMPENSATION_PENALTY"): + penalties = (int(os.getenv("DFATOOL_DRIFT_COMPENSATION_PENALTY")),) + else: + penalties = (1, 2, 5, 10, 15, 20) + for penalty in penalties: + changepoints_by_transition = pelt.get_changepoints( + pelt_traces, penalty=penalty + ) + for i in range(len(expected_transition_start_timestamps)): + candidate_weights.append(dict()) + for changepoint in changepoints_by_transition[i]: + if changepoint in candidate_weights[i]: + candidate_weights[i][changepoint] += 1 else: - candidate_weight[changepoint] = 1 + candidate_weights[i][changepoint] = 1 + + for i, expected_start_ts in enumerate(expected_transition_start_timestamps): # TODO ist expected_start_ts wirklich eine gute Referenz? Wenn vor einer Transition ein UART-Dump # liegt, dürfte expected_end_ts besser sein, dann muss allerdings bei der compensation wieder auf @@ -271,11 +277,11 @@ class DataProcessor: list( map( lambda k: ( - timestamps[k] - expected_start_ts, - timestamps[k] - expected_end_ts, - candidate_weight[k], + timestamps[i][k] - expected_start_ts, + timestamps[i][k] - expected_end_ts, + candidate_weights[i][k], ), - sorted(candidate_weight.keys()), + sorted(candidate_weights[i].keys()), ) ) ) diff --git a/lib/pelt.py b/lib/pelt.py index 3336e4a..10bb135 100644 --- a/lib/pelt.py +++ b/lib/pelt.py @@ -47,12 +47,12 @@ class PELT: self.range_max = 89 self.stretch = 1 self.with_multiprocessing = True + self.cache_dir = "cache" self.__dict__.update(kwargs) self.jump = int(self.jump) self.min_dist = int(self.min_dist) self.stretch = int(self.stretch) - self.cache_dir = "cache" if os.getenv("DFATOOL_PELT_MODEL"): # https://centre-borelli.github.io/ruptures-docs/user-guide/costs/costl1/ @@ -120,7 +120,7 @@ class PELT: pass with open(f"{self.cache_dir}/{cache_key[:2]}/pelt-{cache_key}.json", "w") as f: - json.dump(data, f) + json.dump(data, f, cls=NpEncoder) def load_cache(self, traces, penalty, num_changepoints): cache_key = self.cache_key(traces, penalty, num_changepoints) @@ -131,8 +131,17 @@ class PELT: return json.load(f) except FileNotFoundError: return None + except json.decoder.JSONDecodeError: + logger.warning( + f"Ignoring invalid cache entry {self.cache_dir}/{cache_key[:2]}/pelt-{cache_key}.json" + ) + return None def get_penalty_and_changepoints(self, traces, penalty=None, num_changepoints=None): + list_of_lists = type(traces[0]) is list or type(traces[0]) is np.ndarray + if not list_of_lists: + traces = [traces] + data = self.load_cache(traces, penalty, num_changepoints) if data: for res in data: @@ -140,13 +149,18 @@ class PELT: str_keys = list(res[1].keys()) for k in str_keys: res[1][int(k)] = res[1].pop(k) - return data + if list_of_lists: + return data + return data[0] data = self.calculate_penalty_and_changepoints( traces, penalty, num_changepoints ) self.save_cache(traces, penalty, num_changepoints, data) - return data + + if list_of_lists: + return data + return data[0] def calculate_penalty_and_changepoints( self, traces, penalty=None, num_changepoints=None @@ -156,11 +170,6 @@ class PELT: # long as --pelt isn't active. import ruptures - list_of_lists = type(traces[0]) is list or type(traces[0]) is np.ndarray - - if not list_of_lists: - traces = [traces] - if self.stretch != 1: traces = list( map( @@ -216,17 +225,19 @@ class PELT: if len(changepoints) and changepoints[0] == 0: changepoints.pop(0) if self.stretch != 1: - changepoints = np.array( - np.around(np.array(changepoints) / self.stretch), dtype=np.int + changepoints = list( + np.array( + np.around(np.array(changepoints) / self.stretch), dtype=np.int + ) ) changepoints_by_penalty_by_trace[i][range_penalty] = changepoints for i in range(len(traces)): changepoints_by_penalty = changepoints_by_penalty_by_trace[i] if penalty is not None: - results.append((penalty, changepoints_by_penalty[penalty])) + results.append((penalty, changepoints_by_penalty)) elif self.algo == "dynp" and num_changepoints is not None: - results.append((None, changepoints_by_penalty[None])) + results.append((None, {0: changepoints_by_penalty[None]})) else: results.append( ( @@ -235,10 +246,7 @@ class PELT: ) ) - if list_of_lists: - return results - else: - return results[0] + return results def find_penalty(self, changepoints_by_penalty): changepoint_counts = list() @@ -271,12 +279,14 @@ class PELT: def get_changepoints(self, traces, **kwargs): results = self.get_penalty_and_changepoints(traces, **kwargs) if type(results) is list: - return list(map(lambda res: res[1][res[0]])) + return list(map(lambda res: res[1][res[0]], results)) return results[1][results[0]] def get_penalty(self, traces, **kwargs): results = self.get_penalty_and_changepoints(traces, **kwargs) - return list(map(lambda res: res[0])) + if type(results) is list: + return list(map(lambda res: res[0])) + return res[0] def calc_raw_states( self, |