flaxdiff 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/attention.py +140 -162
- flaxdiff/models/autoencoder/__init__.py +2 -0
- flaxdiff/models/autoencoder/autoencoder.py +19 -0
- flaxdiff/models/autoencoder/diffusers.py +91 -0
- flaxdiff/models/autoencoder/simple_autoenc.py +26 -0
- flaxdiff/models/common.py +322 -0
- flaxdiff/models/simple_unet.py +21 -327
- flaxdiff/trainer/__init__.py +2 -201
- flaxdiff/trainer/autoencoder_trainer.py +182 -0
- flaxdiff/trainer/diffusion_trainer.py +202 -0
- flaxdiff/trainer/simple_trainer.py +175 -80
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/METADATA +12 -2
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/RECORD +15 -9
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.4.dist-info → flaxdiff-0.1.6.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,182 @@
|
|
1
|
+
from dataclasses import field
from typing import Callable
from typing import Dict, Callable, Sequence, Any, Union, Tuple

import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
from flaxdiff.utils import RandomMarkovState

from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
from ..schedulers import NoiseScheduler
from .diffusion_trainer import TrainState
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
19
|
+
|
20
|
+
class AutoEncoderTrainer(SimpleTrainer):
    """Trainer for autoencoder models, built on ``SimpleTrainer``.

    NOTE(review): apart from ``__init__``, this class is a near-verbatim copy of
    ``DiffusionTrainer``. ``_define_train_step`` reads ``noise_schedule``,
    ``model_output_transform``, ``unconditional_prob``, ``autoencoder`` and
    ``ema_decay`` from ``self``. This class never sets any of them, so they must
    be provided some other way or training will fail with ``AttributeError``.
    ``latent_dim`` and ``spatial_scale`` are stored but not read anywhere in
    this class. Treat this class as work in progress.
    """

    def __init__(self,
                 model: nn.Module,
                 input_shape: Tuple[int, int, int],
                 latent_dim: int,
                 spatial_scale: int,
                 optimizer: optax.GradientTransformation,
                 rngs: jax.random.PRNGKey,
                 name: str = "Autoencoder",
                 **kwargs
                 ):
        """Create the trainer.

        Args:
            model: Flax module to train.
            input_shape: Shape forwarded to ``SimpleTrainer`` as
                ``input_shapes={"image": input_shape}``.
            latent_dim: Stored as ``self.latent_dim``.
            spatial_scale: Stored as ``self.spatial_scale``.
            optimizer: Optax gradient transformation used for updates.
            rngs: PRNG key used for initialization.
            name: Trainer name forwarded to ``SimpleTrainer``.
            **kwargs: Forwarded unchanged to ``SimpleTrainer.__init__``.
        """
        super().__init__(
            model=model,
            input_shapes={"image": input_shape},
            optimizer=optimizer,
            rngs=rngs,
            name=name,
            **kwargs
        )
        self.latent_dim = latent_dim
        self.spatial_scale = spatial_scale

    def generate_states(
        self,
        optimizer: optax.GradientTransformation,
        rngs: jax.random.PRNGKey,
        existing_state: dict = None,
        existing_best_state: dict = None,
        model: nn.Module = None,
        param_transforms: Callable = None
    ) -> Tuple[TrainState, TrainState]:
        """Build the current and best ``TrainState`` objects.

        Args:
            optimizer: Optax transformation stored as the state's ``tx``.
            rngs: PRNG key. It is split once: one half initializes the model,
                and the other half is stored on the state.
            existing_state: Optional dict with ``'params'`` and ``'ema_params'``
                to resume from. When ``None``, parameters come from ``model.init``.
            existing_best_state: Optional dict with ``'params'`` and
                ``'ema_params'`` for the best state. When ``None``, the best
                state is the current state.
            model: Module whose ``init``/``apply`` are used.
            param_transforms: Optional callable applied to both ``params`` and
                ``ema_params`` before the state is built.

        Returns:
            ``(state, best_state)``.
        """
        print("Generating states for AutoEncoderTrainer")
        rngs, subkey = jax.random.split(rngs)

        if existing_state is None:
            input_vars = self.get_input_ones()
            params = model.init(subkey, **input_vars)
            # EMA weights start as an exact copy of the freshly initialized weights.
            new_state = {"params": params, "ema_params": params}
        else:
            new_state = existing_state

        if param_transforms is not None:
            # Transform both trees: `apply_ema` tree-maps over `ema_params` and
            # `params` together, so their structures must keep matching.
            # Build a new dict so the caller's `existing_state` is not mutated.
            new_state = {
                "params": param_transforms(new_state["params"]),
                "ema_params": param_transforms(new_state["ema_params"]),
            }

        state = TrainState.create(
            apply_fn=model.apply,
            params=new_state['params'],
            ema_params=new_state['ema_params'],
            tx=optimizer,
            rngs=rngs,
            metrics=Metrics.empty()
        )

        if existing_best_state is not None:
            best_state = state.replace(
                params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
        else:
            best_state = state

        return state, best_state

    def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
        """Build the jitted (and, if distributed, shard-mapped) training step.

        NOTE(review): this is the diffusion training step copied from
        ``DiffusionTrainer``. It needs ``self.noise_schedule``,
        ``self.model_output_transform``, ``self.unconditional_prob``,
        ``self.ema_decay`` and ``self.autoencoder``, and this class does not set
        any of them.

        Args:
            batch_size: Batch size used to size the broadcast null-label batch
                and the number of unconditional rows.
            null_labels_seq: Null-prompt embedding of shape ``(S, C)``. A leading
                singleton batch axis, ``(1, S, C)``, is also accepted.
            text_embedder: Callable taking ``input_ids``/``attention_mask`` and
                returning an object with ``last_hidden_state``.

        Returns:
            ``train_step(train_state, rng_state, batch, local_device_index)``,
            which returns ``(train_state, loss, rng_state)``.
        """
        noise_schedule: NoiseScheduler = self.noise_schedule
        model = self.model
        model_output_transform = self.model_output_transform
        loss_fn = self.loss_fn
        unconditional_prob = self.unconditional_prob
        ema_decay = self.ema_decay
        distributed_training = self.distributed_training
        autoencoder = self.autoencoder

        # A fixed fraction of each batch (its leading rows) is conditioned on the
        # null-label embedding instead of its own text embedding.
        num_unconditional = int(batch_size * unconditional_prob)

        nS, nC = null_labels_seq.shape[-2:]
        null_labels_seq = jnp.broadcast_to(null_labels_seq, (batch_size, nS, nC))

        def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
            """Run one optimization step; return (train_state, loss, rng_state)."""
            rng_state, subkey = rng_state.get_random_key()
            # Fold in the device index so each shard draws different randomness.
            subkey = jax.random.fold_in(subkey, local_device_index.reshape())
            local_rng_state = RandomMarkovState(subkey)

            images = batch['image']

            if autoencoder is not None:
                # Convert the images to latent space.
                local_rng_state, rngs = local_rng_state.get_random_key()
                images = autoencoder.encode(images, rngs)
            else:
                # Rescale so 0 -> -1 and 255 -> 1 (assumes pixels in [0, 255]).
                images = (images - 127.5) / 127.5

            output = text_embedder(
                input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
            label_seq = output.last_hidden_state

            # Deterministically swap the first `num_unconditional` rows for the null label.
            label_seq = jnp.concatenate(
                [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)

            noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)

            local_rng_state, rngs = local_rng_state.get_random_key()
            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)

            rates = noise_schedule.get_rates(noise_level)
            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
                images, noise, rates)

            def model_loss(params):
                preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images * c_in, noise_level), label_seq)
                preds = model_output_transform.pred_transform(noisy_images, preds, rates)
                nloss = loss_fn(preds, expected_output)
                # Weight each sample by its schedule weight, then reduce to a scalar.
                nloss *= noise_schedule.get_weights(noise_level)
                return jnp.mean(nloss)

            loss, grads = jax.value_and_grad(model_loss)(train_state.params)
            if distributed_training:
                # Average gradients and loss over the "data" mesh axis.
                grads = jax.lax.pmean(grads, "data")
                loss = jax.lax.pmean(loss, "data")
            train_state = train_state.apply_gradients(grads=grads)
            train_state = train_state.apply_ema(ema_decay)
            return train_state, loss, rng_state

        if distributed_training:
            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
                                   out_specs=(P(), P(), P()))
        train_step = jax.jit(train_step)

        return train_step

    def _define_compute_metrics(self):
        """Return a jitted function that merges the MSE of ``pred`` vs ``expected`` into ``state.metrics``."""
        @jax.jit
        def compute_metrics(state: TrainState, expected, pred):
            loss = jnp.mean(jnp.square(pred - expected))
            metric_updates = state.metrics.single_from_model_output(loss=loss)
            metrics = state.metrics.merge(metric_updates)
            state = state.replace(metrics=metrics)
            return state
        return compute_metrics

    def fit(self, data, steps_per_epoch, epochs):
        """Run ``SimpleTrainer.fit`` with the train-step arguments pulled from ``data``.

        ``data`` must provide ``'null_labels_full'``, ``'local_batch_size'`` and
        ``'model'`` (the text embedder). These are forwarded as the
        ``batch_size``, ``null_labels_seq`` and ``text_embedder`` arguments.
        """
        null_labels_full = data['null_labels_full']
        local_batch_size = data['local_batch_size']
        text_embedder = data['model']
        super().fit(data, steps_per_epoch, epochs, {
            "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
|
177
|
+
|
178
|
+
def boolean_string(s):
    """Interpret ``s`` as a boolean (e.g. for use as an argparse ``type=``).

    Real ``bool`` values are returned unchanged. Any other value is ``True``
    only when it equals the exact string ``'True'``; everything else
    (including ``'true'`` or ``'1'``) is ``False``.
    """
    if isinstance(s, bool):
        return s
    return s == 'True'
|
182
|
+
|
@@ -0,0 +1,202 @@
|
|
1
|
+
from flax import linen as nn
|
2
|
+
import jax
|
3
|
+
from typing import Callable
|
4
|
+
from dataclasses import field
|
5
|
+
import jax.numpy as jnp
|
6
|
+
import optax
|
7
|
+
from jax.sharding import Mesh, PartitionSpec as P
|
8
|
+
from jax.experimental.shard_map import shard_map
|
9
|
+
from typing import Dict, Callable, Sequence, Any, Union, Tuple
|
10
|
+
|
11
|
+
from ..schedulers import NoiseScheduler
|
12
|
+
from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
|
13
|
+
|
14
|
+
from flaxdiff.utils import RandomMarkovState
|
15
|
+
|
16
|
+
from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
|
17
|
+
|
18
|
+
from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
|
19
|
+
|
20
|
+
class TrainState(SimpleTrainState):
    """Training state that also carries a PRNG key and an EMA copy of the parameters."""

    # PRNG key stored alongside the state.
    rngs: jax.random.PRNGKey
    # Exponential moving average of `params`; must share its tree structure.
    ema_params: dict

    def apply_ema(self, decay: float = 0.999):
        """Return the state with ``ema_params`` moved toward ``params``.

        Every leaf is updated as ``decay * ema + (1 - decay) * param``. The result
        is built through ``self.replace``.
        """
        def _blend(ema_leaf, param_leaf):
            return decay * ema_leaf + (1 - decay) * param_leaf

        updated_ema = jax.tree_util.tree_map(_blend, self.ema_params, self.params)
        return self.replace(ema_params=updated_ema)
|
31
|
+
|
32
|
+
class DiffusionTrainer(SimpleTrainer):
    """Text-conditioned diffusion trainer built on ``SimpleTrainer``.

    It keeps an EMA copy of the parameters and can optionally encode images to
    latents with an ``AutoEncoder`` before adding noise.
    """
    noise_schedule: NoiseScheduler
    model_output_transform: DiffusionPredictionTransform
    ema_decay: float = 0.999

    def __init__(self,
                 model: nn.Module,
                 input_shapes: Dict[str, Tuple[int, ...]],
                 optimizer: optax.GradientTransformation,
                 noise_schedule: NoiseScheduler,
                 rngs: jax.random.PRNGKey,
                 unconditional_prob: float = 0.2,
                 name: str = "Diffusion",
                 model_output_transform: Union[DiffusionPredictionTransform, None] = None,
                 autoencoder: Union[AutoEncoder, None] = None,
                 **kwargs
                 ):
        """Create the trainer.

        Args:
            model: Flax module to train.
            input_shapes: Mapping of input name to shape, forwarded to ``SimpleTrainer``.
            optimizer: Optax gradient transformation used for updates.
            noise_schedule: Schedule providing timesteps, rates, input transforms and loss weights.
            rngs: PRNG key used for initialization.
            unconditional_prob: Fraction of each batch trained on the null label.
            name: Trainer name forwarded to ``SimpleTrainer``.
            model_output_transform: Prediction parameterization. When ``None``,
                a new ``EpsilonPredictionTransform()`` is created for this trainer.
            autoencoder: Optional encoder used to map images to latents before
                noising. When ``None``, pixels are rescaled instead.
            **kwargs: Forwarded unchanged to ``SimpleTrainer.__init__``.
        """
        super().__init__(
            model=model,
            input_shapes=input_shapes,
            optimizer=optimizer,
            rngs=rngs,
            name=name,
            **kwargs
        )
        # Build the default per instance instead of sharing one object created
        # when the function was defined.
        if model_output_transform is None:
            model_output_transform = EpsilonPredictionTransform()
        self.noise_schedule = noise_schedule
        self.model_output_transform = model_output_transform
        self.unconditional_prob = unconditional_prob

        self.autoencoder = autoencoder

    def generate_states(
        self,
        optimizer: optax.GradientTransformation,
        rngs: jax.random.PRNGKey,
        existing_state: dict = None,
        existing_best_state: dict = None,
        model: nn.Module = None,
        param_transforms: Callable = None
    ) -> Tuple[TrainState, TrainState]:
        """Build the current and best ``TrainState`` objects.

        Args:
            optimizer: Optax transformation stored as the state's ``tx``.
            rngs: PRNG key. It is split once: one half initializes the model,
                and the other half is stored on the state.
            existing_state: Optional dict with ``'params'`` and ``'ema_params'``
                to resume from. When ``None``, parameters come from ``model.init``.
            existing_best_state: Optional dict with ``'params'`` and
                ``'ema_params'`` for the best state. When ``None``, the best
                state is the current state.
            model: Module whose ``init``/``apply`` are used.
            param_transforms: Optional callable applied to both ``params`` and
                ``ema_params`` before the state is built.

        Returns:
            ``(state, best_state)``.
        """
        print("Generating states for DiffusionTrainer")
        rngs, subkey = jax.random.split(rngs)

        if existing_state is None:
            input_vars = self.get_input_ones()
            params = model.init(subkey, **input_vars)
            # EMA weights start as an exact copy of the freshly initialized weights.
            new_state = {"params": params, "ema_params": params}
        else:
            new_state = existing_state

        if param_transforms is not None:
            # Transform both trees: `apply_ema` tree-maps over `ema_params` and
            # `params` together, so their structures must keep matching.
            # Build a new dict so the caller's `existing_state` is not mutated.
            new_state = {
                "params": param_transforms(new_state["params"]),
                "ema_params": param_transforms(new_state["ema_params"]),
            }

        state = TrainState.create(
            apply_fn=model.apply,
            params=new_state['params'],
            ema_params=new_state['ema_params'],
            tx=optimizer,
            rngs=rngs,
            metrics=Metrics.empty()
        )

        if existing_best_state is not None:
            best_state = state.replace(
                params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
        else:
            best_state = state

        return state, best_state

    def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
        """Build the jitted (and, if distributed, shard-mapped) diffusion training step.

        Args:
            batch_size: Batch size used to size the broadcast null-label batch
                and the number of unconditional rows.
            null_labels_seq: Null-prompt embedding of shape ``(S, C)``. A leading
                singleton batch axis, ``(1, S, C)``, is also accepted.
            text_embedder: Callable taking ``input_ids``/``attention_mask`` and
                returning an object with ``last_hidden_state``.

        Returns:
            ``train_step(train_state, rng_state, batch, local_device_index)``,
            which returns ``(train_state, loss, rng_state)``.
        """
        noise_schedule: NoiseScheduler = self.noise_schedule
        model = self.model
        model_output_transform = self.model_output_transform
        loss_fn = self.loss_fn
        unconditional_prob = self.unconditional_prob
        ema_decay = self.ema_decay
        distributed_training = self.distributed_training
        autoencoder = self.autoencoder

        # A fixed fraction of each batch (its leading rows) is conditioned on the
        # null-label embedding instead of its own text embedding.
        num_unconditional = int(batch_size * unconditional_prob)

        nS, nC = null_labels_seq.shape[-2:]
        null_labels_seq = jnp.broadcast_to(null_labels_seq, (batch_size, nS, nC))

        def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
            """Run one optimization step; return (train_state, loss, rng_state)."""
            rng_state, subkey = rng_state.get_random_key()
            # Fold in the device index so each shard draws different randomness.
            subkey = jax.random.fold_in(subkey, local_device_index.reshape())
            local_rng_state = RandomMarkovState(subkey)

            images = batch['image']

            if autoencoder is not None:
                # Convert the images to latent space.
                local_rng_state, rngs = local_rng_state.get_random_key()
                images = autoencoder.encode(images, rngs)
            else:
                # Rescale so 0 -> -1 and 255 -> 1 (assumes pixels in [0, 255]).
                images = (images - 127.5) / 127.5

            output = text_embedder(
                input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
            label_seq = output.last_hidden_state

            # Deterministically swap the first `num_unconditional` rows for the null label.
            label_seq = jnp.concatenate(
                [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)

            noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)

            local_rng_state, rngs = local_rng_state.get_random_key()
            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)

            rates = noise_schedule.get_rates(noise_level)
            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
                images, noise, rates)

            def model_loss(params):
                preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images * c_in, noise_level), label_seq)
                preds = model_output_transform.pred_transform(noisy_images, preds, rates)
                nloss = loss_fn(preds, expected_output)
                # Weight each sample by its schedule weight, then reduce to a scalar.
                nloss *= noise_schedule.get_weights(noise_level)
                return jnp.mean(nloss)

            loss, grads = jax.value_and_grad(model_loss)(train_state.params)
            if distributed_training:
                # Average gradients and loss over the "data" mesh axis.
                grads = jax.lax.pmean(grads, "data")
                loss = jax.lax.pmean(loss, "data")
            train_state = train_state.apply_gradients(grads=grads)
            train_state = train_state.apply_ema(ema_decay)
            return train_state, loss, rng_state

        if distributed_training:
            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
                                   out_specs=(P(), P(), P()))
        train_step = jax.jit(train_step)

        return train_step

    def _define_compute_metrics(self):
        """Return a jitted function that merges the MSE of ``pred`` vs ``expected`` into ``state.metrics``."""
        @jax.jit
        def compute_metrics(state: TrainState, expected, pred):
            loss = jnp.mean(jnp.square(pred - expected))
            metric_updates = state.metrics.single_from_model_output(loss=loss)
            metrics = state.metrics.merge(metric_updates)
            state = state.replace(metrics=metrics)
            return state
        return compute_metrics

    def fit(self, data, steps_per_epoch, epochs):
        """Run ``SimpleTrainer.fit`` with the train-step arguments pulled from ``data``.

        ``data`` must provide ``'null_labels_full'``, ``'local_batch_size'`` and
        ``'model'`` (the text embedder). These are forwarded as the
        ``batch_size``, ``null_labels_seq`` and ``text_embedder`` arguments.
        """
        null_labels_full = data['null_labels_full']
        local_batch_size = data['local_batch_size']
        text_embedder = data['model']
        super().fit(data, steps_per_epoch, epochs, {
            "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
|
197
|
+
|
198
|
+
def boolean_string(s):
    """Interpret ``s`` as a boolean (e.g. for use as an argparse ``type=``).

    Real ``bool`` values are returned unchanged. Any other value is ``True``
    only when it equals the exact string ``'True'``; everything else
    (including ``'true'`` or ``'1'``) is ``False``.
    """
    if isinstance(s, bool):
        return s
    return s == 'True'
|
202
|
+
|