flaxdiff 0.1.37.6__py3-none-any.whl → 0.1.38__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -14,15 +14,15 @@ class StableDiffusionVAE(AutoEncoder):
14
14
  def __init__(self, modelname = "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16):
15
15
 
16
16
  from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
17
- from diffusers import FlaxStableDiffusionPipeline
17
+ from diffusers import FlaxStableDiffusionPipeline, FlaxAutoencoderKL
18
18
 
19
- pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
19
+ vae, params = FlaxAutoencoderKL.from_pretrained(
20
20
  modelname,
21
- revision=revision,
21
+ # revision=revision,
22
22
  dtype=dtype,
23
23
  )
24
24
 
25
- vae = pipeline.vae
25
+ # vae = pipeline.vae
26
26
 
27
27
  enc = FlaxEncoder(
28
28
  in_channels=vae.config.in_channels,
@@ -0,0 +1,21 @@
1
+ from flax import linen as nn
2
+ import jax
3
+ import jax.numpy as jnp
4
+
5
+ class BCHWModelWrapper(nn.Module):
6
+ model: nn.Module
7
+
8
+ @nn.compact
9
+ def __call__(self, x, temb, textcontext):
10
+ # Reshape the input to BCHW format from BHWC
11
+ x = jnp.transpose(x, (0, 3, 1, 2))
12
+ # Pass the input through the UNet model
13
+ out = self.model(
14
+ sample=x,
15
+ timesteps=temb,
16
+ encoder_hidden_states=textcontext,
17
+ )
18
+ # Reshape the output back to BHWC format
19
+ out = jnp.transpose(out.sample, (0, 2, 3, 1))
20
+ return out
21
+
@@ -133,7 +133,7 @@ class DiffusionSampler():
133
133
 
134
134
  params = params if params is not None else self.params
135
135
 
136
- @jax.jit
136
+ # @jax.jit
137
137
  def sample_model_fn(x_t, t, *additional_inputs):
138
138
  return self.sample_model(params, x_t, t, *additional_inputs)
139
139
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.37.6
3
+ Version: 0.1.38
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
@@ -14,15 +14,16 @@ flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,2
14
14
  flaxdiff/models/attention.py,sha256=JvrP7-09MV6IfRLRBhqjPmNUU-lkEMk9TOnJSBKcar8,13289
15
15
  flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
16
16
  flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
17
+ flaxdiff/models/general.py,sha256=7xMME6KVKQY8sScyHYH4f-Kek4j1pRfplKShFXwVZd4,587
17
18
  flaxdiff/models/simple_unet.py,sha256=L5m2j5580QP7pJ5VIme7U5xYA22PZiGP7qdvcKUnB38,11463
18
19
  flaxdiff/models/simple_vit.py,sha256=UCDDr0XVnpf6tbJWKFtEt3_nAqMqOoakXf5amyVWZNo,7929
19
20
  flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
20
21
  flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
21
- flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
22
+ flaxdiff/models/autoencoder/diffusers.py,sha256=DVWT4LRMvEtN36Yt0FTD0KzG8Isq_BvHkNpgDy6Gs40,3651
22
23
  flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
23
24
  flaxdiff/predictors/__init__.py,sha256=A6lMnRSqDLwBdwB0a1x4xPpcXcTYCISWFMkldbiK6Vs,4554
24
25
  flaxdiff/samplers/__init__.py,sha256=EY9v1pgwEoR64Kiz9K8fAR-4_ir9c03mYeY3hrpUNhE,308
25
- flaxdiff/samplers/common.py,sha256=wn8tryC3B0KE0V98zMiH_X2x-Tc1NbM5iV27hn5p8Aw,8846
26
+ flaxdiff/samplers/common.py,sha256=wkzalSYrnsq6oUsevEeRCVfzqwk8qfwvggAlgNTqK-o,8848
26
27
  flaxdiff/samplers/ddim.py,sha256=hTjDm0SmIj-Tkc80QRATMcN_sKVhHbqZQboRQCAn4mY,569
27
28
  flaxdiff/samplers/ddpm.py,sha256=JgkNSo7fp7Jm-8rCy4eu5m4YIzXTWzxv-iHf3EQ0z5w,2243
28
29
  flaxdiff/samplers/euler.py,sha256=QSkttB4DYnepDGwhWq3EGXYjMAqj4qLOdh7u98HttzY,2791
@@ -43,7 +44,7 @@ flaxdiff/trainer/autoencoder_trainer.py,sha256=hxihkRL9WCIQVGOP-pc1jjjIUaRXDLcNo
43
44
  flaxdiff/trainer/diffusion_trainer.py,sha256=kEulMnk6ZkKhQRSVr3UtDdCmXR4cWphJ3XNuk7VIAUY,14189
44
45
  flaxdiff/trainer/simple_trainer.py,sha256=LScHQZCy5ksSC7n0GC0tjOXK-zptxpMJsC6Udf-nz18,22178
45
46
  flaxdiff/trainer/video_diffusion_trainer.py,sha256=gMkKpnKNTo8QhTx5ptEEkc7W5-7rzXIr9queU53hXyQ,2197
46
- flaxdiff-0.1.37.6.dist-info/METADATA,sha256=SujaCKk29ECrfSEIdchYvAl-nf0L270t2of7oeX5kgk,23985
47
- flaxdiff-0.1.37.6.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
48
- flaxdiff-0.1.37.6.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
49
- flaxdiff-0.1.37.6.dist-info/RECORD,,
47
+ flaxdiff-0.1.38.dist-info/METADATA,sha256=UdC9L-EG8blWVHclW6ZyhCTQQ8fNda_8JcQKHiyoHN8,23983
48
+ flaxdiff-0.1.38.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
49
+ flaxdiff-0.1.38.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
50
+ flaxdiff-0.1.38.dist-info/RECORD,,