w2t_bkin-0.0.6-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,165 @@
+ """Skeleton management utilities for pose estimation.
+
+ Provides functions to create, validate, and manage Skeleton objects
+ for use with ndx-pose PoseEstimation containers.
+
+ Key Functions:
+ --------------
+ - validate_skeleton_edges: Validate edge indices against the node list
+ - create_skeleton: Factory for individual Skeleton objects
+ - create_skeletons_container: Factory for the Skeletons container (NWB LabMetaData)
+
+ OOP Design:
+ -----------
+ This module follows the Factory pattern, providing clean constructors
+ for ndx-pose objects with proper validation and defaults.
+ """
+
+ from typing import List, Optional
+
+ from ndx_pose import Skeleton, Skeletons
+ import numpy as np
+
+
+ def validate_skeleton_edges(nodes: List[str], edges: List[List[int]]) -> None:
+     """Validate that skeleton edges reference valid node indices.
+
+     Args:
+         nodes: List of node names
+         edges: List of [node_idx1, node_idx2] pairs
+
+     Raises:
+         ValueError: If any edge references an invalid node index
+         TypeError: If edges are not properly formatted
+
+     Example:
+         >>> nodes = ["nose", "ear_left", "ear_right"]
+         >>> edges = [[0, 1], [0, 2]]  # Valid
+         >>> validate_skeleton_edges(nodes, edges)  # No error
+         >>> edges = [[0, 5]]  # Invalid index
+         >>> validate_skeleton_edges(nodes, edges)  # Raises ValueError
+     """
+     if not isinstance(nodes, list):
+         raise TypeError(f"nodes must be a list, got {type(nodes)}")
+
+     if not isinstance(edges, list):
+         raise TypeError(f"edges must be a list, got {type(edges)}")
+
+     n_nodes = len(nodes)
+
+     for i, edge in enumerate(edges):
+         if not isinstance(edge, (list, tuple)) or len(edge) != 2:
+             raise ValueError(f"Edge {i}: must be a list/tuple of 2 integers, got {edge}")
+
+         src, dst = edge
+
+         if not isinstance(src, int) or not isinstance(dst, int):
+             raise TypeError(f"Edge {i}: indices must be integers, got ({type(src)}, {type(dst)})")
+
+         if src < 0 or src >= n_nodes:
+             raise ValueError(f"Edge {i}: source index {src} out of range [0, {n_nodes})")
+
+         if dst < 0 or dst >= n_nodes:
+             raise ValueError(f"Edge {i}: destination index {dst} out of range [0, {n_nodes})")
+
+
+ def create_skeleton(
+     name: str,
+     nodes: List[str],
+     edges: Optional[List[List[int]]] = None,
+     validate: bool = True,
+ ) -> Skeleton:
+     """Create a Skeleton object with optional validation.
+
+     This is the recommended way to create Skeleton objects for use with
+     PoseEstimation. The skeleton should be added to a Skeletons container
+     in the NWBFile and then linked to PoseEstimation objects.
+
+     Args:
+         name: Skeleton identifier (e.g., "mouse_skeleton", "subject")
+         nodes: List of bodypart/node names in order
+         edges: List of [node_idx1, node_idx2] pairs defining connectivity.
+             If None or empty, creates a skeleton with no edges.
+         validate: If True, validate that edges reference valid node indices
+
+     Returns:
+         Skeleton object ready to add to a Skeletons container
+
+     Raises:
+         ValueError: If validation fails
+
+     Example:
+         >>> skeleton = create_skeleton(
+         ...     name="mouse_skeleton",
+         ...     nodes=["nose", "ear_left", "ear_right"],
+         ...     edges=[[0, 1], [0, 2]],
+         ...     validate=True
+         ... )
+         >>> # Add to NWB:
+         >>> # skeletons = Skeletons(skeletons=[skeleton])
+         >>> # nwbfile.add_lab_meta_data(skeletons)
+     """
+     if not nodes:
+         raise ValueError("Skeleton must have at least one node")
+
+     if edges is None:
+         edges = []
+
+     if validate and edges:
+         validate_skeleton_edges(nodes, edges)
+
+     # Convert edges to a uint8 numpy array (the dtype ndx-pose uses for edges)
+     if edges:
+         edges_array = np.array(edges, dtype="uint8")
+     else:
+         # Empty array with correct shape (0, 2)
+         edges_array = np.array([], dtype="uint8").reshape(0, 2)
+
+     return Skeleton(name=name, nodes=nodes, edges=edges_array)
+
+
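Note: with edges omitted, the factory still returns an edges array of shape (0, 2) rather than None, which keeps downstream shape checks simple. A one-line check of that default path, assuming this module's functions are in scope:

    skel = create_skeleton(name="probe", nodes=["tip"])
    assert skel.edges.shape == (0, 2)   # empty uint8 edge array, not None
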
+ def create_skeletons_container(
+     name: str,
+     skeletons: List[Skeleton],
+ ) -> Skeletons:
+     """Create a Skeletons container (NWB LabMetaData) for one or more skeletons.
+
+     The Skeletons container is added to the NWBFile as LabMetaData, then
+     individual PoseEstimation objects link to specific skeletons within it.
+
+     Args:
+         name: Container identifier (accepted for API symmetry; the current
+             implementation does not pass it to Skeletons, which keeps its
+             default name)
+         skeletons: List of Skeleton objects (must be non-empty)
+
+     Returns:
+         Skeletons container ready to add to the NWBFile
+
+     Raises:
+         ValueError: If the skeletons list is empty
+         TypeError: If skeletons is not a list or contains non-Skeleton objects
+
+     Example:
+         >>> # Single skeleton
+         >>> skeleton = create_skeleton(name="mouse", nodes=["nose", "ear_left"])
+         >>> container = create_skeletons_container(name="skeletons", skeletons=[skeleton])
+         >>> nwbfile.add_lab_meta_data(container)
+         >>>
+         >>> # Multiple skeletons
+         >>> mouse_skel = create_skeleton(name="mouse", nodes=["nose", "ear_left"])
+         >>> rat_skel = create_skeleton(name="rat", nodes=["snout", "ear"])
+         >>> container = create_skeletons_container(name="skeletons", skeletons=[mouse_skel, rat_skel])
+         >>> nwbfile.add_lab_meta_data(container)
+     """
+     if not isinstance(skeletons, list):
+         raise TypeError(f"skeletons must be a list, got {type(skeletons)}")
+
+     if not skeletons:
+         raise ValueError("skeletons list cannot be empty")
+
+     if not all(isinstance(s, Skeleton) for s in skeletons):
+         raise TypeError("All items in skeletons list must be Skeleton objects")
+
+     return Skeletons(skeletons=skeletons)
+
+
+ __all__ = ["validate_skeleton_edges", "create_skeleton", "create_skeletons_container"]
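Both docstrings mention linking a skeleton to PoseEstimation but neither shows that side of the link. A minimal sketch, assuming an ndx-pose 0.2-style API, an existing in-memory `nwbfile`, and this module's `create_skeleton` in scope; the series list is left empty for brevity:

    from ndx_pose import PoseEstimation, Skeletons

    skeleton = create_skeleton(
        name="mouse_skeleton",
        nodes=["nose", "ear_left", "ear_right"],
        edges=[[0, 1], [0, 2]],
    )
    nwbfile.add_lab_meta_data(Skeletons(skeletons=[skeleton]))

    pose_est = PoseEstimation(
        name="PoseEstimation",
        pose_estimation_series=[],   # fill with PoseEstimationSeries objects
        description="DLC output",
        skeleton=skeleton,           # links back to the container entry above
    )
    behavior = nwbfile.create_processing_module(name="behavior", description="pose")
    behavior.add(pose_est)
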
@@ -0,0 +1,477 @@
+ """Mock TTL signal generation from DeepLabCut pose data.
+
+ Generates synthetic TTL pulse sequences from pose estimation data by detecting
+ events in tracked body part movements. Designed for testing, validation, and
+ synthetic pipeline scenarios where TTL signals need to be derived from behavioral
+ tracking data.
+
+ Features:
+ ---------
+ - **Likelihood-based detection**: Threshold-based signal generation from keypoint confidence
+ - **Duration filtering**: Minimum duration requirements for valid signal phases
+ - **Frame-to-time conversion**: Automatic FPS-based timestamp generation
+ - **Flexible triggering**: Support for ON/OFF transitions, state changes, and custom predicates
+ - **TTL file format**: Compatible with the w2t_bkin.sync.ttl loader (one timestamp per line)
+
+ Use Cases:
+ ----------
+ 1. Generate trial light signals from tracked LED positions
+ 2. Create behavioral event markers from pose kinematics
+ 3. Produce synthetic TTL data for end-to-end pipeline testing
+ 4. Validate synchronization logic with known ground truth
+
+ TTL File Format:
+ ----------------
+ One floating-point timestamp per line (seconds), sorted ascending:
+
+     0.0000
+     2.0000
+     4.0000
+     ...
+
+ Example:
+ --------
+ >>> from pathlib import Path
+ >>> from w2t_bkin.pose.ttl_mock import (
+ ...     TTLMockOptions,
+ ...     generate_ttl_from_dlc_likelihood,
+ ...     write_ttl_timestamps
+ ... )
+ >>>
+ >>> # Generate TTL pulses from trial_light likelihood
+ >>> h5_path = Path("pose_output.h5")
+ >>> options = TTLMockOptions(
+ ...     bodypart="trial_light",
+ ...     likelihood_threshold=0.99,
+ ...     min_duration_frames=301,
+ ...     fps=150.0
+ ... )
+ >>> timestamps = generate_ttl_from_dlc_likelihood(h5_path, options)
+ >>> print(f"Generated {len(timestamps)} TTL pulses")
+ >>>
+ >>> # Write to TTL file
+ >>> output_path = Path("TTLs/trial_light.txt")
+ >>> write_ttl_timestamps(timestamps, output_path)
+
+ Integration with Pipeline:
+ --------------------------
+ >>> from w2t_bkin.ttl import get_ttl_pulses
+ >>> from w2t_bkin.config import load_session
+ >>>
+ >>> # Generate mock TTL from pose data
+ >>> generate_and_write_ttl_from_pose(
+ ...     h5_path=session_dir / "pose.h5",
+ ...     output_path=session_dir / "TTLs/ttl_sync.txt",
+ ...     options=TTLMockOptions(bodypart="trial_light", ...)
+ ... )
+ >>>
+ >>> # Load in pipeline
+ >>> session = load_session(session_dir / "session.toml")
+ >>> pulses = get_ttl_pulses(session)
+ >>> print(len(pulses['ttl_sync']))  # Matches generated count
+
+ Requirements:
+ -------------
+ - pandas (for HDF5 reading, via PyTables)
+ - numpy (for numerical operations)
+ - Python 3.10+
+
+ See Also:
+ ---------
+ - synthetic.ttl_synth: Pure synthetic TTL generation with deterministic RNG
+ - w2t_bkin.sync.ttl: TTL pulse loading and validation
+ - w2t_bkin.pose.core: DLC pose data import
+ """
+
+ from __future__ import annotations
+
+ import logging
+ from pathlib import Path
+ from typing import Callable, List, Optional, Tuple
+
+ import numpy as np
+ import pandas as pd
+ from pydantic import BaseModel, Field, field_validator
+
+ from w2t_bkin.pose.core import PoseError
+
+ logger = logging.getLogger(__name__)
+
+
+ class TTLMockOptions(BaseModel):
+     """Configuration for generating mock TTL signals from pose data.
+
+     Attributes:
+         bodypart: Name of the body part to track (must match a DLC keypoint name)
+         likelihood_threshold: Minimum confidence score for signal ON state (0-1)
+         min_duration_frames: Minimum number of consecutive frames for a valid signal
+         fps: Camera frame rate for converting frame indices to timestamps
+         transition_type: Type of signal transition to detect
+             - 'rising': Detect OFF→ON transitions (signal start)
+             - 'falling': Detect ON→OFF transitions (signal end)
+             - 'both': Detect both transitions
+         start_time_offset_s: Time offset added to all timestamps (seconds)
+         filter_consecutive: If True, only keep the first pulse in consecutive groups
+             (declared but not yet applied by the generators in this module)
+     """
+
+     bodypart: str = Field(..., description="DLC body part name to track")
+     likelihood_threshold: float = Field(default=0.99, ge=0.0, le=1.0, description="Minimum confidence threshold")
+     min_duration_frames: int = Field(default=1, ge=1, description="Minimum frames for valid signal phase")
+     fps: float = Field(default=30.0, gt=0.0, description="Camera frame rate (Hz)")
+     transition_type: str = Field(default="rising", pattern="^(rising|falling|both)$", description="Transition detection mode")
+     start_time_offset_s: float = Field(default=0.0, description="Time offset for all timestamps (s)")
+     filter_consecutive: bool = Field(default=False, description="Keep only first pulse in consecutive groups")
+
+     @field_validator("bodypart")
+     @classmethod
+     def bodypart_not_empty(cls, v: str) -> str:
+         """Validate bodypart is not empty."""
+         if not v or not v.strip():
+             raise ValueError("bodypart must be a non-empty string")
+         return v.strip()
+
+
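The Field constraints above are enforced by pydantic at construction time. A short sketch of the validation behavior, using only values from the field definitions:

    from pydantic import ValidationError
    from w2t_bkin.pose.ttl_mock import TTLMockOptions

    opts = TTLMockOptions(bodypart="  trial_light  ")
    assert opts.bodypart == "trial_light"            # validator strips whitespace

    try:
        TTLMockOptions(bodypart="led", likelihood_threshold=1.5)
    except ValidationError:
        pass                                         # rejected: violates le=1.0
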
+ def load_dlc_likelihood_series(h5_path: Path, bodypart: str, scorer: Optional[str] = None) -> pd.Series:
+     """Load the likelihood time series for a specific body part from a DLC H5 file.
+
+     Args:
+         h5_path: Path to DeepLabCut H5 output file
+         bodypart: Name of body part to extract
+         scorer: Optional scorer name (auto-detected if None)
+
+     Returns:
+         Pandas Series with frame indices and likelihood values
+
+     Raises:
+         PoseError: If file not found, format invalid, or bodypart missing
+     """
+     if not h5_path.exists():
+         raise PoseError(f"DLC H5 file not found: {h5_path}")
+
+     try:
+         df = pd.read_hdf(h5_path, "df_with_missing")
+     except (KeyError, OSError, ValueError) as e:
+         raise PoseError(f"Failed to read DLC H5 file {h5_path}: {e}") from e
+
+     # Validate the (scorer, bodypart, coord) MultiIndex structure
+     if not isinstance(df.columns, pd.MultiIndex) or df.columns.nlevels != 3:
+         raise PoseError(
+             f"Invalid DLC format: expected a 3-level (scorer, bodypart, coord) MultiIndex, "
+             f"got {type(df.columns).__name__} with {df.columns.nlevels} level(s)"
+         )
+
+     # Auto-detect scorer if not provided
+     if scorer is None:
+         scorer = df.columns.get_level_values(0)[0]
+         logger.debug(f"Auto-detected scorer: {scorer}")
+
+     # Check if bodypart exists
+     bodyparts = df.columns.get_level_values(1).unique()
+     if bodypart not in bodyparts:
+         raise PoseError(f"Body part '{bodypart}' not found. Available: {list(bodyparts)}")
+
+     # Extract likelihood column
+     try:
+         likelihood = df[(scorer, bodypart, "likelihood")]
+     except KeyError as e:
+         raise PoseError(f"Failed to extract likelihood for '{bodypart}': {e}") from e
+
+     return likelihood
+
+
+ def detect_signal_transitions(
+     signal: pd.Series,
+     transition_type: str = "rising",
+ ) -> Tuple[List[int], List[int]]:
+     """Detect rising and/or falling edge transitions in a boolean signal.
+
+     Args:
+         signal: Boolean series indicating signal state (True=ON, False=OFF)
+         transition_type: Type of transitions to detect ('rising', 'falling', 'both')
+
+     Returns:
+         Tuple of (onsets, offsets) frame index lists
+         - onsets: Frame indices where signal transitions OFF→ON
+           (empty unless 'rising' or 'both' is requested)
+         - offsets: Frame indices where signal transitions ON→OFF
+           (empty unless 'falling' or 'both' is requested)
+
+     Example:
+         >>> signal = pd.Series([False, False, True, True, False, True])
+         >>> onsets, offsets = detect_signal_transitions(signal, 'both')
+         >>> print(onsets)   # [2, 5]
+         >>> print(offsets)  # [4]
+     """
+     # Compute transitions by comparing each sample with the previous one
+     prev_signal = signal.shift(1, fill_value=False)
+
+     onsets = []
+     offsets = []
+
+     if transition_type in ("rising", "both"):
+         # Rising edge: previous=False AND current=True
+         rising_mask = (~prev_signal) & signal
+         onsets = signal.index[rising_mask].tolist()
+
+     if transition_type in ("falling", "both"):
+         # Falling edge: previous=True AND current=False
+         falling_mask = prev_signal & (~signal)
+         offsets = signal.index[falling_mask].tolist()
+
+     return onsets, offsets
+
+
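Because the helper shifts with fill_value=False, the sample before frame 0 is treated as OFF: a recording that starts with the signal already ON reports an onset at frame 0, and one that ends ON reports no final offset. A minimal check of that boundary behavior:

    import pandas as pd
    from w2t_bkin.pose.ttl_mock import detect_signal_transitions

    signal = pd.Series([True, True, False])              # starts ON
    onsets, offsets = detect_signal_transitions(signal, "both")
    assert onsets == [0]     # frame 0 counts as an OFF->ON transition
    assert offsets == [2]    # a signal that ends ON would report no final offset
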
+ def filter_by_duration(
+     onsets: List[int],
+     offsets: List[int],
+     min_duration_frames: int,
+ ) -> Tuple[List[int], List[int]]:
+     """Filter signal phases by a minimum duration requirement.
+
+     Args:
+         onsets: Frame indices of signal ON transitions
+         offsets: Frame indices of signal OFF transitions
+         min_duration_frames: Minimum number of frames for a valid phase
+
+     Returns:
+         Tuple of (filtered_onsets, filtered_offsets) with only valid phases
+
+     Example:
+         >>> onsets = [10, 50, 100]
+         >>> offsets = [15, 55, 400]  # Durations: 5, 5, 300
+         >>> filtered = filter_by_duration(onsets, offsets, min_duration_frames=10)
+         >>> print(filtered[0])  # [100]
+         >>> print(filtered[1])  # [400]
+     """
+     if not onsets or not offsets:
+         return [], []
+
+     # Ensure we have matching pairs
+     min_len = min(len(onsets), len(offsets))
+     onsets = onsets[:min_len]
+     offsets = offsets[:min_len]
+
+     # Calculate durations
+     durations = [off - on for on, off in zip(onsets, offsets)]
+
+     # Filter by minimum duration
+     valid_indices = [i for i, dur in enumerate(durations) if dur >= min_duration_frames]
+     filtered_onsets = [onsets[i] for i in valid_indices]
+     filtered_offsets = [offsets[i] for i in valid_indices]
+
+     return filtered_onsets, filtered_offsets
+
+
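When the signal is still ON at the end of the recording there is one more onset than offset, and the pair-matching truncation above silently drops that final unterminated phase. A sketch of that edge case:

    from w2t_bkin.pose.ttl_mock import filter_by_duration

    onsets = [10, 100]           # second phase never turns OFF
    offsets = [60]
    kept = filter_by_duration(onsets, offsets, min_duration_frames=20)
    assert kept == ([10], [60])  # the unterminated phase at frame 100 is discarded
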
+ def frames_to_timestamps(frame_indices: List[int], fps: float, offset_s: float = 0.0) -> List[float]:
+     """Convert frame indices to timestamps in seconds.
+
+     Args:
+         frame_indices: List of frame indices (0-based)
+         fps: Frame rate in frames per second
+         offset_s: Time offset to add to all timestamps
+
+     Returns:
+         List of timestamps in seconds
+
+     Example:
+         >>> frames = [0, 150, 300]
+         >>> timestamps = frames_to_timestamps(frames, fps=150.0)
+         >>> print(timestamps)  # [0.0, 1.0, 2.0]
+     """
+     if not frame_indices:
+         return []
+
+     timestamps = [frame / fps + offset_s for frame in frame_indices]
+     return timestamps
+
+
+ def generate_ttl_from_dlc_likelihood(
+     h5_path: Path,
+     options: TTLMockOptions,
+     scorer: Optional[str] = None,
+ ) -> List[float]:
+     """Generate mock TTL timestamps from DLC likelihood data.
+
+     Main entry point for likelihood-based TTL generation. Loads pose data,
+     applies threshold and duration filters, and converts to timestamps.
+
+     Args:
+         h5_path: Path to DeepLabCut H5 output file
+         options: Configuration for TTL generation
+         scorer: Optional DLC scorer name (auto-detected if None)
+
+     Returns:
+         List of TTL pulse timestamps in seconds (sorted)
+
+     Raises:
+         PoseError: If file not found, format invalid, or bodypart missing
+
+     Example:
+         >>> options = TTLMockOptions(
+         ...     bodypart="trial_light",
+         ...     likelihood_threshold=0.99,
+         ...     min_duration_frames=301,
+         ...     fps=150.0,
+         ...     transition_type="rising"
+         ... )
+         >>> timestamps = generate_ttl_from_dlc_likelihood(Path("pose.h5"), options)
+     """
+     logger.info(f"Generating TTL from DLC pose data: {h5_path}")
+     logger.debug(
+         f"Options: bodypart={options.bodypart}, threshold={options.likelihood_threshold}, "
+         f"min_duration={options.min_duration_frames}, fps={options.fps}"
+     )
+
+     # Load likelihood series
+     likelihood = load_dlc_likelihood_series(h5_path, options.bodypart, scorer)
+     logger.debug(f"Loaded {len(likelihood)} frames, mean likelihood: {likelihood.mean():.3f}")
+
+     # Create boolean signal from threshold
+     signal = likelihood >= options.likelihood_threshold
+     high_conf_count = signal.sum()
+     logger.debug(f"Frames above threshold: {high_conf_count} ({100 * high_conf_count / len(signal):.1f}%)")
+
+     # Detect transitions
+     # For duration filtering, we always need both onsets and offsets
+     if options.min_duration_frames > 1:
+         onsets, offsets = detect_signal_transitions(signal, "both")
+         logger.debug(f"Detected {len(onsets)} onsets, {len(offsets)} offsets (for duration filtering)")
+         onsets, offsets = filter_by_duration(onsets, offsets, options.min_duration_frames)
+         logger.debug(f"After duration filter: {len(onsets)} valid phases")
+     else:
+         onsets, offsets = detect_signal_transitions(signal, options.transition_type)
+         logger.debug(f"Detected {len(onsets)} onsets, {len(offsets)} offsets")
+
+     # Select timestamps based on transition type
+     if options.transition_type == "rising":
+         frame_indices = onsets
+     elif options.transition_type == "falling":
+         frame_indices = offsets
+     else:  # both
+         frame_indices = sorted(onsets + offsets)
+
+     # Convert to timestamps
+     timestamps = frames_to_timestamps(frame_indices, options.fps, options.start_time_offset_s)
+
+     logger.info(f"Generated {len(timestamps)} TTL pulses from {options.bodypart}")
+     if timestamps:
+         logger.debug(f"Time range: {timestamps[0]:.3f}s - {timestamps[-1]:.3f}s")
+
+     return timestamps
+
+
+ def generate_ttl_from_custom_predicate(
+     h5_path: Path,
+     predicate: Callable[[pd.DataFrame], pd.Series],
+     options: TTLMockOptions,
+ ) -> List[float]:
+     """Generate TTL timestamps using a custom predicate function.
+
+     Advanced API for complex signal generation logic. The predicate receives
+     the full DLC DataFrame and returns a boolean Series indicating signal state.
+
+     Args:
+         h5_path: Path to DeepLabCut H5 output file
+         predicate: Function that takes the DataFrame and returns a boolean Series
+         options: Configuration (fps, transition_type, etc.)
+
+     Returns:
+         List of TTL pulse timestamps in seconds
+
+     Example:
+         >>> def detect_movement(df):
+         ...     # Detect when the nose moves > 10 px (|dx| + |dy|) between frames
+         ...     scorer = df.columns.get_level_values(0)[0]
+         ...     x = df[(scorer, 'nose', 'x')]
+         ...     y = df[(scorer, 'nose', 'y')]
+         ...     dx = x.diff().abs()
+         ...     dy = y.diff().abs()
+         ...     return (dx + dy) > 10
+         >>>
+         >>> options = TTLMockOptions(bodypart="nose", fps=150.0)
+         >>> timestamps = generate_ttl_from_custom_predicate(
+         ...     Path("pose.h5"), detect_movement, options
+         ... )
+     """
+     if not h5_path.exists():
+         raise PoseError(f"DLC H5 file not found: {h5_path}")
+
+     try:
+         df = pd.read_hdf(h5_path, "df_with_missing")
+     except (KeyError, OSError, ValueError) as e:
+         raise PoseError(f"Failed to read DLC H5 file {h5_path}: {e}") from e
+
+     # Apply custom predicate
+     signal = predicate(df)
+
+     if not isinstance(signal, pd.Series):
+         raise PoseError(f"Predicate must return pd.Series, got {type(signal)}")
+
+     # Detect transitions. As in generate_ttl_from_dlc_likelihood, duration
+     # filtering needs both onsets and offsets, so detect "both" in that case
+     # and select the requested edge type afterwards.
+     if options.min_duration_frames > 1:
+         onsets, offsets = detect_signal_transitions(signal, "both")
+         onsets, offsets = filter_by_duration(onsets, offsets, options.min_duration_frames)
+     else:
+         onsets, offsets = detect_signal_transitions(signal, options.transition_type)
+
+     # Select timestamps based on transition type
+     if options.transition_type == "rising":
+         frame_indices = onsets
+     elif options.transition_type == "falling":
+         frame_indices = offsets
+     else:  # both
+         frame_indices = sorted(onsets + offsets)
+
+     # Convert to timestamps
+     timestamps = frames_to_timestamps(frame_indices, options.fps, options.start_time_offset_s)
+
+     return timestamps
+
+
+ def write_ttl_timestamps(timestamps: List[float], output_path: Path) -> None:
+     """Write TTL timestamps to a file in w2t_bkin format.
+
+     Writes one timestamp per line, sorted, with microsecond precision. Creates
+     parent directories if needed.
+
+     Args:
+         timestamps: List of timestamps in seconds
+         output_path: Path to output file
+
+     Example:
+         >>> timestamps = [0.0, 1.5, 3.0]
+         >>> write_ttl_timestamps(timestamps, Path("TTLs/ttl_sync.txt"))
+     """
+     output_path.parent.mkdir(parents=True, exist_ok=True)
+
+     # Sort timestamps
+     sorted_timestamps = sorted(timestamps)
+
+     # Write with six decimal places (microsecond precision)
+     with open(output_path, "w") as f:
+         for ts in sorted_timestamps:
+             f.write(f"{ts:.6f}\n")
+
+     logger.info(f"Wrote {len(timestamps)} TTL timestamps to {output_path}")
+
+
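A quick round-trip sketch of the file format (numpy is already a dependency; it is used here only to read the file back):

    import numpy as np
    from pathlib import Path
    from w2t_bkin.pose.ttl_mock import write_ttl_timestamps

    write_ttl_timestamps([3.0, 0.0, 1.5], Path("TTLs/demo.txt"))
    assert np.loadtxt("TTLs/demo.txt").tolist() == [0.0, 1.5, 3.0]  # sorted on write
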
+ def generate_and_write_ttl_from_pose(
+     h5_path: Path,
+     output_path: Path,
+     options: TTLMockOptions,
+     scorer: Optional[str] = None,
+ ) -> int:
+     """Convenience function to generate and write a TTL file in one call.
+
+     Args:
+         h5_path: Path to DeepLabCut H5 output file
+         output_path: Path to output TTL file
+         options: Configuration for TTL generation
+         scorer: Optional DLC scorer name
+
+     Returns:
+         Number of TTL pulses generated
+
+     Example:
+         >>> count = generate_and_write_ttl_from_pose(
+         ...     h5_path=Path("pose.h5"),
+         ...     output_path=Path("TTLs/trial_light.txt"),
+         ...     options=TTLMockOptions(bodypart="trial_light", fps=150.0)
+         ... )
+         >>> print(f"Generated {count} pulses")
+     """
+     timestamps = generate_ttl_from_dlc_likelihood(h5_path, options, scorer)
+     write_ttl_timestamps(timestamps, output_path)
+     return len(timestamps)
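A self-contained smoke test of the whole path, assuming pandas with PyTables is installed; the file names are placeholders. It fabricates a minimal DLC-style table with the 3-level (scorer, bodypart, coord) columns that load_dlc_likelihood_series validates:

    import numpy as np
    import pandas as pd
    from pathlib import Path
    from w2t_bkin.pose.ttl_mock import TTLMockOptions, generate_and_write_ttl_from_pose

    frames = 600
    likelihood = np.zeros(frames)
    likelihood[100:300] = 1.0                       # one 200-frame ON phase
    cols = pd.MultiIndex.from_product(
        [["scorer0"], ["trial_light"], ["x", "y", "likelihood"]]
    )
    df = pd.DataFrame(0.0, index=range(frames), columns=cols)
    df[("scorer0", "trial_light", "likelihood")] = likelihood
    df.to_hdf("pose_demo.h5", key="df_with_missing")

    count = generate_and_write_ttl_from_pose(
        h5_path=Path("pose_demo.h5"),
        output_path=Path("TTLs/demo_ttl.txt"),
        options=TTLMockOptions(bodypart="trial_light", min_duration_frames=150, fps=150.0),
    )
    assert count == 1                               # single rising edge at frame 100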