flaxdiff 0.1.18__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -44,7 +44,7 @@ def fetch_single_image(image_url, timeout=None, retries=0):
44
44
  return image
45
45
 
46
46
 
47
- def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
47
+ def default_image_processor(image, image_shape, interpolation=cv2.INTER_CUBIC):
48
48
  image = A.longest_max_size(image, max(
49
49
  image_shape), interpolation=interpolation)
50
50
  image = A.pad(
@@ -62,7 +62,7 @@ def map_sample(
62
62
  image_shape=(256, 256),
63
63
  timeout=15,
64
64
  retries=3,
65
- upscale_interpolation=cv2.INTER_LANCZOS4,
65
+ upscale_interpolation=cv2.INTER_CUBIC,
66
66
  downscale_interpolation=cv2.INTER_AREA,
67
67
  image_processor=default_image_processor,
68
68
  ):
@@ -105,10 +105,17 @@ def map_sample(
105
105
  })
106
106
 
107
107
 
108
- def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3, image_processor=default_image_processor):
108
+ def map_batch(
109
+ batch, num_threads=256, image_shape=(256, 256),
110
+ timeout=15, retries=3, image_processor=default_image_processor,
111
+ upscale_interpolation=cv2.INTER_CUBIC,
112
+ downscale_interpolation=cv2.INTER_AREA,
113
+ ):
109
114
  try:
110
115
  map_sample_fn = partial(map_sample, image_shape=image_shape,
111
- timeout=timeout, retries=retries, image_processor=image_processor)
116
+ timeout=timeout, retries=retries, image_processor=image_processor,
117
+ upscale_interpolation=upscale_interpolation,
118
+ downscale_interpolation=downscale_interpolation)
112
119
  with ThreadPoolExecutor(max_workers=num_threads) as executor:
113
120
  executor.map(map_sample_fn, batch["url"], batch['caption'])
114
121
  except Exception as e:
@@ -118,10 +125,16 @@ def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retrie
118
125
  })
119
126
 
120
127
 
121
- def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
122
- num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
128
+ def parallel_image_loader(
129
+ dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
130
+ num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
131
+ upscale_interpolation=cv2.INTER_CUBIC,
132
+ downscale_interpolation=cv2.INTER_AREA,
133
+ ):
123
134
  map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
124
- timeout=timeout, retries=retries, image_processor=image_processor)
135
+ timeout=timeout, retries=retries, image_processor=image_processor,
136
+ upscale_interpolation=upscale_interpolation,
137
+ downscale_interpolation=downscale_interpolation)
125
138
  shard_len = len(dataset) // num_workers
126
139
  print(f"Local Shard lengths: {shard_len}")
127
140
  with multiprocessing.Pool(num_workers) as pool:
@@ -141,14 +154,21 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(2
141
154
 
142
155
 
143
156
  class ImageBatchIterator:
144
- def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
145
- num_workers: int = 8, num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
157
+ def __init__(
158
+ self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
159
+ num_workers: int = 8, num_threads=256, timeout=15, retries=3,
160
+ image_processor=default_image_processor,
161
+ upscale_interpolation=cv2.INTER_CUBIC,
162
+ downscale_interpolation=cv2.INTER_AREA,
163
+ ):
146
164
  self.dataset = dataset
147
165
  self.num_workers = num_workers
148
166
  self.batch_size = batch_size
149
167
  loader = partial(parallel_image_loader, num_threads=num_threads,
150
168
  image_shape=image_shape, num_workers=num_workers,
151
- timeout=timeout, retries=retries, image_processor=image_processor)
169
+ timeout=timeout, retries=retries, image_processor=image_processor,
170
+ upscale_interpolation=upscale_interpolation,
171
+ downscale_interpolation=downscale_interpolation)
152
172
  self.thread = threading.Thread(target=loader, args=(dataset,))
153
173
  self.thread.start()
154
174
 
@@ -210,6 +230,8 @@ class OnlineStreamingDataLoader():
210
230
  timeout=15,
211
231
  retries=3,
212
232
  image_processor=default_image_processor,
233
+ upscale_interpolation=cv2.INTER_CUBIC,
234
+ downscale_interpolation=cv2.INTER_AREA,
213
235
  ):
214
236
  if isinstance(dataset, str):
215
237
  dataset_path = dataset
@@ -232,7 +254,9 @@ class OnlineStreamingDataLoader():
232
254
  print(f"Dataset length: {len(dataset)}")
233
255
  self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
234
256
  num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
235
- timeout=timeout, retries=retries, image_processor=image_processor)
257
+ timeout=timeout, retries=retries, image_processor=image_processor,
258
+ upscale_interpolation=upscale_interpolation,
259
+ downscale_interpolation=downscale_interpolation)
236
260
  self.batch_size = batch_size
237
261
 
238
262
  # Launch a thread to load batches in the background
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.18
3
+ Version: 0.1.19
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -1,7 +1,7 @@
1
1
  flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
3
3
  flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
4
- flaxdiff/data/online_loader.py,sha256=qim6SRRGU1lRO0zQbDNjRYC7Qm6g7jtUfELEXotora0,8987
4
+ flaxdiff/data/online_loader.py,sha256=WK4apO8Bx-RTU_z5imB53Lzq12vqGnXA9DhLq8nb0us,9991
5
5
  flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
6
6
  flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
7
7
  flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
34
34
  flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
35
35
  flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
36
36
  flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
37
- flaxdiff-0.1.18.dist-info/METADATA,sha256=aUSr3lBb9P2mnrpmbcgQa41DT8YYM-DtVMU8NI3CZEE,22083
38
- flaxdiff-0.1.18.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
39
- flaxdiff-0.1.18.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
- flaxdiff-0.1.18.dist-info/RECORD,,
37
+ flaxdiff-0.1.19.dist-info/METADATA,sha256=NH-f1SK5obamoVRk8ZPQxvtQcz_R3mui3ToZe0Qx8Vg,22083
38
+ flaxdiff-0.1.19.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
39
+ flaxdiff-0.1.19.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
+ flaxdiff-0.1.19.dist-info/RECORD,,