singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
videoprism/utils.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
1
|
+
# Copyright 2026 VideoPrism Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Utility functions for checkpointing and other purposes."""
|
|
16
|
+
|
|
17
|
+
import collections
|
|
18
|
+
from collections.abc import Mapping, Sequence
|
|
19
|
+
import io
|
|
20
|
+
import os
|
|
21
|
+
import string
|
|
22
|
+
|
|
23
|
+
import jax
|
|
24
|
+
import numpy as np
|
|
25
|
+
from tensorflow.io import gfile
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def traverse_with_names(tree, with_inner_nodes=False):
  """Traverses nested dicts/sequences and emits (leaf_name, leaf_val).

  Mappings are visited in sorted-key order and sequences in index order. A
  leaf's name is its path of keys/indices joined with "/". Strings and bytes
  are treated as leaves, not as sequences of characters.

  Args:
    tree: JAX Pytree object.
    with_inner_nodes: Whether to also yield the non-leaf nodes. Each inner node
      is yielded after its children, named by its own path.

  Yields:
    A pair of (leaf_name, leaf_val).
  """
  # Skip None subtrees entirely. If the optimizer doesn't have a state the tree
  # leaves can be Nones, which would otherwise be interpreted as a leaf by this
  # function but not by the other functions (like jax.tree.map).
  if tree is None:
    return
  if isinstance(tree, Mapping):
    for key in sorted(tree.keys()):
      for path, v in traverse_with_names(tree[key], with_inner_nodes):
        # rstrip drops the trailing "/" left when the child is itself a leaf.
        yield (str(key) + "/" + path).rstrip("/"), v
    if with_inner_nodes:
      yield "", tree
  elif isinstance(tree, Sequence) and not isinstance(tree, (str, bytes)):
    # str/bytes are registered Sequences, but recursing into them never
    # terminates (a 1-char string indexes to itself); treat them as leaves.
    for idx in range(len(tree)):
      for path, v in traverse_with_names(tree[idx], with_inner_nodes):
        yield (str(idx) + "/" + path).rstrip("/"), v
    if with_inner_nodes:
      yield "", tree
  else:
    yield "", tree
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
def tree_flatten_with_names(tree):
  """Populates tree_flatten with leaf names.

  The leaves are returned in `jax.tree.flatten` order, each paired with the
  "/"-joined name that `traverse_with_names` assigns to it.

  Args:
    tree: JAX Pytree object.

  Returns:
    A list of values with names: [(name, value), ...]. Empty if the tree has no
    leaves.
  """
  vals, tree_def = jax.tree.flatten(tree)
  if not vals:
    # zip(*<empty>) has nothing to unpack into two names; no leaves, no names.
    return []

  # Replace every leaf by its flat index so the custom traversal reveals where
  # each flattened leaf lands in name order.
  tokens = range(len(vals))
  token_tree = tree_def.unflatten(tokens)
  val_names, perm = zip(*traverse_with_names(token_tree))
  # perm[j] is the flat index visited j-th; inv_perm[k] is where flat index k
  # was visited, i.e. the position of its name in val_names.
  inv_perm = np.argsort(perm)

  # Custom traversal should visit the same number of leaves.
  assert len(val_names) == len(vals)

  return [(val_names[i], v) for i, v in zip(inv_perm, vals)]
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def recover_tree(keys, values):
  """Rebuilds a nested dict from "/"-separated flat names and their values.

  Names without a "/" become direct entries of the returned dict. Names with a
  "/" are grouped by their first component and rebuilt recursively. All direct
  entries are inserted first, followed by the subtrees in order of first
  appearance.

  Args:
    keys: A list of keys, where '/' is used as separator between nodes.
    values: A list of leaf values, aligned with `keys`.

  Returns:
    A nested tree-like dict.
  """
  result = {}
  children = collections.defaultdict(list)
  for name, value in zip(keys, values):
    head, sep, rest = name.partition("/")
    if not sep:
      result[name] = value
      continue
    children[head].append((rest, value))
  for head, pairs in children.items():
    child_keys = [rest for rest, _ in pairs]
    child_values = [value for _, value in pairs]
    result[head] = recover_tree(child_keys, child_values)
  return result
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def npload(fname):
  """Loads `fname` and returns an np.ndarray or dict thereof.

  Local paths are read directly with `np.load`; any other path is read in full
  through `gfile` first, because `np.load` needs a seekable file.

  Args:
    fname: Path to a file written by `np.save` (.npy) or `np.savez` (.npz).

  Returns:
    The array for a single-array file, or a dict mapping member names to arrays
    for an archive.
  """
  # Load the data; use local paths directly if possible:
  if os.path.exists(fname):
    loaded = np.load(fname, allow_pickle=False)
  else:
    # For other (remote) paths go via gfile+BytesIO as np.load requires seeks.
    with gfile.GFile(fname, "rb") as f:
      data = f.read()
    loaded = np.load(io.BytesIO(data), allow_pickle=False)

  # Support loading both single-array files (np.save) and zips (np.savez).
  if isinstance(loaded, np.ndarray):
    return loaded
  # An .npz archive comes back as an NpzFile that reads members lazily and keeps
  # the underlying file open. Materialize every member, then close it.
  with loaded:
    return dict(loaded)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def load_checkpoint(npz):
  """Loads a jax Pytree from a npz file.

  Args:
    npz: Either a path to the checkpoint file (.npz) as `str` or
      `os.PathLike`, or an already-loaded dict-like of "/"-separated names to
      values.

  Returns:
    A Pytree that is the checkpoint, as a nested dict. Empty if `npz` has no
    entries.
  """
  if isinstance(npz, (str, os.PathLike)):  # If not already loaded, then load.
    npz = npload(npz)
  items = list(npz.items())
  if not items:
    # zip(*<empty>) cannot be unpacked into keys and values.
    return {}
  keys, values = zip(*items)
  return recover_tree(keys, values)
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def canonicalize_text(text: str) -> str:
  """Normalizes free text into a canonical sentence form.

  Steps, in order:
    - Every ASCII punctuation character (`string.punctuation`) becomes a space.
    - The text is lower-cased.
    - Runs of whitespace collapse to a single space; leading/trailing
      whitespace is dropped.
    - A period is appended.

  Examples:
    "Hello, World!" -> "hello world."
    "Hello,World.." -> "hello world."
    " Hello WORLD" -> "hello world."

  Args:
    text: A string for the input text.

  Returns:
    A string for the canonicalized text.
  """
  to_space = str.maketrans(dict.fromkeys(string.punctuation, " "))
  words = text.translate(to_space).lower().split()
  return " ".join(words) + "."
|