diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-01-15 14:01:30 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-01-15 14:01:30 +0100 |
commit | 4896ec66ca4eccbe029a286617d7e74d9a732ab0 (patch) | |
tree | 5c0e407d7a475bce5ece61cc015f42b357bc6d4c /lib/pelt.py | |
parent | 34cd6f989de6a7f1eec27d8390b179c0f943e2a4 (diff) |
pelt: add stretch parameter
Diffstat (limited to 'lib/pelt.py')
-rw-r--r-- | lib/pelt.py | 20 |
1 files changed, 20 insertions, 0 deletions
diff --git a/lib/pelt.py b/lib/pelt.py index 95e828a..500ae13 100644 --- a/lib/pelt.py +++ b/lib/pelt.py @@ -42,6 +42,7 @@ class PELT: self.refinement_threshold = 200e-6 # 200 µW self.range_min = 0 self.range_max = 100 + self.stretch = 1 self.with_multiprocessing = True self.__dict__.update(kwargs) @@ -81,6 +82,13 @@ class PELT: # long as --pelt isn't active. import ruptures + if self.stretch != 1: + signal = np.interp( + np.linspace(0, len(signal) - 1, (len(signal) - 1) * self.stretch + 1), + np.arange(len(signal)), + signal, + ) + if self.num_samples is not None and len(signal) > self.num_samples: self.jump = len(signal) // int(self.num_samples) else: @@ -104,6 +112,10 @@ class PELT: changepoints.pop() if len(changepoints) and changepoints[0] == 0: changepoints.pop(0) + if self.stretch != 1: + changepoints = np.array( + np.around(np.array(changepoints) / self.stretch), dtype=np.int + ) return penalty, changepoints if self.algo == "dynp" and num_changepoints is not None: @@ -112,6 +124,10 @@ class PELT: changepoints.pop() if len(changepoints) and changepoints[0] == 0: changepoints.pop(0) + if self.stretch != 1: + changepoints = np.array( + np.around(np.array(changepoints) / self.stretch), dtype=np.int + ) return None, changepoints queue = list() @@ -126,6 +142,10 @@ class PELT: for res in changepoints: if len(res[1]) > 0 and res[1][-1] == len(signal): res[1].pop() + if self.stretch != 1: + res[1] = np.array( + np.around(np.array(res[1]) / self.stretch), dtype=np.int + ) changepoints_by_penalty[res[0]] = res[1] changepoint_counts = list() for i in range(0, 100): |