w2t-bkin 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
w2t_bkin/config.py ADDED
@@ -0,0 +1,625 @@
1
+ """Configuration management for W2T-BKIN pipeline.
2
+
3
+ This module provides Pydantic models for validating configuration files (config.toml)
4
+ and functions for loading, validating, and hashing configurations.
5
+
6
+ The configuration system enforces strict schema validation to catch errors early,
7
+ supports deterministic hashing for reproducibility, and provides clear error messages.
8
+
9
+ Typical usage example:
10
+ >>> from w2t_bkin.config import load_config
11
+ >>>
12
+ >>> config = load_config("config.toml")
13
+ >>> print(config.project.name)
14
+ >>> print(config.timebase.source)
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ from pathlib import Path
20
+ from typing import Any, Dict, List, Literal, Optional, Union
21
+
22
+ try:
23
+ import tomllib
24
+ except ImportError:
25
+ import tomli as tomllib # Python < 3.11 fallback
26
+
27
+ from pydantic import BaseModel, Field, ValidationError, field_validator
28
+
29
+ from .utils import compute_hash, read_toml
30
+
31
+ # =============================================================================
32
+ # Constants
33
+ # =============================================================================
34
+
35
# Accepted values for enum-like config fields; used by _validate_config_enums
# to produce targeted error messages before Pydantic's structural validation.
VALID_TIMEBASE_SOURCES = frozenset({"nominal_rate", "ttl", "neuropixels"})
VALID_TIMEBASE_MAPPINGS = frozenset({"nearest", "linear"})
VALID_LOGGING_LEVELS = frozenset({"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"})
38
+
39
+
40
+ # =============================================================================
41
+ # Configuration Models - Core
42
+ # =============================================================================
43
+
44
+
45
class ProjectConfig(BaseModel, extra="forbid"):
    """Project identification.

    Attributes:
        name: Project name identifier.
    """

    # Required (ellipsis default): omitting [project].name fails validation early.
    name: str = Field(..., description="Project name")
53
+
54
+
55
class PathsConfig(BaseModel, extra="forbid"):
    """File system paths configuration.

    Attributes:
        raw_root: Path to raw data directory.
        intermediate_root: Path for intermediate processing outputs.
        output_root: Path for final outputs.
        metadata_file: Filename for session metadata (default: session.toml).
        models_root: Directory containing pose estimation models (default: models).
    """

    raw_root: Path = Field(..., description="Raw data root directory")
    intermediate_root: Path = Field(..., description="Intermediate processing outputs")
    output_root: Path = Field(..., description="Output data root directory")
    # Defaults are given as str literals; Pydantic coerces them to Path on validation.
    metadata_file: Path = Field(default="session.toml", description="Session metadata filename")
    models_root: Path = Field(default="models", description="Pose estimation models directory")
71
+
72
+
73
class TimebaseConfig(BaseModel, extra="forbid"):
    """Reference timebase for aligning derived data.

    Defines the reference clock for synchronizing pose and behavior data.
    ImageSeries remain rate-based; this timebase applies to derived modalities.

    Attributes:
        source: Timebase source (nominal_rate, ttl, or neuropixels).
        mapping: Strategy for mapping timestamps (nearest or linear).
        jitter_budget_s: Maximum allowed temporal jitter in seconds.
        offset_s: Global time offset before mapping (default: 0.0).
        ttl_id: TTL channel ID (required when source='ttl').
        neuropixels_stream: Neuropixels stream name (required when source='neuropixels').
    """

    source: Literal["nominal_rate", "ttl", "neuropixels"] = Field(..., description="Timebase source")
    mapping: Literal["nearest", "linear"] = Field(..., description="Mapping strategy")
    jitter_budget_s: float = Field(..., ge=0.0, description="Max allowed jitter in seconds")
    offset_s: float = Field(default=0.0, description="Global offset before mapping")
    # Optional at the schema level; the source-dependent requirements are
    # enforced by _validate_config_conditionals during load_config().
    ttl_id: Optional[str] = Field(None, description="TTL ID (required when source='ttl')")
    neuropixels_stream: Optional[str] = Field(None, description="Neuropixels stream (required when source='neuropixels')")
94
+
95
+
96
class AcquisitionConfig(BaseModel, extra="forbid"):
    """Data acquisition policies.

    Attributes:
        concat_strategy: Video concatenation method (ffconcat or streamlist).
    """

    concat_strategy: Literal["ffconcat", "streamlist"] = Field(default="ffconcat", description="Video concatenation strategy")
104
+
105
+
106
class VerificationConfig(BaseModel, extra="forbid"):
    """Hardware synchronization verification.

    Attributes:
        mismatch_tolerance_frames: Max allowed frame/TTL count mismatch before abort.
        warn_on_mismatch: If True, warn instead of abort when within tolerance.
    """

    # 0 means any frame/TTL count mismatch is treated as an error.
    mismatch_tolerance_frames: int = Field(default=0, ge=0, description="Abort if frame_count - ttl_pulse_count > tolerance")
    warn_on_mismatch: bool = Field(default=False, description="Warn instead of abort if within tolerance")
116
+
117
+
118
+ # =============================================================================
119
+ # Configuration Models - Bpod
120
+ # =============================================================================
121
+
122
+
123
class BpodSyncTrialType(BaseModel, extra="forbid"):
    """Bpod trial type synchronization mapping.

    Maps a Bpod trial type to its synchronization signal and TTL channel,
    enabling conversion from Bpod relative timestamps to absolute time.

    Attributes:
        trial_type: Trial type identifier matching Bpod classification.
        sync_signal: Bpod state/event name for alignment (e.g., 'W2T_Audio').
        sync_ttl: TTL channel whose pulses correspond to sync_signal.
    """

    trial_type: int = Field(..., ge=0, description="Trial type identifier")
    sync_signal: str = Field(..., description="Bpod state/event for alignment")
    sync_ttl: str = Field(..., description="TTL channel for sync pulses")
138
+
139
+
140
class BpodSyncConfig(BaseModel, extra="forbid"):
    """Bpod-to-TTL synchronization configuration.

    Attributes:
        trial_types: List of trial type sync configurations.
    """

    # default_factory keeps the empty-list default per-instance (no shared state).
    trial_types: List[BpodSyncTrialType] = Field(default_factory=list, description="Trial type sync configs")
148
+
149
+
150
class BpodConfig(BaseModel, extra="forbid"):
    """Bpod behavioral control system configuration.

    Attributes:
        parse: Whether to parse Bpod .mat files.
        sync: Trial synchronization configuration.
    """

    parse: bool = Field(default=True, description="Parse Bpod .mat files if present")
    sync: BpodSyncConfig = Field(default_factory=BpodSyncConfig, description="Trial sync configuration")
160
+
161
+
162
+ # =============================================================================
163
+ # Configuration Models - Video
164
+ # =============================================================================
165
+
166
+
167
class TranscodeConfig(BaseModel, extra="forbid"):
    """Video transcoding settings.

    Attributes:
        enabled: Enable video transcoding.
        codec: FFmpeg codec (e.g., 'h264', 'libx264').
        crf: Constant rate factor quality (0-51, lower is better).
        preset: FFmpeg encoding preset (e.g., 'fast', 'medium').
        keyint: GOP (group of pictures) length.
    """

    enabled: bool = Field(default=True, description="Enable transcoding")
    codec: str = Field(default="h264", description="FFmpeg codec name")
    # CRF range mirrors FFmpeg's x264 scale: 0 (lossless) to 51 (worst).
    crf: int = Field(default=20, ge=0, le=51, description="Quality factor (0-51)")
    preset: str = Field(default="fast", description="FFmpeg preset")
    keyint: int = Field(default=15, ge=1, description="GOP length")
183
+
184
+
185
class VideoConfig(BaseModel, extra="forbid"):
    """Video processing configuration.

    Attributes:
        transcode: Transcoding settings.
    """

    transcode: TranscodeConfig = Field(default_factory=TranscodeConfig, description="Transcoding config")
193
+
194
+
195
+ # =============================================================================
196
+ # Configuration Models - Output & Logging
197
+ # =============================================================================
198
+
199
+
200
class NWBConfig(BaseModel, extra="forbid"):
    """NWB (Neurodata Without Borders) export settings.

    Attributes:
        link_external_video: Use external links for videos instead of embedding.
        lab: Laboratory name.
        institution: Institution name.
        file_name_template: Template for NWB filename.
        session_description_template: Template for session description.
    """

    link_external_video: bool = Field(default=True, description="Link videos externally")
    lab: str = Field(default="Lab Name", description="Lab name")
    institution: str = Field(default="Institution Name", description="Institution name")
    # Templates contain {session.*} placeholders; presumably substituted by the
    # NWB export code (not visible here) — confirm against the consumer.
    file_name_template: str = Field(default="{session.id}.nwb", description="NWB filename template")
    session_description_template: str = Field(default="Session {session.id} on {session.date}", description="Session description template")
216
+
217
+
218
class QCConfig(BaseModel, extra="forbid"):
    """Quality control report configuration.

    Attributes:
        generate_report: Enable QC report generation.
        out_template: Output path template for reports.
        include_verification: Include frame/TTL verification in reports.
    """

    generate_report: bool = Field(default=True, description="Generate QC report")
    out_template: str = Field(default="qc/{session.id}", description="Output path template")
    include_verification: bool = Field(default=True, description="Include verification in report")
230
+
231
+
232
class LoggingConfig(BaseModel, extra="forbid"):
    """Logging configuration.

    Attributes:
        level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL).
        structured: Use structured (JSON) logging format.
    """

    # Literal values must stay in sync with VALID_LOGGING_LEVELS above.
    level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = Field(default="INFO", description="Logging level")
    structured: bool = Field(default=False, description="Use structured logging")
242
+
243
+
244
+ # =============================================================================
245
+ # Configuration Models - Inference
246
+ # =============================================================================
247
+
248
+
249
class DLCConfig(BaseModel, extra="forbid"):
    """DeepLabCut pose estimation configuration.

    Attributes:
        run_inference: Enable DLC inference.
        model: Path to DLC model file.
        gputouse: GPU device index (-1 for CPU, None for auto-select).
    """

    run_inference: bool = Field(default=False, description="Run DLC inference")
    model: str = Field(default="model.pb", description="DLC model path")
    gputouse: Optional[int] = Field(None, description="GPU index (-1=CPU, None=auto)")
261
+
262
+
263
class SLEAPConfig(BaseModel, extra="forbid"):
    """SLEAP pose estimation configuration.

    Attributes:
        run_inference: Enable SLEAP inference.
        model: Path to SLEAP model file.
    """

    run_inference: bool = Field(default=False, description="Run SLEAP inference")
    model: str = Field(default="sleap.h5", description="SLEAP model path")
273
+
274
+
275
class LabelsConfig(BaseModel, extra="forbid"):
    """Pose labeling configuration.

    Attributes:
        dlc: DeepLabCut configuration.
        sleap: SLEAP configuration.
    """

    dlc: DLCConfig = Field(default_factory=DLCConfig, description="DLC config")
    sleap: SLEAPConfig = Field(default_factory=SLEAPConfig, description="SLEAP config")
285
+
286
+
287
class FacemapConfig(BaseModel, extra="forbid"):
    """Facemap facial motion tracking configuration.

    Attributes:
        run_inference: Enable Facemap inference.
        ROIs: Regions of interest to process.
    """

    run_inference: bool = Field(default=False, description="Run Facemap inference")
    # Lambda factory keeps the default list per-instance (no shared mutable state).
    ROIs: List[str] = Field(default_factory=lambda: ["face", "left_eye", "right_eye"], description="ROIs to process")
297
+
298
+
299
+ # =============================================================================
300
+ # Main Configuration Model
301
+ # =============================================================================
302
+
303
+
304
class Config(BaseModel, extra="forbid"):
    """Main pipeline configuration.

    Root configuration model loaded from config.toml. Uses strict validation
    with extra="forbid" to prevent typos and configuration errors.

    Only a subset of sections is currently enabled; the remaining sections
    (timebase, acquisition, verification, video, nwb, qc, labels, facemap)
    are commented out below, so supplying them in config.toml will be
    rejected by extra="forbid" until they are re-enabled.

    Attributes:
        project: Project identification.
        paths: File system paths.
        bpod: Bpod behavioral control settings.
        logging: Logging configuration.
    """

    project: ProjectConfig
    paths: PathsConfig
    # timebase: TimebaseConfig
    # acquisition: AcquisitionConfig = Field(default_factory=AcquisitionConfig)
    # verification: VerificationConfig = Field(default_factory=VerificationConfig)
    bpod: BpodConfig = Field(default_factory=BpodConfig)
    # video: VideoConfig = Field(default_factory=VideoConfig)
    # nwb: NWBConfig = Field(default_factory=NWBConfig)
    # qc: QCConfig = Field(default_factory=QCConfig)
    logging: LoggingConfig = Field(default_factory=LoggingConfig)
    # labels: LabelsConfig = Field(default_factory=LabelsConfig)
    # facemap: FacemapConfig = Field(default_factory=FacemapConfig)

    # @field_validator("timebase")
    # @classmethod
    # def validate_timebase_conditionals(cls, v: TimebaseConfig) -> TimebaseConfig:
    #     """Validate conditional timebase requirements.

    #     Args:
    #         v: TimebaseConfig instance to validate.

    #     Returns:
    #         Validated TimebaseConfig.

    #     Raises:
    #         ValueError: If conditional requirements are not met.
    #     """
    #     if v.source == "ttl" and v.ttl_id is None:
    #         raise ValueError("timebase.ttl_id is required when source='ttl'")
    #     if v.source == "neuropixels" and v.neuropixels_stream is None:
    #         raise ValueError("timebase.neuropixels_stream is required when source='neuropixels'")
    #     return v
357
+
358
+
359
+ # =============================================================================
360
+ # Public API Functions
361
+ # =============================================================================
362
+
363
+
364
def load_config(path: Union[str, Path]) -> Config:
    """Load config.toml and return a validated Config.

    Two fail-fast passes run over the raw TOML dict before it is handed to
    Pydantic: enum/range pre-checks and conditional-requirement checks.
    Both raise ValueError with targeted, human-readable messages; anything
    they miss is caught by the strict (extra="forbid") Pydantic schema.

    Args:
        path: Path to the config.toml file.

    Returns:
        Validated Config instance.

    Raises:
        FileNotFoundError: If the config file does not exist.
        ValueError: If enum or conditional pre-validation fails.
        ValidationError: If the data violates the Pydantic schema.

    Example:
        >>> config = load_config("config.toml")
        >>> print(config.project.name)
    """
    raw = read_toml(path)

    # Fail fast with targeted messages before Pydantic's structural pass.
    for check in (_validate_config_enums, _validate_config_conditionals):
        check(raw)

    return Config(**raw)
398
+
399
+
400
def compute_config_hash(config: Config) -> str:
    """Compute a deterministic SHA256 hash of a configuration.

    The config is dumped to a plain dict first, so the hash depends only on
    field values — useful for tracking configuration changes and ensuring
    reproducibility.

    Args:
        config: Config instance to hash.

    Returns:
        SHA256 hex digest (64 characters).

    Example:
        >>> config = load_config("config.toml")
        >>> print(compute_config_hash(config)[:16])
    """
    return compute_hash(config.model_dump())
419
+
420
+
421
+ # =============================================================================
422
+ # Private Validation Helpers
423
+ # =============================================================================
424
+
425
+
426
def _validate_config_enums(data: Dict[str, Any]) -> None:
    """Validate enum and range constraints before Pydantic validation.

    Pre-validates selected fields so users get a targeted error message
    (including the list of accepted values) instead of Pydantic's more
    verbose structural errors.

    Args:
        data: Raw configuration dict parsed from TOML.

    Raises:
        ValueError: If an enum value or numeric constraint is violated.
    """
    timebase = data.get("timebase", {})

    # Compare against None explicitly: a present-but-empty string is an
    # invalid value and should be reported here, not silently deferred to
    # Pydantic (the previous truthiness check skipped "" entirely).
    source = timebase.get("source")
    if source is not None and source not in VALID_TIMEBASE_SOURCES:
        raise ValueError(f"Invalid timebase.source: '{source}'. " f"Must be one of {sorted(VALID_TIMEBASE_SOURCES)}")

    mapping = timebase.get("mapping")
    if mapping is not None and mapping not in VALID_TIMEBASE_MAPPINGS:
        raise ValueError(f"Invalid timebase.mapping: '{mapping}'. " f"Must be one of {sorted(VALID_TIMEBASE_MAPPINGS)}")

    # Mirrors the Field(ge=0.0) constraint with a clearer message.
    jitter_budget = timebase.get("jitter_budget_s")
    if jitter_budget is not None and jitter_budget < 0:
        raise ValueError(f"Invalid timebase.jitter_budget_s: {jitter_budget}. " f"Must be >= 0")

    logging_config = data.get("logging", {})
    level = logging_config.get("level")
    if level is not None and level not in VALID_LOGGING_LEVELS:
        raise ValueError(f"Invalid logging.level: '{level}'. " f"Must be one of {sorted(VALID_LOGGING_LEVELS)}")
460
+
461
+
462
+ def _validate_config_conditionals(data: Dict[str, Any]) -> None:
463
+ """Validate conditional requirements before Pydantic validation.
464
+
465
+ Checks that required fields are present based on other field values.
466
+
467
+ Args:
468
+ data: Raw configuration dict from TOML.
469
+
470
+ Raises:
471
+ ValueError: If conditional requirements are not met.
472
+ """
473
+ timebase = data.get("timebase", {})
474
+ source = timebase.get("source")
475
+
476
+ if source == "ttl" and not timebase.get("ttl_id"):
477
+ raise ValueError("timebase.ttl_id is required when timebase.source='ttl'")
478
+
479
+ if source == "neuropixels" and not timebase.get("neuropixels_stream"):
480
+ raise ValueError("timebase.neuropixels_stream is required when " "timebase.source='neuropixels'")
481
+
482
+
483
+ # =============================================================================
484
+ # Session Loading Functions (backward compatibility)
485
+ # =============================================================================
486
+
487
+
488
def load_session(path: Union[str, Path]) -> Dict[str, Any]:
    """Load session metadata from a TOML file.

    Args:
        path: Path to session.toml or metadata.toml file.

    Returns:
        Parsed session metadata dictionary.

    Raises:
        FileNotFoundError: If the file does not exist.

    Example:
        >>> session = load_session("data/raw/Session-000001/session.toml")
        >>> print(session["identifier"])
    """
    session_path = Path(path)

    # Check existence up front so the caller gets this module's message
    # rather than whatever read_toml would raise.
    if session_path.exists():
        return read_toml(session_path)

    raise FileNotFoundError(f"Session file not found: {session_path}")
510
+
511
+
512
def compute_session_hash(session: Dict[str, Any]) -> str:
    """Compute a deterministic SHA256 hash of session metadata.

    Thin wrapper over compute_hash, kept as a named entry point to mirror
    compute_config_hash.

    Args:
        session: Session metadata dictionary.

    Returns:
        SHA256 hex digest (64 characters).

    Example:
        >>> session = load_session("session.toml")
        >>> print(compute_session_hash(session)[:16])
    """
    return compute_hash(session)
527
+
528
+
529
+ # =============================================================================
530
+ # CLI/Testing Entry Point
531
+ # =============================================================================
532
+
533
if __name__ == "__main__":
    # Demonstrate configuration loading and validation.

    print("=" * 70)
    print("Configuration Loading Examples")
    print("=" * 70)
    print()

    # Example 1: Load valid configuration
    print("Example 1: Load and validate config.toml")
    print("-" * 70)

    try:
        config_path = Path("tests/fixtures/configs/valid_config.toml")
        config = load_config(config_path)

        print(f"✓ Loaded: {config_path}")
        print(f"  Project: {config.project.name}")

        # `timebase` is currently commented out on Config; guard the access
        # so a successful load does not crash with an uncaught AttributeError.
        timebase = getattr(config, "timebase", None)
        if timebase is not None:
            print(f"  Timebase: {timebase.source} ({timebase.mapping})")
            print(f"  Jitter budget: {timebase.jitter_budget_s}s")

        print(f"  Logging: {config.logging.level}")

        config_hash = compute_config_hash(config)
        print(f"  Hash: {config_hash[:16]}...")

    except FileNotFoundError as e:
        print(f"✗ File not found: {e}")
        print("  Hint: Run from project root")
    except ValidationError as e:
        print("✗ Validation failed:")
        for error in e.errors():
            print(f"  - {error['loc']}: {error['msg']}")
    except ValueError as e:
        print(f"✗ Configuration error: {e}")

    print()

    # Example 2: Demonstrate validation errors
    print("Example 2: Validation error handling")
    print("-" * 70)

    # Invalid enum
    print("\n2a. Invalid timebase.source:")
    try:
        test_data = {
            "project": {"name": "test"},
            "paths": {
                "raw_root": "data/raw",
                "intermediate_root": "data/interim",
                "output_root": "data/processed",
            },
            "timebase": {
                "source": "invalid",
                "mapping": "nearest",
                "jitter_budget_s": 0.01,
            },
        }
        _validate_config_enums(test_data)
    except ValueError as e:
        print(f"  ✓ Caught: {e}")

    # Missing conditional field
    print("\n2b. Missing conditional field (ttl_id):")
    try:
        test_data = {
            "timebase": {
                "source": "ttl",
                "mapping": "nearest",
                "jitter_budget_s": 0.01,
            }
        }
        _validate_config_conditionals(test_data)
    except ValueError as e:
        print(f"  ✓ Caught: {e}")

    # Invalid numeric constraint
    print("\n2c. Invalid numeric constraint:")
    try:
        test_data = {
            "timebase": {
                "source": "nominal_rate",
                "mapping": "nearest",
                "jitter_budget_s": -0.01,
            }
        }
        _validate_config_enums(test_data)
    except ValueError as e:
        print(f"  ✓ Caught: {e}")

    print()
    print("=" * 70)
    print("See module docstring for more information")
    print("=" * 70)
@@ -0,0 +1,59 @@
1
+ """DLC (DeepLabCut) inference module.
2
+
3
+ This module provides low-level primitives for running DeepLabCut model inference
4
+ on video files. It follows the 3-tier architecture:
5
+
6
+ - **Low-level**: Functions accept primitives only (Path, int, bool, List)
7
+ - **No Config/Session**: Never imports config, Session, or Manifest
8
+ - **Module-local models**: Owns DLCInferenceOptions, DLCInferenceResult, DLCModelInfo
9
+
10
+ **Key Features**:
11
+ - Batch processing: Single DLC call for multiple videos (optimal GPU utilization)
12
+ - GPU auto-detection: Automatic GPU selection with manual override support
13
+ - Partial failure handling: Gracefully handle individual video failures in batch
14
+ - Idempotency: Content-addressed outputs, skip inference if unchanged
15
+
16
+ **Architecture**:
17
+ - ``dlc/core.py``: Low-level inference functions
18
+ - ``dlc/models.py``: Module-local data models
19
+ - ``dlc/__init__.py``: Public API surface
20
+
21
+ Requirements:
22
+ - FR-5: Optional pose estimation
23
+ - NFR-1: Determinism (idempotent outputs)
24
+ - NFR-2: Performance (batch processing)
25
+
26
+ Example:
27
+ >>> from w2t_bkin.dlc import run_dlc_inference_batch, DLCInferenceOptions
28
+ >>> from pathlib import Path
29
+ >>>
30
+ >>> videos = [Path("cam0.mp4"), Path("cam1.mp4")]
31
+ >>> model_config = Path("models/dlc_model/config.yaml")
32
+ >>> output_dir = Path("output/dlc")
33
+ >>>
34
+ >>> options = DLCInferenceOptions(gputouse=0, save_as_csv=False)
35
+ >>> results = run_dlc_inference_batch(videos, model_config, output_dir, options)
36
+ >>>
37
+ >>> for result in results:
38
+ ... if result.success:
39
+ ... print(f"Success: {result.h5_output_path}")
40
+ ... else:
41
+ ... print(f"Failed: {result.error_message}")
42
+ """
43
+
44
+ from w2t_bkin.dlc.core import DLCInferenceError, auto_detect_gpu, predict_output_paths, run_dlc_inference_batch, validate_dlc_model
45
+ from w2t_bkin.dlc.models import DLCInferenceOptions, DLCInferenceResult, DLCModelInfo
46
+
47
# Public API surface of the dlc package; keep in sync with the re-exports
# from w2t_bkin.dlc.core and w2t_bkin.dlc.models above.
__all__ = [
    # Exception
    "DLCInferenceError",
    # Models
    "DLCInferenceOptions",
    "DLCInferenceResult",
    "DLCModelInfo",
    # Functions
    "run_dlc_inference_batch",
    "validate_dlc_model",
    "predict_output_paths",
    "auto_detect_gpu",
]
+ ]