flaxdiff 0.1.35__tar.gz → 0.1.35.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/PKG-INFO +1 -1
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/data/online_loader.py +51 -59
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/setup.py +1 -1
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/README.md +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/models/common.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/trainer/diffusion_trainer.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff.egg-info/SOURCES.txt +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.35 → flaxdiff-0.1.35.2}/setup.cfg +0 -0
@@ -84,9 +84,9 @@ def default_feature_extractor(sample):
|
|
84
84
|
"caption": sample["caption"],
|
85
85
|
}
|
86
86
|
|
87
|
-
|
88
87
|
def map_sample(
|
89
|
-
|
88
|
+
url,
|
89
|
+
caption,
|
90
90
|
image_shape=(256, 256),
|
91
91
|
min_image_shape=(128, 128),
|
92
92
|
timeout=15,
|
@@ -94,11 +94,8 @@ def map_sample(
|
|
94
94
|
upscale_interpolation=cv2.INTER_CUBIC,
|
95
95
|
downscale_interpolation=cv2.INTER_AREA,
|
96
96
|
image_processor=default_image_processor,
|
97
|
-
feature_extractor=default_feature_extractor,
|
98
97
|
):
|
99
98
|
try:
|
100
|
-
features = feature_extractor(sample)
|
101
|
-
url, caption = features["url"], features["caption"]
|
102
99
|
# Assuming fetch_single_image is defined elsewhere
|
103
100
|
image = fetch_single_image(url, timeout=timeout, retries=retries)
|
104
101
|
if image is None:
|
@@ -130,7 +127,6 @@ def map_sample(
|
|
130
127
|
# })
|
131
128
|
pass
|
132
129
|
|
133
|
-
|
134
130
|
def map_batch(
|
135
131
|
batch, num_threads=256, image_shape=(256, 256),
|
136
132
|
min_image_shape=(128, 128),
|
@@ -147,48 +143,52 @@ def map_batch(
|
|
147
143
|
downscale_interpolation=downscale_interpolation,
|
148
144
|
feature_extractor=feature_extractor
|
149
145
|
)
|
146
|
+
features = feature_extractor(batch)
|
147
|
+
url, caption = features["url"], features["caption"]
|
150
148
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
151
|
-
executor.map(map_sample_fn,
|
149
|
+
executor.map(map_sample_fn, url, caption)
|
150
|
+
return None
|
152
151
|
except Exception as e:
|
153
152
|
print(f"Error maping batch", e)
|
154
153
|
traceback.print_exc()
|
155
|
-
error_queue.put_nowait({
|
156
|
-
|
157
|
-
|
158
|
-
})
|
159
|
-
|
160
|
-
|
161
|
-
|
162
|
-
def map_batch_repeat_forever(
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
):
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
154
|
+
# error_queue.put_nowait({
|
155
|
+
# "batch": batch,
|
156
|
+
# "error": str(e)
|
157
|
+
# })
|
158
|
+
return e
|
159
|
+
|
160
|
+
|
161
|
+
# def map_batch_repeat_forever(
|
162
|
+
# batch, num_threads=256, image_shape=(256, 256),
|
163
|
+
# min_image_shape=(128, 128),
|
164
|
+
# timeout=15, retries=3, image_processor=default_image_processor,
|
165
|
+
# upscale_interpolation=cv2.INTER_CUBIC,
|
166
|
+
# downscale_interpolation=cv2.INTER_AREA,
|
167
|
+
# feature_extractor=default_feature_extractor,
|
168
|
+
# ):
|
169
|
+
# while True: # Repeat forever
|
170
|
+
# try:
|
171
|
+
# map_sample_fn = partial(
|
172
|
+
# map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
|
173
|
+
# timeout=timeout, retries=retries, image_processor=image_processor,
|
174
|
+
# upscale_interpolation=upscale_interpolation,
|
175
|
+
# downscale_interpolation=downscale_interpolation,
|
176
|
+
# feature_extractor=feature_extractor
|
177
|
+
# )
|
178
|
+
# features = feature_extractor(batch)
|
179
|
+
# url, caption = features["url"], features["caption"]
|
180
|
+
# with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
181
|
+
# executor.map(map_sample_fn, url, caption)
|
182
|
+
# # Shuffle the batch
|
183
|
+
# batch = batch.shuffle(seed=np.random.randint(0, 1000000))
|
184
|
+
# except Exception as e:
|
185
|
+
# print(f"Error maping batch", e)
|
186
|
+
# traceback.print_exc()
|
187
|
+
# # error_queue.put_nowait({
|
188
|
+
# # "batch": batch,
|
189
|
+
# # "error": str(e)
|
190
|
+
# # })
|
191
|
+
# pass
|
192
192
|
|
193
193
|
def parallel_image_loader(
|
194
194
|
dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
|
@@ -214,11 +214,14 @@ def parallel_image_loader(
|
|
214
214
|
iteration = 0
|
215
215
|
while True:
|
216
216
|
# Repeat forever
|
217
|
-
|
218
|
-
|
219
|
-
shards = [dataset.shard(num_shards=num_workers, index=i) for i in range(num_workers)]
|
217
|
+
shards = [dataset[i*shard_len:(i+1)*shard_len]
|
218
|
+
for i in range(num_workers)]
|
219
|
+
# shards = [dataset.shard(num_shards=num_workers, index=i) for i in range(num_workers)]
|
220
220
|
print(f"mapping {len(shards)} shards")
|
221
|
-
pool.map(map_batch_fn, shards)
|
221
|
+
errors = pool.map(map_batch_fn, shards)
|
222
|
+
for error in errors:
|
223
|
+
if error is not None:
|
224
|
+
print(f"Error in mapping batch", error)
|
222
225
|
iteration += 1
|
223
226
|
print(f"Shuffling dataset with seed {iteration}")
|
224
227
|
dataset = dataset.shuffle(seed=iteration)
|
@@ -257,17 +260,6 @@ class ImageBatchIterator:
|
|
257
260
|
)
|
258
261
|
self.thread = threading.Thread(target=loader, args=(dataset,))
|
259
262
|
self.thread.start()
|
260
|
-
self.error_queue = queue.Queue()
|
261
|
-
|
262
|
-
def error_fetcher():
|
263
|
-
while True:
|
264
|
-
error = error_queue.get()
|
265
|
-
self.error_queue.put(error)
|
266
|
-
self.error_thread = threading.Thread(target=error_fetcher)
|
267
|
-
self.error_thread.start()
|
268
|
-
|
269
|
-
def get_error(self):
|
270
|
-
yield self.error_queue.get()
|
271
263
|
|
272
264
|
def __iter__(self):
|
273
265
|
return self
|
@@ -11,7 +11,7 @@ required_packages=[
|
|
11
11
|
setup(
|
12
12
|
name='flaxdiff',
|
13
13
|
packages=find_packages(),
|
14
|
-
version='0.1.35',
|
14
|
+
version='0.1.35.2',
|
15
15
|
description='A versatile and easy to understand Diffusion library',
|
16
16
|
long_description=open('README.md').read(),
|
17
17
|
long_description_content_type='text/markdown',
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|