flaxdiff 0.1.18__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +35 -11
- {flaxdiff-0.1.18.dist-info → flaxdiff-0.1.19.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.18.dist-info → flaxdiff-0.1.19.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.18.dist-info → flaxdiff-0.1.19.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.18.dist-info → flaxdiff-0.1.19.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -44,7 +44,7 @@ def fetch_single_image(image_url, timeout=None, retries=0):
|
|
44
44
|
return image
|
45
45
|
|
46
46
|
|
47
|
-
def default_image_processor(image, image_shape, interpolation=cv2.
|
47
|
+
def default_image_processor(image, image_shape, interpolation=cv2.INTER_CUBIC):
|
48
48
|
image = A.longest_max_size(image, max(
|
49
49
|
image_shape), interpolation=interpolation)
|
50
50
|
image = A.pad(
|
@@ -62,7 +62,7 @@ def map_sample(
|
|
62
62
|
image_shape=(256, 256),
|
63
63
|
timeout=15,
|
64
64
|
retries=3,
|
65
|
-
upscale_interpolation=cv2.
|
65
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
66
66
|
downscale_interpolation=cv2.INTER_AREA,
|
67
67
|
image_processor=default_image_processor,
|
68
68
|
):
|
@@ -105,10 +105,17 @@ def map_sample(
|
|
105
105
|
})
|
106
106
|
|
107
107
|
|
108
|
-
def map_batch(
|
108
|
+
def map_batch(
|
109
|
+
batch, num_threads=256, image_shape=(256, 256),
|
110
|
+
timeout=15, retries=3, image_processor=default_image_processor,
|
111
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
112
|
+
downscale_interpolation=cv2.INTER_AREA,
|
113
|
+
):
|
109
114
|
try:
|
110
115
|
map_sample_fn = partial(map_sample, image_shape=image_shape,
|
111
|
-
timeout=timeout, retries=retries, image_processor=image_processor
|
116
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
117
|
+
upscale_interpolation=upscale_interpolation,
|
118
|
+
downscale_interpolation=downscale_interpolation)
|
112
119
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
113
120
|
executor.map(map_sample_fn, batch["url"], batch['caption'])
|
114
121
|
except Exception as e:
|
@@ -118,10 +125,16 @@ def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retrie
|
|
118
125
|
})
|
119
126
|
|
120
127
|
|
121
|
-
def parallel_image_loader(
|
122
|
-
|
128
|
+
def parallel_image_loader(
|
129
|
+
dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
|
130
|
+
num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
|
131
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
132
|
+
downscale_interpolation=cv2.INTER_AREA,
|
133
|
+
):
|
123
134
|
map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
|
124
|
-
timeout=timeout, retries=retries, image_processor=image_processor
|
135
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
136
|
+
upscale_interpolation=upscale_interpolation,
|
137
|
+
downscale_interpolation=downscale_interpolation)
|
125
138
|
shard_len = len(dataset) // num_workers
|
126
139
|
print(f"Local Shard lengths: {shard_len}")
|
127
140
|
with multiprocessing.Pool(num_workers) as pool:
|
@@ -141,14 +154,21 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(2
|
|
141
154
|
|
142
155
|
|
143
156
|
class ImageBatchIterator:
|
144
|
-
def __init__(
|
145
|
-
|
157
|
+
def __init__(
|
158
|
+
self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
|
159
|
+
num_workers: int = 8, num_threads=256, timeout=15, retries=3,
|
160
|
+
image_processor=default_image_processor,
|
161
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
162
|
+
downscale_interpolation=cv2.INTER_AREA,
|
163
|
+
):
|
146
164
|
self.dataset = dataset
|
147
165
|
self.num_workers = num_workers
|
148
166
|
self.batch_size = batch_size
|
149
167
|
loader = partial(parallel_image_loader, num_threads=num_threads,
|
150
168
|
image_shape=image_shape, num_workers=num_workers,
|
151
|
-
timeout=timeout, retries=retries, image_processor=image_processor
|
169
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
170
|
+
upscale_interpolation=upscale_interpolation,
|
171
|
+
downscale_interpolation=downscale_interpolation)
|
152
172
|
self.thread = threading.Thread(target=loader, args=(dataset,))
|
153
173
|
self.thread.start()
|
154
174
|
|
@@ -210,6 +230,8 @@ class OnlineStreamingDataLoader():
|
|
210
230
|
timeout=15,
|
211
231
|
retries=3,
|
212
232
|
image_processor=default_image_processor,
|
233
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
234
|
+
downscale_interpolation=cv2.INTER_AREA,
|
213
235
|
):
|
214
236
|
if isinstance(dataset, str):
|
215
237
|
dataset_path = dataset
|
@@ -232,7 +254,9 @@ class OnlineStreamingDataLoader():
|
|
232
254
|
print(f"Dataset length: {len(dataset)}")
|
233
255
|
self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
|
234
256
|
num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
|
235
|
-
timeout=timeout, retries=retries, image_processor=image_processor
|
257
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
258
|
+
upscale_interpolation=upscale_interpolation,
|
259
|
+
downscale_interpolation=downscale_interpolation)
|
236
260
|
self.batch_size = batch_size
|
237
261
|
|
238
262
|
# Launch a thread to load batches in the background
|
@@ -1,7 +1,7 @@
|
|
1
1
|
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
3
3
|
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
|
-
flaxdiff/data/online_loader.py,sha256=
|
4
|
+
flaxdiff/data/online_loader.py,sha256=WK4apO8Bx-RTU_z5imB53Lzq12vqGnXA9DhLq8nb0us,9991
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
7
|
flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
39
|
-
flaxdiff-0.1.
|
40
|
-
flaxdiff-0.1.
|
37
|
+
flaxdiff-0.1.19.dist-info/METADATA,sha256=NH-f1SK5obamoVRk8ZPQxvtQcz_R3mui3ToZe0Qx8Vg,22083
|
38
|
+
flaxdiff-0.1.19.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
39
|
+
flaxdiff-0.1.19.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.19.dist-info/RECORD,,
|
File without changes
|
File without changes
|