jabs-core 0.1.0a1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jabs_core-0.1.0a1/PKG-INFO +43 -0
- jabs_core-0.1.0a1/README.md +28 -0
- jabs_core-0.1.0a1/pyproject.toml +44 -0
- jabs_core-0.1.0a1/src/jabs/core/__init__.py +1 -0
- jabs_core-0.1.0a1/src/jabs/core/abstract/__init__.py +7 -0
- jabs_core-0.1.0a1/src/jabs/core/abstract/pose_est.py +421 -0
- jabs_core-0.1.0a1/src/jabs/core/constants.py +15 -0
- jabs_core-0.1.0a1/src/jabs/core/enums/__init__.py +12 -0
- jabs_core-0.1.0a1/src/jabs/core/enums/classifier_types.py +9 -0
- jabs_core-0.1.0a1/src/jabs/core/enums/cv_grouping.py +15 -0
- jabs_core-0.1.0a1/src/jabs/core/enums/units.py +8 -0
- jabs_core-0.1.0a1/src/jabs/core/exceptions.py +28 -0
- jabs_core-0.1.0a1/src/jabs/core/utils/__init__.py +12 -0
- jabs_core-0.1.0a1/src/jabs/core/utils/pose_util.py +36 -0
- jabs_core-0.1.0a1/src/jabs/core/utils/process_pool_manager.py +223 -0
- jabs_core-0.1.0a1/src/jabs/core/utils/sampleposeintervals.py +269 -0
- jabs_core-0.1.0a1/src/jabs/core/utils/update_checker.py +54 -0
- jabs_core-0.1.0a1/src/jabs/core/utils/utilities.py +64 -0
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: jabs-core
|
|
3
|
+
Version: 0.1.0a1
|
|
4
|
+
Summary: Core infrastructure and shared utilities for JABS (JAX Animal Behavior System)
|
|
5
|
+
Requires-Dist: packaging>=24.0
|
|
6
|
+
Requires-Dist: toml>=0.10.2,<0.11.0
|
|
7
|
+
Requires-Dist: h5py>=3.10.0,<4.0.0
|
|
8
|
+
Requires-Dist: shapely>=2.0.1,<3.0.0
|
|
9
|
+
Requires-Dist: numpy>=2.0.0,<3.0.0
|
|
10
|
+
Requires-Dist: opencv-python-headless>=4.8.1.78,<5.0.0
Requires-Dist: joblib>=1.3.0,<2.0.0
|
|
11
|
+
Requires-Python: >=3.10, <3.15
|
|
12
|
+
Project-URL: Repository, https://github.com/KumarLabJax/JABS-behavior-classifier
|
|
13
|
+
Project-URL: Issues, https://github.com/KumarLabJax/JABS-behavior-classifier/issues
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
|
|
16
|
+
# JABS Core (`jabs-core`)
|
|
17
|
+
|
|
18
|
+
The infrastructure and shared utility layer for JABS (the JAX Animal Behavior System).
|
|
19
|
+
|
|
20
|
+
## Overview
|
|
21
|
+
|
|
22
|
+
`jabs-core` provides low-level, domain-agnostic utilities used across all JABS packages.
|
|
23
|
+
It is designed to be lightweight and free of heavy scientific dependencies (like
|
|
24
|
+
`scikit-learn` or `pandas`), making it safe to import at any level of the hierarchy.
|
|
25
|
+
|
|
26
|
+
## Responsibilities
|
|
27
|
+
|
|
28
|
+
- **Shared Constants**: Global constants used for file compression and configuration.
|
|
29
|
+
- **Exceptions**: Centralized exception hierarchy (`JabsError`, `PoseHashException`,
|
|
30
|
+
etc.).
|
|
31
|
+
- **Infrastructure**: Base classes for registries and plugin discovery systems.
|
|
32
|
+
- **Abstract Bases**: High-level interface definitions (e.g., the `PoseEstimation`
|
|
33
|
+
abstract base).
|
|
34
|
+
- **Utility Functions**: Generic helpers for file hashing, logging configuration, and
|
|
35
|
+
basic string/path manipulation.
|
|
36
|
+
|
|
37
|
+
## Package Structure
|
|
38
|
+
|
|
39
|
+
- `jabs.core.constants`: Global constants.
|
|
40
|
+
- `jabs.core.exceptions`: Shared exception classes.
|
|
41
|
+
- `jabs.core.abstract`: Abstract base classes for the system.
|
|
42
|
+
- `jabs.core.utils`: Generic utility functions.
|
|
43
|
+
- `jabs.core.enums`: Shared enumerations (e.g., `ClassifierType`).
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# JABS Core (`jabs-core`)
|
|
2
|
+
|
|
3
|
+
The infrastructure and shared utility layer for JABS (the JAX Animal Behavior System).
|
|
4
|
+
|
|
5
|
+
## Overview
|
|
6
|
+
|
|
7
|
+
`jabs-core` provides low-level, domain-agnostic utilities used across all JABS packages.
|
|
8
|
+
It is designed to be lightweight and free of heavy scientific dependencies (like
|
|
9
|
+
`scikit-learn` or `pandas`), making it safe to import at any level of the hierarchy.
|
|
10
|
+
|
|
11
|
+
## Responsibilities
|
|
12
|
+
|
|
13
|
+
- **Shared Constants**: Global constants used for file compression and configuration.
|
|
14
|
+
- **Exceptions**: Centralized exception hierarchy (`JabsError`, `PoseHashException`,
|
|
15
|
+
etc.).
|
|
16
|
+
- **Infrastructure**: Base classes for registries and plugin discovery systems.
|
|
17
|
+
- **Abstract Bases**: High-level interface definitions (e.g., the `PoseEstimation`
|
|
18
|
+
abstract base).
|
|
19
|
+
- **Utility Functions**: Generic helpers for file hashing, logging configuration, and
|
|
20
|
+
basic string/path manipulation.
|
|
21
|
+
|
|
22
|
+
## Package Structure
|
|
23
|
+
|
|
24
|
+
- `jabs.core.constants`: Global constants.
|
|
25
|
+
- `jabs.core.exceptions`: Shared exception classes.
|
|
26
|
+
- `jabs.core.abstract`: Abstract base classes for the system.
|
|
27
|
+
- `jabs.core.utils`: Generic utility functions.
|
|
28
|
+
- `jabs.core.enums`: Shared enumerations (e.g., `ClassifierType`).
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "jabs-core"
|
|
3
|
+
version = "0.1.0a1"
|
|
4
|
+
description = "Core infrastructure and shared utilities for JABS (JAX Animal Behavior System)"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.10,<3.15"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"packaging>=24.0",
|
|
9
|
+
"toml>=0.10.2,<0.11.0",
|
|
10
|
+
"h5py>=3.10.0,<4.0.0",
|
|
11
|
+
"shapely>=2.0.1,<3.0.0",
|
|
12
|
+
"numpy>=2.0.0,<3.0.0",
|
|
13
|
+
"opencv-python-headless>=4.8.1.78,<5.0.0",
|
|
14
|
+
"joblib>=1.3.0,<2.0.0",
]
|
|
15
|
+
|
|
16
|
+
[dependency-groups]
|
|
17
|
+
dev = [
|
|
18
|
+
{include-group = "lint"},
|
|
19
|
+
{include-group = "test"},
|
|
20
|
+
{include-group = "docs"},
|
|
21
|
+
"pre-commit>=4.2.0,<5.0.0",
|
|
22
|
+
"matplotlib>=3.9.3,<4.0.0",
|
|
23
|
+
]
|
|
24
|
+
test = [
|
|
25
|
+
"pytest>=8.3.4,<9.0.0",
|
|
26
|
+
"pytest-cov>=7.0.0",
|
|
27
|
+
]
|
|
28
|
+
lint = [
|
|
29
|
+
"ruff>=0.11.5,<0.12.0",
|
|
30
|
+
]
|
|
31
|
+
docs = [
|
|
32
|
+
"mkdocs>=1.6.1",
|
|
33
|
+
]
|
|
34
|
+
|
|
35
|
+
[project.urls]
|
|
36
|
+
Repository = "https://github.com/KumarLabJax/JABS-behavior-classifier"
|
|
37
|
+
Issues = "https://github.com/KumarLabJax/JABS-behavior-classifier/issues"
|
|
38
|
+
|
|
39
|
+
[build-system]
|
|
40
|
+
requires = ["uv_build>=0.9.26,<0.10.0"]
|
|
41
|
+
build-backend = "uv_build"
|
|
42
|
+
|
|
43
|
+
[tool.uv.build-backend]
|
|
44
|
+
module-name = "jabs.core"
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""The root of the jabs.core package."""
|
|
@@ -0,0 +1,421 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
import logging
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
import h5py
|
|
7
|
+
import joblib
|
|
8
|
+
import numpy as np
|
|
9
|
+
from shapely.geometry import MultiPoint
|
|
10
|
+
|
|
11
|
+
from jabs.core.utils import hash_file
|
|
12
|
+
|
|
13
|
+
# Module-level confidence threshold. It is not referenced anywhere in this
# file; presumably it is the minimum keypoint confidence below which a point
# is treated as invalid by pose subclasses/callers — TODO confirm.
MINIMUM_CONFIDENCE = 0.3
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class PoseEstimation(ABC):
|
|
17
|
+
"""Abstract base class for pose estimation data handlers.
|
|
18
|
+
|
|
19
|
+
Provides a common interface for loading, accessing, and processing pose data
|
|
20
|
+
from HDF5 files. Defines methods for retrieving keypoints, confidence masks, identity
|
|
21
|
+
presence, and static objects, as well as utilities for geometric computations such as
|
|
22
|
+
convex hulls and bearing angles. All pose estimation versioned classes should inherit
|
|
23
|
+
from this base class.
|
|
24
|
+
|
|
25
|
+
Args:
|
|
26
|
+
file_path (Path): Path to the pose HDF5 file.
|
|
27
|
+
cache_dir (Path | None): Optional cache directory for intermediate data.
|
|
28
|
+
fps (int): Frames per second for the video.
|
|
29
|
+
|
|
30
|
+
Abstract Methods:
|
|
31
|
+
get_points(frame_index, identity, scale): Get points and mask for an identity in a frame.
|
|
32
|
+
get_identity_poses(identity, scale): Get all points and masks for an identity.
|
|
33
|
+
get_identity_point_mask(identity): Get the point mask array for a given identity.
|
|
34
|
+
identity_mask(identity): Get the identity mask for a given identity.
|
|
35
|
+
identity_to_track: Get the identity-to-track mapping for this file.
|
|
36
|
+
format_major_version: Returns the major version of the pose file format.
|
|
37
|
+
|
|
38
|
+
Methods:
|
|
39
|
+
get_identity_convex_hulls(identity): Get convex hulls for an identity across frames.
|
|
40
|
+
compute_bearing(points): Compute the bearing angle for a single frame.
|
|
41
|
+
compute_all_bearings(identity): Compute bearing angles for all frames of an identity.
|
|
42
|
+
get_pose_file_attributes(path): Static method to get HDF5 file attributes.
|
|
43
|
+
|
|
44
|
+
Properties:
|
|
45
|
+
num_frames (int): Number of frames.
|
|
46
|
+
identities (list): List of identities.
|
|
47
|
+
num_identities (int): Number of identities.
|
|
48
|
+
cm_per_pixel (float | None): Centimeters per pixel.
|
|
49
|
+
fps (int): Frames per second.
|
|
50
|
+
pose_file (Path): Path to the pose file.
|
|
51
|
+
hash (str): Hash of the pose file.
|
|
52
|
+
static_objects (dict): Static objects in the pose file.
|
|
53
|
+
num_lixit_keypoints (int): Number of lixit keypoints (default 0).
|
|
54
|
+
external_identities (list[int] | None): Mapping to external identities.
|
|
55
|
+
"""
|
|
56
|
+
|
|
57
|
+
class KeypointIndex(enum.IntEnum):
|
|
58
|
+
"""enum defining the 12 keypoint indexes"""
|
|
59
|
+
|
|
60
|
+
NOSE = 0
|
|
61
|
+
LEFT_EAR = 1
|
|
62
|
+
RIGHT_EAR = 2
|
|
63
|
+
BASE_NECK = 3
|
|
64
|
+
LEFT_FRONT_PAW = 4
|
|
65
|
+
RIGHT_FRONT_PAW = 5
|
|
66
|
+
CENTER_SPINE = 6
|
|
67
|
+
LEFT_REAR_PAW = 7
|
|
68
|
+
RIGHT_REAR_PAW = 8
|
|
69
|
+
BASE_TAIL = 9
|
|
70
|
+
MID_TAIL = 10
|
|
71
|
+
TIP_TAIL = 11
|
|
72
|
+
|
|
73
|
+
# Connected segments to use when full 12 keypoints are available.
|
|
74
|
+
FULL_CONNECTED_SEGMENTS = (
|
|
75
|
+
(
|
|
76
|
+
KeypointIndex.LEFT_FRONT_PAW,
|
|
77
|
+
KeypointIndex.CENTER_SPINE,
|
|
78
|
+
KeypointIndex.RIGHT_FRONT_PAW,
|
|
79
|
+
),
|
|
80
|
+
(
|
|
81
|
+
KeypointIndex.LEFT_REAR_PAW,
|
|
82
|
+
KeypointIndex.BASE_TAIL,
|
|
83
|
+
KeypointIndex.RIGHT_REAR_PAW,
|
|
84
|
+
),
|
|
85
|
+
(
|
|
86
|
+
KeypointIndex.NOSE,
|
|
87
|
+
KeypointIndex.BASE_NECK,
|
|
88
|
+
KeypointIndex.CENTER_SPINE,
|
|
89
|
+
KeypointIndex.BASE_TAIL,
|
|
90
|
+
KeypointIndex.MID_TAIL,
|
|
91
|
+
KeypointIndex.TIP_TAIL,
|
|
92
|
+
),
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
# Pose based on the Envision Hydra model will have fewer keypoints,
|
|
96
|
+
# so we adjust the connected segments accordingly.
|
|
97
|
+
NVSN_CONNECTED_SEGMENTS = (
|
|
98
|
+
(
|
|
99
|
+
KeypointIndex.LEFT_EAR,
|
|
100
|
+
KeypointIndex.NOSE,
|
|
101
|
+
KeypointIndex.RIGHT_EAR,
|
|
102
|
+
),
|
|
103
|
+
(
|
|
104
|
+
KeypointIndex.NOSE,
|
|
105
|
+
KeypointIndex.BASE_TAIL,
|
|
106
|
+
KeypointIndex.TIP_TAIL,
|
|
107
|
+
),
|
|
108
|
+
)
|
|
109
|
+
|
|
110
|
+
_CACHE_FILE_VERSION = 1
|
|
111
|
+
|
|
112
|
+
def __init__(self, file_path: Path, cache_dir: Path | None = None, fps: int = 30):
|
|
113
|
+
"""initialize new object from h5 file
|
|
114
|
+
|
|
115
|
+
Args:
|
|
116
|
+
file_path: path to pose_est_v2.h5 file
|
|
117
|
+
cache_dir: optional cache directory, used to cache convex
|
|
118
|
+
hulls
|
|
119
|
+
fps: frames per second, used for scaling time series
|
|
120
|
+
features
|
|
121
|
+
for faster loading
|
|
122
|
+
from "per frame" to "per second"
|
|
123
|
+
"""
|
|
124
|
+
super().__init__()
|
|
125
|
+
self._num_frames = 0
|
|
126
|
+
self._identities = []
|
|
127
|
+
self._external_identities: list[str] | None = None
|
|
128
|
+
self._convex_hull_cache = {}
|
|
129
|
+
self._path = file_path
|
|
130
|
+
self._cache_dir = cache_dir
|
|
131
|
+
self._cm_per_pixel = None
|
|
132
|
+
self._hash = hash_file(file_path)
|
|
133
|
+
self._fps = fps
|
|
134
|
+
|
|
135
|
+
self._static_objects = {}
|
|
136
|
+
|
|
137
|
+
# check cache version, if it doesn't match, clear the cache file for this pose file
|
|
138
|
+
if self._cache_dir is not None and not self.check_cache_version():
|
|
139
|
+
cache_file = self._cache_file_path()
|
|
140
|
+
if cache_file and cache_file.exists():
|
|
141
|
+
try:
|
|
142
|
+
cache_file.unlink()
|
|
143
|
+
except Exception:
|
|
144
|
+
logging.warning("Unable to delete old cache file %s", cache_file)
|
|
145
|
+
pass
|
|
146
|
+
|
|
147
|
+
@property
|
|
148
|
+
def num_frames(self) -> int:
|
|
149
|
+
"""return the number of frames in the pose_est file"""
|
|
150
|
+
return self._num_frames
|
|
151
|
+
|
|
152
|
+
@property
|
|
153
|
+
def identities(self):
|
|
154
|
+
"""return list of integer identities generated from file"""
|
|
155
|
+
return self._identities
|
|
156
|
+
|
|
157
|
+
@property
|
|
158
|
+
def num_identities(self) -> int:
|
|
159
|
+
"""get the number of identities in the pose file"""
|
|
160
|
+
return len(self._identities)
|
|
161
|
+
|
|
162
|
+
@property
|
|
163
|
+
def cm_per_pixel(self):
|
|
164
|
+
"""get centimeters per pixel for video/pose"""
|
|
165
|
+
return self._cm_per_pixel
|
|
166
|
+
|
|
167
|
+
@property
|
|
168
|
+
def fps(self):
|
|
169
|
+
"""get frames per second"""
|
|
170
|
+
return self._fps
|
|
171
|
+
|
|
172
|
+
@property
|
|
173
|
+
def pose_file(self):
|
|
174
|
+
"""get the path to the pose file"""
|
|
175
|
+
return self._path
|
|
176
|
+
|
|
177
|
+
@property
|
|
178
|
+
def hash(self):
|
|
179
|
+
"""get the hash of the pose file"""
|
|
180
|
+
return self._hash
|
|
181
|
+
|
|
182
|
+
@abstractmethod
|
|
183
|
+
def get_points(self, frame_index: int, identity: int, scale: float | None = None):
|
|
184
|
+
"""return points and point masks for an individual frame
|
|
185
|
+
|
|
186
|
+
Args:
|
|
187
|
+
frame_index: frame index of points and masks to be returned
|
|
188
|
+
identity: identity to return points for
|
|
189
|
+
scale: optional scale factor, set to cm_per_pixel to convert
|
|
190
|
+
poses from pixel coordinates to cm coordinates
|
|
191
|
+
|
|
192
|
+
Returns:
|
|
193
|
+
numpy array of points (12,2), numpy array of point masks (12,)
|
|
194
|
+
"""
|
|
195
|
+
pass
|
|
196
|
+
|
|
197
|
+
@abstractmethod
|
|
198
|
+
def get_identity_poses(self, identity: int, scale: float | None = None):
|
|
199
|
+
"""return all points and point masks
|
|
200
|
+
|
|
201
|
+
Args:
|
|
202
|
+
identity: identity to return points for
|
|
203
|
+
scale: optional scale factor, set to cm_per_pixel to convert
|
|
204
|
+
poses from pixel coordinates to cm coordinates
|
|
205
|
+
|
|
206
|
+
Returns:
|
|
207
|
+
numpy array of points (#frames, 12, 2), numpy array of point masks (#frames, 12)
|
|
208
|
+
"""
|
|
209
|
+
pass
|
|
210
|
+
|
|
211
|
+
@abstractmethod
|
|
212
|
+
def get_identity_point_mask(self, identity):
|
|
213
|
+
"""get the point mask array for a given identity
|
|
214
|
+
|
|
215
|
+
Args:
|
|
216
|
+
identity: identity to return point mask for
|
|
217
|
+
|
|
218
|
+
Returns:
|
|
219
|
+
array of point masks (#frames, 12)
|
|
220
|
+
"""
|
|
221
|
+
pass
|
|
222
|
+
|
|
223
|
+
@abstractmethod
|
|
224
|
+
def get_reduced_point_mask(self):
|
|
225
|
+
"""Returns a boolean array of length 12 indicating which keypoints are valid.
|
|
226
|
+
|
|
227
|
+
Determines which keypoints are valid for any identity across all frames.
|
|
228
|
+
|
|
229
|
+
Returns:
|
|
230
|
+
numpy array of shape (12,) with boolean values indicating validity
|
|
231
|
+
of each keypoint.
|
|
232
|
+
"""
|
|
233
|
+
pass
|
|
234
|
+
|
|
235
|
+
def get_connected_segments(self):
|
|
236
|
+
"""Get the segments to use for rendering connections between the keypoints
|
|
237
|
+
|
|
238
|
+
Returns:
|
|
239
|
+
list of tuples, where each tuple contains the indexes of the keypoints
|
|
240
|
+
that form a connected segment
|
|
241
|
+
"""
|
|
242
|
+
return PoseEstimation.FULL_CONNECTED_SEGMENTS
|
|
243
|
+
|
|
244
|
+
@abstractmethod
|
|
245
|
+
def identity_mask(self, identity):
|
|
246
|
+
"""get the identity mask (indicates if specified identity is present in each frame)
|
|
247
|
+
|
|
248
|
+
Args:
|
|
249
|
+
identity: identity to get masks for
|
|
250
|
+
|
|
251
|
+
Returns:
|
|
252
|
+
numpy array of size (#frames,)
|
|
253
|
+
"""
|
|
254
|
+
pass
|
|
255
|
+
|
|
256
|
+
@property
|
|
257
|
+
@abstractmethod
|
|
258
|
+
def identity_to_track(self):
|
|
259
|
+
"""get the identity to track mapping for this file"""
|
|
260
|
+
pass
|
|
261
|
+
|
|
262
|
+
@property
|
|
263
|
+
@abstractmethod
|
|
264
|
+
def format_major_version(self):
|
|
265
|
+
"""an integer giving the major version of the format"""
|
|
266
|
+
pass
|
|
267
|
+
|
|
268
|
+
@property
|
|
269
|
+
def static_objects(self):
|
|
270
|
+
"""get static objects from the pose file"""
|
|
271
|
+
return self._static_objects
|
|
272
|
+
|
|
273
|
+
def get_identity_convex_hulls(self, identity):
|
|
274
|
+
"""get a list of length #frames containing convex hulls for the given identity.
|
|
275
|
+
|
|
276
|
+
The convex hulls are calculated using all valid points except for the
|
|
277
|
+
middle of tail and tip of tail points.
|
|
278
|
+
|
|
279
|
+
Args:
|
|
280
|
+
identity: identity to return points for
|
|
281
|
+
|
|
282
|
+
Returns:
|
|
283
|
+
the convex hulls in pixel units (array elements will be None
|
|
284
|
+
if there is no valid convex hull for that frame)
|
|
285
|
+
"""
|
|
286
|
+
if identity in self._convex_hull_cache:
|
|
287
|
+
return self._convex_hull_cache[identity]
|
|
288
|
+
else:
|
|
289
|
+
convex_hulls = None
|
|
290
|
+
path = None
|
|
291
|
+
if self._cache_dir is not None:
|
|
292
|
+
path = (
|
|
293
|
+
self._cache_dir
|
|
294
|
+
/ "convex_hulls"
|
|
295
|
+
/ self._path.with_suffix("").name
|
|
296
|
+
/ f"convex_hulls_{identity}.pickle"
|
|
297
|
+
)
|
|
298
|
+
path.parents[0].mkdir(mode=0o775, parents=True, exist_ok=True)
|
|
299
|
+
|
|
300
|
+
try:
|
|
301
|
+
with path.open("rb") as f:
|
|
302
|
+
convex_hulls = joblib.load(f)
|
|
303
|
+
except Exception:
|
|
304
|
+
# we weren't able to read in the cached convex hulls,
|
|
305
|
+
# just ignore the exception and we'll generate them
|
|
306
|
+
pass
|
|
307
|
+
|
|
308
|
+
if convex_hulls is None:
|
|
309
|
+
points, point_masks = self.get_identity_poses(identity)
|
|
310
|
+
# Omit tail from convex hull
|
|
311
|
+
body_points = points[:, :-2, :]
|
|
312
|
+
body_point_masks = point_masks[:, :-2]
|
|
313
|
+
convex_hulls = []
|
|
314
|
+
|
|
315
|
+
for frame_index in range(self.num_frames):
|
|
316
|
+
if sum(body_point_masks[frame_index, :]) >= 3:
|
|
317
|
+
filtered_points = body_points[
|
|
318
|
+
frame_index, body_point_masks[frame_index, :] == 1, :
|
|
319
|
+
]
|
|
320
|
+
convex_hulls.append(MultiPoint(filtered_points).convex_hull)
|
|
321
|
+
else:
|
|
322
|
+
convex_hulls.append(None)
|
|
323
|
+
|
|
324
|
+
if path:
|
|
325
|
+
with path.open("wb") as f:
|
|
326
|
+
joblib.dump(convex_hulls, f)
|
|
327
|
+
|
|
328
|
+
self._convex_hull_cache[identity] = convex_hulls
|
|
329
|
+
return convex_hulls
|
|
330
|
+
|
|
331
|
+
def compute_bearing(self, points: np.ndarray, use_nose: bool = False):
|
|
332
|
+
"""compute the bearing of the animal using base tail and base neck keypoints
|
|
333
|
+
|
|
334
|
+
Args:
|
|
335
|
+
points (np.ndarray): the points for a single frame (12,2) array
|
|
336
|
+
use_nose (bool): use nose keypoint instead of base neck, used when
|
|
337
|
+
we have a reduced keypoint pose that lacks base neck
|
|
338
|
+
"""
|
|
339
|
+
# fall back to use nose instead of base neck if base neck is absent from this pose file
|
|
340
|
+
# (for example, 5 keypoint pose instead of 12)
|
|
341
|
+
if use_nose:
|
|
342
|
+
p1_xy = points[self.KeypointIndex.NOSE.value].astype(np.float32)
|
|
343
|
+
else:
|
|
344
|
+
p1_xy = points[self.KeypointIndex.BASE_NECK.value].astype(np.float32)
|
|
345
|
+
p2_xy = points[self.KeypointIndex.BASE_TAIL.value].astype(np.float32)
|
|
346
|
+
offset_xy = p1_xy - p2_xy
|
|
347
|
+
|
|
348
|
+
angle_rad = np.arctan2(offset_xy[1], offset_xy[0])
|
|
349
|
+
|
|
350
|
+
return np.degrees(angle_rad)
|
|
351
|
+
|
|
352
|
+
def compute_all_bearings(self, identity):
|
|
353
|
+
"""compute the bearing for each frame for a given identity"""
|
|
354
|
+
use_nose = not self.get_reduced_point_mask()[self.KeypointIndex.BASE_NECK.value]
|
|
355
|
+
if use_nose:
|
|
356
|
+
logging.warning("Falling back to using nose keypoint for bearing computation")
|
|
357
|
+
|
|
358
|
+
bearings = np.full(self.num_frames, np.nan, dtype=np.float32)
|
|
359
|
+
for i in range(self.num_frames):
|
|
360
|
+
points, mask = self.get_points(i, identity)
|
|
361
|
+
if points is not None:
|
|
362
|
+
bearings[i] = self.compute_bearing(points, use_nose)
|
|
363
|
+
return bearings
|
|
364
|
+
|
|
365
|
+
@staticmethod
|
|
366
|
+
def get_pose_file_attributes(path: Path) -> dict:
|
|
367
|
+
"""get the attributes from the pose file's hdf5 file"""
|
|
368
|
+
with h5py.File(path, "r") as pose_h5:
|
|
369
|
+
attrs = dict(pose_h5.attrs)
|
|
370
|
+
attrs["poseest"] = dict(pose_h5["poseest"].attrs)
|
|
371
|
+
return attrs
|
|
372
|
+
|
|
373
|
+
@property
|
|
374
|
+
def num_lixit_keypoints(self) -> int:
|
|
375
|
+
"""get the number of lixit keypoints
|
|
376
|
+
|
|
377
|
+
always 0 for pose file versions <5
|
|
378
|
+
"""
|
|
379
|
+
return 0
|
|
380
|
+
|
|
381
|
+
@property
|
|
382
|
+
def external_identities(self) -> list[str] | None:
|
|
383
|
+
"""get the jabs identity to external identity mapping"""
|
|
384
|
+
return self._external_identities
|
|
385
|
+
|
|
386
|
+
def identity_index_to_display(self, identity_index: int) -> str:
|
|
387
|
+
"""Convert an identity index to a display string.
|
|
388
|
+
|
|
389
|
+
Args:
|
|
390
|
+
identity_index (int): The identity index to convert.
|
|
391
|
+
|
|
392
|
+
Returns:
|
|
393
|
+
str: The display string for the identity.
|
|
394
|
+
"""
|
|
395
|
+
if self.external_identities and 0 <= identity_index < len(self.external_identities):
|
|
396
|
+
return self.external_identities[identity_index]
|
|
397
|
+
return str(identity_index)
|
|
398
|
+
|
|
399
|
+
def check_cache_version(self) -> bool:
|
|
400
|
+
"""Check if the cache version matches the expected version.
|
|
401
|
+
|
|
402
|
+
Returns:
|
|
403
|
+
bool: True if the cache version matches, False otherwise.
|
|
404
|
+
"""
|
|
405
|
+
try:
|
|
406
|
+
with h5py.File(self._cache_file_path(), "r") as cache_h5:
|
|
407
|
+
cache_version = cache_h5.attrs.get("cache_file_version", None)
|
|
408
|
+
return cache_version == self._CACHE_FILE_VERSION
|
|
409
|
+
except Exception:
|
|
410
|
+
return False
|
|
411
|
+
|
|
412
|
+
def _cache_file_path(self) -> Path | None:
|
|
413
|
+
"""Get the path to the cache file for this pose file.
|
|
414
|
+
|
|
415
|
+
Returns:
|
|
416
|
+
Path | None: The path to the cache file, or None if no cache directory is set.
|
|
417
|
+
"""
|
|
418
|
+
if self._cache_dir is None:
|
|
419
|
+
return None
|
|
420
|
+
filename = self._path.name.replace(".h5", "_cache.h5")
|
|
421
|
+
return self._cache_dir / filename
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Application identity strings.
ORG_NAME: str = "JAX"
APP_NAME: str = "JABS"
APP_NAME_LONG: str = f"{ORG_NAME} Animal Behavior System"

# Fixed random seed for training the final classifier after cross-validation.
# Cross-validation itself does not use this seed; pinning it only for the final
# fit makes the resulting classifier reproducible.
FINAL_TRAIN_SEED: int = 0xAB3BDB

# Default compression settings for HDF5 output.
COMPRESSION: str = "gzip"
COMPRESSION_OPTS_DEFAULT: int = 6

# Key names for project settings stored in the project.json file.
CV_GROUPING_KEY: str = "cv_grouping"
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""Module for defining enums used in JABS"""
|
|
2
|
+
|
|
3
|
+
from .classifier_types import ClassifierType
|
|
4
|
+
from .cv_grouping import DEFAULT_CV_GROUPING_STRATEGY, CrossValidationGroupingStrategy
|
|
5
|
+
from .units import ProjectDistanceUnit
|
|
6
|
+
|
|
7
|
+
# Public API of jabs.core.enums, re-exported from the submodules imported above.
__all__ = [
    "DEFAULT_CV_GROUPING_STRATEGY",
    "ClassifierType",
    "CrossValidationGroupingStrategy",
    "ProjectDistanceUnit",
]
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class CrossValidationGroupingStrategy(str, Enum):
    """How samples are grouped into folds during cross-validation for a project.

    Because the enum also subclasses ``str``, each member compares equal to its
    string value and serializes to JSON as that value.
    """

    INDIVIDUAL = "Individual Animal"
    VIDEO = "Video"


# Default grouping strategy (presumably applied when a project has no explicit setting).
DEFAULT_CV_GROUPING_STRATEGY = CrossValidationGroupingStrategy.INDIVIDUAL
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
class JabsError(Exception):
    """Base class for all exceptions raised by JABS.

    Catch this to handle any JABS-specific error; each subclass still derives
    from ``Exception`` so existing ``except Exception`` handlers keep working.
    """


class PoseHashException(JabsError):
    """Exception raised when the hash of a pose file does not match the expected value."""

    pass


class PoseIdEmbeddingException(JabsError):
    """Exception raised for invalid instance_embed_id values in a pose file."""

    pass


class MissingBehaviorError(JabsError):
    """Exception raised when a behavior is not found in the prediction file."""

    pass


class FeatureVersionException(JabsError):
    """Exception raised when the feature version in an h5 file is incompatible with this version of JABS."""

    pass


class DistanceScaleException(JabsError):
    """Exception raised when the distance scale factor in an h5 file does not match what the classifier expects."""

    pass
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
"""JABS utilities"""
|
|
2
|
+
|
|
3
|
+
from .update_checker import check_for_update, is_pypi_install
|
|
4
|
+
from .utilities import get_bool_env_var, hash_file, hide_stderr
|
|
5
|
+
|
|
6
|
+
# Public API of jabs.core.utils, re-exported from the submodules imported above.
__all__ = [
    "check_for_update",
    "get_bool_env_var",
    "hash_file",
    "hide_stderr",
    "is_pypi_install",
]
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from collections.abc import Generator, Iterable
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from jabs.core.abstract import PoseEstimation
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def gen_line_fragments(
    connected_segments: Iterable[Iterable[PoseEstimation.KeypointIndex]],
    exclude_points: np.ndarray,
) -> Generator[list[int], None, None]:
    """Split connected keypoint segments into drawable line fragments.

    Each segment is walked in order. An excluded keypoint ends the fragment
    being built (and belongs to neither side of the break), and any fragment
    with fewer than two keypoints is dropped since it cannot form a line.
    Fragments never span two segments.

    Args:
        connected_segments: Iterable of Iterables of KeypointIndex, where each
            inner Iterable is one chain of connected keypoints
        exclude_points: numpy array of keypoint indexes to leave out

    Yields:
        lists of keypoint index values making up each fragment to draw
    """
    for segment in connected_segments:
        fragment: list[int] = []
        for keypoint in segment:
            index = keypoint.value
            if index not in exclude_points:
                fragment.append(index)
                continue
            # an excluded keypoint breaks the chain here
            if len(fragment) > 1:
                yield fragment
            fragment = []
        if len(fragment) > 1:
            yield fragment
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import logging
|
|
3
|
+
import os
|
|
4
|
+
import threading
|
|
5
|
+
import time
|
|
6
|
+
from collections.abc import Callable, Iterable
|
|
7
|
+
from concurrent.futures import Future, ProcessPoolExecutor
|
|
8
|
+
from multiprocessing import shared_memory
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)

# Hard upper bound on pool size: ProcessPoolManager.__init__ clamps the requested
# worker count (or os.cpu_count() when none is given) to at most this many processes.
MAX_POOL_WORKERS = 6
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _noop() -> None:
    """Do nothing; used as a trivial task when warming up worker processes.

    Kept at module scope so ProcessPoolExecutor can pickle it by reference.
    """
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class ProcessPoolManager:
|
|
25
|
+
"""
|
|
26
|
+
Manage a shared ProcessPoolExecutor with warm-up and safe shutdown.
|
|
27
|
+
|
|
28
|
+
Attributes:
|
|
29
|
+
_max_workers (int | None): Maximum number of worker processes. Passed to
|
|
30
|
+
ProcessPoolExecutor when created.
|
|
31
|
+
_initializer (Callable[..., object] | None): Optional function executed in
|
|
32
|
+
each worker process when it starts.
|
|
33
|
+
_initargs (tuple[object, ...]): Arguments passed to the initializer.
|
|
34
|
+
_name (str): Logical name for debugging/logging.
|
|
35
|
+
_executor (ProcessPoolExecutor | None): The lazily-created underlying
|
|
36
|
+
process pool. None until first use.
|
|
37
|
+
_lock (threading.RLock): Protects access to `_executor` and `_is_shutdown`.
|
|
38
|
+
_is_shutdown (bool): Whether shutdown() has been called. Prevents reuse once
|
|
39
|
+
the pool has been shut down.
|
|
40
|
+
|
|
41
|
+
Args:
|
|
42
|
+
max_workers (int | None): Maximum number of worker processes. Defaults to
|
|
43
|
+
os.cpu_count() if None.
|
|
44
|
+
initializer (Callable | None): Optional function run in each worker process
|
|
45
|
+
when it starts.
|
|
46
|
+
initargs (tuple): Arguments passed to the initializer.
|
|
47
|
+
name (str): Optional name used only for debugging/logging.
|
|
48
|
+
"""
|
|
49
|
+
|
|
50
|
+
def __init__(
|
|
51
|
+
self,
|
|
52
|
+
max_workers: int | None = None,
|
|
53
|
+
*,
|
|
54
|
+
initializer: Callable[..., object] | None = None,
|
|
55
|
+
initargs: tuple[object, ...] = (),
|
|
56
|
+
name: str = "ProcessPoolManager",
|
|
57
|
+
) -> None:
|
|
58
|
+
logger.debug(f"PPM __init__ name={name} id={id(self)}")
|
|
59
|
+
requested_workers = max_workers or (os.cpu_count() or 1)
|
|
60
|
+
self._max_workers: int = max(1, min(requested_workers, MAX_POOL_WORKERS))
|
|
61
|
+
self._initializer = initializer
|
|
62
|
+
self._initargs = initargs
|
|
63
|
+
self._name = name
|
|
64
|
+
|
|
65
|
+
self._executor: ProcessPoolExecutor | None = None
|
|
66
|
+
self._lock = threading.RLock() # protects _executor and _is_shutdown
|
|
67
|
+
self._is_shutdown = False
|
|
68
|
+
|
|
69
|
+
self._cancel_shm: shared_memory.SharedMemory | None = None
|
|
70
|
+
|
|
71
|
+
@property
|
|
72
|
+
def max_workers(self) -> int:
|
|
73
|
+
"""Maximum number of worker processes in the pool."""
|
|
74
|
+
return self._max_workers
|
|
75
|
+
|
|
76
|
+
@property
|
|
77
|
+
def name(self) -> str:
|
|
78
|
+
"""Logical name of the ProcessPoolManager for debugging/logging."""
|
|
79
|
+
return self._name
|
|
80
|
+
|
|
81
|
+
def _ensure_cancel_shm(self) -> shared_memory.SharedMemory:
|
|
82
|
+
"""Create the shared-memory cancel flag on first use, if not shut down."""
|
|
83
|
+
with self._lock:
|
|
84
|
+
if self._is_shutdown:
|
|
85
|
+
raise RuntimeError(f"{self._name} has been shut down")
|
|
86
|
+
|
|
87
|
+
if self._cancel_shm is None:
|
|
88
|
+
shm = shared_memory.SharedMemory(create=True, size=1)
|
|
89
|
+
# 0 = not cancelled, 1 = cancelled
|
|
90
|
+
shm.buf[0] = 0
|
|
91
|
+
self._cancel_shm = shm
|
|
92
|
+
|
|
93
|
+
return self._cancel_shm
|
|
94
|
+
|
|
95
|
+
@property
|
|
96
|
+
def cancel_flag_name(self) -> str | None:
|
|
97
|
+
"""Name of the shared-memory cancel flag, or None if shut down.
|
|
98
|
+
|
|
99
|
+
Callers can pass this name to worker functions so they can open the
|
|
100
|
+
shared memory and cooperatively check for cancellation.
|
|
101
|
+
"""
|
|
102
|
+
with self._lock:
|
|
103
|
+
if self._is_shutdown:
|
|
104
|
+
return None
|
|
105
|
+
|
|
106
|
+
shm = self._ensure_cancel_shm()
|
|
107
|
+
return shm.name
|
|
108
|
+
|
|
109
|
+
def set_cancelled(self) -> None:
    """Raise the cooperative-cancellation flag (byte 0 becomes 1).

    Allocates the flag first if needed; does nothing after shutdown.
    """
    with self._lock:
        if not self._is_shutdown:
            flag = self._cancel_shm if self._cancel_shm else self._ensure_cancel_shm()
            flag.buf[0] = 1
|
|
117
|
+
|
|
118
|
+
def clear_cancelled(self) -> None:
    """Lower the cancellation flag (byte 0 becomes 0) if it has been allocated."""
    with self._lock:
        flag = self._cancel_shm
        if flag is None:
            return
        flag.buf[0] = 0
|
|
123
|
+
|
|
124
|
+
def _ensure_executor(self) -> ProcessPoolExecutor:
    """Return the process pool, constructing it lazily on the first call.

    The pool is built with this manager's worker limit and initializer.

    Raises:
        RuntimeError: if the manager has already been shut down.
    """
    with self._lock:
        if self._is_shutdown:
            raise RuntimeError(f"{self._name} has been shut down")

        pool = self._executor
        if pool is None:
            pool_options = {
                "max_workers": self._max_workers,
                "initializer": self._initializer,
                "initargs": self._initargs,
            }
            # noinspection PyTypeChecker
            pool = ProcessPoolExecutor(**pool_options)
            self._executor = pool
        return pool
|
|
139
|
+
|
|
140
|
+
def submit(self, fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Future:
    """Schedule ``fn(*args, **kwargs)`` on the pool and return its Future.

    The pool is created on demand. A RuntimeError is raised if the manager
    has been shut down.
    """
    return self._ensure_executor().submit(fn, *args, **kwargs)
|
|
144
|
+
|
|
145
|
+
def map(
    self,
    fn: Callable[[Any], Any],
    iterable: Iterable[Any],
    chunksize: int = 1,
) -> Iterable[Any]:
    """Apply ``fn`` to every item of ``iterable`` on the pool.

    This has the same semantics as ``Executor.map``. Results are yielded in
    input order, and ``chunksize`` controls how many items are batched per
    task.
    """
    pool = self._ensure_executor()
    results = pool.map(fn, iterable, chunksize=chunksize)
    return results
|
|
154
|
+
|
|
155
|
+
def warm_up(self, wait: bool = True) -> None:
    """Eagerly start worker processes and optionally run trivial tasks.

    This is useful if you want the cost of spawning processes and running
    initializers to happen at a controlled time (e.g., on app startup)
    instead of on the first real submit(). The shared-memory cancel flag is
    allocated here as well.

    Args:
        wait (bool): If True, submit one no-op task per worker slot
            (``max_workers`` tasks) and wait for them to finish, so the pool
            has spawned its workers and run their initializers before this
            returns. Which worker runs each task is up to the executor. If
            False, only create the pool and return immediately.

    Warm-up task failures are not raised. They are logged as a warning, so a
    broken pool (for example, an initializer that raised) is not mistaken for
    a successful warm-up.
    """
    # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing
    start_time = time.perf_counter()
    logger.debug(f"PPM warm_up name={self._name} id={id(self)}")
    executor = self._ensure_executor()
    self._ensure_cancel_shm()

    if not wait:
        return

    futures = [executor.submit(_noop) for _ in range(self._max_workers)]
    failures: list[BaseException] = []
    for f in futures:
        try:
            f.result()
        except Exception as exc:  # e.g. BrokenProcessPool when an initializer fails
            failures.append(exc)

    elapsed = time.perf_counter() - start_time
    if failures:
        logger.warning(
            f"PPM warm_up name={self._name} id={id(self)}: {len(failures)} of "
            f"{len(futures)} warm-up tasks failed (first error: {failures[0]!r})"
        )
    logger.debug(
        f"PPM warm_up name={self._name} id={id(self)} COMPLETED in {elapsed:.2f} seconds"
    )
|
|
185
|
+
|
|
186
|
+
def shutdown(self, *, wait: bool = True, cancel_futures: bool = False) -> None:
    """Explicitly shut down the process pool and release the cancel flag.

    After shutdown, the manager cannot be reused. Later calls are no-ops; a
    call racing with an in-progress shutdown returns without waiting for it.

    Args:
        wait: Passed to ``ProcessPoolExecutor.shutdown``. If True, block until
            running tasks have finished.
        cancel_futures: Passed to ``ProcessPoolExecutor.shutdown``. If True,
            cancel tasks that have not started yet.
    """
    # Detach the resources under the lock, but do the (possibly blocking)
    # teardown outside it. executor.shutdown(wait=True) joins the executor's
    # management thread, which runs Future done-callbacks. A callback that
    # calls back into this manager (submit, set_cancelled, ...) would
    # otherwise block on the lock held here and deadlock the join.
    with self._lock:
        self._is_shutdown = True
        executor, self._executor = self._executor, None
        cancel_shm, self._cancel_shm = self._cancel_shm, None

    if executor is not None:
        with contextlib.suppress(Exception):
            executor.shutdown(wait=wait, cancel_futures=cancel_futures)

    if cancel_shm is not None:
        # close() and unlink() are attempted independently: a failing close()
        # must not skip unlink() and leak the named shared-memory segment.
        with contextlib.suppress(Exception):
            cancel_shm.close()
        with contextlib.suppress(Exception):
            cancel_shm.unlink()
|
|
203
|
+
|
|
204
|
+
def __enter__(self) -> "ProcessPoolManager":
    """Start (or reuse) the pool and hand this manager to the ``with`` block.

    Paired with ``__exit__``, this shuts the pool down automatically when the
    block ends.
    """
    self._ensure_executor()
    return self
|
|
211
|
+
|
|
212
|
+
def __exit__(self, exc_type, exc, tb) -> None:
    """Leave the ``with`` block: let queued work finish, then tear the pool down.

    Returns None, so an exception raised inside the block is never suppressed.
    """
    self.shutdown(cancel_futures=False, wait=True)
|
|
215
|
+
|
|
216
|
+
def __del__(self) -> None:
    """Best-effort cleanup when the manager is garbage collected.

    Pending work is cancelled without waiting, and any error is swallowed.
    Finalizers are not guaranteed to run, for example at interpreter exit,
    so call shutdown() or use the manager as a context manager instead of
    relying on this.
    """
    try:
        self.shutdown(wait=False, cancel_futures=True)
    except Exception:
        pass
|
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import os
|
|
3
|
+
import random
|
|
4
|
+
|
|
5
|
+
import cv2
|
|
6
|
+
import h5py
|
|
7
|
+
|
|
8
|
+
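# Command-line examples of using this script. The paths come from the original
# authors' environments, and the script path used in the examples may be outdated:
# in this package the script lives at src/jabs/core/utils/sampleposeintervals.py.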
|
|
9
|
+
#
|
|
10
|
+
# share_root='/run/user/1000/gvfs/smb-share:server=bht2stor.jax.org,share=vkumar'
|
|
11
|
+
# python src/utils/sampleposeintervals.py \
|
|
12
|
+
# --batch-file UCSD_Rotta_TS_v2.txt \
|
|
13
|
+
# --root-dir "${share_root}" \
|
|
14
|
+
# --out-dir UCSD_Rotta_TS_v2-intervals \
|
|
15
|
+
# --out-frame-count 9000 \
|
|
16
|
+
# --start-frame 54000 \
|
|
17
|
+
# --pose-version 3
|
|
18
|
+
#
|
|
19
|
+
# share_root='/media/sheppk/TOSHIBA EXT/rotta-data/UCSD_Rotta_TS_v2-vidcache'
|
|
20
|
+
# python src/utils/sampleposeintervals.py \
|
|
21
|
+
# --batch-file "${share_root}/batch.txt" \
|
|
22
|
+
# --root-dir "${share_root}" \
|
|
23
|
+
# --out-dir UCSD_Rotta_TS_v2-intervals-2021-05-25 \
|
|
24
|
+
# --out-frame-count 9000 \
|
|
25
|
+
# --start-frame 27000 \
|
|
26
|
+
# --pose-version 3
|
|
27
|
+
|
|
28
|
+
# python src/utils/sampleposeintervals.py \
|
|
29
|
+
# --batch-file ~/projects/social-interaction/data/bxd-batch-early-morning-2021-06-09.txt \
|
|
30
|
+
# --root-dir '/run/user/1000/gvfs/smb-share:server=bht2stor.jax.org,share=vkumar' \
|
|
31
|
+
# --out-dir bxd-batch-early-morning-2021-06-09 \
|
|
32
|
+
# --out-frame-count 9000 \
|
|
33
|
+
# --start-frame 54000 \
|
|
34
|
+
# --pose-version 3
|
|
35
|
+
|
|
36
|
+
# python src/utils/sampleposeintervals.py \
|
|
37
|
+
# --batch-file temp/B6J-and-BTBR-3M-strangers-4-day-rand-2021-05-24.txt \
|
|
38
|
+
# --root-dir '/media/sheppk/TOSHIBA EXT/rotta-data/B6J-and-BTBR-3M-strangers-4-day-rand-2021-05-24' \
|
|
39
|
+
# --out-dir B6J-and-BTBR-3M-strangers-4-day-rand-samples-2021-05-24 \
|
|
40
|
+
# --out-frame-count 3600 \
|
|
41
|
+
# --start-frame 6000 \
|
|
42
|
+
# --pose-version 3
|
|
43
|
+
|
|
44
|
+
# python src/utils/sampleposeintervals.py \
|
|
45
|
+
# --batch-file temp/B6J-and-BTBR-3M-strangers-4-day-rand-2021-05-24.txt \
|
|
46
|
+
# --root-dir '/media/sheppk/TOSHIBA EXT/rotta-data/B6J_and_BTBR_3M_stranger_4day_2021-07-20' \
|
|
47
|
+
# --out-dir temp/B6J-and-BTBR-3M-strangers-4-day-rand-samples-2021-08-05 \
|
|
48
|
+
# --out-frame-count 3600 \
|
|
49
|
+
# --start-frame 6000 \
|
|
50
|
+
# --only-pose \
|
|
51
|
+
# --pose-version 4
|
|
52
|
+
|
|
53
|
+
# rclone copy --transfers 4 --progress \
|
|
54
|
+
# --include-from /home/sheppk/projects/behavior-classifier/temp/BTBR_3M_stranger_4day-subset-avi.txt \
|
|
55
|
+
# "labdropbox:/KumarLab's shared workspace/VideoData/MDS_Tests/BTBR_3M_stranger_4day" \
|
|
56
|
+
# /media/sheppk/TOSHIBA\ EXT/BTBR_3M_stranger_4day-2021-08-24
|
|
57
|
+
# rclone copy --transfers 4 --progress \
|
|
58
|
+
# --include-from /home/sheppk/projects/behavior-classifier/temp/BTBR_3M_stranger_4day-subset-pose.txt \
|
|
59
|
+
# /home/sheppk/sshfs/winterproj/bgeuther/IdentityInfer/Data/BTBR_3M_stranger_4day \
|
|
60
|
+
# /media/sheppk/TOSHIBA\ EXT/BTBR_3M_stranger_4day-2021-08-24
|
|
61
|
+
# python src/utils/sampleposeintervals.py \
|
|
62
|
+
# --batch-file /media/sheppk/TOSHIBA\ EXT/BTBR_3M_stranger_4day-2021-08-24/batch.txt \
|
|
63
|
+
# --root-dir /media/sheppk/TOSHIBA\ EXT/BTBR_3M_stranger_4day-2021-08-24 \
|
|
64
|
+
# --out-dir /media/sheppk/TOSHIBA\ EXT/BTBR_3M_stranger_4day-2021-08-24-samples \
|
|
65
|
+
# --out-frame-count 3600 \
|
|
66
|
+
# --start-frame 6000 \
|
|
67
|
+
# --pose-version 4
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _pose_suffix(pose_version: int) -> str:
    """Return the pose-file suffix for a supported pose version, e.g. ``_pose_est_v3.h5``.

    Raises:
        NotImplementedError: if the version is not one of 2, 3, 4 or 5.
    """
    if pose_version not in (2, 3, 4, 5):
        raise NotImplementedError("pose version not implemented: " + str(pose_version))
    return f"_pose_est_v{pose_version}.h5"


def _copy_pose_interval(pose_in, pose_out, start: int, stop: int) -> None:
    """Copy frames ``[start, stop)`` of the pose datasets from ``pose_in`` to ``pose_out``.

    ``poseest/points`` and ``poseest/confidence`` are required. The other
    frame-sliced datasets are copied only when present. ``instance_id_center``
    and the ``static_objects`` group are copied whole, and the ``poseest``
    attributes are copied as-is.
    """
    in_pose = pose_in["poseest"]

    # pose v2 datasets are always expected
    for key in ("points", "confidence"):
        pose_out[f"poseest/{key}"] = in_pose[key][start:stop, ...]

    # optional frame-sliced datasets: pose v3 then pose v4
    for key in (
        "instance_count",
        "instance_embedding",
        "instance_track_id",
        "id_mask",
        "identity_embeds",
        "instance_embed_id",
    ):
        if key in in_pose:
            pose_out[f"poseest/{key}"] = in_pose[key][start:stop, ...]

    # pose v4: copied whole rather than sliced by frame
    if "instance_id_center" in in_pose:
        pose_out["poseest/instance_id_center"] = in_pose["instance_id_center"][:]

    # pose v5: static objects are copied unchanged
    if "static_objects" in pose_in:
        static_group = pose_out.create_group("static_objects")
        for dataset in pose_in["static_objects"]:
            static_group.create_dataset(dataset, data=pose_in["static_objects"][dataset])

    for attr in in_pose.attrs:
        pose_out["poseest"].attrs[attr] = in_pose.attrs[attr]


def _copy_video_interval(vid_path, vid_filename, vid_out_path, start: int, frame_count: int) -> bool:
    """Write ``frame_count`` frames of ``vid_path``, starting at 0-based frame ``start``.

    Output goes to ``vid_out_path`` as MJPG at 30 fps, using the input's frame
    size. A warning is printed for every problem.

    Returns:
        False if the input could not be opened or seeked, or the output video
        could not be created. True otherwise, including when the input ends
        early (the output is then truncated).
    """
    cap = None
    writer = None
    try:
        cap = cv2.VideoCapture(vid_path)
        if not cap.isOpened():
            print(f"WARNING: failed to open {vid_filename}")
            return False

        # set() reports whether the seek was accepted; isOpened() cannot detect a
        # failed seek, which would silently misalign the video with the pose data.
        if not cap.set(cv2.CAP_PROP_POS_FRAMES, start):
            print(f"WARNING: failed to seek to start frame {vid_filename}")
            return False

        frame_size = (
            int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        )
        writer = cv2.VideoWriter(vid_out_path, cv2.VideoWriter_fourcc(*"MJPG"), 30, frame_size)
        if not writer.isOpened():
            print(f"WARNING: failed to create output video {vid_out_path}")
            return False

        for _ in range(frame_count):
            ret, frame = cap.read()
            if not ret:
                print(f"WARNING: {vid_filename} ended prematurely")
                break
            writer.write(frame)
        return True
    finally:
        if cap is not None:
            cap.release()
        if writer is not None:
            writer.release()


def main():
    """Sample a fixed-length interval from each video/pose pair listed in a batch file.

    For every non-blank line of ``--batch-file`` (a path relative to
    ``--root-dir``), the matching ``<video stem>_pose_est_v<N>.h5`` file is read.
    ``--out-frame-count`` frames are written to ``--out-dir``, starting at
    ``--start-frame`` (1-based) or at a random start. Unless ``--only-pose``
    is given, the matching video frames are written as an AVI too.

    Output names are derived from the input path, with ``/`` and ``\\``
    replaced by ``+`` and the 1-based start frame appended. Videos that are
    too short, missing, or unreadable are skipped with a warning.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--batch-file",
        help="path to the file that is a new-line separated list of all videos to process",
        required=True,
    )
    parser.add_argument(
        "--root-dir",
        help="the root directory. All paths given in the batch files are relative to this root",
        required=True,
    )
    parser.add_argument(
        "--out-dir",
        help="output directory. The videos and pose files for sampled intervals are saved to this dir",
        required=True,
    )
    parser.add_argument(
        "--out-frame-count",
        help="this defines how many frames to save. Assuming 30fps a value of 1800 corresponds to one minute",
        required=True,
        type=int,
    )
    parser.add_argument(
        "--start-frame",
        help="this argument specifies which frame we start at. If this option is not specified we randomly select"
        " a start frame from the video.",
        required=False,
        type=int,
    )
    parser.add_argument(
        "--pose-version",
        help="give the integer version number that should be used for pose",
        default=2,
        type=int,
        choices=(2, 3, 4, 5),
    )
    parser.add_argument(
        "--only-pose",
        help="if specified this option will sample pose data and exclude video from output",
        action="store_true",
    )

    args = parser.parse_args()

    # Reject values that would produce empty samples or negative slice indices.
    if args.out_frame_count < 1:
        parser.error("--out-frame-count must be a positive integer")
    if args.start_frame is not None and args.start_frame < 1:
        parser.error("--start-frame is 1-based and must be >= 1")

    pose_suffix = _pose_suffix(args.pose_version)
    os.makedirs(args.out_dir, exist_ok=True)

    with open(args.batch_file) as batch_file:
        for line in batch_file:
            vid_filename = line.strip()
            if not vid_filename:
                continue

            print("Processing:", vid_filename)
            vid_path = os.path.join(args.root_dir, vid_filename)
            vid_path_root, _ = os.path.splitext(vid_path)
            pose_in_path = vid_path_root + pose_suffix

            if not args.only_pose and not os.path.isfile(vid_path):
                print("WARNING: missing video path:", vid_path)
                continue
            if not os.path.isfile(pose_in_path):
                print("WARNING: missing pose path:", pose_in_path)
                continue

            with h5py.File(pose_in_path, "r") as pose_in:
                frame_count = pose_in["poseest"]["confidence"].shape[0]

                # Largest 0-based start index at which a full interval still fits.
                last_candidate_frame = frame_count - args.out_frame_count
                if last_candidate_frame < 0:
                    print(
                        f"WARNING: {vid_filename} skipped because it only contains {frame_count} frames"
                    )
                    continue

                if args.start_frame is None:
                    # randint is inclusive, so the last valid start can be chosen too
                    out_start_frame_index = random.randint(0, last_candidate_frame)
                else:
                    out_start_frame_index = args.start_frame - 1
                    if out_start_frame_index > last_candidate_frame:
                        print(
                            f"WARNING: {vid_filename} skipped because start frame "
                            f"{args.start_frame} leaves fewer than {args.out_frame_count} "
                            f"of its {frame_count} frames"
                        )
                        continue

                vid_out_filename = vid_filename.replace("/", "+").replace("\\", "+")
                vid_out_path_root, _ = os.path.splitext(os.path.join(args.out_dir, vid_out_filename))
                # output names carry the 1-based start frame
                out_stem = f"{vid_out_path_root}_{out_start_frame_index + 1}"
                vid_out_path = out_stem + ".avi"
                pose_out_path = out_stem + pose_suffix

                start = out_start_frame_index
                stop = start + args.out_frame_count
                with h5py.File(pose_out_path, "w") as pose_out:
                    _copy_pose_interval(pose_in, pose_out, start, stop)

            if args.only_pose:
                continue

            if not _copy_video_interval(vid_path, vid_filename, vid_out_path, start, args.out_frame_count):
                # keep the outputs paired: drop the pose sample that has no matching video
                os.remove(pose_out_path)
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
# Run the sampler when this file is executed as a script (not when imported).
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
"""Utilities for checking PyPI for JABS updates."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import logging
|
|
5
|
+
import urllib.request
|
|
6
|
+
from importlib import metadata
|
|
7
|
+
|
|
8
|
+
from packaging.version import parse as parse_version
|
|
9
|
+
|
|
10
|
+
# TODO: Consider moving this to jabs.core
|
|
11
|
+
from jabs.version import version_str
|
|
12
|
+
|
|
13
|
+
# Module-level logger; update-check and install-detection failures are reported through it.
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def check_for_update() -> tuple[bool, str | None, str]:
    """Ask PyPI whether a newer jabs-behavior-classifier release exists.

    Any exception during the check (network, JSON, or version parsing) is
    logged as a warning and reported as "no update". Only a failure of
    ``version_str()`` in that fallback path can propagate.

    Returns:
        tuple: ``(has_update, latest_version, current_version)``.
            ``has_update`` is True when PyPI's version is newer than the
            installed one. ``latest_version`` is PyPI's version string, or
            None if the check failed. ``current_version`` is the installed
            version string.
    """
    try:
        installed = version_str()
        with urllib.request.urlopen(
            "https://pypi.org/pypi/jabs-behavior-classifier/json", timeout=5
        ) as reply:
            release_info = json.loads(reply.read())["info"]
        newest = release_info["version"]
        newer_available = parse_version(newest) > parse_version(installed)
    except Exception as e:
        logger.warning(f"Failed to check for updates: {e}")
        return False, None, version_str()
    return newer_available, newest, installed
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def is_pypi_install() -> bool:
    """Check if jabs-behavior-classifier was installed from a package index such as PyPI.

    A distribution counts as index-installed when its ``INSTALLER`` metadata
    names pip or uv and it has no ``direct_url.json`` record. Installers
    write that record (PEP 610) for installs from a direct URL, a local
    directory, a VCS checkout, or an editable install. None of these can be
    upgraded from PyPI, even though their INSTALLER is still pip or uv.

    Returns:
        bool: True if installed via pip or uv from an index, False otherwise,
        including when the installation method cannot be determined.
    """
    try:
        dist = metadata.distribution("jabs-behavior-classifier")
        installer = dist.read_text("INSTALLER")
        if installer is None or installer.strip() not in ("pip", "uv"):
            return False
        # direct_url.json is only recorded for non-index (URL/local/editable) installs
        return dist.read_text("direct_url.json") is None
    except Exception as e:
        logger.debug(f"Could not determine installation method: {e}")
        return False
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import hashlib
|
|
2
|
+
import os
|
|
3
|
+
import sys
|
|
4
|
+
from collections.abc import Generator
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@contextmanager
def hide_stderr() -> Generator[int, Any, None]:
    """Temporarily point the process's stderr file descriptor at os.devnull.

    The redirection is done at file-descriptor level with dup/dup2, so
    anything written to that descriptor is discarded, not just writes that go
    through sys.stderr. The original descriptor is restored when the context
    exits, even if the body raises.

    Yields:
        int: The stderr file descriptor that is being redirected.
    """
    stderr_fd = sys.stderr.fileno()

    # keep a duplicate of the current target so it can be restored afterwards
    saved_target = os.fdopen(os.dup(stderr_fd), "wb")
    with saved_target:
        sys.stderr.flush()
        with open(os.devnull, "wb") as sink:
            os.dup2(sink.fileno(), stderr_fd)
        try:
            yield stderr_fd
        finally:
            # flush anything still buffered for devnull, then reattach the old target
            sys.stderr.flush()
            os.dup2(saved_target.fileno(), stderr_fd)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def hash_file(file: Path):
    """Return the hex BLAKE2b digest (20-byte digest size) of a file's contents.

    The file is streamed in 8192-byte chunks, so large files are never held
    in memory all at once.
    """
    digest = hashlib.blake2b(digest_size=20)
    with file.open("rb") as stream:
        while chunk := stream.read(8192):
            digest.update(chunk)
    return digest.hexdigest()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def get_bool_env_var(var_name, default_value=False) -> bool:
    """Gets a boolean value from an environment variable.

    Values are matched case-insensitively, ignoring surrounding whitespace.
    "true", "1", "yes", "on", "y" and "t" give True. "false", "0", "no",
    "off", "n" and "f" give False.

    Args:
        var_name: The name of the environment variable.
        default_value: The value to return if the variable is not set or its
            value is not one of the recognized strings.

    Returns:
        A boolean value.
    """
    value = os.getenv(var_name)
    if value is None:
        return default_value

    normalized = value.strip().lower()
    if normalized in ("true", "1", "yes", "on", "y", "t"):
        return True
    if normalized in ("false", "0", "no", "off", "n", "f"):
        return False
    # Unrecognized values fall back to the default, as documented above.
    return default_value
|