xax 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +7 -2
- xax/nn/functions.py +1 -1
- xax/task/loggers/json.py +1 -2
- xax/task/mixins/checkpointing.py +108 -143
- xax/task/mixins/train.py +26 -17
- xax/utils/jaxpr.py +5 -5
- xax/utils/pytree.py +1 -1
- xax/utils/types/frozen_dict.py +1 -1
- {xax-0.2.6.dist-info → xax-0.2.7.dist-info}/METADATA +1 -1
- {xax-0.2.6.dist-info → xax-0.2.7.dist-info}/RECORD +13 -13
- {xax-0.2.6.dist-info → xax-0.2.7.dist-info}/WHEEL +0 -0
- {xax-0.2.6.dist-info → xax-0.2.7.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.6.dist-info → xax-0.2.7.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.2.6"
+__version__ = "0.2.7"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -66,11 +66,13 @@ __all__ = [
     "StateLogger",
     "StdoutLogger",
     "TensorboardLogger",
+    "load_ckpt",
     "CPUStatsOptions",
     "DataloaderConfig",
     "GPUStatsOptions",
     "StepContext",
     "ValidStepTimer",
+    "get_param_count",
     "Script",
     "ScriptConfig",
     "Config",
@@ -230,11 +232,13 @@ NAME_MAP: dict[str, str] = {
     "StateLogger": "task.loggers.state",
     "StdoutLogger": "task.loggers.stdout",
     "TensorboardLogger": "task.loggers.tensorboard",
+    "load_ckpt": "task.mixins.checkpointing",
     "CPUStatsOptions": "task.mixins.cpu_stats",
     "DataloaderConfig": "task.mixins.data_loader",
     "GPUStatsOptions": "task.mixins.gpu_stats",
     "StepContext": "task.mixins.step_wrapper",
     "ValidStepTimer": "task.mixins.train",
+    "get_param_count": "task.mixins.train",
     "Script": "task.script",
     "ScriptConfig": "task.script",
     "Config": "task.task",
@@ -390,11 +394,12 @@ if IMPORT_ALL or TYPE_CHECKING:
     from xax.task.loggers.state import StateLogger
     from xax.task.loggers.stdout import StdoutLogger
     from xax.task.loggers.tensorboard import TensorboardLogger
+    from xax.task.mixins.checkpointing import load_ckpt
     from xax.task.mixins.cpu_stats import CPUStatsOptions
     from xax.task.mixins.data_loader import DataloaderConfig
     from xax.task.mixins.gpu_stats import GPUStatsOptions
     from xax.task.mixins.step_wrapper import StepContext
-    from xax.task.mixins.train import Batch, Output, ValidStepTimer
+    from xax.task.mixins.train import Batch, Output, ValidStepTimer, get_param_count
     from xax.task.script import Script, ScriptConfig
     from xax.task.task import Config, Task
     from xax.utils.data.collate import CollateMode, collate, collate_non_null
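Both additions are wired through the package's lazy-import table (`NAME_MAP`) as well as `__all__`, so they resolve directly off the package root. A minimal sketch of the new surface, assuming a checkpoint path and a `model_template` pytree prepared elsewhere (both hypothetical here):

    import xax

    # `load_ckpt` and `get_param_count` are resolved lazily via NAME_MAP.
    model = xax.load_ckpt("ckpt.bin", part="model", model_template=model_template)
    print(f"parameters: {xax.get_param_count(model):,}")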
xax/nn/functions.py
CHANGED
xax/task/loggers/json.py
CHANGED
@@ -2,7 +2,6 @@
 
 import json
 import sys
-from dataclasses import asdict
 from typing import Any, Literal, Mapping, TextIO
 
 from jaxtyping import Array
@@ -67,7 +66,7 @@ class JsonLogger(LoggerImpl):
         return self.err_log_stream
 
     def get_json(self, line: LogLine) -> str:
-        data: dict = {"state": asdict(line.state)}
+        data: dict = {"state": line.state.to_dict()}
 
         def add_logs(log: Mapping[str, Mapping[str, LogScalar | LogString]], data: dict) -> None:
            for namespace, values in log.items():
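The logger now serializes through `State`'s own `to_dict` instead of `dataclasses.asdict`, mirroring the `State.from_dict(**...)` call in the new checkpoint loader below. A small round-trip sketch, assuming the two methods stay symmetric:

    import json

    from xax.core.state import State

    state = State.init_state()
    payload = json.dumps(state.to_dict())  # what get_json embeds under "state"
    restored = State.from_dict(**json.loads(payload))  # what load_ckpt's get_state does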
xax/task/mixins/checkpointing.py
CHANGED
@@ -52,6 +52,114 @@ class CheckpointingConfig(ArtifactsConfig):
 Config = TypeVar("Config", bound=CheckpointingConfig)
 
 
+@overload
+def load_ckpt(
+    path: Path,
+    *,
+    part: Literal["all"],
+    model_template: PyTree,
+    optimizer_template: PyTree,
+    opt_state_template: PyTree,
+) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+
+
+@overload
+def load_ckpt(
+    path: Path,
+    *,
+    part: Literal["model_state_config"],
+    model_template: PyTree,
+) -> tuple[PyTree, State, DictConfig]: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["model"], model_template: PyTree) -> PyTree: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["opt"], optimizer_template: PyTree) -> optax.GradientTransformation: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["opt_state"], opt_state_template: PyTree) -> optax.OptState: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["state"]) -> State: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["config"]) -> DictConfig: ...
+
+
+def load_ckpt(
+    path: str | Path,
+    *,
+    part: CheckpointPart = "model",
+    model_template: PyTree | None = None,
+    optimizer_template: PyTree | None = None,
+    opt_state_template: PyTree | None = None,
+) -> (
+    tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
+    | tuple[PyTree, State, DictConfig]
+    | PyTree
+    | optax.GradientTransformation
+    | optax.OptState
+    | State
+    | DictConfig
+):
+    with tarfile.open(path, "r:gz") as tar:
+
+        def get_model() -> PyTree:
+            if model_template is None:
+                raise ValueError("model_template must be provided to load model weights")
+            if (model := tar.extractfile("model")) is None:
+                raise ValueError(f"Checkpoint does not contain a model file: {path}")
+            return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
+
+        def get_opt() -> optax.GradientTransformation:
+            if optimizer_template is None:
+                raise ValueError("optimizer_template must be provided to load optimizer")
+            if (opt := tar.extractfile("optimizer")) is None:
+                raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
+            return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
+
+        def get_opt_state() -> optax.OptState:
+            if opt_state_template is None:
+                raise ValueError("opt_state_template must be provided to load optimizer state")
+            if (opt_state := tar.extractfile("opt_state")) is None:
+                raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
+            return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
+
+        def get_state() -> State:
+            if (state := tar.extractfile("state")) is None:
+                raise ValueError(f"Checkpoint does not contain a state file: {path}")
+            return State.from_dict(**json.loads(state.read().decode()))
+
+        def get_config() -> DictConfig:
+            if (config := tar.extractfile("config")) is None:
+                raise ValueError(f"Checkpoint does not contain a config file: {path}")
+            return cast(DictConfig, OmegaConf.load(config))
+
+        match part:
+            case "model":
+                return get_model()
+            case "opt":
+                return get_opt()
+            case "opt_state":
+                return get_opt_state()
+            case "state":
+                return get_state()
+            case "config":
+                return get_config()
+            case "model_state_config":
+                return get_model(), get_state(), get_config()
+            case "all":
+                return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
+            case _:
+                raise ValueError(f"Invalid checkpoint part: {part}")
+
+
 class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def __init__(self, config: Config) -> None:
         super().__init__(config)
@@ -82,149 +190,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
                 return True
         return False
 
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["all"],
-        model_template: PyTree,
-        optimizer_template: PyTree,
-        opt_state_template: PyTree,
-    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["model_state_config"],
-        model_template: PyTree,
-    ) -> tuple[PyTree, State, Config]: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["model"],
-        model_template: PyTree,
-    ) -> PyTree: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["opt"],
-        optimizer_template: PyTree,
-    ) -> optax.GradientTransformation: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["opt_state"],
-        opt_state_template: PyTree,
-    ) -> optax.OptState: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["state"],
-    ) -> State: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["config"],
-    ) -> Config: ...
-
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: CheckpointPart = "all",
-        model_template: PyTree | None = None,
-        optimizer_template: PyTree | None = None,
-        opt_state_template: PyTree | None = None,
-    ) -> (
-        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
-        | tuple[PyTree, State, Config]
-        | PyTree
-        | optax.GradientTransformation
-        | optax.OptState
-        | State
-        | Config
-    ):
-        """Load a checkpoint.
-
-        Args:
-            path: Path to the checkpoint directory
-            part: Which part of the checkpoint to load
-            model_template: Template model with correct structure but uninitialized weights
-            optimizer_template: Template optimizer with correct structure but uninitialized weights
-            opt_state_template: Template optimizer state with correct structure but uninitialized weights
-
-        Returns:
-            The requested checkpoint components
-        """
-        with tarfile.open(path, "r:gz") as tar:
-
-            def get_model() -> PyTree:
-                if model_template is None:
-                    raise ValueError("model_template must be provided to load model weights")
-                if (model := tar.extractfile("model")) is None:
-                    raise ValueError(f"Checkpoint does not contain a model file: {path}")
-                return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
-
-            def get_opt() -> optax.GradientTransformation:
-                if optimizer_template is None:
-                    raise ValueError("optimizer_template must be provided to load optimizer")
-                if (opt := tar.extractfile("optimizer")) is None:
-                    raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
-                return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
-
-            def get_opt_state() -> optax.OptState:
-                if opt_state_template is None:
-                    raise ValueError("opt_state_template must be provided to load optimizer state")
-                if (opt_state := tar.extractfile("opt_state")) is None:
-                    raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
-                return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
-
-            def get_state() -> State:
-                if (state := tar.extractfile("state")) is None:
-                    raise ValueError(f"Checkpoint does not contain a state file: {path}")
-                return State.from_dict(**json.loads(state.read().decode()))
-
-            def get_config() -> Config:
-                if (config := tar.extractfile("config")) is None:
-                    raise ValueError(f"Checkpoint does not contain a config file: {path}")
-                return self.get_config(cast(DictConfig, OmegaConf.load(config)), use_cli=False)
-
-            match part:
-                case "model":
-                    return get_model()
-                case "opt":
-                    return get_opt()
-                case "opt_state":
-                    return get_opt_state()
-                case "state":
-                    return get_state()
-                case "config":
-                    return get_config()
-                case "model_state_config":
-                    return get_model(), get_state(), get_config()
-                case "all":
-                    return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
-                case _:
-                    raise ValueError(f"Invalid checkpoint part: {part}")
-
     def save_checkpoint(
         self,
         model: PyTree | None = None,
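With the loader hoisted out of the mixin, checkpoints can be read without constructing a task, and the overloads keep the return type tied to `part`. The `r:gz` mode shows that a checkpoint is a gzipped tarball with `model`, `optimizer`, `opt_state`, `state`, and `config` members. A standalone sketch, where `MyModel` and the path are hypothetical and the model's constructor is assumed to take a PRNG key:

    import equinox as eqx
    import jax

    from xax.task.mixins.checkpointing import load_ckpt

    # Abstract template: correct structure, no materialized weights.
    model_template = eqx.filter_eval_shape(MyModel, jax.random.PRNGKey(0))

    model = load_ckpt("ckpt.bin", part="model", model_template=model_template)
    state = load_ckpt("ckpt.bin", part="state")  # state and config need no template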
xax/task/mixins/train.py
CHANGED
@@ -40,7 +40,7 @@ from xax.core.state import Phase, State
 from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
-from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart
+from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart, load_ckpt
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
@@ -96,6 +96,12 @@ def batches_per_step_schedule(schedule: list[int] | None) -> list[int] | None:
     return list(itertools.accumulate([0] + schedule))
 
 
+def get_param_count(pytree: PyTree) -> int:
+    """Calculates the total number of parameters in a PyTree."""
+    leaves, _ = jax.tree.flatten(pytree)
+    return sum(x.size for x in leaves if isinstance(x, jnp.ndarray))
+
+
 class ValidStepTimer:
     def __init__(
         self,
@@ -360,6 +366,7 @@ class TrainMixin(
         model = self.get_model(key)
         state = State.init_state()
 
+        self.log_model_size(model)
         if not load_optimizer:
             return model, state
 
@@ -450,44 +457,43 @@ class TrainMixin(
         match part:
             case "model_state_config":
                 model_spec = eqx.filter_eval_shape(self.get_model, key)
-                return self.load_ckpt_with_template(path, part="model_state_config", model_template=model_spec)
+                model, state, config = load_ckpt(path, part="model_state_config", model_template=model_spec)
+                config = self.get_config(config, use_cli=False)
+                return model, state, config
 
             case "model":
                 model_spec = eqx.filter_eval_shape(self.get_model, key)
-                return self.load_ckpt_with_template(path, part="model", model_template=model_spec)
-
-            case "config":
-                return self.load_ckpt_with_template(path, part="config")
+                return load_ckpt(path, part="model", model_template=model_spec)
 
             case "opt":
                 optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                return self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                return load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
 
             case "opt_state":
                 if model is None:
                     model_spec = eqx.filter_eval_shape(self.get_model, key)
-                    model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+                    model = load_ckpt(path, part="model", model_template=model_spec)
                 if optimizer is None:
                     optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                    optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                    optimizer = load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
                 opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
-                return self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
+                return load_ckpt(path, part="opt_state", opt_state_template=opt_state_spec)
 
             case "state":
-                return self.load_ckpt_with_template(path, part="state")
+                return load_ckpt(path, part="state")
 
             case "config":
-                return self.load_ckpt_with_template(path, part="config")
+                return self.get_config(load_ckpt(path, part="config"), use_cli=False)
 
             case "all":
                 model_spec = eqx.filter_eval_shape(self.get_model, key)
-                model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+                model = load_ckpt(path, part="model", model_template=model_spec)
                 optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                optimizer = load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
                 opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
-                opt_state = self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
-                state = self.load_ckpt_with_template(path, part="state")
-                config = self.load_ckpt_with_template(path, part="config")
+                opt_state = load_ckpt(path, part="opt_state", opt_state_template=opt_state_spec)
+                state = load_ckpt(path, part="state")
+                config = self.get_config(load_ckpt(path, part="config"), use_cli=False)
                 return model, optimizer, opt_state, state, config
 
             case _:
@@ -683,6 +689,9 @@ class TrainMixin(
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
         self.logger.log_file("info.json", get_info_json())
 
+    def log_model_size(self, model: PyTree) -> None:
+        logger.info("Model size: %s", f"{get_param_count(model):,}")
+
     def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
         return eqx.is_inexact_array(item)
 
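`get_param_count` sums `.size` over array leaves only, so static fields and Python scalars in a module do not inflate the count that `log_model_size` reports at startup. A quick check against a toy Equinox layer:

    import equinox as eqx
    import jax

    import xax

    linear = eqx.nn.Linear(3, 4, key=jax.random.PRNGKey(0))
    assert xax.get_param_count(linear) == 3 * 4 + 4  # weight plus bias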
xax/utils/jaxpr.py
CHANGED
@@ -3,10 +3,10 @@
 from pathlib import Path
 
 import jax
-import jax.core
+import jax.extend.core
 
 
-def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) -> None:
+def save_jaxpr_dot(closed_jaxpr: jax.extend.core.ClosedJaxpr, filename: str | Path) -> None:
     """Save the JAXPR to a DOT file.
 
     Example usage:
@@ -30,15 +30,15 @@ def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) ->
     with open(filename, "w") as f:
         f.write("digraph Jaxpr {\n")
 
-        var_names: dict[jax.core.Var, str] = {}
+        var_names: dict[jax.extend.core.Var, str] = {}
         var_count = 0
 
-        def get_var_name(var: jax.core.Var) -> str:
+        def get_var_name(var: jax.extend.core.Var) -> str:
             """Get a unique name for a variable."""
             nonlocal var_names, var_count
 
             # Handle Literal objects specially since they're not hashable
-            if isinstance(var, jax.core.Literal):
+            if isinstance(var, jax.extend.core.Literal):
                 # Create a name based on the literal value
                 name = f"lit_{var.val}"
                 return name
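`ClosedJaxpr`, `Var`, and `Literal` are deprecated under `jax.core` in recent JAX releases and re-exported from `jax.extend.core`; the objects themselves are unchanged, so `jax.make_jaxpr` output feeds in as before. A sketch in the spirit of the function's docstring example:

    import jax
    import jax.numpy as jnp

    from xax.utils.jaxpr import save_jaxpr_dot

    closed_jaxpr = jax.make_jaxpr(lambda x: jnp.sin(x) ** 2)(jnp.arange(3.0))
    save_jaxpr_dot(closed_jaxpr, "fn.dot")  # render with e.g. `dot -Tpng fn.dot -o fn.png`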
xax/utils/pytree.py
CHANGED
@@ -57,7 +57,7 @@ def pytree_has_nans(pytree: PyTree) -> Array:
 
 def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
     """Update a pytree based on a condition."""
-    # Tricky, need use tree_map because where expects array leafs.
+    # Tricky, need use tree.map because where expects array leafs.
     return jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)
 
 
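The comment now names the API the body actually calls. The reason for the leaf-wise map stands: `jnp.where` wants array operands, so the condition is broadcast against each pair of leaves. A small sketch:

    import jax.numpy as jnp

    from xax.utils.pytree import update_pytree

    new = {"w": jnp.ones(2)}
    original = {"w": jnp.full(2, 5.0)}
    updated = update_pytree(jnp.array(True), new, original)  # leaves taken from `new`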
xax/utils/types/frozen_dict.py
CHANGED
@@ -138,7 +138,7 @@ class FrozenDict(Mapping[K, V]):
 
 def unfreeze(x: FrozenDict[K, V] | dict[str, Any]) -> dict[Any, Any]:  # noqa: ANN401
     if isinstance(x, FrozenDict):
-        return jax.tree_util.tree_map(lambda y: y, x._dict)
+        return jax.tree.map(lambda y: y, x._dict)
     elif isinstance(x, dict):
         ys = {}
         for key, value in x.items():
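`jax.tree.map` is the current spelling of `jax.tree_util.tree_map`, so behavior is unchanged: mapping the identity over the backing `_dict` rebuilds the container structure, which is what lets `unfreeze` hand back a fresh, mutable dict. The idiom in isolation:

    import jax

    nested = {"a": {"b": 1}}
    copied = jax.tree.map(lambda y: y, nested)
    assert copied == nested and copied is not nested  # same leaves, new containers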
{xax-0.2.6.dist-info → xax-0.2.7.dist-info}/RECORD
CHANGED
@@ -1,4 +1,4 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=94V0RHzCNC-aVXtpv5jtYaXJWytSJYOJCjUR69dIM1g,14428
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
 xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -9,7 +9,7 @@ xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
 xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
 xax/nn/export.py,sha256=pRfM2B4hB2EvljysC6AjtgB_7Cn7JtaP3dhYU2stZtY,5545
-xax/nn/functions.py,sha256=
+xax/nn/functions.py,sha256=Gig46ZjFxzXl3ImOFvpAO6oG6xrnAZywp1Ez8Gwy0BU,2711
 xax/nn/geom.py,sha256=rImNlkHWeoNcY7f84nknizJ6uzsrMhbAtKeb2xAWxNY,6215
 xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
 xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
@@ -26,13 +26,13 @@ xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,140
 xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
 xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,1582
-xax/task/loggers/json.py,sha256=
+xax/task/loggers/json.py,sha256=Ukbo6eAq9mSDmo7AqCVu9OFFjjSjcIsdeKP_1WQbyuw,4110
 xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
 xax/task/loggers/stdout.py,sha256=ERLFrYe61hSSztzyxBRseobHQR72YFYjEd2i_hOeJ20,6595
 xax/task/loggers/tensorboard.py,sha256=KFlsK0zD2ubDqAXYL4Ds7NQ9F-Ke-PHwfhLOYsGcbw4,8306
 xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
 xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
-xax/task/mixins/checkpointing.py,sha256=
+xax/task/mixins/checkpointing.py,sha256=zqospBFnTbGt_iriiduVfXazINPbzWpwmIs91KAniMY,10147
 xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
 xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
 xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
@@ -41,25 +41,25 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
 xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/train.py,sha256=t2KW18S9vpUuvmr3VAeKzSEcEdAD6fbi_fjk-KZ6ssA,31426
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
 xax/utils/experiments.py,sha256=d2H63ECtVOKySMUMrQRqq4kcuZpoXqo-L931usDVAhE,29903
 xax/utils/jax.py,sha256=KQYUHjN6t6JIWa11aRSO3edcsAgTscw_dExxI6kCd9g,6767
-xax/utils/jaxpr.py,sha256=
+xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
 xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
 xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
 xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
-xax/utils/pytree.py,sha256=
+xax/utils/pytree.py,sha256=Q_5NCr90Wlqw1x4yK-FXfDVZEkNK2figmsUSZ4fnxYk,8754
 xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
 xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
 xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-xax/utils/types/frozen_dict.py,sha256=
+xax/utils/types/frozen_dict.py,sha256=s57XaTo2jgeT4_SzNQEjsYb4XrNZJwm1ca4ZdXSE5TY,4676
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.2.6.dist-info/licenses/LICENSE,sha256=
-xax-0.2.6.dist-info/METADATA,sha256=
-xax-0.2.6.dist-info/WHEEL,sha256=
-xax-0.2.6.dist-info/top_level.txt,sha256=
-xax-0.2.6.dist-info/RECORD,,
+xax-0.2.7.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.2.7.dist-info/METADATA,sha256=7otxcR5N4nVGyao5-NsrLfyuY1LqX8v2PWMEWB3sVy8,1879
+xax-0.2.7.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.2.7.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.2.7.dist-info/RECORD,,
{xax-0.2.6.dist-info → xax-0.2.7.dist-info}/WHEEL
File without changes
{xax-0.2.6.dist-info → xax-0.2.7.dist-info}/licenses/LICENSE
File without changes
{xax-0.2.6.dist-info → xax-0.2.7.dist-info}/top_level.txt
File without changes