summaryrefslogtreecommitdiff
path: root/lib/pelt.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pelt.py')
-rw-r--r--lib/pelt.py30
1 files changed, 19 insertions, 11 deletions
diff --git a/lib/pelt.py b/lib/pelt.py
index 71f6370..aaad699 100644
--- a/lib/pelt.py
+++ b/lib/pelt.py
@@ -44,8 +44,8 @@ class PELT:
self.num_samples = None
self.name_filter = None
self.refinement_threshold = 200e-6 # 200 µW
- self.range_min = 0 # TODO 1 .. 89
- self.range_max = 88
+ self.range_min = 1
+ self.range_max = 89
self.stretch = 1
self.with_multiprocessing = True
self.__dict__.update(kwargs)
@@ -55,11 +55,6 @@ class PELT:
self.stretch = int(self.stretch)
self.cache_dir = "cache"
- try:
- os.mkdir(self.cache_dir)
- except FileExistsError:
- pass
-
if os.getenv("DFATOOL_PELT_MODEL"):
# https://centre-borelli.github.io/ruptures-docs/user-guide/costs/costl1/
self.model = os.getenv("DFATOOL_PELT_MODEL")
@@ -114,13 +109,26 @@ class PELT:
if self.cache_dir is None:
return
cache_key = self.cache_key(signal, penalty, num_changepoints)
- with open(f"{self.cache_dir}/pelt-{cache_key}.json", "w") as f:
+
+ try:
+ os.mkdir(self.cache_dir)
+ except FileExistsError:
+ pass
+
+ try:
+ os.mkdir(f"{self.cache_dir}/{cache_key[:2]}")
+ except FileExistsError:
+ pass
+
+ with open(f"{self.cache_dir}/{cache_key[:2]}/pelt-{cache_key}.json", "w") as f:
json.dump(data, f)
def load_cache(self, signal, penalty, num_changepoints):
cache_key = self.cache_key(signal, penalty, num_changepoints)
try:
- with open(f"{self.cache_dir}/pelt-{cache_key}.json", "r") as f:
+ with open(
+ f"{self.cache_dir}/{cache_key[:2]}/pelt-{cache_key}.json", "r"
+ ) as f:
return json.load(f)
except FileNotFoundError:
return None
@@ -239,9 +247,9 @@ class PELT:
longest_end = end_index
start_index = i
prev_val = changepoint_count
- middle_of_plateau = longest_start + (longest_start - longest_start) // 2
+ middle_of_plateau = longest_start + (longest_end - longest_start) // 2
- return middle_of_plateau, changepoints_by_penalty
+ return self.range_min + middle_of_plateau, changepoints_by_penalty
def get_changepoints(self, signal, **kwargs):
penalty, changepoints_by_penalty = self.get_penalty_and_changepoints(