flaxdiff 0.1.35.6__py3-none-any.whl → 0.1.36.1__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- flaxdiff/utils.py +105 -2
- {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/METADATA +16 -7
- flaxdiff-0.1.36.1.dist-info/RECORD +6 -0
- {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/WHEEL +1 -1
- flaxdiff/data/__init__.py +0 -1
- flaxdiff/data/online_loader.py +0 -336
- flaxdiff/models/__init__.py +0 -1
- flaxdiff/models/attention.py +0 -368
- flaxdiff/models/autoencoder/__init__.py +0 -2
- flaxdiff/models/autoencoder/autoencoder.py +0 -19
- flaxdiff/models/autoencoder/diffusers.py +0 -91
- flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
- flaxdiff/models/common.py +0 -346
- flaxdiff/models/favor_fastattn.py +0 -723
- flaxdiff/models/simple_unet.py +0 -233
- flaxdiff/models/simple_vit.py +0 -180
- flaxdiff/predictors/__init__.py +0 -96
- flaxdiff/samplers/__init__.py +0 -7
- flaxdiff/samplers/common.py +0 -113
- flaxdiff/samplers/ddim.py +0 -10
- flaxdiff/samplers/ddpm.py +0 -43
- flaxdiff/samplers/euler.py +0 -59
- flaxdiff/samplers/heun_sampler.py +0 -28
- flaxdiff/samplers/multistep_dpm.py +0 -60
- flaxdiff/samplers/rk4_sampler.py +0 -34
- flaxdiff/schedulers/__init__.py +0 -6
- flaxdiff/schedulers/common.py +0 -98
- flaxdiff/schedulers/continuous.py +0 -12
- flaxdiff/schedulers/cosine.py +0 -40
- flaxdiff/schedulers/discrete.py +0 -74
- flaxdiff/schedulers/exp.py +0 -13
- flaxdiff/schedulers/karras.py +0 -69
- flaxdiff/schedulers/linear.py +0 -14
- flaxdiff/schedulers/sqrt.py +0 -10
- flaxdiff/trainer/__init__.py +0 -2
- flaxdiff/trainer/autoencoder_trainer.py +0 -182
- flaxdiff/trainer/diffusion_trainer.py +0 -234
- flaxdiff/trainer/simple_trainer.py +0 -442
- flaxdiff-0.1.35.6.dist-info/RECORD +0 -40
- {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/top_level.txt +0 -0
flaxdiff/trainer/simple_trainer.py (removed)
@@ -1,442 +0,0 @@
-import orbax.checkpoint
-import tqdm
-from flax import linen as nn
-import jax
-from typing import Callable
-from dataclasses import field
-import jax.numpy as jnp
-import numpy as np
-from functools import partial
-from clu import metrics
-from flax.training import train_state # Useful dataclass to keep train state
-import optax
-from flax import struct # Flax dataclasses
-import flax
-import time
-import os
-import orbax
-from flax.training import orbax_utils
-from jax.sharding import Mesh, PartitionSpec as P
-from jax.experimental import mesh_utils
-from jax.experimental.shard_map import shard_map
-from orbax.checkpoint.utils import fully_replicated_host_local_array_to_global_array
-from termcolor import colored
-from typing import Dict, Callable, Sequence, Any, Union, Tuple
-from flax.training.dynamic_scale import DynamicScale
-from flaxdiff.utils import RandomMarkovState
-
-PROCESS_COLOR_MAP = {
-    0: "green",
-    1: "yellow",
-    2: "magenta",
-    3: "cyan",
-    4: "white",
-    5: "light_blue",
-    6: "light_red",
-    7: "light_cyan"
-}
-
-def _build_global_shape_and_sharding(
-    local_shape: tuple[int, ...], global_mesh: Mesh
-) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
-    sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
-    global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
-    return global_shape, sharding
-
-
-def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
-    """Put local sharded array into local devices"""
-    global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
-    try:
-        local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
-    except ValueError as array_split_error:
-        raise ValueError(
-            f"Unable to put to devices shape {array.shape} with "
-            f"local device count {len(global_mesh.local_devices)} "
-        ) from array_split_error
-    local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
-    return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
-
-def convert_to_global_tree(global_mesh, pytree):
-    return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
-
-@struct.dataclass
-class Metrics(metrics.Collection):
-    accuracy: metrics.Accuracy
-    loss: metrics.Average.from_output('loss')
-
-# Define the TrainState
-class SimpleTrainState(train_state.TrainState):
-    metrics: Metrics
-    dynamic_scale: DynamicScale
-
-class SimpleTrainer:
-    state: SimpleTrainState
-    best_state: SimpleTrainState
-    best_loss: float
-    model: nn.Module
-    ema_decay: float = 0.999
-
-    def __init__(self,
-                 model: nn.Module,
-                 input_shapes: Dict[str, Tuple[int]],
-                 optimizer: optax.GradientTransformation,
-                 rngs: jax.random.PRNGKey,
-                 train_state: SimpleTrainState = None,
-                 name: str = "Simple",
-                 load_from_checkpoint: str = None,
-                 checkpoint_suffix: str = "",
-                 loss_fn=optax.l2_loss,
-                 param_transforms: Callable = None,
-                 wandb_config: Dict[str, Any] = None,
-                 distributed_training: bool = None,
-                 checkpoint_base_path: str = "./checkpoints",
-                 checkpoint_step: int = None,
-                 use_dynamic_scale: bool = False,
-                 ):
-        if distributed_training is None or distributed_training is True:
-            # Auto-detect if we are running on multiple devices
-            distributed_training = jax.device_count() > 1
-            self.mesh = jax.sharding.Mesh(jax.devices(), 'data')
-        else:
-            self.mesh = None
-
-        self.distributed_training = distributed_training
-        self.model = model
-        self.name = name
-        self.loss_fn = loss_fn
-        self.input_shapes = input_shapes
-        self.checkpoint_base_path = checkpoint_base_path
-
-
-        if wandb_config is not None and jax.process_index() == 0:
-            run = wandb.init(**wandb_config)
-            self.wandb = run
-
-            # define our custom x axis metric
-            self.wandb.define_metric("train/step")
-            self.wandb.define_metric("train/epoch")
-
-            self.wandb.define_metric("train/loss", step_metric="train/step")
-
-            self.wandb.define_metric("train/epoch_time", step_metric="train/epoch")
-            self.wandb.define_metric("train/avg_time_per_step", step_metric="train/epoch")
-            self.wandb.define_metric("train/avg_loss", step_metric="train/epoch")
-            self.wandb.define_metric("train/best_loss", step_metric="train/epoch")
-
-        # checkpointer = orbax.checkpoint.PyTreeCheckpointer()
-        async_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler(), timeout_secs=60)
-
-        options = orbax.checkpoint.CheckpointManagerOptions(
-            max_to_keep=4, create=True)
-        self.checkpointer = orbax.checkpoint.CheckpointManager(
-            self.checkpoint_path() + checkpoint_suffix, async_checkpointer, options)
-
-        if load_from_checkpoint is not None:
-            latest_epoch, latest_step, old_state, old_best_state, rngstate = self.load(load_from_checkpoint, checkpoint_step)
-        else:
-            latest_epoch, latest_step, old_state, old_best_state, rngstate = 0, 0, None, None, None
-
-        self.latest_step = latest_step
-
-        if rngstate:
-            self.rngstate = RandomMarkovState(**rngstate)
-        else:
-            self.rngstate = RandomMarkovState(rngs)
-
-        self.rngstate, subkey = self.rngstate.get_random_key()
-
-        if train_state == None:
-            state, best_state = self.generate_states(
-                optimizer, subkey, old_state, old_best_state, model, param_transforms, use_dynamic_scale
-            )
-            self.init_state(state, best_state)
-        else:
-            self.state = train_state
-            self.best_state = train_state
-            self.best_loss = 1e9
-
-    def get_input_ones(self):
-        return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}
-
-    def generate_states(
-        self,
-        optimizer: optax.GradientTransformation,
-        rngs: jax.random.PRNGKey,
-        existing_state: dict = None,
-        existing_best_state: dict = None,
-        model: nn.Module = None,
-        param_transforms: Callable = None,
-        use_dynamic_scale: bool = False
-    ) -> Tuple[SimpleTrainState, SimpleTrainState]:
-        print("Generating states for SimpleTrainer")
-        rngs, subkey = jax.random.split(rngs)
-
-        if existing_state == None:
-            input_vars = self.get_input_ones()
-            params = model.init(subkey, **input_vars)
-        else:
-            params = existing_state['params']
-
-        if param_transforms is not None:
-            params = param_transforms(params)
-
-        state = SimpleTrainState.create(
-            apply_fn=model.apply,
-            params=params,
-            tx=optimizer,
-            metrics=Metrics.empty(),
-            dynamic_scale = DynamicScale() if use_dynamic_scale else None
-        )
-        if existing_best_state is not None:
-            best_state = state.replace(
-                params=existing_best_state['params'])
-        else:
-            best_state = state
-
-        return state, best_state
-
-    def init_state(
-        self,
-        state: SimpleTrainState,
-        best_state: SimpleTrainState,
-    ):
-        self.best_loss = 1e9
-
-        self.state = state
-        self.best_state = best_state
-
-    def get_state(self):
-        # return fully_replicated_host_local_array_to_global_array()
-        return jax.tree_util.tree_map(lambda x : np.array(x), self.state)
-
-    def get_best_state(self):
-        # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.best_state, jax.local_devices()))
-        return jax.tree_util.tree_map(lambda x : np.array(x), self.best_state)
-
-    def get_rngstate(self):
-        # return convert_to_global_tree(self.mesh, flax.jax_utils.replicate(self.rngstate, jax.local_devices()))
-        return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate)
-
-    def checkpoint_path(self):
-        path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
-        if not os.path.exists(path):
-            os.makedirs(path)
-        return path
-
-    def tensorboard_path(self):
-        experiment_name = self.name
-        path = os.path.join(os.path.abspath('./tensorboard'), experiment_name)
-        if not os.path.exists(path):
-            os.makedirs(path)
-        return path
-
-    def load(self, checkpoint_path=None, checkpoint_step=None):
-        if checkpoint_path is None:
-            checkpointer = self.checkpointer
-        else:
-            checkpointer = orbax.checkpoint.PyTreeCheckpointer()
-            options = orbax.checkpoint.CheckpointManagerOptions(
-                max_to_keep=4, create=False)
-            checkpointer = orbax.checkpoint.CheckpointManager(
-                checkpoint_path, checkpointer, options)
-
-        if checkpoint_step is None:
-            step = checkpointer.latest_step()
-        else:
-            step = checkpoint_step
-
-        print("Loading model from checkpoint at step ", step)
-        ckpt = checkpointer.restore(step)
-        state = ckpt['state']
-        best_state = ckpt['best_state']
-        rngstate = ckpt['rngs']
-        # Convert the state to a TrainState
-        self.best_loss = ckpt['best_loss']
-        current_epoch = ckpt.get('epoch', step) # Must be a checkpoint from an older version which used epochs instead of steps
-        print(
-            f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
-        return current_epoch, step, state, best_state, rngstate
-
-    def save(self, epoch=0, step=0):
-        print(f"Saving model at epoch {epoch} step {step}")
-        ckpt = {
-            # 'model': self.model,
-            'rngs': self.get_rngstate(),
-            'state': self.get_state(),
-            'best_state': self.get_best_state(),
-            'best_loss': np.array(self.best_loss),
-            'epoch': epoch,
-        }
-        try:
-            save_args = orbax_utils.save_args_from_target(ckpt)
-            self.checkpointer.save(step, ckpt, save_kwargs={
-                                   'save_args': save_args}, force=True)
-            self.checkpointer.wait_until_finished()
-            pass
-        except Exception as e:
-            print("Error saving checkpoint", e)
-
-    def _define_train_step(self, **kwargs):
-        model = self.model
-        loss_fn = self.loss_fn
-        distributed_training = self.distributed_training
-
-        def train_step(train_state: SimpleTrainState, rng_state: RandomMarkovState, batch, local_device_indexes):
-            """Train for a single step."""
-            images = batch['image']
-            labels = batch['label']
-
-            def model_loss(params):
-                preds = model.apply(params, images)
-                expected_output = labels
-                nloss = loss_fn(preds, expected_output)
-                loss = jnp.mean(nloss)
-                return loss
-            loss, grads = jax.value_and_grad(model_loss)(train_state.params)
-            if distributed_training:
-                grads = jax.lax.pmean(grads, "data")
-            train_state = train_state.apply_gradients(grads=grads)
-            return train_state, loss, rng_state
-
-        if distributed_training:
-            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')), out_specs=(P(), P('data'), P()))
-            train_step = jax.pmap(train_step)
-        return train_step
-
-    def _define_compute_metrics(self):
-        model = self.model
-        loss_fn = self.loss_fn
-
-        @jax.jit
-        def compute_metrics(state: SimpleTrainState, batch):
-            preds = model.apply(state.params, batch['image'])
-            expected_output = batch['label']
-            loss = jnp.mean(loss_fn(preds, expected_output))
-            metric_updates = state.metrics.single_from_model_output(
-                loss=loss, logits=preds, labels=expected_output)
-            metrics = state.metrics.merge(metric_updates)
-            state = state.replace(metrics=metrics)
-            return state
-        return compute_metrics
-
-    def summary(self):
-        input_vars = self.get_input_ones()
-        print(self.model.tabulate(jax.random.key(0), **input_vars,
-                                  console_kwargs={"width": 200, "force_jupyter": True, }))
-
-    def config(self):
-        return {
-            "model": self.model,
-            "state": self.state,
-            "name": self.name,
-            "input_shapes": self.input_shapes
-        }
-
-    def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
-        from flax.metrics import tensorboard
-        summary_writer = tensorboard.SummaryWriter(self.tensorboard_path())
-        summary_writer.hparams({
-            **self.config(),
-            "steps_per_epoch": steps_per_epoch,
-            "epochs": epochs,
-            "batch_size": batch_size
-        })
-        return summary_writer
-
-    def fit(self, data, steps_per_epoch, epochs, train_step_args={}):
-        train_ds = iter(data['train']())
-        if 'test' in data:
-            test_ds = data['test']
-        else:
-            test_ds = None
-        train_step = self._define_train_step(**train_step_args)
-        compute_metrics = self._define_compute_metrics()
-        train_state = self.state
-        rng_state = self.rngstate
-        global_device_count = jax.device_count()
-        local_device_count = jax.local_device_count()
-        process_index = jax.process_index()
-        if self.distributed_training:
-            global_device_indexes = jnp.arange(global_device_count)
-        else:
-            global_device_indexes = 0
-
-        def train_loop(current_step, pbar: tqdm.tqdm, train_state, rng_state):
-            epoch_loss = 0
-            current_epoch = current_step // steps_per_epoch
-            last_save_time = time.time()
-            for i in range(steps_per_epoch):
-                batch = next(train_ds)
-                if self.distributed_training and global_device_count > 1:
-                    # Convert the local device batches to a unified global jax.Array
-                    batch = convert_to_global_tree(self.mesh, batch)
-                train_state, loss, rng_state = train_step(train_state, rng_state, batch, global_device_indexes)
-
-                if self.distributed_training:
-                    loss = jax.experimental.multihost_utils.process_allgather(loss)
-                    loss = jnp.mean(loss) # Just to make sure its a scaler value
-
-                if loss <= 1e-6:
-                    # If the loss is too low, we can assume the model has diverged
-                    print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
-                    # Exit the training loop
-                    exit(1)
-
-                epoch_loss += loss
-                current_step += 1
-                if i % 100 == 0:
-                    if pbar is not None:
-                        pbar.set_postfix(loss=f'{loss:.4f}')
-                        pbar.update(100)
-                    if self.wandb is not None:
-                        self.wandb.log({
-                            "train/step" : current_step,
-                            "train/loss": loss,
-                        }, step=current_step)
-                    # Save the model every 40 minutes
-                    if time.time() - last_save_time > 40 * 60:
-                        print(f"Saving model after 40 minutes at step {current_step}")
-                        self.save(current_epoch, current_step)
-                        last_save_time = time.time()
-            print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/steps_per_epoch}", 'green'))
-            return epoch_loss, current_step, train_state, rng_state
-
-        while self.latest_step < epochs * steps_per_epoch:
-            current_epoch = self.latest_step // steps_per_epoch
-            print(f"\nEpoch {current_epoch}/{epochs}")
-            start_time = time.time()
-            epoch_loss = 0
-
-            if process_index == 0:
-                with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
-                    epoch_loss, current_step, train_state, rng_state = train_loop(self.latest_step, pbar, train_state, rng_state)
-            else:
-                epoch_loss, current_step, train_state, rng_state = train_loop(self.latest_step, None, train_state, rng_state)
-                print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
-
-            self.latest_step = current_step
-            end_time = time.time()
-            self.state = train_state
-            self.rngstate = rng_state
-            total_time = end_time - start_time
-            avg_time_per_step = total_time / steps_per_epoch
-            avg_loss = epoch_loss / steps_per_epoch
-            if avg_loss < self.best_loss:
-                self.best_loss = avg_loss
-                self.best_state = train_state
-                self.save(current_epoch, current_step)
-
-            if process_index == 0:
-                if self.wandb is not None:
-                    self.wandb.log({
-                        "train/epoch_time": total_time,
-                        "train/avg_time_per_step": avg_time_per_step,
-                        "train/avg_loss": avg_loss,
-                        "train/best_loss": self.best_loss,
-                        "train/epoch": current_epoch,
-                    }, step=current_step)
-            print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
-        print("Training done")
-        self.save(epochs)
-        return self.state
flaxdiff-0.1.35.6.dist-info/RECORD (removed)
@@ -1,40 +0,0 @@
-flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
-flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
-flaxdiff/data/online_loader.py,sha256=DoHrMZCi5gMd9tmkCpZIUU9lGxvfYtuaz58943_lCRc,11315
-flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
-flaxdiff/models/attention.py,sha256=JvrP7-09MV6IfRLRBhqjPmNUU-lkEMk9TOnJSBKcar8,13289
-flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
-flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
-flaxdiff/models/simple_unet.py,sha256=L5m2j5580QP7pJ5VIme7U5xYA22PZiGP7qdvcKUnB38,11463
-flaxdiff/models/simple_vit.py,sha256=UCDDr0XVnpf6tbJWKFtEt3_nAqMqOoakXf5amyVWZNo,7929
-flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
-flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
-flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
-flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
-flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
-flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
-flaxdiff/samplers/common.py,sha256=_an5h5Niz9Joz_-ppridLrGHpu8X0VVvhNGknPu6AUY,5272
-flaxdiff/samplers/ddim.py,sha256=XHMBX06S5hMTnKMaGh6fmq189pQcaGkA6fnX6YPbHP0,511
-flaxdiff/samplers/ddpm.py,sha256=d_58hfVShJHsRPQf5h1-4YKrD43-HGjWz9Vd8hltZBg,2627
-flaxdiff/samplers/euler.py,sha256=Epf7LBKUky7B8b-1ZyIlLWdRMgjmP08BQraGSKmr_3I,2726
-flaxdiff/samplers/heun_sampler.py,sha256=hhWnSM26OfOIFAcsuWYa1z-2QPjASuoYTop2byLWqzE,1388
-flaxdiff/samplers/multistep_dpm.py,sha256=ocmEq2sCvsULy6oTFaD5BhTU4c8VHsge4bdg6tfxW80,2724
-flaxdiff/samplers/rk4_sampler.py,sha256=BF-dMV1KauO-SYShqrCfm3U3V-1n4clqQXBeoG8RWQo,1728
-flaxdiff/schedulers/__init__.py,sha256=3id390WEfdf-MN-oLSPAhlRFIXrFWr6ioAHPAwURJyE,375
-flaxdiff/schedulers/common.py,sha256=b-W4iI-aqScpVE8VZbBpiYvAVI6rqDkUP-C_hEVBwCI,4151
-flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
-flaxdiff/schedulers/cosine.py,sha256=9ban0dFHLMm35wQvaBT4nCQwPGmzNsXwQ1xI0oppmJI,2005
-flaxdiff/schedulers/discrete.py,sha256=O54wH2HVu3olJA71NxgAXFW9cr6B6Gl-DR_uZeytpds,3319
-flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
-flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
-flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
-flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
-flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
-flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
-flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
-flaxdiff/trainer/simple_trainer.py,sha256=cawm6fZNQoLLATMneAU2gQ9j7kefqHnBPHuaIj3i_a4,18237
-flaxdiff-0.1.35.6.dist-info/METADATA,sha256=NVCk5V7Zc3iq-nrWTivzO17dQa1fIjYgjJb800ZrZhQ,22085
-flaxdiff-0.1.35.6.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-flaxdiff-0.1.35.6.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.35.6.dist-info/RECORD,,
{flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/top_level.txt: file without changes