summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--lib/pelt.py40
1 files changed, 38 insertions, 2 deletions
diff --git a/lib/pelt.py b/lib/pelt.py
index 05d93f8..29faf5c 100644
--- a/lib/pelt.py
+++ b/lib/pelt.py
@@ -81,8 +81,40 @@ class PELT:
if os.getenv("DFATOOL_PELT_MIN_DIST"):
self.min_dist = int(os.getenv("DFATOOL_PELT_MIN_DIST"))
- # signals: a set of uW measurements belonging to a single parameter configuration (i.e., a single by_param entry)
- def needs_refinement(self, signals):
+ def needs_refinement_pelt(self, signals):
+ import ruptures
+
+ count = 0
+ for signal in signals:
+ if len(signal) < 100:
+ continue
+
+ algo = ruptures.Pelt(
+ model=self.model, jump=len(signal) // 100, min_size=self.min_dist
+ )
+ algo = algo.fit(self.norm_signal(signal))
+
+ # Empirically, most sub-state detectino results use a penalty
+ # in the range 30 to 60. If there's no changepoints with a
+ # penalty of 20, there's also no changepoins with any penalty
+ # > 20, so we can safely skip changepoint detection altogether.
+ changepoints = algo.predict(pen=20)
+
+ if not changepoints:
+ continue
+
+ if len(changepoints) and changepoints[-1] == len(signal):
+ changepoints.pop()
+ if len(changepoints) and changepoints[0] == 0:
+ changepoints.pop(0)
+
+ if changepoints:
+ count += 1
+
+ refinement_ratio = count / len(signals)
+ return refinement_ratio > 0.3
+
+ def needs_refinement_percentile(self, signals):
count = 0
for signal in signals:
if len(signal) < 30:
@@ -97,6 +129,10 @@ class PELT:
refinement_ratio = count / len(signals)
return refinement_ratio > 0.3
+ # signals: a set of uW measurements belonging to a single parameter configuration (i.e., a single by_param entry)
+ def needs_refinement(self, signals):
+ return self.needs_refinement_pelt(signals)
+
def norm_signal(self, signal, scaler=25):
max_val = max(np.abs(signal))
normed_signal = np.zeros(shape=len(signal))