flaxdiff 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flaxdiff/data/online_loader.py

@@ -13,7 +13,7 @@ from typing import Any, Dict, List, Tuple
 import numpy as np
 from functools import partial
 
-from datasets import load_dataset, concatenate_datasets, Dataset
+from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk
 from datasets.utils.file_utils import get_datasets_user_agent
 from concurrent.futures import ThreadPoolExecutor
 import io
@@ -27,6 +27,7 @@ USER_AGENT = get_datasets_user_agent()
 data_queue = Queue(16*2000)
 error_queue = Queue(16*2000)
 
+
 def fetch_single_image(image_url, timeout=None, retries=0):
     for _ in range(retries + 1):
         try:
@@ -42,19 +43,35 @@ def fetch_single_image(image_url, timeout=None, retries=0):
             image = None
     return image
 
+
+def default_image_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
+    image = A.longest_max_size(image, max(
+        image_shape), interpolation=interpolation)
+    image = A.pad(
+        image,
+        min_height=image_shape[0],
+        min_width=image_shape[1],
+        border_mode=cv2.BORDER_CONSTANT,
+        value=[255, 255, 255],
+    )
+    return image
+
+
 def map_sample(
-    url, caption,
+    url, caption,
     image_shape=(256, 256),
     timeout=15,
     retries=3,
     upscale_interpolation=cv2.INTER_LANCZOS4,
     downscale_interpolation=cv2.INTER_AREA,
+    image_processor=default_image_processor,
 ):
     try:
-        image = fetch_single_image(url, timeout=timeout, retries=retries)  # Assuming fetch_single_image is defined elsewhere
+        # Assuming fetch_single_image is defined elsewhere
+        image = fetch_single_image(url, timeout=timeout, retries=retries)
         if image is None:
             return
-
+
         image = np.array(image)
         original_height, original_width = image.shape[:2]
         # check if the image is too small
@@ -69,14 +86,10 @@ def map_sample(
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
         downscale = max(original_width, original_height) > max(image_shape)
         interpolation = downscale_interpolation if downscale else upscale_interpolation
-        image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
-        image = A.pad(
-            image,
-            min_height=image_shape[0],
-            min_width=image_shape[1],
-            border_mode=cv2.BORDER_CONSTANT,
-            value=[255, 255, 255],
-        )
+
+        image = image_processor(
+            image, image_shape, interpolation=interpolation)
+
         data_queue.put({
             "url": url,
             "caption": caption,
@@ -85,65 +98,74 @@ def map_sample(
             "original_width": original_width,
         })
     except Exception as e:
-        print(f"Error in map_sample: {str(e)}")
         error_queue.put({
             "url": url,
             "caption": caption,
             "error": str(e)
         })
 
-def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3):
+
+def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=15, retries=3, image_processor=default_image_processor):
     try:
-        map_sample_fn = partial(map_sample, image_shape=image_shape, timeout=timeout, retries=retries)
+        map_sample_fn = partial(map_sample, image_shape=image_shape,
+                                timeout=timeout, retries=retries, image_processor=image_processor)
         with ThreadPoolExecutor(max_workers=num_threads) as executor:
             executor.map(map_sample_fn, batch["url"], batch['caption'])
     except Exception as e:
-        print(f"Error in map_batch: {str(e)}")
         error_queue.put({
             "batch": batch,
             "error": str(e)
         })
-
-def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
-    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
+
+
+def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
+                          num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
+    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape,
+                           timeout=timeout, retries=retries, image_processor=image_processor)
     shard_len = len(dataset) // num_workers
     print(f"Local Shard lengths: {shard_len}")
     with multiprocessing.Pool(num_workers) as pool:
         iteration = 0
         while True:
             # Repeat forever
-            print(f"Shuffling dataset with seed {iteration}")
-            # dataset = dataset.shuffle(seed=iteration)
-            shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
+            shards = [dataset[i*shard_len:(i+1)*shard_len]
+                      for i in range(num_workers)]
             print(f"mapping {len(shards)} shards")
             pool.map(map_batch_fn, shards)
             iteration += 1
-
+            print(f"Shuffling dataset with seed {iteration}")
+            dataset = dataset.shuffle(seed=iteration)
+
+
 class ImageBatchIterator:
-    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):
+    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
+                 num_workers: int = 8, num_threads=256, timeout=15, retries=3, image_processor=default_image_processor):
         self.dataset = dataset
         self.num_workers = num_workers
         self.batch_size = batch_size
-        loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
+        loader = partial(parallel_image_loader, num_threads=num_threads,
+                         image_shape=image_shape, num_workers=num_workers,
+                         timeout=timeout, retries=retries, image_processor=image_processor)
         self.thread = threading.Thread(target=loader, args=(dataset,))
         self.thread.start()
-
+
     def __iter__(self):
         return self
-
+
     def __next__(self):
         def fetcher(_):
            return data_queue.get()
        with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
            batch = list(executor.map(fetcher, range(self.batch_size)))
        return batch
-
+
    def __del__(self):
        self.thread.join()
-
+
    def __len__(self):
        return len(self.dataset) // self.batch_size
 
+
 def default_collate(batch):
     urls = [sample["url"] for sample in batch]
     captions = [sample["caption"] for sample in batch]
@@ -153,7 +175,8 @@ def default_collate(batch):
         "caption": captions,
         "image": images,
     }
-
+
+
 def dataMapper(map: Dict[str, Any]):
     def _map(sample) -> Dict[str, Any]:
         return {
@@ -162,16 +185,17 @@ def dataMapper(map: Dict[str, Any]):
         }
     return _map
 
+
 class OnlineStreamingDataLoader():
     def __init__(
-        self,
-        dataset,
-        batch_size=64,
+        self,
+        dataset,
+        batch_size=64,
         image_shape=(256, 256),
-        num_workers=16,
+        num_workers=16,
         num_threads=512,
         default_split="all",
-        pre_map_maker=dataMapper,
+        pre_map_maker=dataMapper,
         pre_map_def={
             "url": "URL",
             "caption": "TEXT",
@@ -180,40 +204,51 @@ class OnlineStreamingDataLoader():
         global_process_index=0,
         prefetch=1000,
         collate_fn=default_collate,
+        timeout=15,
+        retries=3,
+        image_processor=default_image_processor,
     ):
         if isinstance(dataset, str):
             dataset_path = dataset
             print("Loading dataset from path")
-            dataset = load_dataset(dataset_path, split=default_split)
+            if "gs://" in dataset:
+                dataset = load_from_disk(dataset_path)
+            else:
+                dataset = load_dataset(dataset_path, split=default_split)
         elif isinstance(dataset, list):
             if isinstance(dataset[0], str):
                 print("Loading multiple datasets from paths")
-                dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
+                dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset(
+                    dataset_path, split=default_split) for dataset_path in dataset]
             print("Concatenating multiple datasets")
             dataset = concatenate_datasets(dataset)
-        dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
-        self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
+            dataset = dataset.shuffle(seed=0)
+        # dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
+        self.dataset = dataset.shard(
+            num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
-        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
+        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
+                                           num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
+                                           timeout=timeout, retries=retries, image_processor=image_processor)
         self.collate_fn = collate_fn
         self.batch_size = batch_size
-
+
         # Launch a thread to load batches in the background
         self.batch_queue = queue.Queue(prefetch)
-
+
         def batch_loader():
             for batch in self.iterator:
                 self.batch_queue.put(batch)
-
+
         self.loader_thread = threading.Thread(target=batch_loader)
         self.loader_thread.start()
-
+
     def __iter__(self):
         return self
-
+
     def __next__(self):
         return self.collate_fn(self.batch_queue.get())
         # return self.collate_fn(next(self.iterator))
-
+
     def __len__(self):
         return len(self.dataset)
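
Taken together, the online_loader.py changes expose three knobs on OnlineStreamingDataLoader that were previously hard-coded: timeout, retries, and a pluggable image_processor (defaulting to the new default_image_processor), plus load_from_disk handling for gs:// dataset paths. Below is a minimal usage sketch, not taken from the package itself: the dataset path is a placeholder, center_crop_processor is a hypothetical replacement for default_image_processor, and gs:// loading assumes a filesystem backend such as gcsfs is installed.

    import cv2
    from flaxdiff.data.online_loader import OnlineStreamingDataLoader

    def center_crop_processor(image, image_shape, interpolation=cv2.INTER_LANCZOS4):
        # Hypothetical alternative to default_image_processor: scale the
        # short side up to the target size, then center-crop instead of
        # padding with white borders.
        h, w = image.shape[:2]
        scale = max(image_shape[0] / h, image_shape[1] / w)
        image = cv2.resize(image, (round(w * scale), round(h * scale)),
                           interpolation=interpolation)
        h, w = image.shape[:2]
        top, left = (h - image_shape[0]) // 2, (w - image_shape[1]) // 2
        return image[top:top + image_shape[0], left:left + image_shape[1]]

    loader = OnlineStreamingDataLoader(
        "laion/laion400m",      # hypothetical HF hub path; a "gs://bucket/ds"
                                # path would go through load_from_disk instead
        batch_size=64,
        image_shape=(256, 256),
        timeout=15,
        retries=3,
        image_processor=center_crop_processor,
    )
    batch = next(loader)        # collated dict with "url", "caption", "image"

Since map_sample invokes the hook as image_processor(image, image_shape, interpolation=interpolation), a replacement processor only needs to accept that signature and return a numpy image array.
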
flaxdiff-0.1.16.dist-info/METADATA → flaxdiff-0.1.17.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.16
+Version: 0.1.17
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
flaxdiff-0.1.16.dist-info/RECORD → flaxdiff-0.1.17.dist-info/RECORD

@@ -1,7 +1,7 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
-flaxdiff/data/online_loader.py,sha256=nrtZU4srZHsg3iN0sG91y_6nY7QtYXRPLk5rGn_BTIU,7728
+flaxdiff/data/online_loader.py,sha256=BM4Le-4BUo8MJpRzGIA2nMHKm4-WynQ2BOdiQz0JCDs,8791
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
 flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
 flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -34,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
 flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
 flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
 flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
-flaxdiff-0.1.16.dist-info/METADATA,sha256=BM2RLOiCDqRSWO_owxvWmL_PS3aFtNokPbf-qAuyK4o,22083
-flaxdiff-0.1.16.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
-flaxdiff-0.1.16.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
-flaxdiff-0.1.16.dist-info/RECORD,,
+flaxdiff-0.1.17.dist-info/METADATA,sha256=2Nr_T2yg3XHFt2jBuUXo8FxLYM8si-DBLdW_PBKxzc4,22083
+flaxdiff-0.1.17.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.17.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.17.dist-info/RECORD,,