flaxdiff 0.1.15__tar.gz → 0.1.16__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/PKG-INFO +1 -1
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/data/online_loader.py +19 -8
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/setup.py +1 -1
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/README.md +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/models/common.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/trainer/diffusion_trainer.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff.egg-info/SOURCES.txt +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.15 → flaxdiff-0.1.16}/setup.cfg +0 -0
@@ -27,7 +27,6 @@ USER_AGENT = get_datasets_user_agent()
|
|
27
27
|
data_queue = Queue(16*2000)
|
28
28
|
error_queue = Queue(16*2000)
|
29
29
|
|
30
|
-
|
31
30
|
def fetch_single_image(image_url, timeout=None, retries=0):
|
32
31
|
for _ in range(retries + 1):
|
33
32
|
try:
|
@@ -46,11 +45,13 @@ def fetch_single_image(image_url, timeout=None, retries=0):
|
|
46
45
|
def map_sample(
|
47
46
|
url, caption,
|
48
47
|
image_shape=(256, 256),
|
48
|
+
timeout=15,
|
49
|
+
retries=3,
|
49
50
|
upscale_interpolation=cv2.INTER_LANCZOS4,
|
50
51
|
downscale_interpolation=cv2.INTER_AREA,
|
51
52
|
):
|
52
53
|
try:
|
53
|
-
image = fetch_single_image(url, timeout=
|
54
|
+
image = fetch_single_image(url, timeout=timeout, retries=retries) # Assuming fetch_single_image is defined elsewhere
|
54
55
|
if image is None:
|
55
56
|
return
|
56
57
|
|
@@ -84,15 +85,24 @@ def map_sample(
|
|
84
85
|
"original_width": original_width,
|
85
86
|
})
|
86
87
|
except Exception as e:
|
88
|
+
print(f"Error in map_sample: {str(e)}")
|
87
89
|
error_queue.put({
|
88
90
|
"url": url,
|
89
91
|
"caption": caption,
|
90
92
|
"error": str(e)
|
91
93
|
})
|
92
|
-
|
93
|
-
def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=
|
94
|
-
|
95
|
-
|
94
|
+
|
95
|
+
def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3):
|
96
|
+
try:
|
97
|
+
map_sample_fn = partial(map_sample, image_shape=image_shape, timeout=timeout, retries=retries)
|
98
|
+
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
99
|
+
executor.map(map_sample_fn, batch["url"], batch['caption'])
|
100
|
+
except Exception as e:
|
101
|
+
print(f"Error in map_batch: {str(e)}")
|
102
|
+
error_queue.put({
|
103
|
+
"batch": batch,
|
104
|
+
"error": str(e)
|
105
|
+
})
|
96
106
|
|
97
107
|
def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
|
98
108
|
map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
|
@@ -102,8 +112,10 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(2
|
|
102
112
|
iteration = 0
|
103
113
|
while True:
|
104
114
|
# Repeat forever
|
105
|
-
dataset
|
115
|
+
print(f"Shuffling dataset with seed {iteration}")
|
116
|
+
# dataset = dataset.shuffle(seed=iteration)
|
106
117
|
shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
|
118
|
+
print(f"mapping {len(shards)} shards")
|
107
119
|
pool.map(map_batch_fn, shards)
|
108
120
|
iteration += 1
|
109
121
|
|
@@ -205,4 +217,3 @@ class OnlineStreamingDataLoader():
|
|
205
217
|
|
206
218
|
def __len__(self):
|
207
219
|
return len(self.dataset)
|
208
|
-
|
@@ -11,7 +11,7 @@ required_packages=[
|
|
11
11
|
setup(
|
12
12
|
name='flaxdiff',
|
13
13
|
packages=find_packages(),
|
14
|
-
version='0.1.
|
14
|
+
version='0.1.16',
|
15
15
|
description='A versatile and easy to understand Diffusion library',
|
16
16
|
long_description=open('README.md').read(),
|
17
17
|
long_description_content_type='text/markdown',
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|