flaxdiff 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +1 -0
- flaxdiff/data/online_loader.py +18 -15
- {flaxdiff-0.1.13.dist-info → flaxdiff-0.1.15.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.13.dist-info → flaxdiff-0.1.15.dist-info}/RECORD +6 -6
- {flaxdiff-0.1.13.dist-info → flaxdiff-0.1.15.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.13.dist-info → flaxdiff-0.1.15.dist-info}/top_level.txt +0 -0
flaxdiff/data/__init__.py
CHANGED
@@ -0,0 +1 @@
+from .online_loader import OnlineStreamingDataLoader
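The new module body is a single re-export, so the loader becomes importable from the subpackage root:

# After this change, both import paths resolve to the same class;
# the shorter one is what this release adds.
from flaxdiff.data import OnlineStreamingDataLoader
# equivalent to: from flaxdiff.data.online_loader import OnlineStreamingDataLoader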
flaxdiff/data/online_loader.py
CHANGED
@@ -79,7 +79,9 @@ def map_sample(
         data_queue.put({
             "url": url,
             "caption": caption,
-            "image": image
+            "image": image,
+            "original_height": original_height,
+            "original_width": original_width,
         })
     except Exception as e:
         error_queue.put({
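The two new keys imply that map_sample now records the source resolution alongside the resized pixels. A minimal sketch of that pattern, assuming an OpenCV-based decode step (decode_and_resize is a hypothetical helper for illustration, not flaxdiff API):

import cv2
import numpy as np

def decode_and_resize(raw_bytes, image_shape=(256, 256)):
    # Hypothetical helper: decode the downloaded bytes, record the source
    # resolution, then resize to the target shape.
    image = cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR)
    original_height, original_width = image.shape[:2]
    image = cv2.resize(image, image_shape)  # cv2 expects (width, height)
    return image, original_height, original_width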
@@ -88,12 +90,12 @@ def map_sample(
             "error": str(e)
         })
 
-def map_batch(batch, num_threads=256, timeout=None, retries=0):
+def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):
     with ThreadPoolExecutor(max_workers=num_threads) as executor:
-        executor.map(map_sample, batch["url"], batch['caption'])
+        executor.map(map_sample, batch["url"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)
 
-def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=256):
-    map_batch_fn = partial(map_batch, num_threads=num_threads)
+def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
+    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
     shard_len = len(dataset) // num_workers
     print(f"Local Shard lengths: {shard_len}")
     with multiprocessing.Pool(num_workers) as pool:
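One caveat worth noting: the released line relies on executor.map forwarding image_shape and retries to map_sample, but the stdlib ThreadPoolExecutor.map accepts only timeout and chunksize as keywords and forwards nothing to the mapped callable. The usual idiom is to pre-bind fixed arguments with functools.partial; a sketch of that idiom under the same signature (map_batch_sketch is illustrative and assumes map_sample accepts these keywords):

from concurrent.futures import ThreadPoolExecutor
from functools import partial

def map_batch_sketch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):
    # Bind the fixed per-sample arguments once, then fan out over urls/captions.
    fn = partial(map_sample, image_shape=image_shape, timeout=timeout, retries=retries)
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        executor.map(fn, batch["url"], batch["caption"])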
@@ -106,12 +108,12 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=256):
         iteration += 1
 
 class ImageBatchIterator:
-    def __init__(self, dataset: Dataset, batch_size: int = 64, num_workers: int = 8, num_threads=256):
+    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):
         self.dataset = dataset
         self.num_workers = num_workers
         self.batch_size = batch_size
-        loader = partial(parallel_image_loader, num_threads=num_threads)
-        self.thread = threading.Thread(target=loader, args=(dataset, num_workers))
+        loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
+        self.thread = threading.Thread(target=loader, args=(dataset,))
         self.thread.start()
 
     def __iter__(self):
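The constructor now pre-binds all loader configuration with partial and hands the dataset over as the thread's only positional argument. The wiring in isolation looks roughly like this (daemon=True is an added assumption so the background thread does not block interpreter exit; dataset stands in for any loaded Dataset):

import threading
from functools import partial

# Configuration is frozen into the callable; the thread only receives the data.
loader = partial(parallel_image_loader, num_threads=256,
                 image_shape=(256, 256), num_workers=8)
thread = threading.Thread(target=loader, args=(dataset,), daemon=True)
thread.start()  # batches accumulate while the training loop consumes them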
@@ -129,7 +131,7 @@ class ImageBatchIterator:
 
     def __len__(self):
         return len(self.dataset) // self.batch_size
-
+
 def default_collate(batch):
     urls = [sample["url"] for sample in batch]
     captions = [sample["caption"] for sample in batch]
@@ -153,6 +155,7 @@ class OnlineStreamingDataLoader():
         self,
         dataset,
         batch_size=64,
+        image_shape=(256, 256),
         num_workers=16,
         num_threads=512,
         default_split="all",
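With the new keyword in place, a caller would construct the loader roughly as follows (the dataset path is a placeholder, not a real dataset):

loader = OnlineStreamingDataLoader(
    "user/some-image-text-dataset",  # placeholder HF dataset path
    batch_size=64,
    image_shape=(256, 256),  # new in 0.1.15: target resolution for resizing
    num_workers=16,
    num_threads=512,
)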
@@ -174,14 +177,14 @@ class OnlineStreamingDataLoader():
         if isinstance(dataset[0], str):
             print("Loading multiple datasets from paths")
             dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
-
-
-
-        dataset = dataset.map(pre_map_maker(pre_map_def))
+            print("Concatenating multiple datasets")
+            dataset = concatenate_datasets(dataset)
+        dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
         self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
-        self.iterator = ImageBatchIterator(self.dataset, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
+        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
         self.collate_fn = collate_fn
+        self.batch_size = batch_size
 
         # Launch a thread to load batches in the background
         self.batch_queue = queue.Queue(prefetch)
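concatenate_datasets comes from the Hugging Face datasets package (its import is presumably added elsewhere in the module), and batched=True with a very large batch_size applies the pre-map function over the table in as few calls as possible. The multi-path branch in isolation, with placeholder dataset paths:

from datasets import load_dataset, concatenate_datasets

paths = ["user/dataset-a", "user/dataset-b"]  # placeholder paths
parts = [load_dataset(p, split="all") for p in paths]
dataset = concatenate_datasets(parts)
# A huge batch_size means the mapped function sees (nearly) the whole table at once.
dataset = dataset.map(lambda batch: batch, batched=True, batch_size=10_000_000)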
@@ -201,5 +204,5 @@ class OnlineStreamingDataLoader():
         # return self.collate_fn(next(self.iterator))
 
     def __len__(self):
-        return len(self.dataset)
+        return len(self.dataset)
 
{flaxdiff-0.1.13.dist-info → flaxdiff-0.1.15.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
-flaxdiff/data/__init__.py,sha256=
-flaxdiff/data/online_loader.py,sha256=
+flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
+flaxdiff/data/online_loader.py,sha256=Xd4xAerzOU65tvb0cuM8S3Hnm8xv1dJ62xbxAxZkeBw,7307
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.13.dist-info/METADATA,sha256=
-flaxdiff-0.1.13.dist-info/WHEEL,sha256=
-flaxdiff-0.1.13.dist-info/top_level.txt,sha256=
-flaxdiff-0.1.13.dist-info/RECORD,,
+flaxdiff-0.1.15.dist-info/METADATA,sha256=45ifUa-j2rPqP3s734HRy2rIM1StRvSZ4lIJI0R3sTw,22083
+flaxdiff-0.1.15.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.15.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.15.dist-info/RECORD,,
{flaxdiff-0.1.13.dist-info → flaxdiff-0.1.15.dist-info}/WHEEL
File without changes
{flaxdiff-0.1.13.dist-info → flaxdiff-0.1.15.dist-info}/top_level.txt
File without changes