flaxdiff 0.1.37.7__py3-none-any.whl → 0.1.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,15 +14,15 @@ class StableDiffusionVAE(AutoEncoder):
14
14
  def __init__(self, modelname = "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16):
15
15
 
16
16
  from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
17
- from diffusers import FlaxStableDiffusionPipeline
17
+ from diffusers import FlaxStableDiffusionPipeline, FlaxAutoencoderKL
18
18
 
19
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
19
+ vae, params = FlaxAutoencoderKL.from_pretrained(
20
20
  modelname,
21
- revision=revision,
21
+ # revision=revision,
22
22
  dtype=dtype,
23
23
  )
24
24
 
25
- vae = pipeline.vae
25
+ # vae = pipeline.vae
26
26
 
27
27
  enc = FlaxEncoder(
28
28
  in_channels=vae.config.in_channels,
@@ -0,0 +1,21 @@
1
+ from flax import linen as nn
2
+ import jax
3
+ import jax.numpy as jnp
4
+
5
+ class BCHWModelWrapper(nn.Module):
6
+ model: nn.Module
7
+
8
+ @nn.compact
9
+ def __call__(self, x, temb, textcontext):
10
+ # Transpose the input from BHWC (channels-last) to BCHW (channels-first) layout
11
+ x = jnp.transpose(x, (0, 3, 1, 2))
12
+ # Pass the input through the UNet model
13
+ out = self.model(
14
+ sample=x,
15
+ timesteps=temb,
16
+ encoder_hidden_states=textcontext,
17
+ )
18
+ # Transpose the model output back from BCHW to BHWC layout
19
+ out = jnp.transpose(out.sample, (0, 2, 3, 1))
20
+ return out
21
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.37.7
3
+ Version: 0.1.38
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -14,11 +14,12 @@ flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,2
14
14
  flaxdiff/models/attention.py,sha256=JvrP7-09MV6IfRLRBhqjPmNUU-lkEMk9TOnJSBKcar8,13289
15
15
  flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
16
16
  flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
17
+ flaxdiff/models/general.py,sha256=7xMME6KVKQY8sScyHYH4f-Kek4j1pRfplKShFXwVZd4,587
17
18
  flaxdiff/models/simple_unet.py,sha256=L5m2j5580QP7pJ5VIme7U5xYA22PZiGP7qdvcKUnB38,11463
18
19
  flaxdiff/models/simple_vit.py,sha256=UCDDr0XVnpf6tbJWKFtEt3_nAqMqOoakXf5amyVWZNo,7929
19
20
  flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
20
21
  flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
21
- flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
22
+ flaxdiff/models/autoencoder/diffusers.py,sha256=DVWT4LRMvEtN36Yt0FTD0KzG8Isq_BvHkNpgDy6Gs40,3651
22
23
  flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
23
24
  flaxdiff/predictors/__init__.py,sha256=A6lMnRSqDLwBdwB0a1x4xPpcXcTYCISWFMkldbiK6Vs,4554
24
25
  flaxdiff/samplers/__init__.py,sha256=EY9v1pgwEoR64Kiz9K8fAR-4_ir9c03mYeY3hrpUNhE,308
@@ -43,7 +44,7 @@ flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo
43
44
  flaxdiff/trainer/diffusion_trainer.py,sha256=kEulMnk6ZkKhQRSVr3UtDdCmXR4cWphJ3XNuk7VIAUY,14189
44
45
  flaxdiff/trainer/simple_trainer.py,sha256=LScHQZCy5ksSC7n0GC0tjOXK-zptxpMJsC6Udf-nz18,22178
45
46
  flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
46
- flaxdiff-0.1.37.7.dist-info/METADATA,sha256=Z4cI3PW0VHzx-zCtBj36U5BjoVhFh1KxbQ5wfOoZPAo,23985
47
- flaxdiff-0.1.37.7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
48
- flaxdiff-0.1.37.7.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
49
- flaxdiff-0.1.37.7.dist-info/RECORD,,
47
+ flaxdiff-0.1.38.dist-info/METADATA,sha256=UdC9L-EG8blWVHclW6ZyhCTQQ8fNda_8JcQKHiyoHN8,23983
48
+ flaxdiff-0.1.38.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
49
+ flaxdiff-0.1.38.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
50
+ flaxdiff-0.1.38.dist-info/RECORD,,