flaxdiff 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataloaders.py +31 -2
- flaxdiff/data/dataset_map.py +1 -1
- flaxdiff/data/sources/images.py +17 -0
- flaxdiff/trainer/general_diffusion_trainer.py +74 -23
- flaxdiff/trainer/simple_trainer.py +7 -5
- {flaxdiff-0.2.2.dist-info → flaxdiff-0.2.3.dist-info}/METADATA +3 -1
- {flaxdiff-0.2.2.dist-info → flaxdiff-0.2.3.dist-info}/RECORD +9 -9
- {flaxdiff-0.2.2.dist-info → flaxdiff-0.2.3.dist-info}/WHEEL +1 -1
- {flaxdiff-0.2.2.dist-info → flaxdiff-0.2.3.dist-info}/top_level.txt +0 -0
flaxdiff/data/dataloaders.py
CHANGED
@@ -291,7 +291,7 @@ def get_dataset_grain(
 
     local_batch_size = batch_size // jax.process_count()
 
-    sampler = pygrain.IndexSampler(
+    train_sampler = pygrain.IndexSampler(
         num_records=len(data_source) if count is None else count,
         shuffle=True,
         seed=seed,
@@ -299,6 +299,14 @@ def get_dataset_grain(
         shard_options=pygrain.ShardByJaxProcess(),
     )
 
+    # val_sampler = pygrain.IndexSampler(
+    #     num_records=len(data_source) if count is None else count,
+    #     shuffle=False,
+    #     seed=seed,
+    #     num_epochs=num_epochs,
+    #     shard_options=pygrain.ShardByJaxProcess(),
+    # )
+
     def get_trainset():
         transformations = [
             augmenter(),
@@ -307,7 +315,7 @@ def get_dataset_grain(
 
         loader = pygrain.DataLoader(
             data_source=data_source,
-            sampler=sampler,
+            sampler=train_sampler,
             operations=transformations,
             worker_count=worker_count,
             read_options=pygrain.ReadOptions(
@@ -316,10 +324,31 @@ def get_dataset_grain(
             worker_buffer_size=worker_buffer_size,
         )
         return loader
+
+    # def get_valset():
+    #     transformations = [
+    #         augmenter(),
+    #         pygrain.Batch(local_batch_size, drop_remainder=True),
+    #     ]
+
+    #     loader = pygrain.DataLoader(
+    #         data_source=data_source,
+    #         sampler=val_sampler,
+    #         operations=transformations,
+    #         worker_count=worker_count,
+    #         read_options=pygrain.ReadOptions(
+    #             read_thread_count, read_buffer_size
+    #         ),
+    #         worker_buffer_size=worker_buffer_size,
+    #     )
+    #     return loader
+    get_valset = get_trainset # For now, use the same function for validation
 
     return {
         "train": get_trainset,
         "train_len": len(data_source),
+        "val": get_valset,
+        "val_len": len(data_source),
         "local_batch_size": local_batch_size,
        "global_batch_size": batch_size,
    }
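The net effect is that the returned mapping now carries a "val" entry alongside "train", with a dedicated shuffle-free validation pipeline left commented out for later. A minimal consumption sketch, assuming an already-configured source; the dataset key and batch size are placeholder values, while the returned keys and the local/global batch relationship come from the code above:

import jax

# Hypothetical call; "laiona_coco" is one of the keys in datasetMap.
data = get_dataset_grain("laiona_coco", batch_size=64)

train_loader = data["train"]()  # factory producing a pygrain.DataLoader
val_loader = data["val"]()      # currently the same factory as "train"

# Each process receives its shard of the global batch.
assert data["local_batch_size"] == data["global_batch_size"] // jax.process_count()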
flaxdiff/data/dataset_map.py
CHANGED
@@ -21,7 +21,7 @@ datasetMap = {
         "augmenter": gcs_augmenters,
     },
     "laiona_coco": {
-        "source": data_source_gcs('
+        "source": data_source_gcs('datasets/laion12m+mscoco'),
         "augmenter": gcs_augmenters,
     },
     "aesthetic_coyo": {
flaxdiff/data/sources/images.py
CHANGED
@@ -167,6 +167,16 @@ class ImageTFDSAugmenter(DataAugmenter):
 
         return TFDSTransform
 
+"""
+Batch structure:
+{
+    "image": image_batch,
+    "text": {
+        "input_ids": input_ids_batch,
+        "attention_mask": attention_mask_batch,
+    }
+
+"""
 
 # ----------------------------------------------------------------------------------
 # GCS Image Source
@@ -248,6 +258,13 @@ class ImageGCSAugmenter(DataAugmenter):
             A callable that returns a pygrain.MapTransform.
         """
         labelizer = self.labelizer
+        if method is None:
+            if image_scale > 256:
+                method = cv2.INTER_CUBIC
+            else:
+                method = cv2.INTER_AREA
+
+        print(f"Using method: {method}")
 
         class GCSTransform(pygrain.MapTransform):
             def __init__(self, *args, **kwargs):
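The new default follows the standard OpenCV guidance: INTER_AREA is preferred when shrinking images and INTER_CUBIC when enlarging, so the 256 px threshold approximates a downscale/upscale split. A standalone sketch of the same heuristic; the resize_image helper is illustrative, not part of the package:

import cv2
import numpy as np

def resize_image(image: np.ndarray, image_scale: int, method=None) -> np.ndarray:
    # Same fallback as the diff above: cubic for large (upscaled) targets,
    # area-based interpolation for small (downscaled) ones.
    if method is None:
        method = cv2.INTER_CUBIC if image_scale > 256 else cv2.INTER_AREA
    return cv2.resize(image, (image_scale, image_scale), interpolation=method)

# e.g. downscale a 512x512 RGB image to 128x128, which selects INTER_AREA
image = np.zeros((512, 512, 3), dtype=np.uint8)
small = resize_image(image, 128)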
flaxdiff/trainer/general_diffusion_trainer.py
CHANGED
@@ -18,13 +18,13 @@ from ..samplers.ddim import DDIMSampler
 from flaxdiff.utils import RandomMarkovState, serialize_model, get_latest_checkpoint
 from flaxdiff.inputs import ConditioningEncoder, ConditionalInputConfig, DiffusionInputConfig
 
-from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
+from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics, convert_to_global_tree
 
 from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
 from flax.training import dynamic_scale as dynamic_scale_lib
 
 # Reuse the TrainState from the DiffusionTrainer
-from .diffusion_trainer import TrainState
+from .diffusion_trainer import TrainState, DiffusionTrainer
 import shutil
 
 def generate_modelname(
@@ -103,6 +103,15 @@ def generate_modelname(
     # model_name = f"{model_name}-{config_hash}"
     return model_name
 
+@dataclass
+class EvaluationMetric:
+    """
+    Evaluation metrics for the diffusion model.
+    The function is given generated samples batch [B, H, W, C] and the original batch.
+    """
+    function: Callable
+    name: str
+
 class GeneralDiffusionTrainer(DiffusionTrainer):
     """
     General trainer for diffusion models supporting both images and videos.
@@ -126,6 +135,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
         native_resolution: int = None,
         frames_per_sample: int = None,
         wandb_config: Dict[str, Any] = None,
+        eval_metrics: List[EvaluationMetric] = None,
         **kwargs
     ):
         """
@@ -150,6 +160,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
             autoencoder=autoencoder,
         )
         self.input_config = input_config
+        self.eval_metrics = eval_metrics
 
         if wandb_config is not None:
             # If input_config is not in wandb_config, add it
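Taken together, these hunks add a pluggable evaluation hook: each EvaluationMetric pairs a name with a function that receives the generated samples batch and the original batch. A hedged sketch of defining one; the MSE function is illustrative and assumes batches follow the {"image": ..., "text": ...} structure documented in images.py above:

import numpy as np

def mse_metric(samples, batch):
    # Compare generated samples [B, H, W, C] against the originals,
    # assuming batch["image"] holds the source images (see images.py above).
    return float(np.mean((np.asarray(samples) - np.asarray(batch["image"])) ** 2))

eval_metrics = [EvaluationMetric(function=mse_metric, name="mse")]

# Passed through the new constructor argument (other arguments elided):
# trainer = GeneralDiffusionTrainer(..., eval_metrics=eval_metrics)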
@@ -363,7 +374,6 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
         def generate_samples(
             val_state: TrainState,
             batch,
-            sampler: DiffusionSampler,
             diffusion_steps: int,
         ):
             # Process all conditional inputs
@@ -385,7 +395,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
             model_conditioning_inputs=tuple(model_conditioning_inputs),
         )
 
-        return
+        return generate_samples
 
     def _get_image_size(self):
         """Helper to determine image size from available information."""
@@ -415,32 +425,73 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
         """
         Run validation and log samples for both image and video diffusion.
        """
-
-
+        global_device_count = jax.device_count()
+        local_device_count = jax.local_device_count()
+        process_index = jax.process_index()
+        generate_samples = val_step_fn
 
+        val_ds = iter(val_ds()) if val_ds else None
+        # Evaluation step
         try:
-
-
-
-
-
-
-
-
-
-
-
+            metrics = {metric.name: [] for metric in self.eval_metrics} if self.eval_metrics else {}
+            for i in range(val_steps_per_epoch):
+                if val_ds is None:
+                    batch = None
+                else:
+                    batch = next(val_ds)
+                    if self.distributed_training and global_device_count > 1:
+                        batch = convert_to_global_tree(self.mesh, batch)
+                # Generate samples
+                samples = generate_samples(
+                    val_state,
+                    batch,
+                    diffusion_steps,
+                )
 
-
-
-
-
-
+                if self.eval_metrics is not None:
+                    for metric in self.eval_metrics:
+                        try:
+                            # Evaluate metrics
+                            metric_val = metric.function(samples, batch)
+                            metrics[metric.name].append(metric_val)
+                        except Exception as e:
+                            print("Error in evaluation metrics:", e)
+                            import traceback
+                            traceback.print_exc()
+                            pass
 
+                if i == 0:
+                    print(f"Evaluation started for process index {process_index}")
+                    # Log samples to wandb
+                    if getattr(self, 'wandb', None) is not None and self.wandb:
+                        import numpy as np
+
+                        # Process samples differently based on dimensionality
+                        if len(samples.shape) == 5: # [B,T,H,W,C] - Video data
+                            self._log_video_samples(samples, current_step)
+                        else: # [B,H,W,C] - Image data
+                            self._log_image_samples(samples, current_step)
+
+            if getattr(self, 'wandb', None) is not None and self.wandb:
+                # metrics is a dict of metrics
+                if metrics and type(metrics) == dict:
+                    # Flatten the metrics
+                    metrics = {k: np.mean(v) for k, v in metrics.items()}
+                    # Log the metrics
+                    for key, value in metrics.items():
+                        if isinstance(value, jnp.ndarray):
+                            value = np.array(value)
+                        self.wandb.log({
+                            f"val/{key}": value,
+                        }, step=current_step)
+
+        except StopIteration:
+            print(f"Validation dataset exhausted for process index {process_index}")
         except Exception as e:
-            print("Error
+            print(f"Error during validation for process index {process_index}: {e}")
             import traceback
             traceback.print_exc()
+
 
     def _log_video_samples(self, samples, current_step):
         """Helper to log video samples to wandb."""
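Note how the loop accumulates one value per metric per validation step and only flattens to scalars at the end; the aggregation step in isolation, on toy numbers:

import numpy as np

# Per-step values collected as lists keyed by metric name (toy data).
metrics = {"mse": [0.12, 0.10, 0.11]}

# Flattened to per-epoch means before being logged under "val/<name>".
metrics = {k: np.mean(v) for k, v in metrics.items()}  # mean over the three steps, about 0.11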
flaxdiff/trainer/simple_trainer.py
CHANGED
@@ -411,7 +411,9 @@ class SimpleTrainer:
         train_ds,
         train_steps_per_epoch,
         current_step,
-        rng_state
+        rng_state,
+        save_every:int=None,
+        val_every=None,
     ):
         global_device_count = jax.device_count()
         process_index = jax.process_index()
@@ -491,8 +493,8 @@ class SimpleTrainer:
                         "train/loss": loss,
                     }, step=current_step)
                 # Save the model every few steps
-                if i %
-                    print(f"Saving model after
+                if save_every and i % save_every == 0 and i > 0:
+                    print(f"Saving model after {save_every} step {current_step}")
                     print(f"Devices: {len(jax.devices())}") # To sync the devices
                     self.save(current_epoch, current_step, train_state, rng_state)
                     print(f"Saving done by process index {process_index}")
@@ -518,7 +520,7 @@ class SimpleTrainer:
             self.validation_loop(
                 train_state,
                 val_step,
-                data.get('test', None),
+                data.get('val', data.get('test', None)),
                 val_steps_per_epoch,
                 self.latest_step,
             )
@@ -569,7 +571,7 @@ class SimpleTrainer:
             self.validation_loop(
                 train_state,
                 val_step,
-                data.get('test', None),
+                data.get('val', data.get('test', None)),
                 val_steps_per_epoch,
                 current_step,
             )
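Both call sites now resolve the validation split with a nested fallback that prefers the new "val" entry from the dataloaders and only then falls back to "test". Plain dict semantics, shown on toy mappings:

# Prefers "val", falls back to "test", otherwise None.
data = {"train": "train_factory", "val": "val_factory"}
assert data.get('val', data.get('test', None)) == "val_factory"

data = {"train": "train_factory", "test": "test_factory"}
assert data.get('val', data.get('test', None)) == "test_factory"

data = {"train": "train_factory"}
assert data.get('val', data.get('test', None)) is None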
{flaxdiff-0.2.2.dist-info → flaxdiff-0.2.3.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.2.2
+Version: 0.2.3
 Summary: A versatile and easy to understand Diffusion library
 Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
 License-Expression: MIT
@@ -22,6 +22,8 @@ Requires-Dist: python-dotenv
 
 # 
 
+**This project is being used for the UMD Course project MSML 605: MLOps**
+
 **This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
 
 ## A Versatile and simple Diffusion Library
{flaxdiff-0.2.2.dist-info → flaxdiff-0.2.3.dist-info}/RECORD
CHANGED
@@ -2,14 +2,14 @@ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=DmlWUY1FGz4ESxIHaPQJf92CHjsdMjyDd651wFUtyNg,8838
 flaxdiff/data/__init__.py,sha256=8W5y7NyAOWtpLi8WRawk4VYeE3DMDnM3B_jKPD8BoFQ,143
 flaxdiff/data/benchmark_decord.py,sha256=x56Db1VPmziv_9KJvWdfS0O7cffsYkF5tt5WvldOKc0,13720
-flaxdiff/data/dataloaders.py,sha256=
-flaxdiff/data/dataset_map.py,sha256=
+flaxdiff/data/dataloaders.py,sha256=TgbR5CMxE86L0-1qy5ohZT8zhOPjk3oncd5WPBv08sQ,23557
+flaxdiff/data/dataset_map.py,sha256=_6SYnmrYO-URDd8vPAmALTV6r0eMGWWmwUtsdjKGXnA,5072
 flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
 flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
 flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
 flaxdiff/data/sources/av_utils.py,sha256=LCr9MJNurOaoxY-sjzkLqJS_MlX0x3gRSlKAVIglAU0,24045
 flaxdiff/data/sources/base.py,sha256=uhF0odJSYRy0SLw1xnI9Q_q_xiVht2DmEYcX1j9AWT4,4246
-flaxdiff/data/sources/images.py,sha256=
+flaxdiff/data/sources/images.py,sha256=P7Rea7Zu0h9l7Zoc33zEHKdLI1ST6JEqgl1-bRwORM4,11460
 flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
 flaxdiff/data/sources/videos.py,sha256=CVpOH6A4P2D8iv3gZIhd2GB5ATUD8Vsm_wVYbbugWD4,9359
 flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
@@ -56,9 +56,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
 flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
 flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
 flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
-flaxdiff/trainer/general_diffusion_trainer.py,sha256=
-flaxdiff/trainer/simple_trainer.py,sha256=
-flaxdiff-0.2.
-flaxdiff-0.2.
-flaxdiff-0.2.
-flaxdiff-0.2.
+flaxdiff/trainer/general_diffusion_trainer.py,sha256=1rLU7iooXIlSDIGFZ7bHgpMWmkqMbUzM9fHBu1L0t-U,27252
+flaxdiff/trainer/simple_trainer.py,sha256=raLS1shwpjJBT_bYXLAB2E4kA9MbwasDTzDTUqfCCUc,24312
+flaxdiff-0.2.3.dist-info/METADATA,sha256=eoCSaBNoDpk90qWz5_NGVkzvuf3Oqt6rSj_ZVTfYn7s,24057
+flaxdiff-0.2.3.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
+flaxdiff-0.2.3.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.2.3.dist-info/RECORD,,
{flaxdiff-0.2.2.dist-info → flaxdiff-0.2.3.dist-info}/top_level.txt
File without changes