jabs-core 0.1.0a1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,43 @@
1
+ Metadata-Version: 2.3
2
+ Name: jabs-core
3
+ Version: 0.1.0a1
4
+ Summary: Add your description here
5
+ Requires-Dist: packaging>=24.0
6
+ Requires-Dist: toml>=0.10.2,<0.11.0
7
+ Requires-Dist: h5py>=3.10.0,<4.0.0
8
+ Requires-Dist: shapely>=2.0.1,<3.0.0
9
+ Requires-Dist: numpy>=2.0.0,<3.0.0
10
+ Requires-Dist: opencv-python-headless>=4.8.1.78,<5.0.0
11
+ Requires-Python: >=3.10, <3.15
12
+ Project-URL: Repository, https://github.com/KumarLabJax/JABS-behavior-classifier
13
+ Project-URL: Issues, https://github.com/KumarLabJax/JABS-behavior-classifier/issues
14
+ Description-Content-Type: text/markdown
15
+
16
+ # JABS Core (`jabs-core`)
17
+
18
+ The infrastructure and shared utility layer for JABS.
19
+
20
+ ## Overview
21
+
22
+ `jabs-core` provides low-level, domain-agnostic utilities used across all JABS packages.
23
+ It is designed to be lightweight and free of heavy scientific dependencies (like
24
+ `scikit-learn` or `pandas`), making it safe to import at any level of the hierarchy.
25
+
26
+ ## Responsibilities
27
+
28
+ - **Shared Constants**: Global constants used for file compression and configuration.
29
+ - **Exceptions**: Centralized exception hierarchy (`JabsError`, `PoseHashException`,
30
+ etc.).
31
+ - **Infrastructure**: Base classes for registries and plugin discovery systems.
32
+ - **Abstract Bases**: High-level interface definitions (e.g., the `PoseEstimation`
33
+ abstract base).
34
+ - **Utility Functions**: Generic helpers for file hashing, logging configuration, and
35
+ basic string/path manipulation.
36
+
37
+ ## Package Structure
38
+
39
+ - `jabs.core.constants`: Global constants.
40
+ - `jabs.core.exceptions`: Shared exception classes.
41
+ - `jabs.core.abstract`: Abstract base classes for the system.
42
+ - `jabs.core.utils`: Generic utility functions.
43
+ - `jabs.core.enums`: Shared enumerations (e.g., `ClassifierType`).
@@ -0,0 +1,28 @@
1
+ # JABS Core (`jabs-core`)
2
+
3
+ The infrastructure and shared utility layer for JABS.
4
+
5
+ ## Overview
6
+
7
+ `jabs-core` provides low-level, domain-agnostic utilities used across all JABS packages.
8
+ It is designed to be lightweight and free of heavy scientific dependencies (like
9
+ `scikit-learn` or `pandas`), making it safe to import at any level of the hierarchy.
10
+
11
+ ## Responsibilities
12
+
13
+ - **Shared Constants**: Global constants used for file compression and configuration.
14
+ - **Exceptions**: Centralized exception hierarchy (`JabsError`, `PoseHashException`,
15
+ etc.).
16
+ - **Infrastructure**: Base classes for registries and plugin discovery systems.
17
+ - **Abstract Bases**: High-level interface definitions (e.g., the `PoseEstimation`
18
+ abstract base).
19
+ - **Utility Functions**: Generic helpers for file hashing, logging configuration, and
20
+ basic string/path manipulation.
21
+
22
+ ## Package Structure
23
+
24
+ - `jabs.core.constants`: Global constants.
25
+ - `jabs.core.exceptions`: Shared exception classes.
26
+ - `jabs.core.abstract`: Abstract base classes for the system.
27
+ - `jabs.core.utils`: Generic utility functions.
28
+ - `jabs.core.enums`: Shared enumerations (e.g., `ClassifierType`).
@@ -0,0 +1,44 @@
1
+ [project]
2
+ name = "jabs-core"
3
+ version = "0.1.0a1"
4
+ description = "Add your description here"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10,<3.15"
7
+ dependencies = [
8
+ "packaging>=24.0",
9
+ "toml>=0.10.2,<0.11.0",
10
+ "h5py>=3.10.0,<4.0.0",
11
+ "shapely>=2.0.1,<3.0.0",
12
+ "numpy>=2.0.0,<3.0.0",
13
+ "opencv-python-headless>=4.8.1.78,<5.0.0",
14
+ ]
15
+
16
+ [dependency-groups]
17
+ dev = [
18
+ {include-group = "lint"},
19
+ {include-group = "test"},
20
+ {include-group = "docs"},
21
+ "pre-commit>=4.2.0,<5.0.0",
22
+ "matplotlib>=3.9.3,<4.0.0",
23
+ ]
24
+ test = [
25
+ "pytest>=8.3.4,<9.0.0",
26
+ "pytest-cov>=7.0.0",
27
+ ]
28
+ lint = [
29
+ "ruff>=0.11.5,<0.12.0",
30
+ ]
31
+ docs = [
32
+ "mkdocs>=1.6.1",
33
+ ]
34
+
35
+ [project.urls]
36
+ Repository = "https://github.com/KumarLabJax/JABS-behavior-classifier"
37
+ Issues = "https://github.com/KumarLabJax/JABS-behavior-classifier/issues"
38
+
39
+ [build-system]
40
+ requires = ["uv_build>=0.9.26,<0.10.0"]
41
+ build-backend = "uv_build"
42
+
43
+ [tool.uv.build-backend]
44
+ module-name = "jabs.core"
@@ -0,0 +1 @@
1
+ """The root of the jabs.core package."""
@@ -0,0 +1,7 @@
1
+ """JABS Abstract Base Classes"""
2
+
3
+ from .pose_est import PoseEstimation
4
+
5
+ __all__ = [
6
+ "PoseEstimation",
7
+ ]
@@ -0,0 +1,421 @@
1
+ import enum
2
+ import logging
3
+ from abc import ABC, abstractmethod
4
+ from pathlib import Path
5
+
6
+ import h5py
7
+ import joblib
8
+ import numpy as np
9
+ from shapely.geometry import MultiPoint
10
+
11
+ from jabs.core.utils import hash_file
12
+
13
+ MINIMUM_CONFIDENCE = 0.3
14
+
15
+
16
+ class PoseEstimation(ABC):
17
+ """Abstract base class for pose estimation data handlers.
18
+
19
+ Provides a common interface for loading, accessing, and processing pose data
20
+ from HDF5 files. Defines methods for retrieving keypoints, confidence masks, identity
21
+ presence, and static objects, as well as utilities for geometric computations such as
22
+ convex hulls and bearing angles. All pose estimation versioned classes should inherit
23
+ from this base class.
24
+
25
+ Args:
26
+ file_path (Path): Path to the pose HDF5 file.
27
+ cache_dir (Path | None): Optional cache directory for intermediate data.
28
+ fps (int): Frames per second for the video.
29
+
30
+ Abstract Methods:
31
+ get_points(frame_index, identity, scale): Get points and mask for an identity in a frame.
32
+ get_identity_poses(identity, scale): Get all points and masks for an identity.
33
+ get_identity_point_mask(identity): Get the point mask array for a given identity.
34
+ identity_mask(identity): Get the identity mask for a given identity.
35
+ identity_to_track: Get the identity-to-track mapping for this file.
36
+ format_major_version: Returns the major version of the pose file format.
37
+
38
+ Methods:
39
+ get_identity_convex_hulls(identity): Get convex hulls for an identity across frames.
40
+ compute_bearing(points): Compute the bearing angle for a single frame.
41
+ compute_all_bearings(identity): Compute bearing angles for all frames of an identity.
42
+ get_pose_file_attributes(path): Static method to get HDF5 file attributes.
43
+
44
+ Properties:
45
+ num_frames (int): Number of frames.
46
+ identities (list): List of identities.
47
+ num_identities (int): Number of identities.
48
+ cm_per_pixel (float | None): Centimeters per pixel.
49
+ fps (int): Frames per second.
50
+ pose_file (Path): Path to the pose file.
51
+ hash (str): Hash of the pose file.
52
+ static_objects (dict): Static objects in the pose file.
53
+ num_lixit_keypoints (int): Number of lixit keypoints (default 0).
54
+ external_identities (list[int] | None): Mapping to external identities.
55
+ """
56
+
57
+ class KeypointIndex(enum.IntEnum):
58
+ """enum defining the 12 keypoint indexes"""
59
+
60
+ NOSE = 0
61
+ LEFT_EAR = 1
62
+ RIGHT_EAR = 2
63
+ BASE_NECK = 3
64
+ LEFT_FRONT_PAW = 4
65
+ RIGHT_FRONT_PAW = 5
66
+ CENTER_SPINE = 6
67
+ LEFT_REAR_PAW = 7
68
+ RIGHT_REAR_PAW = 8
69
+ BASE_TAIL = 9
70
+ MID_TAIL = 10
71
+ TIP_TAIL = 11
72
+
73
+ # Connected segments to use when full 12 keypoints are available.
74
+ FULL_CONNECTED_SEGMENTS = (
75
+ (
76
+ KeypointIndex.LEFT_FRONT_PAW,
77
+ KeypointIndex.CENTER_SPINE,
78
+ KeypointIndex.RIGHT_FRONT_PAW,
79
+ ),
80
+ (
81
+ KeypointIndex.LEFT_REAR_PAW,
82
+ KeypointIndex.BASE_TAIL,
83
+ KeypointIndex.RIGHT_REAR_PAW,
84
+ ),
85
+ (
86
+ KeypointIndex.NOSE,
87
+ KeypointIndex.BASE_NECK,
88
+ KeypointIndex.CENTER_SPINE,
89
+ KeypointIndex.BASE_TAIL,
90
+ KeypointIndex.MID_TAIL,
91
+ KeypointIndex.TIP_TAIL,
92
+ ),
93
+ )
94
+
95
+ # Pose based on the Envision Hydra model will have fewer keypoints,
96
+ # so we adjust the connected segments accordingly.
97
+ NVSN_CONNECTED_SEGMENTS = (
98
+ (
99
+ KeypointIndex.LEFT_EAR,
100
+ KeypointIndex.NOSE,
101
+ KeypointIndex.RIGHT_EAR,
102
+ ),
103
+ (
104
+ KeypointIndex.NOSE,
105
+ KeypointIndex.BASE_TAIL,
106
+ KeypointIndex.TIP_TAIL,
107
+ ),
108
+ )
109
+
110
+ _CACHE_FILE_VERSION = 1
111
+
112
+ def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30):
113
+ """initialize new object from h5 file
114
+
115
+ Args:
116
+ file_path: path to pose_est_v2.h5 file
117
+ cache_dir: optional cache directory, used to cache convex
118
+ hulls
119
+ fps: frames per second, used for scaling time series
120
+ features
121
+ for faster loading
122
+ from "per frame" to "per second"
123
+ """
124
+ super().__init__()
125
+ self._num_frames = 0
126
+ self._identities = []
127
+ self._external_identities: list[str] | None = None
128
+ self._convex_hull_cache = {}
129
+ self._path = file_path
130
+ self._cache_dir = cache_dir
131
+ self._cm_per_pixel = None
132
+ self._hash = hash_file(file_path)
133
+ self._fps = fps
134
+
135
+ self._static_objects = {}
136
+
137
+ # check cache version, if it doesn't match, clear the cache file for this pose file
138
+ if self._cache_dir is not None and not self.check_cache_version():
139
+ cache_file = self._cache_file_path()
140
+ if cache_file and cache_file.exists():
141
+ try:
142
+ cache_file.unlink()
143
+ except Exception:
144
+ logging.warning("Unable to delete old cache file %s", cache_file)
145
+ pass
146
+
147
+ @property
148
+ def num_frames(self) -> int:
149
+ """return the number of frames in the pose_est file"""
150
+ return self._num_frames
151
+
152
+ @property
153
+ def identities(self):
154
+ """return list of integer identities generated from file"""
155
+ return self._identities
156
+
157
+ @property
158
+ def num_identities(self) -> int:
159
+ """get the number of identities in the pose file"""
160
+ return len(self._identities)
161
+
162
+ @property
163
+ def cm_per_pixel(self):
164
+ """get centimeters per pixel for video/pose"""
165
+ return self._cm_per_pixel
166
+
167
+ @property
168
+ def fps(self):
169
+ """get frames per second"""
170
+ return self._fps
171
+
172
+ @property
173
+ def pose_file(self):
174
+ """get the path to the pose file"""
175
+ return self._path
176
+
177
+ @property
178
+ def hash(self):
179
+ """get the hash of the pose file"""
180
+ return self._hash
181
+
182
+ @abstractmethod
183
+ def get_points(self, frame_index: int, identity: int, scale: float | None = None):
184
+ """return points and point masks for an individual frame
185
+
186
+ Args:
187
+ frame_index: frame index of points and masks to be returned
188
+ identity: identity to return points for
189
+ scale: optional scale factor, set to cm_per_pixel to convert
190
+ poses from pixel coordinates to cm coordinates
191
+
192
+ Returns:
193
+ numpy array of points (12,2), numpy array of point masks (12,)
194
+ """
195
+ pass
196
+
197
+ @abstractmethod
198
+ def get_identity_poses(self, identity: int, scale: float | None = None):
199
+ """return all points and point masks
200
+
201
+ Args:
202
+ identity: identity to return points for
203
+ scale: optional scale factor, set to cm_per_pixel to convert
204
+ poses from pixel coordinates to cm coordinates
205
+
206
+ Returns:
207
+ numpy array of points (#frames, 12, 2), numpy array of point masks (#frames, 12)
208
+ """
209
+ pass
210
+
211
+ @abstractmethod
212
+ def get_identity_point_mask(self, identity):
213
+ """get the point mask array for a given identity
214
+
215
+ Args:
216
+ identity: identity to return point mask for
217
+
218
+ Returns:
219
+ array of point masks (#frames, 12)
220
+ """
221
+ pass
222
+
223
+ @abstractmethod
224
+ def get_reduced_point_mask(self):
225
+ """Returns a boolean array of length 12 indicating which keypoints are valid.
226
+
227
+ Determines which keypoints are valid for any identity across all frames.
228
+
229
+ Returns:
230
+ numpy array of shape (12,) with boolean values indicating validity
231
+ of each keypoint.
232
+ """
233
+ pass
234
+
235
+ def get_connected_segments(self):
236
+ """Get the segments to use for rendering connections between the keypoints
237
+
238
+ Returns:
239
+ list of tuples, where each tuple contains the indexes of the keypoints
240
+ that form a connected segment
241
+ """
242
+ return PoseEstimation.FULL_CONNECTED_SEGMENTS
243
+
244
+ @abstractmethod
245
+ def identity_mask(self, identity):
246
+ """get the identity mask (indicates if specified identity is present in each frame)
247
+
248
+ Args:
249
+ identity: identity to get masks for
250
+
251
+ Returns:
252
+ numpy array of size (#frames,)
253
+ """
254
+ pass
255
+
256
+ @property
257
+ @abstractmethod
258
+ def identity_to_track(self):
259
+ """get the identity to track mapping for this file"""
260
+ pass
261
+
262
+ @property
263
+ @abstractmethod
264
+ def format_major_version(self):
265
+ """an integer giving the major version of the format"""
266
+ pass
267
+
268
+ @property
269
+ def static_objects(self):
270
+ """get static objects from the pose file"""
271
+ return self._static_objects
272
+
273
+ def get_identity_convex_hulls(self, identity):
274
+ """get a list of length #frames containing convex hulls for the given identity.
275
+
276
+ The convex hulls are calculated using all valid points except for the
277
+ middle of tail and tip of tail points.
278
+
279
+ Args:
280
+ identity: identity to return points for
281
+
282
+ Returns:
283
+ the convex hulls in pixel units (array elements will be None
284
+ if there is no valid convex hull for that frame)
285
+ """
286
+ if identity in self._convex_hull_cache:
287
+ return self._convex_hull_cache[identity]
288
+ else:
289
+ convex_hulls = None
290
+ path = None
291
+ if self._cache_dir is not None:
292
+ path = (
293
+ self._cache_dir
294
+ / "convex_hulls"
295
+ / self._path.with_suffix("").name
296
+ / f"convex_hulls_{identity}.pickle"
297
+ )
298
+ path.parents[0].mkdir(mode=0o775, parents=True, exist_ok=True)
299
+
300
+ try:
301
+ with path.open("rb") as f:
302
+ convex_hulls = joblib.load(f)
303
+ except Exception:
304
+ # we weren't able to read in the cached convex hulls,
305
+ # just ignore the exception and we'll generate them
306
+ pass
307
+
308
+ if convex_hulls is None:
309
+ points, point_masks = self.get_identity_poses(identity)
310
+ # Omit tail from convex hull
311
+ body_points = points[:, :-2, :]
312
+ body_point_masks = point_masks[:, :-2]
313
+ convex_hulls = []
314
+
315
+ for frame_index in range(self.num_frames):
316
+ if sum(body_point_masks[frame_index, :]) >= 3:
317
+ filtered_points = body_points[
318
+ frame_index, body_point_masks[frame_index, :] == 1, :
319
+ ]
320
+ convex_hulls.append(MultiPoint(filtered_points).convex_hull)
321
+ else:
322
+ convex_hulls.append(None)
323
+
324
+ if path:
325
+ with path.open("wb") as f:
326
+ joblib.dump(convex_hulls, f)
327
+
328
+ self._convex_hull_cache[identity] = convex_hulls
329
+ return convex_hulls
330
+
331
+ def compute_bearing(self, points: np.ndarray, use_nose: bool = False):
332
+ """compute the bearing of the animal using base tail and base neck keypoints
333
+
334
+ Args:
335
+ points (np.ndarray): the points for a single frame (12,2) array
336
+ use_nose (bool): use nose keypoint instead of base neck, used when
337
+ we have a reduced keypoint pose that lacks base neck
338
+ """
339
+ # fall back to use nose instead of base neck if base neck is absent from this pose file
340
+ # (for example, 5 keypoint pose instead of 12)
341
+ if use_nose:
342
+ p1_xy = points[self.KeypointIndex.NOSE.value].astype(np.float32)
343
+ else:
344
+ p1_xy = points[self.KeypointIndex.BASE_NECK.value].astype(np.float32)
345
+ p2_xy = points[self.KeypointIndex.BASE_TAIL.value].astype(np.float32)
346
+ offset_xy = p1_xy - p2_xy
347
+
348
+ angle_rad = np.arctan2(offset_xy[1], offset_xy[0])
349
+
350
+ return np.degrees(angle_rad)
351
+
352
+ def compute_all_bearings(self, identity):
353
+ """compute the bearing for each frame for a given identity"""
354
+ use_nose = not self.get_reduced_point_mask()[self.KeypointIndex.BASE_NECK.value]
355
+ if use_nose:
356
+ logging.warning("Falling back to using nose keypoint for bearing computation")
357
+
358
+ bearings = np.full(self.num_frames, np.nan, dtype=np.float32)
359
+ for i in range(self.num_frames):
360
+ points, mask = self.get_points(i, identity)
361
+ if points is not None:
362
+ bearings[i] = self.compute_bearing(points, use_nose)
363
+ return bearings
364
+
365
+ @staticmethod
366
+ def get_pose_file_attributes(path: Path) -> dict:
367
+ """get the attributes from the pose file's hdf5 file"""
368
+ with h5py.File(path, "r") as pose_h5:
369
+ attrs = dict(pose_h5.attrs)
370
+ attrs["poseest"] = dict(pose_h5["poseest"].attrs)
371
+ return attrs
372
+
373
+ @property
374
+ def num_lixit_keypoints(self) -> int:
375
+ """get the number of lixit keypoints
376
+
377
+ always 0 for pose file versions <5
378
+ """
379
+ return 0
380
+
381
+ @property
382
+ def external_identities(self) -> list[str] | None:
383
+ """get the jabs identity to external identity mapping"""
384
+ return self._external_identities
385
+
386
+ def identity_index_to_display(self, identity_index: int) -> str:
387
+ """Convert an identity index to a display string.
388
+
389
+ Args:
390
+ identity_index (int): The identity index to convert.
391
+
392
+ Returns:
393
+ str: The display string for the identity.
394
+ """
395
+ if self.external_identities and 0 <= identity_index < len(self.external_identities):
396
+ return self.external_identities[identity_index]
397
+ return str(identity_index)
398
+
399
+ def check_cache_version(self) -> bool:
400
+ """Check if the cache version matches the expected version.
401
+
402
+ Returns:
403
+ bool: True if the cache version matches, False otherwise.
404
+ """
405
+ try:
406
+ with h5py.File(self._cache_file_path(), "r") as cache_h5:
407
+ cache_version = cache_h5.attrs.get("cache_file_version", None)
408
+ return cache_version == self._CACHE_FILE_VERSION
409
+ except Exception:
410
+ return False
411
+
412
+ def _cache_file_path(self) -> Path | None:
413
+ """Get the path to the cache file for this pose file.
414
+
415
+ Returns:
416
+ Path | None: The path to the cache file, or None if no cache directory is set.
417
+ """
418
+ if self._cache_dir is None:
419
+ return None
420
+ filename = self._path.name.replace(".h5", "_cache.h5")
421
+ return self._cache_dir / filename
@@ -0,0 +1,15 @@
1
# Organization and application names.
ORG_NAME = "JAX"
APP_NAME = "JABS"
APP_NAME_LONG = ORG_NAME + " Animal Behavior System"

# Hard-coded random seed for training the final classifier. It is not used
# during cross-validation; fixing it makes the final model, which is trained
# after cross-validation, reproducible.
FINAL_TRAIN_SEED = 0xAB3BDB

# Defaults for compressing HDF5 output.
COMPRESSION = "gzip"
COMPRESSION_OPTS_DEFAULT = 6

# Keys for project settings stored in the project.json file.
CV_GROUPING_KEY = "cv_grouping"
@@ -0,0 +1,12 @@
1
+ """Module for defining enums used in JABS"""
2
+
3
+ from .classifier_types import ClassifierType
4
+ from .cv_grouping import DEFAULT_CV_GROUPING_STRATEGY, CrossValidationGroupingStrategy
5
+ from .units import ProjectDistanceUnit
6
+
7
+ __all__ = [
8
+ "DEFAULT_CV_GROUPING_STRATEGY",
9
+ "ClassifierType",
10
+ "CrossValidationGroupingStrategy",
11
+ "ProjectDistanceUnit",
12
+ ]
@@ -0,0 +1,9 @@
1
+ from enum import Enum
2
+
3
+
4
class ClassifierType(str, Enum):
    """Classifier types that can be selected for a project.

    Members are also ``str`` instances (the class mixes in ``str``); each
    member's value is its human-readable label.
    """

    RANDOM_FOREST = "Random Forest"
    CATBOOST = "CatBoost"
    XGBOOST = "XGBoost"
@@ -0,0 +1,15 @@
1
+ from enum import Enum
2
+
3
+
4
class CrossValidationGroupingStrategy(str, Enum):
    """How samples are grouped for cross-validation in a project.

    The ``str`` mixin makes JSON round-tripping straightforward: a member
    serializes as its value.
    """

    INDIVIDUAL = "Individual Animal"
    VIDEO = "Video"


# Grouping strategy used when none has been chosen explicitly.
DEFAULT_CV_GROUPING_STRATEGY = CrossValidationGroupingStrategy.INDIVIDUAL
@@ -0,0 +1,8 @@
1
+ import enum
2
+
3
+
4
class ProjectDistanceUnit(enum.IntEnum):
    """Unit in which a project expresses distances: pixels or centimeters."""

    PIXEL = 0
    CM = 1
@@ -0,0 +1,28 @@
1
class JabsError(Exception):
    """Base class for all JABS-specific exceptions.

    Catching this type handles any of the exceptions defined below; each of
    them is still an ``Exception`` subclass, so existing handlers keep working.
    """


class PoseHashException(JabsError):
    """Exception raised when the hash of a pose file does not match the expected value."""


class PoseIdEmbeddingException(JabsError):
    """Exception raised for invalid instance_embed_id values in pose file."""


class MissingBehaviorError(JabsError):
    """Exception raised when a behavior is not found in the prediction file."""


class FeatureVersionException(JabsError):
    """Exception raised when the feature version in the h5 file is incompatible with this version of JABS."""


class DistanceScaleException(JabsError):
    """Exception raised when the distance scale factor in the h5 file doesn't match what the classifier expects."""
@@ -0,0 +1,12 @@
1
+ """JABS utilities"""
2
+
3
+ from .update_checker import check_for_update, is_pypi_install
4
+ from .utilities import get_bool_env_var, hash_file, hide_stderr
5
+
6
+ __all__ = [
7
+ "check_for_update",
8
+ "get_bool_env_var",
9
+ "hash_file",
10
+ "hide_stderr",
11
+ "is_pypi_install",
12
+ ]
@@ -0,0 +1,36 @@
1
+ from collections.abc import Generator, Iterable
2
+
3
+ import numpy as np
4
+
5
+ from jabs.core.abstract import PoseEstimation
6
+
7
+
8
def gen_line_fragments(
    connected_segments: Iterable[Iterable[PoseEstimation.KeypointIndex]],
    exclude_points: np.ndarray,
) -> Generator[list[int], None, None]:
    """Split connected keypoint segments into drawable line fragments.

    A segment is cut wherever one of its keypoints is excluded; pieces with
    fewer than two keypoints are dropped, since they cannot form a line.

    Args:
        connected_segments: Iterable of Iterables of KeypointIndex, where each inner
            Iterable represents a segment of connected keypoints
        exclude_points: numpy array of keypoint index values to leave out
            when generating fragments

    Yields:
        lists of keypoint index values, each describing one fragment to draw
    """
    for segment in connected_segments:
        fragment = []
        for keypoint in segment:
            index = keypoint.value
            if index not in exclude_points:
                fragment.append(index)
                continue
            # an excluded keypoint ends the current fragment
            if len(fragment) >= 2:
                yield fragment
            fragment = []
        # emit whatever remains once the segment is exhausted
        if len(fragment) >= 2:
            yield fragment
@@ -0,0 +1,223 @@
1
+ import contextlib
2
+ import logging
3
+ import os
4
+ import threading
5
+ import time
6
+ from collections.abc import Callable, Iterable
7
+ from concurrent.futures import Future, ProcessPoolExecutor
8
+ from multiprocessing import shared_memory
9
+ from typing import Any
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ MAX_POOL_WORKERS = 6
14
+
15
+
16
+ def _noop() -> None:
17
+ """No-op function for warming up worker processes.
18
+
19
+ Must be at module level to be pickleable by ProcessPoolExecutor.
20
+ """
21
+ return None
22
+
23
+
24
+ class ProcessPoolManager:
25
+ """
26
+ Manage a shared ProcessPoolExecutor with warm-up and safe shutdown.
27
+
28
+ Attributes:
29
+ _max_workers (int | None): Maximum number of worker processes. Passed to
30
+ ProcessPoolExecutor when created.
31
+ _initializer (Callable[..., object] | None): Optional function executed in
32
+ each worker process when it starts.
33
+ _initargs (tuple[object, ...]): Arguments passed to the initializer.
34
+ _name (str): Logical name for debugging/logging.
35
+ _executor (ProcessPoolExecutor | None): The lazily-created underlying
36
+ process pool. None until first use.
37
+ _lock (threading.RLock): Protects access to `_executor` and `_is_shutdown`.
38
+ _is_shutdown (bool): Whether shutdown() has been called. Prevents reuse once
39
+ the pool has been shut down.
40
+
41
+ Args:
42
+ max_workers (int | None): Maximum number of worker processes. Defaults to
43
+ os.cpu_count() if None.
44
+ initializer (Callable | None): Optional function run in each worker process
45
+ when it starts.
46
+ initargs (tuple): Arguments passed to the initializer.
47
+ name (str): Optional name used only for debugging/logging.
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ max_workers: int | None = None,
53
+ *,
54
+ initializer: Callable[..., object] | None = None,
55
+ initargs: tuple[object, ...] = (),
56
+ name: str = "ProcessPoolManager",
57
+ ) -> None:
58
+ logger.debug(f"PPM __init__ name={name} id={id(self)}")
59
+ requested_workers = max_workers or (os.cpu_count() or 1)
60
+ self._max_workers: int = max(1, min(requested_workers, MAX_POOL_WORKERS))
61
+ self._initializer = initializer
62
+ self._initargs = initargs
63
+ self._name = name
64
+
65
+ self._executor: ProcessPoolExecutor | None = None
66
+ self._lock = threading.RLock() # protects _executor and _is_shutdown
67
+ self._is_shutdown = False
68
+
69
+ self._cancel_shm: shared_memory.SharedMemory | None = None
70
+
71
+ @property
72
+ def max_workers(self) -> int:
73
+ """Maximum number of worker processes in the pool."""
74
+ return self._max_workers
75
+
76
+ @property
77
+ def name(self) -> str:
78
+ """Logical name of the ProcessPoolManager for debugging/logging."""
79
+ return self._name
80
+
81
+ def _ensure_cancel_shm(self) -> shared_memory.SharedMemory:
82
+ """Create the shared-memory cancel flag on first use, if not shut down."""
83
+ with self._lock:
84
+ if self._is_shutdown:
85
+ raise RuntimeError(f"{self._name} has been shut down")
86
+
87
+ if self._cancel_shm is None:
88
+ shm = shared_memory.SharedMemory(create=True, size=1)
89
+ # 0 = not cancelled, 1 = cancelled
90
+ shm.buf[0] = 0
91
+ self._cancel_shm = shm
92
+
93
+ return self._cancel_shm
94
+
95
+ @property
96
+ def cancel_flag_name(self) -> str | None:
97
+ """Name of the shared-memory cancel flag, or None if shut down.
98
+
99
+ Callers can pass this name to worker functions so they can open the
100
+ shared memory and cooperatively check for cancellation.
101
+ """
102
+ with self._lock:
103
+ if self._is_shutdown:
104
+ return None
105
+
106
+ shm = self._ensure_cancel_shm()
107
+ return shm.name
108
+
109
+ def set_cancelled(self) -> None:
110
+ """Set the cancel flag to 1, signalling cooperative cancellation."""
111
+ with self._lock:
112
+ if self._is_shutdown:
113
+ return
114
+
115
+ shm = self._cancel_shm or self._ensure_cancel_shm()
116
+ shm.buf[0] = 1
117
+
118
+ def clear_cancelled(self) -> None:
119
+ """Reset the cancel flag back to 0."""
120
+ with self._lock:
121
+ if self._cancel_shm is not None:
122
+ self._cancel_shm.buf[0] = 0
123
+
124
+ def _ensure_executor(self) -> ProcessPoolExecutor:
125
+ """Create the executor on first use, if not shut down."""
126
+ with self._lock:
127
+ if self._is_shutdown:
128
+ raise RuntimeError(f"{self._name} has been shut down")
129
+
130
+ if self._executor is None:
131
+ # noinspection PyTypeChecker
132
+ self._executor = ProcessPoolExecutor(
133
+ max_workers=self._max_workers,
134
+ initializer=self._initializer,
135
+ initargs=self._initargs,
136
+ )
137
+
138
+ return self._executor
139
+
140
+ def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Future:
141
+ """Submit a task to the process pool."""
142
+ executor = self._ensure_executor()
143
+ return executor.submit(fn, *args, **kwargs)
144
+
145
+ def map(
146
+ self,
147
+ fn: Callable[[Any], Any],
148
+ iterable: Iterable[Any],
149
+ chunksize: int = 1,
150
+ ) -> Iterable[Any]:
151
+ """Map over an iterable using the process pool."""
152
+ executor = self._ensure_executor()
153
+ return executor.map(fn, iterable, chunksize=chunksize)
154
+
155
+ def warm_up(self, wait: bool = True) -> None:
156
+ """Eagerly start worker processes and optionally run trivial tasks.
157
+
158
+ This is useful if you want the cost of spawning processes and running
159
+ initializers to happen at a controlled time (e.g., on app startup)
160
+ instead of on the first real submit().
161
+
162
+ Args:
163
+ wait (bool): If True, submit and wait for trivial tasks to complete
164
+ in each worker process. This ensures that all workers are fully
165
+ initialized and ready to accept real tasks. If False, only starts
166
+ the processes without waiting for task completion.
167
+ """
168
+ start_time = time.time()
169
+ logger.debug(f"PPM warm_up name={self._name} id={id(self)}")
170
+ executor = self._ensure_executor()
171
+ self._ensure_cancel_shm()
172
+
173
+ if not wait:
174
+ return
175
+
176
+ futures = [executor.submit(_noop) for _ in range(self._max_workers)]
177
+ for f in futures:
178
+ with contextlib.suppress(Exception):
179
+ f.result()
180
+
181
+ elapsed = time.time() - start_time
182
+ logger.debug(
183
+ f"PPM warm_up name={self._name} id={id(self)} COMPLETED in {elapsed:.2f} seconds"
184
+ )
185
+
186
+ def shutdown(self, *, wait: bool = True, cancel_futures: bool = False) -> None:
187
+ """Explicitly shut down the process pool.
188
+
189
+ After shutdown, the manager cannot be reused.
190
+ """
191
+ with self._lock:
192
+ self._is_shutdown = True
193
+ executor = self._executor
194
+ if executor is not None:
195
+ with contextlib.suppress(Exception):
196
+ executor.shutdown(wait=wait, cancel_futures=cancel_futures)
197
+ self._executor = None
198
+ if self._cancel_shm is not None:
199
+ with contextlib.suppress(Exception):
200
+ self._cancel_shm.close()
201
+ self._cancel_shm.unlink()
202
+ self._cancel_shm = None
203
+
204
+ def __enter__(self) -> "ProcessPoolManager":
205
+ """Enter context manager, returning self.
206
+
207
+ Allows using the manager in a 'with' statement for automatic cleanup.
208
+ """
209
+ self._ensure_executor()
210
+ return self
211
+
212
+ def __exit__(self, exc_type, exc, tb) -> None:
213
+ """Exit context manager, shutting down the process pool."""
214
+ self.shutdown(wait=True, cancel_futures=False)
215
+
216
+ def __del__(self) -> None:
217
+ """Best-effort cleanup if user code forgets to call shutdown().
218
+
219
+ Note: __del__ is not guaranteed to run at interpreter shutdown, so
220
+ you should still call shutdown() or use the manager as a context manager.
221
+ """
222
+ with contextlib.suppress(Exception):
223
+ self.shutdown(wait=False, cancel_futures=True)
@@ -0,0 +1,269 @@
1
+ import argparse
2
+ import os
3
+ import random
4
+
5
+ import cv2
6
+ import h5py
7
+
8
+ # Command line example of using this script:
9
+ #
10
+ # share_root='/run/user/1000/gvfs/smb-share:server=bht2stor.jax.org,share=vkumar'
11
+ # python src/utils/sampleposeintervals.py \
12
+ # --batch-file UCSD_Rotta_TS_v2.txt \
13
+ # --root-dir "${share_root}" \
14
+ # --out-dir UCSD_Rotta_TS_v2-intervals \
15
+ # --out-frame-count 9000 \
16
+ # --start-frame 54000 \
17
+ # --pose-version 3
18
+ #
19
+ # share_root='/media/sheppk/TOSHIBA EXT/rotta-data/UCSD_Rotta_TS_v2-vidcache'
20
+ # python src/utils/sampleposeintervals.py \
21
+ # --batch-file "${share_root}/batch.txt" \
22
+ # --root-dir "${share_root}" \
23
+ # --out-dir UCSD_Rotta_TS_v2-intervals-2021-05-25 \
24
+ # --out-frame-count 9000 \
25
+ # --start-frame 27000 \
26
+ # --pose-version 3
27
+
28
+ # python src/utils/sampleposeintervals.py \
29
+ # --batch-file ~/projects/social-interaction/data/bxd-batch-early-morning-2021-06-09.txt \
30
+ # --root-dir '/run/user/1000/gvfs/smb-share:server=bht2stor.jax.org,share=vkumar' \
31
+ # --out-dir bxd-batch-early-morning-2021-06-09 \
32
+ # --out-frame-count 9000 \
33
+ # --start-frame 54000 \
34
+ # --pose-version 3
35
+
36
+ # python src/utils/sampleposeintervals.py \
37
+ # --batch-file temp/B6J-and-BTBR-3M-strangers-4-day-rand-2021-05-24.txt \
38
+ # --root-dir '/media/sheppk/TOSHIBA EXT/rotta-data/B6J-and-BTBR-3M-strangers-4-day-rand-2021-05-24' \
39
+ # --out-dir B6J-and-BTBR-3M-strangers-4-day-rand-samples-2021-05-24 \
40
+ # --out-frame-count 3600 \
41
+ # --start-frame 6000 \
42
+ # --pose-version 3
43
+
44
+ # python src/utils/sampleposeintervals.py \
45
+ # --batch-file temp/B6J-and-BTBR-3M-strangers-4-day-rand-2021-05-24.txt \
46
+ # --root-dir '/media/sheppk/TOSHIBA EXT/rotta-data/B6J_and_BTBR_3M_stranger_4day_2021-07-20' \
47
+ # --out-dir temp/B6J-and-BTBR-3M-strangers-4-day-rand-samples-2021-08-05 \
48
+ # --out-frame-count 3600 \
49
+ # --start-frame 6000 \
50
+ # --only-pose \
51
+ # --pose-version 4
52
+
53
+ # rclone copy --transfers 4 --progress \
54
+ # --include-from /home/sheppk/projects/behavior-classifier/temp/BTBR_3M_stranger_4day-subset-avi.txt \
55
+ # "labdropbox:/KumarLab's shared workspace/VideoData/MDS_Tests/BTBR_3M_stranger_4day" \
56
+ # /media/sheppk/TOSHIBA\ EXT/BTBR_3M_stranger_4day-2021-08-24
57
+ # rclone copy --transfers 4 --progress \
58
+ # --include-from /home/sheppk/projects/behavior-classifier/temp/BTBR_3M_stranger_4day-subset-pose.txt \
59
+ # /home/sheppk/sshfs/winterproj/bgeuther/IdentityInfer/Data/BTBR_3M_stranger_4day \
60
+ # /media/sheppk/TOSHIBA\ EXT/BTBR_3M_stranger_4day-2021-08-24
61
+ # python src/utils/sampleposeintervals.py \
62
+ # --batch-file /media/sheppk/TOSHIBA\ EXT/BTBR_3M_stranger_4day-2021-08-24/batch.txt \
63
+ # --root-dir /media/sheppk/TOSHIBA\ EXT/BTBR_3M_stranger_4day-2021-08-24 \
64
+ # --out-dir /media/sheppk/TOSHIBA\ EXT/BTBR_3M_stranger_4day-2021-08-24-samples \
65
+ # --out-frame-count 3600 \
66
+ # --start-frame 6000 \
67
+ # --pose-version 4
68
+
69
+
70
# Frame-indexed datasets under "poseest". The first two are copied
# unconditionally (pose v2); the remaining ones were introduced by later pose
# versions and are copied only when the input file contains them.
_REQUIRED_POSE_DATASETS = ("points", "confidence")
_OPTIONAL_POSE_DATASETS = (
    # pose v3
    "instance_count",
    "instance_embedding",
    "instance_track_id",
    # pose v4
    "id_mask",
    "identity_embeds",
    "instance_embed_id",
)


def _copy_pose_interval(pose_in, pose_out, start, stop):
    """Copy frames ``[start, stop)`` of the pose data from pose_in to pose_out.

    Args:
        pose_in: open h5py file to read from.
        pose_out: open, writable h5py file to populate.
        start: zero-based index of the first frame to copy.
        stop: zero-based index one past the last frame to copy.
    """
    poseest_in = pose_in["poseest"]

    for name in _REQUIRED_POSE_DATASETS:
        pose_out[f"poseest/{name}"] = poseest_in[name][start:stop, ...]

    for name in _OPTIONAL_POSE_DATASETS:
        if name in poseest_in:
            pose_out[f"poseest/{name}"] = poseest_in[name][start:stop, ...]

    # Unlike the datasets above this one is copied whole, not sliced by frame.
    if "instance_id_center" in poseest_in:
        pose_out["poseest/instance_id_center"] = poseest_in["instance_id_center"][:]

    # v5 specific: static objects are copied whole as well.
    if "static_objects" in pose_in:
        static_group = pose_out.create_group("static_objects")
        for dataset in pose_in["static_objects"]:
            static_group.create_dataset(dataset, data=pose_in["static_objects"][dataset])

    # Carry over the attributes stored on the "poseest" group.
    for attr in poseest_in.attrs:
        pose_out["poseest"].attrs[attr] = poseest_in.attrs[attr]


def _write_video_interval(vid_path, vid_out_path, vid_filename, start, frame_count):
    """Write ``frame_count`` frames of a video, beginning at ``start``, as MJPG.

    Problems (video cannot be opened, video ends early) are reported as
    warnings on stdout; the capture and writer are always released.

    Args:
        vid_path: path of the input video.
        vid_out_path: path of the ``.avi`` file to create.
        vid_filename: name used for the input video in warning messages.
        start: zero-based index of the first frame to write.
        frame_count: number of frames to write.
    """
    cap = None
    writer = None
    try:
        cap = cv2.VideoCapture(vid_path)
        if not cap.isOpened():
            print(f"WARNING: failed to open {vid_filename}")
            return

        cap.set(cv2.CAP_PROP_POS_FRAMES, start)
        if not cap.isOpened():
            print(f"WARNING: failed to seek to start frame {vid_filename}")
            return

        writer = cv2.VideoWriter(
            vid_out_path,
            cv2.VideoWriter_fourcc(*"MJPG"),
            30,
            (
                int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
                int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
            ),
        )
        for _ in range(frame_count):
            if not cap.isOpened():
                print(f"WARNING: {vid_filename} ended prematurely")
                break

            ret, frame = cap.read()
            if not ret:
                print(f"WARNING: {vid_filename} ended prematurely")
                break
            writer.write(frame)
    finally:
        if cap is not None:
            cap.release()
        if writer is not None:
            writer.release()


def main():
    """Sample a fixed-length interval of pose (and video) for each batch entry.

    For every non-empty line of the batch file the matching pose file is read
    and ``--out-frame-count`` frames are written to ``--out-dir``. The interval
    starts at ``--start-frame`` (1-based) when given, otherwise at a random
    frame. Unless ``--only-pose`` is set, the same interval of the video is
    written next to the pose file. Entries whose input files are missing, that
    are too short, or whose requested start frame does not leave room for a
    full interval are skipped with a warning.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--batch-file",
        help="path to the file that is a new-line separated list of all videos to process",
        required=True,
    )
    parser.add_argument(
        "--root-dir",
        help="the root directory. All paths given in the batch files are relative to this root",
        required=True,
    )
    parser.add_argument(
        "--out-dir",
        help="output directory. The videos and pose files for sampled intervals are saved to this dir",
        required=True,
    )
    parser.add_argument(
        "--out-frame-count",
        help="this defines how many frames to save. Assuming 30fps a value of 1800 corresponds to one minute",
        required=True,
        type=int,
    )
    parser.add_argument(
        "--start-frame",
        help="this argument specifies which frame we start at. If this option is not specified we randomly select"
        " a start frame from the video.",
        required=False,
        type=int,
    )
    parser.add_argument(
        "--pose-version",
        help="give the integer version number that should be used for pose",
        default=2,
        type=int,
        choices=(2, 3, 4, 5),
    )
    parser.add_argument(
        "--only-pose",
        help="if specified this option will sample pose data and exclude video from output",
        action="store_true",
    )

    args = parser.parse_args()

    # argparse "choices" already restricts the version to a supported value.
    pose_suffix = f"_pose_est_v{args.pose_version}.h5"

    os.makedirs(args.out_dir, exist_ok=True)

    with open(args.batch_file) as batch_file:
        for line in batch_file:
            vid_filename = line.strip()
            if not vid_filename:
                continue

            print("Processing:", vid_filename)
            vid_path = os.path.join(args.root_dir, vid_filename)
            vid_path_root, _ = os.path.splitext(vid_path)
            pose_in_path = vid_path_root + pose_suffix

            if not args.only_pose and not os.path.isfile(vid_path):
                print("WARNING: missing video path:", vid_path)
                continue

            if not os.path.isfile(pose_in_path):
                print("WARNING: missing pose path:", pose_in_path)
                continue

            with h5py.File(pose_in_path, "r") as pose_in:
                frame_count = pose_in["poseest"]["confidence"].shape[0]

                last_candidate_frame = frame_count - args.out_frame_count
                if last_candidate_frame <= 0:
                    print(
                        f"WARNING: {vid_filename} skipped because it only contains {frame_count} frames"
                    )
                    continue

                if args.start_frame is None:
                    out_start_frame_index = random.randrange(last_candidate_frame)
                else:
                    # --start-frame is 1-based; convert to a zero-based index.
                    out_start_frame_index = args.start_frame - 1
                    # A negative index would wrap around when slicing and a
                    # too-late start would silently yield fewer frames than
                    # requested, so reject both instead of writing bad output.
                    if not 0 <= out_start_frame_index <= last_candidate_frame:
                        print(
                            f"WARNING: {vid_filename} skipped because start frame "
                            f"{args.start_frame} does not leave {args.out_frame_count} "
                            f"frames (video has {frame_count} frames)"
                        )
                        continue

                # Flatten the relative path into one file name in the out dir.
                vid_out_filename = vid_filename.replace("/", "+").replace("\\", "+")
                vid_out_path_root, _ = os.path.splitext(
                    os.path.join(args.out_dir, vid_out_filename)
                )
                out_stem = vid_out_path_root + "_" + str(out_start_frame_index + 1)
                vid_out_path = out_stem + ".avi"
                pose_out_path = out_stem + pose_suffix

                with h5py.File(pose_out_path, "w") as pose_out:
                    _copy_pose_interval(
                        pose_in,
                        pose_out,
                        out_start_frame_index,
                        out_start_frame_index + args.out_frame_count,
                    )

            if not args.only_pose:
                _write_video_interval(
                    vid_path,
                    vid_out_path,
                    vid_filename,
                    out_start_frame_index,
                    args.out_frame_count,
                )
266
+
267
+
268
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
@@ -0,0 +1,54 @@
1
+ """Utilities for checking PyPI for JABS updates."""
2
+
3
+ import json
4
+ import logging
5
+ import urllib.request
6
+ from importlib import metadata
7
+
8
+ from packaging.version import parse as parse_version
9
+
10
+ # TODO: Consider moving this to jabs.core
11
+ from jabs.version import version_str
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
def check_for_update() -> tuple[bool, str | None, str]:
    """Ask PyPI whether a newer jabs-behavior-classifier release exists.

    The version reported under ``info.version`` of the PyPI JSON document is
    compared with the installed version from ``version_str()``. Any failure
    (network, parsing, ...) is logged as a warning and reported as "no update".

    Returns:
        tuple: (has_update, latest_version, current_version)
            - has_update: True if the PyPI version is newer than the installed one
            - latest_version: version string from PyPI, or None if the check failed
            - current_version: installed version string
    """
    pypi_url = "https://pypi.org/pypi/jabs-behavior-classifier/json"
    try:
        installed = version_str()

        # Short timeout so a slow or unreachable PyPI cannot stall the caller.
        with urllib.request.urlopen(pypi_url, timeout=5) as resp:
            payload = json.load(resp)

        newest = payload["info"]["version"]
        newer_available = parse_version(newest) > parse_version(installed)
        return newer_available, newest, installed
    except Exception as err:
        logger.warning(f"Failed to check for updates: {err}")
        return False, None, version_str()
39
+
40
+
41
def is_pypi_install() -> bool:
    """Report whether jabs-behavior-classifier was installed by pip or uv.

    The decision is based on the ``INSTALLER`` file of the installed
    distribution's metadata.

    Returns:
        bool: True if the recorded installer is "pip" or "uv"; False if it is
        anything else, is missing, or the metadata cannot be read.
    """
    try:
        recorded = metadata.distribution("jabs-behavior-classifier").read_text("INSTALLER")
        if recorded is None:
            # No INSTALLER file in the distribution metadata.
            return False
        return recorded.strip() in ("pip", "uv")
    except Exception as err:
        logger.debug(f"Could not determine installation method: {err}")
        return False
@@ -0,0 +1,64 @@
1
+ import hashlib
2
+ import os
3
+ import sys
4
+ from collections.abc import Generator
5
+ from contextlib import contextmanager
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+
10
@contextmanager
def hide_stderr() -> Generator[int, Any, None]:
    """Silence standard error (stderr) for the duration of the ``with`` block.

    The stderr file descriptor is pointed at os.devnull while the context is
    active and pointed back at its original target on exit, even if the body
    raises.

    Yields:
        int: The file descriptor for stderr.
    """
    stderr_fd = sys.stderr.fileno()

    # Keep a duplicate of the original target so it can be restored later.
    saved_fd = os.dup(stderr_fd)
    try:
        # Push out anything buffered before the descriptor is redirected.
        sys.stderr.flush()

        with open(os.devnull, "wb") as devnull:
            os.dup2(devnull.fileno(), stderr_fd)
            try:
                yield stderr_fd
            finally:
                # Point stderr back at the saved original target.
                sys.stderr.flush()
                os.dup2(saved_fd, stderr_fd)
    finally:
        os.close(saved_fd)
35
+
36
+
37
def hash_file(file: Path):
    """Return the hex BLAKE2b digest (20-byte digest size) of a file's contents.

    The file is read in 8192-byte chunks, so it is never loaded into memory
    as a whole.

    Args:
        file: Path of the file to hash.

    Returns:
        The digest as a hexadecimal string.
    """
    block_size = 8192
    with file.open("rb") as stream:
        digest = hashlib.blake2b(digest_size=20)
        # read() returns b"" at end of file, which stops the iteration.
        for block in iter(lambda: stream.read(block_size), b""):
            digest.update(block)
    return digest.hexdigest()
47
+
48
+
49
def get_bool_env_var(var_name, default_value=False) -> bool:
    """Interpret an environment variable as a boolean flag.

    Args:
        var_name: The name of the environment variable.
        default_value: The value to return if the variable is not set.

    Returns:
        ``default_value`` when the variable is unset. Otherwise True if its
        value, compared case-insensitively, is one of "true", "1", "yes",
        "on", "y" or "t", and False for any other value.
    """
    raw = os.getenv(var_name)
    if raw is not None:
        return raw.lower() in ("true", "1", "yes", "on", "y", "t")
    return default_value