flaxdiff 0.1.31__tar.gz → 0.1.33__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/PKG-INFO +1 -1
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/data/online_loader.py +121 -42
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/setup.py +1 -1
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/README.md +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/common.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/trainer/diffusion_trainer.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff.egg-info/SOURCES.txt +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.31 → flaxdiff-0.1.33}/setup.cfg +0 -0
@@ -21,11 +21,12 @@ import urllib
|
|
21
21
|
|
22
22
|
import PIL.Image
|
23
23
|
import cv2
|
24
|
-
import traceback
|
24
|
+
import traceback
|
25
25
|
|
26
26
|
USER_AGENT = get_datasets_user_agent()
|
27
27
|
|
28
28
|
data_queue = Queue(16*2000)
|
29
|
+
error_queue = Queue(16*2000)
|
29
30
|
|
30
31
|
|
31
32
|
def fetch_single_image(image_url, timeout=None, retries=0):
|
@@ -45,7 +46,7 @@ def fetch_single_image(image_url, timeout=None, retries=0):
|
|
45
46
|
|
46
47
|
|
47
48
|
def default_image_processor(
|
48
|
-
image, image_shape,
|
49
|
+
image, image_shape,
|
49
50
|
min_image_shape=(128, 128),
|
50
51
|
upscale_interpolation=cv2.INTER_CUBIC,
|
51
52
|
downscale_interpolation=cv2.INTER_AREA,
|
@@ -77,8 +78,15 @@ def default_image_processor(
|
|
77
78
|
return image, original_height, original_width
|
78
79
|
|
79
80
|
|
81
|
+
def default_feature_extractor(sample):
|
82
|
+
return {
|
83
|
+
"url": sample["url"],
|
84
|
+
"caption": sample["caption"],
|
85
|
+
}
|
86
|
+
|
87
|
+
|
80
88
|
def map_sample(
|
81
|
-
|
89
|
+
sample,
|
82
90
|
image_shape=(256, 256),
|
83
91
|
min_image_shape=(128, 128),
|
84
92
|
timeout=15,
|
@@ -86,8 +94,11 @@ def map_sample(
|
|
86
94
|
upscale_interpolation=cv2.INTER_CUBIC,
|
87
95
|
downscale_interpolation=cv2.INTER_AREA,
|
88
96
|
image_processor=default_image_processor,
|
97
|
+
feature_extractor=default_feature_extractor,
|
89
98
|
):
|
90
99
|
try:
|
100
|
+
features = feature_extractor(sample)
|
101
|
+
url, caption = features["url"], features["caption"]
|
91
102
|
# Assuming fetch_single_image is defined elsewhere
|
92
103
|
image = fetch_single_image(url, timeout=timeout, retries=retries)
|
93
104
|
if image is None:
|
@@ -96,11 +107,12 @@ def map_sample(
|
|
96
107
|
image, original_height, original_width = image_processor(
|
97
108
|
image, image_shape, min_image_shape=min_image_shape,
|
98
109
|
upscale_interpolation=upscale_interpolation,
|
99
|
-
downscale_interpolation=downscale_interpolation,
|
100
|
-
|
110
|
+
downscale_interpolation=downscale_interpolation,
|
111
|
+
)
|
112
|
+
|
101
113
|
if image is None:
|
102
114
|
return
|
103
|
-
|
115
|
+
|
104
116
|
data_queue.put({
|
105
117
|
"url": url,
|
106
118
|
"caption": caption,
|
@@ -110,7 +122,7 @@ def map_sample(
|
|
110
122
|
})
|
111
123
|
except Exception as e:
|
112
124
|
print(f"Error maping sample {url}", e)
|
113
|
-
traceback.print_exc()
|
125
|
+
traceback.print_exc()
|
114
126
|
# error_queue.put_nowait({
|
115
127
|
# "url": url,
|
116
128
|
# "caption": caption,
|
@@ -120,43 +132,84 @@ def map_sample(
|
|
120
132
|
|
121
133
|
|
122
134
|
def map_batch(
|
123
|
-
batch, num_threads=256, image_shape=(256, 256),
|
135
|
+
batch, num_threads=256, image_shape=(256, 256),
|
124
136
|
min_image_shape=(128, 128),
|
125
137
|
timeout=15, retries=3, image_processor=default_image_processor,
|
126
138
|
upscale_interpolation=cv2.INTER_CUBIC,
|
127
139
|
downscale_interpolation=cv2.INTER_AREA,
|
140
|
+
feature_extractor=default_feature_extractor,
|
128
141
|
):
|
129
142
|
try:
|
130
|
-
map_sample_fn = partial(
|
131
|
-
|
132
|
-
|
133
|
-
|
143
|
+
map_sample_fn = partial(
|
144
|
+
map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
|
145
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
146
|
+
upscale_interpolation=upscale_interpolation,
|
147
|
+
downscale_interpolation=downscale_interpolation,
|
148
|
+
feature_extractor=feature_extractor
|
149
|
+
)
|
134
150
|
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
135
|
-
executor.map(map_sample_fn, batch
|
151
|
+
executor.map(map_sample_fn, batch)
|
136
152
|
except Exception as e:
|
137
153
|
print(f"Error maping batch", e)
|
138
|
-
traceback.print_exc()
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
154
|
+
traceback.print_exc()
|
155
|
+
error_queue.put_nowait({
|
156
|
+
"batch": batch,
|
157
|
+
"error": str(e)
|
158
|
+
})
|
143
159
|
pass
|
144
160
|
|
145
161
|
|
162
|
+
def map_batch_repeat_forever(
|
163
|
+
batch, num_threads=256, image_shape=(256, 256),
|
164
|
+
min_image_shape=(128, 128),
|
165
|
+
timeout=15, retries=3, image_processor=default_image_processor,
|
166
|
+
upscale_interpolation=cv2.INTER_CUBIC,
|
167
|
+
downscale_interpolation=cv2.INTER_AREA,
|
168
|
+
feature_extractor=default_feature_extractor,
|
169
|
+
):
|
170
|
+
while True: # Repeat forever
|
171
|
+
try:
|
172
|
+
map_sample_fn = partial(
|
173
|
+
map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
|
174
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
175
|
+
upscale_interpolation=upscale_interpolation,
|
176
|
+
downscale_interpolation=downscale_interpolation,
|
177
|
+
feature_extractor=feature_extractor
|
178
|
+
)
|
179
|
+
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
180
|
+
executor.map(map_sample_fn, batch)
|
181
|
+
# Shuffle the batch
|
182
|
+
batch = batch.shuffle(seed=np.random.randint(0, 1000000))
|
183
|
+
except Exception as e:
|
184
|
+
print(f"Error maping batch", e)
|
185
|
+
traceback.print_exc()
|
186
|
+
error_queue.put_nowait({
|
187
|
+
"batch": batch,
|
188
|
+
"error": str(e)
|
189
|
+
})
|
190
|
+
pass
|
191
|
+
|
192
|
+
|
146
193
|
def parallel_image_loader(
|
147
|
-
dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
|
194
|
+
dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
|
148
195
|
min_image_shape=(128, 128),
|
149
196
|
num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
|
150
197
|
upscale_interpolation=cv2.INTER_CUBIC,
|
151
198
|
downscale_interpolation=cv2.INTER_AREA,
|
199
|
+
feature_extractor=default_feature_extractor,
|
200
|
+
map_batch_fn=map_batch,
|
201
|
+
|
152
202
|
):
|
153
|
-
map_batch_fn = partial(
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
203
|
+
map_batch_fn = partial(
|
204
|
+
map_batch_fn, num_threads=num_threads, image_shape=image_shape,
|
205
|
+
min_image_shape=min_image_shape,
|
206
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
207
|
+
upscale_interpolation=upscale_interpolation,
|
208
|
+
downscale_interpolation=downscale_interpolation,
|
209
|
+
feature_extractor=feature_extractor
|
210
|
+
)
|
158
211
|
shard_len = len(dataset) // num_workers
|
159
|
-
print(f"Local Shard lengths: {shard_len}")
|
212
|
+
print(f"Local Shard lengths: {shard_len}, workers: {num_workers}")
|
160
213
|
with multiprocessing.Pool(num_workers) as pool:
|
161
214
|
iteration = 0
|
162
215
|
while True:
|
@@ -168,6 +221,7 @@ def parallel_image_loader(
|
|
168
221
|
iteration += 1
|
169
222
|
print(f"Shuffling dataset with seed {iteration}")
|
170
223
|
dataset = dataset.shuffle(seed=iteration)
|
224
|
+
print(f"Dataset shuffled")
|
171
225
|
# Clear the error queue
|
172
226
|
# while not error_queue.empty():
|
173
227
|
# error_queue.get_nowait()
|
@@ -175,25 +229,44 @@ def parallel_image_loader(
|
|
175
229
|
|
176
230
|
class ImageBatchIterator:
|
177
231
|
def __init__(
|
178
|
-
self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
|
232
|
+
self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
|
179
233
|
min_image_shape=(128, 128),
|
180
|
-
num_workers: int = 8, num_threads=256, timeout=15, retries=3,
|
234
|
+
num_workers: int = 8, num_threads=256, timeout=15, retries=3,
|
181
235
|
image_processor=default_image_processor,
|
182
236
|
upscale_interpolation=cv2.INTER_CUBIC,
|
183
237
|
downscale_interpolation=cv2.INTER_AREA,
|
238
|
+
feature_extractor=default_feature_extractor,
|
239
|
+
map_batch_fn=map_batch,
|
184
240
|
):
|
185
241
|
self.dataset = dataset
|
186
242
|
self.num_workers = num_workers
|
187
243
|
self.batch_size = batch_size
|
188
|
-
loader = partial(
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
244
|
+
loader = partial(
|
245
|
+
parallel_image_loader,
|
246
|
+
num_threads=num_threads,
|
247
|
+
image_shape=image_shape,
|
248
|
+
min_image_shape=min_image_shape,
|
249
|
+
num_workers=num_workers,
|
250
|
+
timeout=timeout, retries=retries,
|
251
|
+
image_processor=image_processor,
|
252
|
+
upscale_interpolation=upscale_interpolation,
|
253
|
+
downscale_interpolation=downscale_interpolation,
|
254
|
+
feature_extractor=feature_extractor,
|
255
|
+
map_batch_fn=map_batch_fn,
|
256
|
+
)
|
195
257
|
self.thread = threading.Thread(target=loader, args=(dataset,))
|
196
258
|
self.thread.start()
|
259
|
+
self.error_queue = queue.Queue()
|
260
|
+
|
261
|
+
def error_fetcher():
|
262
|
+
while True:
|
263
|
+
error = error_queue.get()
|
264
|
+
self.error_queue.put(error)
|
265
|
+
self.error_thread = threading.Thread(target=error_fetcher)
|
266
|
+
self.error_thread.start()
|
267
|
+
|
268
|
+
def get_error(self):
|
269
|
+
yield self.error_queue.get()
|
197
270
|
|
198
271
|
def __iter__(self):
|
199
272
|
return self
|
@@ -256,6 +329,8 @@ class OnlineStreamingDataLoader():
|
|
256
329
|
image_processor=default_image_processor,
|
257
330
|
upscale_interpolation=cv2.INTER_CUBIC,
|
258
331
|
downscale_interpolation=cv2.INTER_AREA,
|
332
|
+
feature_extractor=default_feature_extractor,
|
333
|
+
map_batch_fn=map_batch,
|
259
334
|
):
|
260
335
|
if isinstance(dataset, str):
|
261
336
|
dataset_path = dataset
|
@@ -276,12 +351,16 @@ class OnlineStreamingDataLoader():
|
|
276
351
|
self.dataset = dataset.shard(
|
277
352
|
num_shards=global_process_count, index=global_process_index)
|
278
353
|
print(f"Dataset length: {len(dataset)}")
|
279
|
-
self.iterator = ImageBatchIterator(
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
354
|
+
self.iterator = ImageBatchIterator(
|
355
|
+
self.dataset, image_shape=image_shape,
|
356
|
+
min_image_shape=min_image_shape,
|
357
|
+
num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
|
358
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
359
|
+
upscale_interpolation=upscale_interpolation,
|
360
|
+
downscale_interpolation=downscale_interpolation,
|
361
|
+
feature_extractor=feature_extractor,
|
362
|
+
map_batch_fn=map_batch_fn,
|
363
|
+
)
|
285
364
|
self.batch_size = batch_size
|
286
365
|
|
287
366
|
# Launch a thread to load batches in the background
|
@@ -292,7 +371,7 @@ class OnlineStreamingDataLoader():
|
|
292
371
|
try:
|
293
372
|
self.batch_queue.put(collate_fn(batch))
|
294
373
|
except Exception as e:
|
295
|
-
print("Error
|
374
|
+
print("Error collating batch", e)
|
296
375
|
|
297
376
|
self.loader_thread = threading.Thread(target=batch_loader)
|
298
377
|
self.loader_thread.start()
|
@@ -305,4 +384,4 @@ class OnlineStreamingDataLoader():
|
|
305
384
|
# return self.collate_fn(next(self.iterator))
|
306
385
|
|
307
386
|
def __len__(self):
|
308
|
-
return len(self.dataset)
|
387
|
+
return len(self.dataset)
|
@@ -11,7 +11,7 @@ required_packages=[
|
|
11
11
|
setup(
|
12
12
|
name='flaxdiff',
|
13
13
|
packages=find_packages(),
|
14
|
-
version='0.1.31',
|
14
|
+
version='0.1.33',
|
15
15
|
description='A versatile and easy to understand Diffusion library',
|
16
16
|
long_description=open('README.md').read(),
|
17
17
|
long_description_content_type='text/markdown',
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|