flaxdiff 0.1.36.1__py3-none-any.whl → 0.1.36.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registry.
Files changed (46)
  1. flaxdiff/data/__init__.py +1 -0
  2. flaxdiff/data/dataset_map.py +71 -0
  3. flaxdiff/data/datasets.py +169 -0
  4. flaxdiff/data/online_loader.py +363 -0
  5. flaxdiff/data/sources/gcs.py +81 -0
  6. flaxdiff/data/sources/tfds.py +67 -0
  7. flaxdiff/metrics/inception.py +658 -0
  8. flaxdiff/metrics/utils.py +49 -0
  9. flaxdiff/models/__init__.py +1 -0
  10. flaxdiff/models/attention.py +368 -0
  11. flaxdiff/models/autoencoder/__init__.py +2 -0
  12. flaxdiff/models/autoencoder/autoencoder.py +19 -0
  13. flaxdiff/models/autoencoder/diffusers.py +91 -0
  14. flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
  15. flaxdiff/models/common.py +346 -0
  16. flaxdiff/models/favor_fastattn.py +723 -0
  17. flaxdiff/models/simple_unet.py +233 -0
  18. flaxdiff/models/simple_vit.py +180 -0
  19. flaxdiff/predictors/__init__.py +96 -0
  20. flaxdiff/samplers/__init__.py +7 -0
  21. flaxdiff/samplers/common.py +165 -0
  22. flaxdiff/samplers/ddim.py +10 -0
  23. flaxdiff/samplers/ddpm.py +37 -0
  24. flaxdiff/samplers/euler.py +56 -0
  25. flaxdiff/samplers/heun_sampler.py +27 -0
  26. flaxdiff/samplers/multistep_dpm.py +59 -0
  27. flaxdiff/samplers/rk4_sampler.py +34 -0
  28. flaxdiff/schedulers/__init__.py +6 -0
  29. flaxdiff/schedulers/common.py +98 -0
  30. flaxdiff/schedulers/continuous.py +12 -0
  31. flaxdiff/schedulers/cosine.py +40 -0
  32. flaxdiff/schedulers/discrete.py +74 -0
  33. flaxdiff/schedulers/exp.py +13 -0
  34. flaxdiff/schedulers/karras.py +69 -0
  35. flaxdiff/schedulers/linear.py +14 -0
  36. flaxdiff/schedulers/sqrt.py +10 -0
  37. flaxdiff/trainer/__init__.py +2 -0
  38. flaxdiff/trainer/autoencoder_trainer.py +182 -0
  39. flaxdiff/trainer/diffusion_trainer.py +326 -0
  40. flaxdiff/trainer/simple_trainer.py +540 -0
  41. flaxdiff/trainer/video_diffusion_trainer.py +62 -0
  42. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/METADATA +1 -1
  43. flaxdiff-0.1.36.3.dist-info/RECORD +47 -0
  44. flaxdiff-0.1.36.1.dist-info/RECORD +0 -6
  45. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/WHEEL +0 -0
  46. {flaxdiff-0.1.36.1.dist-info → flaxdiff-0.1.36.3.dist-info}/top_level.txt +0 -0
flaxdiff/data/__init__.py
@@ -0,0 +1 @@
+ from .online_loader import OnlineStreamingDataLoader
flaxdiff/data/dataset_map.py
@@ -0,0 +1,71 @@
+ from .sources.tfds import data_source_tfds, tfds_augmenters
+ from .sources.gcs import data_source_gcs, data_source_combined_gcs, gcs_augmenters
+
+ # Configure the following for your datasets
+ datasetMap = {
+     "oxford_flowers102": {
+         "source": data_source_tfds("oxford_flowers102", use_tf=False),
+         "augmenter": tfds_augmenters,
+     },
+     "cc12m": {
+         "source": data_source_gcs('arrayrecord2/cc12m'),
+         "augmenter": gcs_augmenters,
+     },
+     "laiona_coco": {
+         "source": data_source_gcs('arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
+         "augmenter": gcs_augmenters,
+     },
+     "aesthetic_coyo": {
+         "source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
+         "augmenter": gcs_augmenters,
+     },
+     "combined_aesthetic": {
+         "source": data_source_combined_gcs([
+             'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+             'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+             'arrayrecord2/cc12m',
+             'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+         ]),
+         "augmenter": gcs_augmenters,
+     },
+     "laiona_coco_coyo": {
+         "source": data_source_combined_gcs([
+             'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+             'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+             'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+         ]),
+         "augmenter": gcs_augmenters,
+     },
+     "combined_30m": {
+         "source": data_source_combined_gcs([
+             'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+             'arrayrecord2/cc12m',
+             'arrayrecord2/aestheticCoyo_0.26_clip_5.5aesthetic_256plus',
+             "arrayrecord2/playground+leonardo_x4+cc3m.parquet",
+         ]),
+         "augmenter": gcs_augmenters,
+     }
+ }
+
+ onlineDatasetMap = {
+     "combined_online": {
+         "source": [
+             # "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017.parquet"
+             # "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
+             # "dclure/laion-aesthetics-12m-umap",
+             "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017",
+             "gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/cc12m",
+             "gs://flaxdiff-datasets-regional/datasets/playground-liked",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
+             "gs://flaxdiff-datasets-regional/datasets/cc3m",
+             "gs://flaxdiff-datasets-regional/datasets/cc3m",
+             "gs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_37M",
+             # "gs://flaxdiff-datasets-regional/datasets/laiion400m-185M"
+         ]
+     }
+ }
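
For orientation, each `datasetMap` entry pairs a source factory with an augmenter factory; `datasets.py` (next hunk) resolves both before building a loader. A minimal sketch of that lookup, assuming the GCS bucket is mounted at /mnt/gcs_mount (the mount point and the scale/method arguments are illustrative):

    import jax
    from flaxdiff.data.dataset_map import datasetMap

    entry = datasetMap["cc12m"]
    # The source factory takes a base mount directory and returns a grain data source.
    data_source = entry["source"]("/mnt/gcs_mount")
    # The augmenter factory takes an image scale and resize method and returns a
    # pygrain.MapTransform subclass, instantiated later inside the loader pipeline.
    augmenter_cls = entry["augmenter"](256, jax.image.ResizeMethod.LANCZOS3)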
flaxdiff/data/datasets.py
@@ -0,0 +1,169 @@
+ import jax.numpy as jnp
+ import grain.python as pygrain
+ from typing import Dict
+ import numpy as np
+ import jax
+ from flaxdiff.utils import convert_to_global_tree, AutoTextTokenizer
+ from .dataset_map import datasetMap, onlineDatasetMap
+ import traceback
+ from .online_loader import OnlineStreamingDataLoader
+ import queue
+ from jax.sharding import Mesh
+ import threading
+
+ def batch_mesh_map(mesh):
+     class augmenters(pygrain.MapTransform):
+         def __init__(self, *args, **kwargs):
+             super().__init__(*args, **kwargs)
+
+         def map(self, batch) -> Dict[str, jnp.ndarray]:
+             return convert_to_global_tree(mesh, batch)
+     return augmenters
+
+ def get_dataset_grain(
+     data_name="cc12m",
+     batch_size=64,
+     image_scale=256,
+     count=None,
+     num_epochs=None,
+     method=jax.image.ResizeMethod.LANCZOS3,
+     worker_count=32,
+     read_thread_count=64,
+     read_buffer_size=50,
+     worker_buffer_size=20,
+     seed=0,
+     dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
+ ):
+     dataset = datasetMap[data_name]
+     data_source = dataset["source"](dataset_source)
+     augmenter = dataset["augmenter"](image_scale, method)
+
+     local_batch_size = batch_size // jax.process_count()
+
+     sampler = pygrain.IndexSampler(
+         num_records=len(data_source) if count is None else count,
+         shuffle=True,
+         seed=seed,
+         num_epochs=num_epochs,
+         shard_options=pygrain.ShardByJaxProcess(),
+     )
+
+     def get_trainset():
+         transformations = [
+             augmenter(),
+             pygrain.Batch(local_batch_size, drop_remainder=True),
+         ]
+
+         # if mesh != None:
+         #     transformations += [batch_mesh_map(mesh)]
+
+         loader = pygrain.DataLoader(
+             data_source=data_source,
+             sampler=sampler,
+             operations=transformations,
+             worker_count=worker_count,
+             read_options=pygrain.ReadOptions(
+                 read_thread_count, read_buffer_size
+             ),
+             worker_buffer_size=worker_buffer_size,
+         )
+         return loader
+
+
+     return {
+         "train": get_trainset,
+         "train_len": len(data_source),
+         "local_batch_size": local_batch_size,
+         "global_batch_size": batch_size,
+         # "null_labels": null_labels,
+         # "null_labels_full": null_labels_full,
+         # "model": model,
+         # "tokenizer": tokenizer,
+     }
+
+ def generate_collate_fn():
+     auto_tokenize = AutoTextTokenizer(tensor_type="np")
+     def default_collate(batch):
+         try:
+             # urls = [sample["url"] for sample in batch]
+             captions = [sample["caption"] for sample in batch]
+             results = auto_tokenize(captions)
+             images = np.stack([sample["image"] for sample in batch], axis=0)
+             return {
+                 "image": images,
+                 "input_ids": results['input_ids'],
+                 "attention_mask": results['attention_mask'],
+             }
+         except Exception as e:
+             print("Error in collate function", e, [sample["image"].shape for sample in batch])
+             traceback.print_exc()
+
+     return default_collate
+
+ def get_dataset_online(
+     data_name="combined_online",
+     batch_size=64,
+     image_scale=256,
+     count=None,
+     num_epochs=None,
+     method=jax.image.ResizeMethod.LANCZOS3,
+     worker_count=32,
+     read_thread_count=64,
+     read_buffer_size=50,
+     worker_buffer_size=20,
+     seed=0,
+     dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
+ ):
+     local_batch_size = batch_size // jax.process_count()
+
+     sources = onlineDatasetMap[data_name]["source"]
+     dataloader = OnlineStreamingDataLoader(
+         sources,
+         batch_size=local_batch_size,
+         num_workers=worker_count,
+         num_threads=read_thread_count,
+         image_shape=(image_scale, image_scale),
+         global_process_count=jax.process_count(),
+         global_process_index=jax.process_index(),
+         prefetch=worker_buffer_size,
+         collate_fn=generate_collate_fn(),
+         default_split="train",
+     )
+
+     def get_trainset(mesh: Mesh = None):
+         if mesh is not None:
+             class DataLoaderWithMesh:
+                 def __init__(self, dataloader, mesh):
+                     self.dataloader = dataloader
+                     self.mesh = mesh
+                     self.tmp_queue = queue.Queue(worker_buffer_size)
+                     def batch_loader():
+                         for batch in self.dataloader:
+                             try:
+                                 self.tmp_queue.put(convert_to_global_tree(mesh, batch))
+                             except Exception as e:
+                                 print("Error processing batch", e)
+                     self.loader_thread = threading.Thread(target=batch_loader)
+                     self.loader_thread.start()
+
+                 def __iter__(self):
+                     return self
+
+                 def __next__(self):
+                     return self.tmp_queue.get()
+
+             dataloader_with_mesh = DataLoaderWithMesh(dataloader, mesh)
+
+             return dataloader_with_mesh
+         return dataloader
+
+     return {
+         "train": get_trainset,
+         "train_len": len(dataloader) * jax.process_count(),
+         "local_batch_size": local_batch_size,
+         "global_batch_size": batch_size,
+         # "null_labels": null_labels,
+         # "null_labels_full": null_labels_full,
+         # "model": model,
+         # "tokenizer": tokenizer,
+     }
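
A usage sketch of the grain-backed entry point above (`get_dataset_online` follows the same shape); the names come straight from this file, but the batch size and the availability of the mounted records are assumptions:

    from flaxdiff.data.datasets import get_dataset_grain

    data = get_dataset_grain("cc12m", batch_size=64, image_scale=256)
    loader = data["train"]()  # builds the pygrain.DataLoader lazily
    for batch in loader:
        # For the GCS sources, each batch dict carries "image", "input_ids",
        # and "attention_mask", already batched to data["local_batch_size"].
        print(batch["image"].shape)
        break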
flaxdiff/data/online_loader.py
@@ -0,0 +1,363 @@
+ import multiprocessing
+ import threading
+ from multiprocessing import Queue
+ # from arrayqueues.shared_arrays import ArrayQueue
+ # from faster_fifo import Queue
+ import time
+ import albumentations as A
+ import queue
+ import cv2
+ from functools import partial
+ from typing import Any, Dict, List, Tuple
+
+ import numpy as np
+ from functools import partial
+
+ from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk
+ from datasets.utils.file_utils import get_datasets_user_agent
+ from concurrent.futures import ThreadPoolExecutor
+ import io
+ import urllib
+
+ import PIL.Image
+ import cv2
+ import traceback
+
+ USER_AGENT = get_datasets_user_agent()
+
+ data_queue = Queue(16*2000)
+
+
+ def fetch_single_image(image_url, timeout=None, retries=0):
+     for _ in range(retries + 1):
+         try:
+             request = urllib.request.Request(
+                 image_url,
+                 data=None,
+                 headers={"user-agent": USER_AGENT},
+             )
+             with urllib.request.urlopen(request, timeout=timeout) as req:
+                 image = PIL.Image.open(io.BytesIO(req.read()))
+             break
+         except Exception:
+             image = None
+     return image
+
+
+ def default_image_processor(
+     image, image_shape,
+     min_image_shape=(128, 128),
+     upscale_interpolation=cv2.INTER_CUBIC,
+     downscale_interpolation=cv2.INTER_AREA,
+ ):
+     try:
+         image = np.array(image)
+         if len(image.shape) != 3 or image.shape[2] != 3:
+             return None, 0, 0
+         original_height, original_width = image.shape[:2]
+         # check if the image is too small
+         if min(original_height, original_width) < min(min_image_shape):
+             return None, original_height, original_width
+         # check if wrong aspect ratio
+         if max(original_height, original_width) / min(original_height, original_width) > 2.4:
+             return None, original_height, original_width
+         # check if the variance is too low
+         if np.std(image) < 1e-5:
+             return None, original_height, original_width
+         # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         downscale = max(original_width, original_height) > max(image_shape)
+         interpolation = downscale_interpolation if downscale else upscale_interpolation
+
+         image = A.longest_max_size(image, max(
+             image_shape), interpolation=interpolation)
+         image = A.pad(
+             image,
+             min_height=image_shape[0],
+             min_width=image_shape[1],
+             border_mode=cv2.BORDER_CONSTANT,
+             value=[255, 255, 255],
+         )
+         return image, original_height, original_width
+     except Exception as e:
+         # print("Error processing image", e, image_shape, interpolation)
+         # traceback.print_exc()
+         return None, 0, 0
+
+
+ def map_sample(
+     url,
+     caption,
+     image_shape=(256, 256),
+     min_image_shape=(128, 128),
+     timeout=15,
+     retries=3,
+     upscale_interpolation=cv2.INTER_CUBIC,
+     downscale_interpolation=cv2.INTER_AREA,
+     image_processor=default_image_processor,
+ ):
+     try:
+         # Assuming fetch_single_image is defined elsewhere
+         image = fetch_single_image(url, timeout=timeout, retries=retries)
+         if image is None:
+             return
+
+         image, original_height, original_width = image_processor(
+             image, image_shape, min_image_shape=min_image_shape,
+             upscale_interpolation=upscale_interpolation,
+             downscale_interpolation=downscale_interpolation,
+         )
+
+         if image is None:
+             return
+
+         data_queue.put({
+             "url": url,
+             "caption": caption,
+             "image": image,
+             "original_height": original_height,
+             "original_width": original_width,
+         })
+     except Exception as e:
+         # print(f"Error mapping sample {url}", e)
+         # traceback.print_exc()
+         # error_queue.put_nowait({
+         #     "url": url,
+         #     "caption": caption,
+         #     "error": str(e)
+         # })
+         pass
+
+ def default_feature_extractor(sample):
+     url = None
+     if "url" in sample:
+         url = sample["url"]
+     elif "URL" in sample:
+         url = sample["URL"]
+     elif "image_url" in sample:
+         url = sample["image_url"]
+     else:
+         print("No url found in sample, skipping", sample.keys())
+
+     caption = None
+     if "caption" in sample:
+         caption = sample["caption"]
+     elif "CAPTION" in sample:
+         caption = sample["CAPTION"]
+     elif "txt" in sample:
+         caption = sample["txt"]
+     elif "TEXT" in sample:
+         caption = sample["TEXT"]
+     elif "text" in sample:
+         caption = sample["text"]
+     else:
+         print("No caption found in sample, skipping", sample.keys())
+
+     return {
+         "url": url,
+         "caption": caption,
+     }
+
+ def map_batch(
+     batch, num_threads=256, image_shape=(256, 256),
+     min_image_shape=(128, 128),
+     timeout=15, retries=3, image_processor=default_image_processor,
+     upscale_interpolation=cv2.INTER_CUBIC,
+     downscale_interpolation=cv2.INTER_AREA,
+     feature_extractor=default_feature_extractor,
+ ):
+     try:
+         map_sample_fn = partial(
+             map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
+             timeout=timeout, retries=retries, image_processor=image_processor,
+             upscale_interpolation=upscale_interpolation,
+             downscale_interpolation=downscale_interpolation
+         )
+         with ThreadPoolExecutor(max_workers=num_threads) as executor:
+             features = feature_extractor(batch)
+             url, caption = features["url"], features["caption"]
+             executor.map(map_sample_fn, url, caption)
+     except Exception as e:
+         print("Error mapping batch", e)
+         traceback.print_exc()
+         # error_queue.put_nowait({
+         #     "batch": batch,
+         #     "error": str(e)
+         # })
+         pass
+
+
+ def parallel_image_loader(
+     dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
+     min_image_shape=(128, 128),
+     num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
+     upscale_interpolation=cv2.INTER_CUBIC,
+     downscale_interpolation=cv2.INTER_AREA,
+     feature_extractor=default_feature_extractor,
+ ):
+     map_batch_fn = partial(
+         map_batch, num_threads=num_threads, image_shape=image_shape,
+         min_image_shape=min_image_shape,
+         timeout=timeout, retries=retries, image_processor=image_processor,
+         upscale_interpolation=upscale_interpolation,
+         downscale_interpolation=downscale_interpolation,
+         feature_extractor=feature_extractor
+     )
+     shard_len = len(dataset) // num_workers
+     print(f"Local shard length: {shard_len}")
+     with multiprocessing.Pool(num_workers) as pool:
+         iteration = 0
+         while True:
+             # Repeat forever
+             shards = [dataset[i*shard_len:(i+1)*shard_len]
+                       for i in range(num_workers)]
+             print(f"mapping {len(shards)} shards")
+             pool.map(map_batch_fn, shards)
+             iteration += 1
+             print(f"Shuffling dataset with seed {iteration}")
+             dataset = dataset.shuffle(seed=iteration)
+             # Clear the error queue
+             # while not error_queue.empty():
+             #     error_queue.get_nowait()
+
+
+ class ImageBatchIterator:
+     def __init__(
+         self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
+         min_image_shape=(128, 128),
+         num_workers: int = 8, num_threads=256, timeout=15, retries=3,
+         image_processor=default_image_processor,
+         upscale_interpolation=cv2.INTER_CUBIC,
+         downscale_interpolation=cv2.INTER_AREA,
+         feature_extractor=default_feature_extractor,
+     ):
+         self.dataset = dataset
+         self.num_workers = num_workers
+         self.batch_size = batch_size
+         loader = partial(
+             parallel_image_loader,
+             num_threads=num_threads,
+             image_shape=image_shape,
+             min_image_shape=min_image_shape,
+             num_workers=num_workers,
+             timeout=timeout, retries=retries,
+             image_processor=image_processor,
+             upscale_interpolation=upscale_interpolation,
+             downscale_interpolation=downscale_interpolation,
+             feature_extractor=feature_extractor
+         )
+         self.thread = threading.Thread(target=loader, args=(dataset,))
+         self.thread.start()
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         def fetcher(_):
+             return data_queue.get()
+         with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
+             batch = list(executor.map(fetcher, range(self.batch_size)))
+         return batch
+
+     def __del__(self):
+         self.thread.join()
+
+     def __len__(self):
+         return len(self.dataset) // self.batch_size
+
+
+ def default_collate(batch):
+     urls = [sample["url"] for sample in batch]
+     captions = [sample["caption"] for sample in batch]
+     images = np.stack([sample["image"] for sample in batch], axis=0)
+     return {
+         "url": urls,
+         "caption": captions,
+         "image": images,
+     }
+
+
+ def dataMapper(map: Dict[str, Any]):
+     def _map(sample) -> Dict[str, Any]:
+         return {
+             "url": sample[map["url"]],
+             "caption": sample[map["caption"]],
+         }
+     return _map
+
+
+ class OnlineStreamingDataLoader():
+     def __init__(
+         self,
+         dataset,
+         batch_size=64,
+         image_shape=(256, 256),
+         min_image_shape=(128, 128),
+         num_workers=16,
+         num_threads=512,
+         default_split="all",
+         pre_map_maker=dataMapper,
+         pre_map_def={
+             "url": "URL",
+             "caption": "TEXT",
+         },
+         global_process_count=1,
+         global_process_index=0,
+         prefetch=1000,
+         collate_fn=default_collate,
+         timeout=15,
+         retries=3,
+         image_processor=default_image_processor,
+         upscale_interpolation=cv2.INTER_CUBIC,
+         downscale_interpolation=cv2.INTER_AREA,
+         feature_extractor=default_feature_extractor,
+     ):
+         if isinstance(dataset, str):
+             dataset_path = dataset
+             print("Loading dataset from path")
+             if "gs://" in dataset:
+                 dataset = load_from_disk(dataset_path)
+             else:
+                 dataset = load_dataset(dataset_path, split=default_split)
+         elif isinstance(dataset, list):
+             if isinstance(dataset[0], str):
+                 print("Loading multiple datasets from paths")
+                 dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset(
+                     dataset_path, split=default_split) for dataset_path in dataset]
+             print("Concatenating multiple datasets")
+             dataset = concatenate_datasets(dataset)
+             dataset = dataset.shuffle(seed=0)
+         # dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
+         self.dataset = dataset.shard(
+             num_shards=global_process_count, index=global_process_index)
+         print(f"Dataset length: {len(dataset)}")
+         self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
+                                            min_image_shape=min_image_shape,
+                                            num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
+                                            timeout=timeout, retries=retries, image_processor=image_processor,
+                                            upscale_interpolation=upscale_interpolation,
+                                            downscale_interpolation=downscale_interpolation,
+                                            feature_extractor=feature_extractor)
+         self.batch_size = batch_size
+
+         # Launch a thread to load batches in the background
+         self.batch_queue = queue.Queue(prefetch)
+
+         def batch_loader():
+             for batch in self.iterator:
+                 try:
+                     self.batch_queue.put(collate_fn(batch))
+                 except Exception as e:
+                     print("Error collating batch", e)
+
+         self.loader_thread = threading.Thread(target=batch_loader)
+         self.loader_thread.start()
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         return self.batch_queue.get()
+         # return self.collate_fn(next(self.iterator))
+
+     def __len__(self):
+         return len(self.dataset)
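
As a usage sketch, OnlineStreamingDataLoader accepts a Hugging Face dataset id, a gs:// path saved with save_to_disk, or a list of either. The dataset id below is one of the commented-out candidates from dataset_map.py and is only illustrative; it assumes network access for the URL fetches:

    from flaxdiff.data.online_loader import OnlineStreamingDataLoader

    loader = OnlineStreamingDataLoader(
        "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",  # URL/TEXT columns match default_feature_extractor
        batch_size=16,
        num_workers=4,
        num_threads=64,
        default_split="train",
    )
    # Batches are assembled by background workers; the default collate yields
    # a dict of "url", "caption", and a stacked "image" array.
    batch = next(iter(loader))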
flaxdiff/data/sources/gcs.py
@@ -0,0 +1,81 @@
+ import cv2
+ import jax.numpy as jnp
+ import grain.python as pygrain
+ from flaxdiff.utils import AutoTextTokenizer
+ from typing import Dict
+ import os
+ import struct as st
+ from functools import partial
+ import numpy as np
+
+ # -----------------------------------------------------------------------------------------------#
+ # CC12m and other GCS data sources --------------------------------------------------------------#
+ # -----------------------------------------------------------------------------------------------#
+
+ def data_source_gcs(source='arrayrecord/laion-aesthetics-12m+mscoco-2017'):
+     def data_source(base="/home/mrwhite0racle/gcs_mount"):
+         records_path = os.path.join(base, source)
+         records = [os.path.join(records_path, i) for i in os.listdir(
+             records_path) if 'array_record' in i]
+         ds = pygrain.ArrayRecordDataSource(records)
+         return ds
+     return data_source
+
+ def data_source_combined_gcs(
+         sources=[]):
+     def data_source(base="/home/mrwhite0racle/gcs_mount"):
+         records_paths = [os.path.join(base, source) for source in sources]
+         records = []
+         for records_path in records_paths:
+             records += [os.path.join(records_path, i) for i in os.listdir(
+                 records_path) if 'array_record' in i]
+         ds = pygrain.ArrayRecordDataSource(records)
+         return ds
+     return data_source
+
+ def unpack_dict_of_byte_arrays(packed_data):
+     unpacked_dict = {}
+     offset = 0
+     while offset < len(packed_data):
+         # Unpack the key length
+         key_length = st.unpack_from('I', packed_data, offset)[0]
+         offset += st.calcsize('I')
+         # Unpack the key bytes and convert to string
+         key = packed_data[offset:offset+key_length].decode('utf-8')
+         offset += key_length
+         # Unpack the byte array length
+         byte_array_length = st.unpack_from('I', packed_data, offset)[0]
+         offset += st.calcsize('I')
+         # Unpack the byte array
+         byte_array = packed_data[offset:offset+byte_array_length]
+         offset += byte_array_length
+         unpacked_dict[key] = byte_array
+     return unpacked_dict
+
+ def image_augmenter(image, image_scale, method=cv2.INTER_AREA):
+     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+     image = cv2.resize(image, (image_scale, image_scale),
+                        interpolation=cv2.INTER_AREA)
+     return image
+
+ def gcs_augmenters(image_scale, method):
+     labelizer = lambda sample: sample['txt']
+     class augmenters(pygrain.MapTransform):
+         def __init__(self, *args, **kwargs):
+             super().__init__(*args, **kwargs)
+             self.auto_tokenize = AutoTextTokenizer(tensor_type="np")
+             self.image_augmenter = partial(image_augmenter, image_scale=image_scale, method=method)
+
+         def map(self, element) -> Dict[str, jnp.ndarray]:
+             element = unpack_dict_of_byte_arrays(element)
+             image = np.asarray(bytearray(element['jpg']), dtype="uint8")
+             image = cv2.imdecode(image, cv2.IMREAD_UNCHANGED)
+             image = self.image_augmenter(image)
+             caption = labelizer(element).decode('utf-8')
+             results = self.auto_tokenize(caption)
+             return {
+                 "image": image,
+                 "input_ids": results['input_ids'][0],
+                 "attention_mask": results['attention_mask'][0],
+             }
+     return augmenters
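
The records consumed above are flat byte strings of length-prefixed key/value pairs. A sketch of the inverse packer, inferred from unpack_dict_of_byte_arrays and not itself part of the package, makes the layout explicit:

    import struct as st

    def pack_dict_of_byte_arrays(data):
        # Inverse of unpack_dict_of_byte_arrays: for each entry write a uint32
        # key length, the utf-8 key bytes, a uint32 value length, then the value.
        packed = bytearray()
        for key, byte_array in data.items():
            key_bytes = key.encode('utf-8')
            packed += st.pack('I', len(key_bytes))
            packed += key_bytes
            packed += st.pack('I', len(byte_array))
            packed += byte_array
        return bytes(packed)

    # Round-trip check against the unpacker in this file:
    # unpack_dict_of_byte_arrays(pack_dict_of_byte_arrays({"jpg": b"...", "txt": b"caption"}))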