flaxdiff 0.1.36.1__tar.gz → 0.1.36.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.36.1 → flaxdiff-0.1.36.2}/PKG-INFO +1 -1
- {flaxdiff-0.1.36.1 → flaxdiff-0.1.36.2}/pyproject.toml +2 -5
- flaxdiff-0.1.36.2/src/data/__init__.py +1 -0
- flaxdiff-0.1.36.2/src/data/dataset_map.py +71 -0
- flaxdiff-0.1.36.2/src/data/datasets.py +169 -0
- flaxdiff-0.1.36.2/src/data/online_loader.py +363 -0
- flaxdiff-0.1.36.2/src/data/sources/gcs.py +81 -0
- flaxdiff-0.1.36.2/src/data/sources/tfds.py +67 -0
- {flaxdiff-0.1.36.1 → flaxdiff-0.1.36.2/src}/flaxdiff.egg-info/PKG-INFO +1 -1
- flaxdiff-0.1.36.2/src/flaxdiff.egg-info/SOURCES.txt +50 -0
- flaxdiff-0.1.36.2/src/flaxdiff.egg-info/top_level.txt +9 -0
- flaxdiff-0.1.36.2/src/metrics/inception.py +658 -0
- flaxdiff-0.1.36.2/src/metrics/utils.py +49 -0
- flaxdiff-0.1.36.2/src/models/__init__.py +1 -0
- flaxdiff-0.1.36.2/src/models/attention.py +368 -0
- flaxdiff-0.1.36.2/src/models/autoencoder/__init__.py +2 -0
- flaxdiff-0.1.36.2/src/models/autoencoder/autoencoder.py +19 -0
- flaxdiff-0.1.36.2/src/models/autoencoder/diffusers.py +91 -0
- flaxdiff-0.1.36.2/src/models/autoencoder/simple_autoenc.py +26 -0
- flaxdiff-0.1.36.2/src/models/common.py +346 -0
- flaxdiff-0.1.36.2/src/models/favor_fastattn.py +723 -0
- flaxdiff-0.1.36.2/src/models/simple_unet.py +233 -0
- flaxdiff-0.1.36.2/src/models/simple_vit.py +180 -0
- flaxdiff-0.1.36.2/src/predictors/__init__.py +96 -0
- flaxdiff-0.1.36.2/src/samplers/__init__.py +7 -0
- flaxdiff-0.1.36.2/src/samplers/common.py +165 -0
- flaxdiff-0.1.36.2/src/samplers/ddim.py +10 -0
- flaxdiff-0.1.36.2/src/samplers/ddpm.py +37 -0
- flaxdiff-0.1.36.2/src/samplers/euler.py +56 -0
- flaxdiff-0.1.36.2/src/samplers/heun_sampler.py +27 -0
- flaxdiff-0.1.36.2/src/samplers/multistep_dpm.py +59 -0
- flaxdiff-0.1.36.2/src/samplers/rk4_sampler.py +34 -0
- flaxdiff-0.1.36.2/src/schedulers/__init__.py +6 -0
- flaxdiff-0.1.36.2/src/schedulers/common.py +98 -0
- flaxdiff-0.1.36.2/src/schedulers/continuous.py +12 -0
- flaxdiff-0.1.36.2/src/schedulers/cosine.py +40 -0
- flaxdiff-0.1.36.2/src/schedulers/discrete.py +74 -0
- flaxdiff-0.1.36.2/src/schedulers/exp.py +13 -0
- flaxdiff-0.1.36.2/src/schedulers/karras.py +69 -0
- flaxdiff-0.1.36.2/src/schedulers/linear.py +14 -0
- flaxdiff-0.1.36.2/src/schedulers/sqrt.py +10 -0
- flaxdiff-0.1.36.2/src/trainer/__init__.py +2 -0
- flaxdiff-0.1.36.2/src/trainer/autoencoder_trainer.py +182 -0
- flaxdiff-0.1.36.2/src/trainer/diffusion_trainer.py +326 -0
- flaxdiff-0.1.36.2/src/trainer/simple_trainer.py +540 -0
- flaxdiff-0.1.36.2/src/trainer/video_diffusion_trainer.py +62 -0
- flaxdiff-0.1.36.1/flaxdiff.egg-info/SOURCES.txt +0 -9
- flaxdiff-0.1.36.1/flaxdiff.egg-info/top_level.txt +0 -1
- {flaxdiff-0.1.36.1 → flaxdiff-0.1.36.2}/README.md +0 -0
- {flaxdiff-0.1.36.1 → flaxdiff-0.1.36.2}/setup.cfg +0 -0
- {flaxdiff-0.1.36.1/flaxdiff → flaxdiff-0.1.36.2/src}/__init__.py +0 -0
- {flaxdiff-0.1.36.1 → flaxdiff-0.1.36.2/src}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.36.1 → flaxdiff-0.1.36.2/src}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.36.1/flaxdiff → flaxdiff-0.1.36.2/src}/utils.py +0 -0
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "flaxdiff"
|
7
|
-
version = "0.1.36.1"
|
7
|
+
version = "0.1.36.2"
|
8
8
|
description = "A versatile and easy to understand Diffusion library"
|
9
9
|
readme = "README.md"
|
10
10
|
authors = [
|
@@ -26,7 +26,4 @@ dependencies = [
|
|
26
26
|
"rich",
|
27
27
|
"python-dotenv",
|
28
28
|
]
|
29
|
-
license = "MIT"
|
30
|
-
|
31
|
-
[tool.setuptools]
|
32
|
-
packages = ["flaxdiff"]
|
29
|
+
license = "MIT"
|
@@ -0,0 +1 @@
|
|
1
|
+
from .online_loader import OnlineStreamingDataLoader
|
@@ -0,0 +1,71 @@
|
|
1
|
+
from .sources.tfds import data_source_tfds, tfds_augmenters
|
2
|
+
from .sources.gcs import data_source_gcs, data_source_combined_gcs, gcs_augmenters
|
3
|
+
|
4
|
+
# Configure the following for your datasets.
#
# Registry of offline datasets consumed by `get_dataset_grain` (datasets.py).
# Each entry maps a dataset name to a dict with:
#   "source":    the value returned by a data_source_* factory; get_dataset_grain
#                calls it as `dataset["source"](dataset_source)` to obtain the
#                grain data source.
#   "augmenter": a factory that get_dataset_grain calls as
#                `augmenter(image_scale, method)`; the result is then
#                instantiated (`augmenter()`) as a grain transform.
# NOTE: every data_source_* call below is evaluated when this module is imported.
# NOTE(review): paths listed more than once inside a data_source_combined_gcs
# list (e.g. aestheticCoyo_0.25clip_6aesthetic) presumably act as oversampling
# weights — confirm against data_source_combined_gcs.
datasetMap = {
    "oxford_flowers102": {
        "source": data_source_tfds("oxford_flowers102", use_tf=False),
        "augmenter": tfds_augmenters,
    },
    "cc12m": {
        "source": data_source_gcs('arrayrecord2/cc12m'),
        "augmenter": gcs_augmenters,
    },
    "laiona_coco": {
        "source": data_source_gcs('arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
        "augmenter": gcs_augmenters,
    },
    "aesthetic_coyo": {
        "source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
        "augmenter": gcs_augmenters,
    },
    "combined_aesthetic": {
        "source": data_source_combined_gcs([
            'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
            'arrayrecord2/cc12m',
            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
        ]),
        "augmenter": gcs_augmenters,
    },
    "laiona_coco_coyo": {
        "source": data_source_combined_gcs([
            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
            'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
        ]),
        "augmenter": gcs_augmenters,
    },
    "combined_30m": {
        "source": data_source_combined_gcs([
            'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
            'arrayrecord2/cc12m',
            'arrayrecord2/aestheticCoyo_0.26_clip_5.5aesthetic_256plus',
            "arrayrecord2/playground+leonardo_x4+cc3m.parquet",
        ]),
        "augmenter": gcs_augmenters,
    }
}
|
49
|
+
|
50
|
+
# Registry of online (URL-streaming) datasets consumed by `get_dataset_online`
# (datasets.py), which passes `onlineDatasetMap[data_name]["source"]` directly to
# OnlineStreamingDataLoader. That loader loads each entry with `load_from_disk`
# when the path contains "gs://" (otherwise `load_dataset`) and concatenates the
# results, so the order of this list is the concatenation order.
# NOTE(review): repeated entries (leonardo-liked-1.8m x5, cc3m x2) are presumably
# deliberate oversampling — confirm.
onlineDatasetMap = {
    "combined_online": {
        "source": [
            # "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017.parquet"
            # "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
            # "dclure/laion-aesthetics-12m-umap",
            "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017",
            "gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/cc12m",
            "gs://flaxdiff-datasets-regional/datasets/playground-liked",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
            "gs://flaxdiff-datasets-regional/datasets/cc3m",
            "gs://flaxdiff-datasets-regional/datasets/cc3m",
            "gs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_37M",
            # "gs://flaxdiff-datasets-regional/datasets/laiion400m-185M"
        ]
    }
}
|
@@ -0,0 +1,169 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
import grain.python as pygrain
|
3
|
+
from typing import Dict
|
4
|
+
import numpy as np
|
5
|
+
import jax
|
6
|
+
from flaxdiff.utils import convert_to_global_tree, AutoTextTokenizer
|
7
|
+
from .dataset_map import datasetMap, onlineDatasetMap
|
8
|
+
import traceback
|
9
|
+
from .online_loader import OnlineStreamingDataLoader
|
10
|
+
import queue
|
11
|
+
from jax.sharding import Mesh
|
12
|
+
import threading
|
13
|
+
|
14
|
+
def batch_mesh_map(mesh):
    """Create a grain ``MapTransform`` class that globalizes batches over ``mesh``.

    Args:
        mesh: forwarded unchanged to ``convert_to_global_tree``.

    Returns:
        The transform *class* (not an instance). Its ``map`` returns
        ``convert_to_global_tree(mesh, batch)`` for each incoming batch.
    """

    class augmenters(pygrain.MapTransform):
        # No custom __init__: construction falls straight through to the
        # grain base class, exactly as a forwarding constructor would.

        def map(self, batch) -> Dict[str, jnp.array]:
            # `mesh` is captured from the enclosing batch_mesh_map call.
            global_batch = convert_to_global_tree(mesh, batch)
            return global_batch

    return augmenters
|
22
|
+
|
23
|
+
def get_dataset_grain(
    data_name="cc12m",
    batch_size=64,
    image_scale=256,
    count=None,
    num_epochs=None,
    method=jax.image.ResizeMethod.LANCZOS3,
    worker_count=32,
    read_thread_count=64,
    read_buffer_size=50,
    worker_buffer_size=20,
    seed=0,
    dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
):
    """Build a grain training pipeline for the offline dataset ``data_name``.

    Args:
        data_name: key into ``datasetMap``.
        batch_size: global batch size; each JAX process loads
            ``batch_size // jax.process_count()`` samples per batch.
        image_scale: forwarded to the dataset's augmenter factory.
        count: optional cap on how many records are sampled. It is clamped to
            ``len(data_source)`` so the sampler never produces indices past the
            end of the source. ``None`` means use every record.
        num_epochs: forwarded to ``pygrain.IndexSampler`` (``None`` as given).
        method: resize method forwarded to the augmenter factory.
        worker_count, worker_buffer_size: forwarded to ``pygrain.DataLoader``.
        read_thread_count, read_buffer_size: forwarded positionally to
            ``pygrain.ReadOptions``.
        seed: shuffle seed for the sampler.
        dataset_source: path passed to the dataset's source factory.

    Returns:
        A dict with:
          "train": zero-argument callable that builds a ``pygrain.DataLoader``.
          "train_len": ``len(data_source)``.
          "local_batch_size": per-process batch size.
          "global_batch_size": ``batch_size`` as given.
        NOTE: when ``batch_size`` is not divisible by the process count,
        ``local_batch_size * jax.process_count()`` is smaller than
        ``global_batch_size``.
    """
    dataset = datasetMap[data_name]
    data_source = dataset["source"](dataset_source)
    augmenter = dataset["augmenter"](image_scale, method)

    local_batch_size = batch_size // jax.process_count()

    # Clamp `count`: an IndexSampler with more records than the source holds
    # would hand out record keys that the data source cannot serve.
    num_records = len(data_source) if count is None else min(count, len(data_source))

    sampler = pygrain.IndexSampler(
        num_records=num_records,
        shuffle=True,
        seed=seed,
        num_epochs=num_epochs,
        shard_options=pygrain.ShardByJaxProcess(),
    )

    def get_trainset():
        """Construct a fresh grain DataLoader over the shared sampler/source."""
        transformations = [
            augmenter(),
            pygrain.Batch(local_batch_size, drop_remainder=True),
        ]

        loader = pygrain.DataLoader(
            data_source=data_source,
            sampler=sampler,
            operations=transformations,
            worker_count=worker_count,
            read_options=pygrain.ReadOptions(
                read_thread_count, read_buffer_size
            ),
            worker_buffer_size=worker_buffer_size,
        )
        return loader

    return {
        "train": get_trainset,
        "train_len": len(data_source),
        "local_batch_size": local_batch_size,
        "global_batch_size": batch_size,
    }
|
83
|
+
|
84
|
+
def generate_collate_fn():
    """Build a collate function that tokenizes captions and stacks images.

    The tokenizer (``AutoTextTokenizer(tensor_type="np")``) is created once
    here and shared by every call of the returned function.

    Returns:
        ``default_collate(batch)``. It takes a list of sample dicts, each with
        at least "caption" and "image". It returns a dict with:
          "image": the sample images stacked on a new leading axis (``np.stack``).
          "input_ids", "attention_mask": taken from the tokenizer output.

        On failure the returned function logs the error and re-raises it rather
        than returning ``None``. OnlineStreamingDataLoader's background loader
        catches collate exceptions and skips the batch, whereas a ``None``
        return would be enqueued as if it were a real batch.
    """
    auto_tokenize = AutoTextTokenizer(tensor_type="np")

    def default_collate(batch):
        try:
            captions = [sample["caption"] for sample in batch]
            results = auto_tokenize(captions)
            images = np.stack([sample["image"] for sample in batch], axis=0)
            return {
                "image": images,
                "input_ids": results['input_ids'],
                "attention_mask": results['attention_mask'],
            }
        except Exception as e:
            # Collect shapes for diagnostics without letting this lookup raise
            # (a sample may lack "image" or hold a non-array) and mask `e`.
            try:
                shapes = [getattr(sample.get("image"), "shape", None) for sample in batch]
            except Exception:
                shapes = "<unavailable>"
            print("Error in collate function", e, shapes)
            traceback.print_exc()
            raise

    return default_collate
|
102
|
+
|
103
|
+
def get_dataset_online(
    data_name="combined_online",
    batch_size=64,
    image_scale=256,
    count=None,
    num_epochs=None,
    method=jax.image.ResizeMethod.LANCZOS3,
    worker_count=32,
    read_thread_count=64,
    read_buffer_size=50,
    worker_buffer_size=20,
    seed=0,
    dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
):
    """Build a URL-streaming training pipeline for ``onlineDatasetMap[data_name]``.

    The signature mirrors ``get_dataset_grain`` so the two are interchangeable.
    ``count``, ``num_epochs``, ``method``, ``read_buffer_size``, ``seed`` and
    ``dataset_source`` are accepted but not used by this loader.

    Args:
        data_name: key into ``onlineDatasetMap``.
        batch_size: global batch size; each process loads
            ``batch_size // jax.process_count()`` samples per batch.
        image_scale: images are produced at ``(image_scale, image_scale)``.
        worker_count: number of loader worker processes.
        read_thread_count: download threads per worker.
        worker_buffer_size: prefetch depth of the batch queues.

    Returns:
        A dict with:
          "train": ``get_trainset(mesh=None)``. It returns the underlying
              streaming loader, or a wrapper that converts each batch with
              ``convert_to_global_tree(mesh, batch)`` when a mesh is given.
          "train_len": ``len(dataloader) * jax.process_count()``.
          "local_batch_size": per-process batch size.
          "global_batch_size": ``batch_size`` as given.
    """
    local_batch_size = batch_size // jax.process_count()

    sources = onlineDatasetMap[data_name]["source"]
    dataloader = OnlineStreamingDataLoader(
        sources,
        batch_size=local_batch_size,
        num_workers=worker_count,
        num_threads=read_thread_count,
        image_shape=(image_scale, image_scale),
        global_process_count=jax.process_count(),
        global_process_index=jax.process_index(),
        prefetch=worker_buffer_size,
        collate_fn=generate_collate_fn(),
        default_split="train",
    )

    class dataLoaderWithMesh:
        """Iterator re-emitting ``dataloader`` batches as mesh-global arrays."""

        def __init__(self, dataloader, mesh):
            self.dataloader = dataloader
            self.mesh = mesh
            self.tmp_queue = queue.Queue(worker_buffer_size)
            # The source loader never stops yielding, so this thread never ends.
            # daemon=True keeps it from blocking interpreter shutdown forever.
            self.loader_thread = threading.Thread(target=self._batch_loader, daemon=True)
            self.loader_thread.start()

        def _batch_loader(self):
            """Convert each batch; log and drop batches that fail to convert."""
            for batch in self.dataloader:
                try:
                    self.tmp_queue.put(convert_to_global_tree(self.mesh, batch))
                except Exception as e:
                    print("Error processing batch", e)

        def __iter__(self):
            return self

        def __next__(self):
            return self.tmp_queue.get()

    def get_trainset(mesh: Mesh = None):
        if mesh is not None:
            return dataLoaderWithMesh(dataloader, mesh)
        return dataloader

    return {
        "train": get_trainset,
        "train_len": len(dataloader) * jax.process_count(),
        "local_batch_size": local_batch_size,
        "global_batch_size": batch_size,
    }
|
@@ -0,0 +1,363 @@
|
|
1
|
+
import multiprocessing
|
2
|
+
import threading
|
3
|
+
from multiprocessing import Queue
|
4
|
+
# from arrayqueues.shared_arrays import ArrayQueue
|
5
|
+
# from faster_fifo import Queue
|
6
|
+
import time
|
7
|
+
import albumentations as A
|
8
|
+
import queue
|
9
|
+
import cv2
|
10
|
+
from functools import partial
|
11
|
+
from typing import Any, Dict, List, Tuple
|
12
|
+
|
13
|
+
import numpy as np
|
14
|
+
from functools import partial
|
15
|
+
|
16
|
+
from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk
|
17
|
+
from datasets.utils.file_utils import get_datasets_user_agent
|
18
|
+
from concurrent.futures import ThreadPoolExecutor
|
19
|
+
import io
|
20
|
+
import urllib
|
21
|
+
|
22
|
+
import PIL.Image
|
23
|
+
import cv2
|
24
|
+
import traceback
|
25
|
+
|
26
|
+
# User-agent header sent with every image download request (used by
# fetch_single_image).
USER_AGENT = get_datasets_user_agent()

# Module-global multiprocessing queue of decoded samples. map_sample, which runs
# on threads inside the pool worker processes started by parallel_image_loader,
# puts sample dicts here; ImageBatchIterator.__next__ takes them out.
# It is created at import time so the workers see the same queue.
# NOTE(review): this presumably relies on workers inheriting module globals
# (fork start method). Under "spawn" each worker would re-import this module and
# get its own queue — confirm.
# Because it is module-level, every loader instance in the process shares this
# single queue. Capacity is 16 * 2000 = 32000 samples.
data_queue = Queue(16*2000)
|
29
|
+
|
30
|
+
|
31
|
+
def fetch_single_image(image_url, timeout=None, retries=0):
    """Download ``image_url`` and open it with PIL.

    Args:
        image_url: URL to request.
        timeout: socket timeout in seconds forwarded to ``urlopen``
            (``None`` = library default).
        retries: extra attempts after the first failure. Negative values mean
            no attempt is made.

    Returns:
        A ``PIL.Image.Image`` on success, or ``None`` when every attempt failed.
        ``PIL.Image.open`` is lazy, so pixel decoding (and any decode error)
        happens later, when the image data is first accessed.
    """
    # `import urllib` alone does not load the `urllib.request` submodule.
    # Without this import the attribute access could raise AttributeError, and
    # the broad `except` below would silently turn it into `None` for every URL.
    import urllib.request

    image = None  # stays defined even if the loop never runs (retries < 0)
    for _ in range(retries + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                image = PIL.Image.open(io.BytesIO(req.read()))
            break
        except Exception:
            # Best-effort fetch: any failure (network, HTTP, decode header)
            # just triggers the next attempt.
            image = None
    return image
|
45
|
+
|
46
|
+
|
47
|
+
def default_image_processor(
    image, image_shape,
    min_image_shape=(128, 128),
    upscale_interpolation=cv2.INTER_CUBIC,
    downscale_interpolation=cv2.INTER_AREA,
):
    """Validate an image, resize its longest side and pad it to ``image_shape``.

    Args:
        image: anything ``np.array`` accepts (e.g. a PIL image).
        image_shape: target ``(height, width)``. The longest image side is
            resized to ``max(image_shape)`` with ``A.longest_max_size``. The
            result is padded with white (255, 255, 255) to at least
            ``image_shape[0]`` x ``image_shape[1]``.
        min_image_shape: images whose shorter side is below
            ``min(min_image_shape)`` are rejected.
        upscale_interpolation / downscale_interpolation: used when the longest
            side is at most / strictly above ``max(image_shape)``.

    Returns:
        ``(processed, original_height, original_width)``.
        ``processed`` is ``None`` in these cases:
          - The array is not 3-D with 3 channels (returns ``(None, 0, 0)``).
          - The shorter side is too small, the aspect ratio exceeds 2.4, or the
            pixel standard deviation is below 1e-5 (returns ``None`` with the
            original size).
          - Any exception is raised while processing (returns ``(None, 0, 0)``).
    """
    try:
        pixels = np.array(image)
        if pixels.ndim != 3 or pixels.shape[2] != 3:
            return None, 0, 0

        height, width = pixels.shape[:2]
        short_side, long_side = sorted((height, width))

        # Reject tiny, extremely elongated, or (near-)constant images.
        rejected = (
            short_side < min(min_image_shape)
            or long_side / short_side > 2.4
            or np.std(pixels) < 1e-5
        )
        if rejected:
            return None, height, width

        target_long = max(image_shape)
        if long_side > target_long:
            interpolation = downscale_interpolation
        else:
            interpolation = upscale_interpolation

        resized = A.longest_max_size(pixels, target_long, interpolation=interpolation)
        padded = A.pad(
            resized,
            min_height=image_shape[0],
            min_width=image_shape[1],
            border_mode=cv2.BORDER_CONSTANT,
            value=[255, 255, 255],
        )
        return padded, height, width
    except Exception:
        return None, 0, 0
|
85
|
+
|
86
|
+
|
87
|
+
def map_sample(
    url,
    caption,
    image_shape=(256, 256),
    min_image_shape=(128, 128),
    timeout=15,
    retries=3,
    upscale_interpolation=cv2.INTER_CUBIC,
    downscale_interpolation=cv2.INTER_AREA,
    image_processor=default_image_processor,
):
    """Fetch and process one (url, caption) pair, enqueueing it on success.

    Successful samples are put on the module-level ``data_queue`` as a dict with
    keys "url", "caption", "image", "original_height" and "original_width".
    Samples whose download or processing yields ``None`` are dropped.

    This is deliberately best-effort: any exception is swallowed, and the
    function always returns ``None``.
    """
    try:
        raw = fetch_single_image(url, timeout=timeout, retries=retries)
        if raw is not None:
            processed, height, width = image_processor(
                raw,
                image_shape,
                min_image_shape=min_image_shape,
                upscale_interpolation=upscale_interpolation,
                downscale_interpolation=downscale_interpolation,
            )
            if processed is not None:
                data_queue.put({
                    "url": url,
                    "caption": caption,
                    "image": processed,
                    "original_height": height,
                    "original_width": width,
                })
    except Exception:
        # Individual bad URLs/images must never take down the worker thread.
        pass
|
129
|
+
|
130
|
+
def default_feature_extractor(sample):
    """Pick the URL and caption fields out of ``sample`` by known column names.

    URL keys are tried in order: "url", "URL", "image_url".
    Caption keys are tried in order: "caption", "CAPTION", "txt", "TEXT", "text".
    When no key matches, a message listing ``sample.keys()`` is printed and
    that field is ``None``.

    Returns:
        ``{"url": <value or None>, "caption": <value or None>}``.
    """
    def _first_present(candidates, missing_message):
        for key in candidates:
            if key in sample:
                return sample[key]
        print(missing_message, sample.keys())
        return None

    url = _first_present(("url", "URL", "image_url"), "No url found in sample, skipping")
    caption = _first_present(
        ("caption", "CAPTION", "txt", "TEXT", "text"),
        "No caption found in sample, skipping",
    )
    return {"url": url, "caption": caption}
|
159
|
+
|
160
|
+
def map_batch(
    batch, num_threads=256, image_shape=(256, 256),
    min_image_shape=(128, 128),
    timeout=15, retries=3, image_processor=default_image_processor,
    upscale_interpolation=cv2.INTER_CUBIC,
    downscale_interpolation=cv2.INTER_AREA,
    feature_extractor=default_feature_extractor,
):
    """Run ``map_sample`` over every (url, caption) pair of ``batch`` on threads.

    ``feature_extractor(batch)`` supplies the parallel "url" and "caption"
    sequences. Each pair is submitted to a thread pool of ``num_threads``
    workers, and the pool is drained before returning. Results are delivered
    through ``data_queue`` by ``map_sample``, not returned. Errors are printed
    with a traceback and otherwise ignored.
    """
    try:
        worker = partial(
            map_sample,
            image_shape=image_shape,
            min_image_shape=min_image_shape,
            timeout=timeout,
            retries=retries,
            image_processor=image_processor,
            upscale_interpolation=upscale_interpolation,
            downscale_interpolation=downscale_interpolation,
        )
        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            extracted = feature_extractor(batch)
            urls = extracted["url"]
            captions = extracted["caption"]
            # Fire-and-forget: leaving the `with` waits for all submissions.
            for sample_url, sample_caption in zip(urls, captions):
                pool.submit(worker, sample_url, sample_caption)
    except Exception as e:
        print("Error maping batch", e)
        traceback.print_exc()
|
187
|
+
|
188
|
+
|
189
|
+
def parallel_image_loader(
    dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
    min_image_shape=(128, 128),
    num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
    upscale_interpolation=cv2.INTER_CUBIC,
    downscale_interpolation=cv2.INTER_AREA,
    feature_extractor=default_feature_extractor,
):
    """Feed ``dataset`` through ``map_batch`` in a process pool, forever.

    Each pass slices the dataset into ``num_workers`` contiguous shards and maps
    them with ``pool.map(map_batch, shards)``. The dataset is then reshuffled
    with ``seed=<pass number>`` and the next pass starts. Samples reach
    consumers through the module-level ``data_queue``, not a return value.
    This function never returns normally.

    Raises:
        ValueError: if ``dataset`` is empty. The loop would otherwise spin
            forever without producing a single sample.
    """
    map_batch_fn = partial(
        map_batch, num_threads=num_threads, image_shape=image_shape,
        min_image_shape=min_image_shape,
        timeout=timeout, retries=retries, image_processor=image_processor,
        upscale_interpolation=upscale_interpolation,
        downscale_interpolation=downscale_interpolation,
        feature_extractor=feature_extractor
    )
    if len(dataset) == 0:
        raise ValueError("parallel_image_loader received an empty dataset")
    # Ceiling division so every row lands in some shard. Floor division dropped
    # up to num_workers - 1 rows per pass. It also produced only empty shards
    # (no samples, forever) when the dataset had fewer rows than workers.
    # Trailing shards may be short or empty; slicing past the end is harmless.
    shard_len = -(-len(dataset) // num_workers)
    print(f"Local Shard lengths: {shard_len}")
    with multiprocessing.Pool(num_workers) as pool:
        iteration = 0
        while True:
            # Repeat forever
            shards = [dataset[i*shard_len:(i+1)*shard_len]
                      for i in range(num_workers)]
            print(f"mapping {len(shards)} shards")
            pool.map(map_batch_fn, shards)
            iteration += 1
            print(f"Shuffling dataset with seed {iteration}")
            dataset = dataset.shuffle(seed=iteration)
|
221
|
+
|
222
|
+
|
223
|
+
class ImageBatchIterator:
    """Endless iterator yielding lists of ``batch_size`` processed samples.

    On construction it starts a background thread running
    ``parallel_image_loader`` over ``dataset``. Worker processes push sample
    dicts into the module-level ``data_queue``, and ``__next__`` takes
    ``batch_size`` of them, blocking until enough are available.
    """

    def __init__(
        self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
        min_image_shape=(128, 128),
        num_workers: int = 8, num_threads=256, timeout=15, retries=3,
        image_processor=default_image_processor,
        upscale_interpolation=cv2.INTER_CUBIC,
        downscale_interpolation=cv2.INTER_AREA,
        feature_extractor=default_feature_extractor,
    ):
        self.dataset = dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        loader = partial(
            parallel_image_loader,
            num_threads=num_threads,
            image_shape=image_shape,
            min_image_shape=min_image_shape,
            num_workers=num_workers,
            timeout=timeout, retries=retries,
            image_processor=image_processor,
            upscale_interpolation=upscale_interpolation,
            downscale_interpolation=downscale_interpolation,
            feature_extractor=feature_extractor
        )
        # parallel_image_loader loops forever, so this thread never finishes.
        # It is a daemon so that it cannot block interpreter shutdown. For the
        # same reason there is no finalizer that joins it: such a join would
        # never return.
        self.thread = threading.Thread(target=loader, args=(dataset,), daemon=True)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        # Plain sequential gets. Spinning up a thread pool per batch just to
        # call the same blocking get() added overhead. The queue's reader lock
        # (CPython multiprocessing.Queue) serializes concurrent gets anyway.
        return [data_queue.get() for _ in range(self.batch_size)]

    def __len__(self):
        # Number of full batches in one pass over this iterator's dataset.
        return len(self.dataset) // self.batch_size
|
266
|
+
|
267
|
+
|
268
|
+
def default_collate(batch):
    """Merge a list of sample dicts into one batch dict.

    Returns a dict with keys, in this order:
      "url": list of the samples' "url" values.
      "caption": list of the samples' "caption" values.
      "image": the samples' "image" arrays stacked on a new leading axis.
    """
    collated = {
        "url": [sample["url"] for sample in batch],
        "caption": [sample["caption"] for sample in batch],
    }
    collated["image"] = np.stack([sample["image"] for sample in batch], axis=0)
    return collated
|
277
|
+
|
278
|
+
|
279
|
+
def dataMapper(map: Dict[str, Any]):
    """Return a function that renames a sample's columns to "url"/"caption".

    Args:
        map: dict whose "url" and "caption" entries name the source columns.
            It is looked up on every call of the returned function, not once
            at construction.

    Returns:
        ``_map(sample)``, which returns
        ``{"url": sample[map["url"]], "caption": sample[map["caption"]]}``.
    """
    def _map(sample) -> Dict[str, Any]:
        return {field: sample[map[field]] for field in ("url", "caption")}
    return _map
|
286
|
+
|
287
|
+
|
288
|
+
class OnlineStreamingDataLoader():
    """Stream collated (image, caption) batches by downloading image URLs.

    ``dataset`` may be any of:
      - a path string: ``load_from_disk`` if it contains "gs://", otherwise
        ``load_dataset(path, split=default_split)``;
      - a list of such paths, or a list of already-loaded datasets: these are
        concatenated and shuffled with ``seed=0``;
      - a dataset object, used as-is.
    The result is sharded by ``global_process_index`` of
    ``global_process_count``. Batches are produced by an ImageBatchIterator and
    collated with ``collate_fn`` on a background thread into a queue of
    ``prefetch`` batches.

    ``pre_map_maker`` and ``pre_map_def`` are accepted for API compatibility
    but are currently unused.
    """

    def __init__(
        self,
        dataset,
        batch_size=64,
        image_shape=(256, 256),
        min_image_shape=(128, 128),
        num_workers=16,
        num_threads=512,
        default_split="all",
        pre_map_maker=dataMapper,
        pre_map_def={
            "url": "URL",
            "caption": "TEXT",
        },
        global_process_count=1,
        global_process_index=0,
        prefetch=1000,
        collate_fn=default_collate,
        timeout=15,
        retries=3,
        image_processor=default_image_processor,
        upscale_interpolation=cv2.INTER_CUBIC,
        downscale_interpolation=cv2.INTER_AREA,
        feature_extractor=default_feature_extractor,
    ):
        if isinstance(dataset, str):
            print("Loading dataset from path")
            dataset = self._load_one(dataset, default_split)
        elif isinstance(dataset, list):
            if dataset and isinstance(dataset[0], str):
                print("Loading multiple datasets from paths")
                dataset = [self._load_one(path, default_split) for path in dataset]
            # Concatenate whether the list held paths or already-loaded datasets.
            print("Concatenating multiple datasets")
            dataset = concatenate_datasets(dataset)
            # NOTE(review): confirm whether single-source datasets should also be
            # shuffled here. parallel_image_loader reshuffles after every pass
            # regardless.
            dataset = dataset.shuffle(seed=0)
        self.dataset = dataset.shard(
            num_shards=global_process_count, index=global_process_index)
        print(f"Dataset length: {len(dataset)}")
        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
                                           min_image_shape=min_image_shape,
                                           num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
                                           timeout=timeout, retries=retries, image_processor=image_processor,
                                           upscale_interpolation=upscale_interpolation,
                                           downscale_interpolation=downscale_interpolation,
                                           feature_extractor=feature_extractor)
        self.batch_size = batch_size

        # Launch a thread to load batches in the background
        self.batch_queue = queue.Queue(prefetch)

        def batch_loader():
            for batch in self.iterator:
                try:
                    collated = collate_fn(batch)
                except Exception as e:
                    print("Error collating batch", e)
                    continue
                if collated is None:
                    # Never hand consumers a None batch (a collate function may
                    # signal failure by returning None instead of raising).
                    print("Collate function returned None; skipping batch")
                    continue
                self.batch_queue.put(collated)

        # The iterator is endless, so this thread never finishes. It is a
        # daemon so that it cannot keep the interpreter alive at shutdown.
        self.loader_thread = threading.Thread(target=batch_loader, daemon=True)
        self.loader_thread.start()

    @staticmethod
    def _load_one(path, split):
        """Load one dataset: ``load_from_disk`` for gs:// paths, else ``load_dataset``."""
        if "gs://" in path:
            return load_from_disk(path)
        return load_dataset(path, split=split)

    def __iter__(self):
        return self

    def __next__(self):
        # Blocks until the background thread has produced a collated batch.
        return self.batch_queue.get()

    def __len__(self):
        # Number of rows (not batches) in this process's shard.
        return len(self.dataset)
|