flaxdiff 0.1.36__tar.gz → 0.1.36.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/PKG-INFO +13 -10
  2. {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/flaxdiff.egg-info/PKG-INFO +13 -10
  3. flaxdiff-0.1.36.1/flaxdiff.egg-info/SOURCES.txt +9 -0
  4. flaxdiff-0.1.36.1/flaxdiff.egg-info/requires.txt +14 -0
  5. flaxdiff-0.1.36.1/pyproject.toml +32 -0
  6. flaxdiff-0.1.36/flaxdiff/data/__init__.py +0 -1
  7. flaxdiff-0.1.36/flaxdiff/data/dataset_map.py +0 -71
  8. flaxdiff-0.1.36/flaxdiff/data/datasets.py +0 -169
  9. flaxdiff-0.1.36/flaxdiff/data/online_loader.py +0 -363
  10. flaxdiff-0.1.36/flaxdiff/models/__init__.py +0 -1
  11. flaxdiff-0.1.36/flaxdiff/models/attention.py +0 -368
  12. flaxdiff-0.1.36/flaxdiff/models/autoencoder/__init__.py +0 -2
  13. flaxdiff-0.1.36/flaxdiff/models/autoencoder/autoencoder.py +0 -19
  14. flaxdiff-0.1.36/flaxdiff/models/autoencoder/diffusers.py +0 -91
  15. flaxdiff-0.1.36/flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
  16. flaxdiff-0.1.36/flaxdiff/models/common.py +0 -346
  17. flaxdiff-0.1.36/flaxdiff/models/favor_fastattn.py +0 -723
  18. flaxdiff-0.1.36/flaxdiff/models/simple_unet.py +0 -233
  19. flaxdiff-0.1.36/flaxdiff/models/simple_vit.py +0 -180
  20. flaxdiff-0.1.36/flaxdiff/predictors/__init__.py +0 -96
  21. flaxdiff-0.1.36/flaxdiff/samplers/__init__.py +0 -7
  22. flaxdiff-0.1.36/flaxdiff/samplers/common.py +0 -165
  23. flaxdiff-0.1.36/flaxdiff/samplers/ddim.py +0 -10
  24. flaxdiff-0.1.36/flaxdiff/samplers/ddpm.py +0 -37
  25. flaxdiff-0.1.36/flaxdiff/samplers/euler.py +0 -56
  26. flaxdiff-0.1.36/flaxdiff/samplers/heun_sampler.py +0 -27
  27. flaxdiff-0.1.36/flaxdiff/samplers/multistep_dpm.py +0 -59
  28. flaxdiff-0.1.36/flaxdiff/samplers/rk4_sampler.py +0 -34
  29. flaxdiff-0.1.36/flaxdiff/schedulers/__init__.py +0 -6
  30. flaxdiff-0.1.36/flaxdiff/schedulers/common.py +0 -98
  31. flaxdiff-0.1.36/flaxdiff/schedulers/continuous.py +0 -12
  32. flaxdiff-0.1.36/flaxdiff/schedulers/cosine.py +0 -40
  33. flaxdiff-0.1.36/flaxdiff/schedulers/discrete.py +0 -74
  34. flaxdiff-0.1.36/flaxdiff/schedulers/exp.py +0 -13
  35. flaxdiff-0.1.36/flaxdiff/schedulers/karras.py +0 -69
  36. flaxdiff-0.1.36/flaxdiff/schedulers/linear.py +0 -14
  37. flaxdiff-0.1.36/flaxdiff/schedulers/sqrt.py +0 -10
  38. flaxdiff-0.1.36/flaxdiff/trainer/__init__.py +0 -2
  39. flaxdiff-0.1.36/flaxdiff/trainer/autoencoder_trainer.py +0 -182
  40. flaxdiff-0.1.36/flaxdiff/trainer/diffusion_trainer.py +0 -326
  41. flaxdiff-0.1.36/flaxdiff/trainer/simple_trainer.py +0 -538
  42. flaxdiff-0.1.36/flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  43. flaxdiff-0.1.36/flaxdiff.egg-info/SOURCES.txt +0 -46
  44. flaxdiff-0.1.36/flaxdiff.egg-info/requires.txt +0 -5
  45. flaxdiff-0.1.36/setup.py +0 -21
  46. {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/README.md +0 -0
  47. {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/flaxdiff/__init__.py +0 -0
  48. {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/flaxdiff/utils.py +0 -0
  49. {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/flaxdiff.egg-info/dependency_links.txt +0 -0
  50. {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/flaxdiff.egg-info/top_level.txt +0 -0
  51. {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/setup.cfg +0 -0
@@ -1,21 +1,24 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.36
3
+ Version: 0.1.36.1
4
4
  Summary: A versatile and easy to understand Diffusion library
5
- Author: Ashish Kumar Singh
6
- Author-email: ashishkmr472@gmail.com
5
+ Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
+ License-Expression: MIT
7
7
  Description-Content-Type: text/markdown
8
8
  Requires-Dist: flax>=0.8.4
9
- Requires-Dist: optax>=0.2.2
10
9
  Requires-Dist: jax>=0.4.28
10
+ Requires-Dist: optax>=0.2.2
11
11
  Requires-Dist: orbax
12
+ Requires-Dist: numpy
12
13
  Requires-Dist: clu
13
- Dynamic: author
14
- Dynamic: author-email
15
- Dynamic: description
16
- Dynamic: description-content-type
17
- Dynamic: requires-dist
18
- Dynamic: summary
14
+ Requires-Dist: einops
15
+ Requires-Dist: tqdm
16
+ Requires-Dist: grain
17
+ Requires-Dist: termcolor
18
+ Requires-Dist: augmax
19
+ Requires-Dist: albumentations
20
+ Requires-Dist: rich
21
+ Requires-Dist: python-dotenv
19
22
 
20
23
  # ![](images/logo.jpeg "FlaxDiff")
21
24
 
@@ -1,21 +1,24 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.1.36
3
+ Version: 0.1.36.1
4
4
  Summary: A versatile and easy to understand Diffusion library
5
- Author: Ashish Kumar Singh
6
- Author-email: ashishkmr472@gmail.com
5
+ Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
+ License-Expression: MIT
7
7
  Description-Content-Type: text/markdown
8
8
  Requires-Dist: flax>=0.8.4
9
- Requires-Dist: optax>=0.2.2
10
9
  Requires-Dist: jax>=0.4.28
10
+ Requires-Dist: optax>=0.2.2
11
11
  Requires-Dist: orbax
12
+ Requires-Dist: numpy
12
13
  Requires-Dist: clu
13
- Dynamic: author
14
- Dynamic: author-email
15
- Dynamic: description
16
- Dynamic: description-content-type
17
- Dynamic: requires-dist
18
- Dynamic: summary
14
+ Requires-Dist: einops
15
+ Requires-Dist: tqdm
16
+ Requires-Dist: grain
17
+ Requires-Dist: termcolor
18
+ Requires-Dist: augmax
19
+ Requires-Dist: albumentations
20
+ Requires-Dist: rich
21
+ Requires-Dist: python-dotenv
19
22
 
20
23
  # ![](images/logo.jpeg "FlaxDiff")
21
24
 
@@ -0,0 +1,9 @@
1
+ README.md
2
+ pyproject.toml
3
+ flaxdiff/__init__.py
4
+ flaxdiff/utils.py
5
+ flaxdiff.egg-info/PKG-INFO
6
+ flaxdiff.egg-info/SOURCES.txt
7
+ flaxdiff.egg-info/dependency_links.txt
8
+ flaxdiff.egg-info/requires.txt
9
+ flaxdiff.egg-info/top_level.txt
@@ -0,0 +1,14 @@
1
+ flax>=0.8.4
2
+ jax>=0.4.28
3
+ optax>=0.2.2
4
+ orbax
5
+ numpy
6
+ clu
7
+ einops
8
+ tqdm
9
+ grain
10
+ termcolor
11
+ augmax
12
+ albumentations
13
+ rich
14
+ python-dotenv
@@ -0,0 +1,32 @@
1
+ [build-system]
2
+ requires = ["setuptools", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "flaxdiff"
7
+ version = "0.1.36.1"
8
+ description = "A versatile and easy to understand Diffusion library"
9
+ readme = "README.md"
10
+ authors = [
11
+ { name="Ashish Kumar Singh", email="ashishkmr472@gmail.com" }
12
+ ]
13
+ dependencies = [
14
+ "flax>=0.8.4",
15
+ "jax>=0.4.28",
16
+ "optax>=0.2.2",
17
+ "orbax",
18
+ "numpy",
19
+ "clu",
20
+ "einops",
21
+ "tqdm",
22
+ "grain",
23
+ "termcolor",
24
+ "augmax",
25
+ "albumentations",
26
+ "rich",
27
+ "python-dotenv",
28
+ ]
29
+ license = "MIT"
30
+
31
+ [tool.setuptools]
32
+ packages = ["flaxdiff"]
@@ -1 +0,0 @@
1
- from .online_loader import OnlineStreamingDataLoader
@@ -1,71 +0,0 @@
1
- from .sources.tfds import data_source_tfds, tfds_augmenters
2
- from .sources.gcs import data_source_gcs, data_source_combined_gcs, gcs_augmenters
3
-
4
- # Configure the following for your datasets
5
- datasetMap = {
6
- "oxford_flowers102": {
7
- "source": data_source_tfds("oxford_flowers102", use_tf=False),
8
- "augmenter": tfds_augmenters,
9
- },
10
- "cc12m": {
11
- "source": data_source_gcs('arrayrecord2/cc12m'),
12
- "augmenter": gcs_augmenters,
13
- },
14
- "laiona_coco": {
15
- "source": data_source_gcs('arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
16
- "augmenter": gcs_augmenters,
17
- },
18
- "aesthetic_coyo": {
19
- "source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
20
- "augmenter": gcs_augmenters,
21
- },
22
- "combined_aesthetic": {
23
- "source": data_source_combined_gcs([
24
- 'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
25
- 'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
26
- 'arrayrecord2/cc12m',
27
- 'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
28
- ]),
29
- "augmenter": gcs_augmenters,
30
- },
31
- "laiona_coco_coyo": {
32
- "source": data_source_combined_gcs([
33
- 'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
34
- 'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
35
- 'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
36
- ]),
37
- "augmenter": gcs_augmenters,
38
- },
39
- "combined_30m": {
40
- "source": data_source_combined_gcs([
41
- 'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
42
- 'arrayrecord2/cc12m',
43
- 'arrayrecord2/aestheticCoyo_0.26_clip_5.5aesthetic_256plus',
44
- "arrayrecord2/playground+leonardo_x4+cc3m.parquet",
45
- ]),
46
- "augmenter": gcs_augmenters,
47
- }
48
- }
49
-
50
- onlineDatasetMap = {
51
- "combined_online": {
52
- "source": [
53
- # "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017.parquet"
54
- # "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
55
- # "dclure/laion-aesthetics-12m-umap",
56
- "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017",
57
- "gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M",
58
- "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
59
- "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
60
- "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
61
- "gs://flaxdiff-datasets-regional/datasets/cc12m",
62
- "gs://flaxdiff-datasets-regional/datasets/playground-liked",
63
- "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
64
- "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
65
- "gs://flaxdiff-datasets-regional/datasets/cc3m",
66
- "gs://flaxdiff-datasets-regional/datasets/cc3m",
67
- "gs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_37M",
68
- # "gs://flaxdiff-datasets-regional/datasets/laiion400m-185M"
69
- ]
70
- }
71
- }
@@ -1,169 +0,0 @@
1
- import jax.numpy as jnp
2
- import grain.python as pygrain
3
- from typing import Dict
4
- import numpy as np
5
- import jax
6
- from flaxdiff.utils import convert_to_global_tree, AutoTextTokenizer
7
- from .dataset_map import datasetMap, onlineDatasetMap
8
- import traceback
9
- from .online_loader import OnlineStreamingDataLoader
10
- import queue
11
- from jax.sharding import Mesh
12
- import threading
13
-
14
- def batch_mesh_map(mesh):
15
- class augmenters(pygrain.MapTransform):
16
- def __init__(self, *args, **kwargs):
17
- super().__init__(*args, **kwargs)
18
-
19
- def map(self, batch) -> Dict[str, jnp.array]:
20
- return convert_to_global_tree(mesh, batch)
21
- return augmenters
22
-
23
- def get_dataset_grain(
24
- data_name="cc12m",
25
- batch_size=64,
26
- image_scale=256,
27
- count=None,
28
- num_epochs=None,
29
- method=jax.image.ResizeMethod.LANCZOS3,
30
- worker_count=32,
31
- read_thread_count=64,
32
- read_buffer_size=50,
33
- worker_buffer_size=20,
34
- seed=0,
35
- dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
36
- ):
37
- dataset = datasetMap[data_name]
38
- data_source = dataset["source"](dataset_source)
39
- augmenter = dataset["augmenter"](image_scale, method)
40
-
41
- local_batch_size = batch_size // jax.process_count()
42
-
43
- sampler = pygrain.IndexSampler(
44
- num_records=len(data_source) if count is None else count,
45
- shuffle=True,
46
- seed=seed,
47
- num_epochs=num_epochs,
48
- shard_options=pygrain.ShardByJaxProcess(),
49
- )
50
-
51
- def get_trainset():
52
- transformations = [
53
- augmenter(),
54
- pygrain.Batch(local_batch_size, drop_remainder=True),
55
- ]
56
-
57
- # if mesh != None:
58
- # transformations += [batch_mesh_map(mesh)]
59
-
60
- loader = pygrain.DataLoader(
61
- data_source=data_source,
62
- sampler=sampler,
63
- operations=transformations,
64
- worker_count=worker_count,
65
- read_options=pygrain.ReadOptions(
66
- read_thread_count, read_buffer_size
67
- ),
68
- worker_buffer_size=worker_buffer_size,
69
- )
70
- return loader
71
-
72
-
73
- return {
74
- "train": get_trainset,
75
- "train_len": len(data_source),
76
- "local_batch_size": local_batch_size,
77
- "global_batch_size": batch_size,
78
- # "null_labels": null_labels,
79
- # "null_labels_full": null_labels_full,
80
- # "model": model,
81
- # "tokenizer": tokenizer,
82
- }
83
-
84
- def generate_collate_fn():
85
- auto_tokenize = AutoTextTokenizer(tensor_type="np")
86
- def default_collate(batch):
87
- try:
88
- # urls = [sample["url"] for sample in batch]
89
- captions = [sample["caption"] for sample in batch]
90
- results = auto_tokenize(captions)
91
- images = np.stack([sample["image"] for sample in batch], axis=0)
92
- return {
93
- "image": images,
94
- "input_ids": results['input_ids'],
95
- "attention_mask": results['attention_mask'],
96
- }
97
- except Exception as e:
98
- print("Error in collate function", e, [sample["image"].shape for sample in batch])
99
- traceback.print_exc()
100
-
101
- return default_collate
102
-
103
- def get_dataset_online(
104
- data_name="combined_online",
105
- batch_size=64,
106
- image_scale=256,
107
- count=None,
108
- num_epochs=None,
109
- method=jax.image.ResizeMethod.LANCZOS3,
110
- worker_count=32,
111
- read_thread_count=64,
112
- read_buffer_size=50,
113
- worker_buffer_size=20,
114
- seed=0,
115
- dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
116
- ):
117
- local_batch_size = batch_size // jax.process_count()
118
-
119
- sources = onlineDatasetMap[data_name]["source"]
120
- dataloader = OnlineStreamingDataLoader(
121
- sources,
122
- batch_size=local_batch_size,
123
- num_workers=worker_count,
124
- num_threads=read_thread_count,
125
- image_shape=(image_scale, image_scale),
126
- global_process_count=jax.process_count(),
127
- global_process_index=jax.process_index(),
128
- prefetch=worker_buffer_size,
129
- collate_fn=generate_collate_fn(),
130
- default_split="train",
131
- )
132
-
133
- def get_trainset(mesh: Mesh = None):
134
- if mesh != None:
135
- class dataLoaderWithMesh:
136
- def __init__(self, dataloader, mesh):
137
- self.dataloader = dataloader
138
- self.mesh = mesh
139
- self.tmp_queue = queue.Queue(worker_buffer_size)
140
- def batch_loader():
141
- for batch in self.dataloader:
142
- try:
143
- self.tmp_queue.put(convert_to_global_tree(mesh, batch))
144
- except Exception as e:
145
- print("Error processing batch", e)
146
- self.loader_thread = threading.Thread(target=batch_loader)
147
- self.loader_thread.start()
148
-
149
- def __iter__(self):
150
- return self
151
-
152
- def __next__(self):
153
- return self.tmp_queue.get()
154
-
155
- dataloader_with_mesh = dataLoaderWithMesh(dataloader, mesh)
156
-
157
- return dataloader_with_mesh
158
- return dataloader
159
-
160
- return {
161
- "train": get_trainset,
162
- "train_len": len(dataloader) * jax.process_count(),
163
- "local_batch_size": local_batch_size,
164
- "global_batch_size": batch_size,
165
- # "null_labels": null_labels,
166
- # "null_labels_full": null_labels_full,
167
- # "model": model,
168
- # "tokenizer": tokenizer,
169
- }