flaxdiff 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,69 @@
+ import jax.numpy as jnp
+ from .common import GeneralizedNoiseScheduler
+ import math
+ import jax
+ from ..utils import RandomMarkovState
+
+ class KarrasVENoiseScheduler(GeneralizedNoiseScheduler):
+     def __init__(self, timesteps, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
+         super().__init__(timesteps=timesteps, sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data, *args, **kwargs)
+         self.min_inv_rho = sigma_min ** (1 / rho)
+         self.max_inv_rho = sigma_max ** (1 / rho)
+         self.rho = rho
+
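+     # Karras-style rho schedule: interpolate linearly between sigma_max**(1/rho) and
+     # sigma_min**(1/rho), then raise to rho, so steps=0 maps to sigma_min and
+     # steps=max_timesteps maps to sigma_max.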
+     def get_sigmas(self, steps) -> jnp.ndarray:
+         # steps = jnp.int16(steps)
+         # return self.sigmas[steps]
+         ramp = 1 - steps / self.max_timesteps
+         sigmas = (self.max_inv_rho + ramp * (self.min_inv_rho - self.max_inv_rho)) ** self.rho
+         return sigmas
+
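+     # Per-sample loss weighting lambda(sigma) = (sigma**2 + sigma_data**2) / (sigma * sigma_data)**2,
+     # reshaped for broadcasting against image-shaped losses.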
+     def get_weights(self, steps, shape=(-1, 1, 1, 1)) -> jnp.ndarray:
+         sigma = self.get_sigmas(steps)
+         weights = ((sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2)
+         return weights.reshape(shape)
+
+     def transform_inputs(self, x, steps, num_discrete_chunks=1000) -> tuple[jnp.ndarray, jnp.ndarray]:
+         sigmas = self.get_sigmas(steps)
+         # sigmas = (sigmas / self.sigma_max) * num_discrete_chunks
+         # Condition the model on log(sigma) / 4 rather than on the raw timestep.
+         sigmas = jnp.log(sigmas) / 4
+         return x, sigmas
+
+     def get_timesteps(self, sigmas:jnp.ndarray) -> jnp.ndarray:
+         # Inverse of get_sigmas: recover the continuous timestep for a given sigma.
+         sigmas = sigmas.reshape(-1)
+         inv_rho = sigmas ** (1 / self.rho)
+         ramp = ((inv_rho - self.max_inv_rho) / (self.min_inv_rho - self.max_inv_rho))
+         steps = (1 - ramp) * self.max_timesteps
+         return steps
+
+     def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
+         timesteps, state = super().generate_timesteps(batch_size, state)
+         timesteps = timesteps.astype(jnp.float32)
+         return timesteps, state
+
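+ # Scheduler with a fixed table of geometrically (log-uniformly) spaced sigmas; integer steps index into the table.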
+ class SimpleExpNoiseScheduler(KarrasVENoiseScheduler):
+     def __init__(self, timesteps, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
+         super().__init__(timesteps=timesteps, sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data, *args, **kwargs)
+         if isinstance(timesteps, int) and timesteps > 1:
+             n = timesteps
+         else:
+             n = 1000
+         self.sigmas = jnp.exp(jnp.linspace(math.log(sigma_min), math.log(sigma_max), n))
+
+     def get_sigmas(self, steps) -> jnp.ndarray:
+         steps = jnp.int16(steps)
+         return self.sigmas[steps]
+
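+ # EDM-style scheduler: timesteps are drawn from a standard normal and mapped to
+ # sigma = exp(mean + std * t / max_timesteps), giving log-normally distributed noise levels.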
+ class EDMNoiseScheduler(KarrasVENoiseScheduler):
+     def __init__(self, timesteps, sigma_min=0.002, sigma_max=80, rho=7., sigma_data=0.5, *args, **kwargs):
+         super().__init__(timesteps=timesteps, sigma_min=sigma_min, sigma_max=sigma_max, sigma_data=sigma_data, *args, **kwargs)
+
+     def get_sigmas(self, steps, std=1.2, mean=-1.2) -> jnp.ndarray:
+         space = steps / self.max_timesteps
+         # space = jax.scipy.special.erfinv(self.erf_sigma_min + steps * (self.erf_sigma_max - self.erf_sigma_min))
+         return jnp.exp(space * std + mean)
+
+     def generate_timesteps(self, batch_size, state:RandomMarkovState) -> tuple[jnp.ndarray, RandomMarkovState]:
+         state, rng = state.get_random_key()
+         timesteps = jax.random.normal(rng, (batch_size,), dtype=jnp.float32)
+         return timesteps, state
@@ -0,0 +1,14 @@
+ import numpy as np
+ from .discrete import DiscreteNoiseScheduler
+
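+ # Standard DDPM linear beta schedule; the beta_start/beta_end defaults are defined for
+ # 1000 steps, so they are rescaled by 1000 / timesteps when a different step count is used.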
+ def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
+     scale = 1000 / timesteps
+     beta_start = scale * beta_start
+     beta_end = scale * beta_end
+     betas = np.linspace(
+         beta_start, beta_end, timesteps, dtype=np.float64)
+     return betas
+
+ class LinearNoiseSchedule(DiscreteNoiseScheduler):
+     def __init__(self, timesteps, beta_start=0.0001, beta_end=0.02, *args, **kwargs):
+         super().__init__(timesteps, beta_start, beta_end, schedule_fn=linear_beta_schedule, *args, **kwargs)
@@ -0,0 +1,10 @@
+ import numpy as np
+ import jax.numpy as jnp
+ from .discrete import DiscreteNoiseScheduler
+ from .continuous import ContinuousNoiseScheduler
+
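+ # Variance-preserving rates: signal_rate**2 + noise_rate**2 = 1 for steps in [0, 1].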
+ class SqrtContinuousNoiseScheduler(ContinuousNoiseScheduler):
+     def get_rates(self, steps, shape=(-1, 1, 1, 1)) -> tuple[jnp.ndarray, jnp.ndarray]:
+         signal_rates = jnp.sqrt(1 - steps)
+         noise_rates = jnp.sqrt(steps)
+         return self.reshape_rates((signal_rates, noise_rates), shape=shape)
@@ -0,0 +1,216 @@
+ import orbax.checkpoint
+ import tqdm
+ from flax import linen as nn
+ import jax
+ from typing import Callable
+ from dataclasses import field
+ import jax.numpy as jnp
+ from clu import metrics
+ from flax.training import train_state  # Useful dataclass to keep train state
+ import optax
+ from flax import struct  # Flax dataclasses
+ import time
+ import os
+ import orbax
+ from flax.training import orbax_utils
+
+ from ..schedulers import NoiseScheduler
+ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
+
+ @struct.dataclass
+ class Metrics(metrics.Collection):
+     loss: metrics.Average.from_output('loss')  # type: ignore
+
+ class ModelState():
+     model: nn.Module
+     params: dict
+     noise_schedule: NoiseScheduler
+     model_output_transform: DiffusionPredictionTransform
+
+ # Define the TrainState with EMA parameters
+ class TrainState(train_state.TrainState):
+     rngs: jax.random.PRNGKey
+     ema_params: dict
+
+     def get_random_key(self):
+         rngs, subkey = jax.random.split(self.rngs)
+         return self.replace(rngs=rngs), subkey
+
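+     # Exponential moving average of the parameters: ema <- decay * ema + (1 - decay) * params.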
+     def apply_ema(self, decay: float=0.999):
+         new_ema_params = jax.tree_util.tree_map(
+             lambda ema, param: decay * ema + (1 - decay) * param,
+             self.ema_params,
+             self.params,
+         )
+         return self.replace(ema_params=new_ema_params)
+
+ class DiffusionTrainer:
+     state : TrainState
+     best_state : TrainState
+     best_loss : float
+     model : nn.Module
+     noise_schedule : NoiseScheduler
+     model_output_transform:DiffusionPredictionTransform
+     ema_decay:float = 0.999
+
+     def __init__(self,
+                  model:nn.Module,
+                  optimizer: optax.GradientTransformation,
+                  noise_schedule:NoiseScheduler,
+                  rngs:jax.random.PRNGKey,
+                  train_state:TrainState=None,
+                  name:str="Diffusion",
+                  load_from_checkpoint:bool=False,
+                  param_transforms:Callable=None,
+                  model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(),
+                  loss_fn=optax.l2_loss,
+                  ):
+         self.model = model
+         self.noise_schedule = noise_schedule
+         self.name = name
+         self.model_output_transform = model_output_transform
+         self.loss_fn = loss_fn
+
+         checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+         options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True)
+         self.checkpointer = orbax.checkpoint.CheckpointManager(self.checkpoint_path(), checkpointer, options)
+
+         if load_from_checkpoint:
+             params = self.load()
+         else:
+             params = None
+
+         if train_state is None:
+             self.init_state(optimizer, rngs, params=params, model=model, param_transforms=param_transforms)
+         else:
+             self.state = train_state
+             self.best_state = train_state
+             self.best_loss = 1e9
+
+     def init_state(self,
+                    optimizer: optax.GradientTransformation,
+                    rngs:jax.random.PRNGKey,
+                    params:dict=None,
+                    model:nn.Module=None,
+                    param_transforms:Callable=None,
+                    batch_size=16,
+                    image_size=64
+                    ):
+         inp = jnp.ones((batch_size, image_size, image_size, 3))
+         temb = jnp.ones((batch_size,))
+         rngs, subkey = jax.random.split(rngs)
+         if params is None:
+             params = model.init(subkey, inp, temb)
+         if param_transforms is not None:
+             params = param_transforms(params)
+         self.best_loss = 1e9
+         self.state = TrainState.create(
+             apply_fn=model.apply,
+             params=params,
+             ema_params=params,
+             tx=optimizer,
+             rngs=rngs,
+         )
+         self.best_state = self.state
+
+     def checkpoint_path(self):
+         experiment_name = self.name
+         path = os.path.join(os.path.abspath('./models'), experiment_name)
+         if not os.path.exists(path):
+             os.makedirs(path)
+         return path
+
+     def load(self):
+         step = self.checkpointer.latest_step()
+         print("Loading model from checkpoint", step)
+         ckpt = self.checkpointer.restore(step)
+         state = ckpt['state']
+         # Convert the state to a TrainState
+         self.best_loss = ckpt['best_loss']
+         print(f"Loaded model from checkpoint at step {step}", ckpt['best_loss'])
+         return state.get('params', None)  #, ckpt.get('model', None)
+
+     def save(self, epoch=0, best=False):
+         print(f"Saving model at epoch {epoch}")
+         state = self.best_state if best else self.state
+         # filename = os.path.join(self.checkpoint_path(), f'model_{epoch}' if not best else 'best_model')
+         ckpt = {
+             'model': self.model,
+             'state': state,
+             'best_loss': self.best_loss
+         }
+         save_args = orbax_utils.save_args_from_target(ckpt)
+         self.checkpointer.save(epoch, ckpt, save_kwargs={'save_args': save_args})
+
+     def summary(self, image_size=64):
+         inp = jnp.ones((1, image_size, image_size, 3))
+         temb = jnp.ones((1,))
+         print(self.model.tabulate(jax.random.key(0), inp, temb, console_kwargs={"width": 200, "force_jupyter": True, }))
+
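+     # Builds the jitted training step: sample noise levels, forward-diffuse the batch,
+     # predict with the model, weight the loss with the scheduler's weights, apply the
+     # gradient update and then the EMA update.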
+     def _define_train_step(self):
+         noise_schedule = self.noise_schedule
+         model = self.model
+         model_output_transform = self.model_output_transform
+         loss_fn = self.loss_fn
+         @jax.jit
+         def train_step(state:TrainState, batch):
+             """Train for a single step."""
+             images = batch
+             noise_level, state = noise_schedule.generate_timesteps(images.shape[0], state)
+             state, rngs = state.get_random_key()
+             noise:jax.Array = jax.random.normal(rngs, shape=images.shape)
+             rates = noise_schedule.get_rates(noise_level)
+             noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(images, noise, rates)
+             def model_loss(params):
+                 preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images * c_in, noise_level))
+                 preds = model_output_transform.pred_transform(noisy_images, preds, rates)
+                 nloss = loss_fn(preds, expected_output)
+                 # nloss = jnp.mean(nloss, axis=1)
+                 nloss *= noise_schedule.get_weights(noise_level)
+                 nloss = jnp.mean(nloss)
+                 loss = nloss
+                 return loss
+             loss, grads = jax.value_and_grad(model_loss)(state.params)
+             state = state.apply_gradients(grads=grads)
+             state = state.apply_ema(self.ema_decay)
+             return state, loss
+         return train_step
+
+     def _define_compute_metrics(self):
+         # Note: expects a `metrics` collection field on the train state; defined but not invoked in fit() below.
+         @jax.jit
+         def compute_metrics(state:TrainState, expected, pred):
+             loss = jnp.mean(jnp.square(pred - expected))
+             metric_updates = state.metrics.single_from_model_output(loss=loss)
+             metrics = state.metrics.merge(metric_updates)
+             state = state.replace(metrics=metrics)
+             return state
+         return compute_metrics
+
+     def fit(self, data, steps_per_epoch, epochs):
+         data = iter(data)
+         train_step = self._define_train_step()
+         compute_metrics = self._define_compute_metrics()
+         state = self.state
+         for epoch in range(epochs):
+             print(f"\nEpoch {epoch+1}/{epochs}")
+             start_time = time.time()
+             epoch_loss = 0
+             with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {epoch+1}', ncols=100, unit='step') as pbar:
+                 for i in range(steps_per_epoch):
+                     batch = next(data)
+                     state, loss = train_step(state, batch)
+                     epoch_loss += loss
+                     if i % 100 == 0:
+                         pbar.set_postfix(loss=f'{loss:.4f}')
+                         pbar.update(100)
+             end_time = time.time()
+             self.state = state
+             total_time = end_time - start_time
+             avg_time_per_step = total_time / steps_per_epoch
+             avg_loss = epoch_loss / steps_per_epoch
+             if avg_loss < self.best_loss:
+                 self.best_loss = avg_loss
+                 self.best_state = state
+                 self.save(epoch, best=True)
+             print(f"\n\tEpoch {epoch+1} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}")
+         return self.state
flaxdiff/utils.py ADDED
@@ -0,0 +1,89 @@
+ import jax
+ import jax.numpy as jnp
+ import flax.struct as struct
+ import flax.linen as nn
+ from typing import Any
+
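+ # Small pytree state objects used to thread PRNG keys functionally through the noise schedulers.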
+ class MarkovState(struct.PyTreeNode):
+     pass
+
+ class RandomMarkovState(MarkovState):
+     rng: jax.random.PRNGKey
+
+     def get_random_key(self):
+         rng, subkey = jax.random.split(self.rng)
+         return RandomMarkovState(rng), subkey
+
+ def clip_images(images, clip_min=-1, clip_max=1):
+     return jnp.clip(images, clip_min, clip_max)
+
+ class RMSNorm(nn.Module):
+     """
+     From "Root Mean Square Layer Normalization" by https://arxiv.org/abs/1910.07467
+
+     Adapted from flax.linen.LayerNorm
+     """
+
+     epsilon: float = 1e-6
+     dtype: Any = jnp.float32
+     param_dtype: Any = jnp.float32
+     use_scale: bool = True
+     scale_init: Any = jax.nn.initializers.ones
+
+     @nn.compact
+     def __call__(self, x):
+         reduction_axes = (-1,)
+         feature_axes = (-1,)
+
+         rms_sq = self._compute_rms_sq(x, reduction_axes)
+
+         return self._normalize(
+             self,
+             x,
+             rms_sq,
+             reduction_axes,
+             feature_axes,
+             self.dtype,
+             self.param_dtype,
+             self.epsilon,
+             self.use_scale,
+             self.scale_init,
+         )
+
+     def _compute_rms_sq(self, x, axes):
+         x = jnp.asarray(x, jnp.promote_types(jnp.float32, jnp.result_type(x)))
+         rms_sq = jnp.mean(jax.lax.square(x), axes)
+         return rms_sq
+
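+     # Normalize: y = x * rsqrt(mean(x**2) + epsilon), then apply an optional learned per-feature scale.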
+     def _normalize(
+         self,
+         mdl,
+         x,
+         rms_sq,
+         reduction_axes,
+         feature_axes,
+         dtype,
+         param_dtype,
+         epsilon,
+         use_scale,
+         scale_init,
+     ):
+         reduction_axes = nn.normalization._canonicalize_axes(x.ndim, reduction_axes)
+         feature_axes = nn.normalization._canonicalize_axes(x.ndim, feature_axes)
+         stats_shape = list(x.shape)
+         for axis in reduction_axes:
+             stats_shape[axis] = 1
+         rms_sq = rms_sq.reshape(stats_shape)
+         feature_shape = [1] * x.ndim
+         reduced_feature_shape = []
+         for ax in feature_axes:
+             feature_shape[ax] = x.shape[ax]
+             reduced_feature_shape.append(x.shape[ax])
+         mul = jax.lax.rsqrt(rms_sq + epsilon)
+         if use_scale:
+             scale = mdl.param(
+                 "scale", scale_init, reduced_feature_shape, param_dtype
+             ).reshape(feature_shape)
+             mul *= scale
+         y = mul * x
+         return jnp.asarray(y, dtype)