nmn 0.1.2__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,764 @@
+ from __future__ import annotations
+
+ import functools
+ from typing import Any, Callable, Optional
+ import typing as tp
+
+ import jax
+ import jax.numpy as jnp
+ from jax import lax, random
+
+ from flax import nnx
+ from flax.nnx import rnglib
+ from flax.nnx.module import Module, first_from
+ from flax.nnx.nn import initializers
+ from flax.nnx.nn.dtypes import promote_dtype
+ from flax.nnx.nn.linear import (
+   LinearGeneral,
+   default_kernel_init,
+ )
+ from flax.nnx.nn.normalization import LayerNorm
+ from flax.typing import (
+   Dtype,
+   Shape,
+   Initializer,
+   PrecisionLike,
+   DotGeneralT,
+ )
+
+
+
+ def yat_attention_weights(
+   query: Array,
+   key: Array,
+   bias: Optional[Array] = None,
+   mask: Optional[Array] = None,
+   broadcast_dropout: bool = True,
+   dropout_rng: Optional[Array] = None,
+   dropout_rate: float = 0.0,
+   deterministic: bool = False,
+   dtype: Optional[Dtype] = None,
+   precision: PrecisionLike = None,
+   module: Optional[Module] = None,
+   epsilon: float = 1e-5,
+ ):
+   """Computes attention weights using the YatNMN distance-based score."""
+   query, key = promote_dtype((query, key), dtype=dtype)
+   dtype = query.dtype
+
+   assert query.ndim == key.ndim, 'q, k must have same rank.'
+   assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
+   assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
+   assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
+
+   # YatNMN-style attention scores built from dot products and squared norms.
+   # query shape: [..., q_length, num_heads, head_dim]
+   # key shape:   [..., kv_length, num_heads, head_dim]
+
+   # Calculate dot-product attention scores
+   attn = jnp.einsum('...qhd,...khd->...hqk', query, key, precision=precision)
+   squared_dot_product = jnp.square(attn)
+
+   # Calculate norms
+   q_norm = jnp.sum(jnp.square(query), axis=-1, keepdims=True)  # [..., q_length, num_heads, 1]
+   k_norm = jnp.sum(jnp.square(key), axis=-1, keepdims=True)  # [..., kv_length, num_heads, 1]
+   qk_norm_sum = q_norm + k_norm  # Broadcasting: [..., q_length, num_heads, 1] + [..., kv_length, num_heads, 1]
+
+   # Transpose to match attention dimensions [..., num_heads, q_length, kv_length]
+   # The transpose converts [..., q_length, num_heads, kv_length] -> [..., num_heads, q_length, kv_length]
+   batch_dims = len(qk_norm_sum.shape) - 3
+   transpose_axes = tuple(range(batch_dims)) + (batch_dims + 1, batch_dims, batch_dims + 2)
+   qk_norm_sum_transposed = qk_norm_sum.transpose(transpose_axes)
+
+   # Calculate squared distances: ||q||² + ||k||² - 2*(q·k)²
+   squared_dist = qk_norm_sum_transposed - 2.0 * squared_dot_product
+
+   # YatNMN attention scores: (q·k)² / (squared_distance + ε)
+   attn_weights = squared_dot_product / (squared_dist + epsilon)
+
+   # apply attention bias: masking, dropout, proximity bias, etc.
+   if bias is not None:
+     attn_weights = attn_weights + bias
+   # apply attention mask
+   if mask is not None:
+     big_neg = jnp.finfo(dtype).min
+     attn_weights = jnp.where(mask, attn_weights, big_neg)
+
+   # normalize the attention weights
+   attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
+
+   if module:
+     module.sow(nnx.Intermediate, 'attention_weights', attn_weights)
+
+   # apply attention dropout
+   if not deterministic and dropout_rate > 0.0:
+     keep_prob = 1.0 - dropout_rate
+     if broadcast_dropout:
+       # dropout is broadcast across the batch + head dimensions
+       dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
+       keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
+     else:
+       keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
+     multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
+     attn_weights = attn_weights * multiplier
+
+   return attn_weights
+
+
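A minimal usage sketch for `yat_attention_weights` (not part of the diff), assuming the module above is importable. Shapes follow the `[batch, length, num_heads, head_dim]` convention from the comments, and query and key are given the same length, which keeps the `q_norm + k_norm` addition above broadcastable:

    import jax
    import jax.numpy as jnp

    # toy self-attention shapes: [batch, length, num_heads, head_dim]
    q = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 2, 8))
    k = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 2, 8))

    # post-softmax weights with shape [batch, num_heads, q_length, kv_length]
    w = yat_attention_weights(q, k, deterministic=True)
    print(w.shape)         # (2, 2, 4, 4)
    print(w.sum(axis=-1))  # each row sums to 1 after the softmax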
+ def yat_attention(
+   query: Array,
+   key: Array,
+   value: Array,
+   bias: Optional[Array] = None,
+   mask: Optional[Array] = None,
+   broadcast_dropout: bool = True,
+   dropout_rng: Optional[Array] = None,
+   dropout_rate: float = 0.0,
+   deterministic: bool = False,
+   dtype: Optional[Dtype] = None,
+   precision: PrecisionLike = None,
+   module: Optional[Module] = None,
+   epsilon: float = 1e-5,
+ ):
+   """Computes attention using YatNMN distance-based calculation."""
+   query, key, value = promote_dtype((query, key, value), dtype=dtype)
+   dtype = query.dtype
+   assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
+   assert (
+     query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
+   ), 'q, k, v batch dims must match.'
+   assert (
+     query.shape[-2] == key.shape[-2] == value.shape[-2]
+   ), 'q, k, v num_heads must match.'
+   assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
+
+   # compute attention weights using YatNMN
+   attn_weights = yat_attention_weights(
+     query,
+     key,
+     bias,
+     mask,
+     broadcast_dropout,
+     dropout_rng,
+     dropout_rate,
+     deterministic,
+     dtype,
+     precision,
+     module,
+     epsilon,
+   )
+
+   # return weighted sum over values for each query position
+   return jnp.einsum(
+     '...hqk,...khd->...qhd', attn_weights, value, precision=precision
+   )
+
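A matching sketch for `yat_attention` (not part of the diff; assumes the functions above are in scope). The value depth is chosen differently from the query/key depth only to make the output shape explicit:

    import jax
    import jax.numpy as jnp

    q = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 2, 8))
    k = jax.random.normal(jax.random.PRNGKey(1), (2, 4, 2, 8))
    v = jax.random.normal(jax.random.PRNGKey(2), (2, 4, 2, 16))

    out = yat_attention(q, k, v, deterministic=True)
    print(out.shape)  # (2, 4, 2, 16) -> [batch, q_length, num_heads, v_depth_per_head]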
+ Array = jax.Array
+
+ # Add YatNMN class implementation
+ default_bias_init = initializers.zeros_init()
+ default_alpha_init = initializers.ones_init()
+
+ class YatNMN(Module):
+   """A linear transformation with a distance-based (Yat) computation.
+
+   For each output feature j, computes ``(x·w_j)**2 / (||x - w_j||**2 + epsilon)``,
+   optionally followed by a bias and an alpha-controlled rescaling.
+   """
+
+   def __init__(
+     self,
+     in_features: int,
+     out_features: int,
+     *,
+     use_bias: bool = True,
+     use_alpha: bool = True,
+     dtype: Optional[Dtype] = None,
+     param_dtype: Dtype = jnp.float32,
+     precision: PrecisionLike = None,
+     kernel_init: Initializer = default_kernel_init,
+     bias_init: Initializer = default_bias_init,
+     alpha_init: Initializer = default_alpha_init,
+     dot_general: DotGeneralT = lax.dot_general,
+     rngs: rnglib.Rngs,
+     epsilon: float = 1e-5,
+   ):
+
+     kernel_key = rngs.params()
+     self.kernel = nnx.Param(
+       kernel_init(kernel_key, (in_features, out_features), param_dtype)
+     )
+     self.bias: nnx.Param[jax.Array] | None
+     if use_bias:
+       bias_key = rngs.params()
+       self.bias = nnx.Param(bias_init(bias_key, (out_features,), param_dtype))
+     else:
+       self.bias = None
+
+     self.alpha: nnx.Param[jax.Array] | None
+     if use_alpha:
+       alpha_key = rngs.params()
+       self.alpha = nnx.Param(alpha_init(alpha_key, (1,), param_dtype))
+     else:
+       self.alpha = None
+
+     self.in_features = in_features
+     self.out_features = out_features
+     self.use_bias = use_bias
+     self.use_alpha = use_alpha
+     self.dtype = dtype
+     self.param_dtype = param_dtype
+     self.precision = precision
+     self.kernel_init = kernel_init
+     self.bias_init = bias_init
+     self.dot_general = dot_general
+     self.epsilon = epsilon
+
+   def __call__(self, inputs: Array) -> Array:
+     """Applies the YatNMN transformation to inputs."""
+     kernel = self.kernel.value
+     bias = self.bias.value if self.bias is not None else None
+     alpha = self.alpha.value if self.alpha is not None else None
+
+     y = self.dot_general(
+       inputs,
+       kernel,
+       (((inputs.ndim - 1,), (0,)), ((), ())),
+       precision=self.precision,
+     )
+
+     inputs_squared_sum = jnp.sum(inputs**2, axis=-1, keepdims=True)
+     kernel_squared_sum = jnp.sum(kernel**2, axis=0, keepdims=True)
+     distances = inputs_squared_sum + kernel_squared_sum - 2 * y
+
+     # Element-wise Yat score: (x·w)² / (||x - w||² + ε)
+     y = y ** 2 / (distances + self.epsilon)
+
+     if bias is not None:
+       y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
+
+     if alpha is not None:
+       scale = (jnp.sqrt(self.out_features) / jnp.log(1 + self.out_features)) ** alpha
+       y = y * scale
+
+     return y
+
+
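A small sketch of the `YatNMN` layer on its own (not part of the diff), with arbitrary toy sizes:

    import jax.numpy as jnp
    from flax import nnx

    layer = YatNMN(in_features=8, out_features=16, rngs=nnx.Rngs(0))
    x = jnp.ones((4, 8))  # [batch, in_features]
    y = layer(x)
    print(y.shape)        # (4, 16)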
+ def dot_product_attention_weights(
+   query: Array,
+   key: Array,
+   bias: Optional[Array] = None,
+   mask: Optional[Array] = None,
+   broadcast_dropout: bool = True,
+   dropout_rng: Optional[Array] = None,
+   dropout_rate: float = 0.0,
+   deterministic: bool = False,
+   dtype: Optional[Dtype] = None,
+   precision: PrecisionLike = None,
+   module: Optional[Module] = None,
+ ):
+   """Computes dot-product attention weights given query and key.
+
+   Used by :func:`dot_product_attention`, which is what you'll most likely use.
+   But if you want access to the attention weights for introspection, then
+   you can directly call this function and call einsum yourself.
+
+   Args:
+     query: queries for calculating attention with shape of `[batch..., q_length,
+       num_heads, qk_depth_per_head]`.
+     key: keys for calculating attention with shape of `[batch..., kv_length,
+       num_heads, qk_depth_per_head]`.
+     bias: bias for the attention weights. This should be broadcastable to the
+       shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
+       incorporating causal masks, padding masks, proximity bias, etc.
+     mask: mask for the attention weights. This should be broadcastable to the
+       shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
+       incorporating causal masks. Attention weights are masked out if their
+       corresponding mask value is `False`.
+     broadcast_dropout: bool: use a broadcasted dropout along batch dims.
+     dropout_rng: JAX PRNGKey: to be used for dropout
+     dropout_rate: dropout rate
+     deterministic: bool, deterministic or not (to apply dropout)
+     dtype: the dtype of the computation (default: infer from inputs and params)
+     precision: numerical precision of the computation see `jax.lax.Precision`
+       for details.
+     module: the Module that will sow the attention weights into the
+       ``nnx.Intermediate`` collection. If ``module`` is None, the attention
+       weights will not be sowed.
+
+   Returns:
+     Output of shape `[batch..., num_heads, q_length, kv_length]`.
+   """
+   query, key = promote_dtype((query, key), dtype=dtype)  # type: ignore[bad-unpacking]
+   dtype = query.dtype
+
+   assert query.ndim == key.ndim, 'q, k must have same rank.'
+   assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
+   assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.'
+   assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
+
+   # calculate attention matrix
+   depth = query.shape[-1]
+   query = query / jnp.sqrt(depth).astype(dtype)
+   # attn weight shape is (batch..., num_heads, q_length, kv_length)
+   attn_weights = jnp.einsum(
+     '...qhd,...khd->...hqk', query, key, precision=precision
+   )
+
+   # apply attention bias: masking, dropout, proximity bias, etc.
+   if bias is not None:
+     attn_weights = attn_weights + bias
+   # apply attention mask
+   if mask is not None:
+     big_neg = jnp.finfo(dtype).min
+     attn_weights = jnp.where(mask, attn_weights, big_neg)
+
+   # normalize the attention weights
+   attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
+
+   if module:
+     module.sow(nnx.Intermediate, 'attention_weights', attn_weights)
+
+   # apply attention dropout
+   if not deterministic and dropout_rate > 0.0:
+     keep_prob = 1.0 - dropout_rate
+     if broadcast_dropout:
+       # dropout is broadcast across the batch + head dimensions
+       dropout_shape = tuple([1] * (key.ndim - 2)) + attn_weights.shape[-2:]
+       keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)  # type: ignore
+     else:
+       keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)  # type: ignore
+     multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype)
+     attn_weights = attn_weights * multiplier
+
+   return attn_weights
+
+
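As the docstring above suggests, the weights can be computed separately and combined with the values by hand; a brief sketch (not part of the diff; assumes the function above is in scope):

    import jax
    import jax.numpy as jnp

    q = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 2, 8))
    k = jax.random.normal(jax.random.PRNGKey(1), (1, 6, 2, 8))
    v = jax.random.normal(jax.random.PRNGKey(2), (1, 6, 2, 8))

    w = dot_product_attention_weights(q, k, deterministic=True)  # (1, 2, 4, 6)
    out = jnp.einsum('...hqk,...khd->...qhd', w, v)              # (1, 4, 2, 8)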
+ def dot_product_attention(
+   query: Array,
+   key: Array,
+   value: Array,
+   bias: Optional[Array] = None,
+   mask: Optional[Array] = None,
+   broadcast_dropout: bool = True,
+   dropout_rng: Optional[Array] = None,
+   dropout_rate: float = 0.0,
+   deterministic: bool = False,
+   dtype: Optional[Dtype] = None,
+   precision: PrecisionLike = None,
+   module: Optional[Module] = None,
+ ):
+   """Computes dot-product attention given query, key, and value.
+
+   This is the core function for applying attention based on
+   https://arxiv.org/abs/1706.03762. It calculates the attention weights given
+   query and key and combines the values using the attention weights.
+
+   .. note::
+     ``query``, ``key``, ``value`` needn't have any batch dimensions.
+
+   Args:
+     query: queries for calculating attention with shape of ``[batch..., q_length,
+       num_heads, qk_depth_per_head]``.
+     key: keys for calculating attention with shape of ``[batch..., kv_length,
+       num_heads, qk_depth_per_head]``.
+     value: values to be used in attention with shape of ``[batch..., kv_length,
+       num_heads, v_depth_per_head]``.
+     bias: bias for the attention weights. This should be broadcastable to the
+       shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
+       incorporating causal masks, padding masks, proximity bias, etc.
+     mask: mask for the attention weights. This should be broadcastable to the
+       shape `[batch..., num_heads, q_length, kv_length]`. This can be used for
+       incorporating causal masks. Attention weights are masked out if their
+       corresponding mask value is `False`.
+     broadcast_dropout: bool: use a broadcasted dropout along batch dims.
+     dropout_rng: JAX PRNGKey: to be used for dropout
+     dropout_rate: dropout rate
+     deterministic: bool, deterministic or not (to apply dropout)
+     dtype: the dtype of the computation (default: infer from inputs)
+     precision: numerical precision of the computation see `jax.lax.Precision`
+       for details.
+     module: the Module that will sow the attention weights into the
+       ``nnx.Intermediate`` collection. If ``module`` is None, the attention
+       weights will not be sowed.
+
+   Returns:
+     Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`.
+   """
+   query, key, value = promote_dtype((query, key, value), dtype=dtype)  # type: ignore[bad-unpacking]
+   dtype = query.dtype
+   assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
+   assert (
+     query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
+   ), 'q, k, v batch dims must match.'
+   assert (
+     query.shape[-2] == key.shape[-2] == value.shape[-2]
+   ), 'q, k, v num_heads must match.'
+   assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
+
+   # compute attention weights
+   attn_weights = dot_product_attention_weights(
+     query,
+     key,
+     bias,
+     mask,
+     broadcast_dropout,
+     dropout_rng,
+     dropout_rate,
+     deterministic,
+     dtype,
+     precision,
+     module,
+   )
+
+   # return weighted sum over values for each query position
+   return jnp.einsum(
+     '...hqk,...khd->...qhd', attn_weights, value, precision=precision
+   )
+
+
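A sketch of calling `dot_product_attention` with a causal mask built from the helpers defined further down in this file (not part of the diff):

    import jax
    import jax.numpy as jnp

    q = jax.random.normal(jax.random.PRNGKey(0), (1, 5, 2, 8))
    k = jax.random.normal(jax.random.PRNGKey(1), (1, 5, 2, 8))
    v = jax.random.normal(jax.random.PRNGKey(2), (1, 5, 2, 8))

    mask = make_causal_mask(jnp.ones((1, 5)))            # (1, 1, 5, 5)
    out = dot_product_attention(q, k, v, mask=mask > 0)  # False positions get the dtype minimum before softmax
    print(out.shape)                                     # (1, 5, 2, 8)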
+ class MultiHeadAttention(Module):
+   def __init__(
+     self,
+     num_heads: int,
+     in_features: int,
+     qkv_features: int | None = None,
+     out_features: int | None = None,
+     *,
+     dtype: Dtype | None = None,
+     param_dtype: Dtype = jnp.float32,
+     broadcast_dropout: bool = True,
+     dropout_rate: float = 0.0,
+     deterministic: bool | None = None,
+     precision: PrecisionLike = None,
+     kernel_init: Initializer = default_kernel_init,
+     out_kernel_init: Initializer | None = None,
+     bias_init: Initializer = initializers.zeros_init(),
+     out_bias_init: Initializer | None = None,
+     use_bias: bool = True,
+     attention_fn: Callable[..., Array] = yat_attention,
+     decode: bool | None = None,
+     normalize_qk: bool = False,
+     # Deprecated, will be removed.
+     qkv_dot_general: DotGeneralT | None = None,
+     out_dot_general: DotGeneralT | None = None,
+     qkv_dot_general_cls: Any = None,
+     out_dot_general_cls: Any = None,
+     rngs: rnglib.Rngs,
+     epsilon: float = 1e-5,
+   ):
+     self.num_heads = num_heads
+     self.in_features = in_features
+     self.qkv_features = (
+       qkv_features if qkv_features is not None else in_features
+     )
+     self.out_features = (
+       out_features if out_features is not None else in_features
+     )
+     self.dtype = dtype
+     self.param_dtype = param_dtype
+     self.broadcast_dropout = broadcast_dropout
+     self.dropout_rate = dropout_rate
+     self.deterministic = deterministic
+     self.precision = precision
+     self.kernel_init = kernel_init
+     self.out_kernel_init = out_kernel_init
+     self.bias_init = bias_init
+     self.out_bias_init = out_bias_init
+     self.use_bias = use_bias
+     self.attention_fn = attention_fn
+     self.decode = decode
+     self.normalize_qk = normalize_qk
+     self.qkv_dot_general = qkv_dot_general
+     self.out_dot_general = out_dot_general
+     self.qkv_dot_general_cls = qkv_dot_general_cls
+     self.out_dot_general_cls = out_dot_general_cls
+     self.epsilon = epsilon
+
+     if self.qkv_features % self.num_heads != 0:
+       raise ValueError(
+         f'Memory dimension ({self.qkv_features}) must be divisible by '
+         f"'num_heads' heads ({self.num_heads})."
+       )
+
+     self.head_dim = self.qkv_features // self.num_heads
+
+     # Use YatNMN (instead of LinearGeneral) for the query, key, and value projections
+     yat_linear = functools.partial(
+       YatNMN,
+       in_features=self.in_features,
+       out_features=self.qkv_features,  # Output total features, will reshape later
+       dtype=self.dtype,
+       param_dtype=self.param_dtype,
+       kernel_init=self.kernel_init,
+       bias_init=self.bias_init,
+       use_bias=self.use_bias,
+       precision=self.precision,
+       epsilon=self.epsilon,
+     )
+
+     # project inputs_q to multi-headed q/k/v
+     # dimensions will be reshaped to [batch..., length, n_heads, n_features_per_head]
+     self.query = yat_linear(rngs=rngs)
+     self.key = yat_linear(rngs=rngs)
+     self.value = yat_linear(rngs=rngs)
+
+     self.query_ln: LayerNorm | None
+     self.key_ln: LayerNorm | None
+     if self.normalize_qk:
+       # Normalizing query and key projections stabilizes training with higher
+       # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
+       self.query_ln = LayerNorm(
+         self.head_dim,
+         use_bias=False,
+         dtype=self.dtype,
+         param_dtype=self.param_dtype,
+         rngs=rngs,
+       )
+       self.key_ln = LayerNorm(
+         self.head_dim,
+         use_bias=False,
+         dtype=self.dtype,
+         param_dtype=self.param_dtype,
+         rngs=rngs,
+       )
+     else:
+       self.query_ln = None
+       self.key_ln = None
+
+     # No output projection layer (`self.out`) is created; the attention output
+     # is returned directly after reshaping in __call__.
+     self.rngs = rngs if dropout_rate > 0.0 else None
+
+     self.cached_key: nnx.Cache[Array] | None = None
+     self.cached_value: nnx.Cache[Array] | None = None
+     self.cache_index: nnx.Cache[Array] | None = None
+
+   def __call__(
+     self,
+     inputs_q: Array,
+     inputs_k: Array | None = None,
+     inputs_v: Array | None = None,
+     *,
+     mask: Array | None = None,
+     deterministic: bool | None = None,
+     rngs: rnglib.Rngs | None = None,
+     sow_weights: bool = False,
+     decode: bool | None = None,
+   ):
+     """Applies multi-head attention on the input data.
+
+     Projects the inputs into multi-headed query, key, and value vectors,
+     applies the attention function (YatNMN-based by default), and reshapes
+     the result back to `[batch..., length, qkv_features]`; this variant has
+     no output projection.
+
+     If both inputs_k and inputs_v are None, they will both copy the value of
+     inputs_q (self attention).
+     If only inputs_v is None, it will copy the value of inputs_k.
+
+     Args:
+       inputs_q: input queries of shape `[batch_sizes..., length, features]`.
+       inputs_k: key of shape `[batch_sizes..., length, features]`. If None,
+         inputs_k will copy the value of inputs_q.
+       inputs_v: values of shape `[batch_sizes..., length, features]`. If None,
+         inputs_v will copy the value of inputs_k.
+       mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
+         key/value_length]`. Attention weights are masked out if their
+         corresponding mask value is `False`.
+       deterministic: if false, the attention weight is masked randomly using
+         dropout, whereas if true, the attention weights are deterministic.
+       rngs: container for random number generators to generate the dropout
+         mask when `deterministic` is False. The `rngs` container should have a
+         `dropout` key.
+       sow_weights: if ``True``, the attention weights are sowed into the
+         'intermediates' collection.
+
+     Returns:
+       output of shape `[batch_sizes..., length, qkv_features]`.
+     """
+     if rngs is None:
+       rngs = self.rngs
+
+     if inputs_k is None:
+       if inputs_v is not None:
+         raise ValueError(
+           '`inputs_k` cannot be None if `inputs_v` is not None. '
+           'To have both `inputs_k` and `inputs_v` be the same value, pass in the '
+           'value to `inputs_k` and leave `inputs_v` as None.'
+         )
+       inputs_k = inputs_q
+     if inputs_v is None:
+       inputs_v = inputs_k
+
+     if inputs_q.shape[-1] != self.in_features:
+       raise ValueError(
+         f'Incompatible input dimension, got {inputs_q.shape[-1]} '
+         f'but module expects {self.in_features}.'
+       )
+
+     # Apply YatNMN transformations and reshape to multi-head format.
+     # NOTE: `squash` is not defined or imported in this module; it must be
+     # provided by the surrounding package.
+     query = squash(self.query(inputs_q))
+     key = squash(self.key(inputs_k))
+     value = squash(self.value(inputs_v))
+
+     # Reshape from [batch..., length, qkv_features] to [batch..., length, num_heads, head_dim]
+     query = query.reshape(query.shape[:-1] + (self.num_heads, self.head_dim))
+     key = key.reshape(key.shape[:-1] + (self.num_heads, self.head_dim))
+     value = value.reshape(value.shape[:-1] + (self.num_heads, self.head_dim))
+
+     if self.normalize_qk:
+       assert self.query_ln is not None and self.key_ln is not None
+       # Normalizing query and key projections stabilizes training with higher
+       # LR. See ViT-22B paper http://arxiv.org/abs/2302.05442 for analysis.
+       query = self.query_ln(query)
+       key = self.key_ln(key)
+
+     # During fast autoregressive decoding, we feed one position at a time,
+     # and cache the keys and values step by step.
+     decode = first_from(
+       decode,
+       self.decode,
+       error_msg="""No `decode` argument was provided to MultiHeadAttention
+         as either a __call__ argument, class attribute, or nnx.flag.""",
+     )
+
+     if decode:
+       if (
+         self.cached_key is None
+         or self.cached_value is None
+         or self.cache_index is None
+       ):
+         raise ValueError(
+           'Autoregressive cache not initialized, call ``init_cache`` first.'
+         )
+       (
+         *batch_dims,
+         max_length,
+         num_heads,
+         depth_per_head,
+       ) = self.cached_key.value.shape
+       # shape check of cached keys against query input
+       expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
+       if expected_shape != query.shape:
+         raise ValueError(
+           'Autoregressive cache shape error, '
+           f'expected query shape {expected_shape} instead got {query.shape}.'
+         )
+       # update key, value caches with our new 1d spatial slices
+       cur_index = self.cache_index.value
+       zero = jnp.array(0, dtype=lax.dtype(cur_index.dtype))
+       indices = (zero,) * len(batch_dims) + (cur_index, zero, zero)
+       key = lax.dynamic_update_slice(self.cached_key.value, key, indices)
+       value = lax.dynamic_update_slice(self.cached_value.value, value, indices)
+       self.cached_key.value = key
+       self.cached_value.value = value
+       self.cache_index.value += 1
+       # causal mask for cached decoder self-attention:
+       # our single query position should only attend to those key
+       # positions that have already been generated and cached,
+       # not the remaining zero elements.
+       mask = combine_masks(
+         mask,
+         jnp.broadcast_to(
+           jnp.arange(max_length) <= cur_index,
+           tuple(batch_dims) + (1, 1, max_length),
+         ),
+       )
+
+     if (
+       self.dropout_rate > 0.0
+     ):  # Require `deterministic` only if using dropout.
+       deterministic = first_from(
+         deterministic,
+         self.deterministic,
+         error_msg="""No `deterministic` argument was provided to MultiHeadAttention
+           as either a __call__ argument, class attribute, or nnx.flag.""",
+       )
+       if not deterministic:
+         if rngs is None:
+           raise ValueError(
+             "'rngs' must be provided if 'dropout_rng' is not given."
+           )
+         dropout_rng = rngs.dropout()
+       else:
+         dropout_rng = None
+     else:
+       deterministic = True
+       dropout_rng = None
+
+     # apply attention with the epsilon parameter for YatNMN
+     x = self.attention_fn(
+       query,
+       key,
+       value,
+       mask=mask,
+       dropout_rng=dropout_rng,
+       dropout_rate=self.dropout_rate,
+       broadcast_dropout=self.broadcast_dropout,
+       deterministic=deterministic,
+       dtype=self.dtype,
+       precision=self.precision,
+       module=self if sow_weights else None,
+       epsilon=self.epsilon,  # Pass epsilon to yat_attention
+     )
+     # Reshape the attention output back to the original embedding dimension:
+     # from [batch..., length, num_heads, head_dim] to [batch..., length, qkv_features]
+     x = x.reshape(x.shape[:-2] + (self.qkv_features,))
+     return x
+
+   def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
+     """Initializes the key/value cache used for fast autoregressive decoding."""
+     cache_shape = (*input_shape[:-1], self.num_heads, self.head_dim)
+     self.cached_key = nnx.Cache(jnp.zeros(cache_shape, dtype))
+     self.cached_value = nnx.Cache(jnp.zeros(cache_shape, dtype))
+     self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))
+
+
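A construction sketch for `MultiHeadAttention` (not part of the diff). Note that `__call__` relies on a `squash` function that this file neither defines nor imports, so the forward pass is only shown as a comment; cache initialization does not touch it:

    import jax.numpy as jnp
    from flax import nnx

    attn = MultiHeadAttention(
      num_heads=4,
      in_features=32,
      decode=False,
      rngs=nnx.Rngs(0),
    )

    # Forward pass (requires `squash` to be supplied by the surrounding package):
    #   y = attn(jnp.ones((2, 10, 32)))   # -> (2, 10, 32)

    # Key/value cache for autoregressive decoding:
    attn.init_cache((2, 10, 32))          # [batch, max_length, features]
    print(attn.cached_key.value.shape)    # (2, 10, 4, 8)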
+ # mask-making utility functions
+
+
+ def make_attention_mask(
+   query_input: Array,
+   key_input: Array,
+   pairwise_fn: Callable[..., Any] = jnp.multiply,
+   extra_batch_dims: int = 0,
+   dtype: Dtype = jnp.float32,
+ ):
+   """Creates an attention mask of shape `[batch..., 1, q_length, kv_length]`
+   by applying `pairwise_fn` to each pair of query/key positions."""
+   mask = pairwise_fn(
+     jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2)
+   )
+   mask = jnp.expand_dims(mask, axis=-3)
+   mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
+   return mask.astype(dtype)
+
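A sketch of building a padding mask with `make_attention_mask` (not part of the diff), where nonzero token ids mark real tokens:

    import jax.numpy as jnp

    tokens = jnp.array([[5, 8, 3, 0, 0]])   # 0 marks padding
    pad = (tokens > 0).astype(jnp.float32)  # [batch, length]
    mask = make_attention_mask(pad, pad)    # [batch, 1, q_length, kv_length]
    print(mask.shape)                       # (1, 1, 5, 5)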
+
+ def make_causal_mask(
+   x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32
+ ) -> Array:
+   """Creates a causal self-attention mask for an input of shape `[batch..., length]`."""
+   idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
+   return make_attention_mask(
+     idxs,
+     idxs,
+     jnp.greater_equal,
+     extra_batch_dims=extra_batch_dims,
+     dtype=dtype,
+   )
+
+
+ def combine_masks(
+   *masks: Optional[Array], dtype: Dtype = jnp.float32
+ ) -> Array | None:
+   """Combines any non-None attention masks with a logical AND."""
+   masks_list = [m for m in masks if m is not None]
+   if not masks_list:
+     return None
+   assert all(
+     map(lambda x: x.ndim == masks_list[0].ndim, masks_list)
+   ), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}'
+   mask, *other_masks = masks_list
+   for other_mask in other_masks:
+     mask = jnp.logical_and(mask, other_mask)
+   return mask.astype(dtype)
+
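And a sketch of AND-ing a padding mask with a causal mask via `combine_masks` (not part of the diff):

    import jax.numpy as jnp

    pad = jnp.array([[1.0, 1.0, 1.0, 0.0, 0.0]])  # [batch, length], 1.0 = real token
    pad_mask = make_attention_mask(pad, pad)      # (1, 1, 5, 5)
    causal = make_causal_mask(pad)                # (1, 1, 5, 5)
    mask = combine_masks(pad_mask, causal)        # elementwise AND, cast to float32
    print(mask.shape)                             # (1, 1, 5, 5)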
+
+
+ # Define a triangular mask for causal attention with `jax.numpy.tril` and `jax.numpy.ones`.
+ def causal_attention_mask(seq_len):
+   return jnp.tril(jnp.ones((seq_len, seq_len)))
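Finally, a sketch of the lightweight helper above (not part of the diff); unlike `make_causal_mask`, it returns a bare `[seq_len, seq_len]` lower-triangular matrix without batch or head axes:

    print(causal_attention_mask(4))
    # [[1. 0. 0. 0.]
    #  [1. 1. 0. 0.]
    #  [1. 1. 1. 0.]
    #  [1. 1. 1. 1.]]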