flaxdiff 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,323 +1,14 @@
1
1
  import jax
2
2
  import jax.numpy as jnp
3
3
  from flax import linen as nn
4
- from typing import Dict, Callable, Sequence, Any, Union
4
+ from flax.typing import Dtype, PrecisionLike
5
+ from typing import Dict, Callable, Sequence, Any, Union, Optional
5
6
  import einops
6
- from .common import kernel_init
7
+ from .common import kernel_init, ConvLayer, Downsample, Upsample, FourierEmbedding, TimeProjection
7
8
  from .attention import TransformerBlock
8
9
 
9
- class WeightStandardizedConv(nn.Module):
10
- """
11
- apply weight standardization https://arxiv.org/abs/1903.10520
12
- """
13
- features: int
14
- kernel_size: Sequence[int] = 3
15
- strides: Union[None, int, Sequence[int]] = 1
16
- padding: Any = 1
17
- dtype: Any = jnp.float32
18
- param_dtype: Any = jnp.float32
19
-
20
- @nn.compact
21
- def __call__(self, x):
22
- """
23
- Applies a weight standardized convolution to the inputs.
24
-
25
- Args:
26
- inputs: input data with dimensions (batch, spatial_dims..., features).
27
-
28
- Returns:
29
- The convolved data.
30
- """
31
- x = x.astype(self.dtype)
32
-
33
- conv = nn.Conv(
34
- features=self.features,
35
- kernel_size=self.kernel_size,
36
- strides = self.strides,
37
- padding=self.padding,
38
- dtype=self.dtype,
39
- param_dtype = self.param_dtype,
40
- parent=None)
41
-
42
- kernel_init = lambda rng, x: conv.init(rng,x)['params']['kernel']
43
- bias_init = lambda rng, x: conv.init(rng,x)['params']['bias']
44
-
45
- # standardize kernel
46
- kernel = self.param('kernel', kernel_init, x)
47
- eps = 1e-5 if self.dtype == jnp.float32 else 1e-3
48
- # reduce over dim_out
49
- redux = tuple(range(kernel.ndim - 1))
50
- mean = jnp.mean(kernel, axis=redux, dtype=self.dtype, keepdims=True)
51
- var = jnp.var(kernel, axis=redux, dtype=self.dtype, keepdims=True)
52
- standardized_kernel = (kernel - mean)/jnp.sqrt(var + eps)
53
-
54
- bias = self.param('bias',bias_init, x)
55
-
56
- return(conv.apply({'params': {'kernel': standardized_kernel, 'bias': bias}},x))
57
-
58
- class PixelShuffle(nn.Module):
59
- scale: int
60
-
61
- @nn.compact
62
- def __call__(self, x):
63
- up = einops.rearrange(
64
- x,
65
- pattern="b h w (h2 w2 c) -> b (h h2) (w w2) c",
66
- h2=self.scale,
67
- w2=self.scale,
68
- )
69
- return up
70
-
71
- class TimeEmbedding(nn.Module):
72
- features:int
73
- nax_positions:int=10000
74
-
75
- def setup(self):
76
- half_dim = self.features // 2
77
- emb = jnp.log(self.nax_positions) / (half_dim - 1)
78
- emb = jnp.exp(-emb * jnp.arange(half_dim, dtype=jnp.float32))
79
- self.embeddings = emb
80
-
81
- def __call__(self, x):
82
- x = jax.lax.convert_element_type(x, jnp.float32)
83
- emb = x[:, None] * self.embeddings[None, :]
84
- emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
85
- return emb
86
-
87
- class FourierEmbedding(nn.Module):
88
- features:int
89
- scale:int = 16
90
-
91
- def setup(self):
92
- self.freqs = jax.random.normal(jax.random.PRNGKey(42), (self.features // 2, ), dtype=jnp.float32) * self.scale
93
-
94
- def __call__(self, x):
95
- x = jax.lax.convert_element_type(x, jnp.float32)
96
- emb = x[:, None] * (2 * jnp.pi * self.freqs)[None, :]
97
- emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
98
- return emb
99
-
100
- class TimeProjection(nn.Module):
101
- features:int
102
- activation:Callable=jax.nn.gelu
103
-
104
- @nn.compact
105
- def __call__(self, x):
106
- x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
107
- x = self.activation(x)
108
- x = nn.DenseGeneral(self.features, kernel_init=kernel_init(1.0))(x)
109
- x = self.activation(x)
110
- return x
111
-
112
- class SeparableConv(nn.Module):
113
- features:int
114
- kernel_size:tuple=(3, 3)
115
- strides:tuple=(1, 1)
116
- use_bias:bool=False
117
- kernel_init:Callable=kernel_init(1.0)
118
- padding:str="SAME"
119
- dtype: Any = jnp.bfloat16
120
- precision: Any = jax.lax.Precision.HIGH
121
-
122
- @nn.compact
123
- def __call__(self, x):
124
- in_features = x.shape[-1]
125
- depthwise = nn.Conv(
126
- features=in_features, kernel_size=self.kernel_size,
127
- strides=self.strides, kernel_init=self.kernel_init,
128
- feature_group_count=in_features, use_bias=self.use_bias,
129
- padding=self.padding,
130
- dtype=self.dtype,
131
- precision=self.precision
132
- )(x)
133
- pointwise = nn.Conv(
134
- features=self.features, kernel_size=(1, 1),
135
- strides=(1, 1), kernel_init=self.kernel_init,
136
- use_bias=self.use_bias,
137
- dtype=self.dtype,
138
- precision=self.precision
139
- )(depthwise)
140
- return pointwise
141
-
142
- class ConvLayer(nn.Module):
143
- conv_type:str
144
- features:int
145
- kernel_size:tuple=(3, 3)
146
- strides:tuple=(1, 1)
147
- kernel_init:Callable=kernel_init(1.0)
148
- dtype: Any = jnp.bfloat16
149
- precision: Any = jax.lax.Precision.HIGH
150
-
151
- def setup(self):
152
- # conv_type can be "conv", "separable", "conv_transpose"
153
- if self.conv_type == "conv":
154
- self.conv = nn.Conv(
155
- features=self.features,
156
- kernel_size=self.kernel_size,
157
- strides=self.strides,
158
- kernel_init=self.kernel_init,
159
- dtype=self.dtype,
160
- precision=self.precision
161
- )
162
- elif self.conv_type == "w_conv":
163
- self.conv = WeightStandardizedConv(
164
- features=self.features,
165
- kernel_size=self.kernel_size,
166
- strides=self.strides,
167
- padding="SAME",
168
- param_dtype=self.dtype,
169
- dtype=self.dtype,
170
- precision=self.precision
171
- )
172
- elif self.conv_type == "separable":
173
- self.conv = SeparableConv(
174
- features=self.features,
175
- kernel_size=self.kernel_size,
176
- strides=self.strides,
177
- kernel_init=self.kernel_init,
178
- dtype=self.dtype,
179
- precision=self.precision
180
- )
181
- elif self.conv_type == "conv_transpose":
182
- self.conv = nn.ConvTranspose(
183
- features=self.features,
184
- kernel_size=self.kernel_size,
185
- strides=self.strides,
186
- kernel_init=self.kernel_init,
187
- dtype=self.dtype,
188
- precision=self.precision
189
- )
190
-
191
- def __call__(self, x):
192
- return self.conv(x)
193
-
194
- class Upsample(nn.Module):
195
- features:int
196
- scale:int
197
- activation:Callable=jax.nn.swish
198
- dtype: Any = jnp.bfloat16
199
- precision: Any = jax.lax.Precision.HIGH
200
-
201
- @nn.compact
202
- def __call__(self, x, residual=None):
203
- out = x
204
- # out = PixelShuffle(scale=self.scale)(out)
205
- B, H, W, C = x.shape
206
- out = jax.image.resize(x, (B, H * self.scale, W * self.scale, C), method="nearest")
207
- out = ConvLayer(
208
- "conv",
209
- features=self.features,
210
- kernel_size=(3, 3),
211
- strides=(1, 1),
212
- dtype=self.dtype,
213
- precision=self.precision
214
- )(out)
215
- if residual is not None:
216
- out = jnp.concatenate([out, residual], axis=-1)
217
- return out
218
-
219
- class Downsample(nn.Module):
220
- features:int
221
- scale:int
222
- activation:Callable=jax.nn.swish
223
- dtype: Any = jnp.bfloat16
224
- precision: Any = jax.lax.Precision.HIGH
225
-
226
- @nn.compact
227
- def __call__(self, x, residual=None):
228
- out = ConvLayer(
229
- "conv",
230
- features=self.features,
231
- kernel_size=(3, 3),
232
- strides=(2, 2),
233
- dtype=self.dtype,
234
- precision=self.precision
235
- )(x)
236
- if residual is not None:
237
- if residual.shape[1] > out.shape[1]:
238
- residual = nn.avg_pool(residual, window_shape=(2, 2), strides=(2, 2), padding="SAME")
239
- out = jnp.concatenate([out, residual], axis=-1)
240
- return out
241
-
242
-
243
- def l2norm(t, axis=1, eps=1e-12):
244
- denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
245
- out = t/denom
246
- return (out)
247
-
248
- class ResidualBlock(nn.Module):
249
- conv_type:str
250
- features:int
251
- kernel_size:tuple=(3, 3)
252
- strides:tuple=(1, 1)
253
- padding:str="SAME"
254
- activation:Callable=jax.nn.swish
255
- direction:str=None
256
- res:int=2
257
- norm_groups:int=8
258
- kernel_init:Callable=kernel_init(1.0)
259
- dtype: Any = jnp.float32
260
- precision: Any = jax.lax.Precision.HIGHEST
261
-
262
- @nn.compact
263
- def __call__(self, x:jax.Array, temb:jax.Array, textemb:jax.Array=None, extra_features:jax.Array=None):
264
- residual = x
265
- out = nn.GroupNorm(self.norm_groups)(x)
266
- out = self.activation(out)
267
-
268
- out = ConvLayer(
269
- self.conv_type,
270
- features=self.features,
271
- kernel_size=self.kernel_size,
272
- strides=self.strides,
273
- kernel_init=self.kernel_init,
274
- name="conv1",
275
- dtype=self.dtype,
276
- precision=self.precision
277
- )(out)
278
-
279
- temb = nn.DenseGeneral(
280
- features=self.features,
281
- name="temb_projection",
282
- dtype=self.dtype,
283
- precision=self.precision)(temb)
284
- temb = jnp.expand_dims(jnp.expand_dims(temb, 1), 1)
285
- # scale, shift = jnp.split(temb, 2, axis=-1)
286
- # out = out * (1 + scale) + shift
287
- out = out + temb
288
-
289
- out = nn.GroupNorm(self.norm_groups)(out)
290
- out = self.activation(out)
291
-
292
- out = ConvLayer(
293
- self.conv_type,
294
- features=self.features,
295
- kernel_size=self.kernel_size,
296
- strides=self.strides,
297
- kernel_init=self.kernel_init,
298
- name="conv2",
299
- dtype=self.dtype,
300
- precision=self.precision
301
- )(out)
302
-
303
- if residual.shape != out.shape:
304
- residual = ConvLayer(
305
- self.conv_type,
306
- features=self.features,
307
- kernel_size=(1, 1),
308
- strides=1,
309
- kernel_init=self.kernel_init,
310
- name="residual_conv",
311
- dtype=self.dtype,
312
- precision=self.precision
313
- )(residual)
314
- out = out + residual
315
-
316
- out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out
317
-
318
- return out
319
-
320
10
  class Unet(nn.Module):
11
+ output_channels:int=3
321
12
  emb_features:int=64*4,
322
13
  feature_depths:list=[64, 128, 256, 512],
323
14
  attention_configs:list=[{"heads":8}, {"heads":8}, {"heads":8}, {"heads":8}],
@@ -325,8 +16,8 @@ class Unet(nn.Module):
325
16
  num_middle_res_blocks:int=1,
326
17
  activation:Callable = jax.nn.swish
327
18
  norm_groups:int=8
328
- dtype: Any = jnp.bfloat16
329
- precision: Any = jax.lax.Precision.HIGH
19
+ dtype: Optional[Dtype] = None
20
+ precision: PrecisionLike = None
330
21
 
331
22
  @nn.compact
332
23
  def __call__(self, x, temb, textcontext):
@@ -373,12 +64,13 @@ class Unet(nn.Module):
373
64
  )(x, temb)
374
65
  if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
375
66
  x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
376
- dim_head=dim_in // attention_config['heads'],
377
- use_flash_attention=attention_config.get("flash_attention", True),
378
- use_projection=attention_config.get("use_projection", False),
379
- use_self_and_cross=attention_config.get("use_self_and_cross", True),
380
- precision=attention_config.get("precision", self.precision),
381
- name=f"down_{i}_attention_{j}")(x, textcontext)
67
+ dim_head=dim_in // attention_config['heads'],
68
+ use_flash_attention=attention_config.get("flash_attention", True),
69
+ use_projection=attention_config.get("use_projection", False),
70
+ use_self_and_cross=attention_config.get("use_self_and_cross", True),
71
+ precision=attention_config.get("precision", self.precision),
72
+ only_pure_attention=True,
73
+ name=f"down_{i}_attention_{j}")(x, textcontext)
382
74
  # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
383
75
  downs.append(x)
384
76
  if i != len(feature_depths) - 1:
@@ -416,6 +108,7 @@ class Unet(nn.Module):
416
108
  use_projection=middle_attention.get("use_projection", False),
417
109
  use_self_and_cross=False,
418
110
  precision=attention_config.get("precision", self.precision),
111
+ only_pure_attention=True,
419
112
  name=f"middle_attention_{j}")(x, textcontext)
420
113
  x = ResidualBlock(
421
114
  middle_conv_type,
@@ -452,12 +145,13 @@ class Unet(nn.Module):
452
145
  )(x, temb)
453
146
  if attention_config is not None and j == self.num_res_blocks - 1: # Apply attention only on the last block
454
147
  x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
455
- dim_head=dim_out // attention_config['heads'],
456
- use_flash_attention=attention_config.get("flash_attention", True),
457
- use_projection=attention_config.get("use_projection", False),
458
- use_self_and_cross=attention_config.get("use_self_and_cross", True),
148
+ dim_head=dim_out // attention_config['heads'],
149
+ use_flash_attention=attention_config.get("flash_attention", True),
150
+ use_projection=attention_config.get("use_projection", False),
151
+ use_self_and_cross=attention_config.get("use_self_and_cross", True),
459
152
  precision=attention_config.get("precision", self.precision),
460
- name=f"up_{i}_attention_{j}")(x, textcontext)
153
+ only_pure_attention=True,
154
+ name=f"up_{i}_attention_{j}")(x, textcontext)
461
155
  # print("Upscaling ", i, x.shape)
462
156
  if i != len(feature_depths) - 1:
463
157
  x = Upsample(
@@ -500,7 +194,7 @@ class Unet(nn.Module):
500
194
 
501
195
  noise_out = ConvLayer(
502
196
  conv_type,
503
- features=3,
197
+ features=self.output_channels,
504
198
  kernel_size=(3, 3),
505
199
  strides=(1, 1),
506
200
  # activation=jax.nn.mish
@@ -1,201 +1,2 @@
1
- import orbax.checkpoint
2
- import tqdm
3
- from flax import linen as nn
4
- import jax
5
- from typing import Callable
6
- from dataclasses import field
7
- import jax.numpy as jnp
8
- from clu import metrics
9
- from flax.training import train_state # Useful dataclass to keep train state
10
- import optax
11
- from flax import struct # Flax dataclasses
12
- import time
13
- import os
14
- import orbax
15
- from flax.training import orbax_utils
16
-
17
- from ..schedulers import NoiseScheduler
18
- from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
19
-
20
- from .simple_trainer import SimpleTrainer, SimpleTrainState
21
-
22
- class TrainState(SimpleTrainState):
23
- rngs: jax.random.PRNGKey
24
- ema_params: dict
25
-
26
- def get_random_key(self):
27
- rngs, subkey = jax.random.split(self.rngs)
28
- return self.replace(rngs=rngs), subkey
29
-
30
- def apply_ema(self, decay: float = 0.999):
31
- new_ema_params = jax.tree_util.tree_map(
32
- lambda ema, param: decay * ema + (1 - decay) * param,
33
- self.ema_params,
34
- self.params,
35
- )
36
- return self.replace(ema_params=new_ema_params)
37
-
38
- class DiffusionTrainer(SimpleTrainer):
39
- noise_schedule: NoiseScheduler
40
- model_output_transform: DiffusionPredictionTransform
41
- ema_decay: float = 0.999
42
-
43
- def __init__(self,
44
- model: nn.Module,
45
- input_shapes: Dict[str, Tuple[int]],
46
- optimizer: optax.GradientTransformation,
47
- noise_schedule: NoiseScheduler,
48
- rngs: jax.random.PRNGKey,
49
- unconditional_prob: float = 0.2,
50
- name: str = "Diffusion",
51
- model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
52
- **kwargs
53
- ):
54
- super().__init__(
55
- model=model,
56
- input_shapes=input_shapes,
57
- optimizer=optimizer,
58
- rngs=rngs,
59
- name=name,
60
- **kwargs
61
- )
62
- self.noise_schedule = noise_schedule
63
- self.model_output_transform = model_output_transform
64
- self.unconditional_prob = unconditional_prob
65
-
66
- def __init_fn(
67
- self,
68
- optimizer: optax.GradientTransformation,
69
- rngs: jax.random.PRNGKey,
70
- existing_state: dict = None,
71
- existing_best_state: dict = None,
72
- model: nn.Module = None,
73
- param_transforms: Callable = None
74
- ) -> Tuple[TrainState, TrainState]:
75
- rngs, subkey = jax.random.split(rngs)
76
-
77
- if existing_state == None:
78
- input_vars = self.get_input_ones()
79
- params = model.init(subkey, **input_vars)
80
- new_state = {"params": params, "ema_params": params}
81
- else:
82
- new_state = existing_state
83
-
84
- if param_transforms is not None:
85
- params = param_transforms(params)
86
-
87
- state = TrainState.create(
88
- apply_fn=model.apply,
89
- params=new_state['params'],
90
- ema_params=new_state['ema_params'],
91
- tx=optimizer,
92
- rngs=rngs,
93
- metrics=Metrics.empty()
94
- )
95
-
96
- if existing_best_state is not None:
97
- best_state = state.replace(
98
- params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
99
- else:
100
- best_state = state
101
-
102
- return state, best_state
103
-
104
- def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
105
- noise_schedule = self.noise_schedule
106
- model = self.model
107
- model_output_transform = self.model_output_transform
108
- loss_fn = self.loss_fn
109
- unconditional_prob = self.unconditional_prob
110
-
111
- # Determine the number of unconditional samples
112
- num_unconditional = int(batch_size * unconditional_prob)
113
-
114
- nS, nC = null_labels_seq.shape
115
- null_labels_seq = jnp.broadcast_to(
116
- null_labels_seq, (batch_size, nS, nC))
117
-
118
- distributed_training = self.distributed_training
119
-
120
- def train_step(state: TrainState, batch):
121
- """Train for a single step."""
122
- images = batch['image']
123
- # normalize image
124
- images = (images - 127.5) / 127.5
125
-
126
- output = text_embedder(
127
- input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
128
- # output = infer(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
129
-
130
- label_seq = output.last_hidden_state
131
-
132
- # Generate random probabilities to decide how much of this batch will be unconditional
133
-
134
- label_seq = jnp.concat(
135
- [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
136
-
137
- noise_level, state = noise_schedule.generate_timesteps(
138
- images.shape[0], state)
139
- state, rngs = state.get_random_key()
140
- noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
141
- rates = noise_schedule.get_rates(noise_level)
142
- noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
143
- images, noise, rates)
144
-
145
- def model_loss(params):
146
- preds = model.apply(
147
- params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
148
- preds = model_output_transform.pred_transform(
149
- noisy_images, preds, rates)
150
- nloss = loss_fn(preds, expected_output)
151
- # nloss = jnp.mean(nloss, axis=1)
152
- nloss *= noise_schedule.get_weights(noise_level)
153
- nloss = jnp.mean(nloss)
154
- loss = nloss
155
- return loss
156
-
157
- loss, grads = jax.value_and_grad(model_loss)(state.params)
158
- if distributed_training:
159
- grads = jax.lax.pmean(grads, "device")
160
- state = state.apply_gradients(grads=grads)
161
- state = state.apply_ema(self.ema_decay)
162
- return state, loss
163
-
164
- if distributed_training:
165
- train_step = jax.pmap(axis_name="device")(train_step)
166
- else:
167
- train_step = jax.jit(train_step)
168
-
169
- return train_step
170
-
171
- def _define_compute_metrics(self):
172
- @jax.jit
173
- def compute_metrics(state: TrainState, expected, pred):
174
- loss = jnp.mean(jnp.square(pred - expected))
175
- metric_updates = state.metrics.single_from_model_output(loss=loss)
176
- metrics = state.metrics.merge(metric_updates)
177
- state = state.replace(metrics=metrics)
178
- return state
179
- return compute_metrics
180
-
181
- def fit(self, data, steps_per_epoch, epochs):
182
- null_labels_full = data['null_labels_full']
183
- local_batch_size = data['local_batch_size']
184
- text_embedder = data['model']
185
- super().fit(data, steps_per_epoch, epochs, {
186
- "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
187
-
188
-
189
- pbar.set_postfix(loss=f'{loss:.4f}')
190
- pbar.update(100)
191
- end_time = time.time()
192
- self.state = state
193
- total_time = end_time - start_time
194
- avg_time_per_step = total_time / steps_per_epoch
195
- avg_loss = epoch_loss / steps_per_epoch
196
- if avg_loss < self.best_loss:
197
- self.best_loss = avg_loss
198
- self.best_state = state
199
- self.save(epoch, best=True)
200
- print(f"\n\tEpoch {epoch+1} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}")
201
- return self.state
1
+ from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
2
+ from .diffusion_trainer import DiffusionTrainer, TrainState