flaxdiff 0.1.23__py3-none-any.whl → 0.1.25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flaxdiff/models/common.py CHANGED
@@ -108,13 +108,13 @@ class FourierEmbedding(nn.Module):
 class TimeProjection(nn.Module):
     features:int
     activation:Callable=jax.nn.gelu
-    kernel_init:Callable=partial(kernel_init, 1.0)
+    kernel_init:Callable=kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x):
-        x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init())(x)
+        x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
         x = self.activation(x)
-        x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init())(x)
+        x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
         x = self.activation(x)
         return x
 
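Note: the recurring change in `flaxdiff/models/common.py` is that the `kernel_init` field default moves from `partial(kernel_init, 1.0)` (a factory the modules had to call) to `kernel_init(1.0)` (the ready-to-use initializer), so every use site passes `self.kernel_init` instead of `self.kernel_init()`. Below is a minimal sketch of the two styles, not flaxdiff code; `scaled_init` is a hypothetical stand-in for flaxdiff's `kernel_init(scale)` helper, here backed by a variance-scaling initializer.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from functools import partial
from typing import Callable

def scaled_init(scale: float) -> Callable:
    # Hypothetical stand-in for flaxdiff's kernel_init(scale):
    # returns an initializer function (rng, shape, dtype) -> Array.
    return nn.initializers.variance_scaling(scale, "fan_avg", "truncated_normal")

class OldStyle(nn.Module):
    # 0.1.23 style: the field stores a factory, so use sites call it first.
    features: int = 8
    kernel_init: Callable = partial(scaled_init, 1.0)

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features, kernel_init=self.kernel_init())(x)

class NewStyle(nn.Module):
    # 0.1.25 style: the field stores the initializer itself, passed through as-is.
    features: int = 8
    kernel_init: Callable = scaled_init(1.0)

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features, kernel_init=self.kernel_init)(x)

x = jnp.ones((1, 4))
params_old = OldStyle().init(jax.random.PRNGKey(0), x)
params_new = NewStyle().init(jax.random.PRNGKey(0), x)
```

The two are functionally equivalent here; the new style simply stores the initializer on the module, which is what the updated call sites in this file assume.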
@@ -123,7 +123,7 @@ class SeparableConv(nn.Module):
     kernel_size:tuple=(3, 3)
     strides:tuple=(1, 1)
     use_bias:bool=False
-    kernel_init:Callable=partial(kernel_init, 1.0)
+    kernel_init:Callable=kernel_init(1.0)
     padding:str="SAME"
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
@@ -133,7 +133,7 @@ class SeparableConv(nn.Module):
         in_features = x.shape[-1]
         depthwise = nn.Conv(
             features=in_features, kernel_size=self.kernel_size,
-            strides=self.strides, kernel_init=self.kernel_init(),
+            strides=self.strides, kernel_init=self.kernel_init,
             feature_group_count=in_features, use_bias=self.use_bias,
             padding=self.padding,
             dtype=self.dtype,
@@ -141,7 +141,7 @@ class SeparableConv(nn.Module):
         )(x)
         pointwise = nn.Conv(
             features=self.features, kernel_size=(1, 1),
-            strides=(1, 1), kernel_init=self.kernel_init(),
+            strides=(1, 1), kernel_init=self.kernel_init,
             use_bias=self.use_bias,
             dtype=self.dtype,
             precision=self.precision
@@ -153,7 +153,7 @@ class ConvLayer(nn.Module):
     features:int
     kernel_size:tuple=(3, 3)
     strides:tuple=(1, 1)
-    kernel_init:Callable=partial(kernel_init, 1.0)
+    kernel_init:Callable=kernel_init(1.0)
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
 
@@ -164,7 +164,7 @@ class ConvLayer(nn.Module):
                 features=self.features,
                 kernel_size=self.kernel_size,
                 strides=self.strides,
-                kernel_init=self.kernel_init(),
+                kernel_init=self.kernel_init,
                 dtype=self.dtype,
                 precision=self.precision
             )
@@ -183,7 +183,7 @@ class ConvLayer(nn.Module):
                 features=self.features,
                 kernel_size=self.kernel_size,
                 strides=self.strides,
-                kernel_init=self.kernel_init(),
+                kernel_init=self.kernel_init,
                 dtype=self.dtype,
                 precision=self.precision
             )
@@ -192,7 +192,7 @@ class ConvLayer(nn.Module):
                 features=self.features,
                 kernel_size=self.kernel_size,
                 strides=self.strides,
-                kernel_init=self.kernel_init(),
+                kernel_init=self.kernel_init,
                 dtype=self.dtype,
                 precision=self.precision
             )
@@ -206,7 +206,7 @@ class Upsample(nn.Module):
     activation:Callable=jax.nn.swish
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
-    kernel_init:Callable=partial(kernel_init, 1.0)
+    kernel_init:Callable=kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x, residual=None):
@@ -221,7 +221,7 @@ class Upsample(nn.Module):
             strides=(1, 1),
             dtype=self.dtype,
             precision=self.precision,
-            kernel_init=self.kernel_init()
+            kernel_init=self.kernel_init
         )(out)
         if residual is not None:
             out = jnp.concatenate([out, residual], axis=-1)
@@ -233,7 +233,7 @@ class Downsample(nn.Module):
     activation:Callable=jax.nn.swish
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
-    kernel_init:Callable=partial(kernel_init, 1.0)
+    kernel_init:Callable=kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x, residual=None):
@@ -244,7 +244,7 @@ class Downsample(nn.Module):
             strides=(2, 2),
             dtype=self.dtype,
             precision=self.precision,
-            kernel_init=self.kernel_init()
+            kernel_init=self.kernel_init
         )(x)
         if residual is not None:
             if residual.shape[1] > out.shape[1]:
@@ -269,7 +269,7 @@ class ResidualBlock(nn.Module):
     direction:str=None
     res:int=2
     norm_groups:int=8
-    kernel_init:Callable=partial(kernel_init, 1.0)
+    kernel_init:Callable=kernel_init(1.0)
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
     named_norms:bool=False
@@ -296,7 +296,7 @@ class ResidualBlock(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init(),
+            kernel_init=self.kernel_init,
             name="conv1",
             dtype=self.dtype,
             precision=self.precision
@@ -321,7 +321,7 @@ class ResidualBlock(nn.Module):
             features=self.features,
             kernel_size=self.kernel_size,
             strides=self.strides,
-            kernel_init=self.kernel_init(),
+            kernel_init=self.kernel_init,
             name="conv2",
             dtype=self.dtype,
             precision=self.precision
@@ -333,7 +333,7 @@ class ResidualBlock(nn.Module):
             features=self.features,
             kernel_size=(1, 1),
             strides=1,
-            kernel_init=self.kernel_init(),
+            kernel_init=self.kernel_init,
             name="residual_conv",
             dtype=self.dtype,
             precision=self.precision
flaxdiff/models/simple_vit.py CHANGED
@@ -58,6 +58,9 @@ class UViT(nn.Module):
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
     use_projection: bool = False
+    use_flash_attention: bool = False
+    use_self_and_cross: bool = False
+    force_fp32_for_softmax: bool = True
     activation:Callable = jax.nn.swish
     norm_groups:int=8
     dtype: Optional[Dtype] = None
@@ -102,7 +105,7 @@ class UViT(nn.Module):
         for i in range(self.num_layers // 2):
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True,
+                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
                                  kernel_init=self.kernel_init())(x)
             skips.append(x)
@@ -110,7 +113,7 @@ class UViT(nn.Module):
         # Middle block
         x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                              dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                             use_flash_attention=False, use_self_and_cross=True, force_fp32_for_softmax=True,
+                             use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                              only_pure_attention=False,
                              kernel_init=self.kernel_init())(x)
 
@@ -121,7 +124,7 @@ class UViT(nn.Module):
                                  dtype=self.dtype, precision=self.precision)(skip)
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=False, use_self_and_cross=False, force_fp32_for_softmax=True,
+                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
                                  kernel_init=self.kernel_init())(skip)
 
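The three new `UViT` fields promote arguments that were previously hard-coded at every `TransformerBlock` call site into module-level configuration. A hedged usage sketch, assuming `UViT` is exported from `flaxdiff/models/simple_vit.py` (as the RECORD change suggests) and that its remaining constructor arguments have defaults:

```python
# Hypothetical usage; only the new fields are shown, all other arguments are omitted.
from flaxdiff.models.simple_vit import UViT

model = UViT(
    use_flash_attention=False,    # new in 0.1.25; previously hard-coded to False
    use_self_and_cross=False,     # new in 0.1.25; see note below
    force_fp32_for_softmax=True,  # new in 0.1.25; previously hard-coded to True
)
```

Note that one shared flag now drives all three call sites: the middle block previously used `use_self_and_cross=True` while the down and up blocks used `False`, so with the new default of `False` the middle block's behaviour changes unless the flag is set explicitly.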
flaxdiff-0.1.23.dist-info/METADATA → flaxdiff-0.1.25.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.23
+Version: 0.1.25
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
flaxdiff-0.1.23.dist-info/RECORD → flaxdiff-0.1.25.dist-info/RECORD RENAMED
@@ -4,10 +4,10 @@ flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
 flaxdiff/data/online_loader.py,sha256=LIK_O1C3yDPvvAEOWvsJrVeBopVqjg2IOMTbiSIvH6M,11025
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
-flaxdiff/models/common.py,sha256=lzfOHB-7Bjx83ZZzywXy2mjhwP2UOMKR11vdVSKvsCo,11068
+flaxdiff/models/common.py,sha256=fw_gP7PZayO6RVe6xSf-7FtVq-S0pp5U6NgHg4PlKO8,10990
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
 flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
-flaxdiff/models/simple_vit.py,sha256=tXO0WIozq2C4j2GqphSM9mMAYEwj9fKr7rfm4G6vf4A,6403
+flaxdiff/models/simple_vit.py,sha256=g94RchoccNOELCMqAp9hkt290I3_Jg-GWU6Q3RLtQZs,6699
 flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
 flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.23.dist-info/METADATA,sha256=iTgk4DY-kuALF-Y1U-a433c7pYRhSy-bXPQrtXh6d54,22083
-flaxdiff-0.1.23.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
-flaxdiff-0.1.23.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.23.dist-info/RECORD,,
+flaxdiff-0.1.25.dist-info/METADATA,sha256=DaJHzXya9jzJiiiBF4mzwb0FXx_M0DssZbMQuc-RVsI,22083
+flaxdiff-0.1.25.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
+flaxdiff-0.1.25.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.25.dist-info/RECORD,,
flaxdiff-0.1.23.dist-info/WHEEL → flaxdiff-0.1.25.dist-info/WHEEL RENAMED
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: setuptools (72.2.0)
+Generator: setuptools (74.1.2)
 Root-Is-Purelib: true
 Tag: py3-none-any