xax 0.2.2__py3-none-any.whl → 0.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +4 -2
- xax/core/state.py +95 -47
- xax/task/logger.py +1 -1
- xax/task/loggers/stdout.py +3 -3
- xax/task/loggers/tensorboard.py +10 -8
- xax/task/mixins/checkpointing.py +4 -4
- xax/task/mixins/train.py +42 -40
- xax/utils/experiments.py +4 -2
- xax/utils/jax.py +69 -0
- {xax-0.2.2.dist-info → xax-0.2.4.dist-info}/METADATA +1 -1
- {xax-0.2.2.dist-info → xax-0.2.4.dist-info}/RECORD +14 -14
- {xax-0.2.2.dist-info → xax-0.2.4.dist-info}/WHEEL +0 -0
- {xax-0.2.2.dist-info → xax-0.2.4.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.2.dist-info → xax-0.2.4.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED

@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.2.2"
+__version__ = "0.2.4"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -104,6 +104,7 @@ __all__ = [
     "stage_environment",
     "to_markdown_table",
     "jit",
+    "scan",
     "save_jaxpr_dot",
     "ColoredFormatter",
     "configure_logging",
@@ -267,6 +268,7 @@ NAME_MAP: dict[str, str] = {
     "stage_environment": "utils.experiments",
     "to_markdown_table": "utils.experiments",
     "jit": "utils.jax",
+    "scan": "utils.jax",
     "save_jaxpr_dot": "utils.jaxpr",
     "ColoredFormatter": "utils.logging",
     "configure_logging": "utils.logging",
@@ -422,7 +424,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         stage_environment,
         to_markdown_table,
     )
-    from xax.utils.jax import jit
+    from xax.utils.jax import jit, scan
    from xax.utils.jaxpr import save_jaxpr_dot
    from xax.utils.logging import (
        LOG_ERROR_SUMMARY,
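For downstream code, the upshot of these `__init__.py` changes is that `scan` now sits alongside `jit` in the lazily resolved top-level namespace. A minimal usage sketch, assuming the top-level name resolves through NAME_MAP as the diff suggests (the running-sum body is illustrative, not from the package):

import jax.numpy as jnp

import xax  # "scan" resolves to xax.utils.jax.scan


def running_sum(carry: jnp.ndarray, x: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    # The carry accumulates the total; each partial sum is also emitted.
    carry = carry + x
    return carry, carry


# With jitting enabled this behaves exactly like jax.lax.scan.
total, partials = xax.scan(running_sum, jnp.asarray(0.0), jnp.arange(4.0))
print(total, partials)  # 6.0 [0. 1. 3. 6.]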
xax/core/state.py
CHANGED

@@ -1,10 +1,11 @@
 """Defines a dataclass for keeping track of the current training state."""
 
-import time
-from dataclasses import dataclass
-from typing import Any, Literal, NotRequired, TypedDict, Unpack, cast
+from dataclasses import dataclass
+from typing import Literal, NotRequired, TypedDict, Unpack, cast
 
 import jax
+import jax.numpy as jnp
+from jaxtyping import Array
 from omegaconf import MISSING
 
 from xax.core.conf import field
@@ -17,85 +18,132 @@ def _phase_to_int(phase: Phase) -> int:
 
 
 def _int_to_phase(i: int) -> Phase:
+    if i < 0 or i > 1:
+        raise ValueError(f"Invalid phase: {i}")
     return cast(Phase, ["train", "valid"][i])
 
 
 class StateDict(TypedDict, total=False):
-    num_steps: NotRequired[int]
-    num_samples: NotRequired[int]
-    num_valid_steps: NotRequired[int]
-    num_valid_samples: NotRequired[int]
-    start_time_s: NotRequired[float]
-    elapsed_time_s: NotRequired[float]
+    num_steps: NotRequired[int | Array]
+    num_samples: NotRequired[int | Array]
+    num_valid_steps: NotRequired[int | Array]
+    num_valid_samples: NotRequired[int | Array]
+    elapsed_time_s: NotRequired[float | Array]
+    valid_elapsed_time_s: NotRequired[float | Array]
     phase: NotRequired[Phase]
+    _phase: NotRequired[int | Array]
 
 
 @jax.tree_util.register_dataclass
 @dataclass(frozen=True, kw_only=True)
 class State:
-    num_steps: int = field(MISSING, help="…")
-    num_samples: int = field(MISSING, help="…")
-    num_valid_steps: int = field(MISSING, help="…")
-    num_valid_samples: int = field(MISSING, help="…")
-    start_time_s: float = field(MISSING, help="…")
-    elapsed_time_s: float = field(MISSING, help="…")
-    _phase: int = field(MISSING, help="…")
+    _int32_arr: Array = field(MISSING, help="Internal array for storing int64 values")
+    _float32_arr: Array = field(MISSING, help="Internal array for storing floating-point values")
+
+    @property
+    def num_steps(self) -> Array:
+        return self._int32_arr[0]
+
+    @property
+    def num_valid_steps(self) -> Array:
+        return self._int32_arr[1]
+
+    @property
+    def num_samples(self) -> Array:
+        return self._float32_arr[0]
+
+    @property
+    def num_valid_samples(self) -> Array:
+        return self._float32_arr[1]
+
+    @property
+    def elapsed_time_s(self) -> Array:
+        return self._float32_arr[2]
+
+    @property
+    def valid_elapsed_time_s(self) -> Array:
+        return self._float32_arr[3]
 
     @property
     def phase(self) -> Phase:
-        return _int_to_phase(self._phase)
+        return _int_to_phase(self._int32_arr[2].item())
 
     @classmethod
     def init_state(cls) -> "State":
         return cls(
-            num_steps=0,
-            num_samples=0,
-            num_valid_steps=0,
-            num_valid_samples=0,
-            start_time_s=time.time(),
-            elapsed_time_s=0.0,
-            _phase=0,
+            _int32_arr=jnp.array([0, 0, 0], dtype=jnp.int32),
+            _float32_arr=jnp.array([0.0, 0.0, 0.0, 0.0], dtype=jnp.float32),
         )
 
     @property
     def training(self) -> bool:
         return self.phase == "train"
 
-    def num_phase_steps(self, phase: Phase) -> int:
-        match phase:
-            case "train":
-                return self.num_steps
-            case "valid":
-                return self.num_valid_steps
-            case _:
-                raise ValueError(f"Invalid phase: {phase}")
-
     def replace(self, **kwargs: Unpack[StateDict]) -> "State":
-        …
+        int32_arr = self._int32_arr
+        float32_arr = self._float32_arr
+
+        if "num_steps" in kwargs:
+            int32_arr = int32_arr.at[0].set(kwargs["num_steps"])
+        if "num_valid_steps" in kwargs:
+            int32_arr = int32_arr.at[1].set(kwargs["num_valid_steps"])
+
         if "phase" in kwargs:
-            …
+            int32_arr = int32_arr.at[3].set(_phase_to_int(kwargs["phase"]))
+        if "_phase" in kwargs:
+            int32_arr = int32_arr.at[3].set(kwargs["_phase"])
+
+        if "num_samples" in kwargs:
+            float32_arr = float32_arr.at[0].set(kwargs["num_samples"])
+        if "num_valid_samples" in kwargs:
+            float32_arr = float32_arr.at[1].set(kwargs["num_valid_samples"])
+
+        if "elapsed_time_s" in kwargs:
+            float32_arr = float32_arr.at[2].set(kwargs["elapsed_time_s"])
+        if "valid_elapsed_time_s" in kwargs:
+            float32_arr = float32_arr.at[3].set(kwargs["valid_elapsed_time_s"])
+
+        return State(
+            _int32_arr=int32_arr,
+            _float32_arr=float32_arr,
+        )
 
     def to_dict(self) -> dict[str, int | float | str]:
         return {
             "num_steps": int(self.num_steps),
-            "num_samples": int(self.num_samples),
             "num_valid_steps": int(self.num_valid_steps),
+            "num_samples": int(self.num_samples),
             "num_valid_samples": int(self.num_valid_samples),
-            "start_time_s": float(self.start_time_s),
             "elapsed_time_s": float(self.elapsed_time_s),
+            "valid_elapsed_time_s": float(self.valid_elapsed_time_s),
             "phase": str(self.phase),
         }
 
     @classmethod
-    def from_dict(cls, d: …
+    def from_dict(cls, **d: Unpack[StateDict]) -> "State":
         if "phase" in d:
             d["_phase"] = _phase_to_int(cast(Phase, d.pop("phase")))
-        …
+
+        int32_arr = jnp.array(
+            [
+                d.get("num_steps", 0),
+                d.get("num_valid_steps", 0),
+                d.get("_phase", 0),
+            ],
+            dtype=jnp.int32,
+        )
+
+        float32_arr = jnp.array(
+            [
+                d.get("num_samples", 0),
+                d.get("num_valid_samples", 0),
+                d.get("elapsed_time_s", 0.0),
+                d.get("valid_elapsed_time_s", 0.0),
+            ],
+            dtype=jnp.float32,
+        )
+
+        return cls(
+            _int32_arr=int32_arr,
+            _float32_arr=float32_arr,
+        )
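Taken together, `State` is no longer a bag of Python scalars but a frozen dataclass over two device arrays, registered as a JAX pytree so it can cross `jax.jit` boundaries. Based on `from_dict`, the int32 array packs `[num_steps, num_valid_steps, _phase]` and the float32 array packs `[num_samples, num_valid_samples, elapsed_time_s, valid_elapsed_time_s]`. A sketch of the intended round trip, based only on the methods shown above:

from xax.core.state import State

state = State.init_state()

# Updates are functional: `replace` scatters new values into copies of the
# backing arrays and returns a new frozen State.
state = state.replace(num_steps=state.num_steps + 1, num_samples=32.0)

# Reads are array slices; `.item()` pulls a Python scalar back to the host.
print(int(state.num_steps.item()))  # 1

# Serialization round-trips through plain Python types, e.g. for checkpoints.
assert State.from_dict(**state.to_dict()).to_dict() == state.to_dict()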
xax/task/logger.py
CHANGED

@@ -521,7 +521,7 @@ class LoggerImpl(ABC):
         Returns:
             If the logger should log the current step.
         """
-        return self.tickers[state.phase].tick(state.elapsed_time_s)
+        return self.tickers[state.phase].tick(state.elapsed_time_s.item())
 
 
 class ToastHandler(logging.Handler):
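This one-line change is the general pattern of the release: the state counters are now 0-d JAX arrays, so anything that feeds a host-side API first converts with `.item()`, as in the loggers, checkpointing, and train mixins below. A tiny illustration (the values are made up):

import datetime

import jax.numpy as jnp

elapsed = jnp.asarray(12.5)  # stand-in for the array-valued state.elapsed_time_s

# Array arithmetic and comparisons still work as before...
print(bool(elapsed > 10.0))  # True

# ...but host-side consumers such as datetime expect a plain Python scalar.
print(datetime.timedelta(seconds=elapsed.item()))  # 0:00:12.500000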
xax/task/loggers/stdout.py
CHANGED

@@ -90,9 +90,9 @@ class StdoutLogger(LoggerImpl):
 
     def write_state_window(self, line: LogLine) -> None:
         state_info: dict[str, str] = {
-            "Steps": format_number(line.state.num_steps, 0),
-            "Samples": format_number(line.state.num_samples, 0),
-            "Elapsed Time": format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s), short=True),
+            "Steps": format_number(int(line.state.num_steps.item()), 0),
+            "Samples": format_number(int(line.state.num_samples.item()), 0),
+            "Elapsed Time": format_timedelta(datetime.timedelta(seconds=line.state.elapsed_time_s.item()), short=True),
         }
 
         colored_prefix = colored("Phase: ", "grey", bold=True)
xax/task/loggers/tensorboard.py
CHANGED

@@ -155,14 +155,16 @@ class TensorboardLogger(LoggerImpl):
             return
 
         writer = self.get_writer(line.state.phase)
-        walltime = line.state.elapsed_time_s
+
+        global_step = line.state.num_steps.item()
+        walltime = line.state.elapsed_time_s.item()
 
         for namespace, scalars in line.scalars.items():
             for scalar_key, scalar_value in scalars.items():
                 writer.add_scalar(
                     f"{namespace}/{scalar_key}",
                     as_float(scalar_value.value),
-                    global_step=line.state.num_steps,
+                    global_step=global_step,
                     walltime=walltime,
                 )
 
@@ -172,7 +174,7 @@ class TensorboardLogger(LoggerImpl):
                     f"{namespace}/{distribution_key}",
                     mean=float(distribution_value.mean),
                     std=float(distribution_value.std),
-                    global_step=line.state.num_steps,
+                    global_step=global_step,
                     walltime=walltime,
                 )
 
@@ -187,7 +189,7 @@ class TensorboardLogger(LoggerImpl):
                     sum_squares=float(histogram_value.sum_squares),
                     bucket_limits=[float(x) for x in histogram_value.bucket_limits],
                     bucket_counts=[int(x) for x in histogram_value.bucket_counts],
-                    global_step=line.state.num_steps,
+                    global_step=global_step,
                     walltime=walltime,
                 )
 
@@ -196,7 +198,7 @@ class TensorboardLogger(LoggerImpl):
                 writer.add_text(
                     f"{namespace}/{string_key}",
                     string_value.value,
-                    global_step=line.state.num_steps,
+                    global_step=global_step,
                     walltime=walltime,
                 )
 
@@ -205,7 +207,7 @@ class TensorboardLogger(LoggerImpl):
                 writer.add_image(
                     f"{namespace}/{image_key}",
                     image_value.image,
-                    global_step=line.state.num_steps,
+                    global_step=global_step,
                     walltime=walltime,
                 )
 
@@ -215,7 +217,7 @@ class TensorboardLogger(LoggerImpl):
                     f"{namespace}/{video_key}",
                     video_value.frames,
                     fps=video_value.fps,
-                    global_step=line.state.num_steps,
+                    global_step=global_step,
                     walltime=walltime,
                 )
 
@@ -227,7 +229,7 @@ class TensorboardLogger(LoggerImpl):
                     faces=mesh_value.faces,
                     colors=mesh_value.colors,
                     config_dict=mesh_value.config_dict,
-                    global_step=line.state.num_steps,
+                    global_step=global_step,
                     walltime=walltime,
                 )
xax/task/mixins/checkpointing.py
CHANGED

@@ -4,7 +4,7 @@ import io
 import json
 import logging
 import tarfile
-from dataclasses import asdict, dataclass
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Generic, Literal, TypeVar, cast, overload
 
@@ -76,7 +76,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
             if state.num_steps % self.config.save_every_n_steps == 0:
                 return True
         if self.config.save_every_n_seconds is not None:
-            last_time, cur_time = self.__last_ckpt_time, state.elapsed_time_s
+            last_time, cur_time = self.__last_ckpt_time, state.elapsed_time_s.item()
             if cur_time - last_time >= self.config.save_every_n_seconds:
                 self.__last_ckpt_time = cur_time
                 return True
@@ -200,7 +200,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
         def get_state() -> State:
             if (state := tar.extractfile("state")) is None:
                 raise ValueError(f"Checkpoint does not contain a state file: {path}")
-            return State(**json.loads(state.read().decode()))
+            return State.from_dict(**json.loads(state.read().decode()))
 
         def get_config() -> Config:
             if (config := tar.extractfile("config")) is None:
@@ -300,7 +300,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
             tar.addfile(info, io.BytesIO(data))
 
         if state is not None:
-            add_file_bytes("state", json.dumps(asdict(state), indent=2).encode())
+            add_file_bytes("state", json.dumps(state.to_dict(), indent=2).encode())
         add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
 
         # Updates the symlink to the new checkpoint
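The serialization change follows from the new `State` layout: with the counters packed into private arrays, the old dataclass-based round trip no longer produces readable JSON, so the archive's "state" member now goes through the explicit converters. A sketch of that round trip, with the task plumbing omitted:

import json

from xax.core.state import State

state = State.init_state().replace(num_steps=100)

# What save_checkpoint now writes into the archive's "state" member...
payload = json.dumps(state.to_dict(), indent=2).encode()

# ...and what get_state() reconstructs when the checkpoint is loaded.
restored = State.from_dict(**json.loads(payload.decode()))
assert restored.to_dict()["num_steps"] == 100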
xax/task/mixins/train.py
CHANGED

@@ -46,6 +46,7 @@ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
 from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
 from xax.utils.experiments import (
+    ContextTimer,
     StateTimer,
     TrainingFinishedError,
     diff_configs,
@@ -119,27 +120,30 @@ class ValidStepTimer:
             return True
 
         if self.last_valid_time is None or self.last_valid_step is None:
-            self.last_valid_time = state.elapsed_time_s
-            self.last_valid_step = state.num_steps
+            self.last_valid_time = state.elapsed_time_s.item()
+            self.last_valid_step = state.num_steps.item()
             return False
 
         # Step-based validation.
         valid_every_n_steps = self.valid_every_n_steps
         if valid_every_n_steps is not None and state.num_steps >= valid_every_n_steps + self.last_valid_step:
-            self.last_valid_step = state.num_steps
+            self.last_valid_step = state.num_steps.item()
             return True
 
         # Time-based validation.
         valid_every_n_seconds = self.valid_every_n_seconds
-        if valid_every_n_seconds is not None and state.elapsed_time_s - self.last_valid_time >= valid_every_n_seconds:
-            self.last_valid_time = state.elapsed_time_s
+        if (
+            valid_every_n_seconds is not None
+            and state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
+        ):
+            self.last_valid_time = state.elapsed_time_s.item()
             return True
 
         # Time-based validation for first validation step.
         if self.first_valid_step_flag:
             valid_first_n_seconds = self.valid_first_n_seconds
-            if valid_first_n_seconds is not None and state.elapsed_time_s >= valid_first_n_seconds:
-                self.last_valid_time = state.elapsed_time_s
+            if valid_first_n_seconds is not None and state.elapsed_time_s.item() >= valid_first_n_seconds:
+                self.last_valid_time = state.elapsed_time_s.item()
                 self.first_valid_step_flag = False
                 return True
 
@@ -214,10 +218,6 @@ class TrainMixin(
     def prng_key(self) -> PRNGKeyArray:
         return jax.random.PRNGKey(self.config.random_seed)
 
-    def on_step_end(self, state: State) -> State:
-        state = super().on_step_end(state)
-        return state.replace(elapsed_time_s=time.time() - state.start_time_s)
-
     def log_train_step(
         self,
         model: PyTree,
@@ -548,7 +548,7 @@ class TrainMixin(
             "loss": loss,
         }
 
-    @xax_jit(static_argnames=["self", "model_static"])
+    @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
     def get_output_and_loss(
         self,
         model_arr: PyTree,
@@ -572,12 +572,12 @@ class TrainMixin(
         state: State,
     ) -> tuple[PyTree, optax.OptState, Output, dict[str, Array]]:
         grad_fn = jax.grad(self.get_output_and_loss, argnums=0, has_aux=True)
-        grad_fn = xax_jit(static_argnums=[1])(grad_fn)
+        grad_fn = xax_jit(static_argnums=[1], jit_level=3)(grad_fn)
         grads, (output, metrics) = grad_fn(model_arr, model_static, batch, state)
         model_arr, opt_state, grad_metrics = self.apply_gradients_with_clipping(model_arr, grads, optimizer, opt_state)
         return model_arr, opt_state, output, metrics | grad_metrics
 
-    @xax_jit(static_argnames=["self", "optimizer"])
+    @xax_jit(static_argnames=["self", "optimizer"], jit_level=3)
     def apply_gradients_with_clipping(
         self,
         model_arr: PyTree,
@@ -641,8 +641,8 @@ class TrainMixin(
     def maybe_log_termination_time(self, remaining_percent: float, state: State) -> None:
         if self._last_printed_remaining_time + PRINT_FINISH_TIME_EVERY_N_SECONDS > state.elapsed_time_s:
             return
-        self._last_printed_remaining_time = state.elapsed_time_s
-        remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
+        self._last_printed_remaining_time = state.elapsed_time_s.item()
+        remaining_seconds = remaining_percent * state.elapsed_time_s.item() / (1 - remaining_percent)
         termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
         logger.log(LOG_PING, "Estimated finish time: %s", termination_time)
 
@@ -663,11 +663,11 @@ class TrainMixin(
     def get_step(self, state: State) -> int:
         match self._step_kind:
             case "step":
-                return state.num_steps
+                return int(state.num_steps.item())
             case "sample":
-                return state.num_samples
+                return int(state.num_samples.item())
            case "second":
-                return int(state.elapsed_time_s)
+                return int(state.elapsed_time_s.item())
             case _:
                 raise ValueError(f"Invalid step kind {self._step_kind}")
 
@@ -683,7 +683,7 @@ class TrainMixin(
     def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
         return eqx.is_inexact_array(item)
 
-    @xax_jit(static_argnames=["self", "model_static", "optimizer"])
+    @xax_jit(static_argnames=["self", "model_static", "optimizer"], jit_level=3)
     def train_step(
         self,
         model_arr: PyTree,
@@ -696,7 +696,7 @@ class TrainMixin(
         model_arr, opt_state, output, metrics = self.update(model_arr, model_static, optimizer, opt_state, batch, state)
         return model_arr, opt_state, output, FrozenDict(metrics)
 
-    @xax_jit(static_argnames=["self", "model_static"])
+    @xax_jit(static_argnames=["self", "model_static"], jit_level=3)
     def val_step(
         self,
         model_arr: PyTree,
@@ -720,34 +720,38 @@ class TrainMixin(
 
         while not self.is_training_over(state):
             if self.valid_step_timer.is_valid_step(state):
-                valid_batch = next(valid_pf)
+                with ContextTimer() as timer:
+                    valid_batch = next(valid_pf)
+                    output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
+                    self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
+
                 state = state.replace(
                     phase="valid",
                     num_valid_steps=state.num_valid_steps + 1,
                     num_valid_samples=state.num_valid_samples + (self.get_size_of_batch(valid_batch) or 0),
+                    valid_elapsed_time_s=state.valid_elapsed_time_s + timer.elapsed_time,
                 )
 
-                output, metrics = self.val_step(model_arr, model_static, valid_batch, state)
-                self.log_step(eqx.combine(model_arr, model_static), valid_batch, output, metrics, state)
+            with ContextTimer() as timer:
+                state = self.on_step_start(state)
+                train_batch = next(train_pf)
+                model_arr, opt_state, output, metrics = self.train_step(
+                    model_arr=model_arr,
+                    model_static=model_static,
+                    optimizer=optimizer,
+                    opt_state=opt_state,
+                    batch=train_batch,
+                    state=state,
+                )
+                self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
 
-            state = self.on_step_start(state)
-            train_batch = next(train_pf)
             state = state.replace(
                 phase="train",
                 num_steps=state.num_steps + 1,
                 num_samples=state.num_samples + (self.get_size_of_batch(train_batch) or 0),
+                elapsed_time_s=state.elapsed_time_s + timer.elapsed_time,
             )
 
-            model_arr, opt_state, output, metrics = self.train_step(
-                model_arr=model_arr,
-                model_static=model_static,
-                optimizer=optimizer,
-                opt_state=opt_state,
-                batch=train_batch,
-                state=state,
-            )
-            self.log_step(eqx.combine(model_arr, model_static), train_batch, output, metrics, state)
-
             state = self.on_step_end(state)
 
             if self.should_checkpoint(state):
@@ -843,10 +847,8 @@ class TrainMixin(
 
         except TrainingFinishedError:
             if is_master():
-                show_info(
-                    f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
-                    important=True,
-                )
+                num_steps, num_samples = int(state.num_steps), int(state.num_samples)
+                show_info(f"Finished training after {num_steps} steps, {num_samples} samples", important=True)
             self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
         except (KeyboardInterrupt, bdb.BdbQuit):
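The restructured loop also changes how time is measured: the removed `on_step_end` override derived `elapsed_time_s` from a wall-clock `start_time_s`, whereas now each train or valid step body runs under a `ContextTimer` and its duration is accumulated into `elapsed_time_s` or `valid_elapsed_time_s`. The diff only shows that `ContextTimer` is a context manager exposing `elapsed_time`, so here is a minimal hypothetical stand-in for it, just to illustrate the accumulation pattern:

import time


class ContextTimer:
    # Hypothetical stand-in for xax.utils.experiments.ContextTimer, assuming
    # it records the block's wall-clock duration into `elapsed_time` on exit.

    def __enter__(self) -> "ContextTimer":
        self._start = time.monotonic()
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.elapsed_time = time.monotonic() - self._start


elapsed_s = 0.0
for _ in range(3):
    with ContextTimer() as timer:
        time.sleep(0.01)  # stands in for one training step
    elapsed_s += timer.elapsed_time  # mirrors state.replace(elapsed_time_s=...)
print(f"{elapsed_s:.2f}s accumulated over 3 steps")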
xax/utils/experiments.py
CHANGED

@@ -111,8 +111,10 @@ class StateTimer:
 
     def step(self, state: State) -> None:
        cur_time = time.time()
-        self.step_timer.step(state.num_steps if state.phase == "train" else state.num_valid_steps, cur_time)
-        self.sample_timer.step(state.num_samples if state.phase == "train" else state.num_valid_samples, cur_time)
+        num_steps = int((state.num_steps if state.phase == "train" else state.num_valid_steps).item())
+        num_samples = int((state.num_samples if state.phase == "train" else state.num_valid_samples).item())
+        self.step_timer.step(num_steps, cur_time)
+        self.sample_timer.step(num_samples, cur_time)
         self.iter_timer.step(cur_time)
 
     def log_dict(self) -> dict[str, int | float | tuple[int | float, bool]]:
xax/utils/jax.py
CHANGED

@@ -1,5 +1,6 @@
 """Defines some utility functions for interfacing with Jax."""
 
+import functools
 import inspect
 import logging
 import os
@@ -23,6 +24,28 @@ Number = int | float | np.ndarray | jnp.ndarray
 P = ParamSpec("P")  # For function parameters
 R = TypeVar("R")  # For function return type
 
+# For control flow functions.
+Carry = TypeVar("Carry")
+X = TypeVar("X")
+Y = TypeVar("Y")
+
+
+@functools.lru_cache(maxsize=None)
+def disable_jit_level() -> int:
+    """Gets a debugging flag for disabling jitting.
+
+    For Xax's JIT'ed functions, we can set a JIT level which can be used to
+    disable jitting when we want to debug some NaN issues.
+
+    Returns:
+        The JIT level to disable.
+    """
+    return int(os.environ.get("DISABLE_JIT_LEVEL", "0"))
+
+
+def should_disable_jit(jit_level: int | None) -> bool:
+    return jit_level is not None and jit_level < disable_jit_level()
+
 
 def as_float(value: int | float | np.ndarray | jnp.ndarray) -> float:
     if isinstance(value, (int, float)):
@@ -55,6 +78,7 @@ def jit(
     inline: bool = False,
     abstracted_axes: Any | None = None,  # noqa: ANN401
     compiler_options: dict[str, Any] | None = None,
+    jit_level: int | None = None,
 ) -> Callable[[Callable[P, R]], Callable[P, R]]:
     """Wrapper function that provides utility improvements over Jax's JIT.
 
@@ -64,6 +88,8 @@ def jit(
     This is meant to be used as a decorator factory, and the decorated function
     calls `wrapped`.
     """
+    if should_disable_jit(jit_level):
+        return lambda fn: fn  # Identity function.
 
     def decorator(fn: Callable[P, R]) -> Callable[P, R]:
         class JitState:
@@ -138,3 +164,46 @@ def jit(
         return wrapped
 
     return decorator
+
+
+def scan(
+    f: Callable[[Carry, X], tuple[Carry, Y]],
+    init: Carry,
+    xs: X | None = None,
+    length: int | None = None,
+    reverse: bool = False,
+    unroll: int | bool = 1,
+    jit_level: int | None = None,
+) -> tuple[Carry, Y]:
+    """A wrapper around jax.lax.scan that allows for more flexible tracing.
+
+    If the provided JIT level is below the environment JIT level, we manually
+    unroll the scan function as a for loop.
+
+    Args:
+        f: The function to scan.
+        init: The initial value for the scan.
+        xs: The input to the scan.
+        length: The length of the scan.
+        reverse: Whether to reverse the scan.
+        unroll: The unroll factor for the scan.
+        jit_level: The JIT level to use for the scan.
+
+    Returns:
+        A tuple containing the final carry and the output of the scan.
+    """
+    if not should_disable_jit(jit_level):
+        return jax.lax.scan(f, init, xs, length, reverse, unroll)
+
+    if xs is None:
+        if length is None:
+            raise ValueError("length must be provided if xs is None")
+        xs = cast(X, [None] * length)
+
+    carry = init
+    ys = []
+    for x in cast(Iterable, xs):
+        carry, y = f(carry, x)
+        ys.append(y)
+
+    return carry, jax.tree.map(lambda *ys: jnp.stack(ys), *ys)
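This `jit_level` plumbing is what the train mixin's `jit_level=3` annotations hook into: a function is left uncompiled whenever its level is strictly below the `DISABLE_JIT_LEVEL` environment variable, and `scan` likewise degrades to a plain Python loop so its body can be stepped through in a debugger. A hedged usage sketch (the value is read once and cached by `lru_cache`, so the variable must be set before first use):

import os

os.environ["DISABLE_JIT_LEVEL"] = "5"  # set before the first jit/scan call

import jax.numpy as jnp

from xax.utils.jax import jit, scan


@jit(jit_level=3)  # 3 < 5, so the decorator is a no-op and this runs eagerly
def double(x: jnp.ndarray) -> jnp.ndarray:
    return 2 * x


def body(carry: jnp.ndarray, x: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
    carry = carry + x
    return carry, carry


# Unrolls to a Python for loop instead of jax.lax.scan, so breakpoints and
# NaN checks inside `body` behave like ordinary Python.
total, partials = scan(body, jnp.asarray(0.0), jnp.arange(4.0), jit_level=3)
print(double(jnp.asarray(2.0)), total)  # 4.0 6.0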
{xax-0.2.2.dist-info → xax-0.2.4.dist-info}/RECORD
CHANGED

@@ -1,10 +1,10 @@
-xax/__init__.py,sha256=…
+xax/__init__.py,sha256=bMKUtRtmVnHshkD4Ylw7ymzIPpcasJAoXnBsIdoSEng,14225
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
 xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
-xax/core/state.py,sha256=…
+xax/core/state.py,sha256=bJONQ0wXgbgo1jjSqV3JtqG5tdMlli93Nax_ftZ2D0w,4552
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
 xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
@@ -17,7 +17,7 @@ xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
 xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/base.py,sha256=OnXi2hiKPGwt6ng1dutnoQSiw7lEiWFlC_vx99_JsbQ,7694
-xax/task/logger.py,sha256=…
+xax/task/logger.py,sha256=y4PGfMqKbfvPk8WCzr9MOsgG2X9E61KgeBVOYp-9kOY,40875
 xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
 xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
 xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -28,11 +28,11 @@ xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,
 xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
 xax/task/loggers/json.py,sha256=_tKum6jk_gqVzO-4MqSNXbE-Mmn-yJzkRAT-N1y2zes,4139
 xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
-xax/task/loggers/stdout.py,sha256=…
-xax/task/loggers/tensorboard.py,sha256=…
+xax/task/loggers/stdout.py,sha256=ERLFrYe61hSSztzyxBRseobHQR72YFYjEd2i_hOeJ20,6595
+xax/task/loggers/tensorboard.py,sha256=KFlsK0zD2ubDqAXYL4Ds7NQ9F-Ke-PHwfhLOYsGcbw4,8306
 xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
 xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
-xax/task/mixins/checkpointing.py,sha256=…
+xax/task/mixins/checkpointing.py,sha256=8Hi-2G0EA5OFRjgiOutlk7HgkD5b-0GHazOAYxnGytM,11409
 xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
 xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
 xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
@@ -41,11 +41,11 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=…
+xax/task/mixins/train.py,sha256=XcetJ0MppV_RDhgg1M9_d9heEXo-zeN_FS3MyczeBBU,31219
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
-xax/utils/experiments.py,sha256=…
-xax/utils/jax.py,sha256=…
+xax/utils/experiments.py,sha256=d2H63ECtVOKySMUMrQRqq4kcuZpoXqo-L931usDVAhE,29903
+xax/utils/jax.py,sha256=KQYUHjN6t6JIWa11aRSO3edcsAgTscw_dExxI6kCd9g,6767
 xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
 xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
 xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
@@ -58,8 +58,8 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.2.2.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
-xax-0.2.2.dist-info/METADATA,sha256=…
-xax-0.2.2.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-xax-0.2.2.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
-xax-0.2.2.dist-info/RECORD,,
+xax-0.2.4.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.2.4.dist-info/METADATA,sha256=9hMsPCoszpjVN0rLDMlT20aYqmQwnHvl9T1V_0akl0U,1882
+xax-0.2.4.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.2.4.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.2.4.dist-info/RECORD,,

{xax-0.2.2.dist-info → xax-0.2.4.dist-info}/WHEEL
File without changes

{xax-0.2.2.dist-info → xax-0.2.4.dist-info}/licenses/LICENSE
File without changes

{xax-0.2.2.dist-info → xax-0.2.4.dist-info}/top_level.txt
File without changes