w2t-bkin 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- w2t_bkin/__init__.py +85 -0
- w2t_bkin/behavior/__init__.py +115 -0
- w2t_bkin/behavior/core.py +1027 -0
- w2t_bkin/bpod/__init__.py +38 -0
- w2t_bkin/bpod/core.py +519 -0
- w2t_bkin/config.py +625 -0
- w2t_bkin/dlc/__init__.py +59 -0
- w2t_bkin/dlc/core.py +448 -0
- w2t_bkin/dlc/models.py +124 -0
- w2t_bkin/exceptions.py +426 -0
- w2t_bkin/facemap/__init__.py +42 -0
- w2t_bkin/facemap/core.py +397 -0
- w2t_bkin/facemap/models.py +134 -0
- w2t_bkin/pipeline.py +665 -0
- w2t_bkin/pose/__init__.py +48 -0
- w2t_bkin/pose/core.py +227 -0
- w2t_bkin/pose/io.py +363 -0
- w2t_bkin/pose/skeleton.py +165 -0
- w2t_bkin/pose/ttl_mock.py +477 -0
- w2t_bkin/session.py +423 -0
- w2t_bkin/sync/__init__.py +72 -0
- w2t_bkin/sync/core.py +678 -0
- w2t_bkin/sync/stats.py +176 -0
- w2t_bkin/sync/timebase.py +311 -0
- w2t_bkin/sync/ttl.py +254 -0
- w2t_bkin/transcode/__init__.py +38 -0
- w2t_bkin/transcode/core.py +303 -0
- w2t_bkin/transcode/models.py +96 -0
- w2t_bkin/ttl/__init__.py +64 -0
- w2t_bkin/ttl/core.py +518 -0
- w2t_bkin/ttl/models.py +19 -0
- w2t_bkin/utils.py +1093 -0
- w2t_bkin-0.0.6.dist-info/METADATA +145 -0
- w2t_bkin-0.0.6.dist-info/RECORD +36 -0
- w2t_bkin-0.0.6.dist-info/WHEEL +4 -0
- w2t_bkin-0.0.6.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,1027 @@
|
|
|
1
|
+
"""Core transformation functions: Bpod data → ndx-structured-behavior.
|
|
2
|
+
|
|
3
|
+
This module implements the transformation layer that converts parsed Bpod .mat
|
|
4
|
+
files into ndx-structured-behavior NWB classes. All functions produce NWB-native
|
|
5
|
+
objects directly, following the NWB-first architecture from Phase 1.
|
|
6
|
+
|
|
7
|
+
Architecture:
|
|
8
|
+
- Low-level: Parse Bpod .mat files (w2t_bkin.bpod module)
|
|
9
|
+
- Mid-level: Transform to ndx-structured-behavior (this module)
|
|
10
|
+
- High-level: Integrate with pipeline and NWB assembly (pipeline.py)
|
|
11
|
+
|
|
12
|
+
Data Flow:
|
|
13
|
+
Bpod .mat → parse_bpod() → extract_*_types() → extract_*() → build_task_recording() → build_trials_table() → NWBFile
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
import logging
|
|
17
|
+
from typing import Any, Dict, List, Optional, Set, Tuple
|
|
18
|
+
|
|
19
|
+
from ndx_structured_behavior import ActionsTable, ActionTypesTable, EventsTable, EventTypesTable, StatesTable, StateTypesTable, Task, TaskArgumentsTable, TaskRecording, TrialsTable
|
|
20
|
+
import numpy as np
|
|
21
|
+
|
|
22
|
+
from ..exceptions import BpodParseError
|
|
23
|
+
from ..utils import convert_matlab_struct, is_nan_or_none, to_scalar
|
|
24
|
+
|
|
25
|
+
logger = logging.getLogger(__name__)

# Bpod state name -> action name used in the ActionTypesTable.
# Only states listed here are treated as actions by extract_action_types()
# and extract_actions(); several states may share one action name
# (both audio states map to "audio_stimulus").
ACTION_STATES: Dict[str, str] = dict(
    LeftReward="left_valve_open",
    RightReward="right_valve_open",
    W2T_Audio="audio_stimulus",
    A2L_Audio="audio_stimulus",
    Airpuff="airpuff_stimulus",
    Microstim="microstimulation",
)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# =============================================================================
|
|
40
|
+
# Type Tables (Metadata)
|
|
41
|
+
# =============================================================================
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def extract_state_types(bpod_data: Dict[str, Any]) -> StateTypesTable:
    """Build a StateTypesTable listing every state name seen in the session.

    Collects the union of keys of ``SessionData.RawEvents.Trial[i].States``
    over all trials. Names are added in sorted order, so row indices are
    deterministic for a given set of names.

    Args:
        bpod_data: Parsed Bpod data dictionary from parse_bpod()

    Returns:
        StateTypesTable with one row per unique state name

    Raises:
        BpodParseError: If ``SessionData`` has no ``RawEvents`` entry.

    Example:
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> state_types = extract_state_types(bpod_data)
        >>> list(state_types["state_name"].data)  # sorted
        ['HIT', 'ITI', 'Miss', 'Response_window', ...]
    """
    session_data = convert_matlab_struct(bpod_data.get("SessionData", {}))
    if "RawEvents" not in session_data:
        raise BpodParseError("Missing RawEvents in Bpod data")
    raw_events = convert_matlab_struct(session_data["RawEvents"])

    seen: Set[str] = set()
    for trial in raw_events.get("Trial", []):
        # Trials may be MATLAB struct objects (attribute access) or dicts;
        # anything else is skipped.
        if hasattr(trial, "States"):
            trial_states = trial.States
        elif isinstance(trial, dict):
            trial_states = trial.get("States", {})
        else:
            continue
        seen.update(convert_matlab_struct(trial_states).keys())

    ordered = sorted(seen)
    table = StateTypesTable(description="State types from Bpod protocol")
    for name in ordered:
        table.add_row(state_name=name)

    logger.info(f"Extracted {len(ordered)} unique state types")
    return table
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def extract_event_types(bpod_data: Dict[str, Any]) -> EventTypesTable:
    """Build an EventTypesTable listing every hardware event name in the session.

    Collects the union of keys of ``SessionData.RawEvents.Trial[i].Events``
    over all trials. Names are added in sorted order, so row indices are
    deterministic for a given set of names.

    Args:
        bpod_data: Parsed Bpod data dictionary from parse_bpod()

    Returns:
        EventTypesTable with one row per unique event name

    Raises:
        BpodParseError: If ``SessionData`` has no ``RawEvents`` entry.

    Example:
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> event_types = extract_event_types(bpod_data)
        >>> list(event_types["event_name"].data)  # sorted
        ['BNC1High', 'Flex1Trig1', 'Port1In', 'Port1Out', ...]
    """
    session_data = convert_matlab_struct(bpod_data.get("SessionData", {}))
    if "RawEvents" not in session_data:
        raise BpodParseError("Missing RawEvents in Bpod data")
    raw_events = convert_matlab_struct(session_data["RawEvents"])

    seen: Set[str] = set()
    for trial in raw_events.get("Trial", []):
        # Trials may be MATLAB struct objects (attribute access) or dicts;
        # anything else is skipped.
        if hasattr(trial, "Events"):
            trial_events = trial.Events
        elif isinstance(trial, dict):
            trial_events = trial.get("Events", {})
        else:
            continue
        seen.update(convert_matlab_struct(trial_events).keys())

    ordered = sorted(seen)
    table = EventTypesTable(description="Event types from Bpod hardware")
    for name in ordered:
        table.add_row(event_name=name)

    logger.info(f"Extracted {len(ordered)} unique event types")
    return table
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def extract_action_types(bpod_data: Dict[str, Any]) -> ActionTypesTable:
    """Build an ActionTypesTable for the action states present in the session.

    Every state name found in ``SessionData.RawEvents.Trial[i].States`` that
    is a key of ``ACTION_STATES`` contributes its mapped action name. Action
    names are added in sorted order, so row indices are deterministic.

    Args:
        bpod_data: Parsed Bpod data dictionary from parse_bpod()

    Returns:
        ActionTypesTable with one row per observed action name

    Raises:
        BpodParseError: If ``SessionData`` has no ``RawEvents`` entry.

    Example:
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> action_types = extract_action_types(bpod_data)
        >>> list(action_types["action_name"].data)  # sorted
        ['audio_stimulus', 'left_valve_open', 'right_valve_open', ...]
    """
    session_data = convert_matlab_struct(bpod_data.get("SessionData", {}))
    if "RawEvents" not in session_data:
        raise BpodParseError("Missing RawEvents in Bpod data")
    raw_events = convert_matlab_struct(session_data["RawEvents"])

    mapped: Set[str] = set()
    for trial in raw_events.get("Trial", []):
        # Trials may be MATLAB struct objects (attribute access) or dicts;
        # anything else is skipped.
        if hasattr(trial, "States"):
            trial_states = trial.States
        elif isinstance(trial, dict):
            trial_states = trial.get("States", {})
        else:
            continue
        state_keys = convert_matlab_struct(trial_states).keys()
        mapped.update(ACTION_STATES[name] for name in state_keys if name in ACTION_STATES)

    ordered = sorted(mapped)
    table = ActionTypesTable(description="Action types from Bpod protocol")
    for name in ordered:
        table.add_row(action_name=name)

    logger.info(f"Extracted {len(ordered)} unique action types")
    return table
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
# =============================================================================
|
|
211
|
+
# Data Tables (Temporal Events)
|
|
212
|
+
# =============================================================================
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def extract_states(
    bpod_data: Dict[str, Any],
    state_types: StateTypesTable,
    trial_offsets: Optional[Dict[int, float]] = None,
) -> Tuple[StatesTable, Dict[int, List[int]]]:
    """Extract every visited state interval from Bpod data into a StatesTable.

    Reads ``SessionData.RawEvents.Trial[i].States``. Each state entry holds
    [start, stop] times relative to the trial start. An (N, 2) array is
    treated as N separate visits of that state within the trial, and one row
    is added per visit. Visits with a NaN/None start or stop (state not
    visited) are skipped. State names absent from ``state_types`` are skipped
    with a warning.

    Absolute times are ``offset + TrialStartTimestamp[i] + relative_time``.
    Here ``offset`` is ``trial_offsets[trial_number]`` (trial numbers are
    1-based), or 0.0 when ``trial_offsets`` is None or empty.

    Args:
        bpod_data: Parsed Bpod data dictionary
        state_types: StateTypesTable with state name → index mapping
        trial_offsets: Optional dict mapping trial_number → absolute time offset

    Returns:
        Tuple of (StatesTable with state occurrences, Dict mapping
        trial_number → list of row indices into that StatesTable)

    Raises:
        BpodParseError: If RawEvents or TrialStartTimestamp is missing, or if
            trial_offsets is non-empty but has no entry for a trial.

    Example:
        >>> states, state_indices = extract_states(bpod_data, state_types, trial_offsets)
        >>> print(f"{len(states)} state occurrences")
        >>> print(f"Trial 1 has {len(state_indices[1])} states")
    """
    session_data = convert_matlab_struct(bpod_data.get("SessionData", {}))
    # Report malformed input with the same error type as extract_state_types()
    # rather than a bare KeyError.
    for required in ("RawEvents", "TrialStartTimestamp"):
        if required not in session_data:
            raise BpodParseError(f"Missing {required} in Bpod data")

    raw_events = convert_matlab_struct(session_data["RawEvents"])
    trial_data_list = raw_events.get("Trial", [])
    start_timestamps = session_data["TrialStartTimestamp"]

    # State name → row index in the StateTypesTable
    state_name_to_idx = {name: idx for idx, name in enumerate(state_types["state_name"].data)}

    states = StatesTable(description="State sequences from Bpod trials", state_types_table=state_types)

    n_states = 0
    # trial_number → row indices (into `states`) belonging to that trial
    trial_state_indices: Dict[int, List[int]] = {}

    for trial_idx, trial_data in enumerate(trial_data_list):
        trial_num = trial_idx + 1
        trial_start_ts = float(to_scalar(start_timestamps, trial_idx))
        trial_state_indices[trial_num] = []

        # Trials may be MATLAB struct objects (attribute access) or dicts.
        if hasattr(trial_data, "States"):
            trial_states = trial_data.States
        elif isinstance(trial_data, dict):
            trial_states = trial_data.get("States", {})
        else:
            continue

        # A missing entry previously produced `None + float` (TypeError).
        if not trial_offsets:
            offset = 0.0
        elif trial_num in trial_offsets:
            offset = trial_offsets[trial_num]
        else:
            raise BpodParseError(f"trial_offsets has no entry for trial {trial_num}")

        trial_states = convert_matlab_struct(trial_states)

        for state_name, state_times in trial_states.items():
            if state_name not in state_name_to_idx:
                logger.warning(f"Unknown state '{state_name}' not in StateTypesTable")
                continue
            if not isinstance(state_times, (np.ndarray, list, tuple)):
                continue

            times = np.asarray(state_times, dtype=float)
            if times.size < 2:
                continue
            # One [start, stop] row per visit: an (N, 2) array means the state
            # was entered N times in this trial. Other shapes fall back to the
            # first two values, i.e. a single visit.
            if times.ndim >= 2 and times.shape[-1] == 2:
                visits = times.reshape(-1, 2)
            else:
                visits = times.reshape(-1)[:2].reshape(1, 2)

            state_type_idx = state_name_to_idx[state_name]
            for start_val, stop_val in visits:
                start_rel = float(start_val)
                stop_rel = float(stop_val)
                # NaN start/stop marks a state that was not visited
                if is_nan_or_none(start_rel) or is_nan_or_none(stop_rel):
                    continue

                states.add_state(
                    state_type=state_type_idx,
                    start_time=offset + trial_start_ts + start_rel,
                    stop_time=offset + trial_start_ts + stop_rel,
                )
                trial_state_indices[trial_num].append(n_states)
                n_states += 1

    logger.info(f"Extracted {n_states} state occurrences from {len(trial_data_list)} trials")
    return states, trial_state_indices
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
def extract_events(
    bpod_data: Dict[str, Any],
    event_types: EventTypesTable,
    trial_offsets: Optional[Dict[int, float]] = None,
) -> Tuple[EventsTable, Dict[int, List[int]]]:
    """Extract hardware events from Bpod data into an EventsTable.

    Reads ``SessionData.RawEvents.Trial[i].Events``, where each event name
    maps to one or more timestamps relative to the trial start. A row is
    added per non-NaN timestamp, with ``value`` set to the original event
    name. Event names absent from ``event_types`` are skipped with a warning.

    Absolute times are ``offset + TrialStartTimestamp[i] + relative_time``.
    Here ``offset`` is ``trial_offsets[trial_number]`` (trial numbers are
    1-based), or 0.0 when ``trial_offsets`` is None or empty.

    Args:
        bpod_data: Parsed Bpod data dictionary
        event_types: EventTypesTable with event name → index mapping
        trial_offsets: Optional dict mapping trial_number → absolute time offset

    Returns:
        Tuple of (EventsTable with event occurrences, Dict mapping
        trial_number → list of row indices into that EventsTable)

    Raises:
        BpodParseError: If RawEvents or TrialStartTimestamp is missing, or if
            trial_offsets is non-empty but has no entry for a trial.

    Example:
        >>> events, event_indices = extract_events(bpod_data, event_types, trial_offsets)
        >>> print(f"{len(events)} event occurrences")
        >>> print(f"Trial 1 has {len(event_indices[1])} events")
    """
    session_data = convert_matlab_struct(bpod_data.get("SessionData", {}))
    # Report malformed input with the same error type as extract_event_types()
    # rather than a bare KeyError.
    for required in ("RawEvents", "TrialStartTimestamp"):
        if required not in session_data:
            raise BpodParseError(f"Missing {required} in Bpod data")

    raw_events = convert_matlab_struct(session_data["RawEvents"])
    trial_data_list = raw_events.get("Trial", [])
    start_timestamps = session_data["TrialStartTimestamp"]

    # Event name → row index in the EventTypesTable
    event_name_to_idx = {name: idx for idx, name in enumerate(event_types["event_name"].data)}

    events = EventsTable(description="Hardware events from Bpod", event_types_table=event_types)

    n_events = 0
    # trial_number → row indices (into `events`) belonging to that trial
    trial_event_indices: Dict[int, List[int]] = {}

    for trial_idx, trial_data in enumerate(trial_data_list):
        trial_num = trial_idx + 1
        trial_start_ts = float(to_scalar(start_timestamps, trial_idx))
        trial_event_indices[trial_num] = []

        # Trials may be MATLAB struct objects (attribute access) or dicts.
        if hasattr(trial_data, "Events"):
            trial_events = trial_data.Events
        elif isinstance(trial_data, dict):
            trial_events = trial_data.get("Events", {})
        else:
            continue

        # A missing entry previously produced `None + float` (TypeError).
        if not trial_offsets:
            offset = 0.0
        elif trial_num in trial_offsets:
            offset = trial_offsets[trial_num]
        else:
            raise BpodParseError(f"trial_offsets has no entry for trial {trial_num}")

        trial_events = convert_matlab_struct(trial_events)

        for event_name, timestamps in trial_events.items():
            if event_name not in event_name_to_idx:
                logger.warning(f"Unknown event '{event_name}' not in EventTypesTable")
                continue

            # Normalize arrays and scalars to a flat list of timestamps
            if isinstance(timestamps, np.ndarray):
                timestamps = timestamps.flatten().tolist()
            elif not isinstance(timestamps, (list, tuple)):
                timestamps = [timestamps]

            event_type_idx = event_name_to_idx[event_name]

            for timestamp_rel in timestamps:
                if is_nan_or_none(timestamp_rel):
                    continue

                events.add_event(
                    event_type=event_type_idx,
                    timestamp=offset + trial_start_ts + float(timestamp_rel),
                    value=event_name,  # Store original event name
                )
                trial_event_indices[trial_num].append(n_events)
                n_events += 1

    logger.info(f"Extracted {n_events} event occurrences from {len(trial_data_list)} trials")
    return events, trial_event_indices
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
def extract_actions(
    bpod_data: Dict[str, Any],
    action_types: ActionTypesTable,
    trial_offsets: Optional[Dict[int, float]] = None,
) -> Tuple[ActionsTable, Dict[int, List[int]]]:
    """Extract actions (reward deliveries, stimuli) from Bpod action states.

    Only states whose name is a key of ``ACTION_STATES`` and whose mapped
    action name appears in ``action_types`` are considered. Each state entry
    holds [start, stop] times relative to the trial start. An (N, 2) array is
    treated as N separate visits, and one action row is added per visit, with
    ``duration = stop - start`` and ``value`` set to the original state name.
    Visits with a NaN/None start or stop are skipped.

    Absolute timestamps are ``offset + TrialStartTimestamp[i] + start``. Here
    ``offset`` is ``trial_offsets[trial_number]`` (trial numbers are 1-based),
    or 0.0 when ``trial_offsets`` is None or empty.

    Args:
        bpod_data: Parsed Bpod data dictionary
        action_types: ActionTypesTable with action name → index mapping
        trial_offsets: Optional dict mapping trial_number → absolute time offset

    Returns:
        Tuple of (ActionsTable with action occurrences, Dict mapping
        trial_number → list of row indices into that ActionsTable)

    Raises:
        BpodParseError: If RawEvents or TrialStartTimestamp is missing, or if
            trial_offsets is non-empty but has no entry for a trial.

    Example:
        >>> actions, action_indices = extract_actions(bpod_data, action_types, trial_offsets)
        >>> print(f"{len(actions)} action occurrences")
        >>> print(f"Trial 1 has {len(action_indices[1])} actions")
    """
    session_data = convert_matlab_struct(bpod_data.get("SessionData", {}))
    # Report malformed input with the same error type as extract_action_types()
    # rather than a bare KeyError.
    for required in ("RawEvents", "TrialStartTimestamp"):
        if required not in session_data:
            raise BpodParseError(f"Missing {required} in Bpod data")

    raw_events = convert_matlab_struct(session_data["RawEvents"])
    trial_data_list = raw_events.get("Trial", [])
    start_timestamps = session_data["TrialStartTimestamp"]

    # Action name → row index in the ActionTypesTable
    action_name_to_idx = {name: idx for idx, name in enumerate(action_types["action_name"].data)}

    # state_name → action_name, restricted to actions present in action_types
    state_to_action = {state: action for state, action in ACTION_STATES.items() if action in action_name_to_idx}

    actions = ActionsTable(description="Actions from Bpod protocol", action_types_table=action_types)

    n_actions = 0
    # trial_number → row indices (into `actions`) belonging to that trial
    trial_action_indices: Dict[int, List[int]] = {}

    for trial_idx, trial_data in enumerate(trial_data_list):
        trial_num = trial_idx + 1
        trial_start_ts = float(to_scalar(start_timestamps, trial_idx))
        trial_action_indices[trial_num] = []

        # Trials may be MATLAB struct objects (attribute access) or dicts.
        if hasattr(trial_data, "States"):
            trial_states = trial_data.States
        elif isinstance(trial_data, dict):
            trial_states = trial_data.get("States", {})
        else:
            continue

        # A missing entry previously produced `None + float` (TypeError).
        if not trial_offsets:
            offset = 0.0
        elif trial_num in trial_offsets:
            offset = trial_offsets[trial_num]
        else:
            raise BpodParseError(f"trial_offsets has no entry for trial {trial_num}")

        trial_states = convert_matlab_struct(trial_states)

        for state_name, state_times in trial_states.items():
            if state_name not in state_to_action:
                continue
            if not isinstance(state_times, (np.ndarray, list, tuple)):
                continue

            times = np.asarray(state_times, dtype=float)
            if times.size < 2:
                continue
            # One [start, stop] row per visit: an (N, 2) array means the state
            # was entered N times in this trial (e.g. repeated rewards). Other
            # shapes fall back to the first two values, i.e. a single visit.
            if times.ndim >= 2 and times.shape[-1] == 2:
                visits = times.reshape(-1, 2)
            else:
                visits = times.reshape(-1)[:2].reshape(1, 2)

            action_type_idx = action_name_to_idx[state_to_action[state_name]]
            for start_val, stop_val in visits:
                start_rel = float(start_val)
                stop_rel = float(stop_val)
                # NaN start/stop marks a state that was not visited
                if is_nan_or_none(start_rel) or is_nan_or_none(stop_rel):
                    continue

                actions.add_action(
                    action_type=action_type_idx,
                    timestamp=offset + trial_start_ts + start_rel,
                    duration=stop_rel - start_rel,
                    value=state_name,  # Original state name for traceability
                )
                trial_action_indices[trial_num].append(n_actions)
                n_actions += 1

    logger.info(f"Extracted {n_actions} action occurrences from {len(trial_data_list)} trials")
    return actions, trial_action_indices
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
# =============================================================================
|
|
509
|
+
# Trials and Recording
|
|
510
|
+
# =============================================================================
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def build_trials_table(
    bpod_data: Dict[str, Any],
    recording: TaskRecording,
    state_indices: Dict[int, List[int]],
    event_indices: Dict[int, List[int]],
    action_indices: Dict[int, List[int]],
    trial_offsets: Optional[Dict[int, float]] = None,
) -> TrialsTable:
    """Build a TrialsTable whose rows reference the TaskRecording's tables.

    One row is added for each of ``SessionData.nTrials`` trials. Start and
    stop times are ``offset + TrialStartTimestamp[i]`` and
    ``offset + TrialEndTimestamp[i]``. Here ``offset`` is
    ``trial_offsets[trial_number]`` (trial numbers are 1-based), or 0.0 when
    ``trial_offsets`` is None or empty. A trial missing from one of the index
    dicts gets an empty list for that table.

    The states/events/actions tables are taken from ``recording`` itself, so
    the TrialsTable references the exact same instances as the TaskRecording.

    Args:
        bpod_data: Parsed Bpod data dictionary
        recording: TaskRecording containing states/events/actions tables
        state_indices: Dict mapping trial_number → list of state row indices
        event_indices: Dict mapping trial_number → list of event row indices
        action_indices: Dict mapping trial_number → list of action row indices
        trial_offsets: Optional dict mapping trial_number → absolute time offset

    Returns:
        TrialsTable with trial structure

    Raises:
        BpodParseError: If nTrials, TrialStartTimestamp or TrialEndTimestamp
            is missing, or if trial_offsets is non-empty but has no entry for
            a trial.

    Example:
        >>> # Build TaskRecording first
        >>> recording = build_task_recording(states, events, actions)
        >>> # Build TrialsTable using the same instances
        >>> trials = build_trials_table(bpod_data, recording,
        ...                             state_indices, event_indices, action_indices,
        ...                             trial_offsets)
        >>> print(f"{len(trials)} trials")
    """
    # Use the recording's own tables to guarantee instance consistency
    states = recording.states
    events = recording.events
    actions = recording.actions

    session_data = convert_matlab_struct(bpod_data.get("SessionData", {}))
    for required in ("nTrials", "TrialStartTimestamp", "TrialEndTimestamp"):
        if required not in session_data:
            raise BpodParseError(f"Missing {required} in Bpod data")

    n_trials = int(session_data["nTrials"])
    start_timestamps = session_data["TrialStartTimestamp"]
    end_timestamps = session_data["TrialEndTimestamp"]

    trials = TrialsTable(
        description="Trials from Bpod session",
        states_table=states,
        events_table=events,
        actions_table=actions,
    )

    for trial_idx in range(n_trials):
        trial_num = trial_idx + 1
        trial_start_rel = float(to_scalar(start_timestamps, trial_idx))
        trial_stop_rel = float(to_scalar(end_timestamps, trial_idx))

        # A missing entry previously produced `None + float` (TypeError).
        if not trial_offsets:
            offset = 0.0
        elif trial_num in trial_offsets:
            offset = trial_offsets[trial_num]
        else:
            raise BpodParseError(f"trial_offsets has no entry for trial {trial_num}")

        trials.add_trial(
            start_time=offset + trial_start_rel,
            stop_time=offset + trial_stop_rel,
            states=state_indices.get(trial_num, []),
            events=event_indices.get(trial_num, []),
            actions=action_indices.get(trial_num, []),
        )

    logger.info(f"Built TrialsTable with {n_trials} trials")
    return trials
|
|
596
|
+
|
|
597
|
+
|
|
598
|
+
def extract_trials_table(
    bpod_data: Dict[str, Any],
    recording: TaskRecording,
    trial_offsets: Optional[Dict[int, float]] = None,
) -> TrialsTable:
    """Extract a complete TrialsTable from Bpod data using an existing TaskRecording.

    The per-trial row indices are not stored in the TaskRecording. They are
    therefore re-derived by running extract_states(), extract_events() and
    extract_actions() again against the recording's type tables, and the
    resulting tables are discarded. The indices only match the recording's
    rows if ``recording`` was built from the same ``bpod_data`` and
    ``trial_offsets``.

    Recommended usage:
        1. Build TaskRecording with extract_task_recording() or build_task_recording()
        2. Pass that TaskRecording to this function to build the TrialsTable

    Args:
        bpod_data: Parsed Bpod data dictionary
        recording: TaskRecording with states/events/actions tables
        trial_offsets: Optional dict mapping trial_number → absolute time offset

    Returns:
        TrialsTable with complete trial structure

    Raises:
        Any error raised by extract_states(), extract_events(),
        extract_actions() or build_trials_table() propagates unchanged.

    Example:
        >>> from pathlib import Path
        >>> from w2t_bkin.bpod.core import parse_bpod
        >>> from w2t_bkin.behavior import extract_task_recording, extract_trials_table
        >>>
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> recording = extract_task_recording(bpod_data, trial_offsets)
        >>> trials = extract_trials_table(bpod_data, recording, trial_offsets)
        >>> print(f"{len(trials)} trials extracted")

    Note:
        The TrialsTable references the exact table instances held by
        ``recording``, which avoids NWB serialization errors caused by
        instance mismatches.
    """
    # Type tables are reached through the recording's data tables
    states = recording.states
    events = recording.events
    actions = recording.actions

    state_types = states.state_type.table
    event_types = events.event_type.table
    action_types = actions.action_type.table

    # Re-extract only to obtain trial_number → row-index mappings
    _, state_indices = extract_states(bpod_data, state_types, trial_offsets)
    _, event_indices = extract_events(bpod_data, event_types, trial_offsets)
    _, action_indices = extract_actions(bpod_data, action_types, trial_offsets)

    trials = build_trials_table(
        bpod_data=bpod_data,
        recording=recording,
        state_indices=state_indices,
        event_indices=event_indices,
        action_indices=action_indices,
        trial_offsets=trial_offsets,
    )

    return trials
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
def build_task_recording(
    states: StatesTable,
    events: EventsTable,
    actions: ActionsTable,
) -> TaskRecording:
    """Wrap the states, events and actions tables in a TaskRecording.

    The three tables are passed straight to the ndx-structured-behavior
    TaskRecording constructor, with no copying or validation.

    Args:
        states: StatesTable with state occurrences
        events: EventsTable with event occurrences
        actions: ActionsTable with action occurrences

    Returns:
        TaskRecording container holding the three tables

    Example:
        >>> task_recording = build_task_recording(states, events, actions)
        >>> nwbfile.add_acquisition(task_recording)
    """
    container = TaskRecording(states=states, events=events, actions=actions)
    logger.info("Built TaskRecording container")
    return container
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
def extract_task_recording(
    bpod_data: Dict[str, Any],
    trial_offsets: Optional[Dict[int, float]] = None,
) -> TaskRecording:
    """Build a complete TaskRecording straight from parsed Bpod data.

    Convenience wrapper that runs the whole pipeline in order:

    1. build the state, event, and action type tables;
    2. build the matching occurrence tables (row indices are discarded);
    3. wrap the occurrence tables in a TaskRecording.

    The type tables are created internally and are not returned. If you also
    need a Task that shares those type tables, use extract_behavioral_data.

    Args:
        bpod_data: Parsed Bpod data dictionary.
        trial_offsets: Optional mapping of trial_number -> absolute time offset,
            forwarded unchanged to the per-table extractors.

    Returns:
        TaskRecording containing the state/event/action tables.

    Example:
        >>> from w2t_bkin.bpod.core import parse_bpod
        >>> from w2t_bkin.behavior import extract_task_recording
        >>>
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> task_recording = extract_task_recording(bpod_data)
        >>> nwbfile.add_acquisition(task_recording)

    Note:
        Call the individual extract_* functions instead when the intermediate
        type tables or data tables are needed.
    """
    # Type tables come first: the occurrence tables reference them.
    state_types, event_types, action_types = (
        extract_state_types(bpod_data),
        extract_event_types(bpod_data),
        extract_action_types(bpod_data),
    )

    # Each extractor returns (table, indices); only the table is kept here.
    states = extract_states(bpod_data, state_types, trial_offsets)[0]
    events = extract_events(bpod_data, event_types, trial_offsets)[0]
    actions = extract_actions(bpod_data, action_types, trial_offsets)[0]

    return build_task_recording(states, events, actions)
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
# =============================================================================
|
|
745
|
+
# Task Metadata (Top-level Container)
|
|
746
|
+
# =============================================================================
|
|
747
|
+
|
|
748
|
+
|
|
749
|
+
def _flatten_dict(d: Dict[str, Any], parent_key: str = "", sep: str = ".") -> List[tuple]:
|
|
750
|
+
"""Recursively flatten nested dictionary into list of (key, value) tuples.
|
|
751
|
+
|
|
752
|
+
Args:
|
|
753
|
+
d: Dictionary to flatten
|
|
754
|
+
parent_key: Parent key prefix for nested keys
|
|
755
|
+
sep: Separator between parent and child keys
|
|
756
|
+
|
|
757
|
+
Returns:
|
|
758
|
+
List of (flattened_key, value) tuples
|
|
759
|
+
|
|
760
|
+
Example:
|
|
761
|
+
>>> _flatten_dict({'a': 1, 'b': {'c': 2, 'd': 3}})
|
|
762
|
+
[('a', 1), ('b.c', 2), ('b.d', 3)]
|
|
763
|
+
"""
|
|
764
|
+
items = []
|
|
765
|
+
for k, v in d.items():
|
|
766
|
+
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
|
767
|
+
if isinstance(v, dict):
|
|
768
|
+
items.extend(_flatten_dict(v, new_key, sep=sep))
|
|
769
|
+
else:
|
|
770
|
+
items.append((new_key, v))
|
|
771
|
+
return items
|
|
772
|
+
|
|
773
|
+
|
|
774
|
+
def _task_arg_values_equal(a: Any, b: Any) -> bool:
    """Compare two (possibly nested) settings values safely.

    Plain ``==`` on dicts/lists holding multi-element numpy arrays raises
    ``ValueError`` ("truth value of an array ... is ambiguous"). So arrays are
    compared with ``np.array_equal`` and containers are compared recursively.
    """
    if isinstance(a, dict) or isinstance(b, dict):
        if not (isinstance(a, dict) and isinstance(b, dict)) or a.keys() != b.keys():
            return False
        return all(_task_arg_values_equal(a[k], b[k]) for k in a)
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        return bool(np.array_equal(a, b))
    if isinstance(a, (list, tuple)) and type(a) is type(b):
        return len(a) == len(b) and all(_task_arg_values_equal(x, y) for x, y in zip(a, b))
    return bool(a == b)


def _format_task_argument(arg_value: Any) -> Tuple[str, str]:
    """Render a parameter value as ``(expression, type_name)`` for TaskArgumentsTable."""
    # 0-d arrays have no length; treat them as the scalar they hold.
    if isinstance(arg_value, np.ndarray) and arg_value.ndim == 0:
        arg_value = arg_value.item()

    if isinstance(arg_value, np.ndarray):
        # tolist() yields native Python values (nested for N-d arrays), so the
        # text reads "[1, 2]" instead of NumPy 2's "[np.int64(1), np.int64(2)]".
        return str(arg_value.tolist()), "array"
    if isinstance(arg_value, list):
        items = [v.item() if isinstance(v, np.generic) else v for v in arg_value]
        return str(items), "array"
    # bool must be tested before int: bool is a subclass of int, and np.bool_
    # is neither an int nor an np.integer.
    if isinstance(arg_value, (bool, np.bool_)):
        return str(arg_value), "boolean"
    if isinstance(arg_value, (int, np.integer)):
        return str(arg_value), "integer"
    if isinstance(arg_value, (float, np.floating)):
        return str(arg_value), "float"
    return str(arg_value), "string"


def extract_task_arguments(bpod_data: Dict[str, Any]) -> Optional[TaskArgumentsTable]:
    """Extract task arguments/parameters from Bpod data.

    Parameters are collected from, in order:

    1. ``SessionData.Settings`` (protocol parameters), flattened with
       dotted keys for nested structs.
    2. ``SessionData.TrialSettings``, only if step 1 found nothing. It is used
       only when every trial's settings equal the first trial's; otherwise it
       is skipped.
    3. The top-level metadata fields ``nTrials`` and ``TrialTypes``, unless
       already present. Array values are kept only when all elements are
       identical, in which case the single value is stored.

    Args:
        bpod_data: Parsed Bpod data dictionary from parse_bpod().

    Returns:
        TaskArgumentsTable with one row per parameter (sorted by name), or
        None if no parameters were found.

    Example:
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> task_args = extract_task_arguments(bpod_data)
        >>> if task_args:
        ...     print(f"{len(task_args)} parameters")
    """
    session_data = convert_matlab_struct(bpod_data.get("SessionData", {}))

    params: Dict[str, Any] = {}

    # 1. Settings: the most common location for protocol parameters.
    if "Settings" in session_data:
        settings = convert_matlab_struct(session_data["Settings"])
        if isinstance(settings, dict) and len(settings) > 0:
            params.update(dict(_flatten_dict(settings)))
            logger.info(f"Found {len(params)} parameters in Settings")

    # 2. TrialSettings: usable as task arguments only if identical across trials.
    if "TrialSettings" in session_data and len(params) == 0:
        trial_settings = session_data["TrialSettings"]
        if hasattr(trial_settings, "__len__") and len(trial_settings) > 0:
            first_trial = convert_matlab_struct(trial_settings[0])
            if isinstance(first_trial, dict):
                # Array-aware comparison: dict == on numpy-valued settings raises.
                uniform = all(
                    _task_arg_values_equal(convert_matlab_struct(trial), first_trial)
                    for trial in trial_settings[1:]
                )

                if uniform:
                    params.update(dict(_flatten_dict(first_trial)))
                    logger.info(f"Found {len(params)} uniform parameters in TrialSettings")
                else:
                    logger.debug("TrialSettings vary across trials, not extracting as task arguments")

    # 3. Useful top-level metadata fields.
    for field_name in ("nTrials", "TrialTypes"):
        if field_name not in session_data or field_name in params:
            continue
        value = session_data[field_name]
        if isinstance(value, np.ndarray) and value.ndim == 0:
            value = value.item()
        elif hasattr(value, "__len__") and not isinstance(value, str):
            try:
                distinct = set(value)
            except TypeError:
                # Unhashable elements (e.g. rows of a 2-D array) cannot be summarized.
                continue
            if len(distinct) != 1:
                continue  # Skip empty or non-uniform arrays
            value = value[0]
        params[field_name] = value

    if len(params) == 0:
        logger.info("No task arguments found in Bpod data")
        return None

    task_args = TaskArgumentsTable(description="Task parameters from Bpod")

    # One row per parameter; values are stored as text plus a type label.
    for arg_name, arg_value in sorted(params.items()):
        value_str, value_type = _format_task_argument(arg_value)
        task_args.add_row(
            argument_name=arg_name,
            argument_description="Parameter from Bpod data",
            expression=value_str,
            expression_type=value_type,
            output_type=value_type,
        )

    logger.info(f"Extracted {len(task_args)} task arguments")
    return task_args
|
|
873
|
+
|
|
874
|
+
|
|
875
|
+
def build_task(
    state_types: StateTypesTable,
    event_types: EventTypesTable,
    action_types: ActionTypesTable,
    task_arguments: Optional[TaskArgumentsTable] = None,
) -> Task:
    """Assemble the top-level Task container from type tables and metadata.

    The Task holds the state, event, and action type tables. When given, it
    also holds the task arguments table. This Task is meant to be placed at
    /general/task in the NWBFile.

    Args:
        state_types: StateTypesTable with state definitions.
        event_types: EventTypesTable with event definitions.
        action_types: ActionTypesTable with action definitions.
        task_arguments: Optional task parameters/arguments. When None, the
            Task is built without them.

    Returns:
        Task container for /general/task in NWBFile.

    Example:
        >>> task_args = extract_task_arguments(bpod_data)
        >>> task = build_task(state_types, event_types, action_types,
        ...                   task_arguments=task_args)
        >>> nwbfile.add_lab_meta_data(task)
    """
    task = Task(
        action_types=action_types,
        event_types=event_types,
        state_types=state_types,
    )

    if task_arguments is None:
        logger.info("Built Task without arguments")
        return task

    # Arguments are optional, so they are attached after construction.
    task.task_arguments = task_arguments
    logger.info(f"Built Task with {len(task_arguments)} arguments")
    return task
|
|
917
|
+
|
|
918
|
+
|
|
919
|
+
def extract_task(bpod_data: Dict[str, Any]) -> Task:
    """Build a complete Task container straight from parsed Bpod data.

    Convenience wrapper that, in order:

    1. builds the state, event, and action type tables;
    2. extracts the (optional) task arguments;
    3. assembles them into a Task via build_task.

    Args:
        bpod_data: Parsed Bpod data dictionary.

    Returns:
        Task container with type tables and optional arguments, intended for
        /general/task in NWBFile.

    Example:
        >>> from w2t_bkin.bpod.core import parse_bpod
        >>> from w2t_bkin.behavior import extract_task
        >>>
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> task = extract_task(bpod_data)
        >>> nwbfile.add_lab_meta_data(task)

    Note:
        Call the individual extract_* functions instead when the intermediate
        type tables or task arguments are needed.
    """
    # Arguments are evaluated left to right, preserving the extraction order.
    return build_task(
        extract_state_types(bpod_data),
        extract_event_types(bpod_data),
        extract_action_types(bpod_data),
        extract_task_arguments(bpod_data),
    )
|
|
960
|
+
|
|
961
|
+
|
|
962
|
+
def extract_behavioral_data(
    bpod_data: Dict[str, Any],
    trial_offsets: Optional[Dict[int, float]] = None,
) -> Tuple[Task, TaskRecording, TrialsTable]:
    """Extract Task, TaskRecording, and TrialsTable together in one call.

    Highest-level entry point for Bpod behavioral data. The type tables are
    built once and reused, so all three returned objects refer to the same
    table instances.

    Args:
        bpod_data: Parsed Bpod data dictionary.
        trial_offsets: Optional mapping of trial_number -> absolute time offset,
            forwarded to the occurrence extractors and the trials table builder.

    Returns:
        Tuple of (Task, TaskRecording, TrialsTable).

    Example:
        >>> from w2t_bkin.bpod.core import parse_bpod
        >>> from w2t_bkin.behavior import extract_behavioral_data
        >>>
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> task, recording, trials = extract_behavioral_data(bpod_data, trial_offsets)
        >>>
        >>> # Add to NWB file
        >>> nwbfile.add_lab_meta_data(task)
        >>> nwbfile.add_acquisition(recording)
        >>> nwbfile.trials = trials

    Note:
        Consistency guarantees:
        - Task contains the type tables
        - TaskRecording references those same type tables
        - TrialsTable references the data tables from TaskRecording
    """
    # A single set of type tables, shared by the Task and the TaskRecording.
    state_types = extract_state_types(bpod_data)
    event_types = extract_event_types(bpod_data)
    action_types = extract_action_types(bpod_data)

    task = build_task(
        state_types,
        event_types,
        action_types,
        extract_task_arguments(bpod_data),
    )

    # Occurrence tables plus the row-index maps that build_trials_table needs.
    index_maps: Dict[str, Any] = {}
    states, index_maps["state_indices"] = extract_states(bpod_data, state_types, trial_offsets)
    events, index_maps["event_indices"] = extract_events(bpod_data, event_types, trial_offsets)
    actions, index_maps["action_indices"] = extract_actions(bpod_data, action_types, trial_offsets)

    recording = build_task_recording(states, events, actions)

    # The trials table is built against this exact recording instance.
    trials = build_trials_table(
        bpod_data=bpod_data,
        recording=recording,
        trial_offsets=trial_offsets,
        **index_maps,
    )

    logger.info("Extracted complete behavioral data: Task, TaskRecording, and TrialsTable")
    return task, recording, trials
|