singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
videoprism/models.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
1
|
+
# Copyright 2026 VideoPrism Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Provides builders and loaders of VideoPrism checkpoints.
|
|
16
|
+
|
|
17
|
+
The v1 base model takes videos with shape (16, 288, 288) as inputs and outputs
|
|
18
|
+
embeddings with shape (batch_size, 4096, 768) which could be reshaped into
|
|
19
|
+
(batch_size, 16, 16, 16, 768) for spatiotemporal representations. The input
|
|
20
|
+
videos should be normalized in [0.0, 1.0].
|
|
21
|
+
|
|
22
|
+
Example usage:
|
|
23
|
+
```
|
|
24
|
+
from videoprism import models as vp
|
|
25
|
+
|
|
26
|
+
model_name = 'videoprism_public_v1_base'
|
|
27
|
+
flax_model = vp.get_model(model_name)
|
|
28
|
+
loaded_state = vp.load_pretrained_weights(model_name)
|
|
29
|
+
|
|
30
|
+
@jax.jit
|
|
31
|
+
def forward_fn(inputs):
|
|
32
|
+
return flax_model.apply(loaded_state, inputs, train=False)
|
|
33
|
+
|
|
34
|
+
model_inputs = ...
|
|
35
|
+
outputs = forward_fn(model_inputs)
|
|
36
|
+
```
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
from collections.abc import Callable, Mapping, Sequence
|
|
40
|
+
import functools
|
|
41
|
+
|
|
42
|
+
from flax import linen as nn
|
|
43
|
+
import jax
|
|
44
|
+
import jax.numpy as jnp
|
|
45
|
+
import huggingface_hub
|
|
46
|
+
import numpy as np
|
|
47
|
+
from videoprism import encoders
|
|
48
|
+
from videoprism import tokenizers
|
|
49
|
+
from videoprism import utils
|
|
50
|
+
|
|
51
|
+
K400_NUM_CLASSES: int = 400
|
|
52
|
+
SSV2_NUM_CLASSES: int = 174
|
|
53
|
+
|
|
54
|
+
TEXT_MAX_LEN: int = 64
|
|
55
|
+
TEXT_TOKENIZERS = {
|
|
56
|
+
'c4_en': {
|
|
57
|
+
'model_path': 'gs://t5-data/vocabs/cc_en.32000/sentencepiece.model',
|
|
58
|
+
'vocab_size': 32_000,
|
|
59
|
+
},
|
|
60
|
+
}
|
|
61
|
+
|
|
62
|
+
CHECKPOINTS = {
|
|
63
|
+
# Hugging Face checkpoints (repository, filename).
|
|
64
|
+
'videoprism_public_v1_base': (
|
|
65
|
+
'google/videoprism-base-f16r288',
|
|
66
|
+
'flax_base_f16r288_repeated.npz',
|
|
67
|
+
),
|
|
68
|
+
'videoprism_public_v1_large': (
|
|
69
|
+
'google/videoprism-large-f8r288',
|
|
70
|
+
'flax_large_f8r288_repeated.npz',
|
|
71
|
+
),
|
|
72
|
+
'videoprism_lvt_public_v1_base': (
|
|
73
|
+
'google/videoprism-lvt-base-f16r288',
|
|
74
|
+
'flax_lvt_base_f16r288_repeated.npz',
|
|
75
|
+
),
|
|
76
|
+
'videoprism_lvt_public_v1_large': (
|
|
77
|
+
'google/videoprism-lvt-large-f8r288',
|
|
78
|
+
'flax_lvt_large_f8r288_repeated.npz',
|
|
79
|
+
),
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
CONFIGS = {
|
|
83
|
+
'videoprism_v1_base': dict(
|
|
84
|
+
patch_size=18,
|
|
85
|
+
pos_emb_shape=(16, 16, 16),
|
|
86
|
+
model_dim=768,
|
|
87
|
+
num_spatial_layers=12,
|
|
88
|
+
num_temporal_layers=4,
|
|
89
|
+
num_heads=12,
|
|
90
|
+
mlp_dim=3072,
|
|
91
|
+
atten_logit_cap=50.0,
|
|
92
|
+
scan=True,
|
|
93
|
+
),
|
|
94
|
+
'videoprism_v1_large': dict(
|
|
95
|
+
patch_size=18,
|
|
96
|
+
pos_emb_shape=(8, 16, 16),
|
|
97
|
+
model_dim=1024,
|
|
98
|
+
num_spatial_layers=24,
|
|
99
|
+
num_temporal_layers=4,
|
|
100
|
+
num_heads=16,
|
|
101
|
+
mlp_dim=4096,
|
|
102
|
+
atten_logit_cap=50.0,
|
|
103
|
+
scan=True,
|
|
104
|
+
),
|
|
105
|
+
'videoprism_v1_giant': dict(
|
|
106
|
+
patch_size=18,
|
|
107
|
+
pos_emb_shape=(8, 16, 16),
|
|
108
|
+
model_dim=1408,
|
|
109
|
+
num_spatial_layers=40,
|
|
110
|
+
num_temporal_layers=4,
|
|
111
|
+
num_heads=16,
|
|
112
|
+
mlp_dim=6144,
|
|
113
|
+
atten_logit_cap=50.0,
|
|
114
|
+
scan=True,
|
|
115
|
+
),
|
|
116
|
+
'videoprism_lvt_v1_base': dict(
|
|
117
|
+
patch_size=18,
|
|
118
|
+
pos_emb_shape=(16, 16, 16),
|
|
119
|
+
num_spatial_layers=12,
|
|
120
|
+
num_temporal_layers=4,
|
|
121
|
+
mlp_dim=3072,
|
|
122
|
+
num_auxiliary_layers=2,
|
|
123
|
+
enable_causal_atten=True,
|
|
124
|
+
num_unimodal_layers=12,
|
|
125
|
+
norm_policy='pre',
|
|
126
|
+
model_dim=768,
|
|
127
|
+
num_heads=12,
|
|
128
|
+
atten_logit_cap=50.0,
|
|
129
|
+
scan=True,
|
|
130
|
+
),
|
|
131
|
+
'videoprism_lvt_v1_large': dict(
|
|
132
|
+
patch_size=18,
|
|
133
|
+
pos_emb_shape=(8, 16, 16),
|
|
134
|
+
num_spatial_layers=24,
|
|
135
|
+
num_temporal_layers=4,
|
|
136
|
+
mlp_dim=4096,
|
|
137
|
+
num_auxiliary_layers=2,
|
|
138
|
+
enable_causal_atten=True,
|
|
139
|
+
num_unimodal_layers=12,
|
|
140
|
+
norm_policy='pre',
|
|
141
|
+
model_dim=1024,
|
|
142
|
+
num_heads=16,
|
|
143
|
+
atten_logit_cap=50.0,
|
|
144
|
+
scan=True,
|
|
145
|
+
),
|
|
146
|
+
'videoprism_lvt_v1_giant': dict(
|
|
147
|
+
patch_size=18,
|
|
148
|
+
pos_emb_shape=(8, 16, 16),
|
|
149
|
+
num_spatial_layers=40,
|
|
150
|
+
num_temporal_layers=4,
|
|
151
|
+
mlp_dim=6144,
|
|
152
|
+
num_auxiliary_layers=2,
|
|
153
|
+
enable_causal_atten=True,
|
|
154
|
+
num_unimodal_layers=16,
|
|
155
|
+
norm_policy='primer_hybrid',
|
|
156
|
+
model_dim=1408,
|
|
157
|
+
num_heads=16,
|
|
158
|
+
atten_logit_cap=50.0,
|
|
159
|
+
scan=True,
|
|
160
|
+
),
|
|
161
|
+
}
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def videoprism_v1_base():
  """Builds VideoPrism v1 base model."""
  config = CONFIGS['videoprism_v1_base']
  return encoders.FactorizedEncoder(**config)
|
+
|
|
168
|
+
|
|
169
|
+
def videoprism_v1_large():
  """Builds VideoPrism v1 large model."""
  config = CONFIGS['videoprism_v1_large']
  return encoders.FactorizedEncoder(**config)
|
+
|
|
173
|
+
|
|
174
|
+
def videoprism_v1_giant():
  """Builds VideoPrism v1 giant model."""
  config = CONFIGS['videoprism_v1_giant']
  return encoders.FactorizedEncoder(**config)
|
+
|
|
178
|
+
|
|
179
|
+
def videoprism_lvt_v1_base(text_tokenizer: str = 'c4_en'):
  """Builds VideoPrism LvT v1 base model.

  Args:
    text_tokenizer: Name of the text tokenizer (key into `TEXT_TOKENIZERS`)
      whose vocabulary size configures the text tower.

  Returns:
    A Flax `FactorizedVideoCLIP` module.
  """
  # Copy the shared config dict before adding `vocabulary_size`: the original
  # wrote into CONFIGS['videoprism_lvt_v1_base'] directly, mutating
  # module-level state for every later caller.
  config = dict(CONFIGS['videoprism_lvt_v1_base'])
  config['vocabulary_size'] = TEXT_TOKENIZERS[text_tokenizer]['vocab_size']
  return encoders.FactorizedVideoCLIP(**config)
|
|
185
|
+
|
|
186
|
+
def videoprism_lvt_v1_large(text_tokenizer: str = 'c4_en'):
  """Builds VideoPrism LvT v1 large model.

  Args:
    text_tokenizer: Name of the text tokenizer (key into `TEXT_TOKENIZERS`)
      whose vocabulary size configures the text tower.

  Returns:
    A Flax `FactorizedVideoCLIP` module.
  """
  # Copy the shared config dict before adding `vocabulary_size`: the original
  # mutated the module-level CONFIGS entry in place.
  config = dict(CONFIGS['videoprism_lvt_v1_large'])
  config['vocabulary_size'] = TEXT_TOKENIZERS[text_tokenizer]['vocab_size']
  return encoders.FactorizedVideoCLIP(**config)
|
+
|
|
192
|
+
|
|
193
|
+
def videoprism_lvt_v1_giant(text_tokenizer: str = 'c4_en'):
  """Builds VideoPrism LvT v1 giant model.

  Args:
    text_tokenizer: Name of the text tokenizer (key into `TEXT_TOKENIZERS`)
      whose vocabulary size configures the text tower.

  Returns:
    A Flax `FactorizedVideoCLIP` module.
  """
  # Copy the shared config dict before adding `vocabulary_size`: the original
  # mutated the module-level CONFIGS entry in place.
  config = dict(CONFIGS['videoprism_lvt_v1_giant'])
  config['vocabulary_size'] = TEXT_TOKENIZERS[text_tokenizer]['vocab_size']
  return encoders.FactorizedVideoCLIP(**config)
|
+
|
|
199
|
+
|
|
200
|
+
def videoprism_vc_v1_base(num_classes: int):
  """Builds VideoPrism Classification v1 base model."""
  return encoders.FactorizedVideoClassifier(
      encoder_params=CONFIGS['videoprism_v1_base'],
      num_classes=num_classes,
  )
|
+
|
|
207
|
+
|
|
208
|
+
def videoprism_vc_v1_large(num_classes: int):
  """Builds VideoPrism Classification v1 large model."""
  return encoders.FactorizedVideoClassifier(
      encoder_params=CONFIGS['videoprism_v1_large'],
      num_classes=num_classes,
  )
|
+
|
|
215
|
+
|
|
216
|
+
def videoprism_vc_v1_giant(num_classes: int):
  """Builds VideoPrism Classification v1 giant model."""
  return encoders.FactorizedVideoClassifier(
      encoder_params=CONFIGS['videoprism_v1_giant'],
      num_classes=num_classes,
  )
|
+
|
|
223
|
+
|
|
224
|
+
MODELS = {
|
|
225
|
+
'videoprism_public_v1_base': videoprism_v1_base,
|
|
226
|
+
'videoprism_public_v1_large': videoprism_v1_large,
|
|
227
|
+
'videoprism_lvt_public_v1_base': functools.partial(
|
|
228
|
+
videoprism_lvt_v1_base, text_tokenizer='c4_en'
|
|
229
|
+
),
|
|
230
|
+
'videoprism_lvt_public_v1_large': functools.partial(
|
|
231
|
+
videoprism_lvt_v1_large, text_tokenizer='c4_en'
|
|
232
|
+
),
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
def _get_model_name_by_hf_model_id(model_id: str) -> str | None:
  """Returns model name for the given Hugging Face model ID.

  Hugging Face model ID is typically the name of the repository, e.g.,
  `google/videoprism-base-f16r288`.

  Args:
    model_id: A string for the Hugging Face model ID.

  Returns:
    The model name for the given Hugging Face model ID or None if not found.
  """
  # Only tuple-valued entries carry a (repo_id, filename) pair to match on.
  matches = (
      name
      for name, ckpt in CHECKPOINTS.items()
      if isinstance(ckpt, tuple) and ckpt[0] == model_id
  )
  return next(matches, None)
|
+
|
|
254
|
+
|
|
255
|
+
def has_model(
    model_name: str,
    models: Mapping[str, Callable[[], nn.Module]] | None = None,
) -> bool:
  """Returns whether the model is available."""
  registry = models if models else MODELS
  resolved = model_name
  if resolved.startswith('google/'):
    # Translate a Hugging Face repo ID into the internal model name
    # (None when the ID is unknown).
    resolved = _get_model_name_by_hf_model_id(resolved)
  return resolved is not None and resolved in registry
266
|
+
|
|
267
|
+
|
|
268
|
+
def get_model(
    model_name: str | None,
    model_fn: Callable[[], nn.Module] | None = None,
    models: Mapping[str, Callable[[], nn.Module]] | None = None,
    fprop_dtype: jax.typing.DTypeLike | None = None,
):
  """Returns VideoPrism model with the given name.

  Args:
    model_name: A string for the model name or Hugging Face model ID. Ignored
      when `model_fn` is provided.
    model_fn: Optional function that returns the model.
    models: Mapping from model name to model creation function. Used with
      `model_name`. If None, use the default `MODELS`.
    fprop_dtype: Optional dtype to set as the model's forward-prop dtype.

  Returns:
    A Flax VideoPrism model.

  Raises:
    ValueError: If the model name (or Hugging Face model ID) is unknown.
  """
  if model_fn is None:
    assert model_name is not None
    models = models or MODELS
    if model_name.startswith('google/'):
      # Handle Hugging Face model ID. Keep the caller's original ID so the
      # error message reports it; the original code overwrote `model_name`
      # with None before formatting the message, printing `None` instead.
      hf_model_id = model_name
      model_name = _get_model_name_by_hf_model_id(hf_model_id)
      if model_name is None:
        raise ValueError(f'Failed to find model name with `{hf_model_id}`.')

    if model_name not in models:
      raise ValueError(f'Model `{model_name}` not found.')

    model_fn = models[model_name]

  model = model_fn()
  if fprop_dtype is not None:
    model.fprop_dtype = fprop_dtype
  return model
304
|
+
|
|
305
|
+
|
|
306
|
+
def load_pretrained_weights(
    model_name: str | None,
    checkpoint_path: str | None = None,
    checkpoints: Mapping[str, str | tuple[str, str]] | None = None,
):
  """Loads pretrained model weights.

  Args:
    model_name: A string for the model name or Hugging Face model ID. Ignored
      when `checkpoint_path` is provided.
    checkpoint_path: Optional path of the model checkpoint.
    checkpoints: Mapping from model name to checkpoint path. Used with
      `model_name`. If None, use the default `CHECKPOINTS`.

  Returns:
    Restored Flax model weights.

  Raises:
    ValueError: If the model name (or Hugging Face model ID) has no known
      checkpoint.
  """
  checkpoints = checkpoints or CHECKPOINTS

  if checkpoint_path is None:
    assert model_name is not None
    if model_name.startswith('google/'):
      # Handle Hugging Face model ID. The original code let a failed lookup
      # fall through to `checkpoints[None]`, raising an opaque KeyError.
      hf_model_id = model_name
      model_name = _get_model_name_by_hf_model_id(hf_model_id)
      if model_name is None:
        raise ValueError(f'Failed to find model name with `{hf_model_id}`.')

    if model_name not in checkpoints:
      raise ValueError(f'Checkpoint for model `{model_name}` not found.')

    repo_id, filename = checkpoints[model_name]
    checkpoint_path = huggingface_hub.hf_hub_download(
        repo_id=repo_id, filename=filename
    )

  variables = utils.load_checkpoint(checkpoint_path)
  # Convert every leaf to a jnp array so weights land on the default device.
  return jax.tree_util.tree_map(jnp.asarray, variables)
|
+
|
|
338
|
+
|
|
339
|
+
def load_text_tokenizer(name: str) -> tokenizers.Tokenizer:
  """Loads a text tokenizer by name.

  Args:
    name: A string for the text tokenizer model name.

  Returns:
    A text tokenizer.

  Raises:
    ValueError: If the tokenizer name is unknown.
  """
  try:
    spec = TEXT_TOKENIZERS[name]
  except KeyError:
    raise ValueError(f'Text tokenizer `{name}` not found.') from None
  return tokenizers.SentencePieceTokenizer(spec['model_path'])
353
|
+
|
|
354
|
+
|
|
355
|
+
def tokenize_texts(
    tokenizer: tokenizers.Tokenizer,
    inputs: Sequence[str],
    max_length: int = TEXT_MAX_LEN,
    add_bos: bool | None = None,
    canonicalize: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
  """Tokenizes a batch of texts.

  Args:
    tokenizer: The tokenizer to use.
    inputs: The list of texts to tokenize.
    max_length: The maximum length of the tokenized texts.
    add_bos: Whether to add a beginning-of-sentence token. If None, the
      beginning-of-sentence token will be added if the tokenizer's bos_token
      is a non-negative integer.
    canonicalize: Whether to canonicalize the texts before tokenization.

  Returns:
    A tuple of two numpy arrays containing the padded token ids and the
    corresponding paddings, where 1 denotes padding token.
  """
  texts = (
      [utils.canonicalize_text(t) for t in inputs] if canonicalize else inputs
  )
  use_bos = (tokenizer.bos_token >= 0) if add_bos is None else add_bos

  all_ids, all_paddings = [], []
  for token_ids in tokenizer.to_int(texts, bos=use_bos, eos=False):
    # Truncate to `max_length`, then right-pad; padding positions carry 1.0
    # in the paddings array and token id 0 in the ids array.
    token_ids = token_ids[:max_length]
    n_pad = max_length - len(token_ids)
    ids_arr = np.pad(
        np.asarray(token_ids, dtype=np.int32),
        (0, n_pad),
        'constant',
        constant_values=0,
    )
    pad_arr = np.pad(
        np.zeros(len(token_ids), dtype=np.float32),
        (0, n_pad),
        'constant',
        constant_values=1.0,
    )
    all_ids.append(ids_arr)
    all_paddings.append(pad_arr)

  return np.asarray(all_ids), np.asarray(all_paddings)
videoprism/tokenizers.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
# Copyright 2026 VideoPrism Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""Tokenizers for text encoders."""
|
|
16
|
+
|
|
17
|
+
from collections.abc import Sequence
|
|
18
|
+
from typing import Protocol
|
|
19
|
+
|
|
20
|
+
import tensorflow as tf
|
|
21
|
+
from tensorflow.io import gfile
|
|
22
|
+
|
|
23
|
+
import sentencepiece
|
|
24
|
+
|
|
25
|
+
SentencePieceProcessor = sentencepiece.SentencePieceProcessor
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Tokenizer(Protocol):
  """Tokenizer interface.

  Structural (duck-typed) interface: implementers need not inherit from this
  class. Method bodies are intentionally just docstrings, as is conventional
  for `typing.Protocol` members.
  """

  def to_int(
      self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False
  ) -> list[int] | list[list[int]]:
    """Tokenizes `text` into a list of integer tokens.

    Args:
      text: can be a single string, or a list of strings.
      bos: Whether a beginning-of-sentence token should be prepended.
      eos: Whether an end-of-sentence token should be appended.

    Returns:
      A list or list-of-list of tokens.
    """

  def to_int_tf_op(
      self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False
  ) -> tf.Tensor | tf.RaggedTensor:
    """Same as `to_int()`, but as TF ops to be used in data pipelines.

    Args:
      text: can be a single string, or a list of strings.
      bos: Whether a beginning-of-sentence token should be prepended.
      eos: Whether an end-of-sentence token should be appended.

    Returns:
      A tf.Tensor of tokens.
    """

  @property
  def pad_token(self) -> int:
    """Token id of padding token."""

  @property
  def eos_token(self) -> int:
    """Token id of end-of-sentence token."""

  @property
  def bos_token(self) -> int:
    """Token id of beginning-of-sentence token."""

  @property
  def vocab_size(self) -> int:
    """Returns the size of the vocabulary."""
74
|
+
|
|
75
|
+
|
|
76
|
+
class SentencePieceTokenizer(Tokenizer):
  """Wraps a SentencePiece model for tokenization."""

  def __init__(self, model_path):
    """Initializes the tokenizer.

    Args:
      model_path: A path to load the SentencePiece model.
    """
    # gfile (rather than plain open) lets the model load from remote
    # filesystems such as gs:// buckets — the default c4_en vocab above
    # lives on GCS.
    with gfile.GFile(model_path, "rb") as f:
      model_bytes = f.read()

    self._model = SentencePieceProcessor()
    self._model.LoadFromSerializedProto(model_bytes)

  def to_int(
      self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False
  ) -> list[int] | list[list[int]]:
    """Tokenizes `text` into a list of integer tokens.

    Args:
      text: can be a single string, or a list of strings.
      bos: Whether a beginning-of-sentence token should be prepended.
      eos: Whether an end-of-sentence token should be appended.

    Returns:
      A list or list-of-list of tokens.
    """

    # Encode one string, optionally bracketing it with BOS/EOS ids.
    def _single(s: str) -> list[int]:
      return (
          ([self.bos_token] if bos else [])
          + self._model.EncodeAsIds(s)
          + ([self.eos_token] if eos else [])
      )

    if isinstance(text, str):
      return _single(text)
    return list([_single(s) for s in text])

  def to_int_tf_op(
      self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False
  ) -> tf.Tensor | tf.RaggedTensor:
    """Same as `to_int()`, but as TF ops to be used in data pipelines.

    Args:
      text: can be a single string, or a list of strings.
      bos: Whether a beginning-of-sentence token should be prepended.
      eos: Whether an end-of-sentence token should be appended.

    Returns:
      A tf.Tensor or tf.RaggedTensor of tokens.
    """
    text = tf.convert_to_tensor(text)
    # Scalar (rank-0) input: a single string -> dense 1-D int32 tensor.
    if text.ndim == 0:

      def fn(txt):
        """Tokenizes a single string."""
        # py_function hands us an EagerTensor; decode its bytes to str.
        s = txt.numpy().decode()
        return tf.constant(self.to_int(s, bos=bos, eos=eos), tf.int32)

      return tf.py_function(fn, [text], tf.int32)
    else:

      def fn(txt):
        """Tokenizes a list of strings."""
        strings = [s.decode() for s in txt.numpy().tolist()]
        toks = self.to_int(strings, bos=bos, eos=eos)
        # Per-string token counts differ, hence a RaggedTensor.
        return tf.ragged.constant(toks)

      # Explicit RaggedTensorSpec so py_function can return a ragged value.
      out_type = tf.RaggedTensorSpec([text.shape[0], None], tf.int32)
      return tf.py_function(fn, [text], Tout=out_type)

  @property
  def pad_token(self) -> int:
    """Token id of padding token."""
    return self._model.pad_id()

  @property
  def eos_token(self) -> int:
    """Token id of end-of-sentence token."""
    return self._model.eos_id()

  @property
  def bos_token(self) -> int:
    """Token id of beginning-of-sentence token."""
    return self._model.bos_id()

  @property
  def vocab_size(self) -> int:
    """Returns the size of the vocabulary."""
    return self._model.GetPieceSize()