flaxdiff 0.1.37.7__tar.gz → 0.1.38__tar.gz

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
Files changed (55)
  1. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/PKG-INFO +1 -1
  2. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/models/autoencoder/diffusers.py +4 -4
  3. flaxdiff-0.1.38/flaxdiff/models/general.py +21 -0
  4. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff.egg-info/PKG-INFO +1 -1
  5. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff.egg-info/SOURCES.txt +1 -0
  6. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/pyproject.toml +1 -1
  7. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/README.md +0 -0
  8. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/__init__.py +0 -0
  9. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/data/__init__.py +0 -0
  10. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/data/dataset_map.py +0 -0
  11. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/data/datasets.py +0 -0
  12. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/data/online_loader.py +0 -0
  13. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/data/sources/gcs.py +0 -0
  14. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/data/sources/tfds.py +0 -0
  15. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/metrics/inception.py +0 -0
  16. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/metrics/psnr.py +0 -0
  17. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/metrics/ssim.py +0 -0
  18. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/metrics/utils.py +0 -0
  19. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/models/__init__.py +0 -0
  20. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/models/attention.py +0 -0
  21. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/models/autoencoder/__init__.py +0 -0
  22. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  23. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  24. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/models/common.py +0 -0
  25. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/models/favor_fastattn.py +0 -0
  26. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/models/simple_unet.py +0 -0
  27. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/models/simple_vit.py +0 -0
  28. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/predictors/__init__.py +0 -0
  29. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/samplers/__init__.py +0 -0
  30. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/samplers/common.py +0 -0
  31. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/samplers/ddim.py +0 -0
  32. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/samplers/ddpm.py +0 -0
  33. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/samplers/euler.py +0 -0
  34. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/samplers/heun_sampler.py +0 -0
  35. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/samplers/multistep_dpm.py +0 -0
  36. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/samplers/rk4_sampler.py +0 -0
  37. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/schedulers/__init__.py +0 -0
  38. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/schedulers/common.py +0 -0
  39. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/schedulers/continuous.py +0 -0
  40. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/schedulers/cosine.py +0 -0
  41. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/schedulers/discrete.py +0 -0
  42. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/schedulers/exp.py +0 -0
  43. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/schedulers/karras.py +0 -0
  44. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/schedulers/linear.py +0 -0
  45. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/schedulers/sqrt.py +0 -0
  46. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/trainer/__init__.py +0 -0
  47. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  48. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  49. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/trainer/simple_trainer.py +0 -0
  50. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
  51. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff/utils.py +0 -0
  52. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff.egg-info/dependency_links.txt +0 -0
  53. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff.egg-info/requires.txt +0 -0
  54. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/flaxdiff.egg-info/top_level.txt +0 -0
  55. {flaxdiff-0.1.37.7 → flaxdiff-0.1.38}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.37.7
3
+ Version: 0.1.38
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -14,15 +14,15 @@ class StableDiffusionVAE(AutoEncoder):
14
14
  def __init__(self, modelname = "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16):
15
15
 
16
16
  from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
17
- from diffusers import FlaxStableDiffusionPipeline
17
+ from diffusers import FlaxStableDiffusionPipeline, FlaxAutoencoderKL
18
18
 
19
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
19
+ vae, params = FlaxAutoencoderKL.from_pretrained(
20
20
  modelname,
21
- revision=revision,
21
+ # revision=revision,
22
22
  dtype=dtype,
23
23
  )
24
24
 
25
- vae = pipeline.vae
25
+ # vae = pipeline.vae
26
26
 
27
27
  enc = FlaxEncoder(
28
28
  in_channels=vae.config.in_channels,
@@ -0,0 +1,21 @@
1
+ from flax import linen as nn
2
+ import jax
3
+ import jax.numpy as jnp
4
+
5
+ class BCHWModelWrapper(nn.Module):
6
+ model: nn.Module
7
+
8
+ @nn.compact
9
+ def __call__(self, x, temb, textcontext):
10
+ # Reshape the input to BCHW format from BHWC
11
+ x = jnp.transpose(x, (0, 3, 1, 2))
12
+ # Pass the input through the UNet model
13
+ out = self.model(
14
+ sample=x,
15
+ timesteps=temb,
16
+ encoder_hidden_states=textcontext,
17
+ )
18
+ # Reshape the output back to BHWC format
19
+ out = jnp.transpose(out.sample, (0, 2, 3, 1))
20
+ return out
21
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.37.7
3
+ Version: 0.1.38
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -21,6 +21,7 @@ flaxdiff/models/__init__.py
21
21
  flaxdiff/models/attention.py
22
22
  flaxdiff/models/common.py
23
23
  flaxdiff/models/favor_fastattn.py
24
+ flaxdiff/models/general.py
24
25
  flaxdiff/models/simple_unet.py
25
26
  flaxdiff/models/simple_vit.py
26
27
  flaxdiff/models/autoencoder/__init__.py
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "flaxdiff"
7
- version = "0.1.37.7"
7
+ version = "0.1.38"
8
8
  description = "A versatile and easy to understand Diffusion library"
9
9
  readme = "README.md"
10
10
  authors = [
File without changes
File without changes
File without changes