flaxdiff 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- flaxdiff/data/online_loader.py
+++ flaxdiff/data/online_loader.py
@@ -13,7 +13,7 @@ from typing import Any, Dict, List, Tuple
 import numpy as np
 from functools import partial
 
-from datasets import load_dataset, concatenate_datasets, Dataset
+from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk
 from datasets.utils.file_utils import get_datasets_user_agent
 from concurrent.futures import ThreadPoolExecutor
 import io
@@ -43,17 +43,35 @@ def fetch_single_image(image_url, timeout=None, retries=0):
             image = None
     return image
 
+
+def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
+    image = A.longest_max_size(image, max(
+        image_shape), interpolation=interpolation)
+    image = A.pad(
+        image,
+        min_height=image_shape[0],
+        min_width=image_shape[1],
+        border_mode=cv2.BORDER_CONSTANT,
+        value=[255, 255, 255],
+    )
+    return image
+
+
 def map_sample(
-    url, caption,
+    url, caption,
     image_shape=(256, 256),
+    timeout=15,
+    retries=3,
     upscale_interpolation=cv2.INTER_LANCZOS4,
     downscale_interpolation=cv2.INTER_AREA,
+    image_processor=default_image_processor,
 ):
     try:
-        image = fetch_single_image(url, timeout=15, retries=3)  # Assuming fetch_single_image is defined elsewhere
+        # Assuming fetch_single_image is defined elsewhere
+        image = fetch_single_image(url, timeout=timeout, retries=retries)
         if image is None:
             return
-
+
         image = np.array(image)
         original_height, original_width = image.shape[:2]
         # check if the image is too small
@@ -68,14 +86,10 @@ def map_sample(
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         downscale = max(original_width, original_height) > max(image_shape)
         interpolation = downscale_interpolation if downscale else upscale_interpolation
-        image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
-        image = A.pad(
-            image,
-            min_height=image_shape[0],
-            min_width=image_shape[1],
-            border_mode=cv2.BORDER_CONSTANT,
-            value=[255, 255, 255],
-        )
+
+        image = image_processor(
+            image, image_shape, interpolation=interpolation)
+
         data_queue.put({
             "url": url,
             "caption": caption,
@@ -89,49 +103,69 @@ def map_sample(
             "caption": caption,
             "error": str(e)
         })
-
-def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):
-    with ThreadPoolExecutor(max_workers=num_threads) as executor:
-        executor.map(map_sample, batch["url"], batch['caption'], image_shape=image_shape, timeout=timeout, retries=retries)
-
-def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
-    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
+
+
+def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3, image_processor=default_image_processor):
+    try:
+        map_sample_fn = partial(map_sample, image_shape=image_shape,
+                                timeout=timeout, retries=retries, image_processor=image_processor)
+        with ThreadPoolExecutor(max_workers=num_threads) as executor:
+            executor.map(map_sample_fn, batch["url"], batch['caption'])
+    except Exception as e:
+        error_queue.put({
+            "batch": batch,
+            "error": str(e)
+        })
+
+
+def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
+                          num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
+    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
+                           timeout=timeout, retries=retries, image_processor=image_processor)
     shard_len = len(dataset) // num_workers
     print(f"Local Shard lengths: {shard_len}")
     with multiprocessing.Pool(num_workers) as pool:
         iteration = 0
         while True:
             # Repeat forever
-            dataset = dataset.shuffle(seed=iteration)
-            shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
+            shards = [dataset[i*shard_len:(i+1)*shard_len]
+                      for i in range(num_workers)]
+            print(f"mapping {len(shards)} shards")
             pool.map(map_batch_fn, shards)
             iteration += 1
-
+            print(f"Shuffling dataset with seed {iteration}")
+            dataset = dataset.shuffle(seed=iteration)
+
+
 class ImageBatchIterator:
-    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):
+    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
+                 num_workers: int = 8, num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
         self.dataset = dataset
         self.num_workers = num_workers
         self.batch_size = batch_size
-        loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
+        loader = partial(parallel_image_loader, num_threads=num_threads,
+                         image_shape=image_shape, num_workers=num_workers,
+                         timeout=timeout, retries=retries, image_processor=image_processor)
         self.thread = threading.Thread(target=loader, args=(dataset,))
         self.thread.start()
-
+
     def __iter__(self):
         return self
-
+
     def __next__(self):
         def fetcher(_):
             return data_queue.get()
         with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
             batch = list(executor.map(fetcher, range(self.batch_size)))
         return batch
-
+
     def __del__(self):
         self.thread.join()
-
+
     def __len__(self):
         return len(self.dataset) // self.batch_size
 
+
 def default_collate(batch):
     urls = [sample["url"] for sample in batch]
     captions = [sample["caption"] for sample in batch]
@@ -141,7 +175,8 @@ def default_collate(batch):
         "caption": captions,
         "image": images,
     }
-
+
+
 def dataMapper(map: Dict[str, Any]):
     def _map(sample) -> Dict[str, Any]:
         return {
@@ -150,16 +185,17 @@ def dataMapper(map: Dict[str, Any]):
         }
     return _map
 
+
 class OnlineStreamingDataLoader():
     def __init__(
-        self,
-        dataset,
-        batch_size=64,
+        self,
+        dataset,
+        batch_size=64,
         image_shape=(256, 256),
-        num_workers=16,
+        num_workers=16,
         num_threads=512,
         default_split="all",
-        pre_map_maker=dataMapper,
+        pre_map_maker=dataMapper,
         pre_map_def={
             "url": "URL",
             "caption": "TEXT",
@@ -168,41 +204,51 @@ class OnlineStreamingDataLoader():
         global_process_index=0,
         prefetch=1000,
         collate_fn=default_collate,
+        timeout=15,
+        retries=3,
+        image_processor=default_image_processor,
     ):
         if isinstance(dataset, str):
             dataset_path = dataset
             print("Loading dataset from path")
-            dataset = load_dataset(dataset_path, split=default_split)
+            if "gs://" in dataset:
+                dataset = load_from_disk(dataset_path)
+            else:
+                dataset = load_dataset(dataset_path, split=default_split)
         elif isinstance(dataset, list):
             if isinstance(dataset[0], str):
                 print("Loading multiple datasets from paths")
-                dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
+                dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset(
+                    dataset_path, split=default_split) for dataset_path in dataset]
             print("Concatenating multiple datasets")
             dataset = concatenate_datasets(dataset)
-        dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
-        self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
+            dataset = dataset.shuffle(seed=0)
+        # dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
+        self.dataset = dataset.shard(
+            num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
-        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
+        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
+                                           num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
+                                           timeout=timeout, retries=retries, image_processor=image_processor)
         self.collate_fn = collate_fn
         self.batch_size = batch_size
-
+
         # Launch a thread to load batches in the background
         self.batch_queue = queue.Queue(prefetch)
-
+
         def batch_loader():
             for batch in self.iterator:
                 self.batch_queue.put(batch)
-
+
         self.loader_thread = threading.Thread(target=batch_loader)
         self.loader_thread.start()
-
+
     def __iter__(self):
         return self
-
+
     def __next__(self):
         return self.collate_fn(self.batch_queue.get())
         # return self.collate_fn(next(self.iterator))
-
+
     def __len__(self):
         return len(self.dataset)
-
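
Taken together, the online_loader.py changes above expose per-fetch timeout and retries, let callers plug in their own image_processor in place of default_image_processor, and route "gs://" dataset paths through load_from_disk rather than load_dataset. The following is a minimal usage sketch, not part of the package: the dataset path and the custom resize_processor are hypothetical placeholders, and it assumes a dataset that already provides "url" and "caption" columns.

import cv2
from flaxdiff.data.online_loader import OnlineStreamingDataLoader


def resize_processor(image, image_shape, interpolation=cv2.INTER_AREA):
    # Hypothetical alternative to default_image_processor: plain resize to
    # (height, width), ignoring aspect ratio instead of padding to a square.
    return cv2.resize(image, (image_shape[1], image_shape[0]), interpolation=interpolation)


loader = OnlineStreamingDataLoader(
    "some-org/some-image-text-dataset",  # placeholder dataset path
    batch_size=64,
    image_shape=(256, 256),
    timeout=15,                          # added in this diff; forwarded to fetch_single_image
    retries=3,                           # added in this diff; forwarded to fetch_single_image
    image_processor=resize_processor,    # added in this diff; overrides default_image_processor
)

batch = next(loader)  # dict of "url", "caption", "image" lists from default_collate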

--- flaxdiff-0.1.15.dist-info/METADATA
+++ flaxdiff-0.1.17.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.15
+Version: 0.1.17
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com

--- flaxdiff-0.1.15.dist-info/RECORD
+++ flaxdiff-0.1.17.dist-info/RECORD
@@ -1,7 +1,7 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
-flaxdiff/data/online_loader.py,sha256=Xd4xAerzOU65tvb0cuM8S3Hnm8xv1dJ62xbxAxZkeBw,7307
+flaxdiff/data/online_loader.py,sha256=BM4Le-4BUo8MJpRzGIA2nMHKm4-WynQ2BOdiQz0JCDs,8791
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.15.dist-info/METADATA,sha256=45ifUa-j2rPqP3s734HRy2rIM1StRvSZ4lIJI0R3sTw,22083
-flaxdiff-0.1.15.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-flaxdiff-0.1.15.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.15.dist-info/RECORD,,
+flaxdiff-0.1.17.dist-info/METADATA,sha256=2Nr_T2yg3XHFt2jBuUXo8FxLYM8si-DBLdW_PBKxzc4,22083
+flaxdiff-0.1.17.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.17.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.17.dist-info/RECORD,,