w2t-bkin 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- w2t_bkin/__init__.py +85 -0
- w2t_bkin/behavior/__init__.py +115 -0
- w2t_bkin/behavior/core.py +1027 -0
- w2t_bkin/bpod/__init__.py +38 -0
- w2t_bkin/bpod/core.py +519 -0
- w2t_bkin/config.py +625 -0
- w2t_bkin/dlc/__init__.py +59 -0
- w2t_bkin/dlc/core.py +448 -0
- w2t_bkin/dlc/models.py +124 -0
- w2t_bkin/exceptions.py +426 -0
- w2t_bkin/facemap/__init__.py +42 -0
- w2t_bkin/facemap/core.py +397 -0
- w2t_bkin/facemap/models.py +134 -0
- w2t_bkin/pipeline.py +665 -0
- w2t_bkin/pose/__init__.py +48 -0
- w2t_bkin/pose/core.py +227 -0
- w2t_bkin/pose/io.py +363 -0
- w2t_bkin/pose/skeleton.py +165 -0
- w2t_bkin/pose/ttl_mock.py +477 -0
- w2t_bkin/session.py +423 -0
- w2t_bkin/sync/__init__.py +72 -0
- w2t_bkin/sync/core.py +678 -0
- w2t_bkin/sync/stats.py +176 -0
- w2t_bkin/sync/timebase.py +311 -0
- w2t_bkin/sync/ttl.py +254 -0
- w2t_bkin/transcode/__init__.py +38 -0
- w2t_bkin/transcode/core.py +303 -0
- w2t_bkin/transcode/models.py +96 -0
- w2t_bkin/ttl/__init__.py +64 -0
- w2t_bkin/ttl/core.py +518 -0
- w2t_bkin/ttl/models.py +19 -0
- w2t_bkin/utils.py +1093 -0
- w2t_bkin-0.0.6.dist-info/METADATA +145 -0
- w2t_bkin-0.0.6.dist-info/RECORD +36 -0
- w2t_bkin-0.0.6.dist-info/WHEEL +4 -0
- w2t_bkin-0.0.6.dist-info/licenses/LICENSE +201 -0
w2t_bkin/bpod/__init__.py
ADDED
@@ -0,0 +1,38 @@
"""Behavioral data parsing from Bpod .mat files.

Provides low-level Bpod file operations:
- Parsing and merging Bpod .mat files
- Validating Bpod data structure
- File indexing and manipulation

For behavioral data extraction, use the behavior module with ndx-structured-behavior.

Example:
    >>> from pathlib import Path
    >>> from w2t_bkin.bpod import parse_bpod
    >>> from w2t_bkin.behavior import extract_trials_table
    >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
    >>> trials = extract_trials_table(bpod_data)
"""

# Exceptions
from ..exceptions import BpodParseError, BpodValidationError, EventsError

# Bpod file operations
from .core import index_bpod_data, merge_bpod_sessions, parse_bpod, parse_bpod_from_files, parse_bpod_mat, split_bpod_data, validate_bpod_structure, write_bpod_mat

__all__ = [
    # Exceptions
    "EventsError",
    "BpodParseError",
    "BpodValidationError",
    # Bpod file operations
    "parse_bpod",
    "parse_bpod_mat",
    "merge_bpod_sessions",
    "parse_bpod_from_files",
    "validate_bpod_structure",
    "index_bpod_data",
    "split_bpod_data",
    "write_bpod_mat",
]
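Note: together with the docstring example above, a minimal parse-and-validate pass looks like the sketch below (assuming a session directory laid out as data/Bpod/*.mat):

    from pathlib import Path

    from w2t_bkin.bpod import parse_bpod, validate_bpod_structure

    # Discover, sort, parse, and merge every matching .mat file
    bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")

    # Structural check; returns bool and logs a warning on failure
    if not validate_bpod_structure(bpod_data):
        raise RuntimeError("unexpected Bpod layout")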
w2t_bkin/bpod/core.py
ADDED
@@ -0,0 +1,519 @@
"""Low-level Bpod .mat file I/O operations.

Provides functions to parse, merge, validate, index, and write Bpod data files.
"""

import copy
import logging
from pathlib import Path
from typing import Any, Dict, Iterable, List, Sequence

import numpy as np

try:
    from scipy.io import loadmat, savemat
except ImportError:
    loadmat = None
    savemat = None

from ..exceptions import BpodParseError, BpodValidationError
from ..utils import convert_matlab_struct, discover_files, sanitize_string, sort_files, validate_against_whitelist, validate_file_exists, validate_file_size

logger = logging.getLogger(__name__)

# Constants
MAX_BPOD_FILE_SIZE_MB = 100

def validate_bpod_path(path: Path) -> None:
    """Validate Bpod file path and size.

    Args:
        path: Path to .mat file

    Raises:
        BpodValidationError: Invalid path or file too large
    """
    # Validate file exists
    validate_file_exists(path, BpodValidationError, "Bpod file not found")

    # Check file extension
    if path.suffix.lower() not in [".mat"]:
        raise BpodValidationError(f"Invalid file extension: {path.suffix}", file_path=str(path))

    # Check file size (prevent memory exhaustion)
    try:
        file_size_mb = validate_file_size(path, max_size_mb=MAX_BPOD_FILE_SIZE_MB)
        logger.debug(f"Validated Bpod file: {path.name} ({file_size_mb:.2f}MB)")
    except ValueError as e:
        # Re-raise as BpodValidationError for consistent error handling
        raise BpodValidationError(str(e), file_path=str(path))

def parse_bpod(session_dir: Path, pattern: str, order: str, continuous_time: bool = True) -> Dict[str, Any]:
    """Parse Bpod files matching a glob pattern.

    Discovers files using glob pattern, sorts them, then parses and merges.

    Args:
        session_dir: Base directory for resolving glob pattern
        pattern: Glob pattern for Bpod files (e.g. "Bpod/*.mat")
        order: Sort order (e.g. "name_asc", "modified_desc")
        continuous_time: Offset timestamps for continuous timeline

    Returns:
        Merged Bpod data dictionary

    Raises:
        BpodValidationError: No files found
        BpodParseError: Parse/merge failed

    Example:
        >>> from pathlib import Path
        >>> bpod_data = parse_bpod(Path("data"), "Bpod/*.mat", "name_asc")
    """
    file_paths = discover_bpod_files_from_pattern(session_dir=session_dir, pattern=pattern, order=order)
    return parse_bpod_from_files(file_paths=file_paths, continuous_time=continuous_time)

def discover_bpod_files_from_pattern(session_dir: Path, pattern: str, order: str) -> List[Path]:
    """Discover and sort Bpod .mat files using a glob pattern.

    Args:
        session_dir: Base directory for glob pattern
        pattern: Glob pattern (e.g. "Bpod/*.mat")
        order: Sort order (e.g. "name_asc")

    Returns:
        Sorted list of file paths

    Raises:
        BpodValidationError: No files found
    """
    file_paths = discover_files(session_dir, pattern, sort=False)

    if not file_paths:
        raise BpodValidationError(f"No Bpod files found matching pattern: {pattern}")

    file_paths = sort_files(file_paths, order)

    logger.info("Discovered %d Bpod files with order '%s'", len(file_paths), order)
    return file_paths

def parse_bpod_from_files(file_paths: Sequence[Path], continuous_time: bool = True) -> Dict[str, Any]:
    """Parse and merge Bpod files from explicit paths.

    Args:
        file_paths: Ordered paths to .mat files
        continuous_time: Offset timestamps for continuous timeline

    Returns:
        Merged Bpod data dictionary

    Raises:
        BpodParseError: Parse/merge failed
    """
    return merge_bpod_sessions(list(file_paths), continuous_time=continuous_time)

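Because discovery is exposed separately from parsing, callers can filter the discovered list before merging. A sketch of that pattern; the aborted-run filename filter is hypothetical:

    from pathlib import Path

    from w2t_bkin.bpod.core import discover_bpod_files_from_pattern, parse_bpod_from_files

    files = discover_bpod_files_from_pattern(Path("data"), "Bpod/*.mat", "name_asc")
    # Hypothetical filter: skip recordings flagged as aborted in the filename
    files = [p for p in files if "aborted" not in p.name]
    bpod_data = parse_bpod_from_files(files, continuous_time=True)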
def parse_bpod_mat(path: Path) -> Dict[str, Any]:
    """Parse a single Bpod .mat file.

    Args:
        path: Path to .mat file

    Returns:
        Bpod data dictionary

    Raises:
        BpodValidationError: File validation failed
        BpodParseError: Parse failed

    Example:
        >>> from pathlib import Path
        >>> bpod_data = parse_bpod_mat(Path("data/session.mat"))
    """
    # Validate path and file size
    validate_bpod_path(path)

    if loadmat is None:
        raise BpodParseError("scipy is required for .mat file parsing. Install with: pip install scipy")

    try:
        data = loadmat(str(path), squeeze_me=True, struct_as_record=False)
        logger.info(f"Successfully parsed Bpod file: {path.name}")
        return data
    except Exception as e:
        # Avoid leaking full path in error message
        raise BpodParseError(f"Failed to parse Bpod file: {type(e).__name__}")

def validate_bpod_structure(data: Dict[str, Any]) -> bool:
    """Validate Bpod data has required fields.

    Args:
        data: Bpod data dictionary

    Returns:
        True if valid
    """
    if "SessionData" not in data:
        logger.warning("Missing 'SessionData' in Bpod file")
        return False

    session_data = convert_matlab_struct(data["SessionData"])

    # Check for required fields
    required_fields = ["nTrials", "TrialStartTimestamp", "TrialEndTimestamp"]
    for field in required_fields:
        if field not in session_data:
            logger.warning(f"Missing required field '{field}' in SessionData")
            return False

    # Check for RawEvents structure
    if "RawEvents" not in session_data:
        logger.warning("Missing 'RawEvents' in SessionData")
        return False

    raw_events = convert_matlab_struct(session_data["RawEvents"])

    if "Trial" not in raw_events:
        logger.warning("Missing 'Trial' in RawEvents")
        return False

    logger.debug("Bpod structure validation passed")
    return True

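parse_bpod_mat returns scipy's raw loadmat dictionary, so the usual guard before touching SessionData is the validator above. A minimal sketch:

    from pathlib import Path

    from w2t_bkin.bpod.core import parse_bpod_mat, validate_bpod_structure
    from w2t_bkin.utils import convert_matlab_struct

    data = parse_bpod_mat(Path("data/session.mat"))
    if validate_bpod_structure(data):
        session = convert_matlab_struct(data["SessionData"])
        print(int(session["nTrials"]), "trials")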
def merge_bpod_sessions(file_paths: List[Path], continuous_time: bool = True) -> Dict[str, Any]:
    """Merge multiple Bpod .mat files into one.

    Combines trials from files in order. With continuous_time=True, offsets
    timestamps so each file continues from the previous file's end time.

    Args:
        file_paths: Ordered list of .mat file paths
        continuous_time: Offset timestamps for continuous timeline

    Returns:
        Merged Bpod data dictionary

    Raises:
        BpodParseError: Parse/merge failed
    """
    if not file_paths:
        raise BpodParseError("No Bpod files to merge")

    if len(file_paths) == 1:
        # Single file - just parse and return
        return parse_bpod_mat(file_paths[0])

    # Parse all files
    parsed_files = []
    for path in file_paths:
        try:
            data = parse_bpod_mat(path)
            parsed_files.append((path, data))
        except Exception as e:
            logger.error(f"Failed to parse {path.name}: {e}")
            raise

    # Start with first file as base
    _, merged_data = parsed_files[0]
    merged_session = convert_matlab_struct(merged_data["SessionData"])

    # Extract base data
    all_trials = []
    all_start_times = []
    all_end_times = []
    all_trial_settings = []
    all_trial_types = []

    # Add first file's data
    first_raw_events = convert_matlab_struct(merged_session["RawEvents"])
    # Ensure RawEvents is a dict in merged_session
    merged_session["RawEvents"] = first_raw_events

    # Convert Trial to list if it's a mat_struct or numpy array
    trials = first_raw_events["Trial"]
    if hasattr(trials, "__dict__"):
        # mat_struct object - could be a single trial or not iterable
        # Try to iterate, if not possible, wrap in list
        try:
            trials = [convert_matlab_struct(trial) for trial in trials]
        except TypeError:
            # Single mat_struct object - wrap in list
            trials = [convert_matlab_struct(trials)]
    elif isinstance(trials, np.ndarray):
        # numpy array - convert to list
        trials = trials.tolist()
    elif not isinstance(trials, list):
        # Other types - wrap in list
        trials = list(trials) if hasattr(trials, "__iter__") else [trials]

    all_trials.extend(trials)

    # Convert timestamps to lists if they're numpy arrays
    start_times = merged_session["TrialStartTimestamp"]
    end_times = merged_session["TrialEndTimestamp"]
    if isinstance(start_times, np.ndarray):
        start_times = start_times.tolist()
    if isinstance(end_times, np.ndarray):
        end_times = end_times.tolist()

    all_start_times.extend(start_times if isinstance(start_times, list) else [start_times])
    all_end_times.extend(end_times if isinstance(end_times, list) else [end_times])

    # Convert settings and types to lists if they're numpy arrays
    trial_settings = merged_session.get("TrialSettings", [])
    trial_types = merged_session.get("TrialTypes", [])
    if isinstance(trial_settings, np.ndarray):
        trial_settings = trial_settings.tolist()
    if isinstance(trial_types, np.ndarray):
        trial_types = trial_types.tolist()

    all_trial_settings.extend(trial_settings if isinstance(trial_settings, list) else [trial_settings])
    all_trial_types.extend(trial_types if isinstance(trial_types, list) else [trial_types])

    # Merge subsequent files
    for path, data in parsed_files[1:]:
        session_data = convert_matlab_struct(data["SessionData"])
        raw_events = convert_matlab_struct(session_data["RawEvents"])

        # Get trial offset (time of last trial end) - only if continuous_time is True
        time_offset = all_end_times[-1] if all_end_times and continuous_time else 0.0

        # Convert Trial to list if it's a mat_struct or numpy array
        trials = raw_events["Trial"]
        if hasattr(trials, "__dict__"):
            # mat_struct object - could be a single trial or not iterable
            # Try to iterate, if not possible, wrap in list
            try:
                trials = [convert_matlab_struct(trial) for trial in trials]
            except TypeError:
                # Single mat_struct object - wrap in list
                trials = [convert_matlab_struct(trials)]
        elif isinstance(trials, np.ndarray):
            # numpy array - convert to list
            trials = trials.tolist()
        elif not isinstance(trials, list):
            # Other types - wrap in list
            trials = list(trials) if hasattr(trials, "__iter__") else [trials]

        # Append trials
        all_trials.extend(trials)

        # Offset timestamps
        start_times = session_data["TrialStartTimestamp"]
        end_times = session_data["TrialEndTimestamp"]

        # Convert numpy arrays to lists
        if isinstance(start_times, np.ndarray):
            start_times = start_times.tolist()
        if isinstance(end_times, np.ndarray):
            end_times = end_times.tolist()

        if isinstance(start_times, (list, tuple)):
            all_start_times.extend([t + time_offset for t in start_times])
            all_end_times.extend([t + time_offset for t in end_times])
        else:
            all_start_times.append(start_times + time_offset)
            all_end_times.append(end_times + time_offset)

        # Append settings and types
        trial_settings = session_data.get("TrialSettings", [])
        trial_types = session_data.get("TrialTypes", [])

        # Convert numpy arrays to lists
        if isinstance(trial_settings, np.ndarray):
            trial_settings = trial_settings.tolist()
        if isinstance(trial_types, np.ndarray):
            trial_types = trial_types.tolist()

        all_trial_settings.extend(trial_settings if isinstance(trial_settings, list) else [trial_settings])
        all_trial_types.extend(trial_types if isinstance(trial_types, list) else [trial_types])

        logger.debug(f"Merged {path.name}: added {session_data['nTrials']} trials")

    # Update merged data
    merged_session["nTrials"] = len(all_trials)
    merged_session["TrialStartTimestamp"] = all_start_times
    merged_session["TrialEndTimestamp"] = all_end_times
    merged_session["RawEvents"]["Trial"] = all_trials
    merged_session["TrialSettings"] = all_trial_settings
    merged_session["TrialTypes"] = all_trial_types

    merged_data["SessionData"] = merged_session

    logger.info(f"Merged {len(file_paths)} Bpod files into {len(all_trials)} total trials")
    return merged_data

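The continuous_time offset rebases each file onto the running last end time, so merged trial times stay monotonic. Illustrative numbers, not real data:

    from pathlib import Path

    from w2t_bkin.bpod.core import merge_bpod_sessions

    # Suppose a.mat ends at t=18.0 and b.mat spans t=0.0-16.0 on its own clock
    merged = merge_bpod_sessions([Path("a.mat"), Path("b.mat")], continuous_time=True)
    # b.mat's timestamps are shifted by +18.0:
    #   TrialStartTimestamp: [0.0, 10.0] + [0.0, 9.0]  -> [0.0, 10.0, 18.0, 27.0]
    #   TrialEndTimestamp:   [8.0, 18.0] + [7.0, 16.0] -> [8.0, 18.0, 25.0, 34.0]
    # With continuous_time=False each file keeps its own timebase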
def index_bpod_data(bpod_data: Dict[str, Any], trial_indices: List[int]) -> Dict[str, Any]:
    """Filter Bpod data to keep only specified trials.

    Args:
        bpod_data: Bpod data dictionary
        trial_indices: 0-based indices of trials to keep

    Returns:
        New Bpod data with filtered trials

    Raises:
        BpodParseError: Invalid structure
        IndexError: Indices out of bounds
        ValueError: Empty trial_indices

    Example:
        >>> bpod_data = parse_bpod_mat(Path("data/session.mat"))
        >>> filtered = index_bpod_data(bpod_data, [0, 1, 2])  # First 3 trials
    """
    # Validate structure
    if not validate_bpod_structure(bpod_data):
        raise BpodParseError("Invalid Bpod structure")

    # Deep copy to avoid modifying original
    filtered_data = copy.deepcopy(bpod_data)

    # Convert MATLAB struct to dict if needed
    session_data = convert_matlab_struct(filtered_data["SessionData"])
    filtered_data["SessionData"] = session_data

    n_trials = int(session_data["nTrials"])

    # Validate indices
    if not trial_indices:
        raise ValueError("trial_indices cannot be empty")

    for idx in trial_indices:
        if idx < 0 or idx >= n_trials:
            raise IndexError(f"Trial index {idx} out of bounds (0-{n_trials-1})")

    # Filter trial-related arrays
    start_timestamps = session_data["TrialStartTimestamp"]
    end_timestamps = session_data["TrialEndTimestamp"]

    # Convert RawEvents to dict if needed
    raw_events = convert_matlab_struct(session_data["RawEvents"])
    session_data["RawEvents"] = raw_events

    # Handle both numpy arrays and lists
    def _index_array(arr: Any, indices: List[int]) -> Any:
        """Helper to index arrays or lists."""
        if isinstance(arr, np.ndarray):
            return arr[indices]
        elif isinstance(arr, (list, tuple)):
            return [arr[i] for i in indices]
        else:
            # Scalar - shouldn't happen for these fields
            return arr

    # Filter timestamps
    session_data["TrialStartTimestamp"] = _index_array(start_timestamps, trial_indices)
    session_data["TrialEndTimestamp"] = _index_array(end_timestamps, trial_indices)

    # Filter RawEvents.Trial (now always a dict)
    trial_list = raw_events["Trial"]
    filtered_trials = _index_array(trial_list, trial_indices)
    raw_events["Trial"] = filtered_trials

    # Filter optional fields if present
    if "TrialSettings" in session_data:
        trial_settings = session_data["TrialSettings"]
        session_data["TrialSettings"] = _index_array(trial_settings, trial_indices)

    if "TrialTypes" in session_data:
        trial_types = session_data["TrialTypes"]
        session_data["TrialTypes"] = _index_array(trial_types, trial_indices)

    # Update nTrials count
    session_data["nTrials"] = len(trial_indices)

    logger.info(f"Indexed Bpod data: kept {len(trial_indices)} trials out of {n_trials}")
    return filtered_data

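A typical use of index_bpod_data is subsetting by trial type. A sketch assuming TrialTypes is present and that type 1 marks the trials of interest (both assumptions, not guaranteed by the format):

    from pathlib import Path

    from w2t_bkin.bpod.core import index_bpod_data, parse_bpod_mat
    from w2t_bkin.utils import convert_matlab_struct

    data = parse_bpod_mat(Path("data/session.mat"))
    session = convert_matlab_struct(data["SessionData"])
    keep = [i for i, t in enumerate(session.get("TrialTypes", [])) if t == 1]
    subset = index_bpod_data(data, keep)  # raises ValueError if keep is empty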
def split_bpod_data(bpod_data: Dict[str, Any], splits: Sequence[Sequence[int]]) -> List[Dict[str, Any]]:
    """Split Bpod data into multiple chunks by trial indices.

    Each output chunk is a valid Bpod data dictionary that can be written
    with write_bpod_mat and later re-merged with merge_bpod_sessions.

    Args:
        bpod_data: Bpod data dictionary
        splits: Sequences of 0-based trial indices for each chunk

    Returns:
        List of Bpod data dictionaries

    Raises:
        BpodParseError: Invalid structure
        IndexError: Indices out of bounds
        ValueError: Empty split

    Example:
        >>> bpod_data = parse_bpod_mat(Path("data/session.mat"))
        >>> chunks = split_bpod_data(bpod_data, [[0, 1], [2, 3], [4, 5]])
    """
    # Validate structure first
    if not validate_bpod_structure(bpod_data):
        raise BpodParseError("Invalid Bpod structure")

    # Convert and inspect the source to validate indices against nTrials
    session_data = convert_matlab_struct(bpod_data["SessionData"])
    n_trials = int(session_data["nTrials"])

    # Helper for a single split; reuses index_bpod_data to ensure deep copy
    # and consistent filtering of all fields.
    def _make_chunk(indices: Sequence[int]) -> Dict[str, Any]:
        if not indices:
            raise ValueError("split indices cannot be empty")

        for idx in indices:
            if idx < 0 or idx >= n_trials:
                raise IndexError(f"Trial index {idx} out of bounds (0-{n_trials-1})")

        # Delegate the heavy lifting to index_bpod_data, which:
        # - deep copies the original structure
        # - converts MATLAB structs to dicts
        # - consistently filters timestamps, RawEvents.Trial, TrialSettings,
        #   TrialTypes, and updates nTrials.
        return index_bpod_data(bpod_data, list(indices))

    return [_make_chunk(indices) for indices in splits]

def write_bpod_mat(bpod_data: Dict[str, Any], output_path: Path) -> None:
    """Write Bpod data to a .mat file.

    Args:
        bpod_data: Bpod data dictionary
        output_path: Output .mat file path

    Raises:
        BpodParseError: Write failed or scipy not available
        BpodValidationError: Invalid structure

    Example:
        >>> bpod_data = parse_bpod_mat(Path("data/session.mat"))
        >>> filtered = index_bpod_data(bpod_data, [0, 1, 2])
        >>> write_bpod_mat(filtered, Path("data/filtered.mat"))
    """
    # Validate structure before writing
    if not validate_bpod_structure(bpod_data):
        raise BpodValidationError("Invalid Bpod structure - cannot write to file")

    if savemat is None:
        raise BpodParseError("scipy is required for .mat file writing. Install with: pip install scipy")

    try:
        # Ensure parent directory exists
        output_path.parent.mkdir(parents=True, exist_ok=True)

        # Write to .mat file (MATLAB v5 format for compatibility)
        savemat(str(output_path), bpod_data, format="5", oned_as="column")

        logger.info(f"Successfully wrote Bpod data to: {output_path.name}")
    except Exception as e:
        raise BpodParseError(f"Failed to write Bpod file: {type(e).__name__}: {e}")
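Taken together, split_bpod_data, write_bpod_mat, and merge_bpod_sessions support the round trip the split docstring promises. A minimal sketch with illustrative paths; it assumes the written chunks survive the scipy savemat/loadmat round trip intact:

    from pathlib import Path

    from w2t_bkin.bpod import merge_bpod_sessions, parse_bpod_mat, split_bpod_data, write_bpod_mat

    data = parse_bpod_mat(Path("data/session.mat"))
    first, second = split_bpod_data(data, [[0, 1], [2, 3]])
    write_bpod_mat(first, Path("out/part1.mat"))
    write_bpod_mat(second, Path("out/part2.mat"))

    # continuous_time=False keeps each chunk's original (already-offset) timestamps
    merged = merge_bpod_sessions(
        [Path("out/part1.mat"), Path("out/part2.mat")], continuous_time=False
    )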