diff options
author | Daniel Friesel <daniel.friesel@uos.de> | 2021-02-22 11:19:10 +0100 |
---|---|---|
committer | Daniel Friesel <daniel.friesel@uos.de> | 2021-02-22 11:19:10 +0100 |
commit | 0fdf04c5cfb473cb9bcc0d5bf4eacb0d3c6f51e7 (patch) | |
tree | f5fcd8d0ced9505972ffe702165cc66774e257d8 /lib/pelt.py | |
parent | b3ce8283535499ba420046ca90402c7bb1b22e61 (diff) |
PELT: Increase parallelism
Diffstat (limited to 'lib/pelt.py')
-rw-r--r-- | lib/pelt.py | 171 |
1 files changed, 94 insertions, 77 deletions
diff --git a/lib/pelt.py b/lib/pelt.py index 58415e1..3336e4a 100644 --- a/lib/pelt.py +++ b/lib/pelt.py @@ -11,8 +11,8 @@ from .utils import NpEncoder logger = logging.getLogger(__name__) -def PELT_get_changepoints(algo, penalty): - res = (penalty, algo.predict(pen=penalty)) +def PELT_get_changepoints(index, penalty, algo): + res = (index, penalty, algo.predict(pen=penalty)) return res @@ -41,7 +41,6 @@ class PELT: self.model = "l1" self.jump = 1 self.min_dist = 10 - self.num_samples = None self.name_filter = None self.refinement_threshold = 200e-6 # 200 µW self.range_min = 1 @@ -87,9 +86,9 @@ class PELT: normed_signal[i] = normed_signal[i] * scaler return normed_signal - def cache_key(self, signal, penalty, num_changepoints): + def cache_key(self, traces, penalty, num_changepoints): config = [ - signal, + traces, penalty, num_changepoints, self.algo, @@ -105,10 +104,10 @@ class PELT: ).hexdigest() return cache_key - def save_cache(self, signal, penalty, num_changepoints, data): + def save_cache(self, traces, penalty, num_changepoints, data): if self.cache_dir is None: return - cache_key = self.cache_key(signal, penalty, num_changepoints) + cache_key = self.cache_key(traces, penalty, num_changepoints) try: os.mkdir(self.cache_dir) @@ -123,8 +122,8 @@ class PELT: with open(f"{self.cache_dir}/{cache_key[:2]}/pelt-{cache_key}.json", "w") as f: json.dump(data, f) - def load_cache(self, signal, penalty, num_changepoints): - cache_key = self.cache_key(signal, penalty, num_changepoints) + def load_cache(self, traces, penalty, num_changepoints): + cache_key = self.cache_key(traces, penalty, num_changepoints) try: with open( f"{self.cache_dir}/{cache_key[:2]}/pelt-{cache_key}.json", "r" @@ -133,57 +132,86 @@ class PELT: except FileNotFoundError: return None - def get_penalty_and_changepoints(self, signal, penalty=None, num_changepoints=None): - data = self.load_cache(signal, penalty, num_changepoints) + def get_penalty_and_changepoints(self, traces, penalty=None, num_changepoints=None): + data = self.load_cache(traces, penalty, num_changepoints) if data: - if type(data[1]) is dict: - str_keys = list(data[1].keys()) - for k in str_keys: - data[1][int(k)] = data[1].pop(k) + for res in data: + if type(res[1]) is dict: + str_keys = list(res[1].keys()) + for k in str_keys: + res[1][int(k)] = res[1].pop(k) return data data = self.calculate_penalty_and_changepoints( - signal, penalty, num_changepoints + traces, penalty, num_changepoints ) - self.save_cache(signal, penalty, num_changepoints, data) + self.save_cache(traces, penalty, num_changepoints, data) return data def calculate_penalty_and_changepoints( - self, signal, penalty=None, num_changepoints=None + self, traces, penalty=None, num_changepoints=None ): # imported here as ruptures is only used for changepoint detection. # This way, dfatool can be used without having ruptures installed as # long as --pelt isn't active. import ruptures + list_of_lists = type(traces[0]) is list or type(traces[0]) is np.ndarray + + if not list_of_lists: + traces = [traces] + if self.stretch != 1: - signal = np.interp( - np.linspace(0, len(signal) - 1, (len(signal) - 1) * self.stretch + 1), - np.arange(len(signal)), - signal, + traces = list( + map( + lambda trace: np.interp( + np.linspace( + 0, len(trace) - 1, (len(trace) - 1) * self.stretch + 1 + ), + np.arange(len(trace)), + trace, + ), + traces, + ) ) - if self.num_samples is not None: - if len(signal) > self.num_samples: - self.jump = len(signal) // int(self.num_samples) + algos = list() + queue = list() + changepoints_by_penalty_by_trace = list() + results = list() + + for i in range(len(traces)): + if self.algo == "dynp": + # https://centre-borelli.github.io/ruptures-docs/user-guide/detection/dynp/ + algo = ruptures.Dynp( + model=self.model, jump=self.jump, min_size=self.min_dist + ) + else: + # https://centre-borelli.github.io/ruptures-docs/user-guide/detection/pelt/ + algo = ruptures.Pelt( + model=self.model, jump=self.jump, min_size=self.min_dist + ) + algo = algo.fit(self.norm_signal(traces[i])) + algos.append(algo) + + for i in range(len(traces)): + changepoints_by_penalty_by_trace.append(dict()) + if penalty is not None: + queue.append((i, penalty, algos[i])) + elif self.algo == "dynp" and num_changepoints is not None: + queue.append((i, None, algos[i])) else: - self.jump = 1 + for range_penalty in range(self.range_min, self.range_max): + queue.append((i, range_penalty, algos[i])) - if self.algo == "dynp": - # https://centre-borelli.github.io/ruptures-docs/user-guide/detection/dynp/ - algo = ruptures.Dynp( - model=self.model, jump=self.jump, min_size=self.min_dist - ) + if self.with_multiprocessing: + with Pool() as pool: + changepoints_by_trace = pool.starmap(PELT_get_changepoints, queue) else: - # https://centre-borelli.github.io/ruptures-docs/user-guide/detection/pelt/ - algo = ruptures.Pelt( - model=self.model, jump=self.jump, min_size=self.min_dist - ) - algo = algo.fit(self.norm_signal(signal)) + changepoints_by_trace = map(lambda x: PELT_get_changepoints(*x), queue) - if penalty is not None: - changepoints = algo.predict(pen=penalty) - if len(changepoints) and changepoints[-1] == len(signal): + for i, range_penalty, changepoints in changepoints_by_trace: + if len(changepoints) and changepoints[-1] == len(traces[i]): changepoints.pop() if len(changepoints) and changepoints[0] == 0: changepoints.pop(0) @@ -191,39 +219,28 @@ class PELT: changepoints = np.array( np.around(np.array(changepoints) / self.stretch), dtype=np.int ) - return penalty, changepoints - - if self.algo == "dynp" and num_changepoints is not None: - changepoints = algo.predict(n_bkps=num_changepoints) - if len(changepoints) and changepoints[-1] == len(signal): - changepoints.pop() - if len(changepoints) and changepoints[0] == 0: - changepoints.pop(0) - if self.stretch != 1: - changepoints = np.array( - np.around(np.array(changepoints) / self.stretch), dtype=np.int + changepoints_by_penalty_by_trace[i][range_penalty] = changepoints + + for i in range(len(traces)): + changepoints_by_penalty = changepoints_by_penalty_by_trace[i] + if penalty is not None: + results.append((penalty, changepoints_by_penalty[penalty])) + elif self.algo == "dynp" and num_changepoints is not None: + results.append((None, changepoints_by_penalty[None])) + else: + results.append( + ( + self.find_penalty(changepoints_by_penalty), + changepoints_by_penalty, + ) ) - return None, changepoints - queue = list() - for i in range(self.range_min, self.range_max): - queue.append((algo, i)) - if self.with_multiprocessing: - with Pool() as pool: - changepoints = pool.starmap(PELT_get_changepoints, queue) + if list_of_lists: + return results else: - changepoints = map(lambda x: PELT_get_changepoints(*x), queue) - changepoints_by_penalty = dict() - for res in changepoints: - if len(res[1]) > 0 and res[1][-1] == len(signal): - res[1].pop() - if self.stretch != 1: - res = ( - res[0], - np.array(np.around(np.array(res[1]) / self.stretch), dtype=np.int), - ) - changepoints_by_penalty[res[0]] = res[1] + return results[0] + def find_penalty(self, changepoints_by_penalty): changepoint_counts = list() for i in range(self.range_min, self.range_max): changepoint_counts.append(len(changepoints_by_penalty[i])) @@ -249,17 +266,17 @@ class PELT: prev_val = changepoint_count middle_of_plateau = longest_start + (longest_end - longest_start) // 2 - return self.range_min + middle_of_plateau, changepoints_by_penalty + return self.range_min + middle_of_plateau - def get_changepoints(self, signal, **kwargs): - penalty, changepoints_by_penalty = self.get_penalty_and_changepoints( - signal, **kwargs - ) - return changepoints_by_penalty[penalty] + def get_changepoints(self, traces, **kwargs): + results = self.get_penalty_and_changepoints(traces, **kwargs) + if type(results) is list: + return list(map(lambda res: res[1][res[0]])) + return results[1][results[0]] - def get_penalty(self, signal, **kwargs): - penalty, _ = self.get_penalty_and_changepoints(signal, **kwargs) - return penalty + def get_penalty(self, traces, **kwargs): + results = self.get_penalty_and_changepoints(traces, **kwargs) + return list(map(lambda res: res[0])) def calc_raw_states( self, |