flaxdiff 0.1.36.3__tar.gz → 0.1.36.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/PKG-INFO +60 -22
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/README.md +59 -21
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/data/sources/tfds.py +12 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/__init__.py +1 -1
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/cosine.py +1 -1
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/trainer/diffusion_trainer.py +6 -7
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff.egg-info/PKG-INFO +60 -22
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/pyproject.toml +1 -1
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/data/dataset_map.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/data/datasets.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/data/online_loader.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/data/sources/gcs.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/metrics/inception.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/metrics/utils.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/models/common.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/models/simple_vit.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/trainer/simple_trainer.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/trainer/video_diffusion_trainer.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff.egg-info/SOURCES.txt +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/setup.cfg +0 -0
{flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: flaxdiff
-Version: 0.1.36.3
+Version: 0.1.36.5
 Summary: A versatile and easy to understand Diffusion library
 Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
 License-Expression: MIT
@@ -96,7 +96,7 @@ Also, few of the text may be generated with help of github copilot, so please ex
 ### Schedulers
 Implemented in `flaxdiff.schedulers`:
 - **LinearNoiseSchedule** (`flaxdiff.schedulers.LinearNoiseSchedule`): A beta-parameterized discrete scheduler.
-- **
+- **CosineNoiseScheduler** (`flaxdiff.schedulers.CosineNoiseScheduler`): A beta-parameterized discrete scheduler.
 - **ExpNoiseSchedule** (`flaxdiff.schedulers.ExpNoiseSchedule`): A beta-parameterized discrete scheduler.
 - **CosineContinuousNoiseScheduler** (`flaxdiff.schedulers.CosineContinuousNoiseScheduler`): A continuous scheduler.
 - **CosineGeneralNoiseScheduler** (`flaxdiff.schedulers.CosineGeneralNoiseScheduler`): A continuous sigma parameterized cosine scheduler.
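This README entry tracks the class rename to `CosineNoiseScheduler` made in the `schedulers/cosine.py` and `schedulers/__init__.py` diffs further down. As a minimal usage sketch, assuming the constructor signature visible in that `cosine.py` hunk:

```python
from flaxdiff.schedulers import CosineNoiseScheduler

# Discrete beta-parameterized cosine schedule; default betas taken from the
# cosine.py hunk below (beta_start=0.008, beta_end=0.999).
schedule = CosineNoiseScheduler(timesteps=1000)
```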
@@ -147,43 +147,81 @@ sticking to the versions mentioned in the requirements.txt
 Here is a simplified example to get you started with training a diffusion model using FlaxDiff:
 
 ```python
-from flaxdiff.schedulers import EDMNoiseScheduler
+from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
 from flaxdiff.predictors import KarrasPredictionTransform
-from flaxdiff.models.simple_unet import
+from flaxdiff.models.simple_unet import Unet
 from flaxdiff.trainer import DiffusionTrainer
+from flaxdiff.data.datasets import get_dataset_grain
+from flaxdiff.utils import defaultTextEncodeModel
+from flaxdiff.samplers.euler import EulerAncestralSampler
 import jax
+import jax.numpy as jnp
 import optax
 from datetime import datetime
 
 BATCH_SIZE = 16
-IMAGE_SIZE =
+IMAGE_SIZE = 128
 
 # Define noise scheduler
 edm_schedule = EDMNoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
-
+karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)
 # Define model
-unet =
-feature_depths=[64, 128, 256, 512],
-attention_configs=[
+unet = Unet(emb_features=256,
+            feature_depths=[64, 64, 128, 256, 512],
+            attention_configs=[
+                None,
+                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":True, "use_self_and_cross":True},
+                {"heads":8, "dtype":jnp.float16, "flash_attention":False, "use_projection":False, "use_self_and_cross":False}
+            ],
 num_res_blocks=2,
-num_middle_res_blocks=1
-
+            num_middle_res_blocks=1
+            )
 # Load dataset
-data
+data = get_dataset_grain("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
+datalen = data['train_len']
 batches = datalen // BATCH_SIZE
 
+input_shapes = {
+    "x": (IMAGE_SIZE, IMAGE_SIZE, 3),
+    "temb": (),
+    "textcontext": (77, 768)
+}
+text_encoder = defaultTextEncodeModel()
+
+# Construct a validation set by the prompts
+val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold']
+
+def get_val_dataset(batch_size=8):
+    for i in range(0, len(val_prompts), batch_size):
+        prompts = val_prompts[i:i + batch_size]
+        tokens = text_encoder.tokenize(prompts)
+        yield tokens
+
+data['test'] = get_val_dataset
+data['test_len'] = len(val_prompts)
+
 # Define optimizer
 solver = optax.adam(2e-4)
 
 # Create trainer
-trainer = DiffusionTrainer(
-
-
-
-
+trainer = DiffusionTrainer(
+    unet, optimizer=solver,
+    input_shapes=input_shapes,
+    noise_schedule=edm_schedule,
+    rngs=jax.random.PRNGKey(4),
+    name="Diffusion_SDE_VE_" + datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
+    model_output_transform=KarrasPredictionTransform(sigma_data=edm_schedule.sigma_data),
+    encoder=text_encoder,
+    distributed_training=True,
+    wandb_config = {
+        "project": 'mlops-msml605-project',
+        "name": f"prototype-{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}",
+    })
 
 # Train the model
-final_state = trainer.fit(data, batches, epochs=2000)
+final_state = trainer.fit(data, batches, epochs=2000, sampler_class=EulerAncestralSampler, sampling_noise_schedule=karas_ve_schedule)
 ```
 
 ### Inference Example
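The two new keyword arguments in the final `trainer.fit(...)` call correspond to the updated signature in the `diffusion_trainer.py` hunks below. A minimal sketch of what the new defaults imply, assuming that signature:

```python
# Equivalent call relying on the new defaults from diffusion_trainer.py:
# sampler_class defaults to DDIMSampler, and a sampling_noise_schedule of
# None falls back to the trainer's own training noise schedule.
final_state = trainer.fit(data, batches, epochs=2000)
```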
@@ -301,8 +339,8 @@ Images generated by the following prompts using classifier free guidance with gu
 `Training Epochs: 1000`
 `Steps per epoch: 511`
 
-`Training Noise Schedule:
-`Inference Noise Schedule:
+`Training Noise Schedule: CosineNoiseScheduler`
+`Inference Noise Schedule: CosineNoiseScheduler`
 
 `Model: UNet(emb_features=256,
 feature_depths=[64, 128, 256, 512],
@@ -321,8 +359,8 @@ Images generated by the following prompts using classifier free guidance with gu
 `Training Epochs: 1000`
 `Steps per epoch: 511`
 
-`Training Noise Schedule:
-`Inference Noise Schedule:
+`Training Noise Schedule: CosineNoiseScheduler`
+`Inference Noise Schedule: CosineNoiseScheduler`
 
 `Model: UNet(emb_features=256,
 feature_depths=[64, 128, 256, 512],
{flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/README.md

The README.md changes are identical to the README portion of the PKG-INFO diff above (the README is embedded in the package metadata), only at different offsets: the scheduler-list fix at @@ -74,7 +74,7 @@, the expanded training example at @@ -125,43 +125,81 @@, and the two noise-schedule captions at @@ -279,8 +317,8 @@ and @@ -299,8 +337,8 @@.
{flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/data/sources/tfds.py

@@ -4,6 +4,8 @@ import grain.python as pygrain
 from flaxdiff.utils import AutoTextTokenizer
 from typing import Dict
 import random
+import augmax
+import jax
 
 # -----------------------------------------------------------------------------------------------#
 # Oxford flowers and other TFDS datasources -----------------------------------------------------#
@@ -47,6 +49,15 @@ def tfds_augmenters(image_scale, method):
         interpolation = cv2.INTER_CUBIC
     else:
         interpolation = cv2.INTER_AREA
+
+    augments = augmax.Chain(
+        augmax.HorizontalFlip(0.5),
+        augmax.RandomContrast((-0.05, 0.05), 1.),
+        augmax.RandomBrightness((-0.2, 0.2), 1.)
+    )
+
+    augments = jax.jit(augments, backend="cpu")
+
     class augmenters(pygrain.MapTransform):
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
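For context, a rough sketch of how this jitted augmax chain would typically be invoked. Note the calling convention with a PRNG key as the first argument is an assumption about the augmax API rather than something shown in this diff, and it may be why the call in the next hunk is left commented out:

```python
import augmax
import jax
import jax.numpy as jnp

# Same chain as the hunk above.
augments = augmax.Chain(
    augmax.HorizontalFlip(0.5),
    augmax.RandomContrast((-0.05, 0.05), 1.),
    augmax.RandomBrightness((-0.2, 0.2), 1.)
)
augments = jax.jit(augments, backend="cpu")  # compile once, run on CPU

rng = jax.random.PRNGKey(0)
image = jnp.zeros((128, 128, 3), dtype=jnp.float32)  # placeholder image
augmented = augments(rng, image)  # augmax chains take (rng, image)
```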
@@ -56,6 +67,7 @@ def tfds_augmenters(image_scale, method):
             image = element['image']
             image = cv2.resize(image, (image_scale, image_scale),
                                interpolation=interpolation)
+            # image = augments(image)
             # image = (image - 127.5) / 127.5
             caption = labelizer(element)
             results = self.tokenize(caption)
{flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/__init__.py

@@ -1,6 +1,6 @@
 from .discrete import DiscreteNoiseScheduler
 from .common import NoiseScheduler, GeneralizedNoiseScheduler
-from .cosine import
+from .cosine import CosineNoiseScheduler, ContinuousNoiseScheduler, CosineGeneralNoiseScheduler
 from .linear import LinearNoiseSchedule
 from .sqrt import SqrtContinuousNoiseScheduler
 from .karras import KarrasVENoiseScheduler, SimpleExpNoiseScheduler, EDMNoiseScheduler
{flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/schedulers/cosine.py

@@ -12,7 +12,7 @@ def cosine_beta_schedule(timesteps, start_angle=0.008, end_angle=0.999):
     betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])
     return np.clip(betas, 0, end_angle)
 
-class
+class CosineNoiseScheduler(DiscreteNoiseScheduler):
     def __init__(self, timesteps, beta_start=0.008, beta_end=0.999, *args, **kwargs):
         super().__init__(timesteps, beta_start, beta_end, schedule_fn=cosine_beta_schedule, *args, **kwargs)
 
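For reference, the schedule computed by `cosine_beta_schedule` appears to be the standard cosine schedule of Nichol & Dhariwal (2021), with `start_angle` in the role of the small offset $s$; the beta expression matches the `betas = 1 - (alphas_bar[1:] / alphas_bar[:-1])` and `np.clip(betas, 0, end_angle)` lines visible in the hunk, while the exact form of $f(t)$ is an assumption since its definition lies outside the diff context:

```latex
\bar{\alpha}_t = \frac{f(t)}{f(0)}, \qquad
f(t) = \cos^{2}\!\left( \frac{t/T + s}{1 + s} \cdot \frac{\pi}{2} \right), \qquad
\beta_t = \operatorname{clip}\!\left( 1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}},\; 0,\; \beta_{\mathrm{end}} \right)
```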
|
{flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff/trainer/diffusion_trainer.py

@@ -14,6 +14,7 @@ from typing import Dict, Callable, Sequence, Any, Union, Tuple, Type
 from ..schedulers import NoiseScheduler
 from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
 from ..samplers.common import DiffusionSampler
+from ..samplers.ddim import DDIMSampler
 
 from flaxdiff.utils import RandomMarkovState
|
@@ -179,9 +180,6 @@ class DiffusionTrainer(SimpleTrainer):
             nloss = loss_fn(preds, expected_output)
             # Ignore the loss contribution of images with zero standard deviation
             nloss *= noise_schedule.get_weights(noise_level)
-            # nloss = jnp.mean(nloss, axis=(1,2,3))
-            # nloss = jnp.where(is_non_zero, nloss, 0)
-            # nloss = jnp.mean(nloss, where=nloss != 0)
             nloss = jnp.mean(nloss)
             loss = nloss
             return loss
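Reading the surviving lines, the per-sample loss is weighted by the schedule before the global mean; in symbols (a paraphrase of the code above, with $w$ supplied by `noise_schedule.get_weights` and $\ell$ the configured `loss_fn`):

```latex
\mathcal{L} = \mathbb{E}\!\left[\, w(t)\, \ell\big(\hat{y}_\theta(x_t, t),\, y\big) \,\right]
```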
@@ -224,11 +222,11 @@ class DiffusionTrainer(SimpleTrainer):
         if distributed_training:
             train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
                                    out_specs=(P(), P(), P()))
-
+        train_step = jax.jit(train_step)
 
         return train_step
 
-    def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]):
+    def _define_vaidation_step(self, sampler_class: Type[DiffusionSampler]=DDIMSampler, sampling_noise_schedule: NoiseScheduler=None):
         model = self.model
         encoder = self.encoder
         autoencoder = self.autoencoder
|
|
241
239
|
sampler = sampler_class(
|
242
240
|
model=model,
|
243
241
|
params=state.ema_params,
|
244
|
-
noise_schedule=self.noise_schedule,
|
242
|
+
noise_schedule=self.noise_schedule if sampling_noise_schedule is None else sampling_noise_schedule,
|
245
243
|
model_output_transform=self.model_output_transform,
|
246
244
|
image_size=self.input_shapes['x'][0],
|
247
245
|
null_labels_seq=null_labels_full,
|
@@ -311,10 +309,11 @@ class DiffusionTrainer(SimpleTrainer):
|
|
311
309
|
print("Error logging images to wandb", e)
|
312
310
|
traceback.print_exc()
|
313
311
|
|
314
|
-
def fit(self, data, training_steps_per_epoch, epochs, val_steps_per_epoch=8, sampler_class=None):
|
312
|
+
def fit(self, data, training_steps_per_epoch, epochs, val_steps_per_epoch=8, sampler_class: Type[DiffusionSampler]=DDIMSampler, sampling_noise_schedule: NoiseScheduler=None):
|
315
313
|
local_batch_size = data['local_batch_size']
|
316
314
|
validation_step_args = {
|
317
315
|
"sampler_class": sampler_class,
|
316
|
+
"sampling_noise_schedule": sampling_noise_schedule,
|
318
317
|
}
|
319
318
|
super().fit(
|
320
319
|
data,
|
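Taken together, these trainer changes let callers pick the validation sampler and its noise schedule per `fit()` call. A minimal sketch assuming the signature above (the `trainer`, `data`, and `batches` names are placeholders from the README example):

```python
from flaxdiff.samplers.euler import EulerAncestralSampler
from flaxdiff.schedulers import KarrasVENoiseScheduler

karas_ve_schedule = KarrasVENoiseScheduler(1, sigma_max=80, rho=7, sigma_data=0.5)

# Validation sampling with an Euler-ancestral sampler on a Karras VE schedule,
# while training still uses the trainer's own noise schedule.
final_state = trainer.fit(
    data, batches, epochs=2000,
    sampler_class=EulerAncestralSampler,
    sampling_noise_schedule=karas_ve_schedule,
)
```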
{flaxdiff-0.1.36.3 → flaxdiff-0.1.36.5}/flaxdiff.egg-info/PKG-INFO

This diff is a verbatim copy of the PKG-INFO diff at the top of this page (the egg-info metadata mirrors PKG-INFO): the version bump at @@ -1,6 +1,6 @@, the scheduler-list fix at @@ -96,7 +96,7 @@, the expanded training example at @@ -147,43 +147,81 @@, and the noise-schedule captions at @@ -301,8 +339,8 @@ and @@ -321,8 +359,8 @@.
|