flaxdiff 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -26,6 +26,7 @@ from flax.training.dynamic_scale import DynamicScale
26
26
  from flaxdiff.utils import RandomMarkovState
27
27
  from flax.training import dynamic_scale as dynamic_scale_lib
28
28
  from dataclasses import dataclass
29
+ import shutil
29
30
  import gc
30
31
 
31
32
  PROCESS_COLOR_MAP = {
@@ -73,6 +74,76 @@ class SimpleTrainState(train_state.TrainState):
73
74
  metrics: Metrics
74
75
  dynamic_scale: dynamic_scale_lib.DynamicScale
75
76
 
77
+ def move_contents_to_subdir(target_dir, new_subdir_name):
78
+ # --- 1. Validate Target Directory ---
79
+ if not os.path.isdir(target_dir):
80
+ print(f"Error: Target directory '{target_dir}' not found or is not a directory.")
81
+ return
82
+ # --- 2. Define Paths ---
83
+ # Construct the full path for the new subdirectory
84
+ new_subdir_path = os.path.join(target_dir, new_subdir_name)
85
+ # --- 3. Create New Subdirectory ---
86
+ try:
87
+ # Create the subdirectory.
88
+ # exist_ok=True prevents an error if the directory already exists.
89
+ os.makedirs(new_subdir_path, exist_ok=True)
90
+ print(f"Subdirectory '{new_subdir_path}' created or already exists.")
91
+ except OSError as e:
92
+ print(f"Error creating subdirectory '{new_subdir_path}': {e}")
93
+ return # Stop execution if subdirectory creation fails
94
+ # --- 4. List Contents of Target Directory ---
95
+ try:
96
+ items_to_move = os.listdir(target_dir)
97
+ except OSError as e:
98
+ print(f"Error listing contents of '{target_dir}': {e}")
99
+ return # Stop if we can't list directory contents
100
+ # --- 5. Move Items ---
101
+ print(f"Moving items from '{target_dir}' to '{new_subdir_path}'...")
102
+ moved_count = 0
103
+ error_count = 0
104
+ for item_name in items_to_move:
105
+ # Construct the full path of the item in the target directory
106
+ source_path = os.path.join(target_dir, item_name)
107
+ # IMPORTANT: Skip the newly created subdirectory itself!
108
+ if source_path == new_subdir_path:
109
+ continue
110
+ # Construct the destination path inside the new subdirectory
111
+ destination_path = os.path.join(new_subdir_path, item_name)
112
+ # Move the item
113
+ try:
114
+ shutil.move(source_path, destination_path)
115
+ # print(f" Moved: '{item_name}'") # Uncomment for verbose output
116
+ moved_count += 1
117
+ except Exception as e:
118
+ print(f" Error moving '{item_name}': {e}")
119
+ error_count += 1
120
+ print(f"\nOperation complete.")
121
+ print(f" Successfully moved: {moved_count} item(s).")
122
+ if error_count > 0:
123
+ print(f" Errors encountered: {error_count} item(s).")
124
+
125
+ def load_from_checkpoint(
126
+ checkpoint_dir: str,
127
+ ):
128
+ try:
129
+ checkpointer = orbax.checkpoint.PyTreeCheckpointer()
130
+ options = orbax.checkpoint.CheckpointManagerOptions(create=False)
131
+ # Convert checkpoint_dir to absolute path
132
+ checkpoint_dir = os.path.abspath(checkpoint_dir)
133
+ manager = orbax.checkpoint.CheckpointManager(checkpoint_dir, checkpointer, options)
134
+ ckpt = manager.restore(checkpoint_dir)
135
+ # Extract as above
136
+ state, best_state = None, None
137
+ if 'state' in ckpt:
138
+ state = ckpt['state']
139
+ if 'best_state' in ckpt:
140
+ best_state = ckpt['best_state']
141
+ print(f"Loaded checkpoint from local dir {checkpoint_dir}")
142
+ return state, best_state
143
+ except Exception as e:
144
+ print(f"Warning: Failed to load checkpoint from local dir: {e}")
145
+ return None, None
146
+
76
147
  @dataclass
77
148
  class SimpleTrainer:
78
149
  state: SimpleTrainState
@@ -97,6 +168,7 @@ class SimpleTrainer:
97
168
  checkpoint_step: int = None,
98
169
  use_dynamic_scale: bool = False,
99
170
  max_checkpoints_to_keep: int = 2,
171
+ train_start_step_override: int = None,
100
172
  ):
101
173
  if distributed_training is None or distributed_training is True:
102
174
  # Auto-detect if we are running on multiple devices
@@ -112,11 +184,32 @@ class SimpleTrainer:
112
184
  self.input_shapes = input_shapes
113
185
  self.checkpoint_base_path = checkpoint_base_path
114
186
 
187
+ load_directly_from_dir = False
188
+
115
189
  if wandb_config is not None and jax.process_index() == 0:
116
190
  import wandb
117
191
  run = wandb.init(resume='allow', **wandb_config)
118
192
  self.wandb = run
119
193
 
194
+ if 'id' in wandb_config:
195
+ # If resuming from a previous run, and train_start_step_override is not set,
196
+ # set the start step to the last step of the previous run
197
+ if train_start_step_override is None:
198
+ train_start_step_override = run.summary['train/step'] + 1
199
+ print(f"Resuming from previous run {wandb_config['id']} with start step {train_start_step_override}")
200
+
201
+ # If load_from_checkpoint is not set, and an artifact is found, load the artifact
202
+ if load_from_checkpoint is None:
203
+ api_run = wandb.Api().run(f"{wandb_config['entity']}/{wandb_config['project']}/{wandb_config['id']}")
204
+ model_artifacts = [i for i in api_run.logged_artifacts() if i.type == 'model']
205
+ if model_artifacts:
206
+ artifact = model_artifacts[0]
207
+ artifact_dir = artifact.download()
208
+ print(f"Loading model from artifact {artifact.name} at {artifact_dir}")
209
+ # Move the artifact's contents
210
+ load_from_checkpoint = artifact_dir
211
+ load_directly_from_dir = True
212
+
120
213
  # define our custom x axis metric
121
214
  self.wandb.define_metric("train/step")
122
215
  self.wandb.define_metric("train/epoch")
@@ -142,12 +235,16 @@ class SimpleTrainer:
142
235
  self.checkpoint_path(), async_checkpointer, options)
143
236
 
144
237
  if load_from_checkpoint is not None:
145
- latest_epoch, latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step)
238
+ latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step, load_directly_from_dir)
146
239
  else:
147
- latest_epoch, latest_step, old_state, old_best_state, rngstate = 0, 0, None, None, None
240
+ latest_step, old_state, old_best_state, rngstate = 0, None, None, None
148
241
 
149
242
  self.latest_step = latest_step
150
243
 
244
+ if train_start_step_override is not None:
245
+ self.latest_step = train_start_step_override
246
+ print(f"Overriding start step to {self.latest_step}")
247
+
151
248
  if rngstate:
152
249
  self.rngstate = RandomMarkovState(**rngstate)
153
250
  else:
@@ -239,15 +336,12 @@ class SimpleTrainer:
239
336
  os.makedirs(path)
240
337
  return path
241
338
 
242
- def load(self, checkpoint_path=None, checkpoint_step=None):
243
- if checkpoint_path is None:
244
- checkpointer = self.checkpointer
245
- else:
246
- checkpointer = orbax.checkpoint.PyTreeCheckpointer()
247
- options = orbax.checkpoint.CheckpointManagerOptions(
248
- max_to_keep=4, create=False)
249
- checkpointer = orbax.checkpoint.CheckpointManager(
250
- checkpoint_path, checkpointer, options)
339
+ def load(self, checkpoint_path, checkpoint_step=None, load_directly_from_dir=False):
340
+ checkpointer = orbax.checkpoint.PyTreeCheckpointer()
341
+ options = orbax.checkpoint.CheckpointManagerOptions(
342
+ max_to_keep=4, create=False)
343
+ checkpointer = orbax.checkpoint.CheckpointManager(
344
+ checkpoint_path, checkpointer, options)
251
345
 
252
346
  if checkpoint_step is None:
253
347
  step = checkpointer.latest_step()
@@ -259,7 +353,8 @@ class SimpleTrainer:
259
353
  checkpoint_path if checkpoint_path else self.checkpoint_path(),
260
354
  f"{step}")
261
355
  self.loaded_checkpoint_path = loaded_checkpoint_path
262
- ckpt = checkpointer.restore(step)
356
+ ckpt = checkpointer.restore(step) if not load_directly_from_dir else checkpointer.restore(checkpoint_path)
357
+
263
358
  state = ckpt['state']
264
359
  best_state = ckpt['best_state']
265
360
  rngstate = ckpt['rngs']
@@ -268,10 +363,8 @@ class SimpleTrainer:
268
363
  if self.best_loss == 0:
269
364
  # It cant be zero as that must have been some problem
270
365
  self.best_loss = 1e9
271
- current_epoch = ckpt.get('epoch', step) # Must be a checkpoint from an older version which used epochs instead of steps
272
- print(
273
- f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
274
- return current_epoch, step, state, best_state, rngstate
366
+ print(f"Loaded model from checkpoint at step {step}", ckpt['best_loss'])
367
+ return step, state, best_state, rngstate
275
368
 
276
369
  def save(self, epoch=0, step=0, state=None, rngstate=None):
277
370
  print(f"Saving model at epoch {epoch} step {step}")
@@ -507,6 +600,7 @@ class SimpleTrainer:
507
600
 
508
601
  def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps_per_epoch=5, validation_step_args={}):
509
602
  train_ds = iter(data['train']())
603
+ val_ds = data.get('val', data.get('test', None))()
510
604
  train_step = self._define_train_step(**train_step_args)
511
605
  val_step = self._define_validation_step(**validation_step_args)
512
606
  train_state = self.state
@@ -520,7 +614,7 @@ class SimpleTrainer:
520
614
  self.validation_loop(
521
615
  train_state,
522
616
  val_step,
523
- data.get('val', data.get('test', None)),
617
+ val_ds,
524
618
  val_steps_per_epoch,
525
619
  self.latest_step,
526
620
  )
@@ -571,11 +665,11 @@ class SimpleTrainer:
571
665
  self.validation_loop(
572
666
  train_state,
573
667
  val_step,
574
- data.get('val', data.get('test', None)),
668
+ val_ds,
575
669
  val_steps_per_epoch,
576
670
  current_step,
577
671
  )
578
672
  print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
579
673
 
580
- self.save(epochs)
674
+ self.save(epochs)#
581
675
  return self.state
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.2.7
3
+ Version: 0.2.8
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -2,16 +2,16 @@ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=DmlWUY1FGz4ESxIHaPQJf92CHjsdMjyDd651wFUtyNg,8838
3
3
  flaxdiff/data/__init__.py,sha256=8W5y7NyAOWtpLi8WRawk4VYeE3DMDnM3B_jKPD8BoFQ,143
4
4
  flaxdiff/data/benchmark_decord.py,sha256=x56Db1VPmziv_9KJvWdfS0O7cffsYkF5tt5WvldOKc0,13720
5
- flaxdiff/data/dataloaders.py,sha256=LV8ugqoB86yihfYeOJZHHdRZJNmZ63A2NQkdILMR9QA,23564
6
- flaxdiff/data/dataset_map.py,sha256=_6SYnmrYO-URDd8vPAmALTV6r0eMGWWmwUtsdjKGXnA,5072
5
+ flaxdiff/data/dataloaders.py,sha256=HQR0rsLNYXRPBmdOBKFCc3UfWsmSbSO_-dOQHCbu_VA,23966
6
+ flaxdiff/data/dataset_map.py,sha256=Dz_suGz23Cy7RfWt0FDRX7Q3NTB5SAw2UNHO_-p0qiM,5098
7
7
  flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
8
8
  flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
9
9
  flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
10
10
  flaxdiff/data/sources/av_utils.py,sha256=LCr9MJNurOaoxY-sjzkLqJS_MlX0x3gRSlKAVIglAU0,24045
11
- flaxdiff/data/sources/base.py,sha256=uhF0odJSYRy0SLw1xnI9Q_q_xiVht2DmEYcX1j9AWT4,4246
12
- flaxdiff/data/sources/images.py,sha256=RFLtKW1xzw6ZPVXtCMmnTg1MPb8dc7rP77rZWbK7qpo,11796
11
+ flaxdiff/data/sources/base.py,sha256=4Rm9pCtXxzoB8FO0lkDHsrX3ULoU_PNNcid978e6ir0,4610
12
+ flaxdiff/data/sources/images.py,sha256=71TzTVbPzV-Md3-1Lk4eWfb11w6aaO01OClwK_SiCSM,14708
13
13
  flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
14
- flaxdiff/data/sources/videos.py,sha256=CVpOH6A4P2D8iv3gZIhd2GB5ATUD8Vsm_wVYbbugWD4,9359
14
+ flaxdiff/data/sources/videos.py,sha256=NkxwEruNpAwDCM53q4WurQ802gSjQMOqjNLxYOqjoNE,9545
15
15
  flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
16
16
  flaxdiff/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
17
  flaxdiff/inference/pipeline.py,sha256=8S30FAlXEjvrDd87H-qdD6biySQZ3cJUflU8gdmPxig,9223
@@ -27,12 +27,14 @@ flaxdiff/metrics/ssim.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
27
27
  flaxdiff/metrics/utils.py,sha256=YuuOfqvqgIjsceupwNeJ59vQ2TnGeNMIyKdkIqOmoNg,1702
28
28
  flaxdiff/models/__init__.py,sha256=amtDF07DfiAdnZsvWX4eaW79nwNEU1s8Zb4PB3ewtg4,118
29
29
  flaxdiff/models/attention.py,sha256=YkED3_MRTjI9aTMTTQdsuReHhG8MK0Z4OVuU2j8ZAHs,13524
30
- flaxdiff/models/better_uvit.py,sha256=wPxvYBjuWQH6-OqW79VedzN6_WRY1f2mysPxaciWLww,15598
31
- flaxdiff/models/common.py,sha256=0j9AAjGPgkBLHo2DlYj0R6OsUNw2QaoDjaXSKq2mqkA,12647
30
+ flaxdiff/models/common.py,sha256=QpciwuJldvLUwyAyWBQqiPPGVI-c9qLR7h7C1YoRX7w,10510
32
31
  flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
33
32
  flaxdiff/models/general.py,sha256=7xMME6KVKQY8sScyHYH4f-Kek4j1pRfplKShFXwVZd4,587
33
+ flaxdiff/models/hilbert.py,sha256=AjlAv49dL6UAYWslMJfCMLiFqY4kTgpiUWr2nc1mk34,24823
34
+ flaxdiff/models/simple_dit.py,sha256=Hc2jLOZCYSDm6x88m3bGYu-OKge1TukiQPSdlaO68rE,19667
35
+ flaxdiff/models/simple_mmdit.py,sha256=RmOq6LbfDBUUEib6MSAURujxn9iHgdh77a6ntNsWI2w,36210
34
36
  flaxdiff/models/simple_unet.py,sha256=pjeixszG_6gEY5PNFbQ7KbOyg4z5bfn4RUbINCJexOM,10758
35
- flaxdiff/models/simple_vit.py,sha256=no0o3os8dEKGU5I4PMBJlXq6XKjhUex8S8uZ9BDPZS4,7971
37
+ flaxdiff/models/simple_vit.py,sha256=QEHPyaQIYhqSYrD6eb65X70jQL-y09nRT8Yc4b5Jq6Q,15181
36
38
  flaxdiff/models/unet_3d.py,sha256=LF0PMxBKGU-_lAMtO_Coxy1yRE02yKKdgb7i6YZxI_4,20163
37
39
  flaxdiff/models/unet_3d_blocks.py,sha256=lRYDc9X1VEu54Kg7xEEphXYiQ09tabPXKi-hEcKFYug,19687
38
40
  flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
@@ -60,9 +62,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
60
62
  flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
61
63
  flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
62
64
  flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
63
- flaxdiff/trainer/general_diffusion_trainer.py,sha256=BeDpJzgR8bUClJI4epQXlAul27MwiSfRW0lIBZSiPWk,28342
64
- flaxdiff/trainer/simple_trainer.py,sha256=raLS1shwpjJBT_bYXLAB2E4kA9MbwasDTzDTUqfCCUc,24312
65
- flaxdiff-0.2.7.dist-info/METADATA,sha256=nwglJYeF2lH_MNq5PeFLR8TSPU-I9tzJUcBbTaLYxRM,24057
66
- flaxdiff-0.2.7.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
67
- flaxdiff-0.2.7.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
68
- flaxdiff-0.2.7.dist-info/RECORD,,
65
+ flaxdiff/trainer/general_diffusion_trainer.py,sha256=FUvc--3ibRAjrYiKbA-FyLqKhusakxeNOa6UJZaK4SU,29307
66
+ flaxdiff/trainer/simple_trainer.py,sha256=Hdltuo3lgF61N04Lxc7L3z6NLveW4_h1ff7_5mu3Wbg,28730
67
+ flaxdiff-0.2.8.dist-info/METADATA,sha256=y2jLjsEkR-GKvLWuGzlyBrk1SNM6tCPT0Oc7vRZC7_I,24057
68
+ flaxdiff-0.2.8.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
69
+ flaxdiff-0.2.8.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
70
+ flaxdiff-0.2.8.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.1.0)
2
+ Generator: setuptools (80.3.1)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,380 +0,0 @@
1
- # flaxdiff/models/better_uvit.py
2
- import jax
3
- import jax.numpy as jnp
4
- from flax import linen as nn
5
- from typing import Callable, Any, Optional, Tuple, Sequence, Union
6
- import einops
7
- from functools import partial
8
-
9
- # Re-use existing components if they are suitable
10
- from .common import kernel_init, FourierEmbedding, TimeProjection, hilbert_indices, inverse_permutation
11
- from .attention import NormalAttention # Using NormalAttention for RoPE integration
12
- from flax.typing import Dtype, PrecisionLike
13
-
14
- # --- Rotary Positional Embedding (RoPE) ---
15
- # Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py
16
-
17
- def _rotate_half(x: jax.Array) -> jax.Array:
18
- """Rotates half the hidden dims of the input."""
19
- x1 = x[..., : x.shape[-1] // 2]
20
- x2 = x[..., x.shape[-1] // 2 :]
21
- return jnp.concatenate((-x2, x1), axis=-1)
22
-
23
- def apply_rotary_embedding(
24
- x: jax.Array, freqs_cis: jax.Array
25
- ) -> jax.Array:
26
- """Applies rotary embedding to the input tensor using rotate_half method."""
27
- # x shape: [..., Sequence, Dimension] e.g. [B, H, S, D] or [B, S, D]
28
- # freqs_cis shape: complex [Sequence, Dimension / 2]
29
-
30
- # Extract cos and sin from the complex freqs_cis
31
- cos_freqs = jnp.real(freqs_cis) # Shape [S, D/2]
32
- sin_freqs = jnp.imag(freqs_cis) # Shape [S, D/2]
33
-
34
- # Expand dims for broadcasting: [1, 1, S, D/2] or [1, S, D/2]
35
- if x.ndim == 4: # [B, H, S, D]
36
- cos_freqs = jnp.expand_dims(cos_freqs, axis=(0, 1))
37
- sin_freqs = jnp.expand_dims(sin_freqs, axis=(0, 1))
38
- elif x.ndim == 3: # [B, S, D]
39
- cos_freqs = jnp.expand_dims(cos_freqs, axis=0)
40
- sin_freqs = jnp.expand_dims(sin_freqs, axis=0)
41
-
42
- # Duplicate cos and sin for the full dimension D
43
- # Shape becomes [..., S, D]
44
- cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
45
- sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
46
-
47
- # Apply rotation: x * cos + rotate_half(x) * sin
48
- x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
49
- return x_rotated.astype(x.dtype)
50
-
51
-
52
- class RotaryEmbedding(nn.Module):
53
- dim: int # Dimension of the head
54
- max_seq_len: int = 2048
55
- base: int = 10000
56
- dtype: Dtype = jnp.float32
57
-
58
- def setup(self):
59
- inv_freq = 1.0 / (
60
- self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim)
61
- )
62
- t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
63
- freqs = jnp.outer(t, inv_freq) # Shape: [max_seq_len, dim / 2]
64
-
65
- # Precompute the complex form: cos(theta) + i * sin(theta)
66
- self.freqs_cis_complex = jnp.cos(freqs) + 1j * jnp.sin(freqs)
67
- # Shape: [max_seq_len, dim / 2]
68
-
69
- def __call__(self, seq_len: int):
70
- if seq_len > self.max_seq_len:
71
- raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
72
- # Return complex shape [seq_len, dim / 2]
73
- return self.freqs_cis_complex[:seq_len, :]
74
-
75
- # --- Attention with RoPE ---
76
-
77
- class RoPEAttention(NormalAttention):
78
- rope_emb: RotaryEmbedding
79
-
80
- @nn.compact
81
- def __call__(self, x, context=None, freqs_cis=None):
82
- # x has shape [B, H, W, C] or [B, S, C]
83
- orig_x_shape = x.shape
84
- is_4d = len(orig_x_shape) == 4
85
- if is_4d:
86
- B, H, W, C = x.shape
87
- seq_len = H * W
88
- x = x.reshape((B, seq_len, C))
89
- else:
90
- B, seq_len, C = x.shape
91
-
92
- context = x if context is None else context
93
- if len(context.shape) == 4:
94
- _B, _H, _W, _C = context.shape
95
- context_seq_len = _H * _W
96
- context = context.reshape((B, context_seq_len, _C))
97
- else:
98
- _B, context_seq_len, _C = context.shape
99
-
100
- query = self.query(x) # [B, S, H, D]
101
- key = self.key(context) # [B, S_ctx, H, D]
102
- value = self.value(context) # [B, S_ctx, H, D]
103
-
104
- # Apply RoPE to query and key
105
- if freqs_cis is not None:
106
- # Permute to [B, H, S, D] for RoPE application if needed by apply_rotary_embedding
107
- query = einops.rearrange(query, 'b s h d -> b h s d')
108
- key = einops.rearrange(key, 'b s h d -> b h s d')
109
-
110
- query = apply_rotary_embedding(query, freqs_cis)
111
- key = apply_rotary_embedding(key, freqs_cis) # Apply to key as well
112
-
113
- # Permute back to [B, S, H, D] for dot_product_attention
114
- query = einops.rearrange(query, 'b h s d -> b s h d')
115
- key = einops.rearrange(key, 'b h s d -> b s h d')
116
-
117
- hidden_states = nn.dot_product_attention(
118
- query, key, value, dtype=self.dtype, broadcast_dropout=False,
119
- dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
120
- deterministic=True
121
- ) # Output shape [B, S, H, D]
122
-
123
- proj = self.proj_attn(hidden_states) # Output shape [B, S, C]
124
-
125
- if is_4d:
126
- proj = proj.reshape(orig_x_shape) # Reshape back if input was 4D
127
-
128
- return proj
129
-
130
- # --- adaLN-Zero ---
131
-
132
- class AdaLNZero(nn.Module):
133
- features: int
134
- dtype: Optional[Dtype] = None
135
- precision: PrecisionLike = None
136
- norm_epsilon: float = 1e-5 # Standard LayerNorm epsilon
137
-
138
- @nn.compact
139
- def __call__(self, x, conditioning):
140
- # Project conditioning signal to get scale and shift parameters
141
- # Conditioning shape: [B, D_cond] -> [B, 1, ..., 1, 6 * features] for broadcasting
142
- # Or [B, 1, 6*features] if x is [B, S, F]
143
-
144
- # Ensure conditioning has seq dim if x does
145
- if x.ndim == 3 and conditioning.ndim == 2: # x=[B,S,F], cond=[B,D_cond]
146
- conditioning = jnp.expand_dims(conditioning, axis=1) # cond=[B,1,D_cond]
147
-
148
- # Project conditioning to get 6 params per feature (scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn)
149
- # Using nn.DenseGeneral for flexibility if needed, but nn.Dense is fine if cond is [B, D_cond] or [B, 1, D_cond]
150
- ada_params = nn.Dense(
151
- features=6 * self.features,
152
- dtype=self.dtype,
153
- precision=self.precision,
154
- kernel_init=nn.initializers.zeros, # Initialize projection to zero (Zero init)
155
- name="ada_proj"
156
- )(conditioning)
157
-
158
- # Split into scale, shift, gate for MLP and Attention
159
- scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(ada_params, 6, axis=-1)
160
-
161
- # Apply Layer Normalization
162
- norm = nn.LayerNorm(epsilon=self.norm_epsilon, use_scale=False, use_bias=False, dtype=self.dtype)
163
- # norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype) # Alternative: RMSNorm
164
-
165
- norm_x = norm(x)
166
-
167
- # Modulate for Attention path
168
- x_attn = norm_x * (1 + scale_attn) + shift_attn
169
-
170
- # Modulate for MLP path
171
- x_mlp = norm_x * (1 + scale_mlp) + shift_mlp
172
-
173
- # Return modulated outputs and gates
174
- return x_attn, gate_attn, x_mlp, gate_mlp
175
-
176
-
177
- # --- DiT Block ---
178
-
179
- class DiTBlock(nn.Module):
180
- features: int
181
- num_heads: int
182
- mlp_ratio: int = 4
183
- dropout_rate: float = 0.0 # Typically dropout is not used in diffusion models
184
- dtype: Optional[Dtype] = None
185
- precision: PrecisionLike = None
186
- use_flash_attention: bool = False # Keep option, but RoPEAttention uses NormalAttention base
187
- force_fp32_for_softmax: bool = True
188
- norm_epsilon: float = 1e-5
189
- rope_emb: RotaryEmbedding # Pass RoPE module
190
-
191
- def setup(self):
192
- hidden_features = int(self.features * self.mlp_ratio)
193
- self.ada_ln_zero = AdaLNZero(self.features, dtype=self.dtype, precision=self.precision, norm_epsilon=self.norm_epsilon)
194
-
195
- # Use RoPEAttention
196
- self.attention = RoPEAttention(
197
- query_dim=self.features,
198
- heads=self.num_heads,
199
- dim_head=self.features // self.num_heads,
200
- dtype=self.dtype,
201
- precision=self.precision,
202
- use_bias=True, # Bias is common in DiT attention proj
203
- force_fp32_for_softmax=self.force_fp32_for_softmax,
204
- rope_emb=self.rope_emb # Pass RoPE module instance
205
- )
206
-
207
- # Standard MLP block
208
- self.mlp = nn.Sequential([
209
- nn.Dense(features=hidden_features, dtype=self.dtype, precision=self.precision),
210
- nn.gelu,
211
- nn.Dense(features=self.features, dtype=self.dtype, precision=self.precision)
212
- ])
213
-
214
- @nn.compact
215
- def __call__(self, x, conditioning, freqs_cis):
216
- # x shape: [B, S, F]
217
- # conditioning shape: [B, D_cond]
218
-
219
- residual = x
220
-
221
- # Apply adaLN-Zero to get modulated inputs and gates
222
- x_attn, gate_attn, x_mlp, gate_mlp = self.ada_ln_zero(x, conditioning)
223
-
224
- # Attention block
225
- attn_output = self.attention(x_attn, context=None, freqs_cis=freqs_cis) # Self-attention only
226
- x = residual + gate_attn * attn_output
227
-
228
- # MLP block
229
- mlp_output = self.mlp(x_mlp)
230
- x = x + gate_mlp * mlp_output
231
-
232
- return x
233
-
234
- # --- Patch Embedding (reuse or define if needed) ---
235
- # Assuming PatchEmbedding exists in simple_vit.py and is suitable
236
- from .simple_vit import PatchEmbedding, unpatchify
237
-
238
- # --- Better UViT (DiT Style) ---
239
-
240
- class BetterUViT(nn.Module):
241
- output_channels: int = 3
242
- patch_size: int = 16
243
- emb_features: int = 768
244
- num_layers: int = 12
245
- num_heads: int = 12
246
- mlp_ratio: int = 4
247
- dropout_rate: float = 0.0 # Typically 0 for diffusion
248
- dtype: Optional[Dtype] = None
249
- precision: PrecisionLike = None
250
- use_flash_attention: bool = False # Passed down, but RoPEAttention uses NormalAttention
251
- force_fp32_for_softmax: bool = True
252
- norm_epsilon: float = 1e-5
253
- learn_sigma: bool = False # Option to predict sigma like in DiT paper
254
- use_hilbert: bool = False # Toggle Hilbert patch reorder
255
-
256
- def setup(self):
257
- self.patch_embed = PatchEmbedding(
258
- patch_size=self.patch_size,
259
- embedding_dim=self.emb_features,
260
- dtype=self.dtype,
261
- precision=self.precision
262
- )
263
-
264
- # Time embedding projection
265
- self.time_embed = nn.Sequential([
266
- FourierEmbedding(features=self.emb_features),
267
- TimeProjection(features=self.emb_features * self.mlp_ratio), # Project to MLP dim
268
- nn.Dense(features=self.emb_features, dtype=self.dtype, precision=self.precision) # Final projection
269
- ])
270
-
271
- # Text context projection (if used)
272
- # Assuming textcontext is already projected to some dimension, project it to match emb_features
273
- # This might need adjustment based on how text context is provided
274
- self.text_proj = nn.Dense(features=self.emb_features, dtype=self.dtype, precision=self.precision, name="text_context_proj")
275
-
276
- # Rotary Positional Embedding
277
- # Max length needs to be estimated or set large enough.
278
- # For images, seq len = (H/P) * (W/P). Example: 256/16 * 256/16 = 16*16 = 256
279
- # Add 1 if a class token is used, or more for text tokens if concatenated.
280
- # Let's assume max seq len accommodates patches + time + text tokens if needed, or just patches.
281
- # If only patches use RoPE, max_len = max_image_tokens
282
- # If time/text are concatenated *before* blocks, max_len needs to include them.
283
- # DiT typically applies PE only to patch tokens. Let's follow that.
284
- # max_len should be max number of patches.
285
- # Example: max image size 512x512, patch 16 -> (512/16)^2 = 32^2 = 1024 patches
286
- self.rope = RotaryEmbedding(dim=self.emb_features // self.num_heads, max_seq_len=4096, dtype=self.dtype) # Dim per head
287
-
288
- # Transformer Blocks
289
- self.blocks = [
290
- DiTBlock(
291
- features=self.emb_features,
292
- num_heads=self.num_heads,
293
- mlp_ratio=self.mlp_ratio,
294
- dropout_rate=self.dropout_rate,
295
- dtype=self.dtype,
296
- precision=self.precision,
297
- use_flash_attention=self.use_flash_attention,
298
- force_fp32_for_softmax=self.force_fp32_for_softmax,
299
- norm_epsilon=self.norm_epsilon,
300
- rope_emb=self.rope, # Pass RoPE instance
301
- name=f"dit_block_{i}"
302
- ) for i in range(self.num_layers)
303
- ]
304
-
305
- # Final Layer (Normalization + Linear Projection)
306
- self.final_norm = nn.LayerNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
307
- # self.final_norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
308
-
309
- # Predict patch pixels + potentially sigma
310
- output_dim = self.patch_size * self.patch_size * self.output_channels
311
- if self.learn_sigma:
312
- output_dim *= 2 # Predict both mean and variance (or log_variance)
313
-
314
- self.final_proj = nn.Dense(
315
- features=output_dim,
316
- dtype=self.dtype,
317
- precision=self.precision,
318
- kernel_init=nn.initializers.zeros, # Initialize final layer to zero
319
- name="final_proj"
320
- )
321
-
322
- @nn.compact
323
- def __call__(self, x, temb, textcontext=None):
324
- B, H, W, C = x.shape
325
- assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
326
-
327
- # 1. Patch Embedding
328
- patches = self.patch_embed(x) # Shape: [B, num_patches, emb_features]
329
- num_patches = patches.shape[1]
330
-
331
- # Optional Hilbert reorder
332
- if self.use_hilbert:
333
- idx = hilbert_indices(H // self.patch_size, W // self.patch_size)
334
- inv_idx = inverse_permutation(idx)
335
- patches = patches[:, idx, :]
336
-
337
- # replace x with patches
338
- x_seq = patches
339
-
340
- # 2. Prepare Conditioning Signal (Time + Text Context)
341
- t_emb = self.time_embed(temb) # Shape: [B, emb_features]
342
-
343
- cond_emb = t_emb
344
- if textcontext is not None:
345
- text_emb = self.text_proj(textcontext) # Shape: [B, num_text_tokens, emb_features]
346
- # Pool or select text embedding (e.g., mean pool or use CLS token)
347
- # Assuming mean pooling for simplicity
348
- text_emb_pooled = jnp.mean(text_emb, axis=1) # Shape: [B, emb_features]
349
- cond_emb = cond_emb + text_emb_pooled # Combine time and text embeddings
350
-
351
- # 3. Apply RoPE
352
- # Get RoPE frequencies for the sequence length (number of patches)
353
- freqs_cis = self.rope(seq_len=num_patches) # Shape [num_patches, D_head/2]
354
-
355
- # 4. Apply Transformer Blocks with adaLN-Zero conditioning
356
- for block in self.blocks:
357
- x_seq = block(x_seq, conditioning=cond_emb, freqs_cis=freqs_cis)
358
-
359
- # 5. Final Layer
360
- x_out = self.final_norm(x_seq)
361
- x_out = self.final_proj(x_out) # Shape: [B, num_patches, patch_pixels (*2 if learn_sigma)]
362
-
363
- # Optional Hilbert inverse reorder
364
- if self.use_hilbert:
365
- x_out = x_out[:, inv_idx, :]
366
-
367
- # 6. Unpatchify
368
- if self.learn_sigma:
369
- # Split into mean and variance predictions
370
- x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
371
- x = unpatchify(x_mean, channels=self.output_channels)
372
- # Return both mean and logvar if needed by the loss function
373
- # For now, just returning the mean prediction like standard diffusion models
374
- # logvar = unpatchify(x_logvar, channels=self.output_channels)
375
- # return x, logvar
376
- return x
377
- else:
378
- x = unpatchify(x_out, channels=self.output_channels) # Shape: [B, H, W, C]
379
- return x
380
-