diff options
Diffstat (limited to 'lib/pelt.py')
-rw-r--r-- | lib/pelt.py | 24 |
1 files changed, 20 insertions, 4 deletions
diff --git a/lib/pelt.py b/lib/pelt.py index 58bc9c1..bd79c9b 100644 --- a/lib/pelt.py +++ b/lib/pelt.py @@ -34,6 +34,7 @@ def PELT_get_raw_states(num_measurement, algo, penalty): class PELT: def __init__(self, **kwargs): + self.algo = "pelt" self.model = "l1" self.jump = 1 self.min_dist = 10 @@ -45,6 +46,7 @@ class PELT: self.__dict__.update(kwargs) if os.getenv("DFATOOL_PELT_MODEL"): + # https://centre-borelli.github.io/ruptures-docs/user-guide/costs/costl1/ self.model = os.getenv("DFATOOL_PELT_MODEL") if os.getenv("DFATOOL_PELT_JUMP"): self.jump = int(os.getenv("DFATOOL_PELT_JUMP")) @@ -73,7 +75,7 @@ class PELT: normed_signal[i] = normed_signal[i] * scaler return normed_signal - def get_penalty_and_changepoints(self, signal, penalty=None): + def get_penalty_and_changepoints(self, signal, penalty=None, num_changepoints=None): # imported here as ruptures is only used for changepoint detection. # This way, dfatool can be used without having ruptures installed as # long as --pelt isn't active. @@ -84,9 +86,17 @@ class PELT: else: self.jump = 1 - algo = ruptures.Pelt( - model=self.model, jump=self.jump, min_size=self.min_dist - ).fit(self.norm_signal(signal)) + if self.algo == "dynp": + # https://centre-borelli.github.io/ruptures-docs/user-guide/detection/dynp/ + algo = ruptures.Dynp( + model=self.model, jump=self.jump, min_size=self.min_dist + ) + else: + # https://centre-borelli.github.io/ruptures-docs/user-guide/detection/pelt/ + algo = ruptures.Pelt( + model=self.model, jump=self.jump, min_size=self.min_dist + ) + algo = algo.fit(self.norm_signal(signal)) if penalty is not None: changepoints = algo.predict(pen=penalty) @@ -94,6 +104,12 @@ class PELT: changepoints.pop() return penalty, changepoints + if self.algo == "dynp" and num_changepoints is not None: + changepoints = algo.predict(pen=penalty) + if len(changepoints) and changepoints[-1] == len(signal): + changepoints.pop() + return None, changepoints + queue = list() for i in range(0, 100): queue.append((algo, i)) |