flaxdiff 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +92 -52
- {flaxdiff-0.1.16.dist-info → flaxdiff-0.1.18.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.16.dist-info → flaxdiff-0.1.18.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.16.dist-info → flaxdiff-0.1.18.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.16.dist-info → flaxdiff-0.1.18.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -13,7 +13,7 @@ from typing import Any, Dict, List, Tuple
 import numpy as np
 from functools import partial
 
-from datasets import load_dataset, concatenate_datasets, Dataset
+from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk
 from datasets.utils.file_utils import get_datasets_user_agent
 from concurrent.futures import ThreadPoolExecutor
 import io
@@ -25,7 +25,8 @@ import cv2
 USER_AGENT = get_datasets_user_agent()
 
 data_queue = Queue(16*2000)
-error_queue = Queue(
+error_queue = Queue()
+
 
 def fetch_single_image(image_url, timeout=None, retries=0):
     for _ in range(retries + 1):
@@ -42,19 +43,35 @@ def fetch_single_image(image_url, timeout=None, retries=0):
             image = None
     return image
 
+
+def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
+    image = A.longest_max_size(image, max(
+        image_shape), interpolation=interpolation)
+    image = A.pad(
+        image,
+        min_height=image_shape[0],
+        min_width=image_shape[1],
+        border_mode=cv2.BORDER_CONSTANT,
+        value=[255, 255, 255],
+    )
+    return image
+
+
 def map_sample(
-    url, caption,
+    url, caption,
     image_shape=(256, 256),
     timeout=15,
     retries=3,
     upscale_interpolation=cv2.INTER_LANCZOS4,
     downscale_interpolation=cv2.INTER_AREA,
+    image_processor=default_image_processor,
 ):
     try:
-
+        # Assuming fetch_single_image is defined elsewhere
+        image = fetch_single_image(url, timeout=timeout, retries=retries)
         if image is None:
             return
-
+
         image = np.array(image)
         original_height, original_width = image.shape[:2]
         # check if the image is too small
@@ -69,14 +86,10 @@ def map_sample(
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         downscale = max(original_width, original_height) > max(image_shape)
         interpolation = downscale_interpolation if downscale else upscale_interpolation
-
-        image =
-            image,
-
-            min_width=image_shape[1],
-            border_mode=cv2.BORDER_CONSTANT,
-            value=[255, 255, 255],
-        )
+
+        image = image_processor(
+            image, image_shape, interpolation=interpolation)
+
         data_queue.put({
             "url": url,
             "caption": caption,
@@ -85,65 +98,77 @@ def map_sample(
             "original_width": original_width,
         })
     except Exception as e:
-
-        error_queue.put({
+        error_queue.put_nowait({
            "url": url,
            "caption": caption,
            "error": str(e)
        })
 
-
+
+def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3, image_processor=default_image_processor):
     try:
-        map_sample_fn = partial(map_sample, image_shape=image_shape,
+        map_sample_fn = partial(map_sample, image_shape=image_shape,
+                                timeout=timeout, retries=retries, image_processor=image_processor)
         with ThreadPoolExecutor(max_workers=num_threads) as executor:
             executor.map(map_sample_fn, batch["url"], batch['caption'])
     except Exception as e:
-        print(f"Error in map_batch: {str(e)}")
         error_queue.put({
             "batch": batch,
             "error": str(e)
         })
-
-
-
+
+
+def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
+                          num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
+    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
+                           timeout=timeout, retries=retries, image_processor=image_processor)
     shard_len = len(dataset) // num_workers
     print(f"Local Shard lengths: {shard_len}")
     with multiprocessing.Pool(num_workers) as pool:
         iteration = 0
         while True:
             # Repeat forever
-
-
-            shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
+            shards = [dataset[i*shard_len:(i+1)*shard_len]
+                      for i in range(num_workers)]
             print(f"mapping {len(shards)} shards")
             pool.map(map_batch_fn, shards)
             iteration += 1
-
+            print(f"Shuffling dataset with seed {iteration}")
+            dataset = dataset.shuffle(seed=iteration)
+            # Clear the error queue
+            while not error_queue.empty():
+                error_queue.get_nowait()
+
+
 class ImageBatchIterator:
-    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
+    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
+                 num_workers: int = 8, num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
         self.dataset = dataset
         self.num_workers = num_workers
         self.batch_size = batch_size
-        loader = partial(parallel_image_loader, num_threads=num_threads,
+        loader = partial(parallel_image_loader, num_threads=num_threads,
+                         image_shape=image_shape, num_workers=num_workers,
+                         timeout=timeout, retries=retries, image_processor=image_processor)
         self.thread = threading.Thread(target=loader, args=(dataset,))
         self.thread.start()
-
+
     def __iter__(self):
         return self
-
+
     def __next__(self):
         def fetcher(_):
             return data_queue.get()
         with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
             batch = list(executor.map(fetcher, range(self.batch_size)))
         return batch
-
+
     def __del__(self):
         self.thread.join()
-
+
     def __len__(self):
         return len(self.dataset) // self.batch_size
 
+
 def default_collate(batch):
     urls = [sample["url"] for sample in batch]
     captions = [sample["caption"] for sample in batch]
@@ -153,7 +178,8 @@ def default_collate(batch):
         "caption": captions,
         "image": images,
     }
-
+
+
 def dataMapper(map: Dict[str, Any]):
     def _map(sample) -> Dict[str, Any]:
         return {
@@ -162,16 +188,17 @@ def dataMapper(map: Dict[str, Any]):
         }
     return _map
 
+
 class OnlineStreamingDataLoader():
     def __init__(
-        self,
-        dataset,
-        batch_size=64,
+        self,
+        dataset,
+        batch_size=64,
         image_shape=(256, 256),
-        num_workers=16,
+        num_workers=16,
         num_threads=512,
         default_split="all",
-        pre_map_maker=dataMapper,
+        pre_map_maker=dataMapper,
         pre_map_def={
             "url": "URL",
             "caption": "TEXT",
@@ -180,40 +207,53 @@ class OnlineStreamingDataLoader():
         global_process_index=0,
         prefetch=1000,
         collate_fn=default_collate,
+        timeout=15,
+        retries=3,
+        image_processor=default_image_processor,
     ):
         if isinstance(dataset, str):
             dataset_path = dataset
             print("Loading dataset from path")
-
+            if "gs://" in dataset:
+                dataset = load_from_disk(dataset_path)
+            else:
+                dataset = load_dataset(dataset_path, split=default_split)
         elif isinstance(dataset, list):
             if isinstance(dataset[0], str):
                 print("Loading multiple datasets from paths")
-                dataset = [
+                dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset(
+                    dataset_path, split=default_split) for dataset_path in dataset]
             print("Concatenating multiple datasets")
             dataset = concatenate_datasets(dataset)
-
-
+            dataset = dataset.shuffle(seed=0)
+        # dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
+        self.dataset = dataset.shard(
+            num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
-        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
-
+        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
+                                           num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
+                                           timeout=timeout, retries=retries, image_processor=image_processor)
         self.batch_size = batch_size
-
+
         # Launch a thread to load batches in the background
         self.batch_queue = queue.Queue(prefetch)
-
+
         def batch_loader():
             for batch in self.iterator:
-
-
+                try:
+                    self.batch_queue.put(collate_fn(batch))
+                except Exception as e:
+                    print("Error processing batch", e)
+
         self.loader_thread = threading.Thread(target=batch_loader)
         self.loader_thread.start()
-
+
     def __iter__(self):
         return self
-
+
     def __next__(self):
-        return self.
+        return self.batch_queue.get()
         # return self.collate_fn(next(self.iterator))
-
+
     def __len__(self):
-        return len(self.dataset)
+        return len(self.dataset)
{flaxdiff-0.1.16.dist-info → flaxdiff-0.1.18.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
-flaxdiff/data/online_loader.py,sha256=
+flaxdiff/data/online_loader.py,sha256=qim6SRRGU1lRO0zQbDNjRYC7Qm6g7jtUfELEXotora0,8987
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
+flaxdiff-0.1.18.dist-info/METADATA,sha256=aUSr3lBb9P2mnrpmbcgQa41DT8YYM-DtVMU8NI3CZEE,22083
+flaxdiff-0.1.18.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.18.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.18.dist-info/RECORD,,
{flaxdiff-0.1.16.dist-info → flaxdiff-0.1.18.dist-info}/WHEEL
File without changes
{flaxdiff-0.1.16.dist-info → flaxdiff-0.1.18.dist-info}/top_level.txt
File without changes
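The other notable change in online_loader.py is that string dataset paths containing "gs://" are now read with load_from_disk instead of load_dataset. A small sketch of preparing such a dataset, assuming gcsfs is installed and using a hypothetical bucket and dataset name:

from datasets import load_dataset, load_from_disk

# Materialize a (hypothetical) source dataset to a GCS bucket once...
ds = load_dataset("user/some-image-text-dataset", split="train")
ds.save_to_disk("gs://example-bucket/some-image-text-dataset")

# ...which is roughly what OnlineStreamingDataLoader("gs://example-bucket/some-image-text-dataset", ...)
# now does internally when it sees a "gs://" path:
ds = load_from_disk("gs://example-bucket/some-image-text-dataset")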