flaxdiff 0.1.11__tar.gz → 0.1.13__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/PKG-INFO +1 -1
  2. flaxdiff-0.1.13/flaxdiff/data/__init__.py +0 -0
  3. flaxdiff-0.1.13/flaxdiff/data/online_loader.py +205 -0
  4. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/models/attention.py +12 -12
  5. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/models/autoencoder/diffusers.py +3 -3
  6. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff.egg-info/PKG-INFO +1 -1
  7. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff.egg-info/SOURCES.txt +2 -0
  8. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/setup.py +1 -1
  9. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/README.md +0 -0
  10. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/__init__.py +0 -0
  11. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/models/__init__.py +0 -0
  12. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/models/autoencoder/__init__.py +0 -0
  13. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  14. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  15. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/models/common.py +0 -0
  16. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/models/favor_fastattn.py +0 -0
  17. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/models/simple_unet.py +0 -0
  18. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/models/simple_vit.py +0 -0
  19. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/predictors/__init__.py +0 -0
  20. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/samplers/__init__.py +0 -0
  21. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/samplers/common.py +0 -0
  22. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/samplers/ddim.py +0 -0
  23. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/samplers/ddpm.py +0 -0
  24. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/samplers/euler.py +0 -0
  25. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/samplers/heun_sampler.py +0 -0
  26. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/samplers/multistep_dpm.py +0 -0
  27. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/samplers/rk4_sampler.py +0 -0
  28. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/schedulers/__init__.py +0 -0
  29. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/schedulers/common.py +0 -0
  30. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/schedulers/continuous.py +0 -0
  31. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/schedulers/cosine.py +0 -0
  32. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/schedulers/discrete.py +0 -0
  33. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/schedulers/exp.py +0 -0
  34. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/schedulers/karras.py +0 -0
  35. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/schedulers/linear.py +0 -0
  36. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/schedulers/sqrt.py +0 -0
  37. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/trainer/__init__.py +0 -0
  38. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  39. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  40. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/trainer/simple_trainer.py +0 -0
  41. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff/utils.py +0 -0
  42. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff.egg-info/dependency_links.txt +0 -0
  43. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff.egg-info/requires.txt +0 -0
  44. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/flaxdiff.egg-info/top_level.txt +0 -0
  45. {flaxdiff-0.1.11 → flaxdiff-0.1.13}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.11
3
+ Version: 0.1.13
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
File without changes
@@ -0,0 +1,205 @@
1
+ import multiprocessing
2
+ import threading
3
+ from multiprocessing import Queue
4
+ # from arrayqueues.shared_arrays import ArrayQueue
5
+ # from faster_fifo import Queue
6
+ import time
7
+ import albumentations as A
8
+ import queue
9
+ import cv2
10
+ from functools import partial
11
+ from typing import Any, Dict, List, Tuple
12
+
13
+ import numpy as np
14
+ from functools import partial
15
+
16
+ from datasets import load_dataset, concatenate_datasets, Dataset
17
+ from datasets.utils.file_utils import get_datasets_user_agent
18
+ from concurrent.futures import ThreadPoolExecutor
19
+ import io
20
+ import urllib
21
+
22
+ import PIL.Image
23
+ import cv2
24
+
25
# User-agent string attached to every image download request.
USER_AGENT = get_datasets_user_agent()

# Capacity shared by both cross-process queues below.
_QUEUE_MAXSIZE = 16 * 2000
# NOTE(review): these queues are created at import time; worker processes only
# share them if they inherit them (fork start method) — confirm behaviour on
# spawn-based platforms.
# Preprocessed samples pushed by map_sample and pulled by ImageBatchIterator.
data_queue = Queue(maxsize=_QUEUE_MAXSIZE)
# Failure reports ({"url", "caption", "error"}) pushed by map_sample.
error_queue = Queue(maxsize=_QUEUE_MAXSIZE)
29
+
30
+
31
def fetch_single_image(image_url, timeout=None, retries=0):
    """Download *image_url* and open it as a PIL image.

    The request carries the module's ``USER_AGENT`` header. Any failure
    (network error, HTTP error, undecodable payload) triggers another attempt,
    up to ``retries + 1`` attempts in total.

    Args:
        image_url: URL of the image to download.
        timeout: Timeout in seconds forwarded to ``urllib.request.urlopen``;
            ``None`` uses the global socket default.
        retries: Number of extra attempts after the first one. Negative
            values are treated as 0, so at least one attempt is always made.

    Returns:
        The fully decoded ``PIL.Image.Image`` on success, or ``None`` if every
        attempt failed.
    """
    # ``import urllib`` alone does not guarantee that the ``urllib.request``
    # submodule is loaded; without it every attempt would raise
    # AttributeError, be swallowed below, and silently return None.
    import urllib.request

    for _ in range(max(retries, 0) + 1):
        try:
            request = urllib.request.Request(
                image_url,
                data=None,
                headers={"user-agent": USER_AGENT},
            )
            with urllib.request.urlopen(request, timeout=timeout) as response:
                payload = response.read()
            image = PIL.Image.open(io.BytesIO(payload))
            # PIL decodes lazily; force decoding now so truncated or corrupt
            # downloads are retried here instead of failing downstream.
            image.load()
            return image
        except Exception:
            # Best-effort download: swallow the error and try again.
            continue
    return None
45
+
46
def map_sample(
    url, caption,
    image_shape=(256, 256),
    upscale_interpolation=cv2.INTER_LANCZOS4,
    downscale_interpolation=cv2.INTER_AREA,
    timeout=15,
    retries=3,
):
    """Fetch one image, filter it, resize/pad it and push it onto ``data_queue``.

    Nothing is queued when the download fails, the image's short side is
    below ``min(image_shape)``, its aspect ratio exceeds 2:1, or it is
    (near-)constant. Exceptions raised while processing are reported on
    ``error_queue`` instead of propagating.

    Args:
        url: Image URL.
        caption: Caption associated with the image; forwarded unchanged.
        image_shape: Target ``(height, width)`` used for resizing and padding.
        upscale_interpolation: cv2 interpolation flag used when enlarging.
        downscale_interpolation: cv2 interpolation flag used when shrinking.
        timeout: Per-attempt download timeout in seconds.
        retries: Extra download attempts after the first failure.
    """
    try:
        image = fetch_single_image(url, timeout=timeout, retries=retries)
        if image is None:
            return

        # Normalise grayscale / palette / RGBA / CMYK inputs to 3-channel RGB.
        # PIL already stores pixels in RGB order, so no channel swap is needed
        # (cv2.COLOR_BGR2RGB here would turn RGB into BGR, and would raise
        # for images that are not 3-channel).
        image = np.array(image.convert("RGB"))
        original_height, original_width = image.shape[:2]
        # Reject images whose short side is below the target resolution.
        if min(original_height, original_width) < min(image_shape):
            return
        # Reject extreme aspect ratios (long side more than 2x the short side).
        if max(original_height, original_width) / min(original_height, original_width) > 2:
            return
        # Reject (near-)constant images.
        if np.std(image) < 1e-4:
            return
        downscale = max(original_width, original_height) > max(image_shape)
        interpolation = downscale_interpolation if downscale else upscale_interpolation
        image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
        image = A.pad(
            image,
            min_height=image_shape[0],
            min_width=image_shape[1],
            border_mode=cv2.BORDER_CONSTANT,
            value=[255, 255, 255],
        )
        data_queue.put({
            "url": url,
            "caption": caption,
            "image": image,
        })
    except Exception as e:
        # error_queue is bounded and nothing in this file drains it; a
        # blocking put() would stall this worker forever once it is full,
        # so drop the report instead.
        try:
            error_queue.put_nowait({
                "url": url,
                "caption": caption,
                "error": str(e),
            })
        except queue.Full:
            pass
90
+
91
def map_batch(batch, num_threads=256, timeout=None, retries=0):
    """Process one shard of samples concurrently on a thread pool.

    Every ``(url, caption)`` pair drawn from ``batch["url"]`` and
    ``batch["caption"]`` is handed to ``map_sample`` on a worker thread. The
    function returns once all of them have finished; output is delivered via
    the module-level queues, not returned.

    Args:
        batch: Mapping with parallel ``"url"`` and ``"caption"`` sequences.
        num_threads: Maximum number of concurrent worker threads.
        timeout: Currently unused; ``map_sample`` applies its own timeout.
        retries: Currently unused; ``map_sample`` applies its own retry count.
    """
    pool = ThreadPoolExecutor(max_workers=num_threads)
    with pool:
        # Executor.map submits every task immediately; leaving the ``with``
        # block waits for all of them. The lazy result iterator is discarded.
        dispatch = pool.map
        dispatch(map_sample, batch["url"], batch["caption"])
94
+
95
def parallel_image_loader(dataset: Dataset, num_workers: int = 8, num_threads=256):
    """Endlessly stream *dataset* through a process pool of image downloaders.

    Every epoch the dataset is shuffled (seeded by the epoch index) and split
    into ``num_workers`` contiguous shards. Each shard is processed by
    ``map_batch`` in a worker process. Processed samples reach the consumer
    through the module-level queues; this function never returns normally.

    Args:
        dataset: Dataset whose slices expose ``"url"`` and ``"caption"`` columns.
        num_workers: Number of worker processes (and shards per epoch).
        num_threads: Download threads per worker process.

    Raises:
        ValueError: If ``num_workers`` is not positive or *dataset* is empty;
            either would otherwise leave the loop running without ever
            producing data.
    """
    if num_workers <= 0:
        raise ValueError(f"num_workers must be positive, got {num_workers}")
    total = len(dataset)
    if total == 0:
        raise ValueError("parallel_image_loader received an empty dataset")

    map_batch_fn = partial(map_batch, num_threads=num_threads)
    print(f"Local Shard lengths: {total // num_workers}")
    # Spread the remainder across shards so no rows are dropped. A dataset
    # smaller than num_workers still yields non-empty shards; a plain
    # ``total // num_workers`` split produced only empty shards in that case.
    bounds = [(i * total) // num_workers for i in range(num_workers + 1)]
    with multiprocessing.Pool(num_workers) as pool:
        epoch = 0
        while True:
            # Shuffle the original dataset each epoch rather than the previous
            # permutation, so epoch k's order depends only on k.
            shuffled = dataset.shuffle(seed=epoch)
            shards = [shuffled[bounds[i]:bounds[i + 1]] for i in range(num_workers)]
            pool.map(map_batch_fn, shards)
            epoch += 1
107
+
108
class ImageBatchIterator:
    """Infinite iterator yielding lists of preprocessed samples.

    A background thread runs ``parallel_image_loader`` over the dataset, and
    that loader fills the module-level ``data_queue``. Each ``next()`` call
    pulls ``batch_size`` samples from the queue, blocking until they exist.
    """

    def __init__(self, dataset: Dataset, batch_size: int = 64, num_workers: int = 8, num_threads=256):
        self.dataset = dataset
        self.num_workers = num_workers
        self.batch_size = batch_size
        loader = partial(parallel_image_loader, num_threads=num_threads)
        # parallel_image_loader never returns, so the thread must be a daemon;
        # a non-daemon thread would keep the interpreter alive forever.
        self.thread = threading.Thread(target=loader, args=(dataset, num_workers), daemon=True)
        self.thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        # Queue.get() blocks until a sample is available. Sequential gets
        # collect the same batch without spawning batch_size threads per call.
        return [data_queue.get() for _ in range(self.batch_size)]

    def __del__(self):
        # Intentionally does not join the loader thread: it runs forever, so
        # joining would block indefinitely (e.g. during GC or interpreter
        # exit). As a daemon thread it is torn down with the process.
        pass

    def __len__(self):
        # Nominal number of full batches per pass over the dataset.
        return len(self.dataset) // self.batch_size
132
+
133
def default_collate(batch):
    """Combine a list of sample dicts into a single batch dict.

    Args:
        batch: Sequence of dicts, each holding ``"url"``, ``"caption"`` and
            ``"image"``; all images must share one shape (``np.stack``).

    Returns:
        Dict whose ``"url"`` and ``"caption"`` are lists in batch order and
        whose ``"image"`` is a numpy array stacked along a new leading axis.
    """
    def column(key):
        return [sample[key] for sample in batch]

    return {
        "url": column("url"),
        "caption": column("caption"),
        "image": np.stack(column("image"), axis=0),
    }
142
+
143
def dataMapper(map: Dict[str, Any]):
    """Return a row function that renames dataset columns to ``url``/``caption``.

    Args:
        map: Mapping from the canonical keys ``"url"`` and ``"caption"`` to
            the column names used by the source dataset. It is consulted on
            every call of the returned function, not captured up front.

    Returns:
        A function that takes a dataset row and returns
        ``{"url": ..., "caption": ...}`` read from the mapped columns.
    """
    # The parameter name (which shadows the builtin ``map``) is kept for
    # callers passing it by keyword; alias it for readability.
    column_for = map

    def _map(sample) -> Dict[str, Any]:
        renamed = {}
        for target in ("url", "caption"):
            renamed[target] = sample[column_for[target]]
        return renamed

    return _map
150
+
151
class OnlineStreamingDataLoader():
    """Infinite loader that downloads and preprocesses images on the fly.

    Rows are renamed to ``{"url", "caption"}`` via ``pre_map_maker(pre_map_def)``.
    The dataset is then sharded per host, and a background thread keeps up to
    ``prefetch`` raw batches ready. Each ``next()`` returns ``collate_fn``
    applied to one batch.
    """

    def __init__(
        self,
        dataset,
        batch_size=64,
        num_workers=16,
        num_threads=512,
        default_split="all",
        pre_map_maker=dataMapper,
        pre_map_def=None,
        global_process_count=1,
        global_process_index=0,
        prefetch=1000,
        collate_fn=default_collate,
    ):
        """
        Args:
            dataset: A dataset name/path, a list of names/paths, a list of
                already-loaded datasets, or a single loaded dataset.
            batch_size: Samples per batch.
            num_workers: Worker processes used for downloading.
            num_threads: Download threads per worker process.
            default_split: Split passed to ``load_dataset`` for path inputs.
            pre_map_maker: Factory turning ``pre_map_def`` into a row function.
            pre_map_def: Column mapping for ``pre_map_maker``; defaults to
                ``{"url": "URL", "caption": "TEXT"}``.
            global_process_count: Total number of hosts sharing the dataset.
            global_process_index: Index of this host's shard.
            prefetch: Maximum number of raw batches buffered in memory.
                NOTE(review): each buffered batch holds ``batch_size`` full
                images, so large values can use a lot of memory.
            collate_fn: Function applied to each raw batch in ``__next__``.

        Raises:
            ValueError: If *dataset* is an empty list.
        """
        if pre_map_def is None:
            # Fresh dict per instance instead of a shared mutable default.
            pre_map_def = {
                "url": "URL",
                "caption": "TEXT",
            }
        if isinstance(dataset, str):
            dataset_path = dataset
            print("Loading dataset from path")
            dataset = load_dataset(dataset_path, split=default_split)
        elif isinstance(dataset, list):
            if not dataset:
                raise ValueError("dataset list must not be empty")
            if isinstance(dataset[0], str):
                print("Loading multiple datasets from paths")
                dataset = [load_dataset(dataset_path, split=default_split) for dataset_path in dataset]
            # Both list forms end up as a list of datasets that must be merged;
            # previously the list loaded from paths was never concatenated,
            # so the ``.map`` call below failed on a plain list.
            print("Concatenating multiple datasets")
            dataset = concatenate_datasets(dataset)
        dataset = dataset.map(pre_map_maker(pre_map_def))
        self.dataset = dataset.shard(num_shards=global_process_count, index=global_process_index)
        print(f"Dataset length: {len(dataset)}")
        # Required by __len__ (it was previously never assigned).
        self.batch_size = batch_size
        self.iterator = ImageBatchIterator(self.dataset, num_workers=num_workers, batch_size=batch_size, num_threads=num_threads)
        self.collate_fn = collate_fn

        # Launch a thread to load batches in the background
        self.batch_queue = queue.Queue(prefetch)

        def batch_loader():
            for batch in self.iterator:
                self.batch_queue.put(batch)

        # The iterator is infinite, so this thread never finishes; as a daemon
        # it does not keep the interpreter alive at exit.
        self.loader_thread = threading.Thread(target=batch_loader, daemon=True)
        self.loader_thread.start()

    def __iter__(self):
        return self

    def __next__(self):
        # Blocks until the background thread has produced a batch.
        return self.collate_fn(self.batch_queue.get())

    def __len__(self):
        # Nominal number of full batches in this host's shard.
        return len(self.dataset) // self.batch_size
205
+
@@ -22,7 +22,7 @@ class EfficientAttention(nn.Module):
22
22
  dtype: Optional[Dtype] = None
23
23
  precision: PrecisionLike = None
24
24
  use_bias: bool = True
25
- kernel_init: Callable = lambda : kernel_init(1.0)
25
+ kernel_init: Callable = kernel_init(1.0)
26
26
  force_fp32_for_softmax: bool = True
27
27
 
28
28
  def setup(self):
@@ -33,7 +33,7 @@ class EfficientAttention(nn.Module):
33
33
  self.heads * self.dim_head,
34
34
  precision=self.precision,
35
35
  use_bias=self.use_bias,
36
- kernel_init=self.kernel_init(),
36
+ kernel_init=self.kernel_init,
37
37
  dtype=self.dtype
38
38
  )
39
39
  self.query = dense(name="to_q")
@@ -41,7 +41,7 @@ class EfficientAttention(nn.Module):
41
41
  self.value = dense(name="to_v")
42
42
 
43
43
  self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
44
- kernel_init=self.kernel_init(), dtype=self.dtype, name="to_out_0")
44
+ kernel_init=self.kernel_init, dtype=self.dtype, name="to_out_0")
45
45
  # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)
46
46
 
47
47
  def _reshape_tensor_to_head_dim(self, tensor):
@@ -114,7 +114,7 @@ class NormalAttention(nn.Module):
114
114
  dtype: Optional[Dtype] = None
115
115
  precision: PrecisionLike = None
116
116
  use_bias: bool = True
117
- kernel_init: Callable = lambda : kernel_init(1.0)
117
+ kernel_init: Callable = kernel_init(1.0)
118
118
  force_fp32_for_softmax: bool = True
119
119
 
120
120
  def setup(self):
@@ -125,7 +125,7 @@ class NormalAttention(nn.Module):
125
125
  axis=-1,
126
126
  precision=self.precision,
127
127
  use_bias=self.use_bias,
128
- kernel_init=self.kernel_init(),
128
+ kernel_init=self.kernel_init,
129
129
  dtype=self.dtype
130
130
  )
131
131
  self.query = dense(name="to_q")
@@ -139,7 +139,7 @@ class NormalAttention(nn.Module):
139
139
  use_bias=self.use_bias,
140
140
  dtype=self.dtype,
141
141
  name="to_out_0",
142
- kernel_init=self.kernel_init()
142
+ kernel_init=self.kernel_init
143
143
  # kernel_init=jax.nn.initializers.xavier_uniform()
144
144
  )
145
145
 
@@ -235,7 +235,7 @@ class BasicTransformerBlock(nn.Module):
235
235
  dtype: Optional[Dtype] = None
236
236
  precision: PrecisionLike = None
237
237
  use_bias: bool = True
238
- kernel_init: Callable = lambda : kernel_init(1.0)
238
+ kernel_init: Callable = kernel_init(1.0)
239
239
  use_flash_attention:bool = False
240
240
  use_cross_only:bool = False
241
241
  only_pure_attention:bool = False
@@ -302,7 +302,7 @@ class TransformerBlock(nn.Module):
302
302
  use_self_and_cross:bool = True
303
303
  only_pure_attention:bool = False
304
304
  force_fp32_for_softmax: bool = True
305
- kernel_init: Callable = lambda : kernel_init(1.0)
305
+ kernel_init: Callable = kernel_init(1.0)
306
306
 
307
307
  @nn.compact
308
308
  def __call__(self, x, context=None):
@@ -313,12 +313,12 @@ class TransformerBlock(nn.Module):
313
313
  if self.use_linear_attention:
314
314
  projected_x = nn.Dense(features=inner_dim,
315
315
  use_bias=False, precision=self.precision,
316
- kernel_init=self.kernel_init(),
316
+ kernel_init=self.kernel_init,
317
317
  dtype=self.dtype, name=f'project_in')(normed_x)
318
318
  else:
319
319
  projected_x = nn.Conv(
320
320
  features=inner_dim, kernel_size=(1, 1),
321
- kernel_init=self.kernel_init(),
321
+ kernel_init=self.kernel_init,
322
322
  strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
323
323
  precision=self.precision, name=f'project_in_conv',
324
324
  )(normed_x)
@@ -347,12 +347,12 @@ class TransformerBlock(nn.Module):
347
347
  if self.use_linear_attention:
348
348
  projected_x = nn.Dense(features=C, precision=self.precision,
349
349
  dtype=self.dtype, use_bias=False,
350
- kernel_init=self.kernel_init(),
350
+ kernel_init=self.kernel_init,
351
351
  name=f'project_out')(projected_x)
352
352
  else:
353
353
  projected_x = nn.Conv(
354
354
  features=C, kernel_size=(1, 1),
355
- kernel_init=self.kernel_init(),
355
+ kernel_init=self.kernel_init,
356
356
  strides=(1, 1), padding='VALID', use_bias=False, dtype=self.dtype,
357
357
  precision=self.precision, name=f'project_out_conv',
358
358
  )(projected_x)
@@ -11,15 +11,15 @@ All credits for the model go to the developers of Stable Diffusion VAE and all c
11
11
  """
12
12
 
13
13
  class StableDiffusionVAE(AutoEncoder):
14
- def __init__(self, modelname = "CompVis/stable-diffusion-v1-4"):
14
+ def __init__(self, modelname = "CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16):
15
15
 
16
16
  from diffusers.models.vae_flax import FlaxEncoder, FlaxDecoder
17
17
  from diffusers import FlaxStableDiffusionPipeline
18
18
 
19
19
  pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
20
20
  modelname,
21
- revision="bf16",
22
- dtype=jnp.bfloat16,
21
+ revision=revision,
22
+ dtype=dtype,
23
23
  )
24
24
 
25
25
  vae = pipeline.vae
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.11
3
+ Version: 0.1.13
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -7,6 +7,8 @@ flaxdiff.egg-info/SOURCES.txt
7
7
  flaxdiff.egg-info/dependency_links.txt
8
8
  flaxdiff.egg-info/requires.txt
9
9
  flaxdiff.egg-info/top_level.txt
10
+ flaxdiff/data/__init__.py
11
+ flaxdiff/data/online_loader.py
10
12
  flaxdiff/models/__init__.py
11
13
  flaxdiff/models/attention.py
12
14
  flaxdiff/models/common.py
@@ -11,7 +11,7 @@ required_packages=[
11
11
  setup(
12
12
  name='flaxdiff',
13
13
  packages=find_packages(),
14
- version='0.1.11',
14
+ version='0.1.13',
15
15
  description='A versatile and easy to understand Diffusion library',
16
16
  long_description=open('README.md').read(),
17
17
  long_description_content_type='text/markdown',
File without changes
File without changes
File without changes