flaxdiff 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
flaxdiff/data/online_loader.py

@@ -13,7 +13,7 @@ from typing import Any, Dict, List, Tuple
 import numpy as np
 from functools import partial
 
-from datasets import load_dataset, concatenate_datasets, Dataset
+from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk
 from datasets.utils.file_utils import get_datasets_user_agent
 from concurrent.futures import ThreadPoolExecutor
 import io
@@ -25,7 +25,8 @@ import cv2
 USER_AGENT = get_datasets_user_agent()
 
 data_queue = Queue(16*2000)
-error_queue = Queue(16*2000)
+error_queue = Queue()
+
 
 def fetch_single_image(image_url, timeout=None, retries=0):
     for _ in range(retries + 1):
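
Note that error_queue moves from a bounded Queue(16*2000) to an unbounded Queue(). In Python's queue module, a Queue constructed without a maxsize (or with maxsize <= 0) has unlimited capacity, so the put_nowait() call introduced further down can never raise queue.Full or stall a download worker. A minimal stdlib-only illustration of the difference:

from queue import Full, Queue

bounded = Queue(1)
bounded.put_nowait("a")
try:
    bounded.put_nowait("b")  # a full bounded queue makes put_nowait raise queue.Full
except Full:
    print("bounded queue rejected the item")

unbounded = Queue()          # maxsize defaults to 0, i.e. unlimited capacity
unbounded.put_nowait("b")    # never raises Full, never blocks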
@@ -42,19 +43,35 @@ def fetch_single_image(image_url, timeout=None, retries=0):
             image = None
     return image
 
+
+def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
+    image = A.longest_max_size(image, max(
+        image_shape), interpolation=interpolation)
+    image = A.pad(
+        image,
+        min_height=image_shape[0],
+        min_width=image_shape[1],
+        border_mode=cv2.BORDER_CONSTANT,
+        value=[255, 255, 255],
+    )
+    return image
+
+
 def map_sample(
-    url, caption,
+    url, caption,
     image_shape=(256, 256),
     timeout=15,
     retries=3,
     upscale_interpolation=cv2.INTER_LANCZOS4,
     downscale_interpolation=cv2.INTER_AREA,
+    image_processor=default_image_processor,
 ):
     try:
-        image = fetch_single_image(url, timeout=timeout, retries=retries)  # Assuming fetch_single_image is defined elsewhere
+        # Assuming fetch_single_image is defined elsewhere
+        image = fetch_single_image(url, timeout=timeout, retries=retries)
         if image is None:
             return
-
+
         image = np.array(image)
         original_height, original_width = image.shape[:2]
         # check if the image is too small
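
The new image_processor hook makes the resize-and-pad step of map_sample pluggable, with default_image_processor reproducing the previous inline behavior (resize the longest side, then pad to the target shape with white). A minimal sketch of a custom processor, assuming only the (image, image_shape, interpolation) calling convention visible above; the center-crop strategy and the name center_crop_processor are hypothetical, not part of the package:

import cv2

def center_crop_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
    # Hypothetical alternative to default_image_processor: crop the largest
    # centered square, then resize, instead of padding to a square.
    height, width = image.shape[:2]
    side = min(height, width)
    top, left = (height - side) // 2, (width - side) // 2
    image = image[top:top + side, left:left + side]
    # cv2.resize expects (width, height); image_shape is (height, width),
    # matching the min_height/min_width usage in default_image_processor.
    return cv2.resize(image, (image_shape[1], image_shape[0]),
                      interpolation=interpolation)

Passing map_sample(..., image_processor=center_crop_processor) swaps the cropping policy without touching the download or queueing logic.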
@@ -69,14 +86,10 @@ def map_sample(
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         downscale = max(original_width, original_height) > max(image_shape)
         interpolation = downscale_interpolation if downscale else upscale_interpolation
-        image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
-        image = A.pad(
-            image,
-            min_height=image_shape[0],
-            min_width=image_shape[1],
-            border_mode=cv2.BORDER_CONSTANT,
-            value=[255, 255, 255],
-        )
+
+        image = image_processor(
+            image, image_shape, interpolation=interpolation)
+
         data_queue.put({
             "url": url,
             "caption": caption,
@@ -85,65 +98,77 @@ def map_sample(
             "original_width": original_width,
         })
     except Exception as e:
-        print(f"Error in map_sample: {str(e)}")
-        error_queue.put({
+        error_queue.put_nowait({
             "url": url,
             "caption": caption,
             "error": str(e)
         })
 
-def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3):
+
+def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3, image_processor=default_image_processor):
     try:
-        map_sample_fn = partial(map_sample, image_shape=image_shape, timeout=timeout, retries=retries)
+        map_sample_fn = partial(map_sample, image_shape=image_shape,
+                                timeout=timeout, retries=retries, image_processor=image_processor)
         with ThreadPoolExecutor(max_workers=num_threads) as executor:
             executor.map(map_sample_fn, batch["url"], batch['caption'])
     except Exception as e:
-        print(f"Error in map_batch: {str(e)}")
         error_queue.put({
             "batch": batch,
             "error": str(e)
         })
-
-def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
-    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
+
+
+def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
+                          num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
+    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
+                           timeout=timeout, retries=retries, image_processor=image_processor)
     shard_len = len(dataset) // num_workers
     print(f"Local Shard lengths: {shard_len}")
     with multiprocessing.Pool(num_workers) as pool:
         iteration = 0
         while True:
             # Repeat forever
-            print(f"Shuffling dataset with seed {iteration}")
-            # dataset = dataset.shuffle(seed=iteration)
-            shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
+            shards = [dataset[i*shard_len:(i+1)*shard_len]
+                      for i in range(num_workers)]
             print(f"mapping {len(shards)} shards")
             pool.map(map_batch_fn, shards)
             iteration += 1
-
+            print(f"Shuffling dataset with seed {iteration}")
+            dataset = dataset.shuffle(seed=iteration)
+            # Clear the error queue
+            while not error_queue.empty():
+                error_queue.get_nowait()
+
+
 class ImageBatchIterator:
-    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):
+    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
+                 num_workers: int = 8, num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
         self.dataset = dataset
         self.num_workers = num_workers
         self.batch_size = batch_size
-        loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
+        loader = partial(parallel_image_loader, num_threads=num_threads,
+                         image_shape=image_shape, num_workers=num_workers,
+                         timeout=timeout, retries=retries, image_processor=image_processor)
         self.thread = threading.Thread(target=loader, args=(dataset,))
         self.thread.start()
-
+
     def __iter__(self):
         return self
-
+
     def __next__(self):
         def fetcher(_):
             return data_queue.get()
         with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
             batch = list(executor.map(fetcher, range(self.batch_size)))
         return batch
-
+
     def __del__(self):
         self.thread.join()
-
+
     def __len__(self):
         return len(self.dataset) // self.batch_size
 
+
 def default_collate(batch):
     urls = [sample["url"] for sample in batch]
     captions = [sample["caption"] for sample in batch]
@@ -153,7 +178,8 @@ def default_collate(batch):
         "caption": captions,
         "image": images,
     }
-
+
+
 def dataMapper(map: Dict[str, Any]):
     def _map(sample) -> Dict[str, Any]:
         return {
@@ -162,16 +188,17 @@ def dataMapper(map: Dict[str, Any]):
         }
     return _map
 
+
 class OnlineStreamingDataLoader():
     def __init__(
-        self,
-        dataset,
-        batch_size=64,
+            self,
+            dataset,
+            batch_size=64,
             image_shape=(256, 256),
-            num_workers=16,
+            num_workers=16,
             num_threads=512,
             default_split="all",
-            pre_map_maker=dataMapper,
+            pre_map_maker=dataMapper,
             pre_map_def={
                 "url": "URL",
                 "caption": "TEXT",
@@ -180,40 +207,53 @@ class OnlineStreamingDataLoader():
             global_process_index=0,
             prefetch=1000,
             collate_fn=default_collate,
+            timeout=15,
+            retries=3,
+            image_processor=default_image_processor,
     ):
         if isinstance(dataset, str):
             dataset_path = dataset
             print("Loading dataset from path")
-            dataset = load_dataset(dataset_path, split=default_split)
+            if "gs://" in dataset:
+                dataset = load_from_disk(dataset_path)
+            else:
+                dataset = load_dataset(dataset_path, split=default_split)
         elif isinstance(dataset, list):
             if isinstance(dataset[0], str):
                 print("Loading multiple datasets from paths")
-                dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
+                dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset(
+                    dataset_path, split=default_split) for dataset_path in dataset]
             print("Concatenating multiple datasets")
             dataset = concatenate_datasets(dataset)
-        dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
-        self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
+            dataset = dataset.shuffle(seed=0)
+        # dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
+        self.dataset = dataset.shard(
+            num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
-        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
-        self.collate_fn = collate_fn
+        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
+                                           num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
+                                           timeout=timeout, retries=retries, image_processor=image_processor)
         self.batch_size = batch_size
-
+
         # Launch a thread to load batches in the background
         self.batch_queue = queue.Queue(prefetch)
-
+
         def batch_loader():
             for batch in self.iterator:
-                self.batch_queue.put(batch)
-
+                try:
+                    self.batch_queue.put(collate_fn(batch))
+                except Exception as e:
+                    print("Error processing batch", e)
+
         self.loader_thread = threading.Thread(target=batch_loader)
         self.loader_thread.start()
-
+
     def __iter__(self):
         return self
-
+
     def __next__(self):
-        return self.collate_fn(self.batch_queue.get())
+        return self.batch_queue.get()
         # return self.collate_fn(next(self.iterator))
-
+
     def __len__(self):
-        return len(self.dataset)
+        return len(self.dataset)
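
Taken together, OnlineStreamingDataLoader now threads timeout, retries, and image_processor down to the workers, and collation moves onto the background loader thread, so __next__ just dequeues a ready batch. A minimal usage sketch; the dataset path is illustrative, any "gs://..." path would instead go through load_from_disk per the branch above, and the dataset needs actual "url" and "caption" columns since the pre_map step is commented out in this version:

loader = OnlineStreamingDataLoader(
    "user/url-caption-dataset",  # hypothetical HF dataset with url/caption columns
    batch_size=64,
    image_shape=(256, 256),
    timeout=15,
    retries=3,
    image_processor=default_image_processor,
)
batch = next(loader)  # from default_collate: {"url": [...], "caption": [...], "image": [...]}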
{flaxdiff-0.1.16.dist-info → flaxdiff-0.1.18.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.16
+Version: 0.1.18
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
{flaxdiff-0.1.16.dist-info → flaxdiff-0.1.18.dist-info}/RECORD

@@ -1,7 +1,7 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
-flaxdiff/data/online_loader.py,sha256=nrtZU4srZHsg3iN0sG91y_6nY7QtYXRPLk5rGn_BTIU,7728
+flaxdiff/data/online_loader.py,sha256=qim6SRRGU1lRO0zQbDNjRYC7Qm6g7jtUfELEXotora0,8987
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.16.dist-info/METADATA,sha256=BM2RLOiCDqRSWO_owxvWmL_PS3aFtNokPbf-qAuyK4o,22083
-flaxdiff-0.1.16.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-flaxdiff-0.1.16.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.16.dist-info/RECORD,,
+flaxdiff-0.1.18.dist-info/METADATA,sha256=aUSr3lBb9P2mnrpmbcgQa41DT8YYM-DtVMU8NI3CZEE,22083
+flaxdiff-0.1.18.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.18.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.18.dist-info/RECORD,,