-rw-r--r--  lib/lennart/DataProcessor.py  81
1 file changed, 76 insertions(+), 5 deletions(-)
diff --git a/lib/lennart/DataProcessor.py b/lib/lennart/DataProcessor.py
index 2554dba..23e6766 100644
--- a/lib/lennart/DataProcessor.py
+++ b/lib/lennart/DataProcessor.py
@@ -195,7 +195,6 @@ class DataProcessor:
compensated_timestamps.append(expected_start_ts + drift)
compensated_timestamps.append(expected_end_ts + drift)
- print(drift)
return compensated_timestamps
@@ -265,12 +264,21 @@ class DataProcessor:
edge_dsts = list()
csr_weights = list()
node_drifts = list()
- for candidates in transition_start_candidate_weights:
+
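+ # Remember which graph nodes belong to which transition (and vice versa) so that
+ # the path found below can be mapped back to per-transition drift candidates.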
+ nodes_by_transition_index = dict()
+ transition_by_node = dict()
+
+ for transition_index, candidates in enumerate(
+ transition_start_candidate_weights
+ ):
new_nodes = list()
new_drifts = list()
i_offset = prev_nodes[-1] + 1
+ nodes_by_transition_index[transition_index] = list()
for new_node_i, (_, new_drift, _) in enumerate(candidates):
new_node = new_node_i + i_offset
+ nodes_by_transition_index[transition_index].append(new_node)
+ transition_by_node[new_node] = transition_index
new_nodes.append(new_node)
new_drifts.append(new_drift)
node_drifts.append(new_drift)
@@ -295,6 +303,40 @@ class DataProcessor:
edge_dsts.append(new_node)
csr_weights.append(np.abs(prev_drift))
+ # Add "skip" edges spanning from transition i to transition i+2
+ # and from transition i to transition i+3.
+ # These avoid synchronization errors caused by transitions which are
+ # not found by changepoint detection, as long as they are sufficiently rare.
+ for transition_index, candidates in enumerate(
+ transition_start_candidate_weights
+ ):
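+ # A skip edge ends at transition i and starts at transition i-2 (or i-3),
+ # so the first two transitions cannot be the destination of a skip edge.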
+ if transition_index < 2:
+ continue
+ for from_i, (_, from_drift, _) in enumerate(
+ transition_start_candidate_weights[transition_index - 2]
+ ):
+ for to_i, (_, to_drift, _) in enumerate(candidates):
+ # Penalize shortcut by the duration of one sample
+ # (~270 us)
+ edge_srcs.append(
+ nodes_by_transition_index[transition_index - 2][from_i]
+ )
+ edge_dsts.append(nodes_by_transition_index[transition_index][to_i])
+ csr_weights.append(np.abs(from_drift - to_drift) + 270e-6)
+ if transition_index < 3:
+ continue
+ for from_i, (_, from_drift, _) in enumerate(
+ transition_start_candidate_weights[transition_index - 3]
+ ):
+ for to_i, (_, to_drift, _) in enumerate(candidates):
+ # Penalize shortcut by the duration of two samples
+ # (2 * ~270 us), one per skipped transition
+ edge_srcs.append(
+ nodes_by_transition_index[transition_index - 3][from_i]
+ )
+ edge_dsts.append(nodes_by_transition_index[transition_index][to_i])
+ csr_weights.append(np.abs(from_drift - to_drift) + 2 * 270e-6)
+
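+ # Store the candidate graph as a sparse adjacency matrix so that a shortest-path
+ # search over it can select the drift candidates used below.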
sm = scipy.sparse.csr_matrix(
(csr_weights, (edge_srcs, edge_dsts)), shape=(new_node + 1, new_node + 1)
)
@@ -308,15 +350,44 @@ class DataProcessor:
nodes.append(pred)
pred = predecessors[pred]
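+ # The predecessor chain is collected from the final node back to the start,
+ # so reverse it to obtain the visited nodes in chronological order.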
+ nodes = list(reversed(nodes))
+
# first and last graph nodes are not included in "nodes" as they represent
# the start/stop sync pulse (and not a transition with sync candidates)
- for i, node in enumerate(reversed(nodes)):
+ prev_transition = -1
+ for i, node in enumerate(nodes):
+ transition = transition_by_node[node]
drift = node_drifts[node]
- expected_start_ts = sync_timestamps[i * 2] + drift
- expected_end_ts = sync_timestamps[i * 2 + 1] + drift
+
+ if transition - prev_transition >= 2:
+ # previous transition was skipped due to lack of detected changepoints
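+ # reconstruct its sync pulse using the mean drift of the two neighbouring transitions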
+ prev_drift = node_drifts[nodes[i - 1]]
+ mean_drift = np.mean([prev_drift, drift])
+ expected_start_ts = (
+ sync_timestamps[(prev_transition + 1) * 2] + mean_drift
+ )
+ expected_end_ts = (
+ sync_timestamps[(prev_transition + 1) * 2 + 1] + mean_drift
+ )
+ compensated_timestamps.append(expected_start_ts)
+ compensated_timestamps.append(expected_end_ts)
+ if transition - prev_transition >= 3:
+ # a second consecutive transition was also skipped; reconstruct it from the same mean drift
+ expected_start_ts = (
+ sync_timestamps[(prev_transition + 2) * 2] + mean_drift
+ )
+ expected_end_ts = (
+ sync_timestamps[(prev_transition + 2) * 2 + 1] + mean_drift
+ )
+ compensated_timestamps.append(expected_start_ts)
+ compensated_timestamps.append(expected_end_ts)
+
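+ # Regular case: compensate this transition's sync pulse with its own drift candidate.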
+ expected_start_ts = sync_timestamps[transition * 2] + drift
+ expected_end_ts = sync_timestamps[transition * 2 + 1] + drift
compensated_timestamps.append(expected_start_ts)
compensated_timestamps.append(expected_end_ts)
+ prev_transition = transition
if os.getenv("DFATOOL_EXPORT_DRIFT_COMPENSATION"):
import json