flaxdiff-0.1.32-py3-none-any.whl → flaxdiff-0.1.34-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +118 -53
- {flaxdiff-0.1.32.dist-info → flaxdiff-0.1.34.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.32.dist-info → flaxdiff-0.1.34.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.32.dist-info → flaxdiff-0.1.34.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.32.dist-info → flaxdiff-0.1.34.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -21,11 +21,12 @@ import urllib
|
|
21
21
|
|
22
22
|
import PIL.Image
|
23
23
|
import cv2
|
24
|
-
import traceback
|
24
|
+
import traceback
|
25
25
|
|
26
26
|
USER_AGENT = get_datasets_user_agent()
|
27
27
|
|
28
28
|
data_queue = Queue(16*2000)
|
29
|
+
error_queue = Queue(16*2000)
|
29
30
|
|
30
31
|
|
31
32
|
def fetch_single_image(image_url, timeout=None, retries=0):
|
@@ -45,7 +46,7 @@ def fetch_single_image(image_url, timeout=None, retries=0):
|
|
45
46
|
|
46
47
|
|
47
48
|
def default_image_processor(
|
48
|
-
image, image_shape,
|
49
|
+
image, image_shape,
|
49
50
|
min_image_shape=(128, 128),
|
50
51
|
upscale_interpolation=cv2.INTER_CUBIC,
|
51
52
|
downscale_interpolation=cv2.INTER_AREA,
|
@@ -77,8 +78,15 @@ def default_image_processor(
|
|
77
78
|
return image, original_height, original_width
|
78
79
|
|
79
80
|
|
81
|
+
def default_feature_extractor(sample):
|
82
|
+
return {
|
83
|
+
"url": sample["url"],
|
84
|
+
"caption": sample["caption"],
|
85
|
+
}
|
86
|
+
|
87
|
+
|
80
88
|
def map_sample(
|
81
|
-
|
89
|
+
sample,
|
82
90
|
image_shape=(256, 256),
|
83
91
|
min_image_shape=(128, 128),
|
84
92
|
timeout=15,
|
@@ -86,8 +94,11 @@ def map_sample(
|
|
86
94
|
upscale_interpolation=cv2.INTER_CUBIC,
|
87
95
|
downscale_interpolation=cv2.INTER_AREA,
|
88
96
|
image_processor=default_image_processor,
|
97
|
+
feature_extractor=default_feature_extractor,
|
89
98
|
):
|
90
99
|
try:
|
100
|
+
features = feature_extractor(sample)
|
101
|
+
url, caption = features["url"], features["caption"]
|
91
102
|
# Assuming fetch_single_image is defined elsewhere
|
92
103
|
image = fetch_single_image(url, timeout=timeout, retries=retries)
|
93
104
|
if image is None:
|
@@ -96,11 +107,12 @@ def map_sample(
|
|
96
107
|
image, original_height, original_width = image_processor(
|
97
108
|
image, image_shape, min_image_shape=min_image_shape,
|
98
109
|
upscale_interpolation=upscale_interpolation,
|
99
|
-
downscale_interpolation=downscale_interpolation,
|
100
|
-
|
110
|
+
downscale_interpolation=downscale_interpolation,
|
111
|
+
)
|
112
|
+
|
101
113
|
if image is None:
|
102
114
|
return
|
103
|
-
|
115
|
+
|
104
116
|
data_queue.put({
|
105
117
|
"url": url,
|
106
118
|
"caption": caption,
|
@@ -110,22 +122,17 @@ def map_sample(
|
|
110
122
|
})
|
111
123
|
except Exception as e:
|
112
124
|
print(f"Error maping sample {url}", e)
|
113
|
-
traceback.print_exc()
|
125
|
+
traceback.print_exc()
|
114
126
|
# error_queue.put_nowait({
|
115
127
|
# "url": url,
|
116
128
|
# "caption": caption,
|
117
129
|
# "error": str(e)
|
118
130
|
# })
|
119
131
|
pass
|
120
|
-
|
121
|
-
def default_feature_extractor(sample):
|
122
|
-
return {
|
123
|
-
"url": sample["url"],
|
124
|
-
"caption": sample["caption"],
|
125
|
-
}
|
132
|
+
|
126
133
|
|
127
134
|
def map_batch(
|
128
|
-
batch, num_threads=256, image_shape=(256, 256),
|
135
|
+
batch, num_threads=256, image_shape=(256, 256),
|
129
136
|
min_image_shape=(128, 128),
|
130
137
|
timeout=15, retries=3, image_processor=default_image_processor,
|
131
138
|
upscale_interpolation=cv2.INTER_CUBIC,
|
@@ -133,40 +140,76 @@ def map_batch(
|
|
133
140
|
feature_extractor=default_feature_extractor,
|
134
141
|
):
|
135
142
|
try:
|
136
|
-
map_sample_fn = partial(
|
137
|
-
|
138
|
-
|
139
|
-
|
143
|
+
map_sample_fn = partial(
|
144
|
+
map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
|
145
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
146
|
+
upscale_interpolation=upscale_interpolation,
|
147
|
+
downscale_interpolation=downscale_interpolation,
|
148
|
+
feature_extractor=feature_extractor
|
149
|
+
)
|
140
150
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
141
|
-
|
142
|
-
url, caption = features["url"], features["caption"]
|
143
|
-
executor.map(map_sample_fn, url, caption)
|
151
|
+
executor.map(map_sample_fn, batch)
|
144
152
|
except Exception as e:
|
145
153
|
print(f"Error maping batch", e)
|
146
|
-
traceback.print_exc()
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
154
|
+
traceback.print_exc()
|
155
|
+
error_queue.put_nowait({
|
156
|
+
"batch": batch,
|
157
|
+
"error": str(e)
|
158
|
+
})
|
151
159
|
pass
|
152
160
|
|
153
161
|
|
162
|
+
def map_batch_repeat_forever(
|
163
|
+
batch, num_threads=256, image_shape=(256, 256),
|
164
|
+
min_image_shape=(128, 128),
|
165
|
+
timeout=15, retries=3, image_processor=default_image_processor,
|
166
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
167
|
+
downscale_interpolation=cv2.INTER_AREA,
|
168
|
+
feature_extractor=default_feature_extractor,
|
169
|
+
):
|
170
|
+
while True: # Repeat forever
|
171
|
+
try:
|
172
|
+
map_sample_fn = partial(
|
173
|
+
map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
|
174
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
175
|
+
upscale_interpolation=upscale_interpolation,
|
176
|
+
downscale_interpolation=downscale_interpolation,
|
177
|
+
feature_extractor=feature_extractor
|
178
|
+
)
|
179
|
+
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
180
|
+
executor.map(map_sample_fn, batch)
|
181
|
+
# Shuffle the batch
|
182
|
+
batch = batch.shuffle(seed=np.random.randint(0, 1000000))
|
183
|
+
except Exception as e:
|
184
|
+
print(f"Error maping batch", e)
|
185
|
+
traceback.print_exc()
|
186
|
+
error_queue.put_nowait({
|
187
|
+
"batch": batch,
|
188
|
+
"error": str(e)
|
189
|
+
})
|
190
|
+
pass
|
191
|
+
|
192
|
+
|
154
193
|
def parallel_image_loader(
|
155
|
-
dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
|
194
|
+
dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
|
156
195
|
min_image_shape=(128, 128),
|
157
196
|
num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
|
158
197
|
upscale_interpolation=cv2.INTER_CUBIC,
|
159
198
|
downscale_interpolation=cv2.INTER_AREA,
|
160
199
|
feature_extractor=default_feature_extractor,
|
200
|
+
map_batch_fn=map_batch,
|
201
|
+
|
161
202
|
):
|
162
|
-
map_batch_fn = partial(
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
203
|
+
map_batch_fn = partial(
|
204
|
+
map_batch_fn, num_threads=num_threads, image_shape=image_shape,
|
205
|
+
min_image_shape=min_image_shape,
|
206
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
207
|
+
upscale_interpolation=upscale_interpolation,
|
208
|
+
downscale_interpolation=downscale_interpolation,
|
209
|
+
feature_extractor=feature_extractor
|
210
|
+
)
|
168
211
|
shard_len = len(dataset) // num_workers
|
169
|
-
print(f"Local Shard lengths: {shard_len}")
|
212
|
+
print(f"Local Shard lengths: {shard_len}, workers: {num_workers}")
|
170
213
|
with multiprocessing.Pool(num_workers) as pool:
|
171
214
|
iteration = 0
|
172
215
|
while True:
|
@@ -178,6 +221,7 @@ def parallel_image_loader(
|
|
178
221
|
iteration += 1
|
179
222
|
print(f"Shuffling dataset with seed {iteration}")
|
180
223
|
dataset = dataset.shuffle(seed=iteration)
|
224
|
+
print(f"Dataset shuffled")
|
181
225
|
# Clear the error queue
|
182
226
|
# while not error_queue.empty():
|
183
227
|
# error_queue.get_nowait()
|
@@ -185,27 +229,44 @@ def parallel_image_loader(
|
|
185
229
|
|
186
230
|
class ImageBatchIterator:
|
187
231
|
def __init__(
|
188
|
-
self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
|
232
|
+
self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
|
189
233
|
min_image_shape=(128, 128),
|
190
|
-
num_workers: int = 8, num_threads=256, timeout=15, retries=3,
|
234
|
+
num_workers: int = 8, num_threads=256, timeout=15, retries=3,
|
191
235
|
image_processor=default_image_processor,
|
192
236
|
upscale_interpolation=cv2.INTER_CUBIC,
|
193
237
|
downscale_interpolation=cv2.INTER_AREA,
|
194
238
|
feature_extractor=default_feature_extractor,
|
239
|
+
map_batch_fn=map_batch,
|
195
240
|
):
|
196
241
|
self.dataset = dataset
|
197
242
|
self.num_workers = num_workers
|
198
243
|
self.batch_size = batch_size
|
199
|
-
loader = partial(
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
244
|
+
loader = partial(
|
245
|
+
parallel_image_loader,
|
246
|
+
num_threads=num_threads,
|
247
|
+
image_shape=image_shape,
|
248
|
+
min_image_shape=min_image_shape,
|
249
|
+
num_workers=num_workers,
|
250
|
+
timeout=timeout, retries=retries,
|
251
|
+
image_processor=image_processor,
|
252
|
+
upscale_interpolation=upscale_interpolation,
|
253
|
+
downscale_interpolation=downscale_interpolation,
|
254
|
+
feature_extractor=feature_extractor,
|
255
|
+
map_batch_fn=map_batch_fn,
|
256
|
+
)
|
207
257
|
self.thread = threading.Thread(target=loader, args=(dataset,))
|
208
258
|
self.thread.start()
|
259
|
+
self.error_queue = queue.Queue()
|
260
|
+
|
261
|
+
def error_fetcher():
|
262
|
+
while True:
|
263
|
+
error = error_queue.get()
|
264
|
+
self.error_queue.put(error)
|
265
|
+
self.error_thread = threading.Thread(target=error_fetcher)
|
266
|
+
self.error_thread.start()
|
267
|
+
|
268
|
+
def get_error(self):
|
269
|
+
yield self.error_queue.get()
|
209
270
|
|
210
271
|
def __iter__(self):
|
211
272
|
return self
|
@@ -269,6 +330,7 @@ class OnlineStreamingDataLoader():
|
|
269
330
|
upscale_interpolation=cv2.INTER_CUBIC,
|
270
331
|
downscale_interpolation=cv2.INTER_AREA,
|
271
332
|
feature_extractor=default_feature_extractor,
|
333
|
+
map_batch_fn=map_batch,
|
272
334
|
):
|
273
335
|
if isinstance(dataset, str):
|
274
336
|
dataset_path = dataset
|
@@ -289,13 +351,16 @@ class OnlineStreamingDataLoader():
|
|
289
351
|
self.dataset = dataset.shard(
|
290
352
|
num_shards=global_process_count, index=global_process_index)
|
291
353
|
print(f"Dataset length: {len(dataset)}")
|
292
|
-
self.iterator = ImageBatchIterator(
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
297
|
-
|
298
|
-
|
354
|
+
self.iterator = ImageBatchIterator(
|
355
|
+
self.dataset, image_shape=image_shape,
|
356
|
+
min_image_shape=min_image_shape,
|
357
|
+
num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
|
358
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
359
|
+
upscale_interpolation=upscale_interpolation,
|
360
|
+
downscale_interpolation=downscale_interpolation,
|
361
|
+
feature_extractor=feature_extractor,
|
362
|
+
map_batch_fn=map_batch_fn,
|
363
|
+
)
|
299
364
|
self.batch_size = batch_size
|
300
365
|
|
301
366
|
# Launch a thread to load batches in the background
|
@@ -306,7 +371,7 @@ class OnlineStreamingDataLoader():
|
|
306
371
|
try:
|
307
372
|
self.batch_queue.put(collate_fn(batch))
|
308
373
|
except Exception as e:
|
309
|
-
print("Error
|
374
|
+
print("Error collating batch", e)
|
310
375
|
|
311
376
|
self.loader_thread = threading.Thread(target=batch_loader)
|
312
377
|
self.loader_thread.start()
|
@@ -319,4 +384,4 @@ class OnlineStreamingDataLoader():
|
|
319
384
|
# return self.collate_fn(next(self.iterator))
|
320
385
|
|
321
386
|
def __len__(self):
|
322
|
-
return len(self.dataset)
|
387
|
+
return len(self.dataset)
|
@@ -1,7 +1,7 @@
|
|
1
1
|
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
3
3
|
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
|
-
flaxdiff/data/online_loader.py,sha256=
|
4
|
+
flaxdiff/data/online_loader.py,sha256=VpwTivqTYXM4Eu3jLWoDkYRcl7KsgJ5A6t17zDSDsDc,13226
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
7
|
flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
39
|
-
flaxdiff-0.1.
|
40
|
-
flaxdiff-0.1.
|
37
|
+
flaxdiff-0.1.34.dist-info/METADATA,sha256=5x3QvLbZ3AgBjghAxnWQXQD5-cJ9jZ0ifPhGxdTeuEc,22083
|
38
|
+
flaxdiff-0.1.34.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
39
|
+
flaxdiff-0.1.34.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.34.dist-info/RECORD,,
|
File without changes
|
File without changes
|