flaxdiff 0.1.28__py3-none-any.whl → 0.1.30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flaxdiff/models/common.py CHANGED
@@ -8,7 +8,7 @@ import einops
 from functools import partial
 
 # Kernel initializer to use
-def kernel_init(scale, dtype=jnp.float32):
+def kernel_init(scale=1.0, dtype=jnp.float32):
     scale = max(scale, 1e-10)
     return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)
 
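For reference, the change above only gives scale a default value; the initializer itself is unchanged. A minimal sketch of the two call styles (assuming jax and flax are installed; the variable names are illustrative):

import jax.numpy as jnp
import flax.linen as nn

# Same definition as in flaxdiff/models/common.py after this change.
def kernel_init(scale=1.0, dtype=jnp.float32):
    scale = max(scale, 1e-10)  # clamp so variance_scaling never receives zero
    return nn.initializers.variance_scaling(
        scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)

default_init = kernel_init()      # new in 0.1.30: scale defaults to 1.0
explicit_init = kernel_init(1.0)  # equivalent to the old required-argument call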
flaxdiff/models/simple_vit.py CHANGED
@@ -23,7 +23,7 @@ class PatchEmbedding(nn.Module):
     embedding_dim: int
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
-    kernel_init: Callable = partial(kernel_init, 1.0)
+    kernel_init: Callable = kernel_init(1.0)
 
     @nn.compact
     def __call__(self, x):
@@ -34,7 +34,7 @@ class PatchEmbedding(nn.Module):
             kernel_size=(self.patch_size, self.patch_size),
             strides=(self.patch_size, self.patch_size),
             dtype=self.dtype,
-            kernel_init=self.kernel_init(),
+            kernel_init=self.kernel_init,
             precision=self.precision)(x)
         x = jnp.reshape(x, (batch, -1, self.embedding_dim))
         return x
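The two PatchEmbedding hunks above are one coordinated change: the kernel_init field now stores a concrete initializer (kernel_init(1.0)) rather than a zero-argument factory, so the Conv layer receives self.kernel_init directly instead of self.kernel_init(). A hedged sketch of the difference, using a plain flax.linen.Conv as a stand-in (assuming flaxdiff 0.1.30 is installed; PatchEmbedding itself is not imported here):

from functools import partial
import flax.linen as nn
from flaxdiff.models.common import kernel_init  # defined in the hunk above

# 0.1.28 style: the module field held a factory, and call sites invoked it.
factory = partial(kernel_init, 1.0)
conv_old = nn.Conv(features=768, kernel_size=(2, 2), kernel_init=factory())

# 0.1.30 style: the module field holds the initializer itself.
initializer = kernel_init(1.0)
conv_new = nn.Conv(features=768, kernel_size=(2, 2), kernel_init=initializer)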
@@ -67,7 +67,7 @@ class UViT(nn.Module):
     norm_groups:int=8
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
-    kernel_init: Callable = partial(kernel_init, 1.0)
+    kernel_init: Callable = partial(kernel_init)
     add_residualblock_output: bool = False
 
     def setup(self):
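In contrast to PatchEmbedding, UViT's kernel_init field stays a factory: partial(kernel_init) now relies on the new scale=1.0 default instead of baking the scale in, and the call sites in the hunks below pass the scale explicitly (self.kernel_init(1.0) for dense and attention layers, self.kernel_init(0.0) for the final convolution). A small sketch of the two field defaults (assuming flaxdiff 0.1.30 is installed):

from functools import partial
from flaxdiff.models.common import kernel_init

old_default = partial(kernel_init, 1.0)  # 0.1.28: scale fixed at construction
new_default = partial(kernel_init)       # 0.1.30: scale chosen per call site

# Both yield an initializer for scale 1.0; the new form also allows other scales.
assert callable(old_default()) and callable(new_default(1.0)) and callable(new_default(0.0))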
@@ -86,10 +86,10 @@ class UViT(nn.Module):
 
         # Patch embedding
         x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
-            dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)
+            dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)
         num_patches = x.shape[1]
 
-        context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
+        context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0),
             dtype=self.dtype, precision=self.precision)(textcontext)
         num_text_tokens = textcontext.shape[1]
 
@@ -112,7 +112,7 @@ class UViT(nn.Module):
                 dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                 only_pure_attention=False,
-                kernel_init=self.kernel_init())(x)
+                kernel_init=self.kernel_init(1.0))(x)
             skips.append(x)
 
         # Middle block
@@ -120,48 +120,41 @@
             dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
             use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
             only_pure_attention=False,
-            kernel_init=self.kernel_init())(x)
+            kernel_init=self.kernel_init(1.0))(x)
 
         # # Out blocks
         for i in range(self.num_layers // 2):
             x = jnp.concatenate([x, skips.pop()], axis=-1)
-            x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
+            x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0),
                 dtype=self.dtype, precision=self.precision)(x)
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                 dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                 only_pure_attention=False,
-                kernel_init=self.kernel_init())(x)
+                kernel_init=self.kernel_init(1.0))(x)
 
         # print(f'Shape of x after transformer blocks: {x.shape}')
         x = self.norm()(x)
 
-        # print(f'Shape of x after norm: {x.shape}')
-
         patch_dim = self.patch_size ** 2 * self.output_channels
-        x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
-        # print(f'Shape of x after patch dense layer: {x.shape}, patch_dim: {patch_dim}')
+        x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)
         x = x[:, 1 + num_text_tokens:, :]
         x = unpatchify(x, channels=self.output_channels)
-        # print(f'Shape of x after final dense layer: {x.shape}')
 
         if self.add_residualblock_output:
             # Concatenate the original image
             x = jnp.concatenate([original_img, x], axis=-1)
 
-            x = ResidualBlock(
+            x = ConvLayer(
                 "conv",
-                name="final_residual",
                 features=64,
-                kernel_init=self.kernel_init(1.0),
-                kernel_size=(3,3),
+                kernel_size=(3, 3),
                 strides=(1, 1),
-                activation=self.activation,
-                norm_groups=self.norm_groups,
+                # activation=jax.nn.mish
+                kernel_init=self.kernel_init(0.0),
                 dtype=self.dtype,
-                precision=self.precision,
-                named_norms=False
-            )(x, temb)
+                precision=self.precision
+            )(x)
 
         x = self.norm()(x)
         x = self.activation(x)
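The residual-output branch above now ends in a plain ConvLayer initialized with self.kernel_init(0.0); since kernel_init clamps its scale to 1e-10, that convolution starts with near-zero weights, so the extra branch contributes almost nothing at initialization. A minimal check of that behaviour (assuming flaxdiff 0.1.30 and jax are installed; the kernel shape is illustrative):

import jax
import jax.numpy as jnp
from flaxdiff.models.common import kernel_init

near_zero = kernel_init(0.0)  # scale is clamped to 1e-10 inside kernel_init
w = near_zero(jax.random.PRNGKey(0), (3, 3, 67, 64), jnp.float32)  # HWIO conv kernel
print(jnp.abs(w).max())  # on the order of 1e-6 or smaller: effectively zero at init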
flaxdiff-0.1.28.dist-info/METADATA → flaxdiff-0.1.30.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.28
+Version: 0.1.30
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
flaxdiff-0.1.28.dist-info/RECORD → flaxdiff-0.1.30.dist-info/RECORD CHANGED
@@ -4,10 +4,10 @@ flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
 flaxdiff/data/online_loader.py,sha256=LIK_O1C3yDPvvAEOWvsJrVeBopVqjg2IOMTbiSIvH6M,11025
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
-flaxdiff/models/common.py,sha256=fw_gP7PZayO6RVe6xSf-7FtVq-S0pp5U6NgHg4PlKO8,10990
+flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
 flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
-flaxdiff/models/simple_vit.py,sha256=W2LxTKWA0wJHbPLf4hd2eUO4-ZV5u0Y-M168QulGwTg,7786
+flaxdiff/models/simple_vit.py,sha256=Nnrlo5T9IUu3lu6y-SIWIgfISc07uOztBB4kyfBrQVY,7443
 flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
 flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.28.dist-info/METADATA,sha256=AeUPnS3eT-lJSMSM9p1J_HBkJz9f5QFey5r0wdZddH8,22083
-flaxdiff-0.1.28.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
-flaxdiff-0.1.28.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.28.dist-info/RECORD,,
+flaxdiff-0.1.30.dist-info/METADATA,sha256=lzEiqudjsqRLsDrI1icVnN3NM8hHrAqWloafwhxbhBE,22083
+flaxdiff-0.1.30.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
+flaxdiff-0.1.30.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.30.dist-info/RECORD,,