w2t-bkin 0.0.6__py3-none-any.whl

This diff shows the content of a publicly released package version as it appears in its public registry, and is provided for informational purposes only.
w2t_bkin/utils.py ADDED
@@ -0,0 +1,1093 @@
+ """Utility functions for W2T-BKIN pipeline (Phase 0 - Foundation).
+
+ This module provides core utilities used throughout the pipeline:
+ - Deterministic SHA256 hashing for files and data structures
+ - Path sanitization to prevent directory traversal attacks
+ - File discovery and sorting with glob patterns
+ - Path and file validation with customizable error handling
+ - String sanitization for safe identifiers
+ - File size validation
+ - Directory creation with write permission checking
+ - File checksum computation
+ - TOML file reading
+ - JSON I/O with consistent formatting
+ - Video analysis using FFmpeg/FFprobe
+ - Logger configuration
+
+ The utilities ensure reproducible outputs (NFR-1), secure file handling (NFR-2),
+ and efficient video metadata extraction (FR-2).
+
+ Key Functions:
+ --------------
+ Core Hashing:
+ - compute_hash: Deterministic hashing with key canonicalization for dicts
+ - compute_file_checksum: Compute SHA256/SHA1/MD5 checksum of files
+
+ File Discovery & Sorting:
+ - discover_files: Find files matching glob patterns, return absolute paths
+ - sort_files: Sort files by name or modification time
+
+ Path & File Validation:
+ - sanitize_path: Security validation for file paths (directory traversal prevention)
+ - validate_file_exists: Check file exists and is a file
+ - validate_dir_exists: Check directory exists and is a directory
+ - validate_file_size: Check file size within limits
+
+ String & Directory Operations:
+ - sanitize_string: Remove control characters, limit length
+ - is_nan_or_none: Check if value is None or NaN
+ - convert_matlab_struct: Convert MATLAB struct objects to dictionaries
+ - validate_against_whitelist: Validate value against allowed set
+ - ensure_directory: Create directory with optional write permission check
+
+ File I/O:
+ - read_toml: Load TOML files
+ - read_json: Load JSON files
+ - write_json: Save JSON with Path object support
+
+ Video Analysis:
+ - run_ffprobe: Count frames using ffprobe
+
+ Logging:
+ - configure_logger: Set up structured or standard logging
+
+ Requirements:
+ -------------
+ - NFR-1: Reproducible outputs (deterministic hashing)
+ - NFR-2: Security (path sanitization, validation)
+ - NFR-3: Performance (efficient I/O)
+ - FR-2: Video frame counting
+
+ Acceptance Criteria:
+ --------------------
+ - A18: Deterministic hashing produces identical results for identical inputs
+
+ Example:
+ --------
+ >>> from pathlib import Path
+ >>> from w2t_bkin.utils import compute_hash, sanitize_path, discover_files
+ >>>
+ >>> # Compute deterministic hash
+ >>> data = {"session": "Session-001", "timestamp": "2025-11-12"}
+ >>> hash_value = compute_hash(data)
+ >>> print(hash_value)  # Consistent across runs
+ >>>
+ >>> # Discover files with glob
+ >>> video_files = discover_files(Path("data/raw/session"), "*.avi")
+ >>>
+ >>> # Sanitize file paths
+ >>> safe_path = sanitize_path("data/raw/session.toml")
+ >>> # Raises ValueError for dangerous paths like "../../../etc/passwd"
+ >>>
+ >>> # Validate files exist
+ >>> from w2t_bkin.utils import validate_file_exists
+ >>> validate_file_exists(video_path, IngestError, "Video file required")
+ """
+
+ from datetime import datetime
+ import glob
+ import hashlib
+ import json
+ import logging
+ import math
+ from pathlib import Path
+ import subprocess
+ import sys
+ from typing import Any, Dict, FrozenSet, List, Literal, Optional, Set, Type, Union
+
+ # Module logger
+ logger = logging.getLogger(__name__)
+
+ # Import version info
+ try:
+     from importlib.metadata import version
+ except ImportError:
+     # Python < 3.8
+     from importlib_metadata import version
+
+
+ def parse_datetime(dt_str: str) -> datetime:
+     """Parse ISO 8601 datetime string.
+
+     Supports formats: YYYY-MM-DDTHH:MM:SS and YYYY-MM-DD HH:MM:SS
+
+     Parameters
+     ----------
+     dt_str : str
+         ISO 8601 datetime string
+
+     Returns
+     -------
+     datetime
+         Parsed datetime object
+     """
+     # Try with 'T' separator first
+     try:
+         return datetime.fromisoformat(dt_str)
+     except ValueError:
+         pass
+
+     # Try with space separator
+     try:
+         return datetime.strptime(dt_str, "%Y-%m-%d %H:%M:%S")
+     except ValueError:
+         raise ValueError(f"Invalid datetime format: {dt_str}. Expected ISO 8601: YYYY-MM-DDTHH:MM:SS or YYYY-MM-DD HH:MM:SS")
+
+
+ def get_source_script() -> Optional[str]:
+     """Get the source script that is running (the __main__ script).
+
+     Returns the absolute path of the script file used to create the NWB file.
+     This captures the actual entry point script (e.g., pipeline.py, analysis.py).
+
+     Returns
+     -------
+     Optional[str]
+         Absolute path to the main script file, or None if not available
+
+     Example
+     -------
+     >>> # When running: python pipeline.py
+     >>> get_source_script()
+     '/home/user/project/pipeline.py'
+
+     >>> # When running interactively or from a module
+     >>> get_source_script()
+     None
+     """
+     try:
+         # sys.argv[0] contains the script that was invoked
+         if sys.argv and sys.argv[0]:
+             script_path = Path(sys.argv[0]).resolve()
+             # Only return if it's an actual file (not '<stdin>' or similar)
+             if script_path.exists() and script_path.is_file():
+                 return str(script_path)
+     except (IndexError, OSError):
+         pass
+
+     return None
+
+
+ def get_source_script_file_name() -> Optional[str]:
+     """Get the name of the source script file (without path).
+
+     Returns just the filename of the script used to create the NWB file.
+
+     Returns
+     -------
+     Optional[str]
+         Name of the main script file, or None if not available
+
+     Example
+     -------
+     >>> # When running: python /home/user/project/pipeline.py
+     >>> get_source_script_file_name()
+     'pipeline.py'
+     """
+     script_path = get_source_script()
+     if script_path:
+         return Path(script_path).name
+     return None
+
+
+ def get_software_packages() -> List[str]:
+     """Get list of software package names and versions used.
+
+     Returns a list of package names with versions in the format:
+     "package_name==version"
+
+     This captures key dependencies for reproducibility and provenance tracking.
+
+     Returns
+     -------
+     List[str]
+         List of package names with versions (e.g., ["pynwb==3.1.0", "w2t_bkin==0.0.3"])
+
+     Example
+     -------
+     >>> packages = get_software_packages()
+     >>> print(packages)
+     ['w2t_bkin==0.0.3', 'pynwb==3.1.0', 'hdmf==4.1.0', ...]
+     """
+     packages = []
+
+     # Core packages to track
+     package_names = [
+         "w2t_bkin",  # This package
+         "pynwb",  # NWB file creation
+         "hdmf",  # Data format
+         "deeplabcut",  # Pose estimation
+         "facemap",  # Facial metrics
+         "scipy",  # Scientific computing
+         "numpy",  # Array operations
+         "pandas",  # Data frames
+         "torch",  # Deep learning (if used)
+     ]
+
+     for package_name in package_names:
+         try:
+             pkg_version = version(package_name)
+             packages.append(f"{package_name}=={pkg_version}")
+         except Exception:
+             # Package not installed or version not available
+             continue
+
+     return packages
+
+
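Taken together, the provenance helpers above support a session-level record of what produced an output file. A minimal sketch; record_provenance and its field names are illustrative, not part of the package API:

    from w2t_bkin.utils import compute_hash, get_software_packages, get_source_script

    def record_provenance(config: dict) -> dict:
        # Everything here is deterministic except the script path, which
        # depends on how the pipeline was launched.
        return {
            "source_script": get_source_script(),  # None in interactive sessions
            "software": get_software_packages(),   # e.g. ["pynwb==3.1.0", ...]
            "config_hash": compute_hash(config),   # 64-char SHA256 hex digest
        }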
+ def compute_hash(data: Union[str, Dict[str, Any]]) -> str:
+     """Compute deterministic SHA256 hash of input data.
+
+     For dictionaries, canonicalizes by sorting keys before hashing.
+
+     Args:
+         data: String or dictionary to hash
+
+     Returns:
+         SHA256 hex digest (64 characters)
+     """
+     if isinstance(data, dict):
+         # Canonicalize: sort keys and convert to compact JSON
+         canonical = json.dumps(data, sort_keys=True, separators=(",", ":"))
+         data_bytes = canonical.encode("utf-8")
+     else:
+         data_bytes = data.encode("utf-8")
+
+     return hashlib.sha256(data_bytes).hexdigest()
+
+
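Because dict inputs are serialized with sort_keys=True before hashing, key insertion order cannot change the digest. A quick self-check (a sketch, not module code):

    from w2t_bkin.utils import compute_hash

    a = compute_hash({"session": "S1", "subject": "M42"})
    b = compute_hash({"subject": "M42", "session": "S1"})  # reordered keys
    assert a == b        # canonical JSON makes these identical (A18)
    assert len(a) == 64  # SHA256 hex digest

Note that dict values must be JSON-serializable; anything else raises TypeError from json.dumps.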
+ def sanitize_path(path: Union[str, Path], base: Optional[Path] = None) -> Path:
+     """Sanitize path to prevent directory traversal attacks.
+
+     Args:
+         path: Path to sanitize
+         base: Optional base directory to restrict path to
+
+     Returns:
+         Sanitized Path object
+
+     Raises:
+         ValueError: If path attempts directory traversal
+     """
+     path_obj = Path(path)
+
+     # Check for directory traversal patterns
+     if ".." in path_obj.parts:
+         raise ValueError(f"Directory traversal not allowed: {path}")
+
+     # If base provided, ensure resolved path is within base.
+     # Path.relative_to rejects anything outside base, unlike a plain
+     # string-prefix check, which would accept siblings such as /base_evil
+     # when base is /base.
+     if base is not None:
+         base = Path(base).resolve()
+         resolved = (base / path_obj).resolve()
+         try:
+             resolved.relative_to(base)
+         except ValueError:
+             raise ValueError(f"Path {path} outside allowed base {base}")
+         return resolved
+
+     return path_obj
+
+
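A usage sketch for the base-restricted mode (paths are illustrative):

    from pathlib import Path
    from w2t_bkin.utils import sanitize_path

    safe = sanitize_path("Session-001/video.avi", base=Path("/data/raw"))
    # -> /data/raw/Session-001/video.avi

    for bad in ("../secrets.toml", "a/../../etc/passwd"):
        try:
            sanitize_path(bad, base=Path("/data/raw"))
        except ValueError as err:
            print(f"rejected: {err}")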
+ def discover_files(base_dir: Path, pattern: str, sort: bool = True) -> List[Path]:
+     """Discover files matching glob pattern and return absolute paths.
+
+     Args:
+         base_dir: Base directory to resolve pattern from
+         pattern: Glob pattern (relative to base_dir)
+         sort: If True, sort files by name (default: True)
+
+     Returns:
+         List of absolute Path objects
+
+     Example:
+         >>> files = discover_files(Path("data/raw"), "*.avi")
+         >>> files = discover_files(session_dir, "Bpod/*.mat", sort=True)
+     """
+     full_pattern = str(base_dir / pattern)
+     file_paths = [Path(p).resolve() for p in glob.glob(full_pattern)]
+
+     if sort:
+         file_paths.sort(key=lambda p: p.name)
+
+     return file_paths
+
+
+ def sort_files(files: List[Path], strategy: Literal["name_asc", "name_desc", "time_asc", "time_desc"]) -> List[Path]:
+     """Sort file list by specified strategy.
+
+     Args:
+         files: List of file paths to sort
+         strategy: Sorting strategy:
+             - "name_asc": Sort by filename ascending
+             - "name_desc": Sort by filename descending
+             - "time_asc": Sort by modification time ascending (oldest first)
+             - "time_desc": Sort by modification time descending (newest first)
+
+     Returns:
+         Sorted list of Path objects (new list, does not modify input)
+
+     Example:
+         >>> files = sort_files(discovered_files, "time_desc")
+     """
+     sorted_files = files.copy()
+
+     if strategy == "name_asc":
+         sorted_files.sort(key=lambda p: p.name)
+     elif strategy == "name_desc":
+         sorted_files.sort(key=lambda p: p.name, reverse=True)
+     elif strategy == "time_asc":
+         sorted_files.sort(key=lambda p: p.stat().st_mtime)
+     elif strategy == "time_desc":
+         sorted_files.sort(key=lambda p: p.stat().st_mtime, reverse=True)
+     else:
+         raise ValueError(f"Invalid sort strategy: {strategy}")
+
+     return sorted_files
+
+
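The two functions compose naturally; discover_files already sorts by name, so sort_files is only needed for a different strategy such as newest-first (the session layout below is illustrative):

    from pathlib import Path
    from w2t_bkin.utils import discover_files, sort_files

    videos = discover_files(Path("data/raw/Session-001"), "Cam-*/*.avi")
    newest_first = sort_files(videos, "time_desc")  # original list is untouched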
+ def validate_file_exists(path: Path, error_class: Type[Exception] = FileNotFoundError, message: Optional[str] = None) -> None:
+     """Validate file exists and is a file, not a directory.
+
+     Args:
+         path: Path to validate
+         error_class: Exception class to raise on validation failure
+         message: Optional custom error message
+
+     Raises:
+         error_class: If file doesn't exist or is not a file
+
+     Example:
+         >>> validate_file_exists(video_path, IngestError, "Video file required")
+     """
+     if not path.exists():
+         msg = message or f"File not found: {path}"
+         raise error_class(msg)
+
+     if not path.is_file():
+         msg = message or f"Path is not a file: {path}"
+         raise error_class(msg)
+
+
+ def validate_dir_exists(path: Path, error_class: Type[Exception] = FileNotFoundError, message: Optional[str] = None) -> None:
+     """Validate directory exists and is a directory, not a file.
+
+     Args:
+         path: Path to validate
+         error_class: Exception class to raise on validation failure
+         message: Optional custom error message
+
+     Raises:
+         error_class: If directory doesn't exist or is not a directory
+
+     Example:
+         >>> validate_dir_exists(output_dir, NWBError, "Output directory required")
+     """
+     if not path.exists():
+         msg = message or f"Directory not found: {path}"
+         raise error_class(msg)
+
+     if not path.is_dir():
+         msg = message or f"Path is not a directory: {path}"
+         raise error_class(msg)
+
+
+ def validate_file_size(path: Path, max_size_mb: float) -> float:
+     """Validate file size within limits, return size in MB.
+
+     Args:
+         path: Path to file
+         max_size_mb: Maximum allowed size in megabytes
+
+     Returns:
+         File size in MB
+
+     Raises:
+         ValueError: If file exceeds size limit
+
+     Example:
+         >>> size_mb = validate_file_size(bpod_path, max_size_mb=100)
+     """
+     file_size_mb = path.stat().st_size / (1024 * 1024)
+
+     if file_size_mb > max_size_mb:
+         raise ValueError(f"File too large: {file_size_mb:.1f}MB exceeds {max_size_mb}MB limit")
+
+     return file_size_mb
+
+
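The validators are designed to raise a caller-chosen exception type at pipeline boundaries. IngestError below is hypothetical (the docstrings mention it, but this module does not define it):

    from pathlib import Path
    from w2t_bkin.utils import validate_file_exists, validate_file_size

    class IngestError(Exception):
        """Hypothetical pipeline error type for this sketch."""

    def check_bpod_file(path: Path) -> float:
        validate_file_exists(path, IngestError, f"Bpod file required: {path}")
        return validate_file_size(path, max_size_mb=100)  # ValueError if larger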
+ def sanitize_string(
+     text: str,
+     max_length: int = 100,
+     allowed_pattern: Literal["alphanumeric", "alphanumeric_-", "alphanumeric_-_", "printable"] = "alphanumeric_-_",
+     default: str = "unknown",
+ ) -> str:
+     """Sanitize string by removing control characters and limiting length.
+
+     Args:
+         text: String to sanitize
+         max_length: Maximum length of output string
+         allowed_pattern: Character allowance pattern:
+             - "alphanumeric": Only letters and numbers
+             - "alphanumeric_-": Letters, numbers, hyphens
+             - "alphanumeric_-_": Letters, numbers, hyphens, underscores
+             - "printable": All printable characters
+         default: Default value if sanitized string is empty
+
+     Returns:
+         Sanitized string
+
+     Example:
+         >>> safe_id = sanitize_string("Session-001", allowed_pattern="alphanumeric_-")
+         >>> safe_event = sanitize_string(raw_event_name, max_length=50)
+     """
+     if not isinstance(text, str):
+         return default
+
+     # Keep only the characters allowed by the pattern
+     if allowed_pattern == "alphanumeric":
+         sanitized = "".join(c for c in text if c.isalnum())
+     elif allowed_pattern == "alphanumeric_-":
+         sanitized = "".join(c for c in text if c.isalnum() or c == "-")
+     elif allowed_pattern == "alphanumeric_-_":
+         sanitized = "".join(c for c in text if c.isalnum() or c in "-_")
+     elif allowed_pattern == "printable":
+         sanitized = "".join(c for c in text if c.isprintable())
+     else:
+         raise ValueError(f"Invalid allowed_pattern: {allowed_pattern}")
+
+     # Limit length
+     sanitized = sanitized[:max_length]
+
+     # Return default if empty
+     if not sanitized:
+         return default
+
+     return sanitized
+
+
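Expected behavior for a few representative inputs, per the filtering rules above (a sketch, not a doctest from the package):

    from w2t_bkin.utils import sanitize_string

    sanitize_string("Session-001!", allowed_pattern="alphanumeric_-")  # 'Session-001'
    sanitize_string("\x07\x08")                # 'unknown' (empty after filtering)
    sanitize_string("x" * 500, max_length=50)  # truncated to 50 characters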
+ def is_nan_or_none(value: Any) -> bool:
+     """Check if value is None or NaN (for float values).
+
+     Args:
+         value: Value to check
+
+     Returns:
+         True if value is None or NaN, False otherwise
+
+     Example:
+         >>> is_nan_or_none(None)  # True
+         >>> is_nan_or_none(float('nan'))  # True
+         >>> is_nan_or_none(0.0)  # False
+         >>> is_nan_or_none([1.0, 2.0])  # False
+     """
+     if value is None:
+         return True
+     if isinstance(value, float) and math.isnan(value):
+         return True
+     return False
+
+
+ def convert_matlab_struct(obj: Any) -> Dict[str, Any]:
+     """Convert MATLAB struct object to dictionary.
+
+     Handles scipy.io mat_struct objects by extracting non-private attributes.
+     If already a dict, returns as-is. For other types, returns empty dict.
+
+     Args:
+         obj: MATLAB struct object, dictionary, or other type
+
+     Returns:
+         Dictionary representation
+
+     Example:
+         >>> # With scipy mat_struct
+         >>> from scipy.io import loadmat
+         >>> data = loadmat("file.mat")
+         >>> session_data = convert_matlab_struct(data["SessionData"])
+         >>>
+         >>> # With plain dict
+         >>> convert_matlab_struct({"key": "value"})  # Returns as-is
+     """
+     if hasattr(obj, "__dict__"):
+         # scipy mat_struct or similar object with __dict__
+         return {k: v for k, v in obj.__dict__.items() if not k.startswith("_")}
+     elif isinstance(obj, dict):
+         # Already a dictionary
+         return obj
+     else:
+         # Unsupported type - return empty dict
+         return {}
+
+
+ def validate_against_whitelist(value: str, whitelist: Union[Set[str], FrozenSet[str]], default: str, warn: bool = True) -> str:
+     """Validate string value against whitelist, return default if invalid.
+
+     Args:
+         value: Value to validate
+         whitelist: Set or frozenset of allowed values
+         default: Default value to return if validation fails
+         warn: If True, log warning when value not in whitelist
+
+     Returns:
+         Value if in whitelist, otherwise default
+
+     Example:
+         >>> outcomes = frozenset(["hit", "miss", "correct"])
+         >>> validate_against_whitelist("hit", outcomes, "unknown")  # "hit"
+         >>> validate_against_whitelist("invalid", outcomes, "unknown")  # "unknown"
+     """
+     if value in whitelist:
+         return value
+
+     if warn:
+         # Use the module-level logger rather than re-creating one per call
+         logger.warning(f"Invalid value '{value}', defaulting to '{default}'")
+
+     return default
+
+
+ def ensure_directory(path: Path, check_writable: bool = False) -> Path:
+     """Ensure directory exists, optionally check write permissions.
+
+     Args:
+         path: Directory path to ensure
+         check_writable: If True, verify directory is writable
+
+     Returns:
+         The path (for chaining)
+
+     Raises:
+         OSError: If directory cannot be created
+         PermissionError: If check_writable=True and directory is not writable
+
+     Example:
+         >>> output_dir = ensure_directory(Path("data/processed"), check_writable=True)
+     """
+     if not path.exists():
+         path.mkdir(parents=True, exist_ok=True)
+
+     if not path.is_dir():
+         raise OSError(f"Path exists but is not a directory: {path}")
+
+     if check_writable:
+         # Try to write a test file to check permissions
+         test_file = path / ".test_write"
+         try:
+             test_file.touch()
+             test_file.unlink()
+         except Exception as e:
+             raise PermissionError(f"Directory is not writable: {path}. Error: {e}")
+
+     return path
+
+
+ def compute_file_checksum(file_path: Path, algorithm: str = "sha256", chunk_size: int = 8192) -> str:
+     """Compute checksum of file using specified algorithm.
+
+     Args:
+         file_path: Path to file
+         algorithm: Hash algorithm (sha256, sha1, md5)
+         chunk_size: Read chunk size in bytes
+
+     Returns:
+         Hex digest of file checksum
+
+     Raises:
+         FileNotFoundError: If file doesn't exist
+         ValueError: If algorithm is unsupported
+
+     Example:
+         >>> checksum = compute_file_checksum(video_path)
+         >>> checksum = compute_file_checksum(video_path, algorithm="sha1")
+     """
+     if not file_path.exists():
+         raise FileNotFoundError(f"File not found: {file_path}")
+
+     # Create hash object
+     if algorithm == "sha256":
+         hasher = hashlib.sha256()
+     elif algorithm == "sha1":
+         hasher = hashlib.sha1()
+     elif algorithm == "md5":
+         hasher = hashlib.md5()
+     else:
+         raise ValueError(f"Unsupported hash algorithm: {algorithm}")
+
+     # Read file in chunks and update hash
+     with open(file_path, "rb") as f:
+         while chunk := f.read(chunk_size):
+             hasher.update(chunk)
+
+     return hasher.hexdigest()
+
+
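A manifest-style verification sketch; the manifest dict and file contents are illustrative:

    import tempfile
    from pathlib import Path
    from w2t_bkin.utils import compute_file_checksum

    with tempfile.NamedTemporaryFile(suffix=".bin", delete=False) as f:
        f.write(b"hello")
        path = Path(f.name)

    digest = compute_file_checksum(path)  # SHA256, streamed in 8 KiB chunks
    manifest = {path.name: digest}
    assert compute_file_checksum(path) == manifest[path.name]
    path.unlink()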
+ def read_toml(path: Union[str, Path]) -> Dict[str, Any]:
+     """Read TOML file into dictionary.
+
+     Args:
+         path: Path to TOML file (str or Path)
+
+     Returns:
+         Dictionary with parsed TOML data
+
+     Raises:
+         FileNotFoundError: If file doesn't exist
+
+     Example:
+         >>> data = read_toml("config.toml")
+         >>> data = read_toml(Path("session.toml"))
+     """
+     path = Path(path)  # Path() is a no-op for Path inputs
+
+     if not path.exists():
+         raise FileNotFoundError(f"TOML file not found: {path}")
+
+     try:
+         import tomllib  # Python 3.11+
+     except ImportError:
+         import tomli as tomllib  # Backport for older interpreters
+
+     with open(path, "rb") as f:
+         return tomllib.load(f)
+
+
+ def write_json(data: Dict[str, Any], path: Union[str, Path], indent: int = 2) -> None:
+     """Write data to JSON file with custom encoder for Path objects.
+
+     Args:
+         data: Dictionary to write
+         path: Output file path
+         indent: JSON indentation (default: 2 spaces)
+     """
+
+     class PathEncoder(json.JSONEncoder):
+         """Custom JSON encoder that handles Path objects."""
+
+         def default(self, obj):
+             if isinstance(obj, Path):
+                 return str(obj)
+             return super().default(obj)
+
+     path_obj = Path(path)
+     path_obj.parent.mkdir(parents=True, exist_ok=True)
+
+     with open(path_obj, "w", encoding="utf-8") as f:
+         json.dump(data, f, indent=indent, cls=PathEncoder)
+
+
+ def read_json(path: Union[str, Path]) -> Dict[str, Any]:
+     """Read JSON file into dictionary.
+
+     Args:
+         path: Input file path
+
+     Returns:
+         Dictionary with parsed JSON data
+     """
+     with open(path, "r", encoding="utf-8") as f:
+         return json.load(f)
+
+
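Worth noting: write_json converts Path values to strings via PathEncoder, but read_json does not convert them back, so a round trip yields str. A sketch with illustrative names:

    from pathlib import Path
    from w2t_bkin.utils import read_json, write_json

    data = {"video": Path("data/raw/cam0.avi"), "frames": 1200}
    write_json(data, "manifest.json")
    loaded = read_json("manifest.json")
    assert loaded["video"] == "data/raw/cam0.avi"  # str, not Path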
+ def configure_logger(name: str, level: str = "INFO", structured: bool = False) -> logging.Logger:
+     """Configure logger with specified settings.
+
+     Args:
+         name: Logger name
+         level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
+         structured: If True, use structured (JSON) logging
+
+     Returns:
+         Configured logger instance
+     """
+     logger = logging.getLogger(name)
+     logger.setLevel(getattr(logging, level.upper()))
+
+     # Remove existing handlers
+     logger.handlers.clear()
+
+     handler = logging.StreamHandler()
+
+     if structured:
+         # JSON structured logging
+         formatter = logging.Formatter('{"timestamp":"%(asctime)s","level":"%(levelname)s","name":"%(name)s","message":"%(message)s"}')
+     else:
+         # Standard logging
+         formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+
+     return logger
+
+
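A small usage sketch. One caveat with the structured mode as written: the %-style formatter does not escape quotes inside the message, so a message containing a double quote produces an invalid JSON line:

    from w2t_bkin.utils import configure_logger

    log = configure_logger("w2t_bkin.ingest", level="DEBUG", structured=True)
    log.info("session discovered")
    # {"timestamp":"2025-11-12 10:00:00,123","level":"INFO","name":"w2t_bkin.ingest","message":"session discovered"}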
+ class VideoAnalysisError(Exception):
+     """Error during video analysis operations."""
+
+     pass
+
+
+ def run_ffprobe(video_path: Path, timeout: int = 30) -> int:
+     """Count frames in a video file using ffprobe.
+
+     Uses ffprobe to accurately count video frames by decoding the stream.
+     This is more reliable than using OpenCV for corrupted or unusual video formats.
+
+     Args:
+         video_path: Path to video file
+         timeout: Maximum time in seconds to wait for ffprobe (default: 30)
+
+     Returns:
+         Number of frames in video
+
+     Raises:
+         VideoAnalysisError: If video file is invalid or ffprobe fails
+         FileNotFoundError: If video file does not exist
+         ValueError: If video_path is not a valid path
+
+     Security:
+         - Input path validation to prevent command injection
+         - Subprocess timeout to prevent hanging
+         - stderr capture for diagnostic information
+     """
+     # Input validation
+     if not isinstance(video_path, Path):
+         video_path = Path(video_path)
+
+     if not video_path.exists():
+         raise FileNotFoundError(f"Video file not found: {video_path}")
+
+     if not video_path.is_file():
+         raise ValueError(f"Path is not a file: {video_path}")
+
+     # Sanitize path - resolve to absolute path to prevent injection
+     video_path = video_path.resolve()
+
+     # ffprobe command to count frames accurately
+     # -v error: only show errors
+     # -select_streams v:0: select first video stream
+     # -count_frames: actually count frames (slower but accurate)
+     # -show_entries stream=nb_read_frames: output only frame count
+     # -of csv=p=0: output as CSV without header
+     command = [
+         "ffprobe",
+         "-v",
+         "error",
+         "-select_streams",
+         "v:0",
+         "-count_frames",
+         "-show_entries",
+         "stream=nb_read_frames",
+         "-of",
+         "csv=p=0",
+         str(video_path),
+     ]
+
+     try:
+         # Run ffprobe with timeout and capture output
+         result = subprocess.run(
+             command,
+             capture_output=True,
+             text=True,
+             timeout=timeout,
+             check=True,
+         )
+
+         # Parse output - should be a single integer
+         output = result.stdout.strip()
+
+         if not output:
+             raise VideoAnalysisError(f"ffprobe returned empty output for: {video_path}")
+
+         try:
+             frame_count = int(output)
+         except ValueError:
+             raise VideoAnalysisError(f"ffprobe returned non-integer output: {output}")
+
+         if frame_count < 0:
+             raise VideoAnalysisError(f"ffprobe returned negative frame count: {frame_count}")
+
+         return frame_count
+
+     except subprocess.TimeoutExpired:
+         raise VideoAnalysisError(f"ffprobe timed out after {timeout}s for: {video_path}")
+
+     except subprocess.CalledProcessError as e:
+         # ffprobe failed - provide diagnostic information
+         stderr_msg = e.stderr.strip() if e.stderr else "No error message"
+         raise VideoAnalysisError(f"ffprobe failed for {video_path}: {stderr_msg}")
+
+     except VideoAnalysisError:
+         # Diagnostics raised inside the try block above; re-raise unchanged
+         # instead of letting the generic handler below double-wrap them as
+         # "Unexpected error".
+         raise
+
+     except Exception as e:
+         # Unexpected error
+         raise VideoAnalysisError(f"Unexpected error running ffprobe: {e}")
+
+
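Because -count_frames decodes every frame, run_ffprobe is exact but slow on long videos. A faster, less reliable variant (an assumption for illustration, not part of this module) reads the container's nb_frames metadata, which many formats populate but some report as "N/A":

    import subprocess
    from pathlib import Path
    from typing import Optional

    def estimate_frames_from_metadata(video_path: Path, timeout: int = 10) -> Optional[int]:
        """Sketch: container-reported frame count, or None if absent."""
        result = subprocess.run(
            ["ffprobe", "-v", "error", "-select_streams", "v:0",
             "-show_entries", "stream=nb_frames", "-of", "csv=p=0", str(video_path)],
            capture_output=True, text=True, timeout=timeout, check=True,
        )
        out = result.stdout.strip()
        return int(out) if out.isdigit() else None  # "N/A" when metadata is missing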
+ # =============================================================================
+ # Events Helper Functions (Numpy Array Handling)
+ # =============================================================================
+
+
+ def to_scalar(value: Any, index: int) -> Any:
+     """Extract scalar from array or list.
+
+     Args:
+         value: Numpy array, list, tuple, or scalar
+         index: Index to extract
+
+     Returns:
+         Scalar value at index
+
+     Raises:
+         IndexError: Index out of bounds
+     """
+     import numpy as np
+
+     if isinstance(value, np.ndarray):
+         # Handle numpy arrays (including 0-d arrays)
+         if value.ndim == 0:
+             return value.item()
+         return value[index].item() if hasattr(value[index], "item") else value[index]
+     elif isinstance(value, (list, tuple)):
+         return value[index]
+     else:
+         # Assume it's already a scalar
+         return value
+
+
+ def to_list(value: Any) -> List[Any]:
+     """Convert array or scalar to Python list.
+
+     Args:
+         value: Numpy array, list, tuple, or scalar
+
+     Returns:
+         Python list
+     """
+     import numpy as np
+
+     if isinstance(value, np.ndarray):
+         return value.tolist()
+     elif isinstance(value, (list, tuple)):
+         return list(value)
+     else:
+         # Scalar value
+         return [value]
+
+
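Typical inputs and outputs for the numpy helpers (a sketch; numpy is imported lazily inside each function, so it is only required when these are called):

    import numpy as np
    from w2t_bkin.utils import to_list, to_scalar

    to_scalar(np.array(3.5), index=0)        # 3.5   (0-d array -> .item())
    to_scalar(np.array([1, 2, 3]), index=1)  # 2     (numpy int -> Python int)
    to_scalar("hit", index=0)                # 'hit' (scalars pass through)
    to_list(np.array([1, 2]))                # [1, 2]
    to_list(5)                               # [5]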
+ if __name__ == "__main__":
+     """Usage examples for utils module."""
+     import tempfile
+
+     print("=" * 70)
+     print("W2T-BKIN Utils Module - Usage Examples")
+     print("=" * 70)
+     print()
+
+     # Example 1: Compute hash
+     print("Example 1: Compute Hash")
+     print("-" * 50)
+     test_data = {"session_id": "Session-000001", "timestamp": "2025-11-12"}
+     hash_result = compute_hash(test_data)
+     print(f"Data: {test_data}")
+     print(f"Hash: {hash_result}")
+     print()
+
+     # Example 2: Sanitize path
+     print("Example 2: Sanitize Path")
+     print("-" * 50)
+     safe_path = sanitize_path("data/raw/Session-000001")
+     print("Input: data/raw/Session-000001")
+     print(f"Sanitized: {safe_path}")
+
+     try:
+         dangerous = sanitize_path("../../etc/passwd")
+         print(f"Dangerous path: {dangerous}")
+     except ValueError as e:
+         print(f"Blocked directory traversal: {e}")
+     print()
+
+     # Example 3: JSON I/O
+     print("Example 3: JSON I/O")
+     print("-" * 50)
+     with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
+         temp_path = Path(f.name)
+
+     test_obj = {"key": "value", "number": 42}
+     write_json(test_obj, temp_path)
+     print(f"Wrote to: {temp_path.name}")
+
+     loaded = read_json(temp_path)
+     print(f"Read back: {loaded}")
+     temp_path.unlink()
+     print()
+
+     print("=" * 70)
+     print("Examples completed. See module docstring for more details.")
+     print("=" * 70)
+
+
+ # =============================================================================
+ # Video and TTL Counting Utilities
+ # =============================================================================
+
+
+ def count_video_frames(video_path: Path) -> int:
+     """Count frames in a video file using ffprobe or synthetic stub.
+
+     This function counts video frames using ffprobe. It includes special
+     handling for synthetic stub videos (used in testing) and provides
+     robust error handling for unreadable or missing files.
+
+     Args:
+         video_path: Path to video file
+
+     Returns:
+         Number of frames in video (0 if file not found or empty)
+
+     Raises:
+         RuntimeError: If video file cannot be analyzed
+
+     Example:
+         >>> from pathlib import Path
+         >>> frame_count = count_video_frames(Path("video.avi"))
+         >>> print(f"Frames: {frame_count}")
+     """
+     # Validate input
+     if not video_path.exists():
+         logger.warning(f"Video file not found: {video_path}")
+         return 0
+
+     # Handle empty files
+     if video_path.stat().st_size == 0:
+         logger.warning(f"Video file is empty: {video_path}")
+         return 0
+
+     # Check if this is a synthetic stub video
+     try:
+         # Try importing synthetic module (only available in test/synthetic context)
+         from synthetic.video_synth import count_stub_frames, is_synthetic_stub
+
+         if is_synthetic_stub(video_path):
+             frame_count = count_stub_frames(video_path)
+             logger.debug(f"Counted {frame_count} frames in synthetic stub {video_path.name}")
+             return frame_count
+     except ImportError:
+         # Synthetic module not available - continue with normal ffprobe
+         pass
+
+     # Use ffprobe to count frames
+     try:
+         frame_count = run_ffprobe(video_path)
+         logger.debug(f"Counted {frame_count} frames in {video_path.name}")
+         return frame_count
+     except Exception as e:
+         # Log error and raise - frame counting failure is critical
+         logger.error(f"Failed to count frames in {video_path}: {e}")
+         raise RuntimeError(f"Could not count frames in video {video_path}: {e}")
+
+
+ def normalize_keypoints_to_dict(keypoints) -> Dict[str, Dict]:
+     """Convert keypoints to standard dict format (name -> keypoint data).
+
+     Handles multiple input formats:
+     - KeypointsDict (custom dict that iterates over values)
+     - Regular dict with keypoint name as key
+     - List of keypoint dicts
+
+     Args:
+         keypoints: Keypoints in various formats
+
+     Returns:
+         Dictionary mapping keypoint name to keypoint data dict
+
+     Example:
+         >>> kp_list = [{"name": "nose", "x": 10, "y": 20}]
+         >>> result = normalize_keypoints_to_dict(kp_list)
+         >>> result["nose"]
+         {'name': 'nose', 'x': 10, 'y': 20}
+     """
+     if isinstance(keypoints, dict):
+         # Plain dicts and KeypointsDict instances already map name -> data,
+         # so both are returned unchanged.
+         return keypoints
+     elif isinstance(keypoints, list):
+         # Convert list to dict keyed on each keypoint's name
+         return {kp["name"]: kp for kp in keypoints}
+     else:
+         # Fallback for unknown types
+         return {}
+
+
+ def log_missing_keypoints(
+     frame_index: int,
+     expected_names: Set[str],
+     actual_names: Set[str],
+     logger_instance: logging.Logger,
+ ) -> None:
+     """Log warning for missing keypoints in a frame.
+
+     Args:
+         frame_index: Frame number for logging
+         expected_names: Set of expected keypoint names
+         actual_names: Set of actual keypoint names found
+         logger_instance: Logger to use for warning
+     """
+     missing = expected_names - actual_names
+     if missing:
+         logger_instance.warning(f"Frame {frame_index}: Missing keypoints {missing}")
+
+
+ def derive_bodyparts_from_data(data: List[Dict]) -> List[str]:
+     """Extract canonical bodypart names from harmonized pose data.
+
+     Args:
+         data: List of pose frame dictionaries with keypoints
+
+     Returns:
+         Sorted list of bodypart names
+
+     Raises:
+         ValueError: If data is empty or has no keypoints
+
+     Example:
+         >>> frames = [{"keypoints": {"nose": {...}, "ear_left": {...}}}]
+         >>> derive_bodyparts_from_data(frames)
+         ['ear_left', 'nose']
+     """
+     if not data:
+         raise ValueError("Cannot derive bodyparts from empty data")
+
+     first_frame_keypoints = normalize_keypoints_to_dict(data[0].get("keypoints", {}))
+
+     if not first_frame_keypoints:
+         raise ValueError("First frame has no keypoints")
+
+     # Sort for consistency across runs
+     return sorted(first_frame_keypoints.keys())
+
+
+ def count_ttl_pulses(ttl_path: Path) -> int:
+     """Count TTL pulses from log file.
+
+     Counts non-empty lines in a TTL log file. Each line represents one
+     TTL pulse event.
+
+     Args:
+         ttl_path: Path to TTL log file
+
+     Returns:
+         Number of pulses in file (0 if file not found or unreadable)
+
+     Example:
+         >>> from pathlib import Path
+         >>> pulse_count = count_ttl_pulses(Path("camera_ttl.log"))
+         >>> print(f"Pulses: {pulse_count}")
+     """
+     if not ttl_path.exists():
+         return 0
+
+     # Count non-empty lines in TTL file (each line = one pulse); stream the
+     # file instead of loading it all into memory with readlines()
+     try:
+         with open(ttl_path, "r") as f:
+             return sum(1 for line in f if line.strip())
+     except Exception:
+         return 0
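A typical consumer compares the two counters to flag dropped frames or a miswired TTL line; the tolerance and file layout below are illustrative:

    from pathlib import Path
    from w2t_bkin.utils import count_ttl_pulses, count_video_frames

    def check_camera_sync(video: Path, ttl_log: Path, tolerance: int = 2) -> bool:
        """Sketch: True when frame count and pulse count agree within tolerance."""
        frames = count_video_frames(video)  # 0 if missing or empty
        pulses = count_ttl_pulses(ttl_log)  # 0 if missing or unreadable
        return abs(frames - pulses) <= tolerance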