singlebehaviorlab 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
videoprism/models.py ADDED
@@ -0,0 +1,407 @@
1
+ # Copyright 2026 VideoPrism Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Provides builders and loaders of VideoPrism checkpoints.
16
+
17
+ The v1 base model takes videos with shape (16, 288, 288) as inputs and outputs
18
+ embeddings with shape (batch_size, 4096, 768) which could be reshaped into
19
+ (batch_size, 16, 16, 16, 768) for spatiotemporal representations. The input
20
+ videos should be normalized in [0.0, 1.0].
21
+
22
+ Example usage:
23
+ ```
24
+ from videoprism import models as vp
25
+
26
+ model_name = 'videoprism_public_v1_base'
27
+ flax_model = vp.get_model(model_name)
28
+ loaded_state = vp.load_pretrained_weights(model_name)
29
+
30
+ @jax.jit
31
+ def forward_fn(inputs):
32
+ return flax_model.apply(loaded_state, inputs, train=False)
33
+
34
+ model_inputs = ...
35
+ outputs = forward_fn(model_inputs)
36
+ ```
37
+ """
38
+
39
+ from collections.abc import Callable, Mapping, Sequence
40
+ import functools
41
+
42
+ from flax import linen as nn
43
+ import jax
44
+ import jax.numpy as jnp
45
+ import huggingface_hub
46
+ import numpy as np
47
+ from videoprism import encoders
48
+ from videoprism import tokenizers
49
+ from videoprism import utils
50
+
51
# Label counts of the standard video-classification benchmarks.
K400_NUM_CLASSES: int = 400  # Kinetics-400.
SSV2_NUM_CLASSES: int = 174  # Something-Something v2.

# Maximum token length used when batching text inputs.
TEXT_MAX_LEN: int = 64

# Registry of supported text tokenizers: name -> SentencePiece model info.
TEXT_TOKENIZERS = {
    'c4_en': {
        'model_path': 'gs://t5-data/vocabs/cc_en.32000/sentencepiece.model',
        'vocab_size': 32_000,
    },
}
61
+
62
# Pretrained weights hosted on Hugging Face: model name -> (repo_id, filename).
CHECKPOINTS = {
    # Video-only encoders.
    'videoprism_public_v1_base': (
        'google/videoprism-base-f16r288',
        'flax_base_f16r288_repeated.npz',
    ),
    'videoprism_public_v1_large': (
        'google/videoprism-large-f8r288',
        'flax_large_f8r288_repeated.npz',
    ),
    # Video-text (LvT) models.
    'videoprism_lvt_public_v1_base': (
        'google/videoprism-lvt-base-f16r288',
        'flax_lvt_base_f16r288_repeated.npz',
    ),
    'videoprism_lvt_public_v1_large': (
        'google/videoprism-lvt-large-f8r288',
        'flax_lvt_large_f8r288_repeated.npz',
    ),
}
81
+
82
# Architecture hyperparameters for every supported model variant.
# NOTE(review): key insertion order is preserved from the original definition
# in case any caller iterates over these mappings.
CONFIGS = {
    # Video-only factorized encoders.
    'videoprism_v1_base': dict(
        patch_size=18,
        pos_emb_shape=(16, 16, 16),
        model_dim=768,
        num_spatial_layers=12,
        num_temporal_layers=4,
        num_heads=12,
        mlp_dim=3072,
        atten_logit_cap=50.0,
        scan=True,
    ),
    'videoprism_v1_large': dict(
        patch_size=18,
        pos_emb_shape=(8, 16, 16),
        model_dim=1024,
        num_spatial_layers=24,
        num_temporal_layers=4,
        num_heads=16,
        mlp_dim=4096,
        atten_logit_cap=50.0,
        scan=True,
    ),
    'videoprism_v1_giant': dict(
        patch_size=18,
        pos_emb_shape=(8, 16, 16),
        model_dim=1408,
        num_spatial_layers=40,
        num_temporal_layers=4,
        num_heads=16,
        mlp_dim=6144,
        atten_logit_cap=50.0,
        scan=True,
    ),
    # Video-text (LvT) CLIP-style models.
    'videoprism_lvt_v1_base': dict(
        patch_size=18,
        pos_emb_shape=(16, 16, 16),
        num_spatial_layers=12,
        num_temporal_layers=4,
        mlp_dim=3072,
        num_auxiliary_layers=2,
        enable_causal_atten=True,
        num_unimodal_layers=12,
        norm_policy='pre',
        model_dim=768,
        num_heads=12,
        atten_logit_cap=50.0,
        scan=True,
    ),
    'videoprism_lvt_v1_large': dict(
        patch_size=18,
        pos_emb_shape=(8, 16, 16),
        num_spatial_layers=24,
        num_temporal_layers=4,
        mlp_dim=4096,
        num_auxiliary_layers=2,
        enable_causal_atten=True,
        num_unimodal_layers=12,
        norm_policy='pre',
        model_dim=1024,
        num_heads=16,
        atten_logit_cap=50.0,
        scan=True,
    ),
    'videoprism_lvt_v1_giant': dict(
        patch_size=18,
        pos_emb_shape=(8, 16, 16),
        num_spatial_layers=40,
        num_temporal_layers=4,
        mlp_dim=6144,
        num_auxiliary_layers=2,
        enable_causal_atten=True,
        num_unimodal_layers=16,
        norm_policy='primer_hybrid',
        model_dim=1408,
        num_heads=16,
        atten_logit_cap=50.0,
        scan=True,
    ),
}
162
+
163
+
164
def videoprism_v1_base():
  """Builds the VideoPrism v1 base video encoder."""
  config = CONFIGS['videoprism_v1_base']
  return encoders.FactorizedEncoder(**config)
167
+
168
+
169
def videoprism_v1_large():
  """Builds the VideoPrism v1 large video encoder."""
  config = CONFIGS['videoprism_v1_large']
  return encoders.FactorizedEncoder(**config)
172
+
173
+
174
def videoprism_v1_giant():
  """Builds the VideoPrism v1 giant video encoder."""
  config = CONFIGS['videoprism_v1_giant']
  return encoders.FactorizedEncoder(**config)
177
+
178
+
179
def videoprism_lvt_v1_base(text_tokenizer: str = 'c4_en'):
  """Builds VideoPrism LvT v1 base model.

  Args:
    text_tokenizer: Name of a tokenizer registered in `TEXT_TOKENIZERS`; its
      vocabulary size configures the text tower.

  Returns:
    A `FactorizedVideoCLIP` Flax module.
  """
  # Copy the config so the shared module-level CONFIGS entry is not mutated
  # (the original wrote `vocabulary_size` straight into CONFIGS).
  config = dict(CONFIGS['videoprism_lvt_v1_base'])
  config['vocabulary_size'] = TEXT_TOKENIZERS[text_tokenizer]['vocab_size']
  return encoders.FactorizedVideoCLIP(**config)
184
+
185
+
186
def videoprism_lvt_v1_large(text_tokenizer: str = 'c4_en'):
  """Builds VideoPrism LvT v1 large model.

  Args:
    text_tokenizer: Name of a tokenizer registered in `TEXT_TOKENIZERS`; its
      vocabulary size configures the text tower.

  Returns:
    A `FactorizedVideoCLIP` Flax module.
  """
  # Copy the config so the shared module-level CONFIGS entry is not mutated
  # (the original wrote `vocabulary_size` straight into CONFIGS).
  config = dict(CONFIGS['videoprism_lvt_v1_large'])
  config['vocabulary_size'] = TEXT_TOKENIZERS[text_tokenizer]['vocab_size']
  return encoders.FactorizedVideoCLIP(**config)
191
+
192
+
193
def videoprism_lvt_v1_giant(text_tokenizer: str = 'c4_en'):
  """Builds VideoPrism LvT v1 giant model.

  Args:
    text_tokenizer: Name of a tokenizer registered in `TEXT_TOKENIZERS`; its
      vocabulary size configures the text tower.

  Returns:
    A `FactorizedVideoCLIP` Flax module.
  """
  # Copy the config so the shared module-level CONFIGS entry is not mutated
  # (the original wrote `vocabulary_size` straight into CONFIGS).
  config = dict(CONFIGS['videoprism_lvt_v1_giant'])
  config['vocabulary_size'] = TEXT_TOKENIZERS[text_tokenizer]['vocab_size']
  return encoders.FactorizedVideoCLIP(**config)
198
+
199
+
200
def videoprism_vc_v1_base(num_classes: int):
  """Builds VideoPrism Classification v1 base model."""
  return encoders.FactorizedVideoClassifier(
      encoder_params=CONFIGS['videoprism_v1_base'],
      num_classes=num_classes,
  )
206
+
207
+
208
def videoprism_vc_v1_large(num_classes: int):
  """Builds VideoPrism Classification v1 large model."""
  return encoders.FactorizedVideoClassifier(
      encoder_params=CONFIGS['videoprism_v1_large'],
      num_classes=num_classes,
  )
214
+
215
+
216
def videoprism_vc_v1_giant(num_classes: int):
  """Builds VideoPrism Classification v1 giant model."""
  return encoders.FactorizedVideoClassifier(
      encoder_params=CONFIGS['videoprism_v1_giant'],
      num_classes=num_classes,
  )
222
+
223
+
224
# Public model registry: model name -> zero-argument builder. LvT builders are
# bound to the 'c4_en' tokenizer so every entry is callable with no arguments.
MODELS = {
    'videoprism_public_v1_base': videoprism_v1_base,
    'videoprism_public_v1_large': videoprism_v1_large,
    'videoprism_lvt_public_v1_base': functools.partial(
        videoprism_lvt_v1_base, text_tokenizer='c4_en'
    ),
    'videoprism_lvt_public_v1_large': functools.partial(
        videoprism_lvt_v1_large, text_tokenizer='c4_en'
    ),
}
234
+
235
+
236
def _get_model_name_by_hf_model_id(model_id: str) -> str | None:
  """Returns model name for the given Hugging Face model ID.

  Hugging Face model ID is typically the name of the repository, e.g.,
  `google/videoprism-base-f16r288`.

  Args:
    model_id: A string for the Hugging Face model ID.

  Returns:
    The model name for the given Hugging Face model ID or None if not found.
  """
  matches = (
      name
      for name, entry in CHECKPOINTS.items()
      # Only tuple-valued entries carry a (repo_id, filename) pair.
      if isinstance(entry, tuple) and entry[0] == model_id
  )
  return next(matches, None)
253
+
254
+
255
def has_model(
    model_name: str,
    models: Mapping[str, Callable[[], nn.Module]] | None = None,
) -> bool:
  """Returns whether the model is available."""
  registry = models or MODELS
  if model_name.startswith('google/'):
    # Translate a Hugging Face model ID into a registry name (None if unknown).
    model_name = _get_model_name_by_hf_model_id(model_name)
  if model_name is None:
    return False
  return model_name in registry
266
+
267
+
268
def get_model(
    model_name: str | None,
    model_fn: Callable[[], nn.Module] | None = None,
    models: Mapping[str, Callable[[], nn.Module]] | None = None,
    fprop_dtype: jax.typing.DTypeLike | None = None,
):
  """Returns VideoPrism model with the given name.

  Args:
    model_name: A string for the model name or Hugging Face model ID. Ignored
      when `model_fn` is provided.
    model_fn: Optional function that returns the model.
    models: Mapping from model name to model creation function. Used with
      `model_name`. If None, use the default `MODELS`.
    fprop_dtype: Optional dtype assigned to the model's forward pass.

  Returns:
    A Flax VideoPrism model.

  Raises:
    ValueError: If `model_name` cannot be resolved to a known model.
  """
  if model_fn is None:
    assert model_name is not None
    models = models or MODELS
    if model_name.startswith('google/'):
      # Handle Hugging Face model ID. Keep the caller's ID around: the lookup
      # overwrites `model_name`, and the original code interpolated the
      # post-lookup value (always None on failure) into the error message.
      requested_id = model_name
      model_name = _get_model_name_by_hf_model_id(model_name)
      if model_name is None:
        raise ValueError(f'Failed to find model name with `{requested_id}`.')

    if model_name not in models:
      raise ValueError(f'Model `{model_name}` not found.')

    model_fn = models[model_name]

  model = model_fn()
  if fprop_dtype is not None:
    model.fprop_dtype = fprop_dtype
  return model
304
+
305
+
306
def load_pretrained_weights(
    model_name: str | None,
    checkpoint_path: str | None = None,
    checkpoints: Mapping[str, str | tuple[str, str]] | None = None,
):
  """Loads pretrained model weights.

  Args:
    model_name: A string for the model name or Hugging Face model ID. Ignored
      when `checkpoint_path` is given.
    checkpoint_path: Optional path of the model checkpoint.
    checkpoints: Mapping from model name to checkpoint path. Used with
      `model_name`. If None, use the default `CHECKPOINTS`.

  Returns:
    Restored Flax model weights.

  Raises:
    ValueError: If no checkpoint is registered for `model_name`.
  """
  checkpoints = checkpoints or CHECKPOINTS

  if checkpoint_path is None:
    assert model_name is not None
    requested = model_name
    if model_name.startswith('google/'):
      # Handle Hugging Face model ID.
      model_name = _get_model_name_by_hf_model_id(model_name)

    if model_name not in checkpoints:
      # Previously an unknown name surfaced as a bare KeyError (with key
      # `None` for unresolved HF IDs); raise a descriptive error instead.
      raise ValueError(f'No checkpoint found for model `{requested}`.')

    repo_id, filename = checkpoints[model_name]
    checkpoint_path = huggingface_hub.hf_hub_download(
        repo_id=repo_id, filename=filename
    )

  variables = utils.load_checkpoint(checkpoint_path)
  # Convert every restored leaf to a jnp array for device placement.
  return jax.tree_util.tree_map(jnp.asarray, variables)
337
+
338
+
339
def load_text_tokenizer(name: str) -> tokenizers.Tokenizer:
  """Loads a text tokenizer by name.

  Args:
    name: A string for the text tokenizer model name.

  Returns:
    A text tokenizer.

  Raises:
    ValueError: If `name` is not a registered tokenizer.
  """
  if name not in TEXT_TOKENIZERS:
    raise ValueError(f'Text tokenizer `{name}` not found.')
  return tokenizers.SentencePieceTokenizer(TEXT_TOKENIZERS[name]['model_path'])
353
+
354
+
355
def tokenize_texts(
    tokenizer: tokenizers.Tokenizer,
    inputs: Sequence[str],
    max_length: int = TEXT_MAX_LEN,
    add_bos: bool | None = None,
    canonicalize: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
  """Tokenizes a batch of texts.

  Args:
    tokenizer: The tokenizer to use.
    inputs: The list of texts to tokenize.
    max_length: The maximum length of the tokenized texts.
    add_bos: Whether to add a beginning-of-sentence token. If None, the
      beginning-of-sentence token will be added if the tokenizer's bos_token is
      a non-negative integer.
    canonicalize: Whether to canonicalize the texts before tokenization.

  Returns:
    A tuple of two numpy arrays containing the padded token ids and the
    corresponding paddings, where 1 denotes padding token.
  """
  if canonicalize:
    inputs = [utils.canonicalize_text(text) for text in inputs]

  if add_bos is None:
    # Tokenizers without a BOS token report a negative id.
    add_bos = tokenizer.bos_token >= 0

  batch_ids = []
  batch_paddings = []
  for token_ids in tokenizer.to_int(inputs, bos=add_bos, eos=False):
    # Truncate to the fixed sequence length, then right-pad: ids with 0,
    # paddings with 1.0 (real tokens get 0.0).
    token_ids = token_ids[:max_length]
    num_tokens = len(token_ids)

    ids = np.zeros(max_length, dtype=np.int32)
    ids[:num_tokens] = token_ids
    paddings = np.ones(max_length, dtype=np.float32)
    paddings[:num_tokens] = 0.0

    batch_ids.append(ids)
    batch_paddings.append(paddings)

  return np.asarray(batch_ids), np.asarray(batch_paddings)
@@ -0,0 +1,167 @@
1
+ # Copyright 2026 VideoPrism Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Tokenizers for text encoders."""
16
+
17
+ from collections.abc import Sequence
18
+ from typing import Protocol
19
+
20
+ import tensorflow as tf
21
+ from tensorflow.io import gfile
22
+
23
+ import sentencepiece
24
+
25
+ SentencePieceProcessor = sentencepiece.SentencePieceProcessor
26
+
27
+
28
+ class Tokenizer(Protocol):
29
+ """Tokenizer interface."""
30
+
31
+ def to_int(
32
+ self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False
33
+ ) -> list[int] | list[list[int]]:
34
+ """Tokenizes `text` into a list of integer tokens.
35
+
36
+ Args:
37
+ text: can be a single string, or a list of strings.
38
+ bos: Whether a beginning-of-sentence token should be prepended.
39
+ eos: Whether an end-of-sentence token should be appended.
40
+
41
+ Returns:
42
+ A list or list-of-list of tokens.
43
+ """
44
+
45
+ def to_int_tf_op(
46
+ self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False
47
+ ) -> tf.Tensor | tf.RaggedTensor:
48
+ """Same as `to_int()`, but as TF ops to be used in data pipelines.
49
+
50
+ Args:
51
+ text: can be a single string, or a list of strings.
52
+ bos: Whether a beginning-of-sentence token should be prepended.
53
+ eos: Whether an end-of-sentence token should be appended.
54
+
55
+ Returns:
56
+ A tf.Tensor of tokens.
57
+ """
58
+
59
+ @property
60
+ def pad_token(self) -> int:
61
+ """Token id of padding token."""
62
+
63
+ @property
64
+ def eos_token(self) -> int:
65
+ """Token id of end-of-sentence token."""
66
+
67
+ @property
68
+ def bos_token(self) -> int:
69
+ """Token id of beginning-of-sentence token."""
70
+
71
+ @property
72
+ def vocab_size(self) -> int:
73
+ """Returns the size of the vocabulary."""
74
+
75
+
76
+ class SentencePieceTokenizer(Tokenizer):
77
+ """Wraps a SentencePiece model for tokenization."""
78
+
79
+ def __init__(self, model_path):
80
+ """Initializes the tokenizer.
81
+
82
+ Args:
83
+ model_path: A path to load the SentencePiece model.
84
+ """
85
+ with gfile.GFile(model_path, "rb") as f:
86
+ model_bytes = f.read()
87
+
88
+ self._model = SentencePieceProcessor()
89
+ self._model.LoadFromSerializedProto(model_bytes)
90
+
91
+ def to_int(
92
+ self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False
93
+ ) -> list[int] | list[list[int]]:
94
+ """Tokenizes `text` into a list of integer tokens.
95
+
96
+ Args:
97
+ text: can be a single string, or a list of strings.
98
+ bos: Whether a beginning-of-sentence token should be prepended.
99
+ eos: Whether an end-of-sentence token should be appended.
100
+
101
+ Returns:
102
+ A list or list-of-list of tokens.
103
+ """
104
+
105
+ def _single(s: str) -> list[int]:
106
+ return (
107
+ ([self.bos_token] if bos else [])
108
+ + self._model.EncodeAsIds(s)
109
+ + ([self.eos_token] if eos else [])
110
+ )
111
+
112
+ if isinstance(text, str):
113
+ return _single(text)
114
+ return list([_single(s) for s in text])
115
+
116
+ def to_int_tf_op(
117
+ self, text: str | Sequence[str], *, bos: bool = False, eos: bool = False
118
+ ) -> tf.Tensor | tf.RaggedTensor:
119
+ """Same as `to_int()`, but as TF ops to be used in data pipelines.
120
+
121
+ Args:
122
+ text: can be a single string, or a list of strings.
123
+ bos: Whether a beginning-of-sentence token should be prepended.
124
+ eos: Whether an end-of-sentence token should be appended.
125
+
126
+ Returns:
127
+ A tf.Tensor or tf.RaggedTensor of tokens.
128
+ """
129
+ text = tf.convert_to_tensor(text)
130
+ if text.ndim == 0:
131
+
132
+ def fn(txt):
133
+ """Tokenizes a single string."""
134
+ s = txt.numpy().decode()
135
+ return tf.constant(self.to_int(s, bos=bos, eos=eos), tf.int32)
136
+
137
+ return tf.py_function(fn, [text], tf.int32)
138
+ else:
139
+
140
+ def fn(txt):
141
+ """Tokenizes a list of strings."""
142
+ strings = [s.decode() for s in txt.numpy().tolist()]
143
+ toks = self.to_int(strings, bos=bos, eos=eos)
144
+ return tf.ragged.constant(toks)
145
+
146
+ out_type = tf.RaggedTensorSpec([text.shape[0], None], tf.int32)
147
+ return tf.py_function(fn, [text], Tout=out_type)
148
+
149
+ @property
150
+ def pad_token(self) -> int:
151
+ """Token id of padding token."""
152
+ return self._model.pad_id()
153
+
154
+ @property
155
+ def eos_token(self) -> int:
156
+ """Token id of end-of-sentence token."""
157
+ return self._model.eos_id()
158
+
159
+ @property
160
+ def bos_token(self) -> int:
161
+ """Token id of beginning-of-sentence token."""
162
+ return self._model.bos_id()
163
+
164
+ @property
165
+ def vocab_size(self) -> int:
166
+ """Returns the size of the vocabulary."""
167
+ return self._model.GetPieceSize()