diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-02-23 10:32:00 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-02-23 10:32:00 +0100 |
commit | 5cba61eb83ffb4954c924a789305bc0114b6c7ca (patch) | |
tree | 61ac32c7f8b52ccbd8d3ee1c2888183225f0889e /lib/pelt.py | |
parent | e537dafe711dfbf1cc643442a55668bd285a3c6e (diff) |
fix drift compensation and reduce pelt + caching overhead
Diffstat (limited to 'lib/pelt.py')
-rw-r--r-- | lib/pelt.py | 48 |
1 files changed, 29 insertions, 19 deletions
diff --git a/lib/pelt.py b/lib/pelt.py index 3336e4a..10bb135 100644 --- a/lib/pelt.py +++ b/lib/pelt.py @@ -47,12 +47,12 @@ class PELT: self.range_max = 89 self.stretch = 1 self.with_multiprocessing = True + self.cache_dir = "cache" self.__dict__.update(kwargs) self.jump = int(self.jump) self.min_dist = int(self.min_dist) self.stretch = int(self.stretch) - self.cache_dir = "cache" if os.getenv("DFATOOL_PELT_MODEL"): # https://centre-borelli.github.io/ruptures-docs/user-guide/costs/costl1/ @@ -120,7 +120,7 @@ class PELT: pass with open(f"{self.cache_dir}/{cache_key[:2]}/pelt-{cache_key}.json", "w") as f: - json.dump(data, f) + json.dump(data, f, cls=NpEncoder) def load_cache(self, traces, penalty, num_changepoints): cache_key = self.cache_key(traces, penalty, num_changepoints) @@ -131,8 +131,17 @@ class PELT: return json.load(f) except FileNotFoundError: return None + except json.decoder.JSONDecodeError: + logger.warning( + f"Ignoring invalid cache entry {self.cache_dir}/{cache_key[:2]}/pelt-{cache_key}.json" + ) + return None def get_penalty_and_changepoints(self, traces, penalty=None, num_changepoints=None): + list_of_lists = type(traces[0]) is list or type(traces[0]) is np.ndarray + if not list_of_lists: + traces = [traces] + data = self.load_cache(traces, penalty, num_changepoints) if data: for res in data: @@ -140,13 +149,18 @@ class PELT: str_keys = list(res[1].keys()) for k in str_keys: res[1][int(k)] = res[1].pop(k) - return data + if list_of_lists: + return data + return data[0] data = self.calculate_penalty_and_changepoints( traces, penalty, num_changepoints ) self.save_cache(traces, penalty, num_changepoints, data) - return data + + if list_of_lists: + return data + return data[0] def calculate_penalty_and_changepoints( self, traces, penalty=None, num_changepoints=None @@ -156,11 +170,6 @@ class PELT: # long as --pelt isn't active. import ruptures - list_of_lists = type(traces[0]) is list or type(traces[0]) is np.ndarray - - if not list_of_lists: - traces = [traces] - if self.stretch != 1: traces = list( map( @@ -216,17 +225,19 @@ class PELT: if len(changepoints) and changepoints[0] == 0: changepoints.pop(0) if self.stretch != 1: - changepoints = np.array( - np.around(np.array(changepoints) / self.stretch), dtype=np.int + changepoints = list( + np.array( + np.around(np.array(changepoints) / self.stretch), dtype=np.int + ) ) changepoints_by_penalty_by_trace[i][range_penalty] = changepoints for i in range(len(traces)): changepoints_by_penalty = changepoints_by_penalty_by_trace[i] if penalty is not None: - results.append((penalty, changepoints_by_penalty[penalty])) + results.append((penalty, changepoints_by_penalty)) elif self.algo == "dynp" and num_changepoints is not None: - results.append((None, changepoints_by_penalty[None])) + results.append((None, {0: changepoints_by_penalty[None]})) else: results.append( ( @@ -235,10 +246,7 @@ class PELT: ) ) - if list_of_lists: - return results - else: - return results[0] + return results def find_penalty(self, changepoints_by_penalty): changepoint_counts = list() @@ -271,12 +279,14 @@ class PELT: def get_changepoints(self, traces, **kwargs): results = self.get_penalty_and_changepoints(traces, **kwargs) if type(results) is list: - return list(map(lambda res: res[1][res[0]])) + return list(map(lambda res: res[1][res[0]], results)) return results[1][results[0]] def get_penalty(self, traces, **kwargs): results = self.get_penalty_and_changepoints(traces, **kwargs) - return list(map(lambda res: res[0])) + if type(results) is list: + return list(map(lambda res: res[0])) + return res[0] def calc_raw_states( self, |