w2t_bkin-0.0.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- w2t_bkin/__init__.py +85 -0
- w2t_bkin/behavior/__init__.py +115 -0
- w2t_bkin/behavior/core.py +1027 -0
- w2t_bkin/bpod/__init__.py +38 -0
- w2t_bkin/bpod/core.py +519 -0
- w2t_bkin/config.py +625 -0
- w2t_bkin/dlc/__init__.py +59 -0
- w2t_bkin/dlc/core.py +448 -0
- w2t_bkin/dlc/models.py +124 -0
- w2t_bkin/exceptions.py +426 -0
- w2t_bkin/facemap/__init__.py +42 -0
- w2t_bkin/facemap/core.py +397 -0
- w2t_bkin/facemap/models.py +134 -0
- w2t_bkin/pipeline.py +665 -0
- w2t_bkin/pose/__init__.py +48 -0
- w2t_bkin/pose/core.py +227 -0
- w2t_bkin/pose/io.py +363 -0
- w2t_bkin/pose/skeleton.py +165 -0
- w2t_bkin/pose/ttl_mock.py +477 -0
- w2t_bkin/session.py +423 -0
- w2t_bkin/sync/__init__.py +72 -0
- w2t_bkin/sync/core.py +678 -0
- w2t_bkin/sync/stats.py +176 -0
- w2t_bkin/sync/timebase.py +311 -0
- w2t_bkin/sync/ttl.py +254 -0
- w2t_bkin/transcode/__init__.py +38 -0
- w2t_bkin/transcode/core.py +303 -0
- w2t_bkin/transcode/models.py +96 -0
- w2t_bkin/ttl/__init__.py +64 -0
- w2t_bkin/ttl/core.py +518 -0
- w2t_bkin/ttl/models.py +19 -0
- w2t_bkin/utils.py +1093 -0
- w2t_bkin-0.0.6.dist-info/METADATA +145 -0
- w2t_bkin-0.0.6.dist-info/RECORD +36 -0
- w2t_bkin-0.0.6.dist-info/WHEEL +4 -0
- w2t_bkin-0.0.6.dist-info/licenses/LICENSE +201 -0
w2t_bkin/ttl/__init__.py
ADDED
@@ -0,0 +1,64 @@
"""TTL hardware signals: loading, processing, and NWB integration.

This module handles TTL (Transistor-Transistor Logic) pulse timestamps from
hardware synchronization signals, providing functions to load from files and
convert to standardized NWB EventsTable format using ndx-events.

Public API
----------
from w2t_bkin.ttl import (
    # Loading functions (migrated from sync.ttl)
    load_ttl_file,
    get_ttl_pulses,

    # NWB integration (ndx-events)
    extract_ttl_table,
    add_ttl_table_to_nwb,

    # ndx-events types
    EventsTable,

    # Exceptions
    TTLError,
)

Usage Example
-------------
```python
from w2t_bkin.ttl import get_ttl_pulses, extract_ttl_table

# Load TTL pulses from files
ttl_patterns = {"ttl_camera": "TTLs/cam*.txt", "ttl_cue": "TTLs/cue*.txt"}
ttl_pulses = get_ttl_pulses(session_dir, ttl_patterns)

# Extract TTL descriptions from config
ttl_descriptions = {ttl.id: ttl.description for ttl in session.TTLs}

# Create EventsTable
ttl_table = extract_ttl_table(ttl_pulses, descriptions=ttl_descriptions)

# Add to NWBFile
nwbfile.add_acquisition(ttl_table)
```

Requirements
------------
- FR-17: Hardware sync signal recording
- ndx-events~=0.4.0 (for EventsTable)
"""

from .core import TTLError, add_ttl_table_to_nwb, extract_ttl_table, get_ttl_pulses, load_ttl_file
from .models import EventsTable

__all__ = [
    # Loading functions
    "load_ttl_file",
    "get_ttl_pulses",
    # NWB integration
    "extract_ttl_table",
    "add_ttl_table_to_nwb",
    # ndx-events types
    "EventsTable",
    # Exceptions
    "TTLError",
]
w2t_bkin/ttl/core.py
ADDED
@@ -0,0 +1,518 @@
"""Core functions for TTL pulse loading and NWB EventsTable conversion.

Provides TTL timestamp loading from text files and conversion to structured
NWB-compatible event tables using the ndx-events extension. Optimized for
large datasets (camera frames with 10k+ timestamps).

Functions
---------
- load_ttl_file: Load timestamps from a single TTL file
- get_ttl_pulses: Load TTL pulses from multiple files using glob patterns
- extract_ttl_table: Convert TTL pulses to ndx-events EventsTable
- add_ttl_table_to_nwb: Helper to add TTL EventsTable to NWBFile

Performance
-----------
Uses numpy vectorized operations for efficient handling of large TTL datasets.
Tested with 10k+ events in <60s.

Example
-------
>>> from pathlib import Path
>>> from w2t_bkin.ttl import get_ttl_pulses, extract_ttl_table
>>>
>>> # Load TTL pulses
>>> ttl_patterns = {"ttl_camera": "TTLs/cam*.txt"}
>>> ttl_pulses = get_ttl_pulses(Path("data/session"), ttl_patterns)
>>>
>>> # Create EventsTable
>>> ttl_table = extract_ttl_table(
...     ttl_pulses,
...     descriptions={"ttl_camera": "Camera frame sync (30 Hz)"}
... )
>>>
>>> # Add to NWBFile
>>> nwbfile.add_acquisition(ttl_table)
"""

import glob
import logging
from pathlib import Path
from typing import Dict, List, Optional, Protocol, Tuple

from ndx_events import EventsTable
import numpy as np
import pandas as pd
from pynwb import NWBFile

logger = logging.getLogger(__name__)


class TTLError(Exception):
    """Exception raised for TTL processing errors."""

    pass


# =============================================================================
# TTL File Loading (migrated from sync.ttl)
# =============================================================================


def load_ttl_file(path: Path) -> List[float]:
    """Load TTL timestamps from a single file.

    Expects one timestamp per line in seconds (floating-point format).

    Args:
        path: Path to TTL file

    Returns:
        List of timestamps in seconds

    Raises:
        TTLError: File not found or read error

    Example:
        >>> from pathlib import Path
        >>> timestamps = load_ttl_file(Path("TTLs/cam0.txt"))
        >>> print(f"Loaded {len(timestamps)} TTL pulses")
    """
    if not path.exists():
        raise TTLError(f"TTL file not found: {path}")

    timestamps = []

    try:
        with open(path, "r") as f:
            for line_num, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue

                try:
                    timestamps.append(float(line))
                except ValueError:
                    logger.warning(f"Skipping invalid TTL timestamp in {path.name} line {line_num}: {line}")
    except Exception as e:
        raise TTLError(f"Failed to read TTL file {path}: {e}")

    return timestamps
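

def _example_ttl_file_format(tmp_dir: Path) -> None:
    """Editor's illustrative sketch (not part of the published wheel): shows the
    on-disk format load_ttl_file() expects -- one floating-point timestamp in
    seconds per line; blank lines are skipped and malformed lines are logged and
    dropped. The file name used here is hypothetical."""
    ttl_path = tmp_dir / "cam0.txt"
    ttl_path.write_text("0.0\n0.0333\n\nnot-a-number\n0.0666\n")
    # The blank line and the unparseable line are discarded; the rest are kept in order.
    assert load_ttl_file(ttl_path) == [0.0, 0.0333, 0.0666]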


def get_ttl_pulses(session_dir: Path, ttl_patterns: Dict[str, str]) -> Dict[str, List[float]]:
    """Load TTL pulses from multiple files using glob patterns.

    Discovers and loads TTL files matching glob patterns, merging timestamps
    from multiple files per channel and sorting chronologically.

    Args:
        session_dir: Base directory for resolving patterns
        ttl_patterns: Dict mapping TTL ID to glob pattern
            (e.g., {"ttl_camera": "TTLs/cam*.txt"})

    Returns:
        Dict mapping TTL ID to sorted timestamp list

    Raises:
        TTLError: File read failed

    Example:
        >>> from pathlib import Path
        >>> ttl_patterns = {
        ...     "ttl_camera": "TTLs/*cam*.txt",
        ...     "ttl_cue": "TTLs/*cue*.txt"
        ... }
        >>> ttl_pulses = get_ttl_pulses(Path("data/Session-000001"), ttl_patterns)
        >>> print(f"Camera: {len(ttl_pulses['ttl_camera'])} pulses")
    """
    session_dir = Path(session_dir)
    ttl_pulses = {}

    for ttl_id, pattern_str in ttl_patterns.items():
        # Resolve glob pattern relative to session directory
        pattern = str(session_dir / pattern_str)
        ttl_files = sorted(glob.glob(pattern))

        if not ttl_files:
            logger.warning(f"No TTL files found for '{ttl_id}' with pattern: {pattern}")
            ttl_pulses[ttl_id] = []
            continue

        # Load and merge timestamps from all matching files
        timestamps = []
        for ttl_file in ttl_files:
            path = Path(ttl_file)
            file_timestamps = load_ttl_file(path)
            timestamps.extend(file_timestamps)

        # Sort chronologically and store
        ttl_pulses[ttl_id] = sorted(timestamps)
        logger.debug(f"Loaded {len(timestamps)} TTL pulses for '{ttl_id}' from {len(ttl_files)} file(s)")

    return ttl_pulses


# =============================================================================
# NWB EventsTable Conversion (ndx-events integration)
# =============================================================================


def extract_ttl_table(
    ttl_pulses: Dict[str, List[float]],
    name: str = "TTLEvents",
    descriptions: Optional[Dict[str, str]] = None,
    sources: Optional[Dict[str, str]] = None,
) -> EventsTable:
    """Extract EventsTable from TTL pulse timestamps.

    Converts a dictionary of TTL pulse timestamps into an ndx-events EventsTable
    with one row per pulse. Includes channel ID, description, and source metadata
    via custom columns. Optimized for large datasets using numpy vectorization.

    Performance: Handles 10k+ events efficiently (O(n log n) for sorting).

    Args:
        ttl_pulses: Dict mapping TTL ID to list of timestamps (seconds)
        name: Name for the EventsTable container (default: "TTLEvents")
        descriptions: Optional dict mapping TTL ID to description string
            (typically from session.toml [[TTLs]].description)
        sources: Optional dict mapping TTL ID to source device/system

    Returns:
        EventsTable with all TTL pulses as events, sorted by timestamp

    Raises:
        TTLError: If ttl_pulses is empty or all channels are empty

    Example:
        >>> ttl_pulses = {
        ...     "ttl_camera": [0.0, 0.033, 0.066],  # Camera frames
        ...     "ttl_cue": [1.0, 3.0, 5.0]  # Behavioral cues
        ... }
        >>> ttl_table = extract_ttl_table(
        ...     ttl_pulses,
        ...     descriptions={"ttl_camera": "Camera sync", "ttl_cue": "Cue trigger"},
        ...     sources={"ttl_camera": "FLIR Blackfly", "ttl_cue": "Bpod"}
        ... )
        >>> len(ttl_table.timestamp)  # Total pulses across all channels
        6
    """
    if not ttl_pulses:
        raise TTLError("ttl_pulses dictionary is empty")

    descriptions = descriptions or {}
    sources = sources or {}

    # Pre-compute total size for efficient array allocation
    total_events = sum(len(timestamps) for timestamps in ttl_pulses.values())
    if total_events == 0:
        raise TTLError("No valid TTL pulses found in any channel")

    # Pre-allocate arrays for performance (avoids list appends)
    all_timestamps = np.empty(total_events, dtype=np.float64)
    all_channels = np.empty(total_events, dtype=object)
    all_descriptions = np.empty(total_events, dtype=object)
    all_sources = np.empty(total_events, dtype=object)

    # Fill arrays efficiently
    offset = 0
    for ttl_id in sorted(ttl_pulses.keys()):  # Deterministic order
        timestamps = ttl_pulses[ttl_id]
        if not timestamps:
            logger.warning(f"TTL channel '{ttl_id}' has no pulses, skipping")
            continue

        n = len(timestamps)
        all_timestamps[offset : offset + n] = timestamps
        all_channels[offset : offset + n] = ttl_id
        all_descriptions[offset : offset + n] = descriptions.get(ttl_id, f"TTL pulses from {ttl_id}")
        all_sources[offset : offset + n] = sources.get(ttl_id, "unknown")
        offset += n

    # Trim arrays if some channels were empty
    if offset < total_events:
        all_timestamps = all_timestamps[:offset]
        all_channels = all_channels[:offset]
        all_descriptions = all_descriptions[:offset]
        all_sources = all_sources[:offset]

    # Sort by timestamp (O(n log n), efficient for large datasets)
    sort_indices = np.argsort(all_timestamps)
    sorted_timestamps = all_timestamps[sort_indices]
    sorted_channels = all_channels[sort_indices]
    sorted_descriptions = all_descriptions[sort_indices]
    sorted_sources = all_sources[sort_indices]

    # Create DataFrame for bulk insertion (much faster than add_row loop)
    df = pd.DataFrame(
        {
            "timestamp": sorted_timestamps,
            "channel": sorted_channels,
            "ttl_description": sorted_descriptions,
            "source": sorted_sources,
        }
    )

    # Define column descriptions for EventsTable
    columns = [
        {"name": "channel", "description": "TTL channel identifier"},
        {"name": "ttl_description", "description": "Description of the TTL channel"},
        {"name": "source", "description": "Source device or system generating the TTL signal"},
    ]

    # Create EventsTable from DataFrame (bulk insertion - much faster than add_row)
    ttl_table = EventsTable.from_dataframe(
        df=df,
        name=name,
        table_description=f"Hardware TTL pulse events from {len(ttl_pulses)} channels, {offset} total pulses",
        columns=columns,
    )

    logger.info(f"Created EventsTable '{name}' with {offset} events from {len(ttl_pulses)} TTL channels")

    return ttl_table


def add_ttl_table_to_nwb(
    nwbfile: NWBFile,
    ttl_pulses: Dict[str, List[float]],
    descriptions: Optional[Dict[str, str]] = None,
    sources: Optional[Dict[str, str]] = None,
    container_name: str = "TTLEvents",
) -> NWBFile:
    """Add TTL events to NWBFile as EventsTable.

    Convenience function that creates an EventsTable and adds it to the NWBFile
    acquisition section.

    Args:
        nwbfile: NWBFile to add TTL table to
        ttl_pulses: Dict mapping TTL ID to timestamps
        descriptions: Optional channel descriptions (from session.toml)
        sources: Optional source device/system names
        container_name: Name for the TTL table container (default: "TTLEvents")

    Returns:
        Modified NWBFile with TTL table added to acquisition

    Example:
        >>> from pynwb import NWBFile
        >>> from w2t_bkin.ttl import get_ttl_pulses, add_ttl_table_to_nwb
        >>>
        >>> nwbfile = NWBFile(...)
        >>> ttl_pulses = get_ttl_pulses(session_dir, ttl_patterns)
        >>> nwbfile = add_ttl_table_to_nwb(
        ...     nwbfile,
        ...     ttl_pulses,
        ...     descriptions={"ttl_camera": "Camera sync"},
        ...     sources={"ttl_camera": "FLIR Blackfly"}
        ... )
    """
    ttl_table = extract_ttl_table(
        ttl_pulses,
        name=container_name,
        descriptions=descriptions,
        sources=sources,
    )

    nwbfile.add_acquisition(ttl_table)

    logger.info(f"Added EventsTable '{container_name}' to NWBFile acquisition")

    return nwbfile


# =============================================================================
# Bpod Alignment (migrated from sync.behavior)
# =============================================================================


def get_sync_time_from_bpod_trial(trial_data: Dict, sync_signal: str) -> Optional[float]:
    """Extract sync signal start time from Bpod trial.

    Args:
        trial_data: Trial data with States structure
        sync_signal: State name (e.g. "W2L_Audio")

    Returns:
        Start time relative to trial start, or None if not found
    """
    from ..utils import convert_matlab_struct, is_nan_or_none

    # Convert MATLAB struct to dict if needed
    trial_data = convert_matlab_struct(trial_data)

    states = trial_data.get("States", {})
    if not states:
        return None

    # Convert states to dict if it's a MATLAB struct
    states = convert_matlab_struct(states)

    sync_times = states.get(sync_signal)
    if sync_times is None:
        return None

    if not isinstance(sync_times, (list, tuple, np.ndarray)) or len(sync_times) < 2:
        return None

    start_time = sync_times[0]
    if is_nan_or_none(start_time):
        return None

    return float(start_time)
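

def _example_sync_time_lookup() -> None:
    """Editor's illustrative sketch (not part of the published wheel): shows the
    trial structure get_sync_time_from_bpod_trial() reads, assuming
    convert_matlab_struct() passes plain dicts through unchanged. The state names
    and times below are made up."""
    example_trial = {"States": {"W2L_Audio": [1.25, 1.75], "ITI": [5.0, 7.0]}}
    # Returns the start of the requested state, relative to trial start
    assert get_sync_time_from_bpod_trial(example_trial, "W2L_Audio") == 1.25
    # States that are absent (or never visited) yield None
    assert get_sync_time_from_bpod_trial(example_trial, "Reward") is None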


class BpodTrialTypeProtocol(Protocol):
    """Protocol for Bpod trial type configuration access.

    Defines minimal interface needed by sync.ttl module without
    importing from domain.session.BpodTrialType.

    Attributes:
        trial_type: Trial type identifier
        sync_signal: Bpod state/event name for alignment
        sync_ttl: TTL channel ID for sync pulses
        description: Human-readable description
    """

    trial_type: int
    sync_signal: str
    sync_ttl: str
    description: str
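

# Editor's illustrative sketch (not part of the published wheel): any object with
# these four attributes satisfies BpodTrialTypeProtocol, e.g. a small dataclass
# built from a session-config trial-type entry. Names and values are hypothetical.
from dataclasses import dataclass  # needed only for the illustrative example below


@dataclass
class _ExampleBpodTrialType:
    trial_type: int = 1
    sync_signal: str = "W2L_Audio"  # Bpod state whose onset marks the sync moment
    sync_ttl: str = "ttl_cue"  # TTL channel carrying the matching hardware pulses
    description: str = "Audio-cued trials"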


def align_bpod_trials_to_ttl(
    trial_type_configs: List[BpodTrialTypeProtocol],
    bpod_data: Dict,
    ttl_pulses: Dict[str, List[float]],
) -> Tuple[Dict[int, float], List[str]]:
    """Align Bpod trials to absolute time using TTL sync signals (low-level, Session-free).

    Converts Bpod relative timestamps to absolute time by matching per-trial
    sync signals to corresponding TTL pulses. Returns per-trial offsets that
    can be used with events.extract_trials() and events.extract_behavioral_events()
    to convert relative timestamps to absolute timestamps.

    Args:
        trial_type_configs: List of trial type sync configurations
            (from session.bpod.trial_types)
        bpod_data: Parsed Bpod data (SessionData structure from events.parse_bpod)
        ttl_pulses: Dict mapping TTL channel ID to sorted list of absolute timestamps
            (typically from w2t_bkin.ttl.get_ttl_pulses)

    Returns:
        Tuple of:
        - trial_offsets: Dict mapping trial_number → absolute time offset
        - warnings: List of warning messages for trials that couldn't be aligned

    Raises:
        SyncError: If trial_type config missing or data structure invalid
    """
    from ..utils import convert_matlab_struct, to_scalar

    # NOTE: the published module references SyncError without importing it; the
    # import below assumes it is provided by w2t_bkin.exceptions (adjust if it
    # actually lives elsewhere, e.g. in w2t_bkin.sync).
    from ..exceptions import SyncError

    # Validate Bpod structure
    if "SessionData" not in bpod_data:
        raise SyncError("Invalid Bpod structure: missing SessionData")

    session_data = convert_matlab_struct(bpod_data["SessionData"])
    n_trials = int(session_data["nTrials"])

    if n_trials == 0:
        logger.info("No trials to align")
        return {}, []

    # Build trial_type → sync config mapping
    trial_type_map = {}
    for tt_config in trial_type_configs:
        trial_type_map[tt_config.trial_type] = {
            "sync_signal": tt_config.sync_signal,
            "sync_ttl": tt_config.sync_ttl,
        }

    if not trial_type_map:
        raise SyncError("No trial_type sync configuration provided in trial_type_configs")

    # Prepare TTL pulse pointers (track consumption per channel)
    ttl_pointers = {ttl_id: 0 for ttl_id in ttl_pulses.keys()}

    # Extract raw events
    raw_events = convert_matlab_struct(session_data["RawEvents"])
    trial_data_list = raw_events["Trial"]

    # Extract TrialTypes if available
    trial_types_array = session_data.get("TrialTypes")
    if trial_types_array is None:
        # Default to trial_type 1 for all trials if not specified
        trial_types_array = [1] * n_trials
        logger.warning("TrialTypes not found in Bpod data, defaulting all trials to type 1")

    trial_offsets = {}
    warnings_list = []

    for i in range(n_trials):
        trial_num = i + 1
        trial_data = convert_matlab_struct(trial_data_list[i])

        # Get trial type (handle numpy arrays)
        trial_type = int(to_scalar(trial_types_array, i))

        if trial_type not in trial_type_map:
            warnings_list.append(f"Trial {trial_num}: trial_type {trial_type} not in session config, skipping")
            logger.warning(warnings_list[-1])
            continue

        sync_config = trial_type_map[trial_type]
        sync_signal = sync_config["sync_signal"]
        sync_ttl_id = sync_config["sync_ttl"]

        # Extract sync time from trial (relative to trial start)
        sync_time_rel = get_sync_time_from_bpod_trial(trial_data, sync_signal)
        if sync_time_rel is None:
            warnings_list.append(f"Trial {trial_num}: sync_signal '{sync_signal}' not found or not visited, skipping")
            logger.warning(warnings_list[-1])
            continue

        # Get next TTL pulse
        if sync_ttl_id not in ttl_pulses:
            warnings_list.append(f"Trial {trial_num}: TTL channel '{sync_ttl_id}' not found in ttl_pulses, skipping")
            logger.error(warnings_list[-1])
            continue

        ttl_channel = ttl_pulses[sync_ttl_id]
        ttl_ptr = ttl_pointers[sync_ttl_id]

        if ttl_ptr >= len(ttl_channel):
            warnings_list.append(f"Trial {trial_num}: No more TTL pulses available for '{sync_ttl_id}', skipping")
            logger.warning(warnings_list[-1])
            continue

        ttl_pulse_time = ttl_channel[ttl_ptr]
        ttl_pointers[sync_ttl_id] += 1

        # Get trial start timestamp from Bpod (may be non-zero after merge)
        trial_start_timestamp = float(to_scalar(session_data["TrialStartTimestamp"], i))

        # Compute offset: absolute_time = offset + TrialStartTimestamp
        # The sync signal occurs at: trial_start_timestamp + sync_time_rel (in Bpod timeline)
        # And should align to: ttl_pulse_time (in absolute timeline)
        # Therefore: offset + (trial_start_timestamp + sync_time_rel) = ttl_pulse_time
        offset_abs = ttl_pulse_time - (trial_start_timestamp + sync_time_rel)
        trial_offsets[trial_num] = offset_abs

        logger.debug(
            f"Trial {trial_num}: type={trial_type}, sync_signal={sync_signal}, "
            f"trial_start={trial_start_timestamp:.4f}s, sync_rel={sync_time_rel:.4f}s, "
            f"ttl_abs={ttl_pulse_time:.4f}s, offset={offset_abs:.4f}s"
        )  # fmt: skip

    # Warn about unused TTL pulses
    for ttl_id, ptr in ttl_pointers.items():
        unused = len(ttl_pulses[ttl_id]) - ptr
        if unused > 0:
            warnings_list.append(f"TTL channel '{ttl_id}' has {unused} unused pulses")
            logger.warning(warnings_list[-1])

    logger.info(f"Computed offsets for {len(trial_offsets)} out of {n_trials} trials using TTL sync")
    return trial_offsets, warnings_list
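

def _example_offset_arithmetic() -> None:
    """Editor's illustrative sketch (not part of the published wheel): the offset
    algebra used in align_bpod_trials_to_ttl(), with made-up numbers. A sync state
    starting 1.5 s into a trial whose Bpod TrialStartTimestamp is 10.0 s, matched
    to a hardware TTL pulse at 101.5 s, gives offset = 101.5 - (10.0 + 1.5) = 90.0,
    so any Bpod-relative time t maps to absolute time t + 90.0."""
    trial_start_timestamp = 10.0  # Bpod timeline (s)
    sync_time_rel = 1.5  # sync state onset, relative to trial start (s)
    ttl_pulse_time = 101.5  # absolute / hardware timeline (s)
    offset_abs = ttl_pulse_time - (trial_start_timestamp + sync_time_rel)
    assert offset_abs == 90.0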
w2t_bkin/ttl/models.py
ADDED
@@ -0,0 +1,19 @@
"""Re-export ndx-events types for TTL signal integration.

This module provides convenient access to ndx-events types used for
hardware synchronization signal recording.

Types
-----
- EventsTable: Table for timestamped events with metadata

Requirements
------------
- ndx-events~=0.4.0 (already in dependencies)
"""

from ndx_events import EventsTable

__all__ = [
    "EventsTable",
]