flaxdiff 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,723 @@
1
+ # coding=utf-8
2
+ # Copyright 2024 The Google Research Authors.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Core Fast Attention Module for Flax.
17
+
18
+ Implementation of the approximate fast softmax and generalized
19
+ attention mechanisms, leveraging structured random feature map (RFM) techniques
20
+ and low rank decomposition of the attention matrix.
21
+ """
22
+ # pylint: disable=invalid-name, missing-function-docstring, line-too-long
23
+
24
+ import abc
25
+ from collections.abc import Iterable # pylint: disable=g-importing-member
26
+ import functools
27
+ from absl import logging
28
+ import gin
29
+ import jax
30
+ from jax import lax
31
+ from jax import random
32
+ import jax.numpy as jnp
33
+
34
+ import numpy as onp
35
+
36
+ # Nonlinear mappings encoding different attention kernels.
37
+ gin.external_configurable(jnp.cos, 'jcos')
38
+ gin.external_configurable(jnp.sin, 'jsin')
39
+ gin.external_configurable(jnp.tanh, 'jtanh')
40
+ gin.external_configurable(jax.nn.sigmoid, 'jsigmoid')
41
+ gin.external_configurable(
42
+ lambda x: jax.nn.gelu(x, approximate=False), 'jgelu'
43
+ ) # Needs to be exact, although it might be slower. See https://github.com/google/jax/issues/4428.
44
+ gin.external_configurable(lambda x: x * x * (x > 0.0), 'jrequ')
45
+ gin.external_configurable(jnp.exp, 'jexp')
46
+ gin.external_configurable(lambda x: x, 'jidentity')
47
+ gin.external_configurable(
48
+ lambda x: (jnp.exp(x)) * (x <= 0.0) + (x + 1.0) * (x > 0.0), 'jshiftedelu'
49
+ ) # Nonlinearity used in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" (https://arxiv.org/abs/2006.16236).
50
+
51
+
52
+ def nonnegative_softmax_kernel_feature_creator(data,
53
+ projection_matrix,
54
+ attention_dims_t,
55
+ batch_dims_t,
56
+ precision,
57
+ is_query,
58
+ normalize_data=True,
59
+ eps=0.0001):
60
+ """Constructs nonnegative kernel features for fast softmax attention.
61
+
62
+
63
+ Args:
64
+ data: input for which features are computed
65
+ projection_matrix: random matrix used to compute features
66
+ attention_dims_t: tuple of attention dimensions
67
+ batch_dims_t: tuple of batch dimensions
68
+ precision: precision parameter
69
+ is_query: predicate indicating whether input data corresponds to queries or
70
+ keys
71
+ normalize_data: predicate indicating whether data should be normalized.
72
+ eps: numerical stabilizer.
73
+
74
+ Returns:
75
+ Random features for fast softmax attention.
76
+ """
77
+
78
+ if normalize_data:
79
+ # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
80
+ # w_norm = w * data_normalizer for w in {q,k}.
81
+ data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
82
+ else:
83
+ data_normalizer = 1.0
84
+ ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
85
+ data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
86
+ data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix
87
+
88
+ data_dash = lax.dot_general(
89
+ data_normalizer * data,
90
+ data_thick_random_matrix,
91
+ (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)),
92
+ (batch_dims_t, batch_dims_t)),
93
+ precision=precision)
94
+
95
+ diag_data = jnp.square(data)
96
+ diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
97
+ diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
98
+ diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)
99
+
100
+ last_dims_t = (len(data_dash.shape) - 1,)
101
+ if is_query:
102
+ data_dash = ratio * (
103
+ jnp.exp(data_dash - diag_data -
104
+ jnp.max(data_dash, axis=last_dims_t, keepdims=True)) + eps)
105
+ else:
106
+ data_dash = ratio * (
107
+ jnp.exp(data_dash - diag_data - jnp.max(
108
+ data_dash, axis=last_dims_t + attention_dims_t, keepdims=True)) +
109
+ eps)
110
+
111
+ return data_dash
112
+
113
+
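The feature map above is the positive random-feature construction of FAVOR+: phi(x) = exp(w^T x_n - |x_n|^2 / 2) / sqrt(m) with x_n = x / d^{1/4}, so that averaging over random projections w recovers exp(q k^T / sqrt(d)) up to row-wise constants that cancel once attention is renormalized. The following is a minimal illustrative sketch of that property, not part of the package; the toy shapes and feature count are assumptions, and it assumes the definitions above are in scope.
import jax.numpy as jnp
from jax import random

d, m = 4, 4096
kq, kk, kp = random.split(random.PRNGKey(0), 3)
q = random.normal(kq, (1, 1, 3, d))            # (batch, heads, length, d)
k = random.normal(kk, (1, 1, 3, d))
proj = random.normal(kp, (m, d))               # unstructured random features
phi_q = nonnegative_softmax_kernel_feature_creator(
    q, proj, attention_dims_t=(2,), batch_dims_t=(0, 1),
    precision=None, is_query=True)
phi_k = nonnegative_softmax_kernel_feature_creator(
    k, proj, attention_dims_t=(2,), batch_dims_t=(0, 1),
    precision=None, is_query=False)
approx = jnp.einsum('bhlm,bhsm->bhls', phi_q, phi_k)
exact = jnp.exp(jnp.einsum('bhld,bhsd->bhls', q, k) / jnp.sqrt(d))
# Row-normalized attention weights should roughly agree:
print(approx / approx.sum(-1, keepdims=True))
print(exact / exact.sum(-1, keepdims=True))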
114
+ def sincos_softmax_kernel_feature_creator(data,
115
+ projection_matrix,
116
+ attention_dims_t,
117
+ batch_dims_t,
118
+ precision,
119
+ normalize_data=True):
120
+ """Constructs kernel sin-cos features for fast softmax attention.
121
+
122
+
123
+ Args:
124
+ data: input for which features are computed
125
+ projection_matrix: random matrix used to compute features
126
+ attention_dims_t: tuple of attention dimensions
127
+ batch_dims_t: tuple of batch dimensions
128
+ precision: precision parameter
129
+ normalize_data: predicate indicating whether data should be normalized.
130
+
131
+ Returns:
132
+ Random features for fast softmax attention.
133
+ """
134
+ if normalize_data:
135
+ # We have: exp(qk^T/sqrt{d}) = exp(|q|^2/2sqrt{d}) * exp(|k|^2/2sqrt{d}) *
136
+ # exp(-(|q*c-k*c|^2)/2), where c = 1.0 / sqrt{sqrt{d}}.
137
+ data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
138
+ else:
139
+ data_normalizer = 1.0
140
+ ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
141
+ data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
142
+ data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix
143
+
144
+ data_dash = lax.dot_general(
145
+ data_normalizer * data,
146
+ data_thick_random_matrix,
147
+ (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)),
148
+ (batch_dims_t, batch_dims_t)),
149
+ precision=precision)
150
+ data_dash_cos = ratio * jnp.cos(data_dash)
151
+ data_dash_sin = ratio * jnp.sin(data_dash)
152
+ data_dash = jnp.concatenate((data_dash_cos, data_dash_sin), axis=-1)
153
+
154
+ # Constructing D_data and data^{'}
155
+ diag_data = jnp.square(data)
156
+ diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
157
+ diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
158
+ diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)
159
+ # Additional renormalization for numerical stability
160
+ data_renormalizer = jnp.max(diag_data, attention_dims_t, keepdims=True)
161
+ diag_data -= data_renormalizer
162
+ diag_data = jnp.exp(diag_data)
163
+ data_prime = data_dash * diag_data
164
+ return data_prime
165
+
166
+
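A quick numerical check of the factorization quoted in the comment inside the function above, exp(q k^T / sqrt(d)) = exp(|q c|^2 / 2) * exp(|k c|^2 / 2) * exp(-|q c - k c|^2 / 2) with c = d^{-1/4}. This is an illustrative verification only; the dimension and seeds are arbitrary.
import jax.numpy as jnp
from jax import random

d = 8
q = random.normal(random.PRNGKey(1), (d,))
k = random.normal(random.PRNGKey(2), (d,))
c = 1.0 / jnp.sqrt(jnp.sqrt(float(d)))
lhs = jnp.exp(jnp.dot(q, k) / jnp.sqrt(float(d)))
rhs = (jnp.exp(jnp.sum((c * q) ** 2) / 2.0)
       * jnp.exp(jnp.sum((c * k) ** 2) / 2.0)
       * jnp.exp(-jnp.sum((c * q - c * k) ** 2) / 2.0))
print(lhs, rhs)  # identical up to float32 rounding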
167
+ def generalized_kernel_feature_creator(data, projection_matrix, batch_dims_t,
168
+ precision, kernel_fn, kernel_epsilon,
169
+ normalize_data):
170
+ """Constructs kernel features for fast generalized attention.
171
+
172
+
173
+ Args:
174
+ data: input for which features are computed
175
+ projection_matrix: matrix used to compute features
176
+ batch_dims_t: tuple of batch dimensions
177
+ precision: precision parameter
178
+ kernel_fn: kernel function used
179
+ kernel_epsilon: additive positive term added to every feature for numerical
180
+ stability
181
+ normalize_data: predicate indicating whether data should be normalized.
182
+
183
+ Returns:
184
+ Random features for fast generalized attention.
185
+ """
186
+ if normalize_data:
187
+ data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
188
+ else:
189
+ data_normalizer = 1.0
190
+ if projection_matrix is None:
191
+ return kernel_fn(data_normalizer * data) + kernel_epsilon
192
+ else:
193
+ data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
194
+ data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix
195
+ data_dash = lax.dot_general(
196
+ data_normalizer * data,
197
+ data_thick_random_matrix,
198
+ (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)),
199
+ (batch_dims_t, batch_dims_t)),
200
+ precision=precision)
201
+ data_prime = kernel_fn(data_dash) + kernel_epsilon
202
+ return data_prime
203
+
204
+
205
+ @gin.configurable
206
+ def make_fast_softmax_attention(qkv_dim,
207
+ renormalize_attention=True,
208
+ numerical_stabilizer=0.000001,
209
+ nb_features=256,
210
+ ortho_features=True,
211
+ ortho_scaling=0.0,
212
+ redraw_features=True,
213
+ unidirectional=False,
214
+ nonnegative_features=True,
215
+ lax_scan_unroll=1):
216
+ """Construct a fast softmax attention method."""
217
+ logging.info(
218
+ 'Fast softmax attention: %s features and orthogonal=%s, renormalize=%s',
219
+ nb_features, ortho_features, renormalize_attention)
220
+ if ortho_features:
221
+ matrix_creator = functools.partial(
222
+ GaussianOrthogonalRandomMatrix,
223
+ nb_features,
224
+ qkv_dim,
225
+ scaling=ortho_scaling)
226
+ else:
227
+ matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix,
228
+ nb_features, qkv_dim)
229
+ if nonnegative_features:
230
+
231
+ def kernel_feature_creator(data,
232
+ projection_matrix,
233
+ attention_dims_t,
234
+ batch_dims_t,
235
+ precision,
236
+ is_query,
237
+ normalize_data=True):
238
+ return nonnegative_softmax_kernel_feature_creator(
239
+ data, projection_matrix, attention_dims_t, batch_dims_t, precision,
240
+ is_query, normalize_data, numerical_stabilizer)
241
+ else:
242
+
243
+ def kernel_feature_creator(data,
244
+ projection_matrix,
245
+ attention_dims_t,
246
+ batch_dims_t,
247
+ precision,
248
+ is_query,
249
+ normalize_data=True):
250
+ del is_query
251
+ return sincos_softmax_kernel_feature_creator(data, projection_matrix,
252
+ attention_dims_t,
253
+ batch_dims_t, precision,
254
+ normalize_data)
255
+
256
+ attention_fn = FastAttentionviaLowRankDecomposition(
257
+ matrix_creator,
258
+ kernel_feature_creator,
259
+ renormalize_attention=renormalize_attention,
260
+ numerical_stabilizer=numerical_stabilizer,
261
+ redraw_features=redraw_features,
262
+ unidirectional=unidirectional,
263
+ lax_scan_unroll=lax_scan_unroll).dot_product_attention
264
+ return attention_fn
265
+
266
+
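A hedged usage sketch, not part of the package (shapes and hyper-parameters below are assumptions): the returned attention_fn consumes query/key/value shaped [batch, length, num_heads, channels] and returns an array of the same shape as value.
from jax import random

qkv_dim = 64
fast_attn = make_fast_softmax_attention(
    qkv_dim, nb_features=256, redraw_features=False)
x = random.normal(random.PRNGKey(0), (1, 16, 2, qkv_dim))  # (batch, length, heads, channels)
out = fast_attn(x, x, x)   # self-attention; out.shape == (1, 16, 2, 64)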
267
+ @gin.configurable
268
+ def make_fast_generalized_attention(qkv_dim,
269
+ renormalize_attention=True,
270
+ numerical_stabilizer=0.0,
271
+ nb_features=256,
272
+ features_type='deterministic',
273
+ kernel_fn=jax.nn.relu,
274
+ kernel_epsilon=0.001,
275
+ redraw_features=False,
276
+ unidirectional=False,
277
+ lax_scan_unroll=1):
278
+ """Construct a fast generalized attention menthod."""
279
+ logging.info('Fast generalized attention: %s features and renormalize=%s',
280
+ nb_features, renormalize_attention)
281
+ if features_type == 'ortho':
282
+ matrix_creator = functools.partial(
283
+ GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=False)
284
+ elif features_type == 'iid':
285
+ matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix,
286
+ nb_features, qkv_dim)
287
+ elif features_type == 'deterministic':
288
+ matrix_creator = None
289
+ else:
290
+ raise ValueError('Unknown feature value type')
291
+
292
+ def kernel_feature_creator(data,
293
+ projection_matrix,
294
+ attention_dims_t,
295
+ batch_dims_t,
296
+ precision,
297
+ is_query,
298
+ normalize_data=False):
299
+ del attention_dims_t
300
+ del is_query
301
+ return generalized_kernel_feature_creator(data, projection_matrix,
302
+ batch_dims_t, precision,
303
+ kernel_fn, kernel_epsilon,
304
+ normalize_data)
305
+
306
+ attention_fn = FastAttentionviaLowRankDecomposition(
307
+ matrix_creator,
308
+ kernel_feature_creator,
309
+ renormalize_attention=renormalize_attention,
310
+ numerical_stabilizer=numerical_stabilizer,
311
+ redraw_features=redraw_features,
312
+ unidirectional=unidirectional,
313
+ lax_scan_unroll=lax_scan_unroll).dot_product_attention
314
+ return attention_fn
315
+
316
+
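Similarly, a short illustrative sketch of the deterministic ReLU-kernel variant (parameter values are assumptions): with features_type='deterministic' no projection matrix is drawn and the kernel is applied to queries/keys directly.
import jax
from jax import random

relu_attn = make_fast_generalized_attention(
    qkv_dim=64, features_type='deterministic', kernel_fn=jax.nn.relu)
x = random.normal(random.PRNGKey(0), (1, 16, 2, 64))  # (batch, length, heads, channels)
out = relu_attn(x, x, x)                              # out.shape == (1, 16, 2, 64)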
317
+ class RandomMatrix(object):
318
+ r"""Abstract class providing a method for constructing 2D random arrays.
319
+
320
+ Class is responsible for constructing 2D random arrays.
321
+ """
322
+
323
+ __metaclass__ = abc.ABCMeta
324
+
325
+ @abc.abstractmethod
326
+ def get_2d_array(self):
327
+ raise NotImplementedError('Abstract method')
328
+
329
+
330
+ class GaussianUnstructuredRandomMatrix(RandomMatrix):
331
+
332
+ def __init__(self, nb_rows, nb_columns, key):
333
+ self.nb_rows = nb_rows
334
+ self.nb_columns = nb_columns
335
+ self.key = key
336
+
337
+ def get_2d_array(self):
338
+ return random.normal(self.key, (self.nb_rows, self.nb_columns))
339
+
340
+
341
+ class GaussianOrthogonalRandomMatrix(RandomMatrix):
342
+ r"""Class providing a method to create Gaussian orthogonal matrix.
343
+
344
+ Class is responsible for constructing 2D Gaussian orthogonal arrays.
345
+ """
346
+
347
+ def __init__(self, nb_rows, nb_columns, key, scaling=0):
348
+ self.nb_rows = nb_rows
349
+ self.nb_columns = nb_columns
350
+ self.key = key
351
+ self.scaling = scaling
352
+
353
+ def get_2d_array(self):
354
+ nb_full_blocks = int(self.nb_rows / self.nb_columns)
355
+ block_list = []
356
+ rng = self.key
357
+ for _ in range(nb_full_blocks):
358
+ rng, rng_input = jax.random.split(rng)
359
+ unstructured_block = random.normal(rng_input,
360
+ (self.nb_columns, self.nb_columns))
361
+ q, _ = jnp.linalg.qr(unstructured_block)
362
+ q = jnp.transpose(q)
363
+ block_list.append(q)
364
+ remaining_rows = self.nb_rows - nb_full_blocks * self.nb_columns
365
+ if remaining_rows > 0:
366
+ rng, rng_input = jax.random.split(rng)
367
+ unstructured_block = random.normal(rng_input,
368
+ (self.nb_columns, self.nb_columns))
369
+ q, _ = jnp.linalg.qr(unstructured_block)
370
+ q = jnp.transpose(q)
371
+ block_list.append(q[0:remaining_rows])
372
+ final_matrix = jnp.vstack(block_list)
373
+
374
+ if self.scaling == 0:
375
+ multiplier = jnp.linalg.norm(
376
+ random.normal(self.key, (self.nb_rows, self.nb_columns)), axis=1)
377
+ elif self.scaling == 1:
378
+ multiplier = jnp.sqrt(float(self.nb_columns)) * jnp.ones((self.nb_rows))
379
+ else:
380
+ raise ValueError('Scaling must be one of {0, 1}. Was %s' % self.scaling)
381
+
382
+ return jnp.matmul(jnp.diag(multiplier), final_matrix)
383
+
384
+
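A small property check, purely illustrative and not part of the package: in the square case the block construction yields pairwise-orthogonal rows (with chi-distributed norms under the default scaling=0), so M M^T is diagonal up to numerical error.
import jax.numpy as jnp
from jax import random

M = GaussianOrthogonalRandomMatrix(64, 64, random.PRNGKey(0)).get_2d_array()
gram = M @ M.T
print(jnp.max(jnp.abs(gram - jnp.diag(jnp.diag(gram)))))  # ~0: off-diagonals vanish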
385
+ class FastAttention(object):
386
+ r"""Abstract class providing a method for fast attention.
387
+
388
+ Class is responsible for providing a method <dot_product_attention> for fast
389
+ approximate attention.
390
+ """
391
+
392
+ __metaclass__ = abc.ABCMeta
393
+
394
+ @abc.abstractmethod
395
+ def dot_product_attention(self,
396
+ query,
397
+ key,
398
+ value,
399
+ dtype=jnp.float32,
400
+ bias=None,
401
+ mask=None,
402
+ axis=None,
403
+ broadcast_dropout=True,
404
+ dropout_rng=None,
405
+ dropout_rate=0.,
406
+ deterministic=False,
407
+ precision=None):
408
+ """Computes dot-product attention given query, key, and value.
409
+
410
+ This is the core function for applying fast approximate dot-product
411
+ attention. It calculates the attention weights given query and key and
412
+ combines the values using the attention weights. This function supports
413
+ multi-dimensional inputs.
414
+
415
+
416
+ Args:
417
+ query: queries for calculating attention with shape of [batch_size, dim1,
418
+ dim2, ..., dimN, num_heads, mem_channels].
419
+ key: keys for calculating attention with shape of [batch_size, dim1, dim2,
420
+ ..., dimN, num_heads, mem_channels].
421
+ value: values to be used in attention with shape of [batch_size, dim1,
422
+ dim2,..., dimN, num_heads, value_channels].
423
+ dtype: the dtype of the computation (default: float32)
424
+ bias: bias for the attention weights. This can be used for incorporating
425
+ an autoregressive mask, padding mask, or proximity bias.
426
+ mask: mask for the attention weights. This can be used for incorporating
427
+ autoregressive masks.
428
+ axis: axes over which the attention is applied.
429
+ broadcast_dropout: bool: use a broadcasted dropout along batch dims.
430
+ dropout_rng: JAX PRNGKey: to be used for dropout.
431
+ dropout_rate: dropout rate.
432
+ deterministic: bool, deterministic or not (to apply dropout).
433
+ precision: numerical precision of the computation see `jax.lax.Precision`
434
+ for details.
435
+
436
+ Returns:
437
+ Output of shape [bs, dim1, dim2, ..., dimN, num_heads, value_channels].
438
+ """
439
+ raise NotImplementedError('Abstract method')
440
+
441
+
442
+ def _numerator(z_slice_shape, precision, unroll=1):
443
+
444
+ def fwd(qs, ks, vs):
445
+
446
+ def body(p, qkv):
447
+ (q, k, v) = qkv
448
+ p += jnp.einsum('...m,...d->...md', k, v, precision=precision)
449
+ X_slice = jnp.einsum('...m,...md->...d', q, p, precision=precision)
450
+ return p, X_slice
451
+
452
+ init_value = jnp.zeros(z_slice_shape)
453
+ p, W = lax.scan(body, init_value, (qs, ks, vs), unroll=unroll)
454
+ return W, (p, qs, ks, vs)
455
+
456
+ def bwd(pqkv, W_ct):
457
+
458
+ def body(carry, qkv_xct):
459
+ p, p_ct = carry
460
+ q, k, v, x_ct = qkv_xct
461
+ q_ct = jnp.einsum('...d,...md->...m', x_ct, p, precision=precision)
462
+ p_ct += jnp.einsum('...d,...m->...md', x_ct, q, precision=precision)
463
+ k_ct = jnp.einsum('...md,...d->...m', p_ct, v, precision=precision)
464
+ v_ct = jnp.einsum('...md,...m->...d', p_ct, k, precision=precision)
465
+ p -= jnp.einsum('...m,...d->...md', k, v, precision=precision)
466
+ return (p, p_ct), (q_ct, k_ct, v_ct)
467
+
468
+ p, qs, ks, vs = pqkv
469
+ _, (qs_ct, ks_ct, vs_ct) = lax.scan(
470
+ body, (p, jnp.zeros_like(p)), (qs, ks, vs, W_ct),
471
+ reverse=True,
472
+ unroll=unroll)
473
+ return qs_ct, ks_ct, vs_ct
474
+
475
+ @jax.custom_vjp
476
+ def _numerator_impl(qs, ks, vs):
477
+ W, _ = fwd(qs, ks, vs)
478
+ return W
479
+
480
+ _numerator_impl.defvjp(fwd, bwd)
481
+
482
+ return _numerator_impl
483
+
484
+
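The scan above computes, for every position i, W[i] = q_i^T (sum_{j <= i} k_j v_j^T), i.e. the unnormalized causal attention output in feature space; the custom VJP re-derives the prefix sums in reverse so the full (L, m, d) intermediates are never stored. Below is a direct reference version, useful only as an illustrative cross-check; the toy shapes are assumptions.
import jax.numpy as jnp
from jax import random

def _numerator_reference(qs, ks, vs):
  # Materializes all prefix sums explicitly: O(L * m * d) memory.
  kv = jnp.cumsum(jnp.einsum('l...m,l...d->l...md', ks, vs), axis=0)
  return jnp.einsum('l...m,l...md->l...d', qs, kv)

L, m_feat, d_v = 8, 16, 4
qs = random.normal(random.PRNGKey(1), (L, m_feat))
ks = random.normal(random.PRNGKey(2), (L, m_feat))
vs = random.normal(random.PRNGKey(3), (L, d_v))
fast = _numerator((m_feat, d_v), precision=None)(qs, ks, vs)
print(jnp.max(jnp.abs(fast - _numerator_reference(qs, ks, vs))))  # small: same sums, different association order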
485
+ def _denominator(t_slice_shape, precision, unroll=1):
486
+
487
+ def fwd(qs, ks):
488
+
489
+ def body(p, qk):
490
+ q, k = qk
491
+ p += k
492
+ x = jnp.einsum('...m,...m->...', q, p, precision=precision)
493
+ return p, x
494
+
495
+ p = jnp.zeros(t_slice_shape)
496
+ p, R = lax.scan(body, p, (qs, ks), unroll=unroll)
497
+ return R, (qs, ks, p)
498
+
499
+ def bwd(qkp, R_ct):
500
+
501
+ def body(carry, qkx):
502
+ p, p_ct = carry
503
+ q, k, x_ct = qkx
504
+ q_ct = jnp.einsum('...,...m->...m', x_ct, p, precision=precision)
505
+ p_ct += jnp.einsum('...,...m->...m', x_ct, q, precision=precision)
506
+ k_ct = p_ct
507
+ p -= k
508
+ return (p, p_ct), (q_ct, k_ct)
509
+
510
+ qs, ks, p = qkp
511
+ _, (qs_ct, ks_ct) = lax.scan(
512
+ body, (p, jnp.zeros_like(p)), (qs, ks, R_ct),
513
+ reverse=True,
514
+ unroll=unroll)
515
+ return (qs_ct, ks_ct)
516
+
517
+ @jax.custom_vjp
518
+ def _denominator_impl(qs, ks):
519
+ R, _ = fwd(qs, ks)
520
+ return R
521
+
522
+ _denominator_impl.defvjp(fwd, bwd)
523
+
524
+ return _denominator_impl
525
+
526
+
527
+ class FastAttentionviaLowRankDecomposition(FastAttention):
528
+ r"""Class providing a method for fast attention via low rank decomposition.
529
+
530
+ Class is responsible for providing a method <dot_product_attention> for fast
531
+ dot-product attention with the use of low rank decomposition (e.g. with
532
+ random feature maps).
533
+ """
534
+
535
+ def __init__(self,
536
+ matrix_creator,
537
+ kernel_feature_creator,
538
+ renormalize_attention,
539
+ numerical_stabilizer,
540
+ redraw_features,
541
+ unidirectional,
542
+ lax_scan_unroll=1): # For optimal GPU performance, set to 16.
543
+ rng = random.PRNGKey(0)
544
+ self.matrix_creator = matrix_creator
545
+ self.projection_matrix = self.draw_weights(rng)
546
+ self.kernel_feature_creator = kernel_feature_creator
547
+ self.renormalize_attention = renormalize_attention
548
+ self.numerical_stabilizer = numerical_stabilizer
549
+ self.redraw_features = redraw_features
550
+ self.unidirectional = unidirectional
551
+ self.lax_scan_unroll = lax_scan_unroll
552
+
553
+ def draw_weights(self, key):
554
+ if self.matrix_creator is None:
555
+ return None
556
+ matrixrng, _ = random.split(key)
557
+ projection_matrix = self.matrix_creator(key=matrixrng).get_2d_array()
558
+ return projection_matrix
559
+
560
+ def dot_product_attention(self,
561
+ query,
562
+ key,
563
+ value,
564
+ dtype=jnp.float32,
565
+ bias=None,
566
+ mask=None,
567
+ axis=None,
568
+ broadcast_dropout=True,
569
+ dropout_rng=None,
570
+ dropout_rate=0.,
571
+ deterministic=False,
572
+ precision=None):
573
+
574
+ assert key.shape[:-1] == value.shape[:-1]
575
+ assert (query.shape[0:1] == key.shape[0:1] and
576
+ query.shape[-1] == key.shape[-1])
577
+ if axis is None:
578
+ axis = tuple(range(1, key.ndim - 2))
579
+ if not isinstance(axis, Iterable):
580
+ axis = (axis,)
581
+ assert key.ndim == query.ndim
582
+ assert key.ndim == value.ndim
583
+ for ax in axis:
584
+ if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
585
+ raise ValueError('Attention axis must be between the batch '
586
+ 'axis and the last two axes.')
587
+ n = key.ndim
588
+
589
+ # Constructing projection tensor.
590
+ if self.redraw_features:
591
+ # TODO(kchoro): Get rid of the constant below.
592
+ query_seed = lax.convert_element_type(
593
+ jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32)
594
+ rng = random.PRNGKey(query_seed)
595
+ self.projection_matrix = self.draw_weights(rng)
596
+
597
+ # batch_dims is <bs, <non-attention dims>, num_heads>
598
+ batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))
599
+ # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
600
+ qk_perm = batch_dims + axis + (n - 1,)
601
+ k_extra_perm = axis + batch_dims + (n - 1,)
602
+ key_extra = key.transpose(k_extra_perm)
603
+ key = key.transpose(qk_perm)
604
+ query = query.transpose(qk_perm)
605
+ # v -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
606
+ v_perm = batch_dims + axis + (n - 1,)
607
+ value = value.transpose(v_perm)
608
+ batch_dims_t = tuple(range(len(batch_dims)))
609
+ attention_dims_t = tuple(
610
+ range(len(batch_dims),
611
+ len(batch_dims) + len(axis)))
612
+
613
+ # Constructing tensors Q^{'} and K^{'}.
614
+ query_prime = self.kernel_feature_creator(query, self.projection_matrix,
615
+ attention_dims_t, batch_dims_t,
616
+ precision, True)
617
+ key_prime = self.kernel_feature_creator(key, self.projection_matrix,
618
+ attention_dims_t, batch_dims_t,
619
+ precision, False)
620
+
621
+ if self.unidirectional:
622
+ index = attention_dims_t[0]
623
+ z_slice_shape = key_prime.shape[0:len(batch_dims_t)] + (
624
+ key_prime.shape[-1],) + (value.shape[-1],)
625
+
626
+ numerator_fn = _numerator(z_slice_shape, precision, self.lax_scan_unroll)
627
+ W = numerator_fn(
628
+ jnp.moveaxis(query_prime, index, 0),
629
+ jnp.moveaxis(key_prime, index, 0), jnp.moveaxis(value, index, 0))
630
+
631
+ # Constructing W = (Q^{'}(K^{'})^{T})_{masked}V
632
+ W = jnp.moveaxis(W, 0, index)
633
+
634
+ if not self.renormalize_attention:
635
+ # Unidirectional, not-normalized attention.
636
+ perm_inv = _invert_perm(qk_perm)
637
+ result = W.transpose(perm_inv)
638
+ return result
639
+ else:
640
+ # Unidirectional, normalized attention.
641
+ thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(
642
+ key_extra.shape[0:len(axis)])
643
+
644
+ index = attention_dims_t[0]
645
+ t_slice_shape = key_prime.shape[0:len(batch_dims_t)] + (
646
+ key_prime.shape[-1],)
647
+ denominator_fn = _denominator(t_slice_shape, precision,
648
+ self.lax_scan_unroll)
649
+ R = denominator_fn(
650
+ jnp.moveaxis(query_prime, index, 0),
651
+ jnp.moveaxis(key_prime, index, 0))
652
+
653
+ R = jnp.moveaxis(R, 0, index)
654
+ else:
655
+ contract_query = tuple(
656
+ range(len(batch_dims) + len(axis),
657
+ len(batch_dims) + len(axis) + 1))
658
+ contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1))
659
+ # Constructing Z = (K^{'})^{T}V
660
+ # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
661
+ Z = lax.dot_general(
662
+ key_prime,
663
+ value,
664
+ ((attention_dims_t, attention_dims_t), (batch_dims_t, batch_dims_t)),
665
+ precision=precision)
666
+ # Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V
667
+ # q (bs, <non-attention dims>, num_heads, <attention dims>, channels_m)
668
+ # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
669
+ # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
670
+ W = lax.dot_general(
671
+ query_prime,
672
+ Z, ((contract_query, contract_z), (batch_dims_t, batch_dims_t)),
673
+ precision=precision)
674
+ if not self.renormalize_attention:
675
+ # Bidirectional, not-normalized attention.
676
+ perm_inv = _invert_perm(qk_perm)
677
+ result = W.transpose(perm_inv)
678
+ return result
679
+ else:
680
+ # Bidirectional, normalized attention.
681
+ thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(
682
+ key_extra.shape[0:len(axis)])
683
+ contract_key = tuple(
684
+ range(len(batch_dims),
685
+ len(batch_dims) + len(axis)))
686
+ contract_thick_all_ones = tuple(
687
+ range(thick_all_ones.ndim - len(axis), thick_all_ones.ndim))
688
+ # Construct T = (K^{'})^{T} 1_L
689
+ # k (bs, <non-attention dims>, num_heads, <attention dims>, channels)
690
+ T = lax.dot_general(
691
+ key_prime,
692
+ thick_all_ones, ((contract_key, contract_thick_all_ones),
693
+ (batch_dims_t, batch_dims_t)),
694
+ precision=precision)
695
+
696
+ # Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L
697
+ # q_p (bs, <non-attention dims>, num_heads, <attention dims>, channs_m)
698
+ # T (bs, <non-attention dims>, num_heads, channels_m)
699
+ R = lax.dot_general(
700
+ query_prime,
701
+ T, (((query_prime.ndim - 1,), (T.ndim - 1,)),
702
+ (batch_dims_t, range(0,
703
+ len(T.shape) - 1))),
704
+ precision=precision)
705
+
706
+ R = R + 2 * self.numerical_stabilizer * (
707
+ jnp.abs(R) <= self.numerical_stabilizer)
708
+ R = jnp.reciprocal(R)
709
+ R = jnp.expand_dims(R, len(R.shape))
710
+ # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
711
+ # R (bs, <non-attention dims>, num_heads, <attention dims>, extra_channel)
712
+ result = W * R
713
+ # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
714
+ perm_inv = _invert_perm(qk_perm)
715
+ result = result.transpose(perm_inv)
716
+ return result
717
+
718
+
719
+ def _invert_perm(perm):
720
+ perm_inv = [0] * len(perm)
721
+ for i, j in enumerate(perm):
722
+ perm_inv[j] = i
723
+ return tuple(perm_inv)
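Finally, an end-to-end sanity sketch, not part of the package: with a reasonably large number of random features and moderately scaled inputs, the renormalized fast softmax attention should track exact softmax attention. All shapes, scales, and the tolerance remark are illustrative assumptions; the estimate is stochastic in the drawn features.
import jax
import jax.numpy as jnp
from jax import random

qkv_dim, heads, length = 32, 2, 16
attn = make_fast_softmax_attention(qkv_dim, nb_features=1024, redraw_features=False)
rq, rk, rv = random.split(random.PRNGKey(0), 3)
q = 0.5 * random.normal(rq, (1, length, heads, qkv_dim))
k = 0.5 * random.normal(rk, (1, length, heads, qkv_dim))
v = random.normal(rv, (1, length, heads, qkv_dim))
fast_out = attn(q, k, v)
logits = jnp.einsum('blhd,bmhd->bhlm', q, k) / jnp.sqrt(float(qkv_dim))
exact_out = jnp.einsum('bhlm,bmhd->blhd', jax.nn.softmax(logits, axis=-1), v)
print(jnp.max(jnp.abs(fast_out - exact_out)))  # typically small relative to the O(1) scale of exact_out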