flaxdiff 0.1.35.2__py3-none-any.whl → 0.1.35.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +19 -63
- {flaxdiff-0.1.35.2.dist-info → flaxdiff-0.1.35.3.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.35.2.dist-info → flaxdiff-0.1.35.3.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.35.2.dist-info → flaxdiff-0.1.35.3.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.35.2.dist-info → flaxdiff-0.1.35.3.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -26,7 +26,6 @@ import traceback
|
|
26
26
|
USER_AGENT = get_datasets_user_agent()
|
27
27
|
|
28
28
|
data_queue = Queue(16*2000)
|
29
|
-
error_queue = Queue(16*2000)
|
30
29
|
|
31
30
|
|
32
31
|
def fetch_single_image(image_url, timeout=None, retries=0):
|
@@ -78,12 +77,6 @@ def default_image_processor(
|
|
78
77
|
return image, original_height, original_width
|
79
78
|
|
80
79
|
|
81
|
-
def default_feature_extractor(sample):
|
82
|
-
return {
|
83
|
-
"url": sample["url"],
|
84
|
-
"caption": sample["caption"],
|
85
|
-
}
|
86
|
-
|
87
80
|
def map_sample(
|
88
81
|
url,
|
89
82
|
caption,
|
@@ -127,6 +120,14 @@ def map_sample(
|
|
127
120
|
# })
|
128
121
|
pass
|
129
122
|
|
123
|
+
|
124
|
+
def default_feature_extractor(sample):
|
125
|
+
return {
|
126
|
+
"url": sample["url"],
|
127
|
+
"caption": sample["caption"],
|
128
|
+
}
|
129
|
+
|
130
|
+
|
130
131
|
def map_batch(
|
131
132
|
batch, num_threads=256, image_shape=(256, 256),
|
132
133
|
min_image_shape=(128, 128),
|
@@ -140,14 +141,12 @@ def map_batch(
|
|
140
141
|
map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
|
141
142
|
timeout=timeout, retries=retries, image_processor=image_processor,
|
142
143
|
upscale_interpolation=upscale_interpolation,
|
143
|
-
downscale_interpolation=downscale_interpolation
|
144
|
-
feature_extractor=feature_extractor
|
144
|
+
downscale_interpolation=downscale_interpolation
|
145
145
|
)
|
146
|
-
features = feature_extractor(batch)
|
147
|
-
url, caption = features["url"], features["caption"]
|
148
146
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
147
|
+
features = feature_extractor(batch)
|
148
|
+
url, caption = features["url"], features["caption"]
|
149
149
|
executor.map(map_sample_fn, url, caption)
|
150
|
-
return None
|
151
150
|
except Exception as e:
|
152
151
|
print(f"Error maping batch", e)
|
153
152
|
traceback.print_exc()
|
@@ -155,40 +154,8 @@ def map_batch(
|
|
155
154
|
# "batch": batch,
|
156
155
|
# "error": str(e)
|
157
156
|
# })
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
# def map_batch_repeat_forever(
|
162
|
-
# batch, num_threads=256, image_shape=(256, 256),
|
163
|
-
# min_image_shape=(128, 128),
|
164
|
-
# timeout=15, retries=3, image_processor=default_image_processor,
|
165
|
-
# upscale_interpolation=cv2.INTER_CUBIC,
|
166
|
-
# downscale_interpolation=cv2.INTER_AREA,
|
167
|
-
# feature_extractor=default_feature_extractor,
|
168
|
-
# ):
|
169
|
-
# while True: # Repeat forever
|
170
|
-
# try:
|
171
|
-
# map_sample_fn = partial(
|
172
|
-
# map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
|
173
|
-
# timeout=timeout, retries=retries, image_processor=image_processor,
|
174
|
-
# upscale_interpolation=upscale_interpolation,
|
175
|
-
# downscale_interpolation=downscale_interpolation,
|
176
|
-
# feature_extractor=feature_extractor
|
177
|
-
# )
|
178
|
-
# features = feature_extractor(batch)
|
179
|
-
# url, caption = features["url"], features["caption"]
|
180
|
-
# with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
181
|
-
# executor.map(map_sample_fn, url, caption)
|
182
|
-
# # Shuffle the batch
|
183
|
-
# batch = batch.shuffle(seed=np.random.randint(0, 1000000))
|
184
|
-
# except Exception as e:
|
185
|
-
# print(f"Error maping batch", e)
|
186
|
-
# traceback.print_exc()
|
187
|
-
# # error_queue.put_nowait({
|
188
|
-
# # "batch": batch,
|
189
|
-
# # "error": str(e)
|
190
|
-
# # })
|
191
|
-
# pass
|
157
|
+
pass
|
158
|
+
|
192
159
|
|
193
160
|
def parallel_image_loader(
|
194
161
|
dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
|
@@ -197,11 +164,9 @@ def parallel_image_loader(
|
|
197
164
|
upscale_interpolation=cv2.INTER_CUBIC,
|
198
165
|
downscale_interpolation=cv2.INTER_AREA,
|
199
166
|
feature_extractor=default_feature_extractor,
|
200
|
-
map_batch_fn=map_batch,
|
201
|
-
|
202
167
|
):
|
203
168
|
map_batch_fn = partial(
|
204
|
-
|
169
|
+
map_batch, num_threads=num_threads, image_shape=image_shape,
|
205
170
|
min_image_shape=min_image_shape,
|
206
171
|
timeout=timeout, retries=retries, image_processor=image_processor,
|
207
172
|
upscale_interpolation=upscale_interpolation,
|
@@ -209,23 +174,18 @@ def parallel_image_loader(
|
|
209
174
|
feature_extractor=feature_extractor
|
210
175
|
)
|
211
176
|
shard_len = len(dataset) // num_workers
|
212
|
-
print(f"Local Shard lengths: {shard_len}
|
177
|
+
print(f"Local Shard lengths: {shard_len}")
|
213
178
|
with multiprocessing.Pool(num_workers) as pool:
|
214
179
|
iteration = 0
|
215
180
|
while True:
|
216
181
|
# Repeat forever
|
217
182
|
shards = [dataset[i*shard_len:(i+1)*shard_len]
|
218
183
|
for i in range(num_workers)]
|
219
|
-
# shards = [dataset.shard(num_shards=num_workers, index=i) for i in range(num_workers)]
|
220
184
|
print(f"mapping {len(shards)} shards")
|
221
|
-
|
222
|
-
for error in errors:
|
223
|
-
if error is not None:
|
224
|
-
print(f"Error in mapping batch", error)
|
185
|
+
pool.map(map_batch_fn, shards)
|
225
186
|
iteration += 1
|
226
187
|
print(f"Shuffling dataset with seed {iteration}")
|
227
188
|
dataset = dataset.shuffle(seed=iteration)
|
228
|
-
print(f"Dataset shuffled")
|
229
189
|
# Clear the error queue
|
230
190
|
# while not error_queue.empty():
|
231
191
|
# error_queue.get_nowait()
|
@@ -240,7 +200,6 @@ class ImageBatchIterator:
|
|
240
200
|
upscale_interpolation=cv2.INTER_CUBIC,
|
241
201
|
downscale_interpolation=cv2.INTER_AREA,
|
242
202
|
feature_extractor=default_feature_extractor,
|
243
|
-
map_batch_fn=map_batch,
|
244
203
|
):
|
245
204
|
self.dataset = dataset
|
246
205
|
self.num_workers = num_workers
|
@@ -255,8 +214,7 @@ class ImageBatchIterator:
|
|
255
214
|
image_processor=image_processor,
|
256
215
|
upscale_interpolation=upscale_interpolation,
|
257
216
|
downscale_interpolation=downscale_interpolation,
|
258
|
-
feature_extractor=feature_extractor
|
259
|
-
map_batch_fn=map_batch_fn,
|
217
|
+
feature_extractor=feature_extractor
|
260
218
|
)
|
261
219
|
self.thread = threading.Thread(target=loader, args=(dataset,))
|
262
220
|
self.thread.start()
|
@@ -323,7 +281,6 @@ class OnlineStreamingDataLoader():
|
|
323
281
|
upscale_interpolation=cv2.INTER_CUBIC,
|
324
282
|
downscale_interpolation=cv2.INTER_AREA,
|
325
283
|
feature_extractor=default_feature_extractor,
|
326
|
-
map_batch_fn=map_batch,
|
327
284
|
):
|
328
285
|
if isinstance(dataset, str):
|
329
286
|
dataset_path = dataset
|
@@ -351,8 +308,7 @@ class OnlineStreamingDataLoader():
|
|
351
308
|
timeout=timeout, retries=retries, image_processor=image_processor,
|
352
309
|
upscale_interpolation=upscale_interpolation,
|
353
310
|
downscale_interpolation=downscale_interpolation,
|
354
|
-
feature_extractor=feature_extractor
|
355
|
-
map_batch_fn=map_batch_fn,
|
311
|
+
feature_extractor=feature_extractor
|
356
312
|
)
|
357
313
|
self.batch_size = batch_size
|
358
314
|
|
@@ -364,7 +320,7 @@ class OnlineStreamingDataLoader():
|
|
364
320
|
try:
|
365
321
|
self.batch_queue.put(collate_fn(batch))
|
366
322
|
except Exception as e:
|
367
|
-
print("Error
|
323
|
+
print("Error processing batch", e)
|
368
324
|
|
369
325
|
self.loader_thread = threading.Thread(target=batch_loader)
|
370
326
|
self.loader_thread.start()
|
@@ -1,7 +1,7 @@
|
|
1
1
|
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
3
3
|
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
|
-
flaxdiff/data/online_loader.py,sha256=
|
4
|
+
flaxdiff/data/online_loader.py,sha256=fUM91etaEZmxP0ZxzE1TfxOyHzk1Yq45tYT_P3F-HT0,11311
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
7
|
flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.35.
|
38
|
-
flaxdiff-0.1.35.2.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
39
|
-
flaxdiff-0.1.35.2.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
-
flaxdiff-0.1.35.2.dist-info/RECORD,,
|
37
|
+
flaxdiff-0.1.35.3.dist-info/METADATA,sha256=g845MSjktfjXWKWtae4_ELwFvtsND8ysj4_yt572Rl4,22085
|
38
|
+
flaxdiff-0.1.35.3.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
39
|
+
flaxdiff-0.1.35.3.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.35.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|