flaxdiff 0.1.36__py3-none-any.whl → 0.1.36.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. {flaxdiff-0.1.36.dist-info → flaxdiff-0.1.36.1.dist-info}/METADATA +13 -10
  2. flaxdiff-0.1.36.1.dist-info/RECORD +6 -0
  3. flaxdiff/data/__init__.py +0 -1
  4. flaxdiff/data/dataset_map.py +0 -71
  5. flaxdiff/data/datasets.py +0 -169
  6. flaxdiff/data/online_loader.py +0 -363
  7. flaxdiff/models/__init__.py +0 -1
  8. flaxdiff/models/attention.py +0 -368
  9. flaxdiff/models/autoencoder/__init__.py +0 -2
  10. flaxdiff/models/autoencoder/autoencoder.py +0 -19
  11. flaxdiff/models/autoencoder/diffusers.py +0 -91
  12. flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
  13. flaxdiff/models/common.py +0 -346
  14. flaxdiff/models/favor_fastattn.py +0 -723
  15. flaxdiff/models/simple_unet.py +0 -233
  16. flaxdiff/models/simple_vit.py +0 -180
  17. flaxdiff/predictors/__init__.py +0 -96
  18. flaxdiff/samplers/__init__.py +0 -7
  19. flaxdiff/samplers/common.py +0 -165
  20. flaxdiff/samplers/ddim.py +0 -10
  21. flaxdiff/samplers/ddpm.py +0 -37
  22. flaxdiff/samplers/euler.py +0 -56
  23. flaxdiff/samplers/heun_sampler.py +0 -27
  24. flaxdiff/samplers/multistep_dpm.py +0 -59
  25. flaxdiff/samplers/rk4_sampler.py +0 -34
  26. flaxdiff/schedulers/__init__.py +0 -6
  27. flaxdiff/schedulers/common.py +0 -98
  28. flaxdiff/schedulers/continuous.py +0 -12
  29. flaxdiff/schedulers/cosine.py +0 -40
  30. flaxdiff/schedulers/discrete.py +0 -74
  31. flaxdiff/schedulers/exp.py +0 -13
  32. flaxdiff/schedulers/karras.py +0 -69
  33. flaxdiff/schedulers/linear.py +0 -14
  34. flaxdiff/schedulers/sqrt.py +0 -10
  35. flaxdiff/trainer/__init__.py +0 -2
  36. flaxdiff/trainer/autoencoder_trainer.py +0 -182
  37. flaxdiff/trainer/diffusion_trainer.py +0 -326
  38. flaxdiff/trainer/simple_trainer.py +0 -538
  39. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  40. flaxdiff-0.1.36.dist-info/RECORD +0 -43
  41. {flaxdiff-0.1.36.dist-info → flaxdiff-0.1.36.1.dist-info}/WHEEL +0 -0
  42. {flaxdiff-0.1.36.dist-info → flaxdiff-0.1.36.1.dist-info}/top_level.txt +0 -0
@@ -1,21 +1,24 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.36
3
+ Version: 0.1.36.1
4
4
  Summary: A versatile and easy to understand Diffusion library
5
- Author: Ashish Kumar Singh
6
- Author-email: ashishkmr472@gmail.com
5
+ Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
+ License-Expression: MIT
7
7
  Description-Content-Type: text/markdown
8
8
  Requires-Dist: flax>=0.8.4
9
- Requires-Dist: optax>=0.2.2
10
9
  Requires-Dist: jax>=0.4.28
10
+ Requires-Dist: optax>=0.2.2
11
11
  Requires-Dist: orbax
12
+ Requires-Dist: numpy
12
13
  Requires-Dist: clu
13
- Dynamic: author
14
- Dynamic: author-email
15
- Dynamic: description
16
- Dynamic: description-content-type
17
- Dynamic: requires-dist
18
- Dynamic: summary
14
+ Requires-Dist: einops
15
+ Requires-Dist: tqdm
16
+ Requires-Dist: grain
17
+ Requires-Dist: termcolor
18
+ Requires-Dist: augmax
19
+ Requires-Dist: albumentations
20
+ Requires-Dist: rich
21
+ Requires-Dist: python-dotenv
19
22
 
20
23
  # ![](images/logo.jpeg "FlaxDiff")
21
24
 
@@ -0,0 +1,6 @@
1
+ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ flaxdiff/utils.py,sha256=b_hFXsam2NICQYCFk0EOcqtBjM-RUqnN0NKTn0lQ070,6532
3
+ flaxdiff-0.1.36.1.dist-info/METADATA,sha256=Fl9tlGh_BgRnT-f8k4cEYnFj7G03VecUNOX_1zbJrmE,22310
4
+ flaxdiff-0.1.36.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
5
+ flaxdiff-0.1.36.1.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
6
+ flaxdiff-0.1.36.1.dist-info/RECORD,,
flaxdiff/data/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .online_loader import OnlineStreamingDataLoader
@@ -1,71 +0,0 @@
1
- from .sources.tfds import data_source_tfds, tfds_augmenters
2
- from .sources.gcs import data_source_gcs, data_source_combined_gcs, gcs_augmenters
3
-
4
- # Configure the following for your datasets
5
- datasetMap = {
6
- "oxford_flowers102": {
7
- "source": data_source_tfds("oxford_flowers102", use_tf=False),
8
- "augmenter": tfds_augmenters,
9
- },
10
- "cc12m": {
11
- "source": data_source_gcs('arrayrecord2/cc12m'),
12
- "augmenter": gcs_augmenters,
13
- },
14
- "laiona_coco": {
15
- "source": data_source_gcs('arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
16
- "augmenter": gcs_augmenters,
17
- },
18
- "aesthetic_coyo": {
19
- "source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
20
- "augmenter": gcs_augmenters,
21
- },
22
- "combined_aesthetic": {
23
- "source": data_source_combined_gcs([
24
- 'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
25
- 'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
26
- 'arrayrecord2/cc12m',
27
- 'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
28
- ]),
29
- "augmenter": gcs_augmenters,
30
- },
31
- "laiona_coco_coyo": {
32
- "source": data_source_combined_gcs([
33
- 'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
34
- 'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
35
- 'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
36
- ]),
37
- "augmenter": gcs_augmenters,
38
- },
39
- "combined_30m": {
40
- "source": data_source_combined_gcs([
41
- 'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
42
- 'arrayrecord2/cc12m',
43
- 'arrayrecord2/aestheticCoyo_0.26_clip_5.5aesthetic_256plus',
44
- "arrayrecord2/playground+leonardo_x4+cc3m.parquet",
45
- ]),
46
- "augmenter": gcs_augmenters,
47
- }
48
- }
49
-
50
- onlineDatasetMap = {
51
- "combined_online": {
52
- "source": [
53
- # "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017.parquet"
54
- # "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
55
- # "dclure/laion-aesthetics-12m-umap",
56
- "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017",
57
- "gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M",
58
- "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
59
- "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
60
- "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
61
- "gs://flaxdiff-datasets-regional/datasets/cc12m",
62
- "gs://flaxdiff-datasets-regional/datasets/playground-liked",
63
- "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
64
- "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
65
- "gs://flaxdiff-datasets-regional/datasets/cc3m",
66
- "gs://flaxdiff-datasets-regional/datasets/cc3m",
67
- "gs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_37M",
68
- # "gs://flaxdiff-datasets-regional/datasets/laiion400m-185M"
69
- ]
70
- }
71
- }
flaxdiff/data/datasets.py DELETED
@@ -1,169 +0,0 @@
1
- import jax.numpy as jnp
2
- import grain.python as pygrain
3
- from typing import Dict
4
- import numpy as np
5
- import jax
6
- from flaxdiff.utils import convert_to_global_tree, AutoTextTokenizer
7
- from .dataset_map import datasetMap, onlineDatasetMap
8
- import traceback
9
- from .online_loader import OnlineStreamingDataLoader
10
- import queue
11
- from jax.sharding import Mesh
12
- import threading
13
-
14
- def batch_mesh_map(mesh):
15
- class augmenters(pygrain.MapTransform):
16
- def __init__(self, *args, **kwargs):
17
- super().__init__(*args, **kwargs)
18
-
19
- def map(self, batch) -> Dict[str, jnp.array]:
20
- return convert_to_global_tree(mesh, batch)
21
- return augmenters
22
-
23
- def get_dataset_grain(
24
- data_name="cc12m",
25
- batch_size=64,
26
- image_scale=256,
27
- count=None,
28
- num_epochs=None,
29
- method=jax.image.ResizeMethod.LANCZOS3,
30
- worker_count=32,
31
- read_thread_count=64,
32
- read_buffer_size=50,
33
- worker_buffer_size=20,
34
- seed=0,
35
- dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
36
- ):
37
- dataset = datasetMap[data_name]
38
- data_source = dataset["source"](dataset_source)
39
- augmenter = dataset["augmenter"](image_scale, method)
40
-
41
- local_batch_size = batch_size // jax.process_count()
42
-
43
- sampler = pygrain.IndexSampler(
44
- num_records=len(data_source) if count is None else count,
45
- shuffle=True,
46
- seed=seed,
47
- num_epochs=num_epochs,
48
- shard_options=pygrain.ShardByJaxProcess(),
49
- )
50
-
51
- def get_trainset():
52
- transformations = [
53
- augmenter(),
54
- pygrain.Batch(local_batch_size, drop_remainder=True),
55
- ]
56
-
57
- # if mesh != None:
58
- # transformations += [batch_mesh_map(mesh)]
59
-
60
- loader = pygrain.DataLoader(
61
- data_source=data_source,
62
- sampler=sampler,
63
- operations=transformations,
64
- worker_count=worker_count,
65
- read_options=pygrain.ReadOptions(
66
- read_thread_count, read_buffer_size
67
- ),
68
- worker_buffer_size=worker_buffer_size,
69
- )
70
- return loader
71
-
72
-
73
- return {
74
- "train": get_trainset,
75
- "train_len": len(data_source),
76
- "local_batch_size": local_batch_size,
77
- "global_batch_size": batch_size,
78
- # "null_labels": null_labels,
79
- # "null_labels_full": null_labels_full,
80
- # "model": model,
81
- # "tokenizer": tokenizer,
82
- }
83
-
84
- def generate_collate_fn():
85
- auto_tokenize = AutoTextTokenizer(tensor_type="np")
86
- def default_collate(batch):
87
- try:
88
- # urls = [sample["url"] for sample in batch]
89
- captions = [sample["caption"] for sample in batch]
90
- results = auto_tokenize(captions)
91
- images = np.stack([sample["image"] for sample in batch], axis=0)
92
- return {
93
- "image": images,
94
- "input_ids": results['input_ids'],
95
- "attention_mask": results['attention_mask'],
96
- }
97
- except Exception as e:
98
- print("Error in collate function", e, [sample["image"].shape for sample in batch])
99
- traceback.print_exc()
100
-
101
- return default_collate
102
-
103
- def get_dataset_online(
104
- data_name="combined_online",
105
- batch_size=64,
106
- image_scale=256,
107
- count=None,
108
- num_epochs=None,
109
- method=jax.image.ResizeMethod.LANCZOS3,
110
- worker_count=32,
111
- read_thread_count=64,
112
- read_buffer_size=50,
113
- worker_buffer_size=20,
114
- seed=0,
115
- dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
116
- ):
117
- local_batch_size = batch_size // jax.process_count()
118
-
119
- sources = onlineDatasetMap[data_name]["source"]
120
- dataloader = OnlineStreamingDataLoader(
121
- sources,
122
- batch_size=local_batch_size,
123
- num_workers=worker_count,
124
- num_threads=read_thread_count,
125
- image_shape=(image_scale, image_scale),
126
- global_process_count=jax.process_count(),
127
- global_process_index=jax.process_index(),
128
- prefetch=worker_buffer_size,
129
- collate_fn=generate_collate_fn(),
130
- default_split="train",
131
- )
132
-
133
- def get_trainset(mesh: Mesh = None):
134
- if mesh != None:
135
- class dataLoaderWithMesh:
136
- def __init__(self, dataloader, mesh):
137
- self.dataloader = dataloader
138
- self.mesh = mesh
139
- self.tmp_queue = queue.Queue(worker_buffer_size)
140
- def batch_loader():
141
- for batch in self.dataloader:
142
- try:
143
- self.tmp_queue.put(convert_to_global_tree(mesh, batch))
144
- except Exception as e:
145
- print("Error processing batch", e)
146
- self.loader_thread = threading.Thread(target=batch_loader)
147
- self.loader_thread.start()
148
-
149
- def __iter__(self):
150
- return self
151
-
152
- def __next__(self):
153
- return self.tmp_queue.get()
154
-
155
- dataloader_with_mesh = dataLoaderWithMesh(dataloader, mesh)
156
-
157
- return dataloader_with_mesh
158
- return dataloader
159
-
160
- return {
161
- "train": get_trainset,
162
- "train_len": len(dataloader) * jax.process_count(),
163
- "local_batch_size": local_batch_size,
164
- "global_batch_size": batch_size,
165
- # "null_labels": null_labels,
166
- # "null_labels_full": null_labels_full,
167
- # "model": model,
168
- # "tokenizer": tokenizer,
169
- }
@@ -1,363 +0,0 @@
1
- import multiprocessing
2
- import threading
3
- from multiprocessing import Queue
4
- # from arrayqueues.shared_arrays import ArrayQueue
5
- # from faster_fifo import Queue
6
- import time
7
- import albumentations as A
8
- import queue
9
- import cv2
10
- from functools import partial
11
- from typing import Any, Dict, List, Tuple
12
-
13
- import numpy as np
14
- from functools import partial
15
-
16
- from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk
17
- from datasets.utils.file_utils import get_datasets_user_agent
18
- from concurrent.futures import ThreadPoolExecutor
19
- import io
20
- import urllib
21
-
22
- import PIL.Image
23
- import cv2
24
- import traceback
25
-
26
- USER_AGENT = get_datasets_user_agent()
27
-
28
- data_queue = Queue(16*2000)
29
-
30
-
31
- def fetch_single_image(image_url, timeout=None, retries=0):
32
- for _ in range(retries + 1):
33
- try:
34
- request = urllib.request.Request(
35
- image_url,
36
- data=None,
37
- headers={"user-agent": USER_AGENT},
38
- )
39
- with urllib.request.urlopen(request, timeout=timeout) as req:
40
- image = PIL.Image.open(io.BytesIO(req.read()))
41
- break
42
- except Exception:
43
- image = None
44
- return image
45
-
46
-
47
- def default_image_processor(
48
- image, image_shape,
49
- min_image_shape=(128, 128),
50
- upscale_interpolation=cv2.INTER_CUBIC,
51
- downscale_interpolation=cv2.INTER_AREA,
52
- ):
53
- try:
54
- image = np.array(image)
55
- if len(image.shape) != 3 or image.shape[2] != 3:
56
- return None, 0, 0
57
- original_height, original_width = image.shape[:2]
58
- # check if the image is too small
59
- if min(original_height, original_width) < min(min_image_shape):
60
- return None, original_height, original_width
61
- # check if wrong aspect ratio
62
- if max(original_height, original_width) / min(original_height, original_width) > 2.4:
63
- return None, original_height, original_width
64
- # check if the variance is too low
65
- if np.std(image) < 1e-5:
66
- return None, original_height, original_width
67
- # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
68
- downscale = max(original_width, original_height) > max(image_shape)
69
- interpolation = downscale_interpolation if downscale else upscale_interpolation
70
-
71
- image = A.longest_max_size(image, max(
72
- image_shape), interpolation=interpolation)
73
- image = A.pad(
74
- image,
75
- min_height=image_shape[0],
76
- min_width=image_shape[1],
77
- border_mode=cv2.BORDER_CONSTANT,
78
- value=[255, 255, 255],
79
- )
80
- return image, original_height, original_width
81
- except Exception as e:
82
- # print("Error processing image", e, image_shape, interpolation)
83
- # traceback.print_exc()
84
- return None, 0, 0
85
-
86
-
87
- def map_sample(
88
- url,
89
- caption,
90
- image_shape=(256, 256),
91
- min_image_shape=(128, 128),
92
- timeout=15,
93
- retries=3,
94
- upscale_interpolation=cv2.INTER_CUBIC,
95
- downscale_interpolation=cv2.INTER_AREA,
96
- image_processor=default_image_processor,
97
- ):
98
- try:
99
- # Assuming fetch_single_image is defined elsewhere
100
- image = fetch_single_image(url, timeout=timeout, retries=retries)
101
- if image is None:
102
- return
103
-
104
- image, original_height, original_width = image_processor(
105
- image, image_shape, min_image_shape=min_image_shape,
106
- upscale_interpolation=upscale_interpolation,
107
- downscale_interpolation=downscale_interpolation,
108
- )
109
-
110
- if image is None:
111
- return
112
-
113
- data_queue.put({
114
- "url": url,
115
- "caption": caption,
116
- "image": image,
117
- "original_height": original_height,
118
- "original_width": original_width,
119
- })
120
- except Exception as e:
121
- # print(f"Error maping sample {url}", e)
122
- # traceback.print_exc()
123
- # error_queue.put_nowait({
124
- # "url": url,
125
- # "caption": caption,
126
- # "error": str(e)
127
- # })
128
- pass
129
-
130
- def default_feature_extractor(sample):
131
- url = None
132
- if "url" in sample:
133
- url = sample["url"]
134
- elif "URL" in sample:
135
- url = sample["URL"]
136
- elif "image_url" in sample:
137
- url = sample["image_url"]
138
- else:
139
- print("No url found in sample, skipping", sample.keys())
140
-
141
- caption = None
142
- if "caption" in sample:
143
- caption = sample["caption"]
144
- elif "CAPTION" in sample:
145
- caption = sample["CAPTION"]
146
- elif "txt" in sample:
147
- caption = sample["txt"]
148
- elif "TEXT" in sample:
149
- caption = sample["TEXT"]
150
- elif "text" in sample:
151
- caption = sample["text"]
152
- else:
153
- print("No caption found in sample, skipping", sample.keys())
154
-
155
- return {
156
- "url": url,
157
- "caption": caption,
158
- }
159
-
160
- def map_batch(
161
- batch, num_threads=256, image_shape=(256, 256),
162
- min_image_shape=(128, 128),
163
- timeout=15, retries=3, image_processor=default_image_processor,
164
- upscale_interpolation=cv2.INTER_CUBIC,
165
- downscale_interpolation=cv2.INTER_AREA,
166
- feature_extractor=default_feature_extractor,
167
- ):
168
- try:
169
- map_sample_fn = partial(
170
- map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
171
- timeout=timeout, retries=retries, image_processor=image_processor,
172
- upscale_interpolation=upscale_interpolation,
173
- downscale_interpolation=downscale_interpolation
174
- )
175
- with ThreadPoolExecutor(max_workers=num_threads) as executor:
176
- features = feature_extractor(batch)
177
- url, caption = features["url"], features["caption"]
178
- executor.map(map_sample_fn, url, caption)
179
- except Exception as e:
180
- print(f"Error maping batch", e)
181
- traceback.print_exc()
182
- # error_queue.put_nowait({
183
- # "batch": batch,
184
- # "error": str(e)
185
- # })
186
- pass
187
-
188
-
189
- def parallel_image_loader(
190
- dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
191
- min_image_shape=(128, 128),
192
- num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
193
- upscale_interpolation=cv2.INTER_CUBIC,
194
- downscale_interpolation=cv2.INTER_AREA,
195
- feature_extractor=default_feature_extractor,
196
- ):
197
- map_batch_fn = partial(
198
- map_batch, num_threads=num_threads, image_shape=image_shape,
199
- min_image_shape=min_image_shape,
200
- timeout=timeout, retries=retries, image_processor=image_processor,
201
- upscale_interpolation=upscale_interpolation,
202
- downscale_interpolation=downscale_interpolation,
203
- feature_extractor=feature_extractor
204
- )
205
- shard_len = len(dataset) // num_workers
206
- print(f"Local Shard lengths: {shard_len}")
207
- with multiprocessing.Pool(num_workers) as pool:
208
- iteration = 0
209
- while True:
210
- # Repeat forever
211
- shards = [dataset[i*shard_len:(i+1)*shard_len]
212
- for i in range(num_workers)]
213
- print(f"mapping {len(shards)} shards")
214
- pool.map(map_batch_fn, shards)
215
- iteration += 1
216
- print(f"Shuffling dataset with seed {iteration}")
217
- dataset = dataset.shuffle(seed=iteration)
218
- # Clear the error queue
219
- # while not error_queue.empty():
220
- # error_queue.get_nowait()
221
-
222
-
223
- class ImageBatchIterator:
224
- def __init__(
225
- self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
226
- min_image_shape=(128, 128),
227
- num_workers: int = 8, num_threads=256, timeout=15, retries=3,
228
- image_processor=default_image_processor,
229
- upscale_interpolation=cv2.INTER_CUBIC,
230
- downscale_interpolation=cv2.INTER_AREA,
231
- feature_extractor=default_feature_extractor,
232
- ):
233
- self.dataset = dataset
234
- self.num_workers = num_workers
235
- self.batch_size = batch_size
236
- loader = partial(
237
- parallel_image_loader,
238
- num_threads=num_threads,
239
- image_shape=image_shape,
240
- min_image_shape=min_image_shape,
241
- num_workers=num_workers,
242
- timeout=timeout, retries=retries,
243
- image_processor=image_processor,
244
- upscale_interpolation=upscale_interpolation,
245
- downscale_interpolation=downscale_interpolation,
246
- feature_extractor=feature_extractor
247
- )
248
- self.thread = threading.Thread(target=loader, args=(dataset,))
249
- self.thread.start()
250
-
251
- def __iter__(self):
252
- return self
253
-
254
- def __next__(self):
255
- def fetcher(_):
256
- return data_queue.get()
257
- with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
258
- batch = list(executor.map(fetcher, range(self.batch_size)))
259
- return batch
260
-
261
- def __del__(self):
262
- self.thread.join()
263
-
264
- def __len__(self):
265
- return len(self.dataset) // self.batch_size
266
-
267
-
268
- def default_collate(batch):
269
- urls = [sample["url"] for sample in batch]
270
- captions = [sample["caption"] for sample in batch]
271
- images = np.stack([sample["image"] for sample in batch], axis=0)
272
- return {
273
- "url": urls,
274
- "caption": captions,
275
- "image": images,
276
- }
277
-
278
-
279
- def dataMapper(map: Dict[str, Any]):
280
- def _map(sample) -> Dict[str, Any]:
281
- return {
282
- "url": sample[map["url"]],
283
- "caption": sample[map["caption"]],
284
- }
285
- return _map
286
-
287
-
288
- class OnlineStreamingDataLoader():
289
- def __init__(
290
- self,
291
- dataset,
292
- batch_size=64,
293
- image_shape=(256, 256),
294
- min_image_shape=(128, 128),
295
- num_workers=16,
296
- num_threads=512,
297
- default_split="all",
298
- pre_map_maker=dataMapper,
299
- pre_map_def={
300
- "url": "URL",
301
- "caption": "TEXT",
302
- },
303
- global_process_count=1,
304
- global_process_index=0,
305
- prefetch=1000,
306
- collate_fn=default_collate,
307
- timeout=15,
308
- retries=3,
309
- image_processor=default_image_processor,
310
- upscale_interpolation=cv2.INTER_CUBIC,
311
- downscale_interpolation=cv2.INTER_AREA,
312
- feature_extractor=default_feature_extractor,
313
- ):
314
- if isinstance(dataset, str):
315
- dataset_path = dataset
316
- print("Loading dataset from path")
317
- if "gs://" in dataset:
318
- dataset = load_from_disk(dataset_path)
319
- else:
320
- dataset = load_dataset(dataset_path, split=default_split)
321
- elif isinstance(dataset, list):
322
- if isinstance(dataset[0], str):
323
- print("Loading multiple datasets from paths")
324
- dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset(
325
- dataset_path, split=default_split) for dataset_path in dataset]
326
- print("Concatenating multiple datasets")
327
- dataset = concatenate_datasets(dataset)
328
- dataset = dataset.shuffle(seed=0)
329
- # dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
330
- self.dataset = dataset.shard(
331
- num_shards=global_process_count, index=global_process_index)
332
- print(f"Dataset length: {len(dataset)}")
333
- self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
334
- min_image_shape=min_image_shape,
335
- num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
336
- timeout=timeout, retries=retries, image_processor=image_processor,
337
- upscale_interpolation=upscale_interpolation,
338
- downscale_interpolation=downscale_interpolation,
339
- feature_extractor=feature_extractor)
340
- self.batch_size = batch_size
341
-
342
- # Launch a thread to load batches in the background
343
- self.batch_queue = queue.Queue(prefetch)
344
-
345
- def batch_loader():
346
- for batch in self.iterator:
347
- try:
348
- self.batch_queue.put(collate_fn(batch))
349
- except Exception as e:
350
- print("Error collating batch", e)
351
-
352
- self.loader_thread = threading.Thread(target=batch_loader)
353
- self.loader_thread.start()
354
-
355
- def __iter__(self):
356
- return self
357
-
358
- def __next__(self):
359
- return self.batch_queue.get()
360
- # return self.collate_fn(next(self.iterator))
361
-
362
- def __len__(self):
363
- return len(self.dataset)
@@ -1 +0,0 @@
1
- from .simple_unet import *