summaryrefslogtreecommitdiff
path: root/lib/pelt.py
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-02-23 10:32:00 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2021-02-23 10:32:00 +0100
commit5cba61eb83ffb4954c924a789305bc0114b6c7ca (patch)
tree61ac32c7f8b52ccbd8d3ee1c2888183225f0889e /lib/pelt.py
parente537dafe711dfbf1cc643442a55668bd285a3c6e (diff)
fix drift compensation and reduce pelt + caching overhead
Diffstat (limited to 'lib/pelt.py')
-rw-r--r--lib/pelt.py48
1 files changed, 29 insertions, 19 deletions
diff --git a/lib/pelt.py b/lib/pelt.py
index 3336e4a..10bb135 100644
--- a/lib/pelt.py
+++ b/lib/pelt.py
@@ -47,12 +47,12 @@ class PELT:
self.range_max = 89
self.stretch = 1
self.with_multiprocessing = True
+ self.cache_dir = "cache"
self.__dict__.update(kwargs)
self.jump = int(self.jump)
self.min_dist = int(self.min_dist)
self.stretch = int(self.stretch)
- self.cache_dir = "cache"
if os.getenv("DFATOOL_PELT_MODEL"):
# https://centre-borelli.github.io/ruptures-docs/user-guide/costs/costl1/
@@ -120,7 +120,7 @@ class PELT:
pass
with open(f"{self.cache_dir}/{cache_key[:2]}/pelt-{cache_key}.json", "w") as f:
- json.dump(data, f)
+ json.dump(data, f, cls=NpEncoder)
def load_cache(self, traces, penalty, num_changepoints):
cache_key = self.cache_key(traces, penalty, num_changepoints)
@@ -131,8 +131,17 @@ class PELT:
return json.load(f)
except FileNotFoundError:
return None
+ except json.decoder.JSONDecodeError:
+ logger.warning(
+ f"Ignoring invalid cache entry {self.cache_dir}/{cache_key[:2]}/pelt-{cache_key}.json"
+ )
+ return None
def get_penalty_and_changepoints(self, traces, penalty=None, num_changepoints=None):
+ list_of_lists = type(traces[0]) is list or type(traces[0]) is np.ndarray
+ if not list_of_lists:
+ traces = [traces]
+
data = self.load_cache(traces, penalty, num_changepoints)
if data:
for res in data:
@@ -140,13 +149,18 @@ class PELT:
str_keys = list(res[1].keys())
for k in str_keys:
res[1][int(k)] = res[1].pop(k)
- return data
+ if list_of_lists:
+ return data
+ return data[0]
data = self.calculate_penalty_and_changepoints(
traces, penalty, num_changepoints
)
self.save_cache(traces, penalty, num_changepoints, data)
- return data
+
+ if list_of_lists:
+ return data
+ return data[0]
def calculate_penalty_and_changepoints(
self, traces, penalty=None, num_changepoints=None
@@ -156,11 +170,6 @@ class PELT:
# long as --pelt isn't active.
import ruptures
- list_of_lists = type(traces[0]) is list or type(traces[0]) is np.ndarray
-
- if not list_of_lists:
- traces = [traces]
-
if self.stretch != 1:
traces = list(
map(
@@ -216,17 +225,19 @@ class PELT:
if len(changepoints) and changepoints[0] == 0:
changepoints.pop(0)
if self.stretch != 1:
- changepoints = np.array(
- np.around(np.array(changepoints) / self.stretch), dtype=np.int
+ changepoints = list(
+ np.array(
+ np.around(np.array(changepoints) / self.stretch), dtype=np.int
+ )
)
changepoints_by_penalty_by_trace[i][range_penalty] = changepoints
for i in range(len(traces)):
changepoints_by_penalty = changepoints_by_penalty_by_trace[i]
if penalty is not None:
- results.append((penalty, changepoints_by_penalty[penalty]))
+ results.append((penalty, changepoints_by_penalty))
elif self.algo == "dynp" and num_changepoints is not None:
- results.append((None, changepoints_by_penalty[None]))
+ results.append((None, {0: changepoints_by_penalty[None]}))
else:
results.append(
(
@@ -235,10 +246,7 @@ class PELT:
)
)
- if list_of_lists:
- return results
- else:
- return results[0]
+ return results
def find_penalty(self, changepoints_by_penalty):
changepoint_counts = list()
@@ -271,12 +279,14 @@ class PELT:
def get_changepoints(self, traces, **kwargs):
results = self.get_penalty_and_changepoints(traces, **kwargs)
if type(results) is list:
- return list(map(lambda res: res[1][res[0]]))
+ return list(map(lambda res: res[1][res[0]], results))
return results[1][results[0]]
def get_penalty(self, traces, **kwargs):
results = self.get_penalty_and_changepoints(traces, **kwargs)
- return list(map(lambda res: res[0]))
+ if type(results) is list:
+ return list(map(lambda res: res[0]))
+ return res[0]
def calc_raw_states(
self,