flaxdiff 0.1.19__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,7 +25,6 @@ import cv2
25
25
  USER_AGENT = get_datasets_user_agent()
26
26
 
27
27
  data_queue = Queue(16*2000)
28
- error_queue = Queue()
29
28
 
30
29
 
31
30
  def fetch_single_image(image_url, timeout=None, retries=0):
@@ -60,6 +59,7 @@ def default_image_processor(image, image_shape, interpolation=cv2.INTER_CUBIC):
60
59
  def map_sample(
61
60
  url, caption,
62
61
  image_shape=(256, 256),
62
+ min_image_shape=(128, 128),
63
63
  timeout=15,
64
64
  retries=3,
65
65
  upscale_interpolation=cv2.INTER_CUBIC,
@@ -75,21 +75,21 @@ def map_sample(
75
75
  image = np.array(image)
76
76
  original_height, original_width = image.shape[:2]
77
77
  # check if the image is too small
78
- if min(original_height, original_width) < min(image_shape):
78
+ if min(original_height, original_width) < min(min_image_shape):
79
79
  return
80
80
  # check if wrong aspect ratio
81
- if max(original_height, original_width) / min(original_height, original_width) > 2:
81
+ if max(original_height, original_width) / min(original_height, original_width) > 2.4:
82
82
  return
83
83
  # check if the variance is too low
84
84
  if np.std(image) < 1e-4:
85
85
  return
86
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
86
+ # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
87
87
  downscale = max(original_width, original_height) > max(image_shape)
88
88
  interpolation = downscale_interpolation if downscale else upscale_interpolation
89
89
 
90
90
  image = image_processor(
91
91
  image, image_shape, interpolation=interpolation)
92
-
92
+
93
93
  data_queue.put({
94
94
  "url": url,
95
95
  "caption": caption,
@@ -98,40 +98,47 @@ def map_sample(
98
98
  "original_width": original_width,
99
99
  })
100
100
  except Exception as e:
101
- error_queue.put_nowait({
102
- "url": url,
103
- "caption": caption,
104
- "error": str(e)
105
- })
101
+ print(f"Error processing {url}", e)
102
+ # error_queue.put_nowait({
103
+ # "url": url,
104
+ # "caption": caption,
105
+ # "error": str(e)
106
+ # })
107
+ pass
106
108
 
107
109
 
108
110
  def map_batch(
109
111
  batch, num_threads=256, image_shape=(256, 256),
112
+ min_image_shape=(128, 128),
110
113
  timeout=15, retries=3, image_processor=default_image_processor,
111
114
  upscale_interpolation=cv2.INTER_CUBIC,
112
115
  downscale_interpolation=cv2.INTER_AREA,
113
116
  ):
114
117
  try:
115
- map_sample_fn = partial(map_sample, image_shape=image_shape,
118
+ map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
116
119
  timeout=timeout, retries=retries, image_processor=image_processor,
117
120
  upscale_interpolation=upscale_interpolation,
118
121
  downscale_interpolation=downscale_interpolation)
119
122
  with ThreadPoolExecutor(max_workers=num_threads) as executor:
120
123
  executor.map(map_sample_fn, batch["url"], batch['caption'])
121
124
  except Exception as e:
122
- error_queue.put({
123
- "batch": batch,
124
- "error": str(e)
125
- })
125
+ print(f"Error processing batch", e)
126
+ # error_queue.put_nowait({
127
+ # "batch": batch,
128
+ # "error": str(e)
129
+ # })
130
+ pass
126
131
 
127
132
 
128
133
  def parallel_image_loader(
129
134
  dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
135
+ min_image_shape=(128, 128),
130
136
  num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
131
137
  upscale_interpolation=cv2.INTER_CUBIC,
132
138
  downscale_interpolation=cv2.INTER_AREA,
133
139
  ):
134
- map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
140
+ map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
141
+ min_image_shape=min_image_shape,
135
142
  timeout=timeout, retries=retries, image_processor=image_processor,
136
143
  upscale_interpolation=upscale_interpolation,
137
144
  downscale_interpolation=downscale_interpolation)
@@ -149,13 +156,14 @@ def parallel_image_loader(
149
156
  print(f"Shuffling dataset with seed {iteration}")
150
157
  dataset = dataset.shuffle(seed=iteration)
151
158
  # Clear the error queue
152
- while not error_queue.empty():
153
- error_queue.get_nowait()
159
+ # while not error_queue.empty():
160
+ # error_queue.get_nowait()
154
161
 
155
162
 
156
163
  class ImageBatchIterator:
157
164
  def __init__(
158
165
  self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
166
+ min_image_shape=(128, 128),
159
167
  num_workers: int = 8, num_threads=256, timeout=15, retries=3,
160
168
  image_processor=default_image_processor,
161
169
  upscale_interpolation=cv2.INTER_CUBIC,
@@ -165,7 +173,9 @@ class ImageBatchIterator:
165
173
  self.num_workers = num_workers
166
174
  self.batch_size = batch_size
167
175
  loader = partial(parallel_image_loader, num_threads=num_threads,
168
- image_shape=image_shape, num_workers=num_workers,
176
+ image_shape=image_shape,
177
+ min_image_shape=min_image_shape,
178
+ num_workers=num_workers,
169
179
  timeout=timeout, retries=retries, image_processor=image_processor,
170
180
  upscale_interpolation=upscale_interpolation,
171
181
  downscale_interpolation=downscale_interpolation)
@@ -215,6 +225,7 @@ class OnlineStreamingDataLoader():
215
225
  dataset,
216
226
  batch_size=64,
217
227
  image_shape=(256, 256),
228
+ min_image_shape=(128, 128),
218
229
  num_workers=16,
219
230
  num_threads=512,
220
231
  default_split="all",
@@ -253,8 +264,9 @@ class OnlineStreamingDataLoader():
253
264
  num_shards=global_process_count, index=global_process_index)
254
265
  print(f"Dataset length: {len(dataset)}")
255
266
  self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
267
+ min_image_shape=min_image_shape,
256
268
  num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
257
- timeout=timeout, retries=retries, image_processor=image_processor,
269
+ timeout=timeout, retries=retries, image_processor=image_processor,
258
270
  upscale_interpolation=upscale_interpolation,
259
271
  downscale_interpolation=downscale_interpolation)
260
272
  self.batch_size = batch_size
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.19
3
+ Version: 0.1.21
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -1,7 +1,7 @@
1
1
  flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
3
3
  flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
4
- flaxdiff/data/online_loader.py,sha256=WK4apO8Bx-RTU_z5imB53Lzq12vqGnXA9DhLq8nb0us,9991
4
+ flaxdiff/data/online_loader.py,sha256=w6gi1tAzWr4gtPt7onpStzOxp7Kdo_2q8Ro4Yi7OT4w,10549
5
5
  flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
6
6
  flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
7
7
  flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
34
34
  flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
35
35
  flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
36
36
  flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
37
- flaxdiff-0.1.19.dist-info/METADATA,sha256=NH-f1SK5obamoVRk8ZPQxvtQcz_R3mui3ToZe0Qx8Vg,22083
38
- flaxdiff-0.1.19.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
39
- flaxdiff-0.1.19.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
- flaxdiff-0.1.19.dist-info/RECORD,,
37
+ flaxdiff-0.1.21.dist-info/METADATA,sha256=k1s_EIWBL0y4oCxXxr3QIi7LQbR47_jyDFfjVbURSMY,22083
38
+ flaxdiff-0.1.21.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
39
+ flaxdiff-0.1.21.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
+ flaxdiff-0.1.21.dist-info/RECORD,,