summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/pelt.py70
1 files changed, 65 insertions, 5 deletions
diff --git a/lib/pelt.py b/lib/pelt.py
index 9a11807..7632f0d 100644
--- a/lib/pelt.py
+++ b/lib/pelt.py
@@ -1,9 +1,12 @@
#!/usr/bin/env python3
+import hashlib
+import json
import logging
import numpy as np
import os
from multiprocessing import Pool
+from .utils import NpEncoder
logger = logging.getLogger(__name__)
@@ -41,7 +44,7 @@ class PELT:
self.num_samples = None
self.name_filter = None
self.refinement_threshold = 200e-6 # 200 µW
- self.range_min = 0
+ self.range_min = 0 # TODO 1 .. 89
self.range_max = 88
self.stretch = 1
self.with_multiprocessing = True
@@ -50,6 +53,12 @@ class PELT:
self.jump = int(self.jump)
self.min_dist = int(self.min_dist)
self.stretch = int(self.stretch)
+ self.cache_dir = "cache"
+
+ try:
+ os.mkdir(self.cache_dir)
+ except FileExistsError:
+ pass
if os.getenv("DFATOOL_PELT_MODEL"):
# https://centre-borelli.github.io/ruptures-docs/user-guide/costs/costl1/
@@ -83,7 +92,57 @@ class PELT:
normed_signal[i] = normed_signal[i] * scaler
return normed_signal
+ def cache_key(self, signal, penalty, num_changepoints):
+ config = [
+ signal,
+ penalty,
+ num_changepoints,
+ self.algo,
+ self.model,
+ self.jump,
+ self.min_dist,
+ self.range_min,
+ self.range_max,
+ self.stretch,
+ ]
+ cache_key = hashlib.sha256(
+ json.dumps(config, cls=NpEncoder).encode()
+ ).hexdigest()
+ return cache_key
+
+ def save_cache(self, signal, penalty, num_changepoints, data):
+ if self.cache_dir is None:
+ return
+ cache_key = self.cache_key(signal, penalty, num_changepoints)
+ with open(f"{self.cache_dir}/pelt-{cache_key}.json", "w") as f:
+ json.dump(data, f)
+
+ def load_cache(self, signal, penalty, num_changepoints):
+ cache_key = self.cache_key(signal, penalty, num_changepoints)
+ try:
+ with open(f"{self.cache_dir}/pelt-{cache_key}.json", "r") as f:
+ return json.load(f)
+ except FileNotFoundError:
+ return None
+
def get_penalty_and_changepoints(self, signal, penalty=None, num_changepoints=None):
+ data = self.load_cache(signal, penalty, num_changepoints)
+ if data:
+ if type(data[1]) is dict:
+ str_keys = list(data[1].keys())
+ for k in str_keys:
+ data[1][int(k)] = data[1].pop(k)
+ return data
+
+ data = self.calculate_penalty_and_changepoints(
+ signal, penalty, num_changepoints
+ )
+ self.save_cache(signal, penalty, num_changepoints, data)
+ return data
+
+ def calculate_penalty_and_changepoints(
+ self, signal, penalty=None, num_changepoints=None
+ ):
# imported here as ruptures is only used for changepoint detection.
# This way, dfatool can be used without having ruptures installed as
# long as --pelt isn't active.
@@ -156,6 +215,7 @@ class PELT:
np.array(np.around(np.array(res[1]) / self.stretch), dtype=np.int),
)
changepoints_by_penalty[res[0]] = res[1]
+
changepoint_counts = list()
for i in range(self.range_min, self.range_max):
changepoint_counts.append(len(changepoints_by_penalty[i]))
@@ -165,8 +225,8 @@ class PELT:
longest_start = -1
longest_end = -1
prev_val = -1
- for i, num_changepoints in enumerate(changepoint_counts):
- if num_changepoints != prev_val:
+ for i, changepoint_count in enumerate(changepoint_counts):
+ if changepoint_count != prev_val:
end_index = i - 1
if end_index - start_index > longest_end - longest_start:
longest_start = start_index
@@ -178,9 +238,9 @@ class PELT:
longest_start = start_index
longest_end = end_index
start_index = i
- prev_val = num_changepoints
+ prev_val = changepoint_count
middle_of_plateau = longest_start + (longest_start - longest_start) // 2
- changepoints = np.array(changepoints_by_penalty[middle_of_plateau])
+
return middle_of_plateau, changepoints_by_penalty
def get_changepoints(self, signal, **kwargs):