flaxdiff-0.1.18-py3-none-any.whl → flaxdiff-0.1.20-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +66 -32
- {flaxdiff-0.1.18.dist-info → flaxdiff-0.1.20.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.18.dist-info → flaxdiff-0.1.20.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.18.dist-info → flaxdiff-0.1.20.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.18.dist-info → flaxdiff-0.1.20.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -25,7 +25,6 @@ import cv2
|
|
25
25
|
USER_AGENT = get_datasets_user_agent()
|
26
26
|
|
27
27
|
data_queue = Queue(16*2000)
|
28
|
-
error_queue = Queue()
|
29
28
|
|
30
29
|
|
31
30
|
def fetch_single_image(image_url, timeout=None, retries=0):
|
@@ -44,7 +43,7 @@ def fetch_single_image(image_url, timeout=None, retries=0):
|
|
44
43
|
return image
|
45
44
|
|
46
45
|
|
47
|
-
def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
|
46
|
+
def default_image_processor(image, image_shape, interpolation=cv2.INTER_CUBIC):
|
48
47
|
image = A.longest_max_size(image, max(
|
49
48
|
image_shape), interpolation=interpolation)
|
50
49
|
image = A.pad(
|
@@ -60,9 +59,10 @@ def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4
|
|
60
59
|
def map_sample(
|
61
60
|
url, caption,
|
62
61
|
image_shape=(256, 256),
|
62
|
+
min_image_shape=(128, 128),
|
63
63
|
timeout=15,
|
64
64
|
retries=3,
|
65
|
-
upscale_interpolation=cv2.
|
65
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
66
66
|
downscale_interpolation=cv2.INTER_AREA,
|
67
67
|
image_processor=default_image_processor,
|
68
68
|
):
|
@@ -75,10 +75,10 @@ def map_sample(
|
|
75
75
|
image = np.array(image)
|
76
76
|
original_height, original_width = image.shape[:2]
|
77
77
|
# check if the image is too small
|
78
|
-
if min(original_height, original_width) < min(
|
78
|
+
if min(original_height, original_width) < min(min_image_shape):
|
79
79
|
return
|
80
80
|
# check if wrong aspect ratio
|
81
|
-
if max(original_height, original_width) / min(original_height, original_width) > 2:
|
81
|
+
if max(original_height, original_width) / min(original_height, original_width) > 2.4:
|
82
82
|
return
|
83
83
|
# check if the variance is too low
|
84
84
|
if np.std(image) < 1e-4:
|
@@ -98,30 +98,48 @@ def map_sample(
|
|
98
98
|
"original_width": original_width,
|
99
99
|
})
|
100
100
|
except Exception as e:
|
101
|
-
error_queue.put_nowait({
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
})
|
106
|
-
|
107
|
-
|
108
|
-
|
101
|
+
# error_queue.put_nowait({
|
102
|
+
# "url": url,
|
103
|
+
# "caption": caption,
|
104
|
+
# "error": str(e)
|
105
|
+
# })
|
106
|
+
pass
|
107
|
+
|
108
|
+
|
109
|
+
def map_batch(
|
110
|
+
batch, num_threads=256, image_shape=(256, 256),
|
111
|
+
min_image_shape=(128, 128),
|
112
|
+
timeout=15, retries=3, image_processor=default_image_processor,
|
113
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
114
|
+
downscale_interpolation=cv2.INTER_AREA,
|
115
|
+
):
|
109
116
|
try:
|
110
|
-
map_sample_fn = partial(map_sample, image_shape=image_shape,
|
111
|
-
timeout=timeout, retries=retries, image_processor=image_processor
|
117
|
+
map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
|
118
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
119
|
+
upscale_interpolation=upscale_interpolation,
|
120
|
+
downscale_interpolation=downscale_interpolation)
|
112
121
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
113
122
|
executor.map(map_sample_fn, batch["url"], batch['caption'])
|
114
123
|
except Exception as e:
|
115
|
-
error_queue.put_nowait({
|
116
|
-
|
117
|
-
|
118
|
-
})
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
124
|
+
# error_queue.put_nowait({
|
125
|
+
# "batch": batch,
|
126
|
+
# "error": str(e)
|
127
|
+
# })
|
128
|
+
pass
|
129
|
+
|
130
|
+
|
131
|
+
def parallel_image_loader(
|
132
|
+
dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
|
133
|
+
min_image_shape=(128, 128),
|
134
|
+
num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
|
135
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
136
|
+
downscale_interpolation=cv2.INTER_AREA,
|
137
|
+
):
|
138
|
+
map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
|
139
|
+
min_image_shape=min_image_shape,
|
140
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
141
|
+
upscale_interpolation=upscale_interpolation,
|
142
|
+
downscale_interpolation=downscale_interpolation)
|
125
143
|
shard_len = len(dataset) // num_workers
|
126
144
|
print(f"Local Shard lengths: {shard_len}")
|
127
145
|
with multiprocessing.Pool(num_workers) as pool:
|
@@ -136,19 +154,29 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(2
|
|
136
154
|
print(f"Shuffling dataset with seed {iteration}")
|
137
155
|
dataset = dataset.shuffle(seed=iteration)
|
138
156
|
# Clear the error queue
|
139
|
-
while not error_queue.empty():
|
140
|
-
|
157
|
+
# while not error_queue.empty():
|
158
|
+
# error_queue.get_nowait()
|
141
159
|
|
142
160
|
|
143
161
|
class ImageBatchIterator:
|
144
|
-
def __init__(
|
145
|
-
|
162
|
+
def __init__(
|
163
|
+
self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
|
164
|
+
min_image_shape=(128, 128),
|
165
|
+
num_workers: int = 8, num_threads=256, timeout=15, retries=3,
|
166
|
+
image_processor=default_image_processor,
|
167
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
168
|
+
downscale_interpolation=cv2.INTER_AREA,
|
169
|
+
):
|
146
170
|
self.dataset = dataset
|
147
171
|
self.num_workers = num_workers
|
148
172
|
self.batch_size = batch_size
|
149
173
|
loader = partial(parallel_image_loader, num_threads=num_threads,
|
150
|
-
image_shape=image_shape,
|
151
|
-
|
174
|
+
image_shape=image_shape,
|
175
|
+
min_image_shape=min_image_shape,
|
176
|
+
num_workers=num_workers,
|
177
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
178
|
+
upscale_interpolation=upscale_interpolation,
|
179
|
+
downscale_interpolation=downscale_interpolation)
|
152
180
|
self.thread = threading.Thread(target=loader, args=(dataset,))
|
153
181
|
self.thread.start()
|
154
182
|
|
@@ -195,6 +223,7 @@ class OnlineStreamingDataLoader():
|
|
195
223
|
dataset,
|
196
224
|
batch_size=64,
|
197
225
|
image_shape=(256, 256),
|
226
|
+
min_image_shape=(128, 128),
|
198
227
|
num_workers=16,
|
199
228
|
num_threads=512,
|
200
229
|
default_split="all",
|
@@ -210,6 +239,8 @@ class OnlineStreamingDataLoader():
|
|
210
239
|
timeout=15,
|
211
240
|
retries=3,
|
212
241
|
image_processor=default_image_processor,
|
242
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
243
|
+
downscale_interpolation=cv2.INTER_AREA,
|
213
244
|
):
|
214
245
|
if isinstance(dataset, str):
|
215
246
|
dataset_path = dataset
|
@@ -231,8 +262,11 @@ class OnlineStreamingDataLoader():
|
|
231
262
|
num_shards=global_process_count, index=global_process_index)
|
232
263
|
print(f"Dataset length: {len(dataset)}")
|
233
264
|
self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
|
265
|
+
min_image_shape=min_image_shape,
|
234
266
|
num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
|
235
|
-
|
267
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
268
|
+
upscale_interpolation=upscale_interpolation,
|
269
|
+
downscale_interpolation=downscale_interpolation)
|
236
270
|
self.batch_size = batch_size
|
237
271
|
|
238
272
|
# Launch a thread to load batches in the background
|
@@ -1,7 +1,7 @@
|
|
1
1
|
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
3
3
|
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
|
-
flaxdiff/data/online_loader.py,sha256=
|
4
|
+
flaxdiff/data/online_loader.py,sha256=XVT_kT7v9CQVaQgunTL48KxgPgwQ-bhIi8RN-Q1qbYc,10451
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
7
|
flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
39
|
-
flaxdiff-0.1.
|
40
|
-
flaxdiff-0.1.
|
37
|
+
flaxdiff-0.1.20.dist-info/METADATA,sha256=ls0rUYnHBWdChfQ7meO2nlHSqGVEPn2JzZTOTagt2H8,22083
|
38
|
+
flaxdiff-0.1.20.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
39
|
+
flaxdiff-0.1.20.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.20.dist-info/RECORD,,
|
File without changes
|
File without changes
|