flaxdiff 0.2.3__py3-none-any.whl → 0.2.4__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
@@ -258,7 +258,7 @@ def get_dataset_grain(
258
258
  image_scale=256,
259
259
  count=None,
260
260
  num_epochs=None,
261
- method=jax.image.ResizeMethod.LANCZOS3,
261
+ method=None, #jax.image.ResizeMethod.LANCZOS3,
262
262
  worker_count=32,
263
263
  read_thread_count=64,
264
264
  read_buffer_size=50,
File without changes
@@ -0,0 +1,11 @@
1
+ from typing import Callable
2
+ from dataclasses import dataclass
3
+
4
+ @dataclass
5
+ class EvaluationMetric:
6
+ """
7
+ Evaluation metrics for the diffusion model.
8
+ The function is given generated samples batch [B, H, W, C] and the original batch.
9
+ """
10
+ function: Callable
11
+ name: str
@@ -0,0 +1,59 @@
1
+ from .common import EvaluationMetric
2
+ import jax
3
+ import jax.numpy as jnp
4
+
5
+ def get_clip_metric(
6
+ modelname: str = "openai/clip-vit-large-patch14",
7
+ ):
8
+ from transformers import AutoProcessor, FlaxCLIPModel
9
+ model = FlaxCLIPModel.from_pretrained(modelname, dtype=jnp.float16)
10
+ processor = AutoProcessor.from_pretrained(modelname, use_fast=True, dtype=jnp.float16)
11
+
12
+ @jax.jit
13
+ def calc(pixel_values, input_ids, attention_mask):
14
+ # Get the logits
15
+ generated_out = model(
16
+ pixel_values=pixel_values,
17
+ input_ids=input_ids,
18
+ attention_mask=attention_mask,
19
+ )
20
+
21
+ gen_img_emb = generated_out.image_embeds
22
+ txt_emb = generated_out.text_embeds
23
+
24
+ # 1. Normalize embeddings (essential for cosine similarity/distance)
25
+ gen_img_emb = gen_img_emb / (jnp.linalg.norm(gen_img_emb, axis=-1, keepdims=True) + 1e-6)
26
+ txt_emb = txt_emb / (jnp.linalg.norm(txt_emb, axis=-1, keepdims=True) + 1e-6)
27
+
28
+ # 2. Calculate cosine similarity
29
+ # Using einsum for batch dot product: batch (b), embedding_dim (d) -> bd,bd->b
30
+ # Calculate cosine similarity
31
+ similarity = jnp.einsum('bd,bd->b', gen_img_emb, txt_emb)
32
+
33
+ scaled_distance = (1.0 - similarity)
34
+ # 4. Average over the batch
35
+ mean_scaled_distance = jnp.mean(scaled_distance)
36
+
37
+ return mean_scaled_distance
38
+
39
+ def clip_metric(
40
+ generated: jnp.ndarray,
41
+ batch
42
+ ):
43
+ original_conditions = batch['text']
44
+
45
+ # Convert samples from [-1, 1] to [0, 255] and uint8
46
+ generated = (((generated + 1.0) / 2.0) * 255).astype(jnp.uint8)
47
+
48
+ generated_inputs = processor(images=generated, return_tensors="jax", padding=True,)
49
+
50
+ pixel_values = generated_inputs['pixel_values']
51
+ input_ids = original_conditions['input_ids']
52
+ attention_mask = original_conditions['attention_mask']
53
+
54
+ return calc(pixel_values, input_ids, attention_mask)
55
+
56
+ return EvaluationMetric(
57
+ function=clip_metric,
58
+ name='clip_similarity'
59
+ )
@@ -27,6 +27,8 @@ from flax.training import dynamic_scale as dynamic_scale_lib
27
27
  from .diffusion_trainer import TrainState, DiffusionTrainer
28
28
  import shutil
29
29
 
30
+ from flaxdiff.metrics.common import EvaluationMetric
31
+
30
32
  def generate_modelname(
31
33
  dataset_name: str,
32
34
  noise_schedule_name: str,
@@ -103,15 +105,6 @@ def generate_modelname(
103
105
  # model_name = f"{model_name}-{config_hash}"
104
106
  return model_name
105
107
 
106
- @dataclass
107
- class EvaluationMetric:
108
- """
109
- Evaluation metrics for the diffusion model.
110
- The function is given generated samples batch [B, H, W, C] and the original batch.
111
- """
112
- function: Callable
113
- name: str
114
-
115
108
  class GeneralDiffusionTrainer(DiffusionTrainer):
116
109
  """
117
110
  General trainer for diffusion models supporting both images and videos.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.2.3
3
+ Version: 0.2.4
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -2,7 +2,7 @@ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=DmlWUY1FGz4ESxIHaPQJf92CHjsdMjyDd651wFUtyNg,8838
3
3
  flaxdiff/data/__init__.py,sha256=8W5y7NyAOWtpLi8WRawk4VYeE3DMDnM3B_jKPD8BoFQ,143
4
4
  flaxdiff/data/benchmark_decord.py,sha256=x56Db1VPmziv_9KJvWdfS0O7cffsYkF5tt5WvldOKc0,13720
5
- flaxdiff/data/dataloaders.py,sha256=TgbR5CMxE86L0-1qy5ohZT8zhOPjk3oncd5WPBv08sQ,23557
5
+ flaxdiff/data/dataloaders.py,sha256=LV8ugqoB86yihfYeOJZHHdRZJNmZ63A2NQkdILMR9QA,23564
6
6
  flaxdiff/data/dataset_map.py,sha256=_6SYnmrYO-URDd8vPAmALTV6r0eMGWWmwUtsdjKGXnA,5072
7
7
  flaxdiff/data/online_loader.py,sha256=t1jEhdB6gWTlwx68ehj1ol_PrImbwXYiRlrJPCmNgCM,35701
8
8
  flaxdiff/data/sources/audio_utils.py,sha256=X27gG1yQt_abVOYgMtruYmZD7-8_uQCRhhTSpn4clkI,4514
@@ -18,6 +18,9 @@ flaxdiff/inference/pipeline.py,sha256=oMBRjvTtlC3Yzl1FqiBHcI4V34HXGAecCg8UvQbKoO
18
18
  flaxdiff/inference/utils.py,sha256=SRNYo-YtHzEPRpNv0fD8ZrUvnRIK941Rh4tjlsOGRgM,12278
19
19
  flaxdiff/inputs/__init__.py,sha256=ybPjQsFAf5sqRVZG1sRiOl99EnwpI-NQ8HE3y7UbXmU,7197
20
20
  flaxdiff/inputs/encoders.py,sha256=pjfbx4Rk7bLoE80MOfThZDm6YtsDncRekmn0Bmg_CwI,2963
21
+ flaxdiff/metrics/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
+ flaxdiff/metrics/common.py,sha256=E0MkL43dicImzNNa-RyQ3sVcrUbpeLlooIQsKTIStvo,285
23
+ flaxdiff/metrics/images.py,sha256=sIuF_Sa2VmPOKrFFoUpzhqOqNa9P7NF0njbrYi93AvE,2128
21
24
  flaxdiff/metrics/inception.py,sha256=a5kjMCPMT9gB88c_HCKiek-2vsAyoE35K7nDt4h4pVI,31843
22
25
  flaxdiff/metrics/psnr.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
23
26
  flaxdiff/metrics/ssim.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -56,9 +59,9 @@ flaxdiff/schedulers/sqrt.py,sha256=mCd_szmOqF6vqQKiAiEOqV_3eBIPGYrW3VxK0o4rBuo,4
56
59
  flaxdiff/trainer/__init__.py,sha256=xSoierfi26gxfgxlNnwvyyPmuPAJ--5i3mEHxt3S-AE,215
57
60
  flaxdiff/trainer/autoencoder_trainer.py,sha256=2FP2P-k9c0n_k3eT0trkq73dQrHRdBj9ObK1idcyhSw,6996
58
61
  flaxdiff/trainer/diffusion_trainer.py,sha256=reQEVWKTqKAeyCMQ-curPOfSRmBKxKooK8EVtUuorcM,14599
59
- flaxdiff/trainer/general_diffusion_trainer.py,sha256=1rLU7iooXIlSDIGFZ7bHgpMWmkqMbUzM9fHBu1L0t-U,27252
62
+ flaxdiff/trainer/general_diffusion_trainer.py,sha256=9c3Ys5sN4_eTehusLjS6IKW5XPOkxoguik-6G0cyQc4,27082
60
63
  flaxdiff/trainer/simple_trainer.py,sha256=raLS1shwpjJBT_bYXLAB2E4kA9MbwasDTzDTUqfCCUc,24312
61
- flaxdiff-0.2.3.dist-info/METADATA,sha256=eoCSaBNoDpk90qWz5_NGVkzvuf3Oqt6rSj_ZVTfYn7s,24057
62
- flaxdiff-0.2.3.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
63
- flaxdiff-0.2.3.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
64
- flaxdiff-0.2.3.dist-info/RECORD,,
64
+ flaxdiff-0.2.4.dist-info/METADATA,sha256=mqm2um1TtgjQzrGlvl7x_CCb_09hK376_dsRECe6qLQ,24057
65
+ flaxdiff-0.2.4.dist-info/WHEEL,sha256=wXxTzcEDnjrTwFYjLPcsW_7_XihufBwmpiBeiXNBGEA,91
66
+ flaxdiff-0.2.4.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
67
+ flaxdiff-0.2.4.dist-info/RECORD,,