diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/lennart/DataProcessor.py | 81 |
1 files changed, 76 insertions, 5 deletions
diff --git a/lib/lennart/DataProcessor.py b/lib/lennart/DataProcessor.py index 2554dba..23e6766 100644 --- a/lib/lennart/DataProcessor.py +++ b/lib/lennart/DataProcessor.py @@ -195,7 +195,6 @@ class DataProcessor: compensated_timestamps.append(expected_start_ts + drift) compensated_timestamps.append(expected_end_ts + drift) - print(drift) return compensated_timestamps @@ -265,12 +264,21 @@ class DataProcessor: edge_dsts = list() csr_weights = list() node_drifts = list() - for candidates in transition_start_candidate_weights: + + nodes_by_transition_index = dict() + transition_by_node = dict() + + for transition_index, candidates in enumerate( + transition_start_candidate_weights + ): new_nodes = list() new_drifts = list() i_offset = prev_nodes[-1] + 1 + nodes_by_transition_index[transition_index] = list() for new_node_i, (_, new_drift, _) in enumerate(candidates): new_node = new_node_i + i_offset + nodes_by_transition_index[transition_index].append(new_node) + transition_by_node[new_node] = transition_index new_nodes.append(new_node) new_drifts.append(new_drift) node_drifts.append(new_drift) @@ -295,6 +303,40 @@ class DataProcessor: edge_dsts.append(new_node) csr_weights.append(np.abs(prev_drift)) + # Add "skip" edges spanning from transition i to transition i+2 + # and from transition i to transition i+3. + # These avoid synchronization errors caused by transitions wich are + # not found by changepiont detection, as long as they are sufficiently rare. + for transition_index, candidates in enumerate( + transition_start_candidate_weights + ): + if transition_index < 2: + continue + for from_i, (_, from_drift, _) in enumerate( + transition_start_candidate_weights[transition_index - 2] + ): + for to_i, (_, to_drift, _) in enumerate(candidates): + # Penalize shortcut by the duration of one sample + # (~270 us) + edge_srcs.append( + nodes_by_transition_index[transition_index - 2][from_i] + ) + edge_dsts.append(nodes_by_transition_index[transition_index][to_i]) + csr_weights.append(np.abs(from_drift - to_drift) + 270e-6) + if transition_index < 3: + continue + for from_i, (_, from_drift, _) in enumerate( + transition_start_candidate_weights[transition_index - 3] + ): + for to_i, (_, to_drift, _) in enumerate(candidates): + # Penalize shortcut by the duration of one sample + # (~270 us) + edge_srcs.append( + nodes_by_transition_index[transition_index - 3][from_i] + ) + edge_dsts.append(nodes_by_transition_index[transition_index][to_i]) + csr_weights.append(np.abs(from_drift - to_drift) + 2 * 270e-6) + sm = scipy.sparse.csr_matrix( (csr_weights, (edge_srcs, edge_dsts)), shape=(new_node + 1, new_node + 1) ) @@ -308,15 +350,44 @@ class DataProcessor: nodes.append(pred) pred = predecessors[pred] + nodes = list(reversed(nodes)) + # first and graph nodes are not included in "nodes" as they represent # the start/stop sync pulse (and not a transition with sync candidates) - for i, node in enumerate(reversed(nodes)): + prev_transition = -1 + for i, node in enumerate(nodes): + transition = transition_by_node[node] drift = node_drifts[node] - expected_start_ts = sync_timestamps[i * 2] + drift - expected_end_ts = sync_timestamps[i * 2 + 1] + drift + + if transition - prev_transition >= 2: + # previous transition was skipped due to lack of detected changepoints + prev_drift = node_drifts[nodes[i - 1]] + mean_drift = np.mean([prev_drift, drift]) + expected_start_ts = ( + sync_timestamps[(prev_transition + 1) * 2] + mean_drift + ) + expected_end_ts = ( + sync_timestamps[(prev_transition + 1) * 2 + 1] + mean_drift + ) + compensated_timestamps.append(expected_start_ts) + compensated_timestamps.append(expected_end_ts) + if transition - prev_transition >= 3: + # previous transition was skipped due to lack of detected changepoints + expected_start_ts = ( + sync_timestamps[(prev_transition + 2) * 2] + mean_drift + ) + expected_end_ts = ( + sync_timestamps[(prev_transition + 2) * 2 + 1] + mean_drift + ) + compensated_timestamps.append(expected_start_ts) + compensated_timestamps.append(expected_end_ts) + + expected_start_ts = sync_timestamps[transition * 2] + drift + expected_end_ts = sync_timestamps[transition * 2 + 1] + drift compensated_timestamps.append(expected_start_ts) compensated_timestamps.append(expected_end_ts) + prev_transition = transition if os.getenv("DFATOOL_EXPORT_DRIFT_COMPENSATION"): import json |