diff options
-rw-r--r-- | lib/pelt.py | 30 |
1 files changed, 19 insertions, 11 deletions
diff --git a/lib/pelt.py b/lib/pelt.py index 71f6370..aaad699 100644 --- a/lib/pelt.py +++ b/lib/pelt.py @@ -44,8 +44,8 @@ class PELT: self.num_samples = None self.name_filter = None self.refinement_threshold = 200e-6 # 200 µW - self.range_min = 0 # TODO 1 .. 89 - self.range_max = 88 + self.range_min = 1 + self.range_max = 89 self.stretch = 1 self.with_multiprocessing = True self.__dict__.update(kwargs) @@ -55,11 +55,6 @@ class PELT: self.stretch = int(self.stretch) self.cache_dir = "cache" - try: - os.mkdir(self.cache_dir) - except FileExistsError: - pass - if os.getenv("DFATOOL_PELT_MODEL"): # https://centre-borelli.github.io/ruptures-docs/user-guide/costs/costl1/ self.model = os.getenv("DFATOOL_PELT_MODEL") @@ -114,13 +109,26 @@ class PELT: if self.cache_dir is None: return cache_key = self.cache_key(signal, penalty, num_changepoints) - with open(f"{self.cache_dir}/pelt-{cache_key}.json", "w") as f: + + try: + os.mkdir(self.cache_dir) + except FileExistsError: + pass + + try: + os.mkdir(f"{self.cache_dir}/{cache_key[:2]}") + except FileExistsError: + pass + + with open(f"{self.cache_dir}/{cache_key[:2]}/pelt-{cache_key}.json", "w") as f: json.dump(data, f) def load_cache(self, signal, penalty, num_changepoints): cache_key = self.cache_key(signal, penalty, num_changepoints) try: - with open(f"{self.cache_dir}/pelt-{cache_key}.json", "r") as f: + with open( + f"{self.cache_dir}/{cache_key[:2]}/pelt-{cache_key}.json", "r" + ) as f: return json.load(f) except FileNotFoundError: return None @@ -239,9 +247,9 @@ class PELT: longest_end = end_index start_index = i prev_val = changepoint_count - middle_of_plateau = longest_start + (longest_start - longest_start) // 2 + middle_of_plateau = longest_start + (longest_end - longest_start) // 2 - return middle_of_plateau, changepoints_by_penalty + return self.range_min + middle_of_plateau, changepoints_by_penalty def get_changepoints(self, signal, **kwargs): penalty, changepoints_by_penalty = self.get_penalty_and_changepoints( |