flaxdiff 0.1.13__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flaxdiff/data/__init__.py CHANGED
@@ -0,0 +1 @@
+from .online_loader import OnlineStreamingDataLoader
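The package's data namespace previously exported nothing; this re-export enables a shorter import path, so either spelling below resolves to the same class:

```python
from flaxdiff.data import OnlineStreamingDataLoader                # new in 0.1.14
from flaxdiff.data.online_loader import OnlineStreamingDataLoader  # still works
```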
flaxdiff/data/online_loader.py CHANGED
@@ -79,7 +79,9 @@ def map_sample(
         data_queue.put({
             "url": url,
             "caption": caption,
-            "image": image
+            "image": image,
+            "original_height": original_height,
+            "original_width": original_width,
         })
     except Exception as e:
         error_queue.put({
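For anyone draining `data_queue` downstream, each successfully fetched sample now records its pre-resize dimensions alongside the decoded image. A small, hedged sketch of how a consumer might use the new fields; the key names come from the hunk above, everything else is illustrative:

```python
# Illustrative helper; assumes samples drained from data_queue keep the
# keys shown in the hunk above.
def downscale_factor(sample, image_shape=(256, 256)):
    # Ratio between the stored original size and the resized image.
    h_ratio = sample["original_height"] / image_shape[0]
    w_ratio = sample["original_width"] / image_shape[1]
    return h_ratio, w_ratio
```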
@@ -88,12 +90,12 @@ def map_sample(
             "error": str(e)
         })
 
-def map_batch(batch, num_threads=256, timeout=None, retries=0):
+def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):
     with ThreadPoolExecutor(max_workers=num_threads) as executor:
-        executor.map(map_sample, batch["url"], batch['caption'])
+        executor.map(map_sample, batch["url"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)
 
-def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=256):
-    map_batch_fn = partial(map_batch, num_threads=num_threads)
+def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
+    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
     shard_len = len(dataset) // num_workers
     print(f"Local Shard lengths: {shard_len}")
     with multiprocessing.Pool(num_workers) as pool:
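A note on the `executor.map` call in this hunk: `ThreadPoolExecutor.map` accepts only `timeout` and `chunksize` as its own keyword arguments and does not forward extra keywords to the mapped callable, so `image_shape=...` and `retries=...` here would raise `TypeError`, while `timeout=timeout` would be consumed by `map` itself rather than reaching `map_sample`. A minimal sketch of the usual workaround, assuming `map_sample` is importable at module level as the hunk headers suggest, is to bind the fixed keywords with `functools.partial`:

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from flaxdiff.data.online_loader import map_sample  # assumed module-level

def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):
    # Bind the per-call settings once; executor.map then only needs the
    # per-sample positional iterables.
    fn = partial(map_sample, image_shape=image_shape, timeout=timeout, retries=retries)
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Consume the lazy iterator so worker exceptions surface here.
        list(executor.map(fn, batch["url"], batch["caption"]))
```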
@@ -106,12 +108,12 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=25
             iteration += 1
 
 class ImageBatchIterator:
-    def __init__(self, dataset: Dataset, batch_size: int = 64, num_workers: int = 8, num_threads=256):
+    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):
         self.dataset = dataset
         self.num_workers = num_workers
         self.batch_size = batch_size
-        loader = partial(parallel_image_loader, num_threads=num_threads)
-        self.thread = threading.Thread(target=loader, args=(dataset, num_workers))
+        loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
+        self.thread = threading.Thread(target=loader, args=(dataset))
         self.thread.start()
 
     def __iter__(self):
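Also in this hunk, `args=(dataset)` is not a one-element tuple: the parentheses are plain grouping, so `threading.Thread` receives the dataset itself as its argument sequence and will try to unpack it. The trailing comma in the removed line was what made it a tuple. A small sketch of the distinction, with placeholder stand-ins for the real loader and dataset:

```python
import threading
from functools import partial

def parallel_image_loader(dataset, num_workers=8, image_shape=(256, 256), num_threads=256):
    ...  # placeholder for the loader in the hunk above

dataset = ["sample"]  # placeholder; a datasets.Dataset in the real code
loader = partial(parallel_image_loader, num_threads=256,
                 image_shape=(256, 256), num_workers=8)

# (dataset) evaluates to dataset itself; (dataset,) is the one-element
# tuple that Thread expects for a single positional argument.
thread = threading.Thread(target=loader, args=(dataset,))
thread.start()
thread.join()
```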
@@ -153,6 +155,7 @@ class OnlineStreamingDataLoader():
         self,
         dataset,
         batch_size=64,
+        image_shape=(256, 256),
         num_workers=16,
         num_threads=512,
         default_split="all",
@@ -180,7 +183,7 @@ class OnlineStreamingDataLoader():
         dataset = dataset.map(pre_map_maker(pre_map_def))
         self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
-        self.iterator = ImageBatchIterator(self.dataset, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
+        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
         self.collate_fn = collate_fn
 
         # Launch a thread to load batches in the background
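Taken together, `image_shape` now flows from the public constructor through `ImageBatchIterator` and `parallel_image_loader` down to the per-sample fetch. A hedged usage sketch; the dataset identifier is a placeholder, and iterating the loader directly is assumed from its dataloader interface:

```python
from flaxdiff.data import OnlineStreamingDataLoader  # short path, new in 0.1.14

loader = OnlineStreamingDataLoader(
    "user/image-text-dataset",  # placeholder: any dataset with "url"/"caption" columns
    batch_size=64,
    image_shape=(512, 512),     # new in 0.1.14: target size for decoded images
    num_workers=16,
    num_threads=512,
)

batch = next(iter(loader))
# Samples now also report their pre-resize dimensions (see the first hunk).
print(batch["original_height"], batch["original_width"])
```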
{flaxdiff-0.1.13.dist-info → flaxdiff-0.1.14.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.13
+Version: 0.1.14
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
{flaxdiff-0.1.13.dist-info → flaxdiff-0.1.14.dist-info}/RECORD CHANGED
@@ -1,7 +1,7 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
-flaxdiff/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-flaxdiff/data/online_loader.py,sha256=_q0YgXGif1zLCNOLYJB0w3QFqE_p1zOjE9-hAApZFA4,6938
+flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
+flaxdiff/data/online_loader.py,sha256=IjCVdeq18lF71eLX8RplJLezjKASO1cSFVe2GSYkLQ8,7283
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.13.dist-info/METADATA,sha256=sx4MOQ9jPKy1yl_sUjkJbeu6vUJJG9HX6x6W8ILmW5I,22083
-flaxdiff-0.1.13.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-flaxdiff-0.1.13.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.13.dist-info/RECORD,,
+flaxdiff-0.1.14.dist-info/METADATA,sha256=rcwF7cCFfgPLHn5gD7GZ_KvtG6EmiAIiel7xt8HylAo,22083
+flaxdiff-0.1.14.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.14.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.14.dist-info/RECORD,,