flaxdiff 0.1.35.6__py3-none-any.whl → 0.1.36.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. flaxdiff/utils.py +105 -2
  2. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/METADATA +16 -7
  3. flaxdiff-0.1.36.1.dist-info/RECORD +6 -0
  4. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/WHEEL +1 -1
  5. flaxdiff/data/__init__.py +0 -1
  6. flaxdiff/data/online_loader.py +0 -336
  7. flaxdiff/models/__init__.py +0 -1
  8. flaxdiff/models/attention.py +0 -368
  9. flaxdiff/models/autoencoder/__init__.py +0 -2
  10. flaxdiff/models/autoencoder/autoencoder.py +0 -19
  11. flaxdiff/models/autoencoder/diffusers.py +0 -91
  12. flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
  13. flaxdiff/models/common.py +0 -346
  14. flaxdiff/models/favor_fastattn.py +0 -723
  15. flaxdiff/models/simple_unet.py +0 -233
  16. flaxdiff/models/simple_vit.py +0 -180
  17. flaxdiff/predictors/__init__.py +0 -96
  18. flaxdiff/samplers/__init__.py +0 -7
  19. flaxdiff/samplers/common.py +0 -113
  20. flaxdiff/samplers/ddim.py +0 -10
  21. flaxdiff/samplers/ddpm.py +0 -43
  22. flaxdiff/samplers/euler.py +0 -59
  23. flaxdiff/samplers/heun_sampler.py +0 -28
  24. flaxdiff/samplers/multistep_dpm.py +0 -60
  25. flaxdiff/samplers/rk4_sampler.py +0 -34
  26. flaxdiff/schedulers/__init__.py +0 -6
  27. flaxdiff/schedulers/common.py +0 -98
  28. flaxdiff/schedulers/continuous.py +0 -12
  29. flaxdiff/schedulers/cosine.py +0 -40
  30. flaxdiff/schedulers/discrete.py +0 -74
  31. flaxdiff/schedulers/exp.py +0 -13
  32. flaxdiff/schedulers/karras.py +0 -69
  33. flaxdiff/schedulers/linear.py +0 -14
  34. flaxdiff/schedulers/sqrt.py +0 -10
  35. flaxdiff/trainer/__init__.py +0 -2
  36. flaxdiff/trainer/autoencoder_trainer.py +0 -182
  37. flaxdiff/trainer/diffusion_trainer.py +0 -234
  38. flaxdiff/trainer/simple_trainer.py +0 -442
  39. flaxdiff-0.1.35.6.dist-info/RECORD +0 -40
  40. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/top_level.txt +0 -0
flaxdiff/utils.py CHANGED
@@ -2,7 +2,12 @@ import jax
2
2
  import jax.numpy as jnp
3
3
  import flax.struct as struct
4
4
  import flax.linen as nn
5
- from typing import Any
5
+ from typing import Any, Callable
6
+ from dataclasses import dataclass
7
+ from functools import partial
8
+ import numpy as np
9
+ from jax.sharding import Mesh, PartitionSpec as P
10
+ from abc import ABC, abstractmethod
6
11
 
7
12
  class MarkovState(struct.PyTreeNode):
8
13
  pass
@@ -17,6 +22,30 @@ class RandomMarkovState(MarkovState):
17
22
  def clip_images(images, clip_min=-1, clip_max=1):
18
23
  return jnp.clip(images, clip_min, clip_max)
19
24
 
25
+ def _build_global_shape_and_sharding(
26
+ local_shape: tuple[int, ...], global_mesh: Mesh
27
+ ) -> tuple[tuple[int, ...], jax.sharding.NamedSharding]:
28
+ sharding = jax.sharding.NamedSharding(global_mesh, P(global_mesh.axis_names))
29
+ global_shape = (jax.process_count() * local_shape[0],) + local_shape[1:]
30
+ return global_shape, sharding
31
+
32
+
33
+ def form_global_array(path, array: np.ndarray, global_mesh: Mesh) -> jax.Array:
34
+ """Put local sharded array into local devices"""
35
+ global_shape, sharding = _build_global_shape_and_sharding(np.shape(array), global_mesh)
36
+ try:
37
+ local_device_arrays = np.split(array, len(global_mesh.local_devices), axis=0)
38
+ except ValueError as array_split_error:
39
+ raise ValueError(
40
+ f"Unable to put to devices shape {array.shape} with "
41
+ f"local device count {len(global_mesh.local_devices)} "
42
+ ) from array_split_error
43
+ local_device_buffers = jax.device_put(local_device_arrays, global_mesh.local_devices)
44
+ return jax.make_array_from_single_device_arrays(global_shape, sharding, local_device_buffers)
45
+
46
+ def convert_to_global_tree(global_mesh, pytree):
47
+ return jax.tree_util.tree_map_with_path(partial(form_global_array, global_mesh=global_mesh), pytree)
48
+
20
49
  class RMSNorm(nn.Module):
21
50
  """
22
51
  From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467
@@ -86,4 +115,78 @@ class RMSNorm(nn.Module):
86
115
  ).reshape(feature_shape)
87
116
  mul *= scale
88
117
  y = mul * x
89
- return jnp.asarray(y, dtype)
118
+ return jnp.asarray(y, dtype)
119
+
120
+ @dataclass
121
+ class ConditioningEncoder(ABC):
122
+ model: nn.Module
123
+ tokenizer: Callable
124
+
125
+ def __call__(self, data):
126
+ tokens = self.tokenize(data)
127
+ outputs = self.encode_from_tokens(tokens)
128
+ return outputs
129
+
130
+ def encode_from_tokens(self, tokens):
131
+ outputs = self.model(input_ids=tokens['input_ids'],
132
+ attention_mask=tokens['attention_mask'])
133
+ last_hidden_state = outputs.last_hidden_state
134
+ return last_hidden_state
135
+
136
+ def tokenize(self, data):
137
+ tokens = self.tokenizer(data, padding="max_length",
138
+ max_length=self.tokenizer.model_max_length, truncation=True, return_tensors="np")
139
+ return tokens
140
+
141
+ @dataclass
142
+ class TextEncoder(ConditioningEncoder):
143
+ # def __call__(self, data):
144
+ # tokens = self.tokenize(data)
145
+ # outputs = self.encode_from_tokens(tokens)
146
+ # return outputs
147
+
148
+ # def encode_from_tokens(self, tokens):
149
+ # outputs = self.model(input_ids=tokens['input_ids'],
150
+ # attention_mask=tokens['attention_mask'])
151
+ # last_hidden_state = outputs.last_hidden_state
152
+ # # pooler_output = outputs.pooler_output # pooled (EOS token) states
153
+ # # embed_pooled = pooler_output # .astype(jnp.float16)
154
+ # embed_labels_full = last_hidden_state # .astype(jnp.float16)
155
+
156
+ # return embed_labels_full
157
+ pass
158
+
159
+ class AutoTextTokenizer:
160
+ def __init__(self, tensor_type="pt", modelname="openai/clip-vit-large-patch14"):
161
+ from transformers import AutoTokenizer
162
+ self.tokenizer = AutoTokenizer.from_pretrained(modelname)
163
+ self.tensor_type = tensor_type
164
+
165
+ def __call__(self, inputs):
166
+ # print(caption)
167
+ tokens = self.tokenizer(inputs, padding="max_length", max_length=self.tokenizer.model_max_length,
168
+ truncation=True, return_tensors=self.tensor_type)
169
+ # print(tokens.keys())
170
+ return {
171
+ "input_ids": tokens["input_ids"],
172
+ "attention_mask": tokens["attention_mask"],
173
+ "caption": inputs,
174
+ }
175
+
176
+ def __repr__(self):
177
+ return self.__class__.__name__ + '()'
178
+
179
+ def defaultTextEncodeModel(backend="jax"):
180
+ from transformers import (
181
+ CLIPTextModel,
182
+ FlaxCLIPTextModel,
183
+ AutoTokenizer,
184
+ )
185
+ modelname = "openai/clip-vit-large-patch14"
186
+ if backend == "jax":
187
+ model = FlaxCLIPTextModel.from_pretrained(
188
+ modelname, dtype=jnp.bfloat16)
189
+ else:
190
+ model = CLIPTextModel.from_pretrained(modelname)
191
+ tokenizer = AutoTokenizer.from_pretrained(modelname, dtype=jnp.float16)
192
+ return TextEncoder(model, tokenizer)
{flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/METADATA CHANGED
@@ -1,15 +1,24 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.35.6
3
+ Version: 0.1.36.1
4
4
  Summary: A versatile and easy to understand Diffusion library
5
- Author: Ashish Kumar Singh
6
- Author-email: ashishkmr472@gmail.com
5
+ Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
+ License-Expression: MIT
7
7
  Description-Content-Type: text/markdown
8
- Requires-Dist: flax >=0.8.4
9
- Requires-Dist: optax >=0.2.2
10
- Requires-Dist: jax >=0.4.28
8
+ Requires-Dist: flax>=0.8.4
9
+ Requires-Dist: jax>=0.4.28
10
+ Requires-Dist: optax>=0.2.2
11
11
  Requires-Dist: orbax
12
+ Requires-Dist: numpy
12
13
  Requires-Dist: clu
14
+ Requires-Dist: einops
15
+ Requires-Dist: tqdm
16
+ Requires-Dist: grain
17
+ Requires-Dist: termcolor
18
+ Requires-Dist: augmax
19
+ Requires-Dist: albumentations
20
+ Requires-Dist: rich
21
+ Requires-Dist: python-dotenv
13
22
 
14
23
  # ![](images/logo.jpeg "FlaxDiff")
15
24
 
flaxdiff-0.1.36.1.dist-info/RECORD ADDED
@@ -0,0 +1,6 @@
1
+ flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ flaxdiff/utils.py,sha256=b_hFXsam2NICQYCFk0EOcqtBjM-RUqnN0NKTn0lQ070,6532
3
+ flaxdiff-0.1.36.1.dist-info/METADATA,sha256=Fl9tlGh_BgRnT-f8k4cEYnFj7G03VecUNOX_1zbJrmE,22310
4
+ flaxdiff-0.1.36.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
5
+ flaxdiff-0.1.36.1.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
6
+ flaxdiff-0.1.36.1.dist-info/RECORD,,
{flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/WHEEL CHANGED
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.1.0)
2
+ Generator: setuptools (78.1.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
flaxdiff/data/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .online_loader import OnlineStreamingDataLoader
flaxdiff/data/online_loader.py DELETED
@@ -1,336 +0,0 @@
1
- import multiprocessing
2
- import threading
3
- from multiprocessing import Queue
4
- # from arrayqueues.shared_arrays import ArrayQueue
5
- # from faster_fifo import Queue
6
- import time
7
- import albumentations as A
8
- import queue
9
- import cv2
10
- from functools import partial
11
- from typing import Any, Dict, List, Tuple
12
-
13
- import numpy as np
14
- from functools import partial
15
-
16
- from datasets import load_dataset, concatenate_datasets, Dataset, load_from_disk
17
- from datasets.utils.file_utils import get_datasets_user_agent
18
- from concurrent.futures import ThreadPoolExecutor
19
- import io
20
- import urllib
21
-
22
- import PIL.Image
23
- import cv2
24
- import traceback
25
-
26
- USER_AGENT = get_datasets_user_agent()
27
-
28
- data_queue = Queue(16*2000)
29
-
30
-
31
- def fetch_single_image(image_url, timeout=None, retries=0):
32
- for _ in range(retries + 1):
33
- try:
34
- request = urllib.request.Request(
35
- image_url,
36
- data=None,
37
- headers={"user-agent": USER_AGENT},
38
- )
39
- with urllib.request.urlopen(request, timeout=timeout) as req:
40
- image = PIL.Image.open(io.BytesIO(req.read()))
41
- break
42
- except Exception:
43
- image = None
44
- return image
45
-
46
-
47
- def default_image_processor(
48
- image, image_shape,
49
- min_image_shape=(128, 128),
50
- upscale_interpolation=cv2.INTER_CUBIC,
51
- downscale_interpolation=cv2.INTER_AREA,
52
- ):
53
- image = np.array(image)
54
- original_height, original_width = image.shape[:2]
55
- # check if the image is too small
56
- if min(original_height, original_width) < min(min_image_shape):
57
- return None, original_height, original_width
58
- # check if wrong aspect ratio
59
- if max(original_height, original_width) / min(original_height, original_width) > 2.4:
60
- return None, original_height, original_width
61
- # check if the variance is too low
62
- if np.std(image) < 1e-5:
63
- return None, original_height, original_width
64
- # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
65
- downscale = max(original_width, original_height) > max(image_shape)
66
- interpolation = downscale_interpolation if downscale else upscale_interpolation
67
-
68
- image = A.longest_max_size(image, max(
69
- image_shape), interpolation=interpolation)
70
- image = A.pad(
71
- image,
72
- min_height=image_shape[0],
73
- min_width=image_shape[1],
74
- border_mode=cv2.BORDER_CONSTANT,
75
- value=[255, 255, 255],
76
- )
77
- return image, original_height, original_width
78
-
79
-
80
- def map_sample(
81
- url,
82
- caption,
83
- image_shape=(256, 256),
84
- min_image_shape=(128, 128),
85
- timeout=15,
86
- retries=3,
87
- upscale_interpolation=cv2.INTER_CUBIC,
88
- downscale_interpolation=cv2.INTER_AREA,
89
- image_processor=default_image_processor,
90
- ):
91
- try:
92
- # Assuming fetch_single_image is defined elsewhere
93
- image = fetch_single_image(url, timeout=timeout, retries=retries)
94
- if image is None:
95
- return
96
-
97
- image, original_height, original_width = image_processor(
98
- image, image_shape, min_image_shape=min_image_shape,
99
- upscale_interpolation=upscale_interpolation,
100
- downscale_interpolation=downscale_interpolation,
101
- )
102
-
103
- if image is None:
104
- return
105
-
106
- data_queue.put({
107
- "url": url,
108
- "caption": caption,
109
- "image": image,
110
- "original_height": original_height,
111
- "original_width": original_width,
112
- })
113
- except Exception as e:
114
- # print(f"Error maping sample {url}", e)
115
- # traceback.print_exc()
116
- # error_queue.put_nowait({
117
- # "url": url,
118
- # "caption": caption,
119
- # "error": str(e)
120
- # })
121
- pass
122
-
123
-
124
- def default_feature_extractor(sample):
125
- return {
126
- "url": sample["url"],
127
- "caption": sample["caption"],
128
- }
129
-
130
-
131
- def map_batch(
132
- batch, num_threads=256, image_shape=(256, 256),
133
- min_image_shape=(128, 128),
134
- timeout=15, retries=3, image_processor=default_image_processor,
135
- upscale_interpolation=cv2.INTER_CUBIC,
136
- downscale_interpolation=cv2.INTER_AREA,
137
- feature_extractor=default_feature_extractor,
138
- ):
139
- try:
140
- map_sample_fn = partial(
141
- map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
142
- timeout=timeout, retries=retries, image_processor=image_processor,
143
- upscale_interpolation=upscale_interpolation,
144
- downscale_interpolation=downscale_interpolation
145
- )
146
- with ThreadPoolExecutor(max_workers=num_threads) as executor:
147
- features = feature_extractor(batch)
148
- url, caption = features["url"], features["caption"]
149
- executor.map(map_sample_fn, url, caption)
150
- except Exception as e:
151
- print(f"Error maping batch", e)
152
- traceback.print_exc()
153
- # error_queue.put_nowait({
154
- # "batch": batch,
155
- # "error": str(e)
156
- # })
157
- pass
158
-
159
-
160
- def parallel_image_loader(
161
- dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
162
- min_image_shape=(128, 128),
163
- num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
164
- upscale_interpolation=cv2.INTER_CUBIC,
165
- downscale_interpolation=cv2.INTER_AREA,
166
- feature_extractor=default_feature_extractor,
167
- ):
168
- map_batch_fn = partial(
169
- map_batch, num_threads=num_threads, image_shape=image_shape,
170
- min_image_shape=min_image_shape,
171
- timeout=timeout, retries=retries, image_processor=image_processor,
172
- upscale_interpolation=upscale_interpolation,
173
- downscale_interpolation=downscale_interpolation,
174
- feature_extractor=feature_extractor
175
- )
176
- shard_len = len(dataset) // num_workers
177
- print(f"Local Shard lengths: {shard_len}")
178
- with multiprocessing.Pool(num_workers) as pool:
179
- iteration = 0
180
- while True:
181
- # Repeat forever
182
- shards = [dataset[i*shard_len:(i+1)*shard_len]
183
- for i in range(num_workers)]
184
- print(f"mapping {len(shards)} shards")
185
- pool.map(map_batch_fn, shards)
186
- iteration += 1
187
- print(f"Shuffling dataset with seed {iteration}")
188
- dataset = dataset.shuffle(seed=iteration)
189
- # Clear the error queue
190
- # while not error_queue.empty():
191
- # error_queue.get_nowait()
192
-
193
-
194
- class ImageBatchIterator:
195
- def __init__(
196
- self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
197
- min_image_shape=(128, 128),
198
- num_workers: int = 8, num_threads=256, timeout=15, retries=3,
199
- image_processor=default_image_processor,
200
- upscale_interpolation=cv2.INTER_CUBIC,
201
- downscale_interpolation=cv2.INTER_AREA,
202
- feature_extractor=default_feature_extractor,
203
- ):
204
- self.dataset = dataset
205
- self.num_workers = num_workers
206
- self.batch_size = batch_size
207
- loader = partial(
208
- parallel_image_loader,
209
- num_threads=num_threads,
210
- image_shape=image_shape,
211
- min_image_shape=min_image_shape,
212
- num_workers=num_workers,
213
- timeout=timeout, retries=retries,
214
- image_processor=image_processor,
215
- upscale_interpolation=upscale_interpolation,
216
- downscale_interpolation=downscale_interpolation,
217
- feature_extractor=feature_extractor
218
- )
219
- self.thread = threading.Thread(target=loader, args=(dataset,))
220
- self.thread.start()
221
-
222
- def __iter__(self):
223
- return self
224
-
225
- def __next__(self):
226
- def fetcher(_):
227
- return data_queue.get()
228
- with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
229
- batch = list(executor.map(fetcher, range(self.batch_size)))
230
- return batch
231
-
232
- def __del__(self):
233
- self.thread.join()
234
-
235
- def __len__(self):
236
- return len(self.dataset) // self.batch_size
237
-
238
-
239
- def default_collate(batch):
240
- urls = [sample["url"] for sample in batch]
241
- captions = [sample["caption"] for sample in batch]
242
- images = np.stack([sample["image"] for sample in batch], axis=0)
243
- return {
244
- "url": urls,
245
- "caption": captions,
246
- "image": images,
247
- }
248
-
249
-
250
- def dataMapper(map: Dict[str, Any]):
251
- def _map(sample) -> Dict[str, Any]:
252
- return {
253
- "url": sample[map["url"]],
254
- "caption": sample[map["caption"]],
255
- }
256
- return _map
257
-
258
-
259
- class OnlineStreamingDataLoader():
260
- def __init__(
261
- self,
262
- dataset,
263
- batch_size=64,
264
- image_shape=(256, 256),
265
- min_image_shape=(128, 128),
266
- num_workers=16,
267
- num_threads=512,
268
- default_split="all",
269
- pre_map_maker=dataMapper,
270
- pre_map_def={
271
- "url": "URL",
272
- "caption": "TEXT",
273
- },
274
- global_process_count=1,
275
- global_process_index=0,
276
- prefetch=1000,
277
- collate_fn=default_collate,
278
- timeout=15,
279
- retries=3,
280
- image_processor=default_image_processor,
281
- upscale_interpolation=cv2.INTER_CUBIC,
282
- downscale_interpolation=cv2.INTER_AREA,
283
- feature_extractor=default_feature_extractor,
284
- ):
285
- if isinstance(dataset, str):
286
- dataset_path = dataset
287
- print("Loading dataset from path")
288
- if "gs://" in dataset:
289
- dataset = load_from_disk(dataset_path)
290
- else:
291
- dataset = load_dataset(dataset_path, split=default_split)
292
- elif isinstance(dataset, list):
293
- if isinstance(dataset[0], str):
294
- print("Loading multiple datasets from paths")
295
- dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset(
296
- dataset_path, split=default_split) for dataset_path in dataset]
297
- print("Concatenating multiple datasets")
298
- dataset = concatenate_datasets(dataset)
299
- dataset = dataset.shuffle(seed=0)
300
- # dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
301
- self.dataset = dataset.shard(
302
- num_shards=global_process_count, index=global_process_index)
303
- print(f"Dataset length: {len(dataset)}")
304
- self.iterator = ImageBatchIterator(
305
- self.dataset, image_shape=image_shape,
306
- min_image_shape=min_image_shape,
307
- num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
308
- timeout=timeout, retries=retries, image_processor=image_processor,
309
- upscale_interpolation=upscale_interpolation,
310
- downscale_interpolation=downscale_interpolation,
311
- feature_extractor=feature_extractor
312
- )
313
- self.batch_size = batch_size
314
-
315
- # Launch a thread to load batches in the background
316
- self.batch_queue = queue.Queue(prefetch)
317
-
318
- def batch_loader():
319
- for batch in self.iterator:
320
- try:
321
- self.batch_queue.put(collate_fn(batch))
322
- except Exception as e:
323
- print("Error processing batch", e)
324
-
325
- self.loader_thread = threading.Thread(target=batch_loader)
326
- self.loader_thread.start()
327
-
328
- def __iter__(self):
329
- return self
330
-
331
- def __next__(self):
332
- return self.batch_queue.get()
333
- # return self.collate_fn(next(self.iterator))
334
-
335
- def __len__(self):
336
- return len(self.dataset)
flaxdiff/models/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .simple_unet import *