flaxdiff-0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flaxdiff/__init__.py ADDED
@@ -0,0 +1 @@
+ from .simple_unet import *
@@ -0,0 +1,489 @@
+ """
+ Some code ported from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_flax.py
+ """
+
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from typing import Dict, Callable, Sequence, Any, Union
+ import einops
+ import functools
+ import math
+ from .common import kernel_init
+
+ def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
+     """Multi-head dot product attention with a limited number of queries."""
+     num_kv, num_heads, k_features = key.shape[-3:]
+     v_features = value.shape[-1]
+     key_chunk_size = min(key_chunk_size, num_kv)
+     query = query / jnp.sqrt(k_features)
+
+     @functools.partial(jax.checkpoint, prevent_cse=False)
+     def summarize_chunk(query, key, value):
+         attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
+
+         max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
+         max_score = jax.lax.stop_gradient(max_score)
+         exp_weights = jnp.exp(attn_weights - max_score)
+
+         exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
+         max_score = jnp.einsum("...qhk->...qh", max_score)
+
+         return (exp_values, exp_weights.sum(axis=-1), max_score)
+
+     def chunk_scanner(chunk_idx):
+         # julienne key array
+         key_chunk = jax.lax.dynamic_slice(
+             operand=key,
+             start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
+             slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
+         )
+
+         # julienne value array
+         value_chunk = jax.lax.dynamic_slice(
+             operand=value,
+             start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
+             slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
+         )
+
+         return summarize_chunk(query, key_chunk, value_chunk)
+
+     chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
+
+     global_max = jnp.max(chunk_max, axis=0, keepdims=True)
+     max_diffs = jnp.exp(chunk_max - global_max)
+
+     chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
+     chunk_weights *= max_diffs
+
+     all_values = chunk_values.sum(axis=0)
+     all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
+
+     return all_values / all_weights
+
+
+ def jax_memory_efficient_attention(
+     query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
+ ):
+     r"""
+     Flax memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
+     https://github.com/AminRezaei0x443/memory-efficient-attention
+
+     Args:
+         query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
+         key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
+         value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
+         precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
+             numerical precision for computation
+         query_chunk_size (`int`, *optional*, defaults to 1024):
+             chunk size to divide the query array; its value must divide query_length equally without remainder
+         key_chunk_size (`int`, *optional*, defaults to 4096):
+             chunk size to divide the key and value arrays; its value must divide key_value_length equally without remainder
+
+     Returns:
+         (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
+     """
+     num_q, num_heads, q_features = query.shape[-3:]
+
+     def chunk_scanner(chunk_idx, _):
+         # julienne query array
+         query_chunk = jax.lax.dynamic_slice(
+             operand=query,
+             start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
+             slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
+         )
+
+         return (
+             chunk_idx + query_chunk_size,  # carry: offset of the next chunk (unused)
+             _query_chunk_attention(
+                 query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
+             ),
+         )
+
+     _, res = jax.lax.scan(
+         f=chunk_scanner,
+         init=0,  # start counter
+         xs=None,
+         length=math.ceil(num_q / query_chunk_size),  # stop counter: number of query chunks
+     )
+
+     return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
+
+
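A minimal usage sketch of jax_memory_efficient_attention, assuming the definitions above are in scope; the toy shapes are assumptions and follow the (batch..., length, heads, depth_per_head) layout from the docstring, with chunk sizes dividing the sequence lengths exactly:

    import jax.numpy as jnp

    # Hypothetical shapes: batch of 2, 4 heads, 64 features per head.
    q = jnp.ones((2, 1024, 4, 64))   # query_length divisible by query_chunk_size
    k = jnp.ones((2, 4096, 4, 64))   # key_value_length divisible by key_chunk_size
    v = jnp.ones((2, 4096, 4, 64))
    out = jax_memory_efficient_attention(q, k, v, query_chunk_size=1024, key_chunk_size=4096)
    assert out.shape == (2, 1024, 4, 64)   # (batch, query_length, heads, value_depth_per_head)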
+ class EfficientAttention(nn.Module):
+     """
+     Based on the Pallas attention implementation; currently falls back to nn.dot_product_attention
+     (the Pallas call is left commented out below).
+     """
+     query_dim: int
+     heads: int = 4
+     dim_head: int = 64
+     dtype: Any = jnp.float32
+     precision: Any = jax.lax.Precision.HIGHEST
+     use_bias: bool = True
+     kernel_init: Callable = lambda: kernel_init(1.0)
+
+     def setup(self):
+         inner_dim = self.dim_head * self.heads
+         # Weights were exported with old names {to_q, to_k, to_v, to_out}
+         self.query = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision,
+                                      kernel_init=self.kernel_init(), dtype=self.dtype, name="to_q")
+         self.key = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision,
+                                    kernel_init=self.kernel_init(), dtype=self.dtype, name="to_k")
+         self.value = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision,
+                                      kernel_init=self.kernel_init(), dtype=self.dtype, name="to_v")
+         self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
+                                          kernel_init=self.kernel_init(), dtype=self.dtype, name="to_out_0")
+         # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)
+
+     @nn.compact
+     def __call__(self, x: jax.Array, context=None):
+         # x has shape [B, H * W, C]
+         context = x if context is None else context
+         query = self.query(x)
+         key = self.key(context)
+         value = self.value(context)
+
+         # print(query.shape, key.shape, value.shape)
+
+         # hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.mha_reference(
+         #     query, key, value, None
+         # )
+
+         hidden_states = nn.dot_product_attention(
+             query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
+         )
+         # hidden_states = self.attnfn(
+         #     query, key, value, None
+         # )
+
+         proj = self.proj_attn(hidden_states)
+         return proj
+
+
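A shape-level smoke test for EfficientAttention, assuming the class above is in scope; the input follows the [B, H * W, C] layout noted in __call__, and the toy sizes are assumptions. The output shape matches the input because proj_attn maps back to query_dim:

    import jax
    import jax.numpy as jnp

    attn = EfficientAttention(query_dim=64, heads=4, dim_head=16)
    x = jnp.ones((1, 256, 64))                     # [B, H * W, C]
    params = attn.init(jax.random.PRNGKey(0), x)   # standard Flax init
    y = attn.apply(params, x)                      # -> (1, 256, 64)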
+ class NormalAttention(nn.Module):
+     """
+     Simple implementation of the normal attention.
+     """
+     query_dim: int
+     heads: int = 4
+     dim_head: int = 64
+     dtype: Any = jnp.float32
+     precision: Any = jax.lax.Precision.HIGHEST
+     use_bias: bool = True
+     kernel_init: Callable = lambda: kernel_init(1.0)
+
+     def setup(self):
+         inner_dim = self.dim_head * self.heads
+         dense = functools.partial(
+             nn.DenseGeneral,
+             features=[self.heads, self.dim_head],
+             axis=-1,
+             precision=self.precision,
+             use_bias=self.use_bias,
+             kernel_init=self.kernel_init(),
+             dtype=self.dtype
+         )
+         self.query = dense(name="to_q")
+         self.key = dense(name="to_k")
+         self.value = dense(name="to_v")
+
+         self.proj_attn = nn.DenseGeneral(
+             self.query_dim,
+             axis=(-2, -1),
+             precision=self.precision,
+             use_bias=self.use_bias,
+             dtype=self.dtype,
+             name="to_out_0",
+             kernel_init=self.kernel_init()
+             # kernel_init=jax.nn.initializers.xavier_uniform()
+         )
+
+     @nn.compact
+     def __call__(self, x, context=None):
+         # x has shape [B, H, W, C]
+         context = x if context is None else context
+         query = self.query(x)
+         key = self.key(context)
+         value = self.value(context)
+
+         hidden_states = nn.dot_product_attention(
+             query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
+         )
+         proj = self.proj_attn(hidden_states)
+         return proj
+
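Similarly for NormalAttention, whose __call__ comment documents a [B, H, W, C] feature-map input; a minimal sketch with assumed toy shapes:

    import jax
    import jax.numpy as jnp

    attn = NormalAttention(query_dim=64, heads=4, dim_head=16)
    x = jnp.ones((1, 16, 16, 64))                  # [B, H, W, C]
    params = attn.init(jax.random.PRNGKey(0), x)
    y = attn.apply(params, x)                      # projected back to (1, 16, 16, 64)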
+ class AttentionBlock(nn.Module):
+     # Has self and cross attention
+     query_dim: int
+     heads: int = 4
+     dim_head: int = 64
+     dtype: Any = jnp.float32
+     precision: Any = jax.lax.Precision.HIGHEST
+     use_bias: bool = True
+     kernel_init: Callable = lambda: kernel_init(1.0)
+     use_flash_attention: bool = False
+     use_cross_only: bool = False
+
+     def setup(self):
+         if self.use_flash_attention:
+             attenBlock = EfficientAttention
+         else:
+             attenBlock = NormalAttention
+
+         self.attention1 = attenBlock(
+             query_dim=self.query_dim,
+             heads=self.heads,
+             dim_head=self.dim_head,
+             name='Attention1',
+             precision=self.precision,
+             use_bias=self.use_bias,
+             dtype=self.dtype,
+             kernel_init=self.kernel_init
+         )
+         self.attention2 = attenBlock(
+             query_dim=self.query_dim,
+             heads=self.heads,
+             dim_head=self.dim_head,
+             name='Attention2',
+             precision=self.precision,
+             use_bias=self.use_bias,
+             dtype=self.dtype,
+             kernel_init=self.kernel_init
+         )
+
+         self.ff = nn.DenseGeneral(
+             features=self.query_dim,
+             use_bias=self.use_bias,
+             precision=self.precision,
+             dtype=self.dtype,
+             kernel_init=self.kernel_init(),
+             name="ff"
+         )
+         self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
+         self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
+         self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
+         self.norm4 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)  # defined but not used in __call__
+
+     @nn.compact
+     def __call__(self, hidden_states, context=None):
+         # self attention
+         residual = hidden_states
+         hidden_states = self.norm1(hidden_states)
+         if self.use_cross_only:
+             hidden_states = self.attention1(hidden_states, context)
+         else:
+             hidden_states = self.attention1(hidden_states)
+         hidden_states = hidden_states + residual
+
+         # cross attention
+         residual = hidden_states
+         hidden_states = self.norm2(hidden_states)
+         hidden_states = self.attention2(hidden_states, context)
+         hidden_states = hidden_states + residual
+
+         # feed forward
+         residual = hidden_states
+         hidden_states = self.norm3(hidden_states)
+         hidden_states = nn.gelu(hidden_states)
+         hidden_states = self.ff(hidden_states)
+         hidden_states = hidden_states + residual
+
+         return hidden_states
+
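AttentionBlock chains self-attention, cross-attention, and a feed-forward projection with pre-norm residual connections, so the channel count must equal query_dim. A minimal sketch with assumed toy shapes, using the NormalAttention path (use_flash_attention=False):

    import jax
    import jax.numpy as jnp

    block = AttentionBlock(query_dim=64, heads=4, dim_head=16, use_flash_attention=False)
    x = jnp.ones((1, 8, 8, 64))                    # [B, H, W, C] with C == query_dim
    params = block.init(jax.random.PRNGKey(0), x)
    y = block.apply(params, x)                     # context defaults to x -> (1, 8, 8, 64)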
+ class TransformerBlock(nn.Module):
+     heads: int = 4
+     dim_head: int = 32
+     use_linear_attention: bool = True
+     dtype: Any = jnp.float32
+     precision: Any = jax.lax.Precision.HIGH
+     use_projection: bool = False
+     use_flash_attention: bool = True
+     use_self_and_cross: bool = False
+
+     @nn.compact
+     def __call__(self, x, context=None):
+         inner_dim = self.heads * self.dim_head
+         B, H, W, C = x.shape
+         normed_x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
+         if self.use_projection:
+             if self.use_linear_attention:
+                 projected_x = nn.Dense(features=inner_dim,
+                                        use_bias=False, precision=self.precision,
+                                        kernel_init=kernel_init(1.0),
+                                        dtype=self.dtype, name='project_in')(normed_x)
+             else:
+                 projected_x = nn.Conv(
+                     features=inner_dim, kernel_size=(1, 1),
+                     kernel_init=kernel_init(1.0),
+                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+                     precision=self.precision, name='project_in_conv',
+                 )(normed_x)
+         else:
+             projected_x = normed_x
+             inner_dim = C
+
+         context = projected_x if context is None else context
+
+         if self.use_self_and_cross:
+             projected_x = AttentionBlock(
+                 query_dim=inner_dim,
+                 heads=self.heads,
+                 dim_head=self.dim_head,
+                 name='Attention',
+                 precision=self.precision,
+                 use_bias=False,
+                 dtype=self.dtype,
+                 use_flash_attention=self.use_flash_attention,
+                 use_cross_only=False
+             )(projected_x, context)
+         elif self.use_flash_attention:
+             projected_x = EfficientAttention(
+                 query_dim=inner_dim,
+                 heads=self.heads,
+                 dim_head=self.dim_head,
+                 name='Attention',
+                 precision=self.precision,
+                 use_bias=False,
+                 dtype=self.dtype,
+             )(projected_x, context)
+         else:
+             projected_x = NormalAttention(
+                 query_dim=inner_dim,
+                 heads=self.heads,
+                 dim_head=self.dim_head,
+                 name='Attention',
+                 precision=self.precision,
+                 use_bias=False,
+             )(projected_x, context)
+
+         if self.use_projection:
+             if self.use_linear_attention:
+                 projected_x = nn.Dense(features=C, precision=self.precision,
+                                        dtype=self.dtype, use_bias=False,
+                                        kernel_init=kernel_init(1.0),
+                                        name='project_out')(projected_x)
+             else:
+                 projected_x = nn.Conv(
+                     features=C, kernel_size=(1, 1),
+                     kernel_init=kernel_init(1.0),
+                     strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
+                     precision=self.precision, name='project_out_conv',
+                 )(projected_x)
+
+         out = x + projected_x
+         return out
+
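TransformerBlock wraps the attention path in an outer residual and can optionally project channels in and out. A minimal sketch under assumed toy shapes, without projection (so the attention runs directly on the C input channels) and using the NormalAttention path:

    import jax
    import jax.numpy as jnp

    tb = TransformerBlock(heads=4, dim_head=16, use_projection=False,
                          use_flash_attention=False, use_self_and_cross=True)
    x = jnp.ones((1, 8, 8, 64))                    # [B, H, W, C]
    params = tb.init(jax.random.PRNGKey(0), x)
    y = tb.apply(params, x)                        # x + attention(x) -> (1, 8, 8, 64)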
+ class FlaxGEGLU(nn.Module):
+     r"""
+     Flax implementation of a Linear layer followed by the variant of the gated linear unit activation function from
+     https://arxiv.org/abs/2002.05202.
+
+     Parameters:
+         dim (:obj:`int`):
+             Input hidden states dimension
+         dropout (:obj:`float`, *optional*, defaults to 0.0):
+             Dropout rate
+         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+             Parameters `dtype`
+     """
+
+     dim: int
+     dropout: float = 0.0
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         inner_dim = self.dim * 4
+         self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, precision=jax.lax.Precision.DEFAULT)
+
+     def __call__(self, hidden_states):
+         hidden_states = self.proj(hidden_states)
+         hidden_linear, hidden_gelu = jnp.split(hidden_states, 2, axis=3)
+         return hidden_linear * nn.gelu(hidden_gelu)
+
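FlaxGEGLU projects the last dimension to 2 * dim * 4 and gates one half with GELU of the other, so the output width is dim * 4; note the hard-coded axis=3 split expects 4-D inputs. A minimal sketch with assumed shapes:

    import jax
    import jax.numpy as jnp

    geglu = FlaxGEGLU(dim=64)
    h = jnp.ones((1, 8, 8, 64))
    params = geglu.init(jax.random.PRNGKey(0), h)
    out = geglu.apply(params, h)                   # -> (1, 8, 8, 256), i.e. dim * 4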
+ class FlaxFeedForward(nn.Module):
+     r"""
+     Flax module that encapsulates two Linear layers separated by a non-linearity. It is the counterpart of PyTorch's
+     [`FeedForward`] class, with the following simplifications:
+     - The activation function is currently hardcoded to a gated linear unit from:
+       https://arxiv.org/abs/2002.05202
+     - `dim_out` is equal to `dim`.
+     - The number of hidden dimensions is hardcoded to `dim * 4` in [`FlaxGEGLU`].
+
+     Parameters:
+         dim (:obj:`int`):
+             Inner hidden states dimension
+         dropout (:obj:`float`, *optional*, defaults to 0.0):
+             Dropout rate
+         dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
+             Parameters `dtype`
+     """
+
+     dim: int
+     dropout: float = 0.0
+     dtype: jnp.dtype = jnp.float32
+
+     def setup(self):
+         # The second linear layer needs to be called
+         # net_2 for now to match the index of the Sequential layer
+         self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype)
+         self.net_2 = nn.Dense(self.dim, dtype=self.dtype, precision=jax.lax.Precision.DEFAULT)
+
+     def __call__(self, hidden_states):
+         hidden_states = self.net_0(hidden_states)
+         hidden_states = self.net_2(hidden_states)
+         return hidden_states
+
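FlaxFeedForward is then just GEGLU followed by a projection back down to dim; a minimal sketch under the same assumed shapes:

    import jax
    import jax.numpy as jnp

    ff = FlaxFeedForward(dim=64)
    h = jnp.ones((1, 8, 8, 64))
    params = ff.init(jax.random.PRNGKey(0), h)
    out = ff.apply(params, h)                      # widened to 256 internally, back to (1, 8, 8, 64)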
+ class BasicTransformerBlock(nn.Module):
+     query_dim: int
+     heads: int
+     dim_head: int
+     dropout: float = 0.0
+     only_cross_attention: bool = False
+     dtype: jnp.dtype = jnp.float32
+     use_memory_efficient_attention: bool = False  # accepted but not used yet
+     split_head_dim: bool = False  # accepted but not used yet
+     precision: Any = jax.lax.Precision.DEFAULT
+
+     def setup(self):
+         # self attention (or cross-attention if only_cross_attention is True)
+         self.attn1 = NormalAttention(
+             query_dim=self.query_dim,
+             heads=self.heads,
+             dim_head=self.dim_head,
+             dtype=self.dtype,
+             precision=self.precision,
+         )
+         # cross attention
+         self.attn2 = NormalAttention(
+             query_dim=self.query_dim,
+             heads=self.heads,
+             dim_head=self.dim_head,
+             dtype=self.dtype,
+             precision=self.precision,
+         )
+         self.ff = FlaxFeedForward(dim=self.query_dim, dropout=self.dropout, dtype=self.dtype)
+         self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
+         self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
+         self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
+
+     def __call__(self, hidden_states, context, deterministic=True):
+         # self attention
+         residual = hidden_states
+         if self.only_cross_attention:
+             hidden_states = self.attn1(self.norm1(hidden_states), context)
+         else:
+             hidden_states = self.attn1(self.norm1(hidden_states))
+         hidden_states = hidden_states + residual
+
+         # cross attention
+         residual = hidden_states
+         hidden_states = self.attn2(self.norm2(hidden_states), context)
+         hidden_states = hidden_states + residual
+
+         # feed forward
+         residual = hidden_states
+         hidden_states = self.ff(self.norm3(hidden_states))
+         hidden_states = hidden_states + residual
+
+         return hidden_states
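BasicTransformerBlock composes the two NormalAttention layers with the feed-forward block above; note that context is a required positional argument here. A minimal sketch with assumed toy shapes:

    import jax
    import jax.numpy as jnp

    btb = BasicTransformerBlock(query_dim=64, heads=4, dim_head=16)
    x = jnp.ones((1, 8, 8, 64))                    # hidden states [B, H, W, C]
    ctx = jnp.ones((1, 8, 8, 64))                  # conditioning context
    params = btb.init(jax.random.PRNGKey(0), x, ctx)
    y = btb.apply(params, x, ctx)                  # -> (1, 8, 8, 64)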
@@ -0,0 +1,7 @@
+ import jax.numpy as jnp
+ from flax import linen as nn
+
+ # Kernel initializer to use
+ def kernel_init(scale):
+     scale = max(scale, 1e-10)
+     return nn.initializers.variance_scaling(scale=scale, mode="fan_in", distribution="truncated_normal")
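kernel_init returns a standard Flax initializer, so it can be passed anywhere a kernel_init argument is expected or called directly; a small sketch with assumed shapes:

    import jax
    import jax.numpy as jnp

    init_fn = kernel_init(1.0)                     # variance-scaling (fan_in, truncated normal)
    w = init_fn(jax.random.PRNGKey(0), (64, 128), jnp.float32)
    assert w.shape == (64, 128)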