singlebehaviorlab-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
videoprism/encoders.py
ADDED
@@ -0,0 +1,910 @@
+# Copyright 2026 VideoPrism Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Modules for video encoders."""
+
+from collections.abc import Collection, Sequence
+import dataclasses
+import math
+from typing import Any
+
+import einops
+import einshape
+from flax import linen as nn
+import jax
+from jax import numpy as jnp
+import numpy as np
+from videoprism import layers
+
+Array = jax.Array
+Variables = nn.module.VariableDict
+
+default_kernel_init = layers.default_kernel_init
+
+
+def _contains(collection: Collection[str] | bool, key: str) -> bool:
+  """Checks if a collection contains a key.
+
+  Args:
+    collection: A collection of strings or a boolean value.
+    key: A string key to check.
+
+  Returns:
+    True if the collection contains the key, or if the collection is a True
+    boolean. False otherwise.
+  """
+  return collection if isinstance(collection, bool) else key in collection
+
+
+def _l2_normalize(
+    x: Array, axis: int | Sequence[int] = -1, epsilon: float = 1e-12
+) -> Array:
+  """L2-normalizes a jax.Array along the given dimension(s).
+
+  Args:
+    x: An input jax.Array.
+    axis: An integer or a sequence of integers for the axis to normalize.
+    epsilon: A small constant for numerical stability.
+
+  Returns:
+    Normalized jax.Array.
+  """
+  x_dtype = x.dtype
+  # Always convert embed to float32 for all precisions.
+  x = x.astype(jnp.float32)
+  norm = jnp.sqrt(jnp.sum(x * x, axis=axis, keepdims=True) + epsilon)
+  return (x / norm).astype(x_dtype)
+
+
+def _image_to_patch(inputs: Array, patch_size: int) -> Array:
+  """Converts an image to patches.
+
+  Args:
+    inputs: A jax.Array of shape [B, H, W, C].
+    patch_size: An integer for the dimension of a square patch.
+
+  Returns:
+    batched_patches: [B, (H * W / P^2), P^2 * C].
+  """
+  if len(inputs.shape) < 4:
+    raise ValueError(
+        f'Image should be formatted as 4D [B, H, W, C]. Shape: {inputs.shape}'
+    )
+  height, width, channels = inputs.shape[-3:]
+
+  if height % patch_size != 0 or width % patch_size != 0:
+    raise ValueError(
+        f'Image height ({height}) and width ({width}) should be multiples '
+        f'of patch_size ({patch_size}).'
+    )
+
+  row_blocks = height // patch_size
+  column_blocks = width // patch_size
+
+  patches = einops.rearrange(
+      inputs,
+      '... (m p) (n q) c -> ... (m n) (p q c)',
+      m=row_blocks,
+      n=column_blocks,
+      p=patch_size,
+      q=patch_size,
+      c=channels,
+  )
+  return patches
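
As a quick sanity check on the patchification above: the rearrange pattern folds each non-overlapping P x P window into a single token. A minimal standalone sketch (sizes are hypothetical, chosen to match the default 18-pixel patches configured later in this file):

import einops
import numpy as np

# 2 images of 288 x 288 RGB; 288 = 16 * 18, so a 16 x 16 grid of patches.
images = np.zeros((2, 288, 288, 3), dtype=np.float32)
patches = einops.rearrange(
    images, '... (m p) (n q) c -> ... (m n) (p q c)', p=18, q=18
)
assert patches.shape == (2, 256, 972)  # (B, 16 * 16, 18 * 18 * 3).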
+
+
+def _interpolate_emb_1d(emb: Array, target_emb_length: int) -> Array:
+  """Interpolates a 1D positional embedding to a new shape.
+
+  Args:
+    emb: jax.Array, (1, N, D), flattened 1D positional embedding.
+    target_emb_length: Length of the target embedding.
+
+  Returns:
+    Flattened, interpolated embedding of shape (1, target_emb_length, D).
+  """
+
+  if len(emb.shape) > 3 or emb.shape[0] != 1:
+    raise ValueError('The shape of the embedding should be (1, N, D)')
+
+  emb_dim = emb.shape[-1]
+  emb = jnp.squeeze(emb, axis=0)
+
+  target_emb = jax.image.resize(
+      emb, (target_emb_length, emb_dim), method='bilinear'
+  )
+  target_emb = jnp.reshape(target_emb, (1, target_emb_length, emb_dim))
+  return target_emb
+
+
+def _interpolate_emb_2d(
+    emb: Array,
+    source_emb_shape: tuple[int, int],
+    target_emb_shape: tuple[int, int],
+) -> Array:
+  """Interpolates a 2D positional embedding to a new shape.
+
+  Args:
+    emb: A jax.Array of shape (1, H1xW1, D) for flattened 2D positional
+      embedding.
+    source_emb_shape: Tuple, (H1, W1), height and width of the source
+      embedding.
+    target_emb_shape: Tuple, (H2, W2), height and width of the target
+      embedding.
+
+  Returns:
+    Flattened, interpolated embedding of shape (1, H2xW2, D).
+  """
+
+  if len(emb.shape) > 3 or emb.shape[0] != 1:
+    raise ValueError('The shape of the embedding should be (1, H * W, D)')
+
+  if emb.shape[-2] != source_emb_shape[0] * source_emb_shape[1]:
+    raise ValueError('The shape of the embedding does NOT match input specs.')
+
+  emb_dim = emb.shape[-1]
+  emb = jnp.reshape(emb, (source_emb_shape[0], source_emb_shape[1], emb_dim))
+
+  target_emb = jax.image.resize(
+      emb,
+      (target_emb_shape[0], target_emb_shape[1], emb_dim),
+      method='bilinear',
+  )
+  target_emb = jnp.reshape(
+      target_emb, (1, target_emb_shape[0] * target_emb_shape[1], emb_dim)
+  )
+  return target_emb
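
The two helpers above let the encoders run on sequence lengths and patch grids other than the ones the positional tables were trained with. A small standalone sketch of the 2D case (grid sizes are hypothetical), mirroring what `_interpolate_emb_2d` does:

import jax
from jax import numpy as jnp

emb = jnp.ones((1, 16 * 16, 768))  # Flattened 16 x 16 grid, D = 768.
grid = jnp.reshape(emb, (16, 16, 768))
resized = jax.image.resize(grid, (20, 20, 768), method='bilinear')
resized = jnp.reshape(resized, (1, 20 * 20, 768))
assert resized.shape == (1, 400, 768)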
+
+
+class Embedding(layers.Module):
+  """A simple embedding layer that performs embedding lookups from ids.
+
+  Attributes:
+    num_classes: Number of tokens in the vocabulary.
+    input_dim: Depth of the embedding output. This is called `input_dim` as
+      opposed to the more appropriate `embedding_dim` to be compatible with
+      other embedding layers defined in this file.
+    lookup_style: Style of lookup, one of index or matmul.
+    scale_sqrt_depth: If set to True, activations are scaled with
+      sqrt(embedding_dim) in embedding lookup.
+    set_nan_for_oob_id: If set to True, embeddings corresponding to
+      out-of-bounds ids will be set to NaN.
+  """
+
+  num_classes: int = 0
+  input_dim: int = 0
+  lookup_style: str = 'index'
+  scale_sqrt_depth: bool = False
+  set_nan_for_oob_id: bool = False
+
+  @nn.compact
+  def __call__(self, ids: Array) -> Array:
+    """Generates a jax.Array of embedding lookup result.
+
+    Args:
+      ids: Indexes of shape [...] for embedding lookup.
+
+    Returns:
+      A jax.Array of shape [..., input_dim].
+    """
+    emb_var = self._cast_to_fprop_dtype(
+        self.param(
+            'emb_var',
+            nn.initializers.normal(stddev=1.0 / math.sqrt(self.input_dim)),
+            [self.num_classes, self.input_dim],
+            self.dtype,
+        )
+    )
+    if self.lookup_style == 'index':
+      embs = jnp.asarray(emb_var)[(ids,)]
+    elif self.lookup_style == 'matmul':
+      one_hot_ids = jax.nn.one_hot(
+          ids, self.num_classes, dtype=self.fprop_dtype
+      )
+      embs = jnp.einsum('...y,yz->...z', one_hot_ids, emb_var)
+    else:
+      raise ValueError(f'Unknown lookup style: `{self.lookup_style}`.')
+
+    # Map out-of-bounds ids to NaN.
+    if self.set_nan_for_oob_id:
+      embs = jnp.where(ids[..., jnp.newaxis] < self.num_classes, embs, jnp.nan)
+
+    if self.scale_sqrt_depth:
+      embs *= self.input_dim**0.5
+
+    return embs
+
+
+class PositionalEmbedding(layers.Module):
+  """Generates position embedding for a given 1-d sequence.
+
+  Attributes:
+    embedding_dim: Dimension of the embedding to be generated.
+    min_timescale: Start of the geometric index.
+    max_timescale: End of the geometric index.
+  """
+
+  embedding_dim: int = 0
+  min_timescale: int = 1
+  max_timescale: int = 10_000
+
+  def __call__(self, seq_length: int) -> Array:
+    """Generates a jax.Array of embedding lookup result.
+
+    Args:
+      seq_length: Sequence length of the embeddings to be generated.
+
+    Returns:
+      A jax.Array of shape [1, seq_length, embedding_dim].
+    """
+    position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :]
+    num_timescales = self.embedding_dim // 2
+    log_timescale_increment = math.log(
+        float(self.max_timescale) / float(self.min_timescale)
+    ) / jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1)
+    inv_timescales = self.min_timescale * jnp.exp(
+        jnp.arange(num_timescales, dtype=jnp.float32)
+        * -log_timescale_increment
+    )
+    scaled_time = (
+        position[:, :, jnp.newaxis]
+        * inv_timescales[jnp.newaxis, jnp.newaxis, :]
+    )
+    embs = jnp.concatenate(
+        [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1
+    ).astype(self.fprop_dtype)
+    # Force usage of `np` to compute static values at trace time.
+    embs = jnp.pad(embs, [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]])
+    return embs
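
`PositionalEmbedding` above is the standard sinusoidal scheme: with H = embedding_dim // 2 timescales, the first H channels hold sin(position * inv_timescale) and the next H hold the matching cos. A tiny numpy recomputation of the same construction (dimensions here are made up):

import math
import numpy as np

dim, seq = 8, 4
position = np.arange(seq, dtype=np.float32)[:, None]       # [seq, 1].
increment = math.log(10_000.0) / (dim // 2 - 1)            # min_timescale=1.
inv_timescales = np.exp(np.arange(dim // 2) * -increment)  # [dim // 2].
scaled = position * inv_timescales                         # [seq, dim // 2].
emb = np.concatenate([np.sin(scaled), np.cos(scaled)], axis=-1)
assert emb.shape == (seq, dim)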
+
+
+class TrainablePositionalEmbedding(layers.Module):
+  """Generates trainable position embedding for a given 1-d sequence.
+
+  Attributes:
+    embedding_dim: Dimension of the embedding to be generated.
+    max_seq_length: Max sequence length.
+    lookup_style: Style of lookup, one of index or matmul.
+  """
+
+  embedding_dim: int = 0
+  max_seq_length: int = 10_240
+  lookup_style: str = 'matmul'
+
+  @nn.compact
+  def __call__(self, seq_length: int) -> Array:
+    """Generates a jax.Array of embedding lookup result.
+
+    Args:
+      seq_length: Sequence length of the embeddings to be generated.
+
+    Returns:
+      A jax.Array of shape [1, seq_length, embedding_dim].
+    """
+    position = jnp.arange(seq_length, dtype=jnp.int32)[jnp.newaxis, :]
+    pos_emb_var = self._cast_to_fprop_dtype(
+        self.param(
+            'emb_var',
+            default_kernel_init,
+            [self.max_seq_length, self.embedding_dim],
+            self.dtype,
+        )
+    )
+    pos_emb_var = jax.lax.slice_in_dim(pos_emb_var, 0, seq_length, axis=0)
+    if self.lookup_style == 'matmul':
+      one_hot_ids = jax.nn.one_hot(
+          position, seq_length, dtype=self.fprop_dtype
+      )
+      embs = jnp.einsum('...y,yz->...z', one_hot_ids, pos_emb_var)
+    else:
+      raise ValueError(f'Unknown lookup style: `{self.lookup_style}`.')
+    return embs
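
The 'matmul' lookup above builds position embeddings as one_hot(positions) @ table, which is numerically the same as slicing the first `seq_length` rows but often maps better onto accelerator matrix units. A standalone equivalence check with a made-up table:

import jax
from jax import numpy as jnp

table = jax.random.normal(jax.random.PRNGKey(0), (10, 4))  # [max_seq, D].
position = jnp.arange(6)[jnp.newaxis, :]                   # [1, 6].
one_hot = jax.nn.one_hot(position, 6)                      # [1, 6, 6].
via_matmul = jnp.einsum('...y,yz->...z', one_hot, table[:6])
assert jnp.allclose(via_matmul, table[jnp.newaxis, :6])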
+
+
+class VisionTransformer(layers.Module):
+  """Vision transformer model.
+
+  This class follows a minimalistic design pattern. Users need to configure
+  the templates for the submodules themselves; this increases the
+  generalizability of this class.
+
+  Attributes:
+    num_tfm_layers: Number of layers in this model.
+    mlp_dim: The hidden layer dimension of FFN in Transformer layers.
+    num_heads: Number of attention heads.
+    xformer_has_bias: Whether to use bias.
+    xformer_dropout_prob: Apply dropout at this prob at various places.
+    xformer_atten_dropout_prob: Probability at which we apply dropout to the
+      attention weights.
+    xformer_residual_dropout_prob: Probability at which we apply dropout to
+      the residual layers, such that, residual(x, y) = (x + dropout(y)).
+    xformer_relu_dropout_prob: Probability at which we apply dropout to the
+      FFN layers.
+    atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when
+      a positive value is specified. May not be supported by a subclass.
+    norm_policy: Policy for applying normalization wrt. transformations.
+      Options are: (1) "pre", applied before transformation. (2)
+      "primer_hybrid", applied before and after transformation. (3) "post",
+      applied after transformation. (4) "post_skip", applied after the skip
+      connection.
+    scan: Whether to use `nn.remat` and `nn.scan`.
+  """
+
+  num_tfm_layers: int = 12
+  mlp_dim: int = 3072
+  num_heads: int = 12
+  xformer_has_bias: bool = True
+  xformer_dropout_prob: float = 0.0
+  xformer_atten_dropout_prob: float | None = None
+  xformer_residual_dropout_prob: float | None = None
+  xformer_relu_dropout_prob: float | None = None
+  atten_logit_cap: float = 0.0
+  norm_policy: str = 'pre'
+  scan: bool = False
+
+  @nn.compact
+  def __call__(
+      self, inputs: Array, paddings: Array | None = None, train: bool = False
+  ) -> Array:
+    """Applies the ViT model to the inputs.
+
+    Args:
+      inputs: Input tensor of shape [B, N, D], which are sequences of
+        embeddings or patches.
+      paddings: Optional [B, N] padding field of inputs when inputs are with
+        [B, N, D].
+      train: If the model is in the train mode.
+
+    Returns:
+      Output tensor of shape [B, N, D].
+    """
+    features = inputs
+    if paddings is None:
+      paddings = jnp.zeros(features.shape[:-1], dtype=features.dtype)
+    features = layers.StackedTransformer(
+        name='transformers_stack',
+        num_layers=self.num_tfm_layers,
+        hidden_dim=self.mlp_dim,
+        num_heads=self.num_heads,
+        dropout_prob=self.xformer_dropout_prob,
+        atten_dropout_prob=self.xformer_atten_dropout_prob,
+        residual_dropout_prob=self.xformer_residual_dropout_prob,
+        relu_dropout_prob=self.xformer_relu_dropout_prob,
+        use_bias=self.xformer_has_bias,
+        atten_logit_cap=self.atten_logit_cap,
+        norm_policy=self.norm_policy,
+        internal_enable_per_dim_scale=False,
+        activation_fn=layers.gelu,
+        enable_causal_atten=False,
+        scan=self.scan,
+        dtype=self.dtype,
+        fprop_dtype=self.fprop_dtype,
+    )(features, paddings, train=train)
+    return features
+
+
+class FactorizedEncoder(layers.Module):
+  """Factorized encoder from the paper `ViViT: A Video Vision Transformer`.
+
+  This is an implementation of Model 2 in the paper. It applies the ViT
+  model to video data using a factorized space-time encoder.
+
+  Reference: https://arxiv.org/abs/2103.15691
+  """
+
+  patch_size: int = 18
+  pos_emb_shape: tuple[int, int, int] = (16, 16, 16)
+  model_dim: int = 768
+  num_spatial_layers: int = 12
+  num_temporal_layers: int = 4
+  num_heads: int = 12
+  mlp_dim: int = 3072
+  atten_logit_cap: float = 0.0
+  norm_policy: str = 'pre'
+  scan: bool = False
+
+  def __call__(
+      self,
+      inputs: Array,
+      train: bool = False,
+      return_intermediate: bool | Collection[str] = False,
+      frame_paddings: Array | None = None,
+  ) -> tuple[Array, dict[str, Array]]:
+    """Computes predictions for batched inputs.
+
+    Args:
+      inputs: Input image tensor of shape [B, T, H, W, 3] (H == W).
+      train: If the model is in the train mode.
+      return_intermediate: A boolean for whether all intermediate features
+        are returned, or a collection of intermediate feature names to
+        return.
+      frame_paddings: Optional binary tensor of shape [B, T] indicating
+        padding. 1 denotes a padding frame.
+
+    Returns:
+      embeddings: Output tensor for video embeddings of shape [B, T * N, D].
+      outputs: A dictionary of additional outputs, including
+        `spatial_features` (shape = [B, T * N, D]). Empty if
+        `return_intermediate` is False or does not contain
+        'spatial_features'.
+    """
+    b, t, h, w, c = inputs.shape
+    assert h == w
+    reshaped_inputs = inputs.reshape(b * t, h, w, c)  # (B * T, H, W, C).
+
+    # Tokenization.
+    patches = _image_to_patch(reshaped_inputs, self.patch_size)
+    patches_paddings = None
+    if frame_paddings is not None:
+      assert frame_paddings.shape == (b, t)
+      reshaped_frame_paddings = frame_paddings.reshape(b * t)  # (B * T,).
+      num_patches = patches.shape[1]
+      patches_paddings = jnp.repeat(
+          reshaped_frame_paddings[:, jnp.newaxis], num_patches, axis=-1
+      )  # (B * T, num_patches).
+
+    embeddings, outputs = self.encode_with_patches(
+        patches=patches,
+        image_shape=(t, h, w),
+        train=train,
+        return_intermediate=return_intermediate,
+        patches_paddings=patches_paddings,
+    )
+    return embeddings, outputs
+
+  @nn.compact
+  def encode_with_patches(
+      self,
+      patches: Array,
+      image_shape: tuple[int, int, int],
+      train: bool = False,
+      return_intermediate: bool | Collection[str] = False,
+      patches_paddings: Array | None = None,
+  ) -> tuple[Array, dict[str, Array]]:
+    """Computes predictions for patches.
+
+    Args:
+      patches: Input patches tensor of shape [B * T, (H * W / P^2), P^2 * C].
+      image_shape: Original image shape (T, H, W).
+      train: If the model is in the train mode.
+      return_intermediate: A boolean for whether all intermediate features
+        are returned, or a collection of intermediate feature names to
+        return.
+      patches_paddings: Optional binary tensor of shape
+        [B * T, (H * W / P^2)] indicating padding. 1 denotes a padded patch.
+
+    Returns:
+      embeddings: Output tensor for video embedding sequence of shape
+        [B, T * N, D].
+      outputs: A dictionary of additional outputs, including
+        `spatial_features` of shape [B, T * N, D]. Empty if
+        `return_intermediate` is False or does not contain
+        'spatial_features'.
+    """
+    t, h, w = image_shape
+    b = patches.shape[0] // t
+
+    patches = layers.FeedForward(  # (B * T, N, D).
+        name='patch_projection',
+        output_dim=self.model_dim,
+        activation_fn=layers.identity,
+        dtype=self.dtype,
+        fprop_dtype=self.fprop_dtype,
+    )(patches)
+
+    # Add spatial positional encoding.
+    spatial_pos_emb_shape = self.pos_emb_shape[-2:]
+    spatial_seq_length = np.prod(spatial_pos_emb_shape)
+    spatial_pos_emb = TrainablePositionalEmbedding(
+        name='spatial_pos_emb',
+        embedding_dim=self.model_dim,
+        max_seq_length=spatial_seq_length,
+        dtype=self.dtype,
+        fprop_dtype=self.fprop_dtype,
+    )(seq_length=spatial_seq_length)
+    num_row_patches = h // self.patch_size
+    num_col_patches = w // self.patch_size
+    if spatial_pos_emb_shape != (num_row_patches, num_col_patches):
+      spatial_pos_emb = _interpolate_emb_2d(
+          spatial_pos_emb,
+          spatial_pos_emb_shape,
+          (num_row_patches, num_col_patches),
+      )
+    patches += spatial_pos_emb  # (B * T, N, D).
+
+    # Get features from the spatial encoder.
+    features = VisionTransformer(  # (B * T, N, D).
+        name='spatial_encoder',
+        num_tfm_layers=self.num_spatial_layers,
+        mlp_dim=self.mlp_dim,
+        num_heads=self.num_heads,
+        atten_logit_cap=self.atten_logit_cap,
+        norm_policy=self.norm_policy,
+        scan=self.scan,
+        dtype=self.dtype,
+        fprop_dtype=self.fprop_dtype,
+    )(patches, train=train, paddings=patches_paddings)
+    features = layers.LayerNorm(
+        name='spatial_ln', dtype=self.dtype, fprop_dtype=self.fprop_dtype
+    )(features)
+    spatial_features = features
+
+    # Instead of mean pooling, we keep the spatial tokens.
+    # Shape = (B * N, T, D).
+    features = einshape.jax_einshape('(bt)nd->(bn)td', features, t=t)
+    temporal_paddings = None
+    if patches_paddings is not None:
+      temporal_paddings = einshape.jax_einshape(
+          '(bt)n->(bn)t', patches_paddings, t=t
+      )  # (B * N, T).
+
+    # Add temporal positional encoding.
+    temporal_seq_length = self.pos_emb_shape[0]
+    temporal_pos_emb = TrainablePositionalEmbedding(
+        name='temporal_pos_emb',
+        embedding_dim=self.model_dim,
+        max_seq_length=temporal_seq_length,
+        dtype=self.dtype,
+        fprop_dtype=self.fprop_dtype,
+    )(seq_length=temporal_seq_length)
+    if temporal_seq_length != t:
+      temporal_pos_emb = _interpolate_emb_1d(temporal_pos_emb, t)
+    features += temporal_pos_emb
+
+    # Get features from the temporal encoder.
+    features = VisionTransformer(
+        name='temporal_encoder',
+        num_tfm_layers=self.num_temporal_layers,
+        mlp_dim=self.mlp_dim,
+        num_heads=self.num_heads,
+        atten_logit_cap=self.atten_logit_cap,
+        norm_policy=self.norm_policy,
+        scan=self.scan,
+        dtype=self.dtype,
+        fprop_dtype=self.fprop_dtype,
+    )(features, train=train, paddings=temporal_paddings)
+    features = layers.LayerNorm(
+        name='temporal_ln', dtype=self.dtype, fprop_dtype=self.fprop_dtype
+    )(features)
+    features = einshape.jax_einshape(  # (B, T * N, D).
+        '(bn)td->b(tn)d', features, b=b
+    )
+
+    embeddings, outputs = features, {}
+    if _contains(return_intermediate, 'spatial_features'):
+      outputs['spatial_features'] = einshape.jax_einshape(
+          '(bt)nd->b(tn)d', spatial_features, t=t
+      )
+
+    return embeddings, outputs
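
A hedged usage sketch for `FactorizedEncoder`: the init/apply flow assumed here is the standard Flax one that `layers.Module` appears to follow, and the shapes are illustrative, with 288 = 16 x 18 so no positional-embedding interpolation is triggered.

import jax
from jax import numpy as jnp
from videoprism import encoders

model = encoders.FactorizedEncoder()      # Defaults: patch_size=18, model_dim=768.
videos = jnp.zeros((1, 16, 288, 288, 3))  # [B, T, H, W, 3] with H == W.
params = model.init(jax.random.PRNGKey(0), videos)
embeddings, outputs = model.apply(params, videos)
# 16 frames x (288 / 18)^2 = 256 tokens per frame.
assert embeddings.shape == (1, 16 * 256, 768)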
+
+
+class FactorizedVideoClassifier(layers.Module):
+  """Video classifier with `FactorizedEncoder` backbone.
+
+  Attributes:
+    encoder_params: A dictionary of parameters for `FactorizedEncoder`.
+    num_classes: Number of output classes.
+  """
+
+  encoder_params: dict[str, Any] = dataclasses.field(default_factory=dict)
+  num_classes: int = 0
+
+  @nn.compact
+  def __call__(
+      self,
+      inputs: Array,
+      train: bool = False,
+      return_intermediate: bool | Collection[str] = False,
+      frame_paddings: Array | None = None,
+  ):
+    """Applies video classifier to inputs.
+
+    Args:
+      inputs: Input tensor of shape [B, T, H, W, 3].
+      train: Whether the model is in the training mode.
+      return_intermediate: A boolean for whether all intermediate features
+        are returned, or a collection of intermediate feature names to
+        return.
+      frame_paddings: Optional binary tensor of shape [B, T] indicating
+        padding. 1 denotes a padding frame.
+
+    Returns:
+      logits: Output tensor of shape [B, num_classes].
+      outputs: A dictionary of additional outputs, including
+        `spatial_features` of shape [B, T * N, D], `spatiotemporal_features`
+        of shape [B, T * N, D], and `global_embeddings` of shape [B, D].
+        Empty if `return_intermediate` is False.
+    """
+    features, outputs = FactorizedEncoder(
+        name='encoder',
+        dtype=self.dtype,
+        fprop_dtype=self.fprop_dtype,
+        **self.encoder_params,
+    )(
+        inputs,
+        train=train,
+        return_intermediate=return_intermediate,
+        frame_paddings=frame_paddings,
+    )
+    if _contains(return_intermediate, 'spatiotemporal_features'):
+      outputs['spatiotemporal_features'] = features
+
+    embeddings = layers.AttenTokenPoolingLayer(
+        name='atten_pooler',
+        num_heads=self.encoder_params['num_heads'],
+        hidden_dim=self.encoder_params['model_dim'],
+        num_queries=1,
+        dtype=self.dtype,
+        fprop_dtype=self.fprop_dtype,
+    )(features, paddings=None, train=train)
+    embeddings = jnp.squeeze(embeddings, axis=-2)
+
+    if _contains(return_intermediate, 'global_embeddings'):
+      outputs['global_embeddings'] = embeddings
+
+    logits = layers.FeedForward(
+        name='projection',
+        output_dim=self.num_classes,
+        activation_fn=layers.identity,
+        dtype=self.dtype,
+        fprop_dtype=self.fprop_dtype,
+    )(embeddings)
+    return logits, outputs
+
+
+class TextEncoder(layers.Module):
+  """CoCa-style text encoder.
+
+  Reference: https://arxiv.org/abs/2205.01917
+
+  Attributes:
+    vocabulary_size: Vocabulary size of the text tokens.
+    num_class_tokens: Number of class tokens.
+    enable_causal_atten: Whether to enable causal attention.
+    model_dim: The model dimension.
+    num_layers: Number of layers in this model.
+    mlp_dim: The hidden layer dimension of FFN in Transformer layers.
+    num_heads: Number of attention heads.
+    enable_per_dim_scale: Whether to enable rescaling of attention logits
+      with a 1/sqrt(dim) factor.
+    atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when
+      a positive value is specified. May not be supported by a subclass.
+    norm_policy: Policy for applying normalization wrt. transformations.
+      Options are: (1) "pre", applied before transformation. (2)
+      "primer_hybrid", applied before and after transformation. (3) "post",
+      applied after transformation. (4) "post_skip", applied after the skip
+      connection.
+    scan: Whether to use `nn.remat` and `nn.scan`.
+  """
+
+  vocabulary_size: int = 128
+  num_class_tokens: int = 0
+  enable_causal_atten: bool = True
+  model_dim: int = 768
+  num_layers: int = 12
+  mlp_dim: int = 3072
+  num_heads: int = 12
+  atten_logit_cap: float = 0.0
+  norm_policy: str = 'pre'
+  enable_per_dim_scale: bool = False
+  scan: bool = False
+
+  @nn.compact
+  def __call__(
+      self, inputs: Array, paddings: Array, train: bool = False
+  ) -> Array:
+    """Applies the text encoder to the inputs.
+
+    Args:
+      inputs: Input tensor of shape [B, N] including sequences of token ids.
+      paddings: A [B, N] padding field of inputs.
+      train: If the model is in the train mode.
+
+    Returns:
+      Output tensor of shape [B, N, D].
+    """
+    batch_size, seq_length = inputs.shape
+
+    pos_emb = PositionalEmbedding(
+        name='pos_emb',
+        embedding_dim=self.model_dim,
+        dtype=self.dtype,
+        fprop_dtype=self.fprop_dtype,
+    )(seq_length=seq_length)
+    input_emb = Embedding(
+        name='token_emb',
+        num_classes=self.vocabulary_size,
+        input_dim=self.model_dim,
+        scale_sqrt_depth=True,
+        dtype=self.dtype,
+        fprop_dtype=self.fprop_dtype,
+    )(inputs)
+    features = input_emb + pos_emb
+
+    if self.num_class_tokens > 0:
+      cls_emb = self._cast_to_fprop_dtype(
+          self.param(
+              'cls_emb',
+              nn.initializers.normal(stddev=1.0 / math.sqrt(self.model_dim)),
+              [1, self.num_class_tokens, self.model_dim],
+              self.dtype,
+          )
+      )
+      cls_emb = jnp.tile(cls_emb, [batch_size, 1, 1])
+      cls_emb *= self.model_dim**0.5
+      features = jnp.concatenate([features, cls_emb], axis=-2)
+
+      cls_paddings = jnp.zeros(
+          [batch_size, self.num_class_tokens], dtype=paddings.dtype
+      )
+      paddings = jnp.concatenate([paddings, cls_paddings], axis=-1)
+
+    features = layers.StackedTransformer(
+        name='unimodal_transformer',
+        num_layers=self.num_layers,
+        hidden_dim=self.mlp_dim,
+        num_heads=self.num_heads,
+        atten_logit_cap=self.atten_logit_cap,
+        norm_policy=self.norm_policy,
+        internal_enable_per_dim_scale=self.enable_per_dim_scale,
+        activation_fn=jax.nn.relu,
+        enable_causal_atten=self.enable_causal_atten,
+        scan=self.scan,
+        dtype=self.dtype,
+        fprop_dtype=self.fprop_dtype,
+    )(features, paddings, train=train)
+    features = layers.LayerNorm(
+        name='unimodal_ln', dtype=self.dtype, fprop_dtype=self.fprop_dtype
+    )(features)
+    return features
+
+
+class FactorizedVideoCLIP(layers.Module):
+  """Video CLIP model with a factorized vision encoder."""
+
+  # Vision parameters.
+  patch_size: int = 18
+  pos_emb_shape: tuple[int, int, int] = (16, 16, 16)
+  num_spatial_layers: int = 12
+  num_temporal_layers: int = 4
+  mlp_dim: int = 3072
+  num_auxiliary_layers: int = 0
+  # Text parameters.
+  vocabulary_size: int = 128
+  enable_causal_atten: bool = True
+  num_unimodal_layers: int = 12
+  norm_policy: str = 'pre'
+  # Shared parameters.
+  model_dim: int = 768
+  num_heads: int = 12
+  atten_logit_cap: float = 0.0
+  scan: bool = False
+
+  @nn.compact
+  def __call__(
+      self,
+      inputs: Array | None = None,
+      text_token_ids: Array | None = None,
+      text_paddings: Array | None = None,
+      train: bool = False,
+      normalize: bool = True,
+      return_intermediate: bool | Collection[str] = False,
+      frame_paddings: Array | None = None,
+  ) -> tuple[Array | None, Array | None, dict[str, Array]]:
+    """Computes predictions for the batched inputs.
+
+    Args:
+      inputs: Input frame image tensor of shape [B, T, H, W, 3] (H == W).
+      text_token_ids: Input text token id tensor of shape [B, L].
+      text_paddings: Input text paddings of shape [B, L]. Required if
+        `text_token_ids` is not None.
+      train: If the model is in the train mode.
+      normalize: Whether to normalize the output embeddings.
+      return_intermediate: A boolean for whether all intermediate features
+        are returned, or a collection of intermediate feature names to
+        return.
+      frame_paddings: Optional binary tensor of shape [B, T] indicating
+        padding. 1 denotes a padding frame.
+
+    Returns:
+      video_embeddings: Output contrastive video embeddings of shape [B, D].
+        None if `inputs` is None.
+      text_embeddings: Output contrastive text embeddings of shape [B, D].
+        None if `text_token_ids` is None.
+      outputs: A dictionary of additional outputs, including
+        `spatial_features` of shape [B, T * N, D], `spatiotemporal_features`
+        of shape [B, T * N, D], and `frame_embeddings` of shape [B, T, D].
+        Empty if `return_intermediate` is False or does not contain
+        `spatial_features`.
+    """
+    video_embeddings, text_embeddings, outputs = None, None, {}
+
+    if inputs is not None:
+      num_frames = inputs.shape[-4]
+      vision_features, vision_outputs = FactorizedEncoder(
+          name='vision_encoder',
+          patch_size=self.patch_size,
+          pos_emb_shape=self.pos_emb_shape,
+          model_dim=self.model_dim,
+          num_spatial_layers=self.num_spatial_layers,
+          num_temporal_layers=self.num_temporal_layers,
+          num_heads=self.num_heads,
+          mlp_dim=self.mlp_dim,
+          atten_logit_cap=self.atten_logit_cap,
+          norm_policy='pre',
+          scan=self.scan,
+          dtype=self.dtype,
+          fprop_dtype=self.fprop_dtype,
+      )(
+          inputs,
+          train=train,
+          return_intermediate=return_intermediate,
+          frame_paddings=frame_paddings,
+      )
+      outputs.update(vision_outputs)
+      if _contains(return_intermediate, 'spatiotemporal_features'):
+        outputs['spatiotemporal_features'] = vision_features
+
+      if self.num_auxiliary_layers > 0:
+        vision_features = VisionTransformer(
+            name='auxiliary_encoder',
+            num_tfm_layers=self.num_auxiliary_layers,
+            mlp_dim=self.mlp_dim,
+            num_heads=self.num_heads,
+            atten_logit_cap=self.atten_logit_cap,
+            norm_policy='pre',
+            scan=self.scan,
+            dtype=self.dtype,
+            fprop_dtype=self.fprop_dtype,
+        )(vision_features, train=train)
+
+      pooling_layer = layers.AttenTokenPoolingLayer(
+          name='contrastive_vision_pooler',
+          hidden_dim=self.model_dim * 4,
+          num_heads=self.num_heads,
+          num_queries=1,
+          dtype=self.dtype,
+          fprop_dtype=self.fprop_dtype,
+      )
+      video_embeddings = pooling_layer(vision_features, None, train=train)
+
+      # Squeeze the query dimension in the pooler output.
+      video_embeddings = jnp.squeeze(video_embeddings, axis=-2)
+      if normalize:
+        video_embeddings = _l2_normalize(video_embeddings, axis=-1)
+
+      if _contains(return_intermediate, 'frame_embeddings'):
+        frame_features = einshape.jax_einshape(
+            'b(tn)d->(bt)nd', vision_features, t=num_frames
+        )
+        frame_embeddings = pooling_layer(frame_features, None, train=train)
+        frame_embeddings = jnp.squeeze(frame_embeddings, axis=-2)
+        frame_embeddings = einshape.jax_einshape(
+            '(bt)d->btd', frame_embeddings, t=num_frames
+        )
+        if normalize:
+          frame_embeddings = _l2_normalize(frame_embeddings, axis=-1)
+        outputs['frame_embeddings'] = frame_embeddings
+
+    if text_token_ids is not None:
+      assert text_paddings is not None, 'Text paddings are required.'
+      text_features = TextEncoder(
+          name='text_encoder',
+          vocabulary_size=self.vocabulary_size,
+          num_class_tokens=1,
+          enable_causal_atten=self.enable_causal_atten,
+          model_dim=self.model_dim,
+          num_layers=self.num_unimodal_layers,
+          num_heads=self.num_heads,
+          mlp_dim=self.model_dim * 4,
+          atten_logit_cap=self.atten_logit_cap,
+          norm_policy=self.norm_policy,
+          scan=self.scan,
+          dtype=self.dtype,
+          fprop_dtype=self.fprop_dtype,
+      )(text_token_ids, text_paddings, train=train)
+
+      # Take the last token (i.e., the class token) as the text embedding.
+      text_embeddings = text_features[:, -1]
+      if normalize:
+        text_embeddings = _l2_normalize(text_embeddings, axis=-1)
+
+    return video_embeddings, text_embeddings, outputs
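
And a hedged end-to-end sketch for `FactorizedVideoCLIP`: token ids, paddings, and shapes below are illustrative only, and any tokenizer or checkpoint loading (see `videoprism/tokenizers.py` and `videoprism/models.py` in the file list above) is deliberately left out.

import jax
from jax import numpy as jnp
from videoprism import encoders

model = encoders.FactorizedVideoCLIP()
videos = jnp.zeros((1, 16, 288, 288, 3))         # [B, T, H, W, 3].
token_ids = jnp.zeros((1, 32), dtype=jnp.int32)  # [B, L], ids < vocabulary_size.
paddings = jnp.zeros((1, 32))                    # [B, L], 0 = real token.
params = model.init(jax.random.PRNGKey(0), videos, token_ids, paddings)
video_emb, text_emb, _ = model.apply(params, videos, token_ids, paddings)
# Both embeddings are L2-normalized, so similarity is a plain dot product.
similarity = jnp.einsum('bd,cd->bc', video_emb, text_emb)
assert video_emb.shape == text_emb.shape == (1, 768)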