flaxdiff 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +32 -20
- {flaxdiff-0.1.19.dist-info → flaxdiff-0.1.21.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.19.dist-info → flaxdiff-0.1.21.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.19.dist-info → flaxdiff-0.1.21.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.19.dist-info → flaxdiff-0.1.21.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -25,7 +25,6 @@ import cv2
|
|
25
25
|
USER_AGENT = get_datasets_user_agent()
|
26
26
|
|
27
27
|
data_queue = Queue(16*2000)
|
28
|
-
error_queue = Queue()
|
29
28
|
|
30
29
|
|
31
30
|
def fetch_single_image(image_url, timeout=None, retries=0):
|
@@ -60,6 +59,7 @@ def default_image_processor(image, image_shape, interpolation=cv2.INTER_CUBIC):
|
|
60
59
|
def map_sample(
|
61
60
|
url, caption,
|
62
61
|
image_shape=(256, 256),
|
62
|
+
min_image_shape=(128, 128),
|
63
63
|
timeout=15,
|
64
64
|
retries=3,
|
65
65
|
upscale_interpolation=cv2.INTER_CUBIC,
|
@@ -75,21 +75,21 @@ def map_sample(
|
|
75
75
|
image = np.array(image)
|
76
76
|
original_height, original_width = image.shape[:2]
|
77
77
|
# check if the image is too small
|
78
|
-
if min(original_height, original_width) < min(
|
78
|
+
if min(original_height, original_width) < min(min_image_shape):
|
79
79
|
return
|
80
80
|
# check if wrong aspect ratio
|
81
|
-
if max(original_height, original_width) / min(original_height, original_width) > 2:
|
81
|
+
if max(original_height, original_width) / min(original_height, original_width) > 2.4:
|
82
82
|
return
|
83
83
|
# check if the variance is too low
|
84
84
|
if np.std(image) < 1e-4:
|
85
85
|
return
|
86
|
-
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
86
|
+
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
87
87
|
downscale = max(original_width, original_height) > max(image_shape)
|
88
88
|
interpolation = downscale_interpolation if downscale else upscale_interpolation
|
89
89
|
|
90
90
|
image = image_processor(
|
91
91
|
image, image_shape, interpolation=interpolation)
|
92
|
-
|
92
|
+
|
93
93
|
data_queue.put({
|
94
94
|
"url": url,
|
95
95
|
"caption": caption,
|
@@ -98,40 +98,47 @@ def map_sample(
|
|
98
98
|
"original_width": original_width,
|
99
99
|
})
|
100
100
|
except Exception as e:
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
101
|
+
print(f"Error processing {url}", e)
|
102
|
+
# error_queue.put_nowait({
|
103
|
+
# "url": url,
|
104
|
+
# "caption": caption,
|
105
|
+
# "error": str(e)
|
106
|
+
# })
|
107
|
+
pass
|
106
108
|
|
107
109
|
|
108
110
|
def map_batch(
|
109
111
|
batch, num_threads=256, image_shape=(256, 256),
|
112
|
+
min_image_shape=(128, 128),
|
110
113
|
timeout=15, retries=3, image_processor=default_image_processor,
|
111
114
|
upscale_interpolation=cv2.INTER_CUBIC,
|
112
115
|
downscale_interpolation=cv2.INTER_AREA,
|
113
116
|
):
|
114
117
|
try:
|
115
|
-
map_sample_fn = partial(map_sample, image_shape=image_shape,
|
118
|
+
map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
|
116
119
|
timeout=timeout, retries=retries, image_processor=image_processor,
|
117
120
|
upscale_interpolation=upscale_interpolation,
|
118
121
|
downscale_interpolation=downscale_interpolation)
|
119
122
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
120
123
|
executor.map(map_sample_fn, batch["url"], batch['caption'])
|
121
124
|
except Exception as e:
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
125
|
+
print(f"Error processing batch", e)
|
126
|
+
# error_queue.put_nowait({
|
127
|
+
# "batch": batch,
|
128
|
+
# "error": str(e)
|
129
|
+
# })
|
130
|
+
pass
|
126
131
|
|
127
132
|
|
128
133
|
def parallel_image_loader(
|
129
134
|
dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
|
135
|
+
min_image_shape=(128, 128),
|
130
136
|
num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
|
131
137
|
upscale_interpolation=cv2.INTER_CUBIC,
|
132
138
|
downscale_interpolation=cv2.INTER_AREA,
|
133
139
|
):
|
134
|
-
map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
|
140
|
+
map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
|
141
|
+
min_image_shape=min_image_shape,
|
135
142
|
timeout=timeout, retries=retries, image_processor=image_processor,
|
136
143
|
upscale_interpolation=upscale_interpolation,
|
137
144
|
downscale_interpolation=downscale_interpolation)
|
@@ -149,13 +156,14 @@ def parallel_image_loader(
|
|
149
156
|
print(f"Shuffling dataset with seed {iteration}")
|
150
157
|
dataset = dataset.shuffle(seed=iteration)
|
151
158
|
# Clear the error queue
|
152
|
-
while not error_queue.empty():
|
153
|
-
|
159
|
+
# while not error_queue.empty():
|
160
|
+
# error_queue.get_nowait()
|
154
161
|
|
155
162
|
|
156
163
|
class ImageBatchIterator:
|
157
164
|
def __init__(
|
158
165
|
self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
|
166
|
+
min_image_shape=(128, 128),
|
159
167
|
num_workers: int = 8, num_threads=256, timeout=15, retries=3,
|
160
168
|
image_processor=default_image_processor,
|
161
169
|
upscale_interpolation=cv2.INTER_CUBIC,
|
@@ -165,7 +173,9 @@ class ImageBatchIterator:
|
|
165
173
|
self.num_workers = num_workers
|
166
174
|
self.batch_size = batch_size
|
167
175
|
loader = partial(parallel_image_loader, num_threads=num_threads,
|
168
|
-
image_shape=image_shape,
|
176
|
+
image_shape=image_shape,
|
177
|
+
min_image_shape=min_image_shape,
|
178
|
+
num_workers=num_workers,
|
169
179
|
timeout=timeout, retries=retries, image_processor=image_processor,
|
170
180
|
upscale_interpolation=upscale_interpolation,
|
171
181
|
downscale_interpolation=downscale_interpolation)
|
@@ -215,6 +225,7 @@ class OnlineStreamingDataLoader():
|
|
215
225
|
dataset,
|
216
226
|
batch_size=64,
|
217
227
|
image_shape=(256, 256),
|
228
|
+
min_image_shape=(128, 128),
|
218
229
|
num_workers=16,
|
219
230
|
num_threads=512,
|
220
231
|
default_split="all",
|
@@ -253,8 +264,9 @@ class OnlineStreamingDataLoader():
|
|
253
264
|
num_shards=global_process_count, index=global_process_index)
|
254
265
|
print(f"Dataset length: {len(dataset)}")
|
255
266
|
self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
|
267
|
+
min_image_shape=min_image_shape,
|
256
268
|
num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
|
257
|
-
|
269
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
258
270
|
upscale_interpolation=upscale_interpolation,
|
259
271
|
downscale_interpolation=downscale_interpolation)
|
260
272
|
self.batch_size = batch_size
|
@@ -1,7 +1,7 @@
|
|
1
1
|
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
3
3
|
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
|
-
flaxdiff/data/online_loader.py,sha256=
|
4
|
+
flaxdiff/data/online_loader.py,sha256=w6gi1tAzWr4gtPt7onpStzOxp7Kdo_2q8Ro4Yi7OT4w,10549
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
7
|
flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
39
|
-
flaxdiff-0.1.
|
40
|
-
flaxdiff-0.1.
|
37
|
+
flaxdiff-0.1.21.dist-info/METADATA,sha256=k1s_EIWBL0y4oCxXxr3QIi7LQbR47_jyDFfjVbURSMY,22083
|
38
|
+
flaxdiff-0.1.21.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
39
|
+
flaxdiff-0.1.21.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.21.dist-info/RECORD,,
|
File without changes
|
File without changes
|