singlebehaviorlab-2.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sam2/__init__.py +11 -0
  2. sam2/automatic_mask_generator.py +454 -0
  3. sam2/benchmark.py +92 -0
  4. sam2/build_sam.py +174 -0
  5. sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
  6. sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
  7. sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
  8. sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
  9. sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
  10. sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
  11. sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
  12. sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
  13. sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
  14. sam2/modeling/__init__.py +5 -0
  15. sam2/modeling/backbones/__init__.py +5 -0
  16. sam2/modeling/backbones/hieradet.py +317 -0
  17. sam2/modeling/backbones/image_encoder.py +134 -0
  18. sam2/modeling/backbones/utils.py +93 -0
  19. sam2/modeling/memory_attention.py +169 -0
  20. sam2/modeling/memory_encoder.py +181 -0
  21. sam2/modeling/position_encoding.py +239 -0
  22. sam2/modeling/sam/__init__.py +5 -0
  23. sam2/modeling/sam/mask_decoder.py +295 -0
  24. sam2/modeling/sam/prompt_encoder.py +202 -0
  25. sam2/modeling/sam/transformer.py +311 -0
  26. sam2/modeling/sam2_base.py +913 -0
  27. sam2/modeling/sam2_utils.py +323 -0
  28. sam2/sam2_hiera_b+.yaml +113 -0
  29. sam2/sam2_hiera_l.yaml +117 -0
  30. sam2/sam2_hiera_s.yaml +116 -0
  31. sam2/sam2_hiera_t.yaml +118 -0
  32. sam2/sam2_image_predictor.py +466 -0
  33. sam2/sam2_video_predictor.py +1388 -0
  34. sam2/sam2_video_predictor_legacy.py +1172 -0
  35. sam2/utils/__init__.py +5 -0
  36. sam2/utils/amg.py +348 -0
  37. sam2/utils/misc.py +349 -0
  38. sam2/utils/transforms.py +118 -0
  39. singlebehaviorlab/__init__.py +4 -0
  40. singlebehaviorlab/__main__.py +130 -0
  41. singlebehaviorlab/_paths.py +100 -0
  42. singlebehaviorlab/backend/__init__.py +2 -0
  43. singlebehaviorlab/backend/augmentations.py +320 -0
  44. singlebehaviorlab/backend/data_store.py +420 -0
  45. singlebehaviorlab/backend/model.py +1290 -0
  46. singlebehaviorlab/backend/train.py +4667 -0
  47. singlebehaviorlab/backend/uncertainty.py +578 -0
  48. singlebehaviorlab/backend/video_processor.py +688 -0
  49. singlebehaviorlab/backend/video_utils.py +139 -0
  50. singlebehaviorlab/data/config/config.yaml +85 -0
  51. singlebehaviorlab/data/training_profiles.json +334 -0
  52. singlebehaviorlab/gui/__init__.py +4 -0
  53. singlebehaviorlab/gui/analysis_widget.py +2291 -0
  54. singlebehaviorlab/gui/attention_export.py +311 -0
  55. singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
  56. singlebehaviorlab/gui/clustering_widget.py +3187 -0
  57. singlebehaviorlab/gui/inference_popups.py +1138 -0
  58. singlebehaviorlab/gui/inference_widget.py +4550 -0
  59. singlebehaviorlab/gui/inference_worker.py +651 -0
  60. singlebehaviorlab/gui/labeling_widget.py +2324 -0
  61. singlebehaviorlab/gui/main_window.py +754 -0
  62. singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
  63. singlebehaviorlab/gui/motion_tracking.py +764 -0
  64. singlebehaviorlab/gui/overlay_export.py +1234 -0
  65. singlebehaviorlab/gui/plot_integration.py +729 -0
  66. singlebehaviorlab/gui/qt_helpers.py +29 -0
  67. singlebehaviorlab/gui/registration_widget.py +1485 -0
  68. singlebehaviorlab/gui/review_widget.py +1330 -0
  69. singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
  70. singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
  71. singlebehaviorlab/gui/timeline_themes.py +131 -0
  72. singlebehaviorlab/gui/training_profiles.py +418 -0
  73. singlebehaviorlab/gui/training_widget.py +3719 -0
  74. singlebehaviorlab/gui/video_utils.py +233 -0
  75. singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
  76. singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
  77. singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
  78. singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
  79. singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
  80. singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
  81. singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
  82. singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
  83. videoprism/__init__.py +0 -0
  84. videoprism/encoders.py +910 -0
  85. videoprism/layers.py +1136 -0
  86. videoprism/models.py +407 -0
  87. videoprism/tokenizers.py +167 -0
  88. videoprism/utils.py +168 -0
videoprism/layers.py ADDED
@@ -0,0 +1,1136 @@
1
+ # Copyright 2026 VideoPrism Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """VideoPrism Flax layers."""
16
+
17
+ from collections.abc import Callable
18
+ import functools
19
+ import string
20
+ from typing import Any
21
+ from flax import linen as nn
22
+ import jax
23
+ from jax import numpy as jnp
24
+ import numpy as np
25
+
26
+ Array = jax.Array
27
+ ActivationFunc = Callable[[Array], Array]
28
+ Initializer = nn.initializers.Initializer
29
+
30
+ default_kernel_init = nn.initializers.lecun_normal()
31
+ gelu = functools.partial(jax.nn.gelu, approximate=False)
32
+
33
+
34
+ def identity(x: Array) -> Array:
35
+ """Identity activation."""
36
+ return x
37
+
38
+
39
+ def _get_large_negative_number(dtype: jax.typing.DTypeLike) -> Array:
40
+ """Returns a large-magnitude negative value for the given dtype."""
41
+ # -0.7 is a float64 in JAX, so explicitly cast the output to the target dtype.
42
+ if jnp.issubdtype(dtype, jnp.inexact):
43
+ dtype_max = jnp.finfo(dtype).max
44
+ elif jnp.issubdtype(dtype, jnp.integer):
45
+ dtype_max = jnp.iinfo(dtype).max
46
+ else:
47
+ raise ValueError('Unsupported dtype for inputs.')
48
+ return jnp.asarray(-0.7 * dtype_max, dtype=dtype)
49
+
50
+
51
+ def _apply_mask_to_logits(logits: Array, mask: Array) -> Array:
52
+ """Applies a floating-point mask to a set of logits.
53
+
54
+ The mask is represented as a float32 tensor where 0 represents true and values
55
+ below a large negative number (here set to
56
+ _get_large_negative_number(jnp.float32) / 2) represent false. Applying the
57
+ mask leaves the logits alone in the true case and replaces them by
58
+ _get_large_negative_number(jnp.float32) in the false case. Previously, this
59
+ was done by adding the logits to the mask; however, this leads to a bad fusion
60
+ decision in the compiler that saves the float32 values in memory rather than
61
+ just the predicate. This implementation avoids that problem.
62
+
63
+ Args:
64
+ logits: A jax.Array of logit values.
65
+ mask: A jax.Array (float32) of mask values with the encoding described in
66
+ the function documentation.
67
+
68
+ Returns:
69
+ Masked logits.
70
+ """
71
+ min_value = _get_large_negative_number(logits.dtype)
72
+ return jnp.where((mask >= min_value * 0.5), logits, min_value)
73
+
74
+
75
+ def _convert_paddings_to_mask(
76
+ paddings: Array, dtype: jax.typing.DTypeLike = jnp.float32
77
+ ) -> Array:
78
+ """Converts binary paddings to a logit mask ready to add to attention matrix.
79
+
80
+ Args:
81
+ paddings: A binary jax.Array of shape [B, T], with 1 denoting padding token.
82
+ dtype: Data type of the input.
83
+
84
+ Returns:
85
+ A jax.Array of shape [B, 1, 1, T] ready to be added to attention logits.
86
+ """
87
+ attention_mask = paddings[:, jnp.newaxis, jnp.newaxis, :]
88
+ attention_mask *= _get_large_negative_number(dtype)
89
+ return attention_mask
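A minimal sketch (not part of the package) of how the two padding helpers above compose; the batch size, sequence length, and variable names are illustrative assumptions, and the underscore-prefixed functions are internal, so calling them directly is for illustration only.

```python
import jax
import jax.numpy as jnp
from videoprism import layers as vp_layers

# One sequence of length 4 whose last token is padding (1.0 marks padding).
paddings = jnp.array([[0.0, 0.0, 0.0, 1.0]])            # [B=1, T=4]
mask = vp_layers._convert_paddings_to_mask(paddings)    # [1, 1, 1, 4]: 0 for real tokens,
                                                        # a large negative value for padded ones
logits = jnp.zeros((1, 1, 4, 4))                        # dummy attention logits [B, N, T, S]
masked = vp_layers._apply_mask_to_logits(logits, mask)  # padded source column pushed to a large negative value
probs = jax.nn.softmax(masked, axis=-1)                 # padded column receives ~0 probability
```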
90
+
91
+
92
+ def _causal_mask(input_t: Array) -> Array:
93
+ """Computes and returns causal mask.
94
+
95
+ Args:
96
+ input_t: A jax.Array of shape [B, T, D].
97
+
98
+ Returns:
99
+ An attention_mask jax.Array of shape [1, 1, T, T]. Attention mask has
100
+ already been converted to large negative values.
101
+ """
102
+ assert jnp.issubdtype(input_t.dtype, jnp.floating), input_t.dtype
103
+ large_negative_number = _get_large_negative_number(input_t.dtype)
104
+ t = input_t.shape[-2]
105
+ col_idx = jnp.tile(jnp.arange(t)[jnp.newaxis, :], [t, 1])
106
+ row_idx = jnp.tile(jnp.arange(t)[:, jnp.newaxis], [1, t])
107
+ mask = (row_idx < col_idx).astype(input_t.dtype) * large_negative_number
108
+ return mask[jnp.newaxis, jnp.newaxis, :, :]
109
+
110
+
111
+ def _merge_masks(a: Array, b: Array) -> Array:
112
+ """Merges two masks.
113
+
114
+ This function merges two masks with the same shape, where the smaller value
115
+ will be chosen at the same position. Log-scale mask is expected but 0/1 mask
116
+ is also fine.
117
+
118
+ Args:
119
+ a: A jax.Array of shape [1|B, 1, 1|T, S].
120
+ b: A jax.Array of shape [1|B, 1, 1|T, S].
121
+
122
+ Returns:
123
+ A jax.Array of shape [1|B, 1, 1|T, S].
124
+ """
125
+
126
+ def expand_t(key_mask):
127
+ """Expands the 1D mask to the 2D mask.
128
+
129
+ Given [[1, 1, 0, 0]], this function returns the following mask,
130
+ 1 1 0 0
131
+ 1 1 0 0
132
+ 0 0 0 0
133
+ 0 0 0 0
134
+
135
+ Args:
136
+ key_mask: A jax.Array of the input 1D mask.
137
+
138
+ Returns:
139
+ A jax.Array of the expanded 2D mask.
140
+ """
141
+ query_mask = jnp.transpose(key_mask, [0, 1, 3, 2])
142
+ return jnp.minimum(query_mask, key_mask)
143
+
144
+ if a.shape[-2] != b.shape[-2]:
145
+ if a.shape[-2] == 1:
146
+ a = expand_t(a)
147
+ else:
148
+ assert b.shape[-2] == 1
149
+ b = expand_t(b)
150
+
151
+ assert a.shape[-3:] == b.shape[-3:], f'a.shape={a.shape}, b.shape={b.shape}.'
152
+ return jnp.minimum(a, b)
153
+
154
+
155
+ def compute_attention_masks_for_fprop(
156
+ inputs: Array,
157
+ paddings: Array,
158
+ causal_attention: bool = False,
159
+ ) -> Array:
160
+ """Computes attention mask from inputs and paddings for fprop.
161
+
162
+ Args:
163
+ inputs: Input sequence jax.Array of shape [B, T, H].
164
+ paddings: Input paddings jax.Array of shape [B, T].
165
+ causal_attention: Boolean to apply causal masking.
166
+
167
+ Returns:
168
+ attention_mask: Attention mask jax.Array ready to be added to logits for
169
+ self-attention of shape [1|B, 1, 1|T, T].
170
+ """
171
+ # Get paddings mask to [B, 1, 1, T].
172
+ attention_mask = _convert_paddings_to_mask(paddings, inputs.dtype)
173
+
174
+ # Causal mask of shape [1, 1, T, T].
175
+ if causal_attention:
176
+ causal_mask = _causal_mask(inputs)
177
+ attention_mask = _merge_masks(attention_mask, causal_mask)
178
+
179
+ return attention_mask
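As a sketch of the combined mask (shapes are illustrative assumptions), a padded batch with causal attention yields a [B, 1, T, T] additive mask that broadcasts against [B, N, T, S] attention logits:

```python
import jax.numpy as jnp
from videoprism import layers as vp_layers

inputs = jnp.zeros((2, 5, 8))                      # [B, T, H]
paddings = jnp.array([[0, 0, 0, 1, 1],
                      [0, 0, 0, 0, 0]], jnp.float32)
mask = vp_layers.compute_attention_masks_for_fprop(
    inputs, paddings, causal_attention=True)
assert mask.shape == (2, 1, 5, 5)                  # padding and causal masks merged via _merge_masks
```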
180
+
181
+
182
+ class Module(nn.Module):
183
+ """Base class for layers with dtype configured.
184
+
185
+ Attributes:
186
+ dtype: Default dtype for all variables.
187
+ fprop_dtype: Activations dtype to use.
188
+ """
189
+
190
+ dtype: jnp.dtype = jnp.float32
191
+ fprop_dtype: jnp.dtype = jnp.float32
192
+
193
+ @nn.nowrap
194
+ def _cast_to_fprop_dtype(self, value: Any) -> Any:
195
+ """Casts values to the desired dtype."""
196
+
197
+ def _cast(x):
198
+ if x is None:
199
+ return None
200
+ if self.fprop_dtype != x.dtype:
201
+ if jnp.issubdtype(x.dtype, jnp.floating):
202
+ return x.astype(self.fprop_dtype)
203
+ return x
204
+
205
+ return jax.tree_util.tree_map(_cast, value)
206
+
207
+
208
+ class LayerNorm(Module):
209
+ """Layer normalization.
210
+
211
+ Attributes:
212
+ direct_scale: Whether to apply scale directly without a +1.0. The scale
213
+ variable is initialized to 1.0 instead when True.
214
+ epsilon: Tiny value to guard rsqrt.
215
+ use_scale: Whether to use a learned scaling.
216
+ use_bias: Whether to use bias.
217
+ reductions_in_fp32: Whether to compute mean and variance in fp32.
218
+ Recommended for stable training on GPUs.
219
+ """
220
+
221
+ direct_scale: bool = False
222
+ epsilon: float = 1e-6
223
+ use_scale: bool = True
224
+ use_bias: bool = True
225
+ reductions_in_fp32: bool = False
226
+
227
+ @nn.compact
228
+ def __call__(self, inputs: Array) -> Array:
229
+ """Applies layer norm to inputs.
230
+
231
+ Args:
232
+ inputs: A jax.Array for the inputs of shape [..., dim].
233
+
234
+ Returns:
235
+ A jax.Array for the normalized inputs of the same shape.
236
+ """
237
+ inputs_dtype = inputs.dtype
238
+ if self.reductions_in_fp32:
239
+ inputs = inputs.astype(jnp.float32)
240
+ mean = jnp.mean(inputs, axis=[-1], keepdims=True)
241
+ var = jnp.mean(jnp.square(inputs - mean), axis=[-1], keepdims=True)
242
+ normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon)
243
+ if self.reductions_in_fp32:
244
+ normed_inputs = normed_inputs.astype(inputs_dtype)
245
+
246
+ input_dim = inputs.shape[-1]
247
+ if self.use_scale:
248
+ init_value = 1.0 if self.direct_scale else 0.0
249
+ scale = self._cast_to_fprop_dtype(
250
+ self.param(
251
+ 'scale',
252
+ nn.initializers.constant(init_value),
253
+ [input_dim],
254
+ self.dtype,
255
+ )
256
+ )
257
+ if not self.direct_scale:
258
+ scale += 1.0
259
+ normed_inputs *= scale
260
+ if self.use_bias:
261
+ bias = self._cast_to_fprop_dtype(
262
+ self.param(
263
+ 'bias',
264
+ nn.initializers.zeros_init(),
265
+ [input_dim],
266
+ self.dtype,
267
+ )
268
+ )
269
+ normed_inputs += bias
270
+ return normed_inputs
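A minimal usage sketch for LayerNorm, assuming illustrative shapes; with the default direct_scale=False the learned scale starts at 0 and a constant 1.0 is added, so the layer initially applies plain normalization.

```python
import jax
import jax.numpy as jnp
from videoprism import layers as vp_layers

layer = vp_layers.LayerNorm()                              # use_scale=True, use_bias=True by default
x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 16))   # [..., dim]
params = layer.init(jax.random.PRNGKey(1), x)
y = layer.apply(params, x)                                 # same shape, normalized over the last axis
```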
271
+
272
+
273
+ class FeedForward(Module):
274
+ """Feedforward layer with activation.
275
+
276
+ Attributes:
277
+ output_dim: Depth of the output.
278
+ has_bias: Adds bias weights or not.
279
+ activation_fn: Activation function to use.
280
+ weight_init: Initializer function for the weight matrix.
281
+ bias_init: Initializer function for the bias.
282
+ """
283
+
284
+ output_dim: int = 0
285
+ has_bias: bool = True
286
+ activation_fn: ActivationFunc = nn.relu
287
+ weight_init: Initializer = default_kernel_init
288
+ bias_init: Initializer = nn.initializers.zeros_init()
289
+
290
+ @nn.compact
291
+ def __call__(self, inputs: Array) -> Array:
292
+
293
+ def _promote_dtype(x, kernel, bias, dtype):
294
+ """Promotes the dtype of the arrays to the desired dtype."""
295
+ del dtype
296
+ # To be compatible with other layers, we do not promote the inputs as they
297
+ # are expected to be in the `fprop_dtype`.
298
+ return (
299
+ x,
300
+ self._cast_to_fprop_dtype(kernel),
301
+ self._cast_to_fprop_dtype(bias),
302
+ )
303
+
304
+ projected_inputs = nn.Dense(
305
+ self.output_dim,
306
+ use_bias=self.has_bias,
307
+ kernel_init=self.weight_init,
308
+ bias_init=self.bias_init,
309
+ name='linear',
310
+ param_dtype=self.dtype,
311
+ promote_dtype=_promote_dtype,
312
+ )(inputs)
313
+ return self.activation_fn(projected_inputs)
314
+
315
+
316
+ class TransformerFeedForward(Module):
317
+ """Transformer feedforward layer with residual connection and dropout.
318
+
319
+ Attributes:
320
+ output_dim: Depth of the output. The value of input_dim will be used when
321
+ output_dim is 0. Must be equal to input_dim if add_skip_connection=True.
322
+ hidden_dim: Hidden dimension of FFN.
323
+ has_bias: Adds bias weights to Feedforward or not.
324
+ activation_fn: Activation function to use.
325
+ residual_dropout_prob: Residual dropout.
326
+ relu_dropout_prob: FFN dropout.
327
+ add_skip_connection: Whether to add residual connection.
328
+ residual_weight: Weight of the residual connection. Output = fn(x) *
329
+ residual_weight + x.
330
+ norm_policy: Policy for applying normalization wrt. transformations. Options
331
+ are: (1) "pre", applied before transformation. (2) "primer_hybrid",
332
+ applied before and after transformation. (3) "post", applied after
333
+ transformation. (4) "post_skip", applied after the skip connection.
334
+ """
335
+
336
+ output_dim: int = 0
337
+ hidden_dim: int = 0
338
+ has_bias: bool = True
339
+ activation_fn: ActivationFunc = nn.relu
340
+ residual_dropout_prob: float = 0.0
341
+ relu_dropout_prob: float = 0.0
342
+ add_skip_connection: bool = True
343
+ residual_weight: float = 1.0
344
+ norm_policy: str = 'pre'
345
+
346
+ @nn.nowrap
347
+ def _make_ln(self, name: str) -> LayerNorm:
348
+ """Makes a LayerNorm module."""
349
+ return LayerNorm(
350
+ name=name,
351
+ use_bias=self.has_bias,
352
+ dtype=self.dtype,
353
+ fprop_dtype=self.fprop_dtype,
354
+ )
355
+
356
+ @nn.nowrap
357
+ def _make_ffn(
358
+ self, output_dim: int, name: str, skip_activation: bool = False
359
+ ) -> FeedForward:
360
+ """Makes a FeedForward module."""
361
+ return FeedForward(
362
+ name=name,
363
+ output_dim=output_dim,
364
+ has_bias=self.has_bias,
365
+ activation_fn=identity if skip_activation else self.activation_fn,
366
+ dtype=self.dtype,
367
+ fprop_dtype=self.fprop_dtype,
368
+ )
369
+
370
+ @nn.compact
371
+ def __call__(
372
+ self, inputs: Array, paddings: Array | None, train: bool
373
+ ) -> Array:
374
+ residual = inputs
375
+ output_dim = self.output_dim
376
+ if output_dim == 0:
377
+ output_dim = inputs.shape[-1]
378
+ if self.add_skip_connection and output_dim != inputs.shape[-1]:
379
+ raise ValueError(
380
+ 'Skip connections are only supported when input_dim == output_dim '
381
+ f'but got {inputs.shape[-1]} != {output_dim}'
382
+ )
383
+
384
+ # Expand paddings to last dim if not None to have shape [batch, seq_len, 1].
385
+ if paddings is not None:
386
+ paddings = jnp.expand_dims(paddings, axis=-1)
387
+
388
+ if self.norm_policy == 'primer_hybrid':
389
+ inputs = self._make_ln(name='pre_layer_norm')(inputs)
390
+ elif self.norm_policy == 'pre':
391
+ inputs = self._make_ln(name='layer_norm')(inputs)
392
+
393
+ # Apply first FFN layer.
394
+ activations = self._make_ffn(self.hidden_dim, name='ffn_layer1')(inputs)
395
+
396
+ # Apply paddings if not None.
397
+ if paddings is not None:
398
+ activations *= 1.0 - paddings
399
+
400
+ # Apply RELU dropout.
401
+ activations = nn.Dropout(self.relu_dropout_prob, name='relu_dropout')(
402
+ activations, deterministic=not train
403
+ )
404
+ # Apply second FFN layer.
405
+ outputs = self._make_ffn(
406
+ output_dim, name='ffn_layer2', skip_activation=True
407
+ )(activations)
408
+
409
+ # Apply paddings if not None.
410
+ if paddings is not None:
411
+ outputs *= 1.0 - paddings
412
+
413
+ # Apply Primer normalization before dropout.
414
+ if self.norm_policy == 'primer_hybrid':
415
+ outputs = self._make_ln(name='post_layer_norm')(outputs)
416
+ elif self.norm_policy == 'post':
417
+ outputs = self._make_ln(name='layer_norm')(outputs)
418
+
419
+ # Apply residual dropout.
420
+ outputs = nn.Dropout(self.residual_dropout_prob, name='residual_dropout')(
421
+ outputs, deterministic=not train
422
+ )
423
+ # Apply skip connection.
424
+ if self.add_skip_connection:
425
+ outputs = residual + outputs * self.residual_weight
426
+
427
+ if self.norm_policy == 'post_skip':
428
+ outputs = self._make_ln(name='layer_norm')(outputs)
429
+
430
+ return outputs
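A hypothetical usage sketch (sizes are assumptions): a pre-norm FFN block whose output dimension defaults to the input dimension, so the residual connection is valid.

```python
import jax
import jax.numpy as jnp
from videoprism import layers as vp_layers

ffn = vp_layers.TransformerFeedForward(hidden_dim=64, norm_policy='pre')
x = jnp.ones((2, 5, 16))                              # [B, T, D]; output_dim=0 falls back to D=16
paddings = jnp.zeros((2, 5), jnp.float32)             # no padded positions
params = ffn.init(jax.random.PRNGKey(0), x, paddings, train=False)
y = ffn.apply(params, x, paddings, train=False)       # [2, 5, 16], residual added
```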
431
+
432
+
433
+ class AttentionProjection(Module):
434
+ """Layer that computes multi heads projection.
435
+
436
+ This layer is expected to be used within DotProductAttention below.
437
+
438
+ Attributes:
439
+ output_dim: Output dimension (used only when is_output_projection is True).
440
+ num_heads: Number of attention heads.
441
+ dim_per_head: Size of each head.
442
+ is_output_projection: Whether it is out projection or not. If False, we use
443
+ "...D,DNH->...NH" for query,key,value projection. Otherwise we use
444
+ "...NH,DNH->...D" for output projection.
445
+ use_bias: Whether to add bias in projection or not.
446
+ """
447
+
448
+ output_dim: int = 0
449
+ num_heads: int = 0
450
+ dim_per_head: int = 0
451
+ is_output_projection: bool = False
452
+ use_bias: bool = True
453
+
454
+ @nn.compact
455
+ def __call__(self, inputs: Array) -> Array:
456
+ """Computes the multi headed projection for inputs.
457
+
458
+ Args:
459
+ inputs: A jax.Array with shape [..., num_heads, dim_per_head] if
460
+ is_output_projection is True or [..., input_dim] otherwise.
461
+
462
+ Returns:
463
+ The projected jax.Array with shape [..., input_dim] if
464
+ is_output_projection is True or [..., num_heads, dim_per_head]
465
+ otherwise.
466
+ """
467
+ # Sort the available symbols to avoid nondeterminism.
468
+ eqn_sym = ''.join(sorted(set(string.ascii_uppercase) - set('DHN')))
469
+ output_dim = (
470
+ self.output_dim if self.is_output_projection else inputs.shape[-1]
471
+ )
472
+ rank = len(inputs.shape)
473
+
474
+ hd_shape = [self.num_heads, self.dim_per_head]
475
+ pc_shape = [output_dim] + hd_shape
476
+ w = self._cast_to_fprop_dtype(
477
+ self.param('w', default_kernel_init, pc_shape, self.dtype)
478
+ )
479
+
480
+ if self.is_output_projection:
481
+ assert inputs.shape[-2:] == (self.num_heads, self.dim_per_head)
482
+ batch_eqn = eqn_sym[: (rank - 2)]
483
+ eqn = f'{batch_eqn}NH,DNH->{batch_eqn}D'
484
+ else:
485
+ batch_eqn = eqn_sym[: (rank - 1)] if rank else '...'
486
+ eqn = f'{batch_eqn}D,DNH->{batch_eqn}NH'
487
+
488
+ ret = jnp.einsum(eqn, inputs, w)
489
+ if self.use_bias:
490
+ b = self._cast_to_fprop_dtype(
491
+ self.param(
492
+ 'b',
493
+ nn.initializers.zeros_init(),
494
+ [output_dim] if self.is_output_projection else hd_shape,
495
+ self.dtype,
496
+ )
497
+ )
498
+ ret += b
499
+ return ret
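A shape sketch for both projection directions (sizes and names are illustrative assumptions): the QKV direction maps [..., D] to [..., N, H], and the output direction maps [..., N, H] back to [..., D].

```python
import jax
import jax.numpy as jnp
from videoprism import layers as vp_layers

qkv_proj = vp_layers.AttentionProjection(num_heads=4, dim_per_head=8)       # "...D,DNH->...NH"
out_proj = vp_layers.AttentionProjection(
    output_dim=32, num_heads=4, dim_per_head=8, is_output_projection=True)  # "...NH,DNH->...D"

x = jnp.ones((2, 5, 32))                                                    # [B, T, D]
heads = qkv_proj.apply(qkv_proj.init(jax.random.PRNGKey(0), x), x)          # [2, 5, 4, 8]
back = out_proj.apply(out_proj.init(jax.random.PRNGKey(1), heads), heads)   # [2, 5, 32]
```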
500
+
501
+
502
+ class PerDimScale(Module):
503
+ """A layer to scale individual dimensions of the input."""
504
+
505
+ @nn.compact
506
+ def __call__(self, inputs: Array) -> Array:
507
+ """Returns per_dim_scale * inputs / jnp.sqrt(dim)).
508
+
509
+ Args:
510
+ inputs: A jax.Array with shape [..., dim].
511
+
512
+ Returns:
513
+ outputs: A jax.Array with shape [..., dim].
514
+ """
515
+ dim = inputs.shape[-1]
516
+ per_dim_scale = self._cast_to_fprop_dtype(
517
+ self.param(
518
+ 'per_dim_scale', nn.initializers.zeros_init(), [dim], self.dtype
519
+ )
520
+ )
521
+
522
+ # 1.0/jax.nn.softplus(0.0) = 1.442695041. Hard code this number so that we
523
+ # can avoid unnecessary XLA op fusion mess on TPU.
524
+ r_softplus_0 = 1.442695041
525
+ scale = jnp.array(r_softplus_0 / np.sqrt(dim), dtype=self.fprop_dtype)
526
+ scale *= jax.nn.softplus(per_dim_scale)
527
+ return inputs * scale
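A quick numeric check (an assumption-laden sketch, not library code): because per_dim_scale is zero-initialized and r_softplus_0 = 1/softplus(0), the layer reduces to the standard 1/sqrt(dim) query scaling at initialization.

```python
import jax
import jax.numpy as jnp
import numpy as np
from videoprism import layers as vp_layers

scale_layer = vp_layers.PerDimScale()
x = jnp.ones((1, 64))
params = scale_layer.init(jax.random.PRNGKey(0), x)
y = scale_layer.apply(params, x)
# At init: softplus(0) * 1.442695041 / sqrt(64) == 1 / 8.
np.testing.assert_allclose(y, np.full((1, 64), 0.125), rtol=1e-5)
```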
528
+
529
+
530
+ class DotProductAttention(Module):
531
+ """Dot-product attention with multiple attention heads.
532
+
533
+ Attributes:
534
+ hidden_dim: Number of hidden nodes.
535
+ num_heads: Number of attention heads.
536
+ dim_per_head: Dimension of each attention head. If None then dim_per_head ==
537
+ hidden_dim // num_heads.
538
+ atten_dropout_prob: Probability at which we apply dropout to the attention
539
+ weights.
540
+ use_bias: Whether to use bias for projection layers.
541
+ internal_enable_query_scale: Internal. Enable scaling of query vector.
542
+ internal_enable_per_dim_scale: Internal. Setting to False disables rescaling
543
+ of attention logits with 1/sqrt(dim) factor. Some Transformer variants
544
+ (GShard, T5) use internal_enable_per_dim_scale=False and adjust
545
+ initialization of the linear transformations(einsums), in conjunction with
546
+ Adafactor optimizer.
547
+ scale_query_by_dim_per_head: whether to scale the query by dim_per_head,
548
+ instead of default hidden_dim // num_heads (only activated when
549
+ internal_enable_per_dim_scale = False).
550
+ scale_logits_by_head_dims: Enables a 1/sqrt(head dim) scaling to the logits.
551
+ This occurs prior to logit cap, if any.
552
+ atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a
553
+ positive value is specified. May not be supported by a subclass.
554
+ use_qk_norm: If QK norm is used.
555
+ """
556
+
557
+ hidden_dim: int = 0
558
+ num_heads: int = 1
559
+ dim_per_head: int | None = None
560
+ atten_dropout_prob: float = 0.0
561
+ use_bias: bool = True
562
+ internal_enable_query_scale: bool = True
563
+ internal_enable_per_dim_scale: bool = True
564
+ scale_query_by_dim_per_head: bool = False
565
+ scale_logits_by_head_dims: bool = False
566
+ atten_logit_cap: float = 0.0
567
+ use_qk_norm: bool = False
568
+
569
+ def _scale_query(self, query: Array) -> Array:
570
+ """Scales the query vector if enabled."""
571
+ if not self.internal_enable_query_scale:
572
+ return query
573
+ if self.internal_enable_per_dim_scale:
574
+ query = PerDimScale(
575
+ name='per_dim_scale', dtype=self.dtype, fprop_dtype=self.fprop_dtype
576
+ )(query)
577
+ else:
578
+ if self.scale_query_by_dim_per_head and self.dim_per_head is not None:
579
+ dim_per_head = self.dim_per_head
580
+ else:
581
+ dim_per_head = self.hidden_dim // self.num_heads
582
+
583
+ query *= dim_per_head**-0.5
584
+ return query
585
+
586
+ def _cap_logits(self, logits: Array) -> Array:
587
+ """Caps the logits by p.atten_logit_cap with tanh, if enabled."""
588
+ if not self.atten_logit_cap or self.atten_logit_cap <= 0.0:
589
+ return logits
590
+ cap = jnp.array(self.atten_logit_cap, dtype=self.fprop_dtype)
591
+ # Note that since this caps the negative side as well, caller must defer the
592
+ # pad-with-very-negative-logits logic to after this function returns.
593
+ logits = cap * jnp.tanh(logits / cap)
594
+ return logits
595
+
596
+ def _atten_logits(self, query: Array, key: Array) -> Array:
597
+ """Computes logits from query and key."""
598
+ logits = jnp.einsum('BTNH,BSNH->BNTS', query, key)
599
+ return logits
600
+
601
+ def _dot_atten(
602
+ self,
603
+ query: Array,
604
+ key: Array,
605
+ value: Array,
606
+ atten_mask: Array,
607
+ train: bool,
608
+ ) -> tuple[Array, Array]:
609
+ """Main attention function.
610
+
611
+ Args:
612
+ query: A jax.Array of shape [B, T, N, H].
613
+ key: A jax.Array of shape [B, S, N, H].
614
+ value: A jax.Array of shape [B, S, N, H].
615
+ atten_mask: A jax.Array of shape [1|B, 1, 1|T, S] which is a mask that is
616
+ applied to prevent attention between unwanted pairs. This has already
617
+ been converted into large negative logits. Note that the first and third
618
+ dimension allow size 1 if the mask is shared by every item in the batch
619
+ or every token in the target sequence.
620
+ train: Whether the model is in the train mode.
621
+
622
+ Returns:
623
+ encoded: A jax.Array of shape [B, T, N, H].
624
+ atten_probs: A jax.Array of shape [B, N, T, S].
625
+ """
626
+ assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
627
+ assert (
628
+ query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
629
+ ), 'q, k, v batch dims must match.'
630
+ assert (
631
+ query.shape[-2] == key.shape[-2] == value.shape[-2]
632
+ ), 'q, k, v num_heads must match.'
633
+ assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
634
+ # If only padding bias is supplied, then atten_mask can be [B, 1, 1, S]
635
+ # since each target token is prohibited from attending to the same set of
636
+ # source tokens. In this case tiling is inefficient and unnecessary.
637
+ # If there is no padding mask, and only causal mask then the shape can be
638
+ # [1, 1, T, S].
639
+ assert atten_mask.ndim == 4 and atten_mask.shape[-1] == key.shape[-3]
640
+ assert atten_mask.shape[-2] in [query.shape[-3], 1]
641
+ assert atten_mask.shape[0] in [key.shape[0], 1]
642
+
643
+ query = self._scale_query(query)
644
+ logits = self._atten_logits(query, key)
645
+
646
+ if self.scale_logits_by_head_dims:
647
+ logits = jnp.multiply(logits, 1.0 / np.sqrt(key.shape[-1]))
648
+
649
+ logits = self._cap_logits(logits)
650
+ # Attention softmax is always carried out in fp32.
651
+ logits = logits.astype(jnp.float32)
652
+ # Apply attention masking.
653
+ padded_logits = _apply_mask_to_logits(logits, atten_mask)
654
+ probs = jax.nn.softmax(padded_logits, axis=-1).astype(self.fprop_dtype)
655
+ # Apply attention dropout.
656
+ probs = nn.Dropout(self.atten_dropout_prob, name='atten_dropout')(
657
+ probs, deterministic=not train
658
+ )
659
+ # Compute the attention context.
660
+ encoded = jnp.einsum('BNTS,BSNH->BTNH', probs, value)
661
+ return encoded, probs
662
+
663
+ @nn.nowrap
664
+ def _project_input(self, name: str, dim_per_head: int) -> AttentionProjection:
665
+ """Builds an AttentionProjection module."""
666
+ return AttentionProjection(
667
+ name=name,
668
+ num_heads=self.num_heads,
669
+ dim_per_head=dim_per_head,
670
+ use_bias=self.use_bias,
671
+ dtype=self.dtype,
672
+ fprop_dtype=self.fprop_dtype,
673
+ )
674
+
675
+ @nn.nowrap
676
+ def _make_ln(self, name: str) -> LayerNorm:
677
+ """Makes a LayerNorm module."""
678
+ return LayerNorm(
679
+ name=name,
680
+ use_bias=self.use_bias,
681
+ dtype=self.dtype,
682
+ fprop_dtype=self.fprop_dtype,
683
+ )
684
+
685
+ @nn.compact
686
+ def __call__(
687
+ self,
688
+ query_vec: Array,
689
+ key_vec: Array,
690
+ value_vec: Array,
691
+ atten_mask: Array,
692
+ train: bool,
693
+ ) -> tuple[Array, Array]:
694
+ """Computes the value vector given the current query output.
695
+
696
+ Args:
697
+ query_vec: jax.Array of shape [B, T, D].
698
+ key_vec: jax.Array of shape [B, S, D].
699
+ value_vec: jax.Array of shape [B, S, D].
700
+ atten_mask: jax.Array of shape [1|B, 1, 1|T, S] which is a mask that is
701
+ applied to prevent attention between unwanted pairs. This has already
702
+ been converted into large negative logits. Note that the first and third
703
+ dimension allow size 1 if the mask is shared by every item in the batch
704
+ or every token in the target sequence.
705
+ train: If the model is in the train mode.
706
+
707
+ Returns:
708
+ encoded: jax.Array of shape [B, T, D].
709
+ atten_probs: jax.Array of shape [B, N, T, S].
710
+ """
711
+ dim_per_head = self.dim_per_head
712
+ if dim_per_head is None:
713
+ dim_per_head = self.hidden_dim // self.num_heads
714
+ assert (
715
+ dim_per_head * self.num_heads == self.hidden_dim
716
+ ), f'{dim_per_head} * {self.num_heads} != {self.hidden_dim}'
717
+
718
+ # Project inputs to query, key and value, with shapes [B, T, N, H],
719
+ # [B, S, N, H], and [B, S, N, H], respectively.
720
+ query_proj = self._project_input('query', dim_per_head)(query_vec)
721
+ key_proj = self._project_input('key', dim_per_head)(key_vec)
722
+ value_proj = self._project_input('value', dim_per_head)(value_vec)
723
+
724
+ if self.use_qk_norm:
725
+ query_proj = self._make_ln(name='layer_norm_q')(query_proj)
726
+ key_proj = self._make_ln(name='layer_norm_k')(key_proj)
727
+
728
+ encoded, atten_probs = self._dot_atten(
729
+ query_proj, key_proj, value_proj, atten_mask, train=train
730
+ )
731
+
732
+ # Post projection. Setting is_output_projection=True to set the projection
733
+ # direction from hidden dim to input dim. Output projection follows
734
+ # query_input_dim.
735
+ query_input_dim = query_vec.shape[-1]
736
+ encoded = AttentionProjection(
737
+ name='post',
738
+ output_dim=query_input_dim,
739
+ num_heads=self.num_heads,
740
+ dim_per_head=dim_per_head,
741
+ is_output_projection=True,
742
+ use_bias=self.use_bias,
743
+ dtype=self.dtype,
744
+ fprop_dtype=self.fprop_dtype,
745
+ )(encoded)
746
+ return encoded, atten_probs
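A self-attention usage sketch with assumed sizes; the mask comes from compute_attention_masks_for_fprop above and train=False keeps dropout deterministic.

```python
import jax
import jax.numpy as jnp
from videoprism import layers as vp_layers

attn = vp_layers.DotProductAttention(hidden_dim=32, num_heads=4)   # dim_per_head defaults to 32 // 4
x = jnp.ones((2, 5, 32))                                           # [B, T, D]
paddings = jnp.zeros((2, 5), jnp.float32)
mask = vp_layers.compute_attention_masks_for_fprop(x, paddings)
params = attn.init(jax.random.PRNGKey(0), x, x, x, mask, train=False)
encoded, probs = attn.apply(params, x, x, x, mask, train=False)
# encoded: [2, 5, 32]; probs: [2, 4, 5, 5] (B, N, T, S)
```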
747
+
748
+
749
+ class Transformer(Module):
750
+ """Transformer layer with multi-headed attention.
751
+
752
+ Attributes:
753
+ hidden_dim: Hidden dimension of FFN layer.
754
+ num_heads: Number of heads in self-attention.
755
+ dim_per_head: Dimension of each attention head. If None then dim_per_head ==
756
+ hidden_dim // num_heads.
757
+ atten_dropout_prob: Probability at which we apply dropout to the attention
758
+ weights.
759
+ residual_dropout_prob: Probability at which we apply dropout to the residual
760
+ layers, such that, residual(x, y) = (x + dropout(y)).
761
+ relu_dropout_prob: Probability at which we apply dropout to the FFN layers.
762
+ norm_policy: Policy for applying normalization wrt. transformations. Options
763
+ are: (1) "pre", applied before transformation. (2) "primer_hybrid",
764
+ applied before and after transformation. (3) "post", applied after
765
+ transformation. (4) "post_skip", applied after the skip connection.
766
+ use_bias: Whether to use bias.
767
+ activation_fn: Activation function to use.
768
+ internal_enable_per_dim_scale: Internal. Setting to False disables rescaling
769
+ of attention logits with 1/sqrt(dim) factor.
770
+ atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a
771
+ positive value is specified. May not be supported by a subclass.
772
+ """
773
+
774
+ hidden_dim: int = 0
775
+ num_heads: int = 0
776
+ dim_per_head: int | None = None
777
+ atten_dropout_prob: float = 0.0
778
+ residual_dropout_prob: float = 0.0
779
+ relu_dropout_prob: float = 0.0
780
+ norm_policy: str = 'pre'
781
+ use_bias: bool = True
782
+ activation_fn: ActivationFunc = nn.relu
783
+ internal_enable_per_dim_scale: bool = True
784
+ atten_logit_cap: float = 0.0
785
+
786
+ @nn.nowrap
787
+ def _make_ln(self, name: str) -> LayerNorm:
788
+ """Makes a LayerNorm module."""
789
+ return LayerNorm(
790
+ name=name,
791
+ use_bias=self.use_bias,
792
+ dtype=self.dtype,
793
+ fprop_dtype=self.fprop_dtype,
794
+ )
795
+
796
+ @nn.compact
797
+ def __call__(
798
+ self,
799
+ inputs: Array,
800
+ paddings: Array,
801
+ atten_mask: Array,
802
+ train: bool,
803
+ ) -> Array:
804
+ """Transformer decoder layer.
805
+
806
+ Args:
807
+ inputs: Input sequence jax.Array of shape [B, T, H].
808
+ paddings: Input paddings jax.Array of shape [B, T] (only used in FFN).
809
+ atten_mask: Self attention mask ready to add to the logits. It can be of
810
+ shape [1|B, 1, 1|T, T] which is broadcast compatible with the
811
+ self-attention matrix of shape [B, N, T, T]. This is assumed to have
812
+ combined paddings, causal masking, and segment masking.
813
+ train: Whether the model is in the train mode.
814
+
815
+ Returns:
816
+ The fflayer output with shape [B, T, D].
817
+ """
818
+
819
+ if self.norm_policy == 'primer_hybrid':
820
+ inputs_normalized = self._make_ln(name='pre_layer_norm')(inputs)
821
+ elif self.norm_policy == 'pre':
822
+ inputs_normalized = self._make_ln(name='layer_norm')(inputs)
823
+ else:
824
+ inputs_normalized = inputs
825
+
826
+ # Compute self-attention, key/value vectors are the input itself.
827
+ atten_outputs, _ = DotProductAttention(
828
+ name='self_attention',
829
+ hidden_dim=inputs_normalized.shape[-1],
830
+ num_heads=self.num_heads,
831
+ dim_per_head=self.dim_per_head,
832
+ atten_dropout_prob=self.atten_dropout_prob,
833
+ use_bias=self.use_bias,
834
+ internal_enable_per_dim_scale=self.internal_enable_per_dim_scale,
835
+ atten_logit_cap=self.atten_logit_cap,
836
+ dtype=self.dtype,
837
+ fprop_dtype=self.fprop_dtype,
838
+ )(
839
+ inputs_normalized,
840
+ inputs_normalized,
841
+ inputs_normalized,
842
+ atten_mask=atten_mask,
843
+ train=train,
844
+ )
845
+
846
+ if self.norm_policy == 'primer_hybrid':
847
+ atten_outputs = self._make_ln(name='post_layer_norm')(atten_outputs)
848
+ elif self.norm_policy == 'post':
849
+ atten_outputs = self._make_ln(name='layer_norm')(atten_outputs)
850
+
851
+ # Residual dropout and connection.
852
+ atten_outputs = nn.Dropout(
853
+ self.residual_dropout_prob, name='residual_dropout'
854
+ )(atten_outputs, deterministic=not train)
855
+ atten_outputs += inputs
856
+
857
+ if self.norm_policy == 'post_skip':
858
+ atten_outputs = self._make_ln(name='layer_norm')(atten_outputs)
859
+
860
+ # Apply FFN layer.
861
+ outputs = TransformerFeedForward(
862
+ name='ff_layer',
863
+ hidden_dim=self.hidden_dim,
864
+ has_bias=self.use_bias,
865
+ activation_fn=self.activation_fn,
866
+ residual_dropout_prob=self.residual_dropout_prob,
867
+ relu_dropout_prob=self.relu_dropout_prob,
868
+ norm_policy=self.norm_policy,
869
+ dtype=self.dtype,
870
+ fprop_dtype=self.fprop_dtype,
871
+ )(atten_outputs, paddings=paddings, train=train)
872
+ return outputs
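A single-block sketch (hypothetical sizes): self-attention plus the FFN sub-layer, with the attention mask precomputed once and passed in.

```python
import jax
import jax.numpy as jnp
from videoprism import layers as vp_layers

block = vp_layers.Transformer(hidden_dim=64, num_heads=4)
x = jnp.ones((2, 5, 16))                              # [B, T, H]
paddings = jnp.zeros((2, 5), jnp.float32)
mask = vp_layers.compute_attention_masks_for_fprop(x, paddings)
params = block.init(jax.random.PRNGKey(0), x, paddings, mask, train=False)
y = block.apply(params, x, paddings, mask, train=False)   # [2, 5, 16]
```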
873
+
874
+
875
+ class Repeat(nn.Module):
876
+ """A generic repeat layer with `nn.remat` and`nn.scan`.
877
+
878
+ Attributes:
879
+ block_fn: The block function to repeat.
880
+ times: The number of times to repeat block.
881
+ checkpoint_policy: Checkpoint policy for `nn.remat`.
882
+ """
883
+
884
+ block_fn: Callable[..., Any]
885
+ times: int = 0
886
+ checkpoint_policy: str = 'nothing_saveable'
887
+
888
+ def __call__(
889
+ self,
890
+ inputs: Array,
891
+ *args: Any,
892
+ **kwargs: Any,
893
+ ) -> Any:
894
+ """Forwards inputs through the block layer stack.
895
+
896
+ Block outputs are expected to be of the same structure as inputs.
897
+
898
+ Args:
899
+ inputs: An Array of inputs that goes through the block layer stack.
900
+ *args: Positional args to be passed to the forward method.
901
+ **kwargs: Keyword args to be passed to the forward method.
902
+
903
+ Returns:
904
+ Output from the last layer.
905
+ """
906
+ return self.call_with_custom_method(
907
+ inputs,
908
+ *args,
909
+ main_fn=self.block_fn,
910
+ **kwargs,
911
+ )
912
+
913
+ def call_with_custom_method(
914
+ self,
915
+ inputs: Array,
916
+ *args: Any,
917
+ main_fn: Callable[..., Any],
918
+ **kwargs: Any,
919
+ ) -> Any:
920
+ """Similar to __call__, but allows a custom way to create a layer method."""
921
+
922
+ def body_fn(fn, layer_inputs):
923
+ return fn(layer_inputs, *args, **kwargs), None
924
+
925
+ rematted_body_fn = nn.remat(
926
+ body_fn,
927
+ prevent_cse=False,
928
+ policy=getattr(jax.checkpoint_policies, self.checkpoint_policy, None),
929
+ )
930
+ scan_fn = nn.scan(
931
+ rematted_body_fn,
932
+ variable_axes={'params': 0},
933
+ split_rngs={'params': True, 'dropout': True},
934
+ length=self.times,
935
+ )
936
+ outputs, _ = scan_fn(main_fn, inputs)
937
+ return outputs
938
+
939
+
940
+ class StackedTransformer(Module):
941
+ """A stack of Transformer layers.
942
+
943
+ Attributes:
944
+ num_layers: Number of layers in this stack.
945
+ hidden_dim: The hidden layer dimension of FFN in Transformer layers.
946
+ num_heads: Number of attention heads.
947
+ dim_per_head: Dimension of each attention head. If None then dim_per_head ==
948
+ model_dims // num_heads.
949
+ dropout_prob: Apply dropout at this prob at various places.
950
+ atten_dropout_prob: Probability at which we apply dropout to the attention
951
+ weights.
952
+ residual_dropout_prob: Probability at which we apply dropout to the residual
953
+ layers, such that, residual(x, y) = (x + dropout(y)).
954
+ relu_dropout_prob: Probability at which we apply dropout to the FFN layers.
955
+ input_dropout_prob: Dropout probability applied to the input before any
956
+ processing happens.
957
+ norm_policy: Policy for applying normalization wrt. transformations. Options
958
+ are: (1) "pre", applied before transformation. (2) "primer_hybrid",
959
+ applied before and after transformation. (3) "post", applied after
960
+ transformation. (4) "post_skip", applied after the skip connection.
961
+ use_bias: Whether to use bias.
962
+ activation_fn: Activation function to use.
963
+ internal_enable_per_dim_scale: Internal. Setting to False disables rescaling
964
+ of attention logits with 1/sqrt(dim) factor.
965
+ atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a
966
+ positive value is specified. May not be supported by a subclass.
967
+ enable_causal_atten: Whether to enable causal attention.
968
+ scan: Whether to use `nn.remat` and `nn.scan`.
969
+ """
970
+
971
+ num_layers: int = 0
972
+ hidden_dim: int = 0
973
+ num_heads: int = 0
974
+ dim_per_head: int | None = None
975
+ dropout_prob: float = 0.0
976
+ atten_dropout_prob: float | None = None
977
+ residual_dropout_prob: float | None = None
978
+ relu_dropout_prob: float | None = None
979
+ input_dropout_prob: float = 0.0
980
+ norm_policy: str = 'pre'
981
+ use_bias: bool = True
982
+ activation_fn: ActivationFunc = nn.relu
983
+ internal_enable_per_dim_scale: bool = True
984
+ atten_logit_cap: float = 0.0
985
+ enable_causal_atten: bool = False
986
+ scan: bool = False
987
+
988
+ @nn.compact
989
+ def __call__(
990
+ self,
991
+ inputs: Array,
992
+ paddings: Array,
993
+ train: bool,
994
+ ) -> Array:
995
+ """Stacked Transformer layer.
996
+
997
+ Args:
998
+ inputs: Input sequence of shape [B, T, H].
999
+ paddings: Input paddings of shape [B, T].
1000
+ train: If the model is in the train mode.
1001
+
1002
+ Returns:
1003
+ Output vector with shape [B, T, D].
1004
+ """
1005
+
1006
+ atten_mask = compute_attention_masks_for_fprop(
1007
+ inputs, paddings, causal_attention=self.enable_causal_atten
1008
+ )
1009
+
1010
+ outputs = inputs
1011
+ if self.input_dropout_prob > 0.0:
1012
+ outputs = nn.Dropout(self.input_dropout_prob, name='input_dropout')(
1013
+ outputs, deterministic=not train
1014
+ )
1015
+
1016
+ transformer_kwargs = dict(
1017
+ num_heads=self.num_heads,
1018
+ dim_per_head=self.dim_per_head,
1019
+ hidden_dim=self.hidden_dim,
1020
+ atten_dropout_prob=self.atten_dropout_prob or self.dropout_prob,
1021
+ residual_dropout_prob=self.residual_dropout_prob or self.dropout_prob,
1022
+ relu_dropout_prob=self.relu_dropout_prob or self.dropout_prob,
1023
+ norm_policy=self.norm_policy,
1024
+ use_bias=self.use_bias,
1025
+ activation_fn=self.activation_fn,
1026
+ internal_enable_per_dim_scale=self.internal_enable_per_dim_scale,
1027
+ atten_logit_cap=self.atten_logit_cap,
1028
+ dtype=self.dtype,
1029
+ fprop_dtype=self.fprop_dtype,
1030
+ )
1031
+ if self.scan:
1032
+ block_fn = Transformer(name='x_layers', **transformer_kwargs)
1033
+ outputs = Repeat(block_fn=block_fn, times=self.num_layers)(
1034
+ outputs, paddings, atten_mask, train
1035
+ )
1036
+ else:
1037
+ for i in range(self.num_layers):
1038
+ outputs = Transformer(name=f'x_layers_{i}', **transformer_kwargs)(
1039
+ outputs, paddings, atten_mask, train
1040
+ )
1041
+ return outputs
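An end-to-end sketch for a small stack (sizes assumed); with the default scan=False each layer gets its own parameters under x_layers_{i}, while scan=True routes through the Repeat layer above.

```python
import jax
import jax.numpy as jnp
from videoprism import layers as vp_layers

encoder = vp_layers.StackedTransformer(
    num_layers=2, hidden_dim=64, num_heads=4, dropout_prob=0.1)
x = jnp.ones((2, 7, 16))                               # [B, T, H]
paddings = jnp.zeros((2, 7), jnp.float32)
params = encoder.init(jax.random.PRNGKey(0), x, paddings, train=False)
y = encoder.apply(params, x, paddings, train=False)    # [2, 7, 16]
```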
1042
+
1043
+
1044
+ class AttenTokenPoolingLayer(Module):
1045
+ """Attentional token pooling layer.
1046
+
1047
+ Attributes:
1048
+ query_dim: The query dimension of attention. If None then query_dim ==
1049
+ input_dim.
1050
+ hidden_dim: The hidden layer dimension of FFN in Transformer layers.
1051
+ num_heads: Number of attention heads.
1052
+ num_queries: Number of attention queries.
1053
+ add_layer_norm: Whether to apply layer norm to the pooled tokens.
1054
+ dropout_prob: The probability of dropout on the pooled tokens.
1055
+ use_qk_norm: If QK norm is used.
1056
+ use_bias: Whether to use bias.
1057
+ internal_enable_per_dim_scale: Internal. Setting to False disables rescaling
1058
+ of attention logits with 1/sqrt(dim) factor.
1059
+ """
1060
+
1061
+ query_dim: int | None = None
1062
+ hidden_dim: int = 0
1063
+ num_heads: int = 1
1064
+ num_queries: int = 1
1065
+ add_layer_norm: bool = True
1066
+ dropout_prob: float = 0.0
1067
+ use_qk_norm: bool = False
1068
+ use_bias: bool = True
1069
+ internal_enable_per_dim_scale: bool = True
1070
+
1071
+ @nn.compact
1072
+ def __call__(
1073
+ self,
1074
+ tokens: Array,
1075
+ paddings: Array | None,
1076
+ train: bool,
1077
+ ) -> Array:
1078
+ """Computes the pooled tokens for inputs.
1079
+
1080
+ Args:
1081
+ tokens: Input tokens of shape [B, T, H].
1082
+ paddings: Input paddings of shape [B, T].
1083
+ train: If the model is in the train mode.
1084
+
1085
+ Returns:
1086
+ Output vector with shape [B, N, D].
1087
+ """
1088
+ input_dim = tokens.shape[-1]
1089
+ query_dim = self.query_dim or input_dim
1090
+ hidden_dim = self.hidden_dim if self.hidden_dim > 0 else 4 * input_dim
1091
+ batch_size, seq_length = tokens.shape[0], tokens.shape[-2]
1092
+
1093
+ query = self._cast_to_fprop_dtype(
1094
+ self.param(
1095
+ 'pooling_attention_query',
1096
+ default_kernel_init,
1097
+ [self.num_queries, query_dim],
1098
+ self.dtype,
1099
+ )
1100
+ )
1101
+ query = jnp.tile(query[jnp.newaxis, :, :], [batch_size, 1, 1])
1102
+
1103
+ if paddings is None:
1104
+ paddings = jnp.zeros([batch_size, seq_length], dtype=tokens.dtype)
1105
+
1106
+ atten_mask = _convert_paddings_to_mask(paddings, dtype=paddings.dtype)
1107
+ outputs, _ = DotProductAttention(
1108
+ name='pooling_attention',
1109
+ hidden_dim=hidden_dim,
1110
+ num_heads=self.num_heads,
1111
+ use_bias=self.use_bias,
1112
+ internal_enable_per_dim_scale=self.internal_enable_per_dim_scale,
1113
+ use_qk_norm=self.use_qk_norm,
1114
+ dtype=self.dtype,
1115
+ fprop_dtype=self.fprop_dtype,
1116
+ )(
1117
+ query,
1118
+ tokens,
1119
+ tokens,
1120
+ atten_mask=atten_mask,
1121
+ train=train,
1122
+ )
1123
+
1124
+ if self.add_layer_norm:
1125
+ outputs = LayerNorm(
1126
+ name='pooling_attention_layer_norm',
1127
+ dtype=self.dtype,
1128
+ fprop_dtype=self.fprop_dtype,
1129
+ )(outputs)
1130
+
1131
+ if self.dropout_prob > 0.0:
1132
+ outputs = nn.Dropout(self.dropout_prob, name='attention_dropout')(
1133
+ outputs, deterministic=not train
1134
+ )
1135
+
1136
+ return outputs
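A pooling sketch under assumed sizes: num_queries learned query vectors cross-attend over the token sequence and return [B, num_queries, D].

```python
import jax
import jax.numpy as jnp
from videoprism import layers as vp_layers

pool = vp_layers.AttenTokenPoolingLayer(num_heads=2, num_queries=1)
tokens = jnp.ones((2, 10, 16))                          # [B, T, H]
paddings = jnp.zeros((2, 10), jnp.float32)
params = pool.init(jax.random.PRNGKey(0), tokens, paddings, train=False)
pooled = pool.apply(params, tokens, paddings, train=False)   # [2, 1, 16]
```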