xax 0.2.21__tar.gz → 0.2.23__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.2.21/xax.egg-info → xax-0.2.23}/PKG-INFO +1 -1
- {xax-0.2.21 → xax-0.2.23}/pyproject.toml +1 -0
- {xax-0.2.21 → xax-0.2.23}/xax/__init__.py +15 -2
- {xax-0.2.21 → xax-0.2.23}/xax/core/state.py +10 -37
- xax-0.2.23/xax/nn/attention.py +738 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/logger.py +1 -1
- {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/train.py +17 -19
- {xax-0.2.21 → xax-0.2.23}/xax/utils/experiments.py +2 -2
- {xax-0.2.21 → xax-0.2.23}/xax/utils/jax.py +109 -7
- {xax-0.2.21 → xax-0.2.23/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.2.21 → xax-0.2.23}/xax.egg-info/SOURCES.txt +1 -0
- {xax-0.2.21 → xax-0.2.23}/LICENSE +0 -0
- {xax-0.2.21 → xax-0.2.23}/MANIFEST.in +0 -0
- {xax-0.2.21 → xax-0.2.23}/README.md +0 -0
- {xax-0.2.21 → xax-0.2.23}/setup.cfg +0 -0
- {xax-0.2.21 → xax-0.2.23}/setup.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/cli/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/cli/edit_config.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/core/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/core/conf.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/nn/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/nn/embeddings.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/nn/functions.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/nn/geom.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/nn/losses.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/nn/metrics.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/nn/parallel.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/nn/ssm.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/py.typed +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/requirements-dev.txt +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/requirements.txt +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/base.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/launchers/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/launchers/base.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/launchers/cli.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/launchers/single_process.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/loggers/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/loggers/callback.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/loggers/json.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/loggers/state.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/loggers/stdout.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/compile.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/logger.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/process.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/runnable.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/script.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/task/task.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/utils/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/utils/data/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/utils/data/collate.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/utils/debugging.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/utils/jaxpr.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/utils/logging.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/utils/numpy.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/utils/profile.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/utils/pytree.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/utils/tensorboard.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/utils/text.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/utils/types/__init__.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax.egg-info/requires.txt +0 -0
- {xax-0.2.21 → xax-0.2.23}/xax.egg-info/top_level.txt +0 -0
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.2.21"
|
15
|
+
__version__ = "0.2.23"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -23,6 +23,10 @@ __all__ = [
|
|
23
23
|
"get_run_dir",
|
24
24
|
"load_user_config",
|
25
25
|
"State",
|
26
|
+
"CrossAttentionBlock",
|
27
|
+
"SelfAttentionBlock",
|
28
|
+
"Transformer",
|
29
|
+
"TransformerBlock",
|
26
30
|
"FourierEmbeddings",
|
27
31
|
"IdentityPositionalEmbeddings",
|
28
32
|
"LearnedPositionalEmbeddings",
|
@@ -112,8 +116,10 @@ __all__ = [
|
|
112
116
|
"save_config",
|
113
117
|
"stage_environment",
|
114
118
|
"to_markdown_table",
|
119
|
+
"grad",
|
115
120
|
"jit",
|
116
121
|
"scan",
|
122
|
+
"vmap",
|
117
123
|
"save_jaxpr_dot",
|
118
124
|
"ColoredFormatter",
|
119
125
|
"configure_logging",
|
@@ -198,6 +204,10 @@ NAME_MAP: dict[str, str] = {
|
|
198
204
|
"get_run_dir": "core.conf",
|
199
205
|
"load_user_config": "core.conf",
|
200
206
|
"State": "core.state",
|
207
|
+
"CrossAttentionBlock": "nn.attention",
|
208
|
+
"SelfAttentionBlock": "nn.attention",
|
209
|
+
"Transformer": "nn.attention",
|
210
|
+
"TransformerBlock": "nn.attention",
|
201
211
|
"FourierEmbeddings": "nn.embeddings",
|
202
212
|
"IdentityPositionalEmbeddings": "nn.embeddings",
|
203
213
|
"LearnedPositionalEmbeddings": "nn.embeddings",
|
@@ -287,8 +297,10 @@ NAME_MAP: dict[str, str] = {
|
|
287
297
|
"save_config": "utils.experiments",
|
288
298
|
"stage_environment": "utils.experiments",
|
289
299
|
"to_markdown_table": "utils.experiments",
|
300
|
+
"grad": "utils.jax",
|
290
301
|
"jit": "utils.jax",
|
291
302
|
"scan": "utils.jax",
|
303
|
+
"vmap": "utils.jax",
|
292
304
|
"save_jaxpr_dot": "utils.jaxpr",
|
293
305
|
"ColoredFormatter": "utils.logging",
|
294
306
|
"configure_logging": "utils.logging",
|
@@ -366,6 +378,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
366
378
|
load_user_config,
|
367
379
|
)
|
368
380
|
from xax.core.state import Phase, State
|
381
|
+
from xax.nn.attention import CrossAttentionBlock, SelfAttentionBlock, Transformer, TransformerBlock
|
369
382
|
from xax.nn.embeddings import (
|
370
383
|
EmbeddingKind,
|
371
384
|
FourierEmbeddings,
|
@@ -460,7 +473,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
460
473
|
stage_environment,
|
461
474
|
to_markdown_table,
|
462
475
|
)
|
463
|
-
from xax.utils.jax import jit, scan
|
476
|
+
from xax.utils.jax import grad, jit, scan, vmap
|
464
477
|
from xax.utils.jaxpr import save_jaxpr_dot
|
465
478
|
from xax.utils.logging import (
|
466
479
|
LOG_ERROR_SUMMARY,
|
@@ -27,11 +27,8 @@ def _int_to_phase(i: int) -> Phase:
|
|
27
27
|
class StateDict(TypedDict, total=False):
|
28
28
|
num_steps: NotRequired[int | Array]
|
29
29
|
num_samples: NotRequired[int | Array]
|
30
|
-
num_valid_steps: NotRequired[int | Array]
|
31
|
-
num_valid_samples: NotRequired[int | Array]
|
32
30
|
start_time_s: NotRequired[float | Array]
|
33
31
|
elapsed_time_s: NotRequired[float | Array]
|
34
|
-
valid_elapsed_time_s: NotRequired[float | Array]
|
35
32
|
phase: NotRequired[Phase]
|
36
33
|
_phase: NotRequired[int | Array]
|
37
34
|
|
@@ -47,38 +44,26 @@ class State:
|
|
47
44
|
return self._int32_arr[0]
|
48
45
|
|
49
46
|
@property
|
50
|
-
def num_valid_steps(self) -> Array:
|
51
|
-
return self._int32_arr[1]
|
47
|
+
def phase(self) -> Phase:
|
48
|
+
return _int_to_phase(self._int32_arr[1].item())
|
52
49
|
|
53
50
|
@property
|
54
51
|
def num_samples(self) -> Array:
|
55
52
|
return self._float32_arr[0]
|
56
53
|
|
57
|
-
@property
|
58
|
-
def num_valid_samples(self) -> Array:
|
59
|
-
return self._float32_arr[1]
|
60
|
-
|
61
54
|
@property
|
62
55
|
def start_time_s(self) -> Array:
|
63
|
-
return self._float32_arr[2]
|
56
|
+
return self._float32_arr[1]
|
64
57
|
|
65
58
|
@property
|
66
59
|
def elapsed_time_s(self) -> Array:
|
67
|
-
return self._float32_arr[3]
|
68
|
-
|
69
|
-
@property
|
70
|
-
def valid_elapsed_time_s(self) -> Array:
|
71
|
-
return self._float32_arr[4]
|
72
|
-
|
73
|
-
@property
|
74
|
-
def phase(self) -> Phase:
|
75
|
-
return _int_to_phase(self._int32_arr[2].item())
|
60
|
+
return self._float32_arr[2]
|
76
61
|
|
77
62
|
@classmethod
|
78
63
|
def init_state(cls) -> "State":
|
79
64
|
return cls(
|
80
|
-
_int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
|
81
|
-
_float32_arr=jnp.array([0.0, 0.0, time.time(), 0.0, 0.0], dtype=jnp.float32),
|
65
|
+
_int32_arr=jnp.array([0, 0], dtype=jnp.int32),
|
66
|
+
_float32_arr=jnp.array([0.0, time.time(), 0.0], dtype=jnp.float32),
|
82
67
|
)
|
83
68
|
|
84
69
|
@property
|
@@ -91,25 +76,19 @@ class State:
|
|
91
76
|
|
92
77
|
if "num_steps" in kwargs:
|
93
78
|
int32_arr = int32_arr.at[0].set(kwargs["num_steps"])
|
94
|
-
if "num_valid_steps" in kwargs:
|
95
|
-
int32_arr = int32_arr.at[1].set(kwargs["num_valid_steps"])
|
96
79
|
|
97
80
|
if "phase" in kwargs:
|
98
|
-
int32_arr = int32_arr.at[2].set(_phase_to_int(kwargs["phase"]))
|
81
|
+
int32_arr = int32_arr.at[1].set(_phase_to_int(kwargs["phase"]))
|
99
82
|
if "_phase" in kwargs:
|
100
|
-
int32_arr = int32_arr.at[2].set(kwargs["_phase"])
|
83
|
+
int32_arr = int32_arr.at[1].set(kwargs["_phase"])
|
101
84
|
|
102
85
|
if "num_samples" in kwargs:
|
103
86
|
float32_arr = float32_arr.at[0].set(kwargs["num_samples"])
|
104
|
-
if "num_valid_samples" in kwargs:
|
105
|
-
float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])
|
106
87
|
|
107
88
|
if "start_time_s" in kwargs:
|
108
|
-
float32_arr = float32_arr.at[2].set(kwargs["start_time_s"])
|
89
|
+
float32_arr = float32_arr.at[1].set(kwargs["start_time_s"])
|
109
90
|
if "elapsed_time_s" in kwargs:
|
110
|
-
float32_arr = float32_arr.at[3].set(kwargs["elapsed_time_s"])
|
111
|
-
if "valid_elapsed_time_s" in kwargs:
|
112
|
-
float32_arr = float32_arr.at[4].set(kwargs["valid_elapsed_time_s"])
|
91
|
+
float32_arr = float32_arr.at[2].set(kwargs["elapsed_time_s"])
|
113
92
|
|
114
93
|
return State(
|
115
94
|
_int32_arr=int32_arr,
|
@@ -119,12 +98,9 @@ class State:
|
|
119
98
|
def to_dict(self) -> dict[str, int | float | str]:
|
120
99
|
return {
|
121
100
|
"num_steps": int(self.num_steps.item()),
|
122
|
-
"num_valid_steps": int(self.num_valid_steps.item()),
|
123
101
|
"num_samples": int(self.num_samples.item()),
|
124
|
-
"num_valid_samples": int(self.num_valid_samples.item()),
|
125
102
|
"start_time_s": float(self.start_time_s.item()),
|
126
103
|
"elapsed_time_s": float(self.elapsed_time_s.item()),
|
127
|
-
"valid_elapsed_time_s": float(self.valid_elapsed_time_s.item()),
|
128
104
|
"phase": str(self.phase),
|
129
105
|
}
|
130
106
|
|
@@ -136,7 +112,6 @@ class State:
|
|
136
112
|
int32_arr = jnp.array(
|
137
113
|
[
|
138
114
|
d.get("num_steps", 0),
|
139
|
-
d.get("num_valid_steps", 0),
|
140
115
|
d.get("_phase", 0),
|
141
116
|
],
|
142
117
|
dtype=jnp.int32,
|
@@ -145,10 +120,8 @@ class State:
|
|
145
120
|
float32_arr = jnp.array(
|
146
121
|
[
|
147
122
|
d.get("num_samples", 0),
|
148
|
-
d.get("num_valid_samples", 0),
|
149
123
|
d.get("start_time_s", time.time()),
|
150
124
|
d.get("elapsed_time_s", 0.0),
|
151
|
-
d.get("valid_elapsed_time_s", 0.0),
|
152
125
|
],
|
153
126
|
dtype=jnp.float32,
|
154
127
|
)
|