summaryrefslogtreecommitdiff
path: root/lib/model.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/model.py')
-rw-r--r--lib/model.py31
1 files changed, 22 insertions, 9 deletions
diff --git a/lib/model.py b/lib/model.py
index 9070ece..f0b0624 100644
--- a/lib/model.py
+++ b/lib/model.py
@@ -926,13 +926,22 @@ class PTAModel:
def pelt_refine(self, by_param_key):
logger.debug(f"PELT: {by_param_key} needs refinement")
- # Assumption: All power traces for this parameter setting
- # are similar, so determining the penalty for the first one
- # is sufficient.
- penalty, changepoints = self.pelt.get_penalty_and_changepoints(
- self.by_param[by_param_key]["power_traces"][0]
- )
- if len(changepoints) == 0:
+
+ penalty_by_trace = list()
+ changepoints_by_penalty_by_trace = list()
+ num_changepoints_by_trace = list()
+ changepoints_by_trace = list()
+
+ for power_values in self.by_param[by_param_key]["power_traces"]:
+ penalty, changepoints_by_penalty = self.pelt.get_penalty_and_changepoints(
+ power_values
+ )
+ penalty_by_trace.append(penalty)
+ changepoints_by_penalty_by_trace.append(changepoints_by_penalty)
+ num_changepoints_by_trace.append(len(changepoints_by_penalty[penalty]))
+ changepoints_by_trace.append(changepoints_by_penalty[penalty])
+
+ if np.median(num_changepoints_by_trace) < 1:
logger.debug(f" we found no changepoints with penalty {penalty}")
substate_counts = [1 for i in self.by_param[by_param_key]["param"]]
substate_data = {
@@ -941,13 +950,17 @@ class PTAModel:
"power_std": self.by_param[by_param_key]["power_std"],
}
return (substate_counts, substate_data)
+
+ num_changepoints = np.argmax(np.bincount(num_changepoints_by_trace))
+
logger.debug(
- f" we found {len(changepoints)} changepoints with penalty {penalty}"
+ f" we found {num_changepoints} changepoints from {num_changepoints_by_trace} with penalties {penalty_by_trace}"
)
return self.pelt.calc_raw_states(
self.by_param[by_param_key]["timestamps"],
self.by_param[by_param_key]["power_traces"],
- penalty,
+ changepoints_by_trace,
+ num_changepoints,
)
def find_substates(self):