flaxdiff 0.1.1__tar.gz → 0.1.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/PKG-INFO +1 -1
  2. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/models/attention.py +57 -115
  3. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/models/common.py +2 -2
  4. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/models/simple_unet.py +10 -14
  5. flaxdiff-0.1.3/flaxdiff/models/simple_vit.py +123 -0
  6. flaxdiff-0.1.3/flaxdiff/trainer/__init__.py +201 -0
  7. flaxdiff-0.1.3/flaxdiff/trainer/simple_trainer.py +323 -0
  8. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff.egg-info/PKG-INFO +1 -1
  9. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff.egg-info/SOURCES.txt +3 -1
  10. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/setup.py +1 -1
  11. flaxdiff-0.1.1/flaxdiff/trainer/__init__.py +0 -216
  12. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/README.md +0 -0
  13. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/__init__.py +0 -0
  14. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/models/__init__.py +0 -0
  15. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/models/favor_fastattn.py +0 -0
  16. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/predictors/__init__.py +0 -0
  17. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/samplers/__init__.py +0 -0
  18. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/samplers/common.py +0 -0
  19. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/samplers/ddim.py +0 -0
  20. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/samplers/ddpm.py +0 -0
  21. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/samplers/euler.py +0 -0
  22. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/samplers/heun_sampler.py +0 -0
  23. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/samplers/multistep_dpm.py +0 -0
  24. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/samplers/rk4_sampler.py +0 -0
  25. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/schedulers/__init__.py +0 -0
  26. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/schedulers/common.py +0 -0
  27. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/schedulers/continuous.py +0 -0
  28. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/schedulers/cosine.py +0 -0
  29. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/schedulers/discrete.py +0 -0
  30. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/schedulers/exp.py +0 -0
  31. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/schedulers/karras.py +0 -0
  32. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/schedulers/linear.py +0 -0
  33. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/schedulers/sqrt.py +0 -0
  34. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff/utils.py +0 -0
  35. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff.egg-info/dependency_links.txt +0 -0
  36. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff.egg-info/requires.txt +0 -0
  37. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/flaxdiff.egg-info/top_level.txt +0 -0
  38. {flaxdiff-0.1.1 → flaxdiff-0.1.3}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.1
+Version: 0.1.3
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
@@ -11,105 +11,6 @@ import functools
 import math
 from .common import kernel_init
 
-def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
-    """Multi-head dot product attention with a limited number of queries."""
-    num_kv, num_heads, k_features = key.shape[-3:]
-    v_features = value.shape[-1]
-    key_chunk_size = min(key_chunk_size, num_kv)
-    query = query / jnp.sqrt(k_features)
-
-    @functools.partial(jax.checkpoint, prevent_cse=False)
-    def summarize_chunk(query, key, value):
-        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
-
-        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
-        max_score = jax.lax.stop_gradient(max_score)
-        exp_weights = jnp.exp(attn_weights - max_score)
-
-        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
-        max_score = jnp.einsum("...qhk->...qh", max_score)
-
-        return (exp_values, exp_weights.sum(axis=-1), max_score)
-
-    def chunk_scanner(chunk_idx):
-        # julienne key array
-        key_chunk = jax.lax.dynamic_slice(
-            operand=key,
-            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
-            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
-        )
-
-        # julienne value array
-        value_chunk = jax.lax.dynamic_slice(
-            operand=value,
-            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
-            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
-        )
-
-        return summarize_chunk(query, key_chunk, value_chunk)
-
-    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
-
-    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
-    max_diffs = jnp.exp(chunk_max - global_max)
-
-    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
-    chunk_weights *= max_diffs
-
-    all_values = chunk_values.sum(axis=0)
-    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
-
-    return all_values / all_weights
-
-
-def jax_memory_efficient_attention(
-    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
-):
-    r"""
-    Flax memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
-    https://github.com/AminRezaei0x443/memory-efficient-attention
-
-    Args:
-        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
-        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
-        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
-        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
-            numerical precision for computation
-        query_chunk_size (`int`, *optional*, defaults to 1024):
-            chunk size to divide the query array; must divide query_length equally without remainder
-        key_chunk_size (`int`, *optional*, defaults to 4096):
-            chunk size to divide the key and value arrays; must divide key_value_length equally without remainder
-
-    Returns:
-        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
-    """
-    num_q, num_heads, q_features = query.shape[-3:]
-
-    def chunk_scanner(chunk_idx, _):
-        # julienne query array
-        query_chunk = jax.lax.dynamic_slice(
-            operand=query,
-            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
-            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
-        )
-
-        return (
-            chunk_idx + query_chunk_size,  # unused; ignore it
-            _query_chunk_attention(
-                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
-            ),
-        )
-
-    _, res = jax.lax.scan(
-        f=chunk_scanner,
-        init=0,  # start counter
-        xs=None,
-        length=math.ceil(num_q / query_chunk_size),  # stop counter
-    )
-
-    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
-
-
 class EfficientAttention(nn.Module):
     """
     Based on the pallas attention implementation.
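The block removed above is the chunked (memory-efficient) attention of Rabe & Staats (arXiv:2112.05682), dropped in 0.1.3 in favor of the Pallas flash-attention kernel. A minimal standalone sketch of the same online-softmax idea, with toy shapes and a hypothetical `key_chunk` size, checked against plain attention (this sketch materializes the full score matrix, so it demonstrates correctness only, not the memory savings):

```python
# Sketch only: chunked softmax attention matching ordinary attention.
import jax
import jax.numpy as jnp

def full_attention(q, k, v):
    # q, k: [q_len, d], v: [kv_len, d]; plain softmax attention for reference
    w = jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]), axis=-1)
    return w @ v

def chunked_attention(q, k, v, key_chunk=4):
    # Accumulate unnormalized values and weights per key chunk, then
    # renormalize with a running max (the online-softmax trick).
    q = q / jnp.sqrt(q.shape[-1])
    scores = q @ k.T                                      # [q_len, kv_len]
    chunks = jnp.split(scores, scores.shape[-1] // key_chunk, axis=-1)
    v_chunks = jnp.split(v, v.shape[0] // key_chunk, axis=0)
    maxes = jnp.stack([c.max(axis=-1) for c in chunks])   # [n_chunks, q_len]
    global_max = maxes.max(axis=0)                        # [q_len]
    vals = sum(jnp.exp(c - m[:, None]) @ vc * jnp.exp(m - global_max)[:, None]
               for c, m, vc in zip(chunks, maxes, v_chunks))
    weights = sum(jnp.exp(c - m[:, None]).sum(-1) * jnp.exp(m - global_max)
                  for c, m in zip(chunks, maxes))
    return vals / weights[:, None]

q, k, v = (jax.random.normal(jax.random.PRNGKey(i), (8, 16)) for i in range(3))
assert jnp.allclose(full_attention(q, k, v), chunked_attention(q, k, v), atol=1e-5)
```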
@@ -125,41 +26,77 @@ class EfficientAttention(nn.Module):
     def setup(self):
         inner_dim = self.dim_head * self.heads
         # Weights were exported with old names {to_q, to_k, to_v, to_out}
-        self.query = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision,
-                                     kernel_init=self.kernel_init(), dtype=self.dtype, name="to_q")
-        self.key = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision,
-                                   kernel_init=self.kernel_init(), dtype=self.dtype, name="to_k")
-        self.value = nn.DenseGeneral(inner_dim, use_bias=False, precision=self.precision,
-                                     kernel_init=self.kernel_init(), dtype=self.dtype, name="to_v")
+        dense = functools.partial(
+            nn.Dense,
+            self.heads * self.dim_head,
+            precision=self.precision,
+            use_bias=self.use_bias,
+            kernel_init=self.kernel_init(),
+            dtype=self.dtype
+        )
+        self.query = dense(name="to_q")
+        self.key = dense(name="to_k")
+        self.value = dense(name="to_v")
+
         self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
                                          kernel_init=self.kernel_init(), dtype=self.dtype, name="to_out_0")
         # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)
+
+    def _reshape_tensor_to_head_dim(self, tensor):
+        batch_size, _, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
+        return tensor
+
+    def _reshape_tensor_from_head_dim(self, tensor):
+        batch_size, _, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
+        tensor = tensor.reshape(batch_size, 1, seq_len, dim * head_size)
+        return tensor
 
     @nn.compact
     def __call__(self, x:jax.Array, context=None):
+        # print(x.shape)
         # x has shape [B, H * W, C]
         context = x if context is None else context
+
+        B, H, W, C = x.shape
+        x = x.reshape((B, 1, H * W, C))
+
+        if len(context.shape) == 4:
+            B, _H, _W, _C = context.shape
+            context = context.reshape((B, 1, _H * _W, _C))
+        else:
+            B, SEQ, _C = context.shape
+            context = context.reshape((B, 1, SEQ, _C))
+
         query = self.query(x)
         key = self.key(context)
         value = self.value(context)
 
-        # print(query.shape, key.shape, value.shape)
+        query = self._reshape_tensor_to_head_dim(query)
+        key = self._reshape_tensor_to_head_dim(key)
+        value = self._reshape_tensor_to_head_dim(value)
 
-        # hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.mha_reference(
-        #     query, key, value, None
-        # )
-
-        hidden_states = nn.dot_product_attention(
-            query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
+        hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.flash_attention(
+            query, key, value, None
         )
-        # hidden_states = self.attnfn(
-        #     query, key, value, None
+
+        hidden_states = self._reshape_tensor_from_head_dim(hidden_states)
+
+
+        # hidden_states = nn.dot_product_attention(
+        #     query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
         # )
 
         proj = self.proj_attn(hidden_states)
+
+        proj = proj.reshape((B, H, W, C))
+
         return proj
 
-
 class NormalAttention(nn.Module):
     """
     Simple implementation of the normal attention.
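The new `_reshape_tensor_to_head_dim` / `_reshape_tensor_from_head_dim` helpers round-trip activations between the `[B, 1, S, H*D]` layout used inside the module and the `[B, H, S, D]` layout the Pallas flash-attention kernel consumes. A standalone shape check of the same two reshapes (toy sizes, not flaxdiff defaults):

```python
# Standalone sketch of the reshape round-trip around the flash-attention call.
import jax.numpy as jnp

B, S, heads, dim_head = 2, 16, 4, 8
x = jnp.arange(B * S * heads * dim_head, dtype=jnp.float32).reshape(
    (B, 1, S, heads * dim_head))                     # [B, 1, S, H*D]

def to_head_dim(t, heads):
    b, _, s, d = t.shape
    t = t.reshape(b, s, heads, d // heads)           # split channels into heads
    return jnp.transpose(t, (0, 2, 1, 3))            # [B, H, S, D]

def from_head_dim(t):
    b, h, s, d = t.shape
    t = jnp.transpose(t, (0, 2, 1, 3))               # [B, S, H, D]
    return t.reshape(b, 1, s, h * d)                 # back to [B, 1, S, H*D]

y = to_head_dim(x, heads)
assert y.shape == (B, heads, S, dim_head)
assert jnp.array_equal(from_head_dim(y), x)          # exact round trip
```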
@@ -201,7 +138,11 @@ class NormalAttention(nn.Module):
     @nn.compact
     def __call__(self, x, context=None):
         # x has shape [B, H, W, C]
+        B, H, W, C = x.shape
+        x = x.reshape((B, H*W, C))
         context = x if context is None else context
+        if len(context.shape) == 4:
+            context = context.reshape((B, H*W, C))
         query = self.query(x)
         key = self.key(context)
         value = self.value(context)
@@ -210,6 +151,7 @@ class NormalAttention(nn.Module):
             query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
         )
         proj = self.proj_attn(hidden_states)
+        proj = proj.reshape((B, H, W, C))
         return proj
 
 class AttentionBlock(nn.Module):
@@ -2,6 +2,6 @@ import jax.numpy as jnp
 from flax import linen as nn
 
 # Kernel initializer to use
-def kernel_init(scale):
+def kernel_init(scale, dtype=jnp.float32):
     scale = max(scale, 1e-10)
-    return nn.initializers.variance_scaling(scale=scale, mode="fan_in", distribution="truncated_normal")
+    return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)
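The initializer change switches `variance_scaling` from `fan_in` to `fan_avg` mode and threads a `dtype` through. A small usage sketch of the updated factory (the Dense kernel shape below is illustrative):

```python
# Sketch: kernel_init returns a standard Flax initializer; toy shape below.
import jax
import jax.numpy as jnp
from flax import linen as nn

def kernel_init(scale, dtype=jnp.float32):
    scale = max(scale, 1e-10)  # keep the variance strictly positive, as in flaxdiff
    return nn.initializers.variance_scaling(
        scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)

init = kernel_init(1.0)
w = init(jax.random.PRNGKey(0), (64, 128))  # (fan_in, fan_out) of a Dense kernel
# With fan_avg, Var(w) ≈ scale / ((64 + 128) / 2); fan_in would use scale / 64.
print(w.shape, w.std())
```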
@@ -5,6 +5,7 @@ from typing import Dict, Callable, Sequence, Any, Union
 import einops
 from .common import kernel_init
 from .attention import TransformerBlock
+
 class WeightStandardizedConv(nn.Module):
     """
     apply weight standardization https://arxiv.org/abs/1903.10520
@@ -243,6 +244,7 @@ def l2norm(t, axis=1, eps=1e-12):
     denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
     out = t/denom
     return (out)
+
 class ResidualBlock(nn.Module):
     conv_type:str
     features:int
@@ -327,7 +329,7 @@ class Unet(nn.Module):
     precision: Any = jax.lax.Precision.HIGH
 
     @nn.compact
-    def __call__(self, x, temb, textcontext=None):
+    def __call__(self, x, temb, textcontext):
         # print("embedding features", self.emb_features)
         temb = FourierEmbedding(features=self.emb_features)(temb)
         temb = TimeProjection(features=self.emb_features)(temb)
@@ -341,6 +343,8 @@ class Unet(nn.Module):
         conv_type = up_conv_type = down_conv_type = middle_conv_type = "conv"
         # middle_conv_type = "separable"
 
+        print(f"input shape: {x.shape}")
+
         x = ConvLayer(
             conv_type,
             features=self.feature_depths[0],
@@ -351,6 +355,8 @@ class Unet(nn.Module):
             precision=self.precision
         )(x)
         downs = [x]
+
+        print(f"x shape: {x.shape}")
 
         # Downscaling blocks
         for i, (dim_out, attention_config) in enumerate(zip(feature_depths, attention_configs)):
@@ -370,18 +376,13 @@ class Unet(nn.Module):
                     precision=self.precision
                 )(x, temb)
                 if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
-                    B, H, W, _ = x.shape
-                    if H > TS:
-                        padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC))
-                    else:
-                        padded_context = None
                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                                          dim_head=dim_in // attention_config['heads'],
                                          use_flash_attention=attention_config.get("flash_attention", True),
                                          use_projection=attention_config.get("use_projection", False),
                                          use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                          precision=attention_config.get("precision", self.precision),
-                                         name=f"down_{i}_attention_{j}")(x, padded_context)
+                                         name=f"down_{i}_attention_{j}")(x, textcontext)
                 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
                 downs.append(x)
             if i != len(feature_depths) - 1:
@@ -419,7 +420,7 @@ class Unet(nn.Module):
                                  use_projection=middle_attention.get("use_projection", False),
                                  use_self_and_cross=False,
                                  precision=attention_config.get("precision", self.precision),
-                                 name=f"middle_attention_{j}")(x)
+                                 name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
                 middle_conv_type,
                 name=f"middle_res2_{j}",
@@ -454,18 +455,13 @@ class Unet(nn.Module):
                     precision=self.precision
                 )(x, temb)
                 if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
-                    B, H, W, _ = x.shape
-                    if H > TS:
-                        padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC))
-                    else:
-                        padded_context = None
                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                                          dim_head=dim_out // attention_config['heads'],
                                          use_flash_attention=attention_config.get("flash_attention", True),
                                          use_projection=attention_config.get("use_projection", False),
                                          use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                          precision=attention_config.get("precision", self.precision),
-                                         name=f"up_{i}_attention_{j}")(x, padded_context)
+                                         name=f"up_{i}_attention_{j}")(x, textcontext)
             # print("Upscaling ", i, x.shape)
             if i != len(feature_depths) - 1:
                 x = Upsample(
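The `attention_config` entries consumed above are plain dicts read via `.get(...)` with defaults. A hypothetical `attention_configs` list for a four-level Unet, with key names mirroring the reads in the diff (the values are made up for illustration; `None` disables attention at that level):

```python
# Hypothetical attention_configs; key names mirror the .get() calls above.
import jax
import jax.numpy as jnp

attention_configs = [
    None,                                    # no attention at the highest resolution
    None,
    {"heads": 8, "flash_attention": True,    # defaults spelled out explicitly
     "use_projection": False, "use_self_and_cross": True,
     "dtype": jnp.float32, "precision": jax.lax.Precision.HIGH},
    {"heads": 8, "flash_attention": False,
     "use_projection": True, "use_self_and_cross": True,
     "dtype": jnp.bfloat16, "precision": jax.lax.Precision.HIGH},
]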
@@ -0,0 +1,123 @@
+# simple_vit.py
+
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from typing import Callable, Any
+from .simple_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
+from .attention import TransformerBlock
+
+class PatchEmbedding(nn.Module):
+    patch_size: int
+    embedding_dim: int
+    dtype: Any = jnp.float32
+    precision: Any = jax.lax.Precision.HIGH
+
+    @nn.compact
+    def __call__(self, x):
+        batch, height, width, channels = x.shape
+        assert height % self.patch_size == 0 and width % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+
+        x = nn.Conv(features=self.embedding_dim,
+                    kernel_size=(self.patch_size, self.patch_size),
+                    strides=(self.patch_size, self.patch_size),
+                    dtype=self.dtype,
+                    precision=self.precision)(x)
+        x = jnp.reshape(x, (batch, -1, self.embedding_dim))
+        return x
+
+class PositionalEncoding(nn.Module):
+    max_len: int
+    embedding_dim: int
+
+    @nn.compact
+    def __call__(self, x):
+        pe = self.param('pos_encoding',
+                        jax.nn.initializers.zeros,
+                        (1, self.max_len, self.embedding_dim))
+        return x + pe[:, :x.shape[1], :]
+
+class TransformerEncoder(nn.Module):
+    num_layers: int
+    num_heads: int
+    mlp_dim: int
+    dropout_rate: float = 0.1
+    dtype: Any = jnp.float32
+    precision: Any = jax.lax.Precision.HIGH
+
+    @nn.compact
+    def __call__(self, x, training=True):
+        for _ in range(self.num_layers):
+            x = TransformerBlock(
+                heads=self.num_heads,
+                dim_head=x.shape[-1] // self.num_heads,
+                mlp_dim=self.mlp_dim,
+                dropout_rate=self.dropout_rate,
+                dtype=self.dtype,
+                precision=self.precision
+            )(x)
+        return x
+
+class VisionTransformer(nn.Module):
+    patch_size: int = 16
+    embedding_dim: int = 768
+    num_layers: int = 12
+    num_heads: int = 12
+    mlp_dim: int = 3072
+    emb_features: int = 256
+    dropout_rate: float = 0.1
+    dtype: Any = jnp.float32
+    precision: Any = jax.lax.Precision.HIGH
+
+    @nn.compact
+    def __call__(self, x, temb, textcontext=None):
+        # Time embedding
+        temb = FourierEmbedding(features=self.emb_features)(temb)
+        temb = TimeProjection(features=self.emb_features)(temb)
+
+        # Patch embedding
+        x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.embedding_dim,
+                           dtype=self.dtype, precision=self.precision)(x)
+
+        # Add positional encoding
+        x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.embedding_dim)(x)
+
+        # Add time embedding
+        temb = jnp.expand_dims(temb, axis=1)
+        x = jnp.concatenate([x, temb], axis=1)
+
+        # Add text context
+        if textcontext is not None:
+            x = jnp.concatenate([x, textcontext], axis=1)
+
+        # Transformer encoder
+        x = TransformerEncoder(
+            num_layers=self.num_layers,
+            num_heads=self.num_heads,
+            mlp_dim=self.mlp_dim,
+            dropout_rate=self.dropout_rate,
+            dtype=self.dtype,
+            precision=self.precision
+        )(x)
+
+        # Extract the image tokens (exclude time and text embeddings)
+        num_patches = (x.shape[1] - 1 - (0 if textcontext is None else textcontext.shape[1]))
+        x = x[:, :num_patches, :]
+
+        # Reshape to image dimensions
+        batch, _, _ = x.shape
+        height = width = int((num_patches) ** 0.5)
+        x = jnp.reshape(x, (batch, height, width, self.embedding_dim))
+
+        # Final convolution to get the desired output channels
+        x = ConvLayer(
+            conv_type="conv",
+            features=3,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            kernel_init=kernel_init(0.0),
+            dtype=self.dtype,
+            precision=self.precision
+        )(x)
+
+        return x
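The token bookkeeping in `VisionTransformer.__call__` appends one time token plus optional text tokens, then strips them before reshaping back to a patch grid. A sketch of the sequence-length arithmetic with hypothetical sizes:

```python
# Sketch of the VisionTransformer token arithmetic; sizes are hypothetical.
image_size, patch_size, text_len = 64, 16, 77

num_patches = (image_size // patch_size) ** 2   # 16 patch tokens
seq_len = num_patches + 1 + text_len            # + time token + text tokens
assert num_patches == seq_len - 1 - text_len    # what the slice recovers
side = int(num_patches ** 0.5)                  # 4x4 grid for the final reshape
print(num_patches, seq_len, side)
```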
@@ -0,0 +1,201 @@
+import orbax.checkpoint
+import tqdm
+from flax import linen as nn
+import jax
+from typing import Callable, Dict, Tuple
+from dataclasses import field
+import jax.numpy as jnp
+from clu import metrics
+from flax.training import train_state  # Useful dataclass to keep train state
+import optax
+from flax import struct  # Flax dataclasses
+import time
+import os
+import orbax
+from flax.training import orbax_utils
+
+from ..schedulers import NoiseScheduler
+from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
+
+from .simple_trainer import SimpleTrainer, SimpleTrainState, Metrics
+
+class TrainState(SimpleTrainState):
+    rngs: jax.random.PRNGKey
+    ema_params: dict
+
+    def get_random_key(self):
+        rngs, subkey = jax.random.split(self.rngs)
+        return self.replace(rngs=rngs), subkey
+
+    def apply_ema(self, decay: float = 0.999):
+        new_ema_params = jax.tree_util.tree_map(
+            lambda ema, param: decay * ema + (1 - decay) * param,
+            self.ema_params,
+            self.params,
+        )
+        return self.replace(ema_params=new_ema_params)
+
+class DiffusionTrainer(SimpleTrainer):
+    noise_schedule: NoiseScheduler
+    model_output_transform: DiffusionPredictionTransform
+    ema_decay: float = 0.999
+
+    def __init__(self,
+                 model: nn.Module,
+                 input_shapes: Dict[str, Tuple[int]],
+                 optimizer: optax.GradientTransformation,
+                 noise_schedule: NoiseScheduler,
+                 rngs: jax.random.PRNGKey,
+                 unconditional_prob: float = 0.2,
+                 name: str = "Diffusion",
+                 model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
+                 **kwargs
+                 ):
+        super().__init__(
+            model=model,
+            input_shapes=input_shapes,
+            optimizer=optimizer,
+            rngs=rngs,
+            name=name,
+            **kwargs
+        )
+        self.noise_schedule = noise_schedule
+        self.model_output_transform = model_output_transform
+        self.unconditional_prob = unconditional_prob
+
+    def __init_fn(
+        self,
+        optimizer: optax.GradientTransformation,
+        rngs: jax.random.PRNGKey,
+        existing_state: dict = None,
+        existing_best_state: dict = None,
+        model: nn.Module = None,
+        param_transforms: Callable = None
+    ) -> Tuple[TrainState, TrainState]:
+        rngs, subkey = jax.random.split(rngs)
+
+        if existing_state == None:
+            input_vars = self.get_input_ones()
+            params = model.init(subkey, **input_vars)
+            new_state = {"params": params, "ema_params": params}
+        else:
+            new_state = existing_state
+
+        if param_transforms is not None:
+            params = param_transforms(params)
+
+        state = TrainState.create(
+            apply_fn=model.apply,
+            params=new_state['params'],
+            ema_params=new_state['ema_params'],
+            tx=optimizer,
+            rngs=rngs,
+            metrics=Metrics.empty()
+        )
+
+        if existing_best_state is not None:
+            best_state = state.replace(
+                params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
+        else:
+            best_state = state
+
+        return state, best_state
+
+    def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
+        noise_schedule = self.noise_schedule
+        model = self.model
+        model_output_transform = self.model_output_transform
+        loss_fn = self.loss_fn
+        unconditional_prob = self.unconditional_prob
+
+        # Determine the number of unconditional samples
+        num_unconditional = int(batch_size * unconditional_prob)
+
+        nS, nC = null_labels_seq.shape
+        null_labels_seq = jnp.broadcast_to(
+            null_labels_seq, (batch_size, nS, nC))
+
+        distributed_training = self.distributed_training
+
+        def train_step(state: TrainState, batch):
+            """Train for a single step."""
+            images = batch['image']
+            # normalize image
+            images = (images - 127.5) / 127.5
+
+            output = text_embedder(
+                input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
+            # output = infer(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
+
+            label_seq = output.last_hidden_state
+
+            # Generate random probabilities to decide how much of this batch will be unconditional
+
+            label_seq = jnp.concat(
+                [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
+
+            noise_level, state = noise_schedule.generate_timesteps(
+                images.shape[0], state)
+            state, rngs = state.get_random_key()
+            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
+            rates = noise_schedule.get_rates(noise_level)
+            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
+                images, noise, rates)
+
+            def model_loss(params):
+                preds = model.apply(
+                    params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
+                preds = model_output_transform.pred_transform(
+                    noisy_images, preds, rates)
+                nloss = loss_fn(preds, expected_output)
+                # nloss = jnp.mean(nloss, axis=1)
+                nloss *= noise_schedule.get_weights(noise_level)
+                nloss = jnp.mean(nloss)
+                loss = nloss
+                return loss
+
+            loss, grads = jax.value_and_grad(model_loss)(state.params)
+            if distributed_training:
+                grads = jax.lax.pmean(grads, "device")
+            state = state.apply_gradients(grads=grads)
+            state = state.apply_ema(self.ema_decay)
+            return state, loss
+
+        if distributed_training:
+            train_step = jax.pmap(axis_name="device")(train_step)
+        else:
+            train_step = jax.jit(train_step)
+
+        return train_step
+
+    def _define_compute_metrics(self):
+        @jax.jit
+        def compute_metrics(state: TrainState, expected, pred):
+            loss = jnp.mean(jnp.square(pred - expected))
+            metric_updates = state.metrics.single_from_model_output(loss=loss)
+            metrics = state.metrics.merge(metric_updates)
+            state = state.replace(metrics=metrics)
+            return state
+        return compute_metrics
+
+    def fit(self, data, steps_per_epoch, epochs):
+        null_labels_full = data['null_labels_full']
+        local_batch_size = data['local_batch_size']
+        text_embedder = data['model']
+        super().fit(data, steps_per_epoch, epochs, {
+            "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
+
+
+                    pbar.set_postfix(loss=f'{loss:.4f}')
+                    pbar.update(100)
+            end_time = time.time()
+            self.state = state
+            total_time = end_time - start_time
+            avg_time_per_step = total_time / steps_per_epoch
+            avg_loss = epoch_loss / steps_per_epoch
+            if avg_loss < self.best_loss:
+                self.best_loss = avg_loss
+                self.best_state = state
+                self.save(epoch, best=True)
+            print(f"\n\tEpoch {epoch+1} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}")
+        return self.state
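`TrainState.apply_ema` maintains an exponential moving average of the parameters next to the live ones. A minimal standalone sketch of the same leaf-wise update on a toy pytree (0.999 mirrors the default decay):

```python
# Standalone sketch of the EMA update in TrainState.apply_ema; toy pytree.
import jax
import jax.numpy as jnp

def apply_ema(ema_params, params, decay=0.999):
    # ema <- decay * ema + (1 - decay) * params, applied leaf by leaf
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1 - decay) * p, ema_params, params)

params = {"w": jnp.ones(3), "b": jnp.zeros(3)}
ema = {"w": jnp.zeros(3), "b": jnp.zeros(3)}
ema = apply_ema(ema, params)
print(ema["w"])  # [0.001 0.001 0.001]
```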
@@ -0,0 +1,323 @@
+import flax
+import orbax.checkpoint
+import tqdm
+from flax import linen as nn
+import jax
+from typing import Callable, Dict, Any, Tuple
+from dataclasses import field
+import jax.numpy as jnp
+from clu import metrics
+from flax.training import train_state  # Useful dataclass to keep train state
+from flax.metrics import tensorboard
+import optax
+from flax import struct  # Flax dataclasses
+import wandb
+import time
+import os
+import orbax
+from flax.training import orbax_utils
+
+@struct.dataclass
+class Metrics(metrics.Collection):
+    accuracy: metrics.Accuracy
+    loss: metrics.Average.from_output('loss')
+
+# Define the TrainState
+class SimpleTrainState(train_state.TrainState):
+    rngs: jax.random.PRNGKey
+    metrics: Metrics
+
+    def get_random_key(self):
+        rngs, subkey = jax.random.split(self.rngs)
+        return self.replace(rngs=rngs), subkey
+
+class SimpleTrainer:
+    state: SimpleTrainState
+    best_state: SimpleTrainState
+    best_loss: float
+    model: nn.Module
+    ema_decay: float = 0.999
+
+    def __init__(self,
+                 model: nn.Module,
+                 input_shapes: Dict[str, Tuple[int]],
+                 optimizer: optax.GradientTransformation,
+                 rngs: jax.random.PRNGKey,
+                 train_state: SimpleTrainState = None,
+                 name: str = "Simple",
+                 load_from_checkpoint: bool = False,
+                 checkpoint_suffix: str = "",
+                 loss_fn=optax.l2_loss,
+                 param_transforms: Callable = None,
+                 wandb_config: Dict[str, Any] = None,
+                 distributed_training: bool = None,
+                 ):
+        if distributed_training is None or distributed_training is True:
+            # Auto-detect if we are running on multiple devices
+            distributed_training = jax.device_count() > 1
+
+        self.distributed_training = distributed_training
+        self.model = model
+        self.name = name
+        self.loss_fn = loss_fn
+        self.input_shapes = input_shapes
+
+        if wandb_config is not None:
+            run = wandb.init(**wandb_config)
+            self.wandb = run
+        else:
+            self.wandb = None
+
+        checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+        options = orbax.checkpoint.CheckpointManagerOptions(
+            max_to_keep=4, create=True)
+        self.checkpointer = orbax.checkpoint.CheckpointManager(
+            self.checkpoint_path() + checkpoint_suffix, checkpointer, options)
+
+        if load_from_checkpoint:
+            latest_epoch, old_state, old_best_state = self.load()
+        else:
+            latest_epoch, old_state, old_best_state = 0, None, None
+
+        self.latest_epoch = latest_epoch
+
+        if train_state == None:
+            self.init_state(optimizer, rngs, existing_state=old_state,
+                            existing_best_state=old_best_state, model=model, param_transforms=param_transforms)
+        else:
+            self.state = train_state
+            self.best_state = train_state
+            self.best_loss = 1e9
+
+    def get_input_ones(self):
+        return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}
+
+    def __init_fn(
+        self,
+        optimizer: optax.GradientTransformation,
+        rngs: jax.random.PRNGKey,
+        existing_state: dict = None,
+        existing_best_state: dict = None,
+        model: nn.Module = None,
+        param_transforms: Callable = None
+    ) -> Tuple[SimpleTrainState, SimpleTrainState]:
+        rngs, subkey = jax.random.split(rngs)
+
+        if existing_state == None:
+            input_vars = self.get_input_ones()
+            params = model.init(subkey, **input_vars)
+
+        state = SimpleTrainState.create(
+            apply_fn=model.apply,
+            params=params,
+            tx=optimizer,
+            rngs=rngs,
+            metrics=Metrics.empty()
+        )
+        if existing_best_state is not None:
+            best_state = state.replace(
+                params=existing_best_state['params'])
+        else:
+            best_state = state
+
+        return state, best_state
+
+    def init_state(
+        self,
+        optimizer: optax.GradientTransformation,
+        rngs: jax.random.PRNGKey,
+        existing_state: dict = None,
+        existing_best_state: dict = None,
+        model: nn.Module = None,
+        param_transforms: Callable = None
+    ):
+
+        state, best_state = self.__init_fn(
+            optimizer, rngs, existing_state, existing_best_state, model, param_transforms
+        )
+        self.best_loss = 1e9
+
+        if self.distributed_training:
+            devices = jax.local_devices()
+            if len(devices) > 1:
+                print("Replicating state across devices ", devices)
+                state = flax.jax_utils.replicate(state, devices)
+                best_state = flax.jax_utils.replicate(best_state, devices)
+            else:
+                print("Not replicating any state, only a single device is connected to the process")
+
+        self.state = state
+        self.best_state = best_state
+
+    def get_state(self):
+        return flax.jax_utils.unreplicate(self.state)
+
+    def get_best_state(self):
+        return flax.jax_utils.unreplicate(self.best_state)
+
+    def checkpoint_path(self):
+        experiment_name = self.name
+        path = os.path.join(os.path.abspath('./checkpoints'), experiment_name)
+        if not os.path.exists(path):
+            os.makedirs(path)
+        return path
+
+    def tensorboard_path(self):
+        experiment_name = self.name
+        path = os.path.join(os.path.abspath('./tensorboard'), experiment_name)
+        if not os.path.exists(path):
+            os.makedirs(path)
+        return path
+
+    def load(self):
+        epoch = self.checkpointer.latest_step()
+        print("Loading model from checkpoint", epoch)
+        ckpt = self.checkpointer.restore(epoch)
+        state = ckpt['state']
+        best_state = ckpt['best_state']
+        # Convert the state to a TrainState
+        self.best_loss = ckpt['best_loss']
+        print(
+            f"Loaded model from checkpoint at epoch {epoch}", ckpt['best_loss'])
+        return epoch, state, best_state
+
+    def save(self, epoch=0):
+        print(f"Saving model at epoch {epoch}")
+        ckpt = {
+            # 'model': self.model,
+            'state': self.get_state(),
+            'best_state': self.get_best_state(),
+            'best_loss': self.best_loss
+        }
+        try:
+            save_args = orbax_utils.save_args_from_target(ckpt)
+            self.checkpointer.save(epoch, ckpt, save_kwargs={
+                'save_args': save_args}, force=True)
+            pass
+        except Exception as e:
+            print("Error saving checkpoint", e)
+
+    def _define_train_step(self, **kwargs):
+        model = self.model
+        loss_fn = self.loss_fn
+        distributed_training = self.distributed_training
+
+        def train_step(state: SimpleTrainState, batch):
+            """Train for a single step."""
+            images = batch['image']
+            labels = batch['label']
+
+            def model_loss(params):
+                preds = model.apply(params, images)
+                expected_output = labels
+                nloss = loss_fn(preds, expected_output)
+                loss = jnp.mean(nloss)
+                return loss
+            loss, grads = jax.value_and_grad(model_loss)(state.params)
+            if distributed_training:
+                grads = jax.lax.pmean(grads, "device")
+            state = state.apply_gradients(grads=grads)
+            return state, loss
+
+        if distributed_training:
+            train_step = jax.pmap(axis_name="device")(train_step)
+        else:
+            train_step = jax.jit(train_step)
+
+        return train_step
+
+    def _define_compute_metrics(self):
+        model = self.model
+        loss_fn = self.loss_fn
+
+        @jax.jit
+        def compute_metrics(state: SimpleTrainState, batch):
+            preds = model.apply(state.params, batch['image'])
+            expected_output = batch['label']
+            loss = jnp.mean(loss_fn(preds, expected_output))
+            metric_updates = state.metrics.single_from_model_output(
+                loss=loss, logits=preds, labels=expected_output)
+            metrics = state.metrics.merge(metric_updates)
+            state = state.replace(metrics=metrics)
+            return state
+        return compute_metrics
+
+    def summary(self):
+        input_vars = self.get_input_ones()
+        print(self.model.tabulate(jax.random.key(0), **input_vars,
+                                  console_kwargs={"width": 200, "force_jupyter": True, }))
+
+    def config(self):
+        return {
+            "model": self.model,
+            "state": self.state,
+            "name": self.name,
+            "input_shapes": self.input_shapes
+        }
+
+    def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
+        summary_writer = tensorboard.SummaryWriter(self.tensorboard_path())
+        summary_writer.hparams({
+            **self.config(),
+            "steps_per_epoch": steps_per_epoch,
+            "epochs": epochs,
+            "batch_size": batch_size
+        })
+        return summary_writer
+
+    def fit(self, data, steps_per_epoch, epochs, train_step_args={}):
+        train_ds = iter(data['train']())
+        if 'test' in data:
+            test_ds = data['test']
+        else:
+            test_ds = None
+        train_step = self._define_train_step(**train_step_args)
+        compute_metrics = self._define_compute_metrics()
+        state = self.state
+        device_count = jax.local_device_count()
+        # train_ds = flax.jax_utils.prefetch_to_device(train_ds, jax.devices())
+
+        summary_writer = self.init_tensorboard(
+            data['global_batch_size'], steps_per_epoch, epochs)
+
+        while self.latest_epoch <= epochs:
+            self.latest_epoch += 1
+            current_epoch = self.latest_epoch
+            print(f"\nEpoch {current_epoch}/{epochs}")
+            start_time = time.time()
+            epoch_loss = 0
+
+            with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
+                for i in range(steps_per_epoch):
+                    batch = next(train_ds)
+                    if self.distributed_training and device_count > 1:
+                        batch = jax.tree.map(lambda x: x.reshape(
+                            (device_count, -1, *x.shape[1:])), batch)
+
+                    state, loss = train_step(state, batch)
+                    loss = jnp.mean(loss)
+
+                    epoch_loss += loss
+                    if i % 100 == 0:
+                        pbar.set_postfix(loss=f'{loss:.4f}')
+                        pbar.update(100)
+                        current_step = current_epoch*steps_per_epoch + i
+                        summary_writer.scalar(
+                            'Train Loss', loss, step=current_step)
+                        if self.wandb is not None:
+                            self.wandb.log({"train/loss": loss})
+
+            print(f"\n\tEpoch done")
+            end_time = time.time()
+            self.state = state
+            total_time = end_time - start_time
+            avg_time_per_step = total_time / steps_per_epoch
+            avg_loss = epoch_loss / steps_per_epoch
+            if avg_loss < self.best_loss:
+                self.best_loss = avg_loss
+                self.best_state = state
+                self.save(current_epoch)
+
+            # Compute Metrics
+            metrics_str = ''
+
+            print(
+                f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss} {metrics_str}")
+
+        self.save(epochs)
+        return self.state
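When `distributed_training` is enabled, `fit` adds a leading device axis to each batch before calling the `pmap`-ed train step, and gradients are averaged with `jax.lax.pmean` over `axis_name="device"`. A standalone sketch of just the batch split (hypothetical batch dict):

```python
# Sketch of the per-device batch split used in fit(); hypothetical batch.
import jax
import jax.numpy as jnp

device_count = jax.local_device_count()
global_batch = 8 * device_count
batch = {
    "image": jnp.ones((global_batch, 64, 64, 3)),
    "label": jnp.ones((global_batch, 10)),
}
# [global_batch, ...] -> [device_count, per_device_batch, ...]
sharded = jax.tree.map(
    lambda x: x.reshape((device_count, -1, *x.shape[1:])), batch)
print(sharded["image"].shape)  # (device_count, 8, 64, 64, 3)
```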
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: flaxdiff
-Version: 0.1.1
+Version: 0.1.3
 Summary: A versatile and easy to understand Diffusion library
 Author: Ashish Kumar Singh
 Author-email: ashishkmr472@gmail.com
@@ -12,6 +12,7 @@ flaxdiff/models/attention.py
 flaxdiff/models/common.py
 flaxdiff/models/favor_fastattn.py
 flaxdiff/models/simple_unet.py
+flaxdiff/models/simple_vit.py
 flaxdiff/predictors/__init__.py
 flaxdiff/samplers/__init__.py
 flaxdiff/samplers/common.py
@@ -30,4 +31,5 @@ flaxdiff/schedulers/exp.py
 flaxdiff/schedulers/karras.py
 flaxdiff/schedulers/linear.py
 flaxdiff/schedulers/sqrt.py
-flaxdiff/trainer/__init__.py
+flaxdiff/trainer/__init__.py
+flaxdiff/trainer/simple_trainer.py
@@ -11,7 +11,7 @@ required_packages=[
 setup(
     name='flaxdiff',
     packages=find_packages(),
-    version='0.1.1',
+    version='0.1.3',
     description='A versatile and easy to understand Diffusion library',
     long_description=open('README.md').read(),
     long_description_content_type='text/markdown',
@@ -1,216 +0,0 @@
-import orbax.checkpoint
-import tqdm
-from flax import linen as nn
-import jax
-from typing import Callable
-from dataclasses import field
-import jax.numpy as jnp
-from clu import metrics
-from flax.training import train_state  # Useful dataclass to keep train state
-import optax
-from flax import struct  # Flax dataclasses
-import time
-import os
-import orbax
-from flax.training import orbax_utils
-
-from ..schedulers import NoiseScheduler
-from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
-
-@struct.dataclass
-class Metrics(metrics.Collection):
-    loss: metrics.Average.from_output('loss')  # type: ignore
-
-class ModelState():
-    model: nn.Module
-    params: dict
-    noise_schedule: NoiseScheduler
-    model_output_transform: DiffusionPredictionTransform
-
-# Define the TrainState with EMA parameters
-class TrainState(train_state.TrainState):
-    rngs: jax.random.PRNGKey
-    ema_params: dict
-
-    def get_random_key(self):
-        rngs, subkey = jax.random.split(self.rngs)
-        return self.replace(rngs=rngs), subkey
-
-    def apply_ema(self, decay: float = 0.999):
-        new_ema_params = jax.tree_util.tree_map(
-            lambda ema, param: decay * ema + (1 - decay) * param,
-            self.ema_params,
-            self.params,
-        )
-        return self.replace(ema_params=new_ema_params)
-
-class DiffusionTrainer:
-    state: TrainState
-    best_state: TrainState
-    best_loss: float
-    model: nn.Module
-    noise_schedule: NoiseScheduler
-    model_output_transform: DiffusionPredictionTransform
-    ema_decay: float = 0.999
-
-    def __init__(self,
-                 model: nn.Module,
-                 optimizer: optax.GradientTransformation,
-                 noise_schedule: NoiseScheduler,
-                 rngs: jax.random.PRNGKey,
-                 train_state: TrainState = None,
-                 name: str = "Diffusion",
-                 load_from_checkpoint: bool = False,
-                 param_transforms: Callable = None,
-                 model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
-                 loss_fn=optax.l2_loss,
-                 ):
-        self.model = model
-        self.noise_schedule = noise_schedule
-        self.name = name
-        self.model_output_transform = model_output_transform
-        self.loss_fn = loss_fn
-
-        checkpointer = orbax.checkpoint.PyTreeCheckpointer()
-        options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True)
-        self.checkpointer = orbax.checkpoint.CheckpointManager(self.checkpoint_path(), checkpointer, options)
-
-        if load_from_checkpoint:
-            params = self.load()
-        else:
-            params = None
-
-        if train_state == None:
-            self.init_state(optimizer, rngs, params=params, model=model, param_transforms=param_transforms)
-        else:
-            self.state = train_state
-            self.best_state = train_state
-            self.best_loss = 1e9
-
-    def init_state(self,
-                   optimizer: optax.GradientTransformation,
-                   rngs: jax.random.PRNGKey,
-                   params: dict = None,
-                   model: nn.Module = None,
-                   param_transforms: Callable = None,
-                   batch_size=16,
-                   image_size=64
-                   ):
-        inp = jnp.ones((batch_size, image_size, image_size, 3))
-        temb = jnp.ones((batch_size,))
-        rngs, subkey = jax.random.split(rngs)
-        if params == None:
-            params = model.init(subkey, inp, temb)
-        if param_transforms is not None:
-            params = param_transforms(params)
-        self.best_loss = 1e9
-        self.state = TrainState.create(
-            apply_fn=model.apply,
-            params=params,
-            ema_params=params,
-            tx=optimizer,
-            rngs=rngs,
-        )
-        self.best_state = self.state
-
-    def checkpoint_path(self):
-        experiment_name = self.name
-        path = os.path.join(os.path.abspath('./models'), experiment_name)
-        if not os.path.exists(path):
-            os.makedirs(path)
-        return path
-
-    def load(self):
-        step = self.checkpointer.latest_step()
-        print("Loading model from checkpoint", step)
-        ckpt = self.checkpointer.restore(step)
-        state = ckpt['state']
-        # Convert the state to a TrainState
-        self.best_loss = ckpt['best_loss']
-        print(f"Loaded model from checkpoint at step {step}", ckpt['best_loss'])
-        return state.get('params', None)  # , ckpt.get('model', None)
-
-    def save(self, epoch=0, best=False):
-        print(f"Saving model at epoch {epoch}")
-        state = self.best_state if best else self.state
-        # filename = os.path.join(self.checkpoint_path(), f'model_{epoch}' if not best else 'best_model')
-        ckpt = {
-            'model': self.model,
-            'state': state,
-            'best_loss': self.best_loss
-        }
-        save_args = orbax_utils.save_args_from_target(ckpt)
-        self.checkpointer.save(epoch, ckpt, save_kwargs={'save_args': save_args})
-
-    def summary(self, image_size=64):
-        inp = jnp.ones((1, image_size, image_size, 3))
-        temb = jnp.ones((1,))
-        print(self.model.tabulate(jax.random.key(0), inp, temb, console_kwargs={"width": 200, "force_jupyter": True, }))
-
-    def _define_train_step(self):
-        noise_schedule = self.noise_schedule
-        model = self.model
-        model_output_transform = self.model_output_transform
-        loss_fn = self.loss_fn
-
-        @jax.jit
-        def train_step(state: TrainState, batch):
-            """Train for a single step."""
-            images = batch
-            noise_level, state = noise_schedule.generate_timesteps(images.shape[0], state)
-            state, rngs = state.get_random_key()
-            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
-            rates = noise_schedule.get_rates(noise_level)
-            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(images, noise, rates)
-
-            def model_loss(params):
-                preds = model.apply(params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level))
-                preds = model_output_transform.pred_transform(noisy_images, preds, rates)
-                nloss = loss_fn(preds, expected_output)
-                # nloss = jnp.mean(nloss, axis=1)
-                nloss *= noise_schedule.get_weights(noise_level)
-                nloss = jnp.mean(nloss)
-                loss = nloss
-                return loss
-            loss, grads = jax.value_and_grad(model_loss)(state.params)
-            state = state.apply_gradients(grads=grads)
-            state = state.apply_ema(self.ema_decay)
-            return state, loss
-        return train_step
-
-    def _define_compute_metrics(self):
-        @jax.jit
-        def compute_metrics(state: TrainState, expected, pred):
-            loss = jnp.mean(jnp.square(pred - expected))
-            metric_updates = state.metrics.single_from_model_output(loss=loss)
-            metrics = state.metrics.merge(metric_updates)
-            state = state.replace(metrics=metrics)
-            return state
-        return compute_metrics
-
-    def fit(self, data, steps_per_epoch, epochs):
-        data = iter(data)
-        train_step = self._define_train_step()
-        compute_metrics = self._define_compute_metrics()
-        state = self.state
-        for epoch in range(epochs):
-            print(f"\nEpoch {epoch+1}/{epochs}")
-            start_time = time.time()
-            epoch_loss = 0
-            with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {epoch+1}', ncols=100, unit='step') as pbar:
-                for i in range(steps_per_epoch):
-                    batch = next(data)
-                    state, loss = train_step(state, batch)
-                    epoch_loss += loss
-                    if i % 100 == 0:
-                        pbar.set_postfix(loss=f'{loss:.4f}')
-                        pbar.update(100)
-            end_time = time.time()
-            self.state = state
-            total_time = end_time - start_time
-            avg_time_per_step = total_time / steps_per_epoch
-            avg_loss = epoch_loss / steps_per_epoch
-            if avg_loss < self.best_loss:
-                self.best_loss = avg_loss
-                self.best_state = state
-                self.save(epoch, best=True)
-            print(f"\n\tEpoch {epoch+1} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss}")
-        return self.state