flaxdiff 0.1.31__tar.gz → 0.1.33__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/PKG-INFO +1 -1
  2. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/data/online_loader.py +121 -42
  3. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff.egg-info/PKG-INFO +1 -1
  4. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/setup.py +1 -1
  5. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/README.md +0 -0
  6. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/__init__.py +0 -0
  7. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/data/__init__.py +0 -0
  8. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/__init__.py +0 -0
  9. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/attention.py +0 -0
  10. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/autoencoder/__init__.py +0 -0
  11. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  12. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  13. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  14. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/common.py +0 -0
  15. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/favor_fastattn.py +0 -0
  16. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/simple_unet.py +0 -0
  17. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/models/simple_vit.py +0 -0
  18. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/predictors/__init__.py +0 -0
  19. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/__init__.py +0 -0
  20. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/common.py +0 -0
  21. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/ddim.py +0 -0
  22. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/ddpm.py +0 -0
  23. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/euler.py +0 -0
  24. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/heun_sampler.py +0 -0
  25. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/multistep_dpm.py +0 -0
  26. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/samplers/rk4_sampler.py +0 -0
  27. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/__init__.py +0 -0
  28. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/common.py +0 -0
  29. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/continuous.py +0 -0
  30. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/cosine.py +0 -0
  31. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/discrete.py +0 -0
  32. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/exp.py +0 -0
  33. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/karras.py +0 -0
  34. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/linear.py +0 -0
  35. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/schedulers/sqrt.py +0 -0
  36. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/trainer/__init__.py +0 -0
  37. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  38. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  39. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/trainer/simple_trainer.py +0 -0
  40. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/utils.py +0 -0
  41. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff.egg-info/SOURCES.txt +0 -0
  42. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff.egg-info/dependency_links.txt +0 -0
  43. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff.egg-info/requires.txt +0 -0
  44. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff.egg-info/top_level.txt +0 -0
  45. {flaxdiff-0.1.31 → flaxdiff-0.1.33}/setup.cfg +0 -0

{flaxdiff-0.1.31 → flaxdiff-0.1.33}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: flaxdiff
- Version: 0.1.31
+ Version: 0.1.33
  Summary: A versatile and easy to understand Diffusion library
  Author: Ashish Kumar Singh
  Author-email: ashishkmr472@gmail.com

{flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff/data/online_loader.py
@@ -21,11 +21,12 @@ import urllib

  import PIL.Image
  import cv2
- import traceback
+ import traceback

  USER_AGENT = get_datasets_user_agent()

  data_queue = Queue(16*2000)
+ error_queue = Queue(16*2000)


  def fetch_single_image(image_url, timeout=None, retries=0):
@@ -45,7 +46,7 @@ def fetch_single_image(image_url, timeout=None, retries=0):


  def default_image_processor(
-     image, image_shape,
+     image, image_shape,
      min_image_shape=(128, 128),
      upscale_interpolation=cv2.INTER_CUBIC,
      downscale_interpolation=cv2.INTER_AREA,
@@ -77,8 +78,15 @@ def default_image_processor(
      return image, original_height, original_width


+ def default_feature_extractor(sample):
+     return {
+         "url": sample["url"],
+         "caption": sample["caption"],
+     }
+
+
  def map_sample(
-     url, caption,
+     sample,
      image_shape=(256, 256),
      min_image_shape=(128, 128),
      timeout=15,
@@ -86,8 +94,11 @@ def map_sample(
      upscale_interpolation=cv2.INTER_CUBIC,
      downscale_interpolation=cv2.INTER_AREA,
      image_processor=default_image_processor,
+     feature_extractor=default_feature_extractor,
  ):
      try:
+         features = feature_extractor(sample)
+         url, caption = features["url"], features["caption"]
          # Assuming fetch_single_image is defined elsewhere
          image = fetch_single_image(url, timeout=timeout, retries=retries)
          if image is None:
@@ -96,11 +107,12 @@ def map_sample(
          image, original_height, original_width = image_processor(
              image, image_shape, min_image_shape=min_image_shape,
              upscale_interpolation=upscale_interpolation,
-             downscale_interpolation=downscale_interpolation,)
-
+             downscale_interpolation=downscale_interpolation,
+         )
+
          if image is None:
              return
-
+
          data_queue.put({
              "url": url,
              "caption": caption,
@@ -110,7 +122,7 @@ def map_sample(
          })
      except Exception as e:
          print(f"Error maping sample {url}", e)
-         traceback.print_exc()
+         traceback.print_exc()
          # error_queue.put_nowait({
          #     "url": url,
          #     "caption": caption,
@@ -120,43 +132,84 @@ def map_sample(


  def map_batch(
-     batch, num_threads=256, image_shape=(256, 256),
+     batch, num_threads=256, image_shape=(256, 256),
      min_image_shape=(128, 128),
      timeout=15, retries=3, image_processor=default_image_processor,
      upscale_interpolation=cv2.INTER_CUBIC,
      downscale_interpolation=cv2.INTER_AREA,
+     feature_extractor=default_feature_extractor,
  ):
      try:
-         map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
-                                 timeout=timeout, retries=retries, image_processor=image_processor,
-                                 upscale_interpolation=upscale_interpolation,
-                                 downscale_interpolation=downscale_interpolation)
+         map_sample_fn = partial(
+             map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
+             timeout=timeout, retries=retries, image_processor=image_processor,
+             upscale_interpolation=upscale_interpolation,
+             downscale_interpolation=downscale_interpolation,
+             feature_extractor=feature_extractor
+         )
          with ThreadPoolExecutor(max_workers=num_threads) as executor:
-             executor.map(map_sample_fn, batch["url"], batch['caption'])
+             executor.map(map_sample_fn, batch)
      except Exception as e:
          print(f"Error maping batch", e)
-         traceback.print_exc()
-         # error_queue.put_nowait({
-         #     "batch": batch,
-         #     "error": str(e)
-         # })
+         traceback.print_exc()
+         error_queue.put_nowait({
+             "batch": batch,
+             "error": str(e)
+         })
      pass


+ def map_batch_repeat_forever(
+     batch, num_threads=256, image_shape=(256, 256),
+     min_image_shape=(128, 128),
+     timeout=15, retries=3, image_processor=default_image_processor,
+     upscale_interpolation=cv2.INTER_CUBIC,
+     downscale_interpolation=cv2.INTER_AREA,
+     feature_extractor=default_feature_extractor,
+ ):
+     while True:  # Repeat forever
+         try:
+             map_sample_fn = partial(
+                 map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
+                 timeout=timeout, retries=retries, image_processor=image_processor,
+                 upscale_interpolation=upscale_interpolation,
+                 downscale_interpolation=downscale_interpolation,
+                 feature_extractor=feature_extractor
+             )
+             with ThreadPoolExecutor(max_workers=num_threads) as executor:
+                 executor.map(map_sample_fn, batch)
+             # Shuffle the batch
+             batch = batch.shuffle(seed=np.random.randint(0, 1000000))
+         except Exception as e:
+             print(f"Error maping batch", e)
+             traceback.print_exc()
+             error_queue.put_nowait({
+                 "batch": batch,
+                 "error": str(e)
+             })
+             pass
+
+
  def parallel_image_loader(
-     dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
+     dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
      min_image_shape=(128, 128),
      num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
      upscale_interpolation=cv2.INTER_CUBIC,
      downscale_interpolation=cv2.INTER_AREA,
+     feature_extractor=default_feature_extractor,
+     map_batch_fn=map_batch,
+
  ):
-     map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
-                            min_image_shape=min_image_shape,
-                            timeout=timeout, retries=retries, image_processor=image_processor,
-                            upscale_interpolation=upscale_interpolation,
-                            downscale_interpolation=downscale_interpolation)
+     map_batch_fn = partial(
+         map_batch_fn, num_threads=num_threads, image_shape=image_shape,
+         min_image_shape=min_image_shape,
+         timeout=timeout, retries=retries, image_processor=image_processor,
+         upscale_interpolation=upscale_interpolation,
+         downscale_interpolation=downscale_interpolation,
+         feature_extractor=feature_extractor
+     )
      shard_len = len(dataset) // num_workers
-     print(f"Local Shard lengths: {shard_len}")
+     print(f"Local Shard lengths: {shard_len}, workers: {num_workers}")
      with multiprocessing.Pool(num_workers) as pool:
          iteration = 0
          while True:
@@ -168,6 +221,7 @@ def parallel_image_loader(
              iteration += 1
              print(f"Shuffling dataset with seed {iteration}")
              dataset = dataset.shuffle(seed=iteration)
+             print(f"Dataset shuffled")
              # Clear the error queue
              # while not error_queue.empty():
              #     error_queue.get_nowait()
@@ -175,25 +229,44 @@

  class ImageBatchIterator:
      def __init__(
-         self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
+         self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
          min_image_shape=(128, 128),
-         num_workers: int = 8, num_threads=256, timeout=15, retries=3,
+         num_workers: int = 8, num_threads=256, timeout=15, retries=3,
          image_processor=default_image_processor,
          upscale_interpolation=cv2.INTER_CUBIC,
          downscale_interpolation=cv2.INTER_AREA,
+         feature_extractor=default_feature_extractor,
+         map_batch_fn=map_batch,
      ):
          self.dataset = dataset
          self.num_workers = num_workers
          self.batch_size = batch_size
-         loader = partial(parallel_image_loader, num_threads=num_threads,
-                          image_shape=image_shape,
-                          min_image_shape=min_image_shape,
-                          num_workers=num_workers,
-                          timeout=timeout, retries=retries, image_processor=image_processor,
-                          upscale_interpolation=upscale_interpolation,
-                          downscale_interpolation=downscale_interpolation)
+         loader = partial(
+             parallel_image_loader,
+             num_threads=num_threads,
+             image_shape=image_shape,
+             min_image_shape=min_image_shape,
+             num_workers=num_workers,
+             timeout=timeout, retries=retries,
+             image_processor=image_processor,
+             upscale_interpolation=upscale_interpolation,
+             downscale_interpolation=downscale_interpolation,
+             feature_extractor=feature_extractor,
+             map_batch_fn=map_batch_fn,
+         )
          self.thread = threading.Thread(target=loader, args=(dataset,))
          self.thread.start()
+         self.error_queue = queue.Queue()
+
+         def error_fetcher():
+             while True:
+                 error = error_queue.get()
+                 self.error_queue.put(error)
+         self.error_thread = threading.Thread(target=error_fetcher)
+         self.error_thread.start()
+
+     def get_error(self):
+         yield self.error_queue.get()

      def __iter__(self):
          return self
@@ -256,6 +329,8 @@ class OnlineStreamingDataLoader():
          image_processor=default_image_processor,
          upscale_interpolation=cv2.INTER_CUBIC,
          downscale_interpolation=cv2.INTER_AREA,
+         feature_extractor=default_feature_extractor,
+         map_batch_fn=map_batch,
      ):
          if isinstance(dataset, str):
              dataset_path = dataset
@@ -276,12 +351,16 @@ class OnlineStreamingDataLoader():
          self.dataset = dataset.shard(
              num_shards=global_process_count, index=global_process_index)
          print(f"Dataset length: {len(dataset)}")
-         self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
-                                            min_image_shape=min_image_shape,
-                                            num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
-                                            timeout=timeout, retries=retries, image_processor=image_processor,
-                                            upscale_interpolation=upscale_interpolation,
-                                            downscale_interpolation=downscale_interpolation)
+         self.iterator = ImageBatchIterator(
+             self.dataset, image_shape=image_shape,
+             min_image_shape=min_image_shape,
+             num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
+             timeout=timeout, retries=retries, image_processor=image_processor,
+             upscale_interpolation=upscale_interpolation,
+             downscale_interpolation=downscale_interpolation,
+             feature_extractor=feature_extractor,
+             map_batch_fn=map_batch_fn,
+         )
          self.batch_size = batch_size

          # Launch a thread to load batches in the background
@@ -292,7 +371,7 @@ class OnlineStreamingDataLoader():
              try:
                  self.batch_queue.put(collate_fn(batch))
              except Exception as e:
-                 print("Error processing batch", e)
+                 print("Error collating batch", e)

          self.loader_thread = threading.Thread(target=batch_loader)
          self.loader_thread.start()
@@ -305,4 +384,4 @@ class OnlineStreamingDataLoader():
          # return self.collate_fn(next(self.iterator))

      def __len__(self):
-         return len(self.dataset)
+         return len(self.dataset)
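
Taken together, the online_loader.py changes make the loader's column mapping pluggable (feature_extractor), make the per-shard batch-mapping strategy swappable (map_batch_fn, including the new map_batch_repeat_forever), and surface batch-level failures through an error_queue. The sketch below is illustrative only and is not code from the package: the dataset path and the "URL"/"TEXT" column names are hypothetical stand-ins for a dataset whose columns differ from the default "url"/"caption" keys.

from flaxdiff.data.online_loader import (
    OnlineStreamingDataLoader,
    map_batch_repeat_forever,
)

def my_feature_extractor(sample):
    # Map this dataset's columns onto the "url"/"caption" keys that
    # map_sample() now reads from the extractor's output.
    return {
        "url": sample["URL"],       # hypothetical column name
        "caption": sample["TEXT"],  # hypothetical column name
    }

loader = OnlineStreamingDataLoader(
    "someuser/some-image-text-dataset",  # hypothetical dataset path
    batch_size=64,
    image_shape=(256, 256),
    feature_extractor=my_feature_extractor,
    map_batch_fn=map_batch_repeat_forever,  # keep re-iterating each worker's shard
)

batch = next(loader)  # assuming the loader's usual iterator protocol;
                      # batches are prefetched by a background thread

# Batch-level errors are now pushed onto error_queue; get_error() on the
# underlying ImageBatchIterator is a generator that blocks until one arrives.
err = next(loader.iterator.get_error())
print("batch failed:", err["error"])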

{flaxdiff-0.1.31 → flaxdiff-0.1.33}/flaxdiff.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: flaxdiff
- Version: 0.1.31
+ Version: 0.1.33
  Summary: A versatile and easy to understand Diffusion library
  Author: Ashish Kumar Singh
  Author-email: ashishkmr472@gmail.com

{flaxdiff-0.1.31 → flaxdiff-0.1.33}/setup.py
@@ -11,7 +11,7 @@ required_packages=[
  setup(
      name='flaxdiff',
      packages=find_packages(),
-     version='0.1.31',
+     version='0.1.33',
      description='A versatile and easy to understand Diffusion library',
      long_description=open('README.md').read(),
      long_description_content_type='text/markdown',