flaxdiff 0.1.35.6__tar.gz → 0.1.36.1__tar.gz
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between these versions as they appear in the public registry.
- {flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/PKG-INFO +14 -5
- flaxdiff-0.1.36.1/flaxdiff/utils.py +192 -0
- {flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/flaxdiff.egg-info/PKG-INFO +14 -5
- flaxdiff-0.1.36.1/flaxdiff.egg-info/SOURCES.txt +9 -0
- flaxdiff-0.1.36.1/flaxdiff.egg-info/requires.txt +14 -0
- flaxdiff-0.1.36.1/pyproject.toml +32 -0
- flaxdiff-0.1.35.6/flaxdiff/data/__init__.py +0 -1
- flaxdiff-0.1.35.6/flaxdiff/data/online_loader.py +0 -336
- flaxdiff-0.1.35.6/flaxdiff/models/__init__.py +0 -1
- flaxdiff-0.1.35.6/flaxdiff/models/attention.py +0 -368
- flaxdiff-0.1.35.6/flaxdiff/models/autoencoder/__init__.py +0 -2
- flaxdiff-0.1.35.6/flaxdiff/models/autoencoder/autoencoder.py +0 -19
- flaxdiff-0.1.35.6/flaxdiff/models/autoencoder/diffusers.py +0 -91
- flaxdiff-0.1.35.6/flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
- flaxdiff-0.1.35.6/flaxdiff/models/common.py +0 -346
- flaxdiff-0.1.35.6/flaxdiff/models/favor_fastattn.py +0 -723
- flaxdiff-0.1.35.6/flaxdiff/models/simple_unet.py +0 -233
- flaxdiff-0.1.35.6/flaxdiff/models/simple_vit.py +0 -180
- flaxdiff-0.1.35.6/flaxdiff/predictors/__init__.py +0 -96
- flaxdiff-0.1.35.6/flaxdiff/samplers/__init__.py +0 -7
- flaxdiff-0.1.35.6/flaxdiff/samplers/common.py +0 -113
- flaxdiff-0.1.35.6/flaxdiff/samplers/ddim.py +0 -10
- flaxdiff-0.1.35.6/flaxdiff/samplers/ddpm.py +0 -43
- flaxdiff-0.1.35.6/flaxdiff/samplers/euler.py +0 -59
- flaxdiff-0.1.35.6/flaxdiff/samplers/heun_sampler.py +0 -28
- flaxdiff-0.1.35.6/flaxdiff/samplers/multistep_dpm.py +0 -60
- flaxdiff-0.1.35.6/flaxdiff/samplers/rk4_sampler.py +0 -34
- flaxdiff-0.1.35.6/flaxdiff/schedulers/__init__.py +0 -6
- flaxdiff-0.1.35.6/flaxdiff/schedulers/common.py +0 -98
- flaxdiff-0.1.35.6/flaxdiff/schedulers/continuous.py +0 -12
- flaxdiff-0.1.35.6/flaxdiff/schedulers/cosine.py +0 -40
- flaxdiff-0.1.35.6/flaxdiff/schedulers/discrete.py +0 -74
- flaxdiff-0.1.35.6/flaxdiff/schedulers/exp.py +0 -13
- flaxdiff-0.1.35.6/flaxdiff/schedulers/karras.py +0 -69
- flaxdiff-0.1.35.6/flaxdiff/schedulers/linear.py +0 -14
- flaxdiff-0.1.35.6/flaxdiff/schedulers/sqrt.py +0 -10
- flaxdiff-0.1.35.6/flaxdiff/trainer/__init__.py +0 -2
- flaxdiff-0.1.35.6/flaxdiff/trainer/autoencoder_trainer.py +0 -182
- flaxdiff-0.1.35.6/flaxdiff/trainer/diffusion_trainer.py +0 -234
- flaxdiff-0.1.35.6/flaxdiff/trainer/simple_trainer.py +0 -442
- flaxdiff-0.1.35.6/flaxdiff/utils.py +0 -89
- flaxdiff-0.1.35.6/flaxdiff.egg-info/SOURCES.txt +0 -43
- flaxdiff-0.1.35.6/flaxdiff.egg-info/requires.txt +0 -5
- flaxdiff-0.1.35.6/setup.py +0 -21
- {flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/README.md +0 -0
- {flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/setup.cfg +0 -0

{flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/PKG-INFO

@@ -1,15 +1,24 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.35.6
+Version: 0.1.36.1
 Summary: A versatile and easy to understand Diffusion library
-Author: Ashish Kumar Singh
-
+Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
+License-Expression: MIT
 Description-Content-Type: text/markdown
 Requires-Dist: flax>=0.8.4
-Requires-Dist: optax>=0.2.2
 Requires-Dist: jax>=0.4.28
+Requires-Dist: optax>=0.2.2
 Requires-Dist: orbax
+Requires-Dist: numpy
 Requires-Dist: clu
+Requires-Dist: einops
+Requires-Dist: tqdm
+Requires-Dist: grain
+Requires-Dist: termcolor
+Requires-Dist: augmax
+Requires-Dist: albumentations
+Requires-Dist: rich
+Requires-Dist: python-dotenv
 
 # 
 
flaxdiff-0.1.36.1/flaxdiff/utils.py

@@ -0,0 +1,192 @@
+import jax
+import jax.numpy as jnp
+import flax.struct as struct
+import flax.linen as nn
+from typing import Any, Callable
+from dataclasses import dataclass
+from functools import partial
+import numpy as np
+from jax.sharding import Mesh, PartitionSpec as P
+from abc import ABC, abstractmethod
+
+class MarkovState(struct.PyTreeNode):
+    pass
+
+class RandomMarkovState(MarkovState):
+    rng: jax.random.PRNGKey
+
+    def get_random_key(self):
+        rng, subkey = jax.random.split(self.rng)
+        return RandomMarkovState(rng), subkey
+
+def clip_images(images, clip_min=-1, clip_max=1):
+    return jnp.clip(images, clip_min, clip_max)
+
+def _build_global_shape_and_sharding(
+    local_shape: tuple[int, ...], global_mesh: Mesh
+) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
+    sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
+    global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
+    return global_shape, sharding
+
+
+def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
+    """Put local sharded array into local devices"""
+    global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
+    try:
+        local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
+    except ValueError as array_split_error:
+        raise ValueError(
+            f"Unable to put to devices shape {array.shape} with "
+            f"local device count {len(global_mesh.local_devices)} "
+        ) from array_split_error
+    local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
+    return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
+
+def convert_to_global_tree(global_mesh, pytree):
+    return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
+
+class RMSNorm(nn.Module):
+    """
+    From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467
+
+    Adapted from flax.linen.LayerNorm
+    """
+
+    epsilon: float = 1e-6
+    dtype: Any = jnp.float32
+    param_dtype: Any = jnp.float32
+    use_scale: bool = True
+    scale_init: Any = jax.nn.initializers.ones
+
+    @nn.compact
+    def __call__(self, x):
+        reduction_axes = (-1,)
+        feature_axes = (-1,)
+
+        rms_sq = self._compute_rms_sq(x, reduction_axes)
+
+        return self._normalize(
+            self,
+            x,
+            rms_sq,
+            reduction_axes,
+            feature_axes,
+            self.dtype,
+            self.param_dtype,
+            self.epsilon,
+            self.use_scale,
+            self.scale_init,
+        )
+
+    def _compute_rms_sq(self, x, axes):
+        x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
+        rms_sq = jnp.mean(jax.lax.square(x), axes)
+        return rms_sq
+
+    def _normalize(
+        self,
+        mdl,
+        x,
+        rms_sq,
+        reduction_axes,
+        feature_axes,
+        dtype,
+        param_dtype,
+        epsilon,
+        use_scale,
+        scale_init,
+    ):
+        reduction_axes = nn.normalization._canonicalize_axes(x.ndim, reduction_axes)
+        feature_axes = nn.normalization._canonicalize_axes(x.ndim, feature_axes)
+        stats_shape = list(x.shape)
+        for axis in reduction_axes:
+            stats_shape[axis] = 1
+        rms_sq = rms_sq.reshape(stats_shape)
+        feature_shape = [1] * x.ndim
+        reduced_feature_shape = []
+        for ax in feature_axes:
+            feature_shape[ax] = x.shape[ax]
+            reduced_feature_shape.append(x.shape[ax])
+        mul = jax.lax.rsqrt(rms_sq + epsilon)
+        if use_scale:
+            scale = mdl.param(
+                "scale", scale_init, reduced_feature_shape, param_dtype
+            ).reshape(feature_shape)
+            mul *= scale
+        y = mul * x
+        return jnp.asarray(y, dtype)
+
+@dataclass
+class ConditioningEncoder(ABC):
+    model: nn.Module
+    tokenizer: Callable
+
+    def __call__(self, data):
+        tokens = self.tokenize(data)
+        outputs = self.encode_from_tokens(tokens)
+        return outputs
+
+    def encode_from_tokens(self, tokens):
+        outputs = self.model(input_ids=tokens['input_ids'],
+                             attention_mask=tokens['attention_mask'])
+        last_hidden_state = outputs.last_hidden_state
+        return last_hidden_state
+
+    def tokenize(self, data):
+        tokens = self.tokenizer(data, padding="max_length",
+                                max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np")
+        return tokens
+
+@dataclass
+class TextEncoder(ConditioningEncoder):
+    # def __call__(self, data):
+    #     tokens = self.tokenize(data)
+    #     outputs = self.encode_from_tokens(tokens)
+    #     return outputs
+
+    # def encode_from_tokens(self, tokens):
+    #     outputs = self.model(input_ids=tokens['input_ids'],
+    #                          attention_mask=tokens['attention_mask'])
+    #     last_hidden_state = outputs.last_hidden_state
+    #     # pooler_output = outputs.pooler_output # pooled (EOS token) states
+    #     # embed_pooled = pooler_output # .astype(jnp.float16)
+    #     embed_labels_full = last_hidden_state # .astype(jnp.float16)
+
+    #     return embed_labels_full
+    pass
+
+class AutoTextTokenizer:
+    def __init__(self, tensor_type="pt", modelname="openai/clip-vit-large-patch14"):
+        from transformers import AutoTokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(modelname)
+        self.tensor_type = tensor_type
+
+    def __call__(self, inputs):
+        # print(caption)
+        tokens = self.tokenizer(inputs, padding="max_length", max_length=self.tokenizer.model_max_length,
+                                truncation=True, return_tensors=self.tensor_type)
+        # print(tokens.keys())
+        return {
+            "input_ids": tokens["input_ids"],
+            "attention_mask": tokens["attention_mask"],
+            "caption": inputs,
+        }
+
+    def __repr__(self):
+        return self.__class__.__name__ + '()'
+
+def defaultTextEncodeModel(backend="jax"):
+    from transformers import (
+        CLIPTextModel,
+        FlaxCLIPTextModel,
+        AutoTokenizer,
+    )
+    modelname = "openai/clip-vit-large-patch14"
+    if backend == "jax":
+        model = FlaxCLIPTextModel.from_pretrained(
+            modelname, dtype=jnp.bfloat16)
+    else:
+        model = CLIPTextModel.from_pretrained(modelname)
+    tokenizer = AutoTokenizer.from_pretrained(modelname, dtype=jnp.float16)
+    return TextEncoder(model, tokenizer)

{flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/flaxdiff.egg-info/PKG-INFO

@@ -1,15 +1,24 @@
-Metadata-Version: 2.1
+Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.35.6
+Version: 0.1.36.1
 Summary: A versatile and easy to understand Diffusion library
-Author: Ashish Kumar Singh
-
+Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
+License-Expression: MIT
 Description-Content-Type: text/markdown
 Requires-Dist: flax>=0.8.4
-Requires-Dist: optax>=0.2.2
 Requires-Dist: jax>=0.4.28
+Requires-Dist: optax>=0.2.2
 Requires-Dist: orbax
+Requires-Dist: numpy
 Requires-Dist: clu
+Requires-Dist: einops
+Requires-Dist: tqdm
+Requires-Dist: grain
+Requires-Dist: termcolor
+Requires-Dist: augmax
+Requires-Dist: albumentations
+Requires-Dist: rich
+Requires-Dist: python-dotenv
 
 # 
 
flaxdiff-0.1.36.1/pyproject.toml

@@ -0,0 +1,32 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "flaxdiff"
+version = "0.1.36.1"
+description = "A versatile and easy to understand Diffusion library"
+readme = "README.md"
+authors = [
+    { name="Ashish Kumar Singh", email="ashishkmr472@gmail.com" }
+]
+dependencies = [
+    "flax>=0.8.4",
+    "jax>=0.4.28",
+    "optax>=0.2.2",
+    "orbax",
+    "numpy",
+    "clu",
+    "einops",
+    "tqdm",
+    "grain",
+    "termcolor",
+    "augmax",
+    "albumentations",
+    "rich",
+    "python-dotenv",
+]
+license = "MIT"
+
+[tool.setuptools]
+packages = ["flaxdiff"]

flaxdiff-0.1.35.6/flaxdiff/data/__init__.py

@@ -1 +0,0 @@
-from .online_loader import OnlineStreamingDataLoader

flaxdiff-0.1.35.6/flaxdiff/data/online_loader.py

@@ -1,336 +0,0 @@
-import multiprocessing
-import threading
-from multiprocessing import Queue
-# from arrayqueues.shared_arrays import ArrayQueue
-# from faster_fifo import Queue
-import time
-import albumentations as A
-import queue
-import cv2
-from functools import partial
-from typing import Any, Dict, List, Tuple
-
-import numpy as np
-from functools import partial
-
-from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk
-from datasets.utils.file_utils import get_datasets_user_agent
-from concurrent.futures import ThreadPoolExecutor
-import io
-import urllib
-
-import PIL.Image
-import cv2
-import traceback
-
-USER_AGENT = get_datasets_user_agent()
-
-data_queue = Queue(16*2000)
-
-
-def fetch_single_image(image_url, timeout=None, retries=0):
-    for _ in range(retries + 1):
-        try:
-            request = urllib.request.Request(
-                image_url,
-                data=None,
-                headers={"user-agent": USER_AGENT},
-            )
-            with urllib.request.urlopen(request, timeout=timeout) as req:
-                image = PIL.Image.open(io.BytesIO(req.read()))
-            break
-        except Exception:
-            image = None
-    return image
-
-
-def default_image_processor(
-    image, image_shape,
-    min_image_shape=(128, 128),
-    upscale_interpolation=cv2.INTER_CUBIC,
-    downscale_interpolation=cv2.INTER_AREA,
-):
-    image = np.array(image)
-    original_height, original_width = image.shape[:2]
-    # check if the image is too small
-    if min(original_height, original_width) < min(min_image_shape):
-        return None, original_height, original_width
-    # check if wrong aspect ratio
-    if max(original_height, original_width) / min(original_height, original_width) > 2.4:
-        return None, original_height, original_width
-    # check if the variance is too low
-    if np.std(image) < 1e-5:
-        return None, original_height, original_width
-    # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-    downscale = max(original_width, original_height) > max(image_shape)
-    interpolation = downscale_interpolation if downscale else upscale_interpolation
-
-    image = A.longest_max_size(image, max(
-        image_shape), interpolation=interpolation)
-    image = A.pad(
-        image,
-        min_height=image_shape[0],
-        min_width=image_shape[1],
-        border_mode=cv2.BORDER_CONSTANT,
-        value=[255, 255, 255],
-    )
-    return image, original_height, original_width
-
-
-def map_sample(
-    url,
-    caption,
-    image_shape=(256, 256),
-    min_image_shape=(128, 128),
-    timeout=15,
-    retries=3,
-    upscale_interpolation=cv2.INTER_CUBIC,
-    downscale_interpolation=cv2.INTER_AREA,
-    image_processor=default_image_processor,
-):
-    try:
-        # Assuming fetch_single_image is defined elsewhere
-        image = fetch_single_image(url, timeout=timeout, retries=retries)
-        if image is None:
-            return
-
-        image, original_height, original_width = image_processor(
-            image, image_shape, min_image_shape=min_image_shape,
-            upscale_interpolation=upscale_interpolation,
-            downscale_interpolation=downscale_interpolation,
-        )
-
-        if image is None:
-            return
-
-        data_queue.put({
-            "url": url,
-            "caption": caption,
-            "image": image,
-            "original_height": original_height,
-            "original_width": original_width,
-        })
-    except Exception as e:
-        # print(f"Error maping sample {url}", e)
-        # traceback.print_exc()
-        # error_queue.put_nowait({
-        #     "url": url,
-        #     "caption": caption,
-        #     "error": str(e)
-        # })
-        pass
-
-
-def default_feature_extractor(sample):
-    return {
-        "url": sample["url"],
-        "caption": sample["caption"],
-    }
-
-
-def map_batch(
-    batch, num_threads=256, image_shape=(256, 256),
-    min_image_shape=(128, 128),
-    timeout=15, retries=3, image_processor=default_image_processor,
-    upscale_interpolation=cv2.INTER_CUBIC,
-    downscale_interpolation=cv2.INTER_AREA,
-    feature_extractor=default_feature_extractor,
-):
-    try:
-        map_sample_fn = partial(
-            map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
-            timeout=timeout, retries=retries, image_processor=image_processor,
-            upscale_interpolation=upscale_interpolation,
-            downscale_interpolation=downscale_interpolation
-        )
-        with ThreadPoolExecutor(max_workers=num_threads) as executor:
-            features = feature_extractor(batch)
-            url, caption = features["url"], features["caption"]
-            executor.map(map_sample_fn, url, caption)
-    except Exception as e:
-        print(f"Error maping batch", e)
-        traceback.print_exc()
-        # error_queue.put_nowait({
-        #     "batch": batch,
-        #     "error": str(e)
-        # })
-        pass
-
-
-def parallel_image_loader(
-    dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
-    min_image_shape=(128, 128),
-    num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
-    upscale_interpolation=cv2.INTER_CUBIC,
-    downscale_interpolation=cv2.INTER_AREA,
-    feature_extractor=default_feature_extractor,
-):
-    map_batch_fn = partial(
-        map_batch, num_threads=num_threads, image_shape=image_shape,
-        min_image_shape=min_image_shape,
-        timeout=timeout, retries=retries, image_processor=image_processor,
-        upscale_interpolation=upscale_interpolation,
-        downscale_interpolation=downscale_interpolation,
-        feature_extractor=feature_extractor
-    )
-    shard_len = len(dataset) // num_workers
-    print(f"Local Shard lengths: {shard_len}")
-    with multiprocessing.Pool(num_workers) as pool:
-        iteration = 0
-        while True:
-            # Repeat forever
-            shards = [dataset[i*shard_len:(i+1)*shard_len]
-                      for i in range(num_workers)]
-            print(f"mapping {len(shards)} shards")
-            pool.map(map_batch_fn, shards)
-            iteration += 1
-            print(f"Shuffling dataset with seed {iteration}")
-            dataset = dataset.shuffle(seed=iteration)
-            # Clear the error queue
-            # while not error_queue.empty():
-            #     error_queue.get_nowait()
-
-
-class ImageBatchIterator:
-    def __init__(
-        self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
-        min_image_shape=(128, 128),
-        num_workers: int = 8, num_threads=256, timeout=15, retries=3,
-        image_processor=default_image_processor,
-        upscale_interpolation=cv2.INTER_CUBIC,
-        downscale_interpolation=cv2.INTER_AREA,
-        feature_extractor=default_feature_extractor,
-    ):
-        self.dataset = dataset
-        self.num_workers = num_workers
-        self.batch_size = batch_size
-        loader = partial(
-            parallel_image_loader,
-            num_threads=num_threads,
-            image_shape=image_shape,
-            min_image_shape=min_image_shape,
-            num_workers=num_workers,
-            timeout=timeout, retries=retries,
-            image_processor=image_processor,
-            upscale_interpolation=upscale_interpolation,
-            downscale_interpolation=downscale_interpolation,
-            feature_extractor=feature_extractor
-        )
-        self.thread = threading.Thread(target=loader, args=(dataset,))
-        self.thread.start()
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        def fetcher(_):
-            return data_queue.get()
-        with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
-            batch = list(executor.map(fetcher, range(self.batch_size)))
-        return batch
-
-    def __del__(self):
-        self.thread.join()
-
-    def __len__(self):
-        return len(self.dataset) // self.batch_size
-
-
-def default_collate(batch):
-    urls = [sample["url"] for sample in batch]
-    captions = [sample["caption"] for sample in batch]
-    images = np.stack([sample["image"] for sample in batch], axis=0)
-    return {
-        "url": urls,
-        "caption": captions,
-        "image": images,
-    }
-
-
-def dataMapper(map: Dict[str, Any]):
-    def _map(sample) -> Dict[str, Any]:
-        return {
-            "url": sample[map["url"]],
-            "caption": sample[map["caption"]],
-        }
-    return _map
-
-
-class OnlineStreamingDataLoader():
-    def __init__(
-        self,
-        dataset,
-        batch_size=64,
-        image_shape=(256, 256),
-        min_image_shape=(128, 128),
-        num_workers=16,
-        num_threads=512,
-        default_split="all",
-        pre_map_maker=dataMapper,
-        pre_map_def={
-            "url": "URL",
-            "caption": "TEXT",
-        },
-        global_process_count=1,
-        global_process_index=0,
-        prefetch=1000,
-        collate_fn=default_collate,
-        timeout=15,
-        retries=3,
-        image_processor=default_image_processor,
-        upscale_interpolation=cv2.INTER_CUBIC,
-        downscale_interpolation=cv2.INTER_AREA,
-        feature_extractor=default_feature_extractor,
-    ):
-        if isinstance(dataset, str):
-            dataset_path = dataset
-            print("Loading dataset from path")
-            if "gs://" in dataset:
-                dataset = load_from_disk(dataset_path)
-            else:
-                dataset = load_dataset(dataset_path, split=default_split)
-        elif isinstance(dataset, list):
-            if isinstance(dataset[0], str):
-                print("Loading multiple datasets from paths")
-                dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset(
-                    dataset_path, split=default_split) for dataset_path in dataset]
-            print("Concatenating multiple datasets")
-            dataset = concatenate_datasets(dataset)
-            dataset = dataset.shuffle(seed=0)
-        # dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
-        self.dataset = dataset.shard(
-            num_shards=global_process_count, index=global_process_index)
-        print(f"Dataset length: {len(dataset)}")
-        self.iterator = ImageBatchIterator(
-            self.dataset, image_shape=image_shape,
-            min_image_shape=min_image_shape,
-            num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
-            timeout=timeout, retries=retries, image_processor=image_processor,
-            upscale_interpolation=upscale_interpolation,
-            downscale_interpolation=downscale_interpolation,
-            feature_extractor=feature_extractor
-        )
-        self.batch_size = batch_size
-
-        # Launch a thread to load batches in the background
-        self.batch_queue = queue.Queue(prefetch)
-
-        def batch_loader():
-            for batch in self.iterator:
-                try:
-                    self.batch_queue.put(collate_fn(batch))
-                except Exception as e:
-                    print("Error processing batch", e)
-
-        self.loader_thread = threading.Thread(target=batch_loader)
-        self.loader_thread.start()
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        return self.batch_queue.get()
-        # return self.collate_fn(next(self.iterator))
-
-    def __len__(self):
-        return len(self.dataset)

flaxdiff-0.1.35.6/flaxdiff/models/__init__.py

@@ -1 +0,0 @@
-from .simple_unet import *