singlebehaviorlab 2.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sam2/__init__.py +11 -0
- sam2/automatic_mask_generator.py +454 -0
- sam2/benchmark.py +92 -0
- sam2/build_sam.py +174 -0
- sam2/configs/sam2/sam2_hiera_b+.yaml +113 -0
- sam2/configs/sam2/sam2_hiera_l.yaml +117 -0
- sam2/configs/sam2/sam2_hiera_s.yaml +116 -0
- sam2/configs/sam2/sam2_hiera_t.yaml +118 -0
- sam2/configs/sam2.1/sam2.1_hiera_b+.yaml +116 -0
- sam2/configs/sam2.1/sam2.1_hiera_l.yaml +120 -0
- sam2/configs/sam2.1/sam2.1_hiera_s.yaml +119 -0
- sam2/configs/sam2.1/sam2.1_hiera_t.yaml +121 -0
- sam2/configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml +339 -0
- sam2/modeling/__init__.py +5 -0
- sam2/modeling/backbones/__init__.py +5 -0
- sam2/modeling/backbones/hieradet.py +317 -0
- sam2/modeling/backbones/image_encoder.py +134 -0
- sam2/modeling/backbones/utils.py +93 -0
- sam2/modeling/memory_attention.py +169 -0
- sam2/modeling/memory_encoder.py +181 -0
- sam2/modeling/position_encoding.py +239 -0
- sam2/modeling/sam/__init__.py +5 -0
- sam2/modeling/sam/mask_decoder.py +295 -0
- sam2/modeling/sam/prompt_encoder.py +202 -0
- sam2/modeling/sam/transformer.py +311 -0
- sam2/modeling/sam2_base.py +913 -0
- sam2/modeling/sam2_utils.py +323 -0
- sam2/sam2_hiera_b+.yaml +113 -0
- sam2/sam2_hiera_l.yaml +117 -0
- sam2/sam2_hiera_s.yaml +116 -0
- sam2/sam2_hiera_t.yaml +118 -0
- sam2/sam2_image_predictor.py +466 -0
- sam2/sam2_video_predictor.py +1388 -0
- sam2/sam2_video_predictor_legacy.py +1172 -0
- sam2/utils/__init__.py +5 -0
- sam2/utils/amg.py +348 -0
- sam2/utils/misc.py +349 -0
- sam2/utils/transforms.py +118 -0
- singlebehaviorlab/__init__.py +4 -0
- singlebehaviorlab/__main__.py +130 -0
- singlebehaviorlab/_paths.py +100 -0
- singlebehaviorlab/backend/__init__.py +2 -0
- singlebehaviorlab/backend/augmentations.py +320 -0
- singlebehaviorlab/backend/data_store.py +420 -0
- singlebehaviorlab/backend/model.py +1290 -0
- singlebehaviorlab/backend/train.py +4667 -0
- singlebehaviorlab/backend/uncertainty.py +578 -0
- singlebehaviorlab/backend/video_processor.py +688 -0
- singlebehaviorlab/backend/video_utils.py +139 -0
- singlebehaviorlab/data/config/config.yaml +85 -0
- singlebehaviorlab/data/training_profiles.json +334 -0
- singlebehaviorlab/gui/__init__.py +4 -0
- singlebehaviorlab/gui/analysis_widget.py +2291 -0
- singlebehaviorlab/gui/attention_export.py +311 -0
- singlebehaviorlab/gui/clip_extraction_widget.py +481 -0
- singlebehaviorlab/gui/clustering_widget.py +3187 -0
- singlebehaviorlab/gui/inference_popups.py +1138 -0
- singlebehaviorlab/gui/inference_widget.py +4550 -0
- singlebehaviorlab/gui/inference_worker.py +651 -0
- singlebehaviorlab/gui/labeling_widget.py +2324 -0
- singlebehaviorlab/gui/main_window.py +754 -0
- singlebehaviorlab/gui/metadata_management_widget.py +1119 -0
- singlebehaviorlab/gui/motion_tracking.py +764 -0
- singlebehaviorlab/gui/overlay_export.py +1234 -0
- singlebehaviorlab/gui/plot_integration.py +729 -0
- singlebehaviorlab/gui/qt_helpers.py +29 -0
- singlebehaviorlab/gui/registration_widget.py +1485 -0
- singlebehaviorlab/gui/review_widget.py +1330 -0
- singlebehaviorlab/gui/segmentation_tracking_widget.py +2752 -0
- singlebehaviorlab/gui/tab_tutorial_dialog.py +312 -0
- singlebehaviorlab/gui/timeline_themes.py +131 -0
- singlebehaviorlab/gui/training_profiles.py +418 -0
- singlebehaviorlab/gui/training_widget.py +3719 -0
- singlebehaviorlab/gui/video_utils.py +233 -0
- singlebehaviorlab/licenses/SAM2-LICENSE +201 -0
- singlebehaviorlab/licenses/VideoPrism-LICENSE +202 -0
- singlebehaviorlab-2.0.0.dist-info/METADATA +447 -0
- singlebehaviorlab-2.0.0.dist-info/RECORD +88 -0
- singlebehaviorlab-2.0.0.dist-info/WHEEL +5 -0
- singlebehaviorlab-2.0.0.dist-info/entry_points.txt +2 -0
- singlebehaviorlab-2.0.0.dist-info/licenses/LICENSE +21 -0
- singlebehaviorlab-2.0.0.dist-info/top_level.txt +3 -0
- videoprism/__init__.py +0 -0
- videoprism/encoders.py +910 -0
- videoprism/layers.py +1136 -0
- videoprism/models.py +407 -0
- videoprism/tokenizers.py +167 -0
- videoprism/utils.py +168 -0
videoprism/layers.py
ADDED
|
@@ -0,0 +1,1136 @@
|
|
|
1
|
+
# Copyright 2026 VideoPrism Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""VideoPrism Flax layers."""
|
|
16
|
+
|
|
17
|
+
from collections.abc import Callable
|
|
18
|
+
import functools
|
|
19
|
+
import string
|
|
20
|
+
from typing import Any
|
|
21
|
+
from flax import linen as nn
|
|
22
|
+
import jax
|
|
23
|
+
from jax import numpy as jnp
|
|
24
|
+
import numpy as np
|
|
25
|
+
|
|
26
|
+
Array = jax.Array
|
|
27
|
+
ActivationFunc = Callable[[Array], Array]
|
|
28
|
+
Initializer = nn.initializers.Initializer
|
|
29
|
+
|
|
30
|
+
default_kernel_init = nn.initializers.lecun_normal()
|
|
31
|
+
gelu = functools.partial(jax.nn.gelu, approximate=False)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def identity(x: Array) -> Array:
  """Returns the input unchanged (no-op activation)."""
  return x
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _get_large_negative_number(dtype: jax.typing.DTypeLike) -> Array:
|
|
40
|
+
"""Returns a large-magnitude negative value for the given dtype."""
|
|
41
|
+
# -0.7 is a float64 in JAX. Explicit cast output to target dtype.
|
|
42
|
+
if jnp.issubdtype(dtype, jnp.inexact):
|
|
43
|
+
dtype_max = jnp.finfo(dtype).max
|
|
44
|
+
elif jnp.issubdtype(dtype, jnp.integer):
|
|
45
|
+
dtype_max = jnp.iinfo(dtype).max
|
|
46
|
+
else:
|
|
47
|
+
raise ValueError('Unsupported dtype for inputs.')
|
|
48
|
+
return jnp.asarray(-0.7 * dtype_max, dtype=dtype)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _apply_mask_to_logits(logits: Array, mask: Array) -> Array:
  """Applies a floating-point mask to a set of logits.

  Positions whose mask value stays at or above half of the dtype's large
  negative sentinel are treated as "keep"; every other position is replaced
  by the sentinel itself. Selecting with `jnp.where` on the predicate —
  rather than adding the mask to the logits — avoids a bad compiler fusion
  decision that would materialize the float32 mask values in memory instead
  of just the boolean predicate.

  Args:
    logits: A jax.Array of logit values.
    mask: A jax.Array (float32) whose entries are 0 for "keep" and values
      below _get_large_negative_number(jnp.float32) / 2 for "drop".

  Returns:
    Masked logits.
  """
  sentinel = _get_large_negative_number(logits.dtype)
  keep = mask >= sentinel * 0.5
  return jnp.where(keep, logits, sentinel)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _convert_paddings_to_mask(
    paddings: Array, dtype: jax.typing.DTypeLike = jnp.float32
) -> Array:
  """Converts binary paddings to a logit mask ready to add to attention matrix.

  Args:
    paddings: A binary jax.Array of shape [B, T], with 1 denoting padding
      token.
    dtype: Data type of the input.

  Returns:
    A jax.Array of shape [B, 1, 1, T] ready to be added to attention logits.
  """
  # Insert singleton head and query axes, then scale the 1-valued (padded)
  # entries to a large negative logit; 0-valued entries stay 0.
  expanded = paddings[:, jnp.newaxis, jnp.newaxis, :]
  return expanded * _get_large_negative_number(dtype)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _causal_mask(input_t: Array) -> Array:
  """Computes and returns causal mask.

  Args:
    input_t: A jax.Array of shape [B, T, D].

  Returns:
    An attention_mask jax.Array of shape [1, 1, T, T]. Future (strictly
    upper-triangular) positions already hold large negative values; allowed
    positions hold zeros.
  """
  assert jnp.issubdtype(input_t.dtype, jnp.floating), input_t.dtype
  seq_len = input_t.shape[-2]
  neg = _get_large_negative_number(input_t.dtype)
  rows = jnp.arange(seq_len)[:, jnp.newaxis]
  cols = jnp.arange(seq_len)[jnp.newaxis, :]
  # A query at position t may only attend to keys at positions s <= t.
  mask = jnp.where(rows < cols, neg, jnp.zeros([], input_t.dtype))
  return mask[jnp.newaxis, jnp.newaxis, :, :]
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _merge_masks(a: Array, b: Array) -> Array:
|
|
112
|
+
"""Merges two masks.
|
|
113
|
+
|
|
114
|
+
This function merges two masks with the same shape, where the smaller value
|
|
115
|
+
will be chosen at the same position. Log-scale mask is expected but 0/1 mask
|
|
116
|
+
is also fine.
|
|
117
|
+
|
|
118
|
+
Args:
|
|
119
|
+
a: A jax.Array of shape [1|B, 1, 1|T, S].
|
|
120
|
+
b: A jax.Array of shape [1|B, 1, 1|T, S].
|
|
121
|
+
|
|
122
|
+
Returns:
|
|
123
|
+
A jax.Array of shape [1|B, 1, 1|T, S].
|
|
124
|
+
"""
|
|
125
|
+
|
|
126
|
+
def expand_t(key_mask):
|
|
127
|
+
"""Expands the 1D mask to the 2D mask.
|
|
128
|
+
|
|
129
|
+
Given [[1, 1, 0, 0]], this function returns the following mask,
|
|
130
|
+
1 1 0 0
|
|
131
|
+
1 1 0 0
|
|
132
|
+
0 0 0 0
|
|
133
|
+
0 0 0 0
|
|
134
|
+
|
|
135
|
+
Args:
|
|
136
|
+
key_mask: A jax.Array of the input 1D mask.
|
|
137
|
+
|
|
138
|
+
Returns:
|
|
139
|
+
A jax.Array of the expanded 2D mask.
|
|
140
|
+
"""
|
|
141
|
+
query_mask = jnp.transpose(key_mask, [0, 1, 3, 2])
|
|
142
|
+
return jnp.minimum(query_mask, key_mask)
|
|
143
|
+
|
|
144
|
+
if a.shape[-2] != b.shape[-2]:
|
|
145
|
+
if a.shape[-2] == 1:
|
|
146
|
+
a = expand_t(a)
|
|
147
|
+
else:
|
|
148
|
+
assert b.shape[-2] == 1
|
|
149
|
+
b = expand_t(b)
|
|
150
|
+
|
|
151
|
+
assert a.shape[-3:] == b.shape[-3:], f'a.shape={a.shape}, b.shape={b.shape}.'
|
|
152
|
+
return jnp.minimum(a, b)
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def compute_attention_masks_for_fprop(
    inputs: Array,
    paddings: Array,
    causal_attention: bool = False,
) -> Array:
  """Computes attention mask from inputs and paddings for fprop.

  Args:
    inputs: Input sequence jax.Array of shape [B, T, H].
    paddings: Input paddings jax.Array of shape [B, T].
    causal_attention: Boolean to apply causal masking.

  Returns:
    attention_mask: Attention mask jax.Array ready to be added to logits for
      self-attention of shape [1|B, 1, 1|T, T].
  """
  # Padding-derived mask of shape [B, 1, 1, T].
  mask = _convert_paddings_to_mask(paddings, inputs.dtype)
  if not causal_attention:
    return mask
  # Intersect with the [1, 1, T, T] causal mask when requested.
  return _merge_masks(mask, _causal_mask(inputs))
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
class Module(nn.Module):
  """Base class for layers with dtype configured.

  Attributes:
    dtype: Default dtype for all variables (parameter storage dtype).
    fprop_dtype: Activations dtype to use in the forward pass.
  """

  dtype: jnp.dtype = jnp.float32
  fprop_dtype: jnp.dtype = jnp.float32

  @nn.nowrap
  def _cast_to_fprop_dtype(self, value: Any) -> Any:
    """Casts floating-point leaves of a pytree to `fprop_dtype`.

    `None` leaves and non-floating leaves (e.g. integer arrays) pass through
    unchanged.
    """

    def _maybe_cast(leaf):
      if leaf is None:
        return None
      # Only floating-point leaves are recast; ints/bools keep their dtype.
      if leaf.dtype != self.fprop_dtype and jnp.issubdtype(
          leaf.dtype, jnp.floating
      ):
        return leaf.astype(self.fprop_dtype)
      return leaf

    return jax.tree_util.tree_map(_maybe_cast, value)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
class LayerNorm(Module):
  """Layer normalization.

  Attributes:
    direct_scale: Whether to apply scale directly without a +1.0. Var is
      initialized to 1.0 instead when True.
    epsilon: Tiny value to guard rsqrt.
    use_scale: Whether to use a learned scaling.
    use_bias: Whether to use bias.
    reductions_in_fp32: Whether to compute mean and variance in fp32.
      Recommended for stable training on GPUs.
  """

  direct_scale: bool = False
  epsilon: float = 1e-6
  use_scale: bool = True
  use_bias: bool = True
  reductions_in_fp32: bool = False

  @nn.compact
  def __call__(self, inputs: Array) -> Array:
    """Applies layer norm to inputs.

    Args:
      inputs: A jax.Array for the inputs of shape [..., dim].

    Returns:
      A jax.Array for the normalized inputs of the same shape.
    """
    orig_dtype = inputs.dtype
    if self.reductions_in_fp32:
      inputs = inputs.astype(jnp.float32)

    # Normalize over the trailing feature dimension.
    mean = jnp.mean(inputs, axis=[-1], keepdims=True)
    centered = inputs - mean
    var = jnp.mean(jnp.square(centered), axis=[-1], keepdims=True)
    normed = centered * jax.lax.rsqrt(var + self.epsilon)
    if self.reductions_in_fp32:
      normed = normed.astype(orig_dtype)

    feature_dim = inputs.shape[-1]
    if self.use_scale:
      # With direct_scale=False the variable starts at 0 and is shifted by
      # +1.0 at apply time, so training begins from an identity scaling.
      init_value = 1.0 if self.direct_scale else 0.0
      scale = self._cast_to_fprop_dtype(
          self.param(
              'scale',
              nn.initializers.constant(init_value),
              [feature_dim],
              self.dtype,
          )
      )
      if not self.direct_scale:
        scale += 1.0
      normed *= scale
    if self.use_bias:
      bias = self._cast_to_fprop_dtype(
          self.param(
              'bias',
              nn.initializers.zeros_init(),
              [feature_dim],
              self.dtype,
          )
      )
      normed += bias
    return normed
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
class FeedForward(Module):
  """Feedforward layer with activation.

  Attributes:
    output_dim: Depth of the output.
    has_bias: Adds bias weights or not.
    activation_fn: Activation function to use.
    weight_init: Initializer function for the weight matrix.
    bias_init: Initializer function for the bias.
  """

  output_dim: int = 0
  has_bias: bool = True
  activation_fn: ActivationFunc = nn.relu
  weight_init: Initializer = default_kernel_init
  bias_init: Initializer = nn.initializers.zeros_init()

  @nn.compact
  def __call__(self, inputs: Array) -> Array:
    """Projects inputs to output_dim and applies the activation."""

    def _promote_dtype(x, kernel, bias, dtype):
      """Casts the Dense params (not the inputs) to `fprop_dtype`."""
      del dtype
      # To stay compatible with the other layers here, the inputs are not
      # promoted: they are already expected to be in `fprop_dtype`.
      return (
          x,
          self._cast_to_fprop_dtype(kernel),
          self._cast_to_fprop_dtype(bias),
      )

    dense = nn.Dense(
        self.output_dim,
        use_bias=self.has_bias,
        kernel_init=self.weight_init,
        bias_init=self.bias_init,
        name='linear',
        param_dtype=self.dtype,
        promote_dtype=_promote_dtype,
    )
    return self.activation_fn(dense(inputs))
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
class TransformerFeedForward(Module):
  """Transformer feedforward layer with residual connection and dropout.

  Attributes:
    output_dim: Depth of the output. The value of input_dim will be used when
      output_dim is 0. Must be equal to input_dim if add_skip_connection=True.
    hidden_dim: Hidden dimension of FFN.
    has_bias: Adds bias weights to Feedforward or not.
    activation_fn: Activation function to use.
    residual_dropout_prob: Residual dropout.
    relu_dropout_prob: FFN dropout.
    add_skip_connection: Whether to add residual connection.
    residual_weight: Weight of the residual connection. Output = fn(x) *
      residual_weight + x.
    norm_policy: Policy for applying normalization wrt. transformations.
      Options are: (1) "pre", applied before transformation. (2)
      "primer_hybrid", applied before and after transformation. (3) "post",
      applied after transformation, (4) "post_skip", applied after the skip
      connection.
  """

  output_dim: int = 0
  hidden_dim: int = 0
  has_bias: bool = True
  activation_fn: ActivationFunc = nn.relu
  residual_dropout_prob: float = 0.0
  relu_dropout_prob: float = 0.0
  add_skip_connection: bool = True
  residual_weight: float = 1.0
  norm_policy: str = 'pre'

  @nn.nowrap
  def _make_ln(self, name: str) -> LayerNorm:
    """Makes a LayerNorm module."""
    return LayerNorm(
        name=name,
        use_bias=self.has_bias,
        dtype=self.dtype,
        fprop_dtype=self.fprop_dtype,
    )

  @nn.nowrap
  def _make_ffn(
      self, output_dim: int, name: str, skip_activation: bool = False
  ) -> FeedForward:
    """Makes a FeedForward module."""
    return FeedForward(
        name=name,
        output_dim=output_dim,
        has_bias=self.has_bias,
        activation_fn=identity if skip_activation else self.activation_fn,
        dtype=self.dtype,
        fprop_dtype=self.fprop_dtype,
    )

  @nn.compact
  def __call__(
      self, inputs: Array, paddings: Array | None, train: bool
  ) -> Array:
    """Applies the FFN block (norm, two FFN layers, dropout, residual).

    Args:
      inputs: A jax.Array of shape [B, T, D].
      paddings: Optional binary jax.Array of shape [B, T] with 1 denoting a
        padding token; padded positions are zeroed after each FFN layer.
      train: Whether dropout is active.

    Returns:
      A jax.Array of shape [B, T, output_dim].

    Raises:
      ValueError: If add_skip_connection is True but output_dim differs from
        the input depth.
    """
    residual = inputs
    output_dim = self.output_dim
    if output_dim == 0:
      output_dim = inputs.shape[-1]
    if self.add_skip_connection and output_dim != inputs.shape[-1]:
      # Bug fix: the original message interpolated `self.input_dim`, an
      # attribute this module never defines, so reaching this branch raised
      # AttributeError instead of the intended ValueError. Use the actual
      # input depth instead.
      raise ValueError(
          'Skip connections are only supported when input_dim == output_dim '
          f'but got {inputs.shape[-1]} != {output_dim}'
      )

    # Expand paddings to last dim if not None to have shape [batch, seq_len, 1].
    if paddings is not None:
      paddings = jnp.expand_dims(paddings, axis=-1)

    if self.norm_policy == 'primer_hybrid':
      inputs = self._make_ln(name='pre_layer_norm')(inputs)
    elif self.norm_policy == 'pre':
      inputs = self._make_ln(name='layer_norm')(inputs)

    # Apply first FFN layer.
    activations = self._make_ffn(self.hidden_dim, name='ffn_layer1')(inputs)

    # Apply paddings if not None.
    if paddings is not None:
      activations *= 1.0 - paddings

    # Apply RELU dropout.
    activations = nn.Dropout(self.relu_dropout_prob, name='relu_dropout')(
        activations, deterministic=not train
    )
    # Apply second FFN layer (no activation on the output projection).
    outputs = self._make_ffn(
        output_dim, name='ffn_layer2', skip_activation=True
    )(activations)

    # Apply paddings if not None.
    if paddings is not None:
      outputs *= 1.0 - paddings

    # Apply Primer normalization before dropout.
    if self.norm_policy == 'primer_hybrid':
      outputs = self._make_ln(name='post_layer_norm')(outputs)
    elif self.norm_policy == 'post':
      outputs = self._make_ln(name='layer_norm')(outputs)

    # Apply residual dropout.
    outputs = nn.Dropout(self.residual_dropout_prob, name='residual_dropout')(
        outputs, deterministic=not train
    )
    # Apply skip connection.
    if self.add_skip_connection:
      outputs = residual + outputs * self.residual_weight

    if self.norm_policy == 'post_skip':
      outputs = self._make_ln(name='layer_norm')(outputs)

    return outputs
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
class AttentionProjection(Module):
  """Layer that computes multi heads projection.

  This layer is expected to be used within DotProductAttention below.

  Attributes:
    output_dim: Input dimension.
    num_heads: Number of attention heads.
    dim_per_head: Size of each head.
    is_output_projection: Whether it is out projection or not. If False, we
      use "...D,DNH->...NH" for query,key,value projection. Otherwise we use
      "...NH,DNH->...D" for output projection.
    use_bias: Whether to add bias in projection or not.
  """

  output_dim: int = 0
  num_heads: int = 0
  dim_per_head: int = 0
  is_output_projection: bool = False
  use_bias: bool = True

  @nn.compact
  def __call__(self, inputs: Array) -> Array:
    """Computes the multi headed projection for inputs.

    Args:
      inputs: A jax.Array with shape [..., num_heads, dim_per_head] if
        is_output_projection is True or [..., input_dim] otherwise.

    Returns:
      The projected jax.Array with shape [..., input_dim] if
      is_output_projection is True or [..., num_heads, dim_per_head]
      otherwise.
    """
    # Sort the available einsum symbols to avoid nondeterminism.
    free_syms = ''.join(sorted(set(string.ascii_uppercase) - set('DHN')))
    model_dim = (
        self.output_dim if self.is_output_projection else inputs.shape[-1]
    )
    rank = len(inputs.shape)
    heads_shape = [self.num_heads, self.dim_per_head]

    # Kernel is always stored as [D, N, H] regardless of direction.
    w = self._cast_to_fprop_dtype(
        self.param(
            'w', default_kernel_init, [model_dim] + heads_shape, self.dtype
        )
    )

    if self.is_output_projection:
      assert inputs.shape[-2:] == (self.num_heads, self.dim_per_head)
      batch = free_syms[: (rank - 2)]
      eqn = f'{batch}NH,DNH->{batch}D'
      bias_shape = [model_dim]
    else:
      batch = free_syms[: (rank - 1)] if rank else '...'
      eqn = f'{batch}D,DNH->{batch}NH'
      bias_shape = heads_shape

    projected = jnp.einsum(eqn, inputs, w)
    if self.use_bias:
      b = self._cast_to_fprop_dtype(
          self.param(
              'b', nn.initializers.zeros_init(), bias_shape, self.dtype
          )
      )
      projected = projected + b
    return projected
|
|
500
|
+
|
|
501
|
+
|
|
502
|
+
class PerDimScale(Module):
  """A layer to scale individual dimensions of the input."""

  @nn.compact
  def __call__(self, inputs: Array) -> Array:
    """Returns per_dim_scale * inputs / jnp.sqrt(dim).

    Args:
      inputs: A jax.Array with shape [..., dim].

    Returns:
      outputs: A jax.Array with shape [..., dim].
    """
    dim = inputs.shape[-1]
    per_dim_scale = self._cast_to_fprop_dtype(
        self.param(
            'per_dim_scale', nn.initializers.zeros_init(), [dim], self.dtype
        )
    )

    # 1.0 / jax.nn.softplus(0.0) = 1.442695041, hard-coded so the zero-init
    # variable yields an exact 1/sqrt(dim) scaling while avoiding unnecessary
    # XLA op fusion mess on TPU.
    r_softplus_0 = 1.442695041
    base = jnp.array(r_softplus_0 / np.sqrt(dim), dtype=self.fprop_dtype)
    scale = base * jax.nn.softplus(per_dim_scale)
    return inputs * scale
|
|
528
|
+
|
|
529
|
+
|
|
530
|
+
class DotProductAttention(Module):
|
|
531
|
+
"""Dot-product attention with multiple attention heads.
|
|
532
|
+
|
|
533
|
+
Attributes:
|
|
534
|
+
hidden_dim: Number of hidden nodes.
|
|
535
|
+
num_heads: Number of attention heads.
|
|
536
|
+
dim_per_head: Dimension of each attention head. If None then dim_per_head ==
|
|
537
|
+
hidden_dim // num_heads.
|
|
538
|
+
atten_dropout_prob: Probability at which we apply dropout to the attention
|
|
539
|
+
weights.
|
|
540
|
+
use_bias: Whether to use bias for projection layers.
|
|
541
|
+
internal_enable_query_scale: Internal. Enable scaling of query vector.
|
|
542
|
+
internal_enable_per_dim_scale: Internal. Setting to False disables rescaling
|
|
543
|
+
of attention logits with 1/sqrt(dim) factor. Some Transformer variants
|
|
544
|
+
(GShard, T5) use internal_enable_per_dim_scale=False and adjust
|
|
545
|
+
initialization of the linear transformations(einsums), in conjunction with
|
|
546
|
+
Adafactor optimizer.
|
|
547
|
+
scale_query_by_dim_per_head: whether to scale the query by dim_per_head,
|
|
548
|
+
instead of default hidden_dim // num_heads (only activated when
|
|
549
|
+
internal_enable_per_dim_scale = False).
|
|
550
|
+
scale_logits_by_head_dims: Enables a 1/sqrt(head dim) scaling to the logits.
|
|
551
|
+
This occurs prior to logit cap, if any.
|
|
552
|
+
atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a
|
|
553
|
+
positive value is specified. May not be supported by a subclass.
|
|
554
|
+
use_qk_norm: If QK norm is used.
|
|
555
|
+
"""
|
|
556
|
+
|
|
557
|
+
hidden_dim: int = 0
|
|
558
|
+
num_heads: int = 1
|
|
559
|
+
dim_per_head: int | None = None
|
|
560
|
+
atten_dropout_prob: float = 0.0
|
|
561
|
+
use_bias: bool = True
|
|
562
|
+
internal_enable_query_scale: bool = True
|
|
563
|
+
internal_enable_per_dim_scale: bool = True
|
|
564
|
+
scale_query_by_dim_per_head: bool = False
|
|
565
|
+
scale_logits_by_head_dims: bool = False
|
|
566
|
+
atten_logit_cap: float = 0.0
|
|
567
|
+
use_qk_norm: bool = False
|
|
568
|
+
|
|
569
|
+
def _scale_query(self, query: Array) -> Array:
|
|
570
|
+
"""Scales the query vector if enabled."""
|
|
571
|
+
if not self.internal_enable_query_scale:
|
|
572
|
+
return query
|
|
573
|
+
if self.internal_enable_per_dim_scale:
|
|
574
|
+
query = PerDimScale(
|
|
575
|
+
name='per_dim_scale', dtype=self.dtype, fprop_dtype=self.fprop_dtype
|
|
576
|
+
)(query)
|
|
577
|
+
else:
|
|
578
|
+
if self.scale_query_by_dim_per_head and self.dim_per_head is not None:
|
|
579
|
+
dim_per_head = self.dim_per_head
|
|
580
|
+
else:
|
|
581
|
+
dim_per_head = self.hidden_dim // self.num_heads
|
|
582
|
+
|
|
583
|
+
query *= dim_per_head**-0.5
|
|
584
|
+
return query
|
|
585
|
+
|
|
586
|
+
def _cap_logits(self, logits: Array) -> Array:
|
|
587
|
+
"""Caps the logits by p.atten_logit_cap with tanh, if enabled."""
|
|
588
|
+
if not self.atten_logit_cap or self.atten_logit_cap <= 0.0:
|
|
589
|
+
return logits
|
|
590
|
+
cap = jnp.array(self.atten_logit_cap, dtype=self.fprop_dtype)
|
|
591
|
+
# Note that since this caps the negative side as well, caller must defer the
|
|
592
|
+
# pad-with-very-negative-logits logic to after this function returns.
|
|
593
|
+
logits = cap * jnp.tanh(logits / cap)
|
|
594
|
+
return logits
|
|
595
|
+
|
|
596
|
+
def _atten_logits(self, query: Array, key: Array) -> Array:
|
|
597
|
+
"""Computes logits from query and key."""
|
|
598
|
+
logits = jnp.einsum('BTNH,BSNH->BNTS', query, key)
|
|
599
|
+
return logits
|
|
600
|
+
|
|
601
|
+
  def _dot_atten(
      self,
      query: Array,
      key: Array,
      value: Array,
      atten_mask: Array,
      train: bool,
  ) -> tuple[Array, Array]:
    """Main attention function.

    Pipeline: scale query -> raw logits -> optional head-dim scaling ->
    optional tanh cap -> fp32 softmax over masked logits -> dropout ->
    weighted sum of values.

    Args:
      query: A jax.Array of shape [B, T, N, H].
      key: A jax.Array of shape [B, S, N, H].
      value: A jax.Array of shape [B, S, N, H].
      atten_mask: A jax.Array of shape [1|B, 1, 1|T, S] which is a mask that
        is applied to prevent attention between unwanted pairs. This has
        already been converted into large negative logits. Note that the
        first and third dimension allow size 1 if the mask is shared by every
        item in the batch or every token in the target sequence.
      train: Whether the model is in the train mode (enables dropout).

    Returns:
      encoded: A jax.Array of shape [B, T, N, H].
      atten_probs: A jax.Array of shape [B, N, T, S].
    """
    assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
    assert (
        query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
    ), 'q, k, v batch dims must match.'
    assert (
        query.shape[-2] == key.shape[-2] == value.shape[-2]
    ), 'q, k, v num_heads must match.'
    assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
    # If only padding bias is supplied, then atten_mask can be [B, 1, 1, S]
    # since each target token is prohibited from attending to the same set of
    # source tokens. In this case tiling is inefficient and unnecessary.
    # If there is no padding mask, and only causal mask then the shape can be
    # [1, 1, T, S].
    assert atten_mask.ndim == 4 and atten_mask.shape[-1] == key.shape[-3]
    assert atten_mask.shape[-2] in [query.shape[-3], 1]
    assert atten_mask.shape[0] in [key.shape[0], 1]

    # Optional query scaling (PerDimScale or 1/sqrt(head_dim)).
    query = self._scale_query(query)
    logits = self._atten_logits(query, key)

    if self.scale_logits_by_head_dims:
      logits = jnp.multiply(logits, 1.0 / np.sqrt(key.shape[-1]))

    # Tanh cap must come before masking: it would squash the mask's large
    # negative logits otherwise.
    logits = self._cap_logits(logits)
    # Attention softmax is always carried out in fp32.
    logits = logits.astype(jnp.float32)
    # Apply attention masking.
    padded_logits = _apply_mask_to_logits(logits, atten_mask)
    probs = jax.nn.softmax(padded_logits, axis=-1).astype(self.fprop_dtype)
    # Apply attention dropout.
    probs = nn.Dropout(self.atten_dropout_prob, name='atten_dropout')(
        probs, deterministic=not train
    )
    # Compute the attention context.
    encoded = jnp.einsum('BNTS,BSNH->BTNH', probs, value)
    return encoded, probs
|
|
662
|
+
|
|
663
|
+
@nn.nowrap
def _project_input(self, name: str, dim_per_head: int) -> AttentionProjection:
  """Constructs one input-projection (q/k/v) submodule.

  Args:
    name: Flax submodule name for the projection (e.g. 'query').
    dim_per_head: Projection dimension of each attention head.

  Returns:
    An `AttentionProjection` configured from this layer's settings.
  """
  # All three of q/k/v share the same configuration apart from the name.
  projection_kwargs = dict(
      name=name,
      num_heads=self.num_heads,
      dim_per_head=dim_per_head,
      use_bias=self.use_bias,
      dtype=self.dtype,
      fprop_dtype=self.fprop_dtype,
  )
  return AttentionProjection(**projection_kwargs)
|
|
674
|
+
|
|
675
|
+
@nn.nowrap
def _make_ln(self, name: str) -> LayerNorm:
  """Constructs a `LayerNorm` submodule with this layer's dtype settings.

  Args:
    name: Flax submodule name for the norm.

  Returns:
    A configured `LayerNorm` module.
  """
  norm_kwargs = dict(
      name=name,
      use_bias=self.use_bias,
      dtype=self.dtype,
      fprop_dtype=self.fprop_dtype,
  )
  return LayerNorm(**norm_kwargs)
|
|
684
|
+
|
|
685
|
+
@nn.compact
def __call__(
    self,
    query_vec: Array,
    key_vec: Array,
    value_vec: Array,
    atten_mask: Array,
    train: bool,
) -> tuple[Array, Array]:
  """Computes the value vector given the current query output.

  Args:
    query_vec: jax.Array of shape [B, T, D].
    key_vec: jax.Array of shape [B, S, D].
    value_vec: jax.Array of shape [B, S, D].
    atten_mask: jax.Array of shape [1|B, 1, 1|T, S] which is a mask that is
      applied to prevent attention between unwanted pairs. This has already
      been converted into large negative logits. Note that the first and third
      dimension allow size 1 if the mask is shared by every item in the batch
      or every token in the target sequence.
    train: If the model is in the train mode.

  Returns:
    encoded: jax.Array of shape [B, T, D].
    atten_probs: jax.Array of shape [B, N, T, S].
  """
  # Resolve the per-head dimension; when unset it is derived from hidden_dim
  # and must divide it exactly.
  dim_per_head = self.dim_per_head
  if dim_per_head is None:
    dim_per_head = self.hidden_dim // self.num_heads
    assert (
        dim_per_head * self.num_heads == self.hidden_dim
    ), f'{dim_per_head} * {self.num_heads} != {self.hidden_dim}'

  # Project inputs to key, value and query, respectively has shape
  # [B, S, N, H], [B, S, N, H], and [B, T, N, H].
  query_proj = self._project_input('query', dim_per_head)(query_vec)
  key_proj = self._project_input('key', dim_per_head)(key_vec)
  value_proj = self._project_input('value', dim_per_head)(value_vec)

  # Optional QK-norm: normalize queries and keys before computing logits.
  if self.use_qk_norm:
    query_proj = self._make_ln(name='layer_norm_q')(query_proj)
    key_proj = self._make_ln(name='layer_norm_k')(key_proj)

  encoded, atten_probs = self._dot_atten(
      query_proj, key_proj, value_proj, atten_mask, train=train
  )

  # Post projection. Setting is_output_projection=True to set the projection
  # direction from hidden dim to input dim. Output projection follows
  # query_input_dim.
  query_input_dim = query_vec.shape[-1]
  encoded = AttentionProjection(
      name='post',
      output_dim=query_input_dim,
      num_heads=self.num_heads,
      dim_per_head=dim_per_head,
      is_output_projection=True,
      use_bias=self.use_bias,
      dtype=self.dtype,
      fprop_dtype=self.fprop_dtype,
  )(encoded)
  return encoded, atten_probs
|
|
747
|
+
|
|
748
|
+
|
|
749
|
+
class Transformer(Module):
  """Transformer layer with multi-headed attention.

  Attributes:
    hidden_dim: Hidden dimension of FFN layer.
    num_heads: Number of heads in self-attention.
    dim_per_head: Dimension of each attention head. If None then dim_per_head ==
      hidden_dim // num_heads.
    atten_dropout_prob: Probability at which we apply dropout to the attention
      weights.
    residual_dropout_prob: Probability at which we apply dropout to the residual
      layers, such that, residual(x, y) = (x + dropout(y)).
    relu_dropout_prob: Probability at which we apply dropout to the FFN layers.
    norm_policy: Policy for applying normalization wrt. transformations. Options
      are: (1) "pre", applied before transformation. (2) "primer_hybrid",
      applied before and after transformation. (3) "post", applied after
      transformation. (4) "post_skip", applied after the skip connection.
    use_bias: Whether to use bias.
    activation_fn: Activation function to use.
    internal_enable_per_dim_scale: Internal. Setting to False disables rescaling
      of attention logits with 1/sqrt(dim) factor.
    atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a
      positive value is specified. May not be supported by a subclass.
  """

  hidden_dim: int = 0
  num_heads: int = 0
  dim_per_head: int | None = None
  atten_dropout_prob: float = 0.0
  residual_dropout_prob: float = 0.0
  relu_dropout_prob: float = 0.0
  norm_policy: str = 'pre'
  use_bias: bool = True
  activation_fn: ActivationFunc = nn.relu
  internal_enable_per_dim_scale: bool = True
  atten_logit_cap: float = 0.0

  @nn.nowrap
  def _make_ln(self, name: str) -> LayerNorm:
    """Makes a LayerNorm module."""
    return LayerNorm(
        name=name,
        use_bias=self.use_bias,
        dtype=self.dtype,
        fprop_dtype=self.fprop_dtype,
    )

  @nn.compact
  def __call__(
      self,
      inputs: Array,
      paddings: Array,
      atten_mask: Array,
      train: bool,
  ) -> Array:
    """Transformer decoder layer.

    Args:
      inputs: Input sequence jax.Array of shape [B, T, H].
      paddings: Input paddings jax.Array of shape [B, T] (only used in FFN).
      atten_mask: Self attention mask ready to add to the logits. It can be of
        shape [1|B, 1, 1|T, T] which is broadcast compatible with the
        self-attention matrix of shape [B, N, T, T]. This is assumed to have
        combined paddings, causal masking as well as segment maskings.
      train: Whether the model is in the train mode.

    Returns:
      The fflayer output with shape [B, T, D].
    """

    # 'pre' and 'primer_hybrid' normalize before attention; 'post' and
    # 'post_skip' apply normalization later, so the input passes through here.
    if self.norm_policy == 'primer_hybrid':
      inputs_normalized = self._make_ln(name='pre_layer_norm')(inputs)
    elif self.norm_policy == 'pre':
      inputs_normalized = self._make_ln(name='layer_norm')(inputs)
    else:
      inputs_normalized = inputs

    # Compute self-attention, key/value vectors are the input itself.
    atten_outputs, _ = DotProductAttention(
        name='self_attention',
        hidden_dim=inputs_normalized.shape[-1],
        num_heads=self.num_heads,
        dim_per_head=self.dim_per_head,
        atten_dropout_prob=self.atten_dropout_prob,
        use_bias=self.use_bias,
        internal_enable_per_dim_scale=self.internal_enable_per_dim_scale,
        atten_logit_cap=self.atten_logit_cap,
        dtype=self.dtype,
        fprop_dtype=self.fprop_dtype,
    )(
        inputs_normalized,
        inputs_normalized,
        inputs_normalized,
        atten_mask=atten_mask,
        train=train,
    )

    # 'primer_hybrid' adds a second norm after attention; plain 'post'
    # normalizes the attention output before the residual add.
    if self.norm_policy == 'primer_hybrid':
      atten_outputs = self._make_ln(name='post_layer_norm')(atten_outputs)
    elif self.norm_policy == 'post':
      atten_outputs = self._make_ln(name='layer_norm')(atten_outputs)

    # Residual dropout and connection.
    atten_outputs = nn.Dropout(
        self.residual_dropout_prob, name='residual_dropout'
    )(atten_outputs, deterministic=not train)
    atten_outputs += inputs

    # 'post_skip' normalizes after the residual connection instead.
    if self.norm_policy == 'post_skip':
      atten_outputs = self._make_ln(name='layer_norm')(atten_outputs)

    # Apply FFN layer.
    outputs = TransformerFeedForward(
        name='ff_layer',
        hidden_dim=self.hidden_dim,
        has_bias=self.use_bias,
        activation_fn=self.activation_fn,
        residual_dropout_prob=self.residual_dropout_prob,
        relu_dropout_prob=self.relu_dropout_prob,
        norm_policy=self.norm_policy,
        dtype=self.dtype,
        fprop_dtype=self.fprop_dtype,
    )(atten_outputs, paddings=paddings, train=train)
    return outputs
|
|
873
|
+
|
|
874
|
+
|
|
875
|
+
class Repeat(nn.Module):
  """A generic repeat layer with `nn.remat` and `nn.scan`.

  Attributes:
    block_fn: The block function to repeat.
    times: The number of times to repeat block.
    checkpoint_policy: Checkpoint policy for `nn.remat`.
  """

  block_fn: Callable[..., Any]
  times: int = 0
  checkpoint_policy: str = 'nothing_saveable'

  def __call__(
      self,
      inputs: Array,
      *args: Any,
      **kwargs: Any,
  ) -> Any:
    """Forwards inputs through the block layer stack.

    Block outputs are expected to be of the same structure as inputs.

    Args:
      inputs: Inputs that go through the block layer stack.
      *args: Positional args to be passed to the forward method.
      **kwargs: Keyword args to be passed to the forward method.

    Returns:
      Output from the last layer.
    """
    return self.call_with_custom_method(
        inputs,
        *args,
        main_fn=self.block_fn,
        **kwargs,
    )

  def call_with_custom_method(
      self,
      inputs: Array,
      *args: Any,
      main_fn: Callable[..., Any],
      **kwargs: Any,
  ) -> Any:
    """Similar to __call__, but allows a custom way to create a layer method."""

    def body_fn(fn, layer_inputs):
      # Scan body: apply one block; the second element (per-step scan output)
      # is unused, hence None.
      return fn(layer_inputs, *args, **kwargs), None

    # Rematerialize activations inside the scan according to the named
    # jax.checkpoint_policies entry (None falls back to remat-everything).
    rematted_body_fn = nn.remat(
        body_fn,
        prevent_cse=False,
        policy=getattr(jax.checkpoint_policies, self.checkpoint_policy, None),
    )
    # Stack per-layer params along axis 0 and give each iteration fresh
    # 'params'/'dropout' RNGs.
    scan_fn = nn.scan(
        rematted_body_fn,
        variable_axes={'params': 0},
        split_rngs={'params': True, 'dropout': True},
        length=self.times,
    )
    outputs, _ = scan_fn(main_fn, inputs)
    return outputs
|
|
938
|
+
|
|
939
|
+
|
|
940
|
+
class StackedTransformer(Module):
  """A stack of Transformer layers.

  Attributes:
    num_layers: Number of layers in this stack.
    hidden_dim: The hidden layer dimension of FFN in Transformer layers.
    num_heads: Number of attention heads.
    dim_per_head: Dimension of each attention head. If None then dim_per_head ==
      model_dims // num_heads.
    dropout_prob: Apply dropout at this prob at various places.
    atten_dropout_prob: Probability at which we apply dropout to the attention
      weights. If None, falls back to dropout_prob.
    residual_dropout_prob: Probability at which we apply dropout to the residual
      layers, such that, residual(x, y) = (x + dropout(y)). If None, falls back
      to dropout_prob.
    relu_dropout_prob: Probability at which we apply dropout to the FFN layers.
      If None, falls back to dropout_prob.
    input_dropout_prob: Dropout probability applied to the input before any
      processing happens.
    norm_policy: Policy for applying normalization wrt. transformations. Options
      are: (1) "pre", applied before transformation. (2) "primer_hybrid",
      applied before and after transformation. (3) "post", applied after
      transformation. (4) "post_skip", applied after the skip connection.
    use_bias: Whether to use bias.
    activation_fn: Activation function to use.
    internal_enable_per_dim_scale: Internal. Setting to False disables rescaling
      of attention logits with 1/sqrt(dim) factor.
    atten_logit_cap: Cap the absolute values of logits by tanh. Enabled when a
      positive value is specified. May not be supported by a subclass.
    enable_causal_atten: Whether to enable causal attention.
    scan: Whether to use `nn.remat` and `nn.scan`.
  """

  num_layers: int = 0
  hidden_dim: int = 0
  num_heads: int = 0
  dim_per_head: int | None = None
  dropout_prob: float = 0.0
  atten_dropout_prob: float | None = None
  residual_dropout_prob: float | None = None
  relu_dropout_prob: float | None = None
  input_dropout_prob: float = 0.0
  norm_policy: str = 'pre'
  use_bias: bool = True
  activation_fn: ActivationFunc = nn.relu
  internal_enable_per_dim_scale: bool = True
  atten_logit_cap: float = 0.0
  enable_causal_atten: bool = False
  scan: bool = False

  @nn.nowrap
  def _resolve_dropout_prob(self, prob: float | None) -> float:
    """Returns `prob`, falling back to `dropout_prob` only when unset (None).

    Using `prob or self.dropout_prob` would be wrong here: an explicitly
    configured 0.0 is falsy and would silently be overridden by dropout_prob,
    re-enabling dropout the caller asked to disable. The declared sentinel for
    "unset" is None, so test for it explicitly.
    """
    return self.dropout_prob if prob is None else prob

  @nn.compact
  def __call__(
      self,
      inputs: Array,
      paddings: Array,
      train: bool,
  ) -> Array:
    """Stacked Transformer layer.

    Args:
      inputs: Input sequence of shape [B, T, H].
      paddings: Input paddings of shape [B, T].
      train: If the model is in the train mode.

    Returns:
      Output vector with shape [B, T, D].
    """

    # Build the combined (padding + optional causal) attention mask once and
    # share it across all layers.
    atten_mask = compute_attention_masks_for_fprop(
        inputs, paddings, causal_attention=self.enable_causal_atten
    )

    outputs = inputs
    if self.input_dropout_prob > 0.0:
      outputs = nn.Dropout(self.input_dropout_prob, name='input_dropout')(
          outputs, deterministic=not train
      )

    transformer_kwargs = dict(
        num_heads=self.num_heads,
        dim_per_head=self.dim_per_head,
        hidden_dim=self.hidden_dim,
        atten_dropout_prob=self._resolve_dropout_prob(self.atten_dropout_prob),
        residual_dropout_prob=self._resolve_dropout_prob(
            self.residual_dropout_prob
        ),
        relu_dropout_prob=self._resolve_dropout_prob(self.relu_dropout_prob),
        norm_policy=self.norm_policy,
        use_bias=self.use_bias,
        activation_fn=self.activation_fn,
        internal_enable_per_dim_scale=self.internal_enable_per_dim_scale,
        atten_logit_cap=self.atten_logit_cap,
        dtype=self.dtype,
        fprop_dtype=self.fprop_dtype,
    )
    if self.scan:
      # One shared block definition scanned num_layers times; params are
      # stacked along a leading axis by Repeat/nn.scan.
      block_fn = Transformer(name='x_layers', **transformer_kwargs)
      outputs = Repeat(block_fn=block_fn, times=self.num_layers)(
          outputs, paddings, atten_mask, train
      )
    else:
      # Materialize num_layers independent layers with distinct names.
      for i in range(self.num_layers):
        outputs = Transformer(name=f'x_layers_{i}', **transformer_kwargs)(
            outputs, paddings, atten_mask, train
        )
    return outputs
|
|
1042
|
+
|
|
1043
|
+
|
|
1044
|
+
class AttenTokenPoolingLayer(Module):
  """Attentional token pooling layer.

  Attributes:
    query_dim: The query dimension of attention. If None then query_dim ==
      input_dim.
    hidden_dim: The hidden layer dimension of FFN in Transformer layers.
    num_heads: Number of attention heads.
    num_queries: Number of attention queries.
    add_layer_norm: Whether to apply layer norm to the pooled tokens.
    dropout_prob: The probability of dropout on the pooled tokens.
    use_qk_norm: If QK norm is used.
    use_bias: Whether to use bias.
    internal_enable_per_dim_scale: Internal. Setting to False disables rescaling
      of attention logits with 1/sqrt(dim) factor.
  """

  query_dim: int | None = None
  hidden_dim: int = 0
  num_heads: int = 1
  num_queries: int = 1
  add_layer_norm: bool = True
  dropout_prob: float = 0.0
  use_qk_norm: bool = False
  use_bias: bool = True
  internal_enable_per_dim_scale: bool = True

  @nn.compact
  def __call__(
      self,
      tokens: Array,
      paddings: Array | None,
      train: bool,
  ) -> Array:
    """Computes the pooled tokens for inputs.

    Args:
      tokens: Input tokens of shape [B, T, H].
      paddings: Input paddings of shape [B, T].
      train: If the model is in the train mode.

    Returns:
      Output vector with shape [B, N, D].
    """
    input_dim = tokens.shape[-1]
    # Fall back to input_dim / 4*input_dim when these are unset.
    query_dim = self.query_dim or input_dim
    hidden_dim = self.hidden_dim if self.hidden_dim > 0 else 4 * input_dim
    batch_size, seq_length = tokens.shape[0], tokens.shape[-2]

    # Learned pooling queries of shape [num_queries, query_dim], cast to the
    # forward-prop dtype (param is stored in self.dtype).
    query = self._cast_to_fprop_dtype(
        self.param(
            'pooling_attention_query',
            default_kernel_init,
            [self.num_queries, query_dim],
            self.dtype,
        )
    )
    # Broadcast the shared queries across the batch: [B, num_queries, query_dim].
    query = jnp.tile(query[jnp.newaxis, :, :], [batch_size, 1, 1])

    # No paddings means nothing is masked out.
    if paddings is None:
      paddings = jnp.zeros([batch_size, seq_length], dtype=tokens.dtype)

    atten_mask = _convert_paddings_to_mask(paddings, dtype=paddings.dtype)
    # Cross-attention: learned queries attend over the input tokens.
    outputs, _ = DotProductAttention(
        name='pooling_attention',
        hidden_dim=hidden_dim,
        num_heads=self.num_heads,
        use_bias=self.use_bias,
        internal_enable_per_dim_scale=self.internal_enable_per_dim_scale,
        use_qk_norm=self.use_qk_norm,
        dtype=self.dtype,
        fprop_dtype=self.fprop_dtype,
    )(
        query,
        tokens,
        tokens,
        atten_mask=atten_mask,
        train=train,
    )

    if self.add_layer_norm:
      outputs = LayerNorm(
          name='pooling_attention_layer_norm',
          dtype=self.dtype,
          fprop_dtype=self.fprop_dtype,
      )(outputs)

    if self.dropout_prob > 0.0:
      outputs = nn.Dropout(self.dropout_prob, name='attention_dropout')(
          outputs, deterministic=not train
      )

    return outputs
|