flaxdiff 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +92 -46
- {flaxdiff-0.1.15.dist-info → flaxdiff-0.1.17.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.15.dist-info → flaxdiff-0.1.17.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.15.dist-info → flaxdiff-0.1.17.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.15.dist-info → flaxdiff-0.1.17.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -13,7 +13,7 @@ from typing import Any, Dict, List, Tuple
 import numpy as np
 from functools import partial
 
-from datasets import load_dataset, concatenate_datasets, Dataset
+from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk
 from datasets.utils.file_utils import get_datasets_user_agent
 from concurrent.futures import ThreadPoolExecutor
 import io
@@ -43,17 +43,35 @@ def fetch_single_image(image_url, timeout=None, retries=0):
             image = None
     return image
 
+
+def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
+    image = A.longest_max_size(image, max(
+        image_shape), interpolation=interpolation)
+    image = A.pad(
+        image,
+        min_height=image_shape[0],
+        min_width=image_shape[1],
+        border_mode=cv2.BORDER_CONSTANT,
+        value=[255, 255, 255],
+    )
+    return image
+
+
 def map_sample(
-    url, caption,
+    url, caption,
     image_shape=(256, 256),
+    timeout=15,
+    retries=3,
     upscale_interpolation=cv2.INTER_LANCZOS4,
     downscale_interpolation=cv2.INTER_AREA,
+    image_processor=default_image_processor,
 ):
     try:
-
+        # Assuming fetch_single_image is defined elsewhere
+        image = fetch_single_image(url, timeout=timeout, retries=retries)
         if image is None:
             return
-
+
         image = np.array(image)
         original_height, original_width = image.shape[:2]
         # check if the image is too small
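Note: the new image_processor hook accepts any callable with default_image_processor's signature, (image, image_shape, interpolation=...), and must return the processed array. A minimal sketch of a drop-in alternative that center-crops instead of white-padding (center_crop_processor is a hypothetical name, not part of flaxdiff):

    import math
    import cv2

    def center_crop_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
        # Hypothetical alternative to default_image_processor: scale so both
        # sides reach the target size, then center-crop instead of letterboxing.
        h, w = image.shape[:2]
        scale = max(image_shape[0] / h, image_shape[1] / w)
        image = cv2.resize(image, (math.ceil(w * scale), math.ceil(h * scale)),
                           interpolation=interpolation)
        h, w = image.shape[:2]
        top, left = (h - image_shape[0]) // 2, (w - image_shape[1]) // 2
        return image[top:top + image_shape[0], left:left + image_shape[1]]

Passed as map_sample(..., image_processor=center_crop_processor), it replaces only the resize-and-pad step, leaving the fetch and queue logic untouched.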
@@ -68,14 +86,10 @@ def map_sample(
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         downscale = max(original_width, original_height) > max(image_shape)
         interpolation = downscale_interpolation if downscale else upscale_interpolation
-        image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
-        image = A.pad(
-            image,
-            min_height=image_shape[0],
-            min_width=image_shape[1],
-            border_mode=cv2.BORDER_CONSTANT,
-            value=[255, 255, 255],
-        )
+
+        image = image_processor(
+            image, image_shape, interpolation=interpolation)
+
         data_queue.put({
             "url": url,
             "caption": caption,
@@ -89,49 +103,69 @@ def map_sample(
             "caption": caption,
             "error": str(e)
         })
-
-
-
-
-
-
-
+
+
+def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3, image_processor=default_image_processor):
+    try:
+        map_sample_fn = partial(map_sample, image_shape=image_shape,
+                                timeout=timeout, retries=retries, image_processor=image_processor)
+        with ThreadPoolExecutor(max_workers=num_threads) as executor:
+            executor.map(map_sample_fn, batch["url"], batch['caption'])
+    except Exception as e:
+        error_queue.put({
+            "batch": batch,
+            "error": str(e)
+        })
+
+
+def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
+                          num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
+    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
+                           timeout=timeout, retries=retries, image_processor=image_processor)
     shard_len = len(dataset) // num_workers
     print(f"Local Shard lengths: {shard_len}")
     with multiprocessing.Pool(num_workers) as pool:
         iteration = 0
         while True:
             # Repeat forever
-
-
+            shards = [dataset[i*shard_len:(i+1)*shard_len]
+                      for i in range(num_workers)]
+            print(f"mapping {len(shards)} shards")
             pool.map(map_batch_fn, shards)
             iteration += 1
-
+            print(f"Shuffling dataset with seed {iteration}")
+            dataset = dataset.shuffle(seed=iteration)
+
+
 class ImageBatchIterator:
-    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
+    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
+                 num_workers: int = 8, num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
         self.dataset = dataset
         self.num_workers = num_workers
         self.batch_size = batch_size
-        loader = partial(parallel_image_loader, num_threads=num_threads,
+        loader = partial(parallel_image_loader, num_threads=num_threads,
+                         image_shape=image_shape, num_workers=num_workers,
+                         timeout=timeout, retries=retries, image_processor=image_processor)
         self.thread = threading.Thread(target=loader, args=(dataset,))
         self.thread.start()
-
+
     def __iter__(self):
         return self
-
+
     def __next__(self):
         def fetcher(_):
             return data_queue.get()
         with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
             batch = list(executor.map(fetcher, range(self.batch_size)))
         return batch
-
+
     def __del__(self):
         self.thread.join()
-
+
     def __len__(self):
         return len(self.dataset) // self.batch_size
 
+
 def default_collate(batch):
     urls = [sample["url"] for sample in batch]
     captions = [sample["caption"] for sample in batch]
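The functions above rely on two module-level queues whose definitions sit outside the diff: map_sample feeds data_queue, map_batch reports batch-level failures to error_queue, and ImageBatchIterator.__next__ drains data_queue from a thread pool. A sketch of a setup consistent with that producer/consumer layout (the capacity value is an assumption, not from the source):

    import multiprocessing

    # Producers are Pool worker processes; the consumer is the iterator's
    # thread pool, so the queues must be multiprocessing-safe. Bounding the
    # data queue applies back-pressure when consumers fall behind.
    data_queue = multiprocessing.Queue(maxsize=32000)  # assumed capacity
    error_queue = multiprocessing.Queue()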
@@ -141,7 +175,8 @@ def default_collate(batch):
         "caption": captions,
         "image": images,
     }
-
+
+
 def dataMapper(map: Dict[str, Any]):
     def _map(sample) -> Dict[str, Any]:
         return {
@@ -150,16 +185,17 @@ def dataMapper(map: Dict[str, Any]):
         }
     return _map
 
+
 class OnlineStreamingDataLoader():
     def __init__(
-        self,
-        dataset,
-        batch_size=64,
+        self,
+        dataset,
+        batch_size=64,
         image_shape=(256, 256),
-        num_workers=16,
+        num_workers=16,
         num_threads=512,
         default_split="all",
-        pre_map_maker=dataMapper,
+        pre_map_maker=dataMapper,
         pre_map_def={
             "url": "URL",
             "caption": "TEXT",
@@ -168,41 +204,51 @@ class OnlineStreamingDataLoader():
         global_process_index=0,
         prefetch=1000,
         collate_fn=default_collate,
+        timeout=15,
+        retries=3,
+        image_processor=default_image_processor,
     ):
         if isinstance(dataset, str):
             dataset_path = dataset
             print("Loading dataset from path")
-
+            if "gs://" in dataset:
+                dataset = load_from_disk(dataset_path)
+            else:
+                dataset = load_dataset(dataset_path, split=default_split)
         elif isinstance(dataset, list):
             if isinstance(dataset[0], str):
                 print("Loading multiple datasets from paths")
-                dataset = [
+                dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset(
+                    dataset_path, split=default_split) for dataset_path in dataset]
             print("Concatenating multiple datasets")
             dataset = concatenate_datasets(dataset)
-
-
+            dataset = dataset.shuffle(seed=0)
+        # dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
+        self.dataset = dataset.shard(
+            num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
-        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
+        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
+                                           num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
+                                           timeout=timeout, retries=retries, image_processor=image_processor)
         self.collate_fn = collate_fn
         self.batch_size = batch_size
-
+
         # Launch a thread to load batches in the background
         self.batch_queue = queue.Queue(prefetch)
-
+
         def batch_loader():
             for batch in self.iterator:
                 self.batch_queue.put(batch)
-
+
         self.loader_thread = threading.Thread(target=batch_loader)
         self.loader_thread.start()
-
+
     def __iter__(self):
         return self
-
+
     def __next__(self):
         return self.collate_fn(self.batch_queue.get())
         # return self.collate_fn(next(self.iterator))
-
+
     def __len__(self):
         return len(self.dataset)
-
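Taken together, the 0.1.17 changes thread timeout, retries, and image_processor from OnlineStreamingDataLoader down through ImageBatchIterator, parallel_image_loader, and map_batch into map_sample, and route gs:// dataset paths through load_from_disk. A hedged usage sketch (the dataset name is illustrative, not from this diff):

    from flaxdiff.data.online_loader import (
        OnlineStreamingDataLoader,
        default_image_processor,
    )

    loader = OnlineStreamingDataLoader(
        "laion/laion-coco",      # illustrative; a "gs://..." path goes through load_from_disk
        batch_size=64,
        image_shape=(256, 256),
        timeout=15,              # new in 0.1.17: per-image fetch timeout, seconds
        retries=3,               # new in 0.1.17: retry count for failed fetches
        image_processor=default_image_processor,  # new in 0.1.17: pluggable hook
    )

    batch = next(loader)         # default_collate dict: "url", "caption", "image"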
{flaxdiff-0.1.15.dist-info → flaxdiff-0.1.17.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
-flaxdiff/data/online_loader.py,sha256=
+flaxdiff/data/online_loader.py,sha256=BM4Le-4BUo8MJpRzGIA2nMHKm4-WynQ2BOdiQz0JCDs,8791
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.15.dist-info/METADATA,sha256=
-flaxdiff-0.1.15.dist-info/WHEEL,sha256=
-flaxdiff-0.1.15.dist-info/top_level.txt,sha256=
-flaxdiff-0.1.15.dist-info/RECORD,,
+flaxdiff-0.1.17.dist-info/METADATA,sha256=2Nr_T2yg3XHFt2jBuUXo8FxLYM8si-DBLdW_PBKxzc4,22083
+flaxdiff-0.1.17.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.17.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.17.dist-info/RECORD,,
{flaxdiff-0.1.15.dist-info → flaxdiff-0.1.17.dist-info}/WHEEL
File without changes

{flaxdiff-0.1.15.dist-info → flaxdiff-0.1.17.dist-info}/top_level.txt
File without changes