flaxdiff 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -291,7 +291,7 @@ def get_dataset_grain(
291
291
 
292
292
  local_batch_size = batch_size // jax.process_count()
293
293
 
294
- sampler = pygrain.IndexSampler(
294
+ train_sampler = pygrain.IndexSampler(
295
295
  num_records=len(data_source) if count is None else count,
296
296
  shuffle=True,
297
297
  seed=seed,
@@ -299,6 +299,14 @@ def get_dataset_grain(
299
299
  shard_options=pygrain.ShardByJaxProcess(),
300
300
  )
301
301
 
302
+ # val_sampler = pygrain.IndexSampler(
303
+ # num_records=len(data_source) if count is None else count,
304
+ # shuffle=False,
305
+ # seed=seed,
306
+ # num_epochs=num_epochs,
307
+ # shard_options=pygrain.ShardByJaxProcess(),
308
+ # )
309
+
302
310
  def get_trainset():
303
311
  transformations = [
304
312
  augmenter(),
@@ -307,7 +315,7 @@ def get_dataset_grain(
307
315
 
308
316
  loader = pygrain.DataLoader(
309
317
  data_source=data_source,
310
- sampler=sampler,
318
+ sampler=train_sampler,
311
319
  operations=transformations,
312
320
  worker_count=worker_count,
313
321
  read_options=pygrain.ReadOptions(
@@ -316,10 +324,31 @@ def get_dataset_grain(
316
324
  worker_buffer_size=worker_buffer_size,
317
325
  )
318
326
  return loader
327
+
328
+ # def get_valset():
329
+ # transformations = [
330
+ # augmenter(),
331
+ # pygrain.Batch(local_batch_size, drop_remainder=True),
332
+ # ]
333
+
334
+ # loader = pygrain.DataLoader(
335
+ # data_source=data_source,
336
+ # sampler=val_sampler,
337
+ # operations=transformations,
338
+ # worker_count=worker_count,
339
+ # read_options=pygrain.ReadOptions(
340
+ # read_thread_count, read_buffer_size
341
+ # ),
342
+ # worker_buffer_size=worker_buffer_size,
343
+ # )
344
+ # return loader
345
+ get_valset = get_trainset # For now, use the same function for validation
319
346
 
320
347
  return {
321
348
  "train": get_trainset,
322
349
  "train_len": len(data_source),
350
+ "val": get_valset,
351
+ "val_len": len(data_source),
323
352
  "local_batch_size": local_batch_size,
324
353
  "global_batch_size": batch_size,
325
354
  }
@@ -21,7 +21,7 @@ datasetMap = {
21
21
  "augmenter": gcs_augmenters,
22
22
  },
23
23
  "laiona_coco": {
24
- "source": data_source_gcs('arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
24
+ "source": data_source_gcs('datasets/laion12m+mscoco'),
25
25
  "augmenter": gcs_augmenters,
26
26
  },
27
27
  "aesthetic_coyo": {
@@ -167,6 +167,16 @@ class ImageTFDSAugmenter(DataAugmenter):
167
167
 
168
168
  return TFDSTransform
169
169
 
170
+ """
171
+ Batch structure:
172
+ {
173
+ "image": image_batch,
174
+ "text": {
175
+ "input_ids": input_ids_batch,
176
+ "attention_mask": attention_mask_batch,
177
+ }
178
+
179
+ """
170
180
 
171
181
  # ----------------------------------------------------------------------------------
172
182
  # GCS Image Source
@@ -248,6 +258,13 @@ class ImageGCSAugmenter(DataAugmenter):
248
258
  A callable that returns a pygrain.MapTransform.
249
259
  """
250
260
  labelizer = self.labelizer
261
+ if method is None:
262
+ if image_scale > 256:
263
+ method = cv2.INTER_CUBIC
264
+ else:
265
+ method = cv2.INTER_AREA
266
+
267
+ print(f"Using method: {method}")
251
268
 
252
269
  class GCSTransform(pygrain.MapTransform):
253
270
  def __init__(self, *args, **kwargs):
@@ -18,13 +18,13 @@ from ..samplers.ddim import DDIMSampler
18
18
  from flaxdiff.utils import RandomMarkovState, serialize_model, get_latest_checkpoint
19
19
  from flaxdiff.inputs import ConditioningEncoder, ConditionalInputConfig, DiffusionInputConfig
20
20
 
21
- from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
21
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics, convert_to_global_tree
22
22
 
23
23
  from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
24
24
  from flax.training import dynamic_scale as dynamic_scale_lib
25
25
 
26
26
  # Reuse the TrainState from the DiffusionTrainer
27
- from flaxdiff.trainer.diffusion_trainer import TrainState, DiffusionTrainer
27
+ from .diffusion_trainer import TrainState, DiffusionTrainer
28
28
  import shutil
29
29
 
30
30
  def generate_modelname(
@@ -103,6 +103,15 @@ def generate_modelname(
103
103
  # model_name = f"{model_name}-{config_hash}"
104
104
  return model_name
105
105
 
106
+ @dataclass
107
+ class EvaluationMetric:
108
+ """
109
+ Evaluation metrics for the diffusion model.
110
+ The function is given generated samples batch [B, H, W, C] and the original batch.
111
+ """
112
+ function: Callable
113
+ name: str
114
+
106
115
  class GeneralDiffusionTrainer(DiffusionTrainer):
107
116
  """
108
117
  General trainer for diffusion models supporting both images and videos.
@@ -126,6 +135,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
126
135
  native_resolution: int = None,
127
136
  frames_per_sample: int = None,
128
137
  wandb_config: Dict[str, Any] = None,
138
+ eval_metrics: List[EvaluationMetric] = None,
129
139
  **kwargs
130
140
  ):
131
141
  """
@@ -150,6 +160,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
150
160
  autoencoder=autoencoder,
151
161
  )
152
162
  self.input_config = input_config
163
+ self.eval_metrics = eval_metrics
153
164
 
154
165
  if wandb_config is not None:
155
166
  # If input_config is not in wandb_config, add it
@@ -363,7 +374,6 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
363
374
  def generate_samples(
364
375
  val_state: TrainState,
365
376
  batch,
366
- sampler: DiffusionSampler,
367
377
  diffusion_steps: int,
368
378
  ):
369
379
  # Process all conditional inputs
@@ -385,7 +395,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
385
395
  model_conditioning_inputs=tuple(model_conditioning_inputs),
386
396
  )
387
397
 
388
- return sampler, generate_samples
398
+ return generate_samples
389
399
 
390
400
  def _get_image_size(self):
391
401
  """Helper to determine image size from available information."""
@@ -415,32 +425,73 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
415
425
  """
416
426
  Run validation and log samples for both image and video diffusion.
417
427
  """
418
- sampler, generate_samples = val_step_fn
419
- val_ds = iter(val_ds()) if val_ds else None
428
+ global_device_count = jax.device_count()
429
+ local_device_count = jax.local_device_count()
430
+ process_index = jax.process_index()
431
+ generate_samples = val_step_fn
420
432
 
433
+ val_ds = iter(val_ds()) if val_ds else None
434
+ # Evaluation step
421
435
  try:
422
- # Generate samples
423
- samples = generate_samples(
424
- val_state,
425
- next(val_ds),
426
- sampler,
427
- diffusion_steps,
428
- )
429
-
430
- # Log samples to wandb
431
- if getattr(self, 'wandb', None) is not None and self.wandb:
432
- import numpy as np
436
+ metrics = {metric.name: [] for metric in self.eval_metrics} if self.eval_metrics else {}
437
+ for i in range(val_steps_per_epoch):
438
+ if val_ds is None:
439
+ batch = None
440
+ else:
441
+ batch = next(val_ds)
442
+ if self.distributed_training and global_device_count > 1:
443
+ batch = convert_to_global_tree(self.mesh, batch)
444
+ # Generate samples
445
+ samples = generate_samples(
446
+ val_state,
447
+ batch,
448
+ diffusion_steps,
449
+ )
433
450
 
434
- # Process samples differently based on dimensionality
435
- if len(samples.shape) == 5: # [B,T,H,W,C] - Video data
436
- self._log_video_samples(samples, current_step)
437
- else: # [B,H,W,C] - Image data
438
- self._log_image_samples(samples, current_step)
451
+ if self.eval_metrics is not None:
452
+ for metric in self.eval_metrics:
453
+ try:
454
+ # Evaluate metrics
455
+ metric_val = metric.function(samples, batch)
456
+ metrics[metric.name].append(metric_val)
457
+ except Exception as e:
458
+ print("Error in evaluation metrics:", e)
459
+ import traceback
460
+ traceback.print_exc()
461
+ pass
439
462
 
463
+ if i == 0:
464
+ print(f"Evaluation started for process index {process_index}")
465
+ # Log samples to wandb
466
+ if getattr(self, 'wandb', None) is not None and self.wandb:
467
+ import numpy as np
468
+
469
+ # Process samples differently based on dimensionality
470
+ if len(samples.shape) == 5: # [B,T,H,W,C] - Video data
471
+ self._log_video_samples(samples, current_step)
472
+ else: # [B,H,W,C] - Image data
473
+ self._log_image_samples(samples, current_step)
474
+
475
+ if getattr(self, 'wandb', None) is not None and self.wandb:
476
+ # metrics is a dict of metrics
477
+ if metrics and type(metrics) == dict:
478
+ # Flatten the metrics
479
+ metrics = {k: np.mean(v) for k, v in metrics.items()}
480
+ # Log the metrics
481
+ for key, value in metrics.items():
482
+ if isinstance(value, jnp.ndarray):
483
+ value = np.array(value)
484
+ self.wandb.log({
485
+ f"val/{key}": value,
486
+ }, step=current_step)
487
+
488
+ except StopIteration:
489
+ print(f"Validation dataset exhausted for process index {process_index}")
440
490
  except Exception as e:
441
- print("Error in validation loop:", e)
491
+ print(f"Error during validation for process index {process_index}: {e}")
442
492
  import traceback
443
493
  traceback.print_exc()
494
+
444
495
 
445
496
  def _log_video_samples(self, samples, current_step):
446
497
  """Helper to log video samples to wandb."""
@@ -484,7 +535,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
484
535
  def push_to_registry(
485
536
  self,
486
537
  registry_name: str = 'wandb-registry-model',
487
- aliases: List[str] = ['latest'],
538
+ aliases: List[str] = [],
488
539
  ):
489
540
  """
490
541
  Push the model to wandb registry.
@@ -504,7 +555,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
504
555
  artifact_or_path=latest_checkpoint_path,
505
556
  name=modelname,
506
557
  type="model",
507
- aliases=aliases,
558
+ aliases=['latest'] + aliases,
508
559
  )
509
560
 
510
561
  target_path = f"{registry_name}/{modelname}"
@@ -512,6 +563,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
512
563
  self.wandb.link_artifact(
513
564
  artifact=logged_artifact,
514
565
  target_path=target_path,
566
+ aliases=aliases,
515
567
  )
516
568
  print(f"Model pushed to registry at {target_path}")
517
569
  return logged_artifact
@@ -582,7 +634,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
582
634
  is_good, is_best = self.__compare_run_against_best__(top_k=5, metric="train/best_loss")
583
635
  if is_good:
584
636
  # Push to registry with appropriate aliases
585
- aliases = ["latest"]
637
+ aliases = []
586
638
  if is_best:
587
639
  aliases.append("best")
588
640
  self.push_to_registry(aliases=aliases)
@@ -411,7 +411,9 @@ class SimpleTrainer:
411
411
  train_ds,
412
412
  train_steps_per_epoch,
413
413
  current_step,
414
- rng_state
414
+ rng_state,
415
+ save_every:int=None,
416
+ val_every=None,
415
417
  ):
416
418
  global_device_count = jax.device_count()
417
419
  process_index = jax.process_index()
@@ -491,8 +493,8 @@ class SimpleTrainer:
491
493
  "train/loss": loss,
492
494
  }, step=current_step)
493
495
  # Save the model every few steps
494
- if i % 10000 == 0 and i > 0:
495
- print(f"Saving model after 10000 step {current_step}")
496
+ if save_every and i % save_every == 0 and i > 0:
497
+ print(f"Saving model after {save_every} step {current_step}")
496
498
  print(f"Devices: {len(jax.devices())}") # To sync the devices
497
499
  self.save(current_epoch, current_step, train_state, rng_state)
498
500
  print(f"Saving done by process index {process_index}")
@@ -518,7 +520,7 @@ class SimpleTrainer:
518
520
  self.validation_loop(
519
521
  train_state,
520
522
  val_step,
521
- data.get('test', data.get('val', None)),
523
+ data.get('val', data.get('test', None)),
522
524
  val_steps_per_epoch,
523
525
  self.latest_step,
524
526
  )
@@ -569,7 +571,7 @@ class SimpleTrainer:
569
571
  self.validation_loop(
570
572
  train_state,
571
573
  val_step,
572
- data.get('test', None),
574
+ data.get('val', data.get('test', None)),
573
575
  val_steps_per_epoch,
574
576
  current_step,
575
577
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.2.1
3
+ Version: 0.2.3
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -22,6 +22,8 @@ Requires-Dist: python-dotenv
22
22
 
23
23
  # ![](images/logo.jpeg "FlaxDiff")
24
24
 
25
+ **This project is being used for the UMD Course project MSML 605: MLOps**
26
+
25
27
  **This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
26
28
 
27
29
  ## A Versatile and simple Diffusion Library
@@ -2,14 +2,14 @@ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=DmlWUY1FGz4ESxIHaPQJf92CHjsdMjyDd651wFUtyNg,8838
3
3
  flaxdiff/data/__init__.py,sha256=8W5y7NyAOWtpLi8WRawk4VYeE3DMDnM3B_jKPD8BoFQ,143
4
4
  flaxdiff/data/benchmark_decord.py,sha256=x56Db1VPmziv_9KJvWdfS0O7cffsYkF5tt5WvldOKc0,13720
5
- flaxdiff/data/dataloaders.py,sha256=V4goNCK0JD_TthggXAEgJJD4LxJi1pUDew1x_fMCuO4,22576
6
- flaxdiff/data/dataset_map.py,sha256=NrLG1XtIxy8GcCsZ-e6eascjgsP0Xq5lVA1z3HIIYyI,5093
5
+ flaxdiff/data/dataloaders.py,sha256=TgbR5CMxE86L0-1qy5ohZT8zhOPjk3oncd5WPBv08sQ,23557
6
+ flaxdiff/data/dataset_map.py,sha256=_6SYnmrYO-URDd8vPAmALTV6r0eMGWWmwUtsdjKGXnA,5072
7
7
  flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
8
8
  flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
9
9
  flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
10
10
  flaxdiff/data/sources/av_utils.py,sha256=LCr9MJNurOaoxY-sjzkLqJS_MlX0x3gRSlKAVIglAU0,24045
11
11
  flaxdiff/data/sources/base.py,sha256=uhF0odJSYRy0SLw1xnI9Q_q_xiVht2DmEYcX1j9AWT4,4246
12
- flaxdiff/data/sources/images.py,sha256=WpH4ywZhNol26peX3m6m5NrmDJ1K2s6fRcYHvOFlOk8,11102
12
+ flaxdiff/data/sources/images.py,sha256=P7Rea7Zu0h9l7Zoc33zEHKdLI1ST6JEqgl1-bRwORM4,11460
13
13
  flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
14
14
  flaxdiff/data/sources/videos.py,sha256=CVpOH6A4P2D8iv3gZIhd2GB5ATUD8Vsm_wVYbbugWD4,9359
15
15
  flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
@@ -56,9 +56,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
56
56
  flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
57
57
  flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
58
58
  flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
59
- flaxdiff/trainer/general_diffusion_trainer.py,sha256=lEeDCHJyon2GdColI_EGOHNgg1jqqRWT3PlhlIE-NOg,24776
60
- flaxdiff/trainer/simple_trainer.py,sha256=CF2mMcc6AtBgcR1XiqKevRL0paGS0S9ZJofCns32nRM,24214
61
- flaxdiff-0.2.1.dist-info/METADATA,sha256=bFpBzrWOiBN1S5UPevSnN3vbPRYydLV4l_cToAUIOlI,23982
62
- flaxdiff-0.2.1.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
63
- flaxdiff-0.2.1.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
64
- flaxdiff-0.2.1.dist-info/RECORD,,
59
+ flaxdiff/trainer/general_diffusion_trainer.py,sha256=1rLU7iooXIlSDIGFZ7bHgpMWmkqMbUzM9fHBu1L0t-U,27252
60
+ flaxdiff/trainer/simple_trainer.py,sha256=raLS1shwpjJBT_bYXLAB2E4kA9MbwasDTzDTUqfCCUc,24312
61
+ flaxdiff-0.2.3.dist-info/METADATA,sha256=eoCSaBNoDpk90qWz5_NGVkzvuf3Oqt6rSj_ZVTfYn7s,24057
62
+ flaxdiff-0.2.3.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
63
+ flaxdiff-0.2.3.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
64
+ flaxdiff-0.2.3.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (79.0.0)
2
+ Generator: setuptools (80.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5