flaxdiff 0.1.36.4.tar.gz → 0.1.37.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/PKG-INFO +60 -22
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/README.md +59 -21
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/common.py +20 -13
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/ddim.py +1 -1
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/ddpm.py +2 -2
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/euler.py +3 -3
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/heun_sampler.py +2 -2
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/multistep_dpm.py +1 -1
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/rk4_sampler.py +7 -7
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/__init__.py +1 -1
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/cosine.py +1 -1
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/trainer/diffusion_trainer.py +16 -14
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff.egg-info/PKG-INFO +60 -22
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/pyproject.toml +1 -1
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/data/dataset_map.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/data/datasets.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/data/online_loader.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/data/sources/gcs.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/data/sources/tfds.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/metrics/inception.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/metrics/utils.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/common.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff.egg-info/SOURCES.txt +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/setup.cfg +0 -0
{flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.36.4
+Version: 0.1.37
 Summary: A versatile and easy to understand Diffusion library
 Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
 License-Expression: MIT
@@ -96,7 +96,7 @@ Also, few of the text may be generated with help of github copilot, so please ex
 ### Schedulers
 Implemented in `flaxdiff.schedulers`:
 - **LinearNoiseSchedule** (`flaxdiff.schedulers.LinearNoiseSchedule`): A beta-parameterized discrete scheduler.
-- **…** (entry truncated in source)
+- **CosineNoiseScheduler** (`flaxdiff.schedulers.CosineNoiseScheduler`): A beta-parameterized discrete scheduler.
 - **ExpNoiseSchedule** (`flaxdiff.schedulers.ExpNoiseSchedule`): A beta-parameterized discrete scheduler.
 - **CosineContinuousNoiseScheduler** (`flaxdiff.schedulers.CosineContinuousNoiseScheduler`): A continuous scheduler.
 - **CosineGeneralNoiseScheduler** (`flaxdiff.schedulers.CosineGeneralNoiseScheduler`): A continuous sigma parameterized cosine scheduler.
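The discrete cosine schedule is now exported as `CosineNoiseScheduler`. A minimal usage sketch, assuming only the constructor signature visible in the `flaxdiff/schedulers/cosine.py` hunk further down (`timesteps, beta_start=0.008, beta_end=0.999`):

```python
# Minimal sketch: constructing the renamed discrete cosine scheduler.
from flaxdiff.schedulers import CosineNoiseScheduler

# 1000 diffusion timesteps with the default cosine beta range
cosine_schedule = CosineNoiseScheduler(1000, beta_start=0.008, beta_end=0.999)
```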
@@ -147,43 +147,81 @@ sticking to the versions mentioned in the requirements.txt
 Here is a simplified example to get you started with training a diffusion model using FlaxDiff:

 ```python
-from flaxdiff.schedulers import EDMNoiseScheduler
+from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
 from flaxdiff.predictors import KarrasPredictionTransform
-from flaxdiff.models.simple_unet import …
+from flaxdiff.models.simple_unet import Unet
 from flaxdiff.trainer import DiffusionTrainer
+from flaxdiff.data.datasets import get_dataset_grain
+from flaxdiff.utils import defaultTextEncodeModel
+from flaxdiff.samplers.euler import EulerAncestralSampler
 import jax
+import jax.numpy as jnp
 import optax
 from datetime import datetime

 BATCH_SIZE = 16
-IMAGE_SIZE = …
+IMAGE_SIZE = 128

 # Define noise scheduler
 edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
-
+karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
 # Define model
-unet = …
-            feature_depths=[64, 128, 256, 512],
-            attention_configs=[…
+unet = Unet(emb_features=256,
+            feature_depths=[64, 64, 128, 256, 512],
+            attention_configs=[
+                None,
+                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":False, "use_self_and_cross":False}
+            ],
             num_res_blocks=2,
-            num_middle_res_blocks=1…
-…
+            num_middle_res_blocks=1
+)
 # Load dataset
-data…
+data = get_dataset_grain("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
+datalen = data['train_len']
 batches = datalen // BATCH_SIZE

+input_shapes = {
+    "x": (IMAGE_SIZE, IMAGE_SIZE, 3),
+    "temb": (),
+    "textcontext": (77, 768)
+}
+text_encoder = defaultTextEncodeModel()
+
+# Construct a validation set by the prompts
+val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold']
+
+def get_val_dataset(batch_size=8):
+    for i in range(0, len(val_prompts), batch_size):
+        prompts = val_prompts[i:i + batch_size]
+        tokens = text_encoder.tokenize(prompts)
+        yield tokens
+
+data['test'] = get_val_dataset
+data['test_len'] = len(val_prompts)
+
 # Define optimizer
 solver = optax.adam(2e-4)

 # Create trainer
-trainer = DiffusionTrainer(…
-    … (arguments truncated in source)
+trainer = DiffusionTrainer(
+    unet, optimizer=solver,
+    input_shapes=input_shapes,
+    noise_schedule=edm_schedule,
+    rngs=jax.random.PRNGKey(4),
+    name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
+    model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
+    encoder=text_encoder,
+    distributed_training=True,
+    wandb_config = {
+        "project": 'mlops-msml605-project',
+        "name": f"prototype-{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}",
+    })

 # Train the model
-final_state = trainer.fit(data, batches, epochs=2000)
+final_state = trainer.fit(data, batches, epochs=2000, sampler_class=EulerAncestralSampler, sampling_noise_schedule=karas_ve_schedule)
 ```

 ### Inference Example
@@ -301,8 +339,8 @@ Images generated by the following prompts using classifier free guidance with gu
 `Training Epochs: 1000`
 `Steps per epoch: 511`

-`Training Noise Schedule: …`
-`Inference Noise Schedule: …`
+`Training Noise Schedule: CosineNoiseScheduler`
+`Inference Noise Schedule: CosineNoiseScheduler`

 `Model: UNet(emb_features=256,
 feature_depths=[64, 128, 256, 512],
@@ -321,8 +359,8 @@ Images generated by the following prompts using classifier free guidance with gu
 `Training Epochs: 1000`
 `Steps per epoch: 511`

-`Training Noise Schedule: …`
-`Inference Noise Schedule: …`
+`Training Noise Schedule: CosineNoiseScheduler`
+`Inference Noise Schedule: CosineNoiseScheduler`

 `Model: UNet(emb_features=256,
 feature_depths=[64, 128, 256, 512],
{flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/README.md

The README.md hunks (@@ -74,7 +74,7 @@, @@ -125,43 +125,81 @@, @@ -279,8 +317,8 @@ and @@ -299,8 +337,8 @@) carry exactly the same content changes as the PKG-INFO hunks above, only shifted by the metadata header: PKG-INFO embeds the README as its long description.
{flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/common.py

@@ -37,7 +37,7 @@ class DiffusionSampler():
             # Classifier free guidance
             assert null_labels_seq is not None, "Null labels sequence is required for classifier-free guidance"
             print("Using classifier-free guidance")
-            def sample_model(x_t, t, *additional_inputs):
+            def sample_model(params, x_t, t, *additional_inputs):
                 # Concatenate unconditional and conditional inputs
                 x_t_cat = jnp.concatenate([x_t] * 2, axis=0)
                 t_cat = jnp.concatenate([t] * 2, axis=0)
@@ -46,7 +46,7 @@ class DiffusionSampler():

                 text_labels_seq, = additional_inputs
                 text_labels_seq = jnp.concatenate([text_labels_seq, jnp.broadcast_to(null_labels_seq, text_labels_seq.shape)], axis=0)
-                model_output = self.model.apply(…
+                model_output = self.model.apply(params, *self.noise_schedule.transform_inputs(x_t_cat * c_in_cat, t_cat), text_labels_seq)
                 # Split model output into unconditional and conditional parts
                 model_output_cond, model_output_uncond = jnp.split(model_output, 2, axis=0)
                 model_output = model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)
@@ -55,10 +55,10 @@ class DiffusionSampler():
                 return x_0, eps, model_output
         else:
             # Unconditional sampling
-            def sample_model(x_t, t, *additional_inputs):
+            def sample_model(params, x_t, t, *additional_inputs):
                 rates = self.noise_schedule.get_rates(t)
                 c_in = self.model_output_transform.get_input_scale(rates)
-                model_output = self.model.apply(…
+                model_output = self.model.apply(params, *self.noise_schedule.transform_inputs(x_t * c_in, t), *additional_inputs)
                 x_0, eps = self.model_output_transform(x_t, model_output, t, self.noise_schedule)
                 return x_0, eps, model_output

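The guidance combination in the hunk above is the standard classifier-free guidance formula: the unconditional prediction is pushed toward (and past) the conditional one. A self-contained sketch in plain JAX with toy inputs:

```python
import jax.numpy as jnp

def cfg_combine(model_output_cond, model_output_uncond, guidance_scale):
    # Classifier-free guidance: move the unconditional prediction
    # toward (and past) the conditional one by guidance_scale.
    return model_output_uncond + guidance_scale * (model_output_cond - model_output_uncond)

# Toy example: batch of 2 predictions with 4 features each
cond = jnp.ones((2, 4))
uncond = jnp.zeros((2, 4))
print(cfg_combine(cond, uncond, guidance_scale=3.0))  # 3.0 everywhere
```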
@@ -70,22 +70,23 @@ class DiffusionSampler():
         self.sample_model = sample_model

     # Used to sample from the diffusion model
-    def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
+    def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
         # First clip the noisy images
         step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
         current_step = step_ones * current_step
         next_step = step_ones * next_step
-        pred_images, pred_noise, _ = …
+        pred_images, pred_noise, _ = sample_model_fn(current_samples, current_step, *model_conditioning_inputs)
         # plotImages(pred_images)
         # pred_images = clip_images(pred_images)
         new_samples, state = self.take_next_step(current_samples=current_samples, reconstructed_samples=pred_images,
-            …
-            …
+            pred_noise=pred_noise, current_step=current_step, next_step=next_step, state=state,
+            model_conditioning_inputs=model_conditioning_inputs,
+            sample_model_fn=sample_model_fn,
         )
         return new_samples, state

     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1,) -> tuple[jnp.ndarray, RandomMarkovState]:
         # estimate the q(x_{t-1} | x_t, x_0).
         # pred_images is x_0, noisy_images is x_t, steps is t
         return NotImplementedError
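The change threaded through this file (and every sampler below) is that the model-evaluation closure `sample_model_fn` is now passed into `sample_step`/`take_next_step` explicitly rather than being fixed at sampler construction, so one sampler instance can be reused with different parameter sets such as EMA weights. A minimal sketch of the pattern, with hypothetical stand-in functions:

```python
from functools import partial

# Hypothetical stand-ins for self.model.apply and a sampler step.
def apply_model(params, x_t, t):
    return params * x_t + t  # toy "denoiser"

def sample_step(sample_model_fn, x_t, t):
    # The step never touches params directly; it only calls the closure.
    return sample_model_fn(x_t, t)

params = 0.5
sample_model_fn = partial(apply_model, params)  # bind params once
print(sample_step(sample_model_fn, x_t=2.0, t=1.0))  # 2.0
```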
@@ -114,6 +115,7 @@ class DiffusionSampler():
         return jax.random.normal(rngs, (num_images, image_size, image_size, image_channels)) * variance

     def generate_images(self,
+                        params:dict=None,
                         num_images=16,
                         diffusion_steps=1000,
                         start_step:int = None,
@@ -131,10 +133,15 @@ class DiffusionSampler():
         if self.autoencoder is not None:
             priors = self.autoencoder.encode(priors)
         samples = priors
+
+        params = params if params is not None else self.params
+
+        def sample_model_fn(x_t, t, *additional_inputs):
+            return self.sample_model(params, x_t, t, *additional_inputs)

         # @jax.jit
-        def sample_step(state:RandomMarkovState, samples, current_step, next_step):
-            samples, state = self.sample_step(current_samples=samples,
+        def sample_step(sample_model_fn, state:RandomMarkovState, samples, current_step, next_step):
+            samples, state = self.sample_step(sample_model_fn=sample_model_fn, current_samples=samples,
                                               current_step=current_step,
                                               model_conditioning_inputs=model_conditioning_inputs,
                                               state=state, next_step=next_step)
@@ -154,11 +161,11 @@ class DiffusionSampler():
             next_step = self.scale_steps(steps[i+1] if i+1 < len(steps) else 0)
             if i != len(steps) - 1:
                 # print("normal step")
-                samples, rngstate = sample_step(rngstate, samples, current_step, next_step)
+                samples, rngstate = sample_step(sample_model_fn, rngstate, samples, current_step, next_step)
             else:
                 # print("last step")
                 step_ones = jnp.ones((num_images, ), dtype=jnp.int32)
-                samples, _, _ = …
+                samples, _, _ = sample_model_fn(samples, current_step * step_ones, *model_conditioning_inputs)
         if self.autoencoder is not None:
             samples = self.autoencoder.decode(samples)
         samples = clip_images(samples)
{flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/ddim.py

@@ -4,7 +4,7 @@ from ..utils import MarkovState, RandomMarkovState

 class DDIMSampler(DiffusionSampler):
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         next_signal_rate, next_noise_rate = self.noise_schedule.get_rates(next_step)
         return reconstructed_samples * next_signal_rate + pred_noise * next_noise_rate, state

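The DDIM rule above recombines the predicted clean image and predicted noise at the next step's rates: x_next = alpha_next * x0_hat + sigma_next * eps_hat. A toy numeric sketch:

```python
import jax.numpy as jnp

def ddim_step(x0_hat, eps_hat, next_signal_rate, next_noise_rate):
    # Deterministic DDIM update: recombine the predicted clean image
    # and predicted noise at the next step's signal/noise rates.
    return next_signal_rate * x0_hat + next_noise_rate * eps_hat

x0_hat = jnp.full((2, 2), 0.8)   # toy predicted x_0
eps_hat = jnp.full((2, 2), 0.1)  # toy predicted noise
print(ddim_step(x0_hat, eps_hat, next_signal_rate=0.9, next_noise_rate=0.436))
```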
{flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/ddpm.py

@@ -4,7 +4,7 @@ from .common import DiffusionSampler
 from ..utils import MarkovState, RandomMarkovState
 class DDPMSampler(DiffusionSampler):
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         mean = self.noise_schedule.get_posterior_mean(reconstructed_samples, current_samples, current_step)
         variance = self.noise_schedule.get_posterior_variance(steps=current_step)

@@ -19,7 +19,7 @@ class DDPMSampler(DiffusionSampler):

 class SimpleDDPMSampler(DiffusionSampler):
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         state, rng = state.get_random_key()
         noise = jax.random.normal(rng, reconstructed_samples.shape, dtype=jnp.float32)

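Both DDPM samplers draw fresh Gaussian noise each step and sample from the posterior q(x_{t-1} | x_t, x_0). A minimal ancestral-step sketch; `mean` and `variance` here are toy placeholders for the schedule's posterior statistics:

```python
import jax
import jax.numpy as jnp

def ancestral_step(rng, mean, variance):
    # Sample x_{t-1} ~ N(mean, variance): posterior mean plus scaled noise.
    noise = jax.random.normal(rng, mean.shape, dtype=jnp.float32)
    return mean + jnp.sqrt(variance) * noise

rng = jax.random.PRNGKey(0)
mean = jnp.zeros((2, 2))       # toy posterior mean
variance = jnp.float32(0.01)   # toy posterior variance
print(ancestral_step(rng, mean, variance))
```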
{flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/euler.py

@@ -6,7 +6,7 @@ from ..utils import RandomMarkovState
 class EulerSampler(DiffusionSampler):
     # Basically a DDIM Sampler but parameterized as an ODE
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

@@ -22,7 +22,7 @@ class SimplifiedEulerSampler(DiffusionSampler):
     This is for networks with forward diffusion of the form x_{t+1} = x_t + sigma_t * epsilon_t
     """
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         _, current_sigma = self.noise_schedule.get_rates(current_step)
         _, next_sigma = self.noise_schedule.get_rates(next_step)

@@ -37,7 +37,7 @@ class EulerAncestralSampler(DiffusionSampler):
     Similar to EulerSampler but with ancestral sampling
     """
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)

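For the variance-exploding form named in `SimplifiedEulerSampler`'s docstring, one Euler step of the probability-flow ODE uses the slope dx/dsigma = (x_t - x0_hat) / sigma. A toy sketch:

```python
import jax.numpy as jnp

def euler_step(x_t, x0_hat, current_sigma, next_sigma):
    # Probability-flow ODE slope for the VE parameterization,
    # then one explicit Euler step from current_sigma to next_sigma.
    dx = (x_t - x0_hat) / current_sigma
    return x_t + dx * (next_sigma - current_sigma)

x_t = jnp.full((2, 2), 1.5)
x0_hat = jnp.full((2, 2), 1.0)
print(euler_step(x_t, x0_hat, current_sigma=10.0, next_sigma=8.0))
```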
{flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/heun_sampler.py

@@ -5,7 +5,7 @@ from ..utils import RandomMarkovState

 class HeunSampler(DiffusionSampler):
     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         # Get the noise and signal rates for the current and next steps
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
@@ -17,7 +17,7 @@ class HeunSampler(DiffusionSampler):
         next_samples_0 = current_samples + dx_0 * dt

         # Recompute x_0 and eps at the first estimate to refine the derivative
-        estimated_x_0, _, _ = …
+        estimated_x_0, _, _ = sample_model_fn(next_samples_0, next_step, *model_conditioning_inputs)

         # Estimate the refined derivative using the midpoint (Heun's method)
         dx_1 = (next_samples_0 - x_0_coeff * estimated_x_0) / next_sigma
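Heun's method takes an Euler trial step, re-evaluates the slope at the trial point, and averages the two slopes, as the refined `dx_1` above does. A generic scalar sketch (not flaxdiff-specific):

```python
def heun_step(f, x, t, dt):
    # Predictor: Euler trial step using the slope at (x, t).
    k1 = f(x, t)
    x_trial = x + k1 * dt
    # Corrector: re-evaluate the slope at the trial point and average.
    k2 = f(x_trial, t + dt)
    return x + 0.5 * (k1 + k2) * dt

# Toy ODE dx/dt = -x; exact solution decays like exp(-t).
f = lambda x, t: -x
print(heun_step(f, x=1.0, t=0.0, dt=0.1))  # ≈ 0.905
```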
{flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/multistep_dpm.py

@@ -9,7 +9,7 @@ class MultiStepDPM(DiffusionSampler):
         self.history = []

     def take_next_step(self, current_samples, reconstructed_samples, model_conditioning_inputs,
-                       pred_noise, current_step, state:RandomMarkovState, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
+                       pred_noise, current_step, state:RandomMarkovState, sample_model_fn, next_step=1) -> tuple[jnp.ndarray, RandomMarkovState]:
         # Get the noise and signal rates for the current and next steps
         current_alpha, current_sigma = self.noise_schedule.get_rates(current_step)
         next_alpha, next_sigma = self.noise_schedule.get_rates(next_step)
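The `self.history` list in the context above is what makes the solver multistep: previously computed derivatives are reused to extrapolate the next step without extra model evaluations. A generic two-step Adams-Bashforth sketch of that idea (illustrative coefficients, not the DPM-solver ones):

```python
def adams_bashforth2_step(x, dt, history):
    # history holds past slope evaluations, newest first.
    if len(history) < 2:
        # Fall back to Euler until enough history is accumulated.
        return x + dt * history[0]
    # AB2: extrapolate using the current and previous slopes.
    return x + dt * (1.5 * history[0] - 0.5 * history[1])

history = [-0.9, -1.0]  # toy slopes for dx/dt = -x
print(adams_bashforth2_step(x=1.0, dt=0.1, history=history))  # 0.915
```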
{flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/samplers/rk4_sampler.py

@@ -9,14 +9,14 @@ class RK4Sampler(DiffusionSampler):
         super().__init__(*args, **kwargs)
         assert issubclass(type(self.noise_schedule), GeneralizedNoiseScheduler), "Noise schedule must be a GeneralizedNoiseScheduler"
         @jax.jit
-        def get_derivative(x_t, sigma, state:RandomMarkovState, model_conditioning_inputs) -> tuple[jnp.ndarray, RandomMarkovState]:
+        def get_derivative(sample_model_fn, x_t, sigma, state:RandomMarkovState, model_conditioning_inputs) -> tuple[jnp.ndarray, RandomMarkovState]:
             t = self.noise_schedule.get_timesteps(sigma)
-            x_0, eps, _ = …
+            x_0, eps, _ = sample_model_fn(x_t, t, *model_conditioning_inputs)
             return eps, state

         self.get_derivative = get_derivative

-    def sample_step(self, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
+    def sample_step(self, sample_model_fn, current_samples:jnp.ndarray, current_step, model_conditioning_inputs, next_step=None, state:MarkovState=None) -> tuple[jnp.ndarray, MarkovState]:
         step_ones = jnp.ones((current_samples.shape[0], ), dtype=jnp.int32)
         current_step = step_ones * current_step
         next_step = step_ones * next_step
@@ -25,10 +25,10 @@ class RK4Sampler(DiffusionSampler):

         dt = next_sigma - current_sigma

-        k1, state = self.get_derivative(current_samples, current_sigma, state, model_conditioning_inputs)
-        k2, state = self.get_derivative(current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
-        k3, state = self.get_derivative(current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
-        k4, state = self.get_derivative(current_samples + k3 * dt, current_sigma + dt, state, model_conditioning_inputs)
+        k1, state = self.get_derivative(sample_model_fn, current_samples, current_sigma, state, model_conditioning_inputs)
+        k2, state = self.get_derivative(sample_model_fn, current_samples + 0.5 * k1 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
+        k3, state = self.get_derivative(sample_model_fn, current_samples + 0.5 * k2 * dt, current_sigma + 0.5 * dt, state, model_conditioning_inputs)
+        k4, state = self.get_derivative(sample_model_fn, current_samples + k3 * dt, current_sigma + dt, state, model_conditioning_inputs)

         next_samples = current_samples + (((k1 + 2 * k2 + 2 * k3 + k4) * dt) / 6)
         return next_samples, state
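The four k-terms above are classic fourth-order Runge-Kutta with sigma as the integration variable. A generic scalar sketch:

```python
def rk4_step(f, x, t, dt):
    # Classic RK4: combine four slope samples with 1/6, 2/6, 2/6, 1/6 weights.
    k1 = f(x, t)
    k2 = f(x + 0.5 * k1 * dt, t + 0.5 * dt)
    k3 = f(x + 0.5 * k2 * dt, t + 0.5 * dt)
    k4 = f(x + k3 * dt, t + dt)
    return x + (k1 + 2 * k2 + 2 * k3 + k4) * dt / 6

f = lambda x, t: -x  # toy ODE dx/dt = -x
print(rk4_step(f, x=1.0, t=0.0, dt=0.1))  # ≈ exp(-0.1) ≈ 0.9048
```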
{flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/__init__.py

@@ -1,6 +1,6 @@
 from .discrete import DiscreteNoiseScheduler
 from .common import NoiseScheduler, GeneralizedNoiseScheduler
-from .cosine import …
+from .cosine import CosineNoiseScheduler, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
 from .linear import LinearNoiseSchedule
 from .sqrt import SqrtContinuousNoiseScheduler
 from .karras import KarrasVENoiseScheduler, SimpleExpNoiseScheduler, EDMNoiseScheduler
{flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/schedulers/cosine.py

@@ -12,7 +12,7 @@ def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
     betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
     return np.clip(betas, 0, end_angle)

-class …
+class CosineNoiseScheduler(DiscreteNoiseScheduler):
     def __init__(self, timesteps, beta_start=0.008, beta_end=0.999, *args, **kwargs):
         super().__init__(timesteps, beta_start, beta_end, schedule_fn=cosine_beta_schedule, *args, **kwargs)

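The `cosine_beta_schedule` in the context lines follows the squared-cosine alpha-bar construction of Nichol & Dhariwal (2021). A self-contained NumPy sketch of that construction; the exact flaxdiff implementation may differ in details:

```python
import numpy as np

def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
    # Squared-cosine cumulative signal schedule (Nichol & Dhariwal, 2021).
    ts = np.linspace(0, 1, timesteps + 1)
    alphas_bar = np.cos((ts + start_angle) / (1 + start_angle) * np.pi / 2) ** 2
    alphas_bar = alphas_bar / alphas_bar[0]  # normalize so alpha_bar(0) = 1
    # Per-step betas from consecutive ratios of alpha_bar, clipped as above.
    betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
    return np.clip(betas, 0, end_angle)

print(cosine_beta_schedule(10)[:3])  # small betas early in the schedule
```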
{flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff/trainer/diffusion_trainer.py

@@ -235,19 +235,19 @@ class DiffusionTrainer(SimpleTrainer):
         null_labels_full = null_labels_full.astype(jnp.float16)
         # null_labels_seq = jnp.array(null_labels_full[0], dtype=jnp.float16)

-        … (previous sampler construction; removed lines truncated in source)
-        return sampler
+        sampler = sampler_class(
+            model=model,
+            params=None,
+            noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
+            model_output_transform=self.model_output_transform,
+            image_size=self.input_shapes['x'][0],
+            null_labels_seq=null_labels_full,
+            autoencoder=autoencoder,
+            guidance_scale=3.0,
+        )

         def generate_samples(
+            val_state: TrainState,
             batch,
             sampler: DiffusionSampler,
             diffusion_steps: int,
@@ -255,6 +255,7 @@ class DiffusionTrainer(SimpleTrainer):
             labels_seq = encoder.encode_from_tokens(batch)
             labels_seq = jnp.array(labels_seq, dtype=jnp.float16)
             samples = sampler.generate_images(
+                params=val_state.ema_params,
                 num_images=len(labels_seq),
                 diffusion_steps=diffusion_steps,
                 start_step=1000,
@@ -264,7 +265,7 @@ class DiffusionTrainer(SimpleTrainer):
             )
             return samples

-        return …
+        return sampler, generate_samples

     def validation_loop(
         self,
@@ -275,14 +276,15 @@ class DiffusionTrainer(SimpleTrainer):
         current_step,
         diffusion_steps=200,
     ):
-        …
+        sampler, generate_samples = val_step_fn

-        sampler = generate_sampler(val_state)
+        # sampler = generate_sampler(val_state)

         val_ds = iter(val_ds()) if val_ds else None
         # Evaluation step
         try:
             samples = generate_samples(
+                val_state,
                 next(val_ds),
                 sampler,
                 diffusion_steps,
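Validation sampling now runs with `val_state.ema_params`, an exponential moving average of the training weights, which usually yields cleaner samples than the raw weights. A minimal sketch of how such an EMA is maintained over a JAX parameter pytree (the update rule only, not flaxdiff's trainer internals):

```python
import jax
import jax.numpy as jnp

def ema_update(ema_params, params, decay=0.999):
    # Exponential moving average over every leaf of the parameter pytree.
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params)

params = {"w": jnp.ones((2,)), "b": jnp.zeros((2,))}
ema_params = params  # typically initialized to a copy of the weights
new_params = {"w": jnp.full((2,), 2.0), "b": jnp.ones((2,))}
ema_params = ema_update(ema_params, new_params)
print(ema_params["w"])  # ≈ 1.001: moves slowly toward the new weights
```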
{flaxdiff-0.1.36.4 → flaxdiff-0.1.37}/flaxdiff.egg-info/PKG-INFO

Identical to the PKG-INFO changes shown at the top of this diff; the egg-info copy mirrors PKG-INFO.
All remaining files in the listing above are unchanged between the two versions.