flaxdiff 0.1.35.1__tar.gz → 0.1.35.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/PKG-INFO +1 -1
  2. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/data/online_loader.py +19 -71
  3. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff.egg-info/PKG-INFO +1 -1
  4. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/setup.py +1 -1
  5. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/README.md +0 -0
  6. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/__init__.py +0 -0
  7. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/data/__init__.py +0 -0
  8. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/models/__init__.py +0 -0
  9. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/models/attention.py +0 -0
  10. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/models/autoencoder/__init__.py +0 -0
  11. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  12. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  13. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  14. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/models/common.py +0 -0
  15. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/models/favor_fastattn.py +0 -0
  16. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/models/simple_unet.py +0 -0
  17. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/models/simple_vit.py +0 -0
  18. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/predictors/__init__.py +0 -0
  19. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/samplers/__init__.py +0 -0
  20. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/samplers/common.py +0 -0
  21. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/samplers/ddim.py +0 -0
  22. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/samplers/ddpm.py +0 -0
  23. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/samplers/euler.py +0 -0
  24. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/samplers/heun_sampler.py +0 -0
  25. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/samplers/multistep_dpm.py +0 -0
  26. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/samplers/rk4_sampler.py +0 -0
  27. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/schedulers/__init__.py +0 -0
  28. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/schedulers/common.py +0 -0
  29. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/schedulers/continuous.py +0 -0
  30. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/schedulers/cosine.py +0 -0
  31. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/schedulers/discrete.py +0 -0
  32. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/schedulers/exp.py +0 -0
  33. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/schedulers/karras.py +0 -0
  34. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/schedulers/linear.py +0 -0
  35. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/schedulers/sqrt.py +0 -0
  36. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/trainer/__init__.py +0 -0
  37. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  38. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  39. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/trainer/simple_trainer.py +0 -0
  40. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff/utils.py +0 -0
  41. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff.egg-info/SOURCES.txt +0 -0
  42. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff.egg-info/dependency_links.txt +0 -0
  43. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff.egg-info/requires.txt +0 -0
  44. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/flaxdiff.egg-info/top_level.txt +0 -0
  45. {flaxdiff-0.1.35.1 → flaxdiff-0.1.35.3}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.35.1
3
+ Version: 0.1.35.3
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -26,7 +26,6 @@ import traceback
26
26
  USER_AGENT = get_datasets_user_agent()
27
27
 
28
28
  data_queue = Queue(16*2000)
29
- error_queue = Queue(16*2000)
30
29
 
31
30
 
32
31
  def fetch_single_image(image_url, timeout=None, retries=0):
@@ -78,13 +77,6 @@ def default_image_processor(
78
77
  return image, original_height, original_width
79
78
 
80
79
 
81
- def default_feature_extractor(sample):
82
- return {
83
- "url": sample["url"],
84
- "caption": sample["caption"],
85
- }
86
-
87
-
88
80
  def map_sample(
89
81
  url,
90
82
  caption,
@@ -129,6 +121,13 @@ def map_sample(
129
121
  pass
130
122
 
131
123
 
124
+ def default_feature_extractor(sample):
125
+ return {
126
+ "url": sample["url"],
127
+ "caption": sample["caption"],
128
+ }
129
+
130
+
132
131
  def map_batch(
133
132
  batch, num_threads=256, image_shape=(256, 256),
134
133
  min_image_shape=(128, 128),
@@ -142,54 +141,22 @@ def map_batch(
142
141
  map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
143
142
  timeout=timeout, retries=retries, image_processor=image_processor,
144
143
  upscale_interpolation=upscale_interpolation,
145
- downscale_interpolation=downscale_interpolation,
146
- feature_extractor=feature_extractor
144
+ downscale_interpolation=downscale_interpolation
147
145
  )
148
- features = feature_extractor(batch)
149
- url, caption = features["url"], features["caption"]
150
146
  with ThreadPoolExecutor(max_workers=num_threads) as executor:
147
+ features = feature_extractor(batch)
148
+ url, caption = features["url"], features["caption"]
151
149
  executor.map(map_sample_fn, url, caption)
152
150
  except Exception as e:
153
151
  print(f"Error maping batch", e)
154
152
  traceback.print_exc()
155
- error_queue.put_nowait({
156
- "batch": batch,
157
- "error": str(e)
158
- })
153
+ # error_queue.put_nowait({
154
+ # "batch": batch,
155
+ # "error": str(e)
156
+ # })
159
157
  pass
160
158
 
161
159
 
162
- def map_batch_repeat_forever(
163
- batch, num_threads=256, image_shape=(256, 256),
164
- min_image_shape=(128, 128),
165
- timeout=15, retries=3, image_processor=default_image_processor,
166
- upscale_interpolation=cv2.INTER_CUBIC,
167
- downscale_interpolation=cv2.INTER_AREA,
168
- feature_extractor=default_feature_extractor,
169
- ):
170
- while True: # Repeat forever
171
- try:
172
- map_sample_fn = partial(
173
- map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
174
- timeout=timeout, retries=retries, image_processor=image_processor,
175
- upscale_interpolation=upscale_interpolation,
176
- downscale_interpolation=downscale_interpolation,
177
- feature_extractor=feature_extractor
178
- )
179
- with ThreadPoolExecutor(max_workers=num_threads) as executor:
180
- executor.map(map_sample_fn, batch)
181
- # Shuffle the batch
182
- batch = batch.shuffle(seed=np.random.randint(0, 1000000))
183
- except Exception as e:
184
- print(f"Error maping batch", e)
185
- traceback.print_exc()
186
- error_queue.put_nowait({
187
- "batch": batch,
188
- "error": str(e)
189
- })
190
- pass
191
-
192
-
193
160
  def parallel_image_loader(
194
161
  dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
195
162
  min_image_shape=(128, 128),
@@ -197,11 +164,9 @@ def parallel_image_loader(
197
164
  upscale_interpolation=cv2.INTER_CUBIC,
198
165
  downscale_interpolation=cv2.INTER_AREA,
199
166
  feature_extractor=default_feature_extractor,
200
- map_batch_fn=map_batch,
201
-
202
167
  ):
203
168
  map_batch_fn = partial(
204
- map_batch_fn, num_threads=num_threads, image_shape=image_shape,
169
+ map_batch, num_threads=num_threads, image_shape=image_shape,
205
170
  min_image_shape=min_image_shape,
206
171
  timeout=timeout, retries=retries, image_processor=image_processor,
207
172
  upscale_interpolation=upscale_interpolation,
@@ -209,20 +174,18 @@ def parallel_image_loader(
209
174
  feature_extractor=feature_extractor
210
175
  )
211
176
  shard_len = len(dataset) // num_workers
212
- print(f"Local Shard lengths: {shard_len}, workers: {num_workers}")
177
+ print(f"Local Shard lengths: {shard_len}")
213
178
  with multiprocessing.Pool(num_workers) as pool:
214
179
  iteration = 0
215
180
  while True:
216
181
  # Repeat forever
217
182
  shards = [dataset[i*shard_len:(i+1)*shard_len]
218
183
  for i in range(num_workers)]
219
- # shards = [dataset.shard(num_shards=num_workers, index=i) for i in range(num_workers)]
220
184
  print(f"mapping {len(shards)} shards")
221
185
  pool.map(map_batch_fn, shards)
222
186
  iteration += 1
223
187
  print(f"Shuffling dataset with seed {iteration}")
224
188
  dataset = dataset.shuffle(seed=iteration)
225
- print(f"Dataset shuffled")
226
189
  # Clear the error queue
227
190
  # while not error_queue.empty():
228
191
  # error_queue.get_nowait()
@@ -237,7 +200,6 @@ class ImageBatchIterator:
237
200
  upscale_interpolation=cv2.INTER_CUBIC,
238
201
  downscale_interpolation=cv2.INTER_AREA,
239
202
  feature_extractor=default_feature_extractor,
240
- map_batch_fn=map_batch,
241
203
  ):
242
204
  self.dataset = dataset
243
205
  self.num_workers = num_workers
@@ -252,22 +214,10 @@ class ImageBatchIterator:
252
214
  image_processor=image_processor,
253
215
  upscale_interpolation=upscale_interpolation,
254
216
  downscale_interpolation=downscale_interpolation,
255
- feature_extractor=feature_extractor,
256
- map_batch_fn=map_batch_fn,
217
+ feature_extractor=feature_extractor
257
218
  )
258
219
  self.thread = threading.Thread(target=loader, args=(dataset,))
259
220
  self.thread.start()
260
- self.error_queue = queue.Queue()
261
-
262
- def error_fetcher():
263
- while True:
264
- error = error_queue.get()
265
- self.error_queue.put(error)
266
- self.error_thread = threading.Thread(target=error_fetcher)
267
- self.error_thread.start()
268
-
269
- def get_error(self):
270
- yield self.error_queue.get()
271
221
 
272
222
  def __iter__(self):
273
223
  return self
@@ -331,7 +281,6 @@ class OnlineStreamingDataLoader():
331
281
  upscale_interpolation=cv2.INTER_CUBIC,
332
282
  downscale_interpolation=cv2.INTER_AREA,
333
283
  feature_extractor=default_feature_extractor,
334
- map_batch_fn=map_batch,
335
284
  ):
336
285
  if isinstance(dataset, str):
337
286
  dataset_path = dataset
@@ -359,8 +308,7 @@ class OnlineStreamingDataLoader():
359
308
  timeout=timeout, retries=retries, image_processor=image_processor,
360
309
  upscale_interpolation=upscale_interpolation,
361
310
  downscale_interpolation=downscale_interpolation,
362
- feature_extractor=feature_extractor,
363
- map_batch_fn=map_batch_fn,
311
+ feature_extractor=feature_extractor
364
312
  )
365
313
  self.batch_size = batch_size
366
314
 
@@ -372,7 +320,7 @@ class OnlineStreamingDataLoader():
372
320
  try:
373
321
  self.batch_queue.put(collate_fn(batch))
374
322
  except Exception as e:
375
- print("Error collating batch", e)
323
+ print("Error processing batch", e)
376
324
 
377
325
  self.loader_thread = threading.Thread(target=batch_loader)
378
326
  self.loader_thread.start()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.35.1
3
+ Version: 0.1.35.3
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -11,7 +11,7 @@ required_packages=[
11
11
  setup(
12
12
  name='flaxdiff',
13
13
  packages=find_packages(),
14
- version='0.1.35.1',
14
+ version='0.1.35.3',
15
15
  description='A versatile and easy to understand Diffusion library',
16
16
  long_description=open('README.md').read(),
17
17
  long_description_content_type='text/markdown',
File without changes
File without changes