flaxdiff 0.1.34__py3-none-any.whl → 0.1.35.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/online_loader.py +6 -5
- {flaxdiff-0.1.34.dist-info → flaxdiff-0.1.35.1.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.34.dist-info → flaxdiff-0.1.35.1.dist-info}/RECORD +5 -5
- {flaxdiff-0.1.34.dist-info → flaxdiff-0.1.35.1.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.34.dist-info → flaxdiff-0.1.35.1.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py
CHANGED
@@ -86,7 +86,8 @@ def default_feature_extractor(sample):
|
|
86
86
|
|
87
87
|
|
88
88
|
def map_sample(
|
89
|
-
|
89
|
+
url,
|
90
|
+
caption,
|
90
91
|
image_shape=(256, 256),
|
91
92
|
min_image_shape=(128, 128),
|
92
93
|
timeout=15,
|
@@ -94,11 +95,8 @@ def map_sample(
|
|
94
95
|
upscale_interpolation=cv2.INTER_CUBIC,
|
95
96
|
downscale_interpolation=cv2.INTER_AREA,
|
96
97
|
image_processor=default_image_processor,
|
97
|
-
feature_extractor=default_feature_extractor,
|
98
98
|
):
|
99
99
|
try:
|
100
|
-
features = feature_extractor(sample)
|
101
|
-
url, caption = features["url"], features["caption"]
|
102
100
|
# Assuming fetch_single_image is defined elsewhere
|
103
101
|
image = fetch_single_image(url, timeout=timeout, retries=retries)
|
104
102
|
if image is None:
|
@@ -147,8 +145,10 @@ def map_batch(
|
|
147
145
|
downscale_interpolation=downscale_interpolation,
|
148
146
|
feature_extractor=feature_extractor
|
149
147
|
)
|
148
|
+
features = feature_extractor(batch)
|
149
|
+
url, caption = features["url"], features["caption"]
|
150
150
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
151
|
-
executor.map(map_sample_fn,
|
151
|
+
executor.map(map_sample_fn, url, caption)
|
152
152
|
except Exception as e:
|
153
153
|
print(f"Error maping batch", e)
|
154
154
|
traceback.print_exc()
|
@@ -216,6 +216,7 @@ def parallel_image_loader(
|
|
216
216
|
# Repeat forever
|
217
217
|
shards = [dataset[i*shard_len:(i+1)*shard_len]
|
218
218
|
for i in range(num_workers)]
|
219
|
+
# shards = [dataset.shard(num_shards=num_workers, index=i) for i in range(num_workers)]
|
219
220
|
print(f"mapping {len(shards)} shards")
|
220
221
|
pool.map(map_batch_fn, shards)
|
221
222
|
iteration += 1
|
@@ -1,7 +1,7 @@
|
|
1
1
|
flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
2
|
flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
|
3
3
|
flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
|
4
|
-
flaxdiff/data/online_loader.py,sha256=
|
4
|
+
flaxdiff/data/online_loader.py,sha256=reMG_EYbzcNi5rTVbGnAZE7gq1iY6YxYY7K6D5JIsbw,13293
|
5
5
|
flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
|
6
6
|
flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
|
7
7
|
flaxdiff/models/common.py,sha256=hWsSs2BP2J-JN1s4qLRr-h-KYkcVyl2hOp1Wsm_L-h8,10994
|
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
|
|
34
34
|
flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
|
35
35
|
flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
|
36
36
|
flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
|
37
|
-
flaxdiff-0.1.
|
38
|
-
flaxdiff-0.1.
|
39
|
-
flaxdiff-0.1.
|
40
|
-
flaxdiff-0.1.
|
37
|
+
flaxdiff-0.1.35.1.dist-info/METADATA,sha256=ktAM1HUTLykCSmVZzH3VJHj4PzeCoOEOlpslyavlOWs,22085
|
38
|
+
flaxdiff-0.1.35.1.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
|
39
|
+
flaxdiff-0.1.35.1.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
|
40
|
+
flaxdiff-0.1.35.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|