flaxdiff 0.1.15__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -27,7 +27,6 @@ USER_AGENT = get_datasets_user_agent()
27
27
  data_queue = Queue(16*2000)
28
28
  error_queue = Queue(16*2000)
29
29
 
30
-
31
30
  def fetch_single_image(image_url, timeout=None, retries=0):
32
31
  for _ in range(retries + 1):
33
32
  try:
@@ -46,11 +45,13 @@ def fetch_single_image(image_url, timeout=None, retries=0):
46
45
  def map_sample(
47
46
  url, caption,
48
47
  image_shape=(256, 256),
48
+ timeout=15,
49
+ retries=3,
49
50
  upscale_interpolation=cv2.INTER_LANCZOS4,
50
51
  downscale_interpolation=cv2.INTER_AREA,
51
52
  ):
52
53
  try:
53
- image = fetch_single_image(url, timeout=15, retries=3) # Assuming fetch_single_image is defined elsewhere
54
+ image = fetch_single_image(url, timeout=timeout, retries=retries) # Assuming fetch_single_image is defined elsewhere
54
55
  if image is None:
55
56
  return
56
57
 
@@ -84,15 +85,24 @@ def map_sample(
84
85
  "original_width": original_width,
85
86
  })
86
87
  except Exception as e:
88
+ print(f"Error in map_sample: {str(e)}")
87
89
  error_queue.put({
88
90
  "url": url,
89
91
  "caption": caption,
90
92
  "error": str(e)
91
93
  })
92
-
93
- def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):
94
- with ThreadPoolExecutor(max_workers=num_threads) as executor:
95
- executor.map(map_sample, batch["url"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)
94
+
95
+ def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3):
96
+ try:
97
+ map_sample_fn = partial(map_sample, image_shape=image_shape, timeout=timeout, retries=retries)
98
+ with ThreadPoolExecutor(max_workers=num_threads) as executor:
99
+ executor.map(map_sample_fn, batch["url"], batch['caption'])
100
+ except Exception as e:
101
+ print(f"Error in map_batch: {str(e)}")
102
+ error_queue.put({
103
+ "batch": batch,
104
+ "error": str(e)
105
+ })
96
106
 
97
107
  def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
98
108
  map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
@@ -102,8 +112,10 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(2
102
112
  iteration = 0
103
113
  while True:
104
114
  # Repeat forever
105
- dataset = dataset.shuffle(seed=iteration)
115
+ print(f"Shuffling dataset with seed {iteration}")
116
+ # dataset = dataset.shuffle(seed=iteration)
106
117
  shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
118
+ print(f"mapping {len(shards)} shards")
107
119
  pool.map(map_batch_fn, shards)
108
120
  iteration += 1
109
121
 
@@ -205,4 +217,3 @@ class OnlineStreamingDataLoader():
205
217
 
206
218
  def __len__(self):
207
219
  return len(self.dataset)
208
-
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.15
3
+ Version: 0.1.16
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -1,7 +1,7 @@
1
1
  flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
3
3
  flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
4
- flaxdiff/data/online_loader.py,sha256=Xd4xAerzOU65tvb0cuM8S3Hnm8xv1dJ62xbxAxZkeBw,7307
4
+ flaxdiff/data/online_loader.py,sha256=nrtZU4srZHsg3iN0sG91y_6nY7QtYXRPLk5rGn_BTIU,7728
5
5
  flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
6
6
  flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
7
7
  flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
34
34
  flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
35
35
  flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
36
36
  flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
37
- flaxdiff-0.1.15.dist-info/METADATA,sha256=45ifUa-j2rPqP3s734HRy2rIM1StRvSZ4lIJI0R3sTw,22083
38
- flaxdiff-0.1.15.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
39
- flaxdiff-0.1.15.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
- flaxdiff-0.1.15.dist-info/RECORD,,
37
+ flaxdiff-0.1.16.dist-info/METADATA,sha256=BM2RLOiCDqRSWO_owxvWmL_PS3aFtNokPbf-qAuyK4o,22083
38
+ flaxdiff-0.1.16.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
39
+ flaxdiff-0.1.16.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
+ flaxdiff-0.1.16.dist-info/RECORD,,