flaxdiff 0.1.1__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/models/attention.py +57 -115
- flaxdiff/models/common.py +2 -2
- flaxdiff/models/simple_unet.py +7 -15
- flaxdiff/models/simple_vit.py +123 -0
- flaxdiff/trainer/__init__.py +113 -128
- flaxdiff/trainer/simple_trainer.py +323 -0
- {flaxdiff-0.1.1.dist-info → flaxdiff-0.1.4.dist-info}/METADATA +1 -1
- {flaxdiff-0.1.1.dist-info → flaxdiff-0.1.4.dist-info}/RECORD +10 -8
- {flaxdiff-0.1.1.dist-info → flaxdiff-0.1.4.dist-info}/WHEEL +0 -0
- {flaxdiff-0.1.1.dist-info → flaxdiff-0.1.4.dist-info}/top_level.txt +0 -0
flaxdiff/models/attention.py
CHANGED
@@ -11,105 +11,6 @@ import functools
 import math
 from .common import kernel_init
 
-def _query_chunk_attention(query, key, value, precision, key_chunk_size: int = 4096):
-    """Multi-head dot product attention with a limited number of queries."""
-    num_kv, num_heads, k_features = key.shape[-3:]
-    v_features = value.shape[-1]
-    key_chunk_size = min(key_chunk_size, num_kv)
-    query = query / jnp.sqrt(k_features)
-
-    @functools.partial(jax.checkpoint, prevent_cse=False)
-    def summarize_chunk(query, key, value):
-        attn_weights = jnp.einsum("...qhd,...khd->...qhk", query, key, precision=precision)
-
-        max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
-        max_score = jax.lax.stop_gradient(max_score)
-        exp_weights = jnp.exp(attn_weights - max_score)
-
-        exp_values = jnp.einsum("...vhf,...qhv->...qhf", value, exp_weights, precision=precision)
-        max_score = jnp.einsum("...qhk->...qh", max_score)
-
-        return (exp_values, exp_weights.sum(axis=-1), max_score)
-
-    def chunk_scanner(chunk_idx):
-        # julienne key array
-        key_chunk = jax.lax.dynamic_slice(
-            operand=key,
-            start_indices=[0] * (key.ndim - 3) + [chunk_idx, 0, 0],  # [...,k,h,d]
-            slice_sizes=list(key.shape[:-3]) + [key_chunk_size, num_heads, k_features],  # [...,k,h,d]
-        )
-
-        # julienne value array
-        value_chunk = jax.lax.dynamic_slice(
-            operand=value,
-            start_indices=[0] * (value.ndim - 3) + [chunk_idx, 0, 0],  # [...,v,h,d]
-            slice_sizes=list(value.shape[:-3]) + [key_chunk_size, num_heads, v_features],  # [...,v,h,d]
-        )
-
-        return summarize_chunk(query, key_chunk, value_chunk)
-
-    chunk_values, chunk_weights, chunk_max = jax.lax.map(f=chunk_scanner, xs=jnp.arange(0, num_kv, key_chunk_size))
-
-    global_max = jnp.max(chunk_max, axis=0, keepdims=True)
-    max_diffs = jnp.exp(chunk_max - global_max)
-
-    chunk_values *= jnp.expand_dims(max_diffs, axis=-1)
-    chunk_weights *= max_diffs
-
-    all_values = chunk_values.sum(axis=0)
-    all_weights = jnp.expand_dims(chunk_weights, -1).sum(axis=0)
-
-    return all_values / all_weights
-
-
-def jax_memory_efficient_attention(
-    query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096
-):
-    r"""
-    Flax Memory-efficient multi-head dot product attention. https://arxiv.org/abs/2112.05682v2
-    https://github.com/AminRezaei0x443/memory-efficient-attention
-
-    Args:
-        query (`jnp.ndarray`): (batch..., query_length, head, query_key_depth_per_head)
-        key (`jnp.ndarray`): (batch..., key_value_length, head, query_key_depth_per_head)
-        value (`jnp.ndarray`): (batch..., key_value_length, head, value_depth_per_head)
-        precision (`jax.lax.Precision`, *optional*, defaults to `jax.lax.Precision.HIGHEST`):
-            numerical precision for computation
-        query_chunk_size (`int`, *optional*, defaults to 1024):
-            chunk size to divide query array value must divide query_length equally without remainder
-        key_chunk_size (`int`, *optional*, defaults to 4096):
-            chunk size to divide key and value array value must divide key_value_length equally without remainder
-
-    Returns:
-        (`jnp.ndarray`) with shape of (batch..., query_length, head, value_depth_per_head)
-    """
-    num_q, num_heads, q_features = query.shape[-3:]
-
-    def chunk_scanner(chunk_idx, _):
-        # julienne query array
-        query_chunk = jax.lax.dynamic_slice(
-            operand=query,
-            start_indices=([0] * (query.ndim - 3)) + [chunk_idx, 0, 0],  # [...,q,h,d]
-            slice_sizes=list(query.shape[:-3]) + [min(query_chunk_size, num_q), num_heads, q_features],  # [...,q,h,d]
-        )
-
-        return (
-            chunk_idx + query_chunk_size,  # unused ignore it
-            _query_chunk_attention(
-                query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size
-            ),
-        )
-
-    _, res = jax.lax.scan(
-        f=chunk_scanner,
-        init=0,
-        xs=None,
-        length=math.ceil(num_q / query_chunk_size),  # start counter  # stop counter
-    )
-
-    return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
-
-
 class EfficientAttention(nn.Module):
     """
     Based on the pallas attention implementation.
@@ -125,41 +26,77 @@ class EfficientAttention(nn.Module):
     def setup(self):
         inner_dim = self.dim_head * self.heads
         # Weights were exported with old names {to_q, to_k, to_v, to_out}
-
-
-
-
-
-
+        dense = functools.partial(
+            nn.Dense,
+            self.heads * self.dim_head,
+            precision=self.precision,
+            use_bias=self.use_bias,
+            kernel_init=self.kernel_init(),
+            dtype=self.dtype
+        )
+        self.query = dense(name="to_q")
+        self.key = dense(name="to_k")
+        self.value = dense(name="to_v")
+
         self.proj_attn = nn.DenseGeneral(self.query_dim, use_bias=False, precision=self.precision,
                                          kernel_init=self.kernel_init(), dtype=self.dtype, name="to_out_0")
         # self.attnfn = make_fast_generalized_attention(qkv_dim=inner_dim, lax_scan_unroll=16)
+
+    def _reshape_tensor_to_head_dim(self, tensor):
+        batch_size, _, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
+        return tensor
+
+    def _reshape_tensor_from_head_dim(self, tensor):
+        batch_size, _, seq_len, dim = tensor.shape
+        head_size = self.heads
+        tensor = jnp.transpose(tensor, (0, 2, 1, 3))
+        tensor = tensor.reshape(batch_size, 1, seq_len, dim * head_size)
+        return tensor
 
     @nn.compact
     def __call__(self, x:jax.Array, context=None):
+        # print(x.shape)
         # x has shape [B, H * W, C]
         context = x if context is None else context
+
+        B, H, W, C = x.shape
+        x = x.reshape((B, 1, H * W, C))
+
+        if len(context.shape) == 4:
+            B, _H, _W, _C = context.shape
+            context = context.reshape((B, 1, _H * _W, _C))
+        else:
+            B, SEQ, _C = context.shape
+            context = context.reshape((B, 1, SEQ, _C))
+
         query = self.query(x)
         key = self.key(context)
         value = self.value(context)
 
-
+        query = self._reshape_tensor_to_head_dim(query)
+        key = self._reshape_tensor_to_head_dim(key)
+        value = self._reshape_tensor_to_head_dim(value)
 
-
-
-        # )
-
-        hidden_states = nn.dot_product_attention(
-            query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
+        hidden_states = jax.experimental.pallas.ops.tpu.flash_attention.flash_attention(
+            query, key, value, None
         )
-
-
+
+        hidden_states = self._reshape_tensor_from_head_dim(hidden_states)
+
+
+        # hidden_states = nn.dot_product_attention(
+        #     query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
         # )
 
         proj = self.proj_attn(hidden_states)
+
+        proj = proj.reshape((B, H, W, C))
+
         return proj
 
-
 class NormalAttention(nn.Module):
     """
     Simple implementation of the normal attention.
@@ -201,7 +138,11 @@ class NormalAttention(nn.Module):
     @nn.compact
     def __call__(self, x, context=None):
         # x has shape [B, H, W, C]
+        B, H, W, C = x.shape
+        x = x.reshape((B, H*W, C))
         context = x if context is None else context
+        if len(context.shape) == 4:
+            context = context.reshape((B, H*W, C))
         query = self.query(x)
         key = self.key(context)
         value = self.value(context)
@@ -210,6 +151,7 @@ class NormalAttention(nn.Module):
             query, key, value, dtype=self.dtype, broadcast_dropout=False, dropout_rng=None, precision=self.precision
         )
         proj = self.proj_attn(hidden_states)
+        proj = proj.reshape((B, H, W, C))
         return proj
 
 class AttentionBlock(nn.Module):
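Note: the reworked attention modules now accept feature maps in [B, H, W, C] layout and perform the sequence and head reshapes internally before calling the TPU flash-attention kernel. A minimal standalone sketch of that reshape round-trip in plain jax.numpy; the shapes and head count here are illustrative, not values from the package:

    import jax.numpy as jnp

    B, H, W, C, heads = 2, 16, 16, 64, 8
    x = jnp.ones((B, H, W, C))

    # __call__ first flattens the spatial grid into a sequence: [B, 1, H*W, C]
    seq = x.reshape((B, 1, H * W, C))

    # _reshape_tensor_to_head_dim: [B, 1, L, C] -> [B, heads, L, C // heads],
    # the layout handed to the flash_attention call in the diff
    t = seq.reshape(B, H * W, heads, C // heads)
    t = jnp.transpose(t, (0, 2, 1, 3))

    # _reshape_tensor_from_head_dim inverts it back to [B, 1, L, C]
    t = jnp.transpose(t, (0, 2, 1, 3))
    t = t.reshape(B, 1, H * W, C)

    # and the projected output is restored to the feature-map layout
    out = t.reshape((B, H, W, C))
    assert out.shape == x.shape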
flaxdiff/models/common.py
CHANGED
@@ -2,6 +2,6 @@ import jax.numpy as jnp
 from flax import linen as nn
 
 # Kernel initializer to use
-def kernel_init(scale):
+def kernel_init(scale, dtype=jnp.float32):
     scale = max(scale, 1e-10)
-    return nn.initializers.variance_scaling(scale=scale, mode="
+    return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)
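Note: kernel_init now takes a dtype and forwards it to the variance-scaling initializer, so mixed-precision models can request weights in their compute dtype. A usage sketch; the call-site values are illustrative, only the helper itself comes from the diff:

    import jax
    import jax.numpy as jnp
    from flax import linen as nn

    # the updated helper, as added in this diff
    def kernel_init(scale, dtype=jnp.float32):
        scale = max(scale, 1e-10)
        return nn.initializers.variance_scaling(scale=scale, mode="fan_avg", distribution="truncated_normal", dtype=dtype)

    init = kernel_init(1.0, dtype=jnp.bfloat16)
    w = init(jax.random.key(0), (3, 3, 64, 64))  # e.g. a 3x3 conv kernel
    assert w.dtype == jnp.bfloat16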
flaxdiff/models/simple_unet.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Dict, Callable, Sequence, Any, Union
 import einops
 from .common import kernel_init
 from .attention import TransformerBlock
+
 class WeightStandardizedConv(nn.Module):
     """
     apply weight standardization https://arxiv.org/abs/1903.10520
@@ -243,6 +244,7 @@ def l2norm(t, axis=1, eps=1e-12):
     denom = jnp.clip(jnp.linalg.norm(t, ord=2, axis=axis, keepdims=True), eps)
     out = t/denom
     return (out)
+
 class ResidualBlock(nn.Module):
     conv_type:str
     features:int
@@ -327,7 +329,7 @@ class Unet(nn.Module):
     precision: Any = jax.lax.Precision.HIGH
 
     @nn.compact
-    def __call__(self, x, temb, textcontext
+    def __call__(self, x, temb, textcontext):
         # print("embedding features", self.emb_features)
         temb = FourierEmbedding(features=self.emb_features)(temb)
         temb = TimeProjection(features=self.emb_features)(temb)
@@ -340,7 +342,7 @@ class Unet(nn.Module):
 
         conv_type = up_conv_type = down_conv_type = middle_conv_type = "conv"
         # middle_conv_type = "separable"
-
+
         x = ConvLayer(
             conv_type,
             features=self.feature_depths[0],
@@ -370,18 +372,13 @@ class Unet(nn.Module):
                     precision=self.precision
                 )(x, temb)
                 if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
-                    B, H, W, _ = x.shape
-                    if H > TS:
-                        padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC))
-                    else:
-                        padded_context = None
                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                                          dim_head=dim_in // attention_config['heads'],
                                          use_flash_attention=attention_config.get("flash_attention", True),
                                          use_projection=attention_config.get("use_projection", False),
                                          use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                          precision=attention_config.get("precision", self.precision),
-                                         name=f"down_{i}_attention_{j}")(x,
+                                         name=f"down_{i}_attention_{j}")(x, textcontext)
                 # print("down residual for feature level", i, "is of shape", x.shape, "features", dim_in)
                 downs.append(x)
             if i != len(feature_depths) - 1:
@@ -419,7 +416,7 @@ class Unet(nn.Module):
                                      use_projection=middle_attention.get("use_projection", False),
                                      use_self_and_cross=False,
                                      precision=attention_config.get("precision", self.precision),
-                                     name=f"middle_attention_{j}")(x)
+                                     name=f"middle_attention_{j}")(x, textcontext)
             x = ResidualBlock(
                 middle_conv_type,
                 name=f"middle_res2_{j}",
@@ -454,18 +451,13 @@ class Unet(nn.Module):
                     precision=self.precision
                 )(x, temb)
                 if attention_config is not None and j == self.num_res_blocks - 1:  # Apply attention only on the last block
-                    B, H, W, _ = x.shape
-                    if H > TS:
-                        padded_context = jnp.pad(textcontext, ((0, 0), (0, H - TS), (0, 0)), mode='constant', constant_values=0).reshape((B, 1, H, TC))
-                    else:
-                        padded_context = None
                     x = TransformerBlock(heads=attention_config['heads'], dtype=attention_config.get('dtype', jnp.float32),
                                          dim_head=dim_out // attention_config['heads'],
                                          use_flash_attention=attention_config.get("flash_attention", True),
                                          use_projection=attention_config.get("use_projection", False),
                                          use_self_and_cross=attention_config.get("use_self_and_cross", True),
                                          precision=attention_config.get("precision", self.precision),
-                                         name=f"up_{i}_attention_{j}")(x,
+                                         name=f"up_{i}_attention_{j}")(x, textcontext)
                 # print("Upscaling ", i, x.shape)
             if i != len(feature_depths) - 1:
                 x = Upsample(
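Note: the down, middle, and up attention blocks now receive textcontext directly instead of a locally padded copy. For reference, an illustrative attention_config dict covering the keys the Unet reads in this diff; the values are examples, not package defaults:

    import jax
    import jax.numpy as jnp

    attention_config = {
        "heads": 8,                           # attention_config['heads']
        "dtype": jnp.float32,                 # .get('dtype', jnp.float32)
        "flash_attention": True,              # .get("flash_attention", True)
        "use_projection": False,              # .get("use_projection", False)
        "use_self_and_cross": True,           # .get("use_self_and_cross", True)
        "precision": jax.lax.Precision.HIGH,  # .get("precision", self.precision)
    }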
flaxdiff/models/simple_vit.py
ADDED
@@ -0,0 +1,123 @@
+# simple_vit.py
+
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from typing import Callable, Any
+from .simply_unet import FourierEmbedding, TimeProjection, ConvLayer, kernel_init
+from .attention import TransformerBlock
+
+class PatchEmbedding(nn.Module):
+    patch_size: int
+    embedding_dim: int
+    dtype: Any = jnp.float32
+    precision: Any = jax.lax.Precision.HIGH
+
+    @nn.compact
+    def __call__(self, x):
+        batch, height, width, channels = x.shape
+        assert height % self.patch_size == 0 and width % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+
+        x = nn.Conv(features=self.embedding_dim,
+                    kernel_size=(self.patch_size, self.patch_size),
+                    strides=(self.patch_size, self.patch_size),
+                    dtype=self.dtype,
+                    precision=self.precision)(x)
+        x = jnp.reshape(x, (batch, -1, self.embedding_dim))
+        return x
+
+class PositionalEncoding(nn.Module):
+    max_len: int
+    embedding_dim: int
+
+    @nn.compact
+    def __call__(self, x):
+        pe = self.param('pos_encoding',
+                        jax.nn.initializers.zeros,
+                        (1, self.max_len, self.embedding_dim))
+        return x + pe[:, :x.shape[1], :]
+
+class TransformerEncoder(nn.Module):
+    num_layers: int
+    num_heads: int
+    mlp_dim: int
+    dropout_rate: float = 0.1
+    dtype: Any = jnp.float32
+    precision: Any = jax.lax.Precision.HIGH
+
+    @nn.compact
+    def __call__(self, x, training=True):
+        for _ in range(self.num_layers):
+            x = TransformerBlock(
+                heads=self.num_heads,
+                dim_head=x.shape[-1] // self.num_heads,
+                mlp_dim=self.mlp_dim,
+                dropout_rate=self.dropout_rate,
+                dtype=self.dtype,
+                precision=self.precision
+            )(x)
+        return x
+
+class VisionTransformer(nn.Module):
+    patch_size: int = 16
+    embedding_dim: int = 768
+    num_layers: int = 12
+    num_heads: int = 12
+    mlp_dim: int = 3072
+    emb_features: int = 256
+    dropout_rate: float = 0.1
+    dtype: Any = jnp.float32
+    precision: Any = jax.lax.Precision.HIGH
+
+    @nn.compact
+    def __call__(self, x, temb, textcontext=None):
+        # Time embedding
+        temb = FourierEmbedding(features=self.emb_features)(temb)
+        temb = TimeProjection(features=self.emb_features)(temb)
+
+        # Patch embedding
+        x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.embedding_dim,
+                           dtype=self.dtype, precision=self.precision)(x)
+
+        # Add positional encoding
+        x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.embedding_dim)(x)
+
+        # Add time embedding
+        temb = jnp.expand_dims(temb, axis=1)
+        x = jnp.concatenate([x, temb], axis=1)
+
+        # Add text context
+        if textcontext is not None:
+            x = jnp.concatenate([x, textcontext], axis=1)
+
+        # Transformer encoder
+        x = TransformerEncoder(
+            num_layers=self.num_layers,
+            num_heads=self.num_heads,
+            mlp_dim=self.mlp_dim,
+            dropout_rate=self.dropout_rate,
+            dtype=self.dtype,
+            precision=self.precision
+        )(x)
+
+        # Extract the image tokens (exclude time and text embeddings)
+        num_patches = (x.shape[1] - 1 - (0 if textcontext is None else textcontext.shape[1]))
+        x = x[:, :num_patches, :]
+
+        # Reshape to image dimensions
+        batch, _, _ = x.shape
+        height = width = int((num_patches) ** 0.5)
+        x = jnp.reshape(x, (batch, height, width, self.embedding_dim))
+
+        # Final convolution to get the desired output channels
+        x = ConvLayer(
+            conv_type="conv",
+            features=3,
+            kernel_size=(3, 3),
+            strides=(1, 1),
+            kernel_init=kernel_init(0.0),
+            dtype=self.dtype,
+            precision=self.precision
+        )(x)
+
+        return x
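Note: VisionTransformer.__call__ appends one time token and the text tokens to the patch sequence, runs the encoder, then strips the extra tokens before reshaping back to a grid. A pure-jax.numpy sketch of that token bookkeeping; all sizes are illustrative:

    import jax.numpy as jnp

    batch, image_size, patch_size, embed_dim, text_len = 2, 64, 16, 768, 77
    num_patches = (image_size // patch_size) ** 2     # 16 patch tokens

    patches = jnp.ones((batch, num_patches, embed_dim))
    time_token = jnp.ones((batch, 1, embed_dim))      # expand_dims(temb, axis=1)
    text_tokens = jnp.ones((batch, text_len, embed_dim))

    tokens = jnp.concatenate([patches, time_token, text_tokens], axis=1)

    # after the encoder, only the patch tokens are kept
    kept = tokens.shape[1] - 1 - text_len
    image_tokens = tokens[:, :kept, :]
    side = int(kept ** 0.5)                           # 4
    grid = jnp.reshape(image_tokens, (batch, side, side, embed_dim))
    assert grid.shape == (batch, 4, 4, embed_dim)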
flaxdiff/trainer/__init__.py
CHANGED
@@ -17,18 +17,9 @@ from flax.training import orbax_utils
 from ..schedulers import NoiseScheduler
 from ..predictors import DiffusionPredictionTransform, EpsilonPredictionTransform
 
-
-class Metrics(metrics.Collection):
-    loss: metrics.Average.from_output('loss')  # type: ignore
+from .simple_trainer import SimpleTrainer, SimpleTrainState
 
-class
-    model: nn.Module
-    params: dict
-    noise_schedule: NoiseScheduler
-    model_output_transform: DiffusionPredictionTransform
-
-# Define the TrainState with EMA parameters
-class TrainState(train_state.TrainState):
+class TrainState(SimpleTrainState):
     rngs: jax.random.PRNGKey
     ema_params: dict
 
@@ -36,7 +27,7 @@ class TrainState(train_state.TrainState):
         rngs, subkey = jax.random.split(self.rngs)
         return self.replace(rngs=rngs), subkey
 
-    def apply_ema(self, decay: float=0.999):
+    def apply_ema(self, decay: float = 0.999):
         new_ema_params = jax.tree_util.tree_map(
             lambda ema, param: decay * ema + (1 - decay) * param,
             self.ema_params,
@@ -44,141 +35,142 @@ class TrainState(train_state.TrainState):
         )
         return self.replace(ema_params=new_ema_params)
 
-class DiffusionTrainer:
-
-
-
-
-
-
-
-
-    def __init__(self,
-                 model:nn.Module,
+class DiffusionTrainer(SimpleTrainer):
+    noise_schedule: NoiseScheduler
+    model_output_transform: DiffusionPredictionTransform
+    ema_decay: float = 0.999
+
+    def __init__(self,
+                 model: nn.Module,
+                 input_shapes: Dict[str, Tuple[int]],
                  optimizer: optax.GradientTransformation,
-                 noise_schedule:NoiseScheduler,
-                 rngs:jax.random.PRNGKey,
-
-                 name:str="Diffusion",
-
-
-                 model_output_transform:DiffusionPredictionTransform=EpsilonPredictionTransform(),
-                 loss_fn=optax.l2_loss,
+                 noise_schedule: NoiseScheduler,
+                 rngs: jax.random.PRNGKey,
+                 unconditional_prob: float = 0.2,
+                 name: str = "Diffusion",
+                 model_output_transform: DiffusionPredictionTransform = EpsilonPredictionTransform(),
+                 **kwargs
                  ):
-
+        super().__init__(
+            model=model,
+            input_shapes=input_shapes,
+            optimizer=optimizer,
+            rngs=rngs,
+            name=name,
+            **kwargs
+        )
         self.noise_schedule = noise_schedule
-        self.name = name
         self.model_output_transform = model_output_transform
-        self.
-
-
-
+        self.unconditional_prob = unconditional_prob
+
+    def __init_fn(
+        self,
+        optimizer: optax.GradientTransformation,
+        rngs: jax.random.PRNGKey,
+        existing_state: dict = None,
+        existing_best_state: dict = None,
+        model: nn.Module = None,
+        param_transforms: Callable = None
+    ) -> Tuple[TrainState, TrainState]:
+        rngs, subkey = jax.random.split(rngs)
 
-        if
-
+        if existing_state == None:
+            input_vars = self.get_input_ones()
+            params = model.init(subkey, **input_vars)
+            new_state = {"params": params, "ema_params": params}
         else:
-
+            new_state = existing_state
 
-        if train_state == None:
-            self.init_state(optimizer, rngs, params=params, model=model, param_transforms=param_transforms)
-        else:
-            self.state = train_state
-            self.best_state = train_state
-            self.best_loss = 1e9
-
-    def init_state(self,
-                   optimizer: optax.GradientTransformation,
-                   rngs:jax.random.PRNGKey,
-                   params:dict=None,
-                   model:nn.Module=None,
-                   param_transforms:Callable=None,
-                   batch_size=16,
-                   image_size=64
-                   ):
-        inp = jnp.ones((batch_size, image_size, image_size, 3))
-        temb = jnp.ones((batch_size,))
-        rngs, subkey = jax.random.split(rngs)
-        if params == None:
-            params = model.init(subkey, inp, temb)
         if param_transforms is not None:
             params = param_transforms(params)
-
-
+
+        state = TrainState.create(
             apply_fn=model.apply,
-            params=params,
-            ema_params=
+            params=new_state['params'],
+            ema_params=new_state['ema_params'],
             tx=optimizer,
             rngs=rngs,
+            metrics=Metrics.empty()
         )
-
-
-
-
-
-
-
-        return
-
-    def
-        step = self.checkpointer.latest_step()
-        print("Loading model from checkpoint", step)
-        ckpt = self.checkpointer.restore(step)
-        state = ckpt['state']
-        # Convert the state to a TrainState
-        self.best_loss = ckpt['best_loss']
-        print(f"Loaded model from checkpoint at step {step}", ckpt['best_loss'])
-        return state.get('params', None)#, ckpt.get('model', None)
-
-    def save(self, epoch=0, best=False):
-        print(f"Saving model at epoch {epoch}")
-        state = self.best_state if best else self.state
-        # filename = os.path.join(self.checkpoint_path(), f'model_{epoch}' if not best else 'best_model')
-        ckpt = {
-            'model': self.model,
-            'state': state,
-            'best_loss': self.best_loss
-        }
-        save_args = orbax_utils.save_args_from_target(ckpt)
-        self.checkpointer.save(epoch, ckpt, save_kwargs={'save_args': save_args})
-
-    def summary(self, image_size=64):
-        inp = jnp.ones((1, image_size, image_size, 3))
-        temb = jnp.ones((1,))
-        print(self.model.tabulate(jax.random.key(0), inp, temb, console_kwargs={"width": 200, "force_jupyter":True, }))
-
-    def _define_train_step(self):
+
+        if existing_best_state is not None:
+            best_state = state.replace(
+                params=existing_best_state['params'], ema_params=existing_best_state['ema_params'])
+        else:
+            best_state = state
+
+        return state, best_state
+
+    def _define_train_step(self, batch_size, null_labels_seq, text_embedder):
         noise_schedule = self.noise_schedule
         model = self.model
         model_output_transform = self.model_output_transform
         loss_fn = self.loss_fn
-
-
+        unconditional_prob = self.unconditional_prob
+
+        # Determine the number of unconditional samples
+        num_unconditional = int(batch_size * unconditional_prob)
+
+        nS, nC = null_labels_seq.shape
+        null_labels_seq = jnp.broadcast_to(
+            null_labels_seq, (batch_size, nS, nC))
+
+        distributed_training = self.distributed_training
+
+        def train_step(state: TrainState, batch):
             """Train for a single step."""
-            images = batch
-
+            images = batch['image']
+            # normalize image
+            images = (images - 127.5) / 127.5
+
+            output = text_embedder(
+                input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
+            # output = infer(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'])
+
+            label_seq = output.last_hidden_state
+
+            # Generate random probabilities to decide how much of this batch will be unconditional
+
+            label_seq = jnp.concat(
+                [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
+
+            noise_level, state = noise_schedule.generate_timesteps(
+                images.shape[0], state)
             state, rngs = state.get_random_key()
-            noise:jax.Array = jax.random.normal(rngs, shape=images.shape)
+            noise: jax.Array = jax.random.normal(rngs, shape=images.shape)
             rates = noise_schedule.get_rates(noise_level)
-            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
+            noisy_images, c_in, expected_output = model_output_transform.forward_diffusion(
+                images, noise, rates)
+
             def model_loss(params):
-                preds = model.apply(
-
+                preds = model.apply(
+                    params, *noise_schedule.transform_inputs(noisy_images*c_in, noise_level), label_seq)
+                preds = model_output_transform.pred_transform(
+                    noisy_images, preds, rates)
                 nloss = loss_fn(preds, expected_output)
                 # nloss = jnp.mean(nloss, axis=1)
                 nloss *= noise_schedule.get_weights(noise_level)
                 nloss = jnp.mean(nloss)
                 loss = nloss
                 return loss
+
             loss, grads = jax.value_and_grad(model_loss)(state.params)
-
+            if distributed_training:
+                grads = jax.lax.pmean(grads, "device")
+            state = state.apply_gradients(grads=grads)
             state = state.apply_ema(self.ema_decay)
             return state, loss
+
+        if distributed_training:
+            train_step = jax.pmap(axis_name="device")(train_step)
+        else:
+            train_step = jax.jit(train_step)
+
         return train_step
-
+
     def _define_compute_metrics(self):
         @jax.jit
-        def compute_metrics(state:TrainState, expected, pred):
+        def compute_metrics(state: TrainState, expected, pred):
             loss = jnp.mean(jnp.square(pred - expected))
             metric_updates = state.metrics.single_from_model_output(loss=loss)
             metrics = state.metrics.merge(metric_updates)
@@ -187,20 +179,13 @@ class DiffusionTrainer:
         return compute_metrics
 
     def fit(self, data, steps_per_epoch, epochs):
-
-
-
-
-
-
-
-        epoch_loss = 0
-        with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {epoch+1}', ncols=100, unit='step') as pbar:
-            for i in range(steps_per_epoch):
-                batch = next(data)
-                state, loss = train_step(state, batch)
-                epoch_loss += loss
-                if i % 100 == 0:
+        null_labels_full = data['null_labels_full']
+        local_batch_size = data['local_batch_size']
+        text_embedder = data['model']
+        super().fit(data, steps_per_epoch, epochs, {
+            "batch_size": local_batch_size, "null_labels_seq": null_labels_full, "text_embedder": text_embedder})
+
+
                     pbar.set_postfix(loss=f'{loss:.4f}')
                     pbar.update(100)
         end_time = time.time()
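Note: the new unconditional_prob path replaces the first int(batch_size * unconditional_prob) text embeddings in each batch with a broadcast null embedding, the usual classifier-free-guidance-style label dropping. A standalone sketch of that mixing; the sizes are illustrative, and jnp.concatenate is used here in place of the jnp.concat alias the diff calls:

    import jax.numpy as jnp

    batch_size, seq_len, channels = 8, 77, 768
    unconditional_prob = 0.2
    num_unconditional = int(batch_size * unconditional_prob)  # 1 of 8 samples

    label_seq = jnp.ones((batch_size, seq_len, channels))     # real embeddings
    null_labels_seq = jnp.zeros((seq_len, channels))          # single null row
    null_labels_seq = jnp.broadcast_to(null_labels_seq, (batch_size, seq_len, channels))

    # null rows replace the head of the batch; the rest stay conditional
    mixed = jnp.concatenate(
        [null_labels_seq[:num_unconditional], label_seq[num_unconditional:]], axis=0)
    assert mixed.shape == label_seq.shape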
flaxdiff/trainer/simple_trainer.py
ADDED
@@ -0,0 +1,323 @@
+import orbax.checkpoint
+import tqdm
+from flax import linen as nn
+import jax
+from typing import Callable
+from dataclasses import field
+import jax.numpy as jnp
+from clu import metrics
+from flax.training import train_state  # Useful dataclass to keep train state
+import optax
+from flax import struct                # Flax dataclasses
+import time
+import os
+import orbax
+from flax.training import orbax_utils
+
+@struct.dataclass
+class Metrics(metrics.Collection):
+    accuracy: metrics.Accuracy
+    loss: metrics.Average.from_output('loss')
+
+# Define the TrainState
+class SimpleTrainState(train_state.TrainState):
+    rngs: jax.random.PRNGKey
+    metrics: Metrics
+
+    def get_random_key(self):
+        rngs, subkey = jax.random.split(self.rngs)
+        return self.replace(rngs=rngs), subkey
+
+class SimpleTrainer:
+    state: SimpleTrainState
+    best_state: SimpleTrainState
+    best_loss: float
+    model: nn.Module
+    ema_decay: float = 0.999
+
+    def __init__(self,
+                 model: nn.Module,
+                 input_shapes: Dict[str, Tuple[int]],
+                 optimizer: optax.GradientTransformation,
+                 rngs: jax.random.PRNGKey,
+                 train_state: SimpleTrainState = None,
+                 name: str = "Simple",
+                 load_from_checkpoint: bool = False,
+                 checkpoint_suffix: str = "",
+                 loss_fn=optax.l2_loss,
+                 param_transforms: Callable = None,
+                 wandb_config: Dict[str, Any] = None,
+                 distributed_training: bool = None,
+                 ):
+        if distributed_training is None or distributed_training is True:
+            # Auto-detect if we are running on multiple devices
+            distributed_training = jax.device_count() > 1
+
+        self.distributed_training = distributed_training
+        self.model = model
+        self.name = name
+        self.loss_fn = loss_fn
+        self.input_shapes = input_shapes
+
+        if wandb_config is not None:
+            run = wandb.init(**wandb_config)
+            self.wandb = run
+
+        checkpointer = orbax.checkpoint.PyTreeCheckpointer()
+        options = orbax.checkpoint.CheckpointManagerOptions(
+            max_to_keep=4, create=True)
+        self.checkpointer = orbax.checkpoint.CheckpointManager(
+            self.checkpoint_path() + checkpoint_suffix, checkpointer, options)
+
+        if load_from_checkpoint:
+            latest_epoch, old_state, old_best_state = self.load()
+        else:
+            latest_epoch, old_state, old_best_state = 0, None, None
+
+        self.latest_epoch = latest_epoch
+
+        if train_state == None:
+            self.init_state(optimizer, rngs, existing_state=old_state,
+                            existing_best_state=old_best_state, model=model, param_transforms=param_transforms)
+        else:
+            self.state = train_state
+            self.best_state = train_state
+            self.best_loss = 1e9
+
+    def get_input_ones(self):
+        return {k: jnp.ones((1, *v)) for k, v in self.input_shapes.items()}
+
+    def __init_fn(
+        self,
+        optimizer: optax.GradientTransformation,
+        rngs: jax.random.PRNGKey,
+        existing_state: dict = None,
+        existing_best_state: dict = None,
+        model: nn.Module = None,
+        param_transforms: Callable = None
+    ) -> Tuple[SimpleTrainState, SimpleTrainState]:
+        rngs, subkey = jax.random.split(rngs)
+
+        if existing_state == None:
+            input_vars = self.get_input_ones()
+            params = model.init(subkey, **input_vars)
+
+        state = SimpleTrainState.create(
+            apply_fn=model.apply,
+            params=params,
+            tx=optimizer,
+            rngs=rngs,
+            metrics=Metrics.empty()
+        )
+        if existing_best_state is not None:
+            best_state = state.replace(
+                params=existing_best_state['params'])
+        else:
+            best_state = state
+
+        return state, best_state
+
+    def init_state(
+        self,
+        optimizer: optax.GradientTransformation,
+        rngs: jax.random.PRNGKey,
+        existing_state: dict = None,
+        existing_best_state: dict = None,
+        model: nn.Module = None,
+        param_transforms: Callable = None
+    ):
+
+        state, best_state = self.__init_fn(
+            optimizer, rngs, existing_state, existing_best_state, model, param_transforms
+        )
+        self.best_loss = 1e9
+
+        if self.distributed_training:
+            devices = jax.local_devices()
+            if len(devices) > 1:
+                print("Replicating state across devices ", devices)
+                state = flax.jax_utils.replicate(state, devices)
+                best_state = flax.jax_utils.replicate(best_state, devices)
+            else:
+                print("Not replicating any state, Only single device connected to the process")
+
+        self.state = state
+        self.best_state = best_state
+
+    def get_state(self):
+        return flax.jax_utils.unreplicate(self.state)
+
+    def get_best_state(self):
+        return flax.jax_utils.unreplicate(self.best_state)
+
+    def checkpoint_path(self):
+        experiment_name = self.name
+        path = os.path.join(os.path.abspath('./checkpoints'), experiment_name)
+        if not os.path.exists(path):
+            os.makedirs(path)
+        return path
+
+    def tensorboard_path(self):
+        experiment_name = self.name
+        path = os.path.join(os.path.abspath('./tensorboard'), experiment_name)
+        if not os.path.exists(path):
+            os.makedirs(path)
+        return path
+
+    def load(self):
+        epoch = self.checkpointer.latest_step()
+        print("Loading model from checkpoint", epoch)
+        ckpt = self.checkpointer.restore(epoch)
+        state = ckpt['state']
+        best_state = ckpt['best_state']
+        # Convert the state to a TrainState
+        self.best_loss = ckpt['best_loss']
+        print(
+            f"Loaded model from checkpoint at epoch {epoch}", ckpt['best_loss'])
+        return epoch, state, best_state
+
+    def save(self, epoch=0):
+        print(f"Saving model at epoch {epoch}")
+        ckpt = {
+            # 'model': self.model,
+            'state': self.get_state(),
+            'best_state': self.get_best_state(),
+            'best_loss': self.best_loss
+        }
+        try:
+            save_args = orbax_utils.save_args_from_target(ckpt)
+            self.checkpointer.save(epoch, ckpt, save_kwargs={
+                                   'save_args': save_args}, force=True)
+            pass
+        except Exception as e:
+            print("Error saving checkpoint", e)
+
+    def _define_train_step(self, **kwargs):
+        model = self.model
+        loss_fn = self.loss_fn
+        distributed_training = self.distributed_training
+
+        def train_step(state: SimpleTrainState, batch):
+            """Train for a single step."""
+            images = batch['image']
+            labels = batch['label']
+
+            def model_loss(params):
+                preds = model.apply(params, images)
+                expected_output = labels
+                nloss = loss_fn(preds, expected_output)
+                loss = jnp.mean(nloss)
+                return loss
+            loss, grads = jax.value_and_grad(model_loss)(state.params)
+            if distributed_training:
+                grads = jax.lax.pmean(grads, "device")
+            state = state.apply_gradients(grads=grads)
+            return state, loss
+
+        if distributed_training:
+            train_step = jax.pmap(axis_name="device")(train_step)
+        else:
+            train_step = jax.jit(train_step)
+
+        return train_step
+
+    def _define_compute_metrics(self):
+        model = self.model
+        loss_fn = self.loss_fn
+
+        @jax.jit
+        def compute_metrics(state: SimpleTrainState, batch):
+            preds = model.apply(state.params, batch['image'])
+            expected_output = batch['label']
+            loss = jnp.mean(loss_fn(preds, expected_output))
+            metric_updates = state.metrics.single_from_model_output(
+                loss=loss, logits=preds, labels=expected_output)
+            metrics = state.metrics.merge(metric_updates)
+            state = state.replace(metrics=metrics)
+            return state
+        return compute_metrics
+
+    def summary(self):
+        input_vars = self.get_input_ones()
+        print(self.model.tabulate(jax.random.key(0), **input_vars,
+                                  console_kwargs={"width": 200, "force_jupyter": True, }))
+
+    def config(self):
+        return {
+            "model": self.model,
+            "state": self.state,
+            "name": self.name,
+            "input_shapes": self.input_shapes
+        }
+
+    def init_tensorboard(self, batch_size, steps_per_epoch, epochs):
+        summary_writer = tensorboard.SummaryWriter(self.tensorboard_path())
+        summary_writer.hparams({
+            **self.config(),
+            "steps_per_epoch": steps_per_epoch,
+            "epochs": epochs,
+            "batch_size": batch_size
+        })
+        return summary_writer
+
+    def fit(self, data, steps_per_epoch, epochs, train_step_args={}):
+        train_ds = iter(data['train']())
+        if 'test' in data:
+            test_ds = data['test']
+        else:
+            test_ds = None
+        train_step = self._define_train_step(**train_step_args)
+        compute_metrics = self._define_compute_metrics()
+        state = self.state
+        device_count = jax.local_device_count()
+        # train_ds = flax.jax_utils.prefetch_to_device(train_ds, jax.devices())
+
+        summary_writer = self.init_tensorboard(
+            data['global_batch_size'], steps_per_epoch, epochs)
+
+        while self.latest_epoch <= epochs:
+            self.latest_epoch += 1
+            current_epoch = self.latest_epoch
+            print(f"\nEpoch {current_epoch}/{epochs}")
+            start_time = time.time()
+            epoch_loss = 0
+
+            with tqdm.tqdm(total=steps_per_epoch, desc=f'\t\tEpoch {current_epoch}', ncols=100, unit='step') as pbar:
+                for i in range(steps_per_epoch):
+                    batch = next(train_ds)
+                    if self.distributed_training and device_count > 1:
+                        batch = jax.tree.map(lambda x: x.reshape(
+                            (device_count, -1, *x.shape[1:])), batch)
+
+                    state, loss = train_step(state, batch)
+                    loss = jnp.mean(loss)
+
+                    epoch_loss += loss
+                    if i % 100 == 0:
+                        pbar.set_postfix(loss=f'{loss:.4f}')
+                        pbar.update(100)
+                        current_step = current_epoch*steps_per_epoch + i
+                        summary_writer.scalar(
+                            'Train Loss', loss, step=current_step)
+                        if self.wandb is not None:
+                            self.wandb.log({"train/loss": loss})
+
+            print(f"\n\tEpoch done")
+            end_time = time.time()
+            self.state = state
+            total_time = end_time - start_time
+            avg_time_per_step = total_time / steps_per_epoch
+            avg_loss = epoch_loss / steps_per_epoch
+            if avg_loss < self.best_loss:
+                self.best_loss = avg_loss
+                self.best_state = state
+                self.save(current_epoch)
+
+            # Compute Metrics
+            metrics_str = ''
+
+            print(
+                f"\n\tEpoch {current_epoch} completed. Avg Loss: {avg_loss}, Time: {total_time:.2f}s, Best Loss: {self.best_loss} {metrics_str}")
+
+        self.save(epochs)
+        return self.state
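Note: when SimpleTrainer runs distributed, fit reshapes each batch so the leading axis becomes (device_count, per_device_batch) before handing it to the pmap'ed train step. A standalone sketch of that sharding; the batch contents here are illustrative:

    import jax
    import jax.numpy as jnp

    device_count = jax.local_device_count()
    global_batch = device_count * 4
    batch = {"image": jnp.ones((global_batch, 32, 32, 3)),
             "label": jnp.ones((global_batch, 10))}

    # split the leading batch axis across devices, as fit does before pmap
    sharded = jax.tree.map(
        lambda x: x.reshape((device_count, -1, *x.shape[1:])), batch)
    assert sharded["image"].shape == (device_count, 4, 32, 32, 3)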
{flaxdiff-0.1.1.dist-info → flaxdiff-0.1.4.dist-info}/RECORD
CHANGED
@@ -1,10 +1,11 @@
 flaxdiff/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 flaxdiff/utils.py,sha256=B0GcHlzlVYDNEIdh2v5qmP4u0neIT-FqexNohuyuCvg,2452
 flaxdiff/models/__init__.py,sha256=FAivVYXxM2JrCFIXf-C3374RB2Hth25dBrzOeNFhH1U,26
-flaxdiff/models/attention.py,sha256=
-flaxdiff/models/common.py,sha256=
+flaxdiff/models/attention.py,sha256=SL9cvINjmabW1LPvXLAFZNHv-FF1Ez_d3J7n5uHBTyQ,15301
+flaxdiff/models/common.py,sha256=CjC4iRLjkF3oQ0f6rAqfiLaiHllZGtCOwN3rXDUndbE,274
 flaxdiff/models/favor_fastattn.py,sha256=79Ew1nqarsNLPzZaBSd1ILORzJr74CupYeqGiCQK5E4,27689
-flaxdiff/models/simple_unet.py,sha256=
+flaxdiff/models/simple_unet.py,sha256=WlLry6v18syHBzcN8zAJ-zIVtq6ItMEIBWbeCcX0MLU,18693
+flaxdiff/models/simple_vit.py,sha256=vTu2CQRoSOxetBHTrnCWddm-vxrZDkMe8EpdNxtpJMk,4015
 flaxdiff/predictors/__init__.py,sha256=SKkYYRF9Wfgk2zhtZw4vCXOdOeRlrm2Mk6cvuaEvAzc,4403
 flaxdiff/samplers/__init__.py,sha256=_S-9TwDeshrI0VmapV-J2hqjTByOa0-oOeUs_IdovjU,285
 flaxdiff/samplers/common.py,sha256=_an5h5Niz9Joz_-ppridLrGHpu8X0VVvhNGknPu6AUY,5272
@@ -23,8 +24,9 @@ flaxdiff/schedulers/exp.py,sha256=cPTnUJpYdzJRRZqMLYQz0rRUCpEmaP2tXhRumLx94jA,60
 flaxdiff/schedulers/karras.py,sha256=4GN120kGwdxxU-h2mVdhBVy9IORkUMm_vvz3XjthBcI,3355
 flaxdiff/schedulers/linear.py,sha256=6003F5ISq1Wc0h6UAzY95MJgsDIKGMhBzbiVALpea0k,581
 flaxdiff/schedulers/sqrt.py,sha256=1F84ZgQPuoNMhe6yxGTR2G0h7dPOZtm4UDQOakbSsEU,445
-flaxdiff/trainer/__init__.py,sha256=
-flaxdiff
-flaxdiff-0.1.
-flaxdiff-0.1.
-flaxdiff-0.1.
+flaxdiff/trainer/__init__.py,sha256=kwzkm-BD97hffFIXZUP1Hb3_D85fZ4SRNO7bviEwHU8,7591
+flaxdiff/trainer/simple_trainer.py,sha256=jafxr-yZ6FXn0Qi-iTSnlf275QWnIO4GnSvNAeB3H-Q,11651
+flaxdiff-0.1.4.dist-info/METADATA,sha256=G8OijdrrYWuKyAfCNtD_dKwdfBmdME56vpR-EYIZKXg,19229
+flaxdiff-0.1.4.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
+flaxdiff-0.1.4.dist-info/top_level.txt,sha256=-2-nXnfkJgSfkki1tjm5Faw6Dso7vhtdn2szwCdX5CQ,9
+flaxdiff-0.1.4.dist-info/RECORD,,
{flaxdiff-0.1.1.dist-info → flaxdiff-0.1.4.dist-info}/WHEEL
File without changes
{flaxdiff-0.1.1.dist-info → flaxdiff-0.1.4.dist-info}/top_level.txt
File without changes