summaryrefslogtreecommitdiff
path: root/lib/pelt.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/pelt.py')
-rw-r--r--lib/pelt.py24
1 files changed, 20 insertions, 4 deletions
diff --git a/lib/pelt.py b/lib/pelt.py
index 58bc9c1..bd79c9b 100644
--- a/lib/pelt.py
+++ b/lib/pelt.py
@@ -34,6 +34,7 @@ def PELT_get_raw_states(num_measurement, algo, penalty):
class PELT:
def __init__(self, **kwargs):
+ self.algo = "pelt"
self.model = "l1"
self.jump = 1
self.min_dist = 10
@@ -45,6 +46,7 @@ class PELT:
self.__dict__.update(kwargs)
if os.getenv("DFATOOL_PELT_MODEL"):
+ # https://centre-borelli.github.io/ruptures-docs/user-guide/costs/costl1/
self.model = os.getenv("DFATOOL_PELT_MODEL")
if os.getenv("DFATOOL_PELT_JUMP"):
self.jump = int(os.getenv("DFATOOL_PELT_JUMP"))
@@ -73,7 +75,7 @@ class PELT:
normed_signal[i] = normed_signal[i] * scaler
return normed_signal
- def get_penalty_and_changepoints(self, signal, penalty=None):
+ def get_penalty_and_changepoints(self, signal, penalty=None, num_changepoints=None):
# imported here as ruptures is only used for changepoint detection.
# This way, dfatool can be used without having ruptures installed as
# long as --pelt isn't active.
@@ -84,9 +86,17 @@ class PELT:
else:
self.jump = 1
- algo = ruptures.Pelt(
- model=self.model, jump=self.jump, min_size=self.min_dist
- ).fit(self.norm_signal(signal))
+ if self.algo == "dynp":
+ # https://centre-borelli.github.io/ruptures-docs/user-guide/detection/dynp/
+ algo = ruptures.Dynp(
+ model=self.model, jump=self.jump, min_size=self.min_dist
+ )
+ else:
+ # https://centre-borelli.github.io/ruptures-docs/user-guide/detection/pelt/
+ algo = ruptures.Pelt(
+ model=self.model, jump=self.jump, min_size=self.min_dist
+ )
+ algo = algo.fit(self.norm_signal(signal))
if penalty is not None:
changepoints = algo.predict(pen=penalty)
@@ -94,6 +104,12 @@ class PELT:
changepoints.pop()
return penalty, changepoints
+ if self.algo == "dynp" and num_changepoints is not None:
+ changepoints = algo.predict(pen=penalty)
+ if len(changepoints) and changepoints[-1] == len(signal):
+ changepoints.pop()
+ return None, changepoints
+
queue = list()
for i in range(0, 100):
queue.append((algo, i))