flaxdiff 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
@@ -0,0 +1,182 @@
+ from flax import linen as nn
+ import jax
+ from dataclasses import field
+ import jax.numpy as jnp
+ import optax
+ from jax.sharding import Mesh, PartitionSpec as P
+ from jax.experimental.shard_map import shard_map
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple
+ 
+ from ..schedulers import NoiseScheduler
+ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
+ 
+ from flaxdiff.utils import RandomMarkovState
+ 
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
+ 
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+ 
+ class AutoEncoderTrainer(SimpleTrainer):
+     def __init__(self,
+                  model: nn.Module,
+                  input_shape: Tuple[int, int, int],
+                  latent_dim: int,
+                  spatial_scale: int,
+                  optimizer: optax.GradientTransformation,
+                  rngs: jax.random.PRNGKey,
+                  name: str = "Autoencoder",
+                  **kwargs
+                  ):
+         super().__init__(
+             model=model,
+             input_shapes={"image": input_shape},
+             optimizer=optimizer,
+             rngs=rngs,
+             name=name,
+             **kwargs
+         )
+         self.latent_dim = latent_dim
+         self.spatial_scale = spatial_scale
+ 
+ 
+     def generate_states(
+         self,
+         optimizer: optax.GradientTransformation,
+         rngs: jax.random.PRNGKey,
+         existing_state: dict = None,
+         existing_best_state: dict = None,
+         model: nn.Module = None,
+         param_transforms: Callable = None
+     ) -> Tuple[TrainState, TrainState]:
+         # NOTE: TrainState here is the EMA-enabled train state defined alongside
+         # DiffusionTrainer; it must be importable in this module for the annotation
+         # and the TrainState.create call below to resolve.
+         print("Generating states for AutoEncoderTrainer")
+         rngs, subkey = jax.random.split(rngs)
+ 
+         if existing_state is None:
+             input_vars = self.get_input_ones()
+             params = model.init(subkey, **input_vars)
+             new_state = {"params": params, "ema_params": params}
+         else:
+             new_state = existing_state
+ 
+         if param_transforms is not None:
+             new_state['params'] = param_transforms(new_state['params'])
+ 
+         state = TrainState.create(
+             apply_fn=model.apply,
+             params=new_state['params'],
+             ema_params=new_state['ema_params'],
+             tx=optimizer,
+             rngs=rngs,
+             metrics=Metrics.empty()
+         )
+ 
+         if existing_best_state is not None:
+             best_state = state.replace(
+                 params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
+         else:
+             best_state = state
+ 
+         return state, best_state
+ 
+     def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
+         # NOTE: this mirrors DiffusionTrainer._define_train_step and relies on
+         # noise_schedule, model_output_transform, unconditional_prob and autoencoder
+         # being set on the trainer instance.
+         noise_schedule: NoiseScheduler = self.noise_schedule
+         model = self.model
+         model_output_transform = self.model_output_transform
+         loss_fn = self.loss_fn
+         unconditional_prob = self.unconditional_prob
+ 
+         # Determine the number of unconditional samples
+         num_unconditional = int(batch_size * unconditional_prob)
+ 
+         nS, nC = null_labels_seq.shape
+         null_labels_seq = jnp.broadcast_to(
+             null_labels_seq, (batch_size, nS, nC))
+ 
+         distributed_training = self.distributed_training
+ 
+         autoencoder = self.autoencoder
+ 
+         # @jax.jit
+         def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
+             """Train for a single step."""
+             rng_state, subkey = rng_state.get_random_key()
+             subkey = jax.random.fold_in(subkey, local_device_index.reshape())
+             local_rng_state = RandomMarkovState(subkey)
+ 
+             images = batch['image']
+ 
+             if autoencoder is not None:
+                 # Convert the images to latent space
+                 local_rng_state, rngs = local_rng_state.get_random_key()
+                 images = autoencoder.encode(images, rngs)
+             else:
+                 # Normalize images from [0, 255] to [-1, 1]
+                 images = (images - 127.5) / 127.5
+ 
+             output = text_embedder(
+                 input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
+             label_seq = output.last_hidden_state
+ 
+             # Replace the first num_unconditional label sequences with the null
+             # (unconditional) embeddings so part of the batch trains unconditionally
+             label_seq = jnp.concatenate(
+                 [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
+ 
+             noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
+ 
+             local_rng_state, rngs = local_rng_state.get_random_key()
+             noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
+ 
+             rates = noise_schedule.get_rates(noise_level)
+             noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
+                 images, noise, rates)
+ 
+             def model_loss(params):
+                 preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images * c_in, noise_level), label_seq)
+                 preds = model_output_transform.pred_transform(
+                     noisy_images, preds, rates)
+                 nloss = loss_fn(preds, expected_output)
+                 # nloss = jnp.mean(nloss, axis=1)
+                 nloss *= noise_schedule.get_weights(noise_level)
+                 nloss = jnp.mean(nloss)
+                 loss = nloss
+                 return loss
+ 
+             loss, grads = jax.value_and_grad(model_loss)(train_state.params)
+             if distributed_training:
+                 grads = jax.lax.pmean(grads, "data")
+                 loss = jax.lax.pmean(loss, "data")
+             train_state = train_state.apply_gradients(grads=grads)
+             train_state = train_state.apply_ema(self.ema_decay)
+             return train_state, loss, rng_state
+ 
+         if distributed_training:
+             train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
+                                    out_specs=(P(), P(), P()))
+         train_step = jax.jit(train_step)
+ 
+         return train_step
+ 
+     def _define_compute_metrics(self):
+         @jax.jit
+         def compute_metrics(state: TrainState, expected, pred):
+             loss = jnp.mean(jnp.square(pred - expected))
+             metric_updates = state.metrics.single_from_model_output(loss=loss)
+             metrics = state.metrics.merge(metric_updates)
+             state = state.replace(metrics=metrics)
+             return state
+         return compute_metrics
+ 
+     def fit(self, data, steps_per_epoch, epochs):
+         null_labels_full = data['null_labels_full']
+         local_batch_size = data['local_batch_size']
+         text_embedder = data['model']
+         super().fit(data, steps_per_epoch, epochs, {
+             "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
+ 
+ def boolean_string(s):
+     # Helper for parsing boolean string flags such as "True"/"False".
+     if isinstance(s, bool):
+         return s
+     return s == 'True'
+ 
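
For orientation, here is a minimal construction sketch for the AutoEncoderTrainer added above. Only the trainer's own constructor signature comes from this diff; the concrete AutoEncoder subclass, its arguments, and the hyperparameter values are hypothetical placeholders.

    # Sketch only: "MyAutoEncoder" and all values below are illustrative assumptions.
    import jax
    import optax

    model = MyAutoEncoder()  # a concrete flaxdiff AutoEncoder subclass (hypothetical)
    trainer = AutoEncoderTrainer(
        model=model,
        input_shape=(256, 256, 3),   # passed through as input_shapes={"image": input_shape}
        latent_dim=4,                # stored on the trainer as self.latent_dim
        spatial_scale=8,             # stored on the trainer as self.spatial_scale
        optimizer=optax.adam(1e-4),
        rngs=jax.random.PRNGKey(0),
    )
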
@@ -0,0 +1,202 @@
+ from flax import linen as nn
+ import jax
+ from dataclasses import field
+ import jax.numpy as jnp
+ import optax
+ from jax.sharding import Mesh, PartitionSpec as P
+ from jax.experimental.shard_map import shard_map
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple
+ 
+ from ..schedulers import NoiseScheduler
+ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
+ 
+ from flaxdiff.utils import RandomMarkovState
+ 
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
+ 
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
+ 
+ class TrainState(SimpleTrainState):
+     rngs: jax.random.PRNGKey
+     ema_params: dict
+ 
+     def apply_ema(self, decay: float = 0.999):
+         # Exponential moving average update: ema <- decay * ema + (1 - decay) * param
+         new_ema_params = jax.tree_util.tree_map(
+             lambda ema, param: decay * ema + (1 - decay) * param,
+             self.ema_params,
+             self.params,
+         )
+         return self.replace(ema_params=new_ema_params)
+ 
+ class DiffusionTrainer(SimpleTrainer):
+     noise_schedule: NoiseScheduler
+     model_output_transform: DiffusionPredictionTransform
+     ema_decay: float = 0.999
+ 
+     def __init__(self,
+                  model: nn.Module,
+                  input_shapes: Dict[str, Tuple[int, ...]],
+                  optimizer: optax.GradientTransformation,
+                  noise_schedule: NoiseScheduler,
+                  rngs: jax.random.PRNGKey,
+                  unconditional_prob: float = 0.2,
+                  name: str = "Diffusion",
+                  model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
+                  autoencoder: AutoEncoder = None,
+                  **kwargs
+                  ):
+         super().__init__(
+             model=model,
+             input_shapes=input_shapes,
+             optimizer=optimizer,
+             rngs=rngs,
+             name=name,
+             **kwargs
+         )
+         self.noise_schedule = noise_schedule
+         self.model_output_transform = model_output_transform
+         self.unconditional_prob = unconditional_prob
+ 
+         self.autoencoder = autoencoder
+ 
+     def generate_states(
+         self,
+         optimizer: optax.GradientTransformation,
+         rngs: jax.random.PRNGKey,
+         existing_state: dict = None,
+         existing_best_state: dict = None,
+         model: nn.Module = None,
+         param_transforms: Callable = None
+     ) -> Tuple[TrainState, TrainState]:
+         print("Generating states for DiffusionTrainer")
+         rngs, subkey = jax.random.split(rngs)
+ 
+         if existing_state is None:
+             input_vars = self.get_input_ones()
+             params = model.init(subkey, **input_vars)
+             new_state = {"params": params, "ema_params": params}
+         else:
+             new_state = existing_state
+ 
+         if param_transforms is not None:
+             new_state['params'] = param_transforms(new_state['params'])
+ 
+         state = TrainState.create(
+             apply_fn=model.apply,
+             params=new_state['params'],
+             ema_params=new_state['ema_params'],
+             tx=optimizer,
+             rngs=rngs,
+             metrics=Metrics.empty()
+         )
+ 
+         if existing_best_state is not None:
+             best_state = state.replace(
+                 params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
+         else:
+             best_state = state
+ 
+         return state, best_state
+ 
+     def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
+         noise_schedule: NoiseScheduler = self.noise_schedule
+         model = self.model
+         model_output_transform = self.model_output_transform
+         loss_fn = self.loss_fn
+         unconditional_prob = self.unconditional_prob
+ 
+         # Determine the number of unconditional samples
+         num_unconditional = int(batch_size * unconditional_prob)
+ 
+         nS, nC = null_labels_seq.shape
+         null_labels_seq = jnp.broadcast_to(
+             null_labels_seq, (batch_size, nS, nC))
+ 
+         distributed_training = self.distributed_training
+ 
+         autoencoder = self.autoencoder
+ 
+         # @jax.jit
+         def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
+             """Train for a single step."""
+             rng_state, subkey = rng_state.get_random_key()
+             subkey = jax.random.fold_in(subkey, local_device_index.reshape())
+             local_rng_state = RandomMarkovState(subkey)
+ 
+             images = batch['image']
+ 
+             if autoencoder is not None:
+                 # Convert the images to latent space
+                 local_rng_state, rngs = local_rng_state.get_random_key()
+                 images = autoencoder.encode(images, rngs)
+             else:
+                 # Normalize images from [0, 255] to [-1, 1]
+                 images = (images - 127.5) / 127.5
+ 
+             output = text_embedder(
+                 input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
+             label_seq = output.last_hidden_state
+ 
+             # Replace the first num_unconditional label sequences with the null
+             # (unconditional) embeddings so part of the batch trains unconditionally
+             label_seq = jnp.concatenate(
+                 [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
+ 
+             noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
+ 
+             local_rng_state, rngs = local_rng_state.get_random_key()
+             noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
+ 
+             rates = noise_schedule.get_rates(noise_level)
+             noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
+                 images, noise, rates)
+ 
+             def model_loss(params):
+                 preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images * c_in, noise_level), label_seq)
+                 preds = model_output_transform.pred_transform(
+                     noisy_images, preds, rates)
+                 nloss = loss_fn(preds, expected_output)
+                 # nloss = jnp.mean(nloss, axis=1)
+                 nloss *= noise_schedule.get_weights(noise_level)
+                 nloss = jnp.mean(nloss)
+                 loss = nloss
+                 return loss
+ 
+             loss, grads = jax.value_and_grad(model_loss)(train_state.params)
+             if distributed_training:
+                 grads = jax.lax.pmean(grads, "data")
+                 loss = jax.lax.pmean(loss, "data")
+             train_state = train_state.apply_gradients(grads=grads)
+             train_state = train_state.apply_ema(self.ema_decay)
+             return train_state, loss, rng_state
+ 
+         if distributed_training:
+             train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
+                                    out_specs=(P(), P(), P()))
+         train_step = jax.jit(train_step)
+ 
+         return train_step
+ 
+     def _define_compute_metrics(self):
+         @jax.jit
+         def compute_metrics(state: TrainState, expected, pred):
+             loss = jnp.mean(jnp.square(pred - expected))
+             metric_updates = state.metrics.single_from_model_output(loss=loss)
+             metrics = state.metrics.merge(metric_updates)
+             state = state.replace(metrics=metrics)
+             return state
+         return compute_metrics
+ 
+     def fit(self, data, steps_per_epoch, epochs):
+         null_labels_full = data['null_labels_full']
+         local_batch_size = data['local_batch_size']
+         text_embedder = data['model']
+         super().fit(data, steps_per_epoch, epochs, {
+             "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
+ 
+ def boolean_string(s):
+     # Helper for parsing boolean string flags such as "True"/"False".
+     if isinstance(s, bool):
+         return s
+     return s == 'True'
+ 
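
For orientation, here is a minimal usage sketch for the DiffusionTrainer added above. The constructor signature, the data-dict keys ('null_labels_full', 'local_batch_size', 'model'), and the per-batch keys ('image', 'input_ids', 'attention_mask') come from the code in this diff; the UNet, noise scheduler, text encoder, null-label embeddings, and all numeric values are hypothetical placeholders.

    # Sketch only: unet, noise_schedule, text_encoder and null_label_embeddings are
    # illustrative placeholders, not names defined in this package diff.
    import jax
    import optax

    trainer = DiffusionTrainer(
        model=unet,                            # a Flax denoising model (hypothetical)
        input_shapes={"image": (64, 64, 3)},
        optimizer=optax.adamw(2e-4),
        noise_schedule=noise_schedule,         # any NoiseScheduler from ..schedulers
        rngs=jax.random.PRNGKey(0),
        unconditional_prob=0.2,                # fraction of each batch trained on null labels
    )

    data = {
        # ... plus whatever dataset fields SimpleTrainer.fit consumes (not shown in this diff) ...
        "null_labels_full": null_label_embeddings,  # (seq_len, emb_dim) embeddings of the empty prompt
        "local_batch_size": 64,
        "model": text_encoder,                      # text embedder returning .last_hidden_state
    }
    trainer.fit(data, steps_per_epoch=1000, epochs=10)
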