flaxdiff 0.1.35.6__py3-none-any.whl → 0.1.36.1__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only and reflects the changes between the two versions.
Files changed (40)
  1. flaxdiff/utils.py +105 -2
  2. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/METADATA +16 -7
  3. flaxdiff-0.1.36.1.dist-info/RECORD +6 -0
  4. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/WHEEL +1 -1
  5. flaxdiff/data/__init__.py +0 -1
  6. flaxdiff/data/online_loader.py +0 -336
  7. flaxdiff/models/__init__.py +0 -1
  8. flaxdiff/models/attention.py +0 -368
  9. flaxdiff/models/autoencoder/__init__.py +0 -2
  10. flaxdiff/models/autoencoder/autoencoder.py +0 -19
  11. flaxdiff/models/autoencoder/diffusers.py +0 -91
  12. flaxdiff/models/autoencoder/simple_autoenc.py +0 -26
  13. flaxdiff/models/common.py +0 -346
  14. flaxdiff/models/favor_fastattn.py +0 -723
  15. flaxdiff/models/simple_unet.py +0 -233
  16. flaxdiff/models/simple_vit.py +0 -180
  17. flaxdiff/predictors/__init__.py +0 -96
  18. flaxdiff/samplers/__init__.py +0 -7
  19. flaxdiff/samplers/common.py +0 -113
  20. flaxdiff/samplers/ddim.py +0 -10
  21. flaxdiff/samplers/ddpm.py +0 -43
  22. flaxdiff/samplers/euler.py +0 -59
  23. flaxdiff/samplers/heun_sampler.py +0 -28
  24. flaxdiff/samplers/multistep_dpm.py +0 -60
  25. flaxdiff/samplers/rk4_sampler.py +0 -34
  26. flaxdiff/schedulers/__init__.py +0 -6
  27. flaxdiff/schedulers/common.py +0 -98
  28. flaxdiff/schedulers/continuous.py +0 -12
  29. flaxdiff/schedulers/cosine.py +0 -40
  30. flaxdiff/schedulers/discrete.py +0 -74
  31. flaxdiff/schedulers/exp.py +0 -13
  32. flaxdiff/schedulers/karras.py +0 -69
  33. flaxdiff/schedulers/linear.py +0 -14
  34. flaxdiff/schedulers/sqrt.py +0 -10
  35. flaxdiff/trainer/__init__.py +0 -2
  36. flaxdiff/trainer/autoencoder_trainer.py +0 -182
  37. flaxdiff/trainer/diffusion_trainer.py +0 -234
  38. flaxdiff/trainer/simple_trainer.py +0 -442
  39. flaxdiff-0.1.35.6.dist-info/RECORD +0 -40
  40. {flaxdiff-0.1.35.6.dist-info → flaxdiff-0.1.36.1.dist-info}/top_level.txt +0 -0
flaxdiff/trainer/autoencoder_trainer.py
@@ -1,182 +0,0 @@
- from flax import linen as nn
- import jax
- from typing import Callable
- from dataclasses import field
- import jax.numpy as jnp
- import optax
- from jax.sharding import Mesh, PartitionSpec as P
- from jax.experimental.shard_map import shard_map
- from typing import Dict, Callable, Sequence, Any, Union, Tuple
-
- from ..schedulers import NoiseScheduler
- from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
-
- from flaxdiff.utils import RandomMarkovState
-
- from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
-
- from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
-
- class AutoEncoderTrainer(SimpleTrainer):
-     def __init__(self,
-                  model: nn.Module,
-                  input_shape: Union[int, int, int],
-                  latent_dim: int,
-                  spatial_scale: int,
-                  optimizer: optax.GradientTransformation,
-                  rngs: jax.random.PRNGKey,
-                  name: str = "Autoencoder",
-                  **kwargs
-                  ):
-         super().__init__(
-             model=model,
-             input_shapes={"image": input_shape},
-             optimizer=optimizer,
-             rngs=rngs,
-             name=name,
-             **kwargs
-         )
-         self.latent_dim = latent_dim
-         self.spatial_scale = spatial_scale
-
-
-     def generate_states(
-         self,
-         optimizer: optax.GradientTransformation,
-         rngs: jax.random.PRNGKey,
-         existing_state: dict = None,
-         existing_best_state: dict = None,
-         model: nn.Module = None,
-         param_transforms: Callable = None
-     ) -> Tuple[TrainState, TrainState]:
-         print("Generating states for DiffusionTrainer")
-         rngs, subkey = jax.random.split(rngs)
-
-         if existing_state == None:
-             input_vars = self.get_input_ones()
-             params = model.init(subkey, **input_vars)
-             new_state = {"params": params, "ema_params": params}
-         else:
-             new_state = existing_state
-
-         if param_transforms is not None:
-             params = param_transforms(params)
-
-         state = TrainState.create(
-             apply_fn=model.apply,
-             params=new_state['params'],
-             ema_params=new_state['ema_params'],
-             tx=optimizer,
-             rngs=rngs,
-             metrics=Metrics.empty()
-         )
-
-         if existing_best_state is not None:
-             best_state = state.replace(
-                 params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
-         else:
-             best_state = state
-
-         return state, best_state
-
-     def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
-         noise_schedule: NoiseScheduler = self.noise_schedule
-         model = self.model
-         model_output_transform = self.model_output_transform
-         loss_fn = self.loss_fn
-         unconditional_prob = self.unconditional_prob
-
-         # Determine the number of unconditional samples
-         num_unconditional = int(batch_size * unconditional_prob)
-
-         nS, nC = null_labels_seq.shape
-         null_labels_seq = jnp.broadcast_to(
-             null_labels_seq, (batch_size, nS, nC))
-
-         distributed_training = self.distributed_training
-
-         autoencoder = self.autoencoder
-
-         # @jax.jit
-         def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
-             """Train for a single step."""
-             rng_state, subkey = rng_state.get_random_key()
-             subkey = jax.random.fold_in(subkey, local_device_index.reshape())
-             local_rng_state = RandomMarkovState(subkey)
-
-             images = batch['image']
-
-             if autoencoder is not None:
-                 # Convert the images to latent space
-                 local_rng_state, rngs = local_rng_state.get_random_key()
-                 images = autoencoder.encode(images, rngs)
-             else:
-                 # normalize image
-                 images = (images - 127.5) / 127.5
-
-             output = text_embedder(
-                 input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
-             label_seq = output.last_hidden_state
-
-             # Generate random probabilities to decide how much of this batch will be unconditional
-
-             label_seq = jnp.concat(
-                 [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
-
-             noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
-
-             local_rng_state, rngs = local_rng_state.get_random_key()
-             noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
-
-             rates = noise_schedule.get_rates(noise_level)
-             noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
-                 images, noise, rates)
-
-             def model_loss(params):
-                 preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
-                 preds = model_output_transform.pred_transform(
-                     noisy_images, preds, rates)
-                 nloss = loss_fn(preds, expected_output)
-                 # nloss = jnp.mean(nloss, axis=1)
-                 nloss *= noise_schedule.get_weights(noise_level)
-                 nloss = jnp.mean(nloss)
-                 loss = nloss
-                 return loss
-
-             loss, grads = jax.value_and_grad(model_loss)(train_state.params)
-             if distributed_training:
-                 grads = jax.lax.pmean(grads, "data")
-                 loss = jax.lax.pmean(loss, "data")
-             train_state = train_state.apply_gradients(grads=grads)
-             train_state = train_state.apply_ema(self.ema_decay)
-             return train_state, loss, rng_state
-
-         if distributed_training:
-             train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
-                                    out_specs=(P(), P(), P()))
-         train_step = jax.jit(train_step)
-
-         return train_step
-
-     def _define_compute_metrics(self):
-         @jax.jit
-         def compute_metrics(state: TrainState, expected, pred):
-             loss = jnp.mean(jnp.square(pred - expected))
-             metric_updates = state.metrics.single_from_model_output(loss=loss)
-             metrics = state.metrics.merge(metric_updates)
-             state = state.replace(metrics=metrics)
-             return state
-         return compute_metrics
-
-     def fit(self, data, steps_per_epoch, epochs):
-         null_labels_full = data['null_labels_full']
-         local_batch_size = data['local_batch_size']
-         text_embedder = data['model']
-         super().fit(data, steps_per_epoch, epochs, {
-             "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
-
- def boolean_string(s):
-     if type(s) == bool:
-         return s
-     return s == 'True'
-
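Both removed trainers use the same classifier-free-guidance style conditioning dropout in their train step: a fixed leading fraction of each batch (`unconditional_prob`) has its text embedding replaced by a broadcast null-label sequence before the denoiser is applied. Below is a minimal standalone sketch of that splicing arithmetic; the shapes are hypothetical and `jnp.concatenate` stands in for the trainer plumbing shown in the diff.

```python
import jax.numpy as jnp

# Hypothetical sizes chosen for illustration; the real values come from the
# data pipeline and the text encoder, not from this package.
batch_size, seq_len, embed_dim = 128, 77, 768
unconditional_prob = 0.12

label_seq = jnp.ones((batch_size, seq_len, embed_dim))   # stand-in for text embeddings
null_labels_seq = jnp.zeros((seq_len, embed_dim))        # stand-in for the null-caption embedding

# Same arithmetic as the removed _define_train_step: the first
# num_unconditional samples of every batch are trained unconditionally.
num_unconditional = int(batch_size * unconditional_prob)
null_broadcast = jnp.broadcast_to(null_labels_seq, (batch_size, seq_len, embed_dim))
label_seq = jnp.concatenate(
    [null_broadcast[:num_unconditional], label_seq[num_unconditional:]], axis=0)

print(num_unconditional, label_seq.shape)  # 15 (128, 77, 768)
```

Because `num_unconditional` is a plain Python int computed once per train-step definition, the removed code always drops conditioning for the same leading slice of the batch rather than sampling which examples to drop at each step.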
flaxdiff/trainer/diffusion_trainer.py
@@ -1,234 +0,0 @@
- from flax import linen as nn
- import jax
- from typing import Callable
- from dataclasses import field
- import jax.numpy as jnp
- import optax
- from jax.sharding import Mesh, PartitionSpec as P
- from jax.experimental.shard_map import shard_map
- from typing import Dict, Callable, Sequence, Any, Union, Tuple
-
- from ..schedulers import NoiseScheduler
- from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
-
- from flaxdiff.utils import RandomMarkovState
-
- from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
-
- from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
- from flax.training.dynamic_scale import DynamicScale
-
- class TrainState(SimpleTrainState):
-     rngs: jax.random.PRNGKey
-     ema_params: dict
-
-     def apply_ema(self, decay: float = 0.999):
-         new_ema_params = jax.tree_util.tree_map(
-             lambda ema, param: decay * ema + (1 - decay) * param,
-             self.ema_params,
-             self.params,
-         )
-         return self.replace(ema_params=new_ema_params)
-
- from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
-
- class DiffusionTrainer(SimpleTrainer):
-     noise_schedule: NoiseScheduler
-     model_output_transform: DiffusionPredictionTransform
-     ema_decay: float = 0.999
-
-     def __init__(self,
-                  model: nn.Module,
-                  input_shapes: Dict[str, Tuple[int]],
-                  optimizer: optax.GradientTransformation,
-                  noise_schedule: NoiseScheduler,
-                  rngs: jax.random.PRNGKey,
-                  unconditional_prob: float = 0.12,
-                  name: str = "Diffusion",
-                  model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
-                  autoencoder: AutoEncoder = None,
-                  **kwargs
-                  ):
-         super().__init__(
-             model=model,
-             input_shapes=input_shapes,
-             optimizer=optimizer,
-             rngs=rngs,
-             name=name,
-             **kwargs
-         )
-         self.noise_schedule = noise_schedule
-         self.model_output_transform = model_output_transform
-         self.unconditional_prob = unconditional_prob
-
-         self.autoencoder = autoencoder
-
-     def generate_states(
-         self,
-         optimizer: optax.GradientTransformation,
-         rngs: jax.random.PRNGKey,
-         existing_state: dict = None,
-         existing_best_state: dict = None,
-         model: nn.Module = None,
-         param_transforms: Callable = None,
-         use_dynamic_scale: bool = False
-     ) -> Tuple[TrainState, TrainState]:
-         print("Generating states for DiffusionTrainer")
-         rngs, subkey = jax.random.split(rngs)
-
-         if existing_state == None:
-             input_vars = self.get_input_ones()
-             params = model.init(subkey, **input_vars)
-             new_state = {"params": params, "ema_params": params}
-         else:
-             new_state = existing_state
-
-         if param_transforms is not None:
-             new_state['params'] = param_transforms(new_state['params'])
-             new_state['ema_params'] = param_transforms(new_state['ema_params'])
-
-         state = TrainState.create(
-             apply_fn=model.apply,
-             params=new_state['params'],
-             ema_params=new_state['ema_params'],
-             tx=optimizer,
-             rngs=rngs,
-             metrics=Metrics.empty(),
-             dynamic_scale = DynamicScale() if use_dynamic_scale else None
-         )
-
-         if existing_best_state is not None:
-             best_state = state.replace(
-                 params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
-         else:
-             best_state = state
-
-         return state, best_state
-
-     def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
-         noise_schedule: NoiseScheduler = self.noise_schedule
-         model = self.model
-         model_output_transform = self.model_output_transform
-         loss_fn = self.loss_fn
-         unconditional_prob = self.unconditional_prob
-
-         # Determine the number of unconditional samples
-         num_unconditional = int(batch_size * unconditional_prob)
-
-         nS, nC = null_labels_seq.shape
-         null_labels_seq = jnp.broadcast_to(
-             null_labels_seq, (batch_size, nS, nC))
-
-         distributed_training = self.distributed_training
-
-         autoencoder = self.autoencoder
-
-         # @jax.jit
-         def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
-             """Train for a single step."""
-             rng_state, subkey = rng_state.get_random_key()
-             subkey = jax.random.fold_in(subkey, local_device_index.reshape())
-             local_rng_state = RandomMarkovState(subkey)
-
-             images = batch['image']
-             images = jnp.array(images, dtype=jnp.float32)
-             # normalize image
-             images = (images - 127.5) / 127.5
-
-             if autoencoder is not None:
-                 # Convert the images to latent space
-                 local_rng_state, rngs = local_rng_state.get_random_key()
-                 images = autoencoder.encode(images, rngs)
-
-             output = text_embedder(
-                 input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
-             label_seq = output.last_hidden_state
-
-             # Generate random probabilities to decide how much of this batch will be unconditional
-
-             label_seq = jnp.concat(
-                 [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
-
-             noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
-
-             local_rng_state, rngs = local_rng_state.get_random_key()
-             noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
-
-             rates = noise_schedule.get_rates(noise_level)
-             noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
-                 images, noise, rates)
-
-             def model_loss(params):
-                 preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
-                 preds = model_output_transform.pred_transform(
-                     noisy_images, preds, rates)
-                 nloss = loss_fn(preds, expected_output)
-                 # nloss = jnp.mean(nloss, axis=1)
-                 nloss *= noise_schedule.get_weights(noise_level)
-                 nloss = jnp.mean(nloss)
-                 loss = nloss
-                 return loss
-
-
-             if train_state.dynamic_scale is not None:
-                 # dynamic scale takes care of averaging gradients across replicas
-                 grad_fn = train_state.dynamic_scale.value_and_grad(
-                     model_loss, axis_name="data"
-                 )
-                 dynamic_scale, is_fin, loss, grads = grad_fn(train_state.params)
-                 train_state = train_state.replace(dynamic_scale=dynamic_scale)
-             else:
-                 grad_fn = jax.value_and_grad(model_loss)
-                 loss, grads = grad_fn(train_state.params)
-                 if distributed_training:
-                     grads = jax.lax.pmean(grads, "data")
-
-             new_state = train_state.apply_gradients(grads=grads)
-
-             if train_state.dynamic_scale:
-                 # if is_fin == False the gradients contain Inf/NaNs and optimizer state and
-                 # params should be restored (= skip this step).
-                 select_fn = functools.partial(jnp.where, is_fin)
-                 new_state = train_state.replace(
-                     opt_state=jax.tree_util.tree_map(
-                         select_fn, new_state.opt_state, train_state.opt_state
-                     ),
-                     params=jax.tree_util.tree_map(
-                         select_fn, new_state.params, train_state.params
-                     ),
-                 )
-
-             train_state = new_state.apply_ema(self.ema_decay)
-
-             if distributed_training:
-                 loss = jax.lax.pmean(loss, "data")
-             return train_state, loss, rng_state
-
-         if distributed_training:
-             train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
-                                    out_specs=(P(), P(), P()))
-         train_step = jax.jit(train_step)
-
-         return train_step
-
-     def _define_compute_metrics(self):
-         @jax.jit
-         def compute_metrics(state: TrainState, expected, pred):
-             loss = jnp.mean(jnp.square(pred - expected))
-             metric_updates = state.metrics.single_from_model_output(loss=loss)
-             metrics = state.metrics.merge(metric_updates)
-             state = state.replace(metrics=metrics)
-             return state
-         return compute_metrics
-
-     def fit(self, data, steps_per_epoch, epochs):
-         null_labels_full = data['null_labels_full']
-         local_batch_size = data['local_batch_size']
-         text_embedder = data['model']
-         super().fit(data, steps_per_epoch, epochs, {
-             "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
-
- def boolean_string(s):
-     if type(s) == bool:
-         return s
-     return s == 'True'
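The removed `TrainState.apply_ema` keeps an exponential moving average of the parameters with a 0.999 decay. Here is a self-contained sketch of the same update rule over an arbitrary parameter pytree; the function name and the toy pytree are illustrative only, not part of the flaxdiff API.

```python
import jax
import jax.numpy as jnp

def ema_update(ema_params, params, decay: float = 0.999):
    """One EMA step, leaf by leaf: ema <- decay * ema + (1 - decay) * param."""
    return jax.tree_util.tree_map(
        lambda ema, p: decay * ema + (1.0 - decay) * p, ema_params, params)

# Toy parameter pytree standing in for a Flax params dict.
params = {"dense": {"kernel": jnp.ones((2, 2)), "bias": jnp.zeros((2,))}}
ema_params = jax.tree_util.tree_map(jnp.zeros_like, params)

ema_params = ema_update(ema_params, params)
print(ema_params["dense"]["kernel"])  # every entry is 0.001 after one update
```

In the removed trainer this update runs after every `apply_gradients` call, so `ema_params` tracks a smoothed copy of the live parameters alongside the raw optimizer state.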