flaxdiff 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl
This diff shows the contents of two publicly available versions of this package, as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they were published in that registry.
- flaxdiff/data/online_loader.py +26 -15
- {flaxdiff-0.1.14.dist-info → flaxdiff-0.1.16.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.14.dist-info → flaxdiff-0.1.16.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.14.dist-info → flaxdiff-0.1.16.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.14.dist-info → flaxdiff-0.1.16.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -27,7 +27,6 @@ USER_AGENT = get_datasets_user_agent()
|
|
27
27
|
data_queue = Queue(16*2000)
|
28
28
|
error_queue = Queue(16*2000)
|
29
29
|
|
30
|
-
|
31
30
|
def fetch_single_image(image_url, timeout=None, retries=0):
|
32
31
|
for _ in range(retries + 1):
|
33
32
|
try:
|
@@ -46,11 +45,13 @@ def fetch_single_image(image_url, timeout=None, retries=0):
|
|
46
45
|
def map_sample(
|
47
46
|
url, caption,
|
48
47
|
image_shape=(256, 256),
|
48
|
+
timeout=15,
|
49
|
+
retries=3,
|
49
50
|
upscale_interpolation=cv2.INTER_LANCZOS4,
|
50
51
|
downscale_interpolation=cv2.INTER_AREA,
|
51
52
|
):
|
52
53
|
try:
|
53
|
-
image = fetch_single_image(url, timeout=
|
54
|
+
image = fetch_single_image(url, timeout=timeout, retries=retries) # Assuming fetch_single_image is defined elsewhere
|
54
55
|
if image is None:
|
55
56
|
return
|
56
57
|
|
@@ -84,15 +85,24 @@ def map_sample(
|
|
84
85
|
"original_width": original_width,
|
85
86
|
})
|
86
87
|
except Exception as e:
|
88
|
+
print(f"Error in map_sample: {str(e)}")
|
87
89
|
error_queue.put({
|
88
90
|
"url": url,
|
89
91
|
"caption": caption,
|
90
92
|
"error": str(e)
|
91
93
|
})
|
92
|
-
|
93
|
-
def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=
|
94
|
-
|
95
|
-
|
94
|
+
|
95
|
+
def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3):
|
96
|
+
try:
|
97
|
+
map_sample_fn = partial(map_sample, image_shape=image_shape, timeout=timeout, retries=retries)
|
98
|
+
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
99
|
+
executor.map(map_sample_fn, batch["url"], batch['caption'])
|
100
|
+
except Exception as e:
|
101
|
+
print(f"Error in map_batch: {str(e)}")
|
102
|
+
error_queue.put({
|
103
|
+
"batch": batch,
|
104
|
+
"error": str(e)
|
105
|
+
})
|
96
106
|
|
97
107
|
def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
|
98
108
|
map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
|
@@ -102,8 +112,10 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(2
|
|
102
112
|
iteration = 0
|
103
113
|
while True:
|
104
114
|
# Repeat forever
|
105
|
-
dataset
|
115
|
+
print(f"Shuffling dataset with seed {iteration}")
|
116
|
+
# dataset = dataset.shuffle(seed=iteration)
|
106
117
|
shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
|
118
|
+
print(f"mapping {len(shards)} shards")
|
107
119
|
pool.map(map_batch_fn, shards)
|
108
120
|
iteration += 1
|
109
121
|
|
@@ -113,7 +125,7 @@ class ImageBatchIterator:
|
|
113
125
|
self.num_workers = num_workers
|
114
126
|
self.batch_size = batch_size
|
115
127
|
loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
|
116
|
-
self.thread = threading.Thread(target=loader, args=(dataset))
|
128
|
+
self.thread = threading.Thread(target=loader, args=(dataset,))
|
117
129
|
self.thread.start()
|
118
130
|
|
119
131
|
def __iter__(self):
|
@@ -131,7 +143,7 @@ class ImageBatchIterator:
|
|
131
143
|
|
132
144
|
def __len__(self):
|
133
145
|
return len(self.dataset) // self.batch_size
|
134
|
-
|
146
|
+
|
135
147
|
def default_collate(batch):
|
136
148
|
urls = [sample["url"] for sample in batch]
|
137
149
|
captions = [sample["caption"] for sample in batch]
|
@@ -177,14 +189,14 @@ class OnlineStreamingDataLoader():
|
|
177
189
|
if isinstance(dataset[0], str):
|
178
190
|
print("Loading multiple datasets from paths")
|
179
191
|
dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
dataset = dataset.map(pre_map_maker(pre_map_def))
|
192
|
+
print("Concatenating multiple datasets")
|
193
|
+
dataset = concatenate_datasets(dataset)
|
194
|
+
dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
|
184
195
|
self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
|
185
196
|
print(f"Dataset length: {len(dataset)}")
|
186
197
|
self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
|
187
198
|
self.collate_fn = collate_fn
|
199
|
+
self.batch_size = batch_size
|
188
200
|
|
189
201
|
# Launch a thread to load batches in the background
|
190
202
|
self.batch_queue = queue.Queue(prefetch)
|
@@ -204,5 +216,4 @@ class OnlineStreamingDataLoader():
|
|
204
216
|
# return self.collate_fn(next(self.iterator))
|
205
217
|
|
206
218
|
def __len__(self):
|
207
|
-
return len(self.dataset)
|
208
|
-
|
219
|
+
return len(self.dataset)
|
@@ -1,7 +1,7 @@
|
|
1
1
|
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
3
3
|
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
|
-
flaxdiff/data/online_loader.py,sha256=
|
4
|
+
flaxdiff/data/online_loader.py,sha256=nrtZU4srZHsg3iN0sG91y_6nY7QtYXRPLk5rGn_BTIU,7728
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
7
|
flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
39
|
-
flaxdiff-0.1.
|
40
|
-
flaxdiff-0.1.
|
37
|
+
flaxdiff-0.1.16.dist-info/METADATA,sha256=BM2RLOiCDqRSWO_owxvWmL_PS3aFtNokPbf-qAuyK4o,22083
|
38
|
+
flaxdiff-0.1.16.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
39
|
+
flaxdiff-0.1.16.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.16.dist-info/RECORD,,
|
File without changes
|
File without changes
|