w2t-bkin 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1027 @@
1
+ """Core transformation functions: Bpod data → ndx-structured-behavior.
2
+
3
+ This module implements the transformation layer that converts parsed Bpod .mat
4
+ files into ndx-structured-behavior NWB classes. All functions produce NWB-native
5
+ objects directly, following the NWB-first architecture from Phase 1.
6
+
7
+ Architecture:
8
+ - Low-level: Parse Bpod .mat files (events.bpod module)
9
+ - Mid-level: Transform to ndx-structured-behavior (this module)
10
+ - High-level: Integrate with pipeline and NWB assembly (pipeline.py, nwb.py)
11
+
12
+ Data Flow:
13
+ Bpod .mat → parse_bpod() → extract_*_types() → extract_*() → build_trials_table() → TaskRecording → NWBFile
14
+ """
15
+
16
+ import logging
17
+ from typing import Any, Dict, List, Optional, Set, Tuple
18
+
19
+ from ndx_structured_behavior import ActionsTable, ActionTypesTable, EventsTable, EventTypesTable, StatesTable, StateTypesTable, Task, TaskArgumentsTable, TaskRecording, TrialsTable
20
+ import numpy as np
21
+
22
+ from ..exceptions import BpodParseError
23
+ from ..utils import convert_matlab_struct, is_nan_or_none, to_scalar
24
+
25
logger = logging.getLogger(__name__)

# Bpod state names that represent actions (rewards, stimuli, etc.), mapped to
# the action name recorded in the ActionTypesTable. Several states may share
# one action name (both audio states map to "audio_stimulus").
ACTION_STATES = dict(
    LeftReward="left_valve_open",
    RightReward="right_valve_open",
    W2T_Audio="audio_stimulus",
    A2L_Audio="audio_stimulus",
    Airpuff="airpuff_stimulus",
    Microstim="microstimulation",
)
37
+
38
+
39
+ # =============================================================================
40
+ # Type Tables (Metadata)
41
+ # =============================================================================
42
+
43
+
44
def extract_state_types(bpod_data: Dict[str, Any]) -> StateTypesTable:
    """Build a StateTypesTable listing every state name seen in the session.

    Scans ``SessionData.RawEvents.Trial[*].States`` across all trials and
    collects the union of state names. It then adds one row per name in sorted
    order, so the row index of a name is deterministic. Trial entries that are
    neither dicts nor objects with a ``States`` attribute are ignored.

    Args:
        bpod_data: Parsed Bpod data dictionary from parse_bpod()

    Returns:
        StateTypesTable with one row per unique state name, sorted.

    Raises:
        BpodParseError: SessionData has no RawEvents field.

    Example:
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> state_types = extract_state_types(bpod_data)
        >>> print(state_types["state_name"].data)
        ['HIT', 'ITI', 'Miss', 'Response_window', ...]
    """
    session = convert_matlab_struct(bpod_data.get("SessionData", {}))
    if "RawEvents" not in session:
        raise BpodParseError("Missing RawEvents in Bpod data")

    trials = convert_matlab_struct(session["RawEvents"]).get("Trial", [])

    seen: Set[str] = set()
    for trial in trials:
        # Trials may arrive as MATLAB struct objects (attribute access) or dicts.
        if hasattr(trial, "States"):
            raw = trial.States
        elif isinstance(trial, dict):
            raw = trial.get("States", {})
        else:
            continue
        seen |= set(convert_matlab_struct(raw).keys())

    table = StateTypesTable(description="State types from Bpod protocol")
    for name in sorted(seen):
        table.add_row(state_name=name)

    logger.info(f"Extracted {len(seen)} unique state types")
    return table
97
+
98
+
99
def extract_event_types(bpod_data: Dict[str, Any]) -> EventTypesTable:
    """Build an EventTypesTable listing every hardware event name seen in the session.

    Scans ``SessionData.RawEvents.Trial[*].Events`` across all trials and
    collects the union of event names. It then adds one row per name in sorted
    order. Trial entries that are neither dicts nor objects with an ``Events``
    attribute are ignored.

    Args:
        bpod_data: Parsed Bpod data dictionary from parse_bpod()

    Returns:
        EventTypesTable with one row per unique event name, sorted.

    Raises:
        BpodParseError: SessionData has no RawEvents field.

    Example:
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> event_types = extract_event_types(bpod_data)
        >>> print(event_types["event_name"].data)
        ['BNC1High', 'Flex1Trig1', 'Port1In', 'Port1Out', ...]
    """
    session = convert_matlab_struct(bpod_data.get("SessionData", {}))
    if "RawEvents" not in session:
        raise BpodParseError("Missing RawEvents in Bpod data")

    trials = convert_matlab_struct(session["RawEvents"]).get("Trial", [])

    seen: Set[str] = set()
    for trial in trials:
        # Trials may arrive as MATLAB struct objects (attribute access) or dicts.
        if hasattr(trial, "Events"):
            raw = trial.Events
        elif isinstance(trial, dict):
            raw = trial.get("Events", {})
        else:
            continue
        seen |= set(convert_matlab_struct(raw).keys())

    table = EventTypesTable(description="Event types from Bpod hardware")
    for name in sorted(seen):
        table.add_row(event_name=name)

    logger.info(f"Extracted {len(seen)} unique event types")
    return table
152
+
153
+
154
def extract_action_types(bpod_data: Dict[str, Any]) -> ActionTypesTable:
    """Build an ActionTypesTable from the action states present in the session.

    A state counts as an action when its name is a key of ACTION_STATES. Every
    such state name found in ``SessionData.RawEvents.Trial[*].States`` adds its
    mapped action name. Unique action names are added in sorted order.

    Args:
        bpod_data: Parsed Bpod data dictionary from parse_bpod()

    Returns:
        ActionTypesTable with one row per observed action name, sorted.

    Raises:
        BpodParseError: SessionData has no RawEvents field.

    Example:
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> action_types = extract_action_types(bpod_data)
        >>> print(action_types["action_name"].data)
        ['airpuff_stimulus', 'audio_stimulus', 'left_valve_open', ...]
    """
    session = convert_matlab_struct(bpod_data.get("SessionData", {}))
    if "RawEvents" not in session:
        raise BpodParseError("Missing RawEvents in Bpod data")

    trials = convert_matlab_struct(session["RawEvents"]).get("Trial", [])

    observed: Set[str] = set()
    for trial in trials:
        # Trials may arrive as MATLAB struct objects (attribute access) or dicts.
        if hasattr(trial, "States"):
            raw = trial.States
        elif isinstance(trial, dict):
            raw = trial.get("States", {})
        else:
            continue
        state_map = convert_matlab_struct(raw)
        observed.update(ACTION_STATES[name] for name in state_map.keys() if name in ACTION_STATES)

    table = ActionTypesTable(description="Action types from Bpod protocol")
    for name in sorted(observed):
        table.add_row(action_name=name)

    logger.info(f"Extracted {len(observed)} unique action types")
    return table
208
+
209
+
210
+ # =============================================================================
211
+ # Data Tables (Temporal Events)
212
+ # =============================================================================
213
+
214
+
215
def extract_states(
    bpod_data: Dict[str, Any],
    state_types: StateTypesTable,
    trial_offsets: Optional[Dict[int, float]] = None,
) -> Tuple[StatesTable, Dict[int, List[int]]]:
    """Extract state sequences from Bpod data.

    Converts RawEvents.Trial[].States to an ndx-structured-behavior StatesTable
    with one row (start_time/stop_time) per state visit. Each state's times are
    read as rows of [entry, exit]: a single pair, or an N x 2 matrix when the
    state was entered N times within one trial. Every visit becomes its own row.
    Visits whose entry or exit time is NaN (state not reached) are skipped.

    Absolute times are ``offset + TrialStartTimestamp[trial] + relative_time``.

    Args:
        bpod_data: Parsed Bpod data dictionary
        state_types: StateTypesTable with state name → index mapping
        trial_offsets: Optional dict mapping trial_number (1-based) → absolute
            time offset. If given (and non-empty), it must contain every trial.

    Returns:
        Tuple of (StatesTable with state occurrences, Dict mapping
        trial_number → list of state row indices)

    Raises:
        BpodParseError: RawEvents or TrialStartTimestamp is missing, or
            trial_offsets has no entry for a trial.

    Example:
        >>> states, state_indices = extract_states(bpod_data, state_types, trial_offsets)
        >>> print(f"{len(states)} state occurrences")
        >>> print(f"Trial 1 has {len(state_indices[1])} states")
    """
    session_data = convert_matlab_struct(bpod_data.get("SessionData", {}))
    for required in ("RawEvents", "TrialStartTimestamp"):
        if required not in session_data:
            raise BpodParseError(f"Missing {required} in Bpod data")

    raw_events = convert_matlab_struct(session_data["RawEvents"])
    trial_data_list = raw_events.get("Trial", [])
    start_timestamps = session_data["TrialStartTimestamp"]

    # Build state name → row index mapping
    state_name_to_idx = {name: idx for idx, name in enumerate(state_types["state_name"].data)}

    states = StatesTable(description="State sequences from Bpod trials", state_types_table=state_types)

    n_states = 0
    # Track which StatesTable rows belong to which trial
    trial_state_indices: Dict[int, List[int]] = {}

    for trial_idx, trial_data in enumerate(trial_data_list):
        trial_num = trial_idx + 1
        trial_start_ts = float(to_scalar(start_timestamps, trial_idx))
        trial_state_indices[trial_num] = []

        # Explicit lookup: dict.get() would return None for a missing trial and
        # fail later with an opaque "NoneType + float" TypeError.
        if trial_offsets:
            if trial_num not in trial_offsets:
                raise BpodParseError(f"trial_offsets has no entry for trial {trial_num}")
            offset = float(trial_offsets[trial_num])
        else:
            offset = 0.0

        # Trials may arrive as MATLAB struct objects (attribute access) or dicts.
        if hasattr(trial_data, "States"):
            trial_states = trial_data.States
        elif isinstance(trial_data, dict):
            trial_states = trial_data.get("States", {})
        else:
            continue

        trial_states = convert_matlab_struct(trial_states)

        for state_name, state_times in trial_states.items():
            if state_name not in state_name_to_idx:
                logger.warning(f"Unknown state '{state_name}' not in StateTypesTable")
                continue

            if not isinstance(state_times, (np.ndarray, list, tuple)):
                continue
            times = np.asarray(state_times, dtype=float)
            if times.size < 2:
                continue

            # One [entry, exit] row per visit. Any other layout keeps the
            # previous behaviour of reading only the first two values.
            if times.ndim >= 2 and times.shape[-1] == 2:
                visits = times.reshape(-1, 2)
            else:
                visits = times.reshape(-1)[:2].reshape(1, 2)

            state_type_idx = state_name_to_idx[state_name]
            for start_rel, stop_rel in visits:
                # Skip NaN visits (state not reached in this trial)
                if np.isnan(start_rel) or np.isnan(stop_rel):
                    continue

                states.add_state(
                    state_type=state_type_idx,
                    start_time=offset + trial_start_ts + float(start_rel),
                    stop_time=offset + trial_start_ts + float(stop_rel),
                )
                trial_state_indices[trial_num].append(n_states)
                n_states += 1

    logger.info(f"Extracted {n_states} state occurrences from {len(trial_data_list)} trials")
    return states, trial_state_indices
310
+
311
+
312
def extract_events(
    bpod_data: Dict[str, Any],
    event_types: EventTypesTable,
    trial_offsets: Optional[Dict[int, float]] = None,
) -> Tuple[EventsTable, Dict[int, List[int]]]:
    """Extract hardware events from Bpod data.

    Converts RawEvents.Trial[].Events to an ndx-structured-behavior EventsTable
    with one row per event timestamp. NaN/None timestamps are skipped. The
    original event name is stored in the ``value`` column.

    Absolute times are ``offset + TrialStartTimestamp[trial] + relative_time``.

    Args:
        bpod_data: Parsed Bpod data dictionary
        event_types: EventTypesTable with event name → index mapping
        trial_offsets: Optional dict mapping trial_number (1-based) → absolute
            time offset. If given (and non-empty), it must contain every trial.

    Returns:
        Tuple of (EventsTable with event occurrences, Dict mapping
        trial_number → list of event row indices)

    Raises:
        BpodParseError: RawEvents or TrialStartTimestamp is missing, or
            trial_offsets has no entry for a trial.

    Example:
        >>> events, event_indices = extract_events(bpod_data, event_types, trial_offsets)
        >>> print(f"{len(events)} event occurrences")
        >>> print(f"Trial 1 has {len(event_indices[1])} events")
    """
    session_data = convert_matlab_struct(bpod_data.get("SessionData", {}))
    for required in ("RawEvents", "TrialStartTimestamp"):
        if required not in session_data:
            raise BpodParseError(f"Missing {required} in Bpod data")

    raw_events = convert_matlab_struct(session_data["RawEvents"])
    trial_data_list = raw_events.get("Trial", [])
    start_timestamps = session_data["TrialStartTimestamp"]

    # Build event name → row index mapping
    event_name_to_idx = {name: idx for idx, name in enumerate(event_types["event_name"].data)}

    events = EventsTable(description="Hardware events from Bpod", event_types_table=event_types)

    n_events = 0
    # Track which EventsTable rows belong to which trial
    trial_event_indices: Dict[int, List[int]] = {}

    for trial_idx, trial_data in enumerate(trial_data_list):
        trial_num = trial_idx + 1
        trial_start_ts = float(to_scalar(start_timestamps, trial_idx))
        trial_event_indices[trial_num] = []

        # Explicit lookup: dict.get() would return None for a missing trial and
        # fail later with an opaque "NoneType + float" TypeError.
        if trial_offsets:
            if trial_num not in trial_offsets:
                raise BpodParseError(f"trial_offsets has no entry for trial {trial_num}")
            offset = float(trial_offsets[trial_num])
        else:
            offset = 0.0

        # Trials may arrive as MATLAB struct objects (attribute access) or dicts.
        if hasattr(trial_data, "Events"):
            trial_events = trial_data.Events
        elif isinstance(trial_data, dict):
            trial_events = trial_data.get("Events", {})
        else:
            continue

        trial_events = convert_matlab_struct(trial_events)

        for event_name, timestamps in trial_events.items():
            if event_name not in event_name_to_idx:
                logger.warning(f"Unknown event '{event_name}' not in EventTypesTable")
                continue

            # Normalise a numpy array or a lone scalar to a flat list
            if isinstance(timestamps, np.ndarray):
                timestamps = timestamps.flatten().tolist()
            elif not isinstance(timestamps, (list, tuple)):
                timestamps = [timestamps]

            event_type_idx = event_name_to_idx[event_name]

            for timestamp_rel in timestamps:
                if is_nan_or_none(timestamp_rel):
                    continue

                events.add_event(
                    event_type=event_type_idx,
                    timestamp=offset + trial_start_ts + float(timestamp_rel),
                    value=event_name,  # Store original event name
                )
                trial_event_indices[trial_num].append(n_events)
                n_events += 1

    logger.info(f"Extracted {n_events} event occurrences from {len(trial_data_list)} trials")
    return events, trial_event_indices
404
+
405
+
406
def extract_actions(
    bpod_data: Dict[str, Any],
    action_types: ActionTypesTable,
    trial_offsets: Optional[Dict[int, float]] = None,
) -> Tuple[ActionsTable, Dict[int, List[int]]]:
    """Extract actions from Bpod state transitions.

    Identifies action states (rewards, stimuli — the keys of ACTION_STATES
    whose action name is present in ``action_types``). Each visit of such a
    state becomes one ActionsTable row with:

    - an absolute timestamp (state entry);
    - a duration (exit - entry);
    - the original state name as ``value``.

    State times are read as rows of [entry, exit]. A state entered N times
    within one trial (N x 2 matrix) yields N actions. Visits with NaN entry or
    exit times are skipped.

    Args:
        bpod_data: Parsed Bpod data dictionary
        action_types: ActionTypesTable with action name → index mapping
        trial_offsets: Optional dict mapping trial_number (1-based) → absolute
            time offset. If given (and non-empty), it must contain every trial.

    Returns:
        Tuple of (ActionsTable with action occurrences, Dict mapping
        trial_number → list of action row indices)

    Raises:
        BpodParseError: RawEvents or TrialStartTimestamp is missing, or
            trial_offsets has no entry for a trial.

    Example:
        >>> actions, action_indices = extract_actions(bpod_data, action_types, trial_offsets)
        >>> print(f"{len(actions)} action occurrences")
        >>> print(f"Trial 1 has {len(action_indices[1])} actions")
    """
    session_data = convert_matlab_struct(bpod_data.get("SessionData", {}))
    for required in ("RawEvents", "TrialStartTimestamp"):
        if required not in session_data:
            raise BpodParseError(f"Missing {required} in Bpod data")

    raw_events = convert_matlab_struct(session_data["RawEvents"])
    trial_data_list = raw_events.get("Trial", [])
    start_timestamps = session_data["TrialStartTimestamp"]

    # Build action name → row index mapping
    action_name_to_idx = {name: idx for idx, name in enumerate(action_types["action_name"].data)}

    # state_name → action_name, restricted to actions present in action_types
    state_to_action = {state: action for state, action in ACTION_STATES.items() if action in action_name_to_idx}

    actions = ActionsTable(description="Actions from Bpod protocol", action_types_table=action_types)

    n_actions = 0
    # Track which ActionsTable rows belong to which trial
    trial_action_indices: Dict[int, List[int]] = {}

    for trial_idx, trial_data in enumerate(trial_data_list):
        trial_num = trial_idx + 1
        trial_start_ts = float(to_scalar(start_timestamps, trial_idx))
        trial_action_indices[trial_num] = []

        # Explicit lookup: dict.get() would return None for a missing trial and
        # fail later with an opaque "NoneType + float" TypeError.
        if trial_offsets:
            if trial_num not in trial_offsets:
                raise BpodParseError(f"trial_offsets has no entry for trial {trial_num}")
            offset = float(trial_offsets[trial_num])
        else:
            offset = 0.0

        # Trials may arrive as MATLAB struct objects (attribute access) or dicts.
        if hasattr(trial_data, "States"):
            trial_states = trial_data.States
        elif isinstance(trial_data, dict):
            trial_states = trial_data.get("States", {})
        else:
            continue

        trial_states = convert_matlab_struct(trial_states)

        for state_name, state_times in trial_states.items():
            if state_name not in state_to_action:
                continue

            action_type_idx = action_name_to_idx[state_to_action[state_name]]

            if not isinstance(state_times, (np.ndarray, list, tuple)):
                continue
            times = np.asarray(state_times, dtype=float)
            if times.size < 2:
                continue

            # One [entry, exit] row per visit. Any other layout keeps the
            # previous behaviour of reading only the first two values.
            if times.ndim >= 2 and times.shape[-1] == 2:
                visits = times.reshape(-1, 2)
            else:
                visits = times.reshape(-1)[:2].reshape(1, 2)

            for start_rel, stop_rel in visits:
                # Skip NaN visits (state not reached in this trial)
                if np.isnan(start_rel) or np.isnan(stop_rel):
                    continue

                actions.add_action(
                    action_type=action_type_idx,
                    timestamp=offset + trial_start_ts + float(start_rel),
                    duration=float(stop_rel - start_rel),
                    value=state_name,  # Original state name for traceability
                )
                trial_action_indices[trial_num].append(n_actions)
                n_actions += 1

    logger.info(f"Extracted {n_actions} action occurrences from {len(trial_data_list)} trials")
    return actions, trial_action_indices
506
+
507
+
508
+ # =============================================================================
509
+ # Trials and Recording
510
+ # =============================================================================
511
+
512
+
513
def build_trials_table(
    bpod_data: Dict[str, Any],
    recording: TaskRecording,
    state_indices: Dict[int, List[int]],
    event_indices: Dict[int, List[int]],
    action_indices: Dict[int, List[int]],
    trial_offsets: Optional[Dict[int, float]] = None,
) -> TrialsTable:
    """Build TrialsTable with references to TaskRecording tables.

    Creates an ndx-structured-behavior TrialsTable with one row per trial
    (``nTrials`` rows). Each row holds start/stop times and lists of row
    indices into the states/events/actions tables of ``recording``. Taking the
    tables from the recording ensures the TrialsTable references the exact
    same instances, preventing instance-mismatch errors.

    Absolute times are ``offset + TrialStartTimestamp`` and
    ``offset + TrialEndTimestamp``.

    Args:
        bpod_data: Parsed Bpod data dictionary
        recording: TaskRecording containing states/events/actions tables
        state_indices: Dict mapping trial_number → list of state row indices
        event_indices: Dict mapping trial_number → list of event row indices
        action_indices: Dict mapping trial_number → list of action row indices
        trial_offsets: Optional dict mapping trial_number (1-based) → absolute
            time offset. If given (and non-empty), it must contain every trial.

    Returns:
        TrialsTable with trial structure

    Raises:
        BpodParseError: nTrials, TrialStartTimestamp or TrialEndTimestamp is
            missing, or trial_offsets has no entry for a trial.

    Example:
        >>> # Build TaskRecording first
        >>> recording = build_task_recording(states, events, actions)
        >>> # Build TrialsTable using the same instances
        >>> trials = build_trials_table(bpod_data, recording,
        ...                             state_indices, event_indices, action_indices,
        ...                             trial_offsets)
        >>> print(f"{len(trials)} trials")
    """
    # Take tables from the TaskRecording to ensure instance consistency
    states = recording.states
    events = recording.events
    actions = recording.actions

    session_data = convert_matlab_struct(bpod_data.get("SessionData", {}))
    for required in ("nTrials", "TrialStartTimestamp", "TrialEndTimestamp"):
        if required not in session_data:
            raise BpodParseError(f"Missing {required} in Bpod data")

    n_trials = int(session_data["nTrials"])
    start_timestamps = session_data["TrialStartTimestamp"]
    end_timestamps = session_data["TrialEndTimestamp"]

    trials = TrialsTable(
        description="Trials from Bpod session",
        states_table=states,
        events_table=events,
        actions_table=actions,
    )

    for trial_idx in range(n_trials):
        trial_num = trial_idx + 1
        trial_start_rel = float(to_scalar(start_timestamps, trial_idx))
        trial_stop_rel = float(to_scalar(end_timestamps, trial_idx))

        # Explicit lookup: dict.get() would return None for a missing trial and
        # fail with an opaque "NoneType + float" TypeError.
        if trial_offsets:
            if trial_num not in trial_offsets:
                raise BpodParseError(f"trial_offsets has no entry for trial {trial_num}")
            offset = float(trial_offsets[trial_num])
        else:
            offset = 0.0

        trials.add_trial(
            start_time=offset + trial_start_rel,
            stop_time=offset + trial_stop_rel,
            # Trials absent from the index maps get empty reference lists
            states=state_indices.get(trial_num, []),
            events=event_indices.get(trial_num, []),
            actions=action_indices.get(trial_num, []),
        )

    logger.info(f"Built TrialsTable with {n_trials} trials")
    return trials
596
+
597
+
598
def extract_trials_table(
    bpod_data: Dict[str, Any],
    recording: TaskRecording,
    trial_offsets: Optional[Dict[int, float]] = None,
) -> TrialsTable:
    """Build the TrialsTable for an existing TaskRecording.

    The type tables are recovered through each data table's type column. The
    per-trial row indices are then recomputed by re-running
    extract_states/extract_events/extract_actions, because TaskRecording does
    not store them; the tables those calls build are discarded. Finally
    build_trials_table wires the trials to the recording's own table instances.

    Recommended workflow:
        1. Build TaskRecording with extract_task_recording() or build_task_recording()
        2. Pass that TaskRecording here to build the TrialsTable

    Args:
        bpod_data: Parsed Bpod data dictionary
        recording: TaskRecording with states/events/actions tables
        trial_offsets: Optional dict mapping trial_number → absolute time offset

    Returns:
        TrialsTable with complete trial structure

    Example:
        >>> from w2t_bkin.bpod.code import parse_bpod
        >>> from w2t_bkin.behavior import extract_task_recording, extract_trials_table
        >>>
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> recording = extract_task_recording(bpod_data, trial_offsets)
        >>> trials = extract_trials_table(bpod_data, recording, trial_offsets)
        >>> print(f"{len(trials)} trials extracted")

    Note:
        Reusing the recording's tables keeps TrialsTable references identical
        to the TaskRecording instances, avoiding NWB serialization errors.
        NOTE(review): the recomputed indices only line up with the recording's
        rows if ``recording`` was built from the same ``bpod_data``.
    """
    state_types = recording.states.state_type.table
    event_types = recording.events.event_type.table
    action_types = recording.actions.action_type.table

    # Only the index maps ([1]) are needed; the rebuilt tables ([0]) are dropped.
    state_indices = extract_states(bpod_data, state_types, trial_offsets)[1]
    event_indices = extract_events(bpod_data, event_types, trial_offsets)[1]
    action_indices = extract_actions(bpod_data, action_types, trial_offsets)[1]

    return build_trials_table(
        bpod_data=bpod_data,
        recording=recording,
        state_indices=state_indices,
        event_indices=event_indices,
        action_indices=action_indices,
        trial_offsets=trial_offsets,
    )
661
+
662
+
663
def build_task_recording(
    states: StatesTable,
    events: EventsTable,
    actions: ActionsTable,
) -> TaskRecording:
    """Wrap the three behavioral data tables in a TaskRecording container.

    Args:
        states: StatesTable with state occurrences
        events: EventsTable with event occurrences
        actions: ActionsTable with action occurrences

    Returns:
        TaskRecording holding the given table instances (not copies).

    Example:
        >>> task_recording = build_task_recording(states, events, actions)
        >>> nwbfile.add_acquisition(task_recording)
    """
    recording = TaskRecording(states=states, events=events, actions=actions)
    logger.info("Built TaskRecording container")
    return recording
693
+
694
+
695
def extract_task_recording(
    bpod_data: Dict[str, Any],
    trial_offsets: Optional[Dict[int, float]] = None,
) -> TaskRecording:
    """Run the full extraction chain and return a TaskRecording (convenience wrapper).

    Steps, in order:
        1. Type tables: extract_state_types, extract_event_types, extract_action_types
        2. Data tables: extract_states, extract_events, extract_actions
           (their per-trial index maps are discarded)
        3. build_task_recording on the three data tables

    Args:
        bpod_data: Parsed Bpod data dictionary
        trial_offsets: Optional dict mapping trial_number → absolute time offset

    Returns:
        TaskRecording with complete state/event/action tables

    Example:
        >>> from w2t_bkin.bpod.code import parse_bpod
        >>> from w2t_bkin.behavior import extract_task_recording
        >>>
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> task_recording = extract_task_recording(bpod_data)
        >>> nwbfile.add_acquisition(task_recording)

    Note:
        Call the individual extract_* functions directly when the type tables,
        data tables or per-trial indices are needed.
    """
    state_types = extract_state_types(bpod_data)
    event_types = extract_event_types(bpod_data)
    action_types = extract_action_types(bpod_data)

    steps = (
        (extract_states, state_types),
        (extract_events, event_types),
        (extract_actions, action_types),
    )
    # Keep only the table ([0]) from each (table, indices) pair, in states/events/actions order.
    data_tables = [extractor(bpod_data, types, trial_offsets)[0] for extractor, types in steps]

    return build_task_recording(*data_tables)
742
+
743
+
744
+ # =============================================================================
745
+ # Task Metadata (Top-level Container)
746
+ # =============================================================================
747
+
748
+
749
+ def _flatten_dict(d: Dict[str, Any], parent_key: str = "", sep: str = ".") -> List[tuple]:
750
+ """Recursively flatten nested dictionary into list of (key, value) tuples.
751
+
752
+ Args:
753
+ d: Dictionary to flatten
754
+ parent_key: Parent key prefix for nested keys
755
+ sep: Separator between parent and child keys
756
+
757
+ Returns:
758
+ List of (flattened_key, value) tuples
759
+
760
+ Example:
761
+ >>> _flatten_dict({'a': 1, 'b': {'c': 2, 'd': 3}})
762
+ [('a', 1), ('b.c', 2), ('b.d', 3)]
763
+ """
764
+ items = []
765
+ for k, v in d.items():
766
+ new_key = f"{parent_key}{sep}{k}" if parent_key else k
767
+ if isinstance(v, dict):
768
+ items.extend(_flatten_dict(v, new_key, sep=sep))
769
+ else:
770
+ items.append((new_key, v))
771
+ return items
772
+
773
+
774
def _task_argument_values_equal(a: Any, b: Any) -> bool:
    """Compare two (possibly nested, possibly array-valued) settings values.

    Plain ``==`` cannot be used: comparing dicts that hold multi-element NumPy
    arrays raises "truth value of an array ... is ambiguous". Values that
    cannot be compared are treated as unequal.
    """
    if isinstance(a, dict) or isinstance(b, dict):
        if not (isinstance(a, dict) and isinstance(b, dict)) or a.keys() != b.keys():
            return False
        return all(_task_argument_values_equal(a[key], b[key]) for key in a)
    if isinstance(a, (np.ndarray, list, tuple)) or isinstance(b, (np.ndarray, list, tuple)):
        try:
            return bool(np.array_equal(np.asarray(a), np.asarray(b)))
        except (TypeError, ValueError):  # e.g. ragged sequences
            return False
    try:
        return bool(a == b)
    except (TypeError, ValueError):
        return False


def _describe_task_argument(value: Any) -> Tuple[str, str]:
    """Return (expression string, type label) for a task argument value."""
    # Unwrap 0-d arrays so they are classified as scalars.
    if isinstance(value, np.ndarray) and value.ndim == 0:
        value = value.item()
    # bool must be tested before int: bool is a subclass of int, and np.bool_
    # is neither int nor np.integer.
    if isinstance(value, (bool, np.bool_)):
        return str(bool(value)), "boolean"
    if isinstance(value, np.ndarray):
        # tolist() gives native Python scalars, so the string does not contain
        # NumPy 2 scalar reprs such as "np.float64(1.0)".
        return str(value.tolist()), "array"
    if isinstance(value, list):
        native = [item.tolist() if isinstance(item, (np.ndarray, np.generic)) else item for item in value]
        return str(native), "array"
    if isinstance(value, (int, np.integer)):
        return str(value), "integer"
    if isinstance(value, (float, np.floating)):
        return str(value), "float"
    return str(value), "string"


def extract_task_arguments(bpod_data: Dict[str, Any]) -> Optional[TaskArgumentsTable]:
    """Extract task arguments/parameters from Bpod data.

    Sources, in order:
        1. SessionData.Settings (protocol parameters), flattened to dotted keys.
        2. SessionData.TrialSettings, only if Settings yielded nothing and every
           trial's settings equal the first trial's (array values compared
           element-wise).
        3. SessionData.nTrials / TrialTypes, when not already present. Sequences
           are kept only if all elements are equal (the first element is
           stored); empty, non-uniform or unhashable sequences are skipped.

    Each value is stored as a string expression with a type label: "boolean",
    "array", "integer", "float" or "string".

    Args:
        bpod_data: Parsed Bpod data dictionary from parse_bpod()

    Returns:
        TaskArgumentsTable if arguments found, None otherwise

    Example:
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> task_args = extract_task_arguments(bpod_data)
        >>> if task_args:
        ...     print(f"{len(task_args)} parameters")
    """
    session_data = convert_matlab_struct(bpod_data.get("SessionData", {}))

    # 1. Settings (most common location)
    params: Dict[Any, Any] = {}
    if "Settings" in session_data:
        settings = convert_matlab_struct(session_data["Settings"])
        if isinstance(settings, dict) and len(settings) > 0:
            params.update(dict(_flatten_dict(settings)))
            logger.info(f"Found {len(params)} parameters in Settings")

    # 2. TrialSettings, only when uniform across trials
    if "TrialSettings" in session_data and len(params) == 0:
        trial_settings = session_data["TrialSettings"]
        if hasattr(trial_settings, "__len__") and len(trial_settings) > 0:
            first_trial = convert_matlab_struct(trial_settings[0])
            if isinstance(first_trial, dict):
                uniform = all(
                    _task_argument_values_equal(convert_matlab_struct(trial), first_trial)
                    for trial in trial_settings[1:]
                )
                if uniform:
                    params.update(dict(_flatten_dict(first_trial)))
                    logger.info(f"Found {len(params)} uniform parameters in TrialSettings")
                else:
                    logger.debug("TrialSettings vary across trials, not extracting as task arguments")

    # 3. Useful metadata fields
    for field in ("nTrials", "TrialTypes"):
        if field not in session_data or field in params:
            continue
        value = session_data[field]
        if isinstance(value, np.ndarray) and value.ndim == 0:
            value = value.item()
        elif hasattr(value, "__len__") and not isinstance(value, str):
            try:
                uniform = len(set(value)) == 1
            except TypeError:  # unhashable elements, e.g. rows of a 2-D array
                uniform = False
            if not uniform:
                continue  # Skip empty or non-uniform sequences
            value = value[0]
        params[field] = value

    if len(params) == 0:
        logger.info("No task arguments found in Bpod data")
        return None

    task_args = TaskArgumentsTable(description="Task parameters from Bpod")

    for arg_name, arg_value in sorted(params.items()):
        value_str, value_type = _describe_task_argument(arg_value)
        task_args.add_row(
            argument_name=arg_name,
            argument_description="Parameter from Bpod data",
            expression=value_str,
            expression_type=value_type,
            output_type=value_type,
        )

    logger.info(f"Extracted {len(task_args)} task arguments")
    return task_args
873
+
874
+
875
def build_task(
    state_types: StateTypesTable,
    event_types: EventTypesTable,
    action_types: ActionTypesTable,
    task_arguments: Optional[TaskArgumentsTable] = None,
) -> Task:
    """Assemble the top-level ``Task`` container from behavioral type tables.

    The returned ``Task`` holds the state, event and action type tables. When a
    task-argument table is supplied, it is attached as well. Callers add the
    result to the NWB file (under /general/task) via
    ``nwbfile.add_lab_meta_data``.

    Args:
        state_types: StateTypesTable with the state definitions.
        event_types: EventTypesTable with the event definitions.
        action_types: ActionTypesTable with the action definitions.
        task_arguments: Optional TaskArgumentsTable with task parameters.
            When ``None``, the Task is built without arguments.

    Returns:
        The assembled ``Task`` container.

    Example:
        >>> task_args = extract_task_arguments(bpod_data)
        >>> task = build_task(state_types, event_types, action_types,
        ...                   task_arguments=task_args)
        >>> nwbfile.add_lab_meta_data(task)
    """
    task = Task(
        state_types=state_types,
        event_types=event_types,
        action_types=action_types,
    )

    # Guard clause: nothing optional to attach.
    if task_arguments is None:
        logger.info("Built Task without arguments")
        return task

    # The argument table is assigned onto the already-constructed container.
    task.task_arguments = task_arguments
    logger.info(f"Built Task with {len(task_arguments)} arguments")
    return task
917
+
918
+
919
def extract_task(bpod_data: Dict[str, Any]) -> Task:
    """Build a complete ``Task`` container directly from parsed Bpod data.

    Convenience wrapper that chains the individual extraction steps:

    1. derive the state, event and action type tables;
    2. derive the task-argument table, which may be ``None`` when no
       parameters are found;
    3. assemble everything with ``build_task``.

    Args:
        bpod_data: Parsed Bpod data dictionary.

    Returns:
        ``Task`` container holding the type tables and, when available, the
        task arguments.

    Example:
        >>> from w2t_bkin.bpod.code import parse_bpod
        >>> from w2t_bkin.behavior import extract_task
        >>>
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> task = extract_task(bpod_data)
        >>> nwbfile.add_lab_meta_data(task)

    Note:
        The intermediate tables are not returned. Call the individual
        ``extract_*`` functions directly when you need access to them.
    """
    # Evaluated left to right: states, events, actions.
    type_tables = (
        extract_state_types(bpod_data),
        extract_event_types(bpod_data),
        extract_action_types(bpod_data),
    )
    # Positional order matches build_task(state_types, event_types,
    # action_types, task_arguments).
    return build_task(*type_tables, extract_task_arguments(bpod_data))
960
+
961
+
962
def extract_behavioral_data(
    bpod_data: Dict[str, Any],
    trial_offsets: Optional[Dict[int, float]] = None,
) -> Tuple[Task, TaskRecording, TrialsTable]:
    """Extract Task, TaskRecording and TrialsTable from Bpod data in one call.

    This is the highest-level entry point for behavioral extraction. The type
    tables are derived only once and reused, so the ``Task`` and the data
    tables inside the ``TaskRecording`` share the same type-table instances.
    The ``TrialsTable`` is then built from that same ``TaskRecording``.

    Args:
        bpod_data: Parsed Bpod data dictionary.
        trial_offsets: Optional dict mapping trial_number to an absolute time
            offset. It is forwarded unchanged to the state/event/action
            extractors and to ``build_trials_table``.

    Returns:
        Tuple of ``(Task, TaskRecording, TrialsTable)``.

    Example:
        >>> from w2t_bkin.bpod.code import parse_bpod
        >>> from w2t_bkin.behavior import extract_behavioral_data
        >>>
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
        >>> task, recording, trials = extract_behavioral_data(bpod_data, trial_offsets)
        >>>
        >>> # Add to NWB file
        >>> nwbfile.add_lab_meta_data(task)
        >>> nwbfile.add_acquisition(recording)
        >>> nwbfile.trials = trials

    Note:
        Consistency guarantees:
        - the Task contains the type tables;
        - the TaskRecording's data tables are built against those same tables;
        - the TrialsTable is built from that TaskRecording.
    """
    # Type tables are derived once and shared by the Task and the data tables.
    state_types = extract_state_types(bpod_data)
    event_types = extract_event_types(bpod_data)
    action_types = extract_action_types(bpod_data)

    task = build_task(
        state_types,
        event_types,
        action_types,
        extract_task_arguments(bpod_data),
    )

    # Each extractor yields (data_table, indices). The indices are only
    # consumed by build_trials_table below. Order: states, events, actions.
    extraction_plan = (
        (extract_states, state_types),
        (extract_events, event_types),
        (extract_actions, action_types),
    )
    (
        (states, state_indices),
        (events, event_indices),
        (actions, action_indices),
    ) = [extractor(bpod_data, types, trial_offsets) for extractor, types in extraction_plan]

    recording = build_task_recording(states, events, actions)

    # Built from the recording so the trials reference its data tables.
    trials = build_trials_table(
        bpod_data=bpod_data,
        recording=recording,
        state_indices=state_indices,
        event_indices=event_indices,
        action_indices=action_indices,
        trial_offsets=trial_offsets,
    )

    logger.info("Extracted complete behavioral data: Task, TaskRecording, and TrialsTable")
    return task, recording, trials