flaxdiff-0.1.25-py3-none-any.whl → flaxdiff-0.1.27-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/flaxdiff/models/simple_vit.py
+++ b/flaxdiff/models/simple_vit.py
@@ -23,6 +23,7 @@ class PatchEmbedding(nn.Module):
     embedding_dim: int
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
+    kernel_init: Callable = partial(kernel_init, 1.0)
 
     @nn.compact
     def __call__(self, x):
@@ -33,6 +34,7 @@ class PatchEmbedding(nn.Module):
                     kernel_size=(self.patch_size, self.patch_size),
                     strides=(self.patch_size, self.patch_size),
                     dtype=self.dtype,
+                    kernel_init=self.kernel_init(),
                     precision=self.precision)(x)
         x = jnp.reshape(x, (batch, -1, self.embedding_dim))
         return x
@@ -81,7 +83,7 @@ class UViT(nn.Module):
 
         # Patch embedding
         x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
-                           dtype=self.dtype, precision=self.precision)(x)
+                           dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)
         num_patches = x.shape[1]
 
         context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
@@ -119,14 +121,14 @@ class UViT(nn.Module):
 
         # # Out blocks
         for i in range(self.num_layers // 2):
-            skip = jnp.concatenate([x, skips.pop()], axis=-1)
-            skip = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
-                                   dtype=self.dtype, precision=self.precision)(skip)
+            x = jnp.concatenate([x, skips.pop()], axis=-1)
+            x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
+                                dtype=self.dtype, precision=self.precision)(x)
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
                                  use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
-                                 kernel_init=self.kernel_init())(skip)
+                                 kernel_init=self.kernel_init())(x)
 
         # print(f'Shape of x after transformer blocks: {x.shape}')
         x = self.norm()(x)
@@ -139,6 +141,14 @@ class UViT(nn.Module):
         x = x[:, 1 + num_text_tokens:, :]
         x = unpatchify(x, channels=self.output_channels)
         # print(f'Shape of x after final dense layer: {x.shape}')
-        x = nn.Dense(features=self.output_channels, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
+        x = nn.Conv(
+            features=self.output_channels,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            padding='SAME',
+            dtype=self.dtype,
+            precision=self.precision,
+            kernel_init=kernel_init(0.0),
+        )(x)
 
         return x
--- a/flaxdiff-0.1.25.dist-info/METADATA
+++ b/flaxdiff-0.1.27.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.25
+Version: 0.1.27
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
--- a/flaxdiff-0.1.25.dist-info/RECORD
+++ b/flaxdiff-0.1.27.dist-info/RECORD
@@ -7,7 +7,7 @@ flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,
 flaxdiff/models/common.py,sha256=fw_gP7PZayO6RVe6xSf-7FtVq-S0pp5U6NgHg4PlKO8,10990
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
 flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
-flaxdiff/models/simple_vit.py,sha256=g94RchoccNOELCMqAp9hkt290I3_Jg-GWU6Q3RLtQZs,6699
+flaxdiff/models/simple_vit.py,sha256=-xGeiRztVssisf0CRd9CvlBQNrIUXaRSQbNckYvkuac,6972
 flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
 flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.25.dist-info/METADATA,sha256=DaJHzXya9jzJiiiBF4mzwb0FXx_M0DssZbMQuc-RVsI,22083
-flaxdiff-0.1.25.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
-flaxdiff-0.1.25.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.25.dist-info/RECORD,,
+flaxdiff-0.1.27.dist-info/METADATA,sha256=-344uFzDA8b17cd1LV5RpDDx4bGZ6i8kdNFJ439FD9g,22083
+flaxdiff-0.1.27.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
+flaxdiff-0.1.27.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.27.dist-info/RECORD,,