flaxdiff 0.1.30__py3-none-any.whl → 0.1.32__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions exactly as they appear in their respective public registries.
flaxdiff/data/online_loader.py

@@ -117,7 +117,12 @@ def map_sample(
         #     "error": str(e)
         # })
         pass
-
+
+def default_feature_extractor(sample):
+    return {
+        "url": sample["url"],
+        "caption": sample["caption"],
+    }
 
 def map_batch(
     batch, num_threads=256, image_shape=(256, 256),
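The new `default_feature_extractor` is the heart of this release's data-loader change: instead of hard-coding `batch["url"]` and `batch['caption']` lookups, the pipeline now asks an extractor for the two fields it needs, so any callable returning a dict with `url` and `caption` keys can be substituted. A minimal sketch of a custom extractor, assuming a hypothetical dataset whose columns are named `link` and `text`:

```python
# Hypothetical adapter for a dataset whose columns differ from the
# url/caption names the default extractor expects.
def my_feature_extractor(sample):
    return {
        "url": sample["link"],      # assumed column name, for illustration
        "caption": sample["text"],  # assumed column name, for illustration
    }
```

The remaining hunks in this file thread a `feature_extractor` argument through each layer of the loader so a callable like this can be supplied once at the top level.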
@@ -125,6 +130,7 @@ def map_batch(
     timeout=15, retries=3, image_processor=default_image_processor,
     upscale_interpolation=cv2.INTER_CUBIC,
     downscale_interpolation=cv2.INTER_AREA,
+    feature_extractor=default_feature_extractor,
 ):
     try:
         map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
@@ -132,7 +138,9 @@ def map_batch(
                                 upscale_interpolation=upscale_interpolation,
                                 downscale_interpolation=downscale_interpolation)
         with ThreadPoolExecutor(max_workers=num_threads) as executor:
-            executor.map(map_sample_fn, batch["url"], batch['caption'])
+            features = feature_extractor(batch)
+            url, caption = features["url"], features["caption"]
+            executor.map(map_sample_fn, url, caption)
     except Exception as e:
         print(f"Error maping batch", e)
         traceback.print_exc()
@@ -149,12 +157,14 @@ def parallel_image_loader(
     num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
     upscale_interpolation=cv2.INTER_CUBIC,
     downscale_interpolation=cv2.INTER_AREA,
+    feature_extractor=default_feature_extractor,
 ):
     map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
                            min_image_shape=min_image_shape,
                            timeout=timeout, retries=retries, image_processor=image_processor,
                            upscale_interpolation=upscale_interpolation,
-                           downscale_interpolation=downscale_interpolation)
+                           downscale_interpolation=downscale_interpolation,
+                           feature_extractor=feature_extractor)
     shard_len = len(dataset) // num_workers
     print(f"Local Shard lengths: {shard_len}")
     with multiprocessing.Pool(num_workers) as pool:
@@ -181,6 +191,7 @@ class ImageBatchIterator:
         image_processor=default_image_processor,
         upscale_interpolation=cv2.INTER_CUBIC,
         downscale_interpolation=cv2.INTER_AREA,
+        feature_extractor=default_feature_extractor,
     ):
         self.dataset = dataset
         self.num_workers = num_workers
@@ -191,7 +202,8 @@ class ImageBatchIterator:
                          num_workers=num_workers,
                          timeout=timeout, retries=retries, image_processor=image_processor,
                          upscale_interpolation=upscale_interpolation,
-                         downscale_interpolation=downscale_interpolation)
+                         downscale_interpolation=downscale_interpolation,
+                         feature_extractor=feature_extractor)
         self.thread = threading.Thread(target=loader, args=(dataset,))
         self.thread.start()
 
@@ -256,6 +268,7 @@ class OnlineStreamingDataLoader():
         image_processor=default_image_processor,
         upscale_interpolation=cv2.INTER_CUBIC,
         downscale_interpolation=cv2.INTER_AREA,
+        feature_extractor=default_feature_extractor,
     ):
         if isinstance(dataset, str):
             dataset_path = dataset
@@ -281,7 +294,8 @@ class OnlineStreamingDataLoader():
             num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
             timeout=timeout, retries=retries, image_processor=image_processor,
             upscale_interpolation=upscale_interpolation,
-            downscale_interpolation=downscale_interpolation)
+            downscale_interpolation=downscale_interpolation,
+            feature_extractor=feature_extractor)
         self.batch_size = batch_size
 
         # Launch a thread to load batches in the background
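With the parameter threaded through `map_batch`, `parallel_image_loader`, `ImageBatchIterator`, and `OnlineStreamingDataLoader`, a custom extractor is passed once at construction time. A hedged usage sketch (the dataset identifier is illustrative, and only arguments visible in the hunks above are shown):

```python
from flaxdiff.data.online_loader import (
    OnlineStreamingDataLoader,
    default_feature_extractor,
)

# Illustrative dataset path; any streaming image-text dataset whose
# columns match the chosen extractor would do here.
loader = OnlineStreamingDataLoader(
    "laion/laion400m",
    batch_size=64,
    feature_extractor=default_feature_extractor,  # new in 0.1.32
)
```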
flaxdiff/models/simple_vit.py

@@ -23,7 +23,7 @@ class PatchEmbedding(nn.Module):
     embedding_dim: int
     dtype: Any = jnp.float32
     precision: Any = jax.lax.Precision.HIGH
-    kernel_init: Callable = kernel_init(1.0)
+    kernel_init: Callable = partial(kernel_init, 1.0)
 
     @nn.compact
     def __call__(self, x):
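The `kernel_init` change in `PatchEmbedding` fixes a subtle default-field problem: `kernel_init(1.0)` was being *called* at class-definition time, baking one initializer instance into the dataclass default, whereas `partial(kernel_init, 1.0)` stores a zero-argument factory that each call site invokes as `self.kernel_init()` (see the next hunk). A minimal sketch of the pattern with a stand-in `kernel_init`; flaxdiff's real helper lives elsewhere in the package and may differ in signature:

```python
from functools import partial
from typing import Callable

import flax.linen as nn
import jax.numpy as jnp

# Stand-in for flaxdiff's kernel_init: maps a scale to a Flax initializer.
def kernel_init(scale, dtype=jnp.float32):
    return nn.initializers.variance_scaling(
        scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)

class Demo(nn.Module):
    # Factory default: no initializer is constructed until self.kernel_init()
    # runs inside __call__, mirroring the pattern adopted in this diff.
    kernel_init: Callable = partial(kernel_init, 1.0)

    @nn.compact
    def __call__(self, x):
        return nn.Dense(8, kernel_init=self.kernel_init())(x)
```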
@@ -34,7 +34,7 @@ class PatchEmbedding(nn.Module):
                     kernel_size=(self.patch_size, self.patch_size),
                     strides=(self.patch_size, self.patch_size),
                     dtype=self.dtype,
-                    kernel_init=self.kernel_init,
+                    kernel_init=self.kernel_init(),
                     precision=self.precision)(x)
         x = jnp.reshape(x, (batch, -1, self.embedding_dim))
         return x
@@ -67,7 +67,7 @@ class UViT(nn.Module):
     norm_groups:int=8
     dtype: Optional[Dtype] = None
     precision: PrecisionLike = None
-    kernel_init: Callable = partial(kernel_init)
+    kernel_init: Callable = partial(kernel_init, scale=1.0)
     add_residualblock_output: bool = False
 
     def setup(self):
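For `UViT` the fix is slightly different: the old default `partial(kernel_init)` bound nothing, so the no-argument `self.kernel_init()` calls introduced below would fail with a missing `scale`. Binding it as a *keyword* (`scale=1.0`) supplies the default while still letting individual call sites override it, as the output blocks do with `self.kernel_init(scale=0.0)`, since `functools.partial` lets call-time keywords take precedence over bound ones. A small self-contained demonstration of that override behavior:

```python
from functools import partial

def show_scale(scale):
    return scale

f = partial(show_scale, scale=1.0)
print(f())           # 1.0 -> the default bound in the dataclass field
print(f(scale=0.0))  # 0.0 -> call-site keyword overrides the bound one
```

Note this only works because `scale` is bound by keyword; `partial(show_scale, 1.0)` followed by `f(scale=0.0)` would raise a "multiple values" `TypeError`, which is presumably why `PatchEmbedding` (positional, never overridden) and `UViT` (keyword, overridden to 0.0) bind differently.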
@@ -86,10 +86,10 @@ class UViT(nn.Module):
 
         # Patch embedding
         x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
-                           dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)
+                           dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init)(x)
         num_patches = x.shape[1]
 
-        context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0),
+        context_emb = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
                                       dtype=self.dtype, precision=self.precision)(textcontext)
         num_text_tokens = textcontext.shape[1]
 
@@ -112,7 +112,7 @@ class UViT(nn.Module):
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
                                  use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
-                                 kernel_init=self.kernel_init(1.0))(x)
+                                 kernel_init=self.kernel_init())(x)
             skips.append(x)
 
         # Middle block
@@ -120,24 +120,24 @@ class UViT(nn.Module):
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
                                  use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
-                                 kernel_init=self.kernel_init(1.0))(x)
+                                 kernel_init=self.kernel_init())(x)
 
         # # Out blocks
         for i in range(self.num_layers // 2):
             x = jnp.concatenate([x, skips.pop()], axis=-1)
-            x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(1.0),
+            x = nn.DenseGeneral(features=self.emb_features, kernel_init=self.kernel_init(),
                                 dtype=self.dtype, precision=self.precision)(x)
             x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
                                  dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
                                  use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
                                  only_pure_attention=False,
-                                 kernel_init=self.kernel_init(1.0))(x)
+                                 kernel_init=self.kernel_init())(x)
 
         # print(f'Shape of x after transformer blocks: {x.shape}')
         x = self.norm()(x)
 
         patch_dim = self.patch_size ** 2 * self.output_channels
-        x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init(1.0))(x)
+        x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision, kernel_init=self.kernel_init())(x)
         x = x[:, 1 + num_text_tokens:, :]
         x = unpatchify(x, channels=self.output_channels)
 
@@ -151,7 +151,7 @@ class UViT(nn.Module):
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 # activation=jax.nn.mish
-                kernel_init=self.kernel_init(0.0),
+                kernel_init=self.kernel_init(scale=0.0),
                 dtype=self.dtype,
                 precision=self.precision
             )(x)
@@ -165,7 +165,7 @@ class UViT(nn.Module):
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 # activation=jax.nn.mish
-                kernel_init=self.kernel_init(0.0),
+                kernel_init=self.kernel_init(scale=0.0),
                 dtype=self.dtype,
                 precision=self.precision
             )(x)
flaxdiff-0.1.32.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.30
+Version: 0.1.32
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
flaxdiff-0.1.32.dist-info/RECORD

@@ -1,13 +1,13 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
-flaxdiff/data/online_loader.py,sha256=LIK_O1C3yDPvvAEOWvsJrVeBopVqjg2IOMTbiSIvH6M,11025
+flaxdiff/data/online_loader.py,sha256=Yi1thc2izJoXeB9UPTAkfWhWp8m4UYWi6cDwW6I6zUc,11661
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
 flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZo,10763
-flaxdiff/models/simple_vit.py,sha256=Nnrlo5T9IUu3lu6y-SIWIgfISc07uOztBB4kyfBrQVY,7443
+flaxdiff/models/simple_vit.py,sha256=NHt6v-teGjiI65fk1l1WN3WqfeqTE7xY9VYqBiYUDgI,7454
 flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
 flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
 flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.30.dist-info/METADATA,sha256=lzEiqudjsqRLsDrI1icVnN3NM8hHrAqWloafwhxbhBE,22083
-flaxdiff-0.1.30.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
-flaxdiff-0.1.30.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.30.dist-info/RECORD,,
+flaxdiff-0.1.32.dist-info/METADATA,sha256=yqODOUEAlcE_60tOOfSP7XoH0VTBm3dY9J1ali36VFY,22083
+flaxdiff-0.1.32.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
+flaxdiff-0.1.32.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.32.dist-info/RECORD,,