diff options
-rw-r--r-- | lib/pelt.py | 70 |
1 files changed, 65 insertions, 5 deletions
diff --git a/lib/pelt.py b/lib/pelt.py index 9a11807..7632f0d 100644 --- a/lib/pelt.py +++ b/lib/pelt.py @@ -1,9 +1,12 @@ #!/usr/bin/env python3 +import hashlib +import json import logging import numpy as np import os from multiprocessing import Pool +from .utils import NpEncoder logger = logging.getLogger(__name__) @@ -41,7 +44,7 @@ class PELT: self.num_samples = None self.name_filter = None self.refinement_threshold = 200e-6 # 200 µW - self.range_min = 0 + self.range_min = 0 # TODO 1 .. 89 self.range_max = 88 self.stretch = 1 self.with_multiprocessing = True @@ -50,6 +53,12 @@ class PELT: self.jump = int(self.jump) self.min_dist = int(self.min_dist) self.stretch = int(self.stretch) + self.cache_dir = "cache" + + try: + os.mkdir(self.cache_dir) + except FileExistsError: + pass if os.getenv("DFATOOL_PELT_MODEL"): # https://centre-borelli.github.io/ruptures-docs/user-guide/costs/costl1/ @@ -83,7 +92,57 @@ class PELT: normed_signal[i] = normed_signal[i] * scaler return normed_signal + def cache_key(self, signal, penalty, num_changepoints): + config = [ + signal, + penalty, + num_changepoints, + self.algo, + self.model, + self.jump, + self.min_dist, + self.range_min, + self.range_max, + self.stretch, + ] + cache_key = hashlib.sha256( + json.dumps(config, cls=NpEncoder).encode() + ).hexdigest() + return cache_key + + def save_cache(self, signal, penalty, num_changepoints, data): + if self.cache_dir is None: + return + cache_key = self.cache_key(signal, penalty, num_changepoints) + with open(f"{self.cache_dir}/pelt-{cache_key}.json", "w") as f: + json.dump(data, f) + + def load_cache(self, signal, penalty, num_changepoints): + cache_key = self.cache_key(signal, penalty, num_changepoints) + try: + with open(f"{self.cache_dir}/pelt-{cache_key}.json", "r") as f: + return json.load(f) + except FileNotFoundError: + return None + def get_penalty_and_changepoints(self, signal, penalty=None, num_changepoints=None): + data = self.load_cache(signal, penalty, num_changepoints) + if data: + if type(data[1]) is dict: + str_keys = list(data[1].keys()) + for k in str_keys: + data[1][int(k)] = data[1].pop(k) + return data + + data = self.calculate_penalty_and_changepoints( + signal, penalty, num_changepoints + ) + self.save_cache(signal, penalty, num_changepoints, data) + return data + + def calculate_penalty_and_changepoints( + self, signal, penalty=None, num_changepoints=None + ): # imported here as ruptures is only used for changepoint detection. # This way, dfatool can be used without having ruptures installed as # long as --pelt isn't active. @@ -156,6 +215,7 @@ class PELT: np.array(np.around(np.array(res[1]) / self.stretch), dtype=np.int), ) changepoints_by_penalty[res[0]] = res[1] + changepoint_counts = list() for i in range(self.range_min, self.range_max): changepoint_counts.append(len(changepoints_by_penalty[i])) @@ -165,8 +225,8 @@ class PELT: longest_start = -1 longest_end = -1 prev_val = -1 - for i, num_changepoints in enumerate(changepoint_counts): - if num_changepoints != prev_val: + for i, changepoint_count in enumerate(changepoint_counts): + if changepoint_count != prev_val: end_index = i - 1 if end_index - start_index > longest_end - longest_start: longest_start = start_index @@ -178,9 +238,9 @@ class PELT: longest_start = start_index longest_end = end_index start_index = i - prev_val = num_changepoints + prev_val = changepoint_count middle_of_plateau = longest_start + (longest_start - longest_start) // 2 - changepoints = np.array(changepoints_by_penalty[middle_of_plateau]) + return middle_of_plateau, changepoints_by_penalty def get_changepoints(self, signal, **kwargs): |