flaxdiff 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/__init__.py +0 -0
- flaxdiff/models/__init__.py +1 -0
- flaxdiff/models/attention.py +489 -0
- flaxdiff/models/common.py +7 -0
- flaxdiff/models/favor_fastattn.py +723 -0
- flaxdiff/models/simple_unet.py +519 -0
- flaxdiff/predictors/__init__.py +96 -0
- flaxdiff/samplers/__init__.py +7 -0
- flaxdiff/samplers/common.py +113 -0
- flaxdiff/samplers/ddim.py +10 -0
- flaxdiff/samplers/ddpm.py +43 -0
- flaxdiff/samplers/euler.py +59 -0
- flaxdiff/samplers/heun_sampler.py +28 -0
- flaxdiff/samplers/multistep_dpm.py +60 -0
- flaxdiff/samplers/rk4_sampler.py +34 -0
- flaxdiff/schedulers/__init__.py +6 -0
- flaxdiff/schedulers/common.py +98 -0
- flaxdiff/schedulers/continuous.py +12 -0
- flaxdiff/schedulers/cosine.py +40 -0
- flaxdiff/schedulers/discrete.py +74 -0
- flaxdiff/schedulers/exp.py +13 -0
- flaxdiff/schedulers/karras.py +69 -0
- flaxdiff/schedulers/linear.py +14 -0
- flaxdiff/schedulers/sqrt.py +10 -0
- flaxdiff/trainer/__init__.py +216 -0
- flaxdiff/utils.py +89 -0
- flaxdiff-0.1.1.dist-info/METADATA +326 -0
- flaxdiff-0.1.1.dist-info/RECORD +30 -0
- flaxdiff-0.1.1.dist-info/WHEEL +5 -0
- flaxdiff-0.1.1.dist-info/top_level.txt +1 -0
flaxdiff/models/favor_fastattn.py
@@ -0,0 +1,723 @@
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Core Fast Attention Module for Flax.

Implementation of the approximate fast softmax and generalized attention
mechanisms, leveraging structured random feature map (RFM) techniques
and low rank decomposition of the attention matrix.
"""
# pylint: disable=invalid-name, missing-function-docstring, line-too-long

import abc
from collections.abc import Iterable  # pylint: disable=g-importing-member
import functools
from absl import logging
import gin
import jax
from jax import lax
from jax import random
import jax.numpy as jnp

import numpy as onp

# Nonlinear mappings encoding different attention kernels.
gin.external_configurable(jnp.cos, 'jcos')
gin.external_configurable(jnp.sin, 'jsin')
gin.external_configurable(jnp.tanh, 'jtanh')
gin.external_configurable(jax.nn.sigmoid, 'jsigmoid')
gin.external_configurable(
    lambda x: jax.nn.gelu(x, approximate=False), 'jgelu'
)  # Needs to be exact, although might be slower. See https://github.com/google/jax/issues/4428.
gin.external_configurable(lambda x: x * x * (x > 0.0), 'jrequ')
gin.external_configurable(jnp.exp, 'jexp')
gin.external_configurable(lambda x: x, 'jidentity')
gin.external_configurable(
    lambda x: (jnp.exp(x)) * (x <= 0.0) + (x + 1.0) * (x > 0.0), 'jshiftedelu'
)  # Nonlinearity used in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" (https://arxiv.org/abs/2006.16236).
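
# The names registered above ('jgelu', 'jrequ', ...) are what a gin config
# would reference when selecting a kernel for the generalized attention path.
# A minimal sketch of such a binding (hypothetical config file, not part of
# this module):
#
#   make_fast_generalized_attention.kernel_fn = @jrequ
#   make_fast_generalized_attention.nb_features = 256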


def nonnegative_softmax_kernel_feature_creator(data,
                                               projection_matrix,
                                               attention_dims_t,
                                               batch_dims_t,
                                               precision,
                                               is_query,
                                               normalize_data=True,
                                               eps=0.0001):
  """Constructs nonnegative kernel features for fast softmax attention.

  Args:
    data: input for which features are computed
    projection_matrix: random matrix used to compute features
    attention_dims_t: tuple of attention dimensions
    batch_dims_t: tuple of batch dimensions
    precision: precision parameter
    is_query: predicate indicating whether input data corresponds to queries or
      keys
    normalize_data: predicate indicating whether data should be normalized.
    eps: numerical stabilizer.

  Returns:
    Random features for fast softmax attention.
  """

  if normalize_data:
    # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where
    # w_norm = w * data_normalizer for w in {q,k}.
    data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
  else:
    data_normalizer = 1.0
  ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
  data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
  data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix

  data_dash = lax.dot_general(
      data_normalizer * data,
      data_thick_random_matrix,
      (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)),
       (batch_dims_t, batch_dims_t)),
      precision=precision)

  diag_data = jnp.square(data)
  diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
  diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
  diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)

  last_dims_t = (len(data_dash.shape) - 1,)
  if is_query:
    data_dash = ratio * (
        jnp.exp(data_dash - diag_data -
                jnp.max(data_dash, axis=last_dims_t, keepdims=True)) + eps)
  else:
    data_dash = ratio * (
        jnp.exp(data_dash - diag_data - jnp.max(
            data_dash, axis=last_dims_t + attention_dims_t, keepdims=True)) +
        eps)

  return data_dash
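
# Illustrative sketch (not part of the original module): the feature map above
# follows the positive random features phi(x) = exp(w^T x - |x|^2 / 2) / sqrt(m),
# for which E[phi(q)^T phi(k)] = exp(q^T k). The hypothetical helper below
# checks that identity directly on small vectors, independently of the
# function above.
def _example_positive_features_check(num_features=4096, dim=8, seed=0):
  q_key, k_key, w_key = random.split(random.PRNGKey(seed), 3)
  q = 0.3 * random.normal(q_key, (dim,))
  k = 0.3 * random.normal(k_key, (dim,))
  w = random.normal(w_key, (num_features, dim))  # i.i.d. Gaussian projections.
  phi = lambda x: jnp.exp(w @ x - jnp.dot(x, x) / 2.0) / jnp.sqrt(num_features)
  estimate = jnp.dot(phi(q), phi(k))  # Monte-Carlo estimate of the kernel.
  exact = jnp.exp(jnp.dot(q, k))      # Softmax kernel value.
  return estimate, exact              # Close for large num_features.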


def sincos_softmax_kernel_feature_creator(data,
                                          projection_matrix,
                                          attention_dims_t,
                                          batch_dims_t,
                                          precision,
                                          normalize_data=True):
  """Constructs kernel sin-cos features for fast softmax attention.

  Args:
    data: input for which features are computed
    projection_matrix: random matrix used to compute features
    attention_dims_t: tuple of attention dimensions
    batch_dims_t: tuple of batch dimensions
    precision: precision parameter
    normalize_data: predicate indicating whether data should be normalized.

  Returns:
    Random features for fast softmax attention.
  """
  if normalize_data:
    # We have: exp(qk^T/sqrt{d}) = exp(|q|^2/2sqrt{d}) * exp(|k|^2/2sqrt{d}) *
    # exp(-(|q*c-k*c|^2)/2), where c = 1.0 / sqrt{sqrt{d}}.
    data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
  else:
    data_normalizer = 1.0
  ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0])
  data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
  data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix

  data_dash = lax.dot_general(
      data_normalizer * data,
      data_thick_random_matrix,
      (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)),
       (batch_dims_t, batch_dims_t)),
      precision=precision)
  data_dash_cos = ratio * jnp.cos(data_dash)
  data_dash_sin = ratio * jnp.sin(data_dash)
  data_dash = jnp.concatenate((data_dash_cos, data_dash_sin), axis=-1)

  # Constructing D_data and data^{'}
  diag_data = jnp.square(data)
  diag_data = jnp.sum(diag_data, axis=data.ndim - 1)
  diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
  diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1)
  # Additional renormalization for numerical stability
  data_renormalizer = jnp.max(diag_data, attention_dims_t, keepdims=True)
  diag_data -= data_renormalizer
  diag_data = jnp.exp(diag_data)
  data_prime = data_dash * diag_data
  return data_prime
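
# Illustrative sketch (not part of the original module): the sin/cos features
# are classic trigonometric random features. For w ~ N(0, I),
# E[cos(w^T (q - k))] = exp(-|q - k|^2 / 2), so rescaling by exp(|q|^2 / 2) and
# exp(|k|^2 / 2) recovers exp(q^T k). Unlike the nonnegative features above,
# these can take negative values, which is what the renormalization guards
# against. A hypothetical self-contained check:
def _example_sincos_features_check(num_features=4096, dim=8, seed=0):
  q_key, k_key, w_key = random.split(random.PRNGKey(seed), 3)
  q = 0.3 * random.normal(q_key, (dim,))
  k = 0.3 * random.normal(k_key, (dim,))
  w = random.normal(w_key, (num_features, dim))
  feat = lambda x: jnp.concatenate(
      (jnp.cos(w @ x), jnp.sin(w @ x))) / jnp.sqrt(num_features)
  gaussian_kernel = jnp.dot(feat(q), feat(k))   # ~ exp(-|q - k|^2 / 2)
  estimate = gaussian_kernel * jnp.exp(jnp.dot(q, q) / 2 + jnp.dot(k, k) / 2)
  exact = jnp.exp(jnp.dot(q, k))
  return estimate, exact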


def generalized_kernel_feature_creator(data, projection_matrix, batch_dims_t,
                                       precision, kernel_fn, kernel_epsilon,
                                       normalize_data):
  """Constructs kernel features for fast generalized attention.

  Args:
    data: input for which features are computed
    projection_matrix: matrix used to compute features
    batch_dims_t: tuple of batch dimensions
    precision: precision parameter
    kernel_fn: kernel function used
    kernel_epsilon: additive positive term added to every feature for numerical
      stability
    normalize_data: predicate indicating whether data should be normalized.

  Returns:
    Random features for fast generalized attention.
  """
  if normalize_data:
    data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1])))
  else:
    data_normalizer = 1.0
  if projection_matrix is None:
    return kernel_fn(data_normalizer * data) + kernel_epsilon
  else:
    data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape
    data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix
    data_dash = lax.dot_general(
        data_normalizer * data,
        data_thick_random_matrix,
        (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)),
         (batch_dims_t, batch_dims_t)),
        precision=precision)
    data_prime = kernel_fn(data_dash) + kernel_epsilon
    return data_prime
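
# Illustrative sketch (not part of the original module): with features
# phi = kernel_fn(x) + kernel_epsilon, generalized attention is computed as
# D^{-1} phi(Q) (phi(K)^T V) with D = diag(phi(Q) (phi(K)^T 1)), i.e. without
# ever materializing the L x L attention matrix. A hypothetical single-head,
# single-example version:
def _example_generalized_attention(q, k, v, kernel_fn=jax.nn.relu, eps=0.001):
  # q, k: [length, dim]; v: [length, value_dim].
  q_prime = kernel_fn(q) + eps                 # [L, m]
  k_prime = kernel_fn(k) + eps                 # [L, m]
  kv = k_prime.T @ v                           # [m, value_dim]; O(L m d) cost.
  normalizer = q_prime @ k_prime.sum(axis=0)   # [L]
  return (q_prime @ kv) / normalizer[:, None]  # [L, value_dim]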


@gin.configurable
def make_fast_softmax_attention(qkv_dim,
                                renormalize_attention=True,
                                numerical_stabilizer=0.000001,
                                nb_features=256,
                                ortho_features=True,
                                ortho_scaling=0.0,
                                redraw_features=True,
                                unidirectional=False,
                                nonnegative_features=True,
                                lax_scan_unroll=1):
  """Construct a fast softmax attention method."""
  logging.info(
      'Fast softmax attention: %s features and orthogonal=%s, renormalize=%s',
      nb_features, ortho_features, renormalize_attention)
  if ortho_features:
    matrix_creator = functools.partial(
        GaussianOrthogonalRandomMatrix,
        nb_features,
        qkv_dim,
        scaling=ortho_scaling)
  else:
    matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix,
                                       nb_features, qkv_dim)
  if nonnegative_features:

    def kernel_feature_creator(data,
                               projection_matrix,
                               attention_dims_t,
                               batch_dims_t,
                               precision,
                               is_query,
                               normalize_data=True):
      return nonnegative_softmax_kernel_feature_creator(
          data, projection_matrix, attention_dims_t, batch_dims_t, precision,
          is_query, normalize_data, numerical_stabilizer)
  else:

    def kernel_feature_creator(data,
                               projection_matrix,
                               attention_dims_t,
                               batch_dims_t,
                               precision,
                               is_query,
                               normalize_data=True):
      del is_query
      return sincos_softmax_kernel_feature_creator(data, projection_matrix,
                                                   attention_dims_t,
                                                   batch_dims_t, precision,
                                                   normalize_data)

  attention_fn = FastAttentionviaLowRankDecomposition(
      matrix_creator,
      kernel_feature_creator,
      renormalize_attention=renormalize_attention,
      numerical_stabilizer=numerical_stabilizer,
      redraw_features=redraw_features,
      unidirectional=unidirectional,
      lax_scan_unroll=lax_scan_unroll).dot_product_attention
  return attention_fn
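
# Illustrative sketch (not part of the original module): the returned callable
# has the dot_product_attention signature defined further below, so a minimal
# hypothetical use with [batch, length, heads, head_dim] inputs looks like:
def _example_fast_softmax_attention(batch=2, length=128, heads=4, head_dim=64):
  attention_fn = make_fast_softmax_attention(qkv_dim=head_dim, nb_features=256)
  q_key, k_key, v_key = random.split(random.PRNGKey(0), 3)
  query = random.normal(q_key, (batch, length, heads, head_dim))
  key = random.normal(k_key, (batch, length, heads, head_dim))
  value = random.normal(v_key, (batch, length, heads, head_dim))
  out = attention_fn(query, key, value)  # Same shape as value.
  return out.shape                       # (2, 128, 4, 64)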


@gin.configurable
def make_fast_generalized_attention(qkv_dim,
                                    renormalize_attention=True,
                                    numerical_stabilizer=0.0,
                                    nb_features=256,
                                    features_type='deterministic',
                                    kernel_fn=jax.nn.relu,
                                    kernel_epsilon=0.001,
                                    redraw_features=False,
                                    unidirectional=False,
                                    lax_scan_unroll=1):
  """Construct a fast generalized attention method."""
  logging.info('Fast generalized attention: %s features and renormalize=%s',
               nb_features, renormalize_attention)
  if features_type == 'ortho':
    matrix_creator = functools.partial(
        GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=False)
  elif features_type == 'iid':
    matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix,
                                       nb_features, qkv_dim)
  elif features_type == 'deterministic':
    matrix_creator = None
  else:
    raise ValueError('Unknown feature value type')

  def kernel_feature_creator(data,
                             projection_matrix,
                             attention_dims_t,
                             batch_dims_t,
                             precision,
                             is_query,
                             normalize_data=False):
    del attention_dims_t
    del is_query
    return generalized_kernel_feature_creator(data, projection_matrix,
                                              batch_dims_t, precision,
                                              kernel_fn, kernel_epsilon,
                                              normalize_data)

  attention_fn = FastAttentionviaLowRankDecomposition(
      matrix_creator,
      kernel_feature_creator,
      renormalize_attention=renormalize_attention,
      numerical_stabilizer=numerical_stabilizer,
      redraw_features=redraw_features,
      unidirectional=unidirectional,
      lax_scan_unroll=lax_scan_unroll).dot_product_attention
  return attention_fn
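
# Illustrative sketch (not part of the original module): the generalized
# variant is used the same way; e.g. a hypothetical causal, ReLU-kernel
# attention function for 64-dimensional heads:
def _example_fast_generalized_attention(batch=2, length=128, heads=4,
                                        head_dim=64):
  attention_fn = make_fast_generalized_attention(
      qkv_dim=head_dim, features_type='deterministic',
      kernel_fn=jax.nn.relu, unidirectional=True)
  q_key, k_key, v_key = random.split(random.PRNGKey(0), 3)
  query = random.normal(q_key, (batch, length, heads, head_dim))
  key = random.normal(k_key, (batch, length, heads, head_dim))
  value = random.normal(v_key, (batch, length, heads, head_dim))
  return attention_fn(query, key, value).shape  # (2, 128, 4, 64)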


class RandomMatrix(object):
  r"""Abstract class providing a method for constructing 2D random arrays.

  Class is responsible for constructing 2D random arrays.
  """

  __metaclass__ = abc.ABCMeta

  @abc.abstractmethod
  def get_2d_array(self):
    raise NotImplementedError('Abstract method')


class GaussianUnstructuredRandomMatrix(RandomMatrix):

  def __init__(self, nb_rows, nb_columns, key):
    self.nb_rows = nb_rows
    self.nb_columns = nb_columns
    self.key = key

  def get_2d_array(self):
    return random.normal(self.key, (self.nb_rows, self.nb_columns))


class GaussianOrthogonalRandomMatrix(RandomMatrix):
  r"""Class providing a method to create Gaussian orthogonal matrix.

  Class is responsible for constructing 2D Gaussian orthogonal arrays.
  """

  def __init__(self, nb_rows, nb_columns, key, scaling=0):
    self.nb_rows = nb_rows
    self.nb_columns = nb_columns
    self.key = key
    self.scaling = scaling

  def get_2d_array(self):
    nb_full_blocks = int(self.nb_rows / self.nb_columns)
    block_list = []
    rng = self.key
    for _ in range(nb_full_blocks):
      rng, rng_input = jax.random.split(rng)
      unstructured_block = random.normal(rng_input,
                                         (self.nb_columns, self.nb_columns))
      q, _ = jnp.linalg.qr(unstructured_block)
      q = jnp.transpose(q)
      block_list.append(q)
    remaining_rows = self.nb_rows - nb_full_blocks * self.nb_columns
    if remaining_rows > 0:
      rng, rng_input = jax.random.split(rng)
      unstructured_block = random.normal(rng_input,
                                         (self.nb_columns, self.nb_columns))
      q, _ = jnp.linalg.qr(unstructured_block)
      q = jnp.transpose(q)
      block_list.append(q[0:remaining_rows])
    final_matrix = jnp.vstack(block_list)

    if self.scaling == 0:
      multiplier = jnp.linalg.norm(
          random.normal(self.key, (self.nb_rows, self.nb_columns)), axis=1)
    elif self.scaling == 1:
      multiplier = jnp.sqrt(float(self.nb_columns)) * jnp.ones((self.nb_rows))
    else:
      raise ValueError('Scaling must be one of {0, 1}. Was %s' % self.scaling)

    return jnp.matmul(jnp.diag(multiplier), final_matrix)
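
# Illustrative sketch (not part of the original module): within each
# nb_columns-row block the rows produced above are orthogonal, and with the
# default scaling=0 each row is rescaled to the norm of an independent
# Gaussian vector. A hypothetical quick check:
def _example_orthogonal_block_check(nb_rows=8, nb_columns=8, seed=0):
  sampler = GaussianOrthogonalRandomMatrix(
      nb_rows, nb_columns, random.PRNGKey(seed))
  m = sampler.get_2d_array()                       # [nb_rows, nb_columns]
  normalized = m / jnp.linalg.norm(m, axis=1, keepdims=True)
  gram = normalized @ normalized.T                 # ~ identity within a block.
  return jnp.max(jnp.abs(gram - jnp.eye(nb_rows)))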


class FastAttention(object):
  r"""Abstract class providing a method for fast attention.

  Class is responsible for providing a method <dot_product_attention> for fast
  approximate attention.
  """

  __metaclass__ = abc.ABCMeta

  @abc.abstractmethod
  def dot_product_attention(self,
                            query,
                            key,
                            value,
                            dtype=jnp.float32,
                            bias=None,
                            mask=None,
                            axis=None,
                            broadcast_dropout=True,
                            dropout_rng=None,
                            dropout_rate=0.,
                            deterministic=False,
                            precision=None):
    """Computes dot-product attention given query, key, and value.

    This is the core function for applying fast approximate dot-product
    attention. It calculates the attention weights given query and key and
    combines the values using the attention weights. This function supports
    multi-dimensional inputs.

    Args:
      query: queries for calculating attention with shape of [batch_size, dim1,
        dim2, ..., dimN, num_heads, mem_channels].
      key: keys for calculating attention with shape of [batch_size, dim1,
        dim2, ..., dimN, num_heads, mem_channels].
      value: values to be used in attention with shape of [batch_size, dim1,
        dim2, ..., dimN, num_heads, value_channels].
      dtype: the dtype of the computation (default: float32)
      bias: bias for the attention weights. This can be used for incorporating
        autoregressive mask, padding mask, proximity bias.
      mask: mask for the attention weights. This can be used for incorporating
        autoregressive masks.
      axis: axes over which the attention is applied.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout.
      dropout_rate: dropout rate.
      deterministic: bool, deterministic or not (to apply dropout).
      precision: numerical precision of the computation, see `jax.lax.Precision`
        for details.

    Returns:
      Output of shape [bs, dim1, dim2, ..., dimN, num_heads, value_channels].
    """
    raise NotImplementedError('Abstract method')


def _numerator(z_slice_shape, precision, unroll=1):

  def fwd(qs, ks, vs):

    def body(p, qkv):
      (q, k, v) = qkv
      p += jnp.einsum('...m,...d->...md', k, v, precision=precision)
      X_slice = jnp.einsum('...m,...md->...d', q, p, precision=precision)
      return p, X_slice

    init_value = jnp.zeros(z_slice_shape)
    p, W = lax.scan(body, init_value, (qs, ks, vs), unroll=unroll)
    return W, (p, qs, ks, vs)

  def bwd(pqkv, W_ct):

    def body(carry, qkv_xct):
      p, p_ct = carry
      q, k, v, x_ct = qkv_xct
      q_ct = jnp.einsum('...d,...md->...m', x_ct, p, precision=precision)
      p_ct += jnp.einsum('...d,...m->...md', x_ct, q, precision=precision)
      k_ct = jnp.einsum('...md,...d->...m', p_ct, v, precision=precision)
      v_ct = jnp.einsum('...md,...m->...d', p_ct, k, precision=precision)
      p -= jnp.einsum('...m,...d->...md', k, v, precision=precision)
      return (p, p_ct), (q_ct, k_ct, v_ct)

    p, qs, ks, vs = pqkv
    _, (qs_ct, ks_ct, vs_ct) = lax.scan(
        body, (p, jnp.zeros_like(p)), (qs, ks, vs, W_ct),
        reverse=True,
        unroll=unroll)
    return qs_ct, ks_ct, vs_ct

  @jax.custom_vjp
  def _numerator_impl(qs, ks, vs):
    W, _ = fwd(qs, ks, vs)
    return W

  _numerator_impl.defvjp(fwd, bwd)

  return _numerator_impl
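
# Illustrative sketch (not part of the original module): the scan above is a
# prefix sum over the leading (time) axis, so slice i of the result equals
# q_i @ (sum_{j <= i} k_j v_j^T), i.e. the numerator of causal linear
# attention. A hypothetical reference implementation for comparison:
def _example_causal_numerator_reference(qs, ks, vs):
  # qs, ks: [length, m]; vs: [length, d] (single head, no batch dims).
  kv_prefix = jnp.cumsum(jnp.einsum('lm,ld->lmd', ks, vs), axis=0)
  return jnp.einsum('lm,lmd->ld', qs, kv_prefix)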


def _denominator(t_slice_shape, precision, unroll=1):

  def fwd(qs, ks):

    def body(p, qk):
      q, k = qk
      p += k
      x = jnp.einsum('...m,...m->...', q, p, precision=precision)
      return p, x

    p = jnp.zeros(t_slice_shape)
    p, R = lax.scan(body, p, (qs, ks), unroll=unroll)
    return R, (qs, ks, p)

  def bwd(qkp, R_ct):

    def body(carry, qkx):
      p, p_ct = carry
      q, k, x_ct = qkx
      q_ct = jnp.einsum('...,...m->...m', x_ct, p, precision=precision)
      p_ct += jnp.einsum('...,...m->...m', x_ct, q, precision=precision)
      k_ct = p_ct
      p -= k
      return (p, p_ct), (q_ct, k_ct)

    qs, ks, p = qkp
    _, (qs_ct, ks_ct) = lax.scan(
        body, (p, jnp.zeros_like(p)), (qs, ks, R_ct),
        reverse=True,
        unroll=unroll)
    return (qs_ct, ks_ct)

  @jax.custom_vjp
  def _denominator_impl(qs, ks):
    R, _ = fwd(qs, ks)
    return R

  _denominator_impl.defvjp(fwd, bwd)

  return _denominator_impl
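
# Illustrative sketch (not part of the original module): the matching
# normalizer is R_i = q_i . (sum_{j <= i} k_j), again expressible with a
# cumulative sum over the time axis:
def _example_causal_denominator_reference(qs, ks):
  # qs, ks: [length, m] (single head, no batch dims).
  return jnp.einsum('lm,lm->l', qs, jnp.cumsum(ks, axis=0))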


class FastAttentionviaLowRankDecomposition(FastAttention):
  r"""Class providing a method for fast attention via low rank decomposition.

  Class is responsible for providing a method <dot_product_attention> for fast
  dot-product attention with the use of low rank decomposition (e.g. with
  random feature maps).
  """

  def __init__(self,
               matrix_creator,
               kernel_feature_creator,
               renormalize_attention,
               numerical_stabilizer,
               redraw_features,
               unidirectional,
               lax_scan_unroll=1):  # For optimal GPU performance, set to 16.
    rng = random.PRNGKey(0)
    self.matrix_creator = matrix_creator
    self.projection_matrix = self.draw_weights(rng)
    self.kernel_feature_creator = kernel_feature_creator
    self.renormalize_attention = renormalize_attention
    self.numerical_stabilizer = numerical_stabilizer
    self.redraw_features = redraw_features
    self.unidirectional = unidirectional
    self.lax_scan_unroll = lax_scan_unroll

  def draw_weights(self, key):
    if self.matrix_creator is None:
      return None
    matrixrng, _ = random.split(key)
    projection_matrix = self.matrix_creator(key=matrixrng).get_2d_array()
    return projection_matrix

  def dot_product_attention(self,
                            query,
                            key,
                            value,
                            dtype=jnp.float32,
                            bias=None,
                            mask=None,
                            axis=None,
                            broadcast_dropout=True,
                            dropout_rng=None,
                            dropout_rate=0.,
                            deterministic=False,
                            precision=None):

    assert key.shape[:-1] == value.shape[:-1]
    assert (query.shape[0:1] == key.shape[0:1] and
            query.shape[-1] == key.shape[-1])
    if axis is None:
      axis = tuple(range(1, key.ndim - 2))
    if not isinstance(axis, Iterable):
      axis = (axis,)
    assert key.ndim == query.ndim
    assert key.ndim == value.ndim
    for ax in axis:
      if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
        raise ValueError('Attention axis must be between the batch '
                         'axis and the last-two axes.')
    n = key.ndim

    # Constructing projection tensor.
    if self.redraw_features:
      # TODO(kchoro): Get rid of the constant below.
      query_seed = lax.convert_element_type(
          jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32)
      rng = random.PRNGKey(query_seed)
      self.projection_matrix = self.draw_weights(rng)

    # batch_dims is <bs, <non-attention dims>, num_heads>
    batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))
    # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
    qk_perm = batch_dims + axis + (n - 1,)
    k_extra_perm = axis + batch_dims + (n - 1,)
    key_extra = key.transpose(k_extra_perm)
    key = key.transpose(qk_perm)
    query = query.transpose(qk_perm)
    # v -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
    v_perm = batch_dims + axis + (n - 1,)
    value = value.transpose(v_perm)
    batch_dims_t = tuple(range(len(batch_dims)))
    attention_dims_t = tuple(
        range(len(batch_dims),
              len(batch_dims) + len(axis)))

    # Constructing tensors Q^{'} and K^{'}.
    query_prime = self.kernel_feature_creator(query, self.projection_matrix,
                                              attention_dims_t, batch_dims_t,
                                              precision, True)
    key_prime = self.kernel_feature_creator(key, self.projection_matrix,
                                            attention_dims_t, batch_dims_t,
                                            precision, False)

    if self.unidirectional:
      index = attention_dims_t[0]
      z_slice_shape = key_prime.shape[0:len(batch_dims_t)] + (
          key_prime.shape[-1],) + (value.shape[-1],)

      numerator_fn = _numerator(z_slice_shape, precision, self.lax_scan_unroll)
      W = numerator_fn(
          jnp.moveaxis(query_prime, index, 0),
          jnp.moveaxis(key_prime, index, 0), jnp.moveaxis(value, index, 0))

      # Constructing W = (Q^{'}(K^{'})^{T})_{masked}V
      W = jnp.moveaxis(W, 0, index)

      if not self.renormalize_attention:
        # Unidirectional, not-normalized attention.
        perm_inv = _invert_perm(qk_perm)
        result = W.transpose(perm_inv)
        return result
      else:
        # Unidirectional, normalized attention.
        thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(
            key_extra.shape[0:len(axis)])

        index = attention_dims_t[0]
        t_slice_shape = key_prime.shape[0:len(batch_dims_t)] + (
            key_prime.shape[-1],)
        denominator_fn = _denominator(t_slice_shape, precision,
                                      self.lax_scan_unroll)
        R = denominator_fn(
            jnp.moveaxis(query_prime, index, 0),
            jnp.moveaxis(key_prime, index, 0))

        R = jnp.moveaxis(R, 0, index)
    else:
      contract_query = tuple(
          range(len(batch_dims) + len(axis),
                len(batch_dims) + len(axis) + 1))
      contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1))
      # Constructing Z = (K^{'})^{T}V
      # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
      Z = lax.dot_general(
          key_prime,
          value,
          ((attention_dims_t, attention_dims_t), (batch_dims_t, batch_dims_t)),
          precision=precision)
      # Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V
      # q (bs, <non-attention dims>, num_heads, <attention dims>, channels_m)
      # Z (bs, <non-attention dims>, num_heads, channels_m, channels_v)
      # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
      W = lax.dot_general(
          query_prime,
          Z, ((contract_query, contract_z), (batch_dims_t, batch_dims_t)),
          precision=precision)
      if not self.renormalize_attention:
        # Bidirectional, not-normalized attention.
        perm_inv = _invert_perm(qk_perm)
        result = W.transpose(perm_inv)
        return result
      else:
        # Bidirectional, normalized attention.
        thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones(
            key_extra.shape[0:len(axis)])
        contract_key = tuple(
            range(len(batch_dims),
                  len(batch_dims) + len(axis)))
        contract_thick_all_ones = tuple(
            range(thick_all_ones.ndim - len(axis), thick_all_ones.ndim))
        # Construct T = (K^{'})^{T} 1_L
        # k (bs, <non-attention dims>, num_heads, <attention dims>, channels)
        T = lax.dot_general(
            key_prime,
            thick_all_ones, ((contract_key, contract_thick_all_ones),
                             (batch_dims_t, batch_dims_t)),
            precision=precision)

        # Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L
        # q_p (bs, <non-attention dims>, num_heads, <attention dims>, channs_m)
        # T (bs, <non-attention dims>, num_heads, channels_m)
        R = lax.dot_general(
            query_prime,
            T, (((query_prime.ndim - 1,), (T.ndim - 1,)),
                (batch_dims_t, range(0,
                                     len(T.shape) - 1))),
            precision=precision)

    R = R + 2 * self.numerical_stabilizer * (
        jnp.abs(R) <= self.numerical_stabilizer)
    R = jnp.reciprocal(R)
    R = jnp.expand_dims(R, len(R.shape))
    # W (bs, <non-attention dims>, num_heads, <attention dims>, channels_v)
    # R (bs, <non-attention dims>, num_heads, <attention dims>, extra_channel)
    result = W * R
    # back to (bs, dim1, dim2, ..., dimN, num_heads, channels)
    perm_inv = _invert_perm(qk_perm)
    result = result.transpose(perm_inv)
    return result


def _invert_perm(perm):
  perm_inv = [0] * len(perm)
  for i, j in enumerate(perm):
    perm_inv[j] = i
  return tuple(perm_inv)
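
# Illustrative sketch (not part of the original module): end to end, the fast
# softmax path approximates exact softmax attention, with the error shrinking
# as nb_features grows. A hypothetical comparison on small inputs:
def _example_compare_with_exact_attention(seed=0):
  batch, length, heads, head_dim = 1, 16, 2, 8
  q_key, k_key, v_key = random.split(random.PRNGKey(seed), 3)
  query = random.normal(q_key, (batch, length, heads, head_dim))
  key = random.normal(k_key, (batch, length, heads, head_dim))
  value = random.normal(v_key, (batch, length, heads, head_dim))

  fast_fn = make_fast_softmax_attention(qkv_dim=head_dim, nb_features=1024)
  approx = fast_fn(query, key, value)

  logits = jnp.einsum('blhd,bmhd->bhlm', query, key) / jnp.sqrt(head_dim)
  weights = jax.nn.softmax(logits, axis=-1)
  exact = jnp.einsum('bhlm,bmhd->blhd', weights, value)
  return jnp.mean(jnp.abs(approx - exact))  # Small, but not exactly zero.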