xax 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +15 -2
- xax/core/state.py +10 -37
- xax/nn/attention.py +738 -0
- xax/task/logger.py +1 -1
- xax/task/mixins/train.py +17 -19
- xax/utils/experiments.py +2 -2
- xax/utils/jax.py +109 -7
- {xax-0.2.21.dist-info → xax-0.2.23.dist-info}/METADATA +1 -1
- {xax-0.2.21.dist-info → xax-0.2.23.dist-info}/RECORD +13 -12
- {xax-0.2.21.dist-info → xax-0.2.23.dist-info}/WHEEL +1 -1
- {xax-0.2.21.dist-info → xax-0.2.23.dist-info}/entry_points.txt +0 -0
- {xax-0.2.21.dist-info → xax-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.21.dist-info → xax-0.2.23.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.2.21"
+__version__ = "0.2.23"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -23,6 +23,10 @@ __all__ = [
     "get_run_dir",
     "load_user_config",
     "State",
+    "CrossAttentionBlock",
+    "SelfAttentionBlock",
+    "Transformer",
+    "TransformerBlock",
     "FourierEmbeddings",
     "IdentityPositionalEmbeddings",
     "LearnedPositionalEmbeddings",
@@ -112,8 +116,10 @@ __all__ = [
     "save_config",
     "stage_environment",
     "to_markdown_table",
+    "grad",
     "jit",
     "scan",
+    "vmap",
     "save_jaxpr_dot",
     "ColoredFormatter",
     "configure_logging",
@@ -198,6 +204,10 @@ NAME_MAP: dict[str, str] = {
     "get_run_dir": "core.conf",
     "load_user_config": "core.conf",
     "State": "core.state",
+    "CrossAttentionBlock": "nn.attention",
+    "SelfAttentionBlock": "nn.attention",
+    "Transformer": "nn.attention",
+    "TransformerBlock": "nn.attention",
     "FourierEmbeddings": "nn.embeddings",
     "IdentityPositionalEmbeddings": "nn.embeddings",
     "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -287,8 +297,10 @@ NAME_MAP: dict[str, str] = {
     "save_config": "utils.experiments",
     "stage_environment": "utils.experiments",
     "to_markdown_table": "utils.experiments",
+    "grad": "utils.jax",
     "jit": "utils.jax",
     "scan": "utils.jax",
+    "vmap": "utils.jax",
     "save_jaxpr_dot": "utils.jaxpr",
     "ColoredFormatter": "utils.logging",
     "configure_logging": "utils.logging",
@@ -366,6 +378,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         load_user_config,
     )
     from xax.core.state import Phase, State
+    from xax.nn.attention import CrossAttentionBlock, SelfAttentionBlock, Transformer, TransformerBlock
     from xax.nn.embeddings import (
         EmbeddingKind,
         FourierEmbeddings,
@@ -460,7 +473,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         stage_environment,
         to_markdown_table,
     )
-    from xax.utils.jax import jit, scan
+    from xax.utils.jax import grad, jit, scan, vmap
     from xax.utils.jaxpr import save_jaxpr_dot
     from xax.utils.logging import (
         LOG_ERROR_SUMMARY,
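
The four attention classes and the grad/vmap wrappers are now re-exported at the top level. A minimal sketch of importing them from the installed 0.2.23 wheel; the hyperparameters below are illustrative, not taken from the package:

```python
import jax
import xax

# The new names resolve through NAME_MAP to xax.nn.attention and xax.utils.jax.
block = xax.SelfAttentionBlock(embed_dim=64, num_heads=4, key=jax.random.PRNGKey(0), causal=True)
double = xax.vmap(lambda x: x * 2.0)  # wrapper around jax.vmap, see xax/utils/jax.py below
```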
xax/core/state.py
CHANGED
@@ -27,11 +27,8 @@ def _int_to_phase(i: int) -> Phase:
 class StateDict(TypedDict, total=False):
     num_steps: NotRequired[int | Array]
     num_samples: NotRequired[int | Array]
-    num_valid_steps: NotRequired[int | Array]
-    num_valid_samples: NotRequired[int | Array]
     start_time_s: NotRequired[float | Array]
     elapsed_time_s: NotRequired[float | Array]
-    valid_elapsed_time_s: NotRequired[float | Array]
     phase: NotRequired[Phase]
     _phase: NotRequired[int | Array]
 
@@ -47,38 +44,26 @@ class State:
         return self._int32_arr[0]
 
     @property
-    def num_valid_steps(self) -> Array:
-        return self._int32_arr[1]
+    def phase(self) -> Phase:
+        return _int_to_phase(self._int32_arr[1].item())
 
     @property
     def num_samples(self) -> Array:
         return self._float32_arr[0]
 
-    @property
-    def num_valid_samples(self) -> Array:
-        return self._float32_arr[1]
-
     @property
     def start_time_s(self) -> Array:
-        return self._float32_arr[2]
+        return self._float32_arr[1]
 
     @property
     def elapsed_time_s(self) -> Array:
-        return self._float32_arr[3]
-
-    @property
-    def valid_elapsed_time_s(self) -> Array:
-        return self._float32_arr[4]
+        return self._float32_arr[2]
 
     @classmethod
     def init_state(cls) -> "State":
         return cls(
-            _int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
-            _float32_arr=jnp.array([0.0, 0.0, time.time(), 0.0, 0.0], dtype=jnp.float32),
+            _int32_arr=jnp.array([0, 0], dtype=jnp.int32),
+            _float32_arr=jnp.array([0.0, time.time(), 0.0], dtype=jnp.float32),
         )
 
     @property
@@ -91,25 +76,19 @@ class State:
 
         if "num_steps" in kwargs:
             int32_arr = int32_arr.at[0].set(kwargs["num_steps"])
-        if "num_valid_steps" in kwargs:
-            int32_arr = int32_arr.at[1].set(kwargs["num_valid_steps"])
 
         if "phase" in kwargs:
-            int32_arr = int32_arr.at[2].set(_phase_to_int(kwargs["phase"]))
+            int32_arr = int32_arr.at[1].set(_phase_to_int(kwargs["phase"]))
         if "_phase" in kwargs:
-            int32_arr = int32_arr.at[2].set(kwargs["_phase"])
+            int32_arr = int32_arr.at[1].set(kwargs["_phase"])
 
         if "num_samples" in kwargs:
             float32_arr = float32_arr.at[0].set(kwargs["num_samples"])
-        if "num_valid_samples" in kwargs:
-            float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])
 
         if "start_time_s" in kwargs:
-            float32_arr = float32_arr.at[2].set(kwargs["start_time_s"])
+            float32_arr = float32_arr.at[1].set(kwargs["start_time_s"])
         if "elapsed_time_s" in kwargs:
-            float32_arr = float32_arr.at[3].set(kwargs["elapsed_time_s"])
-        if "valid_elapsed_time_s" in kwargs:
-            float32_arr = float32_arr.at[4].set(kwargs["valid_elapsed_time_s"])
+            float32_arr = float32_arr.at[2].set(kwargs["elapsed_time_s"])
 
         return State(
             _int32_arr=int32_arr,
@@ -119,12 +98,9 @@ class State:
     def to_dict(self) -> dict[str, int | float | str]:
         return {
             "num_steps": int(self.num_steps.item()),
-            "num_valid_steps": int(self.num_valid_steps.item()),
             "num_samples": int(self.num_samples.item()),
-            "num_valid_samples": int(self.num_valid_samples.item()),
             "start_time_s": float(self.start_time_s.item()),
             "elapsed_time_s": float(self.elapsed_time_s.item()),
-            "valid_elapsed_time_s": float(self.valid_elapsed_time_s.item()),
             "phase": str(self.phase),
         }
 
@@ -136,7 +112,6 @@ class State:
         int32_arr = jnp.array(
             [
                 d.get("num_steps", 0),
-                d.get("num_valid_steps", 0),
                 d.get("_phase", 0),
             ],
             dtype=jnp.int32,
@@ -145,10 +120,8 @@ class State:
         float32_arr = jnp.array(
             [
                 d.get("num_samples", 0),
-                d.get("num_valid_samples", 0),
                 d.get("start_time_s", time.time()),
                 d.get("elapsed_time_s", 0.0),
-                d.get("valid_elapsed_time_s", 0.0),
             ],
             dtype=jnp.float32,
         )
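
After this change the state is packed into two int32 slots (num_steps, _phase) and three float32 slots (num_samples, start_time_s, elapsed_time_s); the num_valid_* counters and valid_elapsed_time_s are gone. A minimal sketch of the new layout in use, assuming State.replace accepts the keyword names handled in the hunk above:

```python
import xax

state = xax.State.init_state()
state = state.replace(num_steps=state.num_steps + 1, phase="valid")

print(state.phase)              # decoded from _int32_arr[1]
print(sorted(state.to_dict()))  # no num_valid_* or valid_elapsed_time_s keys anymore
```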
xax/nn/attention.py
ADDED
@@ -0,0 +1,738 @@
+"""Attention mechanisms for transformer models."""
+
+from typing import Literal, cast, overload
+
+import chex
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, PRNGKeyArray
+
+
+class SelfAttentionBlock(eqx.Module):
+    """Self-attention block using jax.nn.dot_product_attention."""
+
+    q_proj: eqx.nn.Linear
+    k_proj: eqx.nn.Linear
+    v_proj: eqx.nn.Linear
+    output_proj: eqx.nn.Linear
+    num_heads: int = eqx.static_field()
+    head_dim: int = eqx.static_field()
+    causal: bool = eqx.static_field()
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        *,
+        key: PRNGKeyArray,
+        causal: bool = False,
+    ) -> None:
+        keys = jax.random.split(key, 4)
+
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
+
+        self.q_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[0])
+        self.k_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[1])
+        self.v_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
+        self.output_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
+
+        self.causal = causal
+
+    def _reshape_for_multihead(self, x: Array) -> Array:
+        """Reshape from (seq_len, embed_dim) to (seq_len, num_heads, head_dim)."""
+        seq_len, _ = x.shape
+        return x.reshape(seq_len, self.num_heads, self.head_dim)
+
+    def _combine_heads(self, x: Array) -> Array:
+        """Reshape from (seq_len, num_heads, head_dim) to (seq_len, embed_dim)."""
+        seq_len, _, _ = x.shape
+        return x.reshape(seq_len, -1)
+
+    def __call__(
+        self,
+        x: Array,
+        *,
+        key: PRNGKeyArray | None = None,
+        mask: Array | None = None,
+        cache: dict[str, Array] | None = None,
+        update_cache: bool = False,
+    ) -> Array | tuple[Array, dict[str, Array]]:
+        """Apply self-attention to the input.
+
+        Args:
+            x: Input tensor of shape (seq_len, embed_dim)
+            key: PRNGKey for dropout randomness
+            mask: Optional mask tensor of shape (seq_len, seq_len) or broadcastable
+            cache: Optional dictionary containing cached key and value tensors
+            update_cache: Whether to update the cache and return it
+
+        Returns:
+            If update_cache is False: Output tensor of shape (seq_len, embed_dim)
+            If update_cache is True: Tuple of (output tensor, updated cache)
+        """
+        chex.assert_rank(x, 2)
+
+        # Project inputs to queries, keys, and values
+        q = jax.vmap(self.q_proj)(x)
+
+        # Use cached key/value if provided and not updating cache
+        if cache is not None and not update_cache:
+            k = cache["k"]
+            v = cache["v"]
+        else:
+            k = jax.vmap(self.k_proj)(x)
+            v = jax.vmap(self.v_proj)(x)
+
+        # Update cache if needed
+        if update_cache:
+            if cache is None:
+                cache = {}
+            cache = {"k": k, "v": v}
+
+        # Reshape to multihead format
+        q = self._reshape_for_multihead(q)
+        k = self._reshape_for_multihead(k)
+        v = self._reshape_for_multihead(v)
+
+        # Apply dot product attention.
+        # Note that Apple Silicon struggles with this:
+        # https://github.com/jax-ml/jax/issues/20114
+        attn_output = jax.nn.dot_product_attention(
+            q,
+            k,
+            v,
+            mask=mask,
+            is_causal=self.causal and mask is None,
+        )
+
+        # Combine heads
+        attn_output = self._combine_heads(attn_output)
+
+        # Final projection
+        output = jax.vmap(self.output_proj)(attn_output)
+
+        if update_cache:
+            return output, cast(dict[str, Array], cache)
+        return output
+
+
+class CrossAttentionBlock(eqx.Module):
+    """Cross-attention block using jax.nn.dot_product_attention."""
+
+    q_proj: eqx.nn.Linear
+    k_proj: eqx.nn.Linear
+    v_proj: eqx.nn.Linear
+    output_proj: eqx.nn.Linear
+    num_heads: int = eqx.static_field()
+    head_dim: int = eqx.static_field()
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        keys = jax.random.split(key, 4)
+
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
+
+        self.q_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[0])
+        self.k_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[1])
+        self.v_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
+        self.output_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
+
+    def _reshape_for_multihead(self, x: Array) -> Array:
+        """Reshape from (seq_len, embed_dim) to (seq_len, num_heads, head_dim)."""
+        seq_len, _ = x.shape
+        return x.reshape(seq_len, self.num_heads, self.head_dim)
+
+    def _combine_heads(self, x: Array) -> Array:
+        """Reshape from (seq_len, num_heads, head_dim) to (seq_len, embed_dim)."""
+        seq_len, _, _ = x.shape
+        return x.reshape(seq_len, -1)
+
+    def __call__(
+        self,
+        q_input: Array,
+        kv_input: Array,
+        *,
+        key: PRNGKeyArray | None = None,
+        mask: Array | None = None,
+        cache: dict[str, Array] | None = None,
+        update_cache: bool = False,
+    ) -> Array | tuple[Array, dict[str, Array]]:
+        """Apply cross-attention.
+
+        Args:
+            q_input: Query input tensor of shape (q_seq_len, embed_dim)
+            kv_input: Key/value input tensor of shape (kv_seq_len, embed_dim)
+            key: PRNGKey for dropout randomness
+            mask: Optional mask tensor
+            cache: Optional dictionary containing cached key and value tensors
+            update_cache: Whether to update the cache and return it
+
+        Returns:
+            If update_cache is False: Output tensor of shape (q_seq_len, embed_dim)
+            If update_cache is True: Tuple of (output tensor, updated cache)
+        """
+        chex.assert_rank(q_input, 2)
+        chex.assert_rank(kv_input, 2)
+
+        # Project inputs to queries, keys, and values
+        q = jax.vmap(self.q_proj)(q_input)
+
+        # Use cached key/value if provided and not updating cache
+        if cache is not None and not update_cache:
+            k = cache["k"]
+            v = cache["v"]
+        else:
+            k = jax.vmap(self.k_proj)(kv_input)
+            v = jax.vmap(self.v_proj)(kv_input)
+
+        # Update cache if needed
+        if update_cache:
+            if cache is None:
+                cache = {}
+            cache = {"k": k, "v": v}
+
+        # Reshape to multihead format
+        q = self._reshape_for_multihead(q)
+        k = self._reshape_for_multihead(k)
+        v = self._reshape_for_multihead(v)
+
+        # Apply dot product attention
+        attn_output = jax.nn.dot_product_attention(
+            q,
+            k,
+            v,
+            mask=mask,
+            is_causal=False,
+        )
+
+        # Combine heads
+        attn_output = self._combine_heads(attn_output)
+
+        # Final projection
+        output = jax.vmap(self.output_proj)(attn_output)
+
+        if update_cache:
+            return output, cast(dict[str, Array], cache)
+        return output
+
+
+class TransformerBlock(eqx.Module):
+    self_attn: SelfAttentionBlock
+    cross_attn: CrossAttentionBlock | None
+    feed_forward: eqx.nn.MLP
+    layer_norm1: eqx.nn.LayerNorm
+    layer_norm2: eqx.nn.LayerNorm
+    layer_norm3: eqx.nn.LayerNorm | None
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        ff_dim: int,
+        *,
+        key: PRNGKeyArray,
+        causal: bool = False,
+        cross_attention: bool = False,
+    ) -> None:
+        keys = jax.random.split(key, 4)
+
+        self.self_attn = SelfAttentionBlock(
+            embed_dim=embed_dim,
+            num_heads=num_heads,
+            key=keys[0],
+            causal=causal,
+        )
+
+        if cross_attention:
+            self.cross_attn = CrossAttentionBlock(
+                embed_dim=embed_dim,
+                num_heads=num_heads,
+                key=keys[1],
+            )
+            self.layer_norm3 = eqx.nn.LayerNorm(embed_dim)
+        else:
+            self.cross_attn = None
+            self.layer_norm3 = None
+
+        self.layer_norm1 = eqx.nn.LayerNorm(embed_dim)
+        self.layer_norm2 = eqx.nn.LayerNorm(embed_dim)
+
+        self.feed_forward = eqx.nn.MLP(
+            in_size=embed_dim,
+            out_size=embed_dim,
+            width_size=ff_dim,
+            depth=1,
+            activation=jax.nn.gelu,
+            key=keys[2],
+        )
+
+    @overload
+    def __call__(
+        self,
+        x: Array,
+        *,
+        context: Array | None = None,
+        self_mask: Array | None = None,
+        cross_mask: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, Array]] | None = None,
+        update_cache: Literal[True],
+    ) -> tuple[Array, dict[str, dict[str, Array]]]: ...
+
+    @overload
+    def __call__(
+        self,
+        x: Array,
+        *,
+        context: Array | None = None,
+        self_mask: Array | None = None,
+        cross_mask: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, Array]] | None = None,
+        update_cache: Literal[False] = False,
+    ) -> Array: ...
+
+    def __call__(
+        self,
+        x: Array,
+        *,
+        context: Array | None = None,
+        self_mask: Array | None = None,
+        cross_mask: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, Array]] | None = None,
+        update_cache: bool = False,
+    ) -> Array | tuple[Array, dict[str, dict[str, Array]]]:
+        """Apply transformer block.
+
+        Args:
+            x: Input tensor
+            context: Optional context for cross-attention
+            self_mask: Mask for self-attention
+            cross_mask: Mask for cross-attention
+            key: Optional PRNG key for dropout
+            cache: Optional dictionary containing cached key and value tensors
+            update_cache: Whether to update the cache and return it
+
+        Returns:
+            If update_cache is False: Output tensor
+            If update_cache is True: Tuple of (output tensor, updated cache)
+        """
+        chex.assert_rank(x, 2)
+        if key is not None:
+            key1, key2 = jax.random.split(key)
+        else:
+            key1 = key2 = None
+
+        # Initialize cache if needed
+        updated_cache = {}
+        if cache is None:
+            cache = {}
+
+        # Self-attention block with pre-norm
+        norm_x = jax.vmap(self.layer_norm1)(x)
+
+        self_attn_cache = cache.get("self_attn")
+        if update_cache:
+            attn_output, self_attn_cache = self.self_attn(
+                norm_x, key=key1, mask=self_mask, cache=self_attn_cache, update_cache=True
+            )
+            updated_cache["self_attn"] = self_attn_cache
+        else:
+            attn_output = self.self_attn(norm_x, key=key1, mask=self_mask, cache=self_attn_cache)
+
+        x = x + attn_output
+
+        # Cross-attention block (if enabled) with pre-norm
+        if self.cross_attn is not None and context is not None:
+            assert self.layer_norm3 is not None
+
+            norm_x = jax.vmap(self.layer_norm3)(x)
+            cross_attn_cache = cache.get("cross_attn")
+
+            if update_cache:
+                cross_attn_output, cross_attn_cache = self.cross_attn(
+                    norm_x, context, key=key2, mask=cross_mask, cache=cross_attn_cache, update_cache=True
+                )
+                updated_cache["cross_attn"] = cross_attn_cache
+            else:
+                cross_attn_output = self.cross_attn(norm_x, context, key=key2, mask=cross_mask, cache=cross_attn_cache)
+
+            x = x + cross_attn_output
+
+        # Feed-forward block with pre-norm
+        norm_x = jax.vmap(self.layer_norm2)(x)
+        ff_output = jax.vmap(self.feed_forward)(norm_x)
+        x = x + ff_output
+
+        if update_cache:
+            return x, updated_cache
+        return x
+
+
+class Transformer(eqx.Module):
+    token_embedding: eqx.nn.Embedding
+    position_embedding: eqx.nn.Embedding | None
+    layers: list[TransformerBlock]
+    output_layer: eqx.nn.Linear | None
+    layer_norm: eqx.nn.LayerNorm
+    max_seq_len: int = eqx.static_field()
+    embed_dim: int = eqx.static_field()
+
+    def __init__(
+        self,
+        vocab_size: int,
+        embed_dim: int,
+        num_heads: int,
+        ff_dim: int,
+        num_layers: int,
+        max_seq_len: int,
+        output_size: int | None = None,
+        *,
+        key: PRNGKeyArray,
+        causal: bool = False,
+        cross_attention: bool = False,
+        use_absolute_position: bool = True,
+    ) -> None:
+        keys = jax.random.split(key, num_layers + 3)
+
+        self.token_embedding = eqx.nn.Embedding(vocab_size, embed_dim, key=keys[0])
+
+        # Position embeddings can be disabled
+        if use_absolute_position:
+            self.position_embedding = eqx.nn.Embedding(max_seq_len, embed_dim, key=keys[1])
+        else:
+            self.position_embedding = None
+
+        self.layers = [
+            TransformerBlock(
+                embed_dim=embed_dim,
+                num_heads=num_heads,
+                ff_dim=ff_dim,
+                key=keys[i + 2],
+                causal=causal,
+                cross_attention=cross_attention,
+            )
+            for i in range(num_layers)
+        ]
+
+        self.layer_norm = eqx.nn.LayerNorm(embed_dim)
+
+        if output_size is not None:
+            self.output_layer = eqx.nn.Linear(embed_dim, output_size, key=keys[-1])
+        else:
+            self.output_layer = None
+
+        self.max_seq_len = max_seq_len
+        self.embed_dim = embed_dim
+
+    def _add_positional_embedding(self, x_embedded: Array, positions: Array | None = None) -> Array:
+        """Add positional embeddings to the token embeddings."""
+        if self.position_embedding is None:
+            return x_embedded
+
+        seq_len, _ = x_embedded.shape
+
+        if positions is None:
+            positions = jnp.arange(seq_len)
+        pos_embedded = jax.vmap(self.position_embedding)(positions)
+
+        return x_embedded + pos_embedded
+
+    def encode(
+        self,
+        x: Array,
+        mask: Array | None = None,
+        positions: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+        update_cache: bool = False,
+    ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+        """Encode the input sequence.
+
+        Args:
+            x: Input token indices of shape (seq_len)
+            mask: Optional attention mask
+            positions: Optional positions
+            key: Optional PRNG key for dropout
+            cache: Optional dictionary containing cached key and value tensors
+            update_cache: Whether to update the cache and return it
+
+        Returns:
+            If update_cache is False: Encoded representation
+            If update_cache is True: Tuple of (encoded representation, updated cache)
+        """
+        # Token embedding
+        x_embedded = jax.vmap(self.token_embedding)(x)
+
+        # Add positional embedding
+        x_embedded = self._add_positional_embedding(x_embedded, positions)
+
+        # Initialize layer caches
+        if cache is None and update_cache:
+            cache = {f"layer_{i}": {} for i in range(len(self.layers))}
+
+        # Updated cache will be built if needed
+        updated_cache = {}
+
+        # Apply transformer layers
+        keys: Array | list[None] = [None] * len(self.layers)
+        if key is not None:
+            keys = jax.random.split(key, len(self.layers))
+
+        for i, (layer, layer_key) in enumerate(zip(self.layers, keys, strict=False)):
+            layer_cache = None if cache is None else cache.get(f"layer_{i}")
+
+            if update_cache:
+                x_embedded, layer_updated_cache = layer.__call__(
+                    x_embedded,
+                    self_mask=mask,
+                    key=layer_key,
+                    cache=layer_cache,
+                    update_cache=True,
+                )
+                updated_cache[f"layer_{i}"] = layer_updated_cache
+            else:
+                x_embedded = layer.__call__(
+                    x_embedded,
+                    self_mask=mask,
+                    key=layer_key,
+                    cache=layer_cache,
+                )
+
+        # Apply final layer norm
+        output = jax.vmap(self.layer_norm)(x_embedded)
+
+        if update_cache:
+            return output, updated_cache
+        return output
+
+    def decode(
+        self,
+        x: Array,
+        context: Array,
+        self_mask: Array | None = None,
+        cross_mask: Array | None = None,
+        positions: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+        update_cache: bool = False,
+    ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+        """Decode with self-attention and cross-attention.
+
+        Args:
+            x: Input token indices
+            context: Context from encoder
+            self_mask: Optional self-attention mask
+            cross_mask: Optional cross-attention mask
+            positions: Optional positions
+            key: Optional PRNG key for dropout
+            cache: Optional dictionary containing cached key and value tensors
+            update_cache: Whether to update the cache and return it
+
+        Returns:
+            If update_cache is False: Decoded representation
+            If update_cache is True: Tuple of (decoded representation, updated cache)
+        """
+        # Token embedding
+        x_embedded = jax.vmap(lambda x_seq: jax.vmap(self.token_embedding)(x_seq))(x)
+
+        # Add positional embedding
+        x_embedded = self._add_positional_embedding(x_embedded, positions)
+
+        # Initialize layer caches
+        if cache is None and update_cache:
+            cache = {f"layer_{i}": {} for i in range(len(self.layers))}
+
+        # Updated cache will be built if needed
+        updated_cache = {}
+
+        # Apply transformer layers with cross-attention
+        keys: Array | list[None] = [None] * len(self.layers)
+        if key is not None:
+            keys = jax.random.split(key, len(self.layers))
+
+        for i, (layer, layer_key) in enumerate(zip(self.layers, keys, strict=False)):
+            layer_cache = None if cache is None else cache.get(f"layer_{i}")
+
+            if update_cache:
+                x_embedded, layer_updated_cache = layer.__call__(
+                    x_embedded,
+                    context=context,
+                    self_mask=self_mask,
+                    cross_mask=cross_mask,
+                    key=layer_key,
+                    cache=layer_cache,
+                    update_cache=True,
+                )
+                updated_cache[f"layer_{i}"] = layer_updated_cache
+            else:
+                x_embedded = layer(
+                    x_embedded,
+                    context=context,
+                    self_mask=self_mask,
+                    cross_mask=cross_mask,
+                    key=layer_key,
+                    cache=layer_cache,
+                )
+
+        # Apply final layer norm
+        output = jax.vmap(self.layer_norm)(x_embedded)
+
+        if update_cache:
+            return output, updated_cache
+        return output
+
+    @overload
+    def __call__(
+        self,
+        x: Array,
+        *,
+        mask: Array | None = None,
+        positions: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+        update_cache: Literal[True],
+    ) -> tuple[Array, dict[str, dict[str, dict[str, Array]]]]: ...
+
+    @overload
+    def __call__(
+        self,
+        x: Array,
+        *,
+        mask: Array | None = None,
+        positions: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+        update_cache: Literal[False] = False,
+    ) -> Array: ...
+
+    def __call__(
+        self,
+        x: Array,
+        *,
+        mask: Array | None = None,
+        positions: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+        update_cache: bool = False,
+    ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+        """Forward pass for encoder-only or decoder-only transformers.
+
+        Args:
+            x: Input token indices of shape (seq_len)
+            mask: Optional attention mask
+            positions: Optional positions
+            key: Optional PRNG key for dropout
+            cache: Optional dictionary containing cached key and value tensors
+            update_cache: Whether to update the cache and return it
+
+        Returns:
+            If update_cache is False: Output representation
+            If update_cache is True: Tuple of (output representation, updated cache)
+        """
+        chex.assert_rank(x, 1)
+
+        if update_cache:
+            output, updated_cache = self.encode(
+                x, mask=mask, positions=positions, key=key, cache=cache, update_cache=True
+            )
+        else:
+            output = self.encode(x, mask=mask, positions=positions, key=key, cache=cache)
+
+        # Apply output layer if it exists
+        if self.output_layer is not None:
+            output = jax.vmap(self.output_layer)(output)
+
+        if update_cache:
+            return output, updated_cache
+        return output
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        return self(x=x_seq)
+
+    def generate_sequence(
+        self,
+        prompt_seq: Array,
+        max_len: int,
+        temperature: float = 1.0,
+        top_k: int | None = None,
+        key: PRNGKeyArray | None = None,
+    ) -> Array:
+        """Generate a sequence autoregressively with KV caching.
+
+        Args:
+            prompt_seq: Input token indices of shape (prompt_len,)
+            max_len: Maximum length of generated sequence
+            temperature: Sampling temperature
+            top_k: Optional top-k sampling parameter
+            key: PRNG key for sampling
+
+        Returns:
+            Generated sequence of shape (prompt_len + max_len,)
+        """
+        if key is None:
+            key = jax.random.PRNGKey(0)
+
+        prompt_len = prompt_seq.shape[0]
+        sequence = prompt_seq
+
+        # Create causal mask for generation
+        causal_mask = jnp.tril(jnp.ones((self.max_seq_len, self.max_seq_len), dtype=jnp.bool_))
+
+        # Initialize cache with the prompt
+        _, cache = self(x=prompt_seq, mask=causal_mask[:prompt_len, :prompt_len], update_cache=True)
+
+        # Define decode step function (for clarity)
+        def decode_step(seq: Array, pos: int, cur_cache: dict, rng: PRNGKeyArray) -> tuple[Array, dict, PRNGKeyArray]:
+            # Get the next position and last token
+            pos_tensor = jnp.array([pos])
+            last_token = seq[-1:]
+
+            # Get logits for next token
+            rng, subrng = jax.random.split(rng)
+            logits, new_cache = self(
+                x=last_token,
+                positions=pos_tensor,
+                key=subrng,
+                cache=cur_cache,
+                update_cache=True,
+            )
+
+            # Extract final logits and apply temperature
+            logits = logits[-1] / temperature
+
+            # Apply top-k sampling if specified
+            if top_k is not None:
+                top_logits, top_indices = jax.lax.top_k(logits, top_k)
+                logits = jnp.full_like(logits, float("-inf"))
+                logits = logits.at[top_indices].set(top_logits)
+
+            # Sample next token
+            rng, subrng = jax.random.split(rng)
+            next_token = jax.random.categorical(subrng, logits[None, ...])[0]
+
+            # Add token to sequence
+            new_seq = jnp.concatenate([seq, next_token[None]], axis=0)
+            return new_seq, new_cache, rng
+
+        # Generate tokens one by one
+        for _ in range(max_len):
+            # Break if max sequence length reached
+            if sequence.shape[0] >= self.max_seq_len:
+                break
+
+            # Decode next token
+            sequence, cache, key = decode_step(seq=sequence, pos=sequence.shape[0] - 1, cur_cache=cache, rng=key)
+
+        return sequence
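
A minimal usage sketch of the new module, based only on the constructor and method signatures above; the vocabulary size and other hyperparameters are illustrative:

```python
import jax
import jax.numpy as jnp
from xax.nn.attention import Transformer

key = jax.random.PRNGKey(0)
model = Transformer(
    vocab_size=1000,
    embed_dim=64,
    num_heads=4,
    ff_dim=128,
    num_layers=2,
    max_seq_len=256,
    output_size=1000,  # project hidden states back to vocabulary logits
    key=key,
    causal=True,
)

prompt = jnp.array([1, 2, 3])
logits = model.predict_sequence(prompt)                                  # (3, 1000)
tokens = model.generate_sequence(prompt, max_len=10, top_k=50, key=key)  # KV-cached decoding
```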
xax/task/logger.py
CHANGED
@@ -526,7 +526,7 @@ class LoggerImpl(ABC):
         Returns:
             If the logger should log the current step.
         """
-        elapsed_time = state.elapsed_time_s.item()
+        elapsed_time = state.elapsed_time_s.item()
         return self.tickers[state.phase].tick(elapsed_time)
 
 
xax/task/mixins/train.py
CHANGED
@@ -121,7 +121,7 @@ class ValidStepTimer:
             self.last_valid_step = state.num_steps.item()
 
     def __call__(self, state: State) -> bool:
-        if state.num_steps < self.valid_first_n_steps
+        if state.num_steps < self.valid_first_n_steps:
             return True
 
         if self.last_valid_time is None or self.last_valid_step is None:
@@ -130,18 +130,15 @@ class ValidStepTimer:
 
         # Step-based validation.
         valid_every_n_steps = self.valid_every_n_steps
-        if valid_every_n_steps is not None and (
-            state.num_steps >= valid_every_n_steps + self.last_valid_step
-            or state.num_valid_steps >= valid_every_n_steps + self.last_valid_step
-        ):
+        if valid_every_n_steps is not None and state.num_steps >= valid_every_n_steps + self.last_valid_step:
            self._reset(state)
            return True
 
         # Time-based validation.
         valid_every_n_seconds = self.valid_every_n_seconds
-        if (
-
-
+        if (
+            valid_every_n_seconds is not None
+            and state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
         ):
             self._reset(state)
             return True
@@ -149,10 +146,7 @@ class ValidStepTimer:
         # Time-based validation for first validation step.
         if self.first_valid_step_flag:
             valid_first_n_seconds = self.valid_first_n_seconds
-            if valid_first_n_seconds is not None and (
-                state.elapsed_time_s.item() >= valid_first_n_seconds
-                or state.valid_elapsed_time_s.item() >= valid_first_n_seconds
-            ):
+            if valid_first_n_seconds is not None and state.elapsed_time_s.item() >= valid_first_n_seconds:
                 self._reset(state)
                 self.first_valid_step_flag = False
                 return True
@@ -625,9 +619,13 @@ class TrainMixin(
             grad_metrics = {"grad_norm": grad_norm}
 
         def apply(grads: PyTree, grad_norm: Array) -> tuple[PyTree, optax.OptState]:
-            # Clip
-
-
+            # Clip gradients based on global norm, similar to optax.clip_by_global_norm
+            trigger = jnp.squeeze(grad_norm < self.config.global_grad_clip)
+
+            def clip_fn(t: Array) -> Array:
+                return jax.lax.select(trigger, t, (t / grad_norm.astype(t.dtype)) * self.config.global_grad_clip)
+
+            grads = jax.tree.map(clip_fn, grads)
 
             # Apply the gradient updates.
             updates, new_opt_state = optimizer.update(grads, opt_state, model_arr)
@@ -773,12 +771,12 @@ class TrainMixin(
                 self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
 
             state = state.replace(
-
-
+                num_steps=state.num_steps + 1,
+                num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
             )
 
             state = state.replace(
-
+                elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
             )
 
         with ContextTimer() as timer:
@@ -878,7 +876,7 @@ class TrainMixin(
             key, model_key = jax.random.split(key)
             models, optimizers, opt_states, state = self.load_initial_state(model_key, load_optimizer=True)
             logger.info("Model size: %s", f"{get_pytree_param_count(models):,}")
-            logger.info("Optimizer size: %s", f"{get_pytree_param_count(
+            logger.info("Optimizer size: %s", f"{get_pytree_param_count(opt_states):,}")
 
             state = self.on_training_start(state)
 
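
The clipping inlined into apply() follows the optax.clip_by_global_norm recipe: leave gradients untouched while the global norm is below the threshold, otherwise rescale every leaf by threshold / norm. A standalone sketch of that logic; the max_norm value and example gradients are illustrative:

```python
import jax
import jax.numpy as jnp
import optax

def clip_by_global_norm(grads, max_norm: float):
    grad_norm = optax.global_norm(grads)
    trigger = grad_norm < max_norm

    def clip_fn(t: jnp.ndarray) -> jnp.ndarray:
        # Keep t unchanged below the threshold, otherwise scale it down to the threshold norm.
        return jax.lax.select(trigger, t, (t / grad_norm.astype(t.dtype)) * max_norm)

    return jax.tree.map(clip_fn, grads)

grads = {"w": jnp.full((3,), 10.0), "b": jnp.full((2,), 10.0)}
print(optax.global_norm(clip_by_global_norm(grads, max_norm=1.0)))  # ~1.0
```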
xax/utils/experiments.py
CHANGED
@@ -111,8 +111,8 @@ class StateTimer:
 
     def step(self, state: State) -> None:
         cur_time = time.time()
-        num_steps = int(
-        num_samples = int(
+        num_steps = int(state.num_steps.item())
+        num_samples = int(state.num_samples.item())
         self.step_timer.step(num_steps, cur_time)
         self.sample_timer.step(num_samples, cur_time)
         self.iter_timer.step(cur_time)
xax/utils/jax.py
CHANGED
@@ -6,13 +6,14 @@ import logging
 import os
 import time
 from functools import wraps
-from typing import Any, Callable, Iterable, ParamSpec, Sequence, TypeVar, cast
+from typing import Any, Callable, Hashable, Iterable, ParamSpec, Sequence, TypeVar, cast
 
 import jax
 import jax.numpy as jnp
 import numpy as np
 from jax._src import sharding_impls
 from jax._src.lib import xla_client as xc
+from jaxtyping import PyTree
 
 logger = logging.getLogger(__name__)
 
@@ -20,6 +21,7 @@ DEFAULT_COMPILE_TIMEOUT = 1.0
 
 Number = int | float | np.ndarray | jnp.ndarray
 
+T = TypeVar("T", bound=PyTree)
 
 P = ParamSpec("P")  # For function parameters
 R = TypeVar("R")  # For function return type
@@ -29,6 +31,9 @@ Carry = TypeVar("Carry")
 X = TypeVar("X")
 Y = TypeVar("Y")
 
+F = TypeVar("F", bound=Callable)
+AxisName = Hashable
+
 
 @functools.lru_cache(maxsize=None)
 def disable_jit_level() -> int:
@@ -166,6 +171,22 @@ def jit(
     return decorator
 
 
+def _split_module(tree: T, axis: int = 0) -> list[T]:
+    """Splits a module in the same way that jax.lax.scan and jax.vmap do.
+
+    Args:
+        tree: The tree to split.
+        axis: The axis to split on.
+
+    Returns:
+        A list of the split trees.
+    """
+    first_leaf = jax.tree.leaves(tree)[0]
+    num_slices = first_leaf.shape[axis]
+    result = [jax.tree.map(lambda x, idx=i: jnp.take(x, idx, axis=axis), tree) for i in range(num_slices)]
+    return result
+
+
 def scan(
     f: Callable[[Carry, X], tuple[Carry, Y]],
     init: Carry,
@@ -195,15 +216,96 @@ def scan(
     if not should_disable_jit(jit_level):
         return jax.lax.scan(f, init, xs, length, reverse, unroll)
 
+    carry = init
+    ys = []
+
     if xs is None:
         if length is None:
             raise ValueError("length must be provided if xs is None")
-
+        for _ in range(length) if not reverse else range(length - 1, -1, -1):
+            carry, y = f(carry, None)  # type: ignore[arg-type]
+            ys.append(y)
 
-
-
-
-
-
+    else:
+        xlist = _split_module(xs, axis=0)
+        if reverse:
+            xlist = xlist[::-1]
+        for x in xlist:
+            carry, y = f(carry, x)
+            ys.append(y)
+
+    if reverse:
+        ys = ys[::-1]
+
+    if not ys:
+        return carry, jnp.array([])  # type: ignore[return-value]
 
     return carry, jax.tree.map(lambda *ys: jnp.stack(ys), *ys)
+
+
+def vmap(
+    fun: Callable[P, R],
+    in_axes: int | Sequence[int | None] = 0,
+    jit_level: int | None = None,
+) -> Callable[P, R]:
+    """A wrapper around jax.lax.vmap that allows for more flexible tracing.
+
+    If the provided JIT level is below the environment JIT level, we manually
+    unroll the scan function as a for loop.
+    """
+    if not should_disable_jit(jit_level):
+        return jax.vmap(fun, in_axes=in_axes)
+
+    @functools.wraps(fun)
+    def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
+        if kwargs:
+            raise ValueError("vmap does not support keyword arguments")
+
+        ia = in_axes
+        if isinstance(ia, int):
+            ia = [ia] * len(args)
+        elif len(ia) != len(args):
+            raise ValueError("in_axes must be the same length as args")
+
+        if not all(isinstance(a, int) or a is None for a in ia):
+            raise ValueError("in_axes must be a list of integers or None")
+
+        ns = next((len(_split_module(a, axis=i)) for i, a in zip(ia, args, strict=True) if i is not None), None)
+        if ns is None:
+            return fun(*args, **kwargs)
+        split_args = [[a] * ns if i is None else _split_module(a, axis=i) for i, a in zip(ia, args, strict=True)]
+        split_outputs = [fun(*sargs, **kwargs) for sargs in zip(*split_args, strict=True)]
+
+        if not split_outputs:
+            return jnp.array([])  # type: ignore[return-value]
+
+        return jax.tree.map(lambda *ys: jnp.stack(ys), *split_outputs)
+
+    return wrapped
+
+
+def grad(
+    fun: Callable[P, R],
+    argnums: int | Sequence[int] = 0,
+    has_aux: bool = False,
+    holomorphic: bool = False,
+    allow_int: bool = False,
+    reduce_axes: Sequence[AxisName] = (),
+    jit_level: int | None = None,
+) -> Callable:
+    """A wrapper around jax.grad that allows for more flexible tracing.
+
+    We don't do anything special here, we just manually evaluate the function
+    if the JIT level is below the environment JIT level.
+    """
+    if not should_disable_jit(jit_level):
+        return jax.grad(fun, argnums, has_aux, holomorphic, allow_int, reduce_axes)
+
+    @functools.wraps(fun)
+    def wrapped(*args: P.args, **kwargs: P.kwargs) -> Callable:
+        # Evaluate the function once, then just return the gradient.
+        fun(*args, **kwargs)
+
+        return jax.grad(fun, argnums, has_aux, holomorphic, allow_int, reduce_axes)(*args, **kwargs)
+
+    return wrapped
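
The new vmap and grad helpers follow the same pattern as the existing jit and scan wrappers: under normal conditions they defer to jax.vmap / jax.grad, and when the requested jit_level falls below the environment's disable threshold they fall back to plain Python evaluation so the code stays steppable in a debugger. A minimal usage sketch; the jit_level value is illustrative:

```python
import jax.numpy as jnp
from xax.utils.jax import grad, vmap

def loss(w: jnp.ndarray) -> jnp.ndarray:
    return jnp.sum(w ** 2)

batched_loss = vmap(loss, in_axes=0, jit_level=3)  # jax.vmap unless tracing is disabled
dloss = grad(loss, jit_level=3)                    # jax.grad unless tracing is disabled

print(batched_loss(jnp.ones((4, 2))))  # [2. 2. 2. 2.]
print(dloss(jnp.ones(2)))              # [2. 2.]
```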
{xax-0.2.21.dist-info → xax-0.2.23.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=Q0boKxPtEUiiJ9j7Cdx51bLLFtYx3fPfCTG-o8o2Chk,15653
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -6,8 +6,9 @@ xax/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/cli/edit_config.py,sha256=LQUIlOS6hvPZyVEaMme3FP-62M0BKQPYavCwVDWuBLw,2600
 xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
-xax/core/state.py,sha256=
+xax/core/state.py,sha256=F9Tj3FfCw8zFKaDEoEGiThZE2ntYEtzNjnBX3pQ1g60,3826
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xax/nn/attention.py,sha256=0essK90OO3x9FxnUqU0DhufwXKRMN41zMtRCki5iKzQ,24742
 xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
 xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
 xax/nn/geom.py,sha256=A7WPefMvgwUNReZC7_HX1GmvHPASyghbaXaKsuhwDrE,7382
@@ -17,7 +18,7 @@ xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
 xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=TYANmjNcce4_V5ZSYLnE91PXRn7Nn0nT7hN8plW_Au0,8117
-xax/task/logger.py,sha256=
+xax/task/logger.py,sha256=Bmhl4mv08Aq49ZyX6BdjPIsPJK28e8s3mVFatM4IY2Q,41060
 xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
 xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
 xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -41,11 +42,11 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/train.py,sha256=TZatz5QwTfrNhQTiO2IqrmQY9P4Lay6FAD2VsQpWa54,33245
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
-xax/utils/experiments.py,sha256=
-xax/utils/jax.py,sha256=
+xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
+xax/utils/jax.py,sha256=6cP95-rcjkRt1fefkZWJQhJhH0uUYWJB3w4NP1-aDp0,10136
 xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
 xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
 xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
@@ -58,9 +59,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.2.
-xax-0.2.
-xax-0.2.
-xax-0.2.
-xax-0.2.
-xax-0.2.
+xax-0.2.23.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.2.23.dist-info/METADATA,sha256=mA98vsIjdfb8XM2mN1vUb2VRVEPU4xf10IWLxxFJjmY,1247
+xax-0.2.23.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+xax-0.2.23.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
+xax-0.2.23.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.2.23.dist-info/RECORD,,
{xax-0.2.21.dist-info → xax-0.2.23.dist-info}/entry_points.txt
File without changes
{xax-0.2.21.dist-info → xax-0.2.23.dist-info}/licenses/LICENSE
File without changes
{xax-0.2.21.dist-info → xax-0.2.23.dist-info}/top_level.txt
File without changes
|