flaxdiff 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,262 @@
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from typing import Any, Optional
+ import einops
+ from flax.typing import Dtype, PrecisionLike
+
+ from .attention import NormalAttention
+
+ def unpatchify(x, channels=3):
+     patch_size = int((x.shape[2] // channels) ** 0.5)
+     h = w = int(x.shape[1] ** .5)
+     assert h * w == x.shape[1] and patch_size ** 2 * \
+         channels == x.shape[2], f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}"
+     x = einops.rearrange(
+         x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)
+     return x
+
+
+ class PatchEmbedding(nn.Module):
+     patch_size: int
+     embedding_dim: int
+     dtype: Any = jnp.float32
+     precision: Any = jax.lax.Precision.HIGH
+
+     @nn.compact
+     def __call__(self, x):
+         batch, height, width, channels = x.shape
+         assert height % self.patch_size == 0 and width % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+
+         x = nn.Conv(features=self.embedding_dim,
+                     kernel_size=(self.patch_size, self.patch_size),
+                     strides=(self.patch_size, self.patch_size),
+                     dtype=self.dtype,
+                     precision=self.precision)(x)
+         x = jnp.reshape(x, (batch, -1, self.embedding_dim))
+         return x
+
+
+ class PositionalEncoding(nn.Module):
+     max_len: int
+     embedding_dim: int
+
+     @nn.compact
+     def __call__(self, x):
+         pe = self.param('pos_encoding',
+                         jax.nn.initializers.zeros,
+                         (1, self.max_len, self.embedding_dim))
+         return x + pe[:, :x.shape[1], :]
+
+
+ # --- Rotary Positional Embedding (RoPE) ---
+ # Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py
+
+
+ def _rotate_half(x: jax.Array) -> jax.Array:
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2:]
+     return jnp.concatenate((-x2, x1), axis=-1)
+
+ def apply_rotary_embedding(
+     x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array
+ ) -> jax.Array:
+     """Applies rotary embedding to the input tensor using the rotate_half method."""
+     # x shape: [..., Sequence, Dimension] e.g. [B, H, S, D] or [B, S, D]
+     # freqs_cos/sin shape: [Sequence, Dimension / 2]
+
+     # Expand dims for broadcasting: [1, 1, S, D/2] or [1, S, D/2]
+     if x.ndim == 4:  # [B, H, S, D]
+         cos_freqs = jnp.expand_dims(freqs_cos, axis=(0, 1))
+         sin_freqs = jnp.expand_dims(freqs_sin, axis=(0, 1))
+     elif x.ndim == 3:  # [B, S, D]
+         cos_freqs = jnp.expand_dims(freqs_cos, axis=0)
+         sin_freqs = jnp.expand_dims(freqs_sin, axis=0)
+     else:
+         # Guard against falling through with undefined cos/sin buffers
+         raise ValueError(f"Unsupported input rank for RoPE: {x.ndim}")
+
+     # Duplicate cos and sin for the full dimension D
+     # Shape becomes [..., S, D]
+     cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
+     sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
+
+     # Apply rotation: x * cos + rotate_half(x) * sin
+     x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
+     return x_rotated.astype(x.dtype)
+
+ class RotaryEmbedding(nn.Module):
+     dim: int
+     max_seq_len: int = 4096  # Increased default based on SimpleDiT
+     base: int = 10000
+     dtype: Dtype = jnp.float32
+
+     def setup(self):
+         inv_freq = 1.0 / (
+             self.base ** (jnp.arange(0, self.dim, 2,
+                                      dtype=jnp.float32) / self.dim)
+         )
+         t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
+         freqs = jnp.outer(t, inv_freq)
+         self.freqs_cos = jnp.cos(freqs)
+         self.freqs_sin = jnp.sin(freqs)
+
+     def __call__(self, seq_len: int):
+         if seq_len > self.max_seq_len:
+             # Dynamically extend frequencies if needed (more robust)
+             t = jnp.arange(seq_len, dtype=jnp.float32)
+             inv_freq = 1.0 / (
+                 self.base ** (jnp.arange(0, self.dim, 2,
+                                          dtype=jnp.float32) / self.dim)
+             )
+             freqs = jnp.outer(t, inv_freq)
+             freqs_cos = jnp.cos(freqs)
+             freqs_sin = jnp.sin(freqs)
+             # Consider caching extended freqs if this happens often
+             return freqs_cos, freqs_sin
+             # Or raise an error instead:
+             # raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
+         return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :]
+
+
+ # --- Attention with RoPE ---
+
+
+ class RoPEAttention(NormalAttention):
+     rope_emb: Optional[RotaryEmbedding] = None
+
+     @nn.compact
+     def __call__(self, x, context=None, freqs_cis=None):
+         orig_x_shape = x.shape
+         is_4d = len(orig_x_shape) == 4
+         if is_4d:
+             B, H, W, C = x.shape
+             seq_len = H * W
+             x = x.reshape((B, seq_len, C))
+         else:
+             B, seq_len, C = x.shape
+
+         context = x if context is None else context
+         if len(context.shape) == 4:
+             _B, _H, _W, _C = context.shape
+             context_seq_len = _H * _W
+             context = context.reshape((B, context_seq_len, _C))
+         # else: context is already [B, S_ctx, C]
+
+         query = self.query(x)        # [B, S, H, D]
+         key = self.key(context)      # [B, S_ctx, H, D]
+         value = self.value(context)  # [B, S_ctx, H, D]
+
+         if freqs_cis is None and self.rope_emb is not None:
+             seq_len_q = query.shape[1]  # Use query's sequence length
+             freqs_cos, freqs_sin = self.rope_emb(seq_len_q)
+         elif freqs_cis is not None:
+             freqs_cos, freqs_sin = freqs_cis
+         else:
+             # Should not happen if rope_emb is provided or freqs_cis are passed
+             raise ValueError("RoPE frequencies not provided.")
+
+         # Apply RoPE to query and key
+         # Permute to [B, H, S, D] for RoPE application
+         query = einops.rearrange(query, 'b s h d -> b h s d')
+         key = einops.rearrange(key, 'b s h d -> b h s d')
+
+         # Apply RoPE only up to the context sequence length for keys if different;
+         # assuming self-attention (same sequence length) for simplicity here
+         query = apply_rotary_embedding(query, freqs_cos, freqs_sin)
+         key = apply_rotary_embedding(
+             key, freqs_cos, freqs_sin)  # Apply same freqs to key
+
+         # Permute back to [B, S, H, D] for dot_product_attention
+         query = einops.rearrange(query, 'b h s d -> b s h d')
+         key = einops.rearrange(key, 'b h s d -> b s h d')
+
+         hidden_states = nn.dot_product_attention(
+             query, key, value, dtype=self.dtype, broadcast_dropout=False,
+             dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
+             deterministic=True
+         )
+
+         proj = self.proj_attn(hidden_states)
+
+         if is_4d:
+             proj = proj.reshape(orig_x_shape)
+
+         return proj
+
+
+ # --- adaLN-Zero ---
+
+
+ class AdaLNZero(nn.Module):
+     features: int
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     norm_epsilon: float = 1e-5  # Standard LayerNorm epsilon
+
+     @nn.compact
+     def __call__(self, x, conditioning):
+         # Project conditioning signal to get scale and shift parameters
+         # Conditioning shape: [B, D_cond] -> [B, 1, ..., 1, 6 * features] for broadcasting
+         # Or [B, 1, 6*features] if x is [B, S, F]
+
+         # Ensure conditioning has seq dim if x does
+         # x=[B,S,F], cond=[B,D_cond]
+         if x.ndim == 3 and conditioning.ndim == 2:
+             conditioning = jnp.expand_dims(
+                 conditioning, axis=1)  # cond=[B,1,D_cond]
+
+         # Project conditioning to get 6 params per feature (scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn)
+         # Using nn.DenseGeneral for flexibility if needed, but nn.Dense is fine if cond is [B, D_cond] or [B, 1, D_cond]
+         ada_params = nn.Dense(
+             features=6 * self.features,
+             dtype=self.dtype,
+             precision=self.precision,
+             # Initialize projection to zero (Zero init)
+             kernel_init=nn.initializers.zeros,
+             name="ada_proj"
+         )(conditioning)
+
+         # Split into scale, shift, gate for MLP and Attention
+         scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
+             ada_params, 6, axis=-1)
+
+         scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
+         shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)
+         # Apply Layer Normalization
+         norm = nn.LayerNorm(epsilon=self.norm_epsilon,
+                             use_scale=False, use_bias=False, dtype=self.dtype)
+         # norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype)  # Alternative: RMSNorm
+
+         norm_x = norm(x)
+
+         # Modulate for Attention path
+         x_attn = norm_x * (1 + scale_attn) + shift_attn
+
+         # Modulate for MLP path
+         x_mlp = norm_x * (1 + scale_mlp) + shift_mlp
+
+         # Return modulated outputs and gates
+         return x_attn, gate_attn, x_mlp, gate_mlp
+
+ class AdaLNParams(nn.Module):  # Renamed for clarity
+     features: int
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+
+     @nn.compact
+     def __call__(self, conditioning):
+         # Ensure conditioning is broadcastable if needed (e.g., [B, 1, D_cond])
+         if conditioning.ndim == 2:
+             conditioning = jnp.expand_dims(conditioning, axis=1)
+
+         # Project conditioning to get 6 params per feature
+         ada_params = nn.Dense(
+             features=6 * self.features,
+             dtype=self.dtype,
+             precision=self.precision,
+             kernel_init=nn.initializers.zeros,
+             name="ada_proj"
+         )(conditioning)
+         # Return all params (or split if preferred, but maybe return a tuple/dict)
+         # Shape: [B, 1, 6*F]
+         return ada_params  # Or split and return a tuple: jnp.split(ada_params, 6, axis=-1)
+
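The hunk above is the new shared ViT module (evidently the flaxdiff/models/vit_common.py listed in the RECORD below): patchify/unpatchify helpers, a learned positional encoding, RoPE utilities, and DiT-style adaLN-Zero blocks. As orientation, here is a minimal usage sketch built only from the definitions in this hunk; the shapes, dimensions, and PRNG keys are illustrative assumptions, not values from the package:

    import jax
    import jax.numpy as jnp

    x = jnp.ones((2, 32, 32, 3))                      # [B, H, W, C], assumed shape
    patcher = PatchEmbedding(patch_size=4, embedding_dim=48)
    tokens = patcher.apply(patcher.init(jax.random.PRNGKey(0), x), x)
    # tokens: [2, 64, 48], i.e. (32/4)^2 = 64 tokens of 4*4*3 = 48 features

    rope = RotaryEmbedding(dim=48, max_seq_len=64)
    freqs_cos, freqs_sin = rope.apply(rope.init(jax.random.PRNGKey(1), 64), 64)
    q = jnp.ones((2, 4, 64, 48))                      # [B, heads, S, D], assumed
    q_rot = apply_rotary_embedding(q, freqs_cos, freqs_sin)

    images = unpatchify(tokens, channels=3)           # shape round-trip to [2, 32, 32, 3]

Because AdaLNZero zero-initializes its ada_proj kernel (and nn.Dense biases default to zero), the first forward pass returns the plain LayerNorm output with zero gates, which is the adaLN-Zero trick for stable DiT training. A sketch under the same assumed shapes:

    adaln = AdaLNZero(features=48)
    cond = jnp.ones((2, 48))                          # [B, D_cond], assumed
    variables = adaln.init(jax.random.PRNGKey(2), tokens, cond)
    x_attn, gate_attn, x_mlp, gate_mlp = adaln.apply(variables, tokens, cond)
    # At init: x_attn == LayerNorm(tokens) and all gates are exactly zero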
@@ -427,7 +427,8 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
          process_index = jax.process_index()
          generate_samples = val_step_fn

-         val_ds = iter(val_ds) if val_ds else None
+         val_ds = iter(val_ds()) if val_ds else None
+         print(f"Validation loop started for process index {process_index} with {global_device_count} devices.")
          # Evaluation step
          try:
              metrics = {metric.name: [] for metric in self.eval_metrics} if self.eval_metrics else {}
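Note the contract change above: validation_loop now receives a dataset factory and builds the iterator itself, once per validation pass, instead of being handed a pre-built iterator (the matching fit() change appears further down in this diff). A hypothetical illustration of the new contract; make_val_ds and the batch contents are stand-ins, not package code:

    def make_val_ds():
        # Stand-in for data['val']: a factory yielding fresh batches each call
        yield from ({"batch": i} for i in range(3))

    val_ds = make_val_ds                           # fit() now stores the callable
    batches = iter(val_ds()) if val_ds else None   # fresh iterator per validation run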
@@ -474,10 +475,11 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
              metrics = {k: np.mean(v) for k, v in metrics.items()}
              # Update the best validation metrics
              for key, value in metrics.items():
-                 if key not in self.best_val_metrics:
-                     self.best_val_metrics[key] = value
+                 final_key = f"val/{key}"
+                 if final_key not in self.best_val_metrics:
+                     self.best_val_metrics[final_key] = value
                  else:
-                     self.best_val_metrics[key] = min(self.best_val_metrics[key], value)
+                     self.best_val_metrics[final_key] = min(self.best_val_metrics[final_key], value)
              # Log the best validation metrics
              if getattr(self, 'wandb', None) is not None and self.wandb:
                  # Log the metrics
@@ -487,7 +489,14 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
                      self.wandb.log({
                          f"val/{key}": value,
                      }, step=current_step)
-
+                 # Log the best validation metrics
+                 for key, value in self.best_val_metrics.items():
+                     if isinstance(value, jnp.ndarray):
+                         value = np.array(value)
+                     self.wandb.log({
+                         f"best_{key}": value,
+                     }, step=current_step)
+             print(f"Validation metrics for process index {process_index}: {metrics}")

          # Close validation dataset iterator
          del val_ds
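Together with the previous hunk, per-metric bests are now keyed as val/{metric} in self.best_val_metrics and surface in wandb under best_val/{metric}. Illustratively, with an assumed metric name:

    metrics = {"fid": 12.3}                                  # assumed name and value
    best_val_metrics = {f"val/{k}": v for k, v in metrics.items()}
    logged = [f"best_{k}" for k in best_val_metrics]         # ["best_val/fid"]

This is also why the run-comparison hunks below query wandb summaries with the best_ prefix.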
@@ -621,10 +630,10 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
          if not runs:
              raise ValueError("No runs found in wandb.")
          print(f"Getting best runs from wandb {self.wandb.id}...")
-         runs = sorted(runs, key=lambda x: x.summary.get(metric, float('inf')))
+         runs = sorted(runs, key=lambda x: x.summary.get(f"best_{metric}", float('inf')))
          best_runs = runs[:top_k]
-         lower_bound = best_runs[-1].summary.get(metric, float('inf'))
-         upper_bound = best_runs[0].summary.get(metric, float('inf'))
+         lower_bound = best_runs[-1].summary.get(f"best_{metric}", float('inf'))
+         upper_bound = best_runs[0].summary.get(f"best_{metric}", float('inf'))
          print(f"Best runs from wandb {self.wandb.id}:")
          for run in best_runs:
              print(f"\t\tRun ID: {run.id}, Metric: {run.summary.get(metric, float('inf'))}")
@@ -648,19 +657,21 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
          best_runs, bounds = self.__get_best_general_runs__(metric=metric, top_k=top_k)

          # Determine if lower or higher values are better (for loss, lower is better)
-         is_lower_better = "loss" in metric.lower()
+         is_lower_better = True

          # Check if current run is one of the best
          if metric == "train/best_loss":
              current_run_metric = self.best_loss
          elif metric in self.best_val_metrics:
+             print(f"Fetching best validation metric {metric} from local")
              current_run_metric = self.best_val_metrics[metric]
          else:
              current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
-
+
+         print(f"Current run {self.wandb.id} metric: {current_run_metric}, Best bounds: {bounds}")
          # Check based on bounds
          if (is_lower_better and current_run_metric < bounds[1]) or (not is_lower_better and current_run_metric > bounds[0]):
-             print(f"Current run {self.wandb.id} meets performance criteria.")
+             print(f"Current run {self.wandb.id} meets performance criteria. Current metric: {current_run_metric}, Best bounds: {bounds}")
              is_best = (is_lower_better and current_run_metric < bounds[0]) or (not is_lower_better and current_run_metric > bounds[1])
              return True, is_best

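The name-based direction heuristic is dropped above: is_lower_better is now hard-coded to True, so every compared metric is assumed to improve downward. The old check misclassified lower-is-better metrics whose names lack "loss"; for example (illustrative metric name):

    is_lower_better = "loss" in "val/fid".lower()   # False: FID was treated as higher-is-better
    is_lower_better = True                          # new behavior: always lower-is-better

A genuinely higher-is-better metric would need this flag revisited.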
@@ -600,7 +600,7 @@ class SimpleTrainer:

      def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps_per_epoch=5, validation_step_args={}):
          train_ds = iter(data['train']())
-         val_ds = data.get('val', data.get('test', None))()
+         val_ds = data.get('val', data.get('test', None))
          train_step = self._define_train_step(**train_step_args)
          val_step = self._define_validation_step(**validation_step_args)
          train_state = self.state
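Correspondingly, fit() stops invoking the validation dataset factory eagerly and passes the callable through to validation_loop. A hypothetical data dict satisfying this contract; the factory names are stand-ins:

    data = {
        "train": make_train_ds,   # callable returning a fresh iterable of batches
        "val": make_val_ds,       # stored uncalled; validation_loop calls it per pass
    }
    trainer.fit(data, train_steps_per_epoch=100, epochs=10)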
@@ -642,6 +642,19 @@ class SimpleTrainer:
              self.rngstate = rng_state
              total_time = end_time - start_time
              avg_time_per_step = total_time / train_steps_per_epoch
+
+             if val_steps_per_epoch > 0:
+                 print(f"Validation started for process index {process_index}")
+                 # Validation step
+                 self.validation_loop(
+                     train_state,
+                     val_step,
+                     val_ds,
+                     val_steps_per_epoch,
+                     current_step,
+                 )
+                 print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
+
              avg_loss = epoch_loss / train_steps_per_epoch
              if avg_loss < self.best_loss:
                  self.best_loss = avg_loss
@@ -659,17 +672,6 @@ class SimpleTrainer:
                  }, step=current_step)
              print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))

-             if val_steps_per_epoch > 0:
-                 print(f"Validation started for process index {process_index}")
-                 # Validation step
-                 self.validation_loop(
-                     train_state,
-                     val_step,
-                     val_ds,
-                     val_steps_per_epoch,
-                     current_step,
-                 )
-                 print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))

          self.save(epochs)#
          return self.state
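These last two SimpleTrainer hunks relocate the per-epoch validation pass: it now runs immediately after the epoch's training steps and before the average-loss/best-loss bookkeeping and the epoch summary log, rather than after them. Schematically, the per-epoch order becomes:

    # 1. run train_steps_per_epoch training steps
    # 2. if val_steps_per_epoch > 0: validation_loop(...)
    # 3. compute avg_loss, update best_loss, log the epoch summary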
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: flaxdiff
- Version: 0.2.8
+ Version: 0.2.10
  Summary: A versatile and easy to understand Diffusion library
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
  License-Expression: MIT
@@ -2,20 +2,20 @@ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  flaxdiff/utils.py,sha256=DmlWUY1FGz4ESxIHaPQJf92CHjsdMjyDd651wFUtyNg,8838
  flaxdiff/data/__init__.py,sha256=8W5y7NyAOWtpLi8WRawk4VYeE3DMDnM3B_jKPD8BoFQ,143
  flaxdiff/data/benchmark_decord.py,sha256=x56Db1VPmziv_9KJvWdfS0O7cffsYkF5tt5WvldOKc0,13720
- flaxdiff/data/dataloaders.py,sha256=HQR0rsLNYXRPBmdOBKFCc3UfWsmSbSO_-dOQHCbu_VA,23966
- flaxdiff/data/dataset_map.py,sha256=Dz_suGz23Cy7RfWt0FDRX7Q3NTB5SAw2UNHO_-p0qiM,5098
+ flaxdiff/data/dataloaders.py,sha256=k_3YGJhiY2Wt_-7qK0Yjl4pmF2QJjX_-BlSFuXbH5-M,23628
+ flaxdiff/data/dataset_map.py,sha256=p30U23RkfgMbR8kfPBDIjrjfzDBszWQ9Q1ff2BvDYZk,5116
  flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
  flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
  flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
  flaxdiff/data/sources/av_utils.py,sha256=LCr9MJNurOaoxY-sjzkLqJS_MlX0x3gRSlKAVIglAU0,24045
  flaxdiff/data/sources/base.py,sha256=4Rm9pCtXxzoB8FO0lkDHsrX3ULoU_PNNcid978e6ir0,4610
- flaxdiff/data/sources/images.py,sha256=71TzTVbPzV-Md3-1Lk4eWfb11w6aaO01OClwK_SiCSM,14708
+ flaxdiff/data/sources/images.py,sha256=ZHBmZ2fnPN75Hc2kiog-Wcs_NZJZOiqw4WcSH5WZJHA,16572
  flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
  flaxdiff/data/sources/videos.py,sha256=NkxwEruNpAwDCM53q4WurQ802gSjQMOqjNLxYOqjoNE,9545
  flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
  flaxdiff/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  flaxdiff/inference/pipeline.py,sha256=8S30FAlXEjvrDd87H-qdD6biySQZ3cJUflU8gdmPxig,9223
- flaxdiff/inference/utils.py,sha256=MVnWl0LnC-1ILk0SsLd1YFu6igaQFR7mGhzo0jE797E,12323
+ flaxdiff/inference/utils.py,sha256=Dh0KawgvQrZxyqN_9wbsb7gUyvPRendwb-YtAU6zIBE,12606
  flaxdiff/inputs/__init__.py,sha256=ybPjQsFAf5sqRVZG1sRiOl99EnwpI-NQ8HE3y7UbXmU,7197
  flaxdiff/inputs/encoders.py,sha256=pjfbx4Rk7bLoE80MOfThZDm6YtsDncRekmn0Bmg_CwI,2963
  flaxdiff/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -31,12 +31,13 @@ flaxdiff/models/common.py,sha256=QpciwuJldvLUwyAyWBQqiPPGVI-c9qLR7h7C1YoRX7w,105
  flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
  flaxdiff/models/general.py,sha256=7xMME6KVKQY8sScyHYH4f-Kek4j1pRfplKShFXwVZd4,587
  flaxdiff/models/hilbert.py,sha256=AjlAv49dL6UAYWslMJfCMLiFqY4kTgpiUWr2nc1mk34,24823
- flaxdiff/models/simple_dit.py,sha256=Hc2jLOZCYSDm6x88m3bGYu-OKge1TukiQPSdlaO68rE,19667
- flaxdiff/models/simple_mmdit.py,sha256=RmOq6LbfDBUUEib6MSAURujxn9iHgdh77a6ntNsWI2w,36210
+ flaxdiff/models/simple_dit.py,sha256=l238MYHRTArv_pS57aY24C2PTfxeL8EmzJ24iQqdoWI,11702
+ flaxdiff/models/simple_mmdit.py,sha256=ARk0juopn2k7giln5BAUrnYD1pTFwgTJoSzrhozQ0A8,31356
  flaxdiff/models/simple_unet.py,sha256=pjeixszG_6gEY5PNFbQ7KbOyg4z5bfn4RUbINCJexOM,10758
- flaxdiff/models/simple_vit.py,sha256=QEHPyaQIYhqSYrD6eb65X70jQL-y09nRT8Yc4b5Jq6Q,15181
+ flaxdiff/models/simple_vit.py,sha256=J9s3hBF87_iVrJDBe2cs9a56N7ect6pux_f_ge07XXc,17357
  flaxdiff/models/unet_3d.py,sha256=LF0PMxBKGU-_lAMtO_Coxy1yRE02yKKdgb7i6YZxI_4,20163
  flaxdiff/models/unet_3d_blocks.py,sha256=lRYDc9X1VEu54Kg7xEEphXYiQ09tabPXKi-hEcKFYug,19687
+ flaxdiff/models/vit_common.py,sha256=1OGu4ezY3uzKinTnw3p8YkQAslHDqEbN78JheXnTleI,9831
  flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
  flaxdiff/models/autoencoder/autoencoder.py,sha256=8XWdsWvsPsyWGtzpCT8w0KXi_ZLGpRuQpn4oXo1gHKw,6039
  flaxdiff/models/autoencoder/diffusers.py,sha256=tPz77YuctrT--jF2AOL8G6vr0NiIr3RXANNrZCxe0bg,5921
@@ -62,9 +63,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
  flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
  flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
  flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
- flaxdiff/trainer/general_diffusion_trainer.py,sha256=FUvc--3ibRAjrYiKbA-FyLqKhusakxeNOa6UJZaK4SU,29307
- flaxdiff/trainer/simple_trainer.py,sha256=Hdltuo3lgF61N04Lxc7L3z6NLveW4_h1ff7_5mu3Wbg,28730
- flaxdiff-0.2.8.dist-info/METADATA,sha256=y2jLjsEkR-GKvLWuGzlyBrk1SNM6tCPT0Oc7vRZC7_I,24057
- flaxdiff-0.2.8.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
- flaxdiff-0.2.8.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
- flaxdiff-0.2.8.dist-info/RECORD,,
+ flaxdiff/trainer/general_diffusion_trainer.py,sha256=gMo0OOz8EFKGfiqZnDwhVSxtk_IUMGUvyt5TTr_Hk8g,30168
+ flaxdiff/trainer/simple_trainer.py,sha256=nXYy9tadteG8N0RovpevPPEs6oeFvbr2gVq7Zot9l78,28754
+ flaxdiff-0.2.10.dist-info/METADATA,sha256=xsqksvLSps2a9nNdvZkguWvsC07kX8A3Z26DPTq-tGI,24058
+ flaxdiff-0.2.10.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
+ flaxdiff-0.2.10.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+ flaxdiff-0.2.10.dist-info/RECORD,,