flaxdiff 0.1.35.1__py3-none-any.whl → 0.1.35.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +19 -71
- {flaxdiff-0.1.35.1.dist-info → flaxdiff-0.1.35.3.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.35.1.dist-info → flaxdiff-0.1.35.3.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.35.1.dist-info → flaxdiff-0.1.35.3.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.35.1.dist-info → flaxdiff-0.1.35.3.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -26,7 +26,6 @@ import traceback
|
|
26
26
|
USER_AGENT = get_datasets_user_agent()
|
27
27
|
|
28
28
|
data_queue = Queue(16*2000)
|
29
|
-
error_queue = Queue(16*2000)
|
30
29
|
|
31
30
|
|
32
31
|
def fetch_single_image(image_url, timeout=None, retries=0):
|
@@ -78,13 +77,6 @@ def default_image_processor(
|
|
78
77
|
return image, original_height, original_width
|
79
78
|
|
80
79
|
|
81
|
-
def default_feature_extractor(sample):
|
82
|
-
return {
|
83
|
-
"url": sample["url"],
|
84
|
-
"caption": sample["caption"],
|
85
|
-
}
|
86
|
-
|
87
|
-
|
88
80
|
def map_sample(
|
89
81
|
url,
|
90
82
|
caption,
|
@@ -129,6 +121,13 @@ def map_sample(
|
|
129
121
|
pass
|
130
122
|
|
131
123
|
|
124
|
+
def default_feature_extractor(sample):
|
125
|
+
return {
|
126
|
+
"url": sample["url"],
|
127
|
+
"caption": sample["caption"],
|
128
|
+
}
|
129
|
+
|
130
|
+
|
132
131
|
def map_batch(
|
133
132
|
batch, num_threads=256, image_shape=(256, 256),
|
134
133
|
min_image_shape=(128, 128),
|
@@ -142,54 +141,22 @@ def map_batch(
|
|
142
141
|
map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
|
143
142
|
timeout=timeout, retries=retries, image_processor=image_processor,
|
144
143
|
upscale_interpolation=upscale_interpolation,
|
145
|
-
downscale_interpolation=downscale_interpolation
|
146
|
-
feature_extractor=feature_extractor
|
144
|
+
downscale_interpolation=downscale_interpolation
|
147
145
|
)
|
148
|
-
features = feature_extractor(batch)
|
149
|
-
url, caption = features["url"], features["caption"]
|
150
146
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
147
|
+
features = feature_extractor(batch)
|
148
|
+
url, caption = features["url"], features["caption"]
|
151
149
|
executor.map(map_sample_fn, url, caption)
|
152
150
|
except Exception as e:
|
153
151
|
print(f"Error maping batch", e)
|
154
152
|
traceback.print_exc()
|
155
|
-
error_queue.put_nowait({
|
156
|
-
|
157
|
-
|
158
|
-
})
|
153
|
+
# error_queue.put_nowait({
|
154
|
+
# "batch": batch,
|
155
|
+
# "error": str(e)
|
156
|
+
# })
|
159
157
|
pass
|
160
158
|
|
161
159
|
|
162
|
-
def map_batch_repeat_forever(
|
163
|
-
batch, num_threads=256, image_shape=(256, 256),
|
164
|
-
min_image_shape=(128, 128),
|
165
|
-
timeout=15, retries=3, image_processor=default_image_processor,
|
166
|
-
upscale_interpolation=cv2.INTER_CUBIC,
|
167
|
-
downscale_interpolation=cv2.INTER_AREA,
|
168
|
-
feature_extractor=default_feature_extractor,
|
169
|
-
):
|
170
|
-
while True: # Repeat forever
|
171
|
-
try:
|
172
|
-
map_sample_fn = partial(
|
173
|
-
map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
|
174
|
-
timeout=timeout, retries=retries, image_processor=image_processor,
|
175
|
-
upscale_interpolation=upscale_interpolation,
|
176
|
-
downscale_interpolation=downscale_interpolation,
|
177
|
-
feature_extractor=feature_extractor
|
178
|
-
)
|
179
|
-
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
180
|
-
executor.map(map_sample_fn, batch)
|
181
|
-
# Shuffle the batch
|
182
|
-
batch = batch.shuffle(seed=np.random.randint(0, 1000000))
|
183
|
-
except Exception as e:
|
184
|
-
print(f"Error maping batch", e)
|
185
|
-
traceback.print_exc()
|
186
|
-
error_queue.put_nowait({
|
187
|
-
"batch": batch,
|
188
|
-
"error": str(e)
|
189
|
-
})
|
190
|
-
pass
|
191
|
-
|
192
|
-
|
193
160
|
def parallel_image_loader(
|
194
161
|
dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
|
195
162
|
min_image_shape=(128, 128),
|
@@ -197,11 +164,9 @@ def parallel_image_loader(
|
|
197
164
|
upscale_interpolation=cv2.INTER_CUBIC,
|
198
165
|
downscale_interpolation=cv2.INTER_AREA,
|
199
166
|
feature_extractor=default_feature_extractor,
|
200
|
-
map_batch_fn=map_batch,
|
201
|
-
|
202
167
|
):
|
203
168
|
map_batch_fn = partial(
|
204
|
-
|
169
|
+
map_batch, num_threads=num_threads, image_shape=image_shape,
|
205
170
|
min_image_shape=min_image_shape,
|
206
171
|
timeout=timeout, retries=retries, image_processor=image_processor,
|
207
172
|
upscale_interpolation=upscale_interpolation,
|
@@ -209,20 +174,18 @@ def parallel_image_loader(
|
|
209
174
|
feature_extractor=feature_extractor
|
210
175
|
)
|
211
176
|
shard_len = len(dataset) // num_workers
|
212
|
-
print(f"Local Shard lengths: {shard_len}
|
177
|
+
print(f"Local Shard lengths: {shard_len}")
|
213
178
|
with multiprocessing.Pool(num_workers) as pool:
|
214
179
|
iteration = 0
|
215
180
|
while True:
|
216
181
|
# Repeat forever
|
217
182
|
shards = [dataset[i*shard_len:(i+1)*shard_len]
|
218
183
|
for i in range(num_workers)]
|
219
|
-
# shards = [dataset.shard(num_shards=num_workers, index=i) for i in range(num_workers)]
|
220
184
|
print(f"mapping {len(shards)} shards")
|
221
185
|
pool.map(map_batch_fn, shards)
|
222
186
|
iteration += 1
|
223
187
|
print(f"Shuffling dataset with seed {iteration}")
|
224
188
|
dataset = dataset.shuffle(seed=iteration)
|
225
|
-
print(f"Dataset shuffled")
|
226
189
|
# Clear the error queue
|
227
190
|
# while not error_queue.empty():
|
228
191
|
# error_queue.get_nowait()
|
@@ -237,7 +200,6 @@ class ImageBatchIterator:
|
|
237
200
|
upscale_interpolation=cv2.INTER_CUBIC,
|
238
201
|
downscale_interpolation=cv2.INTER_AREA,
|
239
202
|
feature_extractor=default_feature_extractor,
|
240
|
-
map_batch_fn=map_batch,
|
241
203
|
):
|
242
204
|
self.dataset = dataset
|
243
205
|
self.num_workers = num_workers
|
@@ -252,22 +214,10 @@ class ImageBatchIterator:
|
|
252
214
|
image_processor=image_processor,
|
253
215
|
upscale_interpolation=upscale_interpolation,
|
254
216
|
downscale_interpolation=downscale_interpolation,
|
255
|
-
feature_extractor=feature_extractor
|
256
|
-
map_batch_fn=map_batch_fn,
|
217
|
+
feature_extractor=feature_extractor
|
257
218
|
)
|
258
219
|
self.thread = threading.Thread(target=loader, args=(dataset,))
|
259
220
|
self.thread.start()
|
260
|
-
self.error_queue = queue.Queue()
|
261
|
-
|
262
|
-
def error_fetcher():
|
263
|
-
while True:
|
264
|
-
error = error_queue.get()
|
265
|
-
self.error_queue.put(error)
|
266
|
-
self.error_thread = threading.Thread(target=error_fetcher)
|
267
|
-
self.error_thread.start()
|
268
|
-
|
269
|
-
def get_error(self):
|
270
|
-
yield self.error_queue.get()
|
271
221
|
|
272
222
|
def __iter__(self):
|
273
223
|
return self
|
@@ -331,7 +281,6 @@ class OnlineStreamingDataLoader():
|
|
331
281
|
upscale_interpolation=cv2.INTER_CUBIC,
|
332
282
|
downscale_interpolation=cv2.INTER_AREA,
|
333
283
|
feature_extractor=default_feature_extractor,
|
334
|
-
map_batch_fn=map_batch,
|
335
284
|
):
|
336
285
|
if isinstance(dataset, str):
|
337
286
|
dataset_path = dataset
|
@@ -359,8 +308,7 @@ class OnlineStreamingDataLoader():
|
|
359
308
|
timeout=timeout, retries=retries, image_processor=image_processor,
|
360
309
|
upscale_interpolation=upscale_interpolation,
|
361
310
|
downscale_interpolation=downscale_interpolation,
|
362
|
-
feature_extractor=feature_extractor
|
363
|
-
map_batch_fn=map_batch_fn,
|
311
|
+
feature_extractor=feature_extractor
|
364
312
|
)
|
365
313
|
self.batch_size = batch_size
|
366
314
|
|
@@ -372,7 +320,7 @@ class OnlineStreamingDataLoader():
|
|
372
320
|
try:
|
373
321
|
self.batch_queue.put(collate_fn(batch))
|
374
322
|
except Exception as e:
|
375
|
-
print("Error
|
323
|
+
print("Error processing batch", e)
|
376
324
|
|
377
325
|
self.loader_thread = threading.Thread(target=batch_loader)
|
378
326
|
self.loader_thread.start()
|
@@ -1,7 +1,7 @@
|
|
1
1
|
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
3
3
|
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
|
-
flaxdiff/data/online_loader.py,sha256=
|
4
|
+
flaxdiff/data/online_loader.py,sha256=fUM91etaEZmxP0ZxzE1TfxOyHzk1Yq45tYT_P3F-HT0,11311
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
7
|
flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.35.
|
38
|
-
flaxdiff-0.1.35.
|
39
|
-
flaxdiff-0.1.35.
|
40
|
-
flaxdiff-0.1.35.
|
37
|
+
flaxdiff-0.1.35.3.dist-info/METADATA,sha256=g845MSjktfjXWKWtae4_ELwFvtsND8ysj4_yt572Rl4,22085
|
38
|
+
flaxdiff-0.1.35.3.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
39
|
+
flaxdiff-0.1.35.3.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.35.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|