flaxdiff 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataloaders.py +11 -19
- flaxdiff/data/dataset_map.py +2 -1
- flaxdiff/data/sources/images.py +29 -14
- flaxdiff/inference/utils.py +7 -1
- flaxdiff/models/simple_dit.py +1 -202
- flaxdiff/models/simple_mmdit.py +1 -132
- flaxdiff/models/simple_vit.py +217 -118
- flaxdiff/models/vit_common.py +262 -0
- flaxdiff/trainer/general_diffusion_trainer.py +22 -11
- flaxdiff/trainer/simple_trainer.py +14 -12
- {flaxdiff-0.2.8.dist-info → flaxdiff-0.2.10.dist-info}/METADATA +1 -1
- {flaxdiff-0.2.8.dist-info → flaxdiff-0.2.10.dist-info}/RECORD +14 -13
- {flaxdiff-0.2.8.dist-info → flaxdiff-0.2.10.dist-info}/WHEEL +0 -0
- {flaxdiff-0.2.8.dist-info → flaxdiff-0.2.10.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,262 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
from flax import linen as nn
|
4
|
+
from typing import Any, Optional
|
5
|
+
import einops
|
6
|
+
from flax.typing import Dtype, PrecisionLike
|
7
|
+
|
8
|
+
from .attention import NormalAttention
|
9
|
+
|
10
|
+
def unpatchify(x, channels=3):
    """Reassemble a sequence of flattened patches into an image-like tensor.

    The patch grid is taken to be square (``h == w``) and every patch is
    taken to be square as well; both sizes are derived from the shape of
    ``x`` and validated before rearranging.

    Args:
        x: Array indexed as ``[B, h * w, p * p * C]`` (axis 1 is the patch
            index, axis 2 the flattened patch content).
        channels: Number of channels ``C`` stored in each flattened patch.

    Returns:
        Array of shape ``[B, h * p, w * p, C]``.

    Raises:
        AssertionError: If axis 1 is not a perfect square or axis 2 is not
            ``p * p * channels`` for an integer ``p``.
    """
    # Side length of one patch, recovered from the flattened feature width.
    patch_size = int((x.shape[2] // channels) ** 0.5)
    # Side length of the (square) grid of patches.
    h = w = int(x.shape[1] ** .5)

    grid_is_square = h * w == x.shape[1]
    patch_is_square = patch_size ** 2 * channels == x.shape[2]
    assert grid_is_square and patch_is_square, \
        f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}"

    return einops.rearrange(
        x,
        'B (h w) (p1 p2 C) -> B (h p1) (w p2) C',
        h=h,
        p1=patch_size,
        p2=patch_size,
    )
|
18
|
+
|
19
|
+
|
20
|
+
class PatchEmbedding(nn.Module):
    """Project an image into a sequence of patch embeddings.

    A single convolution whose kernel size equals its stride is applied, so
    each output position is computed from one non-overlapping
    ``patch_size x patch_size`` window. The spatial axes are then flattened
    into one sequence axis.
    """
    # Side length of each square patch, in pixels.
    patch_size: int
    # Number of output features per patch.
    embedding_dim: int
    # dtype forwarded to the convolution.
    dtype: Any = jnp.float32
    # Numerical precision forwarded to the convolution.
    precision: Any = jax.lax.Precision.HIGH

    @nn.compact
    def __call__(self, x):
        """Embed ``x`` of shape ``[B, H, W, C]`` into ``[B, N, embedding_dim]``.

        Raises:
            AssertionError: If ``H`` or ``W`` is not a multiple of
                ``patch_size``.
        """
        batch, height, width, _ = x.shape
        assert not (height % self.patch_size or width % self.patch_size), \
            "Image dimensions must be divisible by patch size"

        window = (self.patch_size, self.patch_size)
        projector = nn.Conv(
            features=self.embedding_dim,
            kernel_size=window,
            strides=window,
            dtype=self.dtype,
            precision=self.precision,
        )
        patches = projector(x)
        # Collapse the spatial grid into a single token axis.
        return jnp.reshape(patches, (batch, -1, self.embedding_dim))
|
38
|
+
|
39
|
+
|
40
|
+
class PositionalEncoding(nn.Module):
    """Learned additive positional encoding.

    Holds one trainable table of shape ``[1, max_len, embedding_dim]``,
    initialised to zeros, and adds its leading rows to the input.
    """
    # Number of positions stored in the table.
    max_len: int
    # Feature width of every position vector.
    embedding_dim: int

    @nn.compact
    def __call__(self, x):
        """Return ``x`` plus the first ``x.shape[1]`` rows of the table."""
        table_shape = (1, self.max_len, self.embedding_dim)
        pe = self.param('pos_encoding', jax.nn.initializers.zeros, table_shape)
        seq_len = x.shape[1]
        # The slice is clipped at max_len by normal slicing rules.
        return x + pe[:, :seq_len, :]
|
50
|
+
|
51
|
+
|
52
|
+
# --- Rotary Positional Embedding (RoPE) ---
|
53
|
+
# Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py
|
54
|
+
|
55
|
+
|
56
|
+
def _rotate_half(x: jax.Array) -> jax.Array:
    """Swap the two halves of the last axis, negating the second half.

    For a last axis laid out as ``[a, b]`` (split at ``len // 2``) the
    result is ``[-b, a]``.
    """
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return jnp.concatenate((-back, front), axis=-1)
|
61
|
+
|
62
|
+
def apply_rotary_embedding(
    x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array
) -> jax.Array:
    """Apply rotary position embedding to ``x`` using the rotate-half form.

    Computes ``x * cos + rotate_half(x) * sin`` where the cos/sin tables are
    duplicated along the last axis so they cover the full feature width.

    Args:
        x: Tensor whose last two axes are ``[sequence, dimension]``, e.g.
            ``[B, H, S, D]`` or ``[B, S, D]``. Any number of leading axes is
            accepted.
        freqs_cos: Cosine table of shape ``[sequence, dimension / 2]``.
        freqs_sin: Sine table of shape ``[sequence, dimension / 2]``.

    Returns:
        The rotated tensor, cast back to ``x.dtype``.

    Raises:
        ValueError: If ``x`` has fewer than two axes.
    """
    if x.ndim < 2:
        raise ValueError(
            f"Expected input with at least 2 dims [..., S, D], got shape {x.shape}")

    # Duplicate cos and sin for the full dimension D: [S, D/2] -> [S, D].
    cos_freqs = jnp.concatenate([freqs_cos, freqs_cos], axis=-1)
    sin_freqs = jnp.concatenate([freqs_sin, freqs_sin], axis=-1)

    # Broadcasting aligns trailing axes, so the [S, D] tables line up with
    # the last two axes of x for any number of leading axes. Previously the
    # tables were only prepared for 3-D and 4-D inputs, and any other rank
    # hit an unbound local variable.
    x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
    return x_rotated.astype(x.dtype)
|
85
|
+
|
86
|
+
def _rope_tables(dim, base, length):
    """Build the cos/sin tables, each of shape ``[length, ceil(dim / 2)]``."""
    exponents = jnp.arange(0, dim, 2, dtype=jnp.float32) / dim
    inv_freq = 1.0 / (base ** exponents)
    positions = jnp.arange(length, dtype=jnp.float32)
    angles = jnp.outer(positions, inv_freq)
    return jnp.cos(angles), jnp.sin(angles)


class RotaryEmbedding(nn.Module):
    """Provider of rotary-embedding cos/sin tables.

    Tables for ``max_seq_len`` positions are built in ``setup``; longer
    requests are computed on demand and not cached.
    """
    # Feature width being rotated; the tables hold one column per even index.
    dim: int
    # Number of positions precomputed in ``setup``.
    max_seq_len: int = 4096  # Increased default based on SimpleDiT
    # Base of the geometric frequency progression.
    base: int = 10000
    # NOTE(review): not read anywhere in this class; tables are float32.
    dtype: Dtype = jnp.float32

    def setup(self):
        """Precompute the tables for ``max_seq_len`` positions."""
        self.freqs_cos, self.freqs_sin = _rope_tables(
            self.dim, self.base, self.max_seq_len)

    def __call__(self, seq_len: int):
        """Return ``(freqs_cos, freqs_sin)`` covering ``seq_len`` positions."""
        if seq_len <= self.max_seq_len:
            return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :]
        # Request exceeds the precomputed range: rebuild at the requested
        # length instead of raising. Consider caching if this happens often.
        return _rope_tables(self.dim, self.base, seq_len)
|
118
|
+
|
119
|
+
|
120
|
+
# --- Attention with RoPE ---
|
121
|
+
|
122
|
+
|
123
|
+
class RoPEAttention(NormalAttention):
    """Attention layer that rotates queries and keys with RoPE.

    The projections (``self.query``, ``self.key``, ``self.value``,
    ``self.proj_attn``) and the ``dtype`` / ``precision`` /
    ``force_fp32_for_softmax`` settings are read from ``self``; presumably
    they are provided by ``NormalAttention`` (not shown here, confirm).
    """
    # Optional table provider, used when ``freqs_cis`` is not passed in.
    rope_emb: RotaryEmbedding = None

    @nn.compact
    def __call__(self, x, context=None, freqs_cis=None):
        """Run attention over ``x`` (optionally attending to ``context``).

        Args:
            x: ``[B, S, C]`` or ``[B, H, W, C]``; 4-D input is flattened to a
                sequence and the output is reshaped back to the input shape.
            context: Optional key/value source, 3-D or 4-D. Defaults to the
                (flattened) ``x``.
            freqs_cis: Optional ``(freqs_cos, freqs_sin)`` pair. When omitted,
                ``rope_emb`` is queried with the query sequence length.

        Raises:
            ValueError: If neither ``freqs_cis`` nor ``rope_emb`` is available.
        """
        input_shape = x.shape
        spatial_input = len(input_shape) == 4
        if spatial_input:
            B, H, W, C = input_shape
            x = x.reshape((B, H * W, C))
        else:
            B, _, C = input_shape

        if context is None:
            context = x
        if len(context.shape) == 4:
            _, ctx_h, ctx_w, ctx_c = context.shape
            # Uses the batch size of x, as in the flattening of x above.
            context = context.reshape((B, ctx_h * ctx_w, ctx_c))

        query = self.query(x)        # [B, S, H, D]
        key = self.key(context)      # [B, S_ctx, H, D]
        value = self.value(context)  # [B, S_ctx, H, D]

        if freqs_cis is not None:
            freqs_cos, freqs_sin = freqs_cis
        elif self.rope_emb is not None:
            # Tables are sized by the query's sequence length.
            freqs_cos, freqs_sin = self.rope_emb(query.shape[1])
        else:
            # Should not happen if rope_emb is provided or freqs_cis are passed
            raise ValueError("RoPE frequencies not provided.")

        def rotate(heads):
            # RoPE is applied in [B, H, S, D] layout, then restored to
            # [B, S, H, D] for dot_product_attention.
            heads = einops.rearrange(heads, 'b s h d -> b h s d')
            heads = apply_rotary_embedding(heads, freqs_cos, freqs_sin)
            return einops.rearrange(heads, 'b h s d -> b s h d')

        # The same tables are used for queries and keys, which assumes the
        # context has the same sequence length as the query.
        query = rotate(query)
        key = rotate(key)

        hidden_states = nn.dot_product_attention(
            query,
            key,
            value,
            dtype=self.dtype,
            broadcast_dropout=False,
            dropout_rng=None,
            precision=self.precision,
            force_fp32_for_softmax=self.force_fp32_for_softmax,
            deterministic=True,
        )

        proj = self.proj_attn(hidden_states)
        return proj.reshape(input_shape) if spatial_input else proj
|
184
|
+
|
185
|
+
|
186
|
+
# --- adaLN-Zero ---
|
187
|
+
|
188
|
+
|
189
|
+
class AdaLNZero(nn.Module):
    """Adaptive LayerNorm with a zero-initialised conditioning projection.

    One dense layer (kernel initialised to zeros) maps the conditioning
    signal to six chunks of width ``features``: scale, shift and gate for
    the MLP path, then scale, shift and gate for the attention path. The
    input is layer-normalised without learned scale or bias and modulated
    separately for each path.
    """
    # Width of each of the six modulation chunks.
    features: int
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None
    norm_epsilon: float = 1e-5  # Standard LayerNorm epsilon

    @nn.compact
    def __call__(self, x, conditioning):
        """Return ``(x_attn, gate_attn, x_mlp, gate_mlp)``.

        Args:
            x: Features to normalise and modulate, e.g. ``[B, S, F]``.
            conditioning: Conditioning signal; a ``[B, D_cond]`` input gets a
                sequence axis inserted when ``x`` is 3-D so the result
                broadcasts over ``S``.
        """
        if x.ndim == 3 and conditioning.ndim == 2:
            # [B, D_cond] -> [B, 1, D_cond]
            conditioning = conditioning[:, None, :]

        projection = nn.Dense(
            features=6 * self.features,
            dtype=self.dtype,
            precision=self.precision,
            # Zero kernel init is the "Zero" in adaLN-Zero.
            kernel_init=nn.initializers.zeros,
            name="ada_proj",
        )
        chunks = jnp.split(projection(conditioning), 6, axis=-1)
        scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = chunks

        # NOTE(review): only the MLP scale/shift are clipped; the attention
        # scale/shift are used as-is.
        scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
        shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)

        layer_norm = nn.LayerNorm(
            epsilon=self.norm_epsilon,
            use_scale=False,
            use_bias=False,
            dtype=self.dtype,
        )
        norm_x = layer_norm(x)

        # Each path applies its own scale and shift to the shared norm_x.
        x_attn = norm_x * (1 + scale_attn) + shift_attn
        x_mlp = norm_x * (1 + scale_mlp) + shift_mlp

        return x_attn, gate_attn, x_mlp, gate_mlp
|
239
|
+
|
240
|
+
class AdaLNParams(nn.Module):
    """Project a conditioning signal to ``6 * features`` modulation values.

    Only the zero-initialised dense projection is applied; the result is
    returned unsplit.
    """
    # Width of each of the six parameter groups packed in the output.
    features: int
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None

    @nn.compact
    def __call__(self, conditioning):
        """Return the packed parameters, ``[B, 1, 6 * F]`` for 2-D input."""
        if conditioning.ndim == 2:
            # [B, D_cond] -> [B, 1, D_cond] so the output broadcasts over a
            # sequence axis.
            conditioning = conditioning[:, None, :]

        projection = nn.Dense(
            features=6 * self.features,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=nn.initializers.zeros,
            name="ada_proj",
        )
        # Callers can unpack with jnp.split(ada_params, 6, axis=-1).
        ada_params = projection(conditioning)
        return ada_params
|
262
|
+
|
@@ -427,7 +427,8 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
427
427
|
process_index = jax.process_index()
|
428
428
|
generate_samples = val_step_fn
|
429
429
|
|
430
|
-
val_ds = iter(val_ds) if val_ds else None
|
430
|
+
val_ds = iter(val_ds()) if val_ds else None
|
431
|
+
print(f"Validation loop started for process index {process_index} with {global_device_count} devices.")
|
431
432
|
# Evaluation step
|
432
433
|
try:
|
433
434
|
metrics = {metric.name: [] for metric in self.eval_metrics} if self.eval_metrics else {}
|
@@ -474,10 +475,11 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
474
475
|
metrics = {k: np.mean(v) for k, v in metrics.items()}
|
475
476
|
# Update the best validation metrics
|
476
477
|
for key, value in metrics.items():
|
477
|
-
|
478
|
-
|
478
|
+
final_key = f"val/{key}"
|
479
|
+
if final_key not in self.best_val_metrics:
|
480
|
+
self.best_val_metrics[final_key] = value
|
479
481
|
else:
|
480
|
-
self.best_val_metrics[
|
482
|
+
self.best_val_metrics[final_key] = min(self.best_val_metrics[final_key], value)
|
481
483
|
# Log the best validation metrics
|
482
484
|
if getattr(self, 'wandb', None) is not None and self.wandb:
|
483
485
|
# Log the metrics
|
@@ -487,7 +489,14 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
487
489
|
self.wandb.log({
|
488
490
|
f"val/{key}": value,
|
489
491
|
}, step=current_step)
|
490
|
-
|
492
|
+
# Log the best validation metrics
|
493
|
+
for key, value in self.best_val_metrics.items():
|
494
|
+
if isinstance(value, jnp.ndarray):
|
495
|
+
value = np.array(value)
|
496
|
+
self.wandb.log({
|
497
|
+
f"best_{key}": value,
|
498
|
+
}, step=current_step)
|
499
|
+
print(f"Validation metrics for process index {process_index}: {metrics}")
|
491
500
|
|
492
501
|
# Close validation dataset iterator
|
493
502
|
del val_ds
|
@@ -621,10 +630,10 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
621
630
|
if not runs:
|
622
631
|
raise ValueError("No runs found in wandb.")
|
623
632
|
print(f"Getting best runs from wandb {self.wandb.id}...")
|
624
|
-
runs = sorted(runs, key=lambda x: x.summary.get(metric, float('inf')))
|
633
|
+
runs = sorted(runs, key=lambda x: x.summary.get(f"best_{metric}", float('inf')))
|
625
634
|
best_runs = runs[:top_k]
|
626
|
-
lower_bound = best_runs[-1].summary.get(metric, float('inf'))
|
627
|
-
upper_bound = best_runs[0].summary.get(metric, float('inf'))
|
635
|
+
lower_bound = best_runs[-1].summary.get(f"best_{metric}", float('inf'))
|
636
|
+
upper_bound = best_runs[0].summary.get(f"best_{metric}", float('inf'))
|
628
637
|
print(f"Best runs from wandb {self.wandb.id}:")
|
629
638
|
for run in best_runs:
|
630
639
|
print(f"\t\tRun ID: {run.id}, Metric: {run.summary.get(metric, float('inf'))}")
|
@@ -648,19 +657,21 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
648
657
|
best_runs, bounds = self.__get_best_general_runs__(metric=metric, top_k=top_k)
|
649
658
|
|
650
659
|
# Determine if lower or higher values are better (for loss, lower is better)
|
651
|
-
is_lower_better =
|
660
|
+
is_lower_better = True
|
652
661
|
|
653
662
|
# Check if current run is one of the best
|
654
663
|
if metric == "train/best_loss":
|
655
664
|
current_run_metric = self.best_loss
|
656
665
|
elif metric in self.best_val_metrics:
|
666
|
+
print(f"Fetching best validation metric {metric} from local")
|
657
667
|
current_run_metric = self.best_val_metrics[metric]
|
658
668
|
else:
|
659
669
|
current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
|
660
|
-
|
670
|
+
|
671
|
+
print(f"Current run {self.wandb.id} metric: {current_run_metric}, Best bounds: {bounds}")
|
661
672
|
# Check based on bounds
|
662
673
|
if (is_lower_better and current_run_metric < bounds[1]) or (not is_lower_better and current_run_metric > bounds[0]):
|
663
|
-
print(f"Current run {self.wandb.id} meets performance criteria.")
|
674
|
+
print(f"Current run {self.wandb.id} meets performance criteria. Current metric: {current_run_metric}, Best bounds: {bounds}")
|
664
675
|
is_best = (is_lower_better and current_run_metric < bounds[0]) or (not is_lower_better and current_run_metric > bounds[1])
|
665
676
|
return True, is_best
|
666
677
|
|
@@ -600,7 +600,7 @@ class SimpleTrainer:
|
|
600
600
|
|
601
601
|
def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps_per_epoch=5, validation_step_args={}):
|
602
602
|
train_ds = iter(data['train']())
|
603
|
-
val_ds = data.get('val', data.get('test', None))
|
603
|
+
val_ds = data.get('val', data.get('test', None))
|
604
604
|
train_step = self._define_train_step(**train_step_args)
|
605
605
|
val_step = self._define_validation_step(**validation_step_args)
|
606
606
|
train_state = self.state
|
@@ -642,6 +642,19 @@ class SimpleTrainer:
|
|
642
642
|
self.rngstate = rng_state
|
643
643
|
total_time = end_time - start_time
|
644
644
|
avg_time_per_step = total_time / train_steps_per_epoch
|
645
|
+
|
646
|
+
if val_steps_per_epoch > 0:
|
647
|
+
print(f"Validation started for process index {process_index}")
|
648
|
+
# Validation step
|
649
|
+
self.validation_loop(
|
650
|
+
train_state,
|
651
|
+
val_step,
|
652
|
+
val_ds,
|
653
|
+
val_steps_per_epoch,
|
654
|
+
current_step,
|
655
|
+
)
|
656
|
+
print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
|
657
|
+
|
645
658
|
avg_loss = epoch_loss / train_steps_per_epoch
|
646
659
|
if avg_loss < self.best_loss:
|
647
660
|
self.best_loss = avg_loss
|
@@ -659,17 +672,6 @@ class SimpleTrainer:
|
|
659
672
|
}, step=current_step)
|
660
673
|
print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
|
661
674
|
|
662
|
-
if val_steps_per_epoch > 0:
|
663
|
-
print(f"Validation started for process index {process_index}")
|
664
|
-
# Validation step
|
665
|
-
self.validation_loop(
|
666
|
-
train_state,
|
667
|
-
val_step,
|
668
|
-
val_ds,
|
669
|
-
val_steps_per_epoch,
|
670
|
-
current_step,
|
671
|
-
)
|
672
|
-
print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
|
673
675
|
|
674
676
|
self.save(epochs)#
|
675
677
|
return self.state
|
@@ -2,20 +2,20 @@ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
flaxdiff/utils.py,sha256=DmlWUY1FGz4ESxIHaPQJf92CHjsdMjyDd651wFUtyNg,8838
|
3
3
|
flaxdiff/data/__init__.py,sha256=8W5y7NyAOWtpLi8WRawk4VYeE3DMDnM3B_jKPD8BoFQ,143
|
4
4
|
flaxdiff/data/benchmark_decord.py,sha256=x56Db1VPmziv_9KJvWdfS0O7cffsYkF5tt5WvldOKc0,13720
|
5
|
-
flaxdiff/data/dataloaders.py,sha256=
|
6
|
-
flaxdiff/data/dataset_map.py,sha256=
|
5
|
+
flaxdiff/data/dataloaders.py,sha256=k_3YGJhiY2Wt_-7qK0Yjl4pmF2QJjX_-BlSFuXbH5-M,23628
|
6
|
+
flaxdiff/data/dataset_map.py,sha256=p30U23RkfgMbR8kfPBDIjrjfzDBszWQ9Q1ff2BvDYZk,5116
|
7
7
|
flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
|
8
8
|
flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
|
9
9
|
flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
|
10
10
|
flaxdiff/data/sources/av_utils.py,sha256=LCr9MJNurOaoxY-sjzkLqJS_MlX0x3gRSlKAVIglAU0,24045
|
11
11
|
flaxdiff/data/sources/base.py,sha256=4Rm9pCtXxzoB8FO0lkDHsrX3ULoU_PNNcid978e6ir0,4610
|
12
|
-
flaxdiff/data/sources/images.py,sha256=
|
12
|
+
flaxdiff/data/sources/images.py,sha256=ZHBmZ2fnPN75Hc2kiog-Wcs_NZJZOiqw4WcSH5WZJHA,16572
|
13
13
|
flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
|
14
14
|
flaxdiff/data/sources/videos.py,sha256=NkxwEruNpAwDCM53q4WurQ802gSjQMOqjNLxYOqjoNE,9545
|
15
15
|
flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
|
16
16
|
flaxdiff/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
17
|
flaxdiff/inference/pipeline.py,sha256=8S30FAlXEjvrDd87H-qdD6biySQZ3cJUflU8gdmPxig,9223
|
18
|
-
flaxdiff/inference/utils.py,sha256=
|
18
|
+
flaxdiff/inference/utils.py,sha256=Dh0KawgvQrZxyqN_9wbsb7gUyvPRendwb-YtAU6zIBE,12606
|
19
19
|
flaxdiff/inputs/__init__.py,sha256=ybPjQsFAf5sqRVZG1sRiOl99EnwpI-NQ8HE3y7UbXmU,7197
|
20
20
|
flaxdiff/inputs/encoders.py,sha256=pjfbx4Rk7bLoE80MOfThZDm6YtsDncRekmn0Bmg_CwI,2963
|
21
21
|
flaxdiff/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -31,12 +31,13 @@ flaxdiff/models/common.py,sha256=QpciwuJldvLUwyAyWBQqiPPGVI-c9qLR7h7C1YoRX7w,105
|
|
31
31
|
flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
|
32
32
|
flaxdiff/models/general.py,sha256=7xMME6KVKQY8sScyHYH4f-Kek4j1pRfplKShFXwVZd4,587
|
33
33
|
flaxdiff/models/hilbert.py,sha256=AjlAv49dL6UAYWslMJfCMLiFqY4kTgpiUWr2nc1mk34,24823
|
34
|
-
flaxdiff/models/simple_dit.py,sha256=
|
35
|
-
flaxdiff/models/simple_mmdit.py,sha256=
|
34
|
+
flaxdiff/models/simple_dit.py,sha256=l238MYHRTArv_pS57aY24C2PTfxeL8EmzJ24iQqdoWI,11702
|
35
|
+
flaxdiff/models/simple_mmdit.py,sha256=ARk0juopn2k7giln5BAUrnYD1pTFwgTJoSzrhozQ0A8,31356
|
36
36
|
flaxdiff/models/simple_unet.py,sha256=pjeixszG_6gEY5PNFbQ7KbOyg4z5bfn4RUbINCJexOM,10758
|
37
|
-
flaxdiff/models/simple_vit.py,sha256=
|
37
|
+
flaxdiff/models/simple_vit.py,sha256=J9s3hBF87_iVrJDBe2cs9a56N7ect6pux_f_ge07XXc,17357
|
38
38
|
flaxdiff/models/unet_3d.py,sha256=LF0PMxBKGU-_lAMtO_Coxy1yRE02yKKdgb7i6YZxI_4,20163
|
39
39
|
flaxdiff/models/unet_3d_blocks.py,sha256=lRYDc9X1VEu54Kg7xEEphXYiQ09tabPXKi-hEcKFYug,19687
|
40
|
+
flaxdiff/models/vit_common.py,sha256=1OGu4ezY3uzKinTnw3p8YkQAslHDqEbN78JheXnTleI,9831
|
40
41
|
flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
|
41
42
|
flaxdiff/models/autoencoder/autoencoder.py,sha256=8XWdsWvsPsyWGtzpCT8w0KXi_ZLGpRuQpn4oXo1gHKw,6039
|
42
43
|
flaxdiff/models/autoencoder/diffusers.py,sha256=tPz77YuctrT--jF2AOL8G6vr0NiIr3RXANNrZCxe0bg,5921
|
@@ -62,9 +63,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
|
|
62
63
|
flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
|
63
64
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
|
64
65
|
flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
|
65
|
-
flaxdiff/trainer/general_diffusion_trainer.py,sha256=
|
66
|
-
flaxdiff/trainer/simple_trainer.py,sha256=
|
67
|
-
flaxdiff-0.2.
|
68
|
-
flaxdiff-0.2.
|
69
|
-
flaxdiff-0.2.
|
70
|
-
flaxdiff-0.2.
|
66
|
+
flaxdiff/trainer/general_diffusion_trainer.py,sha256=gMo0OOz8EFKGfiqZnDwhVSxtk_IUMGUvyt5TTr_Hk8g,30168
|
67
|
+
flaxdiff/trainer/simple_trainer.py,sha256=nXYy9tadteG8N0RovpevPPEs6oeFvbr2gVq7Zot9l78,28754
|
68
|
+
flaxdiff-0.2.10.dist-info/METADATA,sha256=xsqksvLSps2a9nNdvZkguWvsC07kX8A3Z26DPTq-tGI,24058
|
69
|
+
flaxdiff-0.2.10.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
|
70
|
+
flaxdiff-0.2.10.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
71
|
+
flaxdiff-0.2.10.dist-info/RECORD,,
|
File without changes
|
File without changes
|