xax 0.2.5__tar.gz → 0.2.7__tar.gz
This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- {xax-0.2.5/xax.egg-info → xax-0.2.7}/PKG-INFO +1 -1
- {xax-0.2.5 → xax-0.2.7}/xax/__init__.py +7 -2
- {xax-0.2.5 → xax-0.2.7}/xax/nn/functions.py +1 -1
- {xax-0.2.5 → xax-0.2.7}/xax/task/logger.py +2 -1
- {xax-0.2.5 → xax-0.2.7}/xax/task/loggers/json.py +1 -2
- {xax-0.2.5 → xax-0.2.7}/xax/task/mixins/checkpointing.py +108 -143
- {xax-0.2.5 → xax-0.2.7}/xax/task/mixins/train.py +34 -22
- {xax-0.2.5 → xax-0.2.7}/xax/utils/jaxpr.py +5 -5
- {xax-0.2.5 → xax-0.2.7}/xax/utils/pytree.py +1 -1
- {xax-0.2.5 → xax-0.2.7}/xax/utils/types/frozen_dict.py +1 -1
- {xax-0.2.5 → xax-0.2.7/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.2.5 → xax-0.2.7}/LICENSE +0 -0
- {xax-0.2.5 → xax-0.2.7}/MANIFEST.in +0 -0
- {xax-0.2.5 → xax-0.2.7}/README.md +0 -0
- {xax-0.2.5 → xax-0.2.7}/pyproject.toml +0 -0
- {xax-0.2.5 → xax-0.2.7}/setup.cfg +0 -0
- {xax-0.2.5 → xax-0.2.7}/setup.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/core/__init__.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/core/conf.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/core/state.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/nn/__init__.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/nn/embeddings.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/nn/equinox.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/nn/export.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/nn/geom.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/nn/losses.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/nn/norm.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/nn/parallel.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/nn/ssm.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/py.typed +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/requirements-dev.txt +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/requirements.txt +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/__init__.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/base.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/launchers/__init__.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/launchers/base.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/launchers/cli.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/launchers/single_process.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/loggers/__init__.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/loggers/callback.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/loggers/state.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/loggers/stdout.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/mixins/__init__.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/mixins/compile.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/mixins/logger.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/mixins/process.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/mixins/runnable.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/script.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/task/task.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/utils/__init__.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/utils/data/__init__.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/utils/data/collate.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/utils/debugging.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/utils/experiments.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/utils/jax.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/utils/logging.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/utils/numpy.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/utils/profile.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/utils/tensorboard.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/utils/text.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/utils/types/__init__.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax.egg-info/requires.txt +0 -0
- {xax-0.2.5 → xax-0.2.7}/xax.egg-info/top_level.txt +0 -0
{xax-0.2.5 → xax-0.2.7}/xax/__init__.py

```diff
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.2.5"
+__version__ = "0.2.7"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
```
```diff
@@ -66,11 +66,13 @@ __all__ = [
     "StateLogger",
     "StdoutLogger",
     "TensorboardLogger",
+    "load_ckpt",
     "CPUStatsOptions",
     "DataloaderConfig",
     "GPUStatsOptions",
     "StepContext",
     "ValidStepTimer",
+    "get_param_count",
     "Script",
     "ScriptConfig",
     "Config",
```
```diff
@@ -230,11 +232,13 @@ NAME_MAP: dict[str, str] = {
     "StateLogger": "task.loggers.state",
     "StdoutLogger": "task.loggers.stdout",
     "TensorboardLogger": "task.loggers.tensorboard",
+    "load_ckpt": "task.mixins.checkpointing",
     "CPUStatsOptions": "task.mixins.cpu_stats",
     "DataloaderConfig": "task.mixins.data_loader",
     "GPUStatsOptions": "task.mixins.gpu_stats",
     "StepContext": "task.mixins.step_wrapper",
     "ValidStepTimer": "task.mixins.train",
+    "get_param_count": "task.mixins.train",
     "Script": "task.script",
     "ScriptConfig": "task.script",
     "Config": "task.task",
```
```diff
@@ -390,11 +394,12 @@ if IMPORT_ALL or TYPE_CHECKING:
     from xax.task.loggers.state import StateLogger
     from xax.task.loggers.stdout import StdoutLogger
     from xax.task.loggers.tensorboard import TensorboardLogger
+    from xax.task.mixins.checkpointing import load_ckpt
     from xax.task.mixins.cpu_stats import CPUStatsOptions
     from xax.task.mixins.data_loader import DataloaderConfig
     from xax.task.mixins.gpu_stats import GPUStatsOptions
     from xax.task.mixins.step_wrapper import StepContext
-    from xax.task.mixins.train import Batch, Output, ValidStepTimer
+    from xax.task.mixins.train import Batch, Output, ValidStepTimer, get_param_count
     from xax.task.script import Script, ScriptConfig
     from xax.task.task import Config, Task
     from xax.utils.data.collate import CollateMode, collate, collate_non_null
```
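Together, these `__init__.py` hunks register the two new symbols for lazy loading: `__all__` advertises them, `NAME_MAP` points each name at its home module, and the `TYPE_CHECKING` block keeps static analyzers happy. A minimal sketch of the effect (assuming the package resolves root attributes lazily through `NAME_MAP`, as the hunks imply):

```python
import xax

# Each attribute access imports the owning submodule on demand,
# per the NAME_MAP entries added above.
print(xax.load_ckpt)        # from xax.task.mixins.checkpointing
print(xax.get_param_count)  # from xax.task.mixins.train
```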
{xax-0.2.5 → xax-0.2.7}/xax/task/logger.py

```diff
@@ -521,7 +521,8 @@ class LoggerImpl(ABC):
         Returns:
             If the logger should log the current step.
         """
-        return self.tickers[state.phase].tick(state.elapsed_time_s.item())
+        elapsed_time = state.elapsed_time_s.item() if state.phase == "train" else state.valid_elapsed_time_s.item()
+        return self.tickers[state.phase].tick(elapsed_time)
 
 
 class ToastHandler(logging.Handler):
```
{xax-0.2.5 → xax-0.2.7}/xax/task/loggers/json.py

```diff
@@ -2,7 +2,6 @@
 
 import json
 import sys
-from dataclasses import asdict
 from typing import Any, Literal, Mapping, TextIO
 
 from jaxtyping import Array
```
```diff
@@ -67,7 +66,7 @@ class JsonLogger(LoggerImpl):
         return self.err_log_stream
 
     def get_json(self, line: LogLine) -> str:
-        data: dict = {"state": asdict(line.state)}
+        data: dict = {"state": line.state.to_dict()}
 
         def add_logs(log: Mapping[str, Mapping[str, LogScalar | LogString]], data: dict) -> None:
             for namespace, values in log.items():
```
{xax-0.2.5 → xax-0.2.7}/xax/task/mixins/checkpointing.py

```diff
@@ -52,6 +52,114 @@ class CheckpointingConfig(ArtifactsConfig):
 Config = TypeVar("Config", bound=CheckpointingConfig)
 
 
+@overload
+def load_ckpt(
+    path: Path,
+    *,
+    part: Literal["all"],
+    model_template: PyTree,
+    optimizer_template: PyTree,
+    opt_state_template: PyTree,
+) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+
+
+@overload
+def load_ckpt(
+    path: Path,
+    *,
+    part: Literal["model_state_config"],
+    model_template: PyTree,
+) -> tuple[PyTree, State, DictConfig]: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["model"], model_template: PyTree) -> PyTree: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["opt"], optimizer_template: PyTree) -> optax.GradientTransformation: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["opt_state"], opt_state_template: PyTree) -> optax.OptState: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["state"]) -> State: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["config"]) -> DictConfig: ...
+
+
+def load_ckpt(
+    path: str | Path,
+    *,
+    part: CheckpointPart = "model",
+    model_template: PyTree | None = None,
+    optimizer_template: PyTree | None = None,
+    opt_state_template: PyTree | None = None,
+) -> (
+    tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
+    | tuple[PyTree, State, DictConfig]
+    | PyTree
+    | optax.GradientTransformation
+    | optax.OptState
+    | State
+    | DictConfig
+):
+    with tarfile.open(path, "r:gz") as tar:
+
+        def get_model() -> PyTree:
+            if model_template is None:
+                raise ValueError("model_template must be provided to load model weights")
+            if (model := tar.extractfile("model")) is None:
+                raise ValueError(f"Checkpoint does not contain a model file: {path}")
+            return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
+
+        def get_opt() -> optax.GradientTransformation:
+            if optimizer_template is None:
+                raise ValueError("optimizer_template must be provided to load optimizer")
+            if (opt := tar.extractfile("optimizer")) is None:
+                raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
+            return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
+
+        def get_opt_state() -> optax.OptState:
+            if opt_state_template is None:
+                raise ValueError("opt_state_template must be provided to load optimizer state")
+            if (opt_state := tar.extractfile("opt_state")) is None:
+                raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
+            return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
+
+        def get_state() -> State:
+            if (state := tar.extractfile("state")) is None:
+                raise ValueError(f"Checkpoint does not contain a state file: {path}")
+            return State.from_dict(**json.loads(state.read().decode()))
+
+        def get_config() -> DictConfig:
+            if (config := tar.extractfile("config")) is None:
+                raise ValueError(f"Checkpoint does not contain a config file: {path}")
+            return cast(DictConfig, OmegaConf.load(config))
+
+        match part:
+            case "model":
+                return get_model()
+            case "opt":
+                return get_opt()
+            case "opt_state":
+                return get_opt_state()
+            case "state":
+                return get_state()
+            case "config":
+                return get_config()
+            case "model_state_config":
+                return get_model(), get_state(), get_config()
+            case "all":
+                return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
+            case _:
+                raise ValueError(f"Invalid checkpoint part: {part}")
+
+
 class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def __init__(self, config: Config) -> None:
         super().__init__(config)
```
```diff
@@ -82,149 +190,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
             return True
         return False
 
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["all"],
-        model_template: PyTree,
-        optimizer_template: PyTree,
-        opt_state_template: PyTree,
-    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["model_state_config"],
-        model_template: PyTree,
-    ) -> tuple[PyTree, State, Config]: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["model"],
-        model_template: PyTree,
-    ) -> PyTree: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["opt"],
-        optimizer_template: PyTree,
-    ) -> optax.GradientTransformation: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["opt_state"],
-        opt_state_template: PyTree,
-    ) -> optax.OptState: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["state"],
-    ) -> State: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["config"],
-    ) -> Config: ...
-
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: CheckpointPart = "all",
-        model_template: PyTree | None = None,
-        optimizer_template: PyTree | None = None,
-        opt_state_template: PyTree | None = None,
-    ) -> (
-        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
-        | tuple[PyTree, State, Config]
-        | PyTree
-        | optax.GradientTransformation
-        | optax.OptState
-        | State
-        | Config
-    ):
-        """Load a checkpoint.
-
-        Args:
-            path: Path to the checkpoint directory
-            part: Which part of the checkpoint to load
-            model_template: Template model with correct structure but uninitialized weights
-            optimizer_template: Template optimizer with correct structure but uninitialized weights
-            opt_state_template: Template optimizer state with correct structure but uninitialized weights
-
-        Returns:
-            The requested checkpoint components
-        """
-        with tarfile.open(path, "r:gz") as tar:
-
-            def get_model() -> PyTree:
-                if model_template is None:
-                    raise ValueError("model_template must be provided to load model weights")
-                if (model := tar.extractfile("model")) is None:
-                    raise ValueError(f"Checkpoint does not contain a model file: {path}")
-                return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
-
-            def get_opt() -> optax.GradientTransformation:
-                if optimizer_template is None:
-                    raise ValueError("optimizer_template must be provided to load optimizer")
-                if (opt := tar.extractfile("optimizer")) is None:
-                    raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
-                return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
-
-            def get_opt_state() -> optax.OptState:
-                if opt_state_template is None:
-                    raise ValueError("opt_state_template must be provided to load optimizer state")
-                if (opt_state := tar.extractfile("opt_state")) is None:
-                    raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
-                return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
-
-            def get_state() -> State:
-                if (state := tar.extractfile("state")) is None:
-                    raise ValueError(f"Checkpoint does not contain a state file: {path}")
-                return State.from_dict(**json.loads(state.read().decode()))
-
-            def get_config() -> Config:
-                if (config := tar.extractfile("config")) is None:
-                    raise ValueError(f"Checkpoint does not contain a config file: {path}")
-                return self.get_config(cast(DictConfig, OmegaConf.load(config)), use_cli=False)
-
-            match part:
-                case "model":
-                    return get_model()
-                case "opt":
-                    return get_opt()
-                case "opt_state":
-                    return get_opt_state()
-                case "state":
-                    return get_state()
-                case "config":
-                    return get_config()
-                case "model_state_config":
-                    return get_model(), get_state(), get_config()
-                case "all":
-                    return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
-                case _:
-                    raise ValueError(f"Invalid checkpoint part: {part}")
-
     def save_checkpoint(
         self,
         model: PyTree | None = None,
```
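The removed `load_ckpt_with_template` method is superseded by the module-level `load_ckpt` added above, which no longer needs a task instance and returns the raw `DictConfig` instead of a task-processed config. A minimal usage sketch (the checkpoint path is hypothetical; `part` selects which archive member to deserialise):

```python
from pathlib import Path

from xax.task.mixins.checkpointing import load_ckpt

ckpt = Path("run_dir/checkpoints/ckpt.bin")  # hypothetical checkpoint archive
state = load_ckpt(ckpt, part="state")    # training State only
config = load_ckpt(ckpt, part="config")  # raw DictConfig, no CLI merging
```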
{xax-0.2.5 → xax-0.2.7}/xax/task/mixins/train.py

```diff
@@ -40,7 +40,7 @@ from xax.core.state import Phase, State
 from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
-from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart
+from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart, load_ckpt
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
```
```diff
@@ -96,6 +96,12 @@ def batches_per_step_schedule(schedule: list[int] | None) -> list[int] | None:
     return list(itertools.accumulate([0] + schedule))
 
 
+def get_param_count(pytree: PyTree) -> int:
+    """Calculates the total number of parameters in a PyTree."""
+    leaves, _ = jax.tree.flatten(pytree)
+    return sum(x.size for x in leaves if isinstance(x, jnp.ndarray))
+
+
 class ValidStepTimer:
     def __init__(
         self,
```
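`get_param_count` simply sums `size` over the array leaves, so it works on any Equinox module or plain pytree. A quick sanity check (the layer sizes here are arbitrary):

```python
import equinox as eqx
import jax

from xax.task.mixins.train import get_param_count

linear = eqx.nn.Linear(in_features=4, out_features=3, key=jax.random.PRNGKey(0))
print(get_param_count(linear))  # 4 * 3 weights + 3 biases = 15
```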
```diff
@@ -115,19 +121,22 @@
         self.last_valid_time: float | None = None
         self.last_valid_step: int | None = None
 
+    def _reset(self, state: State) -> None:
+        self.last_valid_time = state.elapsed_time_s.item()
+        self.last_valid_step = state.num_steps.item()
+
     def is_valid_step(self, state: State) -> bool:
         if state.num_steps < self.valid_first_n_steps:
             return True
 
         if self.last_valid_time is None or self.last_valid_step is None:
-            self.last_valid_time = state.elapsed_time_s.item()
-            self.last_valid_step = state.num_steps.item()
+            self._reset(state)
             return False
 
         # Step-based validation.
         valid_every_n_steps = self.valid_every_n_steps
         if valid_every_n_steps is not None and state.num_steps >= valid_every_n_steps + self.last_valid_step:
-            self.last_valid_step = state.num_steps.item()
+            self._reset(state)
             return True
 
         # Time-based validation.
```
```diff
@@ -136,14 +145,14 @@
             valid_every_n_seconds is not None
             and state.elapsed_time_s.item() - self.last_valid_time >= valid_every_n_seconds
         ):
-            self.last_valid_time = state.elapsed_time_s.item()
+            self._reset(state)
             return True
 
         # Time-based validation for first validation step.
         if self.first_valid_step_flag:
             valid_first_n_seconds = self.valid_first_n_seconds
             if valid_first_n_seconds is not None and state.elapsed_time_s.item() >= valid_first_n_seconds:
-                self.last_valid_time = state.elapsed_time_s.item()
+                self._reset(state)
                 self.first_valid_step_flag = False
                 return True
 
```
```diff
@@ -357,6 +366,7 @@ class TrainMixin(
         model = self.get_model(key)
         state = State.init_state()
 
+        self.log_model_size(model)
         if not load_optimizer:
             return model, state
 
```
```diff
@@ -447,44 +457,43 @@
         match part:
             case "model_state_config":
                 model_spec = eqx.filter_eval_shape(self.get_model, key)
-                return self.load_ckpt_with_template(path, part="model_state_config", model_template=model_spec)
+                model, state, config = load_ckpt(path, part="model_state_config", model_template=model_spec)
+                config = self.get_config(config, use_cli=False)
+                return model, state, config
 
             case "model":
                 model_spec = eqx.filter_eval_shape(self.get_model, key)
-                return self.load_ckpt_with_template(path, part="model", model_template=model_spec)
-
-            case "config":
-                return self.load_ckpt_with_template(path, part="config")
+                return load_ckpt(path, part="model", model_template=model_spec)
 
             case "opt":
                 optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                return self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                return load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
 
             case "opt_state":
                 if model is None:
                     model_spec = eqx.filter_eval_shape(self.get_model, key)
-                    model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+                    model = load_ckpt(path, part="model", model_template=model_spec)
                 if optimizer is None:
                     optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                    optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                    optimizer = load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
                 opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
-                return self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
+                return load_ckpt(path, part="opt_state", opt_state_template=opt_state_spec)
 
             case "state":
-                return self.load_ckpt_with_template(path, part="state")
+                return load_ckpt(path, part="state")
 
             case "config":
-                return self.load_ckpt_with_template(path, part="config")
+                return self.get_config(load_ckpt(path, part="config"), use_cli=False)
 
             case "all":
                 model_spec = eqx.filter_eval_shape(self.get_model, key)
-                model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+                model = load_ckpt(path, part="model", model_template=model_spec)
                 optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                optimizer = load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
                 opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
-                opt_state = self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
-                state = self.load_ckpt_with_template(path, part="state")
-                config = self.load_ckpt_with_template(path, part="config")
+                opt_state = load_ckpt(path, part="opt_state", opt_state_template=opt_state_spec)
+                state = load_ckpt(path, part="state")
+                config = self.get_config(load_ckpt(path, part="config"), use_cli=False)
                 return model, optimizer, opt_state, state, config
 
             case _:
```
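Each branch builds its template with `eqx.filter_eval_shape`, which traces the constructor and returns a pytree of abstract shape/dtype structs without allocating any parameters; `load_ckpt` then deserialises real arrays into that structure. A small sketch of the template trick in isolation (the MLP is a stand-in model, not part of xax):

```python
import equinox as eqx
import jax


def make_model(key: jax.Array) -> eqx.nn.MLP:
    return eqx.nn.MLP("scalar", "scalar", width_size=8, depth=2, key=key)


# No parameters are materialised here: the leaves are abstract shape/dtype
# structs, which is all the deserialisation target needs to provide.
model_spec = eqx.filter_eval_shape(make_model, jax.random.PRNGKey(0))
```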
```diff
@@ -680,6 +689,9 @@ class TrainMixin(
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
         self.logger.log_file("info.json", get_info_json())
 
+    def log_model_size(self, model: PyTree) -> None:
+        logger.info("Model size: %s", f"{get_param_count(model):,}")
+
     def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
         return eqx.is_inexact_array(item)
 
```
{xax-0.2.5 → xax-0.2.7}/xax/utils/jaxpr.py

```diff
@@ -3,10 +3,10 @@
 from pathlib import Path
 
 import jax
-import jax.core
+import jax.extend.core
 
 
-def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) -> None:
+def save_jaxpr_dot(closed_jaxpr: jax.extend.core.ClosedJaxpr, filename: str | Path) -> None:
     """Save the JAXPR to a DOT file.
 
     Example usage:
```
```diff
@@ -30,15 +30,15 @@ def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) ->
     with open(filename, "w") as f:
         f.write("digraph Jaxpr {\n")
 
-        var_names: dict[jax.core.Var, str] = {}
+        var_names: dict[jax.extend.core.Var, str] = {}
         var_count = 0
 
-        def get_var_name(var: jax.core.Var) -> str:
+        def get_var_name(var: jax.extend.core.Var) -> str:
             """Get a unique name for a variable."""
             nonlocal var_names, var_count
 
             # Handle Literal objects specially since they're not hashable
-            if isinstance(var, jax.core.Literal):
+            if isinstance(var, jax.extend.core.Literal):
                 # Create a name based on the literal value
                 name = f"lit_{var.val}"
                 return name
```
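This migrates off the deprecated `jax.core` module to its `jax.extend.core` replacement; `jax.make_jaxpr` still returns the `ClosedJaxpr` that `save_jaxpr_dot` expects, so callers are unaffected. A minimal sketch:

```python
import jax
import jax.numpy as jnp

from xax.utils.jaxpr import save_jaxpr_dot


def fn(x: jax.Array) -> jax.Array:
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2


closed_jaxpr = jax.make_jaxpr(fn)(jnp.arange(3.0))
save_jaxpr_dot(closed_jaxpr, "fn.dot")  # render with: dot -Tpng fn.dot -o fn.png
```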
{xax-0.2.5 → xax-0.2.7}/xax/utils/pytree.py

```diff
@@ -57,7 +57,7 @@ def pytree_has_nans(pytree: PyTree) -> Array:
 
 def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
     """Update a pytree based on a condition."""
-    # Tricky, need use
+    # Tricky, need use tree.map because where expects array leafs.
     return jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)
 
 
```
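The corrected comment captures the point: `jnp.where` only accepts arrays, so the condition has to be mapped over paired leaves rather than applied to the pytrees directly. The pattern in isolation:

```python
import jax
import jax.numpy as jnp

new = {"w": jnp.ones(3), "b": jnp.zeros(2)}
original = {"w": jnp.full(3, 5.0), "b": jnp.full(2, 7.0)}
cond = jnp.array(False)  # e.g. True when the new pytree is NaN-free

# Leaf-wise select: every leaf comes from `new` if cond is true, else `original`.
merged = jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)
print(merged["w"])  # [5. 5. 5.] since cond is False
```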
{xax-0.2.5 → xax-0.2.7}/xax/utils/types/frozen_dict.py

```diff
@@ -138,7 +138,7 @@ class FrozenDict(Mapping[K, V]):
 
 def unfreeze(x: FrozenDict[K, V] | dict[str, Any]) -> dict[Any, Any]:  # noqa: ANN401
     if isinstance(x, FrozenDict):
-        return jax.
+        return jax.tree.map(lambda y: y, x._dict)
     elif isinstance(x, dict):
         ys = {}
         for key, value in x.items():
```
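`unfreeze` hands back a plain, mutable dict; the identity `jax.tree.map` produces a structural copy rather than aliasing the frozen internals. A short sketch (assuming `FrozenDict` can be constructed from a plain mapping, as in the flax implementation it mirrors):

```python
import jax.numpy as jnp

from xax.utils.types.frozen_dict import FrozenDict, unfreeze

frozen = FrozenDict({"scale": jnp.ones(2)})
plain = unfreeze(frozen)      # plain dict, safe to mutate
plain["bias"] = jnp.zeros(2)  # item assignment would raise on the FrozenDict
```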