flaxdiff 0.1.37.6__tar.gz → 0.1.38__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/PKG-INFO +1 -1
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/models/autoencoder/diffusers.py +4 -4
- flaxdiff-0.1.38/flaxdiff/models/general.py +21 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/samplers/common.py +1 -1
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff.egg-info/SOURCES.txt +1 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/pyproject.toml +1 -1
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/README.md +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/data/dataset_map.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/data/datasets.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/data/online_loader.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/data/sources/gcs.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/data/sources/tfds.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/metrics/inception.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/metrics/psnr.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/metrics/ssim.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/metrics/utils.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/models/common.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/trainer/diffusion_trainer.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.37.6 → flaxdiff-0.1.38}/setup.cfg +0 -0
@@ -14,15 +14,15 @@ class StableDiffusionVAE(AutoEncoder):
|
|
14
14
|
def __init__(self, modelname = "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16):
|
15
15
|
|
16
16
|
from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
|
17
|
-
from diffusers import FlaxStableDiffusionPipeline
|
17
|
+
from diffusers import FlaxStableDiffusionPipeline, FlaxAutoencoderKL
|
18
18
|
|
19
|
-
|
19
|
+
vae, params = FlaxAutoencoderKL.from_pretrained(
|
20
20
|
modelname,
|
21
|
-
revision=revision,
|
21
|
+
# revision=revision,
|
22
22
|
dtype=dtype,
|
23
23
|
)
|
24
24
|
|
25
|
-
vae = pipeline.vae
|
25
|
+
# vae = pipeline.vae
|
26
26
|
|
27
27
|
enc = FlaxEncoder(
|
28
28
|
in_channels=vae.config.in_channels,
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from flax import linen as nn
|
2
|
+
import jax
|
3
|
+
import jax.numpy as jnp
|
4
|
+
|
5
|
+
class BCHWModelWrapper(nn.Module):
    """Adapt a channels-first (BCHW) model to a channels-last (BHWC) interface.

    The input is transposed from BHWC to BCHW, passed to ``model`` using the
    keyword arguments ``sample``, ``timesteps`` and ``encoder_hidden_states``
    (the argument names used by diffusers' Flax UNet models), and the result
    is transposed back to BHWC.

    Attributes:
        model: the channels-first module being wrapped.
    """

    # The wrapped channels-first module; it is called with keyword arguments
    # sample / timesteps / encoder_hidden_states.
    model: nn.Module

    @nn.compact
    def __call__(self, x, temb, textcontext):
        """Run the wrapped model on a BHWC input and return a BHWC output.

        Args:
            x: 4-D input in BHWC layout; it is transposed to BCHW before the call.
            temb: forwarded unchanged as the wrapped model's ``timesteps``
                argument. NOTE(review): despite the name, this is presumably
                the raw diffusion timestep(s), not an embedding — confirm
                against the wrapped model.
            textcontext: forwarded unchanged as ``encoder_hidden_states``.

        Returns:
            The model's prediction, transposed from BCHW back to BHWC.
        """
        # Transpose BHWC -> BCHW: the wrapped model expects channels-first.
        x = jnp.transpose(x, (0, 3, 1, 2))
        out = self.model(
            sample=x,
            timesteps=temb,
            encoder_hidden_states=textcontext,
        )
        # diffusers models return an output object exposing ``.sample``. Also
        # accept a ``(sample, ...)`` tuple/list or a bare array, so the
        # wrapper does not crash on models without that convention.
        if hasattr(out, "sample"):
            out = out.sample
        elif isinstance(out, (tuple, list)):
            out = out[0]
        # Transpose BCHW -> BHWC to restore the caller's layout.
        return jnp.transpose(out, (0, 2, 3, 1))
|
21
|
+
|
@@ -21,6 +21,7 @@ flaxdiff/models/__init__.py
|
|
21
21
|
flaxdiff/models/attention.py
|
22
22
|
flaxdiff/models/common.py
|
23
23
|
flaxdiff/models/favor_fastattn.py
|
24
|
+
flaxdiff/models/general.py
|
24
25
|
flaxdiff/models/simple_unet.py
|
25
26
|
flaxdiff/models/simple_vit.py
|
26
27
|
flaxdiff/models/autoencoder/__init__.py
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|