w2t-bkin 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,38 @@
1
+ """Behavioral data parsing from Bpod .mat files.
2
+
3
+ Provides low-level Bpod file operations:
4
+ - Parsing and merging Bpod .mat files
5
+ - Validating Bpod data structure
6
+ - File indexing and manipulation
7
+
8
+ For behavioral data extraction, use the behavior module with ndx-structured-behavior.
9
+
10
+ Example:
11
+ >>> from pathlib import Path
12
+ >>> from w2t_bkin.bpod import parse_bpod
13
+ >>> from w2t_bkin.behavior import extract_trials_table
14
+ >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
15
+ >>> trials = extract_trials_table(bpod_data)
16
+ """
17
+
18
+ # Exceptions
19
+ from ..exceptions import BpodParseError, BpodValidationError, EventsError
20
+
21
+ # Bpod file operations
22
+ from .core import index_bpod_data, merge_bpod_sessions, parse_bpod, parse_bpod_from_files, parse_bpod_mat, split_bpod_data, validate_bpod_structure, write_bpod_mat
23
+
24
# Public API re-exported by `from w2t_bkin.bpod import *`.
# Exceptions first, then the Bpod file-operation functions.
__all__ = [
    # Exceptions
    "EventsError",
    "BpodParseError",
    "BpodValidationError",
]
__all__ += [
    # Bpod file operations
    "parse_bpod",
    "parse_bpod_mat",
    "merge_bpod_sessions",
    "parse_bpod_from_files",
    "validate_bpod_structure",
    "index_bpod_data",
    "split_bpod_data",
    "write_bpod_mat",
]
w2t_bkin/bpod/core.py ADDED
@@ -0,0 +1,519 @@
1
+ """Low-level Bpod .mat file I/O operations.
2
+
3
+ Provides functions to parse, merge, validate, index, and write Bpod data files.
4
+ """
5
+
6
+ import copy
7
+ import logging
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Iterable, List, Sequence
10
+
11
+ import numpy as np
12
+
13
+ try:
14
+ from scipy.io import loadmat, savemat
15
+ except ImportError:
16
+ loadmat = None
17
+ savemat = None
18
+
19
+ from ..exceptions import BpodParseError, BpodValidationError
20
+ from ..utils import convert_matlab_struct, discover_files, sanitize_string, sort_files, validate_against_whitelist, validate_file_exists, validate_file_size
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Constants
25
+ MAX_BPOD_FILE_SIZE_MB = 100
26
+
27
+
28
def validate_bpod_path(path: Path) -> None:
    """Validate Bpod file path, extension, and size.

    Args:
        path: Path to .mat file

    Raises:
        BpodValidationError: File missing, wrong extension, or file too large
    """
    # Existence check; raises BpodValidationError with a consistent message
    validate_file_exists(path, BpodValidationError, "Bpod file not found")

    # Only MATLAB .mat files are accepted (case-insensitive suffix match)
    if path.suffix.lower() != ".mat":
        raise BpodValidationError(f"Invalid file extension: {path.suffix}", file_path=str(path))

    # Size cap prevents memory exhaustion when the file is later loaded by scipy
    try:
        file_size_mb = validate_file_size(path, max_size_mb=MAX_BPOD_FILE_SIZE_MB)
        logger.debug(f"Validated Bpod file: {path.name} ({file_size_mb:.2f}MB)")
    except ValueError as e:
        # Re-raise as BpodValidationError for consistent error handling,
        # keeping the original ValueError attached as __cause__
        raise BpodValidationError(str(e), file_path=str(path)) from e
51
+
52
+
53
def parse_bpod(session_dir: Path, pattern: str, order: str, continuous_time: bool = True) -> Dict[str, Any]:
    """Parse Bpod files matching a glob pattern.

    Convenience entry point: discovers files with a glob pattern, sorts them,
    then parses and merges them into a single Bpod data dictionary.

    Args:
        session_dir: Base directory for resolving glob pattern
        pattern: Glob pattern for Bpod files (e.g. "Bpod/*.mat")
        order: Sort order (e.g. "name_asc", "modified_desc")
        continuous_time: Offset timestamps for continuous timeline

    Returns:
        Merged Bpod data dictionary

    Raises:
        BpodValidationError: No files found
        BpodParseError: Parse/merge failed

    Example:
        >>> from pathlib import Path
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
    """
    discovered = discover_bpod_files_from_pattern(session_dir=session_dir, pattern=pattern, order=order)
    merged = parse_bpod_from_files(file_paths=discovered, continuous_time=continuous_time)
    return merged
77
+
78
+
79
def discover_bpod_files_from_pattern(session_dir: Path, pattern: str, order: str) -> List[Path]:
    """Discover and sort Bpod .mat files using a glob pattern.

    Args:
        session_dir: Base directory for glob pattern
        pattern: Glob pattern (e.g. "Bpod/*.mat")
        order: Sort order (e.g. "name_asc")

    Returns:
        Sorted list of file paths

    Raises:
        BpodValidationError: No files found
    """
    # Discover unsorted, then apply the caller-requested ordering explicitly
    candidates = discover_files(session_dir, pattern, sort=False)

    if not candidates:
        raise BpodValidationError(f"No Bpod files found matching pattern: {pattern}")

    ordered = sort_files(candidates, order)

    logger.info("Discovered %d Bpod files with order '%s'", len(ordered), order)
    return ordered
102
+
103
+
104
def parse_bpod_from_files(file_paths: Sequence[Path], continuous_time: bool = True) -> Dict[str, Any]:
    """Parse and merge Bpod files given as explicit, pre-ordered paths.

    Args:
        file_paths: Ordered paths to .mat files
        continuous_time: Offset timestamps for continuous timeline

    Returns:
        Merged Bpod data dictionary

    Raises:
        BpodParseError: Parse/merge failed
    """
    # merge_bpod_sessions expects a concrete list, so materialize the sequence
    ordered_paths = list(file_paths)
    return merge_bpod_sessions(ordered_paths, continuous_time=continuous_time)
118
+
119
+
120
def parse_bpod_mat(path: Path) -> Dict[str, Any]:
    """Parse a single Bpod .mat file.

    Args:
        path: Path to .mat file

    Returns:
        Bpod data dictionary

    Raises:
        BpodValidationError: File validation failed
        BpodParseError: scipy unavailable or parse failed

    Example:
        >>> from pathlib import Path
        >>> bpod_data = parse_bpod_mat(Path("data/session.mat"))
    """
    # Validate existence, extension, and size before loading into memory
    validate_bpod_path(path)

    if loadmat is None:
        raise BpodParseError("scipy is required for .mat file parsing. Install with: pip install scipy")

    try:
        # squeeze_me collapses 1x1 arrays; struct_as_record=False yields
        # mat_struct objects that convert_matlab_struct can handle downstream
        data = loadmat(str(path), squeeze_me=True, struct_as_record=False)
        logger.info(f"Successfully parsed Bpod file: {path.name}")
        return data
    except Exception as e:
        # Only the exception type is surfaced to avoid leaking the full path;
        # the original exception stays attached as __cause__ for debugging
        raise BpodParseError(f"Failed to parse Bpod file: {type(e).__name__}") from e
150
+
151
+
152
def validate_bpod_structure(data: Dict[str, Any]) -> bool:
    """Check that a Bpod dictionary carries the fields downstream code needs.

    Args:
        data: Bpod data dictionary

    Returns:
        True if valid, False otherwise (a warning is logged for each failure)
    """
    if "SessionData" not in data:
        logger.warning("Missing 'SessionData' in Bpod file")
        return False

    session = convert_matlab_struct(data["SessionData"])

    # Core trial bookkeeping fields must all be present
    for field in ("nTrials", "TrialStartTimestamp", "TrialEndTimestamp"):
        if field not in session:
            logger.warning(f"Missing required field '{field}' in SessionData")
            return False

    if "RawEvents" not in session:
        logger.warning("Missing 'RawEvents' in SessionData")
        return False

    # RawEvents must contain the per-trial event structure
    if "Trial" not in convert_matlab_struct(session["RawEvents"]):
        logger.warning("Missing 'Trial' in RawEvents")
        return False

    logger.debug("Bpod structure validation passed")
    return True
187
+
188
+
189
def _normalize_trials(trials: Any) -> List[Any]:
    """Coerce RawEvents.Trial (mat_struct, ndarray, list, or scalar) to a list."""
    if hasattr(trials, "__dict__"):
        # mat_struct object: either an iterable container of trials or a
        # single trial struct that is not iterable
        try:
            return [convert_matlab_struct(trial) for trial in trials]
        except TypeError:
            # Single mat_struct object - wrap in list
            return [convert_matlab_struct(trials)]
    if isinstance(trials, np.ndarray):
        return trials.tolist()
    if isinstance(trials, list):
        return trials
    return list(trials) if hasattr(trials, "__iter__") else [trials]


def _as_list(value: Any) -> List[Any]:
    """Coerce an ndarray/list/tuple/scalar SessionData field to a plain list."""
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, (list, tuple)):
        return list(value)
    # Scalar (single-trial file after squeeze_me) - wrap in a 1-element list
    return [value]


def merge_bpod_sessions(file_paths: List[Path], continuous_time: bool = True) -> Dict[str, Any]:
    """Merge multiple Bpod .mat files into one.

    Combines trials from files in order. With continuous_time=True, offsets
    timestamps so each file continues from the previous file's end time.

    Args:
        file_paths: Ordered list of .mat file paths
        continuous_time: Offset timestamps for continuous timeline

    Returns:
        Merged Bpod data dictionary

    Raises:
        BpodParseError: Parse/merge failed
    """
    if not file_paths:
        raise BpodParseError("No Bpod files to merge")

    if len(file_paths) == 1:
        # Single file - just parse and return
        return parse_bpod_mat(file_paths[0])

    # Parse every file up front so a failure aborts before any merging begins
    parsed_files = []
    for path in file_paths:
        try:
            data = parse_bpod_mat(path)
            parsed_files.append((path, data))
        except Exception as e:
            logger.error(f"Failed to parse {path.name}: {e}")
            raise

    # The first file's structure becomes the base for the merged output
    _, merged_data = parsed_files[0]
    merged_session = convert_matlab_struct(merged_data["SessionData"])
    first_raw_events = convert_matlab_struct(merged_session["RawEvents"])
    # Ensure RawEvents is a dict in merged_session
    merged_session["RawEvents"] = first_raw_events

    # Seed the accumulators with the first file's (normalized) fields
    all_trials = list(_normalize_trials(first_raw_events["Trial"]))
    all_start_times = _as_list(merged_session["TrialStartTimestamp"])
    all_end_times = _as_list(merged_session["TrialEndTimestamp"])
    all_trial_settings = _as_list(merged_session.get("TrialSettings", []))
    all_trial_types = _as_list(merged_session.get("TrialTypes", []))

    # Merge subsequent files, shifting their timestamps when requested
    for path, data in parsed_files[1:]:
        session_data = convert_matlab_struct(data["SessionData"])
        raw_events = convert_matlab_struct(session_data["RawEvents"])

        # Each file's clock restarts at zero; shift by the running end time
        time_offset = all_end_times[-1] if (all_end_times and continuous_time) else 0.0

        all_trials.extend(_normalize_trials(raw_events["Trial"]))
        all_start_times.extend(t + time_offset for t in _as_list(session_data["TrialStartTimestamp"]))
        all_end_times.extend(t + time_offset for t in _as_list(session_data["TrialEndTimestamp"]))
        all_trial_settings.extend(_as_list(session_data.get("TrialSettings", [])))
        all_trial_types.extend(_as_list(session_data.get("TrialTypes", [])))

        logger.debug(f"Merged {path.name}: added {session_data['nTrials']} trials")

    # Write the combined arrays back into the base session
    merged_session["nTrials"] = len(all_trials)
    merged_session["TrialStartTimestamp"] = all_start_times
    merged_session["TrialEndTimestamp"] = all_end_times
    merged_session["RawEvents"]["Trial"] = all_trials
    merged_session["TrialSettings"] = all_trial_settings
    merged_session["TrialTypes"] = all_trial_types

    merged_data["SessionData"] = merged_session

    logger.info(f"Merged {len(file_paths)} Bpod files into {len(all_trials)} total trials")
    return merged_data
351
+
352
+
353
def index_bpod_data(bpod_data: Dict[str, Any], trial_indices: List[int]) -> Dict[str, Any]:
    """Filter Bpod data to keep only specified trials.

    Args:
        bpod_data: Bpod data dictionary
        trial_indices: 0-based indices of trials to keep (output order follows
            this list; duplicates are kept as given)

    Returns:
        New Bpod data with filtered trials (the input is not modified)

    Raises:
        BpodParseError: Invalid structure
        ValueError: Empty trial_indices
        IndexError: Indices out of bounds

    Example:
        >>> bpod_data = parse_bpod_mat(Path("data/session.mat"))
        >>> filtered = index_bpod_data(bpod_data, [0, 1, 2])  # First 3 trials
    """
    # Validate structure before touching anything
    if not validate_bpod_structure(bpod_data):
        raise BpodParseError("Invalid Bpod structure")

    # Deep copy so the caller's data is never mutated
    filtered_data = copy.deepcopy(bpod_data)

    # Normalize MATLAB structs to plain dicts before editing fields
    session_data = convert_matlab_struct(filtered_data["SessionData"])
    filtered_data["SessionData"] = session_data

    n_trials = int(session_data["nTrials"])

    # Validate indices up front so we fail before any field is modified
    if not trial_indices:
        raise ValueError("trial_indices cannot be empty")

    for idx in trial_indices:
        if idx < 0 or idx >= n_trials:
            raise IndexError(f"Trial index {idx} out of bounds (0-{n_trials-1})")

    # Convert RawEvents to dict if needed
    raw_events = convert_matlab_struct(session_data["RawEvents"])
    session_data["RawEvents"] = raw_events

    def _index_array(arr: Any, indices: List[int]) -> Any:
        """Select `indices` from an ndarray (fancy indexing) or list/tuple."""
        if isinstance(arr, np.ndarray):
            return arr[indices]
        if isinstance(arr, (list, tuple)):
            return [arr[i] for i in indices]
        # Scalar - shouldn't happen for trial-aligned fields; pass through
        return arr

    # Filter every trial-aligned field consistently
    session_data["TrialStartTimestamp"] = _index_array(session_data["TrialStartTimestamp"], trial_indices)
    session_data["TrialEndTimestamp"] = _index_array(session_data["TrialEndTimestamp"], trial_indices)
    raw_events["Trial"] = _index_array(raw_events["Trial"], trial_indices)

    # Optional trial-aligned fields
    if "TrialSettings" in session_data:
        session_data["TrialSettings"] = _index_array(session_data["TrialSettings"], trial_indices)
    if "TrialTypes" in session_data:
        session_data["TrialTypes"] = _index_array(session_data["TrialTypes"], trial_indices)

    # Update nTrials to reflect the filtered count
    session_data["nTrials"] = len(trial_indices)

    logger.info(f"Indexed Bpod data: kept {len(trial_indices)} trials out of {n_trials}")
    return filtered_data
434
+
435
+
436
def split_bpod_data(bpod_data: Dict[str, Any], splits: Sequence[Sequence[int]]) -> List[Dict[str, Any]]:
    """Split Bpod data into multiple chunks by trial indices.

    Each output chunk is a valid Bpod data dictionary that can be written
    with write_bpod_mat and later re-merged with merge_bpod_sessions.

    Args:
        bpod_data: Bpod data dictionary
        splits: Sequences of 0-based trial indices for each chunk

    Returns:
        List of Bpod data dictionaries

    Raises:
        BpodParseError: Invalid structure
        IndexError: Indices out of bounds
        ValueError: Empty split

    Example:
        >>> bpod_data = parse_bpod_mat(Path("data/session.mat"))
        >>> chunks = split_bpod_data(bpod_data, [[0, 1], [2, 3], [4, 5]])
    """
    # Validate structure first
    if not validate_bpod_structure(bpod_data):
        raise BpodParseError("Invalid Bpod structure")

    # Inspect the source once to validate every split against nTrials
    session = convert_matlab_struct(bpod_data["SessionData"])
    n_trials = int(session["nTrials"])

    chunks: List[Dict[str, Any]] = []
    for indices in splits:
        if not indices:
            raise ValueError("split indices cannot be empty")

        for idx in indices:
            if idx < 0 or idx >= n_trials:
                raise IndexError(f"Trial index {idx} out of bounds (0-{n_trials-1})")

        # index_bpod_data does the heavy lifting: it deep-copies the source,
        # converts MATLAB structs to dicts, and consistently filters
        # timestamps, RawEvents.Trial, TrialSettings, TrialTypes, and nTrials.
        chunks.append(index_bpod_data(bpod_data, list(indices)))

    return chunks
485
+
486
+
487
def write_bpod_mat(bpod_data: Dict[str, Any], output_path: Path) -> None:
    """Write Bpod data to a .mat file.

    Args:
        bpod_data: Bpod data dictionary
        output_path: Output .mat file path

    Raises:
        BpodParseError: Write failed or scipy not available
        BpodValidationError: Invalid structure

    Example:
        >>> bpod_data = parse_bpod_mat(Path("data/session.mat"))
        >>> filtered = index_bpod_data(bpod_data, [0, 1, 2])
        >>> write_bpod_mat(filtered, Path("data/filtered.mat"))
    """
    # Refuse to write structures that downstream parsers could not read back
    if not validate_bpod_structure(bpod_data):
        raise BpodValidationError("Invalid Bpod structure - cannot write to file")

    if savemat is None:
        raise BpodParseError("scipy is required for .mat file writing. Install with: pip install scipy")

    try:
        # Ensure parent directory exists
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # MATLAB v5 format for broad compatibility; oned_as matches Bpod's
        # column-vector convention
        savemat(str(output_path), bpod_data, format="5", oned_as="column")

        logger.info(f"Successfully wrote Bpod data to: {output_path.name}")
    except Exception as e:
        # Keep the original exception attached as __cause__ for debugging
        raise BpodParseError(f"Failed to write Bpod file: {type(e).__name__}: {e}") from e