flaxdiff 0.1.13__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +1 -0
- flaxdiff/data/online_loader.py +12 -9
- {flaxdiff-0.1.13.dist-info → flaxdiff-0.1.14.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.13.dist-info → flaxdiff-0.1.14.dist-info}/RECORD +6 -6
- {flaxdiff-0.1.13.dist-info → flaxdiff-0.1.14.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.13.dist-info → flaxdiff-0.1.14.dist-info}/top_level.txt +0 -0
flaxdiff/data/__init__.py
CHANGED
@@ -0,0 +1 @@
+from .online_loader import OnlineStreamingDataLoader
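Since the old `flaxdiff/data/__init__.py` was empty, this one-line change re-exports the loader at package level, so it can now be imported without reaching into the submodule:

```python
# Enabled by the re-export above; previously the class was only importable
# from the submodule: from flaxdiff.data.online_loader import OnlineStreamingDataLoader
from flaxdiff.data import OnlineStreamingDataLoader
```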
flaxdiff/data/online_loader.py
CHANGED
@@ -79,7 +79,9 @@ def map_sample(
         data_queue.put({
             "url": url,
             "caption": caption,
-            "image": image
+            "image": image,
+            "original_height": original_height,
+            "original_width": original_width,
         })
     except Exception as e:
         error_queue.put({
@@ -88,12 +90,12 @@ def map_sample(
             "error": str(e)
         })
 
-def map_batch(batch, num_threads=256, timeout=None, retries=0):
+def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):
     with ThreadPoolExecutor(max_workers=num_threads) as executor:
-        executor.map(map_sample, batch["url"], batch['caption'])
+        executor.map(map_sample, batch["url"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)
 
-def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=256):
-    map_batch_fn = partial(map_batch, num_threads=num_threads)
+def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
+    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
     shard_len = len(dataset) // num_workers
     print(f"Local Shard lengths: {shard_len}")
     with multiprocessing.Pool(num_workers) as pool:
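One note on the `map_batch` change above: `ThreadPoolExecutor.map` accepts only its own `timeout` and `chunksize` keywords, so passing `image_shape=...` and `retries=...` through `executor.map` raises a `TypeError` rather than forwarding them to `map_sample`. A minimal sketch of the usual pattern, binding extra arguments with `functools.partial` (the `map_sample` body here is a placeholder, not the package's implementation):

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def map_sample(url, caption, image_shape=(256, 256)):
    # Placeholder body: the real map_sample fetches `url` and resizes
    # the image to `image_shape`.
    return url, caption, image_shape

def map_batch(batch, num_threads=4, image_shape=(256, 256)):
    # Bind per-call keywords up front; ThreadPoolExecutor.map accepts only
    # its own `timeout`/`chunksize` keywords and does not forward extras.
    fn = partial(map_sample, image_shape=image_shape)
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        return list(executor.map(fn, batch["url"], batch["caption"]))

print(map_batch({"url": ["u1", "u2"], "caption": ["c1", "c2"]}))
```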
@@ -106,12 +108,12 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=256):
             iteration += 1
 
 class ImageBatchIterator:
-    def __init__(self, dataset: Dataset, batch_size: int = 64, num_workers: int = 8, num_threads=256):
+    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):
         self.dataset = dataset
         self.num_workers = num_workers
         self.batch_size = batch_size
-        loader = partial(parallel_image_loader, num_threads=num_threads)
-        self.thread = threading.Thread(target=loader, args=(dataset
+        loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
+        self.thread = threading.Thread(target=loader, args=(dataset))
         self.thread.start()
 
     def __iter__(self):
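Also worth flagging in the `ImageBatchIterator` change: `args=(dataset)` is not a one-element tuple (the parentheses are plain grouping), so the thread target receives `*dataset` unpacked rather than the dataset object itself; the conventional spelling adds a trailing comma. A minimal sketch (placeholder `loader`, not the package's function):

```python
import threading

def loader(dataset, num_workers=8):
    # Placeholder standing in for parallel_image_loader.
    print(f"loading {dataset!r} with {num_workers} workers")

dataset = ["sample-0", "sample-1"]

# `args` is unpacked into positional arguments, so it must be a sequence;
# `(dataset)` is just `dataset`, while `(dataset,)` is a one-element tuple.
thread = threading.Thread(target=loader, args=(dataset,))
thread.start()
thread.join()
```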
@@ -153,6 +155,7 @@ class OnlineStreamingDataLoader():
         self,
         dataset,
         batch_size=64,
+        image_shape=(256, 256),
         num_workers=16,
         num_threads=512,
         default_split="all",
@@ -180,7 +183,7 @@ class OnlineStreamingDataLoader():
         dataset = dataset.map(pre_map_maker(pre_map_def))
         self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
-        self.iterator = ImageBatchIterator(self.dataset, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
+        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
         self.collate_fn = collate_fn
 
         # Launch a thread to load batches in the background
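Taken together, these online_loader.py changes thread a single `image_shape` setting from `OnlineStreamingDataLoader` down through `ImageBatchIterator`, `parallel_image_loader`, `map_batch`, and `map_sample`, and record each sample's pre-resize dimensions as `original_height`/`original_width`. A hedged construction sketch, using only the parameters visible in this diff (the dataset id is a placeholder):

```python
from flaxdiff.data import OnlineStreamingDataLoader

# "user/image-caption-dataset" is a placeholder id, not from this diff.
dataloader = OnlineStreamingDataLoader(
    "user/image-caption-dataset",
    batch_size=64,
    image_shape=(128, 128),  # new in 0.1.14: resize target for fetched images
    num_workers=16,
    num_threads=512,
)
```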
{flaxdiff-0.1.13.dist-info → flaxdiff-0.1.14.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
-flaxdiff/data/__init__.py,sha256=
-flaxdiff/data/online_loader.py,sha256=
+flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
+flaxdiff/data/online_loader.py,sha256=IjCVdeq18lF71eLX8RplJLezjKASO1cSFVe2GSYkLQ8,7283
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
+flaxdiff-0.1.14.dist-info/METADATA,sha256=rcwF7cCFfgPLHn5gD7GZ_KvtG6EmiAIiel7xt8HylAo,22083
+flaxdiff-0.1.14.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.14.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.14.dist-info/RECORD,,
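For context on the RECORD entries above: each line has the form `path,sha256=<digest>,<size>`, where the digest is an unpadded URL-safe base64 SHA-256 of the file, per the wheel spec. A minimal sketch for recomputing an entry during verification (the path is a placeholder inside an unpacked wheel):

```python
import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    # RECORD digests are URL-safe base64 SHA-256 with '=' padding stripped.
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=")
    return f"{path},sha256={digest.decode()},{len(data)}"

print(record_entry("flaxdiff/data/__init__.py"))
```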
{flaxdiff-0.1.13.dist-info → flaxdiff-0.1.14.dist-info}/WHEEL
File without changes

{flaxdiff-0.1.13.dist-info → flaxdiff-0.1.14.dist-info}/top_level.txt
File without changes