summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
authorDaniel Friesel <daniel.friesel@uos.de>2021-02-24 12:36:36 +0100
committerDaniel Friesel <daniel.friesel@uos.de>2021-02-24 12:36:36 +0100
commite09189307bb6bf33000839582adc3bdbbc0a8eeb (patch)
treec415c4445d792f660a3ea564cd00b734781be83e /lib
parent62cd6c1adfa6158d039032cc0f6b3823dce74d39 (diff)
pelt: allow contraction (averaging) of data as well as stretching
Diffstat (limited to 'lib')
-rw-r--r--lib/pelt.py25
1 files changed, 23 insertions, 2 deletions
diff --git a/lib/pelt.py b/lib/pelt.py
index 10bb135..1ebc37f 100644
--- a/lib/pelt.py
+++ b/lib/pelt.py
@@ -170,7 +170,7 @@ class PELT:
# long as --pelt isn't active.
import ruptures
- if self.stretch != 1:
+ if self.stretch > 1:
traces = list(
map(
lambda trace: np.interp(
@@ -183,6 +183,20 @@ class PELT:
traces,
)
)
+ elif self.stretch < -1:
+ ds_factor = -self.stretch
+ new_traces = list()
+ for trace in traces:
+ if trace.shape[0] % ds_factor:
+ trace = np.array(
+ list(trace)
+ + [
+ trace[-1]
+ for i in range(ds_factor - (trace.shape[0] % ds_factor))
+ ]
+ )
+ new_traces.append(trace.reshape(-1, ds_factor).mean(axis=1))
+ traces = new_traces
algos = list()
queue = list()
@@ -224,12 +238,19 @@ class PELT:
changepoints.pop()
if len(changepoints) and changepoints[0] == 0:
changepoints.pop(0)
- if self.stretch != 1:
+ if self.stretch > 1:
changepoints = list(
np.array(
np.around(np.array(changepoints) / self.stretch), dtype=np.int
)
)
+ elif self.stretch < -1:
+ ds_factor = -self.stretch
+ changepoints = list(
+ np.array(
+ np.around(np.array(changepoints) * ds_factor), dtype=np.int
+ )
+ )
changepoints_by_penalty_by_trace[i][range_penalty] = changepoints
for i in range(len(traces)):