From 6bd652f1c2097ae228d20b14869e06cca439c07a Mon Sep 17 00:00:00 2001 From: Daniel Friesel Date: Wed, 14 Apr 2021 14:23:11 +0200 Subject: pelt: use single PELT run to determine whether refinement is needed --- lib/pelt.py | 40 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 38 insertions(+), 2 deletions(-) (limited to 'lib/pelt.py') diff --git a/lib/pelt.py b/lib/pelt.py index 05d93f8..29faf5c 100644 --- a/lib/pelt.py +++ b/lib/pelt.py @@ -81,8 +81,40 @@ class PELT: if os.getenv("DFATOOL_PELT_MIN_DIST"): self.min_dist = int(os.getenv("DFATOOL_PELT_MIN_DIST")) - # signals: a set of uW measurements belonging to a single parameter configuration (i.e., a single by_param entry) - def needs_refinement(self, signals): + def needs_refinement_pelt(self, signals): + import ruptures + + count = 0 + for signal in signals: + if len(signal) < 100: + continue + + algo = ruptures.Pelt( + model=self.model, jump=len(signal) // 100, min_size=self.min_dist + ) + algo = algo.fit(self.norm_signal(signal)) + + # Empirically, most sub-state detection results use a penalty + # in the range 30 to 60. If there are no changepoints with a + # penalty of 20, there are also no changepoints with any penalty + # > 20, so we can safely skip changepoint detection altogether. 
+ changepoints = algo.predict(pen=20) + + if not changepoints: + continue + + if len(changepoints) and changepoints[-1] == len(signal): + changepoints.pop() + if len(changepoints) and changepoints[0] == 0: + changepoints.pop(0) + + if changepoints: + count += 1 + + refinement_ratio = count / len(signals) + return refinement_ratio > 0.3 + + def needs_refinement_percentile(self, signals): count = 0 for signal in signals: if len(signal) < 30: @@ -97,6 +129,10 @@ class PELT: refinement_ratio = count / len(signals) return refinement_ratio > 0.3 + # signals: a set of uW measurements belonging to a single parameter configuration (i.e., a single by_param entry) + def needs_refinement(self, signals): + return self.needs_refinement_pelt(signals) + def norm_signal(self, signal, scaler=25): max_val = max(np.abs(signal)) normed_signal = np.zeros(shape=len(signal)) -- cgit v1.2.3