flaxdiff 0.2.3__py3-none-any.whl → 0.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -258,7 +258,7 @@ def get_dataset_grain(
258
258
  image_scale=256,
259
259
  count=None,
260
260
  num_epochs=None,
261
- method=jax.image.ResizeMethod.LANCZOS3,
261
+ method=None, #jax.image.ResizeMethod.LANCZOS3,
262
262
  worker_count=32,
263
263
  read_thread_count=64,
264
264
  read_buffer_size=50,
@@ -208,7 +208,8 @@ class DiffusionInferencePipeline(InferencePipeline):
208
208
  self,
209
209
  num_samples: int,
210
210
  resolution: int,
211
- conditioning_data: Optional[List[Union[Tuple, Dict]]] = None, # one list per modality or list of tuples
211
+ conditioning_data: List[Union[Tuple, Dict]] = None,
212
+ conditioning_data_tokens: Tuple = None,
212
213
  sequence_length: Optional[int] = None,
213
214
  diffusion_steps: int = 50,
214
215
  guidance_scale: float = 1.0,
@@ -256,5 +257,6 @@ class DiffusionInferencePipeline(InferencePipeline):
256
257
  steps_override=steps_override,
257
258
  priors=priors,
258
259
  rngstate=rngstate,
259
- conditioning=conditioning_data
260
+ conditioning=conditioning_data,
261
+ model_conditioning_inputs=conditioning_data_tokens,
260
262
  )
File without changes
@@ -0,0 +1,11 @@
1
+ from typing import Callable
2
+ from dataclasses import dataclass
3
+
4
+ @dataclass
5
+ class EvaluationMetric:
6
+ """
7
+ Evaluation metrics for the diffusion model.
8
+ The function is given generated samples batch [B, H, W, C] and the original batch.
9
+ """
10
+ function: Callable
11
+ name: str
@@ -0,0 +1,59 @@
1
+ from .common import EvaluationMetric
2
+ import jax
3
+ import jax.numpy as jnp
4
+
5
+ def get_clip_metric(
6
+ modelname: str = "openai/clip-vit-large-patch14",
7
+ ):
8
+ from transformers import AutoProcessor, FlaxCLIPModel
9
+ model = FlaxCLIPModel.from_pretrained(modelname, dtype=jnp.float16)
10
+ processor = AutoProcessor.from_pretrained(modelname, use_fast=False, dtype=jnp.float16)
11
+
12
+ @jax.jit
13
+ def calc(pixel_values, input_ids, attention_mask):
14
+ # Get the logits
15
+ generated_out = model(
16
+ pixel_values=pixel_values,
17
+ input_ids=input_ids,
18
+ attention_mask=attention_mask,
19
+ )
20
+
21
+ gen_img_emb = generated_out.image_embeds
22
+ txt_emb = generated_out.text_embeds
23
+
24
+ # 1. Normalize embeddings (essential for cosine similarity/distance)
25
+ gen_img_emb = gen_img_emb / (jnp.linalg.norm(gen_img_emb, axis=-1, keepdims=True) + 1e-6)
26
+ txt_emb = txt_emb / (jnp.linalg.norm(txt_emb, axis=-1, keepdims=True) + 1e-6)
27
+
28
+ # 2. Calculate cosine similarity
29
+ # Using einsum for batch dot product: batch (b), embedding_dim (d) -> bd,bd->b
30
+ # Calculate cosine similarity
31
+ similarity = jnp.einsum('bd,bd->b', gen_img_emb, txt_emb)
32
+
33
+ scaled_distance = (1.0 - similarity)
34
+ # 4. Average over the batch
35
+ mean_scaled_distance = jnp.mean(scaled_distance)
36
+
37
+ return mean_scaled_distance
38
+
39
+ def clip_metric(
40
+ generated: jnp.ndarray,
41
+ batch
42
+ ):
43
+ original_conditions = batch['text']
44
+
45
+ # Convert samples from [-1, 1] to [0, 255] and uint8
46
+ generated = (((generated + 1.0) / 2.0) * 255).astype(jnp.uint8)
47
+
48
+ generated_inputs = processor(images=generated, return_tensors="jax", padding=True,)
49
+
50
+ pixel_values = generated_inputs['pixel_values']
51
+ input_ids = original_conditions['input_ids']
52
+ attention_mask = original_conditions['attention_mask']
53
+
54
+ return calc(pixel_values, input_ids, attention_mask)
55
+
56
+ return EvaluationMetric(
57
+ function=clip_metric,
58
+ name='clip_similarity'
59
+ )
@@ -27,6 +27,8 @@ from flax.training import dynamic_scale as dynamic_scale_lib
27
27
  from .diffusion_trainer import TrainState, DiffusionTrainer
28
28
  import shutil
29
29
 
30
+ from flaxdiff.metrics.common import EvaluationMetric
31
+
30
32
  def generate_modelname(
31
33
  dataset_name: str,
32
34
  noise_schedule_name: str,
@@ -103,15 +105,6 @@ def generate_modelname(
103
105
  # model_name = f"{model_name}-{config_hash}"
104
106
  return model_name
105
107
 
106
- @dataclass
107
- class EvaluationMetric:
108
- """
109
- Evaluation metrics for the diffusion model.
110
- The function is given generated samples batch [B, H, W, C] and the original batch.
111
- """
112
- function: Callable
113
- name: str
114
-
115
108
  class GeneralDiffusionTrainer(DiffusionTrainer):
116
109
  """
117
110
  General trainer for diffusion models supporting both images and videos.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.2.3
3
+ Version: 0.2.5
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -2,7 +2,7 @@ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=DmlWUY1FGz4ESxIHaPQJf92CHjsdMjyDd651wFUtyNg,8838
3
3
  flaxdiff/data/__init__.py,sha256=8W5y7NyAOWtpLi8WRawk4VYeE3DMDnM3B_jKPD8BoFQ,143
4
4
  flaxdiff/data/benchmark_decord.py,sha256=x56Db1VPmziv_9KJvWdfS0O7cffsYkF5tt5WvldOKc0,13720
5
- flaxdiff/data/dataloaders.py,sha256=TgbR5CMxE86L0-1qy5ohZT8zhOPjk3oncd5WPBv08sQ,23557
5
+ flaxdiff/data/dataloaders.py,sha256=LV8ugqoB86yihfYeOJZHHdRZJNmZ63A2NQkdILMR9QA,23564
6
6
  flaxdiff/data/dataset_map.py,sha256=_6SYnmrYO-URDd8vPAmALTV6r0eMGWWmwUtsdjKGXnA,5072
7
7
  flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
8
8
  flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
@@ -14,10 +14,13 @@ flaxdiff/data/sources/utils.py,sha256=kFzM4_kPoThbAu54ulABmEDAR33tR50NgzXIpC0Dzj
14
14
  flaxdiff/data/sources/videos.py,sha256=CVpOH6A4P2D8iv3gZIhd2GB5ATUD8Vsm_wVYbbugWD4,9359
15
15
  flaxdiff/data/sources/voxceleb2.py,sha256=BoKfat_hsw6ObDyyaiQmPbBzuFiqgCGlgAZmf-t5Iz8,18621
16
16
  flaxdiff/inference/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
- flaxdiff/inference/pipeline.py,sha256=oMBRjvTtlC3Yzl1FqiBHcI4V34HXGAecCg8UvQbKoOc,8849
17
+ flaxdiff/inference/pipeline.py,sha256=pVMAiK8-nm-UWJRkd2aJqY3GFBGW9h63VtM68RBvfrM,8909
18
18
  flaxdiff/inference/utils.py,sha256=SRNYo-YtHzEPRpNv0fD8ZrUvnRIK941Rh4tjlsOGRgM,12278
19
19
  flaxdiff/inputs/__init__.py,sha256=ybPjQsFAf5sqRVZG1sRiOl99EnwpI-NQ8HE3y7UbXmU,7197
20
20
  flaxdiff/inputs/encoders.py,sha256=pjfbx4Rk7bLoE80MOfThZDm6YtsDncRekmn0Bmg_CwI,2963
21
+ flaxdiff/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
+ flaxdiff/metrics/common.py,sha256=E0MkL43dicImzNNa-RyQ3sVcrUbpeLlooIQsKTIStvo,285
23
+ flaxdiff/metrics/images.py,sha256=B6yXp-u9TMQY-EvM2hybRDxDbdwPwhzeGP_lBaX0mnc,2129
21
24
  flaxdiff/metrics/inception.py,sha256=a5kjMCPMT9gB88c_HCKiek-2vsAyoE35K7nDt4h4pVI,31843
22
25
  flaxdiff/metrics/psnr.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
26
  flaxdiff/metrics/ssim.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -56,9 +59,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
56
59
  flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
57
60
  flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
58
61
  flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
59
- flaxdiff/trainer/general_diffusion_trainer.py,sha256=1rLU7iooXIlSDIGFZ7bHgpMWmkqMbUzM9fHBu1L0t-U,27252
62
+ flaxdiff/trainer/general_diffusion_trainer.py,sha256=9c3Ys5sN4_eTehusLjS6IKW5XPOkxoguik-6G0cyQc4,27082
60
63
  flaxdiff/trainer/simple_trainer.py,sha256=raLS1shwpjJBT_bYXLAB2E4kA9MbwasDTzDTUqfCCUc,24312
61
- flaxdiff-0.2.3.dist-info/METADATA,sha256=eoCSaBNoDpk90qWz5_NGVkzvuf3Oqt6rSj_ZVTfYn7s,24057
62
- flaxdiff-0.2.3.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
63
- flaxdiff-0.2.3.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
64
- flaxdiff-0.2.3.dist-info/RECORD,,
64
+ flaxdiff-0.2.5.dist-info/METADATA,sha256=hIsqI7gFYMgPEN9cuJLjafhC0la5c3Y6ZShbyMdYl5A,24057
65
+ flaxdiff-0.2.5.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
66
+ flaxdiff-0.2.5.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
67
+ flaxdiff-0.2.5.dist-info/RECORD,,