singlebehaviorlab-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
videoprism/encoders.py ADDED
@@ -0,0 +1,910 @@
+ # Copyright 2026 VideoPrism Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """Modules for video encoders."""
+
+ from collections.abc import Collection, Sequence
+ import dataclasses
+ import math
+ from typing import Any
+
+ import einops
+ import einshape
+ from flax import linen as nn
+ import jax
+ from jax import numpy as jnp
+ import numpy as np
+ from videoprism import layers
+
+ Array = jax.Array
+ Variables = nn.module.VariableDict
+
+ default_kernel_init = layers.default_kernel_init
+
+
+ def _contains(collection: Collection[str] | bool, key: str) -> bool:
+   """Checks if a collection contains a key.
+
+   Args:
+     collection: A collection of strings or a boolean value.
+     key: A string key to check.
+
+   Returns:
+     True if the collection contains the key, or if the collection is a True
+     boolean. False otherwise.
+   """
+   return collection if isinstance(collection, bool) else key in collection
+
+
+ def _l2_normalize(
+     x: Array, axis: int | Sequence[int] = -1, epsilon: float = 1e-12
+ ) -> Array:
+   """L2-normalizes a jax.Array along the given dimension(s).
+
+   Args:
+     x: An input jax.Array.
+     axis: An integer or a sequence of integers for the axes to normalize.
+     epsilon: A small constant for numerical stability.
+
+   Returns:
+     Normalized jax.Array.
+   """
+   x_dtype = x.dtype
+   # Always compute the norm in float32, regardless of input precision.
+   x = x.astype(jnp.float32)
+   norm = jnp.sqrt(jnp.sum(x * x, axis=axis, keepdims=True) + epsilon)
+   return (x / norm).astype(x_dtype)
+
+
+ def _image_to_patch(inputs: Array, patch_size: int) -> Array:
+   """Converts an image to patches.
+
+   Args:
+     inputs: A jax.Array of shape [B, H, W, C].
+     patch_size: An integer for the dimension of a square patch.
+
+   Returns:
+     batched_patches: [B, (H * W / P^2), P^2 * C].
+   """
+   if len(inputs.shape) < 4:
+     raise ValueError(
+         f'Image should be formatted as 4D [B, H, W, C]. Shape: {inputs.shape}'
+     )
+   height, width, channels = inputs.shape[-3:]
+
+   if height % patch_size != 0 or width % patch_size != 0:
+     raise ValueError(
+         f'Image height ({height}) and width ({width}) should be multiples '
+         f'of patch_size ({patch_size}).'
+     )
+
+   row_blocks = height // patch_size
+   column_blocks = width // patch_size
+
+   patches = einops.rearrange(
+       inputs,
+       '... (m p) (n q) c -> ... (m n) (p q c)',
+       m=row_blocks,
+       n=column_blocks,
+       p=patch_size,
+       q=patch_size,
+       c=channels,
+   )
+   return patches
+
+
+ def _interpolate_emb_1d(emb: Array, target_emb_length: int) -> Array:
+   """Interpolates a 1D positional embedding to a new shape.
+
+   Args:
+     emb: jax.Array, (1, N, D), flattened 1D positional embedding.
+     target_emb_length: length of the target embedding.
+
+   Returns:
+     Flattened, interpolated embedding of shape (1, target_emb_length, D)
+   """
+
+   if len(emb.shape) > 3 or emb.shape[0] != 1:
+     raise ValueError('The shape of the embedding should be (1, N, D)')
+
+   emb_dim = emb.shape[-1]
+   emb = jnp.squeeze(emb, axis=0)
+
+   target_emb = jax.image.resize(
+       emb, (target_emb_length, emb_dim), method='bilinear'
+   )
+   target_emb = jnp.reshape(target_emb, (1, target_emb_length, emb_dim))
+   return target_emb
+
+
+ def _interpolate_emb_2d(
+     emb: Array,
+     source_emb_shape: tuple[int, int],
+     target_emb_shape: tuple[int, int],
+ ) -> Array:
+   """Interpolates a 2D positional embedding to a new shape.
+
+   Args:
+     emb: A jax.Array of shape (1, H1xW1, D) for flattened 2D positional
+       embedding.
+     source_emb_shape: Tuple, (H1, W1), height and width of the source
+       embedding.
+     target_emb_shape: Tuple, (H2, W2), height and width of the target
+       embedding.
+
+   Returns:
+     Flattened, interpolated embedding of shape (1, H2xW2, D)
+   """
+
+   if len(emb.shape) > 3 or emb.shape[0] != 1:
+     raise ValueError('The shape of the embedding should be (1, H * W, D)')
+
+   if emb.shape[-2] != source_emb_shape[0] * source_emb_shape[1]:
+     raise ValueError('The shape of the embedding does NOT match input specs.')
+
+   emb_dim = emb.shape[-1]
+   emb = jnp.reshape(emb, (source_emb_shape[0], source_emb_shape[1], emb_dim))
+
+   target_emb = jax.image.resize(
+       emb,
+       (target_emb_shape[0], target_emb_shape[1], emb_dim),
+       method='bilinear',
+   )
+   target_emb = jnp.reshape(
+       target_emb, (1, target_emb_shape[0] * target_emb_shape[1], emb_dim)
+   )
+   return target_emb
+
+
+ class Embedding(layers.Module):
+   """A simple embedding layer that performs embedding lookups from ids.
+
+   Attributes:
+     num_classes: Number of tokens in the vocabulary.
+     input_dim: Depth of the embedding output. This is called `input_dim` as
+       opposed to the more appropriate `embedding_dim` to be compatible with
+       other embedding layers defined in this file.
+     lookup_style: Style of lookup, one of index or matmul.
+     scale_sqrt_depth: If set to True, activations are scaled with
+       sqrt(embedding_dim) in embedding lookup.
+     set_nan_for_oob_id: If set to True, embeddings corresponding to
+       out-of-boundaries ids will be set to NaN.
+   """
+
+   num_classes: int = 0
+   input_dim: int = 0
+   lookup_style: str = 'index'
+   scale_sqrt_depth: bool = False
+   set_nan_for_oob_id: bool = False
+
+   @nn.compact
+   def __call__(self, ids: Array) -> Array:
+     """Generates a jax.Array of embedding lookup result.
+
+     Args:
+       ids: Indexes of shape [...] for embedding lookup.
+
+     Returns:
+       A jax.Array of shape [..., input_dim].
+     """
+     emb_var = self._cast_to_fprop_dtype(
+         self.param(
+             'emb_var',
+             nn.initializers.normal(stddev=1.0 / math.sqrt(self.input_dim)),
+             [self.num_classes, self.input_dim],
+             self.dtype,
+         )
+     )
+     if self.lookup_style == 'index':
+       embs = jnp.asarray(emb_var)[(ids,)]
+     elif self.lookup_style == 'matmul':
+       one_hot_ids = jax.nn.one_hot(
+           ids, self.num_classes, dtype=self.fprop_dtype
+       )
+       embs = jnp.einsum('...y,yz->...z', one_hot_ids, emb_var)
+     else:
+       raise ValueError(f'Unknown lookup style: `{self.lookup_style}`.')
+
+     # Map out-of-boundary ids to NaN.
+     if self.set_nan_for_oob_id:
+       embs = jnp.where(ids[..., jnp.newaxis] < self.num_classes, embs, jnp.nan)
+
+     if self.scale_sqrt_depth:
+       embs *= self.input_dim**0.5
+
+     return embs
+
+
+ class PositionalEmbedding(layers.Module):
+   """Generates position embedding for a given 1-d sequence.
+
+   Attributes:
+     embedding_dim: Dimension of the embedding to be generated.
+     min_timescale: Start of the geometric index.
+     max_timescale: End of the geometric index.
+   """
+
+   embedding_dim: int = 0
+   min_timescale: int = 1
+   max_timescale: int = 10_000
+
+   def __call__(self, seq_length: int) -> Array:
+     """Generates a jax.Array of embedding lookup result.
+
+     Args:
+       seq_length: Sequence length of the embeddings to be generated.
+
+     Returns:
+       A jax.Array of shape [1, seq_length, embedding_dim].
+     """
+     position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]
+     num_timescales = self.embedding_dim // 2
+     log_timescale_increment = math.log(
+         float(self.max_timescale) / float(self.min_timescale)
+     ) / jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1)
+     inv_timescales = self.min_timescale * jnp.exp(
+         jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment
+     )
+     scaled_time = (
+         position[:, :, jnp.newaxis]
+         * inv_timescales[jnp.newaxis, jnp.newaxis, :]
+     )
+     embs = jnp.concatenate(
+         [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1
+     ).astype(self.fprop_dtype)
+     # Force usage of `np` to compute static values at trace time.
+     embs = jnp.pad(embs, [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]])
+     return embs
+
+
+ class TrainablePositionalEmbedding(layers.Module):
+   """Generates trainable position embedding for a given 1-d sequence.
+
+   Attributes:
+     embedding_dim: Dimension of the embedding to be generated.
+     max_seq_length: Max sequence length.
+     lookup_style: Style of lookup, one of index or matmul.
+   """
+
+   embedding_dim: int = 0
+   max_seq_length: int = 10_240
+   lookup_style: str = 'matmul'
+
+   @nn.compact
+   def __call__(self, seq_length: int) -> Array:
+     """Generates a jax.Array of embedding lookup result.
+
+     Args:
+       seq_length: Sequence length of the embeddings to be generated.
+
+     Returns:
+       A jax.Array of shape [1, seq_length, embedding_dim].
+     """
+     position = jnp.arange(seq_length, dtype=jnp.int32)[jnp.newaxis, :]
+     pos_emb_var = self._cast_to_fprop_dtype(
+         self.param(
+             'emb_var',
+             default_kernel_init,
+             [self.max_seq_length, self.embedding_dim],
+             self.dtype,
+         )
+     )
+     pos_emb_var = jax.lax.slice_in_dim(pos_emb_var, 0, seq_length, axis=0)
+     if self.lookup_style == 'matmul':
+       one_hot_ids = jax.nn.one_hot(position, seq_length, dtype=self.fprop_dtype)
+       embs = jnp.einsum('...y,yz->...z', one_hot_ids, pos_emb_var)
+     else:
+       raise ValueError(f'Unknown lookup style: `{self.lookup_style}`.')
+     return embs
+
+
+ class VisionTransformer(layers.Module):
+   """Vision transformer model.
+
+   This class follows a minimalistic design pattern. Users need to configure
+   the templates for the submodules themselves; this increases the
+   generalizability of this class.
+
+   Attributes:
+     num_tfm_layers: Number of layers in this model.
+     mlp_dim: The hidden layer dimension of FFN in Transformer layers.
+     num_heads: Number of attention heads.
+     xformer_has_bias: Whether to use bias.
+     xformer_dropout_prob: Apply dropout at this prob at various places.
+     xformer_atten_dropout_prob: Probability at which we apply dropout to the
+       attention weights.
+     xformer_residual_dropout_prob: Probability at which we apply dropout to
+       the residual layers, such that residual(x, y) = (x + dropout(y)).
+     xformer_relu_dropout_prob: Probability at which we apply dropout to the
+       FFN layers.
+     atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a
+       positive value is specified. May not be supported by a subclass.
+     norm_policy: Policy for applying normalization wrt. transformations.
+       Options are: (1) "pre", applied before transformation. (2)
+       "primer_hybrid", applied before and after transformation. (3) "post",
+       applied after transformation. (4) "post_skip", applied after the skip
+       connection.
+     scan: Whether to use `nn.remat` and `nn.scan`.
+   """
+
+   num_tfm_layers: int = 12
+   mlp_dim: int = 3072
+   num_heads: int = 12
+   xformer_has_bias: bool = True
+   xformer_dropout_prob: float = 0.0
+   xformer_atten_dropout_prob: float | None = None
+   xformer_residual_dropout_prob: float | None = None
+   xformer_relu_dropout_prob: float | None = None
+   atten_logit_cap: float = 0.0
+   norm_policy: str = 'pre'
+   scan: bool = False
+
+   @nn.compact
+   def __call__(
+       self, inputs: Array, paddings: Array | None = None, train: bool = False
+   ) -> Array:
+     """Applies the ViT model to the inputs.
+
+     Args:
+       inputs: Input tensor of shape [B, N, D], which are sequences of
+         embeddings or patches.
+       paddings: Optional [B, N] padding field of inputs when inputs are of
+         shape [B, N, D].
+       train: If the model is in the train mode.
+
+     Returns:
+       Output tensor of shape [B, N, D].
+     """
+     features = inputs
+     if paddings is None:
+       paddings = jnp.zeros(features.shape[:-1], dtype=features.dtype)
+     features = layers.StackedTransformer(
+         name='transformers_stack',
+         num_layers=self.num_tfm_layers,
+         hidden_dim=self.mlp_dim,
+         num_heads=self.num_heads,
+         dropout_prob=self.xformer_dropout_prob,
+         atten_dropout_prob=self.xformer_atten_dropout_prob,
+         residual_dropout_prob=self.xformer_residual_dropout_prob,
+         relu_dropout_prob=self.xformer_relu_dropout_prob,
+         use_bias=self.xformer_has_bias,
+         atten_logit_cap=self.atten_logit_cap,
+         norm_policy=self.norm_policy,
+         internal_enable_per_dim_scale=False,
+         activation_fn=layers.gelu,
+         enable_causal_atten=False,
+         scan=self.scan,
+         dtype=self.dtype,
+         fprop_dtype=self.fprop_dtype,
+     )(features, paddings, train=train)
+     return features
+
+
+ class FactorizedEncoder(layers.Module):
+   """Factorized encoder from the paper `ViViT: A Video Vision Transformer`.
+
+   This is an implementation of model-2 in the paper. It applies a ViT model
+   to video data based on the factorized space-time encoder.
+
+   Reference: https://arxiv.org/abs/2103.15691
+   """
+
+   patch_size: int = 18
+   pos_emb_shape: tuple[int, int, int] = (16, 16, 16)
+   model_dim: int = 768
+   num_spatial_layers: int = 12
+   num_temporal_layers: int = 4
+   num_heads: int = 12
+   mlp_dim: int = 3072
+   atten_logit_cap: float = 0.0
+   norm_policy: str = 'pre'
+   scan: bool = False
+
+   def __call__(
+       self,
+       inputs: Array,
+       train: bool = False,
+       return_intermediate: bool | Collection[str] = False,
+       frame_paddings: Array | None = None,
+   ) -> tuple[Array, dict[str, Array]]:
+     """Computes predictions for batched inputs.
+
+     Args:
+       inputs: Input image tensor of shape [B, T, H, W, 3] (H == W).
+       train: If the model is in the train mode.
+       return_intermediate: A boolean for whether all intermediate features are
+         returned, or a container of intermediate feature names to return.
+       frame_paddings: Optional binary tensor of shape [B, T] indicating
+         padding. 1 denotes a padding frame.
+
+     Returns:
+       embeddings: Output tensor for video embeddings of shape [B, T * N, D].
+       outputs: A dictionary of additional outputs, including `spatial_features`
+         (shape = [B, T * N, D]). Empty if `return_intermediate` is False or
+         does not contain 'spatial_features'.
+     """
+     b, t, h, w, c = inputs.shape
+     assert h == w
+     reshaped_inputs = inputs.reshape(b * t, h, w, c)  # (B * T, H, W, C).
+
+     # Tokenization.
+     patches = _image_to_patch(reshaped_inputs, self.patch_size)
+     patches_paddings = None
+     if frame_paddings is not None:
+       assert frame_paddings.shape == (b, t)
+       reshaped_frame_paddings = frame_paddings.reshape(b * t)  # (B * T,).
+       num_patches = patches.shape[1]
+       patches_paddings = jnp.repeat(
+           reshaped_frame_paddings[:, jnp.newaxis], num_patches, axis=-1
+       )  # (B * T, num_patches).
+
+     embeddings, outputs = self.encode_with_patches(
+         patches=patches,
+         image_shape=(t, h, w),
+         train=train,
+         return_intermediate=return_intermediate,
+         patches_paddings=patches_paddings,
+     )
+     return embeddings, outputs
+
+   @nn.compact
+   def encode_with_patches(
+       self,
+       patches: Array,
+       image_shape: tuple[int, int, int],
+       train: bool = False,
+       return_intermediate: bool | Collection[str] = False,
+       patches_paddings: Array | None = None,
+   ) -> tuple[Array, dict[str, Array]]:
+     """Computes predictions for patches.
+
+     Args:
+       patches: Input patches tensor of shape [B * T, (H * W / P^2), P^2 * C].
+       image_shape: Original image shape (T, H, W).
+       train: If the model is in the train mode.
+       return_intermediate: A boolean for whether all intermediate features are
+         returned, or a collection of intermediate feature names to return.
+       patches_paddings: Optional binary tensor of shape [B * T, (H * W / P^2)]
+         indicating padding. 1 denotes a padded patch.
+
+     Returns:
+       embeddings: Output tensor for video embedding sequence of shape
+         [B, T * N, D].
+       outputs: A dictionary of additional outputs, including `spatial_features`
+         of shape [B, T * N, D]. Empty if `return_intermediate` is False or
+         does not contain 'spatial_features'.
+     """
+     t, h, w = image_shape
+     b = patches.shape[0] // t
+
+     patches = layers.FeedForward(  # (B * T, N, D).
+         name='patch_projection',
+         output_dim=self.model_dim,
+         activation_fn=layers.identity,
+         dtype=self.dtype,
+         fprop_dtype=self.fprop_dtype,
+     )(patches)
+
+     # Add spatial positional encoding.
+     spatial_pos_emb_shape = self.pos_emb_shape[-2:]
+     spatial_seq_length = np.prod(spatial_pos_emb_shape)
+     spatial_pos_emb = TrainablePositionalEmbedding(
+         name='spatial_pos_emb',
+         embedding_dim=self.model_dim,
+         max_seq_length=spatial_seq_length,
+         dtype=self.dtype,
+         fprop_dtype=self.fprop_dtype,
+     )(seq_length=spatial_seq_length)
+     num_row_patches = h // self.patch_size
+     num_col_patches = w // self.patch_size
+     if spatial_pos_emb_shape != (num_row_patches, num_col_patches):
+       spatial_pos_emb = _interpolate_emb_2d(
+           spatial_pos_emb,
+           spatial_pos_emb_shape,
+           (num_row_patches, num_col_patches),
+       )
+     patches += spatial_pos_emb  # (B * T, N, D).
+
+     # Get features from the spatial encoder.
+     features = VisionTransformer(  # (B * T, N, D).
+         name='spatial_encoder',
+         num_tfm_layers=self.num_spatial_layers,
+         mlp_dim=self.mlp_dim,
+         num_heads=self.num_heads,
+         atten_logit_cap=self.atten_logit_cap,
+         norm_policy=self.norm_policy,
+         scan=self.scan,
+         dtype=self.dtype,
+         fprop_dtype=self.fprop_dtype,
+     )(patches, train=train, paddings=patches_paddings)
+     features = layers.LayerNorm(
+         name='spatial_ln', dtype=self.dtype, fprop_dtype=self.fprop_dtype
+     )(features)
+     spatial_features = features
+
+     # Instead of mean pooling, we keep the spatial tokens.
+     # Shape = (B * N, T, D).
+     features = einshape.jax_einshape('(bt)nd->(bn)td', features, t=t)
+     temporal_paddings = None
+     if patches_paddings is not None:
+       temporal_paddings = einshape.jax_einshape(
+           '(bt)n->(bn)t', patches_paddings, t=t
+       )  # (B * N, T).
+
+     # Add temporal positional encoding.
+     temporal_seq_length = self.pos_emb_shape[0]
+     temporal_pos_emb = TrainablePositionalEmbedding(
+         name='temporal_pos_emb',
+         embedding_dim=self.model_dim,
+         max_seq_length=temporal_seq_length,
+         dtype=self.dtype,
+         fprop_dtype=self.fprop_dtype,
+     )(seq_length=temporal_seq_length)
+     if temporal_seq_length != t:
+       temporal_pos_emb = _interpolate_emb_1d(temporal_pos_emb, t)
+     features += temporal_pos_emb
+
+     # Get features from the temporal encoder.
+     features = VisionTransformer(
+         name='temporal_encoder',
+         num_tfm_layers=self.num_temporal_layers,
+         mlp_dim=self.mlp_dim,
+         num_heads=self.num_heads,
+         atten_logit_cap=self.atten_logit_cap,
+         norm_policy=self.norm_policy,
+         scan=self.scan,
+         dtype=self.dtype,
+         fprop_dtype=self.fprop_dtype,
+     )(features, train=train, paddings=temporal_paddings)
+     features = layers.LayerNorm(
+         name='temporal_ln', dtype=self.dtype, fprop_dtype=self.fprop_dtype
+     )(features)
+     features = einshape.jax_einshape(  # (B, T * N, D).
+         '(bn)td->b(tn)d', features, b=b
+     )
+
+     embeddings, outputs = features, {}
+     if _contains(return_intermediate, 'spatial_features'):
+       outputs['spatial_features'] = einshape.jax_einshape(
+           '(bt)nd->b(tn)d', spatial_features, t=t
+       )
+
+     return embeddings, outputs
+
+
+ class FactorizedVideoClassifier(layers.Module):
+   """Video classifier with `FactorizedEncoder` backbone.
+
+   Attributes:
+     encoder_params: A dictionary of parameters for `FactorizedEncoder`.
+     num_classes: Number of output classes.
+   """
+
+   encoder_params: dict[str, Any] = dataclasses.field(default_factory=dict)
+   num_classes: int = 0
+
+   @nn.compact
+   def __call__(
+       self,
+       inputs: Array,
+       train: bool = False,
+       return_intermediate: bool | Collection[str] = False,
+       frame_paddings: Array | None = None,
+   ):
+     """Applies the video classifier to inputs.
+
+     Args:
+       inputs: Input tensor of shape [B, T, H, W, 3].
+       train: Whether the model is in the training mode.
+       return_intermediate: A boolean for whether all intermediate features are
+         returned, or a collection of intermediate feature names to return.
+       frame_paddings: Optional binary tensor of shape [B, T] indicating
+         padding. 1 denotes a padding frame.
+
+     Returns:
+       logits: Output tensor of shape [B, num_classes].
+       outputs: A dictionary of additional outputs, including `spatial_features`
+         of shape [B, T * N, D], `spatiotemporal_features` of shape [B, T * N,
+         D], and `global_embeddings` of shape [B, D]. Empty if
+         `return_intermediate` is False.
+     """
+     features, outputs = FactorizedEncoder(
+         name='encoder',
+         dtype=self.dtype,
+         fprop_dtype=self.fprop_dtype,
+         **self.encoder_params,
+     )(
+         inputs,
+         train=train,
+         return_intermediate=return_intermediate,
+         frame_paddings=frame_paddings,
+     )
+     if _contains(return_intermediate, 'spatiotemporal_features'):
+       outputs['spatiotemporal_features'] = features
+
+     embeddings = layers.AttenTokenPoolingLayer(
+         name='atten_pooler',
+         num_heads=self.encoder_params['num_heads'],
+         hidden_dim=self.encoder_params['model_dim'],
+         num_queries=1,
+         dtype=self.dtype,
+         fprop_dtype=self.fprop_dtype,
+     )(features, paddings=None, train=train)
+     embeddings = jnp.squeeze(embeddings, axis=-2)
+
+     if _contains(return_intermediate, 'global_embeddings'):
+       outputs['global_embeddings'] = embeddings
+
+     logits = layers.FeedForward(
+         name='projection',
+         output_dim=self.num_classes,
+         activation_fn=layers.identity,
+         dtype=self.dtype,
+         fprop_dtype=self.fprop_dtype,
+     )(embeddings)
+     return logits, outputs
+
+
+ class TextEncoder(layers.Module):
+   """CoCa-style text encoder.
+
+   Reference: https://arxiv.org/abs/2205.01917
+
+   Attributes:
+     vocabulary_size: Vocabulary size of the text tokens.
+     num_class_tokens: Number of class tokens.
+     enable_causal_atten: Whether to enable causal attention.
+     model_dim: The model dimension.
+     num_layers: Number of layers in this model.
+     mlp_dim: The hidden layer dimension of FFN in Transformer layers.
+     num_heads: Number of attention heads.
+     enable_per_dim_scale: Whether to enable rescaling of attention logits with
+       1/sqrt(dim) factor.
+     atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a
+       positive value is specified. May not be supported by a subclass.
+     norm_policy: Policy for applying normalization wrt. transformations.
+       Options are: (1) "pre", applied before transformation. (2)
+       "primer_hybrid", applied before and after transformation. (3) "post",
+       applied after transformation. (4) "post_skip", applied after the skip
+       connection.
+     scan: Whether to use `nn.remat` and `nn.scan`.
+   """
+
+   vocabulary_size: int = 128
+   num_class_tokens: int = 0
+   enable_causal_atten: bool = True
+   model_dim: int = 768
+   num_layers: int = 12
+   mlp_dim: int = 3072
+   num_heads: int = 12
+   atten_logit_cap: float = 0.0
+   norm_policy: str = 'pre'
+   enable_per_dim_scale: bool = False
+   scan: bool = False
+
+   @nn.compact
+   def __call__(
+       self, inputs: Array, paddings: Array, train: bool = False
+   ) -> Array:
+     """Applies the text encoder to the inputs.
+
+     Args:
+       inputs: Input tensor of shape [B, N] including sequences of token ids.
+       paddings: A [B, N] padding field of inputs.
+       train: If the model is in the train mode.
+
+     Returns:
+       Output tensor of shape [B, N, D].
+     """
+     batch_size, seq_length = inputs.shape
+
+     pos_emb = PositionalEmbedding(
+         name='pos_emb',
+         embedding_dim=self.model_dim,
+         dtype=self.dtype,
+         fprop_dtype=self.fprop_dtype,
+     )(seq_length=seq_length)
+     input_emb = Embedding(
+         name='token_emb',
+         num_classes=self.vocabulary_size,
+         input_dim=self.model_dim,
+         scale_sqrt_depth=True,
+         dtype=self.dtype,
+         fprop_dtype=self.fprop_dtype,
+     )(inputs)
+     features = input_emb + pos_emb
+
+     if self.num_class_tokens > 0:
+       cls_emb = self._cast_to_fprop_dtype(
+           self.param(
+               'cls_emb',
+               nn.initializers.normal(stddev=1.0 / math.sqrt(self.model_dim)),
+               [1, self.num_class_tokens, self.model_dim],
+               self.dtype,
+           )
+       )
+       cls_emb = jnp.tile(cls_emb, [batch_size, 1, 1])
+       cls_emb *= self.model_dim**0.5
+       features = jnp.concatenate([features, cls_emb], axis=-2)
+
+       cls_paddings = jnp.zeros(
+           [batch_size, self.num_class_tokens], dtype=paddings.dtype
+       )
+       paddings = jnp.concatenate([paddings, cls_paddings], axis=-1)
+
+     features = layers.StackedTransformer(
+         name='unimodal_transformer',
+         num_layers=self.num_layers,
+         hidden_dim=self.mlp_dim,
+         num_heads=self.num_heads,
+         atten_logit_cap=self.atten_logit_cap,
+         norm_policy=self.norm_policy,
+         internal_enable_per_dim_scale=self.enable_per_dim_scale,
+         activation_fn=jax.nn.relu,
+         enable_causal_atten=self.enable_causal_atten,
+         scan=self.scan,
+         dtype=self.dtype,
+         fprop_dtype=self.fprop_dtype,
+     )(features, paddings, train=train)
+     features = layers.LayerNorm(
+         name='unimodal_ln', dtype=self.dtype, fprop_dtype=self.fprop_dtype
+     )(features)
+     return features
+
+
+ class FactorizedVideoCLIP(layers.Module):
+   """Video CLIP model with a factorized vision encoder."""
+
+   # Vision parameters.
+   patch_size: int = 18
+   pos_emb_shape: tuple[int, int, int] = (16, 16, 16)
+   num_spatial_layers: int = 12
+   num_temporal_layers: int = 4
+   mlp_dim: int = 3072
+   num_auxiliary_layers: int = 0
+   # Text parameters.
+   vocabulary_size: int = 128
+   enable_causal_atten: bool = True
+   num_unimodal_layers: int = 12
+   norm_policy: str = 'pre'
+   # Shared parameters.
+   model_dim: int = 768
+   num_heads: int = 12
+   atten_logit_cap: float = 0.0
+   scan: bool = False
+
+   @nn.compact
+   def __call__(
+       self,
+       inputs: Array | None = None,
+       text_token_ids: Array | None = None,
+       text_paddings: Array | None = None,
+       train: bool = False,
+       normalize: bool = True,
+       return_intermediate: bool | Collection[str] = False,
+       frame_paddings: Array | None = None,
+   ) -> tuple[Array | None, Array | None, dict[str, Array]]:
+     """Computes predictions for the input batch.
+
+     Args:
+       inputs: Input frame image tensor of shape [B, T, H, W, 3] (H == W).
+       text_token_ids: Input text token id tensor of shape [B, L].
+       text_paddings: Input text paddings of shape [B, L]. Required if
+         `text_token_ids` is not None.
+       train: If the model is in the train mode.
+       normalize: Whether to normalize the output embeddings.
+       return_intermediate: A boolean for whether all intermediate features are
+         returned, or a collection of intermediate feature names to return.
+       frame_paddings: Optional binary tensor of shape [B, T] indicating
+         padding. 1 denotes a padding frame.
+
+     Returns:
+       video_embeddings: Output contrastive video embeddings of shape [B, D].
+         None if `inputs` is None.
+       text_embeddings: Output contrastive text embeddings of shape [B, D].
+         None if `text_token_ids` is None.
+       outputs: A dictionary of additional outputs, including `spatial_features`
+         of shape [B, T * N, D], `spatiotemporal_features` of shape [B, T * N,
+         D], and `frame_embeddings` of shape [B, T, D]. Empty if
+         `return_intermediate` is False or does not contain `spatial_features`.
+     """
+     video_embeddings, text_embeddings, outputs = None, None, {}
+
+     if inputs is not None:
+       num_frames = inputs.shape[-4]
+       vision_features, vision_outputs = FactorizedEncoder(
+           name='vision_encoder',
+           patch_size=self.patch_size,
+           pos_emb_shape=self.pos_emb_shape,
+           model_dim=self.model_dim,
+           num_spatial_layers=self.num_spatial_layers,
+           num_temporal_layers=self.num_temporal_layers,
+           num_heads=self.num_heads,
+           mlp_dim=self.mlp_dim,
+           atten_logit_cap=self.atten_logit_cap,
+           norm_policy='pre',
+           scan=self.scan,
+           dtype=self.dtype,
+           fprop_dtype=self.fprop_dtype,
+       )(
+           inputs,
+           train=train,
+           return_intermediate=return_intermediate,
+           frame_paddings=frame_paddings,
+       )
+       outputs.update(vision_outputs)
+       if _contains(return_intermediate, 'spatiotemporal_features'):
+         outputs['spatiotemporal_features'] = vision_features
+
+       if self.num_auxiliary_layers > 0:
+         vision_features = VisionTransformer(
+             name='auxiliary_encoder',
+             num_tfm_layers=self.num_auxiliary_layers,
+             mlp_dim=self.mlp_dim,
+             num_heads=self.num_heads,
+             atten_logit_cap=self.atten_logit_cap,
+             norm_policy='pre',
+             scan=self.scan,
+             dtype=self.dtype,
+             fprop_dtype=self.fprop_dtype,
+         )(vision_features, train=train)
+
+       pooling_layer = layers.AttenTokenPoolingLayer(
+           name='contrastive_vision_pooler',
+           hidden_dim=self.model_dim * 4,
+           num_heads=self.num_heads,
+           num_queries=1,
+           dtype=self.dtype,
+           fprop_dtype=self.fprop_dtype,
+       )
+       video_embeddings = pooling_layer(vision_features, None, train=train)
+
+       # Squeeze the query dimension in the pooler output.
+       video_embeddings = jnp.squeeze(video_embeddings, axis=-2)
+       if normalize:
+         video_embeddings = _l2_normalize(video_embeddings, axis=-1)
+
+       if _contains(return_intermediate, 'frame_embeddings'):
+         frame_features = einshape.jax_einshape(
+             'b(tn)d->(bt)nd', vision_features, t=num_frames
+         )
+         frame_embeddings = pooling_layer(frame_features, None, train=train)
+         frame_embeddings = jnp.squeeze(frame_embeddings, axis=-2)
+         frame_embeddings = einshape.jax_einshape(
+             '(bt)d->btd', frame_embeddings, t=num_frames
+         )
+         if normalize:
+           frame_embeddings = _l2_normalize(frame_embeddings, axis=-1)
+         outputs['frame_embeddings'] = frame_embeddings
+
+     if text_token_ids is not None:
+       assert text_paddings is not None, 'Text paddings are required.'
+       text_features = TextEncoder(
+           name='text_encoder',
+           vocabulary_size=self.vocabulary_size,
+           num_class_tokens=1,
+           enable_causal_atten=self.enable_causal_atten,
+           model_dim=self.model_dim,
+           num_layers=self.num_unimodal_layers,
+           num_heads=self.num_heads,
+           mlp_dim=self.model_dim * 4,
+           atten_logit_cap=self.atten_logit_cap,
+           norm_policy=self.norm_policy,
+           scan=self.scan,
+           dtype=self.dtype,
+           fprop_dtype=self.fprop_dtype,
+       )(text_token_ids, text_paddings, train=train)
+
+       # Take the last token (i.e., class token) as the text embedding.
+       text_embeddings = text_features[:, -1]
+       if normalize:
+         text_embeddings = _l2_normalize(text_embeddings, axis=-1)
+
+     return video_embeddings, text_embeddings, outputs
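
For orientation, the sketch below shows one way the `FactorizedVideoCLIP` entry point above can be driven. It is illustrative only and not part of the released wheel: it assumes the class is constructible with its declared defaults, that `layers.Module` supplies workable default `dtype`/`fprop_dtype` settings, and that random initialization is acceptable for a smoke test. The input shapes are inferred from the defaults: `patch_size=18` with `pos_emb_shape=(16, 16, 16)` corresponds to 16 frames of 288x288 RGB.

import jax
from jax import numpy as jnp
from videoprism import encoders

model = encoders.FactorizedVideoCLIP()
videos = jnp.zeros((1, 16, 288, 288, 3), dtype=jnp.float32)   # [B, T, H, W, 3]
token_ids = jnp.zeros((1, 64), dtype=jnp.int32)               # [B, L]
token_paddings = jnp.zeros((1, 64), dtype=jnp.float32)        # [B, L], 1 = pad

# Randomly initialized parameters; real use would restore pretrained
# weights from a checkpoint instead.
variables = model.init(
    jax.random.PRNGKey(0),
    inputs=videos,
    text_token_ids=token_ids,
    text_paddings=token_paddings,
)

video_emb, text_emb, _ = model.apply(
    variables,
    inputs=videos,
    text_token_ids=token_ids,
    text_paddings=token_paddings,
)

# With normalize=True (the default), both embeddings are L2-normalized
# [B, D] arrays, so video-text similarity reduces to a dot product.
similarity = jnp.einsum('bd,bd->b', video_emb, text_emb)

Because `inputs` and `text_token_ids` are each optional, the same `apply` call can embed only videos or only text, which is how a retrieval index would typically be built one side at a time.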