diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2020-12-14 12:53:01 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2020-12-14 12:53:01 +0100 |
commit | 7d319d3c87de653fb9c4d788e17f8c134820171f (patch) | |
tree | ecdf23b2d188b3cd58487f1ddbe7230abdbbac27 /lib/pelt.py | |
parent | 2a4c3f84cc5b19a96f79be6643322d51d581ee14 (diff) |
energytrace: add pelt-based drift compensation experiment.
Enable with DFATOOL_COMPENSATE_DRIFT=1
so far it's pretty unreliable.
Diffstat (limited to 'lib/pelt.py')
-rw-r--r-- | lib/pelt.py | 17 |
1 files changed, 12 insertions, 5 deletions
diff --git a/lib/pelt.py b/lib/pelt.py index 38fb158..613ae80 100644 --- a/lib/pelt.py +++ b/lib/pelt.py @@ -63,7 +63,7 @@ class PELT: normed_signal[i] = normed_signal[i] * scaler return normed_signal - def get_penalty_and_changepoints(self, signal): + def get_penalty_and_changepoints(self, signal, penalty=None): # imported here as ruptures is only used for changepoint detection. # This way, dfatool can be used without having ruptures installed as # long as --pelt isn't active. @@ -77,6 +77,13 @@ class PELT: algo = ruptures.Pelt( model=self.model, jump=self.jump, min_size=self.min_dist ).fit(self.norm_signal(signal)) + + if penalty is not None: + changepoints = algo.predict(pen=penalty) + if len(changepoints) and changepoints[-1] == len(signal): + changepoints.pop() + return penalty, changepoints + queue = list() for i in range(0, 100): queue.append((algo, i)) @@ -117,12 +124,12 @@ class PELT: changepoints = np.array(changepoints_by_penalty[middle_of_plateau]) return middle_of_plateau, changepoints - def get_changepoints(self, signal): - _, changepoints = self.get_penalty_and_changepoints(signal) + def get_changepoints(self, signal, **kwargs): + _, changepoints = self.get_penalty_and_changepoints(signal, **kwargs) return changepoints - def get_penalty(self, signal): - penalty, _ = self.get_penalty_and_changepoints(signal) + def get_penalty(self, signal, **kwargs): + penalty, _ = self.get_penalty_and_changepoints(signal, **kwargs) return penalty def calc_raw_states(self, timestamps, signals, penalty, opt_model=None): |