flaxdiff 0.1.35.1__py3-none-any.whl → 0.1.35.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +43 -51
- {flaxdiff-0.1.35.1.dist-info → flaxdiff-0.1.35.2.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.35.1.dist-info → flaxdiff-0.1.35.2.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.35.1.dist-info → flaxdiff-0.1.35.2.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.35.1.dist-info → flaxdiff-0.1.35.2.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -84,7 +84,6 @@ def default_feature_extractor(sample):
|
|
84
84
|
"caption": sample["caption"],
|
85
85
|
}
|
86
86
|
|
87
|
-
|
88
87
|
def map_sample(
|
89
88
|
url,
|
90
89
|
caption,
|
@@ -128,7 +127,6 @@ def map_sample(
|
|
128
127
|
# })
|
129
128
|
pass
|
130
129
|
|
131
|
-
|
132
130
|
def map_batch(
|
133
131
|
batch, num_threads=256, image_shape=(256, 256),
|
134
132
|
min_image_shape=(128, 128),
|
@@ -149,46 +147,48 @@ def map_batch(
|
|
149
147
|
url, caption = features["url"], features["caption"]
|
150
148
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
151
149
|
executor.map(map_sample_fn, url, caption)
|
150
|
+
return None
|
152
151
|
except Exception as e:
|
153
152
|
print(f"Error maping batch", e)
|
154
153
|
traceback.print_exc()
|
155
|
-
error_queue.put_nowait({
|
156
|
-
|
157
|
-
|
158
|
-
})
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
def map_batch_repeat_forever(
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
):
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
154
|
+
# error_queue.put_nowait({
|
155
|
+
# "batch": batch,
|
156
|
+
# "error": str(e)
|
157
|
+
# })
|
158
|
+
return e
|
159
|
+
|
160
|
+
|
161
|
+
# def map_batch_repeat_forever(
|
162
|
+
# batch, num_threads=256, image_shape=(256, 256),
|
163
|
+
# min_image_shape=(128, 128),
|
164
|
+
# timeout=15, retries=3, image_processor=default_image_processor,
|
165
|
+
# upscale_interpolation=cv2.INTER_CUBIC,
|
166
|
+
# downscale_interpolation=cv2.INTER_AREA,
|
167
|
+
# feature_extractor=default_feature_extractor,
|
168
|
+
# ):
|
169
|
+
# while True: # Repeat forever
|
170
|
+
# try:
|
171
|
+
# map_sample_fn = partial(
|
172
|
+
# map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
|
173
|
+
# timeout=timeout, retries=retries, image_processor=image_processor,
|
174
|
+
# upscale_interpolation=upscale_interpolation,
|
175
|
+
# downscale_interpolation=downscale_interpolation,
|
176
|
+
# feature_extractor=feature_extractor
|
177
|
+
# )
|
178
|
+
# features = feature_extractor(batch)
|
179
|
+
# url, caption = features["url"], features["caption"]
|
180
|
+
# with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
181
|
+
# executor.map(map_sample_fn, url, caption)
|
182
|
+
# # Shuffle the batch
|
183
|
+
# batch = batch.shuffle(seed=np.random.randint(0, 1000000))
|
184
|
+
# except Exception as e:
|
185
|
+
# print(f"Error maping batch", e)
|
186
|
+
# traceback.print_exc()
|
187
|
+
# # error_queue.put_nowait({
|
188
|
+
# # "batch": batch,
|
189
|
+
# # "error": str(e)
|
190
|
+
# # })
|
191
|
+
# pass
|
192
192
|
|
193
193
|
def parallel_image_loader(
|
194
194
|
dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
|
@@ -218,7 +218,10 @@ def parallel_image_loader(
|
|
218
218
|
for i in range(num_workers)]
|
219
219
|
# shards = [dataset.shard(num_shards=num_workers, index=i) for i in range(num_workers)]
|
220
220
|
print(f"mapping {len(shards)} shards")
|
221
|
-
pool.map(map_batch_fn, shards)
|
221
|
+
errors = pool.map(map_batch_fn, shards)
|
222
|
+
for error in errors:
|
223
|
+
if error is not None:
|
224
|
+
print(f"Error in mapping batch", error)
|
222
225
|
iteration += 1
|
223
226
|
print(f"Shuffling dataset with seed {iteration}")
|
224
227
|
dataset = dataset.shuffle(seed=iteration)
|
@@ -257,17 +260,6 @@ class ImageBatchIterator:
|
|
257
260
|
)
|
258
261
|
self.thread = threading.Thread(target=loader, args=(dataset,))
|
259
262
|
self.thread.start()
|
260
|
-
self.error_queue = queue.Queue()
|
261
|
-
|
262
|
-
def error_fetcher():
|
263
|
-
while True:
|
264
|
-
error = error_queue.get()
|
265
|
-
self.error_queue.put(error)
|
266
|
-
self.error_thread = threading.Thread(target=error_fetcher)
|
267
|
-
self.error_thread.start()
|
268
|
-
|
269
|
-
def get_error(self):
|
270
|
-
yield self.error_queue.get()
|
271
263
|
|
272
264
|
def __iter__(self):
|
273
265
|
return self
|
@@ -1,7 +1,7 @@
|
|
1
1
|
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
3
3
|
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
|
-
flaxdiff/data/online_loader.py,sha256=
|
4
|
+
flaxdiff/data/online_loader.py,sha256=r7bIA1TvYcLZ-CyAmNkUJSB2ein0nJc_Mx2-j2GQ_IE,13306
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
7
|
flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.35.
|
38
|
-
flaxdiff-0.1.35.
|
39
|
-
flaxdiff-0.1.35.
|
40
|
-
flaxdiff-0.1.35.
|
37
|
+
flaxdiff-0.1.35.2.dist-info/METADATA,sha256=B2UGjl6c0U5qj20BAi4dRo-7Y59fhG2CMj4XRu8CgAw,22085
|
38
|
+
flaxdiff-0.1.35.2.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
39
|
+
flaxdiff-0.1.35.2.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.35.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|