flaxdiff 0.1.35.5__py3-none-any.whl → 0.1.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataset_map.py +71 -0
- flaxdiff/data/datasets.py +169 -0
- flaxdiff/data/online_loader.py +69 -42
- flaxdiff/models/attention.py +1 -0
- flaxdiff/models/simple_unet.py +11 -11
- flaxdiff/models/simple_vit.py +1 -1
- flaxdiff/samplers/common.py +72 -20
- flaxdiff/samplers/ddim.py +5 -5
- flaxdiff/samplers/ddpm.py +5 -11
- flaxdiff/samplers/euler.py +7 -10
- flaxdiff/samplers/heun_sampler.py +3 -4
- flaxdiff/samplers/multistep_dpm.py +2 -3
- flaxdiff/samplers/rk4_sampler.py +9 -9
- flaxdiff/trainer/autoencoder_trainer.py +1 -1
- flaxdiff/trainer/diffusion_trainer.py +124 -32
- flaxdiff/trainer/simple_trainer.py +187 -91
- flaxdiff/trainer/video_diffusion_trainer.py +62 -0
- flaxdiff/utils.py +105 -2
- {flaxdiff-0.1.35.5.dist-info → flaxdiff-0.1.36.dist-info}/METADATA +11 -5
- flaxdiff-0.1.36.dist-info/RECORD +43 -0
- {flaxdiff-0.1.35.5.dist-info → flaxdiff-0.1.36.dist-info}/WHEEL +1 -1
- flaxdiff-0.1.35.5.dist-info/RECORD +0 -40
- {flaxdiff-0.1.35.5.dist-info → flaxdiff-0.1.36.dist-info}/top_level.txt +0 -0
@@ -24,6 +24,7 @@ from termcolor import colored
 from typing import Dict, Callable, Sequence, Any, Union, Tuple
 from flax.training.dynamic_scale import DynamicScale
 from flaxdiff.utils import RandomMarkovState
+from flax.training import dynamic_scale as dynamic_scale_lib
 
 PROCESS_COLOR_MAP = {
     0: "green",
@@ -63,12 +64,12 @@ def convert_to_global_tree(global_mesh, pytree):
 @struct.dataclass
 class Metrics(metrics.Collection):
     accuracy: metrics.Accuracy
-    loss: metrics.Average
+    loss: metrics.Average#.from_output('loss')
 
 # Define the TrainState
 class SimpleTrainState(train_state.TrainState):
     metrics: Metrics
-    dynamic_scale: DynamicScale
+    dynamic_scale: dynamic_scale_lib.DynamicScale
 
 class SimpleTrainer:
     state: SimpleTrainState
@@ -110,6 +111,7 @@ class SimpleTrainer:
 
 
         if wandb_config is not None and jax.process_index() == 0:
+            import wandb
             run = wandb.init(**wandb_config)
             self.wandb = run
 
@@ -177,16 +179,13 @@ class SimpleTrainer:
             params = model.init(subkey, **input_vars)
         else:
             params = existing_state['params']
-
-        if param_transforms is not None:
-            params = param_transforms(params)
 
         state = SimpleTrainState.create(
             apply_fn=model.apply,
             params=params,
             tx=optimizer,
             metrics=Metrics.empty(),
-            dynamic_scale = DynamicScale() if use_dynamic_scale else None
+            dynamic_scale = dynamic_scale_lib.DynamicScale() if use_dynamic_scale else None
         )
         if existing_best_state is not None:
             best_state = state.replace(
@@ -207,16 +206,16 @@ class SimpleTrainer:
         self.best_state = best_state
 
     def get_state(self):
-
-        return jax.tree_util.tree_map(lambda x : np.array(x), self.state)
+        return self.get_np_tree(self.state)
 
     def get_best_state(self):
-
-        return jax.tree_util.tree_map(lambda x : np.array(x), self.best_state)
+        return self.get_np_tree(self.best_state)
 
     def get_rngstate(self):
-
-        return jax.tree_util.tree_map(lambda x : np.array(x), self.rngstate)
+        return self.get_np_tree(self.rngstate)
+
+    def get_np_tree(self, pytree):
+        return jax.tree_util.tree_map(lambda x : np.array(x), pytree)
 
     def checkpoint_path(self):
         path = os.path.join(self.checkpoint_base_path, self.name.replace(' ', '_').lower())
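The refactor above funnels get_state, get_best_state, and get_rngstate through a single get_np_tree helper that copies every pytree leaf back to host memory. A minimal standalone sketch of that pattern (toy tree, not flaxdiff's actual TrainState):

import jax
import jax.numpy as jnp
import numpy as np

def get_np_tree(pytree):
    # Copy every leaf of a JAX pytree from device to host as a numpy array,
    # the form orbax expects when serializing a checkpoint.
    return jax.tree_util.tree_map(lambda x: np.array(x), pytree)

state_like = {"params": {"w": jnp.ones((2, 2))}, "step": jnp.array(10)}
host_tree = get_np_tree(state_like)
print(type(host_tree["params"]["w"]))  # <class 'numpy.ndarray'>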
@@ -253,29 +252,35 @@ class SimpleTrainer:
         rngstate = ckpt['rngs']
         # Convert the state to a TrainState
         self.best_loss = ckpt['best_loss']
+        if self.best_loss == 0:
+            # It cant be zero as that must have been some problem
+            self.best_loss = 1e9
         current_epoch = ckpt.get('epoch', step) # Must be a checkpoint from an older version which used epochs instead of steps
         print(
             f"Loaded model from checkpoint at epoch {current_epoch} step {step}", ckpt['best_loss'])
         return current_epoch, step, state, best_state, rngstate
 
-    def save(self, epoch=0, step=0):
+    def save(self, epoch=0, step=0, state=None, rngstate=None):
         print(f"Saving model at epoch {epoch} step {step}")
-        ckpt = {
-            # 'model': self.model,
-            'rngs': self.get_rngstate(),
-            'state': self.get_state(),
-            'best_state': self.get_best_state(),
-            'best_loss': np.array(self.best_loss),
-            'epoch': epoch,
-        }
         try:
-            save_args = orbax_utils.save_args_from_target(ckpt)
-            self.checkpointer.save(step, ckpt, save_kwargs={
-                'save_args': save_args}, force=True)
-            self.checkpointer.wait_until_finished()
-            pass
+            ckpt = {
+                # 'model': self.model,
+                'rngs': self.get_rngstate() if rngstate is None else self.get_np_tree(rngstate),
+                'state': self.get_state() if state is None else self.get_np_tree(state),
+                'best_state': self.get_best_state(),
+                'best_loss': np.array(self.best_loss),
+                'epoch': epoch,
+            }
+            try:
+                save_args = orbax_utils.save_args_from_target(ckpt)
+                self.checkpointer.save(step, ckpt, save_kwargs={
+                    'save_args': save_args}, force=True)
+                self.checkpointer.wait_until_finished()
+                pass
+            except Exception as e:
+                print("Error saving checkpoint", e)
         except Exception as e:
-            print("Error saving checkpoint", e)
+            print("Error saving checkpoint outer", e)
 
     def _define_train_step(self, **kwargs):
         model = self.model
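For context, save() drives an orbax checkpointer with the older CheckpointManager API; a hedged standalone sketch of that call pattern (the manager construction, path, and options below are assumptions, not taken from flaxdiff):

import numpy as np
import orbax.checkpoint
from flax.training import orbax_utils

# Assumed setup; flaxdiff builds self.checkpointer elsewhere.
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True)
checkpointer = orbax.checkpoint.CheckpointManager(
    '/tmp/flaxdiff_ckpt', orbax.checkpoint.PyTreeCheckpointer(), options)

ckpt = {'state': {'w': np.ones((2, 2))}, 'best_loss': np.array(1e9), 'epoch': 0}
save_args = orbax_utils.save_args_from_target(ckpt)
checkpointer.save(0, ckpt, save_kwargs={'save_args': save_args}, force=True)
checkpointer.wait_until_finished()  # block until the async write completes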
@@ -304,21 +309,26 @@ class SimpleTrainer:
             train_step = jax.pmap(train_step)
         return train_step
 
-    def _define_compute_metrics(self):
+    def _define_vaidation_step(self):
         model = self.model
         loss_fn = self.loss_fn
+        distributed_training = self.distributed_training
 
-
-        def compute_metrics(state: SimpleTrainState, batch):
+        def validation_step(state: SimpleTrainState, batch):
             preds = model.apply(state.params, batch['image'])
             expected_output = batch['label']
             loss = jnp.mean(loss_fn(preds, expected_output))
+            if distributed_training:
+                loss = jax.lax.pmean(loss, "data")
             metric_updates = state.metrics.single_from_model_output(
                 loss=loss, logits=preds, labels=expected_output)
             metrics = state.metrics.merge(metric_updates)
             state = state.replace(metrics=metrics)
             return state
-        return compute_metrics
+        if distributed_training:
+            validation_step = shard_map(validation_step, mesh=self.mesh, in_specs=(P(), P('data')), out_specs=(P()))
+            validation_step = jax.pmap(validation_step)
+        return validation_step
 
     def summary(self):
         input_vars = self.get_input_ones()
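The new validation step averages the per-shard loss across the 'data' mesh axis with jax.lax.pmean inside shard_map. A runnable single-host sketch of that pattern (toy loss function; the leading batch dimension must divide evenly across devices):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ('data',))

def per_shard_loss(batch):
    loss = jnp.mean(batch ** 2)          # loss on this device's shard only
    return jax.lax.pmean(loss, 'data')   # average across the 'data' axis

global_loss = shard_map(per_shard_loss, mesh=mesh,
                        in_specs=P('data'), out_specs=P())(jnp.arange(8.0))
print(global_loss)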
@@ -343,17 +353,53 @@ class SimpleTrainer:
             "batch_size": batch_size
         })
         return summary_writer
+
+    def validation_loop(
+        self,
+        val_state: SimpleTrainState,
+        val_step_fn: Callable,
+        val_ds,
+        val_steps_per_epoch,
+        current_step,
+    ):
+        global_device_count = jax.device_count()
+        local_device_count = jax.local_device_count()
+        process_index = jax.process_index()
+
+        val_ds = iter(val_ds()) if val_ds else None
+        # Evaluation step
+        try:
+            for i in range(val_steps_per_epoch):
+                if val_ds is None:
+                    batch = None
+                else:
+                    batch = next(val_ds)
+                if self.distributed_training and global_device_count > 1:
+                    batch = convert_to_global_tree(self.mesh, batch)
+                if i == 0:
+                    print(f"Evaluation started for process index {process_index}")
+                metrics = val_step_fn(val_state, batch)
+                if self.wandb is not None:
+                    # metrics is a dict of metrics
+                    if metrics and type(metrics) == dict:
+                        for key, value in metrics.items():
+                            if isinstance(value, jnp.ndarray):
+                                value = np.array(value)
+                            self.wandb.log({
+                                f"val/{key}": value,
+                            }, step=current_step)
+        except Exception as e:
+            print("Error logging images to wandb", e)
 
-    def
-
-
-
-
-
-
-
-
-        rng_state = self.rngstate
+    def train_loop(
+        self,
+        train_state: SimpleTrainState,
+        train_step_fn: Callable,
+        train_ds,
+        train_steps_per_epoch,
+        current_step,
+        rng_state
+    ):
         global_device_count = jax.device_count()
         local_device_count = jax.local_device_count()
         process_index = jax.process_index()
@@ -361,67 +407,105 @@ class SimpleTrainer:
             global_device_indexes = jnp.arange(global_device_count)
         else:
             global_device_indexes = 0
+
+        epoch_loss = 0
+        current_epoch = current_step // train_steps_per_epoch
+        last_save_time = time.time()
+
+        if process_index == 0:
+            pbar = tqdm.tqdm(total=train_steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step')
+
+        for i in range(train_steps_per_epoch):
+            batch = next(train_ds)
+            if i == 0:
+                print(f"First batch loaded at step {current_step}")
+
+            if self.distributed_training and global_device_count > 1:
+                # # Convert the local device batches to a unified global jax.Array
+                batch = convert_to_global_tree(self.mesh, batch)
+            train_state, loss, rng_state = train_step_fn(train_state, rng_state, batch, global_device_indexes)
 
-
-
-
-
-
-
-            if self.distributed_training and global_device_count > 1:
-                # Convert the local device batches to a unified global jax.Array
-                batch = convert_to_global_tree(self.mesh, batch)
-            train_state, loss, rng_state = train_step(train_state, rng_state, batch, global_device_indexes)
-
-            if self.distributed_training:
-                loss = jax.experimental.multihost_utils.process_allgather(loss)
-                loss = jnp.mean(loss) # Just to make sure its a scaler value
+            if i == 0:
+                print(f"Training started for process index {process_index} at step {current_step}")
+
+            if self.distributed_training:
+                # loss = jax.experimental.multihost_utils.process_allgather(loss)
+                loss = jnp.mean(loss) # Just to make sure its a scaler value
 
-
-
-
-
-
+            if loss <= 1e-6:
+                # If the loss is too low, we can assume the model has diverged
+                print(colored(f"Loss too low at step {current_step} => {loss}", 'red'))
+                # Reset the model to the old state
+                exit(1)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            epoch_loss += loss
+            current_step += 1
+            if i % 100 == 0:
+                if pbar is not None:
+                    pbar.set_postfix(loss=f'{loss:.4f}')
+                    pbar.update(100)
+                if self.wandb is not None:
+                    self.wandb.log({
+                        "train/step" : current_step,
+                        "train/loss": loss,
+                    }, step=current_step)
+            # Save the model every few steps
+            if i % 10000 == 0 and i > 0:
+                print(f"Saving model after 10000 step {current_step}")
+                print(f"Devices: {len(jax.devices())}") # To sync the devices
+                self.save(current_epoch, current_step, train_state, rng_state)
+                print(f"Saving done by process index {process_index}")
+                last_save_time = time.time()
+        print(colored(f"Epoch done on index {process_index} => {current_epoch} Loss: {epoch_loss/train_steps_per_epoch}", 'green'))
+        if pbar is not None:
+            pbar.close()
+        return epoch_loss, current_step, train_state, rng_state
+
+
+    def fit(self, data, train_steps_per_epoch, epochs, train_step_args={}, val_steps_per_epoch=5, validation_step_args={}):
+        train_ds = iter(data['train']())
+        train_step = self._define_train_step(**train_step_args)
+        val_step = self._define_vaidation_step(**validation_step_args)
+        train_state = self.state
+        rng_state = self.rngstate
+        process_index = jax.process_index()
+
+        if val_steps_per_epoch > 0:
+            # We should first run a validation step to make sure the model is working
+            print(f"Validation run for sanity check for process index {process_index}")
+            # Validation step
+            self.validation_loop(
+                train_state,
+                val_step,
+                data.get('test', data.get('val', None)),
+                val_steps_per_epoch,
+                self.latest_step,
+            )
+            print(colored(f"Sanity Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
+
+        while self.latest_step < epochs * train_steps_per_epoch:
+            current_epoch = self.latest_step // train_steps_per_epoch
             print(f"\nEpoch {current_epoch}/{epochs}")
             start_time = time.time()
             epoch_loss = 0
-
-
-
-
-
-
-
+
+            epoch_loss, current_step, train_state, rng_state = self.train_loop(
+                train_state,
+                train_step,
+                train_ds,
+                train_steps_per_epoch,
+                self.latest_step,
+                rng_state,
+            )
+            print(colored(f"Epoch done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
 
             self.latest_step = current_step
             end_time = time.time()
             self.state = train_state
             self.rngstate = rng_state
             total_time = end_time - start_time
-            avg_time_per_step = total_time /
-            avg_loss = epoch_loss /
+            avg_time_per_step = total_time / train_steps_per_epoch
+            avg_loss = epoch_loss / train_steps_per_epoch
             if avg_loss < self.best_loss:
                 self.best_loss = avg_loss
                 self.best_state = train_state
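The loop above is step-based rather than epoch-based, so a restarted run derives its position purely from self.latest_step; the bookkeeping reduces to integer arithmetic (numbers illustrative):

epochs = 10
train_steps_per_epoch = 1000
latest_step = 2500                                     # e.g. restored from a checkpoint
current_epoch = latest_step // train_steps_per_epoch   # -> 2, as in train_loop
total_steps = epochs * train_steps_per_epoch           # fit() loops while latest_step < total_steps
print(current_epoch, total_steps - latest_step)        # 2 7500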
@@ -437,6 +521,18 @@ class SimpleTrainer:
                     "train/epoch": current_epoch,
                 }, step=current_step)
             print(colored(f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}", 'green'))
-
+
+            if val_steps_per_epoch > 0:
+                print(f"Validation started for process index {process_index}")
+                # Validation step
+                self.validation_loop(
+                    train_state,
+                    val_step,
+                    data.get('test', None),
+                    val_steps_per_epoch,
+                    current_step,
+                )
+                print(colored(f"Validation done on process index {process_index}", PROCESS_COLOR_MAP[process_index]))
+
         self.save(epochs)
         return self.state
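Note the data contract of the new fit(): data['train'] (and optionally data['test'] or data['val']) are zero-argument callables returning fresh iterables, since both loops call iter(split()). A runnable toy sketch of that contract (the dict-of-arrays batch layout is an assumption matching validation_step above):

import numpy as np

def make_split():
    # assumption: batches are plain dicts of arrays, as validation_step expects
    while True:
        yield {'image': np.zeros((8, 32, 32, 3), dtype=np.float32),
               'label': np.zeros((8,), dtype=np.int32)}

data = {'train': make_split, 'test': make_split}
train_ds = iter(data['train']())   # exactly what fit() does internally
print(next(train_ds)['image'].shape)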
flaxdiff/trainer/video_diffusion_trainer.py
ADDED
@@ -0,0 +1,62 @@
+import flax
+from flax import linen as nn
+import jax
+from typing import Callable
+from dataclasses import field
+import jax.numpy as jnp
+import optax
+import functools
+from jax.sharding import Mesh, PartitionSpec as P
+from jax.experimental.shard_map import shard_map
+from typing import Dict, Callable, Sequence, Any, Union, Tuple
+
+from ..schedulers import NoiseScheduler
+from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
+
+from flaxdiff.utils import RandomMarkovState
+
+from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
+
+from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+from flax.training import dynamic_scale as dynamic_scale_lib
+
+class TrainState(SimpleTrainState):
+    rngs: jax.random.PRNGKey
+    ema_params: dict
+
+    def apply_ema(self, decay: float = 0.999):
+        new_ema_params = jax.tree_util.tree_map(
+            lambda ema, param: decay * ema + (1 - decay) * param,
+            self.ema_params,
+            self.params,
+        )
+        return self.replace(ema_params=new_ema_params)
+
+from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+from flaxdiff.trainer.diffusion_trainer import DiffusionTrainer
+
+class SimpleVideoDiffusionTrainer(DiffusionTrainer):
+    def __init__(self,
+                 model: nn.Module,
+                 input_shapes: Dict[str, Tuple[int]],
+                 optimizer: optax.GradientTransformation,
+                 noise_schedule: NoiseScheduler,
+                 rngs: jax.random.PRNGKey,
+                 unconditional_prob: float = 0.12,
+                 name: str = "SimpleVideoDiffusion",
+                 model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
+                 autoencoder: AutoEncoder = None,
+                 **kwargs
+                 ):
+        super().__init__(
+            model=model,
+            input_shapes=input_shapes,
+            optimizer=optimizer,
+            noise_schedule=noise_schedule,
+            unconditional_prob=unconditional_prob,
+            autoencoder=autoencoder,
+            model_output_transform=model_output_transform,
+            rngs=rngs,
+            name=name,
+            **kwargs
+        )
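The apply_ema method in the new TrainState is the standard exponential moving average over the parameter pytree; a runnable standalone sketch with toy parameters:

import jax
import jax.numpy as jnp

decay = 0.999
params = {'w': jnp.full((2,), 2.0)}
ema_params = {'w': jnp.zeros((2,))}

# ema <- decay * ema + (1 - decay) * param, applied leaf by leaf
new_ema = jax.tree_util.tree_map(
    lambda ema, p: decay * ema + (1 - decay) * p, ema_params, params)
print(new_ema['w'])  # [0.002 0.002]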
flaxdiff/utils.py
CHANGED
@@ -2,7 +2,12 @@ import jax
 import jax.numpy as jnp
 import flax.struct as struct
 import flax.linen as nn
-from typing import Any
+from typing import Any, Callable
+from dataclasses import dataclass
+from functools import partial
+import numpy as np
+from jax.sharding import Mesh, PartitionSpec as P
+from abc import ABC, abstractmethod
 
 class MarkovState(struct.PyTreeNode):
     pass
@@ -17,6 +22,30 @@ class RandomMarkovState(MarkovState):
 def clip_images(images, clip_min=-1, clip_max=1):
     return jnp.clip(images, clip_min, clip_max)
 
+def _build_global_shape_and_sharding(
+    local_shape: tuple[int, ...], global_mesh: Mesh
+) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
+    sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
+    global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
+    return global_shape, sharding
+
+
+def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
+    """Put local sharded array into local devices"""
+    global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
+    try:
+        local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
+    except ValueError as array_split_error:
+        raise ValueError(
+            f"Unable to put to devices shape {array.shape} with "
+            f"local device count {len(global_mesh.local_devices)} "
+        ) from array_split_error
+    local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
+    return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
+
+def convert_to_global_tree(global_mesh, pytree):
+    return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
+
 class RMSNorm(nn.Module):
     """
     From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467
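form_global_array splits a host batch along axis 0 across the local devices and reassembles it as a globally sharded jax.Array, and convert_to_global_tree maps that over a whole batch pytree. A runnable single-process sketch (batch shape illustrative; the leading dimension must be divisible by the local device count):

import jax
import numpy as np
from jax.sharding import Mesh
from flaxdiff.utils import convert_to_global_tree

mesh = Mesh(np.array(jax.devices()), ('data',))
batch = {'image': np.zeros((8, 32, 32, 3), dtype=np.float32)}
global_batch = convert_to_global_tree(mesh, batch)
print(global_batch['image'].shape)  # (8, 32, 32, 3), now a sharded jax.Array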
@@ -86,4 +115,78 @@ class RMSNorm(nn.Module):
         ).reshape(feature_shape)
         mul *= scale
         y = mul * x
-        return jnp.asarray(y, dtype)
+        return jnp.asarray(y, dtype)
+
+@dataclass
+class ConditioningEncoder(ABC):
+    model: nn.Module
+    tokenizer: Callable
+
+    def __call__(self, data):
+        tokens = self.tokenize(data)
+        outputs = self.encode_from_tokens(tokens)
+        return outputs
+
+    def encode_from_tokens(self, tokens):
+        outputs = self.model(input_ids=tokens['input_ids'],
+                             attention_mask=tokens['attention_mask'])
+        last_hidden_state = outputs.last_hidden_state
+        return last_hidden_state
+
+    def tokenize(self, data):
+        tokens = self.tokenizer(data, padding="max_length",
+                                max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np")
+        return tokens
+
+@dataclass
+class TextEncoder(ConditioningEncoder):
+    # def __call__(self, data):
+    #     tokens = self.tokenize(data)
+    #     outputs = self.encode_from_tokens(tokens)
+    #     return outputs
+
+    # def encode_from_tokens(self, tokens):
+    #     outputs = self.model(input_ids=tokens['input_ids'],
+    #                          attention_mask=tokens['attention_mask'])
+    #     last_hidden_state = outputs.last_hidden_state
+    #     # pooler_output = outputs.pooler_output # pooled (EOS token) states
+    #     # embed_pooled = pooler_output # .astype(jnp.float16)
+    #     embed_labels_full = last_hidden_state # .astype(jnp.float16)
+
+    #     return embed_labels_full
+    pass
+
+class AutoTextTokenizer:
+    def __init__(self, tensor_type="pt", modelname="openai/clip-vit-large-patch14"):
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(modelname)
+        self.tensor_type = tensor_type
+
+    def __call__(self, inputs):
+        # print(caption)
+        tokens = self.tokenizer(inputs, padding="max_length", max_length=self.tokenizer.model_max_length,
+                                truncation=True, return_tensors=self.tensor_type)
+        # print(tokens.keys())
+        return {
+            "input_ids": tokens["input_ids"],
+            "attention_mask": tokens["attention_mask"],
+            "caption": inputs,
+        }
+
+    def __repr__(self):
+        return self.__class__.__name__ + '()'
+
+def defaultTextEncodeModel(backend="jax"):
+    from transformers import (
+        CLIPTextModel,
+        FlaxCLIPTextModel,
+        AutoTokenizer,
+    )
+    modelname = "openai/clip-vit-large-patch14"
+    if backend == "jax":
+        model = FlaxCLIPTextModel.from_pretrained(
+            modelname, dtype=jnp.bfloat16)
+    else:
+        model = CLIPTextModel.from_pretrained(modelname)
+    tokenizer = AutoTokenizer.from_pretrained(modelname, dtype=jnp.float16)
+    return TextEncoder(model, tokenizer)
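A hedged usage sketch of the new text-conditioning helpers (the CLIP weights download from the Hugging Face hub on first use; printed shapes assume the openai/clip-vit-large-patch14 tokenizer and text encoder):

from flaxdiff.utils import AutoTextTokenizer, defaultTextEncodeModel

tokenize = AutoTextTokenizer(tensor_type="np")
tokens = tokenize(["a photograph of an astronaut riding a horse"])
print(tokens["input_ids"].shape)     # (1, 77) with the CLIP tokenizer

encoder = defaultTextEncodeModel(backend="jax")
embeddings = encoder(["a photograph of an astronaut riding a horse"])
print(embeddings.shape)              # (1, 77, 768) last hidden states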
{flaxdiff-0.1.35.5.dist-info → flaxdiff-0.1.36.dist-info}/METADATA
CHANGED
@@ -1,15 +1,21 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.35.5
+Version: 0.1.36
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
 Description-Content-Type: text/markdown
-Requires-Dist: flax
-Requires-Dist: optax
-Requires-Dist: jax
+Requires-Dist: flax>=0.8.4
+Requires-Dist: optax>=0.2.2
+Requires-Dist: jax>=0.4.28
 Requires-Dist: orbax
 Requires-Dist: clu
+Dynamic: author
+Dynamic: author-email
+Dynamic: description
+Dynamic: description-content-type
+Dynamic: requires-dist
+Dynamic: summary
 
 # 
 
flaxdiff-0.1.36.dist-info/RECORD
ADDED
@@ -0,0 +1,43 @@
+flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+flaxdiff/utils.py,sha256=b_hFXsam2NICQYCFk0EOcqtBjM-RUqnN0NKTn0lQ070,6532
+flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
+flaxdiff/data/dataset_map.py,sha256=hcHaoR2IbNQmfyPUhYd6_8xinurxxCqawQijAsDI0Ek,3093
+flaxdiff/data/datasets.py,sha256=YUMoSvF2yAyikRvRofZVlHwfEOU3zXSSG4KkLnVfpoA,5626
+flaxdiff/data/online_loader.py,sha256=1Fi_QRixxRzbt602nORINcDeHEccvCrBpagrz4PURYg,12499
+flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
+flaxdiff/models/attention.py,sha256=JvrP7-09MV6IfRLRBhqjPmNUU-lkEMk9TOnJSBKcar8,13289
+flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
+flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
+flaxdiff/models/simple_unet.py,sha256=L5m2j5580QP7pJ5VIme7U5xYA22PZiGP7qdvcKUnB38,11463
+flaxdiff/models/simple_vit.py,sha256=UCDDr0XVnpf6tbJWKFtEt3_nAqMqOoakXf5amyVWZNo,7929
+flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
+flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
+flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
+flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
+flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
+flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
+flaxdiff/samplers/common.py,sha256=ZA08VyovxegpRx4wOQq9LSwZi0gSCz2lrbS5oVYOEYg,8488
+flaxdiff/samplers/ddim.py,sha256=pB8Kod8ZLJ3GXev4uM3cOj1Uy6ibR0jsaZa-VE0fyJM,552
+flaxdiff/samplers/ddpm.py,sha256=u1OchQu0XPhc_6w9JXoaFp2wo4y-zXyQNtGAIJwxNLg,2209
+flaxdiff/samplers/euler.py,sha256=Htb-IJeu7jSgY6mvgYr9yl9pUnos49vijlVk5IQsRps,2740
+flaxdiff/samplers/heun_sampler.py,sha256=UyI-hSlyWvt-7VEUJj27zjgyzKkGVl8fDUHV-YpSOCc,1421
+flaxdiff/samplers/multistep_dpm.py,sha256=3Wu3MrMLYaBb1ObraTbWrJmtEtU0adl1dDbz5fPJ4Gs,2735
+flaxdiff/samplers/rk4_sampler.py,sha256=1j1pES_Q2QiaURvEWeedbbT1LHmkc3jsu0GgH83qBL0,1926
+flaxdiff/schedulers/__init__.py,sha256=3id390WEfdf-MN-oLSPAhlRFIXrFWr6ioAHPAwURJyE,375
+flaxdiff/schedulers/common.py,sha256=b-W4iI-aqScpVE8VZbBpiYvAVI6rqDkUP-C_hEVBwCI,4151
+flaxdiff/schedulers/continuous.py,sha256=5c_niOA20fxJ5oJDi09FfayIRogBGwtfG0XThW2IUZk,334
+flaxdiff/schedulers/cosine.py,sha256=9ban0dFHLMm35wQvaBT4nCQwPGmzNsXwQ1xI0oppmJI,2005
+flaxdiff/schedulers/discrete.py,sha256=O54wH2HVu3olJA71NxgAXFW9cr6B6Gl-DR_uZeytpds,3319
+flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,605
+flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
+flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
+flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
+flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
+flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo3_erTKsuWM,7049
+flaxdiff/trainer/diffusion_trainer.py,sha256=ajOWBgFFwXP_VQScUjcuPoaB4Gk02aF0Ls5LNlA8wqA,12691
+flaxdiff/trainer/simple_trainer.py,sha256=jCD9-qCwX0SC0rN3GrXUBfRrndWNqUI0HmbOAbmYBMM,21906
+flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
+flaxdiff-0.1.36.dist-info/METADATA,sha256=7fO1e_icIEK6dmSopv538Hm2fQnhnkOAE2Ab9inpcNE,22213
+flaxdiff-0.1.36.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+flaxdiff-0.1.36.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.36.dist-info/RECORD,,