flaxdiff 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1 @@
1
+ from .online_loader import OnlineStreamingDataLoader
@@ -0,0 +1,208 @@
1
+ import multiprocessing
2
+ import threading
3
+ from multiprocessing import Queue
4
+ # from arrayqueues.shared_arrays import ArrayQueue
5
+ # from faster_fifo import Queue
6
+ import time
7
+ import albumentations as A
8
+ import queue
9
+ import cv2
10
+ from functools import partial
11
+ from typing import Any, Dict, List, Tuple
12
+
13
+ import numpy as np
14
+ from functools import partial
15
+
16
+ from datasets import load_dataset, concatenate_datasets, Dataset
17
+ from datasets.utils.file_utils import get_datasets_user_agent
18
+ from concurrent.futures import ThreadPoolExecutor
19
+ import io
20
+ import urllib
21
+
22
+ import PIL.Image
23
+ import cv2
24
+
25
# User-agent string from Hugging Face `datasets`, sent with every image request.
USER_AGENT = get_datasets_user_agent()

# Bounded (16 * 2000 items each) queues shared with the loader's process pool:
# workers put processed samples on `data_queue` and failure records on
# `error_queue`. NOTE(review): workers reach these objects by inheriting this
# module's globals, which presumably requires the "fork" start method; under
# "spawn" each worker would re-import the module and get its own unconnected
# queues — confirm on the target platforms.
data_queue, error_queue = Queue(16 * 2000), Queue(16 * 2000)
29
+
30
+
31
def fetch_single_image(image_url, timeout=None, retries=0):
    """Download and open a single image, retrying on failure.

    Args:
        image_url: URL of the image to download.
        timeout: Per-attempt timeout in seconds passed to ``urlopen``;
            ``None`` means no explicit timeout.
        retries: Number of additional attempts after the first one fails.
            A negative value means no attempt is made at all.

    Returns:
        A ``PIL.Image.Image`` on success, or ``None`` if every attempt failed.
    """
    # `import urllib` alone does not guarantee that the `urllib.request`
    # submodule is loaded, so import it explicitly.
    import urllib.request

    for _ in range(retries + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout=timeout) as req:
                # Read the full body into memory so the image stays usable
                # after the connection is closed.
                return PIL.Image.open(io.BytesIO(req.read()))
        except Exception:
            # Deliberately best-effort: any failure (network, HTTP status,
            # unrecognised image header) just moves on to the next attempt.
            continue
    # All attempts failed (or none were made because retries < 0).
    return None
45
+
46
def map_sample(
    url, caption,
    image_shape=(256, 256),
    upscale_interpolation=cv2.INTER_LANCZOS4,
    downscale_interpolation=cv2.INTER_AREA,
    timeout=15,
    retries=3,
):
    """Fetch, filter, resize and pad one image, then publish it.

    On success a dict with keys ``url``, ``caption``, ``image``,
    ``original_height`` and ``original_width`` is put on the module-level
    ``data_queue``. Samples are skipped silently when:
    - the download fails;
    - the short side is below ``min(image_shape)``;
    - the aspect ratio exceeds 2;
    - the pixel standard deviation is below 1e-4.
    Any exception raised while processing is recorded on ``error_queue``
    (best effort; see below).

    Args:
        url: Image URL.
        caption: Caption for the image, passed through unchanged.
        image_shape: Target ``(height, width)``. The long side is resized to
            ``max(image_shape)``, then the image is padded to at least this size.
        upscale_interpolation: cv2 interpolation flag used when enlarging.
        downscale_interpolation: cv2 interpolation flag used when shrinking.
        timeout: Per-attempt download timeout in seconds.
        retries: Extra download attempts after the first failure.
    """
    try:
        image = fetch_single_image(url, timeout=timeout, retries=retries)
        if image is None:
            return

        # np.array on a PIL image keeps PIL's channel order (RGB for mode
        # "RGB"). Converting first also turns grayscale/palette/RGBA images
        # into 3-channel RGB, so they no longer fail later. No BGR->RGB swap
        # is applied: on RGB data it would produce BGR.
        image = np.array(image.convert("RGB"))
        original_height, original_width = image.shape[:2]
        # Skip images whose short side is smaller than the target's short side.
        if min(original_height, original_width) < min(image_shape):
            return
        # Skip extreme aspect ratios (long side more than twice the short side).
        if max(original_height, original_width) / min(original_height, original_width) > 2:
            return
        # Skip (near-)constant images.
        if np.std(image) < 1e-4:
            return
        downscale = max(original_width, original_height) > max(image_shape)
        interpolation = downscale_interpolation if downscale else upscale_interpolation
        image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
        # Pad with white up to at least image_shape.
        image = A.pad(
            image,
            min_height=image_shape[0],
            min_width=image_shape[1],
            border_mode=cv2.BORDER_CONSTANT,
            value=[255, 255, 255],
        )
        data_queue.put({
            "url": url,
            "caption": caption,
            "image": image,
            "original_height": original_height,
            "original_width": original_width,
        })
    except Exception as e:
        try:
            error_queue.put_nowait({
                "url": url,
                "caption": caption,
                "error": str(e)
            })
        except queue.Full:
            # Nothing in this file reads error_queue. Once it is full, drop
            # further error records rather than blocking this worker forever.
            pass

def map_batch(batch, num_threads=256, image_shape=(256, 256), timeout=None, retries=0):
    """Process one column-oriented batch of samples with a thread pool.

    ``batch`` must provide parallel ``"url"`` and ``"caption"`` sequences,
    e.g. the dict returned by slicing a ``datasets.Dataset``. Nothing is
    returned; ``map_sample`` publishes results on ``data_queue``.

    Args:
        batch: Mapping with ``"url"`` and ``"caption"`` sequences.
        num_threads: Maximum number of concurrent download threads.
        image_shape: Target ``(height, width)`` forwarded to ``map_sample``.
        timeout: Per-attempt download timeout forwarded to ``map_sample``.
            ``None`` (the default) keeps ``map_sample``'s own default.
        retries: Extra download attempts forwarded to ``map_sample``.
            ``0`` (the default) keeps ``map_sample``'s own default.
    """
    sample_kwargs = {"image_shape": image_shape}
    if timeout is not None:
        sample_kwargs["timeout"] = timeout
    if retries:
        sample_kwargs["retries"] = retries
    # Executor.map only accepts `timeout`/`chunksize` keywords, so per-sample
    # options must be bound onto the callable instead of passed to map().
    process_sample = partial(map_sample, **sample_kwargs)
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        # Leaving the `with` waits for every submitted sample to finish.
        # map_sample handles its own errors, so results are not consumed.
        executor.map(process_sample, batch["url"], batch["caption"])
96
+
97
def parallel_image_loader(dataset: Dataset, num_workers: int = 8, image_shape=(256, 256), num_threads=256):
    """Feed ``data_queue`` forever by processing reshuffled dataset shards.

    Each pass does three things:
    - reshuffles ``dataset``, seeded with the pass index so the sequence of
      orders is reproducible;
    - splits it into ``num_workers`` contiguous, near-equal shards that
      together cover every sample;
    - runs ``map_batch`` on each shard in a process pool.
    This function does not return under normal operation.

    Args:
        dataset: Dataset whose slices provide ``"url"`` and ``"caption"``
            columns (as required by ``map_batch``).
        num_workers: Number of pool processes, and of shards per pass.
        image_shape: Target ``(height, width)`` forwarded to ``map_batch``.
        num_threads: Download threads per worker, forwarded to ``map_batch``.

    Raises:
        ValueError: If ``num_workers < 1`` or ``dataset`` is empty. Either
            case could never produce any samples.
    """
    if num_workers < 1:
        raise ValueError(f"num_workers must be >= 1, got {num_workers}")
    total = len(dataset)
    if total == 0:
        raise ValueError("dataset is empty; nothing to load")

    map_batch_fn = partial(map_batch, num_threads=num_threads, image_shape=image_shape)
    shard_len = total // num_workers
    print(f"Local Shard lengths: {shard_len}")
    # Boundaries i * total // num_workers spread the remainder across shards.
    # No sample is dropped, and every shard is non-empty when total >= num_workers.
    bounds = [i * total // num_workers for i in range(num_workers + 1)]
    with multiprocessing.Pool(num_workers) as pool:
        iteration = 0
        while True:
            # Repeat forever
            dataset = dataset.shuffle(seed=iteration)
            shards = [dataset[start:stop] for start, stop in zip(bounds, bounds[1:])]
            pool.map(map_batch_fn, shards)
            iteration += 1
109
+
110
class ImageBatchIterator:
    """Infinite iterator over lists of processed sample dicts.

    On construction it starts a background daemon thread that runs
    ``parallel_image_loader`` on ``dataset``; that loader fills the
    module-level ``data_queue``. Each ``next()`` call blocks until
    ``batch_size`` samples have been taken from the queue, then returns them
    as a list.
    """

    def __init__(self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256), num_workers: int = 8, num_threads=256):
        self.dataset = dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        loader = partial(parallel_image_loader, num_threads=num_threads, image_shape=image_shape, num_workers=num_workers)
        # `args` must be a tuple: `(dataset)` is just `dataset`, which Thread
        # would unpack row by row into positional arguments.
        # The loader loops forever, so the thread is a daemon: it must not
        # keep the interpreter alive. For the same reason it is never joined,
        # and there is no __del__ that joins it.
        self.thread = threading.Thread(target=loader, args=(dataset,), daemon=True)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        # Every get hits the same queue, so a plain loop collects the batch
        # without building a new thread pool on each call.
        return [data_queue.get() for _ in range(self.batch_size)]

    def __len__(self):
        """Number of full batches in one pass over ``dataset``."""
        return len(self.dataset) // self.batch_size
134
+
135
def default_collate(batch):
    """Merge a list of sample dicts into a single batch dict.

    Returns ``{"url": [...], "caption": [...], "image": ndarray}``. The image
    array stacks each sample's ``"image"`` along a new leading axis. Any
    other per-sample keys are dropped.
    """
    return {
        "url": [item["url"] for item in batch],
        "caption": [item["caption"] for item in batch],
        "image": np.stack([item["image"] for item in batch], axis=0),
    }
144
+
145
def dataMapper(map: Dict[str, Any]):
    """Build a per-sample function that extracts the URL and caption fields.

    ``map`` names the source column for each output field and must contain
    the keys ``"url"`` and ``"caption"``. The returned function maps a
    sample to ``{"url": sample[map["url"]], "caption": sample[map["caption"]]}``.
    The lookups into ``map`` happen on every call.
    """
    output_fields = ("url", "caption")

    def _map(sample) -> Dict[str, Any]:
        return {field: sample[map[field]] for field in output_fields}

    return _map
152
+
153
class OnlineStreamingDataLoader():
    """Background-prefetching loader yielding collated (url, caption, image) batches.

    Construction proceeds in four steps:
    1. Load the dataset from ``dataset``, which may be a path, a list of paths,
       or a list of datasets; lists are concatenated.
    2. Apply ``pre_map_maker(pre_map_def)`` to every sample. The default
       ``dataMapper`` copies the ``URL``/``TEXT`` columns into ``url``/``caption``.
    3. Keep shard ``global_process_index`` of ``global_process_count``.
    4. Start an ``ImageBatchIterator`` plus a daemon thread that moves its
       batches into a queue holding at most ``prefetch`` batches.

    Iterating yields ``collate_fn(batch)`` for each prefetched batch,
    indefinitely.
    """

    def __init__(
        self,
        dataset,
        batch_size=64,
        image_shape=(256, 256),
        num_workers=16,
        num_threads=512,
        default_split="all",
        pre_map_maker=dataMapper,
        pre_map_def=None,
        global_process_count=1,
        global_process_index=0,
        prefetch=1000,
        collate_fn=default_collate,
    ):
        # None instead of a shared mutable dict as the default.
        if pre_map_def is None:
            pre_map_def = {
                "url": "URL",
                "caption": "TEXT",
            }
        if isinstance(dataset, str):
            dataset_path = dataset
            print("Loading dataset from path")
            dataset = load_dataset(dataset_path, split=default_split)
        elif isinstance(dataset, list):
            if isinstance(dataset[0], str):
                print("Loading multiple datasets from paths")
                dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
            # Both list branches end with a list of datasets, which must be
            # merged into one before `.map` / `.shard` below.
            print("Concatenating multiple datasets")
            dataset = concatenate_datasets(dataset)
        dataset = dataset.map(pre_map_maker(pre_map_def))
        self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
        # Length of the full (pre-shard) dataset.
        print(f"Dataset length: {len(dataset)}")
        self.batch_size = batch_size  # read by __len__
        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
        self.collate_fn = collate_fn

        # Launch a thread to load batches in the background.
        self.batch_queue = queue.Queue(prefetch)

        def batch_loader():
            for batch in self.iterator:
                self.batch_queue.put(batch)

        # The iterator is infinite, so this thread never finishes. Make it a
        # daemon so it does not keep the interpreter alive.
        self.loader_thread = threading.Thread(target=batch_loader, daemon=True)
        self.loader_thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        # Blocks until the background thread has a batch ready.
        return self.collate_fn(self.batch_queue.get())

    def __len__(self):
        """Number of full batches in this process's shard."""
        return len(self.dataset) // self.batch_size
208
+
@@ -11,15 +11,15 @@ All credits for the model go to the developers of Stable Diffusion VAE and all c
11
11
  """
12
12
 
13
13
  class StableDiffusionVAE(AutoEncoder):
14
- def __init__(self, modelname = "CompVis/stable-diffusion-v1-4"):
14
+ def __init__(self, modelname = "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16):
15
15
 
16
16
  from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
17
17
  from diffusers import FlaxStableDiffusionPipeline
18
18
 
19
19
  pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
20
20
  modelname,
21
- revision="bf16",
22
- dtype=jnp.bfloat16,
21
+ revision=revision,
22
+ dtype=dtype,
23
23
  )
24
24
 
25
25
  vae = pipeline.vae
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.12
3
+ Version: 0.1.14
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -1,5 +1,7 @@
1
1
  flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
3
+ flaxdiff/data/__init__.py,sha256=PM3PkHihyohT5SHVYKc8vQ4IeVfGPpCktkSVwvqMjQ4,52
4
+ flaxdiff/data/online_loader.py,sha256=IjCVdeq18lF71eLX8RplJLezjKASO1cSFVe2GSYkLQ8,7283
3
5
  flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
4
6
  flaxdiff/models/attention.py,sha256=ZbDGIb5Q6FRqJ6qRY660cqw4WvF9IwCnhEuYdTpLPdM,13023
5
7
  flaxdiff/models/common.py,sha256=fd-Fl0VCNEBjijHNwGBqYL5VvXe9u0347h25czNTmRw,10780
@@ -8,7 +10,7 @@ flaxdiff/models/simple_unet.py,sha256=h1o9mQlLJy7Ec8Pz_O5miRbAyUaM5UNhSs-oXzpQvZ
8
10
  flaxdiff/models/simple_vit.py,sha256=xD23i1b7WEvoH4tUMsLyCe9ebDcv-PpaV0Nso38Jlb8,3887
9
11
  flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
10
12
  flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
11
- flaxdiff/models/autoencoder/diffusers.py,sha256=l4teVksXd9XCCQWcVn9eB820xJyLT8hpg1CXQ_aHZ6M,3611
13
+ flaxdiff/models/autoencoder/diffusers.py,sha256=JHeFLCxiHhu-QHwhKiCuKsQJn4AZumquiuxgZkiYGQ0,3643
12
14
  flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
13
15
  flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
14
16
  flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
@@ -32,7 +34,7 @@ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,
32
34
  flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
33
35
  flaxdiff/trainer/diffusion_trainer.py,sha256=wKkg63DWZjx2MoM3VQNCDIr40rWN8fUGxH9jWWxfZao,9373
34
36
  flaxdiff/trainer/simple_trainer.py,sha256=Z77zRS5viJpd2Mpl6sonJk5WcnEWi2Cd4gl4u5tIX2M,18206
35
- flaxdiff-0.1.12.dist-info/METADATA,sha256=s3rIj9jqh1Xr1NABqOJZw9XwyHdaLd01c_jKpzEMErQ,22083
36
- flaxdiff-0.1.12.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
37
- flaxdiff-0.1.12.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
38
- flaxdiff-0.1.12.dist-info/RECORD,,
37
+ flaxdiff-0.1.14.dist-info/METADATA,sha256=rcwF7cCFfgPLHn5gD7GZ_KvtG6EmiAIiel7xt8HylAo,22083
38
+ flaxdiff-0.1.14.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
39
+ flaxdiff-0.1.14.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
40
+ flaxdiff-0.1.14.dist-info/RECORD,,