flaxdiff 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
@@ -258,7 +258,7 @@ def get_dataset_grain(
     image_scale=256,
     count=None,
     num_epochs=None,
-    method=jax.image.ResizeMethod.LANCZOS3,
+    method=None, #jax.image.ResizeMethod.LANCZOS3,
     worker_count=32,
     read_thread_count=64,
     read_buffer_size=50,
@@ -291,7 +291,7 @@ def get_dataset_grain(
 
     local_batch_size = batch_size // jax.process_count()
 
-    sampler = pygrain.IndexSampler(
+    train_sampler = pygrain.IndexSampler(
         num_records=len(data_source) if count is None else count,
         shuffle=True,
         seed=seed,
@@ -299,6 +299,14 @@ def get_dataset_grain(
         shard_options=pygrain.ShardByJaxProcess(),
     )
 
+    # val_sampler = pygrain.IndexSampler(
+    #     num_records=len(data_source) if count is None else count,
+    #     shuffle=False,
+    #     seed=seed,
+    #     num_epochs=num_epochs,
+    #     shard_options=pygrain.ShardByJaxProcess(),
+    # )
+
     def get_trainset():
         transformations = [
             augmenter(),
@@ -307,7 +315,7 @@ def get_dataset_grain(
 
         loader = pygrain.DataLoader(
             data_source=data_source,
-            sampler=sampler,
+            sampler=train_sampler,
             operations=transformations,
             worker_count=worker_count,
             read_options=pygrain.ReadOptions(
@@ -316,10 +324,31 @@ def get_dataset_grain(
             worker_buffer_size=worker_buffer_size,
         )
         return loader
+
+    # def get_valset():
+    #     transformations = [
+    #         augmenter(),
+    #         pygrain.Batch(local_batch_size, drop_remainder=True),
+    #     ]
+
+    #     loader = pygrain.DataLoader(
+    #         data_source=data_source,
+    #         sampler=val_sampler,
+    #         operations=transformations,
+    #         worker_count=worker_count,
+    #         read_options=pygrain.ReadOptions(
+    #             read_thread_count, read_buffer_size
+    #         ),
+    #         worker_buffer_size=worker_buffer_size,
+    #     )
+    #     return loader
+    get_valset = get_trainset # For now, use the same function for validation
 
     return {
         "train": get_trainset,
         "train_len": len(data_source),
+        "val": get_valset,
+        "val_len": len(data_source),
         "local_batch_size": local_batch_size,
         "global_batch_size": batch_size,
     }
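
The net effect of these dataloader changes is that get_dataset_grain now also exposes a validation split, although for now it simply reuses the training loader. A minimal sketch of how a caller might consume the returned dictionary, using only the keys visible in this hunk; the import path is inferred from the RECORD changes further below, and the elided arguments are not part of this diff:

    from flaxdiff.data.dataloaders import get_dataset_grain  # assumed module path per the RECORD below

    data = get_dataset_grain(...)           # earlier parameters are not shown in this hunk
    train_loader = data["train"]()          # pygrain.DataLoader built by get_trainset
    val_loader = data["val"]()              # new in 0.2.4; currently the same factory as "train"
    print(data["train_len"], data["val_len"], data["local_batch_size"], data["global_batch_size"])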
@@ -21,7 +21,7 @@ datasetMap = {
         "augmenter": gcs_augmenters,
     },
     "laiona_coco": {
-        "source": data_source_gcs('arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
+        "source": data_source_gcs('datasets/laion12m+mscoco'),
         "augmenter": gcs_augmenters,
     },
     "aesthetic_coyo": {
@@ -167,6 +167,16 @@ class ImageTFDSAugmenter(DataAugmenter):
 
         return TFDSTransform
 
+"""
+Batch structure:
+{
+    "image": image_batch,
+    "text": {
+        "input_ids": input_ids_batch,
+        "attention_mask": attention_mask_batch,
+    }
+
+"""
 
 # ----------------------------------------------------------------------------------
 # GCS Image Source
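
The docstring added above documents the nested batch layout produced by the augmenters. Purely as an illustration of that structure (the `loader` iterator and the shape comments are assumptions, not part of the diff):

    batch = next(iter(loader))                        # `loader` stands for a pygrain DataLoader built above
    images = batch["image"]                           # image batch, e.g. [B, H, W, C]
    input_ids = batch["text"]["input_ids"]            # tokenized caption ids
    attention_mask = batch["text"]["attention_mask"]  # matching attention mask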
@@ -248,6 +258,13 @@ class ImageGCSAugmenter(DataAugmenter):
         A callable that returns a pygrain.MapTransform.
         """
         labelizer = self.labelizer
+        if method is None:
+            if image_scale > 256:
+                method = cv2.INTER_CUBIC
+            else:
+                method = cv2.INTER_AREA
+
+        print(f"Using method: {method}")
 
         class GCSTransform(pygrain.MapTransform):
             def __init__(self, *args, **kwargs):
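
The new default picks an OpenCV interpolation mode from the target image_scale: bicubic when the target exceeds 256 px, area averaging otherwise, which follows the usual guidance for enlarging versus shrinking images. A small, self-contained sketch of that choice applied with cv2.resize; the helper name and the resize call are illustrative, not the actual GCSTransform code:

    import cv2
    import numpy as np

    def resize_image(image: np.ndarray, image_scale: int, method=None) -> np.ndarray:
        # Mirror the default chosen above: INTER_CUBIC for larger targets, INTER_AREA otherwise.
        if method is None:
            method = cv2.INTER_CUBIC if image_scale > 256 else cv2.INTER_AREA
        return cv2.resize(image, (image_scale, image_scale), interpolation=method)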
@@ -0,0 +1,11 @@
+from typing import Callable
+from dataclasses import dataclass
+
+@dataclass
+class EvaluationMetric:
+    """
+    Evaluation metrics for the diffusion model.
+    The function is given generated samples batch [B, H, W, C] and the original batch.
+    """
+    function: Callable
+    name: str
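
This new module defines the contract for evaluation metrics: a named callable that receives the generated sample batch [B, H, W, C] and the original batch. A hedged sketch of a metric built on it; the MSE definition and the assumption that batch["image"] shares the samples' value range are illustrative only:

    import jax.numpy as jnp
    from flaxdiff.metrics.common import EvaluationMetric

    def _mse(generated, batch):
        # generated: [B, H, W, C]; batch["image"] is assumed to be in the same value range.
        return jnp.mean((generated - batch["image"]) ** 2)

    mse_metric = EvaluationMetric(function=_mse, name="mse")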
@@ -0,0 +1,59 @@
+from .common import EvaluationMetric
+import jax
+import jax.numpy as jnp
+
+def get_clip_metric(
+    modelname: str = "openai/clip-vit-large-patch14",
+):
+    from transformers import AutoProcessor, FlaxCLIPModel
+    model = FlaxCLIPModel.from_pretrained(modelname, dtype=jnp.float16)
+    processor = AutoProcessor.from_pretrained(modelname, use_fast=True, dtype=jnp.float16)
+
+    @jax.jit
+    def calc(pixel_values, input_ids, attention_mask):
+        # Get the logits
+        generated_out = model(
+            pixel_values=pixel_values,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+        )
+
+        gen_img_emb = generated_out.image_embeds
+        txt_emb = generated_out.text_embeds
+
+        # 1. Normalize embeddings (essential for cosine similarity/distance)
+        gen_img_emb = gen_img_emb / (jnp.linalg.norm(gen_img_emb, axis=-1, keepdims=True) + 1e-6)
+        txt_emb = txt_emb / (jnp.linalg.norm(txt_emb, axis=-1, keepdims=True) + 1e-6)
+
+        # 2. Calculate cosine similarity
+        # Using einsum for batch dot product: batch (b), embedding_dim (d) -> bd,bd->b
+        # Calculate cosine similarity
+        similarity = jnp.einsum('bd,bd->b', gen_img_emb, txt_emb)
+
+        scaled_distance = (1.0 - similarity)
+        # 4. Average over the batch
+        mean_scaled_distance = jnp.mean(scaled_distance)
+
+        return mean_scaled_distance
+
+    def clip_metric(
+        generated: jnp.ndarray,
+        batch
+    ):
+        original_conditions = batch['text']
+
+        # Convert samples from [-1, 1] to [0, 255] and uint8
+        generated = (((generated + 1.0) / 2.0) * 255).astype(jnp.uint8)
+
+        generated_inputs = processor(images=generated, return_tensors="jax", padding=True,)
+
+        pixel_values = generated_inputs['pixel_values']
+        input_ids = original_conditions['input_ids']
+        attention_mask = original_conditions['attention_mask']
+
+        return calc(pixel_values, input_ids, attention_mask)
+
+    return EvaluationMetric(
+        function=clip_metric,
+        name='clip_similarity'
+    )
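
get_clip_metric wraps a FlaxCLIPModel so the metric reports the mean cosine distance between embeddings of the generated images and the caption embeddings already present in the batch. A hedged usage sketch, assuming the trainer wiring shown in the hunks below; constructor arguments other than eval_metrics are elided because they are not part of this diff:

    from flaxdiff.metrics.images import get_clip_metric
    from flaxdiff.trainer.general_diffusion_trainer import GeneralDiffusionTrainer

    clip_metric = get_clip_metric()   # defaults to openai/clip-vit-large-patch14
    trainer = GeneralDiffusionTrainer(
        ...,                          # model, input_config, etc. as in the existing constructor
        eval_metrics=[clip_metric],   # new keyword added in 0.2.4
    )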
@@ -18,15 +18,17 @@ from ..samplers.ddim import DDIMSampler
 from flaxdiff.utils import RandomMarkovState, serialize_model, get_latest_checkpoint
 from flaxdiff.inputs import ConditioningEncoder, ConditionalInputConfig, DiffusionInputConfig
 
-from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
+from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics, convert_to_global_tree
 
 from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
 from flax.training import dynamic_scale as dynamic_scale_lib
 
 # Reuse the TrainState from the DiffusionTrainer
-from flaxdiff.trainer.diffusion_trainer import TrainState, DiffusionTrainer
+from .diffusion_trainer import TrainState, DiffusionTrainer
 import shutil
 
+from flaxdiff.metrics.common import EvaluationMetric
+
 def generate_modelname(
     dataset_name: str,
     noise_schedule_name: str,
@@ -126,6 +128,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
         native_resolution: int = None,
         frames_per_sample: int = None,
         wandb_config: Dict[str, Any] = None,
+        eval_metrics: List[EvaluationMetric] = None,
         **kwargs
     ):
         """
@@ -150,6 +153,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
             autoencoder=autoencoder,
         )
         self.input_config = input_config
+        self.eval_metrics = eval_metrics
 
         if wandb_config is not None:
             # If input_config is not in wandb_config, add it
@@ -363,7 +367,6 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
         def generate_samples(
             val_state: TrainState,
             batch,
-            sampler: DiffusionSampler,
             diffusion_steps: int,
         ):
             # Process all conditional inputs
@@ -385,7 +388,7 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
                 model_conditioning_inputs=tuple(model_conditioning_inputs),
             )
 
-        return sampler, generate_samples
+        return generate_samples
 
     def _get_image_size(self):
         """Helper to determine image size from available information."""
@@ -415,32 +418,73 @@ class GeneralDiffusionTrainer(DiffusionTrainer):
         """
         Run validation and log samples for both image and video diffusion.
        """
-        sampler, generate_samples = val_step_fn
-        val_ds = iter(val_ds()) if val_ds else None
+        global_device_count = jax.device_count()
+        local_device_count = jax.local_device_count()
+        process_index = jax.process_index()
+        generate_samples = val_step_fn
 
+        val_ds = iter(val_ds()) if val_ds else None
+        # Evaluation step
         try:
-            # Generate samples
-            samples = generate_samples(
-                val_state,
-                next(val_ds),
-                sampler,
-                diffusion_steps,
-            )
-
-            # Log samples to wandb
-            if getattr(self, 'wandb', None) is not None and self.wandb:
-                import numpy as np
+            metrics = {metric.name: [] for metric in self.eval_metrics} if self.eval_metrics else {}
+            for i in range(val_steps_per_epoch):
+                if val_ds is None:
+                    batch = None
+                else:
+                    batch = next(val_ds)
+                    if self.distributed_training and global_device_count > 1:
+                        batch = convert_to_global_tree(self.mesh, batch)
+                # Generate samples
+                samples = generate_samples(
+                    val_state,
+                    batch,
+                    diffusion_steps,
+                )
 
-            # Process samples differently based on dimensionality
-            if len(samples.shape) == 5: # [B,T,H,W,C] - Video data
-                self._log_video_samples(samples, current_step)
-            else: # [B,H,W,C] - Image data
-                self._log_image_samples(samples, current_step)
+                if self.eval_metrics is not None:
+                    for metric in self.eval_metrics:
+                        try:
+                            # Evaluate metrics
+                            metric_val = metric.function(samples, batch)
+                            metrics[metric.name].append(metric_val)
+                        except Exception as e:
+                            print("Error in evaluation metrics:", e)
+                            import traceback
+                            traceback.print_exc()
+                            pass
+
+                if i == 0:
+                    print(f"Evaluation started for process index {process_index}")
+                # Log samples to wandb
+                if getattr(self, 'wandb', None) is not None and self.wandb:
+                    import numpy as np
+
+                    # Process samples differently based on dimensionality
+                    if len(samples.shape) == 5: # [B,T,H,W,C] - Video data
+                        self._log_video_samples(samples, current_step)
+                    else: # [B,H,W,C] - Image data
+                        self._log_image_samples(samples, current_step)
 
+            if getattr(self, 'wandb', None) is not None and self.wandb:
+                # metrics is a dict of metrics
+                if metrics and type(metrics) == dict:
+                    # Flatten the metrics
+                    metrics = {k: np.mean(v) for k, v in metrics.items()}
+                    # Log the metrics
+                    for key, value in metrics.items():
+                        if isinstance(value, jnp.ndarray):
+                            value = np.array(value)
+                        self.wandb.log({
+                            f"val/{key}": value,
+                        }, step=current_step)
+
+        except StopIteration:
+            print(f"Validation dataset exhausted for process index {process_index}")
         except Exception as e:
-            print("Error in validation loop:", e)
+            print(f"Error during validation for process index {process_index}: {e}")
             import traceback
             traceback.print_exc()
+
 
     def _log_video_samples(self, samples, current_step):
         """Helper to log video samples to wandb."""
@@ -411,7 +411,9 @@ class SimpleTrainer:
         train_ds,
         train_steps_per_epoch,
         current_step,
-        rng_state
+        rng_state,
+        save_every:int=None,
+        val_every=None,
     ):
         global_device_count = jax.device_count()
         process_index = jax.process_index()
@@ -491,8 +493,8 @@ class SimpleTrainer:
                    "train/loss": loss,
                }, step=current_step)
                # Save the model every few steps
-                if i % 10000 == 0 and i > 0:
-                    print(f"Saving model after 10000 step {current_step}")
+                if save_every and i % save_every == 0 and i > 0:
+                    print(f"Saving model after {save_every} step {current_step}")
                    print(f"Devices: {len(jax.devices())}") # To sync the devices
                    self.save(current_epoch, current_step, train_state, rng_state)
                    print(f"Saving done by process index {process_index}")
@@ -518,7 +520,7 @@ class SimpleTrainer:
                self.validation_loop(
                    train_state,
                    val_step,
-                    data.get('test', data.get('val', None)),
+                    data.get('val', data.get('test', None)),
                    val_steps_per_epoch,
                    self.latest_step,
                )
@@ -569,7 +571,7 @@ class SimpleTrainer:
            self.validation_loop(
                train_state,
                val_step,
-                data.get('test', None),
+                data.get('val', data.get('test', None)),
                val_steps_per_epoch,
                current_step,
            )
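
Two small behavioural changes land here: the periodic checkpoint interval becomes an optional save_every argument instead of a hard-coded 10000 steps, and validation now prefers a "val" split over "test", matching the "val" key that get_dataset_grain exposes above. A minimal illustration of both; the dict contents and step values are made up:

    def get_trainset():   # stand-in for the factory returned by get_dataset_grain above
        ...

    data = {"train": get_trainset, "val": get_trainset}
    val_source = data.get('val', data.get('test', None))   # picks "val" first, "test" as fallback

    step, save_every = 20000, None
    should_save = bool(save_every) and step % save_every == 0 and step > 0   # False: periodic saves disabled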
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.2.2
+Version: 0.2.4
 Summary: A versatile and easy to understand Diffusion library
 Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
 License-Expression: MIT
@@ -22,6 +22,8 @@ Requires-Dist: python-dotenv
 
 # ![](images/logo.jpeg "FlaxDiff")
 
+**This project is being used for the UMD Course project MSML 605: MLOps**
+
 **This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
 
 ## A Versatile and simple Diffusion Library
@@ -2,14 +2,14 @@ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=DmlWUY1FGz4ESxIHaPQJf92CHjsdMjyDd651wFUtyNg,8838
 flaxdiff/data/__init__.py,sha256=8W5y7NyAOWtpLi8WRawk4VYeE3DMDnM3B_jKPD8BoFQ,143
 flaxdiff/data/benchmark_decord.py,sha256=x56Db1VPmziv_9KJvWdfS0O7cffsYkF5tt5WvldOKc0,13720
-flaxdiff/data/dataloaders.py,sha256=V4goNCK0JD_TthggXAEgJJD4LxJi1pUDew1x_fMCuO4,22576
-flaxdiff/data/dataset_map.py,sha256=NrLG1XtIxy8GcCsZ-e6eascjgsP0Xq5lVA1z3HIIYyI,5093
+flaxdiff/data/dataloaders.py,sha256=LV8ugqoB86yihfYeOJZHHdRZJNmZ63A2NQkdILMR9QA,23564
+flaxdiff/data/dataset_map.py,sha256=_6SYnmrYO-URDd8vPAmALTV6r0eMGWWmwUtsdjKGXnA,5072
 flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
 flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
 flaxdiff/data/sources/av_example.py,sha256=RIcbVKqckFqbfnV65NQotzIBxjdDuM67kD1nY8fqw5Q,3826
 flaxdiff/data/sources/av_utils.py,sha256=LCr9MJNurOaoxY-sjzkLqJS_MlX0x3gRSlKAVIglAU0,24045
 flaxdiff/data/sources/base.py,sha256=uhF0odJSYRy0SLw1xnI9Q_q_xiVht2DmEYcX1j9AWT4,4246
-flaxdiff/data/sources/images.py,sha256=WpH4ywZhNol26peX3m6m5NrmDJ1K2s6fRcYHvOFlOk8,11102
+flaxdiff/data/sources/images.py,sha256=P7Rea7Zu0h9l7Zoc33zEHKdLI1ST6JEqgl1-bRwORM4,11460
 flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzjk,7316
 flaxdiff/data/sources/videos.py,sha256=CVpOH6A4P2D8iv3gZIhd2GB5ATUD8Vsm_wVYbbugWD4,9359
 flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
@@ -18,6 +18,9 @@ flaxdiff/inference/pipeline.py,sha256=oMBRjvTtlC3Yzl1FqiBHcI4V34HXGAecCg8UvQbKoO
 flaxdiff/inference/utils.py,sha256=SRNYo-YtHzEPRpNv0fD8ZrUvnRIK941Rh4tjlsOGRgM,12278
 flaxdiff/inputs/__init__.py,sha256=ybPjQsFAf5sqRVZG1sRiOl99EnwpI-NQ8HE3y7UbXmU,7197
 flaxdiff/inputs/encoders.py,sha256=pjfbx4Rk7bLoE80MOfThZDm6YtsDncRekmn0Bmg_CwI,2963
+flaxdiff/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+flaxdiff/metrics/common.py,sha256=E0MkL43dicImzNNa-RyQ3sVcrUbpeLlooIQsKTIStvo,285
+flaxdiff/metrics/images.py,sha256=sIuF_Sa2VmPOKrFFoUpzhqOqNa9P7NF0njbrYi93AvE,2128
 flaxdiff/metrics/inception.py,sha256=a5kjMCPMT9gB88c_HCKiek-2vsAyoE35K7nDt4h4pVI,31843
 flaxdiff/metrics/psnr.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/metrics/ssim.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -56,9 +59,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
 flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
 flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
 flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
-flaxdiff/trainer/general_diffusion_trainer.py,sha256=7VAeT3TzCDUyns8wdZbIwXJqDKx_FYSzq8toOkaeQMI,24802
-flaxdiff/trainer/simple_trainer.py,sha256=CF2mMcc6AtBgcR1XiqKevRL0paGS0S9ZJofCns32nRM,24214
-flaxdiff-0.2.2.dist-info/METADATA,sha256=pzYYdy1zK7lbaqSRdpopZHHYx7q3BP0DL11hGTOO7h4,23982
-flaxdiff-0.2.2.dist-info/WHEEL,sha256=pxyMxgL8-pra_rKaQ4drOZAegBVuX-G_4nRHjjgWbmo,91
-flaxdiff-0.2.2.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.2.2.dist-info/RECORD,,
+flaxdiff/trainer/general_diffusion_trainer.py,sha256=9c3Ys5sN4_eTehusLjS6IKW5XPOkxoguik-6G0cyQc4,27082
+flaxdiff/trainer/simple_trainer.py,sha256=raLS1shwpjJBT_bYXLAB2E4kA9MbwasDTzDTUqfCCUc,24312
+flaxdiff-0.2.4.dist-info/METADATA,sha256=mqm2um1TtgjQzrGlvl7x_CCb_09hK376_dsRECe6qLQ,24057
+flaxdiff-0.2.4.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
+flaxdiff-0.2.4.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.2.4.dist-info/RECORD,,
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (79.0.0)
+Generator: setuptools (80.1.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 