xax 0.2.22__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +10 -1
- xax/core/state.py +10 -37
- xax/nn/attention.py +738 -0
- xax/task/logger.py +1 -1
- xax/task/mixins/train.py +10 -16
- xax/utils/experiments.py +2 -2
- {xax-0.2.22.dist-info → xax-0.2.23.dist-info}/METADATA +1 -1
- {xax-0.2.22.dist-info → xax-0.2.23.dist-info}/RECORD +12 -11
- {xax-0.2.22.dist-info → xax-0.2.23.dist-info}/WHEEL +1 -1
- {xax-0.2.22.dist-info → xax-0.2.23.dist-info}/entry_points.txt +0 -0
- {xax-0.2.22.dist-info → xax-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.22.dist-info → xax-0.2.23.dist-info}/top_level.txt +0 -0
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.2.22"
+__version__ = "0.2.23"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -23,6 +23,10 @@ __all__ = [
     "get_run_dir",
     "load_user_config",
     "State",
+    "CrossAttentionBlock",
+    "SelfAttentionBlock",
+    "Transformer",
+    "TransformerBlock",
     "FourierEmbeddings",
     "IdentityPositionalEmbeddings",
     "LearnedPositionalEmbeddings",
@@ -200,6 +204,10 @@ NAME_MAP: dict[str, str] = {
     "get_run_dir": "core.conf",
     "load_user_config": "core.conf",
     "State": "core.state",
+    "CrossAttentionBlock": "nn.attention",
+    "SelfAttentionBlock": "nn.attention",
+    "Transformer": "nn.attention",
+    "TransformerBlock": "nn.attention",
     "FourierEmbeddings": "nn.embeddings",
     "IdentityPositionalEmbeddings": "nn.embeddings",
     "LearnedPositionalEmbeddings": "nn.embeddings",
@@ -370,6 +378,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         load_user_config,
     )
     from xax.core.state import Phase, State
+    from xax.nn.attention import CrossAttentionBlock, SelfAttentionBlock, Transformer, TransformerBlock
     from xax.nn.embeddings import (
         EmbeddingKind,
         FourierEmbeddings,
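
The four attention classes are re-exported at the top level alongside the existing symbols, so they can be pulled straight from xax rather than from xax.nn.attention. A minimal sketch, assuming the lazy-import machinery behind __all__ and NAME_MAP resolves top-level names as in previous releases; the dimensions are illustrative:

    import jax
    import jax.numpy as jnp
    from xax import SelfAttentionBlock  # resolved via NAME_MAP to "nn.attention"

    block = SelfAttentionBlock(embed_dim=64, num_heads=4, key=jax.random.PRNGKey(0), causal=True)
    x = jnp.ones((8, 64))   # (seq_len, embed_dim)
    y = block(x)            # (8, 64)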
xax/core/state.py CHANGED
@@ -27,11 +27,8 @@ def _int_to_phase(i: int) -> Phase:
 class StateDict(TypedDict, total=False):
     num_steps: NotRequired[int | Array]
     num_samples: NotRequired[int | Array]
-    num_valid_steps: NotRequired[int | Array]
-    num_valid_samples: NotRequired[int | Array]
     start_time_s: NotRequired[float | Array]
     elapsed_time_s: NotRequired[float | Array]
-    valid_elapsed_time_s: NotRequired[float | Array]
     phase: NotRequired[Phase]
     _phase: NotRequired[int | Array]
 
@@ -47,38 +44,26 @@ class State:
         return self._int32_arr[0]
 
     @property
-    def num_valid_steps(self) -> Array:
-        return self._int32_arr[1]
+    def phase(self) -> Phase:
+        return _int_to_phase(self._int32_arr[1].item())
 
     @property
     def num_samples(self) -> Array:
         return self._float32_arr[0]
 
-    @property
-    def num_valid_samples(self) -> Array:
-        return self._float32_arr[1]
-
     @property
     def start_time_s(self) -> Array:
-        return self._float32_arr[2]
+        return self._float32_arr[1]
 
     @property
     def elapsed_time_s(self) -> Array:
-        return self._float32_arr[3]
-
-    @property
-    def valid_elapsed_time_s(self) -> Array:
-        return self._float32_arr[4]
-
-    @property
-    def phase(self) -> Phase:
-        return _int_to_phase(self._int32_arr[2].item())
+        return self._float32_arr[2]
 
     @classmethod
     def init_state(cls) -> "State":
         return cls(
-            _int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
-            _float32_arr=jnp.array([0.0, 0.0, time.time(), 0.0, 0.0], dtype=jnp.float32),
+            _int32_arr=jnp.array([0, 0], dtype=jnp.int32),
+            _float32_arr=jnp.array([0.0, time.time(), 0.0], dtype=jnp.float32),
         )
 
     @property
@@ -91,25 +76,19 @@ class State:
 
         if "num_steps" in kwargs:
             int32_arr = int32_arr.at[0].set(kwargs["num_steps"])
-        if "num_valid_steps" in kwargs:
-            int32_arr = int32_arr.at[1].set(kwargs["num_valid_steps"])
 
         if "phase" in kwargs:
-            int32_arr = int32_arr.at[2].set(_phase_to_int(kwargs["phase"]))
+            int32_arr = int32_arr.at[1].set(_phase_to_int(kwargs["phase"]))
         if "_phase" in kwargs:
-            int32_arr = int32_arr.at[2].set(kwargs["_phase"])
+            int32_arr = int32_arr.at[1].set(kwargs["_phase"])
 
         if "num_samples" in kwargs:
             float32_arr = float32_arr.at[0].set(kwargs["num_samples"])
-        if "num_valid_samples" in kwargs:
-            float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])
 
         if "start_time_s" in kwargs:
-            float32_arr = float32_arr.at[2].set(kwargs["start_time_s"])
+            float32_arr = float32_arr.at[1].set(kwargs["start_time_s"])
         if "elapsed_time_s" in kwargs:
-            float32_arr = float32_arr.at[3].set(kwargs["elapsed_time_s"])
-        if "valid_elapsed_time_s" in kwargs:
-            float32_arr = float32_arr.at[4].set(kwargs["valid_elapsed_time_s"])
+            float32_arr = float32_arr.at[2].set(kwargs["elapsed_time_s"])
 
         return State(
             _int32_arr=int32_arr,
@@ -119,12 +98,9 @@ class State:
     def to_dict(self) -> dict[str, int | float | str]:
         return {
             "num_steps": int(self.num_steps.item()),
-            "num_valid_steps": int(self.num_valid_steps.item()),
             "num_samples": int(self.num_samples.item()),
-            "num_valid_samples": int(self.num_valid_samples.item()),
             "start_time_s": float(self.start_time_s.item()),
             "elapsed_time_s": float(self.elapsed_time_s.item()),
-            "valid_elapsed_time_s": float(self.valid_elapsed_time_s.item()),
             "phase": str(self.phase),
         }
 
@@ -136,7 +112,6 @@ class State:
         int32_arr = jnp.array(
             [
                 d.get("num_steps", 0),
-                d.get("num_valid_steps", 0),
                 d.get("_phase", 0),
             ],
             dtype=jnp.int32,
@@ -145,10 +120,8 @@ class State:
         float32_arr = jnp.array(
             [
                 d.get("num_samples", 0),
-                d.get("num_valid_samples", 0),
                 d.get("start_time_s", time.time()),
                 d.get("elapsed_time_s", 0.0),
-                d.get("valid_elapsed_time_s", 0.0),
             ],
             dtype=jnp.float32,
        )
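
After this change, State packs two int32 slots (num_steps and the encoded phase) and three float32 slots (num_samples, start_time_s, elapsed_time_s); the separate num_valid_* and valid_elapsed_time_s fields are gone. A minimal sketch of the resulting layout, assuming replace() and the "valid" phase literal behave as in the hunks above; values are illustrative:

    from xax.core.state import State

    state = State.init_state()   # _int32_arr=[0, 0], _float32_arr=[0.0, time.time(), 0.0]
    state = state.replace(num_steps=10, num_samples=320, phase="valid")

    state.num_steps        # _int32_arr[0]
    state.phase            # decoded from _int32_arr[1]
    state.num_samples      # _float32_arr[0]
    state.start_time_s     # _float32_arr[1]
    state.elapsed_time_s   # _float32_arr[2]
    state.to_dict()        # plain ints/floats/str, without the num_valid_* keys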
xax/nn/attention.py ADDED
@@ -0,0 +1,738 @@
+"""Attention mechanisms for transformer models."""
+
+from typing import Literal, cast, overload
+
+import chex
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array, PRNGKeyArray
+
+
+class SelfAttentionBlock(eqx.Module):
+    """Self-attention block using jax.nn.dot_product_attention."""
+
+    q_proj: eqx.nn.Linear
+    k_proj: eqx.nn.Linear
+    v_proj: eqx.nn.Linear
+    output_proj: eqx.nn.Linear
+    num_heads: int = eqx.static_field()
+    head_dim: int = eqx.static_field()
+    causal: bool = eqx.static_field()
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        *,
+        key: PRNGKeyArray,
+        causal: bool = False,
+    ) -> None:
+        keys = jax.random.split(key, 4)
+
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
+
+        self.q_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[0])
+        self.k_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[1])
+        self.v_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
+        self.output_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
+
+        self.causal = causal
+
+    def _reshape_for_multihead(self, x: Array) -> Array:
+        """Reshape from (seq_len, embed_dim) to (seq_len, num_heads, head_dim)."""
+        seq_len, _ = x.shape
+        return x.reshape(seq_len, self.num_heads, self.head_dim)
+
+    def _combine_heads(self, x: Array) -> Array:
+        """Reshape from (seq_len, num_heads, head_dim) to (seq_len, embed_dim)."""
+        seq_len, _, _ = x.shape
+        return x.reshape(seq_len, -1)
+
+    def __call__(
+        self,
+        x: Array,
+        *,
+        key: PRNGKeyArray | None = None,
+        mask: Array | None = None,
+        cache: dict[str, Array] | None = None,
+        update_cache: bool = False,
+    ) -> Array | tuple[Array, dict[str, Array]]:
+        """Apply self-attention to the input.
+
+        Args:
+            x: Input tensor of shape (seq_len, embed_dim)
+            key: PRNGKey for dropout randomness
+            mask: Optional mask tensor of shape (seq_len, seq_len) or broadcastable
+            cache: Optional dictionary containing cached key and value tensors
+            update_cache: Whether to update the cache and return it
+
+        Returns:
+            If update_cache is False: Output tensor of shape (seq_len, embed_dim)
+            If update_cache is True: Tuple of (output tensor, updated cache)
+        """
+        chex.assert_rank(x, 2)
+
+        # Project inputs to queries, keys, and values
+        q = jax.vmap(self.q_proj)(x)
+
+        # Use cached key/value if provided and not updating cache
+        if cache is not None and not update_cache:
+            k = cache["k"]
+            v = cache["v"]
+        else:
+            k = jax.vmap(self.k_proj)(x)
+            v = jax.vmap(self.v_proj)(x)
+
+        # Update cache if needed
+        if update_cache:
+            if cache is None:
+                cache = {}
+            cache = {"k": k, "v": v}
+
+        # Reshape to multihead format
+        q = self._reshape_for_multihead(q)
+        k = self._reshape_for_multihead(k)
+        v = self._reshape_for_multihead(v)
+
+        # Apply dot product attention.
+        # Note that Apple Silicon struggles with this:
+        # https://github.com/jax-ml/jax/issues/20114
+        attn_output = jax.nn.dot_product_attention(
+            q,
+            k,
+            v,
+            mask=mask,
+            is_causal=self.causal and mask is None,
+        )
+
+        # Combine heads
+        attn_output = self._combine_heads(attn_output)
+
+        # Final projection
+        output = jax.vmap(self.output_proj)(attn_output)
+
+        if update_cache:
+            return output, cast(dict[str, Array], cache)
+        return output
+
+
+class CrossAttentionBlock(eqx.Module):
+    """Cross-attention block using jax.nn.dot_product_attention."""
+
+    q_proj: eqx.nn.Linear
+    k_proj: eqx.nn.Linear
+    v_proj: eqx.nn.Linear
+    output_proj: eqx.nn.Linear
+    num_heads: int = eqx.static_field()
+    head_dim: int = eqx.static_field()
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        *,
+        key: PRNGKeyArray,
+    ) -> None:
+        keys = jax.random.split(key, 4)
+
+        self.num_heads = num_heads
+        self.head_dim = embed_dim // num_heads
+        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
+
+        self.q_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[0])
+        self.k_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[1])
+        self.v_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[2])
+        self.output_proj = eqx.nn.Linear(embed_dim, embed_dim, key=keys[3])
+
+    def _reshape_for_multihead(self, x: Array) -> Array:
+        """Reshape from (seq_len, embed_dim) to (seq_len, num_heads, head_dim)."""
+        seq_len, _ = x.shape
+        return x.reshape(seq_len, self.num_heads, self.head_dim)
+
+    def _combine_heads(self, x: Array) -> Array:
+        """Reshape from (seq_len, num_heads, head_dim) to (seq_len, embed_dim)."""
+        seq_len, _, _ = x.shape
+        return x.reshape(seq_len, -1)
+
+    def __call__(
+        self,
+        q_input: Array,
+        kv_input: Array,
+        *,
+        key: PRNGKeyArray | None = None,
+        mask: Array | None = None,
+        cache: dict[str, Array] | None = None,
+        update_cache: bool = False,
+    ) -> Array | tuple[Array, dict[str, Array]]:
+        """Apply cross-attention.
+
+        Args:
+            q_input: Query input tensor of shape (q_seq_len, embed_dim)
+            kv_input: Key/value input tensor of shape (kv_seq_len, embed_dim)
+            key: PRNGKey for dropout randomness
+            mask: Optional mask tensor
+            cache: Optional dictionary containing cached key and value tensors
+            update_cache: Whether to update the cache and return it
+
+        Returns:
+            If update_cache is False: Output tensor of shape (q_seq_len, embed_dim)
+            If update_cache is True: Tuple of (output tensor, updated cache)
+        """
+        chex.assert_rank(q_input, 2)
+        chex.assert_rank(kv_input, 2)
+
+        # Project inputs to queries, keys, and values
+        q = jax.vmap(self.q_proj)(q_input)
+
+        # Use cached key/value if provided and not updating cache
+        if cache is not None and not update_cache:
+            k = cache["k"]
+            v = cache["v"]
+        else:
+            k = jax.vmap(self.k_proj)(kv_input)
+            v = jax.vmap(self.v_proj)(kv_input)
+
+        # Update cache if needed
+        if update_cache:
+            if cache is None:
+                cache = {}
+            cache = {"k": k, "v": v}
+
+        # Reshape to multihead format
+        q = self._reshape_for_multihead(q)
+        k = self._reshape_for_multihead(k)
+        v = self._reshape_for_multihead(v)
+
+        # Apply dot product attention
+        attn_output = jax.nn.dot_product_attention(
+            q,
+            k,
+            v,
+            mask=mask,
+            is_causal=False,
+        )
+
+        # Combine heads
+        attn_output = self._combine_heads(attn_output)
+
+        # Final projection
+        output = jax.vmap(self.output_proj)(attn_output)
+
+        if update_cache:
+            return output, cast(dict[str, Array], cache)
+        return output
+
+
+class TransformerBlock(eqx.Module):
+    self_attn: SelfAttentionBlock
+    cross_attn: CrossAttentionBlock | None
+    feed_forward: eqx.nn.MLP
+    layer_norm1: eqx.nn.LayerNorm
+    layer_norm2: eqx.nn.LayerNorm
+    layer_norm3: eqx.nn.LayerNorm | None
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        ff_dim: int,
+        *,
+        key: PRNGKeyArray,
+        causal: bool = False,
+        cross_attention: bool = False,
+    ) -> None:
+        keys = jax.random.split(key, 4)
+
+        self.self_attn = SelfAttentionBlock(
+            embed_dim=embed_dim,
+            num_heads=num_heads,
+            key=keys[0],
+            causal=causal,
+        )
+
+        if cross_attention:
+            self.cross_attn = CrossAttentionBlock(
+                embed_dim=embed_dim,
+                num_heads=num_heads,
+                key=keys[1],
+            )
+            self.layer_norm3 = eqx.nn.LayerNorm(embed_dim)
+        else:
+            self.cross_attn = None
+            self.layer_norm3 = None
+
+        self.layer_norm1 = eqx.nn.LayerNorm(embed_dim)
+        self.layer_norm2 = eqx.nn.LayerNorm(embed_dim)
+
+        self.feed_forward = eqx.nn.MLP(
+            in_size=embed_dim,
+            out_size=embed_dim,
+            width_size=ff_dim,
+            depth=1,
+            activation=jax.nn.gelu,
+            key=keys[2],
+        )
+
+    @overload
+    def __call__(
+        self,
+        x: Array,
+        *,
+        context: Array | None = None,
+        self_mask: Array | None = None,
+        cross_mask: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, Array]] | None = None,
+        update_cache: Literal[True],
+    ) -> tuple[Array, dict[str, dict[str, Array]]]: ...
+
+    @overload
+    def __call__(
+        self,
+        x: Array,
+        *,
+        context: Array | None = None,
+        self_mask: Array | None = None,
+        cross_mask: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, Array]] | None = None,
+        update_cache: Literal[False] = False,
+    ) -> Array: ...
+
+    def __call__(
+        self,
+        x: Array,
+        *,
+        context: Array | None = None,
+        self_mask: Array | None = None,
+        cross_mask: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, Array]] | None = None,
+        update_cache: bool = False,
+    ) -> Array | tuple[Array, dict[str, dict[str, Array]]]:
+        """Apply transformer block.
+
+        Args:
+            x: Input tensor
+            context: Optional context for cross-attention
+            self_mask: Mask for self-attention
+            cross_mask: Mask for cross-attention
+            key: Optional PRNG key for dropout
+            cache: Optional dictionary containing cached key and value tensors
+            update_cache: Whether to update the cache and return it
+
+        Returns:
+            If update_cache is False: Output tensor
+            If update_cache is True: Tuple of (output tensor, updated cache)
+        """
+        chex.assert_rank(x, 2)
+        if key is not None:
+            key1, key2 = jax.random.split(key)
+        else:
+            key1 = key2 = None
+
+        # Initialize cache if needed
+        updated_cache = {}
+        if cache is None:
+            cache = {}
+
+        # Self-attention block with pre-norm
+        norm_x = jax.vmap(self.layer_norm1)(x)
+
+        self_attn_cache = cache.get("self_attn")
+        if update_cache:
+            attn_output, self_attn_cache = self.self_attn(
+                norm_x, key=key1, mask=self_mask, cache=self_attn_cache, update_cache=True
+            )
+            updated_cache["self_attn"] = self_attn_cache
+        else:
+            attn_output = self.self_attn(norm_x, key=key1, mask=self_mask, cache=self_attn_cache)
+
+        x = x + attn_output
+
+        # Cross-attention block (if enabled) with pre-norm
+        if self.cross_attn is not None and context is not None:
+            assert self.layer_norm3 is not None
+
+            norm_x = jax.vmap(self.layer_norm3)(x)
+            cross_attn_cache = cache.get("cross_attn")
+
+            if update_cache:
+                cross_attn_output, cross_attn_cache = self.cross_attn(
+                    norm_x, context, key=key2, mask=cross_mask, cache=cross_attn_cache, update_cache=True
+                )
+                updated_cache["cross_attn"] = cross_attn_cache
+            else:
+                cross_attn_output = self.cross_attn(norm_x, context, key=key2, mask=cross_mask, cache=cross_attn_cache)
+
+            x = x + cross_attn_output
+
+        # Feed-forward block with pre-norm
+        norm_x = jax.vmap(self.layer_norm2)(x)
+        ff_output = jax.vmap(self.feed_forward)(norm_x)
+        x = x + ff_output
+
+        if update_cache:
+            return x, updated_cache
+        return x
+
+
+class Transformer(eqx.Module):
+    token_embedding: eqx.nn.Embedding
+    position_embedding: eqx.nn.Embedding | None
+    layers: list[TransformerBlock]
+    output_layer: eqx.nn.Linear | None
+    layer_norm: eqx.nn.LayerNorm
+    max_seq_len: int = eqx.static_field()
+    embed_dim: int = eqx.static_field()
+
+    def __init__(
+        self,
+        vocab_size: int,
+        embed_dim: int,
+        num_heads: int,
+        ff_dim: int,
+        num_layers: int,
+        max_seq_len: int,
+        output_size: int | None = None,
+        *,
+        key: PRNGKeyArray,
+        causal: bool = False,
+        cross_attention: bool = False,
+        use_absolute_position: bool = True,
+    ) -> None:
+        keys = jax.random.split(key, num_layers + 3)
+
+        self.token_embedding = eqx.nn.Embedding(vocab_size, embed_dim, key=keys[0])
+
+        # Position embeddings can be disabled
+        if use_absolute_position:
+            self.position_embedding = eqx.nn.Embedding(max_seq_len, embed_dim, key=keys[1])
+        else:
+            self.position_embedding = None
+
+        self.layers = [
+            TransformerBlock(
+                embed_dim=embed_dim,
+                num_heads=num_heads,
+                ff_dim=ff_dim,
+                key=keys[i + 2],
+                causal=causal,
+                cross_attention=cross_attention,
+            )
+            for i in range(num_layers)
+        ]
+
+        self.layer_norm = eqx.nn.LayerNorm(embed_dim)
+
+        if output_size is not None:
+            self.output_layer = eqx.nn.Linear(embed_dim, output_size, key=keys[-1])
+        else:
+            self.output_layer = None
+
+        self.max_seq_len = max_seq_len
+        self.embed_dim = embed_dim
+
+    def _add_positional_embedding(self, x_embedded: Array, positions: Array | None = None) -> Array:
+        """Add positional embeddings to the token embeddings."""
+        if self.position_embedding is None:
+            return x_embedded
+
+        seq_len, _ = x_embedded.shape
+
+        if positions is None:
+            positions = jnp.arange(seq_len)
+        pos_embedded = jax.vmap(self.position_embedding)(positions)
+
+        return x_embedded + pos_embedded
+
+    def encode(
+        self,
+        x: Array,
+        mask: Array | None = None,
+        positions: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+        update_cache: bool = False,
+    ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+        """Encode the input sequence.
+
+        Args:
+            x: Input token indices of shape (seq_len)
+            mask: Optional attention mask
+            positions: Optional positions
+            key: Optional PRNG key for dropout
+            cache: Optional dictionary containing cached key and value tensors
+            update_cache: Whether to update the cache and return it
+
+        Returns:
+            If update_cache is False: Encoded representation
+            If update_cache is True: Tuple of (encoded representation, updated cache)
+        """
+        # Token embedding
+        x_embedded = jax.vmap(self.token_embedding)(x)
+
+        # Add positional embedding
+        x_embedded = self._add_positional_embedding(x_embedded, positions)
+
+        # Initialize layer caches
+        if cache is None and update_cache:
+            cache = {f"layer_{i}": {} for i in range(len(self.layers))}
+
+        # Updated cache will be built if needed
+        updated_cache = {}
+
+        # Apply transformer layers
+        keys: Array | list[None] = [None] * len(self.layers)
+        if key is not None:
+            keys = jax.random.split(key, len(self.layers))
+
+        for i, (layer, layer_key) in enumerate(zip(self.layers, keys, strict=False)):
+            layer_cache = None if cache is None else cache.get(f"layer_{i}")
+
+            if update_cache:
+                x_embedded, layer_updated_cache = layer.__call__(
+                    x_embedded,
+                    self_mask=mask,
+                    key=layer_key,
+                    cache=layer_cache,
+                    update_cache=True,
+                )
+                updated_cache[f"layer_{i}"] = layer_updated_cache
+            else:
+                x_embedded = layer.__call__(
+                    x_embedded,
+                    self_mask=mask,
+                    key=layer_key,
+                    cache=layer_cache,
+                )
+
+        # Apply final layer norm
+        output = jax.vmap(self.layer_norm)(x_embedded)
+
+        if update_cache:
+            return output, updated_cache
+        return output
+
+    def decode(
+        self,
+        x: Array,
+        context: Array,
+        self_mask: Array | None = None,
+        cross_mask: Array | None = None,
+        positions: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+        update_cache: bool = False,
+    ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+        """Decode with self-attention and cross-attention.
+
+        Args:
+            x: Input token indices
+            context: Context from encoder
+            self_mask: Optional self-attention mask
+            cross_mask: Optional cross-attention mask
+            positions: Optional positions
+            key: Optional PRNG key for dropout
+            cache: Optional dictionary containing cached key and value tensors
+            update_cache: Whether to update the cache and return it
+
+        Returns:
+            If update_cache is False: Decoded representation
+            If update_cache is True: Tuple of (decoded representation, updated cache)
+        """
+        # Token embedding
+        x_embedded = jax.vmap(lambda x_seq: jax.vmap(self.token_embedding)(x_seq))(x)
+
+        # Add positional embedding
+        x_embedded = self._add_positional_embedding(x_embedded, positions)
+
+        # Initialize layer caches
+        if cache is None and update_cache:
+            cache = {f"layer_{i}": {} for i in range(len(self.layers))}
+
+        # Updated cache will be built if needed
+        updated_cache = {}
+
+        # Apply transformer layers with cross-attention
+        keys: Array | list[None] = [None] * len(self.layers)
+        if key is not None:
+            keys = jax.random.split(key, len(self.layers))
+
+        for i, (layer, layer_key) in enumerate(zip(self.layers, keys, strict=False)):
+            layer_cache = None if cache is None else cache.get(f"layer_{i}")
+
+            if update_cache:
+                x_embedded, layer_updated_cache = layer.__call__(
+                    x_embedded,
+                    context=context,
+                    self_mask=self_mask,
+                    cross_mask=cross_mask,
+                    key=layer_key,
+                    cache=layer_cache,
+                    update_cache=True,
+                )
+                updated_cache[f"layer_{i}"] = layer_updated_cache
+            else:
+                x_embedded = layer(
+                    x_embedded,
+                    context=context,
+                    self_mask=self_mask,
+                    cross_mask=cross_mask,
+                    key=layer_key,
+                    cache=layer_cache,
+                )
+
+        # Apply final layer norm
+        output = jax.vmap(self.layer_norm)(x_embedded)
+
+        if update_cache:
+            return output, updated_cache
+        return output
+
+    @overload
+    def __call__(
+        self,
+        x: Array,
+        *,
+        mask: Array | None = None,
+        positions: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+        update_cache: Literal[True],
+    ) -> tuple[Array, dict[str, dict[str, dict[str, Array]]]]: ...
+
+    @overload
+    def __call__(
+        self,
+        x: Array,
+        *,
+        mask: Array | None = None,
+        positions: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+        update_cache: Literal[False] = False,
+    ) -> Array: ...
+
+    def __call__(
+        self,
+        x: Array,
+        *,
+        mask: Array | None = None,
+        positions: Array | None = None,
+        key: PRNGKeyArray | None = None,
+        cache: dict[str, dict[str, dict[str, Array]]] | None = None,
+        update_cache: bool = False,
+    ) -> Array | tuple[Array, dict[str, dict[str, dict[str, Array]]]]:
+        """Forward pass for encoder-only or decoder-only transformers.
+
+        Args:
+            x: Input token indices of shape (seq_len)
+            mask: Optional attention mask
+            positions: Optional positions
+            key: Optional PRNG key for dropout
+            cache: Optional dictionary containing cached key and value tensors
+            update_cache: Whether to update the cache and return it
+
+        Returns:
+            If update_cache is False: Output representation
+            If update_cache is True: Tuple of (output representation, updated cache)
+        """
+        chex.assert_rank(x, 1)
+
+        if update_cache:
+            output, updated_cache = self.encode(
+                x, mask=mask, positions=positions, key=key, cache=cache, update_cache=True
+            )
+        else:
+            output = self.encode(x, mask=mask, positions=positions, key=key, cache=cache)
+
+        # Apply output layer if it exists
+        if self.output_layer is not None:
+            output = jax.vmap(self.output_layer)(output)
+
+        if update_cache:
+            return output, updated_cache
+        return output
+
+    def predict_sequence(self, x_seq: Array) -> Array:
+        return self(x=x_seq)
+
+    def generate_sequence(
+        self,
+        prompt_seq: Array,
+        max_len: int,
+        temperature: float = 1.0,
+        top_k: int | None = None,
+        key: PRNGKeyArray | None = None,
+    ) -> Array:
+        """Generate a sequence autoregressively with KV caching.
+
+        Args:
+            prompt_seq: Input token indices of shape (prompt_len,)
+            max_len: Maximum length of generated sequence
+            temperature: Sampling temperature
+            top_k: Optional top-k sampling parameter
+            key: PRNG key for sampling
+
+        Returns:
+            Generated sequence of shape (prompt_len + max_len,)
+        """
+        if key is None:
+            key = jax.random.PRNGKey(0)
+
+        prompt_len = prompt_seq.shape[0]
+        sequence = prompt_seq
+
+        # Create causal mask for generation
+        causal_mask = jnp.tril(jnp.ones((self.max_seq_len, self.max_seq_len), dtype=jnp.bool_))
+
+        # Initialize cache with the prompt
+        _, cache = self(x=prompt_seq, mask=causal_mask[:prompt_len, :prompt_len], update_cache=True)
+
+        # Define decode step function (for clarity)
+        def decode_step(seq: Array, pos: int, cur_cache: dict, rng: PRNGKeyArray) -> tuple[Array, dict, PRNGKeyArray]:
+            # Get the next position and last token
+            pos_tensor = jnp.array([pos])
+            last_token = seq[-1:]
+
+            # Get logits for next token
+            rng, subrng = jax.random.split(rng)
+            logits, new_cache = self(
+                x=last_token,
+                positions=pos_tensor,
+                key=subrng,
+                cache=cur_cache,
+                update_cache=True,
+            )
+
+            # Extract final logits and apply temperature
+            logits = logits[-1] / temperature
+
+            # Apply top-k sampling if specified
+            if top_k is not None:
+                top_logits, top_indices = jax.lax.top_k(logits, top_k)
+                logits = jnp.full_like(logits, float("-inf"))
+                logits = logits.at[top_indices].set(top_logits)
+
+            # Sample next token
+            rng, subrng = jax.random.split(rng)
+            next_token = jax.random.categorical(subrng, logits[None, ...])[0]
+
+            # Add token to sequence
+            new_seq = jnp.concatenate([seq, next_token[None]], axis=0)
+            return new_seq, new_cache, rng
+
+        # Generate tokens one by one
+        for _ in range(max_len):
+            # Break if max sequence length reached
+            if sequence.shape[0] >= self.max_seq_len:
+                break
+
+            # Decode next token
+            sequence, cache, key = decode_step(seq=sequence, pos=sequence.shape[0] - 1, cur_cache=cache, rng=key)
+
+        return sequence
xax/task/logger.py CHANGED
@@ -526,7 +526,7 @@ class LoggerImpl(ABC):
         Returns:
             If the logger should log the current step.
         """
-        elapsed_time = state.elapsed_time_s.item()
+        elapsed_time = state.elapsed_time_s.item()
         return self.tickers[state.phase].tick(elapsed_time)
 
 
xax/task/mixins/train.py CHANGED
@@ -121,7 +121,7 @@ class ValidStepTimer:
             self.last_valid_step = state.num_steps.item()
 
     def __call__(self, state: State) -> bool:
-        if state.num_steps < self.valid_first_n_steps
+        if state.num_steps < self.valid_first_n_steps:
             return True
 
         if self.last_valid_time is None or self.last_valid_step is None:
@@ -130,18 +130,15 @@ class ValidStepTimer:
 
         # Step-based validation.
         valid_every_n_steps = self.valid_every_n_steps
-        if valid_every_n_steps is not None and (
-            state.num_steps >= valid_every_n_steps + self.last_valid_step
-            or state.num_valid_steps >= valid_every_n_steps + self.last_valid_step
-        ):
+        if valid_every_n_steps is not None and state.num_steps >= valid_every_n_steps + self.last_valid_step:
             self._reset(state)
             return True
 
         # Time-based validation.
         valid_every_n_seconds = self.valid_every_n_seconds
-        if
-
-
+        if (
+            valid_every_n_seconds is not None
+            and state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
         ):
             self._reset(state)
             return True
@@ -149,10 +146,7 @@ class ValidStepTimer:
         # Time-based validation for first validation step.
         if self.first_valid_step_flag:
             valid_first_n_seconds = self.valid_first_n_seconds
-            if valid_first_n_seconds is not None and (
-                state.elapsed_time_s.item() >= valid_first_n_seconds
-                or state.valid_elapsed_time_s.item() >= valid_first_n_seconds
-            ):
+            if valid_first_n_seconds is not None and state.elapsed_time_s.item() >= valid_first_n_seconds:
                 self._reset(state)
                 self.first_valid_step_flag = False
                 return True
@@ -777,12 +771,12 @@ class TrainMixin(
             self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
 
             state = state.replace(
-
-
+                num_steps=state.num_steps + 1,
+                num_samples=state.num_samples + (self.get_size_of_batch(valid_batch) or 0),
             )
 
         state = state.replace(
-
+            elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
        )
 
        with ContextTimer() as timer:
@@ -882,7 +876,7 @@ class TrainMixin(
        key, model_key = jax.random.split(key)
        models, optimizers, opt_states, state = self.load_initial_state(model_key, load_optimizer=True)
        logger.info("Model size: %s", f"{get_pytree_param_count(models):,}")
-        logger.info("Optimizer size: %s", f"{get_pytree_param_count(
+        logger.info("Optimizer size: %s", f"{get_pytree_param_count(opt_states):,}")
 
        state = self.on_training_start(state)
 
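
With the valid-specific counters gone from State, the validation trigger reduces to the shared step and time counters. A condensed paraphrase of the new ValidStepTimer.__call__ logic, using only the attribute names visible in the diff; the _reset bookkeeping and the first-validation-by-time branch are elided:

    def should_validate(timer: "ValidStepTimer", state: State) -> bool:
        # Always validate during the first few training steps.
        if state.num_steps < timer.valid_first_n_steps:
            return True

        # Step-based: enough training steps since the last validation pass.
        if (
            timer.valid_every_n_steps is not None
            and state.num_steps >= timer.valid_every_n_steps + timer.last_valid_step
        ):
            return True

        # Time-based: enough elapsed training time since the last validation pass.
        if (
            timer.valid_every_n_seconds is not None
            and state.elapsed_time_s.item() - timer.last_valid_time >= timer.valid_every_n_seconds
        ):
            return True

        return False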
xax/utils/experiments.py CHANGED
@@ -111,8 +111,8 @@ class StateTimer:
 
     def step(self, state: State) -> None:
         cur_time = time.time()
-        num_steps = int(
-        num_samples = int(
+        num_steps = int(state.num_steps.item())
+        num_samples = int(state.num_samples.item())
         self.step_timer.step(num_steps, cur_time)
         self.sample_timer.step(num_samples, cur_time)
         self.iter_timer.step(cur_time)
{xax-0.2.22.dist-info → xax-0.2.23.dist-info}/RECORD CHANGED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=Q0boKxPtEUiiJ9j7Cdx51bLLFtYx3fPfCTG-o8o2Chk,15653
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -6,8 +6,9 @@ xax/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/cli/edit_config.py,sha256=LQUIlOS6hvPZyVEaMme3FP-62M0BKQPYavCwVDWuBLw,2600
 xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
-xax/core/state.py,sha256=
+xax/core/state.py,sha256=F9Tj3FfCw8zFKaDEoEGiThZE2ntYEtzNjnBX3pQ1g60,3826
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+xax/nn/attention.py,sha256=0essK90OO3x9FxnUqU0DhufwXKRMN41zMtRCki5iKzQ,24742
 xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
 xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
 xax/nn/geom.py,sha256=A7WPefMvgwUNReZC7_HX1GmvHPASyghbaXaKsuhwDrE,7382
@@ -17,7 +18,7 @@ xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
 xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=TYANmjNcce4_V5ZSYLnE91PXRn7Nn0nT7hN8plW_Au0,8117
-xax/task/logger.py,sha256=
+xax/task/logger.py,sha256=Bmhl4mv08Aq49ZyX6BdjPIsPJK28e8s3mVFatM4IY2Q,41060
 xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
 xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
 xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -41,10 +42,10 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/train.py,sha256=TZatz5QwTfrNhQTiO2IqrmQY9P4Lay6FAD2VsQpWa54,33245
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
-xax/utils/experiments.py,sha256=
+xax/utils/experiments.py,sha256=5k5hPYSaVjzoR_nm2Q3DAHMMYi3Bcp3N3PAQbwZq7Gg,29830
 xax/utils/jax.py,sha256=6cP95-rcjkRt1fefkZWJQhJhH0uUYWJB3w4NP1-aDp0,10136
 xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
 xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
@@ -58,9 +59,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.2.
-xax-0.2.
-xax-0.2.
-xax-0.2.
-xax-0.2.
-xax-0.2.
+xax-0.2.23.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.2.23.dist-info/METADATA,sha256=mA98vsIjdfb8XM2mN1vUb2VRVEPU4xf10IWLxxFJjmY,1247
+xax-0.2.23.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+xax-0.2.23.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
+xax-0.2.23.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.2.23.dist-info/RECORD,,
{xax-0.2.22.dist-info → xax-0.2.23.dist-info}/entry_points.txt
File without changes
{xax-0.2.22.dist-info → xax-0.2.23.dist-info}/licenses/LICENSE
File without changes
{xax-0.2.22.dist-info → xax-0.2.23.dist-info}/top_level.txt
File without changes