flaxdiff 0.2.4__tar.gz → 0.2.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/PKG-INFO +1 -1
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/data/sources/images.py +7 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/inference/pipeline.py +11 -4
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/inference/utils.py +3 -2
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/metrics/images.py +1 -1
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/attention.py +7 -4
- flaxdiff-0.2.6/flaxdiff/models/better_uvit.py +380 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/common.py +75 -4
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/simple_vit.py +26 -16
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/trainer/general_diffusion_trainer.py +34 -5
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff.egg-info/SOURCES.txt +1 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/pyproject.toml +1 -1
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/README.md +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/data/benchmark_decord.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/data/dataloaders.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/data/dataset_map.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/data/online_loader.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/data/sources/audio_utils.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/data/sources/av_example.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/data/sources/av_utils.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/data/sources/base.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/data/sources/utils.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/data/sources/videos.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/data/sources/voxceleb2.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/inference/__init__.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/inputs/__init__.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/inputs/encoders.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/metrics/__init__.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/metrics/common.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/metrics/inception.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/metrics/psnr.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/metrics/ssim.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/metrics/utils.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/general.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/unet_3d.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/unet_3d_blocks.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/trainer/diffusion_trainer.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.2.4 → flaxdiff-0.2.6}/setup.cfg +0 -0
{flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/data/sources/images.py

@@ -266,6 +266,12 @@ class ImageGCSAugmenter(DataAugmenter):

        print(f"Using method: {method}")

+        from torchvision.transforms import v2
+        augments = v2.Compose([
+            v2.RandomHorizontalFlip(p=0.5),
+            v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2)
+        ])
+
        class GCSTransform(pygrain.MapTransform):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)

@@ -277,6 +283,7 @@ class ImageGCSAugmenter(DataAugmenter):
                image = np.asarray(bytearray(element['jpg']), dtype="uint8")
                image = cv2.imdecode(image, cv2.IMREAD_UNCHANGED)
                image = self.image_augmenter(image)
+                image = augments(image)
                caption = labelizer(element).decode('utf-8')
                results = self.auto_tokenize(caption)
                return {
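Note: the new augmentation chain applies a random horizontal flip and mild color jitter to each decoded image before tokenization. A minimal standalone sketch of the same chain; the HWC-to-CHW tensor conversion is an assumption, since torchvision v2 transforms operate on tensors or PIL images rather than raw numpy arrays:

import numpy as np
import torch
from torchvision.transforms import v2

# Same augmentation chain as added in ImageGCSAugmenter above
augments = v2.Compose([
    v2.RandomHorizontalFlip(p=0.5),
    v2.ColorJitter(brightness=0.2, contrast=0.05, saturation=0.2),
])

image = np.zeros((256, 256, 3), dtype=np.uint8)        # stand-in for a decoded JPEG (HWC)
tensor = torch.from_numpy(image).permute(2, 0, 1)      # HWC -> CHW, as torchvision expects
augmented = augments(tensor).permute(1, 2, 0).numpy()  # back to HWC numpy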
{flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/inference/pipeline.py

@@ -53,6 +53,7 @@ class DiffusionInferencePipeline(InferencePipeline):
    input_config: DiffusionInputConfig = None
    samplers: Dict[Type[DiffusionSampler], Dict[float, DiffusionSampler]] = field(default_factory=dict)
    config: Dict[str, Any] = field(default_factory=dict)
+    wandb_run = None

    @classmethod
    def from_wandb_run(

@@ -75,7 +76,7 @@ class DiffusionInferencePipeline(InferencePipeline):
        Returns:
            DiffusionInferencePipeline instance
        """
-        states, config = load_from_wandb_run(
+        states, config, run = load_from_wandb_run(
            wandb_run,
            project=project,
            entity=entity,

@@ -93,6 +94,7 @@ class DiffusionInferencePipeline(InferencePipeline):
            state=state,
            best_state=best_state,
            rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
+            run=run,
        )
        return pipeline

@@ -117,7 +119,7 @@ class DiffusionInferencePipeline(InferencePipeline):
        Returns:
            DiffusionInferencePipeline instance
        """
-        states, config = load_from_wandb_registry(
+        states, config, run = load_from_wandb_registry(
            modelname=modelname,
            project=project,
            entity=entity,

@@ -137,6 +139,7 @@ class DiffusionInferencePipeline(InferencePipeline):
            state=state,
            best_state=best_state,
            rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
+            run=run,
        )
        return pipeline

@@ -147,6 +150,7 @@ class DiffusionInferencePipeline(InferencePipeline):
        state: Dict[str, Any],
        best_state: Optional[Dict[str, Any]] = None,
        rngstate: Optional[RandomMarkovState] = None,
+        run=None,
    ):
        if rngstate is None:
            rngstate = RandomMarkovState(jax.random.PRNGKey(42))

@@ -161,6 +165,7 @@ class DiffusionInferencePipeline(InferencePipeline):
            autoencoder=config['autoencoder'],
            input_config=config['input_config'],
            config=config,
+            wandb_run=run,
        )

    def get_sampler(

@@ -208,7 +213,8 @@ class DiffusionInferencePipeline(InferencePipeline):
        self,
        num_samples: int,
        resolution: int,
-        conditioning_data:
+        conditioning_data: List[Union[Tuple, Dict]] = None,
+        conditioning_data_tokens: Tuple = None,
        sequence_length: Optional[int] = None,
        diffusion_steps: int = 50,
        guidance_scale: float = 1.0,

@@ -256,5 +262,6 @@ class DiffusionInferencePipeline(InferencePipeline):
            steps_override=steps_override,
            priors=priors,
            rngstate=rngstate,
-            conditioning=conditioning_data
+            conditioning=conditioning_data,
+            model_conditioning_inputs=conditioning_data_tokens,
        )
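Note: the sampling entry point can now accept pre-tokenized conditioning via `conditioning_data_tokens`, bypassing internal tokenization of `conditioning_data`. A hypothetical call sketch; the method name `generate_samples` and the run path are assumptions (the hunk above does not show the method name), only the parameters shown in this diff are taken as given:

# Hypothetical usage sketch, assuming a wandb run path and method name.
pipeline = DiffusionInferencePipeline.from_wandb_run("entity/project/run_id")
samples = pipeline.generate_samples(
    num_samples=4,
    resolution=256,
    conditioning_data=["a photo of a cat"] * 4,  # raw conditioning, tokenized internally
    diffusion_steps=50,
    guidance_scale=3.0,
)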
{flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/inference/utils.py

@@ -292,7 +292,7 @@ def load_from_wandb_run(
        config = run.config
    except Exception as e:
        print(f"Warning: Failed to load model from wandb: {e}")
-    return states, config
+    return states, config, run

def load_from_wandb_registry(
    modelname: str,

@@ -307,6 +307,7 @@ def load_from_wandb_registry(
    # Get the model version from wandb
    states = None
    config = None
+    run = None
    try:
        artifact = wandb.Api().artifact(f"{registry}/{modelname}:{version}")
        ckpt_dir = artifact.download()

@@ -317,4 +318,4 @@ def load_from_wandb_registry(
        config = run.config
    except Exception as e:
        print(f"Warning: Failed to load model from wandb: {e}")
-    return states, config
+    return states, config, run
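Note: both loaders change their return signature from a 2-tuple to a 3-tuple, so every call site must now unpack the run handle (`run` is pre-initialized to `None` in `load_from_wandb_registry` and stays `None` when the wandb lookup fails):

# Callers must now unpack three values instead of two.
states, config, run = load_from_wandb_run(wandb_run, project=project, entity=entity)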
{flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/metrics/images.py

@@ -7,7 +7,7 @@ def get_clip_metric(
):
    from transformers import AutoProcessor, FlaxCLIPModel
    model = FlaxCLIPModel.from_pretrained(modelname, dtype=jnp.float16)
-    processor = AutoProcessor.from_pretrained(modelname, use_fast=
+    processor = AutoProcessor.from_pretrained(modelname, use_fast=False, dtype=jnp.float16)

    @jax.jit
    def calc(pixel_values, input_ids, attention_mask):
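Note: an independent sketch of the kind of CLIP image-text scoring the jitted `calc` presumably wraps (an assumption; only the model/processor setup above is from the diff). It uses only public transformers APIs, and the checkpoint name and dummy image are placeholders:

from transformers import AutoProcessor, FlaxCLIPModel
import jax.numpy as jnp
import numpy as np

model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32", dtype=jnp.float16)
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=False)

image = np.zeros((224, 224, 3), dtype=np.uint8)  # placeholder image
inputs = processor(text=["a cat"], images=image, return_tensors="np", padding=True)
outputs = model(**inputs)
score = outputs.logits_per_image  # image-text similarity logits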
{flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/attention.py

@@ -247,6 +247,7 @@ class BasicTransformerBlock(nn.Module):
    use_cross_only:bool = False
    only_pure_attention:bool = False
    force_fp32_for_softmax: bool = True
+    norm_epsilon: float = 1e-4

    def setup(self):
        if self.use_flash_attention:

@@ -278,9 +279,9 @@ class BasicTransformerBlock(nn.Module):
        )

        self.ff = FlaxFeedForward(dim=self.query_dim)
-        self.norm1 = nn.RMSNorm(epsilon=
-        self.norm2 = nn.RMSNorm(epsilon=
-        self.norm3 = nn.RMSNorm(epsilon=
+        self.norm1 = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype)
+        self.norm2 = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype)
+        self.norm3 = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype)

    @nn.compact
    def __call__(self, hidden_states, context=None):

@@ -312,13 +313,14 @@ class TransformerBlock(nn.Module):
    # kernel_init: Callable = kernel_init(1.0)
    norm_inputs: bool = True
    explicitly_add_residual: bool = True
+    norm_epsilon: float = 1e-4

    @nn.compact
    def __call__(self, x, context=None):
        inner_dim = self.heads * self.dim_head
        C = x.shape[-1]
        if self.norm_inputs:
-            x = nn.RMSNorm(epsilon=
+            x = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype)(x)
        if self.use_projection == True:
            if self.use_linear_attention:
                projected_x = nn.Dense(features=inner_dim,

@@ -350,6 +352,7 @@ class TransformerBlock(nn.Module):
                use_cross_only=(not self.use_self_and_cross),
                only_pure_attention=self.only_pure_attention,
                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                norm_epsilon=self.norm_epsilon
                # kernel_init=self.kernel_init
            )(projected_x, context)

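Note: these changes thread a configurable `norm_epsilon` (default 1e-4) through every RMSNorm. A minimal sketch of the normalization nn.RMSNorm applies (learned scale omitted), showing where the epsilon sits; a larger epsilon keeps the rsqrt stable when activations are tiny, which matters under fp16/bf16:

import jax
import jax.numpy as jnp

def rms_norm(x, eps=1e-4):
    # x / sqrt(mean(x^2) + eps) along the feature axis.
    return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)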
flaxdiff-0.2.6/flaxdiff/models/better_uvit.py (new file)

@@ -0,0 +1,380 @@
+# flaxdiff/models/better_uvit.py
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from typing import Callable, Any, Optional, Tuple, Sequence, Union
+import einops
+from functools import partial
+
+# Re-use existing components if they are suitable
+from .common import kernel_init, FourierEmbedding, TimeProjection, hilbert_indices, inverse_permutation
+from .attention import NormalAttention # Using NormalAttention for RoPE integration
+from flax.typing import Dtype, PrecisionLike
+
+# --- Rotary Positional Embedding (RoPE) ---
+# Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py
+
+def _rotate_half(x: jax.Array) -> jax.Array:
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return jnp.concatenate((-x2, x1), axis=-1)
+
+def apply_rotary_embedding(
+    x: jax.Array, freqs_cis: jax.Array
+) -> jax.Array:
+    """Applies rotary embedding to the input tensor using rotate_half method."""
+    # x shape: [..., Sequence, Dimension] e.g. [B, H, S, D] or [B, S, D]
+    # freqs_cis shape: complex [Sequence, Dimension / 2]
+
+    # Extract cos and sin from the complex freqs_cis
+    cos_freqs = jnp.real(freqs_cis) # Shape [S, D/2]
+    sin_freqs = jnp.imag(freqs_cis) # Shape [S, D/2]
+
+    # Expand dims for broadcasting: [1, 1, S, D/2] or [1, S, D/2]
+    if x.ndim == 4: # [B, H, S, D]
+        cos_freqs = jnp.expand_dims(cos_freqs, axis=(0, 1))
+        sin_freqs = jnp.expand_dims(sin_freqs, axis=(0, 1))
+    elif x.ndim == 3: # [B, S, D]
+        cos_freqs = jnp.expand_dims(cos_freqs, axis=0)
+        sin_freqs = jnp.expand_dims(sin_freqs, axis=0)
+
+    # Duplicate cos and sin for the full dimension D
+    # Shape becomes [..., S, D]
+    cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
+    sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
+
+    # Apply rotation: x * cos + rotate_half(x) * sin
+    x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
+    return x_rotated.astype(x.dtype)
+
+
+class RotaryEmbedding(nn.Module):
+    dim: int # Dimension of the head
+    max_seq_len: int = 2048
+    base: int = 10000
+    dtype: Dtype = jnp.float32
+
+    def setup(self):
+        inv_freq = 1.0 / (
+            self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim)
+        )
+        t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
+        freqs = jnp.outer(t, inv_freq) # Shape: [max_seq_len, dim / 2]
+
+        # Precompute the complex form: cos(theta) + i * sin(theta)
+        self.freqs_cis_complex = jnp.cos(freqs) + 1j * jnp.sin(freqs)
+        # Shape: [max_seq_len, dim / 2]
+
+    def __call__(self, seq_len: int):
+        if seq_len > self.max_seq_len:
+            raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
+        # Return complex shape [seq_len, dim / 2]
+        return self.freqs_cis_complex[:seq_len, :]
+
+# --- Attention with RoPE ---
+
+class RoPEAttention(NormalAttention):
+    rope_emb: RotaryEmbedding
+
+    @nn.compact
+    def __call__(self, x, context=None, freqs_cis=None):
+        # x has shape [B, H, W, C] or [B, S, C]
+        orig_x_shape = x.shape
+        is_4d = len(orig_x_shape) == 4
+        if is_4d:
+            B, H, W, C = x.shape
+            seq_len = H * W
+            x = x.reshape((B, seq_len, C))
+        else:
+            B, seq_len, C = x.shape
+
+        context = x if context is None else context
+        if len(context.shape) == 4:
+            _B, _H, _W, _C = context.shape
+            context_seq_len = _H * _W
+            context = context.reshape((B, context_seq_len, _C))
+        else:
+            _B, context_seq_len, _C = context.shape
+
+        query = self.query(x) # [B, S, H, D]
+        key = self.key(context) # [B, S_ctx, H, D]
+        value = self.value(context) # [B, S_ctx, H, D]
+
+        # Apply RoPE to query and key
+        if freqs_cis is not None:
+            # Permute to [B, H, S, D] for RoPE application if needed by apply_rotary_embedding
+            query = einops.rearrange(query, 'b s h d -> b h s d')
+            key = einops.rearrange(key, 'b s h d -> b h s d')
+
+            query = apply_rotary_embedding(query, freqs_cis)
+            key = apply_rotary_embedding(key, freqs_cis) # Apply to key as well
+
+            # Permute back to [B, S, H, D] for dot_product_attention
+            query = einops.rearrange(query, 'b h s d -> b s h d')
+            key = einops.rearrange(key, 'b h s d -> b s h d')
+
+        hidden_states = nn.dot_product_attention(
+            query, key, value, dtype=self.dtype, broadcast_dropout=False,
+            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
+            deterministic=True
+        ) # Output shape [B, S, H, D]
+
+        proj = self.proj_attn(hidden_states) # Output shape [B, S, C]
+
+        if is_4d:
+            proj = proj.reshape(orig_x_shape) # Reshape back if input was 4D
+
+        return proj
+
+# --- adaLN-Zero ---
+
+class AdaLNZero(nn.Module):
+    features: int
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+    norm_epsilon: float = 1e-5 # Standard LayerNorm epsilon
+
+    @nn.compact
+    def __call__(self, x, conditioning):
+        # Project conditioning signal to get scale and shift parameters
+        # Conditioning shape: [B, D_cond] -> [B, 1, ..., 1, 6 * features] for broadcasting
+        # Or [B, 1, 6*features] if x is [B, S, F]
+
+        # Ensure conditioning has seq dim if x does
+        if x.ndim == 3 and conditioning.ndim == 2: # x=[B,S,F], cond=[B,D_cond]
+            conditioning = jnp.expand_dims(conditioning, axis=1) # cond=[B,1,D_cond]
+
+        # Project conditioning to get 6 params per feature (scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn)
+        # Using nn.DenseGeneral for flexibility if needed, but nn.Dense is fine if cond is [B, D_cond] or [B, 1, D_cond]
+        ada_params = nn.Dense(
+            features=6 * self.features,
+            dtype=self.dtype,
+            precision=self.precision,
+            kernel_init=nn.initializers.zeros, # Initialize projection to zero (Zero init)
+            name="ada_proj"
+        )(conditioning)
+
+        # Split into scale, shift, gate for MLP and Attention
+        scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(ada_params, 6, axis=-1)
+
+        # Apply Layer Normalization
+        norm = nn.LayerNorm(epsilon=self.norm_epsilon, use_scale=False, use_bias=False, dtype=self.dtype)
+        # norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype) # Alternative: RMSNorm
+
+        norm_x = norm(x)
+
+        # Modulate for Attention path
+        x_attn = norm_x * (1 + scale_attn) + shift_attn
+
+        # Modulate for MLP path
+        x_mlp = norm_x * (1 + scale_mlp) + shift_mlp
+
+        # Return modulated outputs and gates
+        return x_attn, gate_attn, x_mlp, gate_mlp
+
+
+# --- DiT Block ---
+
+class DiTBlock(nn.Module):
+    features: int
+    num_heads: int
+    mlp_ratio: int = 4
+    dropout_rate: float = 0.0 # Typically dropout is not used in diffusion models
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+    use_flash_attention: bool = False # Keep option, but RoPEAttention uses NormalAttention base
+    force_fp32_for_softmax: bool = True
+    norm_epsilon: float = 1e-5
+    rope_emb: RotaryEmbedding # Pass RoPE module
+
+    def setup(self):
+        hidden_features = int(self.features * self.mlp_ratio)
+        self.ada_ln_zero = AdaLNZero(self.features, dtype=self.dtype, precision=self.precision, norm_epsilon=self.norm_epsilon)
+
+        # Use RoPEAttention
+        self.attention = RoPEAttention(
+            query_dim=self.features,
+            heads=self.num_heads,
+            dim_head=self.features // self.num_heads,
+            dtype=self.dtype,
+            precision=self.precision,
+            use_bias=True, # Bias is common in DiT attention proj
+            force_fp32_for_softmax=self.force_fp32_for_softmax,
+            rope_emb=self.rope_emb # Pass RoPE module instance
+        )
+
+        # Standard MLP block
+        self.mlp = nn.Sequential([
+            nn.Dense(features=hidden_features, dtype=self.dtype, precision=self.precision),
+            nn.gelu,
+            nn.Dense(features=self.features, dtype=self.dtype, precision=self.precision)
+        ])
+
+    @nn.compact
+    def __call__(self, x, conditioning, freqs_cis):
+        # x shape: [B, S, F]
+        # conditioning shape: [B, D_cond]
+
+        residual = x
+
+        # Apply adaLN-Zero to get modulated inputs and gates
+        x_attn, gate_attn, x_mlp, gate_mlp = self.ada_ln_zero(x, conditioning)
+
+        # Attention block
+        attn_output = self.attention(x_attn, context=None, freqs_cis=freqs_cis) # Self-attention only
+        x = residual + gate_attn * attn_output
+
+        # MLP block
+        mlp_output = self.mlp(x_mlp)
+        x = x + gate_mlp * mlp_output
+
+        return x
+
+# --- Patch Embedding (reuse or define if needed) ---
+# Assuming PatchEmbedding exists in simple_vit.py and is suitable
+from .simple_vit import PatchEmbedding, unpatchify
+
+# --- Better UViT (DiT Style) ---
+
+class BetterUViT(nn.Module):
+    output_channels: int = 3
+    patch_size: int = 16
+    emb_features: int = 768
+    num_layers: int = 12
+    num_heads: int = 12
+    mlp_ratio: int = 4
+    dropout_rate: float = 0.0 # Typically 0 for diffusion
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+    use_flash_attention: bool = False # Passed down, but RoPEAttention uses NormalAttention
+    force_fp32_for_softmax: bool = True
+    norm_epsilon: float = 1e-5
+    learn_sigma: bool = False # Option to predict sigma like in DiT paper
+    use_hilbert: bool = False # Toggle Hilbert patch reorder
+
+    def setup(self):
+        self.patch_embed = PatchEmbedding(
+            patch_size=self.patch_size,
+            embedding_dim=self.emb_features,
+            dtype=self.dtype,
+            precision=self.precision
+        )
+
+        # Time embedding projection
+        self.time_embed = nn.Sequential([
+            FourierEmbedding(features=self.emb_features),
+            TimeProjection(features=self.emb_features * self.mlp_ratio), # Project to MLP dim
+            nn.Dense(features=self.emb_features, dtype=self.dtype, precision=self.precision) # Final projection
+        ])
+
+        # Text context projection (if used)
+        # Assuming textcontext is already projected to some dimension, project it to match emb_features
+        # This might need adjustment based on how text context is provided
+        self.text_proj = nn.Dense(features=self.emb_features, dtype=self.dtype, precision=self.precision, name="text_context_proj")
+
+        # Rotary Positional Embedding
+        # Max length needs to be estimated or set large enough.
+        # For images, seq len = (H/P) * (W/P). Example: 256/16 * 256/16 = 16*16 = 256
+        # Add 1 if a class token is used, or more for text tokens if concatenated.
+        # Let's assume max seq len accommodates patches + time + text tokens if needed, or just patches.
+        # If only patches use RoPE, max_len = max_image_tokens
+        # If time/text are concatenated *before* blocks, max_len needs to include them.
+        # DiT typically applies PE only to patch tokens. Let's follow that.
+        # max_len should be max number of patches.
+        # Example: max image size 512x512, patch 16 -> (512/16)^2 = 32^2 = 1024 patches
+        self.rope = RotaryEmbedding(dim=self.emb_features // self.num_heads, max_seq_len=4096, dtype=self.dtype) # Dim per head
+
+        # Transformer Blocks
+        self.blocks = [
+            DiTBlock(
+                features=self.emb_features,
+                num_heads=self.num_heads,
+                mlp_ratio=self.mlp_ratio,
+                dropout_rate=self.dropout_rate,
+                dtype=self.dtype,
+                precision=self.precision,
+                use_flash_attention=self.use_flash_attention,
+                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                norm_epsilon=self.norm_epsilon,
+                rope_emb=self.rope, # Pass RoPE instance
+                name=f"dit_block_{i}"
+            ) for i in range(self.num_layers)
+        ]
+
+        # Final Layer (Normalization + Linear Projection)
+        self.final_norm = nn.LayerNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+        # self.final_norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+
+        # Predict patch pixels + potentially sigma
+        output_dim = self.patch_size * self.patch_size * self.output_channels
+        if self.learn_sigma:
+            output_dim *= 2 # Predict both mean and variance (or log_variance)
+
+        self.final_proj = nn.Dense(
+            features=output_dim,
+            dtype=self.dtype,
+            precision=self.precision,
+            kernel_init=nn.initializers.zeros, # Initialize final layer to zero
+            name="final_proj"
+        )
+
+    @nn.compact
+    def __call__(self, x, temb, textcontext=None):
+        B, H, W, C = x.shape
+        assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+
+        # 1. Patch Embedding
+        patches = self.patch_embed(x) # Shape: [B, num_patches, emb_features]
+        num_patches = patches.shape[1]
+
+        # Optional Hilbert reorder
+        if self.use_hilbert:
+            idx = hilbert_indices(H // self.patch_size, W // self.patch_size)
+            inv_idx = inverse_permutation(idx)
+            patches = patches[:, idx, :]
+
+        # replace x with patches
+        x_seq = patches
+
+        # 2. Prepare Conditioning Signal (Time + Text Context)
+        t_emb = self.time_embed(temb) # Shape: [B, emb_features]
+
+        cond_emb = t_emb
+        if textcontext is not None:
+            text_emb = self.text_proj(textcontext) # Shape: [B, num_text_tokens, emb_features]
+            # Pool or select text embedding (e.g., mean pool or use CLS token)
+            # Assuming mean pooling for simplicity
+            text_emb_pooled = jnp.mean(text_emb, axis=1) # Shape: [B, emb_features]
+            cond_emb = cond_emb + text_emb_pooled # Combine time and text embeddings
+
+        # 3. Apply RoPE
+        # Get RoPE frequencies for the sequence length (number of patches)
+        freqs_cis = self.rope(seq_len=num_patches) # Shape [num_patches, D_head/2]
+
+        # 4. Apply Transformer Blocks with adaLN-Zero conditioning
+        for block in self.blocks:
+            x_seq = block(x_seq, conditioning=cond_emb, freqs_cis=freqs_cis)
+
+        # 5. Final Layer
+        x_out = self.final_norm(x_seq)
+        x_out = self.final_proj(x_out) # Shape: [B, num_patches, patch_pixels (*2 if learn_sigma)]
+
+        # Optional Hilbert inverse reorder
+        if self.use_hilbert:
+            x_out = x_out[:, inv_idx, :]
+
+        # 6. Unpatchify
+        if self.learn_sigma:
+            # Split into mean and variance predictions
+            x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+            x = unpatchify(x_mean, channels=self.output_channels)
+            # Return both mean and logvar if needed by the loss function
+            # For now, just returning the mean prediction like standard diffusion models
+            # logvar = unpatchify(x_logvar, channels=self.output_channels)
+            # return x, logvar
+            return x
+        else:
+            x = unpatchify(x_out, channels=self.output_channels) # Shape: [B, H, W, C]
+            return x
+
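Note: a minimal smoke-test sketch for the new model, not from the package itself; the input shapes (in particular `temb` as one scalar timestep per sample and a 77-token text context) are assumptions about the surrounding training code:

import jax
import jax.numpy as jnp
from flaxdiff.models.better_uvit import BetterUViT

model = BetterUViT(patch_size=16, emb_features=768, num_layers=12, num_heads=12)
x = jnp.zeros((1, 256, 256, 3))        # [B, H, W, C]
temb = jnp.zeros((1,))                 # diffusion timestep per sample (assumed shape)
textcontext = jnp.zeros((1, 77, 768))  # mean-pooled inside the model
params = model.init(jax.random.PRNGKey(0), x, temb, textcontext)
out = model.apply(params, x, temb, textcontext)  # expected [1, 256, 256, 3]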
{flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/common.py

@@ -6,6 +6,8 @@ from flax.typing import Dtype, PrecisionLike
from typing import Dict, Callable, Sequence, Any, Union
import einops
from functools import partial
+import math
+from einops import rearrange

# Kernel initializer to use
def kernel_init(scale=1.0, dtype=jnp.float32):

@@ -247,7 +249,7 @@ class Downsample(nn.Module):
        return out


-def l2norm(t, axis=1, eps=1e-12):
+def l2norm(t, axis=1, eps=1e-6): # Increased epsilon from 1e-12
    denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
    out = t/denom
    return (out)

@@ -266,14 +268,15 @@ class ResidualBlock(nn.Module):
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None
    named_norms:bool=False
+    norm_epsilon: float = 1e-4 # Added epsilon parameter, increased default

    def setup(self):
        if self.norm_groups > 0:
-            norm = partial(nn.GroupNorm, self.norm_groups)
+            norm = partial(nn.GroupNorm, self.norm_groups, epsilon=self.norm_epsilon)
            self.norm1 = norm(name="GroupNorm_0") if self.named_norms else norm()
            self.norm2 = norm(name="GroupNorm_1") if self.named_norms else norm()
        else:
-            norm = partial(nn.RMSNorm,
+            norm = partial(nn.RMSNorm, epsilon=self.norm_epsilon)
            self.norm1 = norm()
            self.norm2 = norm()

@@ -333,4 +336,72 @@ class ResidualBlock(nn.Module):
        out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out

        return out
-
+
+# Convert Hilbert index d to 2D coordinates (x, y) for an n x n grid
+def _d2xy(n, d):
+    x = 0
+    y = 0
+    t = d
+    s = 1
+    while s < n:
+        rx = (t // 2) & 1
+        ry = (t ^ rx) & 1
+        if ry == 0:
+            if rx == 1:
+                x = n - 1 - x
+                y = n - 1 - y
+            x, y = y, x
+        x += s * rx
+        y += s * ry
+        t //= 4
+        s *= 2
+    return x, y
+
+# Hilbert index mapping for a rectangular grid of patches H_P x W_P
+
+def hilbert_indices(H_P, W_P):
+    size = max(H_P, W_P)
+    order = math.ceil(math.log2(size))
+    n = 1 << order
+    coords = []
+    for d in range(n * n):
+        x, y = _d2xy(n, d)
+        # x is column index, y is row index
+        if x < W_P and y < H_P:
+            coords.append((y, x))  # (row, col)
+        if len(coords) == H_P * W_P:
+            break
+    # Convert (row, col) to linear indices row-major
+    indices = [r * W_P + c for r, c in coords]
+    return jnp.array(indices, dtype=jnp.int32)
+
+# Inverse permutation: given idx where idx[i] = new position of element i, return inv such that inv[idx[i]] = i
+
+def inverse_permutation(idx):
+    inv = jnp.zeros_like(idx)
+    inv = inv.at[idx].set(jnp.arange(idx.shape[0], dtype=idx.dtype))
+    return inv
+
+# Patchify using Hilbert ordering: extract patches and reorder sequence
+
+def hilbert_patchify(x, patch_size):
+    B, H, W, C = x.shape
+    H_P = H // patch_size
+    W_P = W // patch_size
+    # Extract patches in row-major
+    patches = rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
+    idx = hilbert_indices(H_P, W_P)
+    return patches[:, idx, :]
+
+# Unpatchify from Hilbert ordering: reorder sequence back and reconstruct image
+
+def hilbert_unpatchify(patches, patch_size, H, W, C):
+    B, N, D = patches.shape
+    H_P = H // patch_size
+    W_P = W // patch_size
+    inv = inverse_permutation(hilbert_indices(H_P, W_P))
+    # Reorder back to row-major
+    linear = patches[:, inv, :]
+    # Reconstruct image
+    x = rearrange(linear, 'b (h w) (p1 p2 c) -> b (h p1) (w p2) c', h=H_P, w=W_P, p1=patch_size, p2=patch_size, c=C)
+    return x
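Note: a sanity-check sketch for the new Hilbert helpers; patchifying into Hilbert order and unpatchifying back should be an exact roundtrip (the 32x32 image with 8x8 patches below is an assumed example, giving a 4x4 patch grid and 16 patches of dimension 8*8*3 = 192):

import jax
import jax.numpy as jnp
from flaxdiff.models.common import hilbert_patchify, hilbert_unpatchify

x = jax.random.normal(jax.random.PRNGKey(0), (2, 32, 32, 3))
patches = hilbert_patchify(x, patch_size=8)  # [2, 16, 192], Hilbert-curve order
x_back = hilbert_unpatchify(patches, patch_size=8, H=32, W=32, C=3)
assert jnp.allclose(x, x_back)  # exact inverse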
{flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/models/simple_vit.py

@@ -10,6 +10,7 @@ from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLa
import einops
from flax.typing import Dtype, PrecisionLike
from functools import partial
+from .common import hilbert_indices, inverse_permutation

def unpatchify(x, channels=3):
    patch_size = int((x.shape[2] // channels) ** 0.5)

@@ -55,8 +56,6 @@ class UViT(nn.Module):
    num_layers: int = 12
    num_heads: int = 12
    dropout_rate: float = 0.1
-    dtype: Any = jnp.float32
-    precision: Any = jax.lax.Precision.HIGH
    use_projection: bool = False
    use_flash_attention: bool = False
    use_self_and_cross: bool = False

@@ -65,16 +64,17 @@ class UViT(nn.Module):
    norm_groups:int=8
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None
-    # kernel_init: Callable = partial(kernel_init, scale=1.0)
    add_residualblock_output: bool = False
    norm_inputs: bool = False
    explicitly_add_residual: bool = True
+    norm_epsilon: float = 1e-4 # Added epsilon parameter, increased default
+    use_hilbert: bool = False # Toggle Hilbert patch reorder

    def setup(self):
        if self.norm_groups > 0:
-            self.norm = partial(nn.GroupNorm, self.norm_groups)
+            self.norm = partial(nn.GroupNorm, self.norm_groups, epsilon=self.norm_epsilon)
        else:
-            self.norm = partial(nn.RMSNorm,
+            self.norm = partial(nn.RMSNorm, epsilon=self.norm_epsilon)

    @nn.compact
    def __call__(self, x, temb, textcontext=None):

@@ -83,28 +83,32 @@ class UViT(nn.Module):
        temb = TimeProjection(features=self.emb_features)(temb)

        original_img = x
+        B, H, W, C = original_img.shape
+        H_P = H // self.patch_size
+        W_P = W // self.patch_size

        # Patch embedding
        x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
                           dtype=self.dtype, precision=self.precision)(x)
        num_patches = x.shape[1]
-
+
+        # Optional Hilbert reorder
+        if self.use_hilbert:
+            idx = hilbert_indices(H_P, W_P)
+            inv_idx = inverse_permutation(idx)
+            x = x[:, idx, :]
+
        context_emb = nn.DenseGeneral(features=self.emb_features,
                                      dtype=self.dtype, precision=self.precision)(textcontext)
        num_text_tokens = textcontext.shape[1]

-        # print(f'Shape of x after patch embedding: {x.shape}, numPatches: {num_patches}, temb: {temb.shape}, context_emb: {context_emb.shape}')
-
        # Add time embedding
        temb = jnp.expand_dims(temb, axis=1)
        x = jnp.concatenate([x, temb, context_emb], axis=1)
-
-
+
        # Add positional encoding
        x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features)(x)

-        # print(f'Shape of x after positional encoding: {x.shape}')
-
        skips = []
        # In blocks
        for i in range(self.num_layers // 2):

@@ -114,6 +118,7 @@ class UViT(nn.Module):
                only_pure_attention=False,
                norm_inputs=self.norm_inputs,
                explicitly_add_residual=self.explicitly_add_residual,
+                norm_epsilon=self.norm_epsilon, # Pass epsilon
            )(x)
            skips.append(x)

@@ -124,9 +129,10 @@ class UViT(nn.Module):
            only_pure_attention=False,
            norm_inputs=self.norm_inputs,
            explicitly_add_residual=self.explicitly_add_residual,
+            norm_epsilon=self.norm_epsilon, # Pass epsilon
        )(x)

-        #
+        # Out blocks
        for i in range(self.num_layers // 2):
            x = jnp.concatenate([x, skips.pop()], axis=-1)
            x = nn.DenseGeneral(features=self.emb_features,

@@ -137,14 +143,18 @@ class UViT(nn.Module):
                only_pure_attention=False,
                norm_inputs=self.norm_inputs,
                explicitly_add_residual=self.explicitly_add_residual,
+                norm_epsilon=self.norm_epsilon, # Pass epsilon
            )(x)

-
-        x = self.norm()(x)
+        x = self.norm()(x) # Uses norm_epsilon defined in setup

        patch_dim = self.patch_size ** 2 * self.output_channels
        x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision)(x)
-
+        # If Hilbert, restore original patch order
+        if self.use_hilbert:
+            x = x[:, inv_idx, :]
+        # Extract only the image patch tokens (first num_patches tokens)
+        x = x[:, :num_patches, :]
        x = unpatchify(x, channels=self.output_channels)

        if self.add_residualblock_output:
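Note: the existing UViT gains two opt-in fields. A hedged usage sketch (all other constructor arguments are assumed to keep their defaults):

# Enable Hilbert-curve patch ordering and the new, larger norm epsilon.
model = UViT(patch_size=16, emb_features=768, num_layers=12, num_heads=12,
             use_hilbert=True,     # reorder patch tokens along a Hilbert curve
             norm_epsilon=1e-4)    # the raised default shown in this diff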
{flaxdiff-0.2.4 → flaxdiff-0.2.6}/flaxdiff/trainer/general_diffusion_trainer.py

@@ -578,6 +578,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
        if not hasattr(self, "wandb_sweep"):
            raise ValueError("Wandb sweep is not initialized. Cannot get best runs.")

+        print(f"Getting best runs from sweep {self.wandb_sweep.id}...")
        # Get the sweep runs
        runs = sorted(self.wandb_sweep.runs, key=lambda x: x.summary.get(metric, float('inf')))
        best_runs = runs[:top_k]

@@ -588,18 +589,46 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
            print(f"\t\tRun ID: {run.id}, Metric: {run.summary.get(metric, float('inf'))}")
        return best_runs, (min(lower_bound, upper_bound), max(lower_bound, upper_bound))

-    def
+    def __get_best_general_runs__(
+        self,
+        metric: str = "train/best_loss",
+        top_k: int = 5,
+    ):
        """
-
+        Get the best runs from wandb.
+        Args:
+            metric: Metric to sort by.
+            top_k: Number of top runs to return.
+        """
+        if self.wandb is None:
+            raise ValueError("Wandb is not initialized. Cannot get best runs.")
+
+        # Get the sweep runs
+        runs = sorted(self.wandb.runs, key=lambda x: x.summary.get(metric, float('inf')))
+        best_runs = runs[:top_k]
+        lower_bound = best_runs[-1].summary.get(metric, float('inf'))
+        upper_bound = best_runs[0].summary.get(metric, float('inf'))
+        print(f"Best runs from wandb {self.wandb.id}:")
+        for run in best_runs:
+            print(f"\t\tRun ID: {run.id}, Metric: {run.summary.get(metric, float('inf'))}")
+        return best_runs, (min(lower_bound, upper_bound), max(lower_bound, upper_bound))
+
+    def __compare_run_against_best__(self, top_k=2, metric="train/best_loss", from_sweeps=False):
+        """
+        Compare the current run against the best runs from wandb.
        Args:
            top_k: Number of top runs to consider.
            metric: Metric to compare against.
+            from_sweeps: Whether to consider runs from sweeps.
        Returns:
            is_good: Whether the current run is among the best.
            is_best: Whether the current run is the best.
        """
        # Get best runs
-
+        if from_sweeps:
+            best_runs, bounds = self.__get_best_sweep_runs__(metric=metric, top_k=top_k)
+        else:
+            best_runs, bounds = self.__get_best_general_runs__(metric=metric, top_k=top_k)

        # Determine if lower or higher values are better (for loss, lower is better)
        is_lower_better = "loss" in metric.lower()

@@ -621,10 +650,10 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
    def save(self, epoch=0, step=0, state=None, rngstate=None):
        super().save(epoch=epoch, step=step, state=state, rngstate=rngstate)

-        if self.wandb is not None
+        if self.wandb is not None:
            checkpoint = get_latest_checkpoint(self.checkpoint_path())
            try:
-                is_good, is_best = self.__compare_run_against_best__(top_k=5, metric="train/best_loss")
+                is_good, is_best = self.__compare_run_against_best__(top_k=5, metric="train/best_loss", from_sweeps=hasattr(self, "wandb_sweep"))
                if is_good:
                    # Push to registry with appropriate aliases
                    aliases = []
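Note: `save` now routes the ranking through sweeps only when the trainer was launched from one, falling back to all runs on the wandb handle otherwise. A hedged usage sketch of the comparison helper (the trainer instance and its wandb setup are assumed):

is_good, is_best = trainer.__compare_run_against_best__(
    top_k=5, metric="train/best_loss", from_sweeps=hasattr(trainer, "wandb_sweep")
)
if is_good:
    print("Run ranks among the top 5; the checkpoint will be pushed to the registry.")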