flaxdiff 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +81 -46
- {flaxdiff-0.1.16.dist-info → flaxdiff-0.1.17.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.16.dist-info → flaxdiff-0.1.17.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.16.dist-info → flaxdiff-0.1.17.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.16.dist-info → flaxdiff-0.1.17.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -13,7 +13,7 @@ from typing import Any, Dict, List, Tuple
 import numpy as np
 from functools import partial
 
-from datasets import load_dataset, concatenate_datasets, Dataset
+from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk
 from datasets.utils.file_utils import get_datasets_user_agent
 from concurrent.futures import ThreadPoolExecutor
 import io
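The newly imported load_from_disk is what the constructor changes further down rely on for bucket-hosted datasets. A rough sketch of the distinction (the dataset name and path are hypothetical, and reading a gs:// URI assumes gcsfs is installed so fsspec can resolve it):

    from datasets import load_dataset, load_from_disk

    # Hub datasets are fetched and cached by name.
    hub_ds = load_dataset("user/some-url-caption-dataset", split="all")

    # Datasets materialized with Dataset.save_to_disk (e.g. in a GCS
    # bucket) must be reopened with load_from_disk instead.
    disk_ds = load_from_disk("gs://some-bucket/some-dataset")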
@@ -27,6 +27,7 @@ USER_AGENT = get_datasets_user_agent()
 data_queue = Queue(16*2000)
 error_queue = Queue(16*2000)
 
+
 def fetch_single_image(image_url, timeout=None, retries=0):
     for _ in range(retries + 1):
         try:
@@ -42,19 +43,35 @@ def fetch_single_image(image_url, timeout=None, retries=0):
             image = None
     return image
 
+
+def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
+    image = A.longest_max_size(image, max(
+        image_shape), interpolation=interpolation)
+    image = A.pad(
+        image,
+        min_height=image_shape[0],
+        min_width=image_shape[1],
+        border_mode=cv2.BORDER_CONSTANT,
+        value=[255, 255, 255],
+    )
+    return image
+
+
 def map_sample(
-    url, caption,
+    url, caption,
     image_shape=(256, 256),
     timeout=15,
     retries=3,
     upscale_interpolation=cv2.INTER_LANCZOS4,
     downscale_interpolation=cv2.INTER_AREA,
+    image_processor=default_image_processor,
 ):
     try:
-        image = fetch_single_image(url, timeout=timeout, retries=retries)
+        # Assuming fetch_single_image is defined elsewhere
+        image = fetch_single_image(url, timeout=timeout, retries=retries)
         if image is None:
             return
-
+
         image = np.array(image)
         original_height, original_width = image.shape[:2]
         # check if the image is too small
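The new default_image_processor fits images to image_shape by resizing the longest side and white-padding the remainder, and map_sample now accepts any callable with the same (image, image_shape, interpolation) signature. A minimal sketch of a drop-in replacement, using plain OpenCV; the function name is illustrative, and it trades the letterboxing for a direct, aspect-distorting resize:

    import cv2

    def stretch_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
        # cv2.resize takes (width, height), while image_shape is (height, width).
        return cv2.resize(image, (image_shape[1], image_shape[0]),
                          interpolation=interpolation)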
@@ -69,14 +86,10 @@ def map_sample(
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         downscale = max(original_width, original_height) > max(image_shape)
         interpolation = downscale_interpolation if downscale else upscale_interpolation
-        image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
-        image = A.pad(
-            image,
-            min_height=image_shape[0],
-            min_width=image_shape[1],
-            border_mode=cv2.BORDER_CONSTANT,
-            value=[255, 255, 255],
-        )
+
+        image = image_processor(
+            image, image_shape, interpolation=interpolation)
+
         data_queue.put({
             "url": url,
             "caption": caption,
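The call site picks the interpolation before invoking the processor: cv2.INTER_AREA when shrinking (it averages source pixels, which avoids aliasing) and cv2.INTER_LANCZOS4 when enlarging. The same selection in isolation, with a toy array standing in for a fetched image:

    import numpy as np
    import cv2

    image = np.zeros((512, 768, 3), dtype=np.uint8)  # toy 768x512 input
    image_shape = (256, 256)

    original_height, original_width = image.shape[:2]
    downscale = max(original_width, original_height) > max(image_shape)
    interpolation = cv2.INTER_AREA if downscale else cv2.INTER_LANCZOS4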
@@ -85,65 +98,74 @@ def map_sample(
             "original_width": original_width,
         })
     except Exception as e:
-        print(f"Error in map_sample: {str(e)}")
         error_queue.put({
             "url": url,
             "caption": caption,
             "error": str(e)
         })
 
-def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3):
+
+def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3, image_processor=default_image_processor):
     try:
-        map_sample_fn = partial(map_sample, image_shape=image_shape, timeout=timeout, retries=retries)
+        map_sample_fn = partial(map_sample, image_shape=image_shape,
+                                timeout=timeout, retries=retries, image_processor=image_processor)
         with ThreadPoolExecutor(max_workers=num_threads) as executor:
             executor.map(map_sample_fn, batch["url"], batch['caption'])
     except Exception as e:
-        print(f"Error in map_batch: {str(e)}")
         error_queue.put({
             "batch": batch,
             "error": str(e)
         })
-
-
-def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
+
+
+def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
+                          num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
+    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
+                           timeout=timeout, retries=retries, image_processor=image_processor)
     shard_len = len(dataset) // num_workers
     print(f"Local Shard lengths: {shard_len}")
     with multiprocessing.Pool(num_workers) as pool:
         iteration = 0
         while True:
             # Repeat forever
-            map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
-
-            shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
+            shards = [dataset[i*shard_len:(i+1)*shard_len]
+                      for i in range(num_workers)]
             print(f"mapping {len(shards)} shards")
             pool.map(map_batch_fn, shards)
             iteration += 1
-
+            print(f"Shuffling dataset with seed {iteration}")
+            dataset = dataset.shuffle(seed=iteration)
+
+
 class ImageBatchIterator:
-    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):
+    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
+                 num_workers: int = 8, num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
         self.dataset = dataset
         self.num_workers = num_workers
         self.batch_size = batch_size
-        loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
+        loader = partial(parallel_image_loader, num_threads=num_threads,
+                         image_shape=image_shape, num_workers=num_workers,
+                         timeout=timeout, retries=retries, image_processor=image_processor)
         self.thread = threading.Thread(target=loader, args=(dataset,))
         self.thread.start()
-
+
     def __iter__(self):
         return self
-
+
     def __next__(self):
         def fetcher(_):
             return data_queue.get()
         with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
             batch = list(executor.map(fetcher, range(self.batch_size)))
         return batch
-
+
     def __del__(self):
         self.thread.join()
-
+
     def __len__(self):
         return len(self.dataset) // self.batch_size
 
+
 def default_collate(batch):
     urls = [sample["url"] for sample in batch]
     captions = [sample["caption"] for sample in batch]
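Everything above forms a producer/consumer pipeline around the module-level data_queue: pool workers run map_sample and put finished samples in, while ImageBatchIterator.__next__ issues one blocking get per batch slot. A self-contained sketch of that hand-off (queue size and counts are arbitrary):

    from queue import Queue
    from concurrent.futures import ThreadPoolExecutor

    q = Queue(16)  # stands in for the module-level data_queue

    def producer(i):
        q.put({"index": i})  # blocks once the queue is full, throttling workers

    def next_batch(batch_size=4):
        # mirrors ImageBatchIterator.__next__: one blocking get per slot
        with ThreadPoolExecutor(max_workers=batch_size) as executor:
            return list(executor.map(lambda _: q.get(), range(batch_size)))

    with ThreadPoolExecutor(max_workers=8) as pool:
        list(pool.map(producer, range(8)))
    print(next_batch())

Because __next__ only counts gets, a batch contains whichever samples finished first; ordering is not preserved.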
@@ -153,7 +175,8 @@ def default_collate(batch):
         "caption": captions,
         "image": images,
     }
-
+
+
 def dataMapper(map: Dict[str, Any]):
     def _map(sample) -> Dict[str, Any]:
         return {
@@ -162,16 +185,17 @@ def dataMapper(map: Dict[str, Any]):
     }
     return _map
 
+
 class OnlineStreamingDataLoader():
     def __init__(
-        self,
-        dataset,
-        batch_size=64,
+        self,
+        dataset,
+        batch_size=64,
         image_shape=(256, 256),
-        num_workers=16,
+        num_workers=16,
         num_threads=512,
         default_split="all",
-        pre_map_maker=dataMapper,
+        pre_map_maker=dataMapper,
         pre_map_def={
             "url": "URL",
             "caption": "TEXT",
@@ -180,40 +204,51 @@ class OnlineStreamingDataLoader():
         global_process_index=0,
         prefetch=1000,
         collate_fn=default_collate,
+        timeout=15,
+        retries=3,
+        image_processor=default_image_processor,
     ):
         if isinstance(dataset, str):
             dataset_path = dataset
             print("Loading dataset from path")
-            dataset = load_dataset(dataset_path, split=default_split)
+            if "gs://" in dataset:
+                dataset = load_from_disk(dataset_path)
+            else:
+                dataset = load_dataset(dataset_path, split=default_split)
         elif isinstance(dataset, list):
             if isinstance(dataset[0], str):
                 print("Loading multiple datasets from paths")
-                dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
+                dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset(
+                    dataset_path, split=default_split) for dataset_path in dataset]
             print("Concatenating multiple datasets")
             dataset = concatenate_datasets(dataset)
-
-        self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
+        dataset = dataset.shuffle(seed=0)
+        # dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
+        self.dataset = dataset.shard(
+            num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
-        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
+        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
+                                           num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
+                                           timeout=timeout, retries=retries, image_processor=image_processor)
         self.collate_fn = collate_fn
         self.batch_size = batch_size
-
+
         # Launch a thread to load batches in the background
         self.batch_queue = queue.Queue(prefetch)
-
+
         def batch_loader():
             for batch in self.iterator:
                 self.batch_queue.put(batch)
-
+
         self.loader_thread = threading.Thread(target=batch_loader)
         self.loader_thread.start()
-
+
     def __iter__(self):
         return self
-
+
     def __next__(self):
         return self.collate_fn(self.batch_queue.get())
         # return self.collate_fn(next(self.iterator))
-
+
     def __len__(self):
         return len(self.dataset)
{flaxdiff-0.1.16.dist-info → flaxdiff-0.1.17.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
-flaxdiff/data/online_loader.py,sha256=
+flaxdiff/data/online_loader.py,sha256=BM4Le-4BUo8MJpRzGIA2nMHKm4-WynQ2BOdiQz0JCDs,8791
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
+flaxdiff-0.1.17.dist-info/METADATA,sha256=2Nr_T2yg3XHFt2jBuUXo8FxLYM8si-DBLdW_PBKxzc4,22083
+flaxdiff-0.1.17.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.17.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.17.dist-info/RECORD,,
{flaxdiff-0.1.16.dist-info → flaxdiff-0.1.17.dist-info}/WHEEL
File without changes
{flaxdiff-0.1.16.dist-info → flaxdiff-0.1.17.dist-info}/top_level.txt
File without changes