summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-01-15 14:01:30 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2021-01-15 14:01:30 +0100
commit4896ec66ca4eccbe029a286617d7e74d9a732ab0 (patch)
tree5c0e407d7a475bce5ece61cc015f42b357bc6d4c /lib
parent34cd6f989de6a7f1eec27d8390b179c0f943e2a4 (diff)
pelt: add stretch parameter
Diffstat (limited to 'lib')
-rw-r--r--lib/pelt.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/lib/pelt.py b/lib/pelt.py
index 95e828a..500ae13 100644
--- a/lib/pelt.py
+++ b/lib/pelt.py
@@ -42,6 +42,7 @@ class PELT:
self.refinement_threshold = 200e-6 # 200 µW
self.range_min = 0
self.range_max = 100
+ self.stretch = 1
self.with_multiprocessing = True
self.__dict__.update(kwargs)
@@ -81,6 +82,13 @@ class PELT:
# long as --pelt isn't active.
import ruptures
+ if self.stretch != 1:
+ signal = np.interp(
+ np.linspace(0, len(signal) - 1, (len(signal) - 1) * self.stretch + 1),
+ np.arange(len(signal)),
+ signal,
+ )
+
if self.num_samples is not None and len(signal) > self.num_samples:
self.jump = len(signal) // int(self.num_samples)
else:
@@ -104,6 +112,10 @@ class PELT:
changepoints.pop()
if len(changepoints) and changepoints[0] == 0:
changepoints.pop(0)
+ if self.stretch != 1:
+ changepoints = np.array(
+ np.around(np.array(changepoints) / self.stretch), dtype=np.int
+ )
return penalty, changepoints
if self.algo == "dynp" and num_changepoints is not None:
@@ -112,6 +124,10 @@ class PELT:
changepoints.pop()
if len(changepoints) and changepoints[0] == 0:
changepoints.pop(0)
+ if self.stretch != 1:
+ changepoints = np.array(
+ np.around(np.array(changepoints) / self.stretch), dtype=np.int
+ )
return None, changepoints
queue = list()
@@ -126,6 +142,10 @@ class PELT:
for res in changepoints:
if len(res[1]) > 0 and res[1][-1] == len(signal):
res[1].pop()
+ if self.stretch != 1:
+ res[1] = np.array(
+ np.around(np.array(res[1]) / self.stretch), dtype=np.int
+ )
changepoints_by_penalty[res[0]] = res[1]
changepoint_counts = list()
for i in range(0, 100):