w2t_bkin-0.0.6-py3-none-any.whl
- w2t_bkin/__init__.py +85 -0
- w2t_bkin/behavior/__init__.py +115 -0
- w2t_bkin/behavior/core.py +1027 -0
- w2t_bkin/bpod/__init__.py +38 -0
- w2t_bkin/bpod/core.py +519 -0
- w2t_bkin/config.py +625 -0
- w2t_bkin/dlc/__init__.py +59 -0
- w2t_bkin/dlc/core.py +448 -0
- w2t_bkin/dlc/models.py +124 -0
- w2t_bkin/exceptions.py +426 -0
- w2t_bkin/facemap/__init__.py +42 -0
- w2t_bkin/facemap/core.py +397 -0
- w2t_bkin/facemap/models.py +134 -0
- w2t_bkin/pipeline.py +665 -0
- w2t_bkin/pose/__init__.py +48 -0
- w2t_bkin/pose/core.py +227 -0
- w2t_bkin/pose/io.py +363 -0
- w2t_bkin/pose/skeleton.py +165 -0
- w2t_bkin/pose/ttl_mock.py +477 -0
- w2t_bkin/session.py +423 -0
- w2t_bkin/sync/__init__.py +72 -0
- w2t_bkin/sync/core.py +678 -0
- w2t_bkin/sync/stats.py +176 -0
- w2t_bkin/sync/timebase.py +311 -0
- w2t_bkin/sync/ttl.py +254 -0
- w2t_bkin/transcode/__init__.py +38 -0
- w2t_bkin/transcode/core.py +303 -0
- w2t_bkin/transcode/models.py +96 -0
- w2t_bkin/ttl/__init__.py +64 -0
- w2t_bkin/ttl/core.py +518 -0
- w2t_bkin/ttl/models.py +19 -0
- w2t_bkin/utils.py +1093 -0
- w2t_bkin-0.0.6.dist-info/METADATA +145 -0
- w2t_bkin-0.0.6.dist-info/RECORD +36 -0
- w2t_bkin-0.0.6.dist-info/WHEEL +4 -0
- w2t_bkin-0.0.6.dist-info/licenses/LICENSE +201 -0
w2t_bkin/utils.py
ADDED
@@ -0,0 +1,1093 @@
"""Utility functions for W2T-BKIN pipeline (Phase 0 - Foundation).

This module provides core utilities used throughout the pipeline:
- Deterministic SHA256 hashing for files and data structures
- Path sanitization to prevent directory traversal attacks
- File discovery and sorting with glob patterns
- Path and file validation with customizable error handling
- String sanitization for safe identifiers
- File size validation
- Directory creation with write permission checking
- File checksum computation
- TOML file reading
- JSON I/O with consistent formatting
- Video analysis using FFmpeg/FFprobe
- Logger configuration

The utilities ensure reproducible outputs (NFR-1), secure file handling (NFR-2),
and efficient video metadata extraction (FR-2).

Key Functions:
--------------
Core Hashing:
- compute_hash: Deterministic hashing with key canonicalization for dicts
- compute_file_checksum: Compute SHA256/SHA1/MD5 checksum of files

File Discovery & Sorting:
- discover_files: Find files matching glob patterns, return absolute paths
- sort_files: Sort files by name or modification time

Path & File Validation:
- sanitize_path: Security validation for file paths (directory traversal prevention)
- validate_file_exists: Check file exists and is a file
- validate_dir_exists: Check directory exists and is a directory
- validate_file_size: Check file size within limits

String & Directory Operations:
- sanitize_string: Remove control characters, limit length
- is_nan_or_none: Check if value is None or NaN
- convert_matlab_struct: Convert MATLAB struct objects to dictionaries
- validate_against_whitelist: Validate value against allowed set
- ensure_directory: Create directory with optional write permission check

File I/O:
- read_toml: Load TOML files
- read_json: Load JSON files
- write_json: Save JSON with Path object support

Video Analysis:
- run_ffprobe: Count frames using ffprobe

Logging:
- configure_logger: Set up structured or standard logging

Requirements:
-------------
- NFR-1: Reproducible outputs (deterministic hashing)
- NFR-2: Security (path sanitization, validation)
- NFR-3: Performance (efficient I/O)
- FR-2: Video frame counting

Acceptance Criteria:
--------------------
- A18: Deterministic hashing produces identical results for identical inputs

Example:
--------
>>> from w2t_bkin.utils import compute_hash, sanitize_path, discover_files
>>>
>>> # Compute deterministic hash
>>> data = {"session": "Session-001", "timestamp": "2025-11-12"}
>>> hash_value = compute_hash(data)
>>> print(hash_value)  # Consistent across runs
>>>
>>> # Discover files with glob
>>> video_files = discover_files(Path("data/raw/session"), "*.avi")
>>>
>>> # Sanitize file paths
>>> safe_path = sanitize_path("data/raw/session.toml")
>>> # Raises ValueError for dangerous paths like "../../../etc/passwd"
>>>
>>> # Validate files exist
>>> from w2t_bkin.utils import validate_file_exists
>>> validate_file_exists(video_path, IngestError, "Video file required")
"""

from datetime import datetime
import glob
import hashlib
import json
import logging
import math
from pathlib import Path
import subprocess
import sys
from typing import Any, Dict, FrozenSet, List, Literal, Optional, Set, Type, Union

# Module logger
logger = logging.getLogger(__name__)

# Import version info
try:
    from importlib.metadata import version
except ImportError:
    # Python < 3.8
    from importlib_metadata import version

def parse_datetime(dt_str: str) -> datetime:
    """Parse ISO 8601 datetime string.

    Supports the formats YYYY-MM-DDTHH:MM:SS and YYYY-MM-DD HH:MM:SS.

    Parameters
    ----------
    dt_str : str
        ISO 8601 datetime string

    Returns
    -------
    datetime
        Parsed datetime object
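
    Example
    -------
    >>> # Illustrative doctest; both supported separators parse to the same value
    >>> parse_datetime("2025-11-12T10:30:00")
    datetime.datetime(2025, 11, 12, 10, 30)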
"""
|
|
123
|
+
# Try with 'T' separator first
|
|
124
|
+
try:
|
|
125
|
+
return datetime.fromisoformat(dt_str)
|
|
126
|
+
except ValueError:
|
|
127
|
+
pass
|
|
128
|
+
|
|
129
|
+
# Try with space separator
|
|
130
|
+
try:
|
|
131
|
+
return datetime.strptime(dt_str, "%Y-%m-%d %H:%M:%S")
|
|
132
|
+
except ValueError:
|
|
133
|
+
raise ValueError(f"Invalid datetime format: {dt_str}. " "Expected ISO 8601: YYYY-MM-DDTHH:MM:SS or YYYY-MM-DD HH:MM:SS")
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def get_source_script() -> Optional[str]:
    """Get the source script that is running (the __main__ script).

    Returns the absolute path of the script file used to create the NWB file.
    This captures the actual entry point script (e.g., pipeline.py, analysis.py).

    Returns
    -------
    Optional[str]
        Absolute path to the main script file, or None if not available

    Example
    -------
    >>> # When running: python pipeline.py
    >>> get_source_script()
    '/home/user/project/pipeline.py'

    >>> # When running interactively or from a module
    >>> get_source_script()
    None
    """
    try:
        # sys.argv[0] contains the script that was invoked
        if sys.argv and sys.argv[0]:
            script_path = Path(sys.argv[0]).resolve()
            # Only return if it's an actual file (not '<stdin>' or similar)
            if script_path.exists() and script_path.is_file():
                return str(script_path)
    except (IndexError, OSError):
        pass

    return None


def get_source_script_file_name() -> Optional[str]:
    """Get the name of the source script file (without path).

    Returns just the filename of the script used to create the NWB file.

    Returns
    -------
    Optional[str]
        Name of the main script file, or None if not available

    Example
    -------
    >>> # When running: python /home/user/project/pipeline.py
    >>> get_source_script_file_name()
    'pipeline.py'
    """
    script_path = get_source_script()
    if script_path:
        return Path(script_path).name
    return None


def get_software_packages() -> List[str]:
    """Get list of software package names and versions used.

    Returns a list of package names with versions in the format
    "package_name==version".

    This captures key dependencies for reproducibility and provenance tracking.

    Returns
    -------
    List[str]
        List of package names with versions (e.g., ["pynwb==3.1.0", "w2t_bkin==0.0.3"])

    Example
    -------
    >>> packages = get_software_packages()
    >>> print(packages)
    ['w2t_bkin==0.0.3', 'pynwb==3.1.0', 'hdmf==4.1.0', ...]
    """
    packages = []

    # Core packages to track
    package_names = [
        "w2t_bkin",  # This package
        "pynwb",  # NWB file creation
        "hdmf",  # Data format
        "deeplabcut",  # Pose estimation
        "facemap",  # Facial metrics
        "scipy",  # Scientific computing
        "numpy",  # Array operations
        "pandas",  # Data frames
        "torch",  # Deep learning (if used)
    ]

    for package_name in package_names:
        try:
            pkg_version = version(package_name)
            packages.append(f"{package_name}=={pkg_version}")
        except Exception:
            # Package not installed or version not available
            continue

    return packages

def compute_hash(data: Union[str, Dict[str, Any]]) -> str:
    """Compute deterministic SHA256 hash of input data.

    For dictionaries, canonicalizes by sorting keys before hashing.

    Args:
        data: String or dictionary to hash

    Returns:
        SHA256 hex digest (64 characters)
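
    Example:
        >>> # Illustrative: canonicalization makes key order irrelevant
        >>> compute_hash({"b": 2, "a": 1}) == compute_hash({"a": 1, "b": 2})
        True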
"""
|
|
248
|
+
if isinstance(data, dict):
|
|
249
|
+
# Canonicalize: sort keys and convert to compact JSON
|
|
250
|
+
canonical = json.dumps(data, sort_keys=True, separators=(",", ":"))
|
|
251
|
+
data_bytes = canonical.encode("utf-8")
|
|
252
|
+
else:
|
|
253
|
+
data_bytes = data.encode("utf-8")
|
|
254
|
+
|
|
255
|
+
return hashlib.sha256(data_bytes).hexdigest()
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def sanitize_path(path: Union[str, Path], base: Optional[Path] = None) -> Path:
    """Sanitize path to prevent directory traversal attacks.

    Args:
        path: Path to sanitize
        base: Optional base directory to restrict path to

    Returns:
        Sanitized Path object

    Raises:
        ValueError: If path attempts directory traversal
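
    Example:
        >>> # Illustrative (POSIX separators shown); safe paths pass through
        >>> str(sanitize_path("data/raw/session.toml"))
        'data/raw/session.toml'
        >>> # sanitize_path("../etc/passwd") would raise ValueError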
"""
|
|
271
|
+
path_obj = Path(path)
|
|
272
|
+
|
|
273
|
+
# Check for directory traversal patterns
|
|
274
|
+
if ".." in path_obj.parts:
|
|
275
|
+
raise ValueError(f"Directory traversal not allowed: {path}")
|
|
276
|
+
|
|
277
|
+
# If base provided, ensure resolved path is within base
|
|
278
|
+
if base is not None:
|
|
279
|
+
base = Path(base).resolve()
|
|
280
|
+
resolved = (base / path_obj).resolve()
|
|
281
|
+
if not str(resolved).startswith(str(base)):
|
|
282
|
+
raise ValueError(f"Path {path} outside allowed base {base}")
|
|
283
|
+
return resolved
|
|
284
|
+
|
|
285
|
+
return path_obj
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def discover_files(base_dir: Path, pattern: str, sort: bool = True) -> List[Path]:
    """Discover files matching glob pattern and return absolute paths.

    Args:
        base_dir: Base directory to resolve pattern from
        pattern: Glob pattern (relative to base_dir)
        sort: If True, sort files by name (default: True)

    Returns:
        List of absolute Path objects

    Example:
        >>> files = discover_files(Path("data/raw"), "*.avi")
        >>> files = discover_files(session_dir, "Bpod/*.mat", sort=True)
    """
    full_pattern = str(base_dir / pattern)
    file_paths = [Path(p).resolve() for p in glob.glob(full_pattern)]

    if sort:
        file_paths.sort(key=lambda p: p.name)

    return file_paths

def sort_files(files: List[Path], strategy: Literal["name_asc", "name_desc", "time_asc", "time_desc"]) -> List[Path]:
    """Sort file list by specified strategy.

    Args:
        files: List of file paths to sort
        strategy: Sorting strategy:
            - "name_asc": Sort by filename ascending
            - "name_desc": Sort by filename descending
            - "time_asc": Sort by modification time ascending (oldest first)
            - "time_desc": Sort by modification time descending (newest first)

    Returns:
        Sorted list of Path objects (new list, does not modify input)

    Example:
        >>> files = sort_files(discovered_files, "time_desc")
    """
    sorted_files = files.copy()

    if strategy == "name_asc":
        sorted_files.sort(key=lambda p: p.name)
    elif strategy == "name_desc":
        sorted_files.sort(key=lambda p: p.name, reverse=True)
    elif strategy == "time_asc":
        sorted_files.sort(key=lambda p: p.stat().st_mtime)
    elif strategy == "time_desc":
        sorted_files.sort(key=lambda p: p.stat().st_mtime, reverse=True)
    else:
        raise ValueError(f"Invalid sort strategy: {strategy}")

    return sorted_files

def validate_file_exists(path: Path, error_class: Type[Exception] = FileNotFoundError, message: Optional[str] = None) -> None:
    """Validate file exists and is a file, not a directory.

    Args:
        path: Path to validate
        error_class: Exception class to raise on validation failure
        message: Optional custom error message

    Raises:
        error_class: If file doesn't exist or is not a file

    Example:
        >>> validate_file_exists(video_path, IngestError, "Video file required")
    """
    if not path.exists():
        msg = message or f"File not found: {path}"
        raise error_class(msg)

    if not path.is_file():
        msg = message or f"Path is not a file: {path}"
        raise error_class(msg)

def validate_dir_exists(path: Path, error_class: Type[Exception] = FileNotFoundError, message: Optional[str] = None) -> None:
    """Validate directory exists and is a directory, not a file.

    Args:
        path: Path to validate
        error_class: Exception class to raise on validation failure
        message: Optional custom error message

    Raises:
        error_class: If directory doesn't exist or is not a directory

    Example:
        >>> validate_dir_exists(output_dir, NWBError, "Output directory required")
    """
    if not path.exists():
        msg = message or f"Directory not found: {path}"
        raise error_class(msg)

    if not path.is_dir():
        msg = message or f"Path is not a directory: {path}"
        raise error_class(msg)

def validate_file_size(path: Path, max_size_mb: float) -> float:
    """Validate file size within limits, return size in MB.

    Args:
        path: Path to file
        max_size_mb: Maximum allowed size in megabytes

    Returns:
        File size in MB

    Raises:
        ValueError: If file exceeds size limit

    Example:
        >>> size_mb = validate_file_size(bpod_path, max_size_mb=100)
    """
    file_size_mb = path.stat().st_size / (1024 * 1024)

    if file_size_mb > max_size_mb:
        raise ValueError(f"File too large: {file_size_mb:.1f}MB exceeds {max_size_mb}MB limit")

    return file_size_mb

def sanitize_string(
    text: str, max_length: int = 100, allowed_pattern: Literal["alphanumeric", "alphanumeric_-", "alphanumeric_-_", "printable"] = "alphanumeric_-_", default: str = "unknown"
) -> str:
    """Sanitize string by removing control characters and limiting length.

    Args:
        text: String to sanitize
        max_length: Maximum length of output string
        allowed_pattern: Character allowance pattern:
            - "alphanumeric": Only letters and numbers
            - "alphanumeric_-": Letters, numbers, hyphens
            - "alphanumeric_-_": Letters, numbers, hyphens, underscores
            - "printable": All printable characters
        default: Default value if sanitized string is empty

    Returns:
        Sanitized string

    Example:
        >>> safe_id = sanitize_string("Session-001", allowed_pattern="alphanumeric_-")
        >>> safe_event = sanitize_string(raw_event_name, max_length=50)
    """
    if not isinstance(text, str):
        return default

    # Remove control characters based on pattern
    if allowed_pattern == "alphanumeric":
        sanitized = "".join(c for c in text if c.isalnum())
    elif allowed_pattern == "alphanumeric_-":
        sanitized = "".join(c for c in text if c.isalnum() or c == "-")
    elif allowed_pattern == "alphanumeric_-_":
        sanitized = "".join(c for c in text if c.isalnum() or c in "-_")
    elif allowed_pattern == "printable":
        sanitized = "".join(c for c in text if c.isprintable())
    else:
        raise ValueError(f"Invalid allowed_pattern: {allowed_pattern}")

    # Limit length
    sanitized = sanitized[:max_length]

    # Return default if empty
    if not sanitized:
        return default

    return sanitized

def is_nan_or_none(value: Any) -> bool:
    """Check if value is None or NaN (for float values).

    Args:
        value: Value to check

    Returns:
        True if value is None or NaN, False otherwise

    Example:
        >>> is_nan_or_none(None)  # True
        >>> is_nan_or_none(float('nan'))  # True
        >>> is_nan_or_none(0.0)  # False
        >>> is_nan_or_none([1.0, 2.0])  # False
    """
    if value is None:
        return True
    if isinstance(value, float) and math.isnan(value):
        return True
    return False

def convert_matlab_struct(obj: Any) -> Dict[str, Any]:
    """Convert MATLAB struct object to dictionary.

    Handles scipy.io mat_struct objects by extracting non-private attributes.
    If already a dict, returns as-is. For other types, returns an empty dict.

    Args:
        obj: MATLAB struct object, dictionary, or other type

    Returns:
        Dictionary representation

    Example:
        >>> # With scipy mat_struct
        >>> from scipy.io import loadmat
        >>> data = loadmat("file.mat")
        >>> session_data = convert_matlab_struct(data["SessionData"])
        >>>
        >>> # With plain dict
        >>> convert_matlab_struct({"key": "value"})  # Returns as-is
    """
    if hasattr(obj, "__dict__"):
        # scipy mat_struct or similar object with __dict__
        return {k: v for k, v in obj.__dict__.items() if not k.startswith("_")}
    elif isinstance(obj, dict):
        # Already a dictionary
        return obj
    else:
        # Unsupported type - return empty dict
        return {}

def validate_against_whitelist(value: str, whitelist: Union[Set[str], FrozenSet[str]], default: str, warn: bool = True) -> str:
    """Validate string value against whitelist, return default if invalid.

    Args:
        value: Value to validate
        whitelist: Set or frozenset of allowed values
        default: Default value to return if validation fails
        warn: If True, log warning when value not in whitelist

    Returns:
        Value if in whitelist, otherwise default

    Example:
        >>> outcomes = frozenset(["hit", "miss", "correct"])
        >>> validate_against_whitelist("hit", outcomes, "unknown")  # "hit"
        >>> validate_against_whitelist("invalid", outcomes, "unknown")  # "unknown"
    """
    if value in whitelist:
        return value

    if warn:
        # Module-level logger; logging.getLogger(__name__) returns the same object
        logger.warning(f"Invalid value '{value}', defaulting to '{default}'")

    return default

def ensure_directory(path: Path, check_writable: bool = False) -> Path:
    """Ensure directory exists, optionally check write permissions.

    Args:
        path: Directory path to ensure
        check_writable: If True, verify directory is writable

    Returns:
        The path (for chaining)

    Raises:
        OSError: If directory cannot be created
        PermissionError: If check_writable=True and directory is not writable

    Example:
        >>> output_dir = ensure_directory(Path("data/processed"), check_writable=True)
    """
    if not path.exists():
        path.mkdir(parents=True, exist_ok=True)

    if not path.is_dir():
        raise OSError(f"Path exists but is not a directory: {path}")

    if check_writable:
        # Try to write a test file to check permissions
        test_file = path / ".test_write"
        try:
            test_file.touch()
            test_file.unlink()
        except Exception as e:
            raise PermissionError(f"Directory is not writable: {path}. Error: {e}")

    return path

def compute_file_checksum(file_path: Path, algorithm: str = "sha256", chunk_size: int = 8192) -> str:
    """Compute checksum of file using specified algorithm.

    Args:
        file_path: Path to file
        algorithm: Hash algorithm (sha256, sha1, md5)
        chunk_size: Read chunk size in bytes

    Returns:
        Hex digest of file checksum

    Raises:
        FileNotFoundError: If file doesn't exist
        ValueError: If algorithm is unsupported

    Example:
        >>> checksum = compute_file_checksum(video_path)
        >>> checksum = compute_file_checksum(video_path, algorithm="sha1")
    """
    if not file_path.exists():
        raise FileNotFoundError(f"File not found: {file_path}")

    # Create hash object
    if algorithm == "sha256":
        hasher = hashlib.sha256()
    elif algorithm == "sha1":
        hasher = hashlib.sha1()
    elif algorithm == "md5":
        hasher = hashlib.md5()
    else:
        raise ValueError(f"Unsupported hash algorithm: {algorithm}")

    # Read file in chunks and update hash
    with open(file_path, "rb") as f:
        while chunk := f.read(chunk_size):
            hasher.update(chunk)

    return hasher.hexdigest()

def read_toml(path: Union[str, Path]) -> Dict[str, Any]:
    """Read TOML file into dictionary.

    Args:
        path: Path to TOML file (str or Path)

    Returns:
        Dictionary with parsed TOML data

    Raises:
        FileNotFoundError: If file doesn't exist

    Example:
        >>> data = read_toml("config.toml")
        >>> data = read_toml(Path("session.toml"))
    """
    path = Path(path) if isinstance(path, str) else path

    if not path.exists():
        raise FileNotFoundError(f"TOML file not found: {path}")

    # tomllib is stdlib from Python 3.11; fall back to tomli on older versions
    try:
        import tomllib
    except ImportError:
        import tomli as tomllib

    with open(path, "rb") as f:
        return tomllib.load(f)

def write_json(data: Dict[str, Any], path: Union[str, Path], indent: int = 2) -> None:
    """Write data to JSON file with custom encoder for Path objects.

    Args:
        data: Dictionary to write
        path: Output file path
        indent: JSON indentation (default: 2 spaces)
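
    Example:
        >>> # Illustrative round trip (POSIX separators shown); Path values
        >>> # are serialized as strings, so they read back as str
        >>> write_json({"out": Path("data/processed")}, "manifest.json")
        >>> read_json("manifest.json")["out"]
        'data/processed'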
"""
|
|
656
|
+
|
|
657
|
+
class PathEncoder(json.JSONEncoder):
|
|
658
|
+
"""Custom JSON encoder that handles Path objects."""
|
|
659
|
+
|
|
660
|
+
def default(self, obj):
|
|
661
|
+
if isinstance(obj, Path):
|
|
662
|
+
return str(obj)
|
|
663
|
+
return super().default(obj)
|
|
664
|
+
|
|
665
|
+
path_obj = Path(path)
|
|
666
|
+
path_obj.parent.mkdir(parents=True, exist_ok=True)
|
|
667
|
+
|
|
668
|
+
with open(path_obj, "w", encoding="utf-8") as f:
|
|
669
|
+
json.dump(data, f, indent=indent, cls=PathEncoder)
|
|
670
|
+
|
|
671
|
+
|
|
672
|
+
def read_json(path: Union[str, Path]) -> Dict[str, Any]:
    """Read JSON file into dictionary.

    Args:
        path: Input file path

    Returns:
        Dictionary with parsed JSON data
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

def configure_logger(name: str, level: str = "INFO", structured: bool = False) -> logging.Logger:
    """Configure logger with specified settings.

    Args:
        name: Logger name
        level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
        structured: If True, use structured (JSON) logging

    Returns:
        Configured logger instance
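
    Example:
        >>> # Illustrative: emit JSON-formatted records for log aggregation
        >>> log = configure_logger("w2t_bkin.demo", level="DEBUG", structured=True)
        >>> log.info("pipeline started")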
"""
|
|
696
|
+
logger = logging.getLogger(name)
|
|
697
|
+
logger.setLevel(getattr(logging, level.upper()))
|
|
698
|
+
|
|
699
|
+
# Remove existing handlers
|
|
700
|
+
logger.handlers.clear()
|
|
701
|
+
|
|
702
|
+
handler = logging.StreamHandler()
|
|
703
|
+
|
|
704
|
+
if structured:
|
|
705
|
+
# JSON structured logging
|
|
706
|
+
formatter = logging.Formatter('{"timestamp":"%(asctime)s","level":"%(levelname)s","name":"%(name)s","message":"%(message)s"}')
|
|
707
|
+
else:
|
|
708
|
+
# Standard logging
|
|
709
|
+
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
|
710
|
+
|
|
711
|
+
handler.setFormatter(formatter)
|
|
712
|
+
logger.addHandler(handler)
|
|
713
|
+
|
|
714
|
+
return logger
|
|
715
|
+
|
|
716
|
+
|
|
717
|
+
class VideoAnalysisError(Exception):
    """Error during video analysis operations."""

    pass

def run_ffprobe(video_path: Path, timeout: int = 30) -> int:
    """Count frames in a video file using ffprobe.

    Uses ffprobe to accurately count video frames by reading the stream metadata.
    This is more reliable than using OpenCV for corrupted or unusual video formats.

    Args:
        video_path: Path to video file
        timeout: Maximum time in seconds to wait for ffprobe (default: 30)

    Returns:
        Number of frames in video

    Raises:
        VideoAnalysisError: If video file is invalid or ffprobe fails
        FileNotFoundError: If video file does not exist
        ValueError: If video_path is not a valid path

    Security:
        - Input path validation to prevent command injection
        - Subprocess timeout to prevent hanging
        - stderr capture for diagnostic information
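
    Example:
        >>> # Illustrative; assumes ffprobe is on PATH and the file exists
        >>> n_frames = run_ffprobe(Path("data/raw/cam0.avi"), timeout=60)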
"""
|
|
746
|
+
# Input validation
|
|
747
|
+
if not isinstance(video_path, Path):
|
|
748
|
+
video_path = Path(video_path)
|
|
749
|
+
|
|
750
|
+
if not video_path.exists():
|
|
751
|
+
raise FileNotFoundError(f"Video file not found: {video_path}")
|
|
752
|
+
|
|
753
|
+
if not video_path.is_file():
|
|
754
|
+
raise ValueError(f"Path is not a file: {video_path}")
|
|
755
|
+
|
|
756
|
+
# Sanitize path - resolve to absolute path to prevent injection
|
|
757
|
+
video_path = video_path.resolve()
|
|
758
|
+
|
|
759
|
+
# ffprobe command to count frames accurately
|
|
760
|
+
# -v error: only show errors
|
|
761
|
+
# -select_streams v:0: select first video stream
|
|
762
|
+
# -count_frames: actually count frames (slower but accurate)
|
|
763
|
+
# -show_entries stream=nb_read_frames: output only frame count
|
|
764
|
+
# -of csv=p=0: output as CSV without header
|
|
765
|
+
command = [
|
|
766
|
+
"ffprobe",
|
|
767
|
+
"-v",
|
|
768
|
+
"error",
|
|
769
|
+
"-select_streams",
|
|
770
|
+
"v:0",
|
|
771
|
+
"-count_frames",
|
|
772
|
+
"-show_entries",
|
|
773
|
+
"stream=nb_read_frames",
|
|
774
|
+
"-of",
|
|
775
|
+
"csv=p=0",
|
|
776
|
+
str(video_path),
|
|
777
|
+
]
|
|
778
|
+
|
|
779
|
+
try:
|
|
780
|
+
# Run ffprobe with timeout and capture output
|
|
781
|
+
result = subprocess.run(
|
|
782
|
+
command,
|
|
783
|
+
capture_output=True,
|
|
784
|
+
text=True,
|
|
785
|
+
timeout=timeout,
|
|
786
|
+
check=True,
|
|
787
|
+
)
|
|
788
|
+
|
|
789
|
+
# Parse output - should be a single integer
|
|
790
|
+
output = result.stdout.strip()
|
|
791
|
+
|
|
792
|
+
if not output:
|
|
793
|
+
raise VideoAnalysisError(f"ffprobe returned empty output for: {video_path}")
|
|
794
|
+
|
|
795
|
+
try:
|
|
796
|
+
frame_count = int(output)
|
|
797
|
+
except ValueError:
|
|
798
|
+
raise VideoAnalysisError(f"ffprobe returned non-integer output: {output}")
|
|
799
|
+
|
|
800
|
+
if frame_count < 0:
|
|
801
|
+
raise VideoAnalysisError(f"ffprobe returned negative frame count: {frame_count}")
|
|
802
|
+
|
|
803
|
+
return frame_count
|
|
804
|
+
|
|
805
|
+
except subprocess.TimeoutExpired:
|
|
806
|
+
raise VideoAnalysisError(f"ffprobe timed out after {timeout}s for: {video_path}")
|
|
807
|
+
|
|
808
|
+
except subprocess.CalledProcessError as e:
|
|
809
|
+
# ffprobe failed - provide diagnostic information
|
|
810
|
+
stderr_msg = e.stderr.strip() if e.stderr else "No error message"
|
|
811
|
+
raise VideoAnalysisError(f"ffprobe failed for {video_path}: {stderr_msg}")
|
|
812
|
+
|
|
813
|
+
except Exception as e:
|
|
814
|
+
# Unexpected error
|
|
815
|
+
raise VideoAnalysisError(f"Unexpected error running ffprobe: {e}")
|
|
816
|
+
|
|
817
|
+
|
|
818
|
+
# =============================================================================
|
|
819
|
+
# Events Helper Functions (Numpy Array Handling)
|
|
820
|
+
# =============================================================================
|
|
821
|
+
|
|
822
|
+
|
|
823
|
+
def to_scalar(value: Union[Any, "np.ndarray"], index: int) -> Any:
    """Extract scalar from array or list.

    Args:
        value: Array, list, tuple, or scalar
        index: Index to extract

    Returns:
        Scalar value at index

    Raises:
        IndexError: Index out of bounds
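
    Example:
        >>> # Illustrative: sequences are indexed, scalars pass through
        >>> to_scalar([10, 20, 30], 1)
        20
        >>> to_scalar(5.0, 0)
        5.0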
"""
|
|
836
|
+
import numpy as np
|
|
837
|
+
|
|
838
|
+
if isinstance(value, np.ndarray):
|
|
839
|
+
# Handle numpy arrays (including 0-d arrays)
|
|
840
|
+
if value.ndim == 0:
|
|
841
|
+
return value.item()
|
|
842
|
+
return value[index].item() if hasattr(value[index], "item") else value[index]
|
|
843
|
+
elif isinstance(value, (list, tuple)):
|
|
844
|
+
return value[index]
|
|
845
|
+
else:
|
|
846
|
+
# Assume it's already a scalar
|
|
847
|
+
return value
|
|
848
|
+
|
|
849
|
+
|
|
850
|
+
def to_list(value: Union[Any, "np.ndarray"]) -> List[Any]:
    """Convert array or scalar to Python list.

    Args:
        value: Array, list, tuple, or scalar

    Returns:
        Python list
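
    Example:
        >>> # Illustrative: sequences are copied, scalars are wrapped
        >>> to_list((1, 2))
        [1, 2]
        >>> to_list(7)
        [7]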
"""
|
|
859
|
+
import numpy as np
|
|
860
|
+
|
|
861
|
+
if isinstance(value, np.ndarray):
|
|
862
|
+
return value.tolist()
|
|
863
|
+
elif isinstance(value, (list, tuple)):
|
|
864
|
+
return list(value)
|
|
865
|
+
else:
|
|
866
|
+
# Scalar value
|
|
867
|
+
return [value]
|
|
868
|
+
|
|
869
|
+
|
|
870
|
+
if __name__ == "__main__":
    """Usage examples for utils module."""
    import tempfile

    print("=" * 70)
    print("W2T-BKIN Utils Module - Usage Examples")
    print("=" * 70)
    print()

    # Example 1: Compute hash
    print("Example 1: Compute Hash")
    print("-" * 50)
    test_data = {"session_id": "Session-000001", "timestamp": "2025-11-12"}
    hash_result = compute_hash(test_data)
    print(f"Data: {test_data}")
    print(f"Hash: {hash_result}")
    print()

    # Example 2: Sanitize path
    print("Example 2: Sanitize Path")
    print("-" * 50)
    safe_path = sanitize_path("data/raw/Session-000001")
    print("Input: data/raw/Session-000001")
    print(f"Sanitized: {safe_path}")

    try:
        dangerous = sanitize_path("../../etc/passwd")
        print(f"Dangerous path: {dangerous}")
    except ValueError as e:
        print(f"Blocked directory traversal: {e}")
    print()

    # Example 3: JSON I/O
    print("Example 3: JSON I/O")
    print("-" * 50)
    with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f:
        temp_path = Path(f.name)

    test_obj = {"key": "value", "number": 42}
    write_json(test_obj, temp_path)
    print(f"Wrote to: {temp_path.name}")

    loaded = read_json(temp_path)
    print(f"Read back: {loaded}")
    temp_path.unlink()
    print()
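
    # Example 4: File checksum (illustrative addition; uses a throwaway temp
    # file rather than assuming any real session data exists on disk)
    print("Example 4: File Checksum")
    print("-" * 50)
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
        f.write("hello")
        checksum_path = Path(f.name)
    print(f"SHA256: {compute_file_checksum(checksum_path)}")
    checksum_path.unlink()
    print()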
    print("=" * 70)
    print("Examples completed. See module docstring for more details.")
    print("=" * 70)


# =============================================================================
# Video and TTL Counting Utilities
# =============================================================================

def count_video_frames(video_path: Path) -> int:
    """Count frames in a video file using ffprobe or synthetic stub.

    This function counts video frames using ffprobe. It includes special
    handling for synthetic stub videos (used in testing) and provides
    robust error handling for unreadable or missing files.

    Args:
        video_path: Path to video file

    Returns:
        Number of frames in video (0 if file not found or empty)

    Raises:
        RuntimeError: If video file cannot be analyzed

    Example:
        >>> from pathlib import Path
        >>> frame_count = count_video_frames(Path("video.avi"))
        >>> print(f"Frames: {frame_count}")
    """
    # Validate input
    if not video_path.exists():
        logger.warning(f"Video file not found: {video_path}")
        return 0

    # Handle empty files
    if video_path.stat().st_size == 0:
        logger.warning(f"Video file is empty: {video_path}")
        return 0

    # Check if this is a synthetic stub video
    try:
        # Try importing synthetic module (only available in test/synthetic contexts)
        from synthetic.video_synth import count_stub_frames, is_synthetic_stub

        if is_synthetic_stub(video_path):
            frame_count = count_stub_frames(video_path)
            logger.debug(f"Counted {frame_count} frames in synthetic stub {video_path.name}")
            return frame_count
    except ImportError:
        # Synthetic module not available - continue with normal ffprobe
        pass

    # Use ffprobe to count frames
    try:
        frame_count = run_ffprobe(video_path)
        logger.debug(f"Counted {frame_count} frames in {video_path.name}")
        return frame_count
    except Exception as e:
        # Log error and raise - frame counting failure is critical
        logger.error(f"Failed to count frames in {video_path}: {e}")
        raise RuntimeError(f"Could not count frames in video {video_path}: {e}")

def normalize_keypoints_to_dict(keypoints) -> Dict[str, Dict]:
    """Convert keypoints to standard dict format (name -> keypoint data).

    Handles multiple input formats:
    - KeypointsDict (custom dict that iterates over values)
    - Regular dict with keypoint name as key
    - List of keypoint dicts

    Args:
        keypoints: Keypoints in various formats

    Returns:
        Dictionary mapping keypoint name to keypoint data dict

    Example:
        >>> kp_list = [{"name": "nose", "x": 10, "y": 20}]
        >>> result = normalize_keypoints_to_dict(kp_list)
        >>> result["nose"]
        {"name": "nose", "x": 10, "y": 20}
    """
    if isinstance(keypoints, dict):
        # Check if it's a KeypointsDict or dict-like that iterates values
        if hasattr(keypoints, "__iter__") and keypoints:
            first_val = next(iter(keypoints.values()))
            if isinstance(first_val, dict) and "name" in first_val:
                # Already in correct format (name -> dict)
                return keypoints
        # Standard dict format
        return keypoints
    elif isinstance(keypoints, list):
        # Convert list to dict
        return {kp["name"]: kp for kp in keypoints}
    else:
        # Fallback for unknown types
        return {}

def log_missing_keypoints(
    frame_index: int,
    expected_names: Set[str],
    actual_names: Set[str],
    logger_instance: logging.Logger,
) -> None:
    """Log warning for missing keypoints in a frame.

    Args:
        frame_index: Frame number for logging
        expected_names: Set of expected keypoint names
        actual_names: Set of actual keypoint names found
        logger_instance: Logger to use for warning
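
    Example:
        >>> # Illustrative: warns via the given logger about the missing set
        >>> log_missing_keypoints(3, {"nose", "tail"}, {"nose"}, logger)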
"""
|
|
1033
|
+
missing = expected_names - actual_names
|
|
1034
|
+
if missing:
|
|
1035
|
+
logger_instance.warning(f"Frame {frame_index}: Missing keypoints {missing}")
|
|
1036
|
+
|
|
1037
|
+
|
|
1038
|
+
def derive_bodyparts_from_data(data: List[Dict]) -> List[str]:
    """Extract canonical bodypart names from harmonized pose data.

    Args:
        data: List of pose frame dictionaries with keypoints

    Returns:
        Sorted list of bodypart names

    Raises:
        ValueError: If data is empty or has no keypoints

    Example:
        >>> frames = [{"keypoints": {"nose": {...}, "ear_left": {...}}}]
        >>> derive_bodyparts_from_data(frames)
        ["ear_left", "nose"]
    """
    if not data:
        raise ValueError("Cannot derive bodyparts from empty data")

    first_frame_keypoints = normalize_keypoints_to_dict(data[0].get("keypoints", {}))

    if not first_frame_keypoints:
        raise ValueError("First frame has no keypoints")

    # Sort for consistency across runs
    return sorted(first_frame_keypoints.keys())

def count_ttl_pulses(ttl_path: Path) -> int:
    """Count TTL pulses from log file.

    Counts non-empty lines in a TTL log file. Each line represents one
    TTL pulse event.

    Args:
        ttl_path: Path to TTL log file

    Returns:
        Number of pulses in file (0 if file not found or unreadable)

    Example:
        >>> from pathlib import Path
        >>> pulse_count = count_ttl_pulses(Path("camera_ttl.log"))
        >>> print(f"Pulses: {pulse_count}")
    """
    if not ttl_path.exists():
        return 0

    # Count lines in TTL file (each line = one pulse)
    try:
        with open(ttl_path, "r") as f:
            lines = f.readlines()
        return len([line for line in lines if line.strip()])
    except Exception:
        return 0