flaxdiff 0.1.35.6__tar.gz → 0.1.36.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/PKG-INFO +14 -5
  2. flaxdiff-0.1.36.1/flaxdiff/utils.py +192 -0
  3. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/flaxdiff.egg-info/PKG-INFO +14 -5
  4. flaxdiff-0.1.36.1/flaxdiff.egg-info/SOURCES.txt +9 -0
  5. flaxdiff-0.1.36.1/flaxdiff.egg-info/requires.txt +14 -0
  6. flaxdiff-0.1.36.1/pyproject.toml +32 -0
  7. flaxdiff-0.1.35.6/flaxdiff/data/__init__.py +0 -1
  8. flaxdiff-0.1.35.6/flaxdiff/data/online_loader.py +0 -336
  9. flaxdiff-0.1.35.6/flaxdiff/models/__init__.py +0 -1
  10. flaxdiff-0.1.35.6/flaxdiff/models/attention.py +0 -368
  11. flaxdiff-0.1.35.6/flaxdiff/models/autoencoder/__init__.py +0 -2
  12. flaxdiff-0.1.35.6/flaxdiff/models/autoencoder/autoencoder.py +0 -19
  13. flaxdiff-0.1.35.6/flaxdiff/models/autoencoder/diffusers.py +0 -91
  14. flaxdiff-0.1.35.6/flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
  15. flaxdiff-0.1.35.6/flaxdiff/models/common.py +0 -346
  16. flaxdiff-0.1.35.6/flaxdiff/models/favor_fastattn.py +0 -723
  17. flaxdiff-0.1.35.6/flaxdiff/models/simple_unet.py +0 -233
  18. flaxdiff-0.1.35.6/flaxdiff/models/simple_vit.py +0 -180
  19. flaxdiff-0.1.35.6/flaxdiff/predictors/__init__.py +0 -96
  20. flaxdiff-0.1.35.6/flaxdiff/samplers/__init__.py +0 -7
  21. flaxdiff-0.1.35.6/flaxdiff/samplers/common.py +0 -113
  22. flaxdiff-0.1.35.6/flaxdiff/samplers/ddim.py +0 -10
  23. flaxdiff-0.1.35.6/flaxdiff/samplers/ddpm.py +0 -43
  24. flaxdiff-0.1.35.6/flaxdiff/samplers/euler.py +0 -59
  25. flaxdiff-0.1.35.6/flaxdiff/samplers/heun_sampler.py +0 -28
  26. flaxdiff-0.1.35.6/flaxdiff/samplers/multistep_dpm.py +0 -60
  27. flaxdiff-0.1.35.6/flaxdiff/samplers/rk4_sampler.py +0 -34
  28. flaxdiff-0.1.35.6/flaxdiff/schedulers/__init__.py +0 -6
  29. flaxdiff-0.1.35.6/flaxdiff/schedulers/common.py +0 -98
  30. flaxdiff-0.1.35.6/flaxdiff/schedulers/continuous.py +0 -12
  31. flaxdiff-0.1.35.6/flaxdiff/schedulers/cosine.py +0 -40
  32. flaxdiff-0.1.35.6/flaxdiff/schedulers/discrete.py +0 -74
  33. flaxdiff-0.1.35.6/flaxdiff/schedulers/exp.py +0 -13
  34. flaxdiff-0.1.35.6/flaxdiff/schedulers/karras.py +0 -69
  35. flaxdiff-0.1.35.6/flaxdiff/schedulers/linear.py +0 -14
  36. flaxdiff-0.1.35.6/flaxdiff/schedulers/sqrt.py +0 -10
  37. flaxdiff-0.1.35.6/flaxdiff/trainer/__init__.py +0 -2
  38. flaxdiff-0.1.35.6/flaxdiff/trainer/autoencoder_trainer.py +0 -182
  39. flaxdiff-0.1.35.6/flaxdiff/trainer/diffusion_trainer.py +0 -234
  40. flaxdiff-0.1.35.6/flaxdiff/trainer/simple_trainer.py +0 -442
  41. flaxdiff-0.1.35.6/flaxdiff/utils.py +0 -89
  42. flaxdiff-0.1.35.6/flaxdiff.egg-info/SOURCES.txt +0 -43
  43. flaxdiff-0.1.35.6/flaxdiff.egg-info/requires.txt +0 -5
  44. flaxdiff-0.1.35.6/setup.py +0 -21
  45. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/README.md +0 -0
  46. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/flaxdiff/__init__.py +0 -0
  47. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/flaxdiff.egg-info/dependency_links.txt +0 -0
  48. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/flaxdiff.egg-info/top_level.txt +0 -0
  49. {flaxdiff-0.1.35.6 → flaxdiff-0.1.36.1}/setup.cfg +0 -0
@@ -1,15 +1,24 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.35.6
3
+ Version: 0.1.36.1
4
4
  Summary: A versatile and easy to understand Diffusion library
5
- Author: Ashish Kumar Singh
6
- Author-email: ashishkmr472@gmail.com
5
+ Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
+ License-Expression: MIT
7
7
  Description-Content-Type: text/markdown
8
8
  Requires-Dist: flax>=0.8.4
9
- Requires-Dist: optax>=0.2.2
10
9
  Requires-Dist: jax>=0.4.28
10
+ Requires-Dist: optax>=0.2.2
11
11
  Requires-Dist: orbax
12
+ Requires-Dist: numpy
12
13
  Requires-Dist: clu
14
+ Requires-Dist: einops
15
+ Requires-Dist: tqdm
16
+ Requires-Dist: grain
17
+ Requires-Dist: termcolor
18
+ Requires-Dist: augmax
19
+ Requires-Dist: albumentations
20
+ Requires-Dist: rich
21
+ Requires-Dist: python-dotenv
13
22
 
14
23
  # ![](images/logo.jpeg "FlaxDiff")
15
24
 
@@ -0,0 +1,192 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import flax.struct as struct
4
+ import flax.linen as nn
5
+ from typing import Any, Callable
6
+ from dataclasses import dataclass
7
+ from functools import partial
8
+ import numpy as np
9
+ from jax.sharding import Mesh, PartitionSpec as P
10
+ from abc import ABC, abstractmethod
11
+
12
class MarkovState(struct.PyTreeNode):
    """Empty base class for state objects; subclasses declare the fields.

    As a ``flax.struct.PyTreeNode`` subclass, instances are frozen dataclasses
    that JAX treats as pytrees.
    """
14
+
15
class RandomMarkovState(MarkovState):
    """State holding a JAX PRNG key that is advanced by splitting.

    Attributes:
        rng: the current PRNG key.
    """

    rng: jax.random.PRNGKey

    def get_random_key(self):
        """Split the key and return ``(new_state, subkey)``.

        ``new_state`` carries the advanced key and ``subkey`` is meant for a
        single use by the caller.

        ``self.replace`` is used rather than constructing
        ``RandomMarkovState(rng)``. That way a subclass keeps its own type and
        any additional fields instead of collapsing to the base class.
        """
        rng, subkey = jax.random.split(self.rng)
        return self.replace(rng=rng), subkey
21
+
22
def clip_images(images, clip_min=-1, clip_max=1):
    """Clamp every element of ``images`` into ``[clip_min, clip_max]``.

    The defaults match images normalized to the ``[-1, 1]`` range.
    """
    # Raise everything below the floor first, then cap at the ceiling.
    floored = jnp.maximum(images, clip_min)
    return jnp.minimum(floored, clip_max)
24
+
25
+ def _build_global_shape_and_sharding(
26
+ local_shape: tuple[int, ...], global_mesh: Mesh
27
+ ) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
28
+ sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
29
+ global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
30
+ return global_shape, sharding
31
+
32
+
33
def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
    """Turn a process-local host array into this process's shards of a global array.

    The leading axis of ``array`` is split evenly across the mesh's local
    devices. Each chunk is placed on one device, and the buffers are
    assembled into a ``jax.Array``. The global shape and sharding come from
    ``_build_global_shape_and_sharding``.

    Args:
        path: pytree key path of the leaf. It is unused; the parameter exists
            so the function fits ``jax.tree_util.tree_map_with_path``.
        array: local data. Its leading dimension must be divisible by the
            number of local devices.
        global_mesh: mesh whose local devices receive the chunks.

    Raises:
        ValueError: if the leading dimension cannot be split evenly.
    """
    global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
    devices = global_mesh.local_devices
    try:
        chunks = np.split(array, len(devices), axis=0)
    except ValueError as err:
        raise ValueError(
            f"Unable to put to devices shape {array.shape} with "
            f"local device count {len(devices)} "
        ) from err
    # One chunk per local device, in the same order as `devices`.
    buffers = jax.device_put(chunks, devices)
    return jax.make_array_from_single_device_arrays(global_shape, sharding, buffers)
45
+
46
def convert_to_global_tree(global_mesh, pytree):
    """Apply ``form_global_array`` to every leaf of ``pytree`` on ``global_mesh``."""
    def _to_global(path, leaf):
        return form_global_array(path, leaf, global_mesh=global_mesh)

    return jax.tree_util.tree_map_with_path(_to_global, pytree)
48
+
49
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last axis.

    From "Root Mean Square Layer Normalization" (https://arxiv.org/abs/1910.07467);
    adapted from ``flax.linen.LayerNorm``.

    Computes ``y = x * scale * rsqrt(mean(x ** 2, axis=-1) + epsilon)``. There is
    no mean-centring and no bias.

    Attributes:
        epsilon: added to the mean square before the reciprocal square root.
        dtype: dtype of the returned array.
        param_dtype: dtype of the learned ``scale`` parameter.
        use_scale: whether to learn a per-feature ``scale``.
        scale_init: initializer for ``scale``.
    """

    epsilon: float = 1e-6
    dtype: Any = jnp.float32
    param_dtype: Any = jnp.float32
    use_scale: bool = True
    scale_init: Any = jax.nn.initializers.ones

    @nn.compact
    def __call__(self, x):
        # Statistics are reduced over the last axis, and the scale is learned
        # along that same axis.
        last_axis = (-1,)
        mean_square = self._compute_rms_sq(x, last_axis)
        return self._normalize(
            self,
            x,
            mean_square,
            last_axis,
            last_axis,
            self.dtype,
            self.param_dtype,
            self.epsilon,
            self.use_scale,
            self.scale_init,
        )

    def _compute_rms_sq(self, x, axes):
        """Return ``mean(x ** 2)`` over ``axes``, computed in at least float32."""
        compute_dtype = jnp.promote_types(jnp.float32, jnp.result_type(x))
        upcast = jnp.asarray(x, compute_dtype)
        return jnp.mean(jax.lax.square(upcast), axes)

    def _normalize(
        self,
        mdl,
        x,
        rms_sq,
        reduction_axes,
        feature_axes,
        dtype,
        param_dtype,
        epsilon,
        use_scale,
        scale_init,
    ):
        """Scale ``x`` by ``rsqrt(rms_sq + epsilon)``, optionally times a learned ``scale``.

        ``mdl`` is the module used to register the ``scale`` parameter.
        """
        reduced = nn.normalization._canonicalize_axes(x.ndim, reduction_axes)
        features = nn.normalization._canonicalize_axes(x.ndim, feature_axes)
        # Put size-1 dims back where the statistic was reduced so it broadcasts against x.
        stats_shape = [1 if axis in reduced else size for axis, size in enumerate(x.shape)]
        # The scale is stored compactly, then broadcast over the non-feature dims.
        feature_shape = [size if axis in features else 1 for axis, size in enumerate(x.shape)]
        param_shape = [x.shape[axis] for axis in features]

        inv_rms = jax.lax.rsqrt(rms_sq.reshape(stats_shape) + epsilon)
        if use_scale:
            scale = mdl.param("scale", scale_init, param_shape, param_dtype)
            inv_rms = inv_rms * scale.reshape(feature_shape)
        return jnp.asarray(inv_rms * x, dtype)
119
+
120
@dataclass
class ConditioningEncoder(ABC):
    """Tokenize raw conditioning inputs and encode them with ``model``.

    Although declared as an ``ABC``, it defines no abstract methods, so it can
    be instantiated directly.

    Attributes:
        model: called with ``input_ids`` and ``attention_mask`` keyword
            arguments; its result must expose ``last_hidden_state``.
        tokenizer: called on the raw data; must expose ``model_max_length``.
    """

    model: nn.Module
    tokenizer: Callable

    def __call__(self, data):
        """Tokenize ``data`` and return the model's last hidden state."""
        return self.encode_from_tokens(self.tokenize(data))

    def encode_from_tokens(self, tokens):
        """Run the model on pre-tokenized input; return ``last_hidden_state``."""
        model_out = self.model(
            input_ids=tokens['input_ids'],
            attention_mask=tokens['attention_mask'],
        )
        return model_out.last_hidden_state

    def tokenize(self, data):
        """Pad/truncate ``data`` to ``model_max_length`` and return NumPy tensors."""
        return self.tokenizer(
            data,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="np",
        )
140
+
141
@dataclass
class TextEncoder(ConditioningEncoder):
    """Conditioning encoder for text prompts.

    Tokenization and encoding are inherited unchanged from
    ``ConditioningEncoder``. Calling the instance returns the model's
    ``last_hidden_state`` for the padded and truncated prompts.
    """
158
+
159
class AutoTextTokenizer:
    """Wrap a Hugging Face ``AutoTokenizer`` that pads/truncates to the model maximum.

    Args:
        tensor_type: forwarded as ``return_tensors`` (e.g. ``"pt"`` or ``"np"``).
        modelname: name or path given to ``AutoTokenizer.from_pretrained``.
    """

    def __init__(self, tensor_type="pt", modelname="openai/clip-vit-large-patch14"):
        # Imported here rather than at module level, so `transformers` is only
        # needed once this class is instantiated.
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(modelname)
        self.tensor_type = tensor_type

    def __call__(self, inputs):
        """Tokenize ``inputs``.

        Returns a dict with ``input_ids``, ``attention_mask`` and the raw
        inputs under ``caption``.
        """
        encoded = self.tokenizer(
            inputs,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors=self.tensor_type,
        )
        result = {key: encoded[key] for key in ("input_ids", "attention_mask")}
        result["caption"] = inputs
        return result

    def __repr__(self):
        return f"{type(self).__name__}()"
178
+
179
def defaultTextEncodeModel(backend="jax"):
    """Build the default text conditioning encoder (CLIP ViT-L/14 text model).

    Args:
        backend: ``"jax"`` loads ``FlaxCLIPTextModel`` with
            ``dtype=jnp.bfloat16`` (the Flax model's computation dtype). Any
            other value loads the PyTorch ``CLIPTextModel``.

    Returns:
        TextEncoder: the model paired with the matching ``AutoTokenizer``.
    """
    # Each branch imports only the model class for the requested backend, so
    # the other framework's transformers classes are never touched.
    from transformers import AutoTokenizer

    modelname = "openai/clip-vit-large-patch14"
    if backend == "jax":
        from transformers import FlaxCLIPTextModel
        model = FlaxCLIPTextModel.from_pretrained(modelname, dtype=jnp.bfloat16)
    else:
        from transformers import CLIPTextModel
        model = CLIPTextModel.from_pretrained(modelname)
    # A tokenizer produces integer token ids and has no dtype, so none is passed.
    tokenizer = AutoTokenizer.from_pretrained(modelname)
    return TextEncoder(model, tokenizer)
@@ -1,15 +1,24 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.35.6
3
+ Version: 0.1.36.1
4
4
  Summary: A versatile and easy to understand Diffusion library
5
- Author: Ashish Kumar Singh
6
- Author-email: ashishkmr472@gmail.com
5
+ Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
+ License-Expression: MIT
7
7
  Description-Content-Type: text/markdown
8
8
  Requires-Dist: flax>=0.8.4
9
- Requires-Dist: optax>=0.2.2
10
9
  Requires-Dist: jax>=0.4.28
10
+ Requires-Dist: optax>=0.2.2
11
11
  Requires-Dist: orbax
12
+ Requires-Dist: numpy
12
13
  Requires-Dist: clu
14
+ Requires-Dist: einops
15
+ Requires-Dist: tqdm
16
+ Requires-Dist: grain
17
+ Requires-Dist: termcolor
18
+ Requires-Dist: augmax
19
+ Requires-Dist: albumentations
20
+ Requires-Dist: rich
21
+ Requires-Dist: python-dotenv
13
22
 
14
23
  # ![](images/logo.jpeg "FlaxDiff")
15
24
 
@@ -0,0 +1,9 @@
1
+ README.md
2
+ pyproject.toml
3
+ flaxdiff/__init__.py
4
+ flaxdiff/utils.py
5
+ flaxdiff.egg-info/PKG-INFO
6
+ flaxdiff.egg-info/SOURCES.txt
7
+ flaxdiff.egg-info/dependency_links.txt
8
+ flaxdiff.egg-info/requires.txt
9
+ flaxdiff.egg-info/top_level.txt
@@ -0,0 +1,14 @@
1
+ flax>=0.8.4
2
+ jax>=0.4.28
3
+ optax>=0.2.2
4
+ orbax
5
+ numpy
6
+ clu
7
+ einops
8
+ tqdm
9
+ grain
10
+ termcolor
11
+ augmax
12
+ albumentations
13
+ rich
14
+ python-dotenv
@@ -0,0 +1,32 @@
1
[build-system]
# `license = "MIT"` below is a PEP 639 license expression; setuptools only
# accepts that form from version 77.0 onward.
requires = ["setuptools>=77.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "flaxdiff"
version = "0.1.36.1"
description = "A versatile and easy to understand Diffusion library"
readme = "README.md"
authors = [
    { name="Ashish Kumar Singh", email="ashishkmr472@gmail.com" }
]
dependencies = [
    "flax>=0.8.4",
    "jax>=0.4.28",
    "optax>=0.2.2",
    "orbax",
    "numpy",
    "clu",
    "einops",
    "tqdm",
    "grain",
    "termcolor",
    "augmax",
    "albumentations",
    "rich",
    "python-dotenv",
]
license = "MIT"

# An explicit `packages = ["flaxdiff"]` list ships only the top-level package
# and silently drops every subpackage (e.g. flaxdiff.models,
# flaxdiff.samplers). Discover the package together with all of its
# subpackages instead.
[tool.setuptools.packages.find]
include = ["flaxdiff", "flaxdiff.*"]
@@ -1 +0,0 @@
1
- from .online_loader import OnlineStreamingDataLoader
@@ -1,336 +0,0 @@
1
- import multiprocessing
2
- import threading
3
- from multiprocessing import Queue
4
- # from arrayqueues.shared_arrays import ArrayQueue
5
- # from faster_fifo import Queue
6
- import time
7
- import albumentations as A
8
- import queue
9
- import cv2
10
- from functools import partial
11
- from typing import Any, Dict, List, Tuple
12
-
13
- import numpy as np
14
- from functools import partial
15
-
16
- from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk
17
- from datasets.utils.file_utils import get_datasets_user_agent
18
- from concurrent.futures import ThreadPoolExecutor
19
- import io
20
- import urllib
21
-
22
- import PIL.Image
23
- import cv2
24
- import traceback
25
-
26
# User-Agent header sent with every image download (see fetch_single_image).
USER_AGENT = get_datasets_user_agent()

# Multiprocessing queue of processed samples: map_sample puts into it and
# ImageBatchIterator gets from it. Capacity is 16 * 2000 items, so `put`
# blocks once the consumer falls that far behind.
# NOTE(review): Pool worker processes reach this queue only if they inherit
# it (fork start method); under spawn each worker would re-create its own,
# unconnected queue. Confirm the intended start method.
data_queue = Queue(maxsize=16 * 2000)
29
-
30
-
31
def fetch_single_image(image_url, timeout=None, retries=0):
    """Download and open one image, retrying on any failure.

    Args:
        image_url: URL of the image.
        timeout: per-attempt timeout in seconds, passed to ``urlopen``.
        retries: number of extra attempts after the first one fails.
            Negative values mean a single attempt is never made and ``None``
            is returned.

    Returns:
        A ``PIL.Image.Image`` on success, or ``None`` if every attempt failed.
    """
    # `import urllib` alone does not load the `urllib.request` submodule; the
    # old code only worked if some other module happened to import it first.
    import urllib.request

    for _ in range(retries + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                return PIL.Image.open(io.BytesIO(req.read()))
        except Exception:
            # Best-effort download: any failure (network, HTTP status,
            # undecodable bytes) just moves on to the next attempt.
            continue
    return None
45
-
46
-
47
def default_image_processor(
    image, image_shape,
    min_image_shape=(128, 128),
    upscale_interpolation=cv2.INTER_CUBIC,
    downscale_interpolation=cv2.INTER_AREA,
):
    """Validate one image, then resize and pad it to ``image_shape``.

    An image is rejected (``None`` is returned in place of it) when any of
    these holds:
    - its shorter side is below ``min(min_image_shape)``;
    - its long/short side ratio exceeds 2.4;
    - its pixel standard deviation is below 1e-5 (essentially constant).

    Otherwise it is rescaled with ``A.longest_max_size`` to
    ``max(image_shape)``. ``downscale_interpolation`` is used when the longer
    side shrinks and ``upscale_interpolation`` otherwise. The result is then
    padded with white (255) to at least ``image_shape[0]`` x ``image_shape[1]``.

    Returns:
        tuple: ``(image_or_None, original_height, original_width)``.
    """
    pixels = np.array(image)
    height, width = pixels.shape[:2]
    short_side = min(height, width)
    long_side = max(height, width)

    # `or` short-circuits in this order, so the ratio is only computed once
    # the size check has passed.
    rejected = (
        short_side < min(min_image_shape)
        or long_side / short_side > 2.4
        or np.std(pixels) < 1e-5
    )
    if rejected:
        return None, height, width

    target_long_side = max(image_shape)
    if long_side > target_long_side:
        interpolation = downscale_interpolation
    else:
        interpolation = upscale_interpolation

    pixels = A.longest_max_size(pixels, target_long_side, interpolation=interpolation)
    pixels = A.pad(
        pixels,
        min_height=image_shape[0],
        min_width=image_shape[1],
        border_mode=cv2.BORDER_CONSTANT,
        value=[255, 255, 255],
    )
    return pixels, height, width
78
-
79
-
80
def map_sample(
    url,
    caption,
    image_shape=(256, 256),
    min_image_shape=(128, 128),
    timeout=15,
    retries=3,
    upscale_interpolation=cv2.INTER_CUBIC,
    downscale_interpolation=cv2.INTER_AREA,
    image_processor=default_image_processor,
):
    """Download, preprocess and enqueue one ``(url, caption)`` sample.

    On success, a dict with keys ``url``, ``caption``, ``image``,
    ``original_height`` and ``original_width`` is put on ``data_queue``.

    A sample is dropped silently if any of these happens:
    - the download fails;
    - ``image_processor`` rejects it;
    - any exception is raised.

    Always returns ``None``.
    """
    try:
        downloaded = fetch_single_image(url, timeout=timeout, retries=retries)
        if downloaded is not None:
            processed, height, width = image_processor(
                downloaded, image_shape, min_image_shape=min_image_shape,
                upscale_interpolation=upscale_interpolation,
                downscale_interpolation=downscale_interpolation,
            )
            if processed is not None:
                data_queue.put({
                    "url": url,
                    "caption": caption,
                    "image": processed,
                    "original_height": height,
                    "original_width": width,
                })
    except Exception:
        # Best-effort: a failure on one sample is ignored so the rest of the
        # batch keeps going.
        pass
122
-
123
-
124
def default_feature_extractor(sample):
    """Return only the ``url`` and ``caption`` entries of ``sample`` (a row or a column batch)."""
    return {key: sample[key] for key in ("url", "caption")}
129
-
130
-
131
def map_batch(
    batch, num_threads=256, image_shape=(256, 256),
    min_image_shape=(128, 128),
    timeout=15, retries=3, image_processor=default_image_processor,
    upscale_interpolation=cv2.INTER_CUBIC,
    downscale_interpolation=cv2.INTER_AREA,
    feature_extractor=default_feature_extractor,
):
    """Run ``map_sample`` on every sample of a column-oriented batch using threads.

    ``feature_extractor(batch)`` must return ``url`` and ``caption``
    sequences, which are processed pairwise. At most ``num_threads`` samples
    are in flight at once. Results go to ``data_queue`` through
    ``map_sample``; nothing is returned. Errors are printed with a traceback,
    not raised.
    """
    try:
        per_sample = partial(
            map_sample,
            image_shape=image_shape,
            min_image_shape=min_image_shape,
            timeout=timeout,
            retries=retries,
            image_processor=image_processor,
            upscale_interpolation=upscale_interpolation,
            downscale_interpolation=downscale_interpolation,
        )
        with ThreadPoolExecutor(max_workers=num_threads) as pool:
            columns = feature_extractor(batch)
            # Leaving the `with` block waits for every submitted sample to finish.
            pool.map(per_sample, columns["url"], columns["caption"])
    except Exception as e:
        print(f"Error maping batch", e)
        traceback.print_exc()
158
-
159
-
160
def parallel_image_loader(
    dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
    min_image_shape=(128, 128),
    num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
    upscale_interpolation=cv2.INTER_CUBIC,
    downscale_interpolation=cv2.INTER_AREA,
    feature_extractor=default_feature_extractor,
):
    """Keep ``data_queue`` filled forever by mapping dataset shards over a process pool.

    Each pass does the following:
    - cuts the dataset into ``num_workers`` equal contiguous shards;
    - runs ``map_batch`` on them in a ``multiprocessing.Pool``;
    - reshuffles the dataset with seed 1, 2, 3, ... before the next pass.

    This function never returns.

    NOTE(review): the workers push into the module-level ``data_queue``,
    which only reaches this process if the workers inherit it (fork start
    method). Confirm.
    """
    batch_fn = partial(
        map_batch,
        num_threads=num_threads,
        image_shape=image_shape,
        min_image_shape=min_image_shape,
        timeout=timeout,
        retries=retries,
        image_processor=image_processor,
        upscale_interpolation=upscale_interpolation,
        downscale_interpolation=downscale_interpolation,
        feature_extractor=feature_extractor,
    )
    shard_len = len(dataset) // num_workers
    print(f"Local Shard lengths: {shard_len}")
    # Shard boundaries depend only on len(dataset), which shuffling preserves,
    # so they are computed once. The trailing len(dataset) % num_workers rows
    # fall outside every shard; reshuffling makes that a different subset on
    # each pass.
    bounds = [(k * shard_len, (k + 1) * shard_len) for k in range(num_workers)]
    with multiprocessing.Pool(num_workers) as pool:
        seed = 0
        while True:
            shards = [dataset[lo:hi] for lo, hi in bounds]
            print(f"mapping {len(shards)} shards")
            pool.map(batch_fn, shards)
            seed += 1
            print(f"Shuffling dataset with seed {seed}")
            dataset = dataset.shuffle(seed=seed)
192
-
193
-
194
class ImageBatchIterator:
    """Endless iterator that yields lists of ``batch_size`` processed samples.

    Construction starts ``parallel_image_loader`` on ``dataset`` in a
    background thread. That loader loops forever and feeds ``data_queue``.
    Each ``next()`` takes ``batch_size`` sample dicts from that queue, blocking
    until they are available.
    """

    def __init__(
        self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
        min_image_shape=(128, 128),
        num_workers: int = 8, num_threads=256, timeout=15, retries=3,
        image_processor=default_image_processor,
        upscale_interpolation=cv2.INTER_CUBIC,
        downscale_interpolation=cv2.INTER_AREA,
        feature_extractor=default_feature_extractor,
    ):
        self.dataset = dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        loader = partial(
            parallel_image_loader,
            num_threads=num_threads,
            image_shape=image_shape,
            min_image_shape=min_image_shape,
            num_workers=num_workers,
            timeout=timeout, retries=retries,
            image_processor=image_processor,
            upscale_interpolation=upscale_interpolation,
            downscale_interpolation=downscale_interpolation,
            feature_extractor=feature_extractor,
        )
        # parallel_image_loader never returns. A daemon thread lets the
        # interpreter exit; a non-daemon one would keep the process alive forever.
        self.thread = threading.Thread(target=loader, args=(dataset,), daemon=True)
        self.thread.start()
        # There is deliberately no __del__ that joins the thread: joining a
        # thread that never finishes would hang whoever drops the last
        # reference.

    def __iter__(self):
        return self

    def __next__(self):
        # Each get() blocks until the loader has produced another sample.
        return [data_queue.get() for _ in range(self.batch_size)]

    def __len__(self):
        """Number of full batches in one pass over ``dataset``."""
        return len(self.dataset) // self.batch_size
237
-
238
-
239
def default_collate(batch):
    """Merge a list of sample dicts into one batch dict.

    ``url`` and ``caption`` become lists. ``image`` becomes a NumPy array with
    a new leading batch axis, so all images must share one shape.
    """
    return {
        "url": [item["url"] for item in batch],
        "caption": [item["caption"] for item in batch],
        "image": np.stack([item["image"] for item in batch], axis=0),
    }
248
-
249
-
250
def dataMapper(map: Dict[str, Any]):
    """Build a function that renames a sample's columns to ``url`` and ``caption``.

    ``map`` maps each target name (``"url"``, ``"caption"``) to the source
    column. It is read on every call, not captured when the function is built.
    """
    def _map(sample) -> Dict[str, Any]:
        return {field: sample[map[field]] for field in ("url", "caption")}

    return _map
257
-
258
-
259
class OnlineStreamingDataLoader():
    """Endless loader of collated batches built from images downloaded by URL.

    ``dataset`` may be any of the following:
    - a Hugging Face ``Dataset``;
    - a single path or name string;
    - a list whose items are paths and/or ``Dataset`` objects; all of them
      are concatenated.

    Strings containing ``"gs://"`` are opened with ``load_from_disk``; other
    strings go through ``load_dataset(path, split=default_split)``.

    The dataset is then shuffled with seed 0 and sharded to index
    ``global_process_index`` of ``global_process_count``. An
    ``ImageBatchIterator`` downloads and preprocesses this shard. A
    background thread applies ``collate_fn`` to each batch and keeps up to
    ``prefetch`` batches in a queue. Iteration never ends.

    ``pre_map_maker`` and ``pre_map_def`` are accepted for backward
    compatibility but are currently unused; the column-renaming step is
    disabled.
    """

    def __init__(
        self,
        dataset,
        batch_size=64,
        image_shape=(256, 256),
        min_image_shape=(128, 128),
        num_workers=16,
        num_threads=512,
        default_split="all",
        pre_map_maker=dataMapper,
        pre_map_def=None,
        global_process_count=1,
        global_process_index=0,
        prefetch=1000,
        collate_fn=default_collate,
        timeout=15,
        retries=3,
        image_processor=default_image_processor,
        upscale_interpolation=cv2.INTER_CUBIC,
        downscale_interpolation=cv2.INTER_AREA,
        feature_extractor=default_feature_extractor,
    ):
        if isinstance(dataset, str):
            print("Loading dataset from path")
            dataset = self._load_dataset(dataset, default_split)
        elif isinstance(dataset, list):
            if not dataset:
                raise ValueError("`dataset` list must not be empty")
            if any(isinstance(item, str) for item in dataset):
                print("Loading multiple datasets from paths")
                dataset = [
                    self._load_dataset(item, default_split) if isinstance(item, str) else item
                    for item in dataset
                ]
            # Concatenate whether the entries were paths or already-loaded
            # datasets: a plain list has no `.shuffle`/`.shard`.
            print("Concatenating multiple datasets")
            dataset = concatenate_datasets(dataset)
        dataset = dataset.shuffle(seed=0)
        self.dataset = dataset.shard(
            num_shards=global_process_count, index=global_process_index)
        print(f"Dataset length: {len(dataset)}")
        self.iterator = ImageBatchIterator(
            self.dataset, image_shape=image_shape,
            min_image_shape=min_image_shape,
            num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
            timeout=timeout, retries=retries, image_processor=image_processor,
            upscale_interpolation=upscale_interpolation,
            downscale_interpolation=downscale_interpolation,
            feature_extractor=feature_extractor,
        )
        self.batch_size = batch_size

        # Bounded queue: the loader thread blocks once `prefetch` batches are waiting.
        self.batch_queue = queue.Queue(prefetch)

        def batch_loader():
            for batch in self.iterator:
                try:
                    self.batch_queue.put(collate_fn(batch))
                except Exception as e:
                    print("Error processing batch", e)

        # self.iterator never stops, so this thread never finishes. As a daemon
        # it does not prevent the interpreter from exiting.
        self.loader_thread = threading.Thread(target=batch_loader, daemon=True)
        self.loader_thread.start()

    @staticmethod
    def _load_dataset(path, split):
        """Open one dataset: ``load_from_disk`` for ``gs://`` paths, else ``load_dataset``."""
        if "gs://" in path:
            return load_from_disk(path)
        return load_dataset(path, split=split)

    def __iter__(self):
        return self

    def __next__(self):
        # Blocks until the background thread has collated another batch.
        return self.batch_queue.get()

    def __len__(self):
        """Number of samples (not batches) in this process's shard."""
        return len(self.dataset)
@@ -1 +0,0 @@
1
- from .simple_unet import *