flaxdiff 0.2.3__tar.gz → 0.2.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/PKG-INFO +1 -1
  2. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/data/dataloaders.py +1 -1
  3. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/inference/pipeline.py +4 -2
  4. flaxdiff-0.2.5/flaxdiff/metrics/common.py +11 -0
  5. flaxdiff-0.2.5/flaxdiff/metrics/images.py +59 -0
  6. flaxdiff-0.2.5/flaxdiff/metrics/ssim.py +0 -0
  7. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/trainer/general_diffusion_trainer.py +2 -9
  8. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff.egg-info/PKG-INFO +1 -1
  9. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff.egg-info/SOURCES.txt +3 -0
  10. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/pyproject.toml +1 -1
  11. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/README.md +0 -0
  12. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/__init__.py +0 -0
  13. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/data/__init__.py +0 -0
  14. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/data/benchmark_decord.py +0 -0
  15. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/data/dataset_map.py +0 -0
  16. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/data/online_loader.py +0 -0
  17. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/data/sources/audio_utils.py +0 -0
  18. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/data/sources/av_example.py +0 -0
  19. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/data/sources/av_utils.py +0 -0
  20. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/data/sources/base.py +0 -0
  21. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/data/sources/images.py +0 -0
  22. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/data/sources/utils.py +0 -0
  23. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/data/sources/videos.py +0 -0
  24. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/data/sources/voxceleb2.py +0 -0
  25. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/inference/__init__.py +0 -0
  26. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/inference/utils.py +0 -0
  27. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/inputs/__init__.py +0 -0
  28. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/inputs/encoders.py +0 -0
  29. /flaxdiff-0.2.3/flaxdiff/metrics/psnr.py → /flaxdiff-0.2.5/flaxdiff/metrics/__init__.py +0 -0
  30. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/metrics/inception.py +0 -0
  31. /flaxdiff-0.2.3/flaxdiff/metrics/ssim.py → /flaxdiff-0.2.5/flaxdiff/metrics/psnr.py +0 -0
  32. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/metrics/utils.py +0 -0
  33. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/models/__init__.py +0 -0
  34. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/models/attention.py +0 -0
  35. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/models/autoencoder/__init__.py +0 -0
  36. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  37. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  38. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  39. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/models/common.py +0 -0
  40. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/models/favor_fastattn.py +0 -0
  41. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/models/general.py +0 -0
  42. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/models/simple_unet.py +0 -0
  43. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/models/simple_vit.py +0 -0
  44. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/models/unet_3d.py +0 -0
  45. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/models/unet_3d_blocks.py +0 -0
  46. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/predictors/__init__.py +0 -0
  47. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/samplers/__init__.py +0 -0
  48. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/samplers/common.py +0 -0
  49. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/samplers/ddim.py +0 -0
  50. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/samplers/ddpm.py +0 -0
  51. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/samplers/euler.py +0 -0
  52. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/samplers/heun_sampler.py +0 -0
  53. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/samplers/multistep_dpm.py +0 -0
  54. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/samplers/rk4_sampler.py +0 -0
  55. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/schedulers/__init__.py +0 -0
  56. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/schedulers/common.py +0 -0
  57. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/schedulers/continuous.py +0 -0
  58. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/schedulers/cosine.py +0 -0
  59. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/schedulers/discrete.py +0 -0
  60. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/schedulers/exp.py +0 -0
  61. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/schedulers/karras.py +0 -0
  62. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/schedulers/linear.py +0 -0
  63. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/schedulers/sqrt.py +0 -0
  64. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/trainer/__init__.py +0 -0
  65. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  66. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  67. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/trainer/simple_trainer.py +0 -0
  68. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff/utils.py +0 -0
  69. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff.egg-info/dependency_links.txt +0 -0
  70. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff.egg-info/requires.txt +0 -0
  71. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/flaxdiff.egg-info/top_level.txt +0 -0
  72. {flaxdiff-0.2.3 → flaxdiff-0.2.5}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: flaxdiff
- Version: 0.2.3
+ Version: 0.2.5
  Summary: A versatile and easy to understand Diffusion library
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
  License-Expression: MIT
@@ -258,7 +258,7 @@ def get_dataset_grain(
      image_scale=256,
      count=None,
      num_epochs=None,
-     method=jax.image.ResizeMethod.LANCZOS3,
+     method=None, #jax.image.ResizeMethod.LANCZOS3,
      worker_count=32,
      read_thread_count=64,
      read_buffer_size=50,
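
The only functional change in dataloaders.py is this default: LANCZOS3 resampling is no longer forced, and the method now defaults to None. A minimal sketch of passing such a method through to jax.image.resize; the resize_image helper and its LINEAR fallback are illustrative assumptions, not the library's code:

    import jax
    import jax.numpy as jnp

    def resize_image(image: jnp.ndarray, image_scale: int, method=None):
        # Assumed fallback when no method is given; the loader's actual default
        # after this change is not shown in the diff.
        if method is None:
            method = jax.image.ResizeMethod.LINEAR
        return jax.image.resize(image, (image_scale, image_scale, image.shape[-1]), method=method)

    image = jnp.zeros((512, 512, 3))
    # Passing the method explicitly presumably restores the pre-0.2.5 behaviour.
    resized = resize_image(image, 256, method=jax.image.ResizeMethod.LANCZOS3)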
@@ -208,7 +208,8 @@ class DiffusionInferencePipeline(InferencePipeline):
      self,
      num_samples: int,
      resolution: int,
-     conditioning_data: Optional[List[Union[Tuple, Dict]]] = None, # one list per modality or list of tuples
+     conditioning_data: List[Union[Tuple, Dict]] = None,
+     conditioning_data_tokens: Tuple = None,
      sequence_length: Optional[int] = None,
      diffusion_steps: int = 50,
      guidance_scale: float = 1.0,
@@ -256,5 +257,6 @@ class DiffusionInferencePipeline(InferencePipeline):
          steps_override=steps_override,
          priors=priors,
          rngstate=rngstate,
-         conditioning=conditioning_data
+         conditioning=conditioning_data,
+         model_conditioning_inputs=conditioning_data_tokens,
      )
@@ -0,0 +1,11 @@
+ from typing import Callable
+ from dataclasses import dataclass
+
+ @dataclass
+ class EvaluationMetric:
+     """
+     Evaluation metrics for the diffusion model.
+     The function is given generated samples batch [B, H, W, C] and the original batch.
+     """
+     function: Callable
+     name: str
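
The new common.py module is just this small container: a named callable. A minimal sketch of defining a custom metric with it; the MSE computation and the batch['image'] key are illustrative assumptions, not part of the package:

    import jax.numpy as jnp
    from flaxdiff.metrics.common import EvaluationMetric

    def mse_metric(generated: jnp.ndarray, batch) -> jnp.ndarray:
        # Compare generated samples [B, H, W, C] against the originals in the batch.
        return jnp.mean((generated - batch['image']) ** 2)

    mse = EvaluationMetric(function=mse_metric, name='mse')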
@@ -0,0 +1,59 @@
+ from .common import EvaluationMetric
+ import jax
+ import jax.numpy as jnp
+
+ def get_clip_metric(
+     modelname: str = "openai/clip-vit-large-patch14",
+ ):
+     from transformers import AutoProcessor, FlaxCLIPModel
+     model = FlaxCLIPModel.from_pretrained(modelname, dtype=jnp.float16)
+     processor = AutoProcessor.from_pretrained(modelname, use_fast=False, dtype=jnp.float16)
+
+     @jax.jit
+     def calc(pixel_values, input_ids, attention_mask):
+         # Get the logits
+         generated_out = model(
+             pixel_values=pixel_values,
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+         )
+
+         gen_img_emb = generated_out.image_embeds
+         txt_emb = generated_out.text_embeds
+
+         # 1. Normalize embeddings (essential for cosine similarity/distance)
+         gen_img_emb = gen_img_emb / (jnp.linalg.norm(gen_img_emb, axis=-1, keepdims=True) + 1e-6)
+         txt_emb = txt_emb / (jnp.linalg.norm(txt_emb, axis=-1, keepdims=True) + 1e-6)
+
+         # 2. Calculate cosine similarity
+         # Using einsum for batch dot product: batch (b), embedding_dim (d) -> bd,bd->b
+         # Calculate cosine similarity
+         similarity = jnp.einsum('bd,bd->b', gen_img_emb, txt_emb)
+
+         scaled_distance = (1.0 - similarity)
+         # 4. Average over the batch
+         mean_scaled_distance = jnp.mean(scaled_distance)
+
+         return mean_scaled_distance
+
+     def clip_metric(
+         generated: jnp.ndarray,
+         batch
+     ):
+         original_conditions = batch['text']
+
+         # Convert samples from [-1, 1] to [0, 255] and uint8
+         generated = (((generated + 1.0) / 2.0) * 255).astype(jnp.uint8)
+
+         generated_inputs = processor(images=generated, return_tensors="jax", padding=True,)
+
+         pixel_values = generated_inputs['pixel_values']
+         input_ids = original_conditions['input_ids']
+         attention_mask = original_conditions['attention_mask']
+
+         return calc(pixel_values, input_ids, attention_mask)
+
+     return EvaluationMetric(
+         function=clip_metric,
+         name='clip_similarity'
+     )
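
A rough usage sketch for the new CLIP metric, assuming an evaluation batch that carries tokenized captions under batch['text'] as the code above expects; the shapes and keys below are illustrative, not taken from the package:

    import jax.numpy as jnp
    from flaxdiff.metrics.images import get_clip_metric

    clip = get_clip_metric()  # loads openai/clip-vit-large-patch14 on first use

    # Hypothetical evaluation inputs: generated samples in [-1, 1] plus the
    # original tokenized captions (CLIP uses 77-token sequences).
    generated = jnp.zeros((4, 256, 256, 3))
    batch = {
        'text': {
            'input_ids': jnp.zeros((4, 77), dtype=jnp.int32),
            'attention_mask': jnp.ones((4, 77), dtype=jnp.int32),
        }
    }
    score = clip.function(generated, batch)  # mean (1 - cosine similarity); lower is better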
@@ -27,6 +27,8 @@ from flax.training import dynamic_scale as dynamic_scale_lib
  from .diffusion_trainer import TrainState, DiffusionTrainer
  import shutil
 
+ from flaxdiff.metrics.common import EvaluationMetric
+
  def generate_modelname(
      dataset_name: str,
      noise_schedule_name: str,
@@ -103,15 +105,6 @@ def generate_modelname(
      # model_name = f"{model_name}-{config_hash}"
      return model_name
 
- @dataclass
- class EvaluationMetric:
-     """
-     Evaluation metrics for the diffusion model.
-     The function is given generated samples batch [B, H, W, C] and the original batch.
-     """
-     function: Callable
-     name: str
-
  class GeneralDiffusionTrainer(DiffusionTrainer):
      """
      General trainer for diffusion models supporting both images and videos.
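
These hunks move the EvaluationMetric dataclass out of the trainer module and into the new metrics package, with the trainer importing it from there. Downstream code would switch imports accordingly; the old path below is inferred from the file the class was removed from:

    # Before 0.2.5 (inferred old location):
    # from flaxdiff.trainer.general_diffusion_trainer import EvaluationMetric

    # From 0.2.5 onward, matching the import added above:
    from flaxdiff.metrics.common import EvaluationMetric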
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: flaxdiff
- Version: 0.2.3
+ Version: 0.2.5
  Summary: A versatile and easy to understand Diffusion library
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
  License-Expression: MIT
@@ -25,6 +25,9 @@ flaxdiff/inference/pipeline.py
  flaxdiff/inference/utils.py
  flaxdiff/inputs/__init__.py
  flaxdiff/inputs/encoders.py
+ flaxdiff/metrics/__init__.py
+ flaxdiff/metrics/common.py
+ flaxdiff/metrics/images.py
  flaxdiff/metrics/inception.py
  flaxdiff/metrics/psnr.py
  flaxdiff/metrics/ssim.py
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
  [project]
  name = "flaxdiff"
- version = "0.2.3"
+ version = "0.2.5"
  description = "A versatile and easy to understand Diffusion library"
  readme = "README.md"
  authors = [