flaxdiff-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/__init__.py +0 -0
- flaxdiff/models/__init__.py +1 -0
- flaxdiff/models/attention.py +489 -0
- flaxdiff/models/common.py +7 -0
- flaxdiff/models/favor_fastattn.py +723 -0
- flaxdiff/models/simple_unet.py +519 -0
- flaxdiff/predictors/__init__.py +96 -0
- flaxdiff/samplers/__init__.py +7 -0
- flaxdiff/samplers/common.py +113 -0
- flaxdiff/samplers/ddim.py +10 -0
- flaxdiff/samplers/ddpm.py +43 -0
- flaxdiff/samplers/euler.py +59 -0
- flaxdiff/samplers/heun_sampler.py +28 -0
- flaxdiff/samplers/multistep_dpm.py +60 -0
- flaxdiff/samplers/rk4_sampler.py +34 -0
- flaxdiff/schedulers/__init__.py +6 -0
- flaxdiff/schedulers/common.py +98 -0
- flaxdiff/schedulers/continuous.py +12 -0
- flaxdiff/schedulers/cosine.py +40 -0
- flaxdiff/schedulers/discrete.py +74 -0
- flaxdiff/schedulers/exp.py +13 -0
- flaxdiff/schedulers/karras.py +69 -0
- flaxdiff/schedulers/linear.py +14 -0
- flaxdiff/schedulers/sqrt.py +10 -0
- flaxdiff/trainer/__init__.py +216 -0
- flaxdiff/utils.py +89 -0
- flaxdiff-0.1.1.dist-info/METADATA +326 -0
- flaxdiff-0.1.1.dist-info/RECORD +30 -0
- flaxdiff-0.1.1.dist-info/WHEEL +5 -0
- flaxdiff-0.1.1.dist-info/top_level.txt +1 -0
flaxdiff/__init__.py
ADDED
File without changes
flaxdiff/models/__init__.py
ADDED
@@ -0,0 +1 @@
from .simple_unet import *
flaxdiff/models/attention.py
ADDED
@@ -0,0 +1,489 @@
"""
Some code ported from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_flax.py
"""

import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Dict, Callable, Sequence, Any, Union
import einops
import functools
import math
from .common import kernel_init

def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
    """Multi-head dot product attention with a limited number of queries."""
    num_kv, num_heads, k_features = key.shape[-3:]
    v_features = value.shape[-1]
    key_chunk_size = min(key_chunk_size, num_kv)
    query = query / jnp.sqrt(k_features)

    @functools.partial(jax.checkpoint, prevent_cse=False)
    def summarize_chunk(query, key, value):
        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)

        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
        max_score = jax.lax.stop_gradient(max_score)
        exp_weights = jnp.exp(attn_weights - max_score)

        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
        max_score = jnp.einsum("...qhk->...qh", max_score)

        return (exp_values, exp_weights.sum(axis=-1), max_score)

    def chunk_scanner(chunk_idx):
        # julienne key array
        key_chunk = jax.lax.dynamic_slice(
            operand=key,
            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
        )

        # julienne value array
        value_chunk = jax.lax.dynamic_slice(
            operand=value,
            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
        )

        return summarize_chunk(query, key_chunk, value_chunk)

    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))

    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
    max_diffs = jnp.exp(chunk_max - global_max)

    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
    chunk_weights *= max_diffs

    all_values = chunk_values.sum(axis=0)
    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)

    return all_values / all_weights

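The chunk recombination above relies on a log-sum-exp identity: exponential sums computed against each chunk's own maximum can be rescaled by exp(chunk_max - global_max) and summed to recover the global softmax normalizer. A minimal numeric sketch of that identity (illustrative only, not part of the wheel's source):

import jax.numpy as jnp

# One query row of attention logits, split into two key chunks.
scores = jnp.array([0.3, 2.1, -1.0, 0.7, 1.5, -0.2])
chunks = jnp.split(scores, 2)

chunk_max = jnp.stack([c.max() for c in chunks])                                  # per-chunk max
chunk_sum = jnp.stack([jnp.exp(c - m).sum() for c, m in zip(chunks, chunk_max)])  # per-chunk exp sums

global_max = chunk_max.max()
recombined = (chunk_sum * jnp.exp(chunk_max - global_max)).sum()  # what the chunked scan computes
reference = jnp.exp(scores - scores.max()).sum()                  # full softmax denominator
assert jnp.allclose(recombined, reference)
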
def jax_memory_efficient_attention(
    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
):
    r"""
    Flax memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
    https://github.com/AminRezaei0x443/memory-efficient-attention

    Args:
        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
            numerical precision for the computation
        query_chunk_size (`int`, *optional*, defaults to 1024):
            chunk size to divide the query array; must divide query_length equally without remainder
        key_chunk_size (`int`, *optional*, defaults to 4096):
            chunk size to divide the key and value arrays; must divide key_value_length equally without remainder

    Returns:
        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
    """
    num_q, num_heads, q_features = query.shape[-3:]

    def chunk_scanner(chunk_idx, _):
        # julienne query array
        query_chunk = jax.lax.dynamic_slice(
            operand=query,
            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
        )

        return (
            chunk_idx + query_chunk_size,  # unused, ignore it
            _query_chunk_attention(
                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
            ),
        )

    _, res = jax.lax.scan(
        f=chunk_scanner,
        init=0,  # start counter
        xs=None,
        length=math.ceil(num_q / query_chunk_size),  # stop counter
    )

    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back together

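When the chunk sizes evenly divide the sequence lengths, the chunked result should agree with Flax's reference attention up to floating-point noise. A hedged usage sketch (shapes follow the docstring above; the sizes are assumptions, not package documentation):

import jax
import jax.numpy as jnp
from flax import linen as nn
from flaxdiff.models.attention import jax_memory_efficient_attention

rng = jax.random.PRNGKey(0)
q_rng, k_rng, v_rng = jax.random.split(rng, 3)
B, L, H, D = 2, 1024, 4, 64                       # batch, sequence length, heads, head dim
q = jax.random.normal(q_rng, (B, L, H, D))
k = jax.random.normal(k_rng, (B, L, H, D))
v = jax.random.normal(v_rng, (B, L, H, D))

chunked = jax_memory_efficient_attention(q, k, v, query_chunk_size=256, key_chunk_size=512)
reference = nn.dot_product_attention(q, k, v)
print(chunked.shape)                              # (2, 1024, 4, 64)
print(jnp.max(jnp.abs(chunked - reference)))      # expected to be small (float32 noise)
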
class EfficientAttention(nn.Module):
    """
    Based on the pallas attention implementation.
    """
    query_dim: int
    heads: int = 4
    dim_head: int = 64
    dtype: Any = jnp.float32
    precision: Any = jax.lax.Precision.HIGHEST
    use_bias: bool = True
    kernel_init: Callable = lambda: kernel_init(1.0)

    def setup(self):
        inner_dim = self.dim_head * self.heads
        # Weights were exported with old names {to_q, to_k, to_v, to_out}
        self.query = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision,
                                     kernel_init=self.kernel_init(), dtype=self.dtype, name="to_q")
        self.key = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision,
                                   kernel_init=self.kernel_init(), dtype=self.dtype, name="to_k")
        self.value = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision,
                                     kernel_init=self.kernel_init(), dtype=self.dtype, name="to_v")
        self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
                                         kernel_init=self.kernel_init(), dtype=self.dtype, name="to_out_0")
        # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)

    @nn.compact
    def __call__(self, x: jax.Array, context=None):
        # x has shape [B, H * W, C]
        context = x if context is None else context
        query = self.query(x)
        key = self.key(context)
        value = self.value(context)

        # print(query.shape, key.shape, value.shape)

        # hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.mha_reference(
        #     query, key, value, None
        # )

        hidden_states = nn.dot_product_attention(
            query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
        )
        # hidden_states = self.attnfn(
        #     query, key, value, None
        # )

        proj = self.proj_attn(hidden_states)
        return proj

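A minimal init/apply sketch for EfficientAttention, following the [B, H * W, C] layout noted in its __call__ comment. The sizes below are assumptions for illustration, not package documentation; note the module feeds its unsplit [B, L, inner_dim] projections straight into nn.dot_product_attention and then projects back to query_dim:

import jax
import jax.numpy as jnp
from flaxdiff.models.attention import EfficientAttention

attn = EfficientAttention(query_dim=128, heads=4, dim_head=32)
x = jnp.ones((2, 16 * 16, 128))                   # [B, H * W, C]
params = attn.init(jax.random.PRNGKey(0), x)
out = attn.apply(params, x)                       # self-attention; pass a second array as context for cross-attention
print(out.shape)                                  # (2, 256, 128) -- projected back to query_dim
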
class NormalAttention(nn.Module):
    """
    Simple implementation of the normal attention.
    """
    query_dim: int
    heads: int = 4
    dim_head: int = 64
    dtype: Any = jnp.float32
    precision: Any = jax.lax.Precision.HIGHEST
    use_bias: bool = True
    kernel_init: Callable = lambda: kernel_init(1.0)

    def setup(self):
        inner_dim = self.dim_head * self.heads
        dense = functools.partial(
            nn.DenseGeneral,
            features=[self.heads, self.dim_head],
            axis=-1,
            precision=self.precision,
            use_bias=self.use_bias,
            kernel_init=self.kernel_init(),
            dtype=self.dtype
        )
        self.query = dense(name="to_q")
        self.key = dense(name="to_k")
        self.value = dense(name="to_v")

        self.proj_attn = nn.DenseGeneral(
            self.query_dim,
            axis=(-2, -1),
            precision=self.precision,
            use_bias=self.use_bias,
            dtype=self.dtype,
            name="to_out_0",
            kernel_init=self.kernel_init()
            # kernel_init=jax.nn.initializers.xavier_uniform()
        )

    @nn.compact
    def __call__(self, x, context=None):
        # x has shape [B, H, W, C]
        context = x if context is None else context
        query = self.query(x)
        key = self.key(context)
        value = self.value(context)

        hidden_states = nn.dot_product_attention(
            query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
        )
        proj = self.proj_attn(hidden_states)
        return proj

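NormalAttention keeps the spatial layout and splits heads inside its DenseGeneral projections, then folds them back with the axis=(-2, -1) output projection. A small sketch using the [B, H, W, C] layout from its __call__ comment (sizes assumed for illustration):

import jax
import jax.numpy as jnp
from flaxdiff.models.attention import NormalAttention

attn = NormalAttention(query_dim=128, heads=4, dim_head=32)
x = jnp.ones((1, 16, 16, 128))                    # [B, H, W, C]
params = attn.init(jax.random.PRNGKey(0), x)
out = attn.apply(params, x)                       # context defaults to x (self-attention)
print(out.shape)                                  # (1, 16, 16, 128)
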
class AttentionBlock(nn.Module):
    # Has self and cross attention
    query_dim: int
    heads: int = 4
    dim_head: int = 64
    dtype: Any = jnp.float32
    precision: Any = jax.lax.Precision.HIGHEST
    use_bias: bool = True
    kernel_init: Callable = lambda: kernel_init(1.0)
    use_flash_attention: bool = False
    use_cross_only: bool = False

    def setup(self):
        if self.use_flash_attention:
            attenBlock = EfficientAttention
        else:
            attenBlock = NormalAttention

        self.attention1 = attenBlock(
            query_dim=self.query_dim,
            heads=self.heads,
            dim_head=self.dim_head,
            name=f'Attention1',
            precision=self.precision,
            use_bias=self.use_bias,
            dtype=self.dtype,
            kernel_init=self.kernel_init
        )
        self.attention2 = attenBlock(
            query_dim=self.query_dim,
            heads=self.heads,
            dim_head=self.dim_head,
            name=f'Attention2',
            precision=self.precision,
            use_bias=self.use_bias,
            dtype=self.dtype,
            kernel_init=self.kernel_init
        )

        self.ff = nn.DenseGeneral(
            features=self.query_dim,
            use_bias=self.use_bias,
            precision=self.precision,
            dtype=self.dtype,
            kernel_init=self.kernel_init(),
            name="ff"
        )
        self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm4 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)

    @nn.compact
    def __call__(self, hidden_states, context=None):
        # self attention
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        if self.use_cross_only:
            hidden_states = self.attention1(hidden_states, context)
        else:
            hidden_states = self.attention1(hidden_states)
        hidden_states = hidden_states + residual

        # cross attention
        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.attention2(hidden_states, context)
        hidden_states = hidden_states + residual

        # feed forward
        residual = hidden_states
        hidden_states = self.norm3(hidden_states)
        hidden_states = nn.gelu(hidden_states)
        hidden_states = self.ff(hidden_states)
        hidden_states = hidden_states + residual

        return hidden_states

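AttentionBlock stacks pre-norm self-attention, cross-attention, and a GELU feed-forward, each with a residual connection. A hedged sketch of applying it (here the context is simply the input itself, as happens when TransformerBlock is called without an external context; sizes are assumptions):

import jax
import jax.numpy as jnp
from flaxdiff.models.attention import AttentionBlock

block = AttentionBlock(query_dim=128, heads=4, dim_head=32)   # defaults to the NormalAttention path
x = jnp.ones((1, 16, 16, 128))
params = block.init(jax.random.PRNGKey(0), x, x)
y = block.apply(params, x, x)                                 # (hidden_states, context)
print(y.shape)                                                # (1, 16, 16, 128)
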
class TransformerBlock(nn.Module):
    heads: int = 4
    dim_head: int = 32
    use_linear_attention: bool = True
    dtype: Any = jnp.float32
    precision: Any = jax.lax.Precision.HIGH
    use_projection: bool = False
    use_flash_attention: bool = True
    use_self_and_cross: bool = False

    @nn.compact
    def __call__(self, x, context=None):
        inner_dim = self.heads * self.dim_head
        B, H, W, C = x.shape
        normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
        if self.use_projection == True:
            if self.use_linear_attention:
                projected_x = nn.Dense(features=inner_dim,
                                       use_bias=False, precision=self.precision,
                                       kernel_init=kernel_init(1.0),
                                       dtype=self.dtype, name=f'project_in')(normed_x)
            else:
                projected_x = nn.Conv(
                    features=inner_dim, kernel_size=(1, 1),
                    kernel_init=kernel_init(1.0),
                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                    precision=self.precision, name=f'project_in_conv',
                )(normed_x)
        else:
            projected_x = normed_x
            inner_dim = C

        context = projected_x if context is None else context

        if self.use_self_and_cross:
            projected_x = AttentionBlock(
                query_dim=inner_dim,
                heads=self.heads,
                dim_head=self.dim_head,
                name=f'Attention',
                precision=self.precision,
                use_bias=False,
                dtype=self.dtype,
                use_flash_attention=self.use_flash_attention,
                use_cross_only=False
            )(projected_x, context)
        elif self.use_flash_attention == True:
            projected_x = EfficientAttention(
                query_dim=inner_dim,
                heads=self.heads,
                dim_head=self.dim_head,
                name=f'Attention',
                precision=self.precision,
                use_bias=False,
                dtype=self.dtype,
            )(projected_x, context)
        else:
            projected_x = NormalAttention(
                query_dim=inner_dim,
                heads=self.heads,
                dim_head=self.dim_head,
                name=f'Attention',
                precision=self.precision,
                use_bias=False,
            )(projected_x, context)

        if self.use_projection == True:
            if self.use_linear_attention:
                projected_x = nn.Dense(features=C, precision=self.precision,
                                       dtype=self.dtype, use_bias=False,
                                       kernel_init=kernel_init(1.0),
                                       name=f'project_out')(projected_x)
            else:
                projected_x = nn.Conv(
                    features=C, kernel_size=(1, 1),
                    kernel_init=kernel_init(1.0),
                    strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
                    precision=self.precision, name=f'project_out_conv',
                )(projected_x)

        out = x + projected_x
        return out

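TransformerBlock is the residual wrapper used around the attention variants: it normalizes the feature map, optionally projects channels to heads * dim_head, runs one of the attention paths, projects back to the input channel count, and adds the input. A minimal sketch (hyperparameters are assumed; use_flash_attention is switched off here so the NormalAttention-backed AttentionBlock path is exercised):

import jax
import jax.numpy as jnp
from flaxdiff.models.attention import TransformerBlock

block = TransformerBlock(heads=4, dim_head=32, use_projection=True,
                         use_flash_attention=False, use_self_and_cross=True)
x = jnp.ones((2, 16, 16, 128))                    # [B, H, W, C] feature map
params = block.init(jax.random.PRNGKey(0), x)
y = block.apply(params, x)                        # context defaults to the projected input
print(y.shape)                                    # (2, 16, 16, 128) -- same shape as x
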
class FlaxGEGLU(nn.Module):
    r"""
    Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
    https://arxiv.org/abs/2002.05202.

    Parameters:
        dim (:obj:`int`):
            Input hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        inner_dim = self.dim * 4
        self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=jax.lax.Precision.DEFAULT)

    def __call__(self, hidden_states):
        hidden_states = self.proj(hidden_states)
        hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=3)
        return hidden_linear * nn.gelu(hidden_gelu)

class FlaxFeedForward(nn.Module):
    r"""
    Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
    [`FeedForward`] class, with the following simplifications:
    - The activation function is currently hardcoded to a gated linear unit from:
      https://arxiv.org/abs/2002.05202
    - `dim_out` is equal to `dim`.
    - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGELU`].

    Parameters:
        dim (:obj:`int`):
            Inner hidden states dimension
        dropout (:obj:`float`, *optional*, defaults to 0.0):
            Dropout rate
        dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
            Parameters `dtype`
    """

    dim: int
    dropout: float = 0.0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The second linear layer needs to be called
        # net_2 for now to match the index of the Sequential layer
        self.net_0 = FlaxGEGLU(self.dim, self.dtype)
        self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=jax.lax.Precision.DEFAULT)

    def __call__(self, hidden_states):
        hidden_states = self.net_0(hidden_states)
        hidden_states = self.net_2(hidden_states)
        return hidden_states

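FlaxFeedForward expands to dim * 4 hidden units through the GEGLU gate and projects back to dim. Since FlaxGEGLU splits on axis 3, these modules expect a 4-D [B, H, W, C] input. A small sketch under that assumption (sizes are illustrative, not package documentation):

import jax
import jax.numpy as jnp
from flaxdiff.models.attention import FlaxFeedForward

ff = FlaxFeedForward(dim=64)
x = jnp.ones((1, 8, 8, 64))                       # 4-D input so the axis=3 split in FlaxGEGLU applies
params = ff.init(jax.random.PRNGKey(0), x)
y = ff.apply(params, x)
print(y.shape)                                    # (1, 8, 8, 64) -- back to dim after the 4x GEGLU expansion
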
class BasicTransformerBlock(nn.Module):
    query_dim: int
    heads: int
    dim_head: int
    dropout: float = 0.0
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    precision: Any = jax.lax.Precision.DEFAULT

    def setup(self):
        # self attention (or cross_attention if only_cross_attention is True)
        self.attn1 = NormalAttention(
            query_dim=self.query_dim,
            heads=self.heads,
            dim_head=self.dim_head,
            dtype=self.dtype,
            precision=self.precision,
        )
        # cross attention
        self.attn2 = NormalAttention(
            query_dim=self.query_dim,
            heads=self.heads,
            dim_head=self.dim_head,
            dtype=self.dtype,
            precision=self.precision,
        )
        self.ff = FlaxFeedForward(dim=self.query_dim, dropout=self.dropout, dtype=self.dtype)
        self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
        self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)

    def __call__(self, hidden_states, context, deterministic=True):
        # self attention
        residual = hidden_states
        if self.only_cross_attention:
            hidden_states = self.attn1(self.norm1(hidden_states), context)
        else:
            hidden_states = self.attn1(self.norm1(hidden_states))
        hidden_states = hidden_states + residual

        # cross attention
        residual = hidden_states
        hidden_states = self.attn2(self.norm2(hidden_states), context)
        hidden_states = hidden_states + residual

        # feed forward
        residual = hidden_states
        hidden_states = self.ff(self.norm3(hidden_states))
        hidden_states = hidden_states + residual

        return hidden_states
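BasicTransformerBlock mirrors the pre-norm attn1 / attn2 / feed-forward layout of the diffusers block it is ported from, but always backs both attention layers with NormalAttention. A closing usage sketch (the context here is just the input itself, and the 4-D layout again follows from the feed-forward's axis=3 split; sizes are assumptions for illustration):

import jax
import jax.numpy as jnp
from flaxdiff.models.attention import BasicTransformerBlock

block = BasicTransformerBlock(query_dim=64, heads=4, dim_head=16)
x = jnp.ones((1, 8, 8, 64))
params = block.init(jax.random.PRNGKey(0), x, x)
y = block.apply(params, x, x)                     # (hidden_states, context)
print(y.shape)                                    # (1, 8, 8, 64)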