w2t-bkin 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- w2t_bkin/__init__.py +85 -0
- w2t_bkin/behavior/__init__.py +115 -0
- w2t_bkin/behavior/core.py +1027 -0
- w2t_bkin/bpod/__init__.py +38 -0
- w2t_bkin/bpod/core.py +519 -0
- w2t_bkin/config.py +625 -0
- w2t_bkin/dlc/__init__.py +59 -0
- w2t_bkin/dlc/core.py +448 -0
- w2t_bkin/dlc/models.py +124 -0
- w2t_bkin/exceptions.py +426 -0
- w2t_bkin/facemap/__init__.py +42 -0
- w2t_bkin/facemap/core.py +397 -0
- w2t_bkin/facemap/models.py +134 -0
- w2t_bkin/pipeline.py +665 -0
- w2t_bkin/pose/__init__.py +48 -0
- w2t_bkin/pose/core.py +227 -0
- w2t_bkin/pose/io.py +363 -0
- w2t_bkin/pose/skeleton.py +165 -0
- w2t_bkin/pose/ttl_mock.py +477 -0
- w2t_bkin/session.py +423 -0
- w2t_bkin/sync/__init__.py +72 -0
- w2t_bkin/sync/core.py +678 -0
- w2t_bkin/sync/stats.py +176 -0
- w2t_bkin/sync/timebase.py +311 -0
- w2t_bkin/sync/ttl.py +254 -0
- w2t_bkin/transcode/__init__.py +38 -0
- w2t_bkin/transcode/core.py +303 -0
- w2t_bkin/transcode/models.py +96 -0
- w2t_bkin/ttl/__init__.py +64 -0
- w2t_bkin/ttl/core.py +518 -0
- w2t_bkin/ttl/models.py +19 -0
- w2t_bkin/utils.py +1093 -0
- w2t_bkin-0.0.6.dist-info/METADATA +145 -0
- w2t_bkin-0.0.6.dist-info/RECORD +36 -0
- w2t_bkin-0.0.6.dist-info/WHEEL +4 -0
- w2t_bkin-0.0.6.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""Skeleton management utilities for pose estimation.
|
|
2
|
+
|
|
3
|
+
Provides functions to create, validate, and manage Skeleton objects
|
|
4
|
+
for use with ndx-pose PoseEstimation containers.
|
|
5
|
+
|
|
6
|
+
Key Functions:
|
|
7
|
+
--------------
|
|
8
|
+
- validate_skeleton_edges: Validate edge indices against node list
|
|
9
|
+
- create_skeleton: Factory for individual Skeleton objects
|
|
10
|
+
- create_skeletons_container: Factory for Skeletons container (NWB LabMetaData)
|
|
11
|
+
|
|
12
|
+
OOP Design:
|
|
13
|
+
-----------
|
|
14
|
+
This module follows the Factory pattern, providing clean constructors
|
|
15
|
+
for ndx-pose objects with proper validation and defaults.
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from typing import List, Optional
|
|
19
|
+
|
|
20
|
+
from ndx_pose import Skeleton, Skeletons
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def validate_skeleton_edges(nodes: List[str], edges: List[List[int]]) -> None:
    """Validate that skeleton edges reference valid node indices.

    Each edge must be a 2-element list or tuple of integer indices into
    ``nodes``. Plain Python ``int`` values and NumPy integer scalars (for
    example values taken from a row of an ``np.ndarray``) are accepted.
    ``bool`` is rejected even though it subclasses ``int``, because
    ``True``/``False`` used as node indices is almost certainly a mistake.

    Args:
        nodes: List of node names
        edges: List of [node_idx1, node_idx2] pairs

    Raises:
        TypeError: If ``nodes`` or ``edges`` is not a list, or an index is not
            an integer
        ValueError: If an edge is not a 2-element list/tuple, or an index is
            outside ``[0, len(nodes))``

    Example:
        >>> nodes = ["nose", "ear_left", "ear_right"]
        >>> edges = [[0, 1], [0, 2]]  # Valid
        >>> validate_skeleton_edges(nodes, edges)  # No error
        >>> edges = [[0, 5]]  # Invalid index
        >>> validate_skeleton_edges(nodes, edges)  # Raises ValueError
    """
    if not isinstance(nodes, list):
        raise TypeError(f"nodes must be a list, got {type(nodes)}")

    if not isinstance(edges, list):
        raise TypeError(f"edges must be a list, got {type(edges)}")

    n_nodes = len(nodes)

    for i, edge in enumerate(edges):
        if not isinstance(edge, (list, tuple)) or len(edge) != 2:
            raise ValueError(f"Edge {i}: must be a list/tuple of 2 integers, got {edge}")

        src, dst = edge

        # Accept Python and NumPy integers, but not bool (an int subclass).
        both_integral = all(isinstance(idx, (int, np.integer)) and not isinstance(idx, bool) for idx in (src, dst))
        if not both_integral:
            raise TypeError(f"Edge {i}: indices must be integers, got ({type(src)}, {type(dst)})")

        if src < 0 or src >= n_nodes:
            raise ValueError(f"Edge {i}: source index {src} out of range [0, {n_nodes})")

        if dst < 0 or dst >= n_nodes:
            raise ValueError(f"Edge {i}: destination index {dst} out of range [0, {n_nodes})")
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def create_skeleton(
    name: str,
    nodes: List[str],
    edges: Optional[List[List[int]]] = None,
    validate: bool = True,
) -> Skeleton:
    """Create a Skeleton object with optional validation.

    This is the recommended way to create Skeleton objects for use with
    PoseEstimation. The skeleton should be added to a Skeletons container
    in the NWBFile and then linked to PoseEstimation objects.

    Edge indices are stored as ``uint8``. Any index outside ``[0, 255]`` is
    therefore rejected with a ``ValueError`` instead of being silently wrapped.
    This check runs even when ``validate=False``.

    Args:
        name: Skeleton identifier (e.g., "mouse_skeleton", "subject")
        nodes: List of bodypart/node names in order
        edges: List of [node_idx1, node_idx2] pairs defining connectivity.
            If None or empty, creates skeleton with no edges.
        validate: If True, validate that edges reference valid node indices

    Returns:
        Skeleton object ready to add to Skeletons container

    Raises:
        ValueError: If ``nodes`` is empty, validation fails, or an edge index
            cannot be represented as uint8
        TypeError: If validation finds malformed ``nodes``/``edges``

    Example:
        >>> skeleton = create_skeleton(
        ...     name="mouse_skeleton",
        ...     nodes=["nose", "ear_left", "ear_right"],
        ...     edges=[[0, 1], [0, 2]],
        ...     validate=True
        ... )
        >>> # Add to NWB:
        >>> # skeletons = Skeletons(skeletons=[skeleton])
        >>> # nwbfile.add_lab_meta_data(skeletons)
    """
    if not nodes:
        raise ValueError("Skeleton must have at least one node")

    if edges is None:
        edges = []

    if validate and edges:
        validate_skeleton_edges(nodes, edges)

    if edges:
        # Convert through a wide signed dtype first, so out-of-range values can
        # be detected before the narrowing cast to uint8.
        wide_edges = np.asarray(edges, dtype=np.int64)
        uint8_max = int(np.iinfo(np.uint8).max)
        if wide_edges.size and (wide_edges.min() < 0 or wide_edges.max() > uint8_max):
            raise ValueError(
                f"Skeleton '{name}': edge indices must be in [0, {uint8_max}] to be stored as uint8, "
                f"got values in [{wide_edges.min()}, {wide_edges.max()}]"
            )
        edges_array = wide_edges.astype("uint8")
    else:
        # Empty array with correct shape (0, 2)
        edges_array = np.array([], dtype="uint8").reshape(0, 2)

    return Skeleton(name=name, nodes=nodes, edges=edges_array)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def create_skeletons_container(
    name: str,
    skeletons: List[Skeleton],
) -> Skeletons:
    """Create a Skeletons container (NWB LabMetaData) for one or more skeletons.

    The Skeletons container is added to the NWBFile as LabMetaData, then
    individual PoseEstimation objects link to specific skeletons within it.

    Args:
        name: Container identifier (required). It is forwarded to ``Skeletons``
            as the container's name.
        skeletons: List of Skeleton objects (must be non-empty)

    Returns:
        Skeletons container ready to add to NWBFile

    Raises:
        ValueError: If skeletons list is empty
        TypeError: If skeletons is not a list or contains non-Skeleton objects

    Example:
        >>> # Single skeleton
        >>> skeleton = create_skeleton(name="mouse", nodes=["nose", "ear_left"])
        >>> container = create_skeletons_container(name="skeletons", skeletons=[skeleton])
        >>> nwbfile.add_lab_meta_data(container)
        >>>
        >>> # Multiple skeletons
        >>> mouse_skel = create_skeleton(name="mouse", nodes=["nose", "ear_left"])
        >>> rat_skel = create_skeleton(name="rat", nodes=["snout", "ear"])
        >>> container = create_skeletons_container(name="skeletons", skeletons=[mouse_skel, rat_skel])
        >>> nwbfile.add_lab_meta_data(container)
    """
    if not isinstance(skeletons, list):
        raise TypeError(f"skeletons must be a list, got {type(skeletons)}")

    if not skeletons:
        raise ValueError("skeletons list cannot be empty")

    if not all(isinstance(s, Skeleton) for s in skeletons):
        raise TypeError("All items in skeletons list must be Skeleton objects")

    # Previously `name` was accepted but dropped, so the container always got
    # the library default name.
    # NOTE(review): assumes ndx-pose's Skeletons accepts a `name` keyword
    # (default-named, not fixed-named); confirm against the pinned ndx-pose version.
    return Skeletons(name=name, skeletons=skeletons)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
# Public API of this module (the names exported by a star-import).
__all__ = ["validate_skeleton_edges", "create_skeleton", "create_skeletons_container"]
|
|
@@ -0,0 +1,477 @@
|
|
|
1
|
+
"""Mock TTL signal generation from DeepLabCut pose data.
|
|
2
|
+
|
|
3
|
+
Generates synthetic TTL pulse sequences from pose estimation data by detecting
|
|
4
|
+
events in tracked body part movements. Designed for testing, validation, and
|
|
5
|
+
synthetic pipeline scenarios where TTL signals need to be derived from behavioral
|
|
6
|
+
tracking data.
|
|
7
|
+
|
|
8
|
+
Features:
|
|
9
|
+
---------
|
|
10
|
+
- **Likelihood-based detection**: Threshold-based signal generation from keypoint confidence
|
|
11
|
+
- **Duration filtering**: Minimum duration requirements for valid signal phases
|
|
12
|
+
- **Frame-to-time conversion**: Automatic FPS-based timestamp generation
|
|
13
|
+
- **Flexible triggering**: Support for ON/OFF transitions, state changes, and custom predicates
|
|
14
|
+
- **TTL file format**: Compatible with w2t_bkin.sync.ttl loader (one timestamp per line)
|
|
15
|
+
|
|
16
|
+
Use Cases:
|
|
17
|
+
----------
|
|
18
|
+
1. Generate trial light signals from tracked LED positions
|
|
19
|
+
2. Create behavioral event markers from pose kinematics
|
|
20
|
+
3. Produce synthetic TTL data for end-to-end pipeline testing
|
|
21
|
+
4. Validate synchronization logic with known ground truth
|
|
22
|
+
|
|
23
|
+
TTL File Format:
|
|
24
|
+
----------------
|
|
25
|
+
One floating-point timestamp per line (seconds), sorted ascending:
|
|
26
|
+
|
|
27
|
+
0.0000
|
|
28
|
+
2.0000
|
|
29
|
+
4.0000
|
|
30
|
+
...
|
|
31
|
+
|
|
32
|
+
Example:
|
|
33
|
+
--------
|
|
34
|
+
>>> from pathlib import Path
|
|
35
|
+
>>> from w2t_bkin.pose.ttl_mock import (
|
|
36
|
+
... TTLMockOptions,
|
|
37
|
+
... generate_ttl_from_dlc_likelihood,
|
|
38
|
+
... write_ttl_timestamps
|
|
39
|
+
... )
|
|
40
|
+
>>>
|
|
41
|
+
>>> # Generate TTL pulses from trial_light likelihood
|
|
42
|
+
>>> h5_path = Path("pose_output.h5")
|
|
43
|
+
>>> options = TTLMockOptions(
|
|
44
|
+
... bodypart="trial_light",
|
|
45
|
+
... likelihood_threshold=0.99,
|
|
46
|
+
... min_duration_frames=301,
|
|
47
|
+
... fps=150.0
|
|
48
|
+
... )
|
|
49
|
+
>>> timestamps = generate_ttl_from_dlc_likelihood(h5_path, options)
|
|
50
|
+
>>> print(f"Generated {len(timestamps)} TTL pulses")
|
|
51
|
+
>>>
|
|
52
|
+
>>> # Write to TTL file
|
|
53
|
+
>>> output_path = Path("TTLs/trial_light.txt")
|
|
54
|
+
>>> write_ttl_timestamps(timestamps, output_path)
|
|
55
|
+
|
|
56
|
+
Integration with Pipeline:
|
|
57
|
+
--------------------------
|
|
58
|
+
>>> from w2t_bkin.ttl import get_ttl_pulses
|
|
59
|
+
>>> from w2t_bkin.config import load_session
|
|
60
|
+
>>>
|
|
61
|
+
>>> # Generate mock TTL from pose data
|
|
62
|
+
>>> generate_and_write_ttl_from_pose(
|
|
63
|
+
... h5_path=session_dir / "pose.h5",
|
|
64
|
+
... output_path=session_dir / "TTLs/ttl_sync.txt",
|
|
65
|
+
... options=TTLMockOptions(bodypart="trial_light", ...)
|
|
66
|
+
... )
|
|
67
|
+
>>>
|
|
68
|
+
>>> # Load in pipeline
|
|
69
|
+
>>> session = load_session(session_dir / "session.toml")
|
|
70
|
+
>>> pulses = get_ttl_pulses(session)
|
|
71
|
+
>>> print(len(pulses['ttl_sync'])) # Matches generated count
|
|
72
|
+
|
|
73
|
+
Requirements:
|
|
74
|
+
-------------
|
|
75
|
+
- pandas (for HDF5 reading)
|
|
76
|
+
- numpy (for numerical operations)
|
|
77
|
+
- Python 3.10+
|
|
78
|
+
|
|
79
|
+
See Also:
|
|
80
|
+
---------
|
|
81
|
+
- synthetic.ttl_synth: Pure synthetic TTL generation with deterministic RNG
|
|
82
|
+
- w2t_bkin.sync.ttl: TTL pulse loading and validation
|
|
83
|
+
- w2t_bkin.pose.core: DLC pose data import
|
|
84
|
+
"""
|
|
85
|
+
|
|
86
|
+
from __future__ import annotations
|
|
87
|
+
|
|
88
|
+
import logging
|
|
89
|
+
from pathlib import Path
|
|
90
|
+
from typing import Callable, List, Optional, Tuple
|
|
91
|
+
|
|
92
|
+
import numpy as np
|
|
93
|
+
import pandas as pd
|
|
94
|
+
from pydantic import BaseModel, Field, field_validator
|
|
95
|
+
|
|
96
|
+
from w2t_bkin.pose.core import PoseError
|
|
97
|
+
|
|
98
|
+
# Module-level logger, named after this module; used by the TTL generation and writing functions below.
logger = logging.getLogger(__name__)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class TTLMockOptions(BaseModel):
    """Settings that control how mock TTL pulses are derived from pose data.

    Attributes:
        bodypart: DLC keypoint to track. Surrounding whitespace is stripped;
            empty or blank names are rejected.
        likelihood_threshold: Confidence cut-off in [0, 1]. In
            ``generate_ttl_from_dlc_likelihood``, frames with
            likelihood >= this value count as ON.
        min_duration_frames: Shortest ON phase (in frames, >= 1) that is kept.
        fps: Camera frame rate (> 0) used to turn frame indices into seconds.
        transition_type: Which edges become pulses:
            - 'rising': OFF→ON transitions (signal start)
            - 'falling': ON→OFF transitions (signal end)
            - 'both': both kinds, merged in frame order
        start_time_offset_s: Constant offset (seconds) added to every timestamp.
        filter_consecutive: Intended to keep only the first pulse in each
            consecutive group. NOTE(review): no function in this module reads
            this flag, so it currently has no effect on generated pulses.
    """

    bodypart: str = Field(..., description="DLC body part name to track")
    likelihood_threshold: float = Field(
        default=0.99,
        description="Minimum confidence threshold",
        ge=0.0,
        le=1.0,
    )
    min_duration_frames: int = Field(
        default=1,
        description="Minimum frames for valid signal phase",
        ge=1,
    )
    fps: float = Field(default=30.0, description="Camera frame rate (Hz)", gt=0.0)
    transition_type: str = Field(
        default="rising",
        description="Transition detection mode",
        pattern="^(rising|falling|both)$",
    )
    start_time_offset_s: float = Field(default=0.0, description="Time offset for all timestamps (s)")
    filter_consecutive: bool = Field(default=False, description="Keep only first pulse in consecutive groups")

    @field_validator("bodypart")
    @classmethod
    def bodypart_not_empty(cls, v: str) -> str:
        """Reject blank body part names; return the name without surrounding whitespace."""
        cleaned = v.strip() if v else ""
        if cleaned:
            return cleaned
        raise ValueError("bodypart must be a non-empty string")
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def load_dlc_likelihood_series(h5_path: Path, bodypart: str, scorer: Optional[str] = None) -> pd.Series:
    """Load likelihood time series for a specific body part from DLC H5 file.

    Reads the ``"df_with_missing"`` table from the H5 file. Its columns must
    form a 3-level MultiIndex ``(scorer, bodypart, coordinate)``. The function
    returns the ``(scorer, bodypart, "likelihood")`` column.

    Args:
        h5_path: Path to DeepLabCut H5 output file (a ``str`` is also accepted)
        bodypart: Name of body part to extract
        scorer: Optional scorer name; if None, the first scorer found in the
            columns is used

    Returns:
        Pandas Series of likelihood values, indexed like the DLC table's rows

    Raises:
        PoseError: If the file is missing or unreadable, if the table is not a
            DataFrame with non-empty 3-level MultiIndex columns, or if the
            scorer, bodypart, or likelihood column is missing
    """
    h5_path = Path(h5_path)
    if not h5_path.exists():
        raise PoseError(f"DLC H5 file not found: {h5_path}")

    try:
        df = pd.read_hdf(h5_path, "df_with_missing")
    except (KeyError, OSError, ValueError) as e:
        raise PoseError(f"Failed to read DLC H5 file {h5_path}: {e}") from e

    if not isinstance(df, pd.DataFrame):
        raise PoseError(f"Invalid DLC format: expected a DataFrame, got {type(df).__name__}")

    columns = df.columns
    if not isinstance(columns, pd.MultiIndex) or columns.nlevels != 3:
        # Report the level count too; a MultiIndex with the wrong number of
        # levels would otherwise be reported only as "MultiIndex".
        raise PoseError(
            "Invalid DLC format: expected 3-level MultiIndex, "
            f"got {type(columns).__name__} with {columns.nlevels} level(s)"
        )

    if len(columns) == 0:
        raise PoseError(f"Invalid DLC format: no columns in {h5_path}")

    available_scorers = columns.get_level_values(0).unique()
    if scorer is None:
        scorer = available_scorers[0]
        logger.debug(f"Auto-detected scorer: {scorer}")
    elif scorer not in available_scorers:
        raise PoseError(f"Scorer '{scorer}' not found. Available: {list(available_scorers)}")

    # Check if bodypart exists
    bodyparts = columns.get_level_values(1).unique()
    if bodypart not in bodyparts:
        raise PoseError(f"Body part '{bodypart}' not found. Available: {list(bodyparts)}")

    # Extract likelihood column
    try:
        likelihood = df[(scorer, bodypart, "likelihood")]
    except KeyError as e:
        raise PoseError(f"Failed to extract likelihood for '{bodypart}': {e}") from e

    return likelihood
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def detect_signal_transitions(
    signal: pd.Series,
    transition_type: str = "rising",
) -> Tuple[List[int], List[int]]:
    """Detect rising and/or falling edge transitions in a boolean signal.

    Samples are ON where the series is truthy. Missing values (NaN/None/NA)
    count as OFF. The state before the first sample is taken as OFF, so a
    signal that is already ON at its first sample yields an onset there. A
    signal that is still ON at its last sample has no matching offset.

    Args:
        signal: Series indicating signal state (True=ON, False=OFF). Non-bool
            series (e.g. 0/1 integers or floats) are interpreted by truthiness.
        transition_type: 'rising', 'falling', or 'both'. Only the requested
            list(s) are filled; any other value yields two empty lists.

    Returns:
        Tuple of (onsets, offsets), each a list of index labels of ``signal``
        - onsets: labels where the signal goes OFF→ON (empty for 'falling')
        - offsets: labels where the signal goes ON→OFF (empty for 'rising')

    Example:
        >>> signal = pd.Series([False, False, True, True, False, True])
        >>> detect_signal_transitions(signal, 'rising')
        ([2, 5], [])
        >>> detect_signal_transitions(signal, 'both')
        ([2, 5], [4])
    """
    # Work on a plain bool array. Applying ``~`` to a non-bool series would be
    # a bitwise NOT (ints) or a TypeError (floats). Indexing ``signal.index``
    # with a non-bool mask would also select by position instead of filtering.
    state = signal.to_numpy(dtype=bool, na_value=False)

    # Previous-sample state; the sample before the first one is OFF.
    prev_state = np.zeros_like(state)
    prev_state[1:] = state[:-1]

    onsets: List[int] = []
    offsets: List[int] = []

    if transition_type in ("rising", "both"):
        # Rising edge: previous=False AND current=True
        onsets = signal.index[~prev_state & state].tolist()

    if transition_type in ("falling", "both"):
        # Falling edge: previous=True AND current=False
        offsets = signal.index[prev_state & ~state].tolist()

    return onsets, offsets
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def filter_by_duration(
    onsets: List[int],
    offsets: List[int],
    min_duration_frames: int,
) -> Tuple[List[int], List[int]]:
    """Keep only ON phases lasting at least ``min_duration_frames`` frames.

    Onsets and offsets are paired by position (``onsets[k]`` with
    ``offsets[k]``). Unpaired trailing entries of the longer list are
    dropped. The duration of a phase is ``offset - onset``.

    Args:
        onsets: Frame indices of signal ON transitions
        offsets: Frame indices of signal OFF transitions
        min_duration_frames: Minimum number of frames for valid phase

    Returns:
        Tuple of (filtered_onsets, filtered_offsets) with only valid phases

    Example:
        >>> onsets = [10, 50, 100]
        >>> offsets = [15, 55, 400]  # Durations: 5, 5, 300
        >>> filter_by_duration(onsets, offsets, min_duration_frames=10)
        ([100], [400])
    """
    if not onsets or not offsets:
        return [], []

    kept_starts: List[int] = []
    kept_stops: List[int] = []
    # zip() stops at the shorter list, discarding unpaired trailing entries.
    for start, stop in zip(onsets, offsets):
        if stop - start >= min_duration_frames:
            kept_starts.append(start)
            kept_stops.append(stop)

    return kept_starts, kept_stops
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def frames_to_timestamps(frame_indices: List[int], fps: float, offset_s: float = 0.0) -> List[float]:
    """Map 0-based frame indices to times in seconds (``frame / fps + offset_s``).

    Args:
        frame_indices: List of frame indices (0-based)
        fps: Frame rate in frames per second
        offset_s: Constant time offset (seconds) added to every timestamp

    Returns:
        List of timestamps in seconds, in the same order as ``frame_indices``

    Example:
        >>> frames_to_timestamps([0, 150, 300], fps=150.0)
        [0.0, 1.0, 2.0]
    """
    if not frame_indices:
        return []

    # Divide by fps rather than multiply by 1/fps so each value is computed
    # exactly as ``frame / fps``.
    return [offset_s + index / fps for index in frame_indices]
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def generate_ttl_from_dlc_likelihood(
    h5_path: Path,
    options: TTLMockOptions,
    scorer: Optional[str] = None,
) -> List[float]:
    """Generate mock TTL timestamps from DLC likelihood data.

    Main entry point for likelihood-based TTL generation. It loads the
    likelihood series for ``options.bodypart`` and treats frames with
    likelihood >= ``options.likelihood_threshold`` as ON. It then detects
    transitions, optionally filters ON phases by duration, and converts the
    selected frames to timestamps.

    When ``options.min_duration_frames > 1``, onsets and offsets are paired for
    the duration filter. An ON phase that is still running at the last frame
    has no offset, so it is dropped in that case.

    Args:
        h5_path: Path to DeepLabCut H5 output file (a ``str`` is also accepted)
        options: Configuration for TTL generation
        scorer: Optional DLC scorer name (auto-detected if None)

    Returns:
        List of TTL pulse timestamps in seconds (sorted)

    Raises:
        PoseError: If file not found, format invalid, or bodypart missing

    Example:
        >>> options = TTLMockOptions(
        ...     bodypart="trial_light",
        ...     likelihood_threshold=0.99,
        ...     min_duration_frames=301,
        ...     fps=150.0,
        ...     transition_type="rising"
        ... )
        >>> timestamps = generate_ttl_from_dlc_likelihood(Path("pose.h5"), options)
    """
    h5_path = Path(h5_path)
    logger.info(f"Generating TTL from DLC pose data: {h5_path}")
    logger.debug(f"Options: bodypart={options.bodypart}, threshold={options.likelihood_threshold}, " f"min_duration={options.min_duration_frames}, fps={options.fps}")

    # Load likelihood series
    likelihood = load_dlc_likelihood_series(h5_path, options.bodypart, scorer)
    n_frames = len(likelihood)

    # Create boolean signal from threshold (NaN likelihood compares False -> OFF)
    signal = likelihood >= options.likelihood_threshold

    # The f-string statistics are evaluated eagerly, so guard against an empty
    # table, where the percentage would divide by zero.
    if n_frames:
        high_conf_count = int(signal.sum())
        logger.debug(f"Loaded {n_frames} frames, mean likelihood: {likelihood.mean():.3f}")
        logger.debug(f"Frames above threshold: {high_conf_count} ({100 * high_conf_count / n_frames:.1f}%)")
    else:
        logger.warning(f"No frames found for '{options.bodypart}' in {h5_path}; no TTL pulses will be generated")

    # Duration filtering needs onset/offset pairs, so both edge kinds are
    # detected whenever min_duration_frames > 1, regardless of transition_type.
    if options.min_duration_frames > 1:
        onsets, offsets = detect_signal_transitions(signal, "both")
        logger.debug(f"Detected {len(onsets)} onsets, {len(offsets)} offsets (for duration filtering)")
        onsets, offsets = filter_by_duration(onsets, offsets, options.min_duration_frames)
        logger.debug(f"After duration filter: {len(onsets)} valid phases")
    else:
        onsets, offsets = detect_signal_transitions(signal, options.transition_type)
        logger.debug(f"Detected {len(onsets)} onsets, {len(offsets)} offsets")

    # Select timestamps based on transition type
    if options.transition_type == "rising":
        frame_indices = onsets
    elif options.transition_type == "falling":
        frame_indices = offsets
    else:  # both
        frame_indices = sorted(onsets + offsets)

    # Convert to timestamps
    timestamps = frames_to_timestamps(frame_indices, options.fps, options.start_time_offset_s)

    logger.info(f"Generated {len(timestamps)} TTL pulses from {options.bodypart}")
    if timestamps:
        logger.debug(f"Time range: {timestamps[0]:.3f}s - {timestamps[-1]:.3f}s")

    return timestamps
|
|
353
|
+
|
|
354
|
+
|
|
355
|
+
def generate_ttl_from_custom_predicate(
    h5_path: Path,
    predicate: Callable[[pd.DataFrame], pd.Series],
    options: TTLMockOptions,
) -> List[float]:
    """Generate TTL timestamps using a custom predicate function.

    Advanced API for complex signal generation logic. The predicate receives
    the full DLC DataFrame and returns a boolean Series indicating signal state.
    Of ``options``, only ``transition_type``, ``min_duration_frames``, ``fps``
    and ``start_time_offset_s`` are used here. ``bodypart`` and
    ``likelihood_threshold`` are not consulted.

    Duration filtering follows the same rule as
    ``generate_ttl_from_dlc_likelihood``: when ``min_duration_frames > 1``,
    ON phases shorter than that are removed for every transition type. A phase
    still ON at the last frame has no offset and is dropped in that case.

    Args:
        h5_path: Path to DeepLabCut H5 output file (a ``str`` is also accepted)
        predicate: Function that takes DataFrame and returns boolean Series
        options: Configuration (fps, transition_type, etc.)

    Returns:
        List of TTL pulse timestamps in seconds

    Raises:
        PoseError: If the file is missing or unreadable, or the predicate does
            not return a ``pd.Series``

    Example:
        >>> def detect_movement(df):
        ...     # Detect when nose moves > 10 pixels between frames
        ...     scorer = df.columns.get_level_values(0)[0]
        ...     x = df[(scorer, 'nose', 'x')]
        ...     y = df[(scorer, 'nose', 'y')]
        ...     dx = x.diff().abs()
        ...     dy = y.diff().abs()
        ...     return (dx + dy) > 10
        >>>
        >>> options = TTLMockOptions(bodypart="nose", fps=150.0)
        >>> timestamps = generate_ttl_from_custom_predicate(
        ...     Path("pose.h5"), detect_movement, options
        ... )
    """
    h5_path = Path(h5_path)
    if not h5_path.exists():
        raise PoseError(f"DLC H5 file not found: {h5_path}")

    try:
        df = pd.read_hdf(h5_path, "df_with_missing")
    except (KeyError, OSError, ValueError) as e:
        raise PoseError(f"Failed to read DLC H5 file {h5_path}: {e}") from e

    # Apply custom predicate
    signal = predicate(df)

    if not isinstance(signal, pd.Series):
        raise PoseError(f"Predicate must return pd.Series, got {type(signal)}")

    if options.min_duration_frames > 1:
        # Duration filtering needs onset/offset *pairs*. Detecting only rising
        # edges would leave offsets empty, and filter_by_duration would then
        # discard every phase, so always detect both edge kinds here.
        onsets, offsets = detect_signal_transitions(signal, "both")
        onsets, offsets = filter_by_duration(onsets, offsets, options.min_duration_frames)
    else:
        onsets, offsets = detect_signal_transitions(signal, options.transition_type)

    # Select timestamps based on transition type
    if options.transition_type == "rising":
        frame_indices = onsets
    elif options.transition_type == "falling":
        frame_indices = offsets
    else:  # both
        frame_indices = sorted(onsets + offsets)

    # Convert to timestamps
    timestamps = frames_to_timestamps(frame_indices, options.fps, options.start_time_offset_s)

    return timestamps
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
def write_ttl_timestamps(timestamps: List[float], output_path: Path) -> None:
    """Write TTL timestamps to file in w2t_bkin format.

    Writes one timestamp per line, sorted ascending, with 6 decimal places.
    Parent directories are created if needed, and an existing file is
    overwritten.

    Args:
        timestamps: List of timestamps in seconds
        output_path: Path to output file (a ``str`` is also accepted)

    Example:
        >>> timestamps = [0.0, 1.5, 3.0]
        >>> write_ttl_timestamps(timestamps, Path("TTLs/ttl_sync.txt"))
    """
    # Coerce so plain string paths work too (``str`` has no ``.parent``).
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Format every line up front, then write them in one batched call.
    lines = [f"{ts:.6f}\n" for ts in sorted(timestamps)]
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(lines)

    logger.info(f"Wrote {len(timestamps)} TTL timestamps to {output_path}")
|
|
448
|
+
|
|
449
|
+
|
|
450
|
+
def generate_and_write_ttl_from_pose(
    h5_path: Path,
    output_path: Path,
    options: TTLMockOptions,
    scorer: Optional[str] = None,
) -> int:
    """Generate likelihood-based TTL pulses and save them to a file in one step.

    This chains ``generate_ttl_from_dlc_likelihood`` and
    ``write_ttl_timestamps``. Any exception raised by either step propagates
    unchanged.

    Args:
        h5_path: Path to DeepLabCut H5 output file
        output_path: Path to output TTL file
        options: Configuration for TTL generation
        scorer: Optional DLC scorer name

    Returns:
        Number of TTL pulses generated (and written)

    Example:
        >>> count = generate_and_write_ttl_from_pose(
        ...     h5_path=Path("pose.h5"),
        ...     output_path=Path("TTLs/trial_light.txt"),
        ...     options=TTLMockOptions(bodypart="trial_light", fps=150.0)
        ... )
        >>> print(f"Generated {count} pulses")
    """
    pulse_times = generate_ttl_from_dlc_likelihood(h5_path=h5_path, options=options, scorer=scorer)
    write_ttl_timestamps(timestamps=pulse_times, output_path=output_path)
    pulse_count = len(pulse_times)
    return pulse_count
|