flaxdiff-0.1.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/__init__.py +0 -0
- flaxdiff/models/__init__.py +1 -0
- flaxdiff/models/attention.py +489 -0
- flaxdiff/models/common.py +7 -0
- flaxdiff/models/favor_fastattn.py +723 -0
- flaxdiff/models/simple_unet.py +519 -0
- flaxdiff/predictors/__init__.py +96 -0
- flaxdiff/samplers/__init__.py +7 -0
- flaxdiff/samplers/common.py +113 -0
- flaxdiff/samplers/ddim.py +10 -0
- flaxdiff/samplers/ddpm.py +43 -0
- flaxdiff/samplers/euler.py +59 -0
- flaxdiff/samplers/heun_sampler.py +28 -0
- flaxdiff/samplers/multistep_dpm.py +60 -0
- flaxdiff/samplers/rk4_sampler.py +34 -0
- flaxdiff/schedulers/__init__.py +6 -0
- flaxdiff/schedulers/common.py +98 -0
- flaxdiff/schedulers/continuous.py +12 -0
- flaxdiff/schedulers/cosine.py +40 -0
- flaxdiff/schedulers/discrete.py +74 -0
- flaxdiff/schedulers/exp.py +13 -0
- flaxdiff/schedulers/karras.py +69 -0
- flaxdiff/schedulers/linear.py +14 -0
- flaxdiff/schedulers/sqrt.py +10 -0
- flaxdiff/trainer/__init__.py +216 -0
- flaxdiff/utils.py +89 -0
- flaxdiff-0.1.1.dist-info/METADATA +326 -0
- flaxdiff-0.1.1.dist-info/RECORD +30 -0
- flaxdiff-0.1.1.dist-info/WHEEL +5 -0
- flaxdiff-0.1.1.dist-info/top_level.txt +1 -0
flaxdiff/schedulers/karras.py
ADDED
@@ -0,0 +1,69 @@
+import jax.numpy as jnp
+from .common import GeneralizedNoiseScheduler
+import math
+import jax
+from ..utils import RandomMarkovState
+
+class KarrasVENoiseScheduler(GeneralizedNoiseScheduler):
+    def __init__(self, timesteps, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
+        super().__init__(timesteps=timesteps, sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data, *args, **kwargs)
+        self.min_inv_rho = sigma_min ** (1 / rho)
+        self.max_inv_rho = sigma_max ** (1 / rho)
+        self.rho = rho
+
+    def get_sigmas(self, steps) -> jnp.ndarray:
+        # steps = jnp.int16(steps)
+        # return self.sigmas[steps]
+        ramp = 1 - steps / self.max_timesteps
+        sigmas = (self.max_inv_rho + ramp * (self.min_inv_rho - self.max_inv_rho)) ** self.rho
+        return sigmas
+
+    def get_weights(self, steps, shape=(-1, 1, 1, 1)) -> jnp.ndarray:
+        sigma = self.get_sigmas(steps)
+        weights = ((sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2)
+        return weights.reshape(shape)
+
+    def transform_inputs(self, x, steps, num_discrete_chunks=1000) -> tuple[jnp.ndarray, jnp.ndarray]:
+        sigmas = self.get_sigmas(steps)
+        # sigmas = (sigmas / self.sigma_max) * num_discrete_chunks
+        sigmas = jnp.log(sigmas) / 4
+        return x, sigmas
+
+    def get_timesteps(self, sigmas:jnp.ndarray) -> jnp.ndarray:
+        sigmas = sigmas.reshape(-1)
+        inv_rho = sigmas ** (1 / self.rho)
+        ramp = ((inv_rho - self.max_inv_rho) / (self.min_inv_rho - self.max_inv_rho))
+        steps = 1 - ramp * self.max_timesteps
+        return steps
+
+    def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
+        timesteps, state = super().generate_timesteps(batch_size, state)
+        timesteps = timesteps.astype(jnp.float32)
+        return timesteps, state
+
+class SimpleExpNoiseScheduler(KarrasVENoiseScheduler):
+    def __init__(self, timesteps, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
+        super().__init__(timesteps=timesteps, sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data, *args, **kwargs)
+        if type(timesteps) == int and timesteps > 1:
+            n = timesteps
+        else:
+            n = 1000
+        self.sigmas = jnp.exp(jnp.linspace(math.log(sigma_min), math.log(sigma_max), n))
+
+    def get_sigmas(self, steps) -> jnp.ndarray:
+        steps = jnp.int16(steps)
+        return self.sigmas[steps]
+
+class EDMNoiseScheduler(KarrasVENoiseScheduler):
+    def __init__(self, timesteps, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
+        super().__init__(timesteps=timesteps, sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data, *args, **kwargs)
+
+    def get_sigmas(self, steps, std=1.2, mean=-1.2) -> jnp.ndarray:
+        space = steps / self.max_timesteps
+        # space = jax.scipy.special.erfinv(self.erf_sigma_min + steps * (self.erf_sigma_max - self.erf_sigma_min))
+        return jnp.exp(space * std + mean)
+
+    def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
+        state, rng = state.get_random_key()
+        timesteps = jax.random.normal(rng, (batch_size,), dtype=jnp.float32)
+        return timesteps, state
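For reference, here is a minimal standalone sketch of the sigma ramp that `KarrasVENoiseScheduler.get_sigmas` computes. It re-derives the formula with plain JAX and the class's default hyperparameters; it does not import the package, and `max_timesteps` is an assumption here (taken to equal the `timesteps` constructor argument, which the base class is not shown setting in this diff).

```python
import jax.numpy as jnp

# Standalone re-derivation of the Karras-style sigma schedule (not the package API).
sigma_min, sigma_max, rho = 0.002, 80.0, 7.0
max_timesteps = 1000  # assumption: mirrors the `timesteps` constructor argument

min_inv_rho = sigma_min ** (1 / rho)
max_inv_rho = sigma_max ** (1 / rho)

steps = jnp.arange(max_timesteps)                 # 0 .. T-1
ramp = 1 - steps / max_timesteps                  # ~1 at step 0, ~0 at the last step
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(sigmas[0], sigmas[-1])  # ~sigma_min at step 0, approaches sigma_max at the end
```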
flaxdiff/schedulers/linear.py
ADDED
@@ -0,0 +1,14 @@
+import numpy as np
+from .discrete import DiscreteNoiseScheduler
+
+def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
+    scale = 1000 / timesteps
+    beta_start = scale * beta_start
+    beta_end = scale * beta_end
+    betas = np.linspace(
+        beta_start, beta_end, timesteps, dtype=np.float64)
+    return betas
+
+class LinearNoiseSchedule(DiscreteNoiseScheduler):
+    def __init__(self, timesteps, beta_start=0.0001, beta_end=0.02, *args, **kwargs):
+        super().__init__(timesteps, beta_start, beta_end, schedule_fn=linear_beta_schedule, *args, **kwargs)
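A quick standalone check of `linear_beta_schedule`: the `1000 / timesteps` factor rescales the endpoints so that the cumulative amount of noise stays roughly comparable when a different number of discretization steps is used. This sketch copies the function body rather than importing the package.

```python
import numpy as np

def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    # Same formula as above: rescale endpoints by 1000 / timesteps so the
    # schedule's total noise budget is roughly independent of step count.
    scale = 1000 / timesteps
    return np.linspace(scale * beta_start, scale * beta_end, timesteps, dtype=np.float64)

betas = linear_beta_schedule(500)
print(betas[0], betas[-1])       # 0.0002, 0.04 (endpoints scaled by 1000/500)
print(betas.sum())               # ~10.05, close to the 1000-step total of ~10.05
```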
flaxdiff/schedulers/sqrt.py
ADDED
@@ -0,0 +1,10 @@
+import numpy as np
+import jax.numpy as jnp
+from .discrete import DiscreteNoiseScheduler
+from .continuous import ContinuousNoiseScheduler
+
+class SqrtContinuousNoiseScheduler(ContinuousNoiseScheduler):
+    def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
+        signal_rates = jnp.sqrt(1 - steps)
+        noise_rates = jnp.sqrt(steps)
+        return self.reshape_rates((signal_rates, noise_rates), shape=shape)
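The sqrt schedule above is variance-preserving: for any t in [0, 1], the signal and noise rates satisfy signal_rate^2 + noise_rate^2 = 1. A standalone sketch of that property with plain JAX (it does not use the package's scheduler classes):

```python
import jax.numpy as jnp

t = jnp.linspace(0.0, 1.0, 5)        # continuous "time" in [0, 1]
signal_rates = jnp.sqrt(1 - t)       # same formulas as SqrtContinuousNoiseScheduler.get_rates
noise_rates = jnp.sqrt(t)

# A noisy sample would be formed as signal_rate * x_0 + noise_rate * eps;
# the two rates always sum to one in quadrature.
print(signal_rates**2 + noise_rates**2)  # all ones
```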
flaxdiff/trainer/__init__.py
ADDED
@@ -0,0 +1,216 @@
+import orbax.checkpoint
+import tqdm
+from flax import linen as nn
+import jax
+from typing import Callable
+from dataclasses import field
+import jax.numpy as jnp
+from clu import metrics
+from flax.training import train_state # Useful dataclass to keep train state
+import optax
+from flax import struct # Flax dataclasses
+import time
+import os
+import orbax
+from flax.training import orbax_utils
+
+from ..schedulers import NoiseScheduler
+from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
+
+@struct.dataclass
+class Metrics(metrics.Collection):
+    loss: metrics.Average.from_output('loss')  # type: ignore
+
+class ModelState():
+    model: nn.Module
+    params: dict
+    noise_schedule: NoiseScheduler
+    model_output_transform: DiffusionPredictionTransform
+
+# Define the TrainState with EMA parameters
+class TrainState(train_state.TrainState):
+    rngs: jax.random.PRNGKey
+    ema_params: dict
+
+    def get_random_key(self):
+        rngs, subkey = jax.random.split(self.rngs)
+        return self.replace(rngs=rngs), subkey
+
+    def apply_ema(self, decay: float=0.999):
+        new_ema_params = jax.tree_util.tree_map(
+            lambda ema, param: decay * ema + (1 - decay) * param,
+            self.ema_params,
+            self.params,
+        )
+        return self.replace(ema_params=new_ema_params)
+
+class DiffusionTrainer:
+    state : TrainState
+    best_state : TrainState
+    best_loss : float
+    model : nn.Module
+    noise_schedule : NoiseScheduler
+    model_output_transform:DiffusionPredictionTransform
+    ema_decay:float = 0.999
+
+    def __init__(self,
+                 model:nn.Module,
+                 optimizer: optax.GradientTransformation,
+                 noise_schedule:NoiseScheduler,
+                 rngs:jax.random.PRNGKey,
+                 train_state:TrainState=None,
+                 name:str="Diffusion",
+                 load_from_checkpoint:bool=False,
+                 param_transforms:Callable=None,
+                 model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(),
+                 loss_fn=optax.l2_loss,
+                 ):
+        self.model = model
+        self.noise_schedule = noise_schedule
+        self.name = name
+        self.model_output_transform = model_output_transform
+        self.loss_fn = loss_fn
+
+        checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+        options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True)
+        self.checkpointer = orbax.checkpoint.CheckpointManager(self.checkpoint_path(), checkpointer, options)
+
+        if load_from_checkpoint:
+            params = self.load()
+        else:
+            params = None
+
+        if train_state == None:
+            self.init_state(optimizer, rngs, params=params, model=model, param_transforms=param_transforms)
+        else:
+            self.state = train_state
+            self.best_state = train_state
+            self.best_loss = 1e9
+
+    def init_state(self,
+                   optimizer: optax.GradientTransformation,
+                   rngs:jax.random.PRNGKey,
+                   params:dict=None,
+                   model:nn.Module=None,
+                   param_transforms:Callable=None,
+                   batch_size=16,
+                   image_size=64
+                   ):
+        inp = jnp.ones((batch_size, image_size, image_size, 3))
+        temb = jnp.ones((batch_size,))
+        rngs, subkey = jax.random.split(rngs)
+        if params == None:
+            params = model.init(subkey, inp, temb)
+        if param_transforms is not None:
+            params = param_transforms(params)
+        self.best_loss = 1e9
+        self.state = TrainState.create(
+            apply_fn=model.apply,
+            params=params,
+            ema_params=params,
+            tx=optimizer,
+            rngs=rngs,
+        )
+        self.best_state = self.state
+
+    def checkpoint_path(self):
+        experiment_name = self.name
+        path = os.path.join(os.path.abspath('./models'), experiment_name)
+        if not os.path.exists(path):
+            os.makedirs(path)
+        return path
+
+    def load(self):
+        step = self.checkpointer.latest_step()
+        print("Loading model from checkpoint", step)
+        ckpt = self.checkpointer.restore(step)
+        state = ckpt['state']
+        # Convert the state to a TrainState
+        self.best_loss = ckpt['best_loss']
+        print(f"Loaded model from checkpoint at step {step}", ckpt['best_loss'])
+        return state.get('params', None)#, ckpt.get('model', None)
+
+    def save(self, epoch=0, best=False):
+        print(f"Saving model at epoch {epoch}")
+        state = self.best_state if best else self.state
+        # filename = os.path.join(self.checkpoint_path(), f'model_{epoch}' if not best else 'best_model')
+        ckpt = {
+            'model': self.model,
+            'state': state,
+            'best_loss': self.best_loss
+        }
+        save_args = orbax_utils.save_args_from_target(ckpt)
+        self.checkpointer.save(epoch, ckpt, save_kwargs={'save_args': save_args})
+
+    def summary(self, image_size=64):
+        inp = jnp.ones((1, image_size, image_size, 3))
+        temb = jnp.ones((1,))
+        print(self.model.tabulate(jax.random.key(0), inp, temb, console_kwargs={"width": 200, "force_jupyter": True, }))
+
+    def _define_train_step(self):
+        noise_schedule = self.noise_schedule
+        model = self.model
+        model_output_transform = self.model_output_transform
+        loss_fn = self.loss_fn
+        @jax.jit
+        def train_step(state:TrainState, batch):
+            """Train for a single step."""
+            images = batch
+            noise_level, state = noise_schedule.generate_timesteps(images.shape[0], state)
+            state, rngs = state.get_random_key()
+            noise:jax.Array = jax.random.normal(rngs, shape=images.shape)
+            rates = noise_schedule.get_rates(noise_level)
+            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(images, noise, rates)
+            def model_loss(params):
+                preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level))
+                preds = model_output_transform.pred_transform(noisy_images, preds, rates)
+                nloss = loss_fn(preds, expected_output)
+                # nloss = jnp.mean(nloss, axis=1)
+                nloss *= noise_schedule.get_weights(noise_level)
+                nloss = jnp.mean(nloss)
+                loss = nloss
+                return loss
+            loss, grads = jax.value_and_grad(model_loss)(state.params)
+            state = state.apply_gradients(grads=grads)
+            state = state.apply_ema(self.ema_decay)
+            return state, loss
+        return train_step
+
+    def _define_compute_metrics(self):
+        @jax.jit
+        def compute_metrics(state:TrainState, expected, pred):
+            loss = jnp.mean(jnp.square(pred - expected))
+            metric_updates = state.metrics.single_from_model_output(loss=loss)
+            metrics = state.metrics.merge(metric_updates)
+            state = state.replace(metrics=metrics)
+            return state
+        return compute_metrics
+
+    def fit(self, data, steps_per_epoch, epochs):
+        data = iter(data)
+        train_step = self._define_train_step()
+        compute_metrics = self._define_compute_metrics()
+        state = self.state
+        for epoch in range(epochs):
+            print(f"\nEpoch {epoch+1}/{epochs}")
+            start_time = time.time()
+            epoch_loss = 0
+            with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {epoch+1}', ncols=100, unit='step') as pbar:
+                for i in range(steps_per_epoch):
+                    batch = next(data)
+                    state, loss = train_step(state, batch)
+                    epoch_loss += loss
+                    if i % 100 == 0:
+                        pbar.set_postfix(loss=f'{loss:.4f}')
+                        pbar.update(100)
+            end_time = time.time()
+            self.state = state
+            total_time = end_time - start_time
+            avg_time_per_step = total_time / steps_per_epoch
+            avg_loss = epoch_loss / steps_per_epoch
+            if avg_loss < self.best_loss:
+                self.best_loss = avg_loss
+                self.best_state = state
+                self.save(epoch, best=True)
+            print(f"\n\tEpoch {epoch+1} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}")
+        return self.state
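The EMA bookkeeping in `TrainState.apply_ema` is a plain `tree_map` over the parameter pytrees. A minimal standalone illustration of the same update rule, using toy parameter dictionaries rather than the package's `TrainState`:

```python
import jax
import jax.numpy as jnp

decay = 0.999
params = {"w": jnp.ones((2,)) * 2.0, "b": jnp.zeros((2,))}       # hypothetical current params
ema_params = {"w": jnp.ones((2,)), "b": jnp.zeros((2,))}          # hypothetical running average

# Same rule as apply_ema above: ema <- decay * ema + (1 - decay) * param
new_ema_params = jax.tree_util.tree_map(
    lambda ema, param: decay * ema + (1 - decay) * param,
    ema_params,
    params,
)
print(new_ema_params["w"])  # [1.001 1.001]
```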
flaxdiff/utils.py
ADDED
@@ -0,0 +1,89 @@
+import jax
+import jax.numpy as jnp
+import flax.struct as struct
+import flax.linen as nn
+from typing import Any
+
+class MarkovState(struct.PyTreeNode):
+    pass
+
+class RandomMarkovState(MarkovState):
+    rng: jax.random.PRNGKey
+
+    def get_random_key(self):
+        rng, subkey = jax.random.split(self.rng)
+        return RandomMarkovState(rng), subkey
+
+def clip_images(images, clip_min=-1, clip_max=1):
+    return jnp.clip(images, clip_min, clip_max)
+
+class RMSNorm(nn.Module):
+    """
+    From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467
+
+    Adapted from flax.linen.LayerNorm
+    """
+
+    epsilon: float = 1e-6
+    dtype: Any = jnp.float32
+    param_dtype: Any = jnp.float32
+    use_scale: bool = True
+    scale_init: Any = jax.nn.initializers.ones
+
+    @nn.compact
+    def __call__(self, x):
+        reduction_axes = (-1,)
+        feature_axes = (-1,)
+
+        rms_sq = self._compute_rms_sq(x, reduction_axes)
+
+        return self._normalize(
+            self,
+            x,
+            rms_sq,
+            reduction_axes,
+            feature_axes,
+            self.dtype,
+            self.param_dtype,
+            self.epsilon,
+            self.use_scale,
+            self.scale_init,
+        )
+
+    def _compute_rms_sq(self, x, axes):
+        x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
+        rms_sq = jnp.mean(jax.lax.square(x), axes)
+        return rms_sq
+
+    def _normalize(
+        self,
+        mdl,
+        x,
+        rms_sq,
+        reduction_axes,
+        feature_axes,
+        dtype,
+        param_dtype,
+        epsilon,
+        use_scale,
+        scale_init,
+    ):
+        reduction_axes = nn.normalization._canonicalize_axes(x.ndim, reduction_axes)
+        feature_axes = nn.normalization._canonicalize_axes(x.ndim, feature_axes)
+        stats_shape = list(x.shape)
+        for axis in reduction_axes:
+            stats_shape[axis] = 1
+        rms_sq = rms_sq.reshape(stats_shape)
+        feature_shape = [1] * x.ndim
+        reduced_feature_shape = []
+        for ax in feature_axes:
+            feature_shape[ax] = x.shape[ax]
+            reduced_feature_shape.append(x.shape[ax])
+        mul = jax.lax.rsqrt(rms_sq + epsilon)
+        if use_scale:
+            scale = mdl.param(
+                "scale", scale_init, reduced_feature_shape, param_dtype
+            ).reshape(feature_shape)
+            mul *= scale
+        y = mul * x
+        return jnp.asarray(y, dtype)
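For intuition, the arithmetic that `RMSNorm` performs on the last axis (ignoring the learned scale parameter) is x * rsqrt(mean(x^2) + eps). The sketch below re-derives that with plain JAX; `rms_norm` is a hypothetical standalone helper, not the module or API from `flaxdiff.utils`.

```python
import jax
import jax.numpy as jnp

def rms_norm(x, epsilon=1e-6):
    # Root-mean-square normalization over the last axis; mirrors the unscaled
    # path of the RMSNorm module above (the learned scale is handled by Flax params there).
    rms_sq = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(rms_sq + epsilon)

x = jnp.array([[1.0, 2.0, 3.0, 4.0]])
y = rms_norm(x)
print(jnp.sqrt(jnp.mean(y**2, axis=-1)))  # ~1.0: unit RMS after normalization
```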