flaxdiff 0.1.18__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,7 +25,6 @@ import cv2
25
25
  USER_AGENT = get_datasets_user_agent()
26
26
 
27
27
  data_queue = Queue(16*2000)
28
- error_queue = Queue()
29
28
 
30
29
 
31
30
  def fetch_single_image(image_url, timeout=None, retries=0):
@@ -44,7 +43,7 @@ def fetch_single_image(image_url, timeout=None, retries=0):
44
43
  return image
45
44
 
46
45
 
47
- def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
46
+ def default_image_processor(image, image_shape, interpolation=cv2.INTER_CUBIC):
48
47
  image = A.longest_max_size(image, max(
49
48
  image_shape), interpolation=interpolation)
50
49
  image = A.pad(
@@ -60,9 +59,10 @@ def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4
60
59
  def map_sample(
61
60
  url, caption,
62
61
  image_shape=(256, 256),
62
+ min_image_shape=(128, 128),
63
63
  timeout=15,
64
64
  retries=3,
65
- upscale_interpolation=cv2.INTER_LANCZOS4,
65
+ upscale_interpolation=cv2.INTER_CUBIC,
66
66
  downscale_interpolation=cv2.INTER_AREA,
67
67
  image_processor=default_image_processor,
68
68
  ):
@@ -75,10 +75,10 @@ def map_sample(
75
75
  image = np.array(image)
76
76
  original_height, original_width = image.shape[:2]
77
77
  # check if the image is too small
78
- if min(original_height, original_width) < min(image_shape):
78
+ if min(original_height, original_width) < min(min_image_shape):
79
79
  return
80
80
  # check if wrong aspect ratio
81
- if max(original_height, original_width) / min(original_height, original_width) > 2:
81
+ if max(original_height, original_width) / min(original_height, original_width) > 2.4:
82
82
  return
83
83
  # check if the variance is too low
84
84
  if np.std(image) < 1e-4:
@@ -98,30 +98,48 @@ def map_sample(
98
98
  "original_width": original_width,
99
99
  })
100
100
  except Exception as e:
101
- error_queue.put_nowait({
102
- "url": url,
103
- "caption": caption,
104
- "error": str(e)
105
- })
106
-
107
-
108
- def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3, image_processor=default_image_processor):
101
+ # error_queue.put_nowait({
102
+ # "url": url,
103
+ # "caption": caption,
104
+ # "error": str(e)
105
+ # })
106
+ pass
107
+
108
+
109
+ def map_batch(
110
+ batch, num_threads=256, image_shape=(256, 256),
111
+ min_image_shape=(128, 128),
112
+ timeout=15, retries=3, image_processor=default_image_processor,
113
+ upscale_interpolation=cv2.INTER_CUBIC,
114
+ downscale_interpolation=cv2.INTER_AREA,
115
+ ):
109
116
  try:
110
- map_sample_fn = partial(map_sample, image_shape=image_shape,
111
- timeout=timeout, retries=retries, image_processor=image_processor)
117
+ map_sample_fn = partial(map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
118
+ timeout=timeout, retries=retries, image_processor=image_processor,
119
+ upscale_interpolation=upscale_interpolation,
120
+ downscale_interpolation=downscale_interpolation)
112
121
  with ThreadPoolExecutor(max_workers=num_threads) as executor:
113
122
  executor.map(map_sample_fn, batch["url"], batch['caption'])
114
123
  except Exception as e:
115
- error_queue.put({
116
- "batch": batch,
117
- "error": str(e)
118
- })
119
-
120
-
121
- def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
122
- num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
123
- map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
124
- timeout=timeout, retries=retries, image_processor=image_processor)
124
+ # error_queue.put_nowait({
125
+ # "batch": batch,
126
+ # "error": str(e)
127
+ # })
128
+ pass
129
+
130
+
131
+ def parallel_image_loader(
132
+ dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
133
+ min_image_shape=(128, 128),
134
+ num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
135
+ upscale_interpolation=cv2.INTER_CUBIC,
136
+ downscale_interpolation=cv2.INTER_AREA,
137
+ ):
138
+ map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
139
+ min_image_shape=min_image_shape,
140
+ timeout=timeout, retries=retries, image_processor=image_processor,
141
+ upscale_interpolation=upscale_interpolation,
142
+ downscale_interpolation=downscale_interpolation)
125
143
  shard_len = len(dataset) // num_workers
126
144
  print(f"Local Shard lengths: {shard_len}")
127
145
  with multiprocessing.Pool(num_workers) as pool:
@@ -136,19 +154,29 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(2
136
154
  print(f"Shuffling dataset with seed {iteration}")
137
155
  dataset = dataset.shuffle(seed=iteration)
138
156
  # Clear the error queue
139
- while not error_queue.empty():
140
- error_queue.get_nowait()
157
+ # while not error_queue.empty():
158
+ # error_queue.get_nowait()
141
159
 
142
160
 
143
161
  class ImageBatchIterator:
144
- def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
145
- num_workers: int = 8, num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
162
+ def __init__(
163
+ self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
164
+ min_image_shape=(128, 128),
165
+ num_workers: int = 8, num_threads=256, timeout=15, retries=3,
166
+ image_processor=default_image_processor,
167
+ upscale_interpolation=cv2.INTER_CUBIC,
168
+ downscale_interpolation=cv2.INTER_AREA,
169
+ ):
146
170
  self.dataset = dataset
147
171
  self.num_workers = num_workers
148
172
  self.batch_size = batch_size
149
173
  loader = partial(parallel_image_loader, num_threads=num_threads,
150
- image_shape=image_shape, num_workers=num_workers,
151
- timeout=timeout, retries=retries, image_processor=image_processor)
174
+ image_shape=image_shape,
175
+ min_image_shape=min_image_shape,
176
+ num_workers=num_workers,
177
+ timeout=timeout, retries=retries, image_processor=image_processor,
178
+ upscale_interpolation=upscale_interpolation,
179
+ downscale_interpolation=downscale_interpolation)
152
180
  self.thread = threading.Thread(target=loader, args=(dataset,))
153
181
  self.thread.start()
154
182
 
@@ -195,6 +223,7 @@ class OnlineStreamingDataLoader():
195
223
  dataset,
196
224
  batch_size=64,
197
225
  image_shape=(256, 256),
226
+ min_image_shape=(128, 128),
198
227
  num_workers=16,
199
228
  num_threads=512,
200
229
  default_split="all",
@@ -210,6 +239,8 @@ class OnlineStreamingDataLoader():
210
239
  timeout=15,
211
240
  retries=3,
212
241
  image_processor=default_image_processor,
242
+ upscale_interpolation=cv2.INTER_CUBIC,
243
+ downscale_interpolation=cv2.INTER_AREA,
213
244
  ):
214
245
  if isinstance(dataset, str):
215
246
  dataset_path = dataset
@@ -231,8 +262,11 @@ class OnlineStreamingDataLoader():
231
262
  num_shards=global_process_count, index=global_process_index)
232
263
  print(f"Dataset length: {len(dataset)}")
233
264
  self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
265
+ min_image_shape=min_image_shape,
234
266
  num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
235
- timeout=timeout, retries=retries, image_processor=image_processor)
267
+ timeout=timeout, retries=retries, image_processor=image_processor,
268
+ upscale_interpolation=upscale_interpolation,
269
+ downscale_interpolation=downscale_interpolation)
236
270
  self.batch_size = batch_size
237
271
 
238
272
  # Launch a thread to load batches in the background
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.18
3
+ Version: 0.1.20
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -1,7 +1,7 @@
1
1
  flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
3
3
  flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
4
- flaxdiff/data/online_loader.py,sha256=qim6SRRGU1lRO0zQbDNjRYC7Qm6g7jtUfELEXotora0,8987
4
+ flaxdiff/data/online_loader.py,sha256=XVT_kT7v9CQVaQgunTL48KxgPgwQ-bhIi8RN-Q1qbYc,10451
5
5
  flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
6
6
  flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
7
7
  flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
34
34
  flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
35
35
  flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
36
36
  flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
37
- flaxdiff-0.1.18.dist-info/METADATA,sha256=aUSr3lBb9P2mnrpmbcgQa41DT8YYM-DtVMU8NI3CZEE,22083
38
- flaxdiff-0.1.18.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
39
- flaxdiff-0.1.18.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
- flaxdiff-0.1.18.dist-info/RECORD,,
37
+ flaxdiff-0.1.20.dist-info/METADATA,sha256=ls0rUYnHBWdChfQ7meO2nlHSqGVEPn2JzZTOTagt2H8,22083
38
+ flaxdiff-0.1.20.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
39
+ flaxdiff-0.1.20.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
+ flaxdiff-0.1.20.dist-info/RECORD,,