flaxdiff 0.2.5__tar.gz → 0.2.6.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/PKG-INFO +1 -1
  2. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/data/sources/images.py +7 -0
  3. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/inference/pipeline.py +8 -3
  4. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/inference/utils.py +3 -2
  5. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/attention.py +7 -4
  6. flaxdiff-0.2.6.1/flaxdiff/models/better_uvit.py +380 -0
  7. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/common.py +75 -4
  8. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/simple_vit.py +26 -16
  9. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/trainer/general_diffusion_trainer.py +34 -5
  10. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff.egg-info/PKG-INFO +1 -1
  11. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff.egg-info/SOURCES.txt +1 -0
  12. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/pyproject.toml +1 -1
  13. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/README.md +0 -0
  14. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/__init__.py +0 -0
  15. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/data/__init__.py +0 -0
  16. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/data/benchmark_decord.py +0 -0
  17. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/data/dataloaders.py +0 -0
  18. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/data/dataset_map.py +0 -0
  19. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/data/online_loader.py +0 -0
  20. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/data/sources/audio_utils.py +0 -0
  21. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/data/sources/av_example.py +0 -0
  22. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/data/sources/av_utils.py +0 -0
  23. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/data/sources/base.py +0 -0
  24. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/data/sources/utils.py +0 -0
  25. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/data/sources/videos.py +0 -0
  26. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/data/sources/voxceleb2.py +0 -0
  27. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/inference/__init__.py +0 -0
  28. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/inputs/__init__.py +0 -0
  29. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/inputs/encoders.py +0 -0
  30. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/metrics/__init__.py +0 -0
  31. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/metrics/common.py +0 -0
  32. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/metrics/images.py +0 -0
  33. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/metrics/inception.py +0 -0
  34. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/metrics/psnr.py +0 -0
  35. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/metrics/ssim.py +0 -0
  36. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/metrics/utils.py +0 -0
  37. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/__init__.py +0 -0
  38. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/autoencoder/__init__.py +0 -0
  39. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  40. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  41. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  42. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/favor_fastattn.py +0 -0
  43. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/general.py +0 -0
  44. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/simple_unet.py +0 -0
  45. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/unet_3d.py +0 -0
  46. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/unet_3d_blocks.py +0 -0
  47. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/predictors/__init__.py +0 -0
  48. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/samplers/__init__.py +0 -0
  49. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/samplers/common.py +0 -0
  50. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/samplers/ddim.py +0 -0
  51. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/samplers/ddpm.py +0 -0
  52. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/samplers/euler.py +0 -0
  53. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/samplers/heun_sampler.py +0 -0
  54. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/samplers/multistep_dpm.py +0 -0
  55. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/samplers/rk4_sampler.py +0 -0
  56. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/schedulers/__init__.py +0 -0
  57. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/schedulers/common.py +0 -0
  58. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/schedulers/continuous.py +0 -0
  59. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/schedulers/cosine.py +0 -0
  60. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/schedulers/discrete.py +0 -0
  61. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/schedulers/exp.py +0 -0
  62. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/schedulers/karras.py +0 -0
  63. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/schedulers/linear.py +0 -0
  64. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/schedulers/sqrt.py +0 -0
  65. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/trainer/__init__.py +0 -0
  66. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  67. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  68. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/trainer/simple_trainer.py +0 -0
  69. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/utils.py +0 -0
  70. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff.egg-info/dependency_links.txt +0 -0
  71. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff.egg-info/requires.txt +0 -0
  72. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff.egg-info/top_level.txt +0 -0
  73. {flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/setup.cfg +0 -0
{flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: flaxdiff
- Version: 0.2.5
+ Version: 0.2.6.1
  Summary: A versatile and easy to understand Diffusion library
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
  License-Expression: MIT
{flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/data/sources/images.py
@@ -266,6 +266,12 @@ class ImageGCSAugmenter(DataAugmenter):

          print(f"Using method: {method}")

+         from torchvision.transforms import v2
+         augments = v2.Compose([
+             v2.RandomHorizontalFlip(p=0.5),
+             v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2)
+         ])
+
          class GCSTransform(pygrain.MapTransform):
              def __init__(self, *args, **kwargs):
                  super().__init__(*args, **kwargs)
@@ -277,6 +283,7 @@ class ImageGCSAugmenter(DataAugmenter):
                  image = np.asarray(bytearray(element['jpg']), dtype="uint8")
                  image = cv2.imdecode(image, cv2.IMREAD_UNCHANGED)
                  image = self.image_augmenter(image)
+                 image = augments(image)
                  caption = labelizer(element).decode('utf-8')
                  results = self.auto_tokenize(caption)
                  return {
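Note: the hunk above adds a random horizontal flip and mild color jitter before tokenization. Below is a minimal standalone sketch of the same pipeline; the random uint8 array and the tensor conversion are illustrative assumptions (torchvision's v2 transforms expect PIL images or CHW tensors), not part of the package.

    import numpy as np
    import torch
    from torchvision.transforms import v2

    augments = v2.Compose([
        v2.RandomHorizontalFlip(p=0.5),
        v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2),
    ])

    # Stand-in for a decoded HWC uint8 image from the loader (illustrative only).
    image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
    tensor = torch.from_numpy(image).permute(2, 0, 1)      # HWC -> CHW
    augmented = augments(tensor).permute(1, 2, 0).numpy()  # back to HWC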
{flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/inference/pipeline.py
@@ -20,7 +20,7 @@ from flaxdiff.models.autoencoder import AutoEncoder
  from flaxdiff.inputs import DiffusionInputConfig
  from flaxdiff.utils import defaultTextEncodeModel, RandomMarkovState
  from flaxdiff.samplers.euler import EulerAncestralSampler
- from .utils import parse_config, load_from_wandb_run, load_from_wandb_registry
+ from flaxdiff.inference.utils import parse_config, load_from_wandb_run, load_from_wandb_registry

  @dataclass
  class InferencePipeline:
@@ -51,6 +51,7 @@ class DiffusionInferencePipeline(InferencePipeline):
      model_output_transform: DiffusionPredictionTransform = None
      autoencoder: AutoEncoder = None
      input_config: DiffusionInputConfig = None
+     wandb_run = None
      samplers: Dict[Type[DiffusionSampler], Dict[float, DiffusionSampler]] = field(default_factory=dict)
      config: Dict[str, Any] = field(default_factory=dict)

@@ -75,7 +76,7 @@ class DiffusionInferencePipeline(InferencePipeline):
          Returns:
              DiffusionInferencePipeline instance
          """
-         states, config = load_from_wandb_run(
+         states, config, run = load_from_wandb_run(
              wandb_run,
              project=project,
              entity=entity,
@@ -93,6 +94,7 @@ class DiffusionInferencePipeline(InferencePipeline):
              state=state,
              best_state=best_state,
              rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
+             run=run,
          )
          return pipeline

@@ -117,7 +119,7 @@ class DiffusionInferencePipeline(InferencePipeline):
          Returns:
              DiffusionInferencePipeline instance
          """
-         states, config = load_from_wandb_registry(
+         states, config, run = load_from_wandb_registry(
              modelname=modelname,
              project=project,
              entity=entity,
@@ -137,6 +139,7 @@ class DiffusionInferencePipeline(InferencePipeline):
              state=state,
              best_state=best_state,
              rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
+             run=run,
          )
          return pipeline

@@ -147,6 +150,7 @@ class DiffusionInferencePipeline(InferencePipeline):
          state: Dict[str, Any],
          best_state: Optional[Dict[str, Any]] = None,
          rngstate: Optional[RandomMarkovState] = None,
+         run=None,
      ):
          if rngstate is None:
              rngstate = RandomMarkovState(jax.random.PRNGKey(42))
@@ -161,6 +165,7 @@ class DiffusionInferencePipeline(InferencePipeline):
              autoencoder=config['autoencoder'],
              input_config=config['input_config'],
              config=config,
+             wandb_run=run,
          )

      def get_sampler(
{flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/inference/utils.py
@@ -292,7 +292,7 @@ def load_from_wandb_run(
          config = run.config
      except Exception as e:
          print(f"Warning: Failed to load model from wandb: {e}")
-     return states, config
+     return states, config, run

  def load_from_wandb_registry(
      modelname: str,
@@ -307,6 +307,7 @@ def load_from_wandb_registry(
      # Get the model version from wandb
      states = None
      config = None
+     run = None
      try:
          artifact = wandb.Api().artifact(f"{registry}/{modelname}:{version}")
          ckpt_dir = artifact.download()
@@ -317,4 +318,4 @@
          config = run.config
      except Exception as e:
          print(f"Warning: Failed to load model from wandb: {e}")
-     return states, config
+     return states, config, run
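Note: both loaders now return the wandb run object as a third value, which the inference pipeline stores as `wandb_run`. A minimal sketch of the new return signature; the run/project/entity identifiers are placeholders, and any remaining keyword arguments are assumed to keep their defaults.

    from flaxdiff.inference.utils import load_from_wandb_run

    # "abcd1234", "my-project" and "my-entity" are placeholder identifiers.
    states, config, run = load_from_wandb_run(
        "abcd1234",
        project="my-project",
        entity="my-entity",
    )
    # `run` is the newly returned wandb run; it is what the pipeline keeps as `wandb_run`.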
{flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/attention.py
@@ -247,6 +247,7 @@ class BasicTransformerBlock(nn.Module):
      use_cross_only:bool = False
      only_pure_attention:bool = False
      force_fp32_for_softmax: bool = True
+     norm_epsilon: float = 1e-4

      def setup(self):
          if self.use_flash_attention:
@@ -278,9 +279,9 @@
          )

          self.ff = FlaxFeedForward(dim=self.query_dim)
-         self.norm1 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-         self.norm2 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
-         self.norm3 = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)
+         self.norm1 = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype)
+         self.norm2 = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype)
+         self.norm3 = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype)

      @nn.compact
      def __call__(self, hidden_states, context=None):
@@ -312,13 +313,14 @@ class TransformerBlock(nn.Module):
      # kernel_init: Callable = kernel_init(1.0)
      norm_inputs: bool = True
      explicitly_add_residual: bool = True
+     norm_epsilon: float = 1e-4

      @nn.compact
      def __call__(self, x, context=None):
          inner_dim = self.heads * self.dim_head
          C = x.shape[-1]
          if self.norm_inputs:
-             x = nn.RMSNorm(epsilon=1e-5, dtype=self.dtype)(x)
+             x = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype)(x)
          if self.use_projection == True:
              if self.use_linear_attention:
                  projected_x = nn.Dense(features=inner_dim,
@@ -350,6 +352,7 @@
              use_cross_only=(not self.use_self_and_cross),
              only_pure_attention=self.only_pure_attention,
              force_fp32_for_softmax=self.force_fp32_for_softmax,
+             norm_epsilon=self.norm_epsilon
              # kernel_init=self.kernel_init
          )(projected_x, context)

flaxdiff-0.2.6.1/flaxdiff/models/better_uvit.py
@@ -0,0 +1,380 @@
+ # flaxdiff/models/better_uvit.py
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from typing import Callable, Any, Optional, Tuple, Sequence, Union
+ import einops
+ from functools import partial
+
+ # Re-use existing components if they are suitable
+ from .common import kernel_init, FourierEmbedding, TimeProjection, hilbert_indices, inverse_permutation
+ from .attention import NormalAttention # Using NormalAttention for RoPE integration
+ from flax.typing import Dtype, PrecisionLike
+
+ # --- Rotary Positional Embedding (RoPE) ---
+ # Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py
+
+ def _rotate_half(x: jax.Array) -> jax.Array:
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return jnp.concatenate((-x2, x1), axis=-1)
+
+ def apply_rotary_embedding(
+     x: jax.Array, freqs_cis: jax.Array
+ ) -> jax.Array:
+     """Applies rotary embedding to the input tensor using rotate_half method."""
+     # x shape: [..., Sequence, Dimension] e.g. [B, H, S, D] or [B, S, D]
+     # freqs_cis shape: complex [Sequence, Dimension / 2]
+
+     # Extract cos and sin from the complex freqs_cis
+     cos_freqs = jnp.real(freqs_cis) # Shape [S, D/2]
+     sin_freqs = jnp.imag(freqs_cis) # Shape [S, D/2]
+
+     # Expand dims for broadcasting: [1, 1, S, D/2] or [1, S, D/2]
+     if x.ndim == 4: # [B, H, S, D]
+         cos_freqs = jnp.expand_dims(cos_freqs, axis=(0, 1))
+         sin_freqs = jnp.expand_dims(sin_freqs, axis=(0, 1))
+     elif x.ndim == 3: # [B, S, D]
+         cos_freqs = jnp.expand_dims(cos_freqs, axis=0)
+         sin_freqs = jnp.expand_dims(sin_freqs, axis=0)
+
+     # Duplicate cos and sin for the full dimension D
+     # Shape becomes [..., S, D]
+     cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
+     sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
+
+     # Apply rotation: x * cos + rotate_half(x) * sin
+     x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
+     return x_rotated.astype(x.dtype)
+
+
+ class RotaryEmbedding(nn.Module):
+     dim: int # Dimension of the head
+     max_seq_len: int = 2048
+     base: int = 10000
+     dtype: Dtype = jnp.float32
+
+     def setup(self):
+         inv_freq = 1.0 / (
+             self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim)
+         )
+         t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
+         freqs = jnp.outer(t, inv_freq) # Shape: [max_seq_len, dim / 2]
+
+         # Precompute the complex form: cos(theta) + i * sin(theta)
+         self.freqs_cis_complex = jnp.cos(freqs) + 1j * jnp.sin(freqs)
+         # Shape: [max_seq_len, dim / 2]
+
+     def __call__(self, seq_len: int):
+         if seq_len > self.max_seq_len:
+             raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
+         # Return complex shape [seq_len, dim / 2]
+         return self.freqs_cis_complex[:seq_len, :]
+
+ # --- Attention with RoPE ---
+
+ class RoPEAttention(NormalAttention):
+     rope_emb: RotaryEmbedding
+
+     @nn.compact
+     def __call__(self, x, context=None, freqs_cis=None):
+         # x has shape [B, H, W, C] or [B, S, C]
+         orig_x_shape = x.shape
+         is_4d = len(orig_x_shape) == 4
+         if is_4d:
+             B, H, W, C = x.shape
+             seq_len = H * W
+             x = x.reshape((B, seq_len, C))
+         else:
+             B, seq_len, C = x.shape
+
+         context = x if context is None else context
+         if len(context.shape) == 4:
+             _B, _H, _W, _C = context.shape
+             context_seq_len = _H * _W
+             context = context.reshape((B, context_seq_len, _C))
+         else:
+             _B, context_seq_len, _C = context.shape
+
+         query = self.query(x) # [B, S, H, D]
+         key = self.key(context) # [B, S_ctx, H, D]
+         value = self.value(context) # [B, S_ctx, H, D]
+
+         # Apply RoPE to query and key
+         if freqs_cis is not None:
+             # Permute to [B, H, S, D] for RoPE application if needed by apply_rotary_embedding
+             query = einops.rearrange(query, 'b s h d -> b h s d')
+             key = einops.rearrange(key, 'b s h d -> b h s d')
+
+             query = apply_rotary_embedding(query, freqs_cis)
+             key = apply_rotary_embedding(key, freqs_cis) # Apply to key as well
+
+             # Permute back to [B, S, H, D] for dot_product_attention
+             query = einops.rearrange(query, 'b h s d -> b s h d')
+             key = einops.rearrange(key, 'b h s d -> b s h d')
+
+         hidden_states = nn.dot_product_attention(
+             query, key, value, dtype=self.dtype, broadcast_dropout=False,
+             dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
+             deterministic=True
+         ) # Output shape [B, S, H, D]
+
+         proj = self.proj_attn(hidden_states) # Output shape [B, S, C]
+
+         if is_4d:
+             proj = proj.reshape(orig_x_shape) # Reshape back if input was 4D
+
+         return proj
+
+ # --- adaLN-Zero ---
+
+ class AdaLNZero(nn.Module):
+     features: int
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     norm_epsilon: float = 1e-5 # Standard LayerNorm epsilon
+
+     @nn.compact
+     def __call__(self, x, conditioning):
+         # Project conditioning signal to get scale and shift parameters
+         # Conditioning shape: [B, D_cond] -> [B, 1, ..., 1, 6 * features] for broadcasting
+         # Or [B, 1, 6*features] if x is [B, S, F]
+
+         # Ensure conditioning has seq dim if x does
+         if x.ndim == 3 and conditioning.ndim == 2: # x=[B,S,F], cond=[B,D_cond]
+             conditioning = jnp.expand_dims(conditioning, axis=1) # cond=[B,1,D_cond]
+
+         # Project conditioning to get 6 params per feature (scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn)
+         # Using nn.DenseGeneral for flexibility if needed, but nn.Dense is fine if cond is [B, D_cond] or [B, 1, D_cond]
+         ada_params = nn.Dense(
+             features=6 * self.features,
+             dtype=self.dtype,
+             precision=self.precision,
+             kernel_init=nn.initializers.zeros, # Initialize projection to zero (Zero init)
+             name="ada_proj"
+         )(conditioning)
+
+         # Split into scale, shift, gate for MLP and Attention
+         scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(ada_params, 6, axis=-1)
+
+         # Apply Layer Normalization
+         norm = nn.LayerNorm(epsilon=self.norm_epsilon, use_scale=False, use_bias=False, dtype=self.dtype)
+         # norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype) # Alternative: RMSNorm
+
+         norm_x = norm(x)
+
+         # Modulate for Attention path
+         x_attn = norm_x * (1 + scale_attn) + shift_attn
+
+         # Modulate for MLP path
+         x_mlp = norm_x * (1 + scale_mlp) + shift_mlp
+
+         # Return modulated outputs and gates
+         return x_attn, gate_attn, x_mlp, gate_mlp
+
+
+ # --- DiT Block ---
+
+ class DiTBlock(nn.Module):
+     features: int
+     num_heads: int
+     mlp_ratio: int = 4
+     dropout_rate: float = 0.0 # Typically dropout is not used in diffusion models
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     use_flash_attention: bool = False # Keep option, but RoPEAttention uses NormalAttention base
+     force_fp32_for_softmax: bool = True
+     norm_epsilon: float = 1e-5
+     rope_emb: RotaryEmbedding # Pass RoPE module
+
+     def setup(self):
+         hidden_features = int(self.features * self.mlp_ratio)
+         self.ada_ln_zero = AdaLNZero(self.features, dtype=self.dtype, precision=self.precision, norm_epsilon=self.norm_epsilon)
+
+         # Use RoPEAttention
+         self.attention = RoPEAttention(
+             query_dim=self.features,
+             heads=self.num_heads,
+             dim_head=self.features // self.num_heads,
+             dtype=self.dtype,
+             precision=self.precision,
+             use_bias=True, # Bias is common in DiT attention proj
+             force_fp32_for_softmax=self.force_fp32_for_softmax,
+             rope_emb=self.rope_emb # Pass RoPE module instance
+         )
+
+         # Standard MLP block
+         self.mlp = nn.Sequential([
+             nn.Dense(features=hidden_features, dtype=self.dtype, precision=self.precision),
+             nn.gelu,
+             nn.Dense(features=self.features, dtype=self.dtype, precision=self.precision)
+         ])
+
+     @nn.compact
+     def __call__(self, x, conditioning, freqs_cis):
+         # x shape: [B, S, F]
+         # conditioning shape: [B, D_cond]
+
+         residual = x
+
+         # Apply adaLN-Zero to get modulated inputs and gates
+         x_attn, gate_attn, x_mlp, gate_mlp = self.ada_ln_zero(x, conditioning)
+
+         # Attention block
+         attn_output = self.attention(x_attn, context=None, freqs_cis=freqs_cis) # Self-attention only
+         x = residual + gate_attn * attn_output
+
+         # MLP block
+         mlp_output = self.mlp(x_mlp)
+         x = x + gate_mlp * mlp_output
+
+         return x
+
+ # --- Patch Embedding (reuse or define if needed) ---
+ # Assuming PatchEmbedding exists in simple_vit.py and is suitable
+ from .simple_vit import PatchEmbedding, unpatchify
+
+ # --- Better UViT (DiT Style) ---
+
+ class BetterUViT(nn.Module):
+     output_channels: int = 3
+     patch_size: int = 16
+     emb_features: int = 768
+     num_layers: int = 12
+     num_heads: int = 12
+     mlp_ratio: int = 4
+     dropout_rate: float = 0.0 # Typically 0 for diffusion
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     use_flash_attention: bool = False # Passed down, but RoPEAttention uses NormalAttention
+     force_fp32_for_softmax: bool = True
+     norm_epsilon: float = 1e-5
+     learn_sigma: bool = False # Option to predict sigma like in DiT paper
+     use_hilbert: bool = False # Toggle Hilbert patch reorder
+
+     def setup(self):
+         self.patch_embed = PatchEmbedding(
+             patch_size=self.patch_size,
+             embedding_dim=self.emb_features,
+             dtype=self.dtype,
+             precision=self.precision
+         )
+
+         # Time embedding projection
+         self.time_embed = nn.Sequential([
+             FourierEmbedding(features=self.emb_features),
+             TimeProjection(features=self.emb_features * self.mlp_ratio), # Project to MLP dim
+             nn.Dense(features=self.emb_features, dtype=self.dtype, precision=self.precision) # Final projection
+         ])
+
+         # Text context projection (if used)
+         # Assuming textcontext is already projected to some dimension, project it to match emb_features
+         # This might need adjustment based on how text context is provided
+         self.text_proj = nn.Dense(features=self.emb_features, dtype=self.dtype, precision=self.precision, name="text_context_proj")
+
+         # Rotary Positional Embedding
+         # Max length needs to be estimated or set large enough.
+         # For images, seq len = (H/P) * (W/P). Example: 256/16 * 256/16 = 16*16 = 256
+         # Add 1 if a class token is used, or more for text tokens if concatenated.
+         # Let's assume max seq len accommodates patches + time + text tokens if needed, or just patches.
+         # If only patches use RoPE, max_len = max_image_tokens
+         # If time/text are concatenated *before* blocks, max_len needs to include them.
+         # DiT typically applies PE only to patch tokens. Let's follow that.
+         # max_len should be max number of patches.
+         # Example: max image size 512x512, patch 16 -> (512/16)^2 = 32^2 = 1024 patches
+         self.rope = RotaryEmbedding(dim=self.emb_features // self.num_heads, max_seq_len=4096, dtype=self.dtype) # Dim per head
+
+         # Transformer Blocks
+         self.blocks = [
+             DiTBlock(
+                 features=self.emb_features,
+                 num_heads=self.num_heads,
+                 mlp_ratio=self.mlp_ratio,
+                 dropout_rate=self.dropout_rate,
+                 dtype=self.dtype,
+                 precision=self.precision,
+                 use_flash_attention=self.use_flash_attention,
+                 force_fp32_for_softmax=self.force_fp32_for_softmax,
+                 norm_epsilon=self.norm_epsilon,
+                 rope_emb=self.rope, # Pass RoPE instance
+                 name=f"dit_block_{i}"
+             ) for i in range(self.num_layers)
+         ]
+
+         # Final Layer (Normalization + Linear Projection)
+         self.final_norm = nn.LayerNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+         # self.final_norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+
+         # Predict patch pixels + potentially sigma
+         output_dim = self.patch_size * self.patch_size * self.output_channels
+         if self.learn_sigma:
+             output_dim *= 2 # Predict both mean and variance (or log_variance)
+
+         self.final_proj = nn.Dense(
+             features=output_dim,
+             dtype=self.dtype,
+             precision=self.precision,
+             kernel_init=nn.initializers.zeros, # Initialize final layer to zero
+             name="final_proj"
+         )
+
+     @nn.compact
+     def __call__(self, x, temb, textcontext=None):
+         B, H, W, C = x.shape
+         assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+
+         # 1. Patch Embedding
+         patches = self.patch_embed(x) # Shape: [B, num_patches, emb_features]
+         num_patches = patches.shape[1]
+
+         # Optional Hilbert reorder
+         if self.use_hilbert:
+             idx = hilbert_indices(H // self.patch_size, W // self.patch_size)
+             inv_idx = inverse_permutation(idx)
+             patches = patches[:, idx, :]
+
+         # replace x with patches
+         x_seq = patches
+
+         # 2. Prepare Conditioning Signal (Time + Text Context)
+         t_emb = self.time_embed(temb) # Shape: [B, emb_features]
+
+         cond_emb = t_emb
+         if textcontext is not None:
+             text_emb = self.text_proj(textcontext) # Shape: [B, num_text_tokens, emb_features]
+             # Pool or select text embedding (e.g., mean pool or use CLS token)
+             # Assuming mean pooling for simplicity
+             text_emb_pooled = jnp.mean(text_emb, axis=1) # Shape: [B, emb_features]
+             cond_emb = cond_emb + text_emb_pooled # Combine time and text embeddings
+
+         # 3. Apply RoPE
+         # Get RoPE frequencies for the sequence length (number of patches)
+         freqs_cis = self.rope(seq_len=num_patches) # Shape [num_patches, D_head/2]
+
+         # 4. Apply Transformer Blocks with adaLN-Zero conditioning
+         for block in self.blocks:
+             x_seq = block(x_seq, conditioning=cond_emb, freqs_cis=freqs_cis)
+
+         # 5. Final Layer
+         x_out = self.final_norm(x_seq)
+         x_out = self.final_proj(x_out) # Shape: [B, num_patches, patch_pixels (*2 if learn_sigma)]
+
+         # Optional Hilbert inverse reorder
+         if self.use_hilbert:
+             x_out = x_out[:, inv_idx, :]
+
+         # 6. Unpatchify
+         if self.learn_sigma:
+             # Split into mean and variance predictions
+             x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+             x = unpatchify(x_mean, channels=self.output_channels)
+             # Return both mean and logvar if needed by the loss function
+             # For now, just returning the mean prediction like standard diffusion models
+             # logvar = unpatchify(x_logvar, channels=self.output_channels)
+             # return x, logvar
+             return x
+         else:
+             x = unpatchify(x_out, channels=self.output_channels) # Shape: [B, H, W, C]
+             return x
+
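Note: the RoPE helpers at the top of this new module are self-contained, so they can be exercised independently of the full model. Below is a small sketch that rebuilds the same frequency table RotaryEmbedding precomputes and applies it to a dummy query tensor; the shapes are illustrative only.

    import jax.numpy as jnp
    from flaxdiff.models.better_uvit import apply_rotary_embedding

    dim, seq_len = 8, 4
    inv_freq = 1.0 / (10000 ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    freqs = jnp.outer(jnp.arange(seq_len, dtype=jnp.float32), inv_freq)  # [S, D/2]
    freqs_cis = jnp.cos(freqs) + 1j * jnp.sin(freqs)  # same form RotaryEmbedding precomputes

    q = jnp.ones((1, 2, seq_len, dim))                # [B, H, S, D] dummy queries
    q_rot = apply_rotary_embedding(q, freqs_cis)
    print(q_rot.shape)                                # (1, 2, 4, 8)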
{flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/common.py
@@ -6,6 +6,8 @@ from flax.typing import Dtype, PrecisionLike
  from typing import Dict, Callable, Sequence, Any, Union
  import einops
  from functools import partial
+ import math
+ from einops import rearrange

  # Kernel initializer to use
  def kernel_init(scale=1.0, dtype=jnp.float32):
@@ -247,7 +249,7 @@ class Downsample(nn.Module):
          return out


- def l2norm(t, axis=1, eps=1e-12):
+ def l2norm(t, axis=1, eps=1e-6): # Increased epsilon from 1e-12
      denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
      out = t/denom
      return (out)
@@ -266,14 +268,15 @@ class ResidualBlock(nn.Module):
      dtype: Optional[Dtype] = None
      precision: PrecisionLike = None
      named_norms:bool=False
+     norm_epsilon: float = 1e-4 # Added epsilon parameter, increased default

      def setup(self):
          if self.norm_groups > 0:
-             norm = partial(nn.GroupNorm, self.norm_groups)
+             norm = partial(nn.GroupNorm, self.norm_groups, epsilon=self.norm_epsilon)
              self.norm1 = norm(name="GroupNorm_0") if self.named_norms else norm()
              self.norm2 = norm(name="GroupNorm_1") if self.named_norms else norm()
          else:
-             norm = partial(nn.RMSNorm, 1e-5)
+             norm = partial(nn.RMSNorm, epsilon=self.norm_epsilon)
              self.norm1 = norm()
              self.norm2 = norm()

@@ -333,4 +336,72 @@
          out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out

          return out
-
+
+ # Convert Hilbert index d to 2D coordinates (x, y) for an n x n grid
+ def _d2xy(n, d):
+     x = 0
+     y = 0
+     t = d
+     s = 1
+     while s < n:
+         rx = (t // 2) & 1
+         ry = (t ^ rx) & 1
+         if ry == 0:
+             if rx == 1:
+                 x = n - 1 - x
+                 y = n - 1 - y
+             x, y = y, x
+         x += s * rx
+         y += s * ry
+         t //= 4
+         s *= 2
+     return x, y
+
+ # Hilbert index mapping for a rectangular grid of patches H_P x W_P
+
+ def hilbert_indices(H_P, W_P):
+     size = max(H_P, W_P)
+     order = math.ceil(math.log2(size))
+     n = 1 << order
+     coords = []
+     for d in range(n * n):
+         x, y = _d2xy(n, d)
+         # x is column index, y is row index
+         if x < W_P and y < H_P:
+             coords.append((y, x)) # (row, col)
+         if len(coords) == H_P * W_P:
+             break
+     # Convert (row, col) to linear indices row-major
+     indices = [r * W_P + c for r, c in coords]
+     return jnp.array(indices, dtype=jnp.int32)
+
+ # Inverse permutation: given idx where idx[i] = new position of element i, return inv such that inv[idx[i]] = i
+
+ def inverse_permutation(idx):
+     inv = jnp.zeros_like(idx)
+     inv = inv.at[idx].set(jnp.arange(idx.shape[0], dtype=idx.dtype))
+     return inv
+
+ # Patchify using Hilbert ordering: extract patches and reorder sequence
+
+ def hilbert_patchify(x, patch_size):
+     B, H, W, C = x.shape
+     H_P = H // patch_size
+     W_P = W // patch_size
+     # Extract patches in row-major
+     patches = rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
+     idx = hilbert_indices(H_P, W_P)
+     return patches[:, idx, :]
+
+ # Unpatchify from Hilbert ordering: reorder sequence back and reconstruct image
+
+ def hilbert_unpatchify(patches, patch_size, H, W, C):
+     B, N, D = patches.shape
+     H_P = H // patch_size
+     W_P = W // patch_size
+     inv = inverse_permutation(hilbert_indices(H_P, W_P))
+     # Reorder back to row-major
+     linear = patches[:, inv, :]
+     # Reconstruct image
+     x = rearrange(linear, 'b (h w) (p1 p2 c) -> b (h p1) (w p2) c', h=H_P, w=W_P, p1=patch_size, p2=patch_size, c=C)
+     return x
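Note: the Hilbert helpers above reorder a row-major patch sequence with a gathered index array and undo that reorder later via `inverse_permutation`. A minimal sketch of the permutation/inverse relationship, using a small hand-written permutation in place of `hilbert_indices`:

    import jax.numpy as jnp
    from flaxdiff.models.common import inverse_permutation

    idx = jnp.array([2, 0, 3, 1], dtype=jnp.int32)  # toy permutation standing in for hilbert_indices
    inv = inverse_permutation(idx)

    seq = jnp.arange(4)
    reordered = seq[idx]        # gather into the new order
    restored = reordered[inv]   # gather back to the original order
    assert (restored == seq).all()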
{flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/models/simple_vit.py
@@ -10,6 +10,7 @@ from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLa
  import einops
  from flax.typing import Dtype, PrecisionLike
  from functools import partial
+ from .common import hilbert_indices, inverse_permutation

  def unpatchify(x, channels=3):
      patch_size = int((x.shape[2] // channels) ** 0.5)
@@ -55,8 +56,6 @@ class UViT(nn.Module):
      num_layers: int = 12
      num_heads: int = 12
      dropout_rate: float = 0.1
-     dtype: Any = jnp.float32
-     precision: Any = jax.lax.Precision.HIGH
      use_projection: bool = False
      use_flash_attention: bool = False
      use_self_and_cross: bool = False
@@ -65,16 +64,17 @@
      norm_groups:int=8
      dtype: Optional[Dtype] = None
      precision: PrecisionLike = None
-     # kernel_init: Callable = partial(kernel_init, scale=1.0)
      add_residualblock_output: bool = False
      norm_inputs: bool = False
      explicitly_add_residual: bool = True
+     norm_epsilon: float = 1e-4 # Added epsilon parameter, increased default
+     use_hilbert: bool = False # Toggle Hilbert patch reorder

      def setup(self):
          if self.norm_groups > 0:
-             self.norm = partial(nn.GroupNorm, self.norm_groups)
+             self.norm = partial(nn.GroupNorm, self.norm_groups, epsilon=self.norm_epsilon)
          else:
-             self.norm = partial(nn.RMSNorm, 1e-5)
+             self.norm = partial(nn.RMSNorm, epsilon=self.norm_epsilon)

      @nn.compact
      def __call__(self, x, temb, textcontext=None):
@@ -83,28 +83,32 @@
          temb = TimeProjection(features=self.emb_features)(temb)

          original_img = x
+         B, H, W, C = original_img.shape
+         H_P = H // self.patch_size
+         W_P = W // self.patch_size

          # Patch embedding
          x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
                             dtype=self.dtype, precision=self.precision)(x)
          num_patches = x.shape[1]
-
+
+         # Optional Hilbert reorder
+         if self.use_hilbert:
+             idx = hilbert_indices(H_P, W_P)
+             inv_idx = inverse_permutation(idx)
+             x = x[:, idx, :]
+
          context_emb = nn.DenseGeneral(features=self.emb_features,
                                        dtype=self.dtype, precision=self.precision)(textcontext)
          num_text_tokens = textcontext.shape[1]

-         # print(f'Shape of x after patch embedding: {x.shape}, numPatches: {num_patches}, temb: {temb.shape}, context_emb: {context_emb.shape}')
-
          # Add time embedding
          temb = jnp.expand_dims(temb, axis=1)
          x = jnp.concatenate([x, temb, context_emb], axis=1)
-         # print(f'Shape of x after time embedding: {x.shape}')
-
+
          # Add positional encoding
          x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features)(x)

-         # print(f'Shape of x after positional encoding: {x.shape}')
-
          skips = []
          # In blocks
          for i in range(self.num_layers // 2):
@@ -114,6 +118,7 @@
                  only_pure_attention=False,
                  norm_inputs=self.norm_inputs,
                  explicitly_add_residual=self.explicitly_add_residual,
+                 norm_epsilon=self.norm_epsilon, # Pass epsilon
              )(x)
              skips.append(x)

@@ -124,9 +129,10 @@
              only_pure_attention=False,
              norm_inputs=self.norm_inputs,
              explicitly_add_residual=self.explicitly_add_residual,
+             norm_epsilon=self.norm_epsilon, # Pass epsilon
          )(x)

-         # # Out blocks
+         # Out blocks
          for i in range(self.num_layers // 2):
              x = jnp.concatenate([x, skips.pop()], axis=-1)
              x = nn.DenseGeneral(features=self.emb_features,
@@ -137,14 +143,18 @@
                  only_pure_attention=False,
                  norm_inputs=self.norm_inputs,
                  explicitly_add_residual=self.explicitly_add_residual,
+                 norm_epsilon=self.norm_epsilon, # Pass epsilon
              )(x)

-         # print(f'Shape of x after transformer blocks: {x.shape}')
-         x = self.norm()(x)
+         x = self.norm()(x) # Uses norm_epsilon defined in setup

          patch_dim = self.patch_size ** 2 * self.output_channels
          x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision)(x)
-         x = x[:, 1 + num_text_tokens:, :]
+         # If Hilbert, restore original patch order
+         if self.use_hilbert:
+             x = x[:, inv_idx, :]
+         # Extract only the image patch tokens (first num_patches tokens)
+         x = x[:, :num_patches, :]
          x = unpatchify(x, channels=self.output_channels)

          if self.add_residualblock_output:
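Note: both new UViT knobs are plain constructor fields, so enabling them is just a matter of passing them at model creation. A minimal construction-only sketch; all other fields keep their defaults, and the Hilbert reorder is undone before unpatchify as shown in the hunk above.

    from flaxdiff.models.simple_vit import UViT

    model = UViT(
        use_hilbert=True,   # reorder patch tokens along the Hilbert curve
        norm_epsilon=1e-4,  # shared epsilon for the GroupNorm/RMSNorm layers
    )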
{flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff/trainer/general_diffusion_trainer.py
@@ -578,6 +578,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
          if not hasattr(self, "wandb_sweep"):
              raise ValueError("Wandb sweep is not initialized. Cannot get best runs.")

+         print(f"Getting best runs from sweep {self.wandb_sweep.id}...")
          # Get the sweep runs
          runs = sorted(self.wandb_sweep.runs, key=lambda x: x.summary.get(metric, float('inf')))
          best_runs = runs[:top_k]
@@ -588,18 +589,46 @@
              print(f"\t\tRun ID: {run.id}, Metric: {run.summary.get(metric, float('inf'))}")
          return best_runs, (min(lower_bound, upper_bound), max(lower_bound, upper_bound))

-     def __compare_run_against_best__(self, top_k=2, metric="train/best_loss"):
+     def __get_best_general_runs__(
+         self,
+         metric: str = "train/best_loss",
+         top_k: int = 5,
+     ):
          """
-         Compare the current run against the best runs from the sweep.
+         Get the best runs from wandb.
+         Args:
+             metric: Metric to sort by.
+             top_k: Number of top runs to return.
+         """
+         if self.wandb is None:
+             raise ValueError("Wandb is not initialized. Cannot get best runs.")
+
+         # Get the sweep runs
+         runs = sorted(self.wandb.runs, key=lambda x: x.summary.get(metric, float('inf')))
+         best_runs = runs[:top_k]
+         lower_bound = best_runs[-1].summary.get(metric, float('inf'))
+         upper_bound = best_runs[0].summary.get(metric, float('inf'))
+         print(f"Best runs from wandb {self.wandb.id}:")
+         for run in best_runs:
+             print(f"\t\tRun ID: {run.id}, Metric: {run.summary.get(metric, float('inf'))}")
+         return best_runs, (min(lower_bound, upper_bound), max(lower_bound, upper_bound))
+
+     def __compare_run_against_best__(self, top_k=2, metric="train/best_loss", from_sweeps=False):
+         """
+         Compare the current run against the best runs from wandb.
          Args:
              top_k: Number of top runs to consider.
              metric: Metric to compare against.
+             from_sweeps: Whether to consider runs from sweeps.
          Returns:
              is_good: Whether the current run is among the best.
              is_best: Whether the current run is the best.
          """
          # Get best runs
-         best_runs, bounds = self.__get_best_sweep_runs__(metric=metric, top_k=top_k)
+         if from_sweeps:
+             best_runs, bounds = self.__get_best_sweep_runs__(metric=metric, top_k=top_k)
+         else:
+             best_runs, bounds = self.__get_best_general_runs__(metric=metric, top_k=top_k)

          # Determine if lower or higher values are better (for loss, lower is better)
          is_lower_better = "loss" in metric.lower()
@@ -621,10 +650,10 @@
      def save(self, epoch=0, step=0, state=None, rngstate=None):
          super().save(epoch=epoch, step=step, state=state, rngstate=rngstate)

-         if self.wandb is not None and hasattr(self, "wandb_sweep"):
+         if self.wandb is not None:
              checkpoint = get_latest_checkpoint(self.checkpoint_path())
              try:
-                 is_good, is_best = self.__compare_run_against_best__(top_k=5, metric="train/best_loss")
+                 is_good, is_best = self.__compare_run_against_best__(top_k=5, metric="train/best_loss", from_sweeps=hasattr(self, "wandb_sweep"))
                  if is_good:
                      # Push to registry with appropriate aliases
                      aliases = []
{flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: flaxdiff
- Version: 0.2.5
+ Version: 0.2.6.1
  Summary: A versatile and easy to understand Diffusion library
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
  License-Expression: MIT
{flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/flaxdiff.egg-info/SOURCES.txt
@@ -34,6 +34,7 @@ flaxdiff/metrics/ssim.py
  flaxdiff/metrics/utils.py
  flaxdiff/models/__init__.py
  flaxdiff/models/attention.py
+ flaxdiff/models/better_uvit.py
  flaxdiff/models/common.py
  flaxdiff/models/favor_fastattn.py
  flaxdiff/models/general.py
{flaxdiff-0.2.5 → flaxdiff-0.2.6.1}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

  [project]
  name = "flaxdiff"
- version = "0.2.5"
+ version = "0.2.6.1"
  description = "A versatile and easy to understand Diffusion library"
  readme = "README.md"
  authors = [