flaxdiff 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +46 -17
- {flaxdiff-0.1.17.dist-info → flaxdiff-0.1.19.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.17.dist-info → flaxdiff-0.1.19.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.17.dist-info → flaxdiff-0.1.19.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.17.dist-info → flaxdiff-0.1.19.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -25,7 +25,7 @@ import cv2
|
|
25
25
|
USER_AGENT = get_datasets_user_agent()
|
26
26
|
|
27
27
|
data_queue = Queue(16*2000)
|
28
|
-
error_queue = Queue(
|
28
|
+
error_queue = Queue()
|
29
29
|
|
30
30
|
|
31
31
|
def fetch_single_image(image_url, timeout=None, retries=0):
|
@@ -44,7 +44,7 @@ def fetch_single_image(image_url, timeout=None, retries=0):
|
|
44
44
|
return image
|
45
45
|
|
46
46
|
|
47
|
-
def default_image_processor(image, image_shape, interpolation=cv2.
|
47
|
+
def default_image_processor(image, image_shape, interpolation=cv2.INTER_CUBIC):
|
48
48
|
image = A.longest_max_size(image, max(
|
49
49
|
image_shape), interpolation=interpolation)
|
50
50
|
image = A.pad(
|
@@ -62,7 +62,7 @@ def map_sample(
|
|
62
62
|
image_shape=(256, 256),
|
63
63
|
timeout=15,
|
64
64
|
retries=3,
|
65
|
-
upscale_interpolation=cv2.
|
65
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
66
66
|
downscale_interpolation=cv2.INTER_AREA,
|
67
67
|
image_processor=default_image_processor,
|
68
68
|
):
|
@@ -98,17 +98,24 @@ def map_sample(
|
|
98
98
|
"original_width": original_width,
|
99
99
|
})
|
100
100
|
except Exception as e:
|
101
|
-
error_queue.
|
101
|
+
error_queue.put_nowait({
|
102
102
|
"url": url,
|
103
103
|
"caption": caption,
|
104
104
|
"error": str(e)
|
105
105
|
})
|
106
106
|
|
107
107
|
|
108
|
-
def map_batch(
|
108
|
+
def map_batch(
|
109
|
+
batch, num_threads=256, image_shape=(256, 256),
|
110
|
+
timeout=15, retries=3, image_processor=default_image_processor,
|
111
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
112
|
+
downscale_interpolation=cv2.INTER_AREA,
|
113
|
+
):
|
109
114
|
try:
|
110
115
|
map_sample_fn = partial(map_sample, image_shape=image_shape,
|
111
|
-
timeout=timeout, retries=retries, image_processor=image_processor
|
116
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
117
|
+
upscale_interpolation=upscale_interpolation,
|
118
|
+
downscale_interpolation=downscale_interpolation)
|
112
119
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
113
120
|
executor.map(map_sample_fn, batch["url"], batch['caption'])
|
114
121
|
except Exception as e:
|
@@ -118,10 +125,16 @@ def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retrie
|
|
118
125
|
})
|
119
126
|
|
120
127
|
|
121
|
-
def parallel_image_loader(
|
122
|
-
|
128
|
+
def parallel_image_loader(
|
129
|
+
dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
|
130
|
+
num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
|
131
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
132
|
+
downscale_interpolation=cv2.INTER_AREA,
|
133
|
+
):
|
123
134
|
map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
|
124
|
-
timeout=timeout, retries=retries, image_processor=image_processor
|
135
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
136
|
+
upscale_interpolation=upscale_interpolation,
|
137
|
+
downscale_interpolation=downscale_interpolation)
|
125
138
|
shard_len = len(dataset) // num_workers
|
126
139
|
print(f"Local Shard lengths: {shard_len}")
|
127
140
|
with multiprocessing.Pool(num_workers) as pool:
|
@@ -135,17 +148,27 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(2
|
|
135
148
|
iteration += 1
|
136
149
|
print(f"Shuffling dataset with seed {iteration}")
|
137
150
|
dataset = dataset.shuffle(seed=iteration)
|
151
|
+
# Clear the error queue
|
152
|
+
while not error_queue.empty():
|
153
|
+
error_queue.get_nowait()
|
138
154
|
|
139
155
|
|
140
156
|
class ImageBatchIterator:
|
141
|
-
def __init__(
|
142
|
-
|
157
|
+
def __init__(
|
158
|
+
self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
|
159
|
+
num_workers: int = 8, num_threads=256, timeout=15, retries=3,
|
160
|
+
image_processor=default_image_processor,
|
161
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
162
|
+
downscale_interpolation=cv2.INTER_AREA,
|
163
|
+
):
|
143
164
|
self.dataset = dataset
|
144
165
|
self.num_workers = num_workers
|
145
166
|
self.batch_size = batch_size
|
146
167
|
loader = partial(parallel_image_loader, num_threads=num_threads,
|
147
168
|
image_shape=image_shape, num_workers=num_workers,
|
148
|
-
timeout=timeout, retries=retries, image_processor=image_processor
|
169
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
170
|
+
upscale_interpolation=upscale_interpolation,
|
171
|
+
downscale_interpolation=downscale_interpolation)
|
149
172
|
self.thread = threading.Thread(target=loader, args=(dataset,))
|
150
173
|
self.thread.start()
|
151
174
|
|
@@ -207,6 +230,8 @@ class OnlineStreamingDataLoader():
|
|
207
230
|
timeout=15,
|
208
231
|
retries=3,
|
209
232
|
image_processor=default_image_processor,
|
233
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
234
|
+
downscale_interpolation=cv2.INTER_AREA,
|
210
235
|
):
|
211
236
|
if isinstance(dataset, str):
|
212
237
|
dataset_path = dataset
|
@@ -229,8 +254,9 @@ class OnlineStreamingDataLoader():
|
|
229
254
|
print(f"Dataset length: {len(dataset)}")
|
230
255
|
self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
|
231
256
|
num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
|
232
|
-
timeout=timeout, retries=retries, image_processor=image_processor
|
233
|
-
|
257
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
258
|
+
upscale_interpolation=upscale_interpolation,
|
259
|
+
downscale_interpolation=downscale_interpolation)
|
234
260
|
self.batch_size = batch_size
|
235
261
|
|
236
262
|
# Launch a thread to load batches in the background
|
@@ -238,7 +264,10 @@ class OnlineStreamingDataLoader():
|
|
238
264
|
|
239
265
|
def batch_loader():
|
240
266
|
for batch in self.iterator:
|
241
|
-
|
267
|
+
try:
|
268
|
+
self.batch_queue.put(collate_fn(batch))
|
269
|
+
except Exception as e:
|
270
|
+
print("Error processing batch", e)
|
242
271
|
|
243
272
|
self.loader_thread = threading.Thread(target=batch_loader)
|
244
273
|
self.loader_thread.start()
|
@@ -247,8 +276,8 @@ class OnlineStreamingDataLoader():
|
|
247
276
|
return self
|
248
277
|
|
249
278
|
def __next__(self):
|
250
|
-
return self.
|
279
|
+
return self.batch_queue.get()
|
251
280
|
# return self.collate_fn(next(self.iterator))
|
252
281
|
|
253
282
|
def __len__(self):
|
254
|
-
return len(self.dataset)
|
283
|
+
return len(self.dataset)
|
@@ -1,7 +1,7 @@
|
|
1
1
|
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
3
3
|
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
|
-
flaxdiff/data/online_loader.py,sha256=
|
4
|
+
flaxdiff/data/online_loader.py,sha256=WK4apO8Bx-RTU_z5imB53Lzq12vqGnXA9DhLq8nb0us,9991
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
7
|
flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
39
|
-
flaxdiff-0.1.
|
40
|
-
flaxdiff-0.1.
|
37
|
+
flaxdiff-0.1.19.dist-info/METADATA,sha256=NH-f1SK5obamoVRk8ZPQxvtQcz_R3mui3ToZe0Qx8Vg,22083
|
38
|
+
flaxdiff-0.1.19.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
39
|
+
flaxdiff-0.1.19.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.19.dist-info/RECORD,,
|
File without changes
|
File without changes
|