flaxdiff 0.1.24.tar.gz → 0.1.26.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/PKG-INFO +1 -1
  2. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/common.py +18 -18
  3. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/simple_vit.py +18 -8
  4. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff.egg-info/PKG-INFO +1 -1
  5. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/setup.py +1 -1
  6. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/README.md +0 -0
  7. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/__init__.py +0 -0
  8. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/data/__init__.py +0 -0
  9. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/data/online_loader.py +0 -0
  10. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/__init__.py +0 -0
  11. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/attention.py +0 -0
  12. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/autoencoder/__init__.py +0 -0
  13. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  14. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  15. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  16. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/favor_fastattn.py +0 -0
  17. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/simple_unet.py +0 -0
  18. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/predictors/__init__.py +0 -0
  19. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/__init__.py +0 -0
  20. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/common.py +0 -0
  21. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/ddim.py +0 -0
  22. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/ddpm.py +0 -0
  23. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/euler.py +0 -0
  24. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/heun_sampler.py +0 -0
  25. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/multistep_dpm.py +0 -0
  26. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/samplers/rk4_sampler.py +0 -0
  27. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/__init__.py +0 -0
  28. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/common.py +0 -0
  29. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/continuous.py +0 -0
  30. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/cosine.py +0 -0
  31. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/discrete.py +0 -0
  32. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/exp.py +0 -0
  33. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/karras.py +0 -0
  34. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/linear.py +0 -0
  35. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/schedulers/sqrt.py +0 -0
  36. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/trainer/__init__.py +0 -0
  37. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  38. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  39. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/trainer/simple_trainer.py +0 -0
  40. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/utils.py +0 -0
  41. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff.egg-info/SOURCES.txt +0 -0
  42. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff.egg-info/dependency_links.txt +0 -0
  43. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff.egg-info/requires.txt +0 -0
  44. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff.egg-info/top_level.txt +0 -0
  45. {flaxdiff-0.1.24 → flaxdiff-0.1.26}/setup.cfg +0 -0
{flaxdiff-0.1.24 → flaxdiff-0.1.26}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: flaxdiff
- Version: 0.1.24
+ Version: 0.1.26
  Summary: A versatile and easy to understand Diffusion library
  Author: Ashish Kumar Singh
  Author-email: ashishkmr472@gmail.com
{flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/common.py
@@ -108,13 +108,13 @@ class FourierEmbedding(nn.Module):
  class TimeProjection(nn.Module):
  features:int
  activation:Callable=jax.nn.gelu
- kernel_init:Callable=partial(kernel_init, 1.0)
+ kernel_init:Callable=kernel_init(1.0)

  @nn.compact
  def __call__(self, x):
- x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init())(x)
+ x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
  x = self.activation(x)
- x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init())(x)
+ x = nn.DenseGeneral(self.features, kernel_init=self.kernel_init)(x)
  x = self.activation(x)
  return x

@@ -123,7 +123,7 @@ class SeparableConv(nn.Module):
  kernel_size:tuple=(3, 3)
  strides:tuple=(1, 1)
  use_bias:bool=False
- kernel_init:Callable=partial(kernel_init, 1.0)
+ kernel_init:Callable=kernel_init(1.0)
  padding:str="SAME"
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None
@@ -133,7 +133,7 @@ class SeparableConv(nn.Module):
  in_features = x.shape[-1]
  depthwise = nn.Conv(
  features=in_features, kernel_size=self.kernel_size,
- strides=self.strides, kernel_init=self.kernel_init(),
+ strides=self.strides, kernel_init=self.kernel_init,
  feature_group_count=in_features, use_bias=self.use_bias,
  padding=self.padding,
  dtype=self.dtype,
@@ -141,7 +141,7 @@ class SeparableConv(nn.Module):
  )(x)
  pointwise = nn.Conv(
  features=self.features, kernel_size=(1, 1),
- strides=(1, 1), kernel_init=self.kernel_init(),
+ strides=(1, 1), kernel_init=self.kernel_init,
  use_bias=self.use_bias,
  dtype=self.dtype,
  precision=self.precision
@@ -153,7 +153,7 @@ class ConvLayer(nn.Module):
  features:int
  kernel_size:tuple=(3, 3)
  strides:tuple=(1, 1)
- kernel_init:Callable=partial(kernel_init, 1.0)
+ kernel_init:Callable=kernel_init(1.0)
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None

@@ -164,7 +164,7 @@ class ConvLayer(nn.Module):
  features=self.features,
  kernel_size=self.kernel_size,
  strides=self.strides,
- kernel_init=self.kernel_init(),
+ kernel_init=self.kernel_init,
  dtype=self.dtype,
  precision=self.precision
  )
@@ -183,7 +183,7 @@ class ConvLayer(nn.Module):
  features=self.features,
  kernel_size=self.kernel_size,
  strides=self.strides,
- kernel_init=self.kernel_init(),
+ kernel_init=self.kernel_init,
  dtype=self.dtype,
  precision=self.precision
  )
@@ -192,7 +192,7 @@ class ConvLayer(nn.Module):
  features=self.features,
  kernel_size=self.kernel_size,
  strides=self.strides,
- kernel_init=self.kernel_init(),
+ kernel_init=self.kernel_init,
  dtype=self.dtype,
  precision=self.precision
  )
@@ -206,7 +206,7 @@ class Upsample(nn.Module):
  activation:Callable=jax.nn.swish
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None
- kernel_init:Callable=partial(kernel_init, 1.0)
+ kernel_init:Callable=kernel_init(1.0)

  @nn.compact
  def __call__(self, x, residual=None):
@@ -221,7 +221,7 @@ class Upsample(nn.Module):
  strides=(1, 1),
  dtype=self.dtype,
  precision=self.precision,
- kernel_init=self.kernel_init()
+ kernel_init=self.kernel_init
  )(out)
  if residual is not None:
  out = jnp.concatenate([out, residual], axis=-1)
@@ -233,7 +233,7 @@ class Downsample(nn.Module):
  activation:Callable=jax.nn.swish
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None
- kernel_init:Callable=partial(kernel_init, 1.0)
+ kernel_init:Callable=kernel_init(1.0)

  @nn.compact
  def __call__(self, x, residual=None):
@@ -244,7 +244,7 @@ class Downsample(nn.Module):
  strides=(2, 2),
  dtype=self.dtype,
  precision=self.precision,
- kernel_init=self.kernel_init()
+ kernel_init=self.kernel_init
  )(x)
  if residual is not None:
  if residual.shape[1] > out.shape[1]:
@@ -269,7 +269,7 @@ class ResidualBlock(nn.Module):
  direction:str=None
  res:int=2
  norm_groups:int=8
- kernel_init:Callable=partial(kernel_init, 1.0)
+ kernel_init:Callable=kernel_init(1.0)
  dtype: Optional[Dtype] = None
  precision: PrecisionLike = None
  named_norms:bool=False
@@ -296,7 +296,7 @@ class ResidualBlock(nn.Module):
  features=self.features,
  kernel_size=self.kernel_size,
  strides=self.strides,
- kernel_init=self.kernel_init(),
+ kernel_init=self.kernel_init,
  name="conv1",
  dtype=self.dtype,
  precision=self.precision
@@ -321,7 +321,7 @@ class ResidualBlock(nn.Module):
  features=self.features,
  kernel_size=self.kernel_size,
  strides=self.strides,
- kernel_init=self.kernel_init(),
+ kernel_init=self.kernel_init,
  name="conv2",
  dtype=self.dtype,
  precision=self.precision
@@ -333,7 +333,7 @@ class ResidualBlock(nn.Module):
  features=self.features,
  kernel_size=(1, 1),
  strides=1,
- kernel_init=self.kernel_init(),
+ kernel_init=self.kernel_init,
  name="residual_conv",
  dtype=self.dtype,
  precision=self.precision
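Note: the recurring change in flaxdiff/models/common.py above is a single pattern: the kernel_init field default moves from a factory (partial(kernel_init, 1.0), invoked as self.kernel_init() at every call site) to the initializer itself (kernel_init(1.0), passed through as self.kernel_init). A minimal sketch of the new pattern, assuming kernel_init(scale) returns a Flax-style initializer; the helper below is a hypothetical stand-in, not flaxdiff's exact implementation:

    import flax.linen as nn
    from typing import Callable

    def kernel_init(scale: float) -> Callable:
        # hypothetical stand-in for flaxdiff's kernel_init helper
        return nn.initializers.variance_scaling(scale, "fan_in", "truncated_normal")

    class ConvLayer(nn.Module):
        features: int
        # 0.1.24 default: partial(kernel_init, 1.0) -- a factory, called as self.kernel_init()
        # 0.1.26 default: the initializer itself, passed along unchanged
        kernel_init: Callable = kernel_init(1.0)

        @nn.compact
        def __call__(self, x):
            return nn.Conv(self.features, kernel_size=(3, 3),
                           kernel_init=self.kernel_init)(x)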
{flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff/models/simple_vit.py
@@ -23,6 +23,7 @@ class PatchEmbedding(nn.Module):
  embedding_dim: int
  dtype: Any = jnp.float32
  precision: Any = jax.lax.Precision.HIGH
+ kernel_init: Callable = partial(kernel_init, 1.0)

  @nn.compact
  def __call__(self, x):
@@ -33,6 +34,7 @@ class PatchEmbedding(nn.Module):
  kernel_size=(self.patch_size, self.patch_size),
  strides=(self.patch_size, self.patch_size),
  dtype=self.dtype,
+ kernel_init=self.kernel_init(),
  precision=self.precision)(x)
  x = jnp.reshape(x, (batch, -1, self.embedding_dim))
  return x
@@ -96,7 +98,7 @@ class UViT(nn.Module):
  # print(f'Shape of x after time embedding: {x.shape}')

  # Add positional encoding
- x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features)(x)
+ x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features, kernel_init=self.kernel_init)(x)

  # print(f'Shape of x after positional encoding: {x.shape}')

@@ -113,20 +115,20 @@ class UViT(nn.Module):
  # Middle block
  x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
- use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.fforce_fp32_for_softmax,
+ use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
  only_pure_attention=False,
  kernel_init=self.kernel_init())(x)

  # # Out blocks
  for i in range(self.num_layers // 2):
- skip = jnp.concatenate([x, skips.pop()], axis=-1)
- skip = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
- dtype=self.dtype, precision=self.precision)(skip)
+ x = jnp.concatenate([x, skips.pop()], axis=-1)
+ x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
+ dtype=self.dtype, precision=self.precision)(x)
  x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
- use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.fforce_fp32_for_softmax,
+ use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
  only_pure_attention=False,
- kernel_init=self.kernel_init())(skip)
+ kernel_init=self.kernel_init())(x)

  # print(f'Shape of x after transformer blocks: {x.shape}')
  x = self.norm()(x)
@@ -139,6 +141,14 @@ class UViT(nn.Module):
  x = x[:, 1 + num_text_tokens:, :]
  x = unpatchify(x, channels=self.output_channels)
  # print(f'Shape of x after final dense layer: {x.shape}')
- x = nn.Dense(features=self.output_channels, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
+ x = nn.Conv(
+ features=self.output_channels,
+ kernel_size=(3, 3),
+ strides=(1, 1),
+ padding='SAME',
+ dtype=self.dtype,
+ precision=self.precision,
+ kernel_init=kernel_init(0.0),
+ )(x)

  return x
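Note: the last simple_vit.py hunk above replaces the UViT output head: a per-token nn.Dense projection becomes a 3x3 nn.Conv applied after unpatchify, initialized with kernel_init(0.0) so the head starts from a zero output. A minimal Flax sketch of an equivalent zero-initialized head (nn.initializers.zeros is assumed here to match the effect of kernel_init(0.0)):

    import flax.linen as nn

    class OutputHead(nn.Module):
        output_channels: int = 3

        @nn.compact
        def __call__(self, x):
            # zero-initialized 3x3 conv over the unpatchified feature map
            return nn.Conv(features=self.output_channels, kernel_size=(3, 3),
                           strides=(1, 1), padding='SAME',
                           kernel_init=nn.initializers.zeros)(x)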
{flaxdiff-0.1.24 → flaxdiff-0.1.26}/flaxdiff.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: flaxdiff
- Version: 0.1.24
+ Version: 0.1.26
  Summary: A versatile and easy to understand Diffusion library
  Author: Ashish Kumar Singh
  Author-email: ashishkmr472@gmail.com
{flaxdiff-0.1.24 → flaxdiff-0.1.26}/setup.py
@@ -11,7 +11,7 @@ required_packages=[
  setup(
  name='flaxdiff',
  packages=find_packages(),
- version='0.1.24',
+ version='0.1.26',
  description='A versatile and easy to understand Diffusion library',
  long_description=open('README.md').read(),
  long_description_content_type='text/markdown',