flaxdiff 0.1.35__py3-none-any.whl → 0.1.35.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -84,9 +84,9 @@ def default_feature_extractor(sample):
84
84
  "caption": sample["caption"],
85
85
  }
86
86
 
87
-
88
87
  def map_sample(
89
- sample,
88
+ url,
89
+ caption,
90
90
  image_shape=(256, 256),
91
91
  min_image_shape=(128, 128),
92
92
  timeout=15,
@@ -94,11 +94,8 @@ def map_sample(
94
94
  upscale_interpolation=cv2.INTER_CUBIC,
95
95
  downscale_interpolation=cv2.INTER_AREA,
96
96
  image_processor=default_image_processor,
97
- feature_extractor=default_feature_extractor,
98
97
  ):
99
98
  try:
100
- features = feature_extractor(sample)
101
- url, caption = features["url"], features["caption"]
102
99
  # Assuming fetch_single_image is defined elsewhere
103
100
  image = fetch_single_image(url, timeout=timeout, retries=retries)
104
101
  if image is None:
@@ -130,7 +127,6 @@ def map_sample(
130
127
  # })
131
128
  pass
132
129
 
133
-
134
130
  def map_batch(
135
131
  batch, num_threads=256, image_shape=(256, 256),
136
132
  min_image_shape=(128, 128),
@@ -147,48 +143,52 @@ def map_batch(
147
143
  downscale_interpolation=downscale_interpolation,
148
144
  feature_extractor=feature_extractor
149
145
  )
146
+ features = feature_extractor(batch)
147
+ url, caption = features["url"], features["caption"]
150
148
  with ThreadPoolExecutor(max_workers=num_threads) as executor:
151
- executor.map(map_sample_fn, batch)
149
+ executor.map(map_sample_fn, url, caption)
150
+ return None
152
151
  except Exception as e:
153
152
  print(f"Error maping batch", e)
154
153
  traceback.print_exc()
155
- error_queue.put_nowait({
156
- "batch": batch,
157
- "error": str(e)
158
- })
159
- pass
160
-
161
-
162
- def map_batch_repeat_forever(
163
- batch, num_threads=256, image_shape=(256, 256),
164
- min_image_shape=(128, 128),
165
- timeout=15, retries=3, image_processor=default_image_processor,
166
- upscale_interpolation=cv2.INTER_CUBIC,
167
- downscale_interpolation=cv2.INTER_AREA,
168
- feature_extractor=default_feature_extractor,
169
- ):
170
- while True: # Repeat forever
171
- try:
172
- map_sample_fn = partial(
173
- map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
174
- timeout=timeout, retries=retries, image_processor=image_processor,
175
- upscale_interpolation=upscale_interpolation,
176
- downscale_interpolation=downscale_interpolation,
177
- feature_extractor=feature_extractor
178
- )
179
- with ThreadPoolExecutor(max_workers=num_threads) as executor:
180
- executor.map(map_sample_fn, batch)
181
- # Shuffle the batch
182
- batch = batch.shuffle(seed=np.random.randint(0, 1000000))
183
- except Exception as e:
184
- print(f"Error maping batch", e)
185
- traceback.print_exc()
186
- error_queue.put_nowait({
187
- "batch": batch,
188
- "error": str(e)
189
- })
190
- pass
191
-
154
+ # error_queue.put_nowait({
155
+ # "batch": batch,
156
+ # "error": str(e)
157
+ # })
158
+ return e
159
+
160
+
161
+ # def map_batch_repeat_forever(
162
+ # batch, num_threads=256, image_shape=(256, 256),
163
+ # min_image_shape=(128, 128),
164
+ # timeout=15, retries=3, image_processor=default_image_processor,
165
+ # upscale_interpolation=cv2.INTER_CUBIC,
166
+ # downscale_interpolation=cv2.INTER_AREA,
167
+ # feature_extractor=default_feature_extractor,
168
+ # ):
169
+ # while True: # Repeat forever
170
+ # try:
171
+ # map_sample_fn = partial(
172
+ # map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
173
+ # timeout=timeout, retries=retries, image_processor=image_processor,
174
+ # upscale_interpolation=upscale_interpolation,
175
+ # downscale_interpolation=downscale_interpolation,
176
+ # feature_extractor=feature_extractor
177
+ # )
178
+ # features = feature_extractor(batch)
179
+ # url, caption = features["url"], features["caption"]
180
+ # with ThreadPoolExecutor(max_workers=num_threads) as executor:
181
+ # executor.map(map_sample_fn, url, caption)
182
+ # # Shuffle the batch
183
+ # batch = batch.shuffle(seed=np.random.randint(0, 1000000))
184
+ # except Exception as e:
185
+ # print(f"Error maping batch", e)
186
+ # traceback.print_exc()
187
+ # # error_queue.put_nowait({
188
+ # # "batch": batch,
189
+ # # "error": str(e)
190
+ # # })
191
+ # pass
192
192
 
193
193
  def parallel_image_loader(
194
194
  dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
@@ -214,11 +214,14 @@ def parallel_image_loader(
214
214
  iteration = 0
215
215
  while True:
216
216
  # Repeat forever
217
- # shards = [dataset[i*shard_len:(i+1)*shard_len]
218
- # for i in range(num_workers)]
219
- shards = [dataset.shard(num_shards=num_workers, index=i) for i in range(num_workers)]
217
+ shards = [dataset[i*shard_len:(i+1)*shard_len]
218
+ for i in range(num_workers)]
219
+ # shards = [dataset.shard(num_shards=num_workers, index=i) for i in range(num_workers)]
220
220
  print(f"mapping {len(shards)} shards")
221
- pool.map(map_batch_fn, shards)
221
+ errors = pool.map(map_batch_fn, shards)
222
+ for error in errors:
223
+ if error is not None:
224
+ print(f"Error in mapping batch", error)
222
225
  iteration += 1
223
226
  print(f"Shuffling dataset with seed {iteration}")
224
227
  dataset = dataset.shuffle(seed=iteration)
@@ -257,17 +260,6 @@ class ImageBatchIterator:
257
260
  )
258
261
  self.thread = threading.Thread(target=loader, args=(dataset,))
259
262
  self.thread.start()
260
- self.error_queue = queue.Queue()
261
-
262
- def error_fetcher():
263
- while True:
264
- error = error_queue.get()
265
- self.error_queue.put(error)
266
- self.error_thread = threading.Thread(target=error_fetcher)
267
- self.error_thread.start()
268
-
269
- def get_error(self):
270
- yield self.error_queue.get()
271
263
 
272
264
  def __iter__(self):
273
265
  return self
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.35
3
+ Version: 0.1.35.2
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -1,7 +1,7 @@
1
1
  flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
3
3
  flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
4
- flaxdiff/data/online_loader.py,sha256=SPcKexYYBWCVOZqmH7v8Zfhtlf9NtnTsjdZKbPcoALY,13328
4
+ flaxdiff/data/online_loader.py,sha256=r7bIA1TvYcLZ-CyAmNkUJSB2ein0nJc_Mx2-j2GQ_IE,13306
5
5
  flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
6
6
  flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
7
7
  flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
34
34
  flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
35
35
  flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
36
36
  flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
37
- flaxdiff-0.1.35.dist-info/METADATA,sha256=rEM3EQWZtPTTtEyVAnUQa88P1Dh2daPq8zr4kDC3Al0,22083
38
- flaxdiff-0.1.35.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
39
- flaxdiff-0.1.35.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
- flaxdiff-0.1.35.dist-info/RECORD,,
37
+ flaxdiff-0.1.35.2.dist-info/METADATA,sha256=B2UGjl6c0U5qj20BAi4dRo-7Y59fhG2CMj4XRu8CgAw,22085
38
+ flaxdiff-0.1.35.2.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
39
+ flaxdiff-0.1.35.2.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
+ flaxdiff-0.1.35.2.dist-info/RECORD,,