flaxdiff 0.1.35.5__py3-none-any.whl → 0.1.35.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flaxdiff/models/attention.py

@@ -11,6 +11,7 @@ import einops
 import functools
 import math
 from .common import kernel_init
+import jax.experimental.pallas.ops.tpu.flash_attention

 class EfficientAttention(nn.Module):
 """
flaxdiff/models/simple_unet.py

@@ -50,7 +50,7 @@ class Unet(nn.Module):
 features=self.feature_depths[0],
 kernel_size=(3, 3),
 strides=(1, 1),
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 dtype=self.dtype,
 precision=self.precision
 )(x)
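
This hunk and the ten that follow in flaxdiff/models/simple_unet.py make the same one-line change: the scale passed to self.kernel_init is now given as a keyword argument instead of positionally. A hedged sketch of the pattern, assuming a keyword-only initializer factory roughly like the kernel_init helper imported from flaxdiff.models.common (whose real signature is not shown in this diff):

    import flax.linen as nn
    import jax.numpy as jnp

    # Hypothetical keyword-only factory; not the flaxdiff.models.common source.
    def kernel_init(*, scale: float = 1.0, dtype=jnp.float32):
        # Variance-scaling initializer whose magnitude is set by `scale`.
        return nn.initializers.variance_scaling(
            max(scale, 1e-10), mode="fan_avg",
            distribution="truncated_normal", dtype=dtype)

    # With a keyword-only `scale`, kernel_init(1.0) raises a TypeError while
    # kernel_init(scale=1.0) works, which is the form this diff switches to.
    conv = nn.Conv(features=64, kernel_size=(3, 3),
                   kernel_init=kernel_init(scale=1.0))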
@@ -65,7 +65,7 @@ class Unet(nn.Module):
 down_conv_type,
 name=f"down_{i}_residual_{j}",
 features=dim_in,
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 kernel_size=(3, 3),
 strides=(1, 1),
 activation=self.activation,
@@ -85,7 +85,7 @@ class Unet(nn.Module):
 force_fp32_for_softmax=attention_config.get("force_fp32_for_softmax", False),
 norm_inputs=attention_config.get("norm_inputs", True),
 explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 name=f"down_{i}_attention_{j}")(x, textcontext)
 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
 downs.append(x)
@@ -108,7 +108,7 @@ class Unet(nn.Module):
 middle_conv_type,
 name=f"middle_res1_{j}",
 features=middle_dim_out,
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 kernel_size=(3, 3),
 strides=(1, 1),
 activation=self.activation,
@@ -129,13 +129,13 @@ class Unet(nn.Module):
 force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
 norm_inputs=middle_attention.get("norm_inputs", True),
 explicitly_add_residual=middle_attention.get("explicitly_add_residual", True),
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 name=f"middle_attention_{j}")(x, textcontext)
 x = ResidualBlock(
 middle_conv_type,
 name=f"middle_res2_{j}",
 features=middle_dim_out,
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 kernel_size=(3, 3),
 strides=(1, 1),
 activation=self.activation,
@@ -157,7 +157,7 @@ class Unet(nn.Module):
 up_conv_type,# if j == 0 else "separable",
 name=f"up_{i}_residual_{j}",
 features=dim_out,
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 kernel_size=kernel_size,
 strides=(1, 1),
 activation=self.activation,
@@ -177,7 +177,7 @@ class Unet(nn.Module):
 force_fp32_for_softmax=middle_attention.get("force_fp32_for_softmax", False),
 norm_inputs=attention_config.get("norm_inputs", True),
 explicitly_add_residual=attention_config.get("explicitly_add_residual", True),
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 name=f"up_{i}_attention_{j}")(x, textcontext)
 # print("Upscaling ", i, x.shape)
 if i != len(feature_depths) - 1:
@@ -196,7 +196,7 @@ class Unet(nn.Module):
 features=self.feature_depths[0],
 kernel_size=(3, 3),
 strides=(1, 1),
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 dtype=self.dtype,
 precision=self.precision
 )(x)
@@ -207,7 +207,7 @@ class Unet(nn.Module):
 conv_type,
 name="final_residual",
 features=self.feature_depths[0],
-kernel_init=self.kernel_init(1.0),
+kernel_init=self.kernel_init(scale=1.0),
 kernel_size=(3,3),
 strides=(1, 1),
 activation=self.activation,
@@ -226,7 +226,7 @@ class Unet(nn.Module):
 kernel_size=(3, 3),
 strides=(1, 1),
 # activation=jax.nn.mish
-kernel_init=self.kernel_init(0.0),
+kernel_init=self.kernel_init(scale=0.0),
 dtype=self.dtype,
 precision=self.precision
 )(x)
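
This last Unet hunk is the only one that passes scale=0.0: the output convolution keeps a (near-)zero kernel at initialization, a common diffusion-UNet convention so that the untrained network predicts approximately zero. A minimal illustration of the idea using Flax's built-in zeros initializer, not the flaxdiff code itself:

    import flax.linen as nn

    # Zero-initialized output projection: at initialization this layer
    # contributes nothing, so the model's initial prediction is close to zero.
    final_conv = nn.Conv(features=3, kernel_size=(3, 3),
                         kernel_init=nn.initializers.zeros)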
flaxdiff/models/simple_vit.py

@@ -70,7 +70,7 @@ class UViT(nn.Module):
 kernel_init: Callable = partial(kernel_init, scale=1.0)
 add_residualblock_output: bool = False
 norm_inputs: bool = False
-explicitly_add_residual: bool = False
+explicitly_add_residual: bool = True

 def setup(self):
 if self.norm_groups > 0:
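
The UViT change flips the default of explicitly_add_residual from False to True, matching the True default the Unet attention blocks above already read from attention_config. A hedged illustration of what such a flag usually controls, not the UViT source:

    import flax.linen as nn

    class TinyBlock(nn.Module):
        # `features` is assumed to match the input width so the skip adds cleanly.
        features: int
        explicitly_add_residual: bool = True

        @nn.compact
        def __call__(self, x):
            h = nn.LayerNorm()(x)
            h = nn.Dense(self.features)(h)
            # With the flag set, the skip connection is added at the call site
            # as x + f(x) rather than left to the sub-block to handle.
            return x + h if self.explicitly_add_residual else h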
flaxdiff-0.1.35.5.dist-info/METADATA → flaxdiff-0.1.35.6.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.35.5
+Version: 0.1.35.6
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
flaxdiff-0.1.35.5.dist-info/RECORD → flaxdiff-0.1.35.6.dist-info/RECORD

@@ -3,11 +3,11 @@ flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
 flaxdiff/data/online_loader.py,sha256=DoHrMZCi5gMd9tmkCpZIUU9lGxvfYtuaz58943_lCRc,11315
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
-flaxdiff/models/attention.py,sha256=Oz0-F0jllo3cIqQQJwYbYMrkhVzZWf5bq1UC6_RU1r8,13234
+flaxdiff/models/attention.py,sha256=JvrP7-09MV6IfRLRBhqjPmNUU-lkEMk9TOnJSBKcar8,13289
 flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
-flaxdiff/models/simple_unet.py,sha256=EmTmf3wLIVULGpL1cSmAPhvg7RW4U8Ff3E8Qrtt8RLY,11397
-flaxdiff/models/simple_vit.py,sha256=YQiK5AHMjNxH5DAF6Oy4kjRPqddrvUP4i-CuavASPPI,7930
+flaxdiff/models/simple_unet.py,sha256=L5m2j5580QP7pJ5VIme7U5xYA22PZiGP7qdvcKUnB38,11463
+flaxdiff/models/simple_vit.py,sha256=UCDDr0XVnpf6tbJWKFtEt3_nAqMqOoakXf5amyVWZNo,7929
 flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
 flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=cawm6fZNQoLLATMneAU2gQ9j7kefqHnBPHuaIj3i_a4,18237
-flaxdiff-0.1.35.5.dist-info/METADATA,sha256=J8ikdMf9LIzdAuk6E3E77sI8NPegdgaKf5mrDHhpdXc,22085
-flaxdiff-0.1.35.5.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
-flaxdiff-0.1.35.5.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.35.5.dist-info/RECORD,,
+flaxdiff-0.1.35.6.dist-info/METADATA,sha256=NVCk5V7Zc3iq-nrWTivzO17dQa1fIjYgjJb800ZrZhQ,22085
+flaxdiff-0.1.35.6.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
+flaxdiff-0.1.35.6.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.35.6.dist-info/RECORD,,