flaxdiff 0.2.8__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -292,14 +292,12 @@ def get_dataset_grain(
         Dictionary with train dataset function and metadata.
     """
     dataset = datasetMap[data_name]
-    train_source = dataset["source"](dataset_source, split="train")
-    # val_source = dataset["source"](dataset_source, split="val")
+    data_source = dataset["source"](dataset_source)
     augmenter = dataset["augmenter"](image_scale, method)
-
     local_batch_size = batch_size // jax.process_count()
 
     train_sampler = pygrain.IndexSampler(
-        num_records=len(train_source) if count is None else count,
+        num_records=len(data_source) if count is None else count,
         shuffle=True,
         seed=seed,
         num_epochs=num_epochs,
@@ -307,7 +305,7 @@ def get_dataset_grain(
     )
 
     # val_sampler = pygrain.IndexSampler(
-    #     num_records=len(val_source) if count is None else count,
+    #     num_records=len(data_source) if count is None else count,
     #     shuffle=False,
     #     seed=seed,
     #     num_epochs=num_epochs,
@@ -318,16 +316,10 @@ def get_dataset_grain(
         transformations = [
             augmenter(),
         ]
-
-        # if filters:
-        #     print("Adding filters to transformations")
-        #     transformations.append(filters())
-
-        # transformations.append(CaptionDeletionTransform())
         transformations.append(pygrain.Batch(local_batch_size, drop_remainder=True))
 
         loader = pygrain.DataLoader(
-            data_source=train_source,
+            data_source=data_source,
             sampler=train_sampler,
             operations=transformations,
             worker_count=worker_count,
@@ -341,26 +333,26 @@ def get_dataset_grain(
     def get_valset():
         transformations = [
             augmenter(),
-            pygrain.Batch(local_batch_size, drop_remainder=True),
+            pygrain.Batch(32, drop_remainder=True),
         ]
 
         loader = pygrain.DataLoader(
-            data_source=train_source,
+            data_source=data_source,
             sampler=train_sampler,
             operations=transformations,
-            worker_count=2,
+            worker_count=8,
             read_options=pygrain.ReadOptions(
-                read_thread_count, read_buffer_size
+                32, 128
             ),
-            worker_buffer_size=2,
+            worker_buffer_size=32,
         )
         return loader
 
     return {
         "train": get_trainset,
-        "train_len": len(train_source),
+        "train_len": len(data_source),
         "val": get_valset,
-        "val_len": len(train_source),
+        "val_len": len(data_source),
         "local_batch_size": local_batch_size,
         "global_batch_size": batch_size,
     }
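
With this hunk, both the train and validation loaders are built from the single data_source, and the validation loader's batch size (32), worker count (8), and read options (32 threads, 128-record buffer) are now hardcoded. A minimal usage sketch; the argument names are inferred from the hunks above and may differ from the released wheel:

    data = get_dataset_grain(
        data_name="laiona_coco",
        dataset_source="/home/mrwhite0racle/gcs_mount",  # mount path from the GCS defaults below
        batch_size=256,
        image_scale=256,
        method="bilinear",  # illustrative resize method
    )
    train_loader = data["train"]()   # pygrain.DataLoader over data_source
    val_loader = data["val"]()       # same source; batch size fixed at 32
    steps_per_epoch = data["train_len"] // data["global_batch_size"]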
@@ -21,8 +21,9 @@ datasetMap = {
         "augmenter": gcs_augmenters,
     },
     "laiona_coco": {
-        "source": data_source_gcs('datasets/laion12m+mscoco_filtered-new'),
+        "source": data_source_gcs('datasets/laion12m+mscoco'),
         "augmenter": gcs_augmenters,
+        "filter": gcs_filters,
     },
     "aesthetic_coyo": {
         "source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
@@ -11,7 +11,7 @@ import struct as st
 from functools import partial
 import numpy as np
 from .base import DataSource, DataAugmenter
-
+import traceback
 
 # ----------------------------------------------------------------------------------
 # Utility functions
@@ -79,10 +79,28 @@ def labelizer_oxford_flowers102(path):
 # TFDS Image Source
 # ----------------------------------------------------------------------------------
 
+def get_oxford_valset(text_encoder):
+    # Construct a validation set from a fixed list of prompts for consistency
+    val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']
+    val_prompts *= 100
+
+    def get_val_dataset(batch_size=128):
+        for i in range(0, len(val_prompts), batch_size):
+            try:
+                prompts = val_prompts[i:i + batch_size]
+                tokens = text_encoder.tokenize(prompts)
+                yield {"text": tokens}
+            except Exception as e:
+                print(f"Error in get_val_dataset: {e}")
+                traceback.print_exc()
+                continue
+
+    return get_val_dataset, len(val_prompts)
+
 class ImageTFDSSource(DataSource):
     """Data source for TensorFlow Datasets (TFDS) image datasets."""
 
-    def __init__(self, name: str, use_tf: bool = True):
+    def __init__(self, name: str, use_tf: bool = True, split: str = "all"):
         """Initialize a TFDS image data source.
 
         Args:
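
A hedged usage sketch for the new helper; text_encoder is assumed to expose a tokenize() method that returns token arrays (for example, something like the defaultTextEncodeModel imported in the training config):

    get_val_dataset, val_len = get_oxford_valset(text_encoder)
    for batch in get_val_dataset(batch_size=128):
        tokens = batch["text"]  # tokenized prompts, ready for conditioning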
@@ -92,8 +110,9 @@ class ImageTFDSSource(DataSource):
         """
         self.name = name
         self.use_tf = use_tf
+        self.split = split
 
-    def get_source(self, path_override: str, split: str = "all") -> Any:
+    def get_source(self, path_override: str) -> Any:
         """Get the TFDS data source.
 
         Args:
@@ -104,9 +123,9 @@ class ImageTFDSSource(DataSource):
         """
         import tensorflow_datasets as tfds
         if self.use_tf:
-            return tfds.load(self.name, split=split, shuffle_files=True)
+            return tfds.load(self.name, split=self.split, shuffle_files=True)
         else:
-            return tfds.data_source(self.name, split=split, try_gcs=False)
+            return tfds.data_source(self.name, split=self.split, try_gcs=False)
 
 
 class ImageTFDSAugmenter(DataAugmenter):
@@ -198,7 +217,7 @@ class ImageGCSSource(DataSource):
         """
         self.source = source
 
-    def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split: str = "train") -> Any:
+    def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
         """Get the GCS data source.
 
         Args:
@@ -210,8 +229,6 @@ class ImageGCSSource(DataSource):
         """
         records_path = os.path.join(path_override, self.source)
         records = [os.path.join(records_path, i) for i in os.listdir(
             records_path) if 'array_record' in i]
-        if split == "val":
-            records = records[:1]
         return pygrain.ArrayRecordDataSource(records)
 
 
@@ -226,7 +243,7 @@ class CombinedImageGCSSource(DataSource):
         """
         self.sources = sources
 
-    def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split: str = "train") -> Any:
+    def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
         """Get the combined GCS data source.
 
         Args:
@@ -240,8 +257,6 @@ class CombinedImageGCSSource(DataSource):
         for records_path in records_paths:
             records += [os.path.join(records_path, i) for i in os.listdir(
                 records_path) if 'array_record' in i]
-        if split == "val":
-            records = records[:1]
         return pygrain.ArrayRecordDataSource(records)
 
 class ImageGCSAugmenter(DataAugmenter):
@@ -357,9 +372,9 @@ class ImageGCSAugmenter(DataAugmenter):
 
 # These functions maintain backward compatibility with existing code
 
-def data_source_tfds(name, use_tf=True):
+def data_source_tfds(name, use_tf=True, split="all"):
     """Legacy function for TFDS data sources."""
-    source = ImageTFDSSource(name=name, use_tf=use_tf)
+    source = ImageTFDSSource(name=name, use_tf=use_tf, split=split)
     return source.get_source
 
 
@@ -389,4 +404,4 @@ def gcs_augmenters(image_scale, method):
 def gcs_filters(image_scale):
     """Legacy function for GCS Filters."""
     augmenter = ImageGCSAugmenter()
-    return augmenter.create_filter(image_scale=image_scale)
\ No newline at end of file
+    return augmenter.create_filter(image_scale=image_scale)
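
Across this file, the split moves from a per-call get_source argument to a constructor argument, and the GCS sources drop their one-record "val" shortcut entirely. A minimal sketch of the updated legacy helper; the dataset name is illustrative, and path_override is not referenced in the TFDS branch shown above:

    source_fn = data_source_tfds("oxford_flowers102", use_tf=False, split="train")
    data_source = source_fn("")  # returns tfds.data_source(...) for the chosen split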
@@ -25,6 +25,9 @@ from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE
 from flaxdiff.inputs import DiffusionInputConfig, ConditionalInputConfig
 from flaxdiff.utils import defaultTextEncodeModel
 
+from flaxdiff.models.simple_vit import UViT, SimpleUDiT
+from flaxdiff.models.simple_dit import SimpleDiT
+from flaxdiff.models.simple_mmdit import SimpleMMDiT, HierarchicalMMDiT
 from orbax.checkpoint import CheckpointManager, CheckpointManagerOptions, PyTreeCheckpointer
 import os
 
@@ -116,7 +119,10 @@ def parse_config(config, overrides=None):
     MODEL_CLASSES = {
         'unet': Unet,
         'uvit': UViT,
-        'diffusers_unet_simple': FlaxUNet2DConditionModel
+        'diffusers_unet_simple': FlaxUNet2DConditionModel,
+        'simple_dit': SimpleDiT,
+        'simple_uvit': SimpleUDiT,
+        'simple_mmdit': SimpleMMDiT,
     }
 
     # Map all the leaves of the model config, converting strings to appropriate types
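
The three new architectures become selectable by name alongside the existing ones (HierarchicalMMDiT is imported above but not registered in this map). A sketch of the lookup this enables; the config key is an assumption, not taken from parse_config:

    architecture = "simple_mmdit"  # e.g. read from the experiment config
    model_class = MODEL_CLASSES[architecture]  # -> SimpleMMDiT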
@@ -1,4 +1,3 @@
-from .simple_vit import PatchEmbedding, unpatchify
 import jax
 import jax.numpy as jnp
 from flax import linen as nn
@@ -7,6 +6,7 @@ import einops
 from functools import partial
 
 # Re-use existing components if they are suitable
+from .vit_common import PatchEmbedding, unpatchify, RotaryEmbedding, RoPEAttention, AdaLNParams
 from .common import kernel_init, FourierEmbedding, TimeProjection
 # Using NormalAttention for RoPE integration
 from .attention import NormalAttention
@@ -15,207 +15,6 @@ from flax.typing import Dtype, PrecisionLike
 # Use our improved Hilbert implementation
 from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
 
-# --- Rotary Positional Embedding (RoPE) ---
-# Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py
-
-
-def _rotate_half(x: jax.Array) -> jax.Array:
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return jnp.concatenate((-x2, x1), axis=-1)
-
-def apply_rotary_embedding(
-    x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array
-) -> jax.Array:
-    """Applies rotary embedding to the input tensor using rotate_half method."""
-    # x shape: [..., Sequence, Dimension] e.g. [B, H, S, D] or [B, S, D]
-    # freqs_cos/sin shape: [Sequence, Dimension / 2]
-
-    # Expand dims for broadcasting: [1, 1, S, D/2] or [1, S, D/2]
-    if x.ndim == 4:  # [B, H, S, D]
-        cos_freqs = jnp.expand_dims(freqs_cos, axis=(0, 1))
-        sin_freqs = jnp.expand_dims(freqs_sin, axis=(0, 1))
-    elif x.ndim == 3:  # [B, S, D]
-        cos_freqs = jnp.expand_dims(freqs_cos, axis=0)
-        sin_freqs = jnp.expand_dims(freqs_sin, axis=0)
-
-    # Duplicate cos and sin for the full dimension D
-    # Shape becomes [..., S, D]
-    cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
-    sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
-
-    # Apply rotation: x * cos + rotate_half(x) * sin
-    x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
-    return x_rotated.astype(x.dtype)
-
-class RotaryEmbedding(nn.Module):
-    dim: int  # Dimension of the head
-    max_seq_len: int = 2048
-    base: int = 10000
-    dtype: Dtype = jnp.float32
-
-    def setup(self):
-        inv_freq = 1.0 / (
-            self.base ** (jnp.arange(0, self.dim, 2,
-                                     dtype=jnp.float32) / self.dim)
-        )
-        t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
-        freqs = jnp.outer(t, inv_freq)  # Shape: [max_seq_len, dim / 2]
-
-        # Store cosine and sine separately instead of as complex numbers
-        self.freqs_cos = jnp.cos(freqs)  # Shape: [max_seq_len, dim / 2]
-        self.freqs_sin = jnp.sin(freqs)  # Shape: [max_seq_len, dim / 2]
-
-    def __call__(self, seq_len: int):
-        if seq_len > self.max_seq_len:
-            raise ValueError(
-                f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
-        # Return separate cos and sin components
-        return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :]
-
-# --- Attention with RoPE ---
-
-
-class RoPEAttention(NormalAttention):
-    rope_emb: RotaryEmbedding = None  # Instance of RotaryEmbedding
-
-    @nn.compact
-    def __call__(self, x, context=None, freqs_cis=None):
-        # x has shape [B, H, W, C] or [B, S, C]
-        orig_x_shape = x.shape
-        is_4d = len(orig_x_shape) == 4
-        if is_4d:
-            B, H, W, C = x.shape
-            seq_len = H * W
-            x = x.reshape((B, seq_len, C))
-        else:
-            B, seq_len, C = x.shape
-
-        context = x if context is None else context
-        if len(context.shape) == 4:
-            _B, _H, _W, _C = context.shape
-            context_seq_len = _H * _W
-            context = context.reshape((B, context_seq_len, _C))
-        # else: context is already [B, S_ctx, C]
-
-        query = self.query(x)  # [B, S, H, D]
-        key = self.key(context)  # [B, S_ctx, H, D]
-        value = self.value(context)  # [B, S_ctx, H, D]
-
-        # Apply RoPE to query and key
-        if freqs_cis is None:
-            # Generate frequencies using the rope_emb instance
-            seq_len_q = query.shape[1]  # Use query's sequence length
-            freqs_cos, freqs_sin = self.rope_emb(seq_len_q)
-        else:
-            # If freqs_cis is passed in as a tuple
-            freqs_cos, freqs_sin = freqs_cis
-
-        # Apply RoPE to query and key
-        # Permute to [B, H, S, D] for RoPE application
-        query = einops.rearrange(query, 'b s h d -> b h s d')
-        key = einops.rearrange(key, 'b s h d -> b h s d')
-
-        # Apply RoPE only up to the context sequence length for keys if different
-        # Assuming self-attention or context has same seq len for simplicity here
-        query = apply_rotary_embedding(query, freqs_cos, freqs_sin)
-        key = apply_rotary_embedding(key, freqs_cos, freqs_sin)  # Apply same freqs to key
-
-        # Permute back to [B, S, H, D] for dot_product_attention
-        query = einops.rearrange(query, 'b h s d -> b s h d')
-        key = einops.rearrange(key, 'b h s d -> b s h d')
-
-        hidden_states = nn.dot_product_attention(
-            query, key, value, dtype=self.dtype, broadcast_dropout=False,
-            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
-            deterministic=True
-        )  # Output shape [B, S, H, D]
-
-        # Use the proj_attn from NormalAttention which expects [B, S, H, D]
-        proj = self.proj_attn(hidden_states)  # Output shape [B, S, C]
-
-        if is_4d:
-            proj = proj.reshape(orig_x_shape)  # Reshape back if input was 4D
-
-        return proj
-
-# --- adaLN-Zero ---
-
-
-class AdaLNZero(nn.Module):
-    features: int
-    dtype: Optional[Dtype] = None
-    precision: PrecisionLike = None
-    norm_epsilon: float = 1e-5  # Standard LayerNorm epsilon
-
-    @nn.compact
-    def __call__(self, x, conditioning):
-        # Project conditioning signal to get scale and shift parameters
-        # Conditioning shape: [B, D_cond] -> [B, 1, ..., 1, 6 * features] for broadcasting
-        # Or [B, 1, 6*features] if x is [B, S, F]
-
-        # Ensure conditioning has seq dim if x does
-        # x=[B,S,F], cond=[B,D_cond]
-        if x.ndim == 3 and conditioning.ndim == 2:
-            conditioning = jnp.expand_dims(
-                conditioning, axis=1)  # cond=[B,1,D_cond]
-
-        # Project conditioning to get 6 params per feature (scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn)
-        # Using nn.DenseGeneral for flexibility if needed, but nn.Dense is fine if cond is [B, D_cond] or [B, 1, D_cond]
-        ada_params = nn.Dense(
-            features=6 * self.features,
-            dtype=self.dtype,
-            precision=self.precision,
-            # Initialize projection to zero (Zero init)
-            kernel_init=nn.initializers.zeros,
-            name="ada_proj"
-        )(conditioning)
-
-        # Split into scale, shift, gate for MLP and Attention
-        scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
-            ada_params, 6, axis=-1)
-
-        scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
-        shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)
-        # Apply Layer Normalization
-        norm = nn.LayerNorm(epsilon=self.norm_epsilon,
-                            use_scale=False, use_bias=False, dtype=self.dtype)
-        # norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype)  # Alternative: RMSNorm
-
-        norm_x = norm(x)
-
-        # Modulate for Attention path
-        x_attn = norm_x * (1 + scale_attn) + shift_attn
-
-        # Modulate for MLP path
-        x_mlp = norm_x * (1 + scale_mlp) + shift_mlp
-
-        # Return modulated outputs and gates
-        return x_attn, gate_attn, x_mlp, gate_mlp
-
-class AdaLNParams(nn.Module):  # Renamed for clarity
-    features: int
-    dtype: Optional[Dtype] = None
-    precision: PrecisionLike = None
-
-    @nn.compact
-    def __call__(self, conditioning):
-        # Ensure conditioning is broadcastable if needed (e.g., [B, 1, D_cond])
-        if conditioning.ndim == 2:
-            conditioning = jnp.expand_dims(conditioning, axis=1)
-
-        # Project conditioning to get 6 params per feature
-        ada_params = nn.Dense(
-            features=6 * self.features,
-            dtype=self.dtype,
-            precision=self.precision,
-            kernel_init=nn.initializers.zeros,
-            name="ada_proj"
-        )(conditioning)
-        # Return all params (or split if preferred, but maybe return tuple/dict)
-        # Shape: [B, 1, 6*F]
-        return ada_params  # Or split and return tuple: jnp.split(ada_params, 6, axis=-1)
-
 # --- DiT Block ---
 class DiTBlock(nn.Module):
     features: int
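
The RoPE and AdaLN helpers deleted here (and their duplicates in simple_mmdit, next hunk) are now imported from .vit_common, per the import hunk above. A small self-contained sketch of the rotate-half property the removed code implemented: query/key dot products depend only on the relative position, not the absolute positions (function names here are illustrative, not the .vit_common API):

    import jax.numpy as jnp

    def rotate_half(v):
        v1, v2 = v[..., : v.shape[-1] // 2], v[..., v.shape[-1] // 2:]
        return jnp.concatenate((-v2, v1), axis=-1)

    def rope(v, pos, base=10000):
        d = v.shape[-1]
        inv_freq = 1.0 / (base ** (jnp.arange(0, d, 2, dtype=jnp.float32) / d))
        angles = pos * inv_freq
        cos = jnp.concatenate([jnp.cos(angles)] * 2, axis=-1)
        sin = jnp.concatenate([jnp.sin(angles)] * 2, axis=-1)
        return v * cos + rotate_half(v) * sin

    q, k = jnp.ones(8), jnp.arange(8.0)
    # Same relative offset (2) at different absolute positions -> same score:
    print(jnp.dot(rope(q, 3.0), rope(k, 1.0)))
    print(jnp.dot(rope(q, 7.0), rope(k, 5.0)))  # matches the line above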
@@ -7,143 +7,12 @@ from functools import partial
 from flax.typing import Dtype, PrecisionLike
 
 # Imports from local modules
-from .simple_vit import PatchEmbedding, unpatchify
+from .vit_common import PatchEmbedding, unpatchify, RotaryEmbedding, RoPEAttention
 from .common import kernel_init, FourierEmbedding, TimeProjection
 from .attention import NormalAttention  # Base for RoPEAttention
 # Replace common.hilbert_indices with improved implementation from hilbert.py
 from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
 
-# --- Rotary Positional Embedding (RoPE) ---
-# Re-used from simple_dit.py
-
-
-def _rotate_half(x: jax.Array) -> jax.Array:
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return jnp.concatenate((-x2, x1), axis=-1)
-
-
-def apply_rotary_embedding(
-    x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array
-) -> jax.Array:
-    """Applies rotary embedding to the input tensor using rotate_half method."""
-    if x.ndim == 4:  # [B, H, S, D]
-        cos_freqs = jnp.expand_dims(freqs_cos, axis=(0, 1))
-        sin_freqs = jnp.expand_dims(freqs_sin, axis=(0, 1))
-    elif x.ndim == 3:  # [B, S, D]
-        cos_freqs = jnp.expand_dims(freqs_cos, axis=0)
-        sin_freqs = jnp.expand_dims(freqs_sin, axis=0)
-    else:
-        raise ValueError(f"Unsupported input dimension: {x.ndim}")
-
-    cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
-    sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
-
-    x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
-    return x_rotated.astype(x.dtype)
-
-
-class RotaryEmbedding(nn.Module):
-    dim: int
-    max_seq_len: int = 4096  # Increased default based on SimpleDiT
-    base: int = 10000
-    dtype: Dtype = jnp.float32
-
-    def setup(self):
-        inv_freq = 1.0 / (
-            self.base ** (jnp.arange(0, self.dim, 2,
-                                     dtype=jnp.float32) / self.dim)
-        )
-        t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
-        freqs = jnp.outer(t, inv_freq)
-        self.freqs_cos = jnp.cos(freqs)
-        self.freqs_sin = jnp.sin(freqs)
-
-    def __call__(self, seq_len: int):
-        if seq_len > self.max_seq_len:
-            # Dynamically extend frequencies if needed (more robust)
-            t = jnp.arange(seq_len, dtype=jnp.float32)
-            inv_freq = 1.0 / (
-                self.base ** (jnp.arange(0, self.dim, 2,
-                                         dtype=jnp.float32) / self.dim)
-            )
-            freqs = jnp.outer(t, inv_freq)
-            freqs_cos = jnp.cos(freqs)
-            freqs_sin = jnp.sin(freqs)
-            # Consider caching extended freqs if this happens often
-            return freqs_cos, freqs_sin
-            # Or raise error like before:
-            # raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
-        return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :]
-
-# --- Attention with RoPE ---
-# Re-used from simple_dit.py
-
-
-class RoPEAttention(NormalAttention):
-    rope_emb: RotaryEmbedding = None
-
-    @nn.compact
-    def __call__(self, x, context=None, freqs_cis=None):
-        orig_x_shape = x.shape
-        is_4d = len(orig_x_shape) == 4
-        if is_4d:
-            B, H, W, C = x.shape
-            seq_len = H * W
-            x = x.reshape((B, seq_len, C))
-        else:
-            B, seq_len, C = x.shape
-
-        context = x if context is None else context
-        if len(context.shape) == 4:
-            _B, _H, _W, _C = context.shape
-            context_seq_len = _H * _W
-            context = context.reshape((B, context_seq_len, _C))
-        # else: # context is already [B, S_ctx, C]
-
-        query = self.query(x)  # [B, S, H, D]
-        key = self.key(context)  # [B, S_ctx, H, D]
-        value = self.value(context)  # [B, S_ctx, H, D]
-
-        if freqs_cis is None and self.rope_emb is not None:
-            seq_len_q = query.shape[1]  # Use query's sequence length
-            freqs_cos, freqs_sin = self.rope_emb(seq_len_q)
-        elif freqs_cis is not None:
-            freqs_cos, freqs_sin = freqs_cis
-        else:
-            # Should not happen if rope_emb is provided or freqs_cis are passed
-            raise ValueError("RoPE frequencies not provided.")
-
-        # Apply RoPE to query and key
-        # Permute to [B, H, S, D] for RoPE application
-        query = einops.rearrange(query, 'b s h d -> b h s d')
-        key = einops.rearrange(key, 'b s h d -> b h s d')
-
-        # Apply RoPE only up to the context sequence length for keys if different
-        # Assuming self-attention or context has same seq len for simplicity here
-        query = apply_rotary_embedding(query, freqs_cos, freqs_sin)
-        key = apply_rotary_embedding(
-            key, freqs_cos, freqs_sin)  # Apply same freqs to key
-
-        # Permute back to [B, S, H, D] for dot_product_attention
-        query = einops.rearrange(query, 'b h s d -> b s h d')
-        key = einops.rearrange(key, 'b h s d -> b s h d')
-
-        hidden_states = nn.dot_product_attention(
-            query, key, value, dtype=self.dtype, broadcast_dropout=False,
-            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
-            deterministic=True
-        )
-
-        proj = self.proj_attn(hidden_states)
-
-        if is_4d:
-            proj = proj.reshape(orig_x_shape)
-
-        return proj
-
-
 # --- MM-DiT AdaLN-Zero ---
 class MMAdaLNZero(nn.Module):
     """