flaxdiff 0.1.36__tar.gz → 0.1.36.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/PKG-INFO +13 -10
- {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/flaxdiff.egg-info/PKG-INFO +13 -10
- flaxdiff-0.1.36.1/flaxdiff.egg-info/SOURCES.txt +9 -0
- flaxdiff-0.1.36.1/flaxdiff.egg-info/requires.txt +14 -0
- flaxdiff-0.1.36.1/pyproject.toml +32 -0
- flaxdiff-0.1.36/flaxdiff/data/__init__.py +0 -1
- flaxdiff-0.1.36/flaxdiff/data/dataset_map.py +0 -71
- flaxdiff-0.1.36/flaxdiff/data/datasets.py +0 -169
- flaxdiff-0.1.36/flaxdiff/data/online_loader.py +0 -363
- flaxdiff-0.1.36/flaxdiff/models/__init__.py +0 -1
- flaxdiff-0.1.36/flaxdiff/models/attention.py +0 -368
- flaxdiff-0.1.36/flaxdiff/models/autoencoder/__init__.py +0 -2
- flaxdiff-0.1.36/flaxdiff/models/autoencoder/autoencoder.py +0 -19
- flaxdiff-0.1.36/flaxdiff/models/autoencoder/diffusers.py +0 -91
- flaxdiff-0.1.36/flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
- flaxdiff-0.1.36/flaxdiff/models/common.py +0 -346
- flaxdiff-0.1.36/flaxdiff/models/favor_fastattn.py +0 -723
- flaxdiff-0.1.36/flaxdiff/models/simple_unet.py +0 -233
- flaxdiff-0.1.36/flaxdiff/models/simple_vit.py +0 -180
- flaxdiff-0.1.36/flaxdiff/predictors/__init__.py +0 -96
- flaxdiff-0.1.36/flaxdiff/samplers/__init__.py +0 -7
- flaxdiff-0.1.36/flaxdiff/samplers/common.py +0 -165
- flaxdiff-0.1.36/flaxdiff/samplers/ddim.py +0 -10
- flaxdiff-0.1.36/flaxdiff/samplers/ddpm.py +0 -37
- flaxdiff-0.1.36/flaxdiff/samplers/euler.py +0 -56
- flaxdiff-0.1.36/flaxdiff/samplers/heun_sampler.py +0 -27
- flaxdiff-0.1.36/flaxdiff/samplers/multistep_dpm.py +0 -59
- flaxdiff-0.1.36/flaxdiff/samplers/rk4_sampler.py +0 -34
- flaxdiff-0.1.36/flaxdiff/schedulers/__init__.py +0 -6
- flaxdiff-0.1.36/flaxdiff/schedulers/common.py +0 -98
- flaxdiff-0.1.36/flaxdiff/schedulers/continuous.py +0 -12
- flaxdiff-0.1.36/flaxdiff/schedulers/cosine.py +0 -40
- flaxdiff-0.1.36/flaxdiff/schedulers/discrete.py +0 -74
- flaxdiff-0.1.36/flaxdiff/schedulers/exp.py +0 -13
- flaxdiff-0.1.36/flaxdiff/schedulers/karras.py +0 -69
- flaxdiff-0.1.36/flaxdiff/schedulers/linear.py +0 -14
- flaxdiff-0.1.36/flaxdiff/schedulers/sqrt.py +0 -10
- flaxdiff-0.1.36/flaxdiff/trainer/__init__.py +0 -2
- flaxdiff-0.1.36/flaxdiff/trainer/autoencoder_trainer.py +0 -182
- flaxdiff-0.1.36/flaxdiff/trainer/diffusion_trainer.py +0 -326
- flaxdiff-0.1.36/flaxdiff/trainer/simple_trainer.py +0 -538
- flaxdiff-0.1.36/flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.36/flaxdiff.egg-info/SOURCES.txt +0 -46
- flaxdiff-0.1.36/flaxdiff.egg-info/requires.txt +0 -5
- flaxdiff-0.1.36/setup.py +0 -21
- {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/README.md +0 -0
- {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/setup.cfg +0 -0
{flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/PKG-INFO

@@ -1,21 +1,24 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.36
+Version: 0.1.36.1
 Summary: A versatile and easy to understand Diffusion library
-Author: Ashish Kumar Singh
-
+Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
+License-Expression: MIT
 Description-Content-Type: text/markdown
 Requires-Dist: flax>=0.8.4
-Requires-Dist: optax>=0.2.2
 Requires-Dist: jax>=0.4.28
+Requires-Dist: optax>=0.2.2
 Requires-Dist: orbax
+Requires-Dist: numpy
 Requires-Dist: clu
-
-
-
-
-
-
+Requires-Dist: einops
+Requires-Dist: tqdm
+Requires-Dist: grain
+Requires-Dist: termcolor
+Requires-Dist: augmax
+Requires-Dist: albumentations
+Requires-Dist: rich
+Requires-Dist: python-dotenv
 
 # 
 
{flaxdiff-0.1.36 → flaxdiff-0.1.36.1}/flaxdiff.egg-info/PKG-INFO

@@ -1,21 +1,24 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.36
+Version: 0.1.36.1
 Summary: A versatile and easy to understand Diffusion library
-Author: Ashish Kumar Singh
-
+Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
+License-Expression: MIT
 Description-Content-Type: text/markdown
 Requires-Dist: flax>=0.8.4
-Requires-Dist: optax>=0.2.2
 Requires-Dist: jax>=0.4.28
+Requires-Dist: optax>=0.2.2
 Requires-Dist: orbax
+Requires-Dist: numpy
 Requires-Dist: clu
-
-
-
-
-
-
+Requires-Dist: einops
+Requires-Dist: tqdm
+Requires-Dist: grain
+Requires-Dist: termcolor
+Requires-Dist: augmax
+Requires-Dist: albumentations
+Requires-Dist: rich
+Requires-Dist: python-dotenv
 
 # 
 
flaxdiff-0.1.36.1/pyproject.toml

@@ -0,0 +1,32 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "flaxdiff"
+version = "0.1.36.1"
+description = "A versatile and easy to understand Diffusion library"
+readme = "README.md"
+authors = [
+    { name="Ashish Kumar Singh", email="ashishkmr472@gmail.com" }
+]
+dependencies = [
+    "flax>=0.8.4",
+    "jax>=0.4.28",
+    "optax>=0.2.2",
+    "orbax",
+    "numpy",
+    "clu",
+    "einops",
+    "tqdm",
+    "grain",
+    "termcolor",
+    "augmax",
+    "albumentations",
+    "rich",
+    "python-dotenv",
+]
+license = "MIT"
+
+[tool.setuptools]
+packages = ["flaxdiff"]
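The new pyproject.toml replaces the deleted setup.py as the build configuration and is the source of the metadata changes visible in both PKG-INFO hunks above. A minimal sketch of reading that metadata back with the standard library, assuming flaxdiff 0.1.36.1 is installed in the active environment:

```python
# Minimal sketch, assuming flaxdiff 0.1.36.1 is installed: read back the
# metadata that setuptools generates from the new pyproject.toml.
from importlib.metadata import metadata, requires

meta = metadata("flaxdiff")
print(meta["Name"], meta["Version"])  # flaxdiff 0.1.36.1
print(requires("flaxdiff"))           # ['flax>=0.8.4', 'jax>=0.4.28', ...]
```

Note that `packages = ["flaxdiff"]` is an explicit list, so setuptools packages only the top-level module directory; that is consistent with the data, models, samplers, schedulers, and trainer submodules all disappearing from this sdist.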
flaxdiff-0.1.36/flaxdiff/data/__init__.py

@@ -1 +0,0 @@
-from .online_loader import OnlineStreamingDataLoader
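This one-line module re-exported the streaming loader, so the import below resolved in 0.1.36 but no longer does against the 0.1.36.1 sdist:

```python
# Worked in flaxdiff 0.1.36, where flaxdiff/data/__init__.py re-exported the
# class from .online_loader; the module is absent from the 0.1.36.1 sdist.
from flaxdiff.data import OnlineStreamingDataLoader
```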
@@ -1,71 +0,0 @@
|
|
1
|
-
from .sources.tfds import data_source_tfds, tfds_augmenters
|
2
|
-
from .sources.gcs import data_source_gcs, data_source_combined_gcs, gcs_augmenters
|
3
|
-
|
4
|
-
# Configure the following for your datasets
|
5
|
-
datasetMap = {
|
6
|
-
"oxford_flowers102": {
|
7
|
-
"source": data_source_tfds("oxford_flowers102", use_tf=False),
|
8
|
-
"augmenter": tfds_augmenters,
|
9
|
-
},
|
10
|
-
"cc12m": {
|
11
|
-
"source": data_source_gcs('arrayrecord2/cc12m'),
|
12
|
-
"augmenter": gcs_augmenters,
|
13
|
-
},
|
14
|
-
"laiona_coco": {
|
15
|
-
"source": data_source_gcs('arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
|
16
|
-
"augmenter": gcs_augmenters,
|
17
|
-
},
|
18
|
-
"aesthetic_coyo": {
|
19
|
-
"source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
|
20
|
-
"augmenter": gcs_augmenters,
|
21
|
-
},
|
22
|
-
"combined_aesthetic": {
|
23
|
-
"source": data_source_combined_gcs([
|
24
|
-
'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
|
25
|
-
'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
|
26
|
-
'arrayrecord2/cc12m',
|
27
|
-
'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
|
28
|
-
]),
|
29
|
-
"augmenter": gcs_augmenters,
|
30
|
-
},
|
31
|
-
"laiona_coco_coyo": {
|
32
|
-
"source": data_source_combined_gcs([
|
33
|
-
'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
|
34
|
-
'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
|
35
|
-
'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
|
36
|
-
]),
|
37
|
-
"augmenter": gcs_augmenters,
|
38
|
-
},
|
39
|
-
"combined_30m": {
|
40
|
-
"source": data_source_combined_gcs([
|
41
|
-
'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
|
42
|
-
'arrayrecord2/cc12m',
|
43
|
-
'arrayrecord2/aestheticCoyo_0.26_clip_5.5aesthetic_256plus',
|
44
|
-
"arrayrecord2/playground+leonardo_x4+cc3m.parquet",
|
45
|
-
]),
|
46
|
-
"augmenter": gcs_augmenters,
|
47
|
-
}
|
48
|
-
}
|
49
|
-
|
50
|
-
onlineDatasetMap = {
|
51
|
-
"combined_online": {
|
52
|
-
"source": [
|
53
|
-
# "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017.parquet"
|
54
|
-
# "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
|
55
|
-
# "dclure/laion-aesthetics-12m-umap",
|
56
|
-
"gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017",
|
57
|
-
"gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M",
|
58
|
-
"gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
|
59
|
-
"gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
|
60
|
-
"gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
|
61
|
-
"gs://flaxdiff-datasets-regional/datasets/cc12m",
|
62
|
-
"gs://flaxdiff-datasets-regional/datasets/playground-liked",
|
63
|
-
"gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
|
64
|
-
"gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
|
65
|
-
"gs://flaxdiff-datasets-regional/datasets/cc3m",
|
66
|
-
"gs://flaxdiff-datasets-regional/datasets/cc3m",
|
67
|
-
"gs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_37M",
|
68
|
-
# "gs://flaxdiff-datasets-regional/datasets/laiion400m-185M"
|
69
|
-
]
|
70
|
-
}
|
71
|
-
}
|
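These maps were pure configuration; the removed datasets.py (next hunk) resolved entries by name and called the `source` and `augmenter` factories. A minimal sketch of that lookup pattern under 0.1.36, with the mount path taken from the defaults in datasets.py:

```python
# Sketch of the 0.1.36 lookup pattern (mirrors the removed datasets.py);
# assumes the GCS bucket with the ArrayRecord data is mounted locally.
import jax
from flaxdiff.data.dataset_map import datasetMap

entry = datasetMap["cc12m"]
data_source = entry["source"]("/mnt/gcs_mount/arrayrecord2/cc12m/")
augmenter = entry["augmenter"](256, jax.image.ResizeMethod.LANCZOS3)
```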
flaxdiff-0.1.36/flaxdiff/data/datasets.py

@@ -1,169 +0,0 @@
-import jax.numpy as jnp
-import grain.python as pygrain
-from typing import Dict
-import numpy as np
-import jax
-from flaxdiff.utils import convert_to_global_tree, AutoTextTokenizer
-from .dataset_map import datasetMap, onlineDatasetMap
-import traceback
-from .online_loader import OnlineStreamingDataLoader
-import queue
-from jax.sharding import Mesh
-import threading
-
-def batch_mesh_map(mesh):
-    class augmenters(pygrain.MapTransform):
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-
-        def map(self, batch) -> Dict[str, jnp.array]:
-            return convert_to_global_tree(mesh, batch)
-    return augmenters
-
-def get_dataset_grain(
-    data_name="cc12m",
-    batch_size=64,
-    image_scale=256,
-    count=None,
-    num_epochs=None,
-    method=jax.image.ResizeMethod.LANCZOS3,
-    worker_count=32,
-    read_thread_count=64,
-    read_buffer_size=50,
-    worker_buffer_size=20,
-    seed=0,
-    dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
-):
-    dataset = datasetMap[data_name]
-    data_source = dataset["source"](dataset_source)
-    augmenter = dataset["augmenter"](image_scale, method)
-
-    local_batch_size = batch_size // jax.process_count()
-
-    sampler = pygrain.IndexSampler(
-        num_records=len(data_source) if count is None else count,
-        shuffle=True,
-        seed=seed,
-        num_epochs=num_epochs,
-        shard_options=pygrain.ShardByJaxProcess(),
-    )
-
-    def get_trainset():
-        transformations = [
-            augmenter(),
-            pygrain.Batch(local_batch_size, drop_remainder=True),
-        ]
-
-        # if mesh != None:
-        #     transformations += [batch_mesh_map(mesh)]
-
-        loader = pygrain.DataLoader(
-            data_source=data_source,
-            sampler=sampler,
-            operations=transformations,
-            worker_count=worker_count,
-            read_options=pygrain.ReadOptions(
-                read_thread_count, read_buffer_size
-            ),
-            worker_buffer_size=worker_buffer_size,
-        )
-        return loader
-
-
-    return {
-        "train": get_trainset,
-        "train_len": len(data_source),
-        "local_batch_size": local_batch_size,
-        "global_batch_size": batch_size,
-        # "null_labels": null_labels,
-        # "null_labels_full": null_labels_full,
-        # "model": model,
-        # "tokenizer": tokenizer,
-    }
-
-def generate_collate_fn():
-    auto_tokenize = AutoTextTokenizer(tensor_type="np")
-    def default_collate(batch):
-        try:
-            # urls = [sample["url"] for sample in batch]
-            captions = [sample["caption"] for sample in batch]
-            results = auto_tokenize(captions)
-            images = np.stack([sample["image"] for sample in batch], axis=0)
-            return {
-                "image": images,
-                "input_ids": results['input_ids'],
-                "attention_mask": results['attention_mask'],
-            }
-        except Exception as e:
-            print("Error in collate function", e, [sample["image"].shape for sample in batch])
-            traceback.print_exc()
-
-    return default_collate
-
-def get_dataset_online(
-    data_name="combined_online",
-    batch_size=64,
-    image_scale=256,
-    count=None,
-    num_epochs=None,
-    method=jax.image.ResizeMethod.LANCZOS3,
-    worker_count=32,
-    read_thread_count=64,
-    read_buffer_size=50,
-    worker_buffer_size=20,
-    seed=0,
-    dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
-):
-    local_batch_size = batch_size // jax.process_count()
-
-    sources = onlineDatasetMap[data_name]["source"]
-    dataloader = OnlineStreamingDataLoader(
-        sources,
-        batch_size=local_batch_size,
-        num_workers=worker_count,
-        num_threads=read_thread_count,
-        image_shape=(image_scale, image_scale),
-        global_process_count=jax.process_count(),
-        global_process_index=jax.process_index(),
-        prefetch=worker_buffer_size,
-        collate_fn=generate_collate_fn(),
-        default_split="train",
-    )
-
-    def get_trainset(mesh: Mesh = None):
-        if mesh != None:
-            class dataLoaderWithMesh:
-                def __init__(self, dataloader, mesh):
-                    self.dataloader = dataloader
-                    self.mesh = mesh
-                    self.tmp_queue = queue.Queue(worker_buffer_size)
-                    def batch_loader():
-                        for batch in self.dataloader:
-                            try:
-                                self.tmp_queue.put(convert_to_global_tree(mesh, batch))
-                            except Exception as e:
-                                print("Error processing batch", e)
-                    self.loader_thread = threading.Thread(target=batch_loader)
-                    self.loader_thread.start()
-
-                def __iter__(self):
-                    return self
-
-                def __next__(self):
-                    return self.tmp_queue.get()
-
-            dataloader_with_mesh = dataLoaderWithMesh(dataloader, mesh)
-
-            return dataloader_with_mesh
-        return dataloader
-
-    return {
-        "train": get_trainset,
-        "train_len": len(dataloader) * jax.process_count(),
-        "local_batch_size": local_batch_size,
-        "global_batch_size": batch_size,
-        # "null_labels": null_labels,
-        # "null_labels_full": null_labels_full,
-        # "model": model,
-        # "tokenizer": tokenizer,
-    }
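For reference, a usage sketch of the removed entry points as they behaved in 0.1.36; the dataset name and batch size are illustrative, and the ArrayRecord data must be reachable at the configured source:

```python
# Usage sketch for flaxdiff 0.1.36 (these helpers are gone in 0.1.36.1).
# data["train"] is a factory: calling it builds the pygrain.DataLoader.
from flaxdiff.data.datasets import get_dataset_grain

data = get_dataset_grain(data_name="cc12m", batch_size=64, image_scale=256)
print(data["train_len"], data["local_batch_size"], data["global_batch_size"])

loader = data["train"]()
batch = next(iter(loader))  # dict batch produced by the dataset's augmenter
```

The online variant works the same way, except that `get_dataset_online(...)["train"]` optionally takes a `jax.sharding.Mesh`, in which case it wraps the loader in a background thread that converts each batch into a globally sharded tree via `convert_to_global_tree`.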