flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/attention.py +22 -16
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/common.py +8 -18
- flaxdiff/models/simple_unet.py +6 -17
- flaxdiff/models/simple_vit.py +9 -13
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +35 -29
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +51 -16
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,583 @@
|
|
1
|
+
import json
|
2
|
+
import flax
|
3
|
+
from flax import linen as nn
|
4
|
+
import jax
|
5
|
+
from typing import Callable, List, Dict, Tuple, Union, Any, Sequence, Type, Optional
|
6
|
+
from dataclasses import field, dataclass
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import optax
|
9
|
+
import functools
|
10
|
+
from jax.sharding import Mesh, PartitionSpec as P
|
11
|
+
from jax.experimental.shard_map import shard_map
|
12
|
+
|
13
|
+
from ..schedulers import NoiseScheduler, get_coeff_shapes_tuple
|
14
|
+
from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
|
15
|
+
from ..samplers.common import DiffusionSampler
|
16
|
+
from ..samplers.ddim import DDIMSampler
|
17
|
+
|
18
|
+
from flaxdiff.utils import RandomMarkovState, serialize_model, get_latest_checkpoint
|
19
|
+
from flaxdiff.inputs import ConditioningEncoder, ConditionalInputConfig, DiffusionInputConfig
|
20
|
+
|
21
|
+
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
22
|
+
|
23
|
+
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
24
|
+
from flax.training import dynamic_scale as dynamic_scale_lib
|
25
|
+
|
26
|
+
# Reuse the TrainState from the DiffusionTrainer
|
27
|
+
from flaxdiff.trainer.diffusion_trainer import TrainState, DiffusionTrainer
|
28
|
+
import shutil
|
29
|
+
|
30
|
+
def generate_modelname(
    dataset_name: str,
    noise_schedule_name: str,
    architecture_name: str,
    model: nn.Module,
    input_config: DiffusionInputConfig,
    autoencoder: AutoEncoder = None,
    frames_per_sample: int = None,
) -> str:
    """
    Generate a model name from the dataset name and the sample resolution.

    The name has the form ``diffusion-{dataset_name}-res{R}``, where ``R`` is
    ``input_config.sample_data_shape[-2]`` (the width axis for ``[..., H, W, C]``
    layouts).

    Args:
        dataset_name: Name of the training dataset; embedded verbatim in the name.
        noise_schedule_name: Currently unused; accepted for API compatibility.
        architecture_name: Currently unused; accepted for API compatibility.
        model: Currently unused; accepted for API compatibility.
        input_config: Input configuration; only ``sample_data_shape`` is read.
        autoencoder: Currently unused; accepted for API compatibility.
        frames_per_sample: Currently unused; accepted for API compatibility.

    Returns:
        The generated model name string.
    """
    resolution = input_config.sample_data_shape[-2]
    return f"diffusion-{dataset_name}-res{resolution}"
|
105
|
+
|
106
|
+
class GeneralDiffusionTrainer(DiffusionTrainer):
    """
    General trainer for diffusion models supporting both images and videos.

    Extends DiffusionTrainer to support:
    1. Both image data (4D tensors: B,H,W,C) and video data (5D tensors: B,T,H,W,C)
    2. Multiple conditioning inputs
    3. Various model architectures
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: optax.GradientTransformation,
        noise_schedule: NoiseScheduler,
        input_config: DiffusionInputConfig,
        rngs: jax.random.PRNGKey,
        unconditional_prob: float = 0.12,
        name: str = "GeneralDiffusion",
        model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
        autoencoder: AutoEncoder = None,
        native_resolution: int = None,
        frames_per_sample: int = None,
        wandb_config: Dict[str, Any] = None,
        **kwargs
    ):
        """
        Initialize the general diffusion trainer.

        Args:
            model: Neural network model
            optimizer: Optimization algorithm
            noise_schedule: Noise scheduler for diffusion process
            input_config: Configuration for input data, including keys, shapes and conditioning inputs
            rngs: Random number generator keys
            unconditional_prob: Probability of training with unconditional samples
            name: Name of this trainer
            model_output_transform: Transform for model predictions
            autoencoder: Optional autoencoder for latent diffusion
            native_resolution: Native resolution of the data
            frames_per_sample: Number of frames per video sample (for video only)
            wandb_config: Optional wandb configuration. When given it must contain
                ``wandb_config['config']['arguments']`` with the keys ``'dataset'``,
                ``'noise_schedule'`` and ``'architecture'``. ``wandb_config['config']``
                is updated IN PLACE with ``input_config``, ``model``, ``autoencoder``,
                ``autoencoder_opts`` (when missing) and ``modelname``.
            **kwargs: Additional arguments for parent class

        Raises:
            KeyError: If ``wandb_config`` is given but lacks the keys listed above.
        """
        # The autoencoder is forwarded to get_input_shapes alongside the input config.
        input_shapes = input_config.get_input_shapes(
            autoencoder=autoencoder,
        )
        self.input_config = input_config

        if wandb_config is not None:
            run_config = wandb_config['config']
            # Record input/model/autoencoder configuration unless the caller already did.
            if 'input_config' not in run_config:
                run_config['input_config'] = input_config.serialize()
            if 'model' not in run_config:
                run_config['model'] = serialize_model(model)
            if 'autoencoder' not in run_config and autoencoder is not None:
                run_config['autoencoder'] = autoencoder.name
                run_config['autoencoder_opts'] = json.dumps(autoencoder.serialize())

            # Generate a model name based on the configuration
            arguments = run_config['arguments']
            modelname = generate_modelname(
                dataset_name=arguments['dataset'],
                noise_schedule_name=arguments['noise_schedule'],
                architecture_name=arguments['architecture'],
                model=model,
                input_config=input_config,
                autoencoder=autoencoder,
                frames_per_sample=frames_per_sample,
            )
            print("Model name:", modelname)
            self.modelname = modelname
            run_config['modelname'] = modelname

        super().__init__(
            model=model,
            input_shapes=input_shapes,
            optimizer=optimizer,
            noise_schedule=noise_schedule,
            unconditional_prob=unconditional_prob,
            autoencoder=autoencoder,
            model_output_transform=model_output_transform,
            rngs=rngs,
            name=name,
            native_resolution=native_resolution,
            encoder=None,  # Conditioning is handled by input_config, not the parent's encoder
            wandb_config=wandb_config,
            **kwargs
        )

        # Store video-specific parameters
        self.frames_per_sample = frames_per_sample

        # List of conditional inputs
        self.conditional_inputs = input_config.conditions
        # Determine if we're working with video or images
        self.is_video = self._is_video_data()

    def _is_video_data(self):
        """Return True when the sample data shape is 5-D (B, T, H, W, C), i.e. video."""
        return len(self.input_config.sample_data_shape) == 5

    def _define_train_step(self, batch_size):
        """
        Build the jitted (and, when distributed, shard_mapped) training step.

        Works for both image and video diffusion.

        Args:
            batch_size: Not used by this implementation; kept for interface compatibility.

        Returns:
            A jitted callable
            ``train_step(train_state, rng_state, batch, local_device_index)`` that
            returns ``(new_state, loss, rng_state)``. The ``batch`` argument is
            donated, so callers must not reuse it after the call.
        """
        # Read attributes once so the traced function closes over plain values.
        noise_schedule = self.noise_schedule
        model = self.model
        model_output_transform = self.model_output_transform
        loss_fn = self.loss_fn
        distributed_training = self.distributed_training
        autoencoder = self.autoencoder
        unconditional_prob = self.unconditional_prob

        input_config = self.input_config
        sample_data_key = input_config.sample_data_key

        # DynamicScale.value_and_grad all-reduces gradients over `axis_name` when it is
        # not None. The "data" axis only exists inside shard_map, so single-device
        # training must pass None or the step fails with an unbound axis name.
        grad_axis_name = "data" if distributed_training else None

        def process_conditioning(batch, uncond_mask):
            """Delegate conditioning to the input config, forwarding the unconditional mask."""
            return input_config.process_conditioning(
                batch,
                uncond_mask=uncond_mask,
            )

        def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
            """Run one optimisation step on the (local shard of the) batch."""
            # Fold in the device index so every shard draws different randomness.
            rng_state, key_fold = rng_state.get_random_key()
            folded_key = jax.random.fold_in(key_fold, local_device_index.reshape())
            local_rng_state = RandomMarkovState(folded_key)

            # Assumes raw pixel values in [0, 255]; maps them to [-1, 1].
            data = batch[sample_data_key]
            local_batch_size = data.shape[0]
            data = (jnp.asarray(data, dtype=jnp.float32) - 127.5) / 127.5

            # Latent diffusion: encode into the autoencoder's latent space (image or video).
            if autoencoder is not None:
                local_rng_state, enc_key = local_rng_state.get_random_key()
                data = autoencoder.encode(data, enc_key)

            # Mark each sample as unconditional with probability `unconditional_prob`.
            local_rng_state, uncond_key = local_rng_state.get_random_key()
            uncond_mask = jax.random.bernoulli(
                uncond_key,
                shape=(local_batch_size,),
                p=unconditional_prob,
            )
            all_conditional_inputs = process_conditioning(batch, uncond_mask)

            # Diffusion timesteps and Gaussian noise.
            noise_level, local_rng_state = noise_schedule.generate_timesteps(local_batch_size, local_rng_state)
            local_rng_state, noise_key = local_rng_state.get_random_key()
            noise = jax.random.normal(noise_key, shape=data.shape, dtype=jnp.float32)

            # Forward diffusion process
            rates = noise_schedule.get_rates(noise_level, get_coeff_shapes_tuple(data))
            noisy_data, c_in, expected_output = model_output_transform.forward_diffusion(data, noise, rates)

            def model_loss(params):
                """Weighted mean diffusion loss for the given parameters."""
                inputs = noise_schedule.transform_inputs(noisy_data * c_in, noise_level)
                preds = model.apply(params, *inputs, *all_conditional_inputs)

                preds = model_output_transform.pred_transform(noisy_data, preds, rates)
                sample_losses = loss_fn(preds, expected_output)

                weights = noise_schedule.get_weights(noise_level, get_coeff_shapes_tuple(sample_losses))
                return jnp.mean(sample_losses * weights)

            if train_state.dynamic_scale is not None:
                # Mixed precision: gradients are pmean'd inside value_and_grad when distributed.
                grad_fn = train_state.dynamic_scale.value_and_grad(model_loss, axis_name=grad_axis_name)
                dynamic_scale, is_finite, loss, grads = grad_fn(train_state.params)

                train_state = train_state.replace(dynamic_scale=dynamic_scale)
                new_state = train_state.apply_gradients(grads=grads)

                # Keep the previous params/optimizer state when the gradients were not finite.
                select_fn = functools.partial(jnp.where, is_finite)
                new_state = new_state.replace(
                    opt_state=jax.tree_util.tree_map(select_fn, new_state.opt_state, train_state.opt_state),
                    params=jax.tree_util.tree_map(select_fn, new_state.params, train_state.params),
                )
            else:
                loss, grads = jax.value_and_grad(model_loss)(train_state.params)

                if distributed_training:
                    grads = jax.lax.pmean(grads, axis_name="data")

                new_state = train_state.apply_gradients(grads=grads)

            # Apply EMA update
            new_state = new_state.apply_ema(self.ema_decay)

            # Average loss across devices if distributed
            if distributed_training:
                loss = jax.lax.pmean(loss, axis_name="data")

            return new_state, loss, rng_state

        if distributed_training:
            # State and RNG are replicated; batch and device indices are split on "data".
            train_step = shard_map(
                train_step,
                mesh=self.mesh,
                in_specs=(P(), P(), P('data'), P('data')),
                out_specs=(P(), P(), P()),
            )

        # Donate the batch (positional argument 2) so XLA may reuse its buffer.
        return jax.jit(train_step, donate_argnums=(2,))

    def _define_validation_step(
        self,
        sampler_class: Type[DiffusionSampler] = DDIMSampler,
        sampling_noise_schedule: NoiseScheduler = None,
        guidance_scale: float = 3.0,
    ):
        """
        Define the validation step for both image and video diffusion models.

        Args:
            sampler_class: Sampler class to instantiate.
            sampling_noise_schedule: Noise schedule for sampling; defaults to the
                training noise schedule when None.
            guidance_scale: Guidance scale passed to the sampler (default 3.0).

        Returns:
            ``(sampler, generate_samples)``, where
            ``generate_samples(val_state, batch, sampler, diffusion_steps)`` samples
            from ``val_state.ema_params``.
        """
        model = self.model
        autoencoder = self.autoencoder
        input_config = self.input_config
        conditional_inputs = self.conditional_inputs
        is_video = self.is_video

        image_size = self._get_image_size()
        # Sequence length is only meaningful for video data.
        sequence_length = self._get_sequence_length() if is_video else None

        sampler = sampler_class(
            model=model,
            noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
            model_output_transform=self.model_output_transform,
            input_config=input_config,
            autoencoder=autoencoder,
            guidance_scale=guidance_scale,
        )

        def generate_samples(
            val_state: TrainState,
            batch,
            sampler: DiffusionSampler,
            diffusion_steps: int,
        ):
            """Sample from the EMA params, conditioned on every conditional input of `batch`."""
            model_conditioning_inputs = [cond_input(batch) for cond_input in conditional_inputs]

            # Batch size follows the first conditioning input; 4 samples when unconditional.
            batch_size = len(model_conditioning_inputs[0]) if model_conditioning_inputs else 4

            return sampler.generate_samples(
                params=val_state.ema_params,
                resolution=image_size,
                num_samples=batch_size,
                sequence_length=sequence_length,  # None for images
                diffusion_steps=diffusion_steps,
                start_step=1000,
                end_step=0,
                priors=None,
                model_conditioning_inputs=tuple(model_conditioning_inputs),
            )

        return sampler, generate_samples

    def _get_image_size(self):
        """Return native_resolution if set, else sample_data_shape[-2] (assumes [..., H, W, C])."""
        if self.native_resolution is not None:
            return self.native_resolution

        return self.input_config.sample_data_shape[-2]

    def _get_sequence_length(self):
        """Return the frame count (axis 1 of [B, T, H, W, C]) for video data, else None."""
        if not self.is_video:
            return None

        return self.input_config.sample_data_shape[1]

    def validation_loop(
        self,
        val_state: SimpleTrainState,
        val_step_fn: Callable,
        val_ds,
        val_steps_per_epoch,
        current_step,
        diffusion_steps=200,
    ):
        """
        Generate one batch of samples and log them to wandb (images or videos).

        Args:
            val_state: State whose ``ema_params`` are used for sampling.
            val_step_fn: ``(sampler, generate_samples)`` pair, as returned by
                ``_define_validation_step``.
            val_ds: Zero-argument callable returning an iterable of validation
                batches, or a falsy value when there is no validation data.
            val_steps_per_epoch: Not used by this implementation.
            current_step: Global step used when logging to wandb.
            diffusion_steps: Number of sampler steps.

        Any exception is printed with its traceback instead of being raised.
        """
        sampler, generate_samples = val_step_fn
        val_iter = iter(val_ds()) if val_ds else None

        try:
            # Without validation data, conditional inputs receive batch=None.
            batch = next(val_iter) if val_iter is not None else None
            samples = generate_samples(
                val_state,
                batch,
                sampler,
                diffusion_steps,
            )

            if getattr(self, 'wandb', None):
                if len(samples.shape) == 5:  # [B,T,H,W,C] - Video data
                    self._log_video_samples(samples, current_step)
                else:  # [B,H,W,C] - Image data
                    self._log_image_samples(samples, current_step)

        except Exception as e:
            print("Error in validation loop:", e)
            import traceback
            traceback.print_exc()

    def _log_video_samples(self, samples, current_step):
        """
        Log each video of `samples` (B, T, H, W, C, values in [-1, 1]) to wandb.

        wandb.Video expects numpy input laid out as (T, C, H, W), so each sample
        is transposed before logging. All samples go in a single wandb.log call.
        """
        import numpy as np
        from wandb import Video as wandbVideo

        payload = {}
        for i in range(samples.shape[0]):
            # Denormalize [-1, 1] -> [0, 255] and clip.
            sample = np.asarray(samples[i])
            sample = np.clip((sample + 1) * 127.5, 0, 255).astype(np.uint8)
            # (T, H, W, C) -> (T, C, H, W), the array layout wandb.Video requires.
            sample = np.transpose(sample, (0, 3, 1, 2))
            payload[f"video_sample_{i}"] = wandbVideo(
                sample,
                fps=10,
                caption=f"Video Sample {i} at step {current_step}",
            )

        if payload:
            self.wandb.log(payload, step=current_step)

    def _log_image_samples(self, samples, current_step):
        """
        Log each image of `samples` (B, H, W, C, values in [-1, 1]) to wandb.

        All samples go in a single wandb.log call.
        """
        import numpy as np
        from wandb import Image as wandbImage

        payload = {}
        for i in range(samples.shape[0]):
            # Denormalize [-1, 1] -> [0, 255] and clip.
            sample = np.asarray(samples[i])
            sample = np.clip((sample + 1) * 127.5, 0, 255).astype(np.uint8)
            payload[f"sample_{i}"] = wandbImage(
                sample,
                caption=f"Sample {i} at step {current_step}",
            )

        if payload:
            self.wandb.log(payload, step=current_step)

    def push_to_registry(
        self,
        registry_name: str = 'wandb-registry-model',
    ):
        """
        Log the latest checkpoint as a wandb model artifact and link it into a registry.

        Args:
            registry_name: Name of the model registry.

        Returns:
            The logged wandb artifact.

        Raises:
            ValueError: If wandb is not initialized, no model name was generated
                (no ``wandb_config`` given), or no checkpoint could be found.
        """
        if self.wandb is None:
            raise ValueError("Wandb is not initialized. Cannot push to registry.")

        modelname = getattr(self, "modelname", None)
        if modelname is None:
            raise ValueError("Model name is not set (no wandb_config was given). Cannot push to registry.")
        if hasattr(self, "wandb_sweep"):
            modelname = f"{modelname}-sweep-{self.wandb_sweep.id}"

        checkpoint_dir = self.checkpoint_path()
        latest_checkpoint_path = get_latest_checkpoint(checkpoint_dir)
        if not latest_checkpoint_path:
            raise ValueError(f"No checkpoint found in {checkpoint_dir}. Cannot push to registry.")

        logged_artifact = self.wandb.log_artifact(
            artifact_or_path=latest_checkpoint_path,
            name=modelname,
            type="model",
        )

        target_path = f"{registry_name}/{modelname}"
        self.wandb.link_artifact(
            artifact=logged_artifact,
            target_path=target_path,
        )
        print(f"Model pushed to registry at {target_path}")
        return logged_artifact

    def __get_best_sweep_runs__(
        self,
        metric: str = "train/best_loss",
        top_k: int = 5,
    ):
        """
        Get the best runs from the current wandb sweep.

        Metrics whose name contains "loss" are treated as lower-is-better, all
        others as higher-is-better (the same rule as __compare_run_against_best__).
        Runs missing the metric rank last.

        Args:
            metric: Metric to sort by.
            top_k: Number of top runs to return.

        Returns:
            ``(best_runs, (low, high))``: up to ``top_k`` runs, best first, and the
            min/max of the metric between the best and the k-th best run. When the
            sweep has no runs, both bounds are the "worst possible" sentinel
            (+inf for lower-is-better, -inf otherwise).

        Raises:
            ValueError: If wandb or the sweep is not initialized.
        """
        if self.wandb is None:
            raise ValueError("Wandb is not initialized. Cannot get best runs.")

        if not hasattr(self, "wandb_sweep"):
            raise ValueError("Wandb sweep is not initialized. Cannot get best runs.")

        is_lower_better = "loss" in metric.lower()
        missing = float('inf') if is_lower_better else float('-inf')

        def score(run):
            return run.summary.get(metric, missing)

        runs = sorted(self.wandb_sweep.runs, key=score, reverse=not is_lower_better)
        best_runs = runs[:top_k]

        print(f"Best runs from sweep {self.wandb_sweep.id}:")
        if not best_runs:
            return best_runs, (missing, missing)

        for run in best_runs:
            print(f"\t\tRun ID: {run.id}, Metric: {score(run)}")

        best_value = score(best_runs[0])
        kth_value = score(best_runs[-1])
        return best_runs, (min(best_value, kth_value), max(best_value, kth_value))

    def __compare_run_against_best__(self, top_k=2, metric="train/best_loss"):
        """
        Return True if the current run is among the sweep's ``top_k`` runs for `metric`.

        A run qualifies if its id is in the best runs, or if its own summary value
        beats the k-th best value. Metrics containing "loss" are lower-is-better.
        """
        best_runs, bounds = self.__get_best_sweep_runs__(metric=metric, top_k=top_k)

        is_lower_better = "loss" in metric.lower()
        current_run_metric = self.wandb.summary.get(metric, float('inf') if is_lower_better else float('-inf'))

        # Direct check if current run is in best runs
        if any(run.id == self.wandb.id for run in best_runs):
            print(f"Current run {self.wandb.id} is one of the best runs.")
            return True

        # Backup check based on metric value: beat the k-th best value.
        if (is_lower_better and current_run_metric < bounds[1]) or (not is_lower_better and current_run_metric > bounds[0]):
            print(f"Current run {self.wandb.id} meets performance criteria.")
            return True

        return False

    def save(self, epoch=0, step=0, state=None, rngstate=None):
        """
        Save a checkpoint via the parent class. Inside a wandb sweep, additionally
        push it to the registry if this run ranks among the best, and then delete
        the local checkpoint.
        """
        super().save(epoch=epoch, step=step, state=state, rngstate=rngstate)

        if self.wandb is None or not hasattr(self, "wandb_sweep"):
            return

        checkpoint = get_latest_checkpoint(self.checkpoint_path())
        try:
            if self.__compare_run_against_best__(top_k=5, metric="train/best_loss"):
                self.push_to_registry()
                print("Model pushed to registry successfully")
            else:
                print("Current run is not one of the best runs. Not saving model.")

            # The local checkpoint is deleted both after a successful push and when the
            # run did not qualify. Any exception above skips deletion, and the
            # checkpoint is preserved.
            if checkpoint:
                shutil.rmtree(checkpoint, ignore_errors=True)
                print(f"Checkpoint deleted at {checkpoint}")
        except Exception as e:
            print(f"Error during registry operations: {e}")
            print(f"Checkpoint preserved at {checkpoint}")
|