flaxdiff 0.1.36.1__py3-none-any.whl → 0.1.36.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/__init__.py +1 -0
- data/dataset_map.py +71 -0
- data/datasets.py +169 -0
- data/online_loader.py +363 -0
- data/sources/gcs.py +81 -0
- data/sources/tfds.py +67 -0
- {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.2.dist-info}/METADATA +1 -1
- flaxdiff-0.1.36.2.dist-info/RECORD +47 -0
- flaxdiff-0.1.36.2.dist-info/top_level.txt +9 -0
- metrics/inception.py +658 -0
- metrics/utils.py +49 -0
- models/__init__.py +1 -0
- models/attention.py +368 -0
- models/autoencoder/__init__.py +2 -0
- models/autoencoder/autoencoder.py +19 -0
- models/autoencoder/diffusers.py +91 -0
- models/autoencoder/simple_autoenc.py +26 -0
- models/common.py +346 -0
- models/favor_fastattn.py +723 -0
- models/simple_unet.py +233 -0
- models/simple_vit.py +180 -0
- predictors/__init__.py +96 -0
- samplers/__init__.py +7 -0
- samplers/common.py +165 -0
- samplers/ddim.py +10 -0
- samplers/ddpm.py +37 -0
- samplers/euler.py +56 -0
- samplers/heun_sampler.py +27 -0
- samplers/multistep_dpm.py +59 -0
- samplers/rk4_sampler.py +34 -0
- schedulers/__init__.py +6 -0
- schedulers/common.py +98 -0
- schedulers/continuous.py +12 -0
- schedulers/cosine.py +40 -0
- schedulers/discrete.py +74 -0
- schedulers/exp.py +13 -0
- schedulers/karras.py +69 -0
- schedulers/linear.py +14 -0
- schedulers/sqrt.py +10 -0
- trainer/__init__.py +2 -0
- trainer/autoencoder_trainer.py +182 -0
- trainer/diffusion_trainer.py +326 -0
- trainer/simple_trainer.py +540 -0
- trainer/video_diffusion_trainer.py +62 -0
- flaxdiff-0.1.36.1.dist-info/RECORD +0 -6
- flaxdiff-0.1.36.1.dist-info/top_level.txt +0 -1
- /flaxdiff/__init__.py → /__init__.py +0 -0
- {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.2.dist-info}/WHEEL +0 -0
- /flaxdiff/utils.py → /utils.py +0 -0
@@ -0,0 +1,540 @@
|
|
1
|
+
import orbax.checkpoint
|
2
|
+
import tqdm
|
3
|
+
from flax import linen as nn
|
4
|
+
import jax
|
5
|
+
from typing import Callable
|
6
|
+
from dataclasses import field
|
7
|
+
import jax.numpy as jnp
|
8
|
+
import numpy as np
|
9
|
+
from functools import partial
|
10
|
+
from clu import metrics
|
11
|
+
from flax.training import train_state # Useful dataclass to keep train state
|
12
|
+
import optax
|
13
|
+
from flax import struct # Flax dataclasses
|
14
|
+
import flax
|
15
|
+
import time
|
16
|
+
import os
|
17
|
+
import orbax
|
18
|
+
from flax.training import orbax_utils
|
19
|
+
from jax.sharding import Mesh, PartitionSpec as P
|
20
|
+
from jax.experimental import mesh_utils
|
21
|
+
from jax.experimental.shard_map import shard_map
|
22
|
+
from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
|
23
|
+
from termcolor import colored
|
24
|
+
from typing import Dict, Callable, Sequence, Any, Union, Tuple
|
25
|
+
from flax.training.dynamic_scale import DynamicScale
|
26
|
+
from flaxdiff.utils import RandomMarkovState
|
27
|
+
from flax.training import dynamic_scale as dynamic_scale_lib
|
28
|
+
|
29
|
+
# Terminal colour per JAX process index, used to tell apart status messages
# printed by different hosts in a multi-process run.
PROCESS_COLOR_MAP = dict(enumerate([
    "green",
    "yellow",
    "magenta",
    "cyan",
    "white",
    "light_blue",
    "light_red",
    "light_cyan",
]))
|
39
|
+
|
40
|
+
def _build_global_shape_and_sharding(
    local_shape: tuple[int, ...], global_mesh: Mesh
) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
    """Compute the global shape and sharding for a process-local batch.

    The leading dimension of ``local_shape`` is multiplied by the number of
    JAX processes (each process is expected to hold an equally sized slice),
    and that leading dimension is sharded over all axes of ``global_mesh``.

    Returns:
        ``(global_shape, sharding)``.
    """
    batch_dim = local_shape[0] * jax.process_count()
    global_shape = (batch_dim,) + local_shape[1:]
    batch_sharding = jax.sharding.NamedSharding(
        global_mesh, P(global_mesh.axis_names)
    )
    return global_shape, batch_sharding
|
46
|
+
|
47
|
+
|
48
|
+
def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
    """Assemble a process-local batch into one globally sharded ``jax.Array``.

    ``path`` is the pytree key path passed by ``tree_map_with_path`` and is
    ignored. ``array`` is split evenly along axis 0 into one piece per local
    device of ``global_mesh``; each piece is placed on its device and the
    pieces are combined into a single global array.

    Raises:
        ValueError: if axis 0 cannot be split evenly across the local devices.
    """
    global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
    local_devices = global_mesh.local_devices
    device_count = len(local_devices)
    try:
        shards = np.split(array, device_count, axis=0)
    except ValueError as split_err:
        message = (
            f"Unable to put to devices shape {array.shape} with "
            f"local device count {device_count} "
        )
        raise ValueError(message) from split_err
    per_device = jax.device_put(shards, local_devices)
    return jax.make_array_from_single_device_arrays(global_shape, sharding, per_device)
|
60
|
+
|
61
|
+
def convert_to_global_tree(global_mesh, pytree):
    """Turn every leaf of ``pytree`` into a global array sharded on ``global_mesh``."""
    def _to_global(key_path, leaf):
        return form_global_array(key_path, leaf, global_mesh=global_mesh)

    return jax.tree_util.tree_map_with_path(_to_global, pytree)
|
63
|
+
|
64
|
+
@struct.dataclass
class Metrics(metrics.Collection):
    """clu metric collection stored on the train state.

    ``accuracy`` consumes the ``logits`` / ``labels`` model outputs. ``loss``
    averages the ``loss`` model output. It is bound with
    ``Average.from_output('loss')`` because a bare ``metrics.Average`` expects
    a ``values`` keyword, which the validation step does not pass (it passes
    ``loss=...``).
    """
    accuracy: metrics.Accuracy
    loss: metrics.Average.from_output('loss')
|
68
|
+
|
69
|
+
# Define the TrainState
class SimpleTrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with a clu ``Metrics`` collection and a
    ``DynamicScale`` slot for dynamic loss scaling.

    ``SimpleTrainer.generate_states`` stores ``None`` in ``dynamic_scale``
    when the trainer is built with ``use_dynamic_scale=False``.
    """

    # clu metric collection updated by the validation step.
    metrics: Metrics
    # Dynamic loss-scaling helper, or None when disabled.
    dynamic_scale: dynamic_scale_lib.DynamicScale
|
73
|
+
|
74
|
+
class SimpleTrainer:
    """Minimal single-/multi-device trainer for a Flax model.

    Responsibilities:
      * initialise (or restore) the train state, best state and RNG state;
      * Orbax checkpointing of state, best state, RNG state and best loss;
      * optional Weights & Biases logging, on process 0 only;
      * the epoch/step training loop with periodic validation.

    Subclasses customise the update rule by overriding ``_define_train_step``
    and ``_define_vaidation_step``.
    """

    state: SimpleTrainState
    best_state: SimpleTrainState
    best_loss: float
    model: nn.Module
    ema_decay: float = 0.999

    def __init__(self,
                 model: nn.Module,
                 input_shapes: Dict[str, Tuple[int]],
                 optimizer: optax.GradientTransformation,
                 rngs: jax.random.PRNGKey,
                 train_state: SimpleTrainState = None,
                 name: str = "Simple",
                 load_from_checkpoint: str = None,
                 checkpoint_suffix: str = "",
                 loss_fn=optax.l2_loss,
                 param_transforms: Callable = None,
                 wandb_config: Dict[str, Any] = None,
                 distributed_training: bool = None,
                 checkpoint_base_path: str = "./checkpoints",
                 checkpoint_step: int = None,
                 use_dynamic_scale: bool = False,
                 ):
        """Build the trainer.

        Args:
            model: Flax module to train.
            input_shapes: per-input shapes (without batch dim) used to build
                dummy inputs for ``model.init`` / ``tabulate``.
            optimizer: optax transformation for the train state.
            rngs: PRNG key seeding the trainer's ``RandomMarkovState``.
            train_state: optional ready-made state (skips initialisation).
            name: experiment name; also used for checkpoint/tensorboard dirs.
            load_from_checkpoint: checkpoint directory to restore from.
            checkpoint_suffix: appended to this trainer's checkpoint directory.
            loss_fn: elementwise loss ``loss_fn(preds, labels)``.
            param_transforms: forwarded to ``generate_states``.
            wandb_config: kwargs for ``wandb.init`` (process 0 only).
            distributed_training: ``None``/``True`` auto-detects
                (``device_count() > 1``); ``False`` disables sharding.
            checkpoint_base_path: root directory for checkpoints.
            checkpoint_step: specific step to restore (default: latest).
            use_dynamic_scale: attach a ``DynamicScale`` to the state.
        """
        if distributed_training is None or distributed_training is True:
            # Auto-detect if we are running on multiple devices
            distributed_training = jax.device_count() > 1
            self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
        else:
            self.mesh = None

        self.distributed_training = distributed_training
        self.model = model
        self.name = name
        self.loss_fn = loss_fn
        self.input_shapes = input_shapes
        self.checkpoint_base_path = checkpoint_base_path
        # Every logging path checks ``self.wandb is not None``, so it must
        # exist even on processes (or runs) without a wandb config.
        self.wandb = None
        # Default best loss; ``load`` overwrites it from the checkpoint.
        self.best_loss = 1e9

        if wandb_config is not None and jax.process_index() == 0:
            import wandb
            run = wandb.init(**wandb_config)
            self.wandb = run

            # define our custom x axis metric
            self.wandb.define_metric("train/step")
            self.wandb.define_metric("train/epoch")

            self.wandb.define_metric("train/loss", step_metric="train/step")

            self.wandb.define_metric("train/epoch_time", step_metric="train/epoch")
            self.wandb.define_metric("train/avg_time_per_step", step_metric="train/epoch")
            self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
            self.wandb.define_metric("train/best_loss", step_metric="train/epoch")

        async_checkpointer = orbax.checkpoint.AsyncCheckpointer(
            orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)

        options = orbax.checkpoint.CheckpointManagerOptions(
            max_to_keep=4, create=True)
        self.checkpointer = orbax.checkpoint.CheckpointManager(
            self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)

        if load_from_checkpoint is not None:
            latest_epoch, latest_step, old_state, old_best_state, rngstate = self.load(
                load_from_checkpoint, checkpoint_step)
        else:
            latest_epoch, latest_step, old_state, old_best_state, rngstate = 0, 0, None, None, None

        self.latest_step = latest_step

        if rngstate:
            self.rngstate = RandomMarkovState(**rngstate)
        else:
            self.rngstate = RandomMarkovState(rngs)

        self.rngstate, subkey = self.rngstate.get_random_key()

        # ``init_state`` resets ``best_loss``; keep the value restored by
        # ``load`` so a resumed run does not forget its best loss.
        restored_best_loss = self.best_loss
        if train_state is None:
            state, best_state = self.generate_states(
                optimizer, subkey, old_state, old_best_state, model, param_transforms, use_dynamic_scale
            )
            self.init_state(state, best_state)
        else:
            self.state = train_state
            self.best_state = train_state
        self.best_loss = restored_best_loss

    def get_input_ones(self):
        """Return dummy all-ones inputs with a batch dimension of 1."""
        return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}

    def generate_states(
        self,
        optimizer: optax.GradientTransformation,
        rngs: jax.random.PRNGKey,
        existing_state: dict = None,
        existing_best_state: dict = None,
        model: nn.Module = None,
        param_transforms: Callable = None,
        use_dynamic_scale: bool = False
    ) -> Tuple[SimpleTrainState, SimpleTrainState]:
        """Create the train state and best state.

        Parameters come from ``existing_state['params']`` when given, otherwise
        from ``model.init`` on dummy inputs. The best state reuses the train
        state, swapping in ``existing_best_state['params']`` when provided.
        """
        print("Generating states for SimpleTrainer")
        rngs, subkey = jax.random.split(rngs)

        if existing_state is None:
            input_vars = self.get_input_ones()
            params = model.init(subkey, **input_vars)
        else:
            params = existing_state['params']

        state = SimpleTrainState.create(
            apply_fn=model.apply,
            params=params,
            tx=optimizer,
            metrics=Metrics.empty(),
            dynamic_scale=dynamic_scale_lib.DynamicScale() if use_dynamic_scale else None
        )
        if existing_best_state is not None:
            best_state = state.replace(
                params=existing_best_state['params'])
        else:
            best_state = state

        return state, best_state

    def init_state(
        self,
        state: SimpleTrainState,
        best_state: SimpleTrainState,
    ):
        """Install ``state`` / ``best_state`` and reset ``best_loss`` to 1e9."""
        self.best_loss = 1e9

        self.state = state
        self.best_state = best_state

    def get_state(self):
        """Return the current train state with all leaves as numpy arrays."""
        return self.get_np_tree(self.state)

    def get_best_state(self):
        """Return the best state with all leaves as numpy arrays."""
        return self.get_np_tree(self.best_state)

    def get_rngstate(self):
        """Return the RNG state with all leaves as numpy arrays."""
        return self.get_np_tree(self.rngstate)

    def get_np_tree(self, pytree):
        """Copy every leaf of ``pytree`` into a host numpy array."""
        return jax.tree_util.tree_map(lambda x: np.array(x), pytree)

    def checkpoint_path(self):
        """Absolute checkpoint directory for this experiment (created if missing)."""
        path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
        # Convert the path to an absolute path
        path = os.path.abspath(path)
        # exist_ok avoids a race when several processes create it at once.
        os.makedirs(path, exist_ok=True)
        return path

    def tensorboard_path(self):
        """Absolute tensorboard log directory for this experiment (created if missing)."""
        experiment_name = self.name
        path = os.path.join(os.path.abspath('./tensorboard'), experiment_name)
        os.makedirs(path, exist_ok=True)
        return path

    def load(self, checkpoint_path=None, checkpoint_step=None):
        """Restore a checkpoint.

        Args:
            checkpoint_path: checkpoint directory; ``None`` uses this
                trainer's own checkpoint manager.
            checkpoint_step: step to restore; ``None`` restores the latest.

        Returns:
            ``(epoch, step, state, best_state, rngstate)`` as raw pytrees.
            Also sets ``self.best_loss`` from the checkpoint.

        Raises:
            FileNotFoundError: if no checkpoint step is available.
        """
        if checkpoint_path is None:
            checkpointer = self.checkpointer
        else:
            checkpointer = orbax.checkpoint.PyTreeCheckpointer()
            options = orbax.checkpoint.CheckpointManagerOptions(
                max_to_keep=4, create=False)
            checkpointer = orbax.checkpoint.CheckpointManager(
                checkpoint_path, checkpointer, options)

        if checkpoint_step is None:
            step = checkpointer.latest_step()
        else:
            step = checkpoint_step

        if step is None:
            location = checkpoint_path if checkpoint_path is not None else self.checkpoint_path()
            raise FileNotFoundError(f"No checkpoint found to load in {location}")

        print("Loading model from checkpoint at step ", step)
        ckpt = checkpointer.restore(step)
        state = ckpt['state']
        best_state = ckpt['best_state']
        rngstate = ckpt['rngs']
        self.best_loss = ckpt['best_loss']
        if self.best_loss == 0:
            # A best loss of exactly zero indicates a broken save; ignore it.
            self.best_loss = 1e9
        # Checkpoints from older versions stored no epoch; fall back to step.
        current_epoch = ckpt.get('epoch', step)
        print(
            f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
        return current_epoch, step, state, best_state, rngstate

    def save(self, epoch=0, step=0, state=None, rngstate=None):
        """Save a checkpoint at ``step`` (best-effort: errors are printed, not raised).

        ``state`` / ``rngstate`` default to the trainer's current ones; the
        best state and best loss are always taken from the trainer.
        """
        print(f"Saving model at epoch {epoch} step {step}")
        try:
            ckpt = {
                'rngs': self.get_rngstate() if rngstate is None else self.get_np_tree(rngstate),
                'state': self.get_state() if state is None else self.get_np_tree(state),
                'best_state': self.get_best_state(),
                'best_loss': np.array(self.best_loss),
                'epoch': epoch,
            }
            try:
                save_args = orbax_utils.save_args_from_target(ckpt)
                self.checkpointer.save(step, ckpt, save_kwargs={
                    'save_args': save_args}, force=True)
                self.checkpointer.wait_until_finished()
            except Exception as e:
                print("Error saving checkpoint", e)
        except Exception as e:
            print("Error saving checkpoint outer", e)

    def _define_train_step(self, **kwargs):
        """Build the compiled training step.

        The returned function has signature
        ``(train_state, rng_state, batch, local_device_indexes) ->
        (train_state, loss, rng_state)``. In distributed mode it is wrapped in
        ``shard_map`` over the ``'data'`` mesh axis (batch sharded, state
        replicated) and gradients and loss are averaged across shards.
        """
        model = self.model
        loss_fn = self.loss_fn
        distributed_training = self.distributed_training

        def train_step(train_state: SimpleTrainState, rng_state: RandomMarkovState, batch, local_device_indexes):
            """Train for a single step."""
            images = batch['image']
            labels = batch['label']

            def model_loss(params):
                preds = model.apply(params, images)
                return jnp.mean(loss_fn(preds, labels))

            loss, grads = jax.value_and_grad(model_loss)(train_state.params)
            if distributed_training:
                # Average across shards so every shard applies the same update
                # and reports the same (replicated) scalar loss.
                grads = jax.lax.pmean(grads, "data")
                loss = jax.lax.pmean(loss, "data")
            train_state = train_state.apply_gradients(grads=grads)
            return train_state, loss, rng_state

        if distributed_training:
            # The loss is a replicated scalar, so it uses P(): a rank-0 output
            # cannot be sharded along 'data'.
            train_step = shard_map(
                train_step,
                mesh=self.mesh,
                in_specs=(P(), P(), P('data'), P('data')),
                out_specs=(P(), P(), P()),
            )
        # jit, not pmap: pmap would try to map over the leading axis of every
        # state leaf (including the scalar step counter).
        return jax.jit(train_step)

    def _define_vaidation_step(self, **kwargs):
        """Build the compiled validation step ``(state, batch) -> state``.

        It updates ``state.metrics`` with the loss and predictions. Extra
        keyword arguments are accepted (and ignored) so ``fit`` can forward
        ``validation_step_args`` to subclasses that use them.
        """
        model = self.model
        loss_fn = self.loss_fn
        distributed_training = self.distributed_training

        def validation_step(state: SimpleTrainState, batch):
            preds = model.apply(state.params, batch['image'])
            expected_output = batch['label']
            loss = jnp.mean(loss_fn(preds, expected_output))
            if distributed_training:
                loss = jax.lax.pmean(loss, "data")
                # Gather per-shard metric updates so the returned state is
                # identical on every shard, as out_specs=P() requires.
                metric_updates = state.metrics.gather_from_model_output(
                    axis_name="data", loss=loss, logits=preds, labels=expected_output)
            else:
                metric_updates = state.metrics.single_from_model_output(
                    loss=loss, logits=preds, labels=expected_output)
            metrics = state.metrics.merge(metric_updates)
            return state.replace(metrics=metrics)

        if distributed_training:
            validation_step = shard_map(
                validation_step, mesh=self.mesh, in_specs=(P(), P('data')), out_specs=P())
        return jax.jit(validation_step)

    def summary(self):
        """Print a parameter table of the model for dummy inputs."""
        input_vars = self.get_input_ones()
        print(self.model.tabulate(jax.random.key(0), **input_vars,
                                  console_kwargs={"width": 200, "force_jupyter": True, }))

    def config(self):
        """Return a dict describing this trainer (model, state, name, input shapes)."""
        return {
            "model": self.model,
            "state": self.state,
            "name": self.name,
            "input_shapes": self.input_shapes
        }

    def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
        """Create a tensorboard writer and record the run hyper-parameters."""
        from flax.metrics import tensorboard
        summary_writer = tensorboard.SummaryWriter(self.tensorboard_path())
        summary_writer.hparams({
            **self.config(),
            "steps_per_epoch": steps_per_epoch,
            "epochs": epochs,
            "batch_size": batch_size
        })
        return summary_writer

    def validation_loop(
        self,
        val_state: SimpleTrainState,
        val_step_fn: Callable,
        val_ds,
        val_steps_per_epoch,
        current_step,
    ):
        """Run ``val_steps_per_epoch`` validation steps (best-effort).

        ``val_ds`` is a zero-argument callable returning an iterable of batches
        (or falsy, in which case ``None`` batches are passed). If a step
        returns a non-empty dict, its entries are logged to wandb as
        ``val/<key>``. Any exception is printed and swallowed.
        """
        global_device_count = jax.device_count()
        process_index = jax.process_index()

        val_ds = iter(val_ds()) if val_ds else None
        # Evaluation step
        try:
            for i in range(val_steps_per_epoch):
                if val_ds is None:
                    batch = None
                else:
                    batch = next(val_ds)
                    if self.distributed_training and global_device_count > 1:
                        batch = convert_to_global_tree(self.mesh, batch)
                if i == 0:
                    print(f"Evaluation started for process index {process_index}")
                metrics = val_step_fn(val_state, batch)
                if self.wandb is not None and isinstance(metrics, dict) and metrics:
                    for key, value in metrics.items():
                        if isinstance(value, jnp.ndarray):
                            value = np.array(value)
                        self.wandb.log({
                            f"val/{key}": value,
                        }, step=current_step)
        except Exception as e:
            print("Error logging images to wandb", e)

    def train_loop(
        self,
        train_state: SimpleTrainState,
        train_step_fn: Callable,
        train_ds,
        train_steps_per_epoch,
        current_step,
        rng_state
    ):
        """Run one epoch of ``train_steps_per_epoch`` steps.

        Returns:
            ``(epoch_loss, current_step, train_state, rng_state)`` where
            ``epoch_loss`` is the sum of per-step losses.

        Raises:
            SystemExit: if a step loss drops to ``<= 1e-6``.
        """
        global_device_count = jax.device_count()
        process_index = jax.process_index()
        if self.distributed_training:
            global_device_indexes = jnp.arange(global_device_count)
        else:
            global_device_indexes = 0

        epoch_loss = 0
        current_epoch = current_step // train_steps_per_epoch

        # Only process 0 shows a progress bar; elsewhere pbar stays None.
        pbar = None
        if process_index == 0:
            pbar = tqdm.tqdm(total=train_steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step')

        for i in range(train_steps_per_epoch):
            batch = next(train_ds)
            if i == 0:
                print(f"First batch loaded at step {current_step}")

            if self.distributed_training and global_device_count > 1:
                # Convert the local device batches to a unified global jax.Array
                batch = convert_to_global_tree(self.mesh, batch)
            train_state, loss, rng_state = train_step_fn(train_state, rng_state, batch, global_device_indexes)

            if i == 0:
                print(f"Training started for process index {process_index} at step {current_step}")

            if self.distributed_training:
                loss = jnp.mean(loss)  # Make sure it is a scalar value

            if loss <= 1e-6:
                # A (near-)zero loss is treated as a broken run: abort instead
                # of continuing to train and checkpoint it.
                print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
                raise SystemExit(1)

            epoch_loss += loss
            current_step += 1
            if pbar is not None:
                pbar.update(1)
            if i % 100 == 0:
                if pbar is not None:
                    pbar.set_postfix(loss=f'{loss:.4f}')
                if self.wandb is not None:
                    self.wandb.log({
                        "train/step": current_step,
                        "train/loss": loss,
                    }, step=current_step)
            # Save the model every few steps
            if i % 10000 == 0 and i > 0:
                print(f"Saving model after 10000 step {current_step}")
                print(f"Devices: {len(jax.devices())}")  # To sync the devices
                self.save(current_epoch, current_step, train_state, rng_state)
                print(f"Saving done by process index {process_index}")
        print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/train_steps_per_epoch}", 'green'))
        if pbar is not None:
            pbar.close()
        return epoch_loss, current_step, train_state, rng_state

    def fit(self, data, train_steps_per_epoch, epochs, train_step_args=None, val_steps_per_epoch=5, validation_step_args=None):
        """Train until ``epochs * train_steps_per_epoch`` steps have run.

        Args:
            data: dict with a ``'train'`` entry (zero-arg callable returning an
                iterable of batches) and optionally ``'test'`` or ``'val'``.
            train_steps_per_epoch: steps per epoch.
            epochs: total number of epochs.
            train_step_args: kwargs for ``_define_train_step``.
            val_steps_per_epoch: validation steps per epoch (0 disables).
            validation_step_args: kwargs for ``_define_vaidation_step``.

        Returns:
            The final train state.
        """
        train_step_args = {} if train_step_args is None else train_step_args
        validation_step_args = {} if validation_step_args is None else validation_step_args

        train_ds = iter(data['train']())
        train_step = self._define_train_step(**train_step_args)
        val_step = self._define_vaidation_step(**validation_step_args)
        train_state = self.state
        rng_state = self.rngstate
        process_index = jax.process_index()
        # Wrap around so runs with more than len(PROCESS_COLOR_MAP) processes work.
        process_color = PROCESS_COLOR_MAP[process_index % len(PROCESS_COLOR_MAP)]
        val_ds = data.get('test', data.get('val', None))

        if val_steps_per_epoch > 0:
            # First run a validation pass to make sure the model is working.
            print(f"Validation run for sanity check for process index {process_index}")
            self.validation_loop(
                train_state,
                val_step,
                val_ds,
                val_steps_per_epoch,
                self.latest_step,
            )
            print(colored(f"Sanity Validation done on process index {process_index}", process_color))

        while self.latest_step < epochs * train_steps_per_epoch:
            current_epoch = self.latest_step // train_steps_per_epoch
            print(f"\nEpoch {current_epoch}/{epochs}")
            start_time = time.time()

            epoch_loss, current_step, train_state, rng_state = self.train_loop(
                train_state,
                train_step,
                train_ds,
                train_steps_per_epoch,
                self.latest_step,
                rng_state,
            )
            print(colored(f"Epoch done on process index {process_index}", process_color))

            self.latest_step = current_step
            end_time = time.time()
            self.state = train_state
            self.rngstate = rng_state
            total_time = end_time - start_time
            avg_time_per_step = total_time / train_steps_per_epoch
            avg_loss = epoch_loss / train_steps_per_epoch
            if avg_loss < self.best_loss:
                self.best_loss = avg_loss
                self.best_state = train_state
                self.save(current_epoch, current_step)

            if process_index == 0:
                if self.wandb is not None:
                    self.wandb.log({
                        "train/epoch_time": total_time,
                        "train/avg_time_per_step": avg_time_per_step,
                        "train/avg_loss": avg_loss,
                        "train/best_loss": self.best_loss,
                        "train/epoch": current_epoch,
                    }, step=current_step)
                print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))

            if val_steps_per_epoch > 0:
                print(f"Validation started for process index {process_index}")
                self.validation_loop(
                    train_state,
                    val_step,
                    val_ds,
                    val_steps_per_epoch,
                    current_step,
                )
                print(colored(f"Validation done on process index {process_index}", process_color))

        # Save the final state under the last trained step, unless that step
        # was already checkpointed by the best-loss save above.
        if self.latest_step not in self.checkpointer.all_steps():
            self.save(epochs, self.latest_step)
        return self.state
|
@@ -0,0 +1,62 @@
|
|
1
|
+
import flax
|
2
|
+
from flax import linen as nn
|
3
|
+
import jax
|
4
|
+
from typing import Callable
|
5
|
+
from dataclasses import field
|
6
|
+
import jax.numpy as jnp
|
7
|
+
import optax
|
8
|
+
import functools
|
9
|
+
from jax.sharding import Mesh, PartitionSpec as P
|
10
|
+
from jax.experimental.shard_map import shard_map
|
11
|
+
from typing import Dict, Callable, Sequence, Any, Union, Tuple
|
12
|
+
|
13
|
+
from ..schedulers import NoiseScheduler
|
14
|
+
from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
|
15
|
+
|
16
|
+
from flaxdiff.utils import RandomMarkovState
|
17
|
+
|
18
|
+
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
19
|
+
|
20
|
+
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
21
|
+
from flax.training import dynamic_scale as dynamic_scale_lib
|
22
|
+
|
23
|
+
class TrainState(SimpleTrainState):
    """``SimpleTrainState`` plus an RNG key and an EMA copy of the parameters."""

    rngs: jax.random.PRNGKey
    ema_params: dict

    def apply_ema(self, decay: float = 0.999):
        """Return a copy whose ``ema_params`` have moved towards ``params``.

        Each leaf is updated as ``decay * ema + (1 - decay) * param``.
        """
        def _blend(ema_leaf, param_leaf):
            return decay * ema_leaf + (1 - decay) * param_leaf

        updated_ema = jax.tree_util.tree_map(_blend, self.ema_params, self.params)
        return self.replace(ema_params=updated_ema)
|
34
|
+
|
35
|
+
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
36
|
+
from flaxdiff.trainer.diffusion_trainer import DiffusionTrainer
|
37
|
+
|
38
|
+
class SimpleVideoDiffusionTrainer(DiffusionTrainer):
    """``DiffusionTrainer`` for video models.

    Forwards every argument to ``DiffusionTrainer.__init__``, with ``name``
    defaulting to ``"SimpleVideoDiffusion"``.
    """

    def __init__(self,
                 model: nn.Module,
                 input_shapes: Dict[str, Tuple[int]],
                 optimizer: optax.GradientTransformation,
                 noise_schedule: NoiseScheduler,
                 rngs: jax.random.PRNGKey,
                 unconditional_prob: float = 0.12,
                 name: str = "SimpleVideoDiffusion",
                 # NOTE(review): this default instance is created once at class
                 # definition and shared by every trainer that relies on it.
                 model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
                 autoencoder: AutoEncoder = None,
                 **kwargs
                 ):
        forwarded = dict(
            model=model,
            input_shapes=input_shapes,
            optimizer=optimizer,
            noise_schedule=noise_schedule,
            unconditional_prob=unconditional_prob,
            autoencoder=autoencoder,
            model_output_transform=model_output_transform,
            rngs=rngs,
            name=name,
        )
        super().__init__(**forwarded, **kwargs)
|
@@ -1,6 +0,0 @@
|
|
1
|
-
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
-
flaxdiff/utils.py,sha256=b_hFXsam2NICQYCFk0EOcqtBjM-RUqnN0NKTn0lQ070,6532
|
3
|
-
flaxdiff-0.1.36.1.dist-info/METADATA,sha256=Fl9tlGh_BgRnT-f8k4cEYnFj7G03VecUNOX_1zbJrmE,22310
|
4
|
-
flaxdiff-0.1.36.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
5
|
-
flaxdiff-0.1.36.1.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
6
|
-
flaxdiff-0.1.36.1.dist-info/RECORD,,
|
@@ -1 +0,0 @@
|
|
1
|
-
flaxdiff
|
File without changes
|
File without changes
|
/flaxdiff/utils.py → /utils.py
RENAMED
File without changes
|