flaxdiff 0.1.35.6__py3-none-any.whl → 0.1.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataset_map.py +71 -0
- flaxdiff/data/datasets.py +169 -0
- flaxdiff/data/online_loader.py +69 -42
- flaxdiff/samplers/common.py +72 -20
- flaxdiff/samplers/ddim.py +5 -5
- flaxdiff/samplers/ddpm.py +5 -11
- flaxdiff/samplers/euler.py +7 -10
- flaxdiff/samplers/heun_sampler.py +3 -4
- flaxdiff/samplers/multistep_dpm.py +2 -3
- flaxdiff/samplers/rk4_sampler.py +9 -9
- flaxdiff/trainer/autoencoder_trainer.py +1 -1
- flaxdiff/trainer/diffusion_trainer.py +124 -32
- flaxdiff/trainer/simple_trainer.py +187 -91
- flaxdiff/trainer/video_diffusion_trainer.py +62 -0
- flaxdiff/utils.py +105 -2
- {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.dist-info}/METADATA +11 -5
- {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.dist-info}/RECORD +19 -16
- {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.dist-info}/WHEEL +1 -1
- {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,71 @@
|
|
1
|
+
from .sources.tfds import data_source_tfds, tfds_augmenters
|
2
|
+
from .sources.gcs import data_source_gcs, data_source_combined_gcs, gcs_augmenters
|
3
|
+
|
4
|
+
# Configure the following for your datasets
|
5
|
+
datasetMap = {
|
6
|
+
"oxford_flowers102": {
|
7
|
+
"source": data_source_tfds("oxford_flowers102", use_tf=False),
|
8
|
+
"augmenter": tfds_augmenters,
|
9
|
+
},
|
10
|
+
"cc12m": {
|
11
|
+
"source": data_source_gcs('arrayrecord2/cc12m'),
|
12
|
+
"augmenter": gcs_augmenters,
|
13
|
+
},
|
14
|
+
"laiona_coco": {
|
15
|
+
"source": data_source_gcs('arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
|
16
|
+
"augmenter": gcs_augmenters,
|
17
|
+
},
|
18
|
+
"aesthetic_coyo": {
|
19
|
+
"source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
|
20
|
+
"augmenter": gcs_augmenters,
|
21
|
+
},
|
22
|
+
"combined_aesthetic": {
|
23
|
+
"source": data_source_combined_gcs([
|
24
|
+
'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
|
25
|
+
'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
|
26
|
+
'arrayrecord2/cc12m',
|
27
|
+
'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
|
28
|
+
]),
|
29
|
+
"augmenter": gcs_augmenters,
|
30
|
+
},
|
31
|
+
"laiona_coco_coyo": {
|
32
|
+
"source": data_source_combined_gcs([
|
33
|
+
'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
|
34
|
+
'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
|
35
|
+
'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
|
36
|
+
]),
|
37
|
+
"augmenter": gcs_augmenters,
|
38
|
+
},
|
39
|
+
"combined_30m": {
|
40
|
+
"source": data_source_combined_gcs([
|
41
|
+
'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
|
42
|
+
'arrayrecord2/cc12m',
|
43
|
+
'arrayrecord2/aestheticCoyo_0.26_clip_5.5aesthetic_256plus',
|
44
|
+
"arrayrecord2/playground+leonardo_x4+cc3m.parquet",
|
45
|
+
]),
|
46
|
+
"augmenter": gcs_augmenters,
|
47
|
+
}
|
48
|
+
}
|
49
|
+
|
50
|
+
onlineDatasetMap = {
|
51
|
+
"combined_online": {
|
52
|
+
"source": [
|
53
|
+
# "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017.parquet"
|
54
|
+
# "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
|
55
|
+
# "dclure/laion-aesthetics-12m-umap",
|
56
|
+
"gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017",
|
57
|
+
"gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M",
|
58
|
+
"gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
|
59
|
+
"gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
|
60
|
+
"gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
|
61
|
+
"gs://flaxdiff-datasets-regional/datasets/cc12m",
|
62
|
+
"gs://flaxdiff-datasets-regional/datasets/playground-liked",
|
63
|
+
"gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
|
64
|
+
"gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
|
65
|
+
"gs://flaxdiff-datasets-regional/datasets/cc3m",
|
66
|
+
"gs://flaxdiff-datasets-regional/datasets/cc3m",
|
67
|
+
"gs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_37M",
|
68
|
+
# "gs://flaxdiff-datasets-regional/datasets/laiion400m-185M"
|
69
|
+
]
|
70
|
+
}
|
71
|
+
}
|
@@ -0,0 +1,169 @@
|
|
1
|
+
import jax.numpy as jnp
|
2
|
+
import grain.python as pygrain
|
3
|
+
from typing import Dict
|
4
|
+
import numpy as np
|
5
|
+
import jax
|
6
|
+
from flaxdiff.utils import convert_to_global_tree, AutoTextTokenizer
|
7
|
+
from .dataset_map import datasetMap, onlineDatasetMap
|
8
|
+
import traceback
|
9
|
+
from .online_loader import OnlineStreamingDataLoader
|
10
|
+
import queue
|
11
|
+
from jax.sharding import Mesh
|
12
|
+
import threading
|
13
|
+
|
14
|
+
def batch_mesh_map(mesh):
    """Build a grain ``MapTransform`` class bound to ``mesh``.

    The returned object is a class, not an instance. Each instance's ``map``
    passes the incoming batch through ``convert_to_global_tree(mesh, batch)``.

    Args:
        mesh: Forwarded unchanged as the first argument of
            ``convert_to_global_tree``.

    Returns:
        A ``pygrain.MapTransform`` subclass closed over ``mesh``.
    """

    class augmenters(pygrain.MapTransform):
        """Map transform that converts each batch using the captured mesh."""

        # No explicit __init__: construction is inherited from the base class.

        def map(self, batch) -> Dict[str, jnp.array]:
            converted = convert_to_global_tree(mesh, batch)
            return converted

    return augmenters
|
22
|
+
|
23
|
+
def get_dataset_grain(
    data_name="cc12m",
    batch_size=64,
    image_scale=256,
    count=None,
    num_epochs=None,
    method=jax.image.ResizeMethod.LANCZOS3,
    worker_count=32,
    read_thread_count=64,
    read_buffer_size=50,
    worker_buffer_size=20,
    seed=0,
    dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
):
    """Prepare a grain-backed training dataset described by ``datasetMap``.

    Args:
        data_name: Key into ``datasetMap``.
        batch_size: Global batch size; it is floor-divided by
            ``jax.process_count()`` to get the per-process batch size.
        image_scale: First argument passed to the entry's ``"augmenter"``.
        count: Number of records handed to the index sampler. When ``None``
            the length of the data source is used.
        num_epochs: Forwarded to ``pygrain.IndexSampler``.
        method: Second argument passed to the entry's ``"augmenter"``.
        worker_count: Forwarded to ``pygrain.DataLoader``.
        read_thread_count: First positional argument of ``pygrain.ReadOptions``.
        read_buffer_size: Second positional argument of ``pygrain.ReadOptions``.
        worker_buffer_size: Forwarded to ``pygrain.DataLoader``.
        seed: Shuffle seed forwarded to ``pygrain.IndexSampler``.
        dataset_source: Argument passed to the entry's ``"source"`` callable.

    Returns:
        A dict with ``"train"`` (a zero-argument callable that builds a new
        ``pygrain.DataLoader``), ``"train_len"`` (``len`` of the data source),
        ``"local_batch_size"`` and ``"global_batch_size"``.
    """
    entry = datasetMap[data_name]
    data_source = entry["source"](dataset_source)
    augmenter = entry["augmenter"](image_scale, method)

    local_batch_size = batch_size // jax.process_count()

    if count is None:
        num_records = len(data_source)
    else:
        num_records = count

    sampler = pygrain.IndexSampler(
        num_records=num_records,
        shuffle=True,
        seed=seed,
        num_epochs=num_epochs,
        shard_options=pygrain.ShardByJaxProcess(),
    )

    def get_trainset():
        """Create a new DataLoader over the shared source and sampler."""
        read_options = pygrain.ReadOptions(read_thread_count, read_buffer_size)
        operations = [
            augmenter(),
            pygrain.Batch(local_batch_size, drop_remainder=True),
        ]
        return pygrain.DataLoader(
            data_source=data_source,
            sampler=sampler,
            operations=operations,
            worker_count=worker_count,
            read_options=read_options,
            worker_buffer_size=worker_buffer_size,
        )

    # NOTE(review): "train_len" is the full source length even when ``count``
    # limits the sampler -- confirm that callers expect this.
    result = {
        "train": get_trainset,
        "train_len": len(data_source),
        "local_batch_size": local_batch_size,
        "global_batch_size": batch_size,
    }
    return result
|
83
|
+
|
84
|
+
def generate_collate_fn():
    """Create a collate function that tokenizes captions and stacks images.

    A single ``AutoTextTokenizer(tensor_type="np")`` is created here and
    reused by every call of the returned function.

    Returns:
        ``default_collate(batch)``. It reads ``sample["caption"]`` and
        ``sample["image"]`` from every sample and returns a dict with the keys
        ``"image"``, ``"input_ids"`` and ``"attention_mask"``.
    """
    auto_tokenize = AutoTextTokenizer(tensor_type="np")

    def default_collate(batch):
        try:
            texts = [item["caption"] for item in batch]
            tokenized = auto_tokenize(texts)
            stacked = np.stack([item["image"] for item in batch], axis=0)
            collated = {
                "image": stacked,
                "input_ids": tokenized['input_ids'],
                "attention_mask": tokenized['attention_mask'],
            }
            return collated
        except Exception as e:
            # NOTE(review): the failure is only reported; the function then
            # falls through and returns None -- confirm callers handle that.
            shapes = [item["image"].shape for item in batch]
            print("Error in collate function", e, shapes)
            traceback.print_exc()

    return default_collate
|
102
|
+
|
103
|
+
def get_dataset_online(
    data_name="combined_online",
    batch_size=64,
    image_scale=256,
    count=None,
    num_epochs=None,
    method=jax.image.ResizeMethod.LANCZOS3,
    worker_count=32,
    read_thread_count=64,
    read_buffer_size=50,
    worker_buffer_size=20,
    seed=0,
    dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
):
    """Prepare a streaming training dataset described by ``onlineDatasetMap``.

    One ``OnlineStreamingDataLoader`` is created immediately and shared by
    every call of the returned ``"train"`` factory.

    Args:
        data_name: Key into ``onlineDatasetMap``; its ``"source"`` entry is
            passed to the loader.
        batch_size: Global batch size; it is floor-divided by
            ``jax.process_count()`` to get the per-process batch size.
        image_scale: Side length; the loader receives
            ``image_shape=(image_scale, image_scale)``.
        worker_count: Passed to the loader as ``num_workers``.
        read_thread_count: Passed to the loader as ``num_threads``.
        worker_buffer_size: Passed to the loader as ``prefetch`` and used as
            the capacity of the mesh-conversion queue.
        count, num_epochs, method, read_buffer_size, seed, dataset_source:
            Not read by this function; they mirror the signature of
            ``get_dataset_grain``.

    Returns:
        A dict with ``"train"`` (a callable taking an optional ``mesh``),
        ``"train_len"`` (``len(dataloader) * jax.process_count()``),
        ``"local_batch_size"`` and ``"global_batch_size"``.
    """
    local_batch_size = batch_size // jax.process_count()

    sources = onlineDatasetMap[data_name]["source"]
    dataloader = OnlineStreamingDataLoader(
        sources,
        batch_size=local_batch_size,
        num_workers=worker_count,
        num_threads=read_thread_count,
        image_shape=(image_scale, image_scale),
        global_process_count=jax.process_count(),
        global_process_index=jax.process_index(),
        prefetch=worker_buffer_size,
        collate_fn=generate_collate_fn(),
        default_split="train",
    )

    def get_trainset(mesh: Mesh = None):
        """Return the loader, wrapped for mesh conversion when ``mesh`` is given."""
        # Identity check: the question is "was a mesh supplied", so the
        # argument's own equality operators should not be consulted.
        if mesh is None:
            return dataloader

        class dataLoaderWithMesh:
            """Iterator yielding batches converted by ``convert_to_global_tree``."""

            def __init__(self, dataloader, mesh):
                self.dataloader = dataloader
                self.mesh = mesh
                # Bounded queue: the producer blocks once this many converted
                # batches are waiting to be consumed.
                self.tmp_queue = queue.Queue(worker_buffer_size)

                def batch_loader():
                    for batch in self.dataloader:
                        try:
                            self.tmp_queue.put(convert_to_global_tree(mesh, batch))
                        except Exception as e:
                            # Best effort: a batch that fails is reported and
                            # skipped.
                            print("Error processing batch", e)

                # NOTE(review): this thread is not a daemon and blocks on a
                # full queue -- confirm this is acceptable at shutdown.
                self.loader_thread = threading.Thread(target=batch_loader)
                self.loader_thread.start()

            def __iter__(self):
                return self

            def __next__(self):
                # Blocks until the background thread has produced a batch.
                return self.tmp_queue.get()

        dataloader_with_mesh = dataLoaderWithMesh(dataloader, mesh)

        return dataloader_with_mesh

    return {
        "train": get_trainset,
        "train_len": len(dataloader) * jax.process_count(),
        "local_batch_size": local_batch_size,
        "global_batch_size": batch_size,
        # "null_labels": null_labels,
        # "null_labels_full": null_labels_full,
        # "model": model,
        # "tokenizer": tokenizer,
    }
|
flaxdiff/data/online_loader.py
CHANGED
@@ -45,36 +45,43 @@ def fetch_single_image(image_url, timeout=None, retries=0):
|
|
45
45
|
|
46
46
|
|
47
47
|
def default_image_processor(
|
48
|
-
image, image_shape,
|
48
|
+
image, image_shape,
|
49
49
|
min_image_shape=(128, 128),
|
50
50
|
upscale_interpolation=cv2.INTER_CUBIC,
|
51
51
|
downscale_interpolation=cv2.INTER_AREA,
|
52
52
|
):
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
image,
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
53
|
+
try:
|
54
|
+
image = np.array(image)
|
55
|
+
if len(image.shape) != 3 or image.shape[2] != 3:
|
56
|
+
return None, 0, 0
|
57
|
+
original_height, original_width = image.shape[:2]
|
58
|
+
# check if the image is too small
|
59
|
+
if min(original_height, original_width) < min(min_image_shape):
|
60
|
+
return None, original_height, original_width
|
61
|
+
# check if wrong aspect ratio
|
62
|
+
if max(original_height, original_width) / min(original_height, original_width) > 2.4:
|
63
|
+
return None, original_height, original_width
|
64
|
+
# check if the variance is too low
|
65
|
+
if np.std(image) < 1e-5:
|
66
|
+
return None, original_height, original_width
|
67
|
+
# image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
68
|
+
downscale = max(original_width, original_height) > max(image_shape)
|
69
|
+
interpolation = downscale_interpolation if downscale else upscale_interpolation
|
70
|
+
|
71
|
+
image = A.longest_max_size(image, max(
|
72
|
+
image_shape), interpolation=interpolation)
|
73
|
+
image = A.pad(
|
74
|
+
image,
|
75
|
+
min_height=image_shape[0],
|
76
|
+
min_width=image_shape[1],
|
77
|
+
border_mode=cv2.BORDER_CONSTANT,
|
78
|
+
value=[255, 255, 255],
|
79
|
+
)
|
80
|
+
return image, original_height, original_width
|
81
|
+
except Exception as e:
|
82
|
+
# print("Error processing image", e, image_shape, interpolation)
|
83
|
+
# traceback.print_exc()
|
84
|
+
return None, 0, 0
|
78
85
|
|
79
86
|
|
80
87
|
def map_sample(
|
@@ -120,14 +127,36 @@ def map_sample(
|
|
120
127
|
# })
|
121
128
|
pass
|
122
129
|
|
123
|
-
|
124
130
|
def default_feature_extractor(sample):
|
131
|
+
url = None
|
132
|
+
if "url" in sample:
|
133
|
+
url = sample["url"]
|
134
|
+
elif "URL" in sample:
|
135
|
+
url = sample["URL"]
|
136
|
+
elif "image_url" in sample:
|
137
|
+
url = sample["image_url"]
|
138
|
+
else:
|
139
|
+
print("No url found in sample, skipping", sample.keys())
|
140
|
+
|
141
|
+
caption = None
|
142
|
+
if "caption" in sample:
|
143
|
+
caption = sample["caption"]
|
144
|
+
elif "CAPTION" in sample:
|
145
|
+
caption = sample["CAPTION"]
|
146
|
+
elif "txt" in sample:
|
147
|
+
caption = sample["txt"]
|
148
|
+
elif "TEXT" in sample:
|
149
|
+
caption = sample["TEXT"]
|
150
|
+
elif "text" in sample:
|
151
|
+
caption = sample["text"]
|
152
|
+
else:
|
153
|
+
print("No caption found in sample, skipping", sample.keys())
|
154
|
+
|
125
155
|
return {
|
126
|
-
"url":
|
127
|
-
"caption":
|
156
|
+
"url": url,
|
157
|
+
"caption": caption,
|
128
158
|
}
|
129
|
-
|
130
|
-
|
159
|
+
|
131
160
|
def map_batch(
|
132
161
|
batch, num_threads=256, image_shape=(256, 256),
|
133
162
|
min_image_shape=(128, 128),
|
@@ -301,15 +330,13 @@ class OnlineStreamingDataLoader():
|
|
301
330
|
self.dataset = dataset.shard(
|
302
331
|
num_shards=global_process_count, index=global_process_index)
|
303
332
|
print(f"Dataset length: {len(dataset)}")
|
304
|
-
self.iterator = ImageBatchIterator(
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
|
311
|
-
feature_extractor=feature_extractor
|
312
|
-
)
|
333
|
+
self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
|
334
|
+
min_image_shape=min_image_shape,
|
335
|
+
num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
|
336
|
+
timeout=timeout, retries=retries, image_processor=image_processor,
|
337
|
+
upscale_interpolation=upscale_interpolation,
|
338
|
+
downscale_interpolation=downscale_interpolation,
|
339
|
+
feature_extractor=feature_extractor)
|
313
340
|
self.batch_size = batch_size
|
314
341
|
|
315
342
|
# Launch a thread to load batches in the background
|
@@ -320,7 +347,7 @@ class OnlineStreamingDataLoader():
|
|
320
347
|
try:
|
321
348
|
self.batch_queue.put(collate_fn(batch))
|
322
349
|
except Exception as e:
|
323
|
-
print("Error
|
350
|
+
print("Error collating batch", e)
|
324
351
|
|
325
352
|
self.loader_thread = threading.Thread(target=batch_loader)
|
326
353
|
self.loader_thread.start()
|
@@ -333,4 +360,4 @@ class OnlineStreamingDataLoader():
|
|
333
360
|
# return self.collate_fn(next(self.iterator))
|
334
361
|
|
335
362
|
def __len__(self):
|
336
|
-
return len(self.dataset)
|
363
|
+
return len(self.dataset)
|
flaxdiff/samplers/common.py
CHANGED
@@ -15,36 +15,76 @@ class DiffusionSampler():
|
|
15
15
|
|
16
16
|
def __init__(self, model:nn.Module, params:dict,
|
17
17
|
noise_schedule:NoiseScheduler,
|
18
|
-
model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform()
|
18
|
+
model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(),
|
19
|
+
guidance_scale:float = 0.0,
|
20
|
+
null_labels_seq:jax.Array=None,
|
21
|
+
autoencoder=None,
|
22
|
+
image_size=256,
|
23
|
+
autoenc_scale_reduction=8,
|
24
|
+
autoenc_latent_channels=4,
|
25
|
+
):
|
19
26
|
self.model = model
|
20
27
|
self.noise_schedule = noise_schedule
|
21
28
|
self.params = params
|
22
29
|
self.model_output_transform = model_output_transform
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t))
|
29
|
-
x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
|
30
|
-
return x_0, eps, model_output
|
30
|
+
self.guidance_scale = guidance_scale
|
31
|
+
self.image_size = image_size
|
32
|
+
self.autoenc_scale_reduction = autoenc_scale_reduction
|
33
|
+
self.autoencoder = autoencoder
|
34
|
+
self.autoenc_latent_channels = autoenc_latent_channels
|
31
35
|
|
36
|
+
if self.guidance_scale > 0:
|
37
|
+
# Classifier free guidance
|
38
|
+
assert null_labels_seq is not None, "Null labels sequence is required for classifier-free guidance"
|
39
|
+
print("Using classifier-free guidance")
|
40
|
+
def sample_model(x_t, t, *additional_inputs):
|
41
|
+
# Concatenate unconditional and conditional inputs
|
42
|
+
x_t_cat = jnp.concatenate([x_t] * 2, axis=0)
|
43
|
+
t_cat = jnp.concatenate([t] * 2, axis=0)
|
44
|
+
rates_cat = self.noise_schedule.get_rates(t_cat)
|
45
|
+
c_in_cat = self.model_output_transform.get_input_scale(rates_cat)
|
46
|
+
|
47
|
+
text_labels_seq, = additional_inputs
|
48
|
+
text_labels_seq = jnp.concatenate([text_labels_seq, jnp.broadcast_to(null_labels_seq, text_labels_seq.shape)], axis=0)
|
49
|
+
model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat), text_labels_seq)
|
50
|
+
# Split model output into unconditional and conditional parts
|
51
|
+
model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0)
|
52
|
+
model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
|
53
|
+
|
54
|
+
x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
|
55
|
+
return x_0, eps, model_output
|
56
|
+
else:
|
57
|
+
# Unconditional sampling
|
58
|
+
def sample_model(x_t, t, *additional_inputs):
|
59
|
+
rates = self.noise_schedule.get_rates(t)
|
60
|
+
c_in = self.model_output_transform.get_input_scale(rates)
|
61
|
+
model_output = self.model.apply(self.params, *self.noise_schedule.transform_inputs(x_t * c_in, t), *additional_inputs)
|
62
|
+
x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
|
63
|
+
return x_0, eps, model_output
|
64
|
+
|
65
|
+
# if jax.device_count() > 1:
|
66
|
+
# mesh = jax.sharding.Mesh(jax.devices(), 'data')
|
67
|
+
# sample_model = shard_map(sample_model, mesh=mesh, in_specs=(P('data'), P('data'), P('data')),
|
68
|
+
# out_specs=(P('data'), P('data'), P('data')))
|
69
|
+
sample_model = jax.jit(sample_model)
|
32
70
|
self.sample_model = sample_model
|
33
71
|
|
34
72
|
# Used to sample from the diffusion model
|
35
|
-
def sample_step(self, current_samples:jnp.ndarray, current_step, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
|
73
|
+
def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
|
36
74
|
# First clip the noisy images
|
37
|
-
# pred_images = clip_images(pred_images)
|
38
75
|
step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
|
39
76
|
current_step = step_ones * current_step
|
40
77
|
next_step = step_ones * next_step
|
41
|
-
pred_images, pred_noise, _ = self.sample_model(current_samples, current_step)
|
78
|
+
pred_images, pred_noise, _ = self.sample_model(current_samples, current_step, *model_conditioning_inputs)
|
42
79
|
# plotImages(pred_images)
|
80
|
+
# pred_images = clip_images(pred_images)
|
43
81
|
new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
|
44
|
-
pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state
|
82
|
+
pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state,
|
83
|
+
model_conditioning_inputs=model_conditioning_inputs
|
84
|
+
)
|
45
85
|
return new_samples, state
|
46
86
|
|
47
|
-
def take_next_step(self, current_samples, reconstructed_samples,
|
87
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
48
88
|
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
49
89
|
# estimate the q(x_{t-1} | x_t, x_0).
|
50
90
|
# pred_images is x_0, noisy_images is x_t, steps is t
|
@@ -62,11 +102,16 @@ class DiffusionSampler():
|
|
62
102
|
steps = jnp.linspace(end_step, start_step, diffusion_steps, dtype=jnp.int16)[::-1]
|
63
103
|
return steps
|
64
104
|
|
65
|
-
def get_initial_samples(self, num_images, rngs:jax.random.PRNGKey, start_step
|
105
|
+
def get_initial_samples(self, num_images, rngs:jax.random.PRNGKey, start_step):
|
66
106
|
start_step = self.scale_steps(start_step)
|
67
107
|
alpha_n, sigma_n = self.noise_schedule.get_rates(start_step)
|
68
108
|
variance = jnp.sqrt(alpha_n ** 2 + sigma_n ** 2)
|
69
|
-
|
109
|
+
image_size = self.image_size
|
110
|
+
image_channels = 3
|
111
|
+
if self.autoencoder is not None:
|
112
|
+
image_size = image_size // self.autoenc_scale_reduction
|
113
|
+
image_channels = self.autoenc_latent_channels
|
114
|
+
return jax.random.normal(rngs, (num_images, image_size, image_size, image_channels)) * variance
|
70
115
|
|
71
116
|
def generate_images(self,
|
72
117
|
num_images=16,
|
@@ -75,18 +120,23 @@ class DiffusionSampler():
|
|
75
120
|
end_step:int = 0,
|
76
121
|
steps_override=None,
|
77
122
|
priors=None,
|
78
|
-
rngstate:RandomMarkovState=RandomMarkovState(jax.random.PRNGKey(42))
|
123
|
+
rngstate:RandomMarkovState=RandomMarkovState(jax.random.PRNGKey(42)),
|
124
|
+
model_conditioning_inputs:tuple=()
|
125
|
+
) -> jnp.ndarray:
|
79
126
|
if priors is None:
|
80
127
|
rngstate, newrngs = rngstate.get_random_key()
|
81
128
|
samples = self.get_initial_samples(num_images, newrngs, start_step)
|
82
129
|
else:
|
83
130
|
print("Using priors")
|
131
|
+
if self.autoencoder is not None:
|
132
|
+
priors = self.autoencoder.encode(priors)
|
84
133
|
samples = priors
|
85
134
|
|
86
|
-
@jax.jit
|
135
|
+
# @jax.jit
|
87
136
|
def sample_step(state:RandomMarkovState, samples, current_step, next_step):
|
88
137
|
samples, state = self.sample_step(current_samples=samples,
|
89
138
|
current_step=current_step,
|
139
|
+
model_conditioning_inputs=model_conditioning_inputs,
|
90
140
|
state=state, next_step=next_step)
|
91
141
|
return samples, state
|
92
142
|
|
@@ -108,6 +158,8 @@ class DiffusionSampler():
|
|
108
158
|
else:
|
109
159
|
# print("last step")
|
110
160
|
step_ones = jnp.ones((num_images, ), dtype=jnp.int32)
|
111
|
-
samples, _, _ = self.sample_model(samples, current_step * step_ones)
|
161
|
+
samples, _, _ = self.sample_model(samples, current_step * step_ones, *model_conditioning_inputs)
|
162
|
+
if self.autoencoder is not None:
|
163
|
+
samples = self.autoencoder.decode(samples)
|
112
164
|
samples = clip_images(samples)
|
113
|
-
return samples
|
165
|
+
return samples
|
flaxdiff/samplers/ddim.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1
1
|
import jax.numpy as jnp
|
2
2
|
from .common import DiffusionSampler
|
3
|
-
from ..utils import MarkovState
|
3
|
+
from ..utils import MarkovState, RandomMarkovState
|
4
4
|
|
5
5
|
class DDIMSampler(DiffusionSampler):
|
6
|
-
def take_next_step(self,
|
7
|
-
|
8
|
-
pred_noise, current_step, state:MarkovState, next_step=None) -> tuple[jnp.ndarray, MarkovState]:
|
6
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
7
|
+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
9
8
|
next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step)
|
10
|
-
return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state
|
9
|
+
return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state
|
10
|
+
|
flaxdiff/samplers/ddpm.py
CHANGED
@@ -3,9 +3,8 @@ import jax.numpy as jnp
|
|
3
3
|
from .common import DiffusionSampler
|
4
4
|
from ..utils import MarkovState, RandomMarkovState
|
5
5
|
class DDPMSampler(DiffusionSampler):
|
6
|
-
def take_next_step(self,
|
7
|
-
|
8
|
-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
|
6
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
7
|
+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
9
8
|
mean = self.noise_schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step)
|
10
9
|
variance = self.noise_schedule.get_posterior_variance(steps=current_step)
|
11
10
|
|
@@ -19,9 +18,8 @@ class DDPMSampler(DiffusionSampler):
|
|
19
18
|
return super().generate_images(num_images=num_images, diffusion_steps=diffusion_steps, start_step=start_step, *args, **kwargs)
|
20
19
|
|
21
20
|
class SimpleDDPMSampler(DiffusionSampler):
|
22
|
-
def take_next_step(self,
|
23
|
-
|
24
|
-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
|
21
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
22
|
+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
25
23
|
state, rng = state.get_random_key()
|
26
24
|
noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32)
|
27
25
|
|
@@ -33,11 +31,7 @@ class SimpleDDPMSampler(DiffusionSampler):
|
|
33
31
|
|
34
32
|
noise_ratio_squared = (next_noise_rate ** 2) / (current_noise_rate ** 2)
|
35
33
|
signal_ratio_squared = (current_signal_rate ** 2) / (next_signal_rate ** 2)
|
36
|
-
|
37
|
-
gamma = jnp.sqrt(noise_ratio_squared * betas)
|
34
|
+
gamma = jnp.sqrt(noise_ratio_squared * (1 - signal_ratio_squared))
|
38
35
|
|
39
36
|
next_samples = next_signal_rate * reconstructed_samples + pred_noise_coeff * pred_noise + noise * gamma
|
40
|
-
# pred_noise_coeff = ((next_noise_rate ** 2) * current_signal_rate) / (current_noise_rate * next_signal_rate)
|
41
|
-
# next_samples = (2 - jnp.sqrt(1 - betas)) * current_samples - betas * (pred_noise / current_noise_rate) + noise * gamma#jnp.sqrt(betas)
|
42
|
-
# next_samples = (1 / (jnp.sqrt(1 - betas) + 1.e-24)) * (current_samples - betas * (pred_noise / current_noise_rate)) + noise * gamma
|
43
37
|
return next_samples, state
|
flaxdiff/samplers/euler.py
CHANGED
@@ -5,9 +5,8 @@ from ..utils import RandomMarkovState
|
|
5
5
|
|
6
6
|
class EulerSampler(DiffusionSampler):
|
7
7
|
# Basically a DDIM Sampler but parameterized as an ODE
|
8
|
-
def take_next_step(self,
|
9
|
-
|
10
|
-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
|
8
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
9
|
+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
11
10
|
current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
|
12
11
|
next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
|
13
12
|
|
@@ -22,9 +21,8 @@ class SimplifiedEulerSampler(DiffusionSampler):
|
|
22
21
|
"""
|
23
22
|
This is for networks with forward diffusion of the form x_{t+1} = x_t + sigma_t * epsilon_t
|
24
23
|
"""
|
25
|
-
def take_next_step(self,
|
26
|
-
|
27
|
-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
|
24
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
25
|
+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
28
26
|
_, current_sigma = self.noise_schedule.get_rates(current_step)
|
29
27
|
_, next_sigma = self.noise_schedule.get_rates(next_step)
|
30
28
|
|
@@ -38,9 +36,8 @@ class EulerAncestralSampler(DiffusionSampler):
|
|
38
36
|
"""
|
39
37
|
Similar to EulerSampler but with ancestral sampling
|
40
38
|
"""
|
41
|
-
def take_next_step(self,
|
42
|
-
|
43
|
-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
|
39
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
40
|
+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
44
41
|
current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
|
45
42
|
next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
|
46
43
|
|
@@ -56,4 +53,4 @@ class EulerAncestralSampler(DiffusionSampler):
|
|
56
53
|
dW = jax.random.normal(subkey, current_samples.shape) * sigma_up
|
57
54
|
|
58
55
|
next_samples = current_samples + dx * dt + dW
|
59
|
-
return next_samples, state
|
56
|
+
return next_samples, state
|
@@ -4,9 +4,8 @@ from .common import DiffusionSampler
|
|
4
4
|
from ..utils import RandomMarkovState
|
5
5
|
|
6
6
|
class HeunSampler(DiffusionSampler):
|
7
|
-
def take_next_step(self,
|
8
|
-
|
9
|
-
pred_noise, current_step, state:RandomMarkovState, next_step=None) -> tuple[jnp.ndarray, RandomMarkovState]:
|
7
|
+
def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
|
8
|
+
pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
|
10
9
|
# Get the noise and signal rates for the current and next steps
|
11
10
|
current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
|
12
11
|
next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
|
@@ -18,7 +17,7 @@ class HeunSampler(DiffusionSampler):
|
|
18
17
|
next_samples_0 = current_samples + dx_0 * dt
|
19
18
|
|
20
19
|
# Recompute x_0 and eps at the first estimate to refine the derivative
|
21
|
-
estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step)
|
20
|
+
estimated_x_0, _, _ = self.sample_model(next_samples_0, next_step, *model_conditioning_inputs)
|
22
21
|
|
23
22
|
# Estimate the refined derivative using the midpoint (Heun's method)
|
24
23
|
dx_1 = (next_samples_0 - x_0_coeff * estimated_x_0) / next_sigma
|