summaryrefslogtreecommitdiff
path: root/lib/pelt.py
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2020-12-14 12:53:01 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2020-12-14 12:53:01 +0100
commit7d319d3c87de653fb9c4d788e17f8c134820171f (patch)
treeecdf23b2d188b3cd58487f1ddbe7230abdbbac27 /lib/pelt.py
parent2a4c3f84cc5b19a96f79be6643322d51d581ee14 (diff)
energytrace: add pelt-based drift compensation experiment.
Enable with DFATOOL_COMPENSATE_DRIFT=1 so far it's pretty unreliable.
Diffstat (limited to 'lib/pelt.py')
-rw-r--r--lib/pelt.py17
1 files changed, 12 insertions, 5 deletions
diff --git a/lib/pelt.py b/lib/pelt.py
index 38fb158..613ae80 100644
--- a/lib/pelt.py
+++ b/lib/pelt.py
@@ -63,7 +63,7 @@ class PELT:
normed_signal[i] = normed_signal[i] * scaler
return normed_signal
- def get_penalty_and_changepoints(self, signal):
+ def get_penalty_and_changepoints(self, signal, penalty=None):
# imported here as ruptures is only used for changepoint detection.
# This way, dfatool can be used without having ruptures installed as
# long as --pelt isn't active.
@@ -77,6 +77,13 @@ class PELT:
algo = ruptures.Pelt(
model=self.model, jump=self.jump, min_size=self.min_dist
).fit(self.norm_signal(signal))
+
+ if penalty is not None:
+ changepoints = algo.predict(pen=penalty)
+ if len(changepoints) and changepoints[-1] == len(signal):
+ changepoints.pop()
+ return penalty, changepoints
+
queue = list()
for i in range(0, 100):
queue.append((algo, i))
@@ -117,12 +124,12 @@ class PELT:
changepoints = np.array(changepoints_by_penalty[middle_of_plateau])
return middle_of_plateau, changepoints
- def get_changepoints(self, signal):
- _, changepoints = self.get_penalty_and_changepoints(signal)
+ def get_changepoints(self, signal, **kwargs):
+ _, changepoints = self.get_penalty_and_changepoints(signal, **kwargs)
return changepoints
- def get_penalty(self, signal):
- penalty, _ = self.get_penalty_and_changepoints(signal)
+ def get_penalty(self, signal, **kwargs):
+ penalty, _ = self.get_penalty_and_changepoints(signal, **kwargs)
return penalty
def calc_raw_states(self, timestamps, signals, penalty, opt_model=None):