w2t-bkin 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- w2t_bkin/__init__.py +85 -0
- w2t_bkin/behavior/__init__.py +115 -0
- w2t_bkin/behavior/core.py +1027 -0
- w2t_bkin/bpod/__init__.py +38 -0
- w2t_bkin/bpod/core.py +519 -0
- w2t_bkin/config.py +625 -0
- w2t_bkin/dlc/__init__.py +59 -0
- w2t_bkin/dlc/core.py +448 -0
- w2t_bkin/dlc/models.py +124 -0
- w2t_bkin/exceptions.py +426 -0
- w2t_bkin/facemap/__init__.py +42 -0
- w2t_bkin/facemap/core.py +397 -0
- w2t_bkin/facemap/models.py +134 -0
- w2t_bkin/pipeline.py +665 -0
- w2t_bkin/pose/__init__.py +48 -0
- w2t_bkin/pose/core.py +227 -0
- w2t_bkin/pose/io.py +363 -0
- w2t_bkin/pose/skeleton.py +165 -0
- w2t_bkin/pose/ttl_mock.py +477 -0
- w2t_bkin/session.py +423 -0
- w2t_bkin/sync/__init__.py +72 -0
- w2t_bkin/sync/core.py +678 -0
- w2t_bkin/sync/stats.py +176 -0
- w2t_bkin/sync/timebase.py +311 -0
- w2t_bkin/sync/ttl.py +254 -0
- w2t_bkin/transcode/__init__.py +38 -0
- w2t_bkin/transcode/core.py +303 -0
- w2t_bkin/transcode/models.py +96 -0
- w2t_bkin/ttl/__init__.py +64 -0
- w2t_bkin/ttl/core.py +518 -0
- w2t_bkin/ttl/models.py +19 -0
- w2t_bkin/utils.py +1093 -0
- w2t_bkin-0.0.6.dist-info/METADATA +145 -0
- w2t_bkin-0.0.6.dist-info/RECORD +36 -0
- w2t_bkin-0.0.6.dist-info/WHEEL +4 -0
- w2t_bkin-0.0.6.dist-info/licenses/LICENSE +201 -0
w2t_bkin/config.py
ADDED
|
@@ -0,0 +1,625 @@
|
|
|
1
|
+
"""Configuration management for W2T-BKIN pipeline.
|
|
2
|
+
|
|
3
|
+
This module provides Pydantic models for validating configuration files (config.toml)
|
|
4
|
+
and functions for loading, validating, and hashing configurations.
|
|
5
|
+
|
|
6
|
+
The configuration system enforces strict schema validation to catch errors early,
|
|
7
|
+
supports deterministic hashing for reproducibility, and provides clear error messages.
|
|
8
|
+
|
|
9
|
+
Typical usage example:
|
|
10
|
+
>>> from w2t_bkin.config import load_config
|
|
11
|
+
>>>
|
|
12
|
+
>>> config = load_config("config.toml")
|
|
13
|
+
>>> print(config.project.name)
|
|
14
|
+
>>> print(config.timebase.source)
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from pathlib import Path
|
|
20
|
+
from typing import Any, Dict, List, Literal, Optional, Union
|
|
21
|
+
|
|
22
|
+
try:
|
|
23
|
+
import tomllib
|
|
24
|
+
except ImportError:
|
|
25
|
+
import tomli as tomllib # Python < 3.11 fallback
|
|
26
|
+
|
|
27
|
+
from pydantic import BaseModel, Field, ValidationError, field_validator
|
|
28
|
+
|
|
29
|
+
from .utils import compute_hash, read_toml
|
|
30
|
+
|
|
31
|
+
# =============================================================================
|
|
32
|
+
# Constants
|
|
33
|
+
# =============================================================================
|
|
34
|
+
|
|
35
|
+
# Allowed values used by the pre-validation helpers below; kept in sync with
# the Literal annotations on TimebaseConfig and LoggingConfig.
VALID_TIMEBASE_SOURCES = frozenset({"nominal_rate", "ttl", "neuropixels"})
VALID_TIMEBASE_MAPPINGS = frozenset({"nearest", "linear"})
VALID_LOGGING_LEVELS = frozenset({"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"})
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# =============================================================================
|
|
41
|
+
# Configuration Models - Core
|
|
42
|
+
# =============================================================================
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class ProjectConfig(BaseModel, extra="forbid"):
    """Project identification.

    Attributes:
        name: Project name identifier.
    """

    # extra="forbid" turns unknown keys in [project] into validation errors,
    # catching config.toml typos early.
    name: str = Field(..., description="Project name")
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class PathsConfig(BaseModel, extra="forbid"):
    """File system paths configuration.

    Attributes:
        raw_root: Path to raw data directory.
        intermediate_root: Path for intermediate processing outputs.
        output_root: Path for final outputs.
        metadata_file: Filename for session metadata (default: session.toml).
        models_root: Directory containing pose estimation models (default: models).
    """

    raw_root: Path = Field(..., description="Raw data root directory")
    intermediate_root: Path = Field(..., description="Intermediate processing outputs")
    output_root: Path = Field(..., description="Output data root directory")
    # String defaults are coerced to Path by Pydantic at instantiation time.
    metadata_file: Path = Field(default="session.toml", description="Session metadata filename")
    models_root: Path = Field(default="models", description="Pose estimation models directory")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class TimebaseConfig(BaseModel, extra="forbid"):
    """Reference timebase for aligning derived data.

    Defines the reference clock for synchronizing pose and behavior data.
    ImageSeries remain rate-based; this timebase applies to derived modalities.

    Attributes:
        source: Timebase source (nominal_rate, ttl, or neuropixels).
        mapping: Strategy for mapping timestamps (nearest or linear).
        jitter_budget_s: Maximum allowed temporal jitter in seconds.
        offset_s: Global time offset before mapping (default: 0.0).
        ttl_id: TTL channel ID (required when source='ttl').
        neuropixels_stream: Neuropixels stream name (required when source='neuropixels').
    """

    source: Literal["nominal_rate", "ttl", "neuropixels"] = Field(..., description="Timebase source")
    mapping: Literal["nearest", "linear"] = Field(..., description="Mapping strategy")
    jitter_budget_s: float = Field(..., ge=0.0, description="Max allowed jitter in seconds")
    offset_s: float = Field(default=0.0, description="Global offset before mapping")
    # The conditional requirements on the two fields below are enforced by
    # _validate_config_conditionals(), not by this model itself.
    ttl_id: Optional[str] = Field(None, description="TTL ID (required when source='ttl')")
    neuropixels_stream: Optional[str] = Field(None, description="Neuropixels stream (required when source='neuropixels')")
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class AcquisitionConfig(BaseModel, extra="forbid"):
    """Data acquisition policies.

    Attributes:
        concat_strategy: Video concatenation method (ffconcat or streamlist).
    """

    # Literal restricts the value to the two supported concatenation modes.
    concat_strategy: Literal["ffconcat", "streamlist"] = Field(default="ffconcat", description="Video concatenation strategy")
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class VerificationConfig(BaseModel, extra="forbid"):
    """Hardware synchronization verification.

    Attributes:
        mismatch_tolerance_frames: Max allowed frame/TTL count mismatch before abort.
        warn_on_mismatch: If True, warn instead of abort when within tolerance.
    """

    # Default of 0 means any frame/TTL count mismatch is treated as an error.
    mismatch_tolerance_frames: int = Field(default=0, ge=0, description="Abort if frame_count - ttl_pulse_count > tolerance")
    warn_on_mismatch: bool = Field(default=False, description="Warn instead of abort if within tolerance")
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# =============================================================================
|
|
119
|
+
# Configuration Models - Bpod
|
|
120
|
+
# =============================================================================
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
class BpodSyncTrialType(BaseModel, extra="forbid"):
    """Bpod trial type synchronization mapping.

    Maps a Bpod trial type to its synchronization signal and TTL channel,
    enabling conversion from Bpod relative timestamps to absolute time.

    Attributes:
        trial_type: Trial type identifier matching Bpod classification.
        sync_signal: Bpod state/event name for alignment (e.g., 'W2T_Audio').
        sync_ttl: TTL channel whose pulses correspond to sync_signal.
    """

    trial_type: int = Field(..., ge=0, description="Trial type identifier")
    sync_signal: str = Field(..., description="Bpod state/event for alignment")
    sync_ttl: str = Field(..., description="TTL channel for sync pulses")
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
class BpodSyncConfig(BaseModel, extra="forbid"):
    """Bpod-to-TTL synchronization configuration.

    Attributes:
        trial_types: List of trial type sync configurations.
    """

    # default_factory=list gives each instance its own (empty) list.
    trial_types: List[BpodSyncTrialType] = Field(default_factory=list, description="Trial type sync configs")
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class BpodConfig(BaseModel, extra="forbid"):
    """Bpod behavioral control system configuration.

    Attributes:
        parse: Whether to parse Bpod .mat files.
        sync: Trial synchronization configuration.
    """

    parse: bool = Field(default=True, description="Parse Bpod .mat files if present")
    # Nested model gets fully-defaulted instance when the table is omitted.
    sync: BpodSyncConfig = Field(default_factory=BpodSyncConfig, description="Trial sync configuration")
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
# =============================================================================
|
|
163
|
+
# Configuration Models - Video
|
|
164
|
+
# =============================================================================
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
class TranscodeConfig(BaseModel, extra="forbid"):
    """Video transcoding settings.

    Attributes:
        enabled: Enable video transcoding.
        codec: FFmpeg codec (e.g., 'h264', 'libx264').
        crf: Constant rate factor quality (0-51, lower is better).
        preset: FFmpeg encoding preset (e.g., 'fast', 'medium').
        keyint: GOP (group of pictures) length.
    """

    enabled: bool = Field(default=True, description="Enable transcoding")
    codec: str = Field(default="h264", description="FFmpeg codec name")
    # crf range matches FFmpeg's valid CRF interval for H.264.
    crf: int = Field(default=20, ge=0, le=51, description="Quality factor (0-51)")
    preset: str = Field(default="fast", description="FFmpeg preset")
    keyint: int = Field(default=15, ge=1, description="GOP length")
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class VideoConfig(BaseModel, extra="forbid"):
    """Video processing configuration.

    Attributes:
        transcode: Transcoding settings.
    """

    transcode: TranscodeConfig = Field(default_factory=TranscodeConfig, description="Transcoding config")
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
# =============================================================================
|
|
196
|
+
# Configuration Models - Output & Logging
|
|
197
|
+
# =============================================================================
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class NWBConfig(BaseModel, extra="forbid"):
    """NWB (Neurodata Without Borders) export settings.

    Attributes:
        link_external_video: Use external links for videos instead of embedding.
        lab: Laboratory name.
        institution: Institution name.
        file_name_template: Template for NWB filename.
        session_description_template: Template for session description.
    """

    link_external_video: bool = Field(default=True, description="Link videos externally")
    lab: str = Field(default="Lab Name", description="Lab name")
    institution: str = Field(default="Institution Name", description="Institution name")
    # Template placeholders such as {session.id} are presumably substituted by
    # the export code elsewhere in the pipeline — not in this module.
    file_name_template: str = Field(default="{session.id}.nwb", description="NWB filename template")
    session_description_template: str = Field(default="Session {session.id} on {session.date}", description="Session description template")
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class QCConfig(BaseModel, extra="forbid"):
    """Quality control report configuration.

    Attributes:
        generate_report: Enable QC report generation.
        out_template: Output path template for reports.
        include_verification: Include frame/TTL verification in reports.
    """

    generate_report: bool = Field(default=True, description="Generate QC report")
    out_template: str = Field(default="qc/{session.id}", description="Output path template")
    include_verification: bool = Field(default=True, description="Include verification in report")
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class LoggingConfig(BaseModel, extra="forbid"):
    """Logging configuration.

    Attributes:
        level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL).
        structured: Use structured (JSON) logging format.
    """

    # The Literal mirrors VALID_LOGGING_LEVELS, which pre-validates raw TOML.
    level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(default="INFO", description="Logging level")
    structured: bool = Field(default=False, description="Use structured logging")
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
# =============================================================================
|
|
245
|
+
# Configuration Models - Inference
|
|
246
|
+
# =============================================================================
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
class DLCConfig(BaseModel, extra="forbid"):
    """DeepLabCut pose estimation configuration.

    Attributes:
        run_inference: Enable DLC inference.
        model: Path to DLC model file.
        gputouse: GPU device index (-1 for CPU, None for auto-select).
    """

    run_inference: bool = Field(default=False, description="Run DLC inference")
    model: str = Field(default="model.pb", description="DLC model path")
    gputouse: Optional[int] = Field(None, description="GPU index (-1=CPU, None=auto)")
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
class SLEAPConfig(BaseModel, extra="forbid"):
    """SLEAP pose estimation configuration.

    Attributes:
        run_inference: Enable SLEAP inference.
        model: Path to SLEAP model file.
    """

    run_inference: bool = Field(default=False, description="Run SLEAP inference")
    model: str = Field(default="sleap.h5", description="SLEAP model path")
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
class LabelsConfig(BaseModel, extra="forbid"):
    """Pose labeling configuration.

    Attributes:
        dlc: DeepLabCut configuration.
        sleap: SLEAP configuration.
    """

    dlc: DLCConfig = Field(default_factory=DLCConfig, description="DLC config")
    sleap: SLEAPConfig = Field(default_factory=SLEAPConfig, description="SLEAP config")
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
class FacemapConfig(BaseModel, extra="forbid"):
    """Facemap facial motion tracking configuration.

    Attributes:
        run_inference: Enable Facemap inference.
        ROIs: Regions of interest to process.
    """

    run_inference: bool = Field(default=False, description="Run Facemap inference")
    # default_factory lambda avoids a shared mutable default across instances.
    ROIs: List[str] = Field(default_factory=lambda: ["face", "left_eye", "right_eye"], description="ROIs to process")
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
# =============================================================================
|
|
300
|
+
# Main Configuration Model
|
|
301
|
+
# =============================================================================
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
class Config(BaseModel, extra="forbid"):
    """Main pipeline configuration.

    Root configuration model loaded from config.toml. Uses strict validation
    with extra="forbid" to prevent typos and configuration errors.

    Attributes:
        project: Project identification.
        paths: File system paths.
        bpod: Bpod behavioral control settings.
        logging: Logging configuration.

    NOTE(review): timebase, acquisition, verification, video, nwb, qc, labels
    and facemap are currently disabled (commented out below). Because of
    extra="forbid", a corresponding table in config.toml (e.g. [timebase])
    will be REJECTED by validation even though load_config() pre-validates
    the timebase section — confirm whether this is intended.
    """

    project: ProjectConfig
    paths: PathsConfig
    # timebase: TimebaseConfig
    # acquisition: AcquisitionConfig = Field(default_factory=AcquisitionConfig)
    # verification: VerificationConfig = Field(default_factory=VerificationConfig)
    bpod: BpodConfig = Field(default_factory=BpodConfig)
    # video: VideoConfig = Field(default_factory=VideoConfig)
    # nwb: NWBConfig = Field(default_factory=NWBConfig)
    # qc: QCConfig = Field(default_factory=QCConfig)
    logging: LoggingConfig = Field(default_factory=LoggingConfig)
    # labels: LabelsConfig = Field(default_factory=LabelsConfig)
    # facemap: FacemapConfig = Field(default_factory=FacemapConfig)

    # @field_validator("timebase")
    # @classmethod
    # def validate_timebase_conditionals(cls, v: TimebaseConfig) -> TimebaseConfig:
    #     """Validate conditional timebase requirements.

    #     Args:
    #         v: TimebaseConfig instance to validate.

    #     Returns:
    #         Validated TimebaseConfig.

    #     Raises:
    #         ValueError: If conditional requirements are not met.
    #     """
    #     if v.source == "ttl" and v.ttl_id is None:
    #         raise ValueError("timebase.ttl_id is required when source='ttl'")
    #     if v.source == "neuropixels" and v.neuropixels_stream is None:
    #         raise ValueError("timebase.neuropixels_stream is required when source='neuropixels'")
    #     return v
|
|
357
|
+
|
|
358
|
+
|
|
359
|
+
# =============================================================================
|
|
360
|
+
# Public API Functions
|
|
361
|
+
# =============================================================================
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def load_config(path: Union[str, Path]) -> Config:
    """Load and validate configuration from TOML file.

    Performs comprehensive validation including:
    - Schema validation with extra="forbid" to prevent typos
    - Enum validation for source, mapping, and level fields (checked on the
      raw dict first for clearer error messages than Pydantic's defaults)
    - Numeric constraints (e.g., jitter_budget_s >= 0)
    - Conditional requirements (e.g., ttl_id when source='ttl')

    NOTE(review): the Config model currently declares only project, paths,
    bpod and logging; its timebase field is commented out. The timebase
    section is still pre-validated here, but a [timebase] table present in
    the TOML will then be rejected by Config's extra="forbid" — confirm
    whether that is the intended behavior.

    Args:
        path: Path to config.toml file.

    Returns:
        Validated Config instance.

    Raises:
        FileNotFoundError: If config file doesn't exist.
        ValidationError: If config violates Pydantic schema.
        ValueError: If enum or conditional validation fails.

    Example:
        >>> config = load_config("config.toml")
        >>> print(config.project.name)
    """
    data = read_toml(path)

    # Pre-validate enums for clearer error messages
    _validate_config_enums(data)

    # Pre-validate conditional requirements
    _validate_config_conditionals(data)

    return Config(**data)
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def compute_config_hash(config: Config) -> str:
    """Compute deterministic SHA256 hash of configuration.

    The config is first serialized to a plain dict (``model_dump``) so the
    hash depends only on the configuration values, not on the model object,
    then hashed via the shared ``compute_hash`` utility. Useful for tracking
    configuration changes and ensuring reproducibility.

    Args:
        config: Config instance to hash.

    Returns:
        SHA256 hex digest (64 characters).

    Example:
        >>> config = load_config("config.toml")
        >>> hash_value = compute_config_hash(config)
        >>> print(f"Config hash: {hash_value[:16]}...")
    """
    return compute_hash(config.model_dump())
|
|
419
|
+
|
|
420
|
+
|
|
421
|
+
# =============================================================================
|
|
422
|
+
# Private Validation Helpers
|
|
423
|
+
# =============================================================================
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def _validate_config_enums(data: Dict[str, Any]) -> None:
    """Validate enum constraints before Pydantic validation.

    Pre-validates enum fields to provide clearer error messages than
    Pydantic's default validation.

    Args:
        data: Raw configuration dict from TOML.

    Raises:
        ValueError: If any enum value is invalid.
    """
    timebase = data.get("timebase", {})

    # Fix: compare against None instead of relying on truthiness. With the
    # old `if source and ...` form, an explicit empty string ("") silently
    # skipped this pre-check and fell through to Pydantic's generic Literal
    # error — the opposite of this function's stated purpose.
    source = timebase.get("source")
    if source is not None and source not in VALID_TIMEBASE_SOURCES:
        raise ValueError(f"Invalid timebase.source: '{source}'. " f"Must be one of {sorted(VALID_TIMEBASE_SOURCES)}")

    # Validate timebase.mapping
    mapping = timebase.get("mapping")
    if mapping is not None and mapping not in VALID_TIMEBASE_MAPPINGS:
        raise ValueError(f"Invalid timebase.mapping: '{mapping}'. " f"Must be one of {sorted(VALID_TIMEBASE_MAPPINGS)}")

    # Validate jitter_budget_s >= 0
    jitter_budget = timebase.get("jitter_budget_s")
    if jitter_budget is not None and jitter_budget < 0:
        raise ValueError(f"Invalid timebase.jitter_budget_s: {jitter_budget}. " f"Must be >= 0")

    # Validate logging.level
    logging_config = data.get("logging", {})
    level = logging_config.get("level")
    if level is not None and level not in VALID_LOGGING_LEVELS:
        raise ValueError(f"Invalid logging.level: '{level}'. " f"Must be one of {sorted(VALID_LOGGING_LEVELS)}")
|
|
460
|
+
|
|
461
|
+
|
|
462
|
+
def _validate_config_conditionals(data: Dict[str, Any]) -> None:
|
|
463
|
+
"""Validate conditional requirements before Pydantic validation.
|
|
464
|
+
|
|
465
|
+
Checks that required fields are present based on other field values.
|
|
466
|
+
|
|
467
|
+
Args:
|
|
468
|
+
data: Raw configuration dict from TOML.
|
|
469
|
+
|
|
470
|
+
Raises:
|
|
471
|
+
ValueError: If conditional requirements are not met.
|
|
472
|
+
"""
|
|
473
|
+
timebase = data.get("timebase", {})
|
|
474
|
+
source = timebase.get("source")
|
|
475
|
+
|
|
476
|
+
if source == "ttl" and not timebase.get("ttl_id"):
|
|
477
|
+
raise ValueError("timebase.ttl_id is required when timebase.source='ttl'")
|
|
478
|
+
|
|
479
|
+
if source == "neuropixels" and not timebase.get("neuropixels_stream"):
|
|
480
|
+
raise ValueError("timebase.neuropixels_stream is required when " "timebase.source='neuropixels'")
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
# =============================================================================
|
|
484
|
+
# Session Loading Functions (backward compatibility)
|
|
485
|
+
# =============================================================================
|
|
486
|
+
|
|
487
|
+
|
|
488
|
+
def load_session(path: Union[str, Path]) -> Dict[str, Any]:
    """Load session metadata from TOML file.

    Args:
        path: Path to session.toml or metadata.toml file.

    Returns:
        Parsed session metadata dictionary.

    Raises:
        FileNotFoundError: If file doesn't exist.

    Example:
        >>> session = load_session("data/raw/Session-000001/session.toml")
        >>> print(session["identifier"])
    """
    session_path = Path(path)

    # Fail fast with an explicit message rather than letting read_toml
    # surface a lower-level error.
    if session_path.exists():
        return read_toml(session_path)

    raise FileNotFoundError(f"Session file not found: {session_path}")
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def compute_session_hash(session: Dict[str, Any]) -> str:
    """Compute deterministic SHA256 hash of session metadata.

    Thin wrapper around the shared ``compute_hash`` utility (imported from
    ``.utils``), mirroring ``compute_config_hash`` for Config objects.

    Args:
        session: Session metadata dictionary.

    Returns:
        SHA256 hex digest (64 characters).

    Example:
        >>> session = load_session("session.toml")
        >>> hash_value = compute_session_hash(session)
        >>> print(f"Session hash: {hash_value[:16]}...")
    """
    return compute_hash(session)
|
|
527
|
+
|
|
528
|
+
|
|
529
|
+
# =============================================================================
|
|
530
|
+
# CLI/Testing Entry Point
|
|
531
|
+
# =============================================================================
|
|
532
|
+
|
|
533
|
+
if __name__ == "__main__":
    """Demonstrate configuration loading and validation."""

    print("=" * 70)
    print("Configuration Loading Examples")
    print("=" * 70)
    print()

    # Example 1: Load valid configuration
    print("Example 1: Load and validate config.toml")
    print("-" * 70)

    try:
        config_path = Path("tests/fixtures/configs/valid_config.toml")
        config = load_config(config_path)

        print(f"✓ Loaded: {config_path}")
        print(f"  Project: {config.project.name}")
        # Fix: Config currently has no `timebase` field (it is commented out
        # on the model), so the previous unconditional attribute access raised
        # AttributeError here after a successful load — and AttributeError is
        # not caught by any of the except clauses below. Guard the access.
        timebase = getattr(config, "timebase", None)
        if timebase is not None:
            print(f"  Timebase: {timebase.source} ({timebase.mapping})")
            print(f"  Jitter budget: {timebase.jitter_budget_s}s")
        print(f"  Logging: {config.logging.level}")

        config_hash = compute_config_hash(config)
        print(f"  Hash: {config_hash[:16]}...")

    except FileNotFoundError as e:
        print(f"✗ File not found: {e}")
        print("  Hint: Run from project root")
    except ValidationError as e:
        # (f-string removed: the message has no placeholders)
        print("✗ Validation failed:")
        for error in e.errors():
            print(f"  - {error['loc']}: {error['msg']}")
    except ValueError as e:
        print(f"✗ Configuration error: {e}")

    print()

    # Example 2: Demonstrate validation errors
    print("Example 2: Validation error handling")
    print("-" * 70)

    # Invalid enum
    print("\n2a. Invalid timebase.source:")
    try:
        test_data = {
            "project": {"name": "test"},
            "paths": {
                "raw_root": "data/raw",
                "intermediate_root": "data/interim",
                "output_root": "data/processed",
            },
            "timebase": {
                "source": "invalid",
                "mapping": "nearest",
                "jitter_budget_s": 0.01,
            },
        }
        _validate_config_enums(test_data)
    except ValueError as e:
        print(f"  ✓ Caught: {e}")

    # Missing conditional field
    print("\n2b. Missing conditional field (ttl_id):")
    try:
        test_data = {
            "timebase": {
                "source": "ttl",
                "mapping": "nearest",
                "jitter_budget_s": 0.01,
            }
        }
        _validate_config_conditionals(test_data)
    except ValueError as e:
        print(f"  ✓ Caught: {e}")

    # Invalid numeric constraint (jitter_budget_s < 0 is caught by the
    # enum/constraint pre-validator, not the conditional one)
    print("\n2c. Invalid numeric constraint:")
    try:
        test_data = {
            "timebase": {
                "source": "nominal_rate",
                "mapping": "nearest",
                "jitter_budget_s": -0.01,
            }
        }
        _validate_config_enums(test_data)
    except ValueError as e:
        print(f"  ✓ Caught: {e}")

    print()
    print("=" * 70)
    print("See module docstring for more information")
    print("=" * 70)
|
w2t_bkin/dlc/__init__.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""DLC (DeepLabCut) inference module.
|
|
2
|
+
|
|
3
|
+
This module provides low-level primitives for running DeepLabCut model inference
|
|
4
|
+
on video files. It follows the 3-tier architecture:
|
|
5
|
+
|
|
6
|
+
- **Low-level**: Functions accept primitives only (Path, int, bool, List)
|
|
7
|
+
- **No Config/Session**: Never imports config, Session, or Manifest
|
|
8
|
+
- **Module-local models**: Owns DLCInferenceOptions, DLCInferenceResult, DLCModelInfo
|
|
9
|
+
|
|
10
|
+
**Key Features**:
|
|
11
|
+
- Batch processing: Single DLC call for multiple videos (optimal GPU utilization)
|
|
12
|
+
- GPU auto-detection: Automatic GPU selection with manual override support
|
|
13
|
+
- Partial failure handling: Gracefully handle individual video failures in batch
|
|
14
|
+
- Idempotency: Content-addressed outputs, skip inference if unchanged
|
|
15
|
+
|
|
16
|
+
**Architecture**:
|
|
17
|
+
- ``dlc/core.py``: Low-level inference functions
|
|
18
|
+
- ``dlc/models.py``: Module-local data models
|
|
19
|
+
- ``dlc/__init__.py``: Public API surface
|
|
20
|
+
|
|
21
|
+
Requirements:
|
|
22
|
+
- FR-5: Optional pose estimation
|
|
23
|
+
- NFR-1: Determinism (idempotent outputs)
|
|
24
|
+
- NFR-2: Performance (batch processing)
|
|
25
|
+
|
|
26
|
+
Example:
|
|
27
|
+
>>> from w2t_bkin.dlc import run_dlc_inference_batch, DLCInferenceOptions
|
|
28
|
+
>>> from pathlib import Path
|
|
29
|
+
>>>
|
|
30
|
+
>>> videos = [Path("cam0.mp4"), Path("cam1.mp4")]
|
|
31
|
+
>>> model_config = Path("models/dlc_model/config.yaml")
|
|
32
|
+
>>> output_dir = Path("output/dlc")
|
|
33
|
+
>>>
|
|
34
|
+
>>> options = DLCInferenceOptions(gputouse=0, save_as_csv=False)
|
|
35
|
+
>>> results = run_dlc_inference_batch(videos, model_config, output_dir, options)
|
|
36
|
+
>>>
|
|
37
|
+
>>> for result in results:
|
|
38
|
+
... if result.success:
|
|
39
|
+
... print(f"Success: {result.h5_output_path}")
|
|
40
|
+
... else:
|
|
41
|
+
... print(f"Failed: {result.error_message}")
|
|
42
|
+
"""
|
|
43
|
+
|
|
44
|
+
from w2t_bkin.dlc.core import DLCInferenceError, auto_detect_gpu, predict_output_paths, run_dlc_inference_batch, validate_dlc_model
|
|
45
|
+
from w2t_bkin.dlc.models import DLCInferenceOptions, DLCInferenceResult, DLCModelInfo
|
|
46
|
+
|
|
47
|
+
# Public API surface of the dlc subpackage; everything else is internal.
__all__ = [
    # Exception
    "DLCInferenceError",
    # Models
    "DLCInferenceOptions",
    "DLCInferenceResult",
    "DLCModelInfo",
    # Functions
    "run_dlc_inference_batch",
    "validate_dlc_model",
    "predict_output_paths",
    "auto_detect_gpu",
]
|