flaxdiff 0.2.8__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,262 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from flax import linen as nn
4
+ from typing import Any, Optional
5
+ import einops
6
+ from flax.typing import Dtype, PrecisionLike
7
+
8
+ from .attention import NormalAttention
9
+
10
def unpatchify(x, channels=3):
    """Fold a sequence of flattened square patches back into an image grid.

    Args:
        x: Array indexed as ``[batch, patches, patch_values]``. The patch
            count must be a perfect square and ``patch_values`` must equal
            ``patch_size ** 2 * channels`` for an integer ``patch_size``.
        channels: Number of channels in the reconstructed image.

    Returns:
        ``x`` rearranged to ``[batch, h * patch_size, w * patch_size,
        channels]`` where ``h == w == sqrt(patches)``.

    Raises:
        AssertionError: If the patch count is not a perfect square or the
            last axis is not ``patch_size ** 2 * channels``.
    """
    num_patches, patch_values = x.shape[1], x.shape[2]

    # Both side lengths come from a truncating square root; the assertion
    # below rejects every input for which the truncation lost information.
    patch_size = int((patch_values // channels) ** 0.5)
    h = int(num_patches ** .5)  # the patch grid is square, so width == h

    square_grid = h * h == num_patches
    assert square_grid and patch_size ** 2 * channels == patch_values, \
        f"Invalid shape: {x.shape}, should be {h*h}, {patch_size**2*channels}"

    return einops.rearrange(
        x,
        'B (h w) (p1 p2 C) -> B (h p1) (w p2) C',
        h=h,
        p1=patch_size,
        p2=patch_size,
    )
18
+
19
+
20
class PatchEmbedding(nn.Module):
    """Embed non-overlapping square image patches with a strided convolution."""

    # Side length of a patch; used as both the kernel size and the stride.
    patch_size: int
    # Number of features produced for every patch.
    embedding_dim: int
    # dtype forwarded to the convolution.
    dtype: Any = jnp.float32
    # Precision forwarded to the convolution.
    precision: Any = jax.lax.Precision.HIGH

    @nn.compact
    def __call__(self, x):
        """Turn an image batch into a sequence of patch embeddings.

        Args:
            x: 4-D array unpacked as ``[batch, height, width, channels]``.

        Returns:
            Array reshaped to ``[batch, -1, embedding_dim]``.

        Raises:
            AssertionError: If height or width is not a multiple of
                ``patch_size``.
        """
        batch, height, width, _ = x.shape
        side = self.patch_size
        rows_fit = height % side == 0
        cols_fit = width % side == 0
        assert rows_fit and cols_fit, "Image dimensions must be divisible by patch size"

        # Kernel size == stride, so each output position sees exactly one patch.
        patch_conv = nn.Conv(
            features=self.embedding_dim,
            kernel_size=(side, side),
            strides=(side, side),
            dtype=self.dtype,
            precision=self.precision,
        )
        patches = patch_conv(x)

        # Collapse the spatial grid into a single sequence axis.
        return jnp.reshape(patches, (batch, -1, self.embedding_dim))
38
+
39
+
40
class PositionalEncoding(nn.Module):
    """Add a learned, zero-initialised positional table to a token sequence."""

    # Number of positions stored in the learned table.
    max_len: int
    # Feature size of every position vector.
    embedding_dim: int

    @nn.compact
    def __call__(self, x):
        """Return ``x`` plus the first ``x.shape[1]`` rows of the table.

        Args:
            x: Array whose axis 1 is the sequence axis; the table has shape
                ``(1, max_len, embedding_dim)`` and is broadcast over axis 0.

        Returns:
            ``x`` with the positional parameters added.
        """
        table_shape = (1, self.max_len, self.embedding_dim)
        pe = self.param('pos_encoding', jax.nn.initializers.zeros, table_shape)

        # Only the leading rows are used, one per position present in x.
        seq_len = x.shape[1]
        return x + pe[:, :seq_len, :]
50
+
51
+
52
+ # --- Rotary Positional Embedding (RoPE) ---
53
+ # Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py
54
+
55
+
56
+ def _rotate_half(x: jax.Array) -> jax.Array:
57
+ """Rotates half the hidden dims of the input."""
58
+ x1 = x[..., : x.shape[-1] // 2]
59
+ x2 = x[..., x.shape[-1] // 2:]
60
+ return jnp.concatenate((-x2, x1), axis=-1)
61
+
62
def apply_rotary_embedding(
    x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array
) -> jax.Array:
    """Apply rotary position embedding to ``x`` (rotate-half formulation).

    Computes ``x * cos + rotate_half(x) * sin`` where ``rotate_half`` maps a
    last axis split as ``(x1, x2)`` to ``(-x2, x1)``. The expression is
    evaluated per half, so the cos/sin tables never have to be duplicated to
    the full feature width.

    Args:
        x: Array shaped ``[..., S, D]``, e.g. ``[B, H, S, D]`` or
            ``[B, S, D]``. Any rank >= 2 is accepted.
        freqs_cos: Cosine table shaped ``[S, D / 2]`` (or anything that
            broadcasts against ``x[..., :D / 2]``).
        freqs_sin: Sine table with the same shape as ``freqs_cos``.

    Returns:
        The rotated array, same shape and dtype as ``x``.

    Raises:
        ValueError: If ``x`` has fewer than two axes.
    """
    if x.ndim < 2:
        raise ValueError(
            f"apply_rotary_embedding expects x shaped [..., S, D], got shape {x.shape}")

    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]

    # [S, D/2] tables line up with the trailing [S, D/2] axes of each half, so
    # ordinary broadcasting covers every leading batch/head axis of x.
    first_half = x1 * freqs_cos - x2 * freqs_sin
    second_half = x2 * freqs_cos + x1 * freqs_sin

    rotated = jnp.concatenate((first_half, second_half), axis=-1)
    # The tables may be wider than x (e.g. float32 vs bfloat16); cast back.
    return rotated.astype(x.dtype)
85
+
86
def _rope_frequencies(dim, base, seq_len):
    """Return the ``(cos, sin)`` RoPE tables, each ``[seq_len, len(range(0, dim, 2))]``."""
    inv_freq = 1.0 / (
        base ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim)
    )
    positions = jnp.arange(seq_len, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)
    return jnp.cos(angles), jnp.sin(angles)


class RotaryEmbedding(nn.Module):
    """Provide cosine/sine tables for rotary position embedding (RoPE).

    Tables for ``max_seq_len`` positions are built once in ``setup``; longer
    requests are computed on demand and are not cached.
    """

    # Feature width being rotated; frequencies are generated for range(0, dim, 2).
    dim: int
    # Number of positions precomputed in setup().
    max_seq_len: int = 4096  # Increased default based on SimpleDiT
    # Base of the geometric frequency progression.
    base: int = 10000
    # NOTE(review): not read anywhere in this class; tables are always float32.
    dtype: Dtype = jnp.float32

    def setup(self):
        """Precompute the tables for positions ``0 .. max_seq_len - 1``."""
        self.freqs_cos, self.freqs_sin = _rope_frequencies(
            self.dim, self.base, self.max_seq_len)

    def __call__(self, seq_len: int):
        """Return ``(freqs_cos, freqs_sin)`` covering ``seq_len`` positions.

        Args:
            seq_len: Number of positions required.

        Returns:
            Tuple of two float32 arrays with ``seq_len`` rows each.
        """
        if seq_len > self.max_seq_len:
            # Beyond the precomputed range: rebuild for the requested length
            # instead of raising. Consider caching if this happens often.
            return _rope_frequencies(self.dim, self.base, seq_len)
        return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :]
118
+
119
+
120
+ # --- Attention with RoPE ---
121
+
122
+
123
class RoPEAttention(NormalAttention):
    """Attention layer that rotates queries and keys with RoPE.

    The projections (``query``, ``key``, ``value``, ``proj_attn``) and the
    ``dtype`` / ``precision`` / ``force_fp32_for_softmax`` settings are read
    from ``self`` and are expected to come from ``NormalAttention``.
    """

    # Optional module producing (cos, sin) tables when freqs_cis is not passed.
    rope_emb: Optional[RotaryEmbedding] = None

    @nn.compact
    def __call__(self, x, context=None, freqs_cis=None):
        """Run attention with rotary position embedding on queries and keys.

        Args:
            x: Input shaped ``[B, S, C]`` or ``[B, H, W, C]`` (flattened to
                ``[B, H * W, C]`` internally and restored on output).
            context: Optional key/value source with the same layout options;
                defaults to ``x`` (self-attention).
            freqs_cis: Optional ``(freqs_cos, freqs_sin)`` tuple. When given it
                takes precedence over ``rope_emb`` and must cover at least the
                longer of the query and key sequences.

        Returns:
            The projected attention output, reshaped to ``x``'s original shape
            when ``x`` was 4-D.

        Raises:
            ValueError: If neither ``freqs_cis`` nor ``rope_emb`` is available.
        """
        orig_x_shape = x.shape
        is_4d = len(orig_x_shape) == 4
        if is_4d:
            B, H, W, C = x.shape
            x = x.reshape((B, H * W, C))
        else:
            B, _, C = x.shape

        context = x if context is None else context
        if len(context.shape) == 4:
            _B, _H, _W, _C = context.shape
            context = context.reshape((B, _H * _W, _C))
        # else: context is already [B, S_ctx, C]

        query = self.query(x)        # [B, S, H, D]
        key = self.key(context)      # [B, S_ctx, H, D]
        value = self.value(context)  # [B, S_ctx, H, D]

        q_len, k_len = query.shape[1], key.shape[1]
        if freqs_cis is not None:
            freqs_cos, freqs_sin = freqs_cis
        elif self.rope_emb is not None:
            # Build tables for the longer sequence so that a context whose
            # length differs from the query still gets one row per position.
            freqs_cos, freqs_sin = self.rope_emb(max(q_len, k_len))
        else:
            # Should not happen if rope_emb is provided or freqs_cis are passed
            raise ValueError("RoPE frequencies not provided.")

        # Permute to [B, H, S, D] for RoPE application
        query = einops.rearrange(query, 'b s h d -> b h s d')
        key = einops.rearrange(key, 'b s h d -> b h s d')

        # Each tensor uses the table rows for its own positions. For
        # self-attention both slices are the full table, as before.
        query = apply_rotary_embedding(
            query, freqs_cos[..., :q_len, :], freqs_sin[..., :q_len, :])
        key = apply_rotary_embedding(
            key, freqs_cos[..., :k_len, :], freqs_sin[..., :k_len, :])

        # Permute back to [B, S, H, D] for dot_product_attention
        query = einops.rearrange(query, 'b h s d -> b s h d')
        key = einops.rearrange(key, 'b h s d -> b s h d')

        hidden_states = nn.dot_product_attention(
            query, key, value, dtype=self.dtype, broadcast_dropout=False,
            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
            deterministic=True
        )

        proj = self.proj_attn(hidden_states)

        if is_4d:
            proj = proj.reshape(orig_x_shape)

        return proj
184
+
185
+
186
+ # --- adaLN-Zero ---
187
+
188
+
189
class AdaLNZero(nn.Module):
    """adaLN-Zero: LayerNorm modulated by a zero-initialised conditioning projection.

    One dense layer maps the conditioning signal to six tensors per feature
    (scale, shift and gate for the MLP path and for the attention path).
    """

    # Size of the feature axis of x; the projection outputs 6 * features.
    features: int
    # dtype forwarded to the projection and the LayerNorm.
    dtype: Optional[Dtype] = None
    # Precision forwarded to the projection.
    precision: PrecisionLike = None
    norm_epsilon: float = 1e-5  # Standard LayerNorm epsilon

    @nn.compact
    def __call__(self, x, conditioning):
        """Normalise ``x`` and modulate it for the attention and MLP paths.

        Args:
            x: Activations with the feature axis last, e.g. ``[B, S, F]`` or
                ``[B, H, W, F]``.
            conditioning: Conditioning signal. A 2-D ``[B, D_cond]`` input is
                given singleton axes after the batch axis so it broadcasts
                against ``x``; other ranks are used as passed.

        Returns:
            Tuple ``(x_attn, gate_attn, x_mlp, gate_mlp)``.
        """
        # Insert one singleton axis per non-batch, non-feature axis of x:
        # [B, D_cond] -> [B, 1, D_cond] for 3-D x, [B, 1, 1, D_cond] for 4-D x.
        if conditioning.ndim == 2 and x.ndim > 2:
            conditioning = jnp.expand_dims(
                conditioning, axis=tuple(range(1, x.ndim - 1)))

        # Project conditioning to get 6 params per feature (scale_mlp,
        # shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn).
        ada_params = nn.Dense(
            features=6 * self.features,
            dtype=self.dtype,
            precision=self.precision,
            # Zero kernel: with the default bias the modulation starts as identity.
            kernel_init=nn.initializers.zeros,
            name="ada_proj"
        )(conditioning)

        # Split into scale, shift, gate for MLP and Attention
        scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
            ada_params, 6, axis=-1)

        # NOTE(review): only the MLP scale/shift are clipped; the attention
        # scale/shift are left unbounded. Confirm this asymmetry is intended.
        scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
        shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)

        # LayerNorm without its own scale/bias; those come from conditioning.
        norm = nn.LayerNorm(epsilon=self.norm_epsilon,
                            use_scale=False, use_bias=False, dtype=self.dtype)
        norm_x = norm(x)

        # Modulate for Attention path
        x_attn = norm_x * (1 + scale_attn) + shift_attn

        # Modulate for MLP path
        x_mlp = norm_x * (1 + scale_mlp) + shift_mlp

        # Return modulated outputs and gates
        return x_attn, gate_attn, x_mlp, gate_mlp
239
+
240
class AdaLNParams(nn.Module):  # Renamed for clarity
    """Project a conditioning signal to ``6 * features`` adaLN parameters."""

    # Feature size the parameters are generated for.
    features: int
    # dtype forwarded to the projection.
    dtype: Optional[Dtype] = None
    # Precision forwarded to the projection.
    precision: PrecisionLike = None

    @nn.compact
    def __call__(self, conditioning):
        """Return the unsplit projection of ``conditioning``.

        Args:
            conditioning: Conditioning array; a 2-D ``[B, D_cond]`` input gets
                a singleton axis inserted at position 1 first.

        Returns:
            The projected parameters with a last axis of ``6 * features``
            (``[B, 1, 6 * features]`` for 2-D input). Callers can split it
            with ``jnp.split(ada_params, 6, axis=-1)``.
        """
        is_flat = conditioning.ndim == 2
        cond = jnp.expand_dims(conditioning, axis=1) if is_flat else conditioning

        # Zero-initialised kernel, matching the adaLN-Zero convention.
        projection = nn.Dense(
            features=6 * self.features,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=nn.initializers.zeros,
            name="ada_proj",
        )
        return projection(cond)
262
+
@@ -428,6 +428,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
428
428
  generate_samples = val_step_fn
429
429
 
430
430
  val_ds = iter(val_ds) if val_ds else None
431
+ print(f"Validation loop started for process index {process_index} with {global_device_count} devices.")
431
432
  # Evaluation step
432
433
  try:
433
434
  metrics = {metric.name: [] for metric in self.eval_metrics} if self.eval_metrics else {}
@@ -487,7 +488,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
487
488
  self.wandb.log({
488
489
  f"val/{key}": value,
489
490
  }, step=current_step)
490
-
491
+ print(f"Validation metrics for process index {process_index}: {metrics}")
491
492
 
492
493
  # Close validation dataset iterator
493
494
  del val_ds
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.2.8
3
+ Version: 0.2.9
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -2,20 +2,20 @@ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=DmlWUY1FGz4ESxIHaPQJf92CHjsdMjyDd651wFUtyNg,8838
3
3
  flaxdiff/data/__init__.py,sha256=8W5y7NyAOWtpLi8WRawk4VYeE3DMDnM3B_jKPD8BoFQ,143
4
4
  flaxdiff/data/benchmark_decord.py,sha256=x56Db1VPmziv_9KJvWdfS0O7cffsYkF5tt5WvldOKc0,13720
5
- flaxdiff/data/dataloaders.py,sha256=HQR0rsLNYXRPBmdOBKFCc3UfWsmSbSO_-dOQHCbu_VA,23966
6
- flaxdiff/data/dataset_map.py,sha256=Dz_suGz23Cy7RfWt0FDRX7Q3NTB5SAw2UNHO_-p0qiM,5098
5
+ flaxdiff/data/dataloaders.py,sha256=k_3YGJhiY2Wt_-7qK0Yjl4pmF2QJjX_-BlSFuXbH5-M,23628
6
+ flaxdiff/data/dataset_map.py,sha256=p30U23RkfgMbR8kfPBDIjrjfzDBszWQ9Q1ff2BvDYZk,5116
7
7
  flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
8
8
  flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
9
9
  flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
10
10
  flaxdiff/data/sources/av_utils.py,sha256=LCr9MJNurOaoxY-sjzkLqJS_MlX0x3gRSlKAVIglAU0,24045
11
11
  flaxdiff/data/sources/base.py,sha256=4Rm9pCtXxzoB8FO0lkDHsrX3ULoU_PNNcid978e6ir0,4610
12
- flaxdiff/data/sources/images.py,sha256=71TzTVbPzV-Md3-1Lk4eWfb11w6aaO01OClwK_SiCSM,14708
12
+ flaxdiff/data/sources/images.py,sha256=ZHBmZ2fnPN75Hc2kiog-Wcs_NZJZOiqw4WcSH5WZJHA,16572
13
13
  flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
14
14
  flaxdiff/data/sources/videos.py,sha256=NkxwEruNpAwDCM53q4WurQ802gSjQMOqjNLxYOqjoNE,9545
15
15
  flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
16
16
  flaxdiff/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
17
  flaxdiff/inference/pipeline.py,sha256=8S30FAlXEjvrDd87H-qdD6biySQZ3cJUflU8gdmPxig,9223
18
- flaxdiff/inference/utils.py,sha256=MVnWl0LnC-1ILk0SsLd1YFu6igaQFR7mGhzo0jE797E,12323
18
+ flaxdiff/inference/utils.py,sha256=JEBZYSgj-0DLJTV-TNmIAllAqqVJMn0KfryHwFO-MFs,12606
19
19
  flaxdiff/inputs/__init__.py,sha256=ybPjQsFAf5sqRVZG1sRiOl99EnwpI-NQ8HE3y7UbXmU,7197
20
20
  flaxdiff/inputs/encoders.py,sha256=pjfbx4Rk7bLoE80MOfThZDm6YtsDncRekmn0Bmg_CwI,2963
21
21
  flaxdiff/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -31,12 +31,13 @@ flaxdiff/models/common.py,sha256=QpciwuJldvLUwyAyWBQqiPPGVI-c9qLR7h7C1YoRX7w,105
31
31
  flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
32
32
  flaxdiff/models/general.py,sha256=7xMME6KVKQY8sScyHYH4f-Kek4j1pRfplKShFXwVZd4,587
33
33
  flaxdiff/models/hilbert.py,sha256=AjlAv49dL6UAYWslMJfCMLiFqY4kTgpiUWr2nc1mk34,24823
34
- flaxdiff/models/simple_dit.py,sha256=Hc2jLOZCYSDm6x88m3bGYu-OKge1TukiQPSdlaO68rE,19667
35
- flaxdiff/models/simple_mmdit.py,sha256=RmOq6LbfDBUUEib6MSAURujxn9iHgdh77a6ntNsWI2w,36210
34
+ flaxdiff/models/simple_dit.py,sha256=l238MYHRTArv_pS57aY24C2PTfxeL8EmzJ24iQqdoWI,11702
35
+ flaxdiff/models/simple_mmdit.py,sha256=ARk0juopn2k7giln5BAUrnYD1pTFwgTJoSzrhozQ0A8,31356
36
36
  flaxdiff/models/simple_unet.py,sha256=pjeixszG_6gEY5PNFbQ7KbOyg4z5bfn4RUbINCJexOM,10758
37
- flaxdiff/models/simple_vit.py,sha256=QEHPyaQIYhqSYrD6eb65X70jQL-y09nRT8Yc4b5Jq6Q,15181
37
+ flaxdiff/models/simple_vit.py,sha256=J9s3hBF87_iVrJDBe2cs9a56N7ect6pux_f_ge07XXc,17357
38
38
  flaxdiff/models/unet_3d.py,sha256=LF0PMxBKGU-_lAMtO_Coxy1yRE02yKKdgb7i6YZxI_4,20163
39
39
  flaxdiff/models/unet_3d_blocks.py,sha256=lRYDc9X1VEu54Kg7xEEphXYiQ09tabPXKi-hEcKFYug,19687
40
+ flaxdiff/models/vit_common.py,sha256=1OGu4ezY3uzKinTnw3p8YkQAslHDqEbN78JheXnTleI,9831
40
41
  flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
41
42
  flaxdiff/models/autoencoder/autoencoder.py,sha256=8XWdsWvsPsyWGtzpCT8w0KXi_ZLGpRuQpn4oXo1gHKw,6039
42
43
  flaxdiff/models/autoencoder/diffusers.py,sha256=tPz77YuctrT--jF2AOL8G6vr0NiIr3RXANNrZCxe0bg,5921
@@ -62,9 +63,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
62
63
  flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
63
64
  flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
64
65
  flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
65
- flaxdiff/trainer/general_diffusion_trainer.py,sha256=FUvc--3ibRAjrYiKbA-FyLqKhusakxeNOa6UJZaK4SU,29307
66
+ flaxdiff/trainer/general_diffusion_trainer.py,sha256=OtE2spZIBFPpY6q-ijYol5Y-CaP2UHJYIDX3PFBiPtg,29492
66
67
  flaxdiff/trainer/simple_trainer.py,sha256=Hdltuo3lgF61N04Lxc7L3z6NLveW4_h1ff7_5mu3Wbg,28730
67
- flaxdiff-0.2.8.dist-info/METADATA,sha256=y2jLjsEkR-GKvLWuGzlyBrk1SNM6tCPT0Oc7vRZC7_I,24057
68
- flaxdiff-0.2.8.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
69
- flaxdiff-0.2.8.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
70
- flaxdiff-0.2.8.dist-info/RECORD,,
68
+ flaxdiff-0.2.9.dist-info/METADATA,sha256=a8btxHRkAZVieuZfTyXgPkJbEG9fZRknEhq2Ti3_7m4,24057
69
+ flaxdiff-0.2.9.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
70
+ flaxdiff-0.2.9.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
71
+ flaxdiff-0.2.9.dist-info/RECORD,,