flaxdiff 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataloaders.py +23 -19
- flaxdiff/data/dataset_map.py +2 -1
- flaxdiff/data/sources/base.py +12 -0
- flaxdiff/data/sources/images.py +75 -3
- flaxdiff/data/sources/videos.py +5 -0
- flaxdiff/inference/utils.py +7 -1
- flaxdiff/models/common.py +1 -70
- flaxdiff/models/hilbert.py +617 -0
- flaxdiff/models/simple_dit.py +275 -0
- flaxdiff/models/simple_mmdit.py +730 -0
- flaxdiff/models/simple_vit.py +405 -145
- flaxdiff/models/vit_common.py +262 -0
- flaxdiff/trainer/general_diffusion_trainer.py +30 -10
- flaxdiff/trainer/simple_trainer.py +113 -19
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.9.dist-info}/METADATA +1 -1
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.9.dist-info}/RECORD +18 -15
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.9.dist-info}/WHEEL +1 -1
- flaxdiff/models/better_uvit.py +0 -380
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.9.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,262 @@
|
|
1
|
+
import jax
|
2
|
+
import jax.numpy as jnp
|
3
|
+
from flax import linen as nn
|
4
|
+
from typing import Any, Optional
|
5
|
+
import einops
|
6
|
+
from flax.typing import Dtype, PrecisionLike
|
7
|
+
|
8
|
+
from .attention import NormalAttention
|
9
|
+
|
10
|
+
def unpatchify(x, channels=3):
|
11
|
+
patch_size = int((x.shape[2] // channels) ** 0.5)
|
12
|
+
h = w = int(x.shape[1] ** .5)
|
13
|
+
assert h * w == x.shape[1] and patch_size ** 2 * \
|
14
|
+
channels == x.shape[2], f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}"
|
15
|
+
x = einops.rearrange(
|
16
|
+
x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)
|
17
|
+
return x
|
18
|
+
|
19
|
+
|
20
|
+
class PatchEmbedding(nn.Module):
|
21
|
+
patch_size: int
|
22
|
+
embedding_dim: int
|
23
|
+
dtype: Any = jnp.float32
|
24
|
+
precision: Any = jax.lax.Precision.HIGH
|
25
|
+
|
26
|
+
@nn.compact
|
27
|
+
def __call__(self, x):
|
28
|
+
batch, height, width, channels = x.shape
|
29
|
+
assert height % self.patch_size == 0 and width % self.patch_size == 0, "Image dimensions must be divisible by patch size"
|
30
|
+
|
31
|
+
x = nn.Conv(features=self.embedding_dim,
|
32
|
+
kernel_size=(self.patch_size, self.patch_size),
|
33
|
+
strides=(self.patch_size, self.patch_size),
|
34
|
+
dtype=self.dtype,
|
35
|
+
precision=self.precision)(x)
|
36
|
+
x = jnp.reshape(x, (batch, -1, self.embedding_dim))
|
37
|
+
return x
|
38
|
+
|
39
|
+
|
40
|
+
class PositionalEncoding(nn.Module):
|
41
|
+
max_len: int
|
42
|
+
embedding_dim: int
|
43
|
+
|
44
|
+
@nn.compact
|
45
|
+
def __call__(self, x):
|
46
|
+
pe = self.param('pos_encoding',
|
47
|
+
jax.nn.initializers.zeros,
|
48
|
+
(1, self.max_len, self.embedding_dim))
|
49
|
+
return x + pe[:, :x.shape[1], :]
|
50
|
+
|
51
|
+
|
52
|
+
# --- Rotary Positional Embedding (RoPE) ---
|
53
|
+
# Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py
|
54
|
+
|
55
|
+
|
56
|
+
def _rotate_half(x: jax.Array) -> jax.Array:
|
57
|
+
"""Rotates half the hidden dims of the input."""
|
58
|
+
x1 = x[..., : x.shape[-1] // 2]
|
59
|
+
x2 = x[..., x.shape[-1] // 2:]
|
60
|
+
return jnp.concatenate((-x2, x1), axis=-1)
|
61
|
+
|
62
|
+
def apply_rotary_embedding(
|
63
|
+
x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array
|
64
|
+
) -> jax.Array:
|
65
|
+
"""Applies rotary embedding to the input tensor using rotate_half method."""
|
66
|
+
# x shape: [..., Sequence, Dimension] e.g. [B, H, S, D] or [B, S, D]
|
67
|
+
# freqs_cos/sin shape: [Sequence, Dimension / 2]
|
68
|
+
|
69
|
+
# Expand dims for broadcasting: [1, 1, S, D/2] or [1, S, D/2]
|
70
|
+
if x.ndim == 4: # [B, H, S, D]
|
71
|
+
cos_freqs = jnp.expand_dims(freqs_cos, axis=(0, 1))
|
72
|
+
sin_freqs = jnp.expand_dims(freqs_sin, axis=(0, 1))
|
73
|
+
elif x.ndim == 3: # [B, S, D]
|
74
|
+
cos_freqs = jnp.expand_dims(freqs_cos, axis=0)
|
75
|
+
sin_freqs = jnp.expand_dims(freqs_sin, axis=0)
|
76
|
+
|
77
|
+
# Duplicate cos and sin for the full dimension D
|
78
|
+
# Shape becomes [..., S, D]
|
79
|
+
cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
|
80
|
+
sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
|
81
|
+
|
82
|
+
# Apply rotation: x * cos + rotate_half(x) * sin
|
83
|
+
x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
|
84
|
+
return x_rotated.astype(x.dtype)
|
85
|
+
|
86
|
+
class RotaryEmbedding(nn.Module):
|
87
|
+
dim: int
|
88
|
+
max_seq_len: int = 4096 # Increased default based on SimpleDiT
|
89
|
+
base: int = 10000
|
90
|
+
dtype: Dtype = jnp.float32
|
91
|
+
|
92
|
+
def setup(self):
|
93
|
+
inv_freq = 1.0 / (
|
94
|
+
self.base ** (jnp.arange(0, self.dim, 2,
|
95
|
+
dtype=jnp.float32) / self.dim)
|
96
|
+
)
|
97
|
+
t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
|
98
|
+
freqs = jnp.outer(t, inv_freq)
|
99
|
+
self.freqs_cos = jnp.cos(freqs)
|
100
|
+
self.freqs_sin = jnp.sin(freqs)
|
101
|
+
|
102
|
+
def __call__(self, seq_len: int):
|
103
|
+
if seq_len > self.max_seq_len:
|
104
|
+
# Dynamically extend frequencies if needed (more robust)
|
105
|
+
t = jnp.arange(seq_len, dtype=jnp.float32)
|
106
|
+
inv_freq = 1.0 / (
|
107
|
+
self.base ** (jnp.arange(0, self.dim, 2,
|
108
|
+
dtype=jnp.float32) / self.dim)
|
109
|
+
)
|
110
|
+
freqs = jnp.outer(t, inv_freq)
|
111
|
+
freqs_cos = jnp.cos(freqs)
|
112
|
+
freqs_sin = jnp.sin(freqs)
|
113
|
+
# Consider caching extended freqs if this happens often
|
114
|
+
return freqs_cos, freqs_sin
|
115
|
+
# Or raise error like before:
|
116
|
+
# raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
|
117
|
+
return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :]
|
118
|
+
|
119
|
+
|
120
|
+
# --- Attention with RoPE ---
|
121
|
+
|
122
|
+
|
123
|
+
class RoPEAttention(NormalAttention):
|
124
|
+
rope_emb: RotaryEmbedding = None
|
125
|
+
|
126
|
+
@nn.compact
|
127
|
+
def __call__(self, x, context=None, freqs_cis=None):
|
128
|
+
orig_x_shape = x.shape
|
129
|
+
is_4d = len(orig_x_shape) == 4
|
130
|
+
if is_4d:
|
131
|
+
B, H, W, C = x.shape
|
132
|
+
seq_len = H * W
|
133
|
+
x = x.reshape((B, seq_len, C))
|
134
|
+
else:
|
135
|
+
B, seq_len, C = x.shape
|
136
|
+
|
137
|
+
context = x if context is None else context
|
138
|
+
if len(context.shape) == 4:
|
139
|
+
_B, _H, _W, _C = context.shape
|
140
|
+
context_seq_len = _H * _W
|
141
|
+
context = context.reshape((B, context_seq_len, _C))
|
142
|
+
# else: # context is already [B, S_ctx, C]
|
143
|
+
|
144
|
+
query = self.query(x) # [B, S, H, D]
|
145
|
+
key = self.key(context) # [B, S_ctx, H, D]
|
146
|
+
value = self.value(context) # [B, S_ctx, H, D]
|
147
|
+
|
148
|
+
if freqs_cis is None and self.rope_emb is not None:
|
149
|
+
seq_len_q = query.shape[1] # Use query's sequence length
|
150
|
+
freqs_cos, freqs_sin = self.rope_emb(seq_len_q)
|
151
|
+
elif freqs_cis is not None:
|
152
|
+
freqs_cos, freqs_sin = freqs_cis
|
153
|
+
else:
|
154
|
+
# Should not happen if rope_emb is provided or freqs_cis are passed
|
155
|
+
raise ValueError("RoPE frequencies not provided.")
|
156
|
+
|
157
|
+
# Apply RoPE to query and key
|
158
|
+
# Permute to [B, H, S, D] for RoPE application
|
159
|
+
query = einops.rearrange(query, 'b s h d -> b h s d')
|
160
|
+
key = einops.rearrange(key, 'b s h d -> b h s d')
|
161
|
+
|
162
|
+
# Apply RoPE only up to the context sequence length for keys if different
|
163
|
+
# Assuming self-attention or context has same seq len for simplicity here
|
164
|
+
query = apply_rotary_embedding(query, freqs_cos, freqs_sin)
|
165
|
+
key = apply_rotary_embedding(
|
166
|
+
key, freqs_cos, freqs_sin) # Apply same freqs to key
|
167
|
+
|
168
|
+
# Permute back to [B, S, H, D] for dot_product_attention
|
169
|
+
query = einops.rearrange(query, 'b h s d -> b s h d')
|
170
|
+
key = einops.rearrange(key, 'b h s d -> b s h d')
|
171
|
+
|
172
|
+
hidden_states = nn.dot_product_attention(
|
173
|
+
query, key, value, dtype=self.dtype, broadcast_dropout=False,
|
174
|
+
dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
|
175
|
+
deterministic=True
|
176
|
+
)
|
177
|
+
|
178
|
+
proj = self.proj_attn(hidden_states)
|
179
|
+
|
180
|
+
if is_4d:
|
181
|
+
proj = proj.reshape(orig_x_shape)
|
182
|
+
|
183
|
+
return proj
|
184
|
+
|
185
|
+
|
186
|
+
# --- adaLN-Zero ---
|
187
|
+
|
188
|
+
|
189
|
+
class AdaLNZero(nn.Module):
|
190
|
+
features: int
|
191
|
+
dtype: Optional[Dtype] = None
|
192
|
+
precision: PrecisionLike = None
|
193
|
+
norm_epsilon: float = 1e-5 # Standard LayerNorm epsilon
|
194
|
+
|
195
|
+
@nn.compact
|
196
|
+
def __call__(self, x, conditioning):
|
197
|
+
# Project conditioning signal to get scale and shift parameters
|
198
|
+
# Conditioning shape: [B, D_cond] -> [B, 1, ..., 1, 6 * features] for broadcasting
|
199
|
+
# Or [B, 1, 6*features] if x is [B, S, F]
|
200
|
+
|
201
|
+
# Ensure conditioning has seq dim if x does
|
202
|
+
# x=[B,S,F], cond=[B,D_cond]
|
203
|
+
if x.ndim == 3 and conditioning.ndim == 2:
|
204
|
+
conditioning = jnp.expand_dims(
|
205
|
+
conditioning, axis=1) # cond=[B,1,D_cond]
|
206
|
+
|
207
|
+
# Project conditioning to get 6 params per feature (scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn)
|
208
|
+
# Using nn.DenseGeneral for flexibility if needed, but nn.Dense is fine if cond is [B, D_cond] or [B, 1, D_cond]
|
209
|
+
ada_params = nn.Dense(
|
210
|
+
features=6 * self.features,
|
211
|
+
dtype=self.dtype,
|
212
|
+
precision=self.precision,
|
213
|
+
# Initialize projection to zero (Zero init)
|
214
|
+
kernel_init=nn.initializers.zeros,
|
215
|
+
name="ada_proj"
|
216
|
+
)(conditioning)
|
217
|
+
|
218
|
+
# Split into scale, shift, gate for MLP and Attention
|
219
|
+
scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
|
220
|
+
ada_params, 6, axis=-1)
|
221
|
+
|
222
|
+
scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
|
223
|
+
shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)
|
224
|
+
# Apply Layer Normalization
|
225
|
+
norm = nn.LayerNorm(epsilon=self.norm_epsilon,
|
226
|
+
use_scale=False, use_bias=False, dtype=self.dtype)
|
227
|
+
# norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype) # Alternative: RMSNorm
|
228
|
+
|
229
|
+
norm_x = norm(x)
|
230
|
+
|
231
|
+
# Modulate for Attention path
|
232
|
+
x_attn = norm_x * (1 + scale_attn) + shift_attn
|
233
|
+
|
234
|
+
# Modulate for MLP path
|
235
|
+
x_mlp = norm_x * (1 + scale_mlp) + shift_mlp
|
236
|
+
|
237
|
+
# Return modulated outputs and gates
|
238
|
+
return x_attn, gate_attn, x_mlp, gate_mlp
|
239
|
+
|
240
|
+
class AdaLNParams(nn.Module): # Renamed for clarity
|
241
|
+
features: int
|
242
|
+
dtype: Optional[Dtype] = None
|
243
|
+
precision: PrecisionLike = None
|
244
|
+
|
245
|
+
@nn.compact
|
246
|
+
def __call__(self, conditioning):
|
247
|
+
# Ensure conditioning is broadcastable if needed (e.g., [B, 1, D_cond])
|
248
|
+
if conditioning.ndim == 2:
|
249
|
+
conditioning = jnp.expand_dims(conditioning, axis=1)
|
250
|
+
|
251
|
+
# Project conditioning to get 6 params per feature
|
252
|
+
ada_params = nn.Dense(
|
253
|
+
features=6 * self.features,
|
254
|
+
dtype=self.dtype,
|
255
|
+
precision=self.precision,
|
256
|
+
kernel_init=nn.initializers.zeros,
|
257
|
+
name="ada_proj"
|
258
|
+
)(conditioning)
|
259
|
+
# Return all params (or split if preferred, but maybe return tuple/dict)
|
260
|
+
# Shape: [B, 1, 6*F]
|
261
|
+
return ada_params # Or split and return tuple: jnp.split(ada_params, 6, axis=-1)
|
262
|
+
|
@@ -129,6 +129,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
129
129
|
frames_per_sample: int = None,
|
130
130
|
wandb_config: Dict[str, Any] = None,
|
131
131
|
eval_metrics: List[EvaluationMetric] = None,
|
132
|
+
best_tracker_metric: str = "train/best_loss",
|
132
133
|
**kwargs
|
133
134
|
):
|
134
135
|
"""
|
@@ -196,6 +197,8 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
196
197
|
**kwargs
|
197
198
|
)
|
198
199
|
|
200
|
+
self.best_tracker_metric = best_tracker_metric
|
201
|
+
|
199
202
|
# Store video-specific parameters
|
200
203
|
self.frames_per_sample = frames_per_sample
|
201
204
|
|
@@ -203,6 +206,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
203
206
|
self.conditional_inputs = input_config.conditions
|
204
207
|
# Determine if we're working with video or images
|
205
208
|
self.is_video = self._is_video_data()
|
209
|
+
self.best_val_metrics = {}
|
206
210
|
|
207
211
|
def _is_video_data(self):
|
208
212
|
sample_data_shape = self.input_config.sample_data_shape
|
@@ -423,7 +427,8 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
423
427
|
process_index = jax.process_index()
|
424
428
|
generate_samples = val_step_fn
|
425
429
|
|
426
|
-
val_ds = iter(val_ds
|
430
|
+
val_ds = iter(val_ds) if val_ds else None
|
431
|
+
print(f"Validation loop started for process index {process_index} with {global_device_count} devices.")
|
427
432
|
# Evaluation step
|
428
433
|
try:
|
429
434
|
metrics = {metric.name: [] for metric in self.eval_metrics} if self.eval_metrics else {}
|
@@ -465,11 +470,17 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
465
470
|
else: # [B,H,W,C] - Image data
|
466
471
|
self._log_image_samples(samples, current_step)
|
467
472
|
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
+
# Flatten the metrics
|
474
|
+
if metrics:
|
475
|
+
metrics = {k: np.mean(v) for k, v in metrics.items()}
|
476
|
+
# Update the best validation metrics
|
477
|
+
for key, value in metrics.items():
|
478
|
+
if key not in self.best_val_metrics:
|
479
|
+
self.best_val_metrics[key] = value
|
480
|
+
else:
|
481
|
+
self.best_val_metrics[key] = min(self.best_val_metrics[key], value)
|
482
|
+
# Log the best validation metrics
|
483
|
+
if getattr(self, 'wandb', None) is not None and self.wandb:
|
473
484
|
# Log the metrics
|
474
485
|
for key, value in metrics.items():
|
475
486
|
if isinstance(value, jnp.ndarray):
|
@@ -477,7 +488,10 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
477
488
|
self.wandb.log({
|
478
489
|
f"val/{key}": value,
|
479
490
|
}, step=current_step)
|
480
|
-
|
491
|
+
print(f"Validation metrics for process index {process_index}: {metrics}")
|
492
|
+
|
493
|
+
# Close validation dataset iterator
|
494
|
+
del val_ds
|
481
495
|
except StopIteration:
|
482
496
|
print(f"Validation dataset exhausted for process index {process_index}")
|
483
497
|
except Exception as e:
|
@@ -602,9 +616,13 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
602
616
|
"""
|
603
617
|
if self.wandb is None:
|
604
618
|
raise ValueError("Wandb is not initialized. Cannot get best runs.")
|
605
|
-
|
619
|
+
import wandb
|
606
620
|
# Get the sweep runs
|
607
|
-
runs =
|
621
|
+
runs = [i for i in wandb.Api().runs(path=f"{self.wandb.entity}/{self.wandb.project}", filters={"config.dataset.name": self.wandb.config['dataset']['name']})]
|
622
|
+
if not runs:
|
623
|
+
raise ValueError("No runs found in wandb.")
|
624
|
+
print(f"Getting best runs from wandb {self.wandb.id}...")
|
625
|
+
runs = sorted(runs, key=lambda x: x.summary.get(metric, float('inf')))
|
608
626
|
best_runs = runs[:top_k]
|
609
627
|
lower_bound = best_runs[-1].summary.get(metric, float('inf'))
|
610
628
|
upper_bound = best_runs[0].summary.get(metric, float('inf'))
|
@@ -636,6 +654,8 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
636
654
|
# Check if current run is one of the best
|
637
655
|
if metric == "train/best_loss":
|
638
656
|
current_run_metric = self.best_loss
|
657
|
+
elif metric in self.best_val_metrics:
|
658
|
+
current_run_metric = self.best_val_metrics[metric]
|
639
659
|
else:
|
640
660
|
current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))
|
641
661
|
|
@@ -653,7 +673,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
|
|
653
673
|
if self.wandb is not None:
|
654
674
|
checkpoint = get_latest_checkpoint(self.checkpoint_path())
|
655
675
|
try:
|
656
|
-
is_good, is_best = self.__compare_run_against_best__(top_k=5, metric=
|
676
|
+
is_good, is_best = self.__compare_run_against_best__(top_k=5, metric=self.best_tracker_metric, from_sweeps=hasattr(self, "wandb_sweep"))
|
657
677
|
if is_good:
|
658
678
|
# Push to registry with appropriate aliases
|
659
679
|
aliases = []
|
@@ -26,6 +26,7 @@ from flax.training.dynamic_scale import DynamicScale
|
|
26
26
|
from flaxdiff.utils import RandomMarkovState
|
27
27
|
from flax.training import dynamic_scale as dynamic_scale_lib
|
28
28
|
from dataclasses import dataclass
|
29
|
+
import shutil
|
29
30
|
import gc
|
30
31
|
|
31
32
|
PROCESS_COLOR_MAP = {
|
@@ -73,6 +74,76 @@ class SimpleTrainState(train_state.TrainState):
|
|
73
74
|
metrics: Metrics
|
74
75
|
dynamic_scale: dynamic_scale_lib.DynamicScale
|
75
76
|
|
77
|
+
def move_contents_to_subdir(target_dir, new_subdir_name):
|
78
|
+
# --- 1. Validate Target Directory ---
|
79
|
+
if not os.path.isdir(target_dir):
|
80
|
+
print(f"Error: Target directory '{target_dir}' not found or is not a directory.")
|
81
|
+
return
|
82
|
+
# --- 2. Define Paths ---
|
83
|
+
# Construct the full path for the new subdirectory
|
84
|
+
new_subdir_path = os.path.join(target_dir, new_subdir_name)
|
85
|
+
# --- 3. Create New Subdirectory ---
|
86
|
+
try:
|
87
|
+
# Create the subdirectory.
|
88
|
+
# exist_ok=True prevents an error if the directory already exists.
|
89
|
+
os.makedirs(new_subdir_path, exist_ok=True)
|
90
|
+
print(f"Subdirectory '{new_subdir_path}' created or already exists.")
|
91
|
+
except OSError as e:
|
92
|
+
print(f"Error creating subdirectory '{new_subdir_path}': {e}")
|
93
|
+
return # Stop execution if subdirectory creation fails
|
94
|
+
# --- 4. List Contents of Target Directory ---
|
95
|
+
try:
|
96
|
+
items_to_move = os.listdir(target_dir)
|
97
|
+
except OSError as e:
|
98
|
+
print(f"Error listing contents of '{target_dir}': {e}")
|
99
|
+
return # Stop if we can't list directory contents
|
100
|
+
# --- 5. Move Items ---
|
101
|
+
print(f"Moving items from '{target_dir}' to '{new_subdir_path}'...")
|
102
|
+
moved_count = 0
|
103
|
+
error_count = 0
|
104
|
+
for item_name in items_to_move:
|
105
|
+
# Construct the full path of the item in the target directory
|
106
|
+
source_path = os.path.join(target_dir, item_name)
|
107
|
+
# IMPORTANT: Skip the newly created subdirectory itself!
|
108
|
+
if source_path == new_subdir_path:
|
109
|
+
continue
|
110
|
+
# Construct the destination path inside the new subdirectory
|
111
|
+
destination_path = os.path.join(new_subdir_path, item_name)
|
112
|
+
# Move the item
|
113
|
+
try:
|
114
|
+
shutil.move(source_path, destination_path)
|
115
|
+
# print(f" Moved: '{item_name}'") # Uncomment for verbose output
|
116
|
+
moved_count += 1
|
117
|
+
except Exception as e:
|
118
|
+
print(f" Error moving '{item_name}': {e}")
|
119
|
+
error_count += 1
|
120
|
+
print(f"\nOperation complete.")
|
121
|
+
print(f" Successfully moved: {moved_count} item(s).")
|
122
|
+
if error_count > 0:
|
123
|
+
print(f" Errors encountered: {error_count} item(s).")
|
124
|
+
|
125
|
+
def load_from_checkpoint(
|
126
|
+
checkpoint_dir: str,
|
127
|
+
):
|
128
|
+
try:
|
129
|
+
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
|
130
|
+
options = orbax.checkpoint.CheckpointManagerOptions(create=False)
|
131
|
+
# Convert checkpoint_dir to absolute path
|
132
|
+
checkpoint_dir = os.path.abspath(checkpoint_dir)
|
133
|
+
manager = orbax.checkpoint.CheckpointManager(checkpoint_dir, checkpointer, options)
|
134
|
+
ckpt = manager.restore(checkpoint_dir)
|
135
|
+
# Extract as above
|
136
|
+
state, best_state = None, None
|
137
|
+
if 'state' in ckpt:
|
138
|
+
state = ckpt['state']
|
139
|
+
if 'best_state' in ckpt:
|
140
|
+
best_state = ckpt['best_state']
|
141
|
+
print(f"Loaded checkpoint from local dir {checkpoint_dir}")
|
142
|
+
return state, best_state
|
143
|
+
except Exception as e:
|
144
|
+
print(f"Warning: Failed to load checkpoint from local dir: {e}")
|
145
|
+
return None, None
|
146
|
+
|
76
147
|
@dataclass
|
77
148
|
class SimpleTrainer:
|
78
149
|
state: SimpleTrainState
|
@@ -97,6 +168,7 @@ class SimpleTrainer:
|
|
97
168
|
checkpoint_step: int = None,
|
98
169
|
use_dynamic_scale: bool = False,
|
99
170
|
max_checkpoints_to_keep: int = 2,
|
171
|
+
train_start_step_override: int = None,
|
100
172
|
):
|
101
173
|
if distributed_training is None or distributed_training is True:
|
102
174
|
# Auto-detect if we are running on multiple devices
|
@@ -112,11 +184,32 @@ class SimpleTrainer:
|
|
112
184
|
self.input_shapes = input_shapes
|
113
185
|
self.checkpoint_base_path = checkpoint_base_path
|
114
186
|
|
187
|
+
load_directly_from_dir = False
|
188
|
+
|
115
189
|
if wandb_config is not None and jax.process_index() == 0:
|
116
190
|
import wandb
|
117
191
|
run = wandb.init(resume='allow', **wandb_config)
|
118
192
|
self.wandb = run
|
119
193
|
|
194
|
+
if 'id' in wandb_config:
|
195
|
+
# If resuming from a previous run, and train_start_step_override is not set,
|
196
|
+
# set the start step to the last step of the previous run
|
197
|
+
if train_start_step_override is None:
|
198
|
+
train_start_step_override = run.summary['train/step'] + 1
|
199
|
+
print(f"Resuming from previous run {wandb_config['id']} with start step {train_start_step_override}")
|
200
|
+
|
201
|
+
# If load_from_checkpoint is not set, and an artifact is found, load the artifact
|
202
|
+
if load_from_checkpoint is None:
|
203
|
+
api_run = wandb.Api().run(f"{wandb_config['entity']}/{wandb_config['project']}/{wandb_config['id']}")
|
204
|
+
model_artifacts = [i for i in api_run.logged_artifacts() if i.type == 'model']
|
205
|
+
if model_artifacts:
|
206
|
+
artifact = model_artifacts[0]
|
207
|
+
artifact_dir = artifact.download()
|
208
|
+
print(f"Loading model from artifact {artifact.name} at {artifact_dir}")
|
209
|
+
# Move the artifact's contents
|
210
|
+
load_from_checkpoint = artifact_dir
|
211
|
+
load_directly_from_dir = True
|
212
|
+
|
120
213
|
# define our custom x axis metric
|
121
214
|
self.wandb.define_metric("train/step")
|
122
215
|
self.wandb.define_metric("train/epoch")
|
@@ -142,12 +235,16 @@ class SimpleTrainer:
|
|
142
235
|
self.checkpoint_path(), async_checkpointer, options)
|
143
236
|
|
144
237
|
if load_from_checkpoint is not None:
|
145
|
-
|
238
|
+
latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step, load_directly_from_dir)
|
146
239
|
else:
|
147
|
-
|
240
|
+
latest_step, old_state, old_best_state, rngstate = 0, None, None, None
|
148
241
|
|
149
242
|
self.latest_step = latest_step
|
150
243
|
|
244
|
+
if train_start_step_override is not None:
|
245
|
+
self.latest_step = train_start_step_override
|
246
|
+
print(f"Overriding start step to {self.latest_step}")
|
247
|
+
|
151
248
|
if rngstate:
|
152
249
|
self.rngstate = RandomMarkovState(**rngstate)
|
153
250
|
else:
|
@@ -239,15 +336,12 @@ class SimpleTrainer:
|
|
239
336
|
os.makedirs(path)
|
240
337
|
return path
|
241
338
|
|
242
|
-
def load(self, checkpoint_path=None,
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
max_to_keep=4, create=False)
|
249
|
-
checkpointer = orbax.checkpoint.CheckpointManager(
|
250
|
-
checkpoint_path, checkpointer, options)
|
339
|
+
def load(self, checkpoint_path, checkpoint_step=None, load_directly_from_dir=False):
|
340
|
+
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
|
341
|
+
options = orbax.checkpoint.CheckpointManagerOptions(
|
342
|
+
max_to_keep=4, create=False)
|
343
|
+
checkpointer = orbax.checkpoint.CheckpointManager(
|
344
|
+
checkpoint_path, checkpointer, options)
|
251
345
|
|
252
346
|
if checkpoint_step is None:
|
253
347
|
step = checkpointer.latest_step()
|
@@ -259,7 +353,8 @@ class SimpleTrainer:
|
|
259
353
|
checkpoint_path if checkpoint_path else self.checkpoint_path(),
|
260
354
|
f"{step}")
|
261
355
|
self.loaded_checkpoint_path = loaded_checkpoint_path
|
262
|
-
ckpt = checkpointer.restore(step)
|
356
|
+
ckpt = checkpointer.restore(step) if not load_directly_from_dir else checkpointer.restore(checkpoint_path)
|
357
|
+
|
263
358
|
state = ckpt['state']
|
264
359
|
best_state = ckpt['best_state']
|
265
360
|
rngstate = ckpt['rngs']
|
@@ -268,10 +363,8 @@ class SimpleTrainer:
|
|
268
363
|
if self.best_loss == 0:
|
269
364
|
# It cant be zero as that must have been some problem
|
270
365
|
self.best_loss = 1e9
|
271
|
-
|
272
|
-
|
273
|
-
f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
|
274
|
-
return current_epoch, step, state, best_state, rngstate
|
366
|
+
print(f"Loaded model from checkpoint at step {step}", ckpt['best_loss'])
|
367
|
+
return step, state, best_state, rngstate
|
275
368
|
|
276
369
|
def save(self, epoch=0, step=0, state=None, rngstate=None):
|
277
370
|
print(f"Saving model at epoch {epoch} step {step}")
|
@@ -507,6 +600,7 @@ class SimpleTrainer:
|
|
507
600
|
|
508
601
|
def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps_per_epoch=5, validation_step_args={}):
|
509
602
|
train_ds = iter(data['train']())
|
603
|
+
val_ds = data.get('val', data.get('test', None))()
|
510
604
|
train_step = self._define_train_step(**train_step_args)
|
511
605
|
val_step = self._define_validation_step(**validation_step_args)
|
512
606
|
train_state = self.state
|
@@ -520,7 +614,7 @@ class SimpleTrainer:
|
|
520
614
|
self.validation_loop(
|
521
615
|
train_state,
|
522
616
|
val_step,
|
523
|
-
|
617
|
+
val_ds,
|
524
618
|
val_steps_per_epoch,
|
525
619
|
self.latest_step,
|
526
620
|
)
|
@@ -571,11 +665,11 @@ class SimpleTrainer:
|
|
571
665
|
self.validation_loop(
|
572
666
|
train_state,
|
573
667
|
val_step,
|
574
|
-
|
668
|
+
val_ds,
|
575
669
|
val_steps_per_epoch,
|
576
670
|
current_step,
|
577
671
|
)
|
578
672
|
print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
|
579
673
|
|
580
|
-
self.save(epochs)
|
674
|
+
self.save(epochs)#
|
581
675
|
return self.state
|
@@ -2,20 +2,20 @@ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
2
|
flaxdiff/utils.py,sha256=DmlWUY1FGz4ESxIHaPQJf92CHjsdMjyDd651wFUtyNg,8838
|
3
3
|
flaxdiff/data/__init__.py,sha256=8W5y7NyAOWtpLi8WRawk4VYeE3DMDnM3B_jKPD8BoFQ,143
|
4
4
|
flaxdiff/data/benchmark_decord.py,sha256=x56Db1VPmziv_9KJvWdfS0O7cffsYkF5tt5WvldOKc0,13720
|
5
|
-
flaxdiff/data/dataloaders.py,sha256=
|
6
|
-
flaxdiff/data/dataset_map.py,sha256=
|
5
|
+
flaxdiff/data/dataloaders.py,sha256=k_3YGJhiY2Wt_-7qK0Yjl4pmF2QJjX_-BlSFuXbH5-M,23628
|
6
|
+
flaxdiff/data/dataset_map.py,sha256=p30U23RkfgMbR8kfPBDIjrjfzDBszWQ9Q1ff2BvDYZk,5116
|
7
7
|
flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
|
8
8
|
flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
|
9
9
|
flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
|
10
10
|
flaxdiff/data/sources/av_utils.py,sha256=LCr9MJNurOaoxY-sjzkLqJS_MlX0x3gRSlKAVIglAU0,24045
|
11
|
-
flaxdiff/data/sources/base.py,sha256=
|
12
|
-
flaxdiff/data/sources/images.py,sha256=
|
11
|
+
flaxdiff/data/sources/base.py,sha256=4Rm9pCtXxzoB8FO0lkDHsrX3ULoU_PNNcid978e6ir0,4610
|
12
|
+
flaxdiff/data/sources/images.py,sha256=ZHBmZ2fnPN75Hc2kiog-Wcs_NZJZOiqw4WcSH5WZJHA,16572
|
13
13
|
flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
|
14
|
-
flaxdiff/data/sources/videos.py,sha256=
|
14
|
+
flaxdiff/data/sources/videos.py,sha256=NkxwEruNpAwDCM53q4WurQ802gSjQMOqjNLxYOqjoNE,9545
|
15
15
|
flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
|
16
16
|
flaxdiff/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
17
17
|
flaxdiff/inference/pipeline.py,sha256=8S30FAlXEjvrDd87H-qdD6biySQZ3cJUflU8gdmPxig,9223
|
18
|
-
flaxdiff/inference/utils.py,sha256=
|
18
|
+
flaxdiff/inference/utils.py,sha256=JEBZYSgj-0DLJTV-TNmIAllAqqVJMn0KfryHwFO-MFs,12606
|
19
19
|
flaxdiff/inputs/__init__.py,sha256=ybPjQsFAf5sqRVZG1sRiOl99EnwpI-NQ8HE3y7UbXmU,7197
|
20
20
|
flaxdiff/inputs/encoders.py,sha256=pjfbx4Rk7bLoE80MOfThZDm6YtsDncRekmn0Bmg_CwI,2963
|
21
21
|
flaxdiff/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -27,14 +27,17 @@ flaxdiff/metrics/ssim.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
27
27
|
flaxdiff/metrics/utils.py,sha256=YuuOfqvqgIjsceupwNeJ59vQ2TnGeNMIyKdkIqOmoNg,1702
|
28
28
|
flaxdiff/models/__init__.py,sha256=amtDF07DfiAdnZsvWX4eaW79nwNEU1s8Zb4PB3ewtg4,118
|
29
29
|
flaxdiff/models/attention.py,sha256=YkED3_MRTjI9aTMTTQdsuReHhG8MK0Z4OVuU2j8ZAHs,13524
|
30
|
-
flaxdiff/models/
|
31
|
-
flaxdiff/models/common.py,sha256=0j9AAjGPgkBLHo2DlYj0R6OsUNw2QaoDjaXSKq2mqkA,12647
|
30
|
+
flaxdiff/models/common.py,sha256=QpciwuJldvLUwyAyWBQqiPPGVI-c9qLR7h7C1YoRX7w,10510
|
32
31
|
flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
|
33
32
|
flaxdiff/models/general.py,sha256=7xMME6KVKQY8sScyHYH4f-Kek4j1pRfplKShFXwVZd4,587
|
33
|
+
flaxdiff/models/hilbert.py,sha256=AjlAv49dL6UAYWslMJfCMLiFqY4kTgpiUWr2nc1mk34,24823
|
34
|
+
flaxdiff/models/simple_dit.py,sha256=l238MYHRTArv_pS57aY24C2PTfxeL8EmzJ24iQqdoWI,11702
|
35
|
+
flaxdiff/models/simple_mmdit.py,sha256=ARk0juopn2k7giln5BAUrnYD1pTFwgTJoSzrhozQ0A8,31356
|
34
36
|
flaxdiff/models/simple_unet.py,sha256=pjeixszG_6gEY5PNFbQ7KbOyg4z5bfn4RUbINCJexOM,10758
|
35
|
-
flaxdiff/models/simple_vit.py,sha256=
|
37
|
+
flaxdiff/models/simple_vit.py,sha256=J9s3hBF87_iVrJDBe2cs9a56N7ect6pux_f_ge07XXc,17357
|
36
38
|
flaxdiff/models/unet_3d.py,sha256=LF0PMxBKGU-_lAMtO_Coxy1yRE02yKKdgb7i6YZxI_4,20163
|
37
39
|
flaxdiff/models/unet_3d_blocks.py,sha256=lRYDc9X1VEu54Kg7xEEphXYiQ09tabPXKi-hEcKFYug,19687
|
40
|
+
flaxdiff/models/vit_common.py,sha256=1OGu4ezY3uzKinTnw3p8YkQAslHDqEbN78JheXnTleI,9831
|
38
41
|
flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
|
39
42
|
flaxdiff/models/autoencoder/autoencoder.py,sha256=8XWdsWvsPsyWGtzpCT8w0KXi_ZLGpRuQpn4oXo1gHKw,6039
|
40
43
|
flaxdiff/models/autoencoder/diffusers.py,sha256=tPz77YuctrT--jF2AOL8G6vr0NiIr3RXANNrZCxe0bg,5921
|
@@ -60,9 +63,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
|
|
60
63
|
flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
|
61
64
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
|
62
65
|
flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
|
63
|
-
flaxdiff/trainer/general_diffusion_trainer.py,sha256=
|
64
|
-
flaxdiff/trainer/simple_trainer.py,sha256=
|
65
|
-
flaxdiff-0.2.
|
66
|
-
flaxdiff-0.2.
|
67
|
-
flaxdiff-0.2.
|
68
|
-
flaxdiff-0.2.
|
66
|
+
flaxdiff/trainer/general_diffusion_trainer.py,sha256=OtE2spZIBFPpY6q-ijYol5Y-CaP2UHJYIDX3PFBiPtg,29492
|
67
|
+
flaxdiff/trainer/simple_trainer.py,sha256=Hdltuo3lgF61N04Lxc7L3z6NLveW4_h1ff7_5mu3Wbg,28730
|
68
|
+
flaxdiff-0.2.9.dist-info/METADATA,sha256=a8btxHRkAZVieuZfTyXgPkJbEG9fZRknEhq2Ti3_7m4,24057
|
69
|
+
flaxdiff-0.2.9.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
|
70
|
+
flaxdiff-0.2.9.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
71
|
+
flaxdiff-0.2.9.dist-info/RECORD,,
|