flaxdiff 0.2.3__tar.gz → 0.2.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/PKG-INFO +1 -1
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/data/dataloaders.py +1 -1
- flaxdiff-0.2.4/flaxdiff/metrics/common.py +11 -0
- flaxdiff-0.2.4/flaxdiff/metrics/images.py +59 -0
- flaxdiff-0.2.4/flaxdiff/metrics/ssim.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/trainer/general_diffusion_trainer.py +2 -9
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff.egg-info/SOURCES.txt +3 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/pyproject.toml +1 -1
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/README.md +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/data/benchmark_decord.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/data/dataset_map.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/data/online_loader.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/data/sources/audio_utils.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/data/sources/av_example.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/data/sources/av_utils.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/data/sources/base.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/data/sources/images.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/data/sources/utils.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/data/sources/videos.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/data/sources/voxceleb2.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/inference/__init__.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/inference/pipeline.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/inference/utils.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/inputs/__init__.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/inputs/encoders.py +0 -0
- /flaxdiff-0.2.3/flaxdiff/metrics/psnr.py → /flaxdiff-0.2.4/flaxdiff/metrics/__init__.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/metrics/inception.py +0 -0
- /flaxdiff-0.2.3/flaxdiff/metrics/ssim.py → /flaxdiff-0.2.4/flaxdiff/metrics/psnr.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/metrics/utils.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/models/common.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/models/general.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/models/unet_3d.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/models/unet_3d_blocks.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/trainer/diffusion_trainer.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.2.3 → flaxdiff-0.2.4}/setup.cfg +0 -0
@@ -0,0 +1,11 @@
|
|
1
|
+
from typing import Callable
|
2
|
+
from dataclasses import dataclass
|
3
|
+
|
4
|
+
@dataclass
class EvaluationMetric:
    """A named metric used to evaluate a diffusion model.

    ``function`` receives a batch of generated samples shaped [B, H, W, C]
    together with the original batch, and ``name`` is the label the metric
    is reported under.
    """

    # Callable invoked with (generated_samples, original_batch).
    function: Callable[..., object]
    # Label identifying this metric.
    name: str
|
@@ -0,0 +1,59 @@
|
|
1
|
+
from .common import EvaluationMetric
|
2
|
+
import jax
|
3
|
+
import jax.numpy as jnp
|
4
|
+
|
5
|
+
def get_clip_metric(
    modelname: str = "openai/clip-vit-large-patch14",
):
    """Build a CLIP-based text/image alignment metric.

    Loads a Flax CLIP model and its processor from ``modelname`` and wraps a
    scoring function in an ``EvaluationMetric`` named ``'clip_similarity'``.
    Despite that name, the score returned is the batch mean of the cosine
    *distance* ``1 - cos(image_embedding, text_embedding)``, so lower values
    mean the generated images match their conditioning text more closely.

    Args:
        modelname: Model identifier passed to ``from_pretrained`` for both the
            CLIP model and its processor.

    Returns:
        EvaluationMetric whose ``function(generated, batch)`` takes generated
        samples in [-1, 1] and a batch where ``batch['text']`` holds the
        tokenized ``'input_ids'`` and ``'attention_mask'``.
    """
    # Imported lazily so transformers is only needed when this metric is used.
    from transformers import AutoProcessor, FlaxCLIPModel

    model = FlaxCLIPModel.from_pretrained(modelname, dtype=jnp.float16)
    processor = AutoProcessor.from_pretrained(modelname, use_fast=True, dtype=jnp.float16)

    # Hand the weights to jit as an explicit argument. Closing over the model
    # would embed every parameter array into the compiled computation as a
    # constant, inflating compile time and memory.
    params = model.params

    @jax.jit
    def calc(params, pixel_values, input_ids, attention_mask):
        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            params=params,
        )

        image_emb = outputs.image_embeds
        text_emb = outputs.text_embeds

        # 1. L2-normalize so the dot product below is a cosine similarity
        #    (epsilon guards against division by zero).
        image_emb = image_emb / (jnp.linalg.norm(image_emb, axis=-1, keepdims=True) + 1e-6)
        text_emb = text_emb / (jnp.linalg.norm(text_emb, axis=-1, keepdims=True) + 1e-6)

        # 2. Per-sample cosine similarity: (B, D) x (B, D) -> (B,).
        similarity = jnp.einsum('bd,bd->b', image_emb, text_emb)

        # 3. Turn similarity into a distance (0 when perfectly aligned).
        distance = 1.0 - similarity

        # 4. Average over the batch.
        return jnp.mean(distance)

    def clip_metric(
        generated: jnp.ndarray,
        batch,
    ):
        original_conditions = batch['text']

        # Map samples from [-1, 1] to [0, 255] uint8. Clip first: sampled
        # images can overshoot [-1, 1], and casting out-of-range floats to
        # uint8 does not reliably saturate (values may wrap around).
        generated = jnp.clip((generated + 1.0) / 2.0, 0.0, 1.0)
        generated = (generated * 255).astype(jnp.uint8)

        generated_inputs = processor(images=generated, return_tensors="jax", padding=True,)

        pixel_values = generated_inputs['pixel_values']
        input_ids = original_conditions['input_ids']
        attention_mask = original_conditions['attention_mask']

        return calc(params, pixel_values, input_ids, attention_mask)

    return EvaluationMetric(
        function=clip_metric,
        name='clip_similarity'
    )
|
File without changes
|
@@ -27,6 +27,8 @@ from flax.training import dynamic_scale as dynamic_scale_lib
|
|
27
27
|
from .diffusion_trainer import TrainState, DiffusionTrainer
|
28
28
|
import shutil
|
29
29
|
|
30
|
+
from flaxdiff.metrics.common import EvaluationMetric
|
31
|
+
|
30
32
|
def generate_modelname(
|
31
33
|
dataset_name: str,
|
32
34
|
noise_schedule_name: str,
|
@@ -103,15 +105,6 @@ def generate_modelname(
|
|
103
105
|
# model_name = f"{model_name}-{config_hash}"
|
104
106
|
return model_name
|
105
107
|
|
106
|
-
@dataclass
|
107
|
-
class EvaluationMetric:
|
108
|
-
"""
|
109
|
-
Evaluation metrics for the diffusion model.
|
110
|
-
The function is given generated samples batch [B, H, W, C] and the original batch.
|
111
|
-
"""
|
112
|
-
function: Callable
|
113
|
-
name: str
|
114
|
-
|
115
108
|
class GeneralDiffusionTrainer(DiffusionTrainer):
|
116
109
|
"""
|
117
110
|
General trainer for diffusion models supporting both images and videos.
|
@@ -25,6 +25,9 @@ flaxdiff/inference/pipeline.py
|
|
25
25
|
flaxdiff/inference/utils.py
|
26
26
|
flaxdiff/inputs/__init__.py
|
27
27
|
flaxdiff/inputs/encoders.py
|
28
|
+
flaxdiff/metrics/__init__.py
|
29
|
+
flaxdiff/metrics/common.py
|
30
|
+
flaxdiff/metrics/images.py
|
28
31
|
flaxdiff/metrics/inception.py
|
29
32
|
flaxdiff/metrics/psnr.py
|
30
33
|
flaxdiff/metrics/ssim.py
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|