nmn-0.1.2-py3-none-any.whl → nmn-0.1.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- nmn/nnx/examples/language/mingpt.py +1650 -0
- nmn/nnx/examples/vision/cnn_cifar.py +1769 -0
- nmn/nnx/nmn.py +1 -1
- nmn/nnx/yatattention.py +764 -0
- nmn/nnx/yatconv.py +22 -2
- nmn/torch/nmn.py +2 -1
- {nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/METADATA +2 -2
- nmn-0.1.4.dist-info/RECORD +14 -0
- nmn-0.1.2.dist-info/RECORD +0 -11
- {nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/WHEEL +0 -0
- {nmn-0.1.2.dist-info → nmn-0.1.4.dist-info}/licenses/LICENSE +0 -0
nmn/nnx/yatattention.py
ADDED
@@ -0,0 +1,764 @@
from __future__ import annotations

import functools
from typing import Any, Callable, Optional
import typing as tp

import jax
import jax.numpy as jnp
from jax import lax, random

from flax import nnx
from flax.nnx import rnglib
from flax.nnx.module import Module, first_from
from flax.nnx.nn import initializers
from flax.nnx.nn.dtypes import promote_dtype
from flax.nnx.nn.linear import (
  LinearGeneral,
  default_kernel_init,
)
from flax.nnx.nn.normalization import LayerNorm
from flax.typing import (
  Dtype,
  Shape,
  Initializer,
  PrecisionLike,
  DotGeneralT,
)


def yat_attention_weights(
  query: Array,
  key: Array,
  bias: Optional[Array] = None,
  mask: Optional[Array] = None,
  broadcast_dropout: bool = True,
  dropout_rng: Optional[Array] = None,
  dropout_rate: float = 0.0,
  deterministic: bool = False,
  dtype: Optional[Dtype] = None,
  precision: PrecisionLike = None,
  module: Optional[Module] = None,
  epsilon: float = 1e-5,
):
  """Computes attention weights using YatNMN distance-based calculation."""
  query, key = promote_dtype((query, key), dtype=dtype)
  dtype = query.dtype

  assert query.ndim == key.ndim, 'q, k must have same rank.'
  assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
  assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
  assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

  # YatNMN-style attention calculation using the cleaner approach
  # query shape: [..., q_length, num_heads, head_dim]
  # key shape: [..., kv_length, num_heads, head_dim]

  # Calculate dot product attention scores
  attn = jnp.einsum('...qhd,...khd->...hqk', query, key, precision=precision)
  squared_dot_product = jnp.square(attn)

  # Calculate norms
  q_norm = jnp.sum(jnp.square(query), axis=-1, keepdims=True)  # [..., q_length, num_heads, 1]
  k_norm = jnp.sum(jnp.square(key), axis=-1, keepdims=True)  # [..., kv_length, num_heads, 1]
  qk_norm_sum = q_norm + k_norm  # Broadcasting: [..., q_length, num_heads, 1] + [..., kv_length, num_heads, 1]

  # Transpose to match attention dimensions [..., num_heads, q_length, kv_length]
  # The transpose converts [..., q_length, num_heads, kv_length] -> [..., num_heads, q_length, kv_length]
  batch_dims = len(qk_norm_sum.shape) - 3
  transpose_axes = tuple(range(batch_dims)) + (batch_dims + 1, batch_dims, batch_dims + 2)
  qk_norm_sum_transposed = qk_norm_sum.transpose(transpose_axes)

  # Calculate squared distances: ||q||² + ||k||² - 2*(q·k)²
  squared_dist = qk_norm_sum_transposed - 2.0 * squared_dot_product

  # YatNMN attention scores: (q·k)² / (squared_distance + ε)
  attn_weights = squared_dot_product / (squared_dist + epsilon)

  # apply attention bias: masking, dropout, proximity bias, etc.
  if bias is not None:
    attn_weights = attn_weights + bias
  # apply attention mask
  if mask is not None:
    big_neg = jnp.finfo(dtype).min
    attn_weights = jnp.where(mask, attn_weights, big_neg)

  # normalize the attention weights
  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

  if module:
    module.sow(nnx.Intermediate, 'attention_weights', attn_weights)

  # apply attention dropout
  if not deterministic and dropout_rate > 0.0:
    keep_prob = 1.0 - dropout_rate
    if broadcast_dropout:
      # dropout is broadcast across the batch + head dimensions
      dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
      keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
    else:
      keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
    multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
    attn_weights = attn_weights * multiplier

  return attn_weights

def yat_attention(
  query: Array,
  key: Array,
  value: Array,
  bias: Optional[Array] = None,
  mask: Optional[Array] = None,
  broadcast_dropout: bool = True,
  dropout_rng: Optional[Array] = None,
  dropout_rate: float = 0.0,
  deterministic: bool = False,
  dtype: Optional[Dtype] = None,
  precision: PrecisionLike = None,
  module: Optional[Module] = None,
  epsilon: float = 1e-5,
):
  """Computes attention using YatNMN distance-based calculation."""
  query, key, value = promote_dtype((query, key, value), dtype=dtype)
  dtype = query.dtype
  assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
  assert (
    query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
  ), 'q, k, v batch dims must match.'
  assert (
    query.shape[-2] == key.shape[-2] == value.shape[-2]
  ), 'q, k, v num_heads must match.'
  assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'

  # compute attention weights using YatNMN
  attn_weights = yat_attention_weights(
    query,
    key,
    bias,
    mask,
    broadcast_dropout,
    dropout_rng,
    dropout_rate,
    deterministic,
    dtype,
    precision,
    module,
    epsilon,
  )

  # return weighted sum over values for each query position
  return jnp.einsum(
    '...hqk,...khd->...qhd', attn_weights, value, precision=precision
  )


Array = jax.Array

# Add YatNMN class implementation
default_bias_init = initializers.zeros_init()
default_alpha_init = initializers.ones_init()

class YatNMN(Module):
  """A linear transformation with custom distance-based computation."""

  def __init__(
    self,
    in_features: int,
    out_features: int,
    *,
    use_bias: bool = True,
    use_alpha: bool = True,
    dtype: Optional[Dtype] = None,
    param_dtype: Dtype = jnp.float32,
    precision: PrecisionLike = None,
    kernel_init: Initializer = default_kernel_init,
    bias_init: Initializer = default_bias_init,
    alpha_init: Initializer = default_alpha_init,
    dot_general: DotGeneralT = lax.dot_general,
    rngs: rnglib.Rngs,
    epsilon: float = 1e-5,
  ):

    kernel_key = rngs.params()
    self.kernel = nnx.Param(
      kernel_init(kernel_key, (in_features, out_features), param_dtype)
    )
    self.bias: nnx.Param[jax.Array] | None
    if use_bias:
      bias_key = rngs.params()
      self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
    else:
      self.bias = None

    self.alpha: nnx.Param[jax.Array] | None
    if use_alpha:
      alpha_key = rngs.params()
      self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
    else:
      self.alpha = None

    self.in_features = in_features
    self.out_features = out_features
    self.use_bias = use_bias
    self.use_alpha = use_alpha
    self.dtype = dtype
    self.param_dtype = param_dtype
    self.precision = precision
    self.kernel_init = kernel_init
    self.bias_init = bias_init
    self.dot_general = dot_general
    self.epsilon = epsilon

  def __call__(self, inputs: Array) -> Array:
    """Applies YatNMN transformation to inputs."""
    kernel = self.kernel.value
    bias = self.bias.value if self.bias is not None else None
    alpha = self.alpha.value if self.alpha is not None else None

    y = self.dot_general(
      inputs,
      kernel,
      (((inputs.ndim - 1,), (0,)), ((), ())),
      precision=self.precision,
    )

    inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
    kernel_squared_sum = jnp.sum(kernel**2, axis=0, keepdims=True)
    distances = inputs_squared_sum + kernel_squared_sum - 2 * y

    # Element-wise operation
    y = y ** 2 / (distances + self.epsilon)

    if bias is not None:
      y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))

    if alpha is not None:
      scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
      y = y * scale

    return y


def dot_product_attention_weights(
  query: Array,
  key: Array,
  bias: Optional[Array] = None,
  mask: Optional[Array] = None,
  broadcast_dropout: bool = True,
  dropout_rng: Optional[Array] = None,
  dropout_rate: float = 0.0,
  deterministic: bool = False,
  dtype: Optional[Dtype] = None,
  precision: PrecisionLike = None,
  module: Optional[Module] = None,
):
  """Computes dot-product attention weights given query and key.

  Used by :func:`dot_product_attention`, which is what you'll most likely use.
  But if you want access to the attention weights for introspection, then
  you can directly call this function and call einsum yourself.

  Args:
    query: queries for calculating attention with shape of `[batch..., q_length,
      num_heads, qk_depth_per_head]`.
    key: keys for calculating attention with shape of `[batch..., kv_length,
      num_heads, qk_depth_per_head]`.
    bias: bias for the attention weights. This should be broadcastable to the
      shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
      incorporating causal masks, padding masks, proximity bias, etc.
    mask: mask for the attention weights. This should be broadcastable to the
      shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
      incorporating causal masks. Attention weights are masked out if their
      corresponding mask value is `False`.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    dtype: the dtype of the computation (default: infer from inputs and params)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    module: the Module that will sow the attention weights into the
      ``nnx.Intermediate`` collection. If ``module`` is None, the attention
      weights will not be sowed.

  Returns:
    Output of shape `[batch..., num_heads, q_length, kv_length]`.
  """
  query, key = promote_dtype((query, key), dtype=dtype)  # type: ignore[bad-unpacking]
  dtype = query.dtype

  assert query.ndim == key.ndim, 'q, k must have same rank.'
  assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
  assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
  assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'

  # calculate attention matrix
  depth = query.shape[-1]
  query = query / jnp.sqrt(depth).astype(dtype)
  # attn weight shape is (batch..., num_heads, q_length, kv_length)
  attn_weights = jnp.einsum(
    '...qhd,...khd->...hqk', query, key, precision=precision
  )

  # apply attention bias: masking, dropout, proximity bias, etc.
  if bias is not None:
    attn_weights = attn_weights + bias
  # apply attention mask
  if mask is not None:
    big_neg = jnp.finfo(dtype).min
    attn_weights = jnp.where(mask, attn_weights, big_neg)

  # normalize the attention weights
  attn_weights = jax.nn.softmax(attn_weights).astype(dtype)

  if module:
    module.sow(nnx.Intermediate, 'attention_weights', attn_weights)

  # apply attention dropout
  if not deterministic and dropout_rate > 0.0:
    keep_prob = 1.0 - dropout_rate
    if broadcast_dropout:
      # dropout is broadcast across the batch + head dimensions
      dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
      keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)  # type: ignore
    else:
      keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)  # type: ignore
    multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
    attn_weights = attn_weights * multiplier

  return attn_weights


def dot_product_attention(
  query: Array,
  key: Array,
  value: Array,
  bias: Optional[Array] = None,
  mask: Optional[Array] = None,
  broadcast_dropout: bool = True,
  dropout_rng: Optional[Array] = None,
  dropout_rate: float = 0.0,
  deterministic: bool = False,
  dtype: Optional[Dtype] = None,
  precision: PrecisionLike = None,
  module: Optional[Module] = None,
):
  """Computes dot-product attention given query, key, and value.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights.

  .. note::
    ``query``, ``key``, ``value`` needn't have any batch dimensions.

  Args:
    query: queries for calculating attention with shape of ``[batch..., q_length,
      num_heads, qk_depth_per_head]``.
    key: keys for calculating attention with shape of ``[batch..., kv_length,
      num_heads, qk_depth_per_head]``.
    value: values to be used in attention with shape of ``[batch..., kv_length,
      num_heads, v_depth_per_head]``.
    bias: bias for the attention weights. This should be broadcastable to the
      shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
      incorporating causal masks, padding masks, proximity bias, etc.
    mask: mask for the attention weights. This should be broadcastable to the
      shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
      incorporating causal masks. Attention weights are masked out if their
      corresponding mask value is `False`.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    dtype: the dtype of the computation (default: infer from inputs)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    module: the Module that will sow the attention weights into the
      ``nnx.Intermediate`` collection. If ``module`` is None, the attention
      weights will not be sowed.

  Returns:
    Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
  """
  query, key, value = promote_dtype((query, key, value), dtype=dtype)  # type: ignore[bad-unpacking]
  dtype = query.dtype
  assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
  assert (
    query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
  ), 'q, k, v batch dims must match.'
  assert (
    query.shape[-2] == key.shape[-2] == value.shape[-2]
  ), 'q, k, v num_heads must match.'
  assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'

  # compute attention weights
  attn_weights = dot_product_attention_weights(
    query,
    key,
    bias,
    mask,
    broadcast_dropout,
    dropout_rng,
    dropout_rate,
    deterministic,
    dtype,
    precision,
    module,
  )

  # return weighted sum over values for each query position
  return jnp.einsum(
    '...hqk,...khd->...qhd', attn_weights, value, precision=precision
  )


class MultiHeadAttention(Module):
  def __init__(
    self,
    num_heads: int,
    in_features: int,
    qkv_features: int | None = None,
    out_features: int | None = None,
    *,
    dtype: Dtype | None = None,
    param_dtype: Dtype = jnp.float32,
    broadcast_dropout: bool = True,
    dropout_rate: float = 0.0,
    deterministic: bool | None = None,
    precision: PrecisionLike = None,
    kernel_init: Initializer = default_kernel_init,
    out_kernel_init: Initializer | None = None,
    bias_init: Initializer = initializers.zeros_init(),
    out_bias_init: Initializer | None = None,
    use_bias: bool = True,
    attention_fn: Callable[..., Array] = yat_attention,
    decode: bool | None = None,
    normalize_qk: bool = False,
    # Deprecated, will be removed.
    qkv_dot_general: DotGeneralT | None = None,
    out_dot_general: DotGeneralT | None = None,
    qkv_dot_general_cls: Any = None,
    out_dot_general_cls: Any = None,
    rngs: rnglib.Rngs,
    epsilon: float = 1e-5,
  ):
    self.num_heads = num_heads
    self.in_features = in_features
    self.qkv_features = (
      qkv_features if qkv_features is not None else in_features
    )
    self.out_features = (
      out_features if out_features is not None else in_features
    )
    self.dtype = dtype
    self.param_dtype = param_dtype
    self.broadcast_dropout = broadcast_dropout
    self.dropout_rate = dropout_rate
    self.deterministic = deterministic
    self.precision = precision
    self.kernel_init = kernel_init
    self.out_kernel_init = out_kernel_init
    self.bias_init = bias_init
    self.out_bias_init = out_bias_init
    self.use_bias = use_bias
    self.attention_fn = attention_fn
    self.decode = decode
    self.normalize_qk = normalize_qk
    self.qkv_dot_general = qkv_dot_general
    self.out_dot_general = out_dot_general
    self.qkv_dot_general_cls = qkv_dot_general_cls
    self.out_dot_general_cls = out_dot_general_cls
    self.epsilon = epsilon

    if self.qkv_features % self.num_heads != 0:
      raise ValueError(
        f'Memory dimension ({self.qkv_features}) must be divisible by '
        f"'num_heads' heads ({self.num_heads})."
      )

    self.head_dim = self.qkv_features // self.num_heads

    # Replace LinearGeneral with YatNMN for query, key, value projections
    yat_linear = functools.partial(
      YatNMN,
      in_features=self.in_features,
      out_features=self.qkv_features,  # Output total features, will reshape later
      dtype=self.dtype,
      param_dtype=self.param_dtype,
      kernel_init=self.kernel_init,
      bias_init=self.bias_init,
      use_bias=self.use_bias,
      precision=self.precision,
      epsilon=self.epsilon,
    )

    # project inputs_q to multi-headed q/k/v
    # dimensions will be reshaped to [batch..., length, n_heads, n_features_per_head]
    self.query = yat_linear(rngs=rngs)
    self.key = yat_linear(rngs=rngs)
    self.value = yat_linear(rngs=rngs)

    self.query_ln: LayerNorm | None
    self.key_ln: LayerNorm | None
    if self.normalize_qk:
      # Normalizing query and key projections stabilizes training with higher
      # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
      self.query_ln = LayerNorm(
        self.head_dim,
        use_bias=False,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        rngs=rngs,
      )
      self.key_ln = LayerNorm(
        self.head_dim,
        use_bias=False,
        dtype=self.dtype,
        param_dtype=self.param_dtype,
        rngs=rngs,
      )
    else:
      self.query_ln = None
      self.key_ln = None

    # Remove the output layer - no more self.out
    self.rngs = rngs if dropout_rate > 0.0 else None

    self.cached_key: nnx.Cache[Array] | None = None
    self.cached_value: nnx.Cache[Array] | None = None
    self.cache_index: nnx.Cache[Array] | None = None

  def __call__(
    self,
    inputs_q: Array,
    inputs_k: Array | None = None,
    inputs_v: Array | None = None,
    *,
    mask: Array | None = None,
    deterministic: bool | None = None,
    rngs: rnglib.Rngs | None = None,
    sow_weights: bool = False,
    decode: bool | None = None,
  ):
    """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and project the results to an output vector.

    If both inputs_k and inputs_v are None, they will both copy the value of
    inputs_q (self attention).
    If only inputs_v is None, it will copy the value of inputs_k.

    Args:
      inputs_q: input queries of shape `[batch_sizes..., length, features]`.
      inputs_k: key of shape `[batch_sizes..., length, features]`. If None,
        inputs_k will copy the value of inputs_q.
      inputs_v: values of shape `[batch_sizes..., length, features]`. If None,
        inputs_v will copy the value of inputs_k.
      mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
        key/value_length]`. Attention weights are masked out if their
        corresponding mask value is `False`.
      deterministic: if false, the attention weight is masked randomly using
        dropout, whereas if true, the attention weights are deterministic.
      rngs: container for random number generators to generate the dropout
        mask when `deterministic` is False. The `rngs` container should have a
        `dropout` key.
      sow_weights: if ``True``, the attention weights are sowed into the
        'intermediates' collection.

    Returns:
      output of shape `[batch_sizes..., length, features]`.
    """
    if rngs is None:
      rngs = self.rngs

    if inputs_k is None:
      if inputs_v is not None:
        raise ValueError(
          '`inputs_k` cannot be None if `inputs_v` is not None. '
          'To have both `inputs_k` and `inputs_v` be the same value, pass in the '
          'value to `inputs_k` and leave `inputs_v` as None.'
        )
      inputs_k = inputs_q
    if inputs_v is None:
      inputs_v = inputs_k

    if inputs_q.shape[-1] != self.in_features:
      raise ValueError(
        f'Incompatible input dimension, got {inputs_q.shape[-1]} '
        f'but module expects {self.in_features}.'
      )

    # Apply YatNMN transformations and reshape to multi-head format
    query = squash(self.query(inputs_q))
    key = squash(self.key(inputs_k))
    value = squash(self.value(inputs_v))

    # Reshape from [batch..., length, qkv_features] to [batch..., length, num_heads, head_dim]
    query = query.reshape(query.shape[:-1] + (self.num_heads, self.head_dim))
    key = key.reshape(key.shape[:-1] + (self.num_heads, self.head_dim))
    value = value.reshape(value.shape[:-1] + (self.num_heads, self.head_dim))

    if self.normalize_qk:
      assert self.query_ln is not None and self.key_ln is not None
      # Normalizing query and key projections stabilizes training with higher
      # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
      query = self.query_ln(query)
      key = self.key_ln(key)

    # During fast autoregressive decoding, we feed one position at a time,
    # and cache the keys and values step by step.
    decode = first_from(
      decode,
      self.decode,
      error_msg="""No `decode` argument was provided to MultiHeadAttention
        as either a __call__ argument, class attribute, or nnx.flag.""",
    )

    if decode:
      if (
        self.cached_key is None
        or self.cached_value is None
        or self.cache_index is None
      ):
        raise ValueError(
          'Autoregressive cache not initialized, call ``init_cache`` first.'
        )
      (
        *batch_dims,
        max_length,
        num_heads,
        depth_per_head,
      ) = self.cached_key.value.shape
      # shape check of cached keys against query input
      expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
      if expected_shape != query.shape:
        raise ValueError(
          'Autoregressive cache shape error, '
          'expected query shape %s instead got %s.'
          # % (expected_shape, query.shape)
        )
      # update key, value caches with our new 1d spatial slices
      cur_index = self.cache_index.value
      zero = jnp.array(0, dtype=lax.dtype(cur_index.dtype))
      indices = (zero,) * len(batch_dims) + (cur_index, zero, zero)
      key = lax.dynamic_update_slice(self.cached_key.value, key, indices)
      value = lax.dynamic_update_slice(self.cached_value.value, value, indices)
      self.cached_key.value = key
      self.cached_value.value = value
      self.cache_index.value += 1
      # causal mask for cached decoder self-attention:
      # our single query position should only attend to those key
      # positions that have already been generated and cached,
      # not the remaining zero elements.
      mask = combine_masks(
        mask,
        jnp.broadcast_to(
          jnp.arange(max_length) <= cur_index,
          tuple(batch_dims) + (1, 1, max_length),
        ),
      )

    if (
      self.dropout_rate > 0.0
    ):  # Require `deterministic` only if using dropout.
      deterministic = first_from(
        deterministic,
        self.deterministic,
        error_msg="""No `deterministic` argument was provided to MultiHeadAttention
          as either a __call__ argument, class attribute, or nnx.flag.""",
      )
      if not deterministic:
        if rngs is None:
          raise ValueError(
            "'rngs' must be provided if 'dropout_rng' is not given."
          )
        dropout_rng = rngs.dropout()
      else:
        dropout_rng = None
    else:
      deterministic = True
      dropout_rng = None

    # apply attention with epsilon parameter for YatNMN
    x = self.attention_fn(
      query,
      key,
      value,
      mask=mask,
      dropout_rng=dropout_rng,
      dropout_rate=self.dropout_rate,
      broadcast_dropout=self.broadcast_dropout,
      deterministic=deterministic,
      dtype=self.dtype,
      precision=self.precision,
      module=self if sow_weights else None,
      epsilon=self.epsilon,  # Pass epsilon to yat_attention
    )
    # Reshape attention output back to original embedding dimension
    # from [batch..., length, num_heads, head_dim] to [batch..., length, qkv_features]
    x = x.reshape(x.shape[:-2] + (self.qkv_features,))
    return x

  def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):

    cache_shape = (*input_shape[:-1], self.num_heads, self.head_dim)
    self.cached_key = nnx.Cache(jnp.zeros(cache_shape, dtype))
    self.cached_value = nnx.Cache(jnp.zeros(cache_shape, dtype))
    self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))


# mask-making utility functions


def make_attention_mask(
  query_input: Array,
  key_input: Array,
  pairwise_fn: Callable[..., Any] = jnp.multiply,
  extra_batch_dims: int = 0,
  dtype: Dtype = jnp.float32,
):

  mask = pairwise_fn(
    jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2)
  )
  mask = jnp.expand_dims(mask, axis=-3)
  mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
  return mask.astype(dtype)


def make_causal_mask(
  x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32
) -> Array:

  idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
  return make_attention_mask(
    idxs,
    idxs,
    jnp.greater_equal,
    extra_batch_dims=extra_batch_dims,
    dtype=dtype,
  )


def combine_masks(
  *masks: Optional[Array], dtype: Dtype = jnp.float32
) -> Array | None:

  masks_list = [m for m in masks if m is not None]
  if not masks_list:
    return None
  assert all(
    map(lambda x: x.ndim == masks_list[0].ndim, masks_list)
  ), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}'
  mask, *other_masks = masks_list
  for other_mask in other_masks:
    mask = jnp.logical_and(mask, other_mask)
  return mask.astype(dtype)



# Define a triangular mask for causal attention with `jax.numpy.tril` and `jax.numpy.ones`.
def causal_attention_mask(seq_len):
  return jnp.tril(jnp.ones((seq_len, seq_len)))
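
For orientation, below is a minimal smoke-test sketch of the pieces this file adds: a YatNMN projection, the causal_attention_mask helper, and yat_attention itself. The import path nmn.nnx.yatattention, the toy shapes, and the mask broadcasting are illustrative assumptions, not documented usage from the package. The sketch sticks to the functions that are self-contained in this file; note that MultiHeadAttention.__call__ calls a squash helper that is not defined here, so the class-level path is not exercised.

# Hypothetical usage sketch (assumed import path and shapes), not from the package docs.
import jax
import jax.numpy as jnp
from flax import nnx

from nmn.nnx.yatattention import (  # assumed public path, mirroring the file above
  YatNMN,
  yat_attention,
  causal_attention_mask,
)

batch, seq_len, features, num_heads = 2, 8, 16, 4
head_dim = features // num_heads

# Distance-based projection: y = (x·w)² / (||x||² + ||w||² - 2·x·w + eps), plus bias/alpha scaling.
layer = YatNMN(in_features=features, out_features=features, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.PRNGKey(1), (batch, seq_len, features))
y = layer(x)                                # shape (2, 8, 16)

# Reshape to [batch, length, num_heads, head_dim] and run causal yat attention.
q = k = v = y.reshape(batch, seq_len, num_heads, head_dim)
mask = causal_attention_mask(seq_len)       # lower-triangular [seq_len, seq_len]
mask = mask[None, None, :, :].astype(bool)  # broadcast to [batch, heads, q_len, kv_len]
out = yat_attention(q, k, v, mask=mask, deterministic=True)
print(out.shape)                            # (2, 8, 4, 4)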