w2t-bkin 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
w2t_bkin/sync/core.py ADDED
@@ -0,0 +1,678 @@
1
+ """Core synchronization logic and algorithms.
2
+
3
+ This module consolidates the core synchronization functionality, including:
4
+ 1. Low-level alignment algorithms (nearest neighbor, linear interpolation)
5
+ 2. Jitter computation and budget enforcement
6
+ 3. High-level stream synchronization utilities
7
+ 4. Protocol definitions for configuration
8
+
9
+ Key Concepts:
10
+ -------------
11
+ - **Sample Times**: The original timestamps recorded by your data stream's internal
12
+ clock (e.g., video camera's timer, pose estimation software's clock). These may
13
+ have clock drift or offset relative to other devices.
14
+
15
+ - **Reference Times**: The ground-truth timestamps from hardware synchronization
16
+ signals (e.g., TTL pulses from a master clock). These define the canonical
17
+ timebase that all data streams should align to.
18
+
19
+ - **Synchronized Times**: Sample times that have been **mapped** to the reference
20
+ timebase. These are NOT the raw reference times—they are your original sample
21
+ times adjusted to account for clock drift and offset between your device and
22
+ the reference clock.
23
+
24
+ - **Alignment**: The process of finding the mapping between sample_times and
25
+ reference_times, then computing synchronized times that place your samples on
26
+ the reference timebase.
27
+
28
+ Clock Drift Example:
29
+ --------------------
30
+ Device Clock (sample_times): [0.0, 1.0, 2.0, 3.0, 4.0] # 1 Hz
31
+ Reference Clock: [0.0, 1.01, 2.02, 3.03, 4.04] # 1% drift
32
+
33
+ Synchronized times: [0.0, 1.01, 2.02, 3.03, 4.04]
34
+ # ^ Sample times mapped to reference clock to correct for drift
35
+
36
+ Typical Workflow:
37
+ -----------------
38
+ 1. Acquire TTL pulses from hardware (e.g., camera trigger signals)
39
+ 2. Use TTL timestamps as reference_times
40
+ 3. Align each data stream (video, pose, facemap) to reference_times
41
+ 4. Use synchronized times (NOT raw sample times) when building NWB objects
42
+ """
43
+
44
+ import logging
45
+ from typing import Any, Dict, List, Literal, Protocol, Tuple
46
+ import warnings
47
+
48
+ import numpy as np
49
+ from scipy import stats
50
+
51
+ from ..exceptions import JitterExceedsBudgetError, SyncError
52
+
53
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Public API: the names exported by a star-import of this module.
__all__ = [
    "TimebaseConfigProtocol",
    "map_nearest",
    "map_linear",
    "compute_jitter_stats",
    "enforce_jitter_budget",
    "align_samples",
    "sync_stream_to_timebase",
    "align_pose_frames_to_reference",
    "fit_robust_linear_model",
]
66
+
67
+
68
+ # =============================================================================
69
+ # Protocols
70
+ # =============================================================================
71
+
72
+
73
class TimebaseConfigProtocol(Protocol):
    """Structural type describing the timebase settings used by the sync code.

    Any object that exposes the two attributes below satisfies this protocol,
    which lets the sync modules read configuration without importing from
    domain.config.TimebaseConfig.
    """

    # Alignment strategy: "nearest" or "linear".
    mapping: Literal["nearest", "linear"]
    # Maximum acceptable jitter, in seconds.
    jitter_budget_s: float
86
+
87
+
88
+ # =============================================================================
89
+ # Mapping Strategies (Primitives)
90
+ # =============================================================================
91
+
92
+
93
def map_nearest(sample_times: List[float], reference_times: List[float]) -> List[int]:
    """Map each sample time to the index of its nearest reference timestamp.

    The reference timebase is validated as sorted, so the nearest neighbour of
    every sample is located with a binary search instead of scanning the whole
    reference array once per sample.

    Args:
        sample_times: Times to align (list or 1-D array).
        reference_times: Reference timebase, sorted in non-decreasing order
            (list or 1-D array).

    Returns:
        One index into ``reference_times`` per sample, as plain Python ints.
        A sample equidistant from two reference points maps to the lower
        index; among duplicated reference values the first occurrence is
        reported.

    Raises:
        SyncError: Empty or non-monotonic reference

    Warns:
        UserWarning: Once for every sample that lies more than 1 second away
            from its nearest reference timestamp.

    Example:
        >>> map_nearest([0.3, 1.7], [0.0, 1.0, 2.0])
        [0, 2]
    """
    if len(reference_times) == 0:
        raise SyncError("Cannot map to empty reference timebase")

    ref_array = np.asarray(reference_times, dtype=float)

    # Check monotonicity (non-decreasing; repeated timestamps are allowed)
    if np.any(np.diff(ref_array) < 0):
        raise SyncError("Reference timestamps must be monotonic")

    if len(sample_times) == 0:
        return []

    sample_array = np.asarray(sample_times, dtype=float)
    last = ref_array.size - 1

    # `right` is the first reference index whose value is >= the sample; the
    # nearest neighbour is either that entry or the one just before it.
    right = np.searchsorted(ref_array, sample_array, side="left")
    upper = np.minimum(right, last)
    lower = np.maximum(right - 1, 0)
    upper_is_closer = np.abs(ref_array[upper] - sample_array) < np.abs(ref_array[lower] - sample_array)
    nearest = np.where(upper_is_closer, upper, lower)

    # A linear scan reports the first occurrence of a duplicated reference
    # value; normalise to that index so results match such a scan.
    nearest = np.searchsorted(ref_array, ref_array[nearest], side="left")

    # NaN samples have no meaningful neighbour; report index 0 for them, which
    # is what an argmin over all-NaN distances yields.
    nearest = np.where(np.isnan(sample_array), 0, nearest)

    # Warn about samples that are far (> 1 second) from any reference point
    gaps = np.abs(ref_array[nearest] - sample_array)
    for pos in np.flatnonzero(gaps > 1.0):
        warnings.warn(
            f"Sample time {sample_times[pos]} has large gap ({gaps[pos]:.3f}s) from nearest reference",
            UserWarning,
            stacklevel=2,
        )

    return nearest.tolist()
134
+
135
+
136
def map_linear(sample_times: List[float], reference_times: List[float]) -> Tuple[List[Tuple[int, int]], List[Tuple[float, float]]]:
    """Map samples to bracketing reference points with interpolation weights.

    Args:
        sample_times: Times to align (list or 1-D array).
        reference_times: Reference timebase, sorted in non-decreasing order
            (list or 1-D array).

    Returns:
        ``(indices, weights)`` with one entry per sample. ``indices`` holds
        ``(idx0, idx1)`` pairs of bracketing reference indices and ``weights``
        the matching ``(w0, w1)`` interpolation weights. Samples at or before
        the first reference point, or after the last one, are clamped to that
        end point with weights ``(1.0, 0.0)``. All values are builtin
        ``int``/``float`` objects.

    Raises:
        SyncError: Empty or non-monotonic reference

    Example:
        >>> map_linear([0.5], [0.0, 1.0])
        ([(0, 1)], [(0.5, 0.5)])
    """
    if len(reference_times) == 0:
        raise SyncError("Cannot map to empty reference timebase")

    ref_array = np.asarray(reference_times, dtype=float)

    if np.any(np.diff(ref_array) < 0):
        raise SyncError("Reference timestamps must be monotonic")

    if len(sample_times) == 0:
        return [], []

    last = ref_array.size - 1
    # For each sample: index of the first reference value >= the sample.
    insert_positions = np.searchsorted(ref_array, np.asarray(sample_times, dtype=float), side="left")

    indices: List[Tuple[int, int]] = []
    weights: List[Tuple[float, float]] = []

    for sample_time, idx_after in zip(sample_times, insert_positions.tolist()):
        if idx_after == 0:
            # At or before first reference point - clamp to first
            indices.append((0, 0))
            weights.append((1.0, 0.0))
        elif idx_after > last:
            # After last reference point - clamp to last
            indices.append((last, last))
            weights.append((1.0, 0.0))
        else:
            # Interpolate between idx_after-1 and idx_after
            idx0, idx1 = idx_after - 1, idx_after
            t0 = float(ref_array[idx0])
            t1 = float(ref_array[idx1])
            span = t1 - t0

            if span > 0:
                w1 = (float(sample_time) - t0) / span
                w0 = 1.0 - w1
            else:
                # Degenerate interval (defensive guard) - equal weights
                w0, w1 = 0.5, 0.5

            indices.append((idx0, idx1))
            weights.append((w0, w1))

    return indices, weights
199
+
200
+
201
+ # =============================================================================
202
+ # Jitter Computation
203
+ # =============================================================================
204
+
205
+
206
def compute_jitter_stats(sample_times: List[float], reference_times: List[float], indices: List[int]) -> Dict[str, float]:
    """Compute jitter statistics for a nearest-neighbour mapping.

    Jitter of sample ``i`` is ``abs(sample_times[i] - reference_times[indices[i]])``.

    Args:
        sample_times: Original sample times (list or 1-D array).
        reference_times: Reference timebase (list or 1-D array).
        indices: For each sample, the index of its matched reference time.
            If shorter than ``sample_times``, only the leading samples are
            evaluated.

    Returns:
        Dict with ``max_jitter_s`` (largest jitter) and ``p95_jitter_s``
        (95th percentile), both in the units of the inputs. Both are 0.0 when
        ``sample_times`` or ``indices`` is empty.

    Raises:
        IndexError: If ``indices`` has more entries than ``sample_times`` or
            refers to a position outside ``reference_times``.

    Example:
        >>> compute_jitter_stats([0.5, 1.0], [0.0, 1.0], [0, 1])["max_jitter_s"]
        0.5
    """
    if len(sample_times) == 0 or len(indices) == 0:
        return {"max_jitter_s": 0.0, "p95_jitter_s": 0.0}

    ref_array = np.asarray(reference_times, dtype=float)
    sample_array = np.asarray(sample_times, dtype=float)
    idx_array = np.asarray(indices)

    if idx_array.size > sample_array.size:
        raise IndexError("indices has more entries than sample_times")

    # Absolute offset between every sample and its matched reference time
    jitter_array = np.abs(sample_array[: idx_array.size] - ref_array[idx_array])

    return {
        "max_jitter_s": float(np.max(jitter_array)),
        "p95_jitter_s": float(np.percentile(jitter_array, 95)),
    }
235
+
236
+
237
+ # =============================================================================
238
+ # Jitter Budget Enforcement
239
+ # =============================================================================
240
+
241
+
242
def enforce_jitter_budget(max_jitter: float, p95_jitter: float, budget: float) -> None:
    """Reject a synchronization result whose jitter is above the budget.

    Intended as a quality gate before NWB assembly: the maximum jitter is
    checked first, then the 95th percentile, and the first value found to be
    strictly greater than ``budget`` aborts with an error.

    Args:
        max_jitter: Maximum jitter observed (seconds)
        p95_jitter: 95th percentile jitter (seconds)
        budget: Configured jitter budget threshold (seconds)

    Raises:
        JitterExceedsBudgetError: If max or p95 jitter exceeds budget

    Example:
        >>> enforce_jitter_budget(max_jitter=0.005, p95_jitter=0.003, budget=0.010)  # Passes
    """
    # Order matters: the max-jitter violation is reported before the P95 one.
    for label, observed in (("Max", max_jitter), ("P95", p95_jitter)):
        if observed > budget:
            raise JitterExceedsBudgetError(f"{label} jitter {observed:.6f}s exceeds budget {budget:.6f}s")
268
+
269
+
270
+ # =============================================================================
271
+ # High-Level Alignment (Primitives)
272
+ # =============================================================================
273
+
274
+
275
def align_samples(
    sample_times: List[float],
    reference_times: List[float],
    config: TimebaseConfigProtocol,
    enforce_budget: bool = False,
) -> Dict[str, Any]:
    """Run the configured mapping strategy and report its jitter.

    Ties together the mapping primitive selected by ``config.mapping``, the
    jitter statistics, and (optionally) the jitter budget check.

    Args:
        sample_times: Times to align
        reference_times: Reference timebase
        config: Timebase configuration (``mapping`` and ``jitter_budget_s``)
        enforce_budget: When True, validate the jitter against
            ``config.jitter_budget_s`` before returning

    Returns:
        Dict with ``indices``, ``jitter_stats`` and ``mapping``; the "linear"
        strategy additionally includes ``weights``.

    Raises:
        JitterExceedsBudgetError: Jitter exceeds budget
        SyncError: Alignment failed or the mapping strategy is unknown
    """
    strategy = config.mapping

    if strategy == "nearest":
        nearest = map_nearest(sample_times, reference_times)
        outcome: Dict[str, Any] = {"indices": nearest}
        jitter_stats = compute_jitter_stats(sample_times, reference_times, nearest)
        outcome["jitter_stats"] = jitter_stats
        outcome["mapping"] = "nearest"
    elif strategy == "linear":
        bracket_indices, bracket_weights = map_linear(sample_times, reference_times)
        outcome = {"indices": bracket_indices, "weights": bracket_weights}
        # Simplification: jitter for the linear strategy is measured against
        # the nearest reference point rather than the interpolation residual.
        nearest = map_nearest(sample_times, reference_times)
        jitter_stats = compute_jitter_stats(sample_times, reference_times, nearest)
        outcome["jitter_stats"] = jitter_stats
        outcome["mapping"] = "linear"
    else:
        raise SyncError(f"Unknown mapping strategy: {config.mapping}")

    if not enforce_budget:
        return outcome

    enforce_jitter_budget(
        max_jitter=jitter_stats["max_jitter_s"],
        p95_jitter=jitter_stats["p95_jitter_s"],
        budget=config.jitter_budget_s,
    )
    return outcome
322
+
323
+
324
+ # =============================================================================
325
+ # Stream Synchronization
326
+ # =============================================================================
327
+
328
+
329
def sync_stream_to_timebase(
    sample_times: List[float],
    reference_times: List[float],
    config: TimebaseConfigProtocol,
    enforce_budget: bool = False,
) -> Dict[str, Any]:
    """Align a data stream's timestamps to a reference timebase.

    Delegates the mapping and jitter computation to ``align_samples`` and then
    turns the resulting indices (and weights) into one timestamp per sample,
    expressed on the reference timebase.

    Algorithm:
    ----------
    For "nearest" mapping:
        1. For each sample_time, find the nearest reference_time
        2. Return that reference_time as the aligned timestamp
           (each sample is snapped to the closest reference point)

    For "linear" mapping:
        1. For each sample_time, find the bracketing reference_times
        2. Return ``w0 * reference_times[idx0] + w1 * reference_times[idx1]``

        NOTE(review): with the weights produced by ``map_linear`` this
        weighted average reproduces the sample time itself for samples inside
        the reference range, and clamps samples outside it to the first/last
        reference time. Confirm this is the intended behaviour.

    Args:
        sample_times: Original timestamps from the data stream's own clock
            (e.g. video frame timestamps, pose estimation frame times).
        reference_times: Timestamps defining the reference timebase
            (e.g. hardware TTL pulses). Must be sorted.
        config: Object exposing ``mapping`` ("nearest" or "linear") and
            ``jitter_budget_s`` (maximum acceptable jitter in seconds).
        enforce_budget: If True, raise when the jitter exceeds
            ``config.jitter_budget_s``.

    Returns:
        Dictionary containing:
            - indices: Reference indices used for each sample (int for
              "nearest", ``(idx0, idx1)`` tuple for "linear")
            - aligned_times: One timestamp per sample on the reference timebase
            - jitter_stats: Dict with ``max_jitter_s`` and ``p95_jitter_s``.
              For both strategies these are measured against the nearest
              reference timestamp.
            - mapping: Strategy used ("nearest" or "linear")

    Raises:
        JitterExceedsBudgetError: If enforce_budget=True and the max or P95
            jitter exceeds the configured budget
        SyncError: If the reference is empty or non-monotonic, or the mapping
            strategy is unknown

    Example:
        >>> from types import SimpleNamespace
        >>> config = SimpleNamespace(mapping="nearest", jitter_budget_s=0.05)
        >>> result = sync_stream_to_timebase(
        ...     sample_times=[0.0, 1.0, 2.0, 3.0],       # device clock
        ...     reference_times=[0.0, 1.01, 2.02, 3.03],  # reference clock
        ...     config=config,
        ... )
        >>> result["aligned_times"]
        [0.0, 1.01, 2.02, 3.03]
        >>> sorted(result["jitter_stats"])
        ['max_jitter_s', 'p95_jitter_s']
    """
    # align_samples validates the strategy and raises SyncError for anything
    # other than "nearest"/"linear", so only those two cases reach this point.
    result = align_samples(sample_times, reference_times, config, enforce_budget)
    indices = result["indices"]

    if result["mapping"] == "nearest":
        # Snap each sample to its nearest reference point
        aligned_times = [reference_times[idx] for idx in indices]
    else:
        # "linear": weighted average of the two bracketing reference points
        aligned_times = [
            w0 * reference_times[idx0] + w1 * reference_times[idx1]
            for (idx0, idx1), (w0, w1) in zip(indices, result["weights"])
        ]

    return {
        "indices": indices,
        "aligned_times": aligned_times,
        "jitter_stats": result["jitter_stats"],
        "mapping": result["mapping"],
    }
475
+
476
+
477
def align_pose_frames_to_reference(
    pose_data: List[Dict],
    reference_times: List[float],
    mapping: str = "nearest",
) -> Dict[int, float]:
    """Map pose frame indices to reference timebase timestamps.

    Unlike sync_stream_to_timebase, this works with **frame indices** rather
    than timestamps: frame ``i`` is assigned ``reference_times[i]``. It is
    meant for pose estimation output that carries frame numbers but no
    absolute timestamps.

    **Workflow:**
    1. Sync video frames to the reference clock (e.g. with
       sync_stream_to_timebase) to obtain one aligned timestamp per frame
    2. Pass those per-frame timestamps as ``reference_times`` here
    3. Each pose frame_index is looked up in that list

    Args:
        pose_data: List of dicts, each with an integer ``frame_index`` key
            (0-based frame number in the video). Other keys are ignored.
        reference_times: Timestamp for each video frame; index ``i`` is the
            timestamp of frame ``i``.
        mapping: How to handle a frame_index beyond the end of
            ``reference_times`` (in-range indices are always a direct lookup):
            - "nearest": log a warning and use the last timestamp
            - "linear": extrapolate from the last timestamp using the spacing
              of the final two reference times (falls back to the last
              timestamp when only one reference time exists)

    Returns:
        Dictionary mapping frame_index → timestamp. If a frame_index occurs
        more than once in ``pose_data`` it appears once in the result.
        Returns ``{}`` when ``pose_data`` is empty.

    Raises:
        SyncError: If ``reference_times`` is empty, the mapping strategy is
            unknown, or a frame_index is negative
        KeyError: If an entry of ``pose_data`` has no ``frame_index`` key

    Example:
        >>> align_pose_frames_to_reference(
        ...     pose_data=[{"frame_index": 0}, {"frame_index": 1}],
        ...     reference_times=[10.5, 10.533, 10.566],
        ... )
        {0: 10.5, 1: 10.533}
    """
    if not pose_data:
        return {}

    if len(reference_times) == 0:
        raise SyncError("Reference timebase is empty")

    if mapping not in ["nearest", "linear"]:
        raise SyncError(f"Unknown mapping strategy: {mapping}")

    n_ref = len(reference_times)
    last_time = reference_times[-1]
    frame_timestamps: Dict[int, float] = {}

    for frame_data in pose_data:
        frame_idx = frame_data["frame_index"]

        if frame_idx < 0:
            # A negative index would silently wrap around to the end of
            # reference_times; treat it as malformed input instead.
            raise SyncError(f"Invalid negative frame index: {frame_idx}")

        if frame_idx < n_ref:
            # Direct lookup: frame N → reference_times[N]
            timestamp = reference_times[frame_idx]
        elif mapping == "nearest":
            # Out of bounds - use last timestamp
            logger.warning(f"Frame {frame_idx} out of bounds, using last timestamp")
            timestamp = last_time
        elif n_ref >= 2:
            # Linear extrapolation beyond last frame
            dt = last_time - reference_times[-2]
            timestamp = last_time + dt * (frame_idx - n_ref + 1)
        else:
            # A single reference point gives no spacing to extrapolate with
            timestamp = last_time

        frame_timestamps[frame_idx] = timestamp

    logger.debug(f"Aligned {len(frame_timestamps)} pose frames to reference timebase")
    return frame_timestamps
615
+
616
+
617
+ # =============================================================================
618
+ # Robust Synchronization Recovery
619
+ # =============================================================================
620
+
621
+
622
def fit_robust_linear_model(
    source_times: np.ndarray,
    target_times: np.ndarray,
    outlier_threshold_s: float = 0.1,
    min_valid_points: int = 2,
) -> Tuple[float, float, np.ndarray]:
    """Fit a robust linear model to align two timebases, handling missing data/outliers.

    Recovers the linear relationship: target = slope * source + intercept
    even when the correspondence is noisy or has missing points (e.g., dropped TTL pulses).

    Algorithm:
        1. Match every source time to its nearest target time
        2. Compute residuals (matched target - source)
        3. Flag as outliers the pairs whose residual deviates from the median
           residual by ``outlier_threshold_s`` or more
        4. Fit an ordinary least-squares line to the remaining pairs

    Args:
        source_times: Timestamps from source clock (array or list)
        target_times: Timestamps from target clock, sorted (array or list)
        outlier_threshold_s: Maximum residual deviation to consider valid (seconds)
        min_valid_points: Minimum number of valid points required for fit.
            Values below 2 are treated as 2, since a line cannot be fitted to
            fewer points.

    Returns:
        Tuple containing:
            - slope: Fitted slope (clock drift factor)
            - intercept: Fitted intercept (clock offset, seconds)
            - valid_mask: Boolean mask over source_times marking the pairs
              used for the fit

    Raises:
        SyncError: If too few valid points found to fit model, or if
            target_times is empty or not sorted
    """
    source = np.asarray(source_times, dtype=float)
    target = np.asarray(target_times, dtype=float)

    # 1. Naive Nearest Neighbor Mapping
    indices = map_nearest(source.tolist(), target.tolist())
    matched_target = target[indices]

    # 2. Calculate Residuals
    # diff = Target - Source
    # For correct matches: diff ≈ Intercept (constant offset)
    # For incorrect matches: diff ≈ Intercept ± Interval (outlier)
    diffs = matched_target - source

    # 3. Filter Outliers (skip the median on empty input to avoid a NaN warning)
    if diffs.size:
        valid_mask = np.abs(diffs - np.median(diffs)) < outlier_threshold_s
    else:
        valid_mask = np.zeros(0, dtype=bool)
    n_valid = int(np.sum(valid_mask))

    # A regression line needs at least two points regardless of the caller's setting.
    required = max(int(min_valid_points), 2)
    if n_valid < required:
        raise SyncError(f"Too few valid points ({n_valid}) to fit robust model (min={required})")

    # 4. Linear Regression on Valid Pairs
    res = stats.linregress(source[valid_mask], matched_target[valid_mask])

    return float(res.slope), float(res.intercept), valid_mask