flaxdiff 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
flaxdiff/data/online_loader.py
@@ -27,7 +27,6 @@ USER_AGENT = get_datasets_user_agent()
 data_queue = Queue(16*2000)
 error_queue = Queue(16*2000)
 
-
 def fetch_single_image(image_url, timeout=None, retries=0):
     for _ in range(retries + 1):
         try:
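Only the signature and the retry loop's first lines of fetch_single_image appear in this hunk. For orientation, here is a minimal sketch of what such a fetcher conventionally looks like in Hugging Face dataset scripts; the urllib/PIL body below is an assumption for illustration, not code taken from this diff (only USER_AGENT = get_datasets_user_agent() and the signature are confirmed above):

```python
import io
import urllib.request

import PIL.Image
from datasets.utils.file_utils import get_datasets_user_agent

USER_AGENT = get_datasets_user_agent()

def fetch_single_image(image_url, timeout=None, retries=0):
    image = None
    # Attempt the download up to `retries + 1` times before giving up.
    for _ in range(retries + 1):
        try:
            request = urllib.request.Request(
                image_url, data=None, headers={"user-agent": USER_AGENT}
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                # Decode the response bytes into a PIL image.
                image = PIL.Image.open(io.BytesIO(req.read()))
            break
        except Exception:
            image = None  # Sketch assumption: failed fetches yield None.
    return image
```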
@@ -46,11 +45,13 @@ def fetch_single_image(image_url, timeout=None, retries=0):
 def map_sample(
     url, caption,
     image_shape=(256, 256),
+    timeout=15,
+    retries=3,
     upscale_interpolation=cv2.INTER_LANCZOS4,
     downscale_interpolation=cv2.INTER_AREA,
 ):
     try:
-        image = fetch_single_image(url, timeout=15, retries=3)  # Assuming fetch_single_image is defined elsewhere
+        image = fetch_single_image(url, timeout=timeout, retries=retries)  # Assuming fetch_single_image is defined elsewhere
         if image is None:
             return
 
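The point of this hunk is that the previously hard-coded timeout=15 and retries=3 now arrive as parameters of map_sample, with the old values kept as defaults. A usage sketch; the URL and caption are hypothetical:

```python
# Defaults reproduce the old behaviour (timeout=15, retries=3);
# callers can now tune both per dataset.
map_sample(
    "https://example.com/cat.jpg",  # hypothetical URL
    "a photo of a cat",             # hypothetical caption
    image_shape=(256, 256),
    timeout=30,  # e.g. a slower mirror
    retries=5,   # e.g. a flaky host
)
```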
@@ -84,15 +85,24 @@ def map_sample(
             "original_width": original_width,
         })
     except Exception as e:
+        print(f"Error in map_sample: {str(e)}")
         error_queue.put({
             "url": url,
             "caption": caption,
             "error": str(e)
         })
-
-def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):
-    with ThreadPoolExecutor(max_workers=num_threads) as executor:
-        executor.map(map_sample, batch["url"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)
+
+def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3):
+    try:
+        map_sample_fn = partial(map_sample, image_shape=image_shape, timeout=timeout, retries=retries)
+        with ThreadPoolExecutor(max_workers=num_threads) as executor:
+            executor.map(map_sample_fn, batch["url"], batch['caption'])
+    except Exception as e:
+        print(f"Error in map_batch: {str(e)}")
+        error_queue.put({
+            "batch": batch,
+            "error": str(e)
+        })
 
 def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
     map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
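This rewrite fixes a genuine bug rather than just restyling: ThreadPoolExecutor.map fans out positional iterables only, so in the old call image_shape= and retries= would have raised TypeError, while timeout= would have been silently consumed as map's own per-result timeout. Binding the keywords with functools.partial first is the standard fix. A self-contained sketch of the pattern (the process function is a stand-in for map_sample):

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def process(url, caption, image_shape=(256, 256), timeout=15, retries=3):
    # Stand-in for map_sample: handles one (url, caption) pair.
    print(f"{url}: {caption} -> {image_shape}, timeout={timeout}, retries={retries}")

urls = ["u1", "u2"]
captions = ["c1", "c2"]

# Broken (old style): executor.map(process, urls, captions, image_shape=(128, 128))
# raises TypeError, because map() does not forward keyword arguments
# to the mapped callable.
fn = partial(process, image_shape=(128, 128), timeout=30, retries=5)
with ThreadPoolExecutor(max_workers=4) as executor:
    # partial pre-binds the keywords; map zips urls and captions positionally.
    list(executor.map(fn, urls, captions))
```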
@@ -102,8 +112,10 @@ def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(2
     iteration = 0
     while True:
         # Repeat forever
-        dataset = dataset.shuffle(seed=iteration)
+        print(f"Shuffling dataset with seed {iteration}")
+        # dataset = dataset.shuffle(seed=iteration)
         shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
+        print(f"mapping {len(shards)} shards")
         pool.map(map_batch_fn, shards)
         iteration += 1
 
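One detail worth noting about the sharding line: with a Hugging Face Dataset, slicing like dataset[i*shard_len:(i+1)*shard_len] returns a dict of column lists, not a list of rows, which is exactly the columnar shape map_batch consumes via batch["url"] and batch['caption']. A small sketch, assuming a toy dataset with those two columns:

```python
from datasets import Dataset

ds = Dataset.from_dict({
    "url": ["u0", "u1", "u2", "u3"],
    "caption": ["c0", "c1", "c2", "c3"],
})

num_workers = 2
shard_len = len(ds) // num_workers
shards = [ds[i * shard_len:(i + 1) * shard_len] for i in range(num_workers)]

# Each shard is a dict of column lists, e.g.
# {"url": ["u0", "u1"], "caption": ["c0", "c1"]}
print(shards[0])
```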
@@ -113,7 +125,7 @@ class ImageBatchIterator:
         self.num_workers = num_workers
         self.batch_size = batch_size
         loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
-        self.thread = threading.Thread(target=loader, args=(dataset))
+        self.thread = threading.Thread(target=loader, args=(dataset,))
         self.thread.start()
 
     def __iter__(self):
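The one-character change above is a real fix: (dataset) is just a parenthesized expression, whereas (dataset,) is a one-element tuple, and threading.Thread calls target(*args). With the old form, a sized dataset would be unpacked element by element into the loader. A sketch of the failure mode, with a plain list standing in for the dataset:

```python
import threading

def loader(dataset):
    print(f"loading {len(dataset)} samples")

data = ["a", "b", "c"]

# Buggy: args=(data) is not a tuple, so Thread calls loader("a", "b", "c"),
# which raises TypeError (1 positional argument expected, 3 given).
# threading.Thread(target=loader, args=(data)).start()

# Fixed: the trailing comma makes a one-element tuple.
t = threading.Thread(target=loader, args=(data,))
t.start()
t.join()
```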
@@ -131,7 +143,7 @@ class ImageBatchIterator:
 
     def __len__(self):
         return len(self.dataset) // self.batch_size
-
+
 def default_collate(batch):
     urls = [sample["url"] for sample in batch]
     captions = [sample["caption"] for sample in batch]
@@ -177,14 +189,14 @@ class OnlineStreamingDataLoader():
             if isinstance(dataset[0], str):
                 print("Loading multiple datasets from paths")
                 dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
-            else:
-                print("Concatenating multiple datasets")
-                dataset = concatenate_datasets(dataset)
-                dataset = dataset.map(pre_map_maker(pre_map_def))
+            print("Concatenating multiple datasets")
+            dataset = concatenate_datasets(dataset)
+            dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
         self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
         self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
         self.collate_fn = collate_fn
+        self.batch_size = batch_size
 
         # Launch a thread to load batches in the background
         self.batch_queue = queue.Queue(prefetch)
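The pre-map step now runs with batched=True and a deliberately oversized batch_size, so the mapping function receives columnar batches (dicts of column lists), effectively the whole dataset in one call whenever it has fewer than 10,000,000 rows. A sketch of the batched-map contract; pre_map below is a hypothetical stand-in for whatever pre_map_maker(pre_map_def) produces:

```python
from datasets import Dataset

ds = Dataset.from_dict({"caption": [" A Cat ", " A Dog "]})

def pre_map(batch):
    # batched=True hands over {"caption": [...]} and expects a dict of
    # equally sized column lists back.
    batch["caption"] = [c.strip().lower() for c in batch["caption"]]
    return batch

ds = ds.map(pre_map, batched=True, batch_size=10_000_000)
print(ds["caption"])  # ['a cat', 'a dog']
```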
@@ -204,5 +216,4 @@ class OnlineStreamingDataLoader():
         # return self.collate_fn(next(self.iterator))
 
     def __len__(self):
-        return len(self.dataset) // self.batch_size
-
+        return len(self.dataset)
flaxdiff-0.1.14.dist-info/METADATA → flaxdiff-0.1.16.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.14
+Version: 0.1.16
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
flaxdiff-0.1.14.dist-info/RECORD → flaxdiff-0.1.16.dist-info/RECORD
@@ -1,7 +1,7 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
-flaxdiff/data/online_loader.py,sha256=IjCVdeq18lF71eLX8RplJLezjKASO1cSFVe2GSYkLQ8,7283
+flaxdiff/data/online_loader.py,sha256=nrtZU4srZHsg3iN0sG91y_6nY7QtYXRPLk5rGn_BTIU,7728
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.14.dist-info/METADATA,sha256=rcwF7cCFfgPLHn5gD7GZ_KvtG6EmiAIiel7xt8HylAo,22083
-flaxdiff-0.1.14.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-flaxdiff-0.1.14.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.14.dist-info/RECORD,,
+flaxdiff-0.1.16.dist-info/METADATA,sha256=BM2RLOiCDqRSWO_owxvWmL_PS3aFtNokPbf-qAuyK4o,22083
+flaxdiff-0.1.16.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.16.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.16.dist-info/RECORD,,