flaxdiff 0.1.32__tar.gz → 0.1.34__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/PKG-INFO +1 -1
  2. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/data/online_loader.py +118 -53
  3. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff.egg-info/PKG-INFO +1 -1
  4. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/setup.py +1 -1
  5. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/README.md +0 -0
  6. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/__init__.py +0 -0
  7. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/data/__init__.py +0 -0
  8. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/models/__init__.py +0 -0
  9. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/models/attention.py +0 -0
  10. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/models/autoencoder/__init__.py +0 -0
  11. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  12. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  13. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  14. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/models/common.py +0 -0
  15. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/models/favor_fastattn.py +0 -0
  16. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/models/simple_unet.py +0 -0
  17. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/models/simple_vit.py +0 -0
  18. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/predictors/__init__.py +0 -0
  19. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/samplers/__init__.py +0 -0
  20. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/samplers/common.py +0 -0
  21. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/samplers/ddim.py +0 -0
  22. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/samplers/ddpm.py +0 -0
  23. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/samplers/euler.py +0 -0
  24. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/samplers/heun_sampler.py +0 -0
  25. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/samplers/multistep_dpm.py +0 -0
  26. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/samplers/rk4_sampler.py +0 -0
  27. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/schedulers/__init__.py +0 -0
  28. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/schedulers/common.py +0 -0
  29. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/schedulers/continuous.py +0 -0
  30. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/schedulers/cosine.py +0 -0
  31. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/schedulers/discrete.py +0 -0
  32. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/schedulers/exp.py +0 -0
  33. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/schedulers/karras.py +0 -0
  34. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/schedulers/linear.py +0 -0
  35. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/schedulers/sqrt.py +0 -0
  36. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/trainer/__init__.py +0 -0
  37. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  38. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  39. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/trainer/simple_trainer.py +0 -0
  40. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff/utils.py +0 -0
  41. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff.egg-info/SOURCES.txt +0 -0
  42. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff.egg-info/dependency_links.txt +0 -0
  43. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff.egg-info/requires.txt +0 -0
  44. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/flaxdiff.egg-info/top_level.txt +0 -0
  45. {flaxdiff-0.1.32 → flaxdiff-0.1.34}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.32
3
+ Version: 0.1.34
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -21,11 +21,12 @@ import urllib
21
21
 
22
22
  import PIL.Image
23
23
  import cv2
24
- import traceback
24
+ import traceback
25
25
 
26
26
  USER_AGENT = get_datasets_user_agent()
27
27
 
28
28
  data_queue = Queue(16*2000)
29
+ error_queue = Queue(16*2000)
29
30
 
30
31
 
31
32
  def fetch_single_image(image_url, timeout=None, retries=0):
@@ -45,7 +46,7 @@ def fetch_single_image(image_url, timeout=None, retries=0):
45
46
 
46
47
 
47
48
  def default_image_processor(
48
- image, image_shape,
49
+ image, image_shape,
49
50
  min_image_shape=(128, 128),
50
51
  upscale_interpolation=cv2.INTER_CUBIC,
51
52
  downscale_interpolation=cv2.INTER_AREA,
@@ -77,8 +78,15 @@ def default_image_processor(
77
78
  return image, original_height, original_width
78
79
 
79
80
 
81
+ def default_feature_extractor(sample):
82
+ return {
83
+ "url": sample["url"],
84
+ "caption": sample["caption"],
85
+ }
86
+
87
+
80
88
  def map_sample(
81
- url, caption,
89
+ sample,
82
90
  image_shape=(256, 256),
83
91
  min_image_shape=(128, 128),
84
92
  timeout=15,
@@ -86,8 +94,11 @@ def map_sample(
86
94
  upscale_interpolation=cv2.INTER_CUBIC,
87
95
  downscale_interpolation=cv2.INTER_AREA,
88
96
  image_processor=default_image_processor,
97
+ feature_extractor=default_feature_extractor,
89
98
  ):
90
99
  try:
100
+ features = feature_extractor(sample)
101
+ url, caption = features["url"], features["caption"]
91
102
  # Assuming fetch_single_image is defined elsewhere
92
103
  image = fetch_single_image(url, timeout=timeout, retries=retries)
93
104
  if image is None:
@@ -96,11 +107,12 @@ def map_sample(
96
107
  image, original_height, original_width = image_processor(
97
108
  image, image_shape, min_image_shape=min_image_shape,
98
109
  upscale_interpolation=upscale_interpolation,
99
- downscale_interpolation=downscale_interpolation,)
100
-
110
+ downscale_interpolation=downscale_interpolation,
111
+ )
112
+
101
113
  if image is None:
102
114
  return
103
-
115
+
104
116
  data_queue.put({
105
117
  "url": url,
106
118
  "caption": caption,
@@ -110,22 +122,17 @@ def map_sample(
110
122
  })
111
123
  except Exception as e:
112
124
  print(f"Error maping sample {url}", e)
113
- traceback.print_exc()
125
+ traceback.print_exc()
114
126
  # error_queue.put_nowait({
115
127
  # "url": url,
116
128
  # "caption": caption,
117
129
  # "error": str(e)
118
130
  # })
119
131
  pass
120
-
121
- def default_feature_extractor(sample):
122
- return {
123
- "url": sample["url"],
124
- "caption": sample["caption"],
125
- }
132
+
126
133
 
127
134
  def map_batch(
128
- batch, num_threads=256, image_shape=(256, 256),
135
+ batch, num_threads=256, image_shape=(256, 256),
129
136
  min_image_shape=(128, 128),
130
137
  timeout=15, retries=3, image_processor=default_image_processor,
131
138
  upscale_interpolation=cv2.INTER_CUBIC,
@@ -133,40 +140,76 @@ def map_batch(
133
140
  feature_extractor=default_feature_extractor,
134
141
  ):
135
142
  try:
136
- map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
137
- timeout=timeout, retries=retries, image_processor=image_processor,
138
- upscale_interpolation=upscale_interpolation,
139
- downscale_interpolation=downscale_interpolation)
143
+ map_sample_fn = partial(
144
+ map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
145
+ timeout=timeout, retries=retries, image_processor=image_processor,
146
+ upscale_interpolation=upscale_interpolation,
147
+ downscale_interpolation=downscale_interpolation,
148
+ feature_extractor=feature_extractor
149
+ )
140
150
  with ThreadPoolExecutor(max_workers=num_threads) as executor:
141
- features = feature_extractor(batch)
142
- url, caption = features["url"], features["caption"]
143
- executor.map(map_sample_fn, url, caption)
151
+ executor.map(map_sample_fn, batch)
144
152
  except Exception as e:
145
153
  print(f"Error maping batch", e)
146
- traceback.print_exc()
147
- # error_queue.put_nowait({
148
- # "batch": batch,
149
- # "error": str(e)
150
- # })
154
+ traceback.print_exc()
155
+ error_queue.put_nowait({
156
+ "batch": batch,
157
+ "error": str(e)
158
+ })
151
159
  pass
152
160
 
153
161
 
162
+ def map_batch_repeat_forever(
163
+ batch, num_threads=256, image_shape=(256, 256),
164
+ min_image_shape=(128, 128),
165
+ timeout=15, retries=3, image_processor=default_image_processor,
166
+ upscale_interpolation=cv2.INTER_CUBIC,
167
+ downscale_interpolation=cv2.INTER_AREA,
168
+ feature_extractor=default_feature_extractor,
169
+ ):
170
+ while True: # Repeat forever
171
+ try:
172
+ map_sample_fn = partial(
173
+ map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
174
+ timeout=timeout, retries=retries, image_processor=image_processor,
175
+ upscale_interpolation=upscale_interpolation,
176
+ downscale_interpolation=downscale_interpolation,
177
+ feature_extractor=feature_extractor
178
+ )
179
+ with ThreadPoolExecutor(max_workers=num_threads) as executor:
180
+ executor.map(map_sample_fn, batch)
181
+ # Shuffle the batch
182
+ batch = batch.shuffle(seed=np.random.randint(0, 1000000))
183
+ except Exception as e:
184
+ print(f"Error maping batch", e)
185
+ traceback.print_exc()
186
+ error_queue.put_nowait({
187
+ "batch": batch,
188
+ "error": str(e)
189
+ })
190
+ pass
191
+
192
+
154
193
  def parallel_image_loader(
155
- dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
194
+ dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
156
195
  min_image_shape=(128, 128),
157
196
  num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
158
197
  upscale_interpolation=cv2.INTER_CUBIC,
159
198
  downscale_interpolation=cv2.INTER_AREA,
160
199
  feature_extractor=default_feature_extractor,
200
+ map_batch_fn=map_batch,
201
+
161
202
  ):
162
- map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
163
- min_image_shape=min_image_shape,
164
- timeout=timeout, retries=retries, image_processor=image_processor,
165
- upscale_interpolation=upscale_interpolation,
166
- downscale_interpolation=downscale_interpolation,
167
- feature_extractor=feature_extractor)
203
+ map_batch_fn = partial(
204
+ map_batch_fn, num_threads=num_threads, image_shape=image_shape,
205
+ min_image_shape=min_image_shape,
206
+ timeout=timeout, retries=retries, image_processor=image_processor,
207
+ upscale_interpolation=upscale_interpolation,
208
+ downscale_interpolation=downscale_interpolation,
209
+ feature_extractor=feature_extractor
210
+ )
168
211
  shard_len = len(dataset) // num_workers
169
- print(f"Local Shard lengths: {shard_len}")
212
+ print(f"Local Shard lengths: {shard_len}, workers: {num_workers}")
170
213
  with multiprocessing.Pool(num_workers) as pool:
171
214
  iteration = 0
172
215
  while True:
@@ -178,6 +221,7 @@ def parallel_image_loader(
178
221
  iteration += 1
179
222
  print(f"Shuffling dataset with seed {iteration}")
180
223
  dataset = dataset.shuffle(seed=iteration)
224
+ print(f"Dataset shuffled")
181
225
  # Clear the error queue
182
226
  # while not error_queue.empty():
183
227
  # error_queue.get_nowait()
@@ -185,27 +229,44 @@ def parallel_image_loader(
185
229
 
186
230
  class ImageBatchIterator:
187
231
  def __init__(
188
- self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
232
+ self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
189
233
  min_image_shape=(128, 128),
190
- num_workers: int = 8, num_threads=256, timeout=15, retries=3,
234
+ num_workers: int = 8, num_threads=256, timeout=15, retries=3,
191
235
  image_processor=default_image_processor,
192
236
  upscale_interpolation=cv2.INTER_CUBIC,
193
237
  downscale_interpolation=cv2.INTER_AREA,
194
238
  feature_extractor=default_feature_extractor,
239
+ map_batch_fn=map_batch,
195
240
  ):
196
241
  self.dataset = dataset
197
242
  self.num_workers = num_workers
198
243
  self.batch_size = batch_size
199
- loader = partial(parallel_image_loader, num_threads=num_threads,
200
- image_shape=image_shape,
201
- min_image_shape=min_image_shape,
202
- num_workers=num_workers,
203
- timeout=timeout, retries=retries, image_processor=image_processor,
204
- upscale_interpolation=upscale_interpolation,
205
- downscale_interpolation=downscale_interpolation,
206
- feature_extractor=feature_extractor)
244
+ loader = partial(
245
+ parallel_image_loader,
246
+ num_threads=num_threads,
247
+ image_shape=image_shape,
248
+ min_image_shape=min_image_shape,
249
+ num_workers=num_workers,
250
+ timeout=timeout, retries=retries,
251
+ image_processor=image_processor,
252
+ upscale_interpolation=upscale_interpolation,
253
+ downscale_interpolation=downscale_interpolation,
254
+ feature_extractor=feature_extractor,
255
+ map_batch_fn=map_batch_fn,
256
+ )
207
257
  self.thread = threading.Thread(target=loader, args=(dataset,))
208
258
  self.thread.start()
259
+ self.error_queue = queue.Queue()
260
+
261
+ def error_fetcher():
262
+ while True:
263
+ error = error_queue.get()
264
+ self.error_queue.put(error)
265
+ self.error_thread = threading.Thread(target=error_fetcher)
266
+ self.error_thread.start()
267
+
268
+ def get_error(self):
269
+ yield self.error_queue.get()
209
270
 
210
271
  def __iter__(self):
211
272
  return self
@@ -269,6 +330,7 @@ class OnlineStreamingDataLoader():
269
330
  upscale_interpolation=cv2.INTER_CUBIC,
270
331
  downscale_interpolation=cv2.INTER_AREA,
271
332
  feature_extractor=default_feature_extractor,
333
+ map_batch_fn=map_batch,
272
334
  ):
273
335
  if isinstance(dataset, str):
274
336
  dataset_path = dataset
@@ -289,13 +351,16 @@ class OnlineStreamingDataLoader():
289
351
  self.dataset = dataset.shard(
290
352
  num_shards=global_process_count, index=global_process_index)
291
353
  print(f"Dataset length: {len(dataset)}")
292
- self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
293
- min_image_shape=min_image_shape,
294
- num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
295
- timeout=timeout, retries=retries, image_processor=image_processor,
296
- upscale_interpolation=upscale_interpolation,
297
- downscale_interpolation=downscale_interpolation,
298
- feature_extractor=feature_extractor)
354
+ self.iterator = ImageBatchIterator(
355
+ self.dataset, image_shape=image_shape,
356
+ min_image_shape=min_image_shape,
357
+ num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
358
+ timeout=timeout, retries=retries, image_processor=image_processor,
359
+ upscale_interpolation=upscale_interpolation,
360
+ downscale_interpolation=downscale_interpolation,
361
+ feature_extractor=feature_extractor,
362
+ map_batch_fn=map_batch_fn,
363
+ )
299
364
  self.batch_size = batch_size
300
365
 
301
366
  # Launch a thread to load batches in the background
@@ -306,7 +371,7 @@ class OnlineStreamingDataLoader():
306
371
  try:
307
372
  self.batch_queue.put(collate_fn(batch))
308
373
  except Exception as e:
309
- print("Error processing batch", e)
374
+ print("Error collating batch", e)
310
375
 
311
376
  self.loader_thread = threading.Thread(target=batch_loader)
312
377
  self.loader_thread.start()
@@ -319,4 +384,4 @@ class OnlineStreamingDataLoader():
319
384
  # return self.collate_fn(next(self.iterator))
320
385
 
321
386
  def __len__(self):
322
- return len(self.dataset)
387
+ return len(self.dataset)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.32
3
+ Version: 0.1.34
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -11,7 +11,7 @@ required_packages=[
11
11
  setup(
12
12
  name='flaxdiff',
13
13
  packages=find_packages(),
14
- version='0.1.32',
14
+ version='0.1.34',
15
15
  description='A versatile and easy to understand Diffusion library',
16
16
  long_description=open('README.md').read(),
17
17
  long_description_content_type='text/markdown',
File without changes
File without changes
File without changes