flaxdiff 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,7 +25,7 @@ import cv2
25
25
  USER_AGENT = get_datasets_user_agent()
26
26
 
27
27
  data_queue = Queue(16*2000)
28
- error_queue = Queue(16*2000)
28
+ error_queue = Queue()
29
29
 
30
30
 
31
31
  def fetch_single_image(image_url, timeout=None, retries=0):
@@ -44,7 +44,7 @@ def fetch_single_image(image_url, timeout=None, retries=0):
44
44
  return image
45
45
 
46
46
 
47
- def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
47
+ def default_image_processor(image, image_shape, interpolation=cv2.INTER_CUBIC):
48
48
  image = A.longest_max_size(image, max(
49
49
  image_shape), interpolation=interpolation)
50
50
  image = A.pad(
@@ -62,7 +62,7 @@ def map_sample(
62
62
  image_shape=(256, 256),
63
63
  timeout=15,
64
64
  retries=3,
65
- upscale_interpolation=cv2.INTER_LANCZOS4,
65
+ upscale_interpolation=cv2.INTER_CUBIC,
66
66
  downscale_interpolation=cv2.INTER_AREA,
67
67
  image_processor=default_image_processor,
68
68
  ):
@@ -98,17 +98,24 @@ def map_sample(
98
98
  "original_width": original_width,
99
99
  })
100
100
  except Exception as e:
101
- error_queue.put({
101
+ error_queue.put_nowait({
102
102
  "url": url,
103
103
  "caption": caption,
104
104
  "error": str(e)
105
105
  })
106
106
 
107
107
 
108
- def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3, image_processor=default_image_processor):
108
+ def map_batch(
109
+ batch, num_threads=256, image_shape=(256, 256),
110
+ timeout=15, retries=3, image_processor=default_image_processor,
111
+ upscale_interpolation=cv2.INTER_CUBIC,
112
+ downscale_interpolation=cv2.INTER_AREA,
113
+ ):
109
114
  try:
110
115
  map_sample_fn = partial(map_sample, image_shape=image_shape,
111
- timeout=timeout, retries=retries, image_processor=image_processor)
116
+ timeout=timeout, retries=retries, image_processor=image_processor,
117
+ upscale_interpolation=upscale_interpolation,
118
+ downscale_interpolation=downscale_interpolation)
112
119
  with ThreadPoolExecutor(max_workers=num_threads) as executor:
113
120
  executor.map(map_sample_fn, batch["url"], batch['caption'])
114
121
  except Exception as e:
@@ -118,10 +125,16 @@ def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retrie
118
125
  })
119
126
 
120
127
 
121
- def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
122
- num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
128
+ def parallel_image_loader(
129
+ dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
130
+ num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
131
+ upscale_interpolation=cv2.INTER_CUBIC,
132
+ downscale_interpolation=cv2.INTER_AREA,
133
+ ):
123
134
  map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
124
- timeout=timeout, retries=retries, image_processor=image_processor)
135
+ timeout=timeout, retries=retries, image_processor=image_processor,
136
+ upscale_interpolation=upscale_interpolation,
137
+ downscale_interpolation=downscale_interpolation)
125
138
  shard_len = len(dataset) // num_workers
126
139
  print(f"Local Shard lengths: {shard_len}")
127
140
  with multiprocessing.Pool(num_workers) as pool:
@@ -135,17 +148,27 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(2
135
148
  iteration += 1
136
149
  print(f"Shuffling dataset with seed {iteration}")
137
150
  dataset = dataset.shuffle(seed=iteration)
151
+ # Clear the error queue
152
+ while not error_queue.empty():
153
+ error_queue.get_nowait()
138
154
 
139
155
 
140
156
  class ImageBatchIterator:
141
- def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
142
- num_workers: int = 8, num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
157
+ def __init__(
158
+ self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
159
+ num_workers: int = 8, num_threads=256, timeout=15, retries=3,
160
+ image_processor=default_image_processor,
161
+ upscale_interpolation=cv2.INTER_CUBIC,
162
+ downscale_interpolation=cv2.INTER_AREA,
163
+ ):
143
164
  self.dataset = dataset
144
165
  self.num_workers = num_workers
145
166
  self.batch_size = batch_size
146
167
  loader = partial(parallel_image_loader, num_threads=num_threads,
147
168
  image_shape=image_shape, num_workers=num_workers,
148
- timeout=timeout, retries=retries, image_processor=image_processor)
169
+ timeout=timeout, retries=retries, image_processor=image_processor,
170
+ upscale_interpolation=upscale_interpolation,
171
+ downscale_interpolation=downscale_interpolation)
149
172
  self.thread = threading.Thread(target=loader, args=(dataset,))
150
173
  self.thread.start()
151
174
 
@@ -207,6 +230,8 @@ class OnlineStreamingDataLoader():
207
230
  timeout=15,
208
231
  retries=3,
209
232
  image_processor=default_image_processor,
233
+ upscale_interpolation=cv2.INTER_CUBIC,
234
+ downscale_interpolation=cv2.INTER_AREA,
210
235
  ):
211
236
  if isinstance(dataset, str):
212
237
  dataset_path = dataset
@@ -229,8 +254,9 @@ class OnlineStreamingDataLoader():
229
254
  print(f"Dataset length: {len(dataset)}")
230
255
  self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
231
256
  num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
232
- timeout=timeout, retries=retries, image_processor=image_processor)
233
- self.collate_fn = collate_fn
257
+ timeout=timeout, retries=retries, image_processor=image_processor,
258
+ upscale_interpolation=upscale_interpolation,
259
+ downscale_interpolation=downscale_interpolation)
234
260
  self.batch_size = batch_size
235
261
 
236
262
  # Launch a thread to load batches in the background
@@ -238,7 +264,10 @@ class OnlineStreamingDataLoader():
238
264
 
239
265
  def batch_loader():
240
266
  for batch in self.iterator:
241
- self.batch_queue.put(batch)
267
+ try:
268
+ self.batch_queue.put(collate_fn(batch))
269
+ except Exception as e:
270
+ print("Error processing batch", e)
242
271
 
243
272
  self.loader_thread = threading.Thread(target=batch_loader)
244
273
  self.loader_thread.start()
@@ -247,8 +276,8 @@ class OnlineStreamingDataLoader():
247
276
  return self
248
277
 
249
278
  def __next__(self):
250
- return self.collate_fn(self.batch_queue.get())
279
+ return self.batch_queue.get()
251
280
  # return self.collate_fn(next(self.iterator))
252
281
 
253
282
  def __len__(self):
254
- return len(self.dataset)
283
+ return len(self.dataset)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.17
3
+ Version: 0.1.19
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -1,7 +1,7 @@
1
1
  flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
3
3
  flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
4
- flaxdiff/data/online_loader.py,sha256=BM4Le-4BUo8MJpRzGIA2nMHKm4-WynQ2BOdiQz0JCDs,8791
4
+ flaxdiff/data/online_loader.py,sha256=WK4apO8Bx-RTU_z5imB53Lzq12vqGnXA9DhLq8nb0us,9991
5
5
  flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
6
6
  flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
7
7
  flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
34
34
  flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
35
35
  flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
36
36
  flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
37
- flaxdiff-0.1.17.dist-info/METADATA,sha256=2Nr_T2yg3XHFt2jBuUXo8FxLYM8si-DBLdW_PBKxzc4,22083
38
- flaxdiff-0.1.17.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
39
- flaxdiff-0.1.17.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
- flaxdiff-0.1.17.dist-info/RECORD,,
37
+ flaxdiff-0.1.19.dist-info/METADATA,sha256=NH-f1SK5obamoVRk8ZPQxvtQcz_R3mui3ToZe0Qx8Vg,22083
38
+ flaxdiff-0.1.19.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
39
+ flaxdiff-0.1.19.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
+ flaxdiff-0.1.19.dist-info/RECORD,,