flaxdiff 0.1.35.2__py3-none-any.whl → 0.1.35.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -26,7 +26,6 @@ import traceback
26
26
  USER_AGENT = get_datasets_user_agent()
27
27
 
28
28
  data_queue = Queue(16*2000)
29
- error_queue = Queue(16*2000)
30
29
 
31
30
 
32
31
  def fetch_single_image(image_url, timeout=None, retries=0):
@@ -78,12 +77,6 @@ def default_image_processor(
78
77
  return image, original_height, original_width
79
78
 
80
79
 
81
- def default_feature_extractor(sample):
82
- return {
83
- "url": sample["url"],
84
- "caption": sample["caption"],
85
- }
86
-
87
80
  def map_sample(
88
81
  url,
89
82
  caption,
@@ -127,6 +120,14 @@ def map_sample(
127
120
  # })
128
121
  pass
129
122
 
123
+
124
+ def default_feature_extractor(sample):
125
+ return {
126
+ "url": sample["url"],
127
+ "caption": sample["caption"],
128
+ }
129
+
130
+
130
131
  def map_batch(
131
132
  batch, num_threads=256, image_shape=(256, 256),
132
133
  min_image_shape=(128, 128),
@@ -140,14 +141,12 @@ def map_batch(
140
141
  map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
141
142
  timeout=timeout, retries=retries, image_processor=image_processor,
142
143
  upscale_interpolation=upscale_interpolation,
143
- downscale_interpolation=downscale_interpolation,
144
- feature_extractor=feature_extractor
144
+ downscale_interpolation=downscale_interpolation
145
145
  )
146
- features = feature_extractor(batch)
147
- url, caption = features["url"], features["caption"]
148
146
  with ThreadPoolExecutor(max_workers=num_threads) as executor:
147
+ features = feature_extractor(batch)
148
+ url, caption = features["url"], features["caption"]
149
149
  executor.map(map_sample_fn, url, caption)
150
- return None
151
150
  except Exception as e:
152
151
  print(f"Error maping batch", e)
153
152
  traceback.print_exc()
@@ -155,40 +154,8 @@ def map_batch(
155
154
  # "batch": batch,
156
155
  # "error": str(e)
157
156
  # })
158
- return e
159
-
160
-
161
- # def map_batch_repeat_forever(
162
- # batch, num_threads=256, image_shape=(256, 256),
163
- # min_image_shape=(128, 128),
164
- # timeout=15, retries=3, image_processor=default_image_processor,
165
- # upscale_interpolation=cv2.INTER_CUBIC,
166
- # downscale_interpolation=cv2.INTER_AREA,
167
- # feature_extractor=default_feature_extractor,
168
- # ):
169
- # while True: # Repeat forever
170
- # try:
171
- # map_sample_fn = partial(
172
- # map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
173
- # timeout=timeout, retries=retries, image_processor=image_processor,
174
- # upscale_interpolation=upscale_interpolation,
175
- # downscale_interpolation=downscale_interpolation,
176
- # feature_extractor=feature_extractor
177
- # )
178
- # features = feature_extractor(batch)
179
- # url, caption = features["url"], features["caption"]
180
- # with ThreadPoolExecutor(max_workers=num_threads) as executor:
181
- # executor.map(map_sample_fn, url, caption)
182
- # # Shuffle the batch
183
- # batch = batch.shuffle(seed=np.random.randint(0, 1000000))
184
- # except Exception as e:
185
- # print(f"Error maping batch", e)
186
- # traceback.print_exc()
187
- # # error_queue.put_nowait({
188
- # # "batch": batch,
189
- # # "error": str(e)
190
- # # })
191
- # pass
157
+ pass
158
+
192
159
 
193
160
  def parallel_image_loader(
194
161
  dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
@@ -197,11 +164,9 @@ def parallel_image_loader(
197
164
  upscale_interpolation=cv2.INTER_CUBIC,
198
165
  downscale_interpolation=cv2.INTER_AREA,
199
166
  feature_extractor=default_feature_extractor,
200
- map_batch_fn=map_batch,
201
-
202
167
  ):
203
168
  map_batch_fn = partial(
204
- map_batch_fn, num_threads=num_threads, image_shape=image_shape,
169
+ map_batch, num_threads=num_threads, image_shape=image_shape,
205
170
  min_image_shape=min_image_shape,
206
171
  timeout=timeout, retries=retries, image_processor=image_processor,
207
172
  upscale_interpolation=upscale_interpolation,
@@ -209,23 +174,18 @@ def parallel_image_loader(
209
174
  feature_extractor=feature_extractor
210
175
  )
211
176
  shard_len = len(dataset) // num_workers
212
- print(f"Local Shard lengths: {shard_len}, workers: {num_workers}")
177
+ print(f"Local Shard lengths: {shard_len}")
213
178
  with multiprocessing.Pool(num_workers) as pool:
214
179
  iteration = 0
215
180
  while True:
216
181
  # Repeat forever
217
182
  shards = [dataset[i*shard_len:(i+1)*shard_len]
218
183
  for i in range(num_workers)]
219
- # shards = [dataset.shard(num_shards=num_workers, index=i) for i in range(num_workers)]
220
184
  print(f"mapping {len(shards)} shards")
221
- errors = pool.map(map_batch_fn, shards)
222
- for error in errors:
223
- if error is not None:
224
- print(f"Error in mapping batch", error)
185
+ pool.map(map_batch_fn, shards)
225
186
  iteration += 1
226
187
  print(f"Shuffling dataset with seed {iteration}")
227
188
  dataset = dataset.shuffle(seed=iteration)
228
- print(f"Dataset shuffled")
229
189
  # Clear the error queue
230
190
  # while not error_queue.empty():
231
191
  # error_queue.get_nowait()
@@ -240,7 +200,6 @@ class ImageBatchIterator:
240
200
  upscale_interpolation=cv2.INTER_CUBIC,
241
201
  downscale_interpolation=cv2.INTER_AREA,
242
202
  feature_extractor=default_feature_extractor,
243
- map_batch_fn=map_batch,
244
203
  ):
245
204
  self.dataset = dataset
246
205
  self.num_workers = num_workers
@@ -255,8 +214,7 @@ class ImageBatchIterator:
255
214
  image_processor=image_processor,
256
215
  upscale_interpolation=upscale_interpolation,
257
216
  downscale_interpolation=downscale_interpolation,
258
- feature_extractor=feature_extractor,
259
- map_batch_fn=map_batch_fn,
217
+ feature_extractor=feature_extractor
260
218
  )
261
219
  self.thread = threading.Thread(target=loader, args=(dataset,))
262
220
  self.thread.start()
@@ -323,7 +281,6 @@ class OnlineStreamingDataLoader():
323
281
  upscale_interpolation=cv2.INTER_CUBIC,
324
282
  downscale_interpolation=cv2.INTER_AREA,
325
283
  feature_extractor=default_feature_extractor,
326
- map_batch_fn=map_batch,
327
284
  ):
328
285
  if isinstance(dataset, str):
329
286
  dataset_path = dataset
@@ -351,8 +308,7 @@ class OnlineStreamingDataLoader():
351
308
  timeout=timeout, retries=retries, image_processor=image_processor,
352
309
  upscale_interpolation=upscale_interpolation,
353
310
  downscale_interpolation=downscale_interpolation,
354
- feature_extractor=feature_extractor,
355
- map_batch_fn=map_batch_fn,
311
+ feature_extractor=feature_extractor
356
312
  )
357
313
  self.batch_size = batch_size
358
314
 
@@ -364,7 +320,7 @@ class OnlineStreamingDataLoader():
364
320
  try:
365
321
  self.batch_queue.put(collate_fn(batch))
366
322
  except Exception as e:
367
- print("Error collating batch", e)
323
+ print("Error processing batch", e)
368
324
 
369
325
  self.loader_thread = threading.Thread(target=batch_loader)
370
326
  self.loader_thread.start()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.35.2
3
+ Version: 0.1.35.3
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -1,7 +1,7 @@
1
1
  flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
3
3
  flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
4
- flaxdiff/data/online_loader.py,sha256=r7bIA1TvYcLZ-CyAmNkUJSB2ein0nJc_Mx2-j2GQ_IE,13306
4
+ flaxdiff/data/online_loader.py,sha256=fUM91etaEZmxP0ZxzE1TfxOyHzk1Yq45tYT_P3F-HT0,11311
5
5
  flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
6
6
  flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
7
7
  flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
34
34
  flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
35
35
  flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
36
36
  flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
37
- flaxdiff-0.1.35.2.dist-info/METADATA,sha256=B2UGjl6c0U5qj20BAi4dRo-7Y59fhG2CMj4XRu8CgAw,22085
38
- flaxdiff-0.1.35.2.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
39
- flaxdiff-0.1.35.2.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
- flaxdiff-0.1.35.2.dist-info/RECORD,,
37
+ flaxdiff-0.1.35.3.dist-info/METADATA,sha256=g845MSjktfjXWKWtae4_ELwFvtsND8ysj4_yt572Rl4,22085
38
+ flaxdiff-0.1.35.3.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
39
+ flaxdiff-0.1.35.3.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
+ flaxdiff-0.1.35.3.dist-info/RECORD,,