flaxdiff 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -5,7 +5,8 @@ Some Code ported from https://github.com/huggingface/diffusers/blob/main/src/dif
5
5
  import jax
6
6
  import jax.numpy as jnp
7
7
  from flax import linen as nn
8
- from typing import Dict, Callable, Sequence, Any, Union
8
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple, Optional
9
+ from flax.typing import Dtype, PrecisionLike
9
10
  import einops
10
11
  import functools
11
12
  import math
@@ -18,8 +19,8 @@ class EfficientAttention(nn.Module):
18
19
  query_dim: int
19
20
  heads: int = 4
20
21
  dim_head: int = 64
21
- dtype: Any = jnp.float32
22
- precision: Any = jax.lax.Precision.HIGHEST
22
+ dtype: Optional[Dtype] = None
23
+ precision: PrecisionLike = None
23
24
  use_bias: bool = True
24
25
  kernel_init: Callable = lambda : kernel_init(1.0)
25
26
 
@@ -109,8 +110,8 @@ class NormalAttention(nn.Module):
109
110
  query_dim: int
110
111
  heads: int = 4
111
112
  dim_head: int = 64
112
- dtype: Any = jnp.float32
113
- precision: Any = jax.lax.Precision.HIGHEST
113
+ dtype: Optional[Dtype] = None
114
+ precision: PrecisionLike = None
114
115
  use_bias: bool = True
115
116
  kernel_init: Callable = lambda : kernel_init(1.0)
116
117
 
@@ -166,8 +167,8 @@ class BasicTransformerBlock(nn.Module):
166
167
  query_dim: int
167
168
  heads: int = 4
168
169
  dim_head: int = 64
169
- dtype: Any = jnp.float32
170
- precision: Any = jax.lax.Precision.HIGHEST
170
+ dtype: Optional[Dtype] = None
171
+ precision: PrecisionLike = None
171
172
  use_bias: bool = True
172
173
  kernel_init: Callable = lambda : kernel_init(1.0)
173
174
  use_flash_attention:bool = False
@@ -286,8 +287,8 @@ class BasicTransformerBlock(nn.Module):
286
287
  query_dim: int
287
288
  heads: int = 4
288
289
  dim_head: int = 64
289
- dtype: Any = jnp.float32
290
- precision: Any = jax.lax.Precision.HIGHEST
290
+ dtype: Optional[Dtype] = None
291
+ precision: PrecisionLike = None
291
292
  use_bias: bool = True
292
293
  kernel_init: Callable = lambda : kernel_init(1.0)
293
294
  use_flash_attention:bool = False
@@ -346,8 +347,8 @@ class TransformerBlock(nn.Module):
346
347
  heads: int = 4
347
348
  dim_head: int = 32
348
349
  use_linear_attention: bool = True
349
- dtype: Any = jnp.float32
350
- precision: Any = jax.lax.Precision.HIGH
350
+ dtype: Optional[Dtype] = None
351
+ precision: PrecisionLike = None
351
352
  use_projection: bool = False
352
353
  use_flash_attention:bool = True
353
354
  use_self_and_cross:bool = False
@@ -0,0 +1,2 @@
1
+ from .autoencoder import AutoEncoder
2
+ from .diffusers import StableDiffusionVAE
@@ -6,9 +6,14 @@ import einops
6
6
  from ..common import kernel_init, ConvLayer, Upsample, Downsample, PixelShuffle
7
7
 
8
8
 
9
class AutoEncoder:
    """Minimal autoencoder interface.

    Subclasses provide ``encode`` and ``decode``; calling an instance performs a
    full round trip ``decode(encode(x))``. The base implementations raise
    ``NotImplementedError``.
    """

    def encode(self, x: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Map ``x`` into the latent space. Must be overridden by subclasses."""
        raise NotImplementedError

    def decode(self, z: jnp.ndarray, **kwargs) -> jnp.ndarray:
        """Map latent ``z`` back to data space. Must be overridden by subclasses."""
        raise NotImplementedError

    def __call__(self, x: jnp.ndarray):
        """Return the reconstruction of ``x`` (encode followed by decode)."""
        return self.decode(self.encode(x))
@@ -5,6 +5,9 @@ from .autoencoder import AutoEncoder
5
5
 
6
6
  """
7
7
  This module contains an Autoencoder implementation which uses the Stable Diffusion VAE model from the HuggingFace Diffusers library.
8
+ The actual model was not trained by me, but was taken from the HuggingFace model hub.
9
+ I have only implemented the wrapper around the diffusers pipeline to make it compatible with our library
10
+ All credits for the model go to the developers of Stable Diffusion VAE and all credits for the pipeline go to the developers of the Diffusers library.
8
11
  """
9
12
 
10
13
  class StableDiffusionVAE(AutoEncoder):
@@ -0,0 +1,26 @@
1
+ from typing import Any, List, Optional, Callable
2
+ import jax
3
+ import flax.linen as nn
4
+ from jax import numpy as jnp
5
+ from flax.typing import Dtype, PrecisionLike
6
+ from .autoencoder import AutoEncoder
7
+
8
class SimpleAutoEncoder(AutoEncoder):
    """Configuration holder for a simple convolutional autoencoder (work in progress).

    NOTE(review): ``encode``/``decode`` are not implemented here, so they fall back
    to ``AutoEncoder``'s ``NotImplementedError``. This class does not derive from
    ``nn.Module``, so the annotated attributes below are plain class attributes,
    not dataclass fields.
    """
    latent_channels: int
    # Tuples instead of lists: the defaults are class attributes shared by every
    # instance, so they must be immutable (and Flax/dataclass fields reject
    # unhashable defaults if this ever becomes a Module).
    feature_depths: tuple = (64, 128, 256, 512)
    # Previously written with a trailing comma, which wrapped the list in a tuple.
    attention_configs: tuple = ({"heads": 8}, {"heads": 8}, {"heads": 8}, {"heads": 8})
    num_res_blocks: int = 2
    # Previously `=1,`, which silently made this the tuple `(1,)`.
    num_middle_res_blocks: int = 1
    activation: Callable = jax.nn.swish
    norm_groups: int = 8
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None

    # TODO: implement encode(self, x) and decode(self, z).

    def __call__(self, x: jnp.ndarray):
        """Reconstruct ``x`` by encoding it and decoding the resulting latents.

        ``@nn.compact`` was dropped: it only has meaning on ``nn.Module`` methods.
        """
        latents = self.encode(x)
        reconstructions = self.decode(latents)
        return reconstructions
flaxdiff/models/common.py CHANGED
@@ -1,6 +1,8 @@
1
1
  import jax.numpy as jnp
2
2
  import jax
3
3
  from flax import linen as nn
4
+ from typing import Optional, Any, Callable, Sequence, Union
5
+ from flax.typing import Dtype, PrecisionLike
4
6
  from typing import Dict, Callable, Sequence, Any, Union
5
7
  import einops
6
8
 
@@ -18,8 +20,9 @@ class WeightStandardizedConv(nn.Module):
18
20
  kernel_size: Sequence[int] = 3
19
21
  strides: Union[None, int, Sequence[int]] = 1
20
22
  padding: Any = 1
21
- dtype: Any = jnp.float32
22
- param_dtype: Any = jnp.float32
23
+ dtype: Optional[Dtype] = None
24
+ precision: PrecisionLike = None
25
+ param_dtype: Optional[Dtype] = None
23
26
 
24
27
  @nn.compact
25
28
  def __call__(self, x):
@@ -120,8 +123,8 @@ class SeparableConv(nn.Module):
120
123
  use_bias:bool=False
121
124
  kernel_init:Callable=kernel_init(1.0)
122
125
  padding:str="SAME"
123
- dtype: Any = jnp.bfloat16
124
- precision: Any = jax.lax.Precision.HIGH
126
+ dtype: Optional[Dtype] = None
127
+ precision: PrecisionLike = None
125
128
 
126
129
  @nn.compact
127
130
  def __call__(self, x):
@@ -149,8 +152,8 @@ class ConvLayer(nn.Module):
149
152
  kernel_size:tuple=(3, 3)
150
153
  strides:tuple=(1, 1)
151
154
  kernel_init:Callable=kernel_init(1.0)
152
- dtype: Any = jnp.bfloat16
153
- precision: Any = jax.lax.Precision.HIGH
155
+ dtype: Optional[Dtype] = None
156
+ precision: PrecisionLike = None
154
157
 
155
158
  def setup(self):
156
159
  # conv_type can be "conv", "separable", "conv_transpose"
@@ -199,8 +202,8 @@ class Upsample(nn.Module):
199
202
  features:int
200
203
  scale:int
201
204
  activation:Callable=jax.nn.swish
202
- dtype: Any = jnp.bfloat16
203
- precision: Any = jax.lax.Precision.HIGH
205
+ dtype: Optional[Dtype] = None
206
+ precision: PrecisionLike = None
204
207
 
205
208
  @nn.compact
206
209
  def __call__(self, x, residual=None):
@@ -224,8 +227,8 @@ class Downsample(nn.Module):
224
227
  features:int
225
228
  scale:int
226
229
  activation:Callable=jax.nn.swish
227
- dtype: Any = jnp.bfloat16
228
- precision: Any = jax.lax.Precision.HIGH
230
+ dtype: Optional[Dtype] = None
231
+ precision: PrecisionLike = None
229
232
 
230
233
  @nn.compact
231
234
  def __call__(self, x, residual=None):
@@ -248,3 +251,79 @@ def l2norm(t, axis=1, eps=1e-12):
248
251
  denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
249
252
  out = t/denom
250
253
  return (out)
254
+
255
+
256
class ResidualBlock(nn.Module):
    """Pre-norm residual convolution block with an additive time-embedding bias.

    Pipeline: RMSNorm -> activation -> conv1 -> (+ projected ``temb``) ->
    RMSNorm -> activation -> conv2, then a skip connection that is projected by a
    1x1 ``residual_conv`` whenever its shape differs from the main path. When
    ``extra_features`` is given it is concatenated on the last axis of the output.

    ``padding``, ``direction``, ``res`` and ``norm_groups`` are accepted but not
    read here, and ``textemb`` is accepted but unused.
    """
    conv_type: str
    features: int
    kernel_size: tuple = (3, 3)
    strides: tuple = (1, 1)
    padding: str = "SAME"
    activation: Callable = jax.nn.swish
    direction: str = None
    res: int = 2
    norm_groups: int = 8
    kernel_init: Callable = kernel_init(1.0)
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None

    @nn.compact
    def __call__(self, x: jax.Array, temb: jax.Array, textemb: jax.Array = None, extra_features: jax.Array = None):
        def conv(h, name, kernel_size, strides):
            # Every convolution in this block shares type, width, init and precision.
            return ConvLayer(
                self.conv_type,
                features=self.features,
                kernel_size=kernel_size,
                strides=strides,
                kernel_init=self.kernel_init,
                name=name,
                dtype=self.dtype,
                precision=self.precision,
            )(h)

        def norm_then_activate(h):
            # Unnamed RMSNorm submodules are auto-named in creation order.
            return self.activation(nn.RMSNorm()(h))

        hidden = conv(norm_then_activate(x), "conv1", self.kernel_size, self.strides)

        time_bias = nn.DenseGeneral(
            features=self.features,
            name="temb_projection",
            dtype=self.dtype,
            precision=self.precision,
        )(temb)
        # Insert two singleton axes after the leading axis so the bias broadcasts
        # over the two axes that follow it in ``hidden``.
        hidden = hidden + time_bias[:, None, None, ...]

        hidden = conv(norm_then_activate(hidden), "conv2", self.kernel_size, self.strides)

        shortcut = x
        if shortcut.shape != hidden.shape:
            shortcut = conv(shortcut, "residual_conv", (1, 1), 1)
        hidden = hidden + shortcut

        if extra_features is not None:
            hidden = jnp.concatenate([hidden, extra_features], axis=-1)
        return hidden
329
+
@@ -1,83 +1,12 @@
1
1
  import jax
2
2
  import jax.numpy as jnp
3
3
  from flax import linen as nn
4
- from typing import Dict, Callable, Sequence, Any, Union
4
+ from flax.typing import Dtype, PrecisionLike
5
+ from typing import Dict, Callable, Sequence, Any, Union, Optional
5
6
  import einops
6
7
  from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection
7
8
  from .attention import TransformerBlock
8
9
 
9
- class ResidualBlock(nn.Module):
10
- conv_type:str
11
- features:int
12
- kernel_size:tuple=(3, 3)
13
- strides:tuple=(1, 1)
14
- padding:str="SAME"
15
- activation:Callable=jax.nn.swish
16
- direction:str=None
17
- res:int=2
18
- norm_groups:int=8
19
- kernel_init:Callable=kernel_init(1.0)
20
- dtype: Any = jnp.float32
21
- precision: Any = jax.lax.Precision.HIGHEST
22
-
23
- @nn.compact
24
- def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
25
- residual = x
26
- out = nn.GroupNorm(self.norm_groups)(x)
27
- out = self.activation(out)
28
-
29
- out = ConvLayer(
30
- self.conv_type,
31
- features=self.features,
32
- kernel_size=self.kernel_size,
33
- strides=self.strides,
34
- kernel_init=self.kernel_init,
35
- name="conv1",
36
- dtype=self.dtype,
37
- precision=self.precision
38
- )(out)
39
-
40
- temb = nn.DenseGeneral(
41
- features=self.features,
42
- name="temb_projection",
43
- dtype=self.dtype,
44
- precision=self.precision)(temb)
45
- temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
46
- # scale, shift = jnp.split(temb, 2, axis=-1)
47
- # out = out * (1 + scale) + shift
48
- out = out + temb
49
-
50
- out = nn.GroupNorm(self.norm_groups)(out)
51
- out = self.activation(out)
52
-
53
- out = ConvLayer(
54
- self.conv_type,
55
- features=self.features,
56
- kernel_size=self.kernel_size,
57
- strides=self.strides,
58
- kernel_init=self.kernel_init,
59
- name="conv2",
60
- dtype=self.dtype,
61
- precision=self.precision
62
- )(out)
63
-
64
- if residual.shape != out.shape:
65
- residual = ConvLayer(
66
- self.conv_type,
67
- features=self.features,
68
- kernel_size=(1, 1),
69
- strides=1,
70
- kernel_init=self.kernel_init,
71
- name="residual_conv",
72
- dtype=self.dtype,
73
- precision=self.precision
74
- )(residual)
75
- out = out + residual
76
-
77
- out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out
78
-
79
- return out
80
-
81
10
  class Unet(nn.Module):
82
11
  output_channels:int=3
83
12
  emb_features:int=64*4,
@@ -87,8 +16,8 @@ class Unet(nn.Module):
87
16
  num_middle_res_blocks:int=1,
88
17
  activation:Callable = jax.nn.swish
89
18
  norm_groups:int=8
90
- dtype: Any = jnp.bfloat16
91
- precision: Any = jax.lax.Precision.HIGH
19
+ dtype: Optional[Dtype] = None
20
+ precision: PrecisionLike = None
92
21
 
93
22
  @nn.compact
94
23
  def __call__(self, x, temb, textcontext):
@@ -1,184 +1,2 @@
1
- from flax import linen as nn
2
- import jax
3
- from typing import Callable
4
- from dataclasses import field
5
- import jax.numpy as jnp
6
- import optax
7
- from jax.sharding import Mesh, PartitionSpec as P
8
- from jax.experimental.shard_map import shard_map
9
- from typing import Dict, Callable, Sequence, Any, Union, Tuple
10
-
11
- from ..schedulers import NoiseScheduler
12
- from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
13
-
14
- from flaxdiff.utils import RandomMarkovState
15
-
16
1
  from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
17
-
18
- class TrainState(SimpleTrainState):
19
- rngs: jax.random.PRNGKey
20
- ema_params: dict
21
-
22
- def apply_ema(self, decay: float = 0.999):
23
- new_ema_params = jax.tree_util.tree_map(
24
- lambda ema, param: decay * ema + (1 - decay) * param,
25
- self.ema_params,
26
- self.params,
27
- )
28
- return self.replace(ema_params=new_ema_params)
29
-
30
- class DiffusionTrainer(SimpleTrainer):
31
- noise_schedule: NoiseScheduler
32
- model_output_transform: DiffusionPredictionTransform
33
- ema_decay: float = 0.999
34
-
35
- def __init__(self,
36
- model: nn.Module,
37
- input_shapes: Dict[str, Tuple[int]],
38
- optimizer: optax.GradientTransformation,
39
- noise_schedule: NoiseScheduler,
40
- rngs: jax.random.PRNGKey,
41
- unconditional_prob: float = 0.2,
42
- name: str = "Diffusion",
43
- model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
44
- **kwargs
45
- ):
46
- super().__init__(
47
- model=model,
48
- input_shapes=input_shapes,
49
- optimizer=optimizer,
50
- rngs=rngs,
51
- name=name,
52
- **kwargs
53
- )
54
- self.noise_schedule = noise_schedule
55
- self.model_output_transform = model_output_transform
56
- self.unconditional_prob = unconditional_prob
57
-
58
- def generate_states(
59
- self,
60
- optimizer: optax.GradientTransformation,
61
- rngs: jax.random.PRNGKey,
62
- existing_state: dict = None,
63
- existing_best_state: dict = None,
64
- model: nn.Module = None,
65
- param_transforms: Callable = None
66
- ) -> Tuple[TrainState, TrainState]:
67
- print("Generating states for DiffusionTrainer")
68
- rngs, subkey = jax.random.split(rngs)
69
-
70
- if existing_state == None:
71
- input_vars = self.get_input_ones()
72
- params = model.init(subkey, **input_vars)
73
- new_state = {"params": params, "ema_params": params}
74
- else:
75
- new_state = existing_state
76
-
77
- if param_transforms is not None:
78
- params = param_transforms(params)
79
-
80
- state = TrainState.create(
81
- apply_fn=model.apply,
82
- params=new_state['params'],
83
- ema_params=new_state['ema_params'],
84
- tx=optimizer,
85
- rngs=rngs,
86
- metrics=Metrics.empty()
87
- )
88
-
89
- if existing_best_state is not None:
90
- best_state = state.replace(
91
- params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
92
- else:
93
- best_state = state
94
-
95
- return state, best_state
96
-
97
- def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
98
- noise_schedule: NoiseScheduler = self.noise_schedule
99
- model = self.model
100
- model_output_transform = self.model_output_transform
101
- loss_fn = self.loss_fn
102
- unconditional_prob = self.unconditional_prob
103
-
104
- # Determine the number of unconditional samples
105
- num_unconditional = int(batch_size * unconditional_prob)
106
-
107
- nS, nC = null_labels_seq.shape
108
- null_labels_seq = jnp.broadcast_to(
109
- null_labels_seq, (batch_size, nS, nC))
110
-
111
- distributed_training = self.distributed_training
112
-
113
- # @jax.jit
114
- def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
115
- """Train for a single step."""
116
- rng_state, subkey = rng_state.get_random_key()
117
- subkey = jax.random.fold_in(subkey, local_device_index.reshape())
118
- local_rng_state = RandomMarkovState(subkey)
119
-
120
- images = batch['image']
121
- # normalize image
122
- images = (images - 127.5) / 127.5
123
-
124
- output = text_embedder(
125
- input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
126
- label_seq = output.last_hidden_state
127
-
128
- # Generate random probabilities to decide how much of this batch will be unconditional
129
-
130
- label_seq = jnp.concat(
131
- [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
132
-
133
- noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)
134
-
135
- local_rng_state, rngs = local_rng_state.get_random_key()
136
- noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
137
-
138
- rates = noise_schedule.get_rates(noise_level)
139
- noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
140
- images, noise, rates)
141
-
142
- def model_loss(params):
143
- preds = model.apply(
144
- params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
145
- preds = model_output_transform.pred_transform(
146
- noisy_images, preds, rates)
147
- nloss = loss_fn(preds, expected_output)
148
- # nloss = jnp.mean(nloss, axis=1)
149
- nloss *= noise_schedule.get_weights(noise_level)
150
- nloss = jnp.mean(nloss)
151
- loss = nloss
152
- return loss
153
-
154
- loss, grads = jax.value_and_grad(model_loss)(train_state.params)
155
- if distributed_training:
156
- grads = jax.lax.pmean(grads, "data")
157
- loss = jax.lax.pmean(loss, "data")
158
- train_state = train_state.apply_gradients(grads=grads)
159
- train_state = train_state.apply_ema(self.ema_decay)
160
- return train_state, loss, rng_state
161
-
162
- if distributed_training:
163
- train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
164
- out_specs=(P(), P(), P()))
165
- train_step = jax.jit(train_step)
166
-
167
- return train_step
168
-
169
- def _define_compute_metrics(self):
170
- @jax.jit
171
- def compute_metrics(state: TrainState, expected, pred):
172
- loss = jnp.mean(jnp.square(pred - expected))
173
- metric_updates = state.metrics.single_from_model_output(loss=loss)
174
- metrics = state.metrics.merge(metric_updates)
175
- state = state.replace(metrics=metrics)
176
- return state
177
- return compute_metrics
178
-
179
- def fit(self, data, steps_per_epoch, epochs):
180
- null_labels_full = data['null_labels_full']
181
- local_batch_size = data['local_batch_size']
182
- text_embedder = data['model']
183
- super().fit(data, steps_per_epoch, epochs, {
184
- "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
2
+ from .diffusion_trainer import DiffusionTrainer, TrainState
@@ -0,0 +1,182 @@
1
+ from flax import linen as nn
2
+ import jax
3
+ from typing import Callable
4
+ from dataclasses import field
5
+ import jax.numpy as jnp
6
+ import optax
7
+ from jax.sharding import Mesh, PartitionSpec as P
8
+ from jax.experimental.shard_map import shard_map
9
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple
10
+
11
+ from ..schedulers import NoiseScheduler
12
+ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
13
+
14
+ from flaxdiff.utils import RandomMarkovState
15
+
16
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
17
+
18
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
19
+
20
class AutoEncoderTrainer(SimpleTrainer):
    """Trainer that fits an autoencoder to reconstruct its input images.

    Each step scales ``batch['image']`` from [0, 255] to [-1, 1], runs the model
    on it and minimizes the mean of ``self.loss_fn(reconstruction, image)``. An EMA
    copy of the parameters is updated after every optimizer step.
    """
    # Decay used when updating the EMA copy of the parameters after each step.
    ema_decay: float = 0.999

    def __init__(self,
                 model: nn.Module,
                 input_shape: Tuple[int, int, int],
                 latent_dim: int,
                 spatial_scale: int,
                 optimizer: optax.GradientTransformation,
                 rngs: jax.random.PRNGKey,
                 name: str = "Autoencoder",
                 **kwargs
                 ):
        """Create the trainer.

        Args:
            model: the autoencoder module to train; it is initialized with a single
                ``image`` input of shape ``input_shape``.
            input_shape: shape of one image input used for initialization.
            latent_dim: latent channel count, stored on the trainer.
            spatial_scale: spatial down-sampling factor, stored on the trainer.
            optimizer: optax optimizer for the model parameters.
            rngs: PRNG key forwarded to ``SimpleTrainer``.
            name: trainer name forwarded to ``SimpleTrainer``.
            **kwargs: forwarded unchanged to ``SimpleTrainer.__init__``.
        """
        super().__init__(
            model=model,
            input_shapes={"image": input_shape},
            optimizer=optimizer,
            rngs=rngs,
            name=name,
            **kwargs
        )
        self.latent_dim = latent_dim
        self.spatial_scale = spatial_scale

    def generate_states(
        self,
        optimizer: optax.GradientTransformation,
        rngs: jax.random.PRNGKey,
        existing_state: dict = None,
        existing_best_state: dict = None,
        model: nn.Module = None,
        param_transforms: Callable = None
    ) -> Tuple[SimpleTrainState, SimpleTrainState]:
        """Build the current and best train states (both carry EMA params).

        Returns:
            ``(state, best_state)``; ``best_state`` equals ``state`` unless
            ``existing_best_state`` is provided.
        """
        # TrainState (SimpleTrainState + rngs + ema_params) is defined in the
        # diffusion trainer module; imported here so this file needs no new
        # top-level import.
        from .diffusion_trainer import TrainState

        print("Generating states for AutoEncoderTrainer")
        rngs, subkey = jax.random.split(rngs)

        if existing_state is None:
            input_vars = self.get_input_ones()
            params = model.init(subkey, **input_vars)
            new_state = {"params": params, "ema_params": params}
        else:
            new_state = existing_state

        if param_transforms is not None:
            # Transform both trees so they keep identical structure for apply_ema;
            # build a new dict rather than mutating the caller's existing_state.
            new_state = {
                **new_state,
                "params": param_transforms(new_state["params"]),
                "ema_params": param_transforms(new_state["ema_params"]),
            }

        state = TrainState.create(
            apply_fn=model.apply,
            params=new_state['params'],
            ema_params=new_state['ema_params'],
            tx=optimizer,
            rngs=rngs,
            metrics=Metrics.empty()
        )

        if existing_best_state is not None:
            best_state = state.replace(
                params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
        else:
            best_state = state

        return state, best_state

    def _define_train_step(self, batch_size, null_labels_seq=None, text_embedder=None):
        """Return a jitted reconstruction train step.

        ``batch_size``, ``null_labels_seq`` and ``text_embedder`` are accepted for
        compatibility with the arguments ``fit`` forwards, but are not needed to
        train an unconditional autoencoder.
        """
        model = self.model
        loss_fn = self.loss_fn
        distributed_training = self.distributed_training
        ema_decay = self.ema_decay

        def train_step(train_state: SimpleTrainState, rng_state: RandomMarkovState, batch, local_device_index):
            """Run one optimizer step; returns ``(train_state, loss, rng_state)``."""
            # Advance the RNG chain so successive steps receive a fresh state.
            rng_state, _ = rng_state.get_random_key()

            # Maps [0, 255] to [-1, 1]; assumes raw pixel values in that range.
            images = (batch['image'] - 127.5) / 127.5

            def model_loss(params):
                # NOTE(review): the model was initialized with a single `image`
                # input; it is applied positionally here — confirm its signature.
                reconstructions = model.apply(params, images)
                return jnp.mean(loss_fn(reconstructions, images))

            loss, grads = jax.value_and_grad(model_loss)(train_state.params)
            if distributed_training:
                grads = jax.lax.pmean(grads, "data")
                loss = jax.lax.pmean(loss, "data")
            train_state = train_state.apply_gradients(grads=grads)
            train_state = train_state.apply_ema(ema_decay)
            return train_state, loss, rng_state

        if distributed_training:
            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
                                   out_specs=(P(), P(), P()))
        train_step = jax.jit(train_step)

        return train_step

    def _define_compute_metrics(self):
        """Return a jitted function that merges an MSE metric into the state."""
        @jax.jit
        def compute_metrics(state: SimpleTrainState, expected, pred):
            loss = jnp.mean(jnp.square(pred - expected))
            metric_updates = state.metrics.single_from_model_output(loss=loss)
            metrics = state.metrics.merge(metric_updates)
            state = state.replace(metrics=metrics)
            return state
        return compute_metrics

    def fit(self, data, steps_per_epoch, epochs):
        """Train for ``epochs`` epochs of ``steps_per_epoch`` steps.

        Only ``data['local_batch_size']`` is required; the text-conditioning
        entries (``null_labels_full``, ``model``) are optional for this trainer.
        """
        super().fit(data, steps_per_epoch, epochs, {
            "batch_size": data['local_batch_size'],
            "null_labels_seq": data.get('null_labels_full'),
            "text_embedder": data.get('model'),
        })
177
+
178
def boolean_string(s):
    """Interpret ``s`` as a boolean flag.

    Booleans are returned unchanged; any other value is ``True`` only when it
    equals the exact string ``'True'`` (case-sensitive).
    """
    return s if isinstance(s, bool) else s == 'True'
182
+
@@ -0,0 +1,202 @@
1
+ from flax import linen as nn
2
+ import jax
3
+ from typing import Callable
4
+ from dataclasses import field
5
+ import jax.numpy as jnp
6
+ import optax
7
+ from jax.sharding import Mesh, PartitionSpec as P
8
+ from jax.experimental.shard_map import shard_map
9
+ from typing import Dict, Callable, Sequence, Any, Union, Tuple
10
+
11
+ from ..schedulers import NoiseScheduler
12
+ from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
13
+
14
+ from flaxdiff.utils import RandomMarkovState
15
+
16
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
17
+
18
+ from flaxdiff.models.autoencoder.autoencoder import AutoEncoder
19
+
20
class TrainState(SimpleTrainState):
    """``SimpleTrainState`` extended with a PRNG key and EMA parameters."""
    # PRNG key stored alongside the optimizer state.
    rngs: jax.random.PRNGKey
    # Exponential moving average of ``params``; same tree structure as ``params``.
    ema_params: dict

    def apply_ema(self, decay: float = 0.999):
        """Return a copy of the state with ``ema_params`` moved toward ``params``.

        Each leaf becomes ``decay * ema + (1 - decay) * param``.
        """
        def blend(ema, param):
            return decay * ema + (1 - decay) * param

        updated = jax.tree_util.tree_map(blend, self.ema_params, self.params)
        return self.replace(ema_params=updated)
31
+
32
class DiffusionTrainer(SimpleTrainer):
    """Trainer for text-conditioned diffusion models.

    Each step:
    1. Maps the images to model space, either via ``autoencoder.encode`` or by
       scaling [0, 255] pixels to [-1, 1].
    2. Embeds the text with ``text_embedder`` and replaces the embeddings of the
       first ``int(batch_size * unconditional_prob)`` samples with null labels.
    3. Noises the inputs at timesteps drawn from ``noise_schedule``.
    4. Minimizes the weighted loss against the target produced by
       ``model_output_transform``.

    An EMA copy of the parameters is updated after every optimizer step.
    """
    noise_schedule: NoiseScheduler
    model_output_transform: DiffusionPredictionTransform
    ema_decay: float = 0.999

    def __init__(self,
                 model: nn.Module,
                 input_shapes: Dict[str, Tuple[int]],
                 optimizer: optax.GradientTransformation,
                 noise_schedule: NoiseScheduler,
                 rngs: jax.random.PRNGKey,
                 unconditional_prob: float = 0.2,
                 name: str = "Diffusion",
                 model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
                 autoencoder: AutoEncoder = None,
                 **kwargs
                 ):
        """Create the trainer.

        Args:
            model: denoising network; called as ``model(*transformed_inputs, label_seq)``.
            input_shapes: per-input shapes used by ``SimpleTrainer`` for initialization.
            optimizer: optax optimizer for the model parameters.
            noise_schedule: schedule providing timesteps, rates, weights and input transforms.
            rngs: PRNG key forwarded to ``SimpleTrainer``.
            unconditional_prob: fraction of each local batch trained with null labels.
            name: trainer name forwarded to ``SimpleTrainer``.
            model_output_transform: defines the forward-diffusion target and the
                transform applied to predictions.
            autoencoder: optional encoder used to train in latent space.
            **kwargs: forwarded unchanged to ``SimpleTrainer.__init__``.
        """
        super().__init__(
            model=model,
            input_shapes=input_shapes,
            optimizer=optimizer,
            rngs=rngs,
            name=name,
            **kwargs
        )
        self.noise_schedule = noise_schedule
        self.model_output_transform = model_output_transform
        self.unconditional_prob = unconditional_prob

        self.autoencoder = autoencoder

    def generate_states(
        self,
        optimizer: optax.GradientTransformation,
        rngs: jax.random.PRNGKey,
        existing_state: dict = None,
        existing_best_state: dict = None,
        model: nn.Module = None,
        param_transforms: Callable = None
    ) -> Tuple[TrainState, TrainState]:
        """Build the current and best train states (both carry EMA params).

        Returns:
            ``(state, best_state)``; ``best_state`` equals ``state`` unless
            ``existing_best_state`` is provided.
        """
        print("Generating states for DiffusionTrainer")
        rngs, subkey = jax.random.split(rngs)

        if existing_state is None:
            input_vars = self.get_input_ones()
            params = model.init(subkey, **input_vars)
            new_state = {"params": params, "ema_params": params}
        else:
            new_state = existing_state

        if param_transforms is not None:
            # Previously this transformed a local `params` that was unbound when
            # `existing_state` was given, and the result was never used. Transform
            # both trees so they keep identical structure for apply_ema; build a
            # new dict rather than mutating the caller's existing_state.
            new_state = {
                **new_state,
                "params": param_transforms(new_state["params"]),
                "ema_params": param_transforms(new_state["ema_params"]),
            }

        state = TrainState.create(
            apply_fn=model.apply,
            params=new_state['params'],
            ema_params=new_state['ema_params'],
            tx=optimizer,
            rngs=rngs,
            metrics=Metrics.empty()
        )

        if existing_best_state is not None:
            best_state = state.replace(
                params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
        else:
            best_state = state

        return state, best_state

    def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
        """Return a jitted (and, if distributed, shard-mapped) train step.

        Args:
            batch_size: local batch size; fixes how many samples get null labels.
            null_labels_seq: 2-D ``(seq, channels)`` null text embedding, broadcast
                to the batch.
            text_embedder: callable taking ``input_ids``/``attention_mask`` and
                returning an object with ``last_hidden_state``.
        """
        noise_schedule: NoiseScheduler = self.noise_schedule
        model = self.model
        model_output_transform = self.model_output_transform
        loss_fn = self.loss_fn
        unconditional_prob = self.unconditional_prob
        ema_decay = self.ema_decay

        # Number of samples (at the start of each local batch) trained unconditionally.
        num_unconditional = int(batch_size * unconditional_prob)

        nS, nC = null_labels_seq.shape
        null_labels_seq = jnp.broadcast_to(
            null_labels_seq, (batch_size, nS, nC))

        distributed_training = self.distributed_training

        autoencoder = self.autoencoder

        def train_step(train_state: TrainState, rng_state: RandomMarkovState, batch, local_device_index):
            """Train for a single step; returns ``(train_state, loss, rng_state)``."""
            rng_state, subkey = rng_state.get_random_key()
            # Give every device its own random stream.
            subkey = jax.random.fold_in(subkey, local_device_index.reshape())
            local_rng_state = RandomMarkovState(subkey)

            images = batch['image']

            if autoencoder is not None:
                # Convert the images to latent space.
                # NOTE(review): assumes encode accepts a PRNG key positionally and
                # handles its own input scaling — confirm against the autoencoder.
                local_rng_state, rngs = local_rng_state.get_random_key()
                images = autoencoder.encode(images, rngs)
            else:
                # Map [0, 255] pixel values to [-1, 1].
                images = (images - 127.5) / 127.5

            output = text_embedder(
                input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
            label_seq = output.last_hidden_state

            # Deterministically drop conditioning for the first `num_unconditional`
            # samples of the batch (trains the unconditional branch).
            label_seq = jnp.concat(
                [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)

            noise_level, local_rng_state = noise_schedule.generate_timesteps(images.shape[0], local_rng_state)

            local_rng_state, rngs = local_rng_state.get_random_key()
            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)

            rates = noise_schedule.get_rates(noise_level)
            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
                images, noise, rates)

            def model_loss(params):
                preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
                preds = model_output_transform.pred_transform(
                    noisy_images, preds, rates)
                nloss = loss_fn(preds, expected_output)
                nloss *= noise_schedule.get_weights(noise_level)
                return jnp.mean(nloss)

            loss, grads = jax.value_and_grad(model_loss)(train_state.params)
            if distributed_training:
                grads = jax.lax.pmean(grads, "data")
                loss = jax.lax.pmean(loss, "data")
            train_state = train_state.apply_gradients(grads=grads)
            train_state = train_state.apply_ema(ema_decay)
            return train_state, loss, rng_state

        if distributed_training:
            train_step = shard_map(train_step, mesh=self.mesh, in_specs=(P(), P(), P('data'), P('data')),
                                   out_specs=(P(), P(), P()))
        train_step = jax.jit(train_step)

        return train_step

    def _define_compute_metrics(self):
        """Return a jitted function that merges an MSE metric into the state."""
        @jax.jit
        def compute_metrics(state: TrainState, expected, pred):
            loss = jnp.mean(jnp.square(pred - expected))
            metric_updates = state.metrics.single_from_model_output(loss=loss)
            metrics = state.metrics.merge(metric_updates)
            state = state.replace(metrics=metrics)
            return state
        return compute_metrics

    def fit(self, data, steps_per_epoch, epochs):
        """Train using ``data['null_labels_full']``, ``data['local_batch_size']`` and ``data['model']``."""
        null_labels_full = data['null_labels_full']
        local_batch_size = data['local_batch_size']
        text_embedder = data['model']
        super().fit(data, steps_per_epoch, epochs, {
            "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
197
+
198
def boolean_string(s):
    """Interpret ``s`` as a boolean flag.

    Booleans are returned unchanged; any other value is ``True`` only when it
    equals the exact string ``'True'`` (case-sensitive).
    """
    return s if isinstance(s, bool) else s == 'True'
202
+
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flaxdiff
3
- Version: 0.1.5
3
+ Version: 0.1.6
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author: Ashish Kumar Singh
6
6
  Author-email: ashishkmr472@gmail.com
@@ -13,6 +13,8 @@ Requires-Dist: clu
13
13
 
14
14
  # ![](images/logo.jpeg "FlaxDiff")
15
15
 
16
+ **This project is partially supported by [Google TPU Research Cloud](https://sites.research.google/trc/about/). I would like to thank the Google Cloud TPU team for providing me with the resources to train the bigger text-conditional models in multi-host distributed settings.**
17
+
16
18
  ## A Versatile and simple Diffusion Library
17
19
 
18
20
  In recent years, diffusion and score-based multi-step models have revolutionized the generative AI domain. However, the latest research in this field has become highly math-intensive, making it challenging to understand how state-of-the-art diffusion models work and generate such impressive images. Replicating this research in code can be daunting.
@@ -1,14 +1,15 @@
1
1
  flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
2
  flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
3
3
  flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
4
- flaxdiff/models/attention.py,sha256=KiAUyfujGpUZR13aJR6RVnL6pBXk5UcyM62VIXhojMg,14468
5
- flaxdiff/models/common.py,sha256=jlyRB4uF7BmeuExor1YHaqEbBjSuyaDZ4mDsSW3rWKE,7948
4
+ flaxdiff/models/attention.py,sha256=OhpKQXdxWbf8K2_yotLfS0DYdHb-zNpL2p8--ql_FAg,14503
5
+ flaxdiff/models/common.py,sha256=RYNxX9K19hvwSWaB9Wtv7MIZLhcacdugDgD9uZDh8XM,10358
6
6
  flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
7
- flaxdiff/models/simple_unet.py,sha256=o1DCa9yvqarEGTiUKsTqE70q-h6bRU6HcU0lZpb65jc,11418
7
+ flaxdiff/models/simple_unet.py,sha256=hAcz074E9NVdUtECPMi1c1Kw-52Dc6l_ME-5FqIg-n8,9255
8
8
  flaxdiff/models/simple_vit.py,sha256=vTu2CQRoSOxetBHTrnCWddm-vxrZDkMe8EpdNxtpJMk,4015
9
- flaxdiff/models/autoencoder/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
- flaxdiff/models/autoencoder/autoencoder.py,sha256=At-DhcmrZ0Gao4PUa4l9D25FTdTPwbE4gu6LKcFKzUQ,433
11
- flaxdiff/models/autoencoder/diffusers.py,sha256=gwyD98277vQGKVPFbyd6w6CupoxMsNgKlN67AtzLCtg,3267
9
+ flaxdiff/models/autoencoder/__init__.py,sha256=qY-7MldZpsfkF-_T2LqlRK7VHbqfmosz0NmvzDlBkOk,78
10
+ flaxdiff/models/autoencoder/autoencoder.py,sha256=27_hYl0yXAdH9Mx4Xu9J79mSNo-FEKr9SxhVaS3ffn4,591
11
+ flaxdiff/models/autoencoder/diffusers.py,sha256=kwlKwHBSAegtTiEkGju_1Trltegj-e47hXFN9jCKmgY,3609
12
+ flaxdiff/models/autoencoder/simple_autoenc.py,sha256=UXHPgDmwGTnv3Uts6Zj3p9R9nJXnEiEXbllgarwDfXM,805
12
13
  flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
13
14
  flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
14
15
  flaxdiff/samplers/common.py,sha256=_an5h5Niz9Joz_-ppridLrGHpu8X0VVvhNGknPu6AUY,5272
@@ -27,9 +28,11 @@ flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,60
27
28
  flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
28
29
  flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
29
30
  flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
30
- flaxdiff/trainer/__init__.py,sha256=17qKQFITCfaXQFKYElMzkE-c-EPrv5iUL66gY1gKOsQ,7243
31
+ flaxdiff/trainer/__init__.py,sha256=T-vUVq4zHcMK6kpCsG4Gu8vn71q6lZD-lg-Ul7yKfEk,128
32
+ flaxdiff/trainer/autoencoder_trainer.py,sha256=al7AsZ7yeDMEiDD-gbcXf0ADq_xfk1VMxvg24GfA-XQ,7008
33
+ flaxdiff/trainer/diffusion_trainer.py,sha256=h5YxIMjBI553xDNeapzLDGF0_4y0MfGRMuHume5sPtM,7785
31
34
  flaxdiff/trainer/simple_trainer.py,sha256=f4g2KGuGM__d9v_4Ip3ng8wQubmenWZUW60VEu2ANOg,16774
32
- flaxdiff-0.1.5.dist-info/METADATA,sha256=tGKayFhkYSJJnLY_sHiaCJ60kJZqnO-kcLM3uH3JSN4,19811
33
- flaxdiff-0.1.5.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
34
- flaxdiff-0.1.5.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
35
- flaxdiff-0.1.5.dist-info/RECORD,,
35
+ flaxdiff-0.1.6.dist-info/METADATA,sha256=sWY_oQgQhhuyW89KyRwIBrpVHBPJjRMmsk5twfgIBlo,20090
36
+ flaxdiff-0.1.6.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
37
+ flaxdiff-0.1.6.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
38
+ flaxdiff-0.1.6.dist-info/RECORD,,