Diffstat (limited to 'lib/drift.py')
-rw-r--r--  lib/drift.py  339
1 file changed, 339 insertions, 0 deletions
diff --git a/lib/drift.py b/lib/drift.py
new file mode 100644
index 0000000..cb769f4
--- /dev/null
+++ b/lib/drift.py
@@ -0,0 +1,339 @@
+#!/usr/bin/env python3
+
+import numpy as np
+import os
+import scipy.sparse
+import scipy.sparse.csgraph
+from bisect import bisect_left, bisect_right
+
+
+def compensate(data, timestamps, event_timestamps, offline_index=None):
+ """Use ruptures (e.g. Pelt, Dynp) to determine transition timestamps."""
+ from dfatool.pelt import PELT
+
+ # "rbf" und "l2" scheinen ähnlich gut zu funktionieren, l2 ist schneller. l1 ist wohl noch besser.
+ # PELT does not find changepoints for transitions which span just four or five data points (i.e., transitions shorter than ~2ms).
+ # Workaround: Double the data rate passed to PELT by interpolation ("stretch=2")
+ pelt = PELT(with_multiprocessing=False, stretch=2, min_dist=1, cache_dir=None)
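+    # event_timestamps holds alternating expected transition start and end
+    # timestamps, so every second entry is an expected transition start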
+ expected_transition_start_timestamps = event_timestamps[::2]
+ transition_start_candidate_weights = list()
+
+    # TODO also try determining candidates via the derivative
+    # (-> add an environment variable to select the method)
+
+ pelt_traces = list()
+ range_timestamps = list()
+ candidate_weights = list()
+
+ for i, expected_start_ts in enumerate(expected_transition_start_timestamps):
+ expected_end_ts = event_timestamps[2 * i + 1]
+        # assumption: the maximum deviation between expected and actual timestamps is 5 ms.
+        # We use ±10 ms to have some context for PELT.
+ et_timestamps_start = bisect_left(timestamps, expected_start_ts - 10e-3)
+ et_timestamps_end = bisect_right(timestamps, expected_end_ts + 10e-3)
+ range_timestamps.append(timestamps[et_timestamps_start : et_timestamps_end + 1])
+ pelt_traces.append(data[et_timestamps_start : et_timestamps_end + 1])
+
+ # TODO for greedy mode, perform changepoint detection between greedy steps
+ # (-> the expected changepoint area is well-known, Dynp with 1/2 changepoints
+ # should work much better than "somewhere in these 20ms there should be a transition")
+
+ if os.getenv("DFATOOL_DRIFT_COMPENSATION_PENALTY"):
+ penalties = (int(os.getenv("DFATOOL_DRIFT_COMPENSATION_PENALTY")),)
+ else:
+ penalties = (1, 2, 5, 10, 15, 20)
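+    # Sweep several penalties and count how often each changepoint is found;
+    # changepoints which are stable across penalties receive a higher weight.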
+    for i in range(len(expected_transition_start_timestamps)):
+        candidate_weights.append(dict())
+    for penalty in penalties:
+        changepoints_by_transition = pelt.get_changepoints(pelt_traces, penalty=penalty)
+        for i in range(len(expected_transition_start_timestamps)):
+            for changepoint in changepoints_by_transition[i]:
+                candidate_weights[i][changepoint] = (
+                    candidate_weights[i].get(changepoint, 0) + 1
+                )
+
+ for i, expected_start_ts in enumerate(expected_transition_start_timestamps):
+
+        # TODO Is expected_start_ts really a good reference? If a transition is
+        # preceded by a UART dump, expected_end_ts is probably better; in that
+        # case, the compensation must be mapped back to start_ts, though.
+ transition_start_candidate_weights.append(
+ list(
+ map(
+ lambda k: (
+ range_timestamps[i][k] - expected_start_ts,
+ range_timestamps[i][k] - expected_end_ts,
+ candidate_weights[i][k],
+ ),
+ sorted(candidate_weights[i].keys()),
+ )
+ )
+ )
+
+ if os.getenv("DFATOOL_COMPENSATE_DRIFT_GREEDY"):
+ return compensate_drift_greedy(
+ event_timestamps, transition_start_candidate_weights
+ )
+
+ return compensate_drift_graph(
+ event_timestamps,
+ transition_start_candidate_weights,
+ offline_index=offline_index,
+ )
+
+
+def compensate_drift_graph(
+ event_timestamps, transition_start_candidate_weights, offline_index=None
+):
+    # Algorithm: obtain the shortest path in a layered graph made up of
+    # transition candidates. Each node represents a candidate timestamp
+    # and each layer represents a transition. Each node in layer i has a
+    # directed edge to each node in layer i+1. The edge weight is the
+    # drift delta between the two nodes: if node X (transition i,
+    # candidate a) has a drift of 5 and node Y (transition i+1,
+    # candidate b) has a drift of -2, the edge weight is 7.
+    # The first and last layers consist of a single node each with a
+    # drift of 0, representing the start and end synchronization pulse.
+
+ prev_nodes = [0]
+ prev_drifts = [0]
+ node_drifts = [0]
+ edge_srcs = list()
+ edge_dsts = list()
+ csr_weights = list()
+
+ # (transition index) -> [candidate 0/start node, candidate 0/end node, candidate 1/start node, ...]
+ nodes_by_transition_index = dict()
+
+ # (node number) -> (transition index, candidate index, is_end)
+ # (-> transition_start_candidate_weights[transition index][candidate index][is_end])
+ transition_by_node = dict()
+
+ compensated_timestamps = list()
+
+    # default: up to two consecutive transitions may be skipped
+ max_skip_count = 2
+
+ if os.getenv("DFATOOL_DC_MAX_SKIP"):
+ max_skip_count = int(os.getenv("DFATOOL_DC_MAX_SKIP"))
+
+ for transition_index, candidates in enumerate(transition_start_candidate_weights):
+ new_nodes = list()
+ new_drifts = list()
+ i_offset = prev_nodes[-1] + 1
+ nodes_by_transition_index[transition_index] = list()
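+        # each candidate contributes two nodes: one with its drift relative to
+        # the expected start timestamp, one relative to the expected end timestamp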
+ for new_node_i, (new_drift_start, new_drift_end, _) in enumerate(candidates):
+ for is_end, new_drift in enumerate((new_drift_start, new_drift_end)):
+ new_node = i_offset + new_node_i * 2 + is_end
+ nodes_by_transition_index[transition_index].append(new_node)
+ transition_by_node[new_node] = (transition_index, new_node_i, is_end)
+ new_nodes.append(new_node)
+ new_drifts.append(new_drift)
+ node_drifts.append(new_drift)
+ for prev_node_i, prev_node in enumerate(prev_nodes):
+ prev_drift = prev_drifts[prev_node_i]
+
+ edge_srcs.append(prev_node)
+ edge_dsts.append(new_node)
+
+ delta_drift = np.abs(prev_drift - new_drift)
+ # TODO evaluate "delta_drift ** 2" or similar nonlinear
+ # weights -> further penalize large drift deltas
+ csr_weights.append(delta_drift)
+
+ # a transition's candidate list may be empty
+ if len(new_nodes):
+ prev_nodes = new_nodes
+ prev_drifts = new_drifts
+
+ # add an end node for shortest path search
+ # (end node == final sync, so drift == 0)
+ new_node = prev_nodes[-1] + 1
+ for prev_node_i, prev_node in enumerate(prev_nodes):
+ prev_drift = prev_drifts[prev_node_i]
+ edge_srcs.append(prev_node)
+ edge_dsts.append(new_node)
+ csr_weights.append(np.abs(prev_drift))
+
+ # Add "skip" edges spanning from transition i to transition i+n (n > 1).
+ # These avoid synchronization errors caused by transitions wich are
+ # not found by changepiont detection, as long as they are sufficiently rare.
+ for transition_index, candidates in enumerate(transition_start_candidate_weights):
+ for skip_count in range(2, max_skip_count + 2):
+ if transition_index < skip_count:
+ continue
+ for from_node in nodes_by_transition_index[transition_index - skip_count]:
+ for to_node in nodes_by_transition_index[transition_index]:
+
+ (from_trans_i, from_candidate_i, from_is_end) = transition_by_node[
+ from_node
+ ]
+ to_trans_i, to_candidate_i, to_is_end = transition_by_node[to_node]
+
+ assert transition_index - skip_count == from_trans_i
+ assert transition_index == to_trans_i
+
+ from_drift = transition_start_candidate_weights[from_trans_i][
+ from_candidate_i
+ ][from_is_end]
+ to_drift = transition_start_candidate_weights[to_trans_i][
+ to_candidate_i
+ ][to_is_end]
+
+ edge_srcs.append(from_node)
+ edge_dsts.append(to_node)
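+                    # drift delta plus a flat 270 µs penalty per skipped
+                    # transition, so skips are only taken if they pay off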
+ csr_weights.append(
+ np.abs(from_drift - to_drift) + (skip_count - 1) * 270e-6
+ )
+
+ sm = scipy.sparse.csr_matrix(
+ (csr_weights, (edge_srcs, edge_dsts)), shape=(new_node + 1, new_node + 1)
+ )
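+    # single-source shortest path starting at the start-pulse node (index 0);
+    # "predecessors" lets us walk the optimal path back from the end node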
+ dm, predecessors = scipy.sparse.csgraph.shortest_path(
+ sm, return_predecessors=True, indices=0
+ )
+
+ nodes = list()
+ pred = predecessors[-1]
+ while pred > 0:
+ nodes.append(pred)
+ pred = predecessors[pred]
+
+ nodes = list(reversed(nodes))
+
+ # first and last node are not included in "nodes" as they represent
+ # the start/stop sync pulse (and not a transition with sync candidates)
+
+ prev_transition = -1
+ for i, node in enumerate(nodes):
+ transition, _, _ = transition_by_node[node]
+ drift = node_drifts[node]
+
+        while transition - prev_transition > 1:
+            # transition(s) skipped by the shortest path: carry over the drift
+            # of the preceding path node (0 right after the start pulse)
+            prev_drift = node_drifts[nodes[i - 1]] if i > 0 else 0
+ prev_transition += 1
+ expected_start_ts = event_timestamps[prev_transition * 2] + prev_drift
+ expected_end_ts = event_timestamps[prev_transition * 2 + 1] + prev_drift
+ compensated_timestamps.append(expected_start_ts)
+ compensated_timestamps.append(expected_end_ts)
+
+ expected_start_ts = event_timestamps[transition * 2] + drift
+ expected_end_ts = event_timestamps[transition * 2 + 1] + drift
+ compensated_timestamps.append(expected_start_ts)
+ compensated_timestamps.append(expected_end_ts)
+ prev_transition = transition
+
+ # handle skips over the last few transitions, if any
+ transition = len(transition_start_candidate_weights) - 1
+ while transition - prev_transition > 0:
+ prev_drift = node_drifts[nodes[-1]]
+ prev_transition += 1
+ expected_start_ts = event_timestamps[prev_transition * 2] + prev_drift
+ expected_end_ts = event_timestamps[prev_transition * 2 + 1] + prev_drift
+ compensated_timestamps.append(expected_start_ts)
+ compensated_timestamps.append(expected_end_ts)
+
+ if os.getenv("DFATOOL_EXPORT_DRIFT_COMPENSATION"):
+ import json
+ from dfatool.utils import NpEncoder
+
+ expected_transition_start_timestamps = event_timestamps[::2]
+ filename = os.getenv("DFATOOL_EXPORT_DRIFT_COMPENSATION")
+ filename = f"{filename}.{offline_index}"
+
+ with open(filename, "w") as f:
+ json.dump(
+ [
+ expected_transition_start_timestamps,
+ transition_start_candidate_weights,
+ ],
+ f,
+ cls=NpEncoder,
+ )
+
+ return compensated_timestamps
+
+
+def compensate_drift_greedy(event_timestamps, transition_start_candidate_weights):
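+    """Greedy drift compensation.
+
+    For each transition, snap to the candidate changepoint closest to the
+    drift-corrected expected start or end timestamp, provided it is within
+    0.5 ms; the resulting drift is carried over to the next transition.
+    """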
+ drift = 0
+ expected_transition_start_timestamps = event_timestamps[::2]
+ compensated_timestamps = list()
+
+ for i, expected_start_ts in enumerate(expected_transition_start_timestamps):
+ candidates = sorted(
+ map(
+ lambda x: x[0] + expected_start_ts,
+ transition_start_candidate_weights[i],
+ )
+ )
+ expected_start_ts += drift
+ expected_end_ts = event_timestamps[2 * i + 1] + drift
+
+ # choose the next candidates around the expected sync point.
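+        # bisect_left returns the insertion point, so candidates[x_right_sync]
+        # is the first candidate at or above the expected timestamp and
+        # candidates[x_left_sync] is the last one below it (if any)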
+ start_right_sync = bisect_left(candidates, expected_start_ts)
+ start_left_sync = start_right_sync - 1
+
+ end_right_sync = bisect_left(candidates, expected_end_ts)
+ end_left_sync = end_right_sync - 1
+
+        if start_left_sync >= 0:
+ start_left_diff = expected_start_ts - candidates[start_left_sync]
+ else:
+ start_left_diff = np.inf
+
+ if start_right_sync < len(candidates):
+ start_right_diff = candidates[start_right_sync] - expected_start_ts
+ else:
+ start_right_diff = np.inf
+
+ if end_left_sync >= 0:
+ end_left_diff = expected_end_ts - candidates[end_left_sync]
+ else:
+ end_left_diff = np.inf
+
+ if end_right_sync < len(candidates):
+ end_right_diff = candidates[end_right_sync] - expected_end_ts
+ else:
+ end_right_diff = np.inf
+
+ drift_candidates = (
+ start_left_diff,
+ start_right_diff,
+ end_left_diff,
+ end_right_diff,
+ )
+ min_drift_i = np.argmin(drift_candidates)
+ min_drift = min(drift_candidates)
+
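+        # only snap to a candidate within 0.5 ms of the expected timestamp;
+        # otherwise keep the current drift estimate unchanged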
+ if min_drift < 5e-4:
+ if min_drift_i % 2 == 0:
+ # left
+ compensated_timestamps.append(expected_start_ts - min_drift)
+ compensated_timestamps.append(expected_end_ts - min_drift)
+ drift -= min_drift
+ else:
+ # right
+ compensated_timestamps.append(expected_start_ts + min_drift)
+ compensated_timestamps.append(expected_end_ts + min_drift)
+ drift += min_drift
+
+ else:
+ compensated_timestamps.append(expected_start_ts)
+ compensated_timestamps.append(expected_end_ts)
+
+ if os.getenv("DFATOOL_EXPORT_DRIFT_COMPENSATION"):
+ import json
+ from dfatool.utils import NpEncoder
+
+ expected_transition_start_timestamps = event_timestamps[::2]
+
+ with open(os.getenv("DFATOOL_EXPORT_DRIFT_COMPENSATION"), "w") as f:
+ json.dump(
+ [
+ expected_transition_start_timestamps,
+ transition_start_candidate_weights,
+ ],
+ f,
+ cls=NpEncoder,
+ )
+
+ return compensated_timestamps