xax 0.2.6__tar.gz → 0.2.7__tar.gz

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries; it is provided for informational purposes only.
Files changed (72)
  1. {xax-0.2.6/xax.egg-info → xax-0.2.7}/PKG-INFO +1 -1
  2. {xax-0.2.6 → xax-0.2.7}/xax/__init__.py +7 -2
  3. {xax-0.2.6 → xax-0.2.7}/xax/nn/functions.py +1 -1
  4. {xax-0.2.6 → xax-0.2.7}/xax/task/loggers/json.py +1 -2
  5. {xax-0.2.6 → xax-0.2.7}/xax/task/mixins/checkpointing.py +108 -143
  6. {xax-0.2.6 → xax-0.2.7}/xax/task/mixins/train.py +26 -17
  7. {xax-0.2.6 → xax-0.2.7}/xax/utils/jaxpr.py +5 -5
  8. {xax-0.2.6 → xax-0.2.7}/xax/utils/pytree.py +1 -1
  9. {xax-0.2.6 → xax-0.2.7}/xax/utils/types/frozen_dict.py +1 -1
  10. {xax-0.2.6 → xax-0.2.7/xax.egg-info}/PKG-INFO +1 -1
  11. {xax-0.2.6 → xax-0.2.7}/LICENSE +0 -0
  12. {xax-0.2.6 → xax-0.2.7}/MANIFEST.in +0 -0
  13. {xax-0.2.6 → xax-0.2.7}/README.md +0 -0
  14. {xax-0.2.6 → xax-0.2.7}/pyproject.toml +0 -0
  15. {xax-0.2.6 → xax-0.2.7}/setup.cfg +0 -0
  16. {xax-0.2.6 → xax-0.2.7}/setup.py +0 -0
  17. {xax-0.2.6 → xax-0.2.7}/xax/core/__init__.py +0 -0
  18. {xax-0.2.6 → xax-0.2.7}/xax/core/conf.py +0 -0
  19. {xax-0.2.6 → xax-0.2.7}/xax/core/state.py +0 -0
  20. {xax-0.2.6 → xax-0.2.7}/xax/nn/__init__.py +0 -0
  21. {xax-0.2.6 → xax-0.2.7}/xax/nn/embeddings.py +0 -0
  22. {xax-0.2.6 → xax-0.2.7}/xax/nn/equinox.py +0 -0
  23. {xax-0.2.6 → xax-0.2.7}/xax/nn/export.py +0 -0
  24. {xax-0.2.6 → xax-0.2.7}/xax/nn/geom.py +0 -0
  25. {xax-0.2.6 → xax-0.2.7}/xax/nn/losses.py +0 -0
  26. {xax-0.2.6 → xax-0.2.7}/xax/nn/norm.py +0 -0
  27. {xax-0.2.6 → xax-0.2.7}/xax/nn/parallel.py +0 -0
  28. {xax-0.2.6 → xax-0.2.7}/xax/nn/ssm.py +0 -0
  29. {xax-0.2.6 → xax-0.2.7}/xax/py.typed +0 -0
  30. {xax-0.2.6 → xax-0.2.7}/xax/requirements-dev.txt +0 -0
  31. {xax-0.2.6 → xax-0.2.7}/xax/requirements.txt +0 -0
  32. {xax-0.2.6 → xax-0.2.7}/xax/task/__init__.py +0 -0
  33. {xax-0.2.6 → xax-0.2.7}/xax/task/base.py +0 -0
  34. {xax-0.2.6 → xax-0.2.7}/xax/task/launchers/__init__.py +0 -0
  35. {xax-0.2.6 → xax-0.2.7}/xax/task/launchers/base.py +0 -0
  36. {xax-0.2.6 → xax-0.2.7}/xax/task/launchers/cli.py +0 -0
  37. {xax-0.2.6 → xax-0.2.7}/xax/task/launchers/single_process.py +0 -0
  38. {xax-0.2.6 → xax-0.2.7}/xax/task/logger.py +0 -0
  39. {xax-0.2.6 → xax-0.2.7}/xax/task/loggers/__init__.py +0 -0
  40. {xax-0.2.6 → xax-0.2.7}/xax/task/loggers/callback.py +0 -0
  41. {xax-0.2.6 → xax-0.2.7}/xax/task/loggers/state.py +0 -0
  42. {xax-0.2.6 → xax-0.2.7}/xax/task/loggers/stdout.py +0 -0
  43. {xax-0.2.6 → xax-0.2.7}/xax/task/loggers/tensorboard.py +0 -0
  44. {xax-0.2.6 → xax-0.2.7}/xax/task/mixins/__init__.py +0 -0
  45. {xax-0.2.6 → xax-0.2.7}/xax/task/mixins/artifacts.py +0 -0
  46. {xax-0.2.6 → xax-0.2.7}/xax/task/mixins/compile.py +0 -0
  47. {xax-0.2.6 → xax-0.2.7}/xax/task/mixins/cpu_stats.py +0 -0
  48. {xax-0.2.6 → xax-0.2.7}/xax/task/mixins/data_loader.py +0 -0
  49. {xax-0.2.6 → xax-0.2.7}/xax/task/mixins/gpu_stats.py +0 -0
  50. {xax-0.2.6 → xax-0.2.7}/xax/task/mixins/logger.py +0 -0
  51. {xax-0.2.6 → xax-0.2.7}/xax/task/mixins/process.py +0 -0
  52. {xax-0.2.6 → xax-0.2.7}/xax/task/mixins/runnable.py +0 -0
  53. {xax-0.2.6 → xax-0.2.7}/xax/task/mixins/step_wrapper.py +0 -0
  54. {xax-0.2.6 → xax-0.2.7}/xax/task/script.py +0 -0
  55. {xax-0.2.6 → xax-0.2.7}/xax/task/task.py +0 -0
  56. {xax-0.2.6 → xax-0.2.7}/xax/utils/__init__.py +0 -0
  57. {xax-0.2.6 → xax-0.2.7}/xax/utils/data/__init__.py +0 -0
  58. {xax-0.2.6 → xax-0.2.7}/xax/utils/data/collate.py +0 -0
  59. {xax-0.2.6 → xax-0.2.7}/xax/utils/debugging.py +0 -0
  60. {xax-0.2.6 → xax-0.2.7}/xax/utils/experiments.py +0 -0
  61. {xax-0.2.6 → xax-0.2.7}/xax/utils/jax.py +0 -0
  62. {xax-0.2.6 → xax-0.2.7}/xax/utils/logging.py +0 -0
  63. {xax-0.2.6 → xax-0.2.7}/xax/utils/numpy.py +0 -0
  64. {xax-0.2.6 → xax-0.2.7}/xax/utils/profile.py +0 -0
  65. {xax-0.2.6 → xax-0.2.7}/xax/utils/tensorboard.py +0 -0
  66. {xax-0.2.6 → xax-0.2.7}/xax/utils/text.py +0 -0
  67. {xax-0.2.6 → xax-0.2.7}/xax/utils/types/__init__.py +0 -0
  68. {xax-0.2.6 → xax-0.2.7}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.2.6 → xax-0.2.7}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.2.6 → xax-0.2.7}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.2.6 → xax-0.2.7}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.2.6 → xax-0.2.7}/xax.egg-info/top_level.txt +0 -0

{xax-0.2.6/xax.egg-info → xax-0.2.7}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.2.6
+Version: 0.2.7
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte

{xax-0.2.6 → xax-0.2.7}/xax/__init__.py
@@ -12,7 +12,7 @@ and running the update script:
 python -m scripts.update_api --inplace
 """
 
-__version__ = "0.2.6"
+__version__ = "0.2.7"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -66,11 +66,13 @@ __all__ = [
     "StateLogger",
     "StdoutLogger",
     "TensorboardLogger",
+    "load_ckpt",
     "CPUStatsOptions",
     "DataloaderConfig",
     "GPUStatsOptions",
     "StepContext",
     "ValidStepTimer",
+    "get_param_count",
     "Script",
     "ScriptConfig",
     "Config",
@@ -230,11 +232,13 @@ NAME_MAP: dict[str, str] = {
     "StateLogger": "task.loggers.state",
     "StdoutLogger": "task.loggers.stdout",
     "TensorboardLogger": "task.loggers.tensorboard",
+    "load_ckpt": "task.mixins.checkpointing",
     "CPUStatsOptions": "task.mixins.cpu_stats",
     "DataloaderConfig": "task.mixins.data_loader",
     "GPUStatsOptions": "task.mixins.gpu_stats",
     "StepContext": "task.mixins.step_wrapper",
     "ValidStepTimer": "task.mixins.train",
+    "get_param_count": "task.mixins.train",
     "Script": "task.script",
     "ScriptConfig": "task.script",
     "Config": "task.task",
@@ -390,11 +394,12 @@ if IMPORT_ALL or TYPE_CHECKING:
     from xax.task.loggers.state import StateLogger
     from xax.task.loggers.stdout import StdoutLogger
     from xax.task.loggers.tensorboard import TensorboardLogger
+    from xax.task.mixins.checkpointing import load_ckpt
     from xax.task.mixins.cpu_stats import CPUStatsOptions
     from xax.task.mixins.data_loader import DataloaderConfig
     from xax.task.mixins.gpu_stats import GPUStatsOptions
     from xax.task.mixins.step_wrapper import StepContext
-    from xax.task.mixins.train import Batch, Output, ValidStepTimer
+    from xax.task.mixins.train import Batch, Output, ValidStepTimer, get_param_count
     from xax.task.script import Script, ScriptConfig
     from xax.task.task import Config, Task
     from xax.utils.data.collate import CollateMode, collate, collate_non_null

{xax-0.2.6 → xax-0.2.7}/xax/nn/functions.py
@@ -1,5 +1,5 @@
 # mypy: disable-error-code="override"
-"""Defines helper Torch functions."""
+"""Defines helper Jax functions."""
 
 import random
 from dataclasses import is_dataclass

{xax-0.2.6 → xax-0.2.7}/xax/task/loggers/json.py
@@ -2,7 +2,6 @@
 
 import json
 import sys
-from dataclasses import asdict
 from typing import Any, Literal, Mapping, TextIO
 
 from jaxtyping import Array
@@ -67,7 +66,7 @@ class JsonLogger(LoggerImpl):
         return self.err_log_stream
 
     def get_json(self, line: LogLine) -> str:
-        data: dict = {"state": asdict(line.state)}
+        data: dict = {"state": line.state.to_dict()}
 
         def add_logs(log: Mapping[str, Mapping[str, LogScalar | LogString]], data: dict) -> None:
            for namespace, values in log.items():
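
The asdict-to-to_dict switch means State controls its own serialization, matching the State.from_dict(**...) call that checkpoint loading uses (see checkpointing.py below). For illustration, a minimal sketch of the implied round-trip; it assumes State.to_dict() returns a JSON-serializable mapping, which this diff does not show directly:

import json

from xax.core.state import State

state = State.init_state()
payload = json.dumps({"state": state.to_dict()})  # what JsonLogger now emits
restored = State.from_dict(**json.loads(payload)["state"])  # mirrors load_ckpt's get_state()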

{xax-0.2.6 → xax-0.2.7}/xax/task/mixins/checkpointing.py
@@ -52,6 +52,114 @@ class CheckpointingConfig(ArtifactsConfig):
 Config = TypeVar("Config", bound=CheckpointingConfig)
 
 
+@overload
+def load_ckpt(
+    path: Path,
+    *,
+    part: Literal["all"],
+    model_template: PyTree,
+    optimizer_template: PyTree,
+    opt_state_template: PyTree,
+) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+
+
+@overload
+def load_ckpt(
+    path: Path,
+    *,
+    part: Literal["model_state_config"],
+    model_template: PyTree,
+) -> tuple[PyTree, State, DictConfig]: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["model"], model_template: PyTree) -> PyTree: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["opt"], optimizer_template: PyTree) -> optax.GradientTransformation: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["opt_state"], opt_state_template: PyTree) -> optax.OptState: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["state"]) -> State: ...
+
+
+@overload
+def load_ckpt(path: Path, *, part: Literal["config"]) -> DictConfig: ...
+
+
+def load_ckpt(
+    path: str | Path,
+    *,
+    part: CheckpointPart = "model",
+    model_template: PyTree | None = None,
+    optimizer_template: PyTree | None = None,
+    opt_state_template: PyTree | None = None,
+) -> (
+    tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
+    | tuple[PyTree, State, DictConfig]
+    | PyTree
+    | optax.GradientTransformation
+    | optax.OptState
+    | State
+    | DictConfig
+):
+    with tarfile.open(path, "r:gz") as tar:
+
+        def get_model() -> PyTree:
+            if model_template is None:
+                raise ValueError("model_template must be provided to load model weights")
+            if (model := tar.extractfile("model")) is None:
+                raise ValueError(f"Checkpoint does not contain a model file: {path}")
+            return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
+
+        def get_opt() -> optax.GradientTransformation:
+            if optimizer_template is None:
+                raise ValueError("optimizer_template must be provided to load optimizer")
+            if (opt := tar.extractfile("optimizer")) is None:
+                raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
+            return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
+
+        def get_opt_state() -> optax.OptState:
+            if opt_state_template is None:
+                raise ValueError("opt_state_template must be provided to load optimizer state")
+            if (opt_state := tar.extractfile("opt_state")) is None:
+                raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
+            return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
+
+        def get_state() -> State:
+            if (state := tar.extractfile("state")) is None:
+                raise ValueError(f"Checkpoint does not contain a state file: {path}")
+            return State.from_dict(**json.loads(state.read().decode()))
+
+        def get_config() -> DictConfig:
+            if (config := tar.extractfile("config")) is None:
+                raise ValueError(f"Checkpoint does not contain a config file: {path}")
+            return cast(DictConfig, OmegaConf.load(config))
+
+        match part:
+            case "model":
+                return get_model()
+            case "opt":
+                return get_opt()
+            case "opt_state":
+                return get_opt_state()
+            case "state":
+                return get_state()
+            case "config":
+                return get_config()
+            case "model_state_config":
+                return get_model(), get_state(), get_config()
+            case "all":
+                return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
+            case _:
+                raise ValueError(f"Invalid checkpoint part: {part}")
+
+
 class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def __init__(self, config: Config) -> None:
         super().__init__(config)
@@ -82,149 +190,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
                 return True
         return False
 
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["all"],
-        model_template: PyTree,
-        optimizer_template: PyTree,
-        opt_state_template: PyTree,
-    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["model_state_config"],
-        model_template: PyTree,
-    ) -> tuple[PyTree, State, Config]: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["model"],
-        model_template: PyTree,
-    ) -> PyTree: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["opt"],
-        optimizer_template: PyTree,
-    ) -> optax.GradientTransformation: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["opt_state"],
-        opt_state_template: PyTree,
-    ) -> optax.OptState: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["state"],
-    ) -> State: ...
-
-    @overload
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: Literal["config"],
-    ) -> Config: ...
-
-    def load_ckpt_with_template(
-        self,
-        path: Path,
-        *,
-        part: CheckpointPart = "all",
-        model_template: PyTree | None = None,
-        optimizer_template: PyTree | None = None,
-        opt_state_template: PyTree | None = None,
-    ) -> (
-        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
-        | tuple[PyTree, State, Config]
-        | PyTree
-        | optax.GradientTransformation
-        | optax.OptState
-        | State
-        | Config
-    ):
-        """Load a checkpoint.
-
-        Args:
-            path: Path to the checkpoint directory
-            part: Which part of the checkpoint to load
-            model_template: Template model with correct structure but uninitialized weights
-            optimizer_template: Template optimizer with correct structure but uninitialized weights
-            opt_state_template: Template optimizer state with correct structure but uninitialized weights
-
-        Returns:
-            The requested checkpoint components
-        """
-        with tarfile.open(path, "r:gz") as tar:
-
-            def get_model() -> PyTree:
-                if model_template is None:
-                    raise ValueError("model_template must be provided to load model weights")
-                if (model := tar.extractfile("model")) is None:
-                    raise ValueError(f"Checkpoint does not contain a model file: {path}")
-                return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
-
-            def get_opt() -> optax.GradientTransformation:
-                if optimizer_template is None:
-                    raise ValueError("optimizer_template must be provided to load optimizer")
-                if (opt := tar.extractfile("optimizer")) is None:
-                    raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
-                return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
-
-            def get_opt_state() -> optax.OptState:
-                if opt_state_template is None:
-                    raise ValueError("opt_state_template must be provided to load optimizer state")
-                if (opt_state := tar.extractfile("opt_state")) is None:
-                    raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
-                return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
-
-            def get_state() -> State:
-                if (state := tar.extractfile("state")) is None:
-                    raise ValueError(f"Checkpoint does not contain a state file: {path}")
-                return State.from_dict(**json.loads(state.read().decode()))
-
-            def get_config() -> Config:
-                if (config := tar.extractfile("config")) is None:
-                    raise ValueError(f"Checkpoint does not contain a config file: {path}")
-                return self.get_config(cast(DictConfig, OmegaConf.load(config)), use_cli=False)
-
-            match part:
-                case "model":
-                    return get_model()
-                case "opt":
-                    return get_opt()
-                case "opt_state":
-                    return get_opt_state()
-                case "state":
-                    return get_state()
-                case "config":
-                    return get_config()
-                case "model_state_config":
-                    return get_model(), get_state(), get_config()
-                case "all":
-                    return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
-                case _:
-                    raise ValueError(f"Invalid checkpoint part: {part}")
-
     def save_checkpoint(
         self,
         model: PyTree | None = None,
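
The net effect is that load_ckpt is now a free function, so checkpoints can be read without constructing a task object; load_ckpt_with_template disappears and TrainMixin delegates to the new function (see train.py below). A usage sketch under stated assumptions: MyModel and the checkpoint path are hypothetical, and building a shape-only template with eqx.filter_eval_shape mirrors what TrainMixin does internally:

import equinox as eqx
import jax

from xax.task.mixins.checkpointing import load_ckpt

# Shape-only template: same pytree structure as the saved model, no real weights.
model_template = eqx.filter_eval_shape(MyModel, jax.random.PRNGKey(0))

model = load_ckpt("runs/exp0/ckpt.bin", part="model", model_template=model_template)
state = load_ckpt("runs/exp0/ckpt.bin", part="state")  # metadata parts need no template
config = load_ckpt("runs/exp0/ckpt.bin", part="config")  # raw OmegaConf DictConfig

Two behavioral shifts are visible in the diff: the default part is now "model" (the removed method defaulted to "all"), and the config part returns the raw DictConfig instead of a task-resolved Config; TrainMixin now layers self.get_config(..., use_cli=False) on top where needed.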

{xax-0.2.6 → xax-0.2.7}/xax/task/mixins/train.py
@@ -40,7 +40,7 @@ from xax.core.state import Phase, State
 from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
-from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart
+from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart, load_ckpt
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
@@ -96,6 +96,12 @@ def batches_per_step_schedule(schedule: list[int] | None) -> list[int] | None:
     return list(itertools.accumulate([0] + schedule))
 
 
+def get_param_count(pytree: PyTree) -> int:
+    """Calculates the total number of parameters in a PyTree."""
+    leaves, _ = jax.tree.flatten(pytree)
+    return sum(x.size for x in leaves if isinstance(x, jnp.ndarray))
+
+
 class ValidStepTimer:
     def __init__(
         self,
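
get_param_count is small enough to sanity-check in isolation: it flattens the pytree and sums element counts over array leaves, so non-array leaves (strings, Nones, callables) contribute nothing. A quick check, assuming only jax and jax.numpy:

import jax
import jax.numpy as jnp

def get_param_count(pytree) -> int:
    """Sum element counts over array leaves, skipping non-array leaves."""
    leaves, _ = jax.tree.flatten(pytree)
    return sum(x.size for x in leaves if isinstance(x, jnp.ndarray))

params = {"w": jnp.zeros((128, 64)), "b": jnp.zeros(64), "name": "layer0"}
print(get_param_count(params))  # 8256 == 128 * 64 + 64; "name" is ignored

The log_model_size hook added below formats this count with thousands separators and logs it once when the model is built.
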
@@ -360,6 +366,7 @@ class TrainMixin(
         model = self.get_model(key)
         state = State.init_state()
 
+        self.log_model_size(model)
         if not load_optimizer:
             return model, state
 
@@ -450,44 +457,43 @@ class TrainMixin(
         match part:
             case "model_state_config":
                 model_spec = eqx.filter_eval_shape(self.get_model, key)
-                return self.load_ckpt_with_template(path, part="model_state_config", model_template=model_spec)
+                model, state, config = load_ckpt(path, part="model_state_config", model_template=model_spec)
+                config = self.get_config(config, use_cli=False)
+                return model, state, config
 
             case "model":
                 model_spec = eqx.filter_eval_shape(self.get_model, key)
-                return self.load_ckpt_with_template(path, part="model", model_template=model_spec)
-
-            case "config":
-                return self.load_ckpt_with_template(path, part="config")
+                return load_ckpt(path, part="model", model_template=model_spec)
 
             case "opt":
                 optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                return self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                return load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
 
             case "opt_state":
                 if model is None:
                     model_spec = eqx.filter_eval_shape(self.get_model, key)
-                    model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+                    model = load_ckpt(path, part="model", model_template=model_spec)
                 if optimizer is None:
                     optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                    optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                    optimizer = load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
                 opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
-                return self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
+                return load_ckpt(path, part="opt_state", opt_state_template=opt_state_spec)
 
             case "state":
-                return self.load_ckpt_with_template(path, part="state")
+                return load_ckpt(path, part="state")
 
             case "config":
-                return self.load_ckpt_with_template(path, part="config")
+                return self.get_config(load_ckpt(path, part="config"), use_cli=False)
 
             case "all":
                 model_spec = eqx.filter_eval_shape(self.get_model, key)
-                model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+                model = load_ckpt(path, part="model", model_template=model_spec)
                 optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-                optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                optimizer = load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
                 opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
-                opt_state = self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
-                state = self.load_ckpt_with_template(path, part="state")
-                config = self.load_ckpt_with_template(path, part="config")
+                opt_state = load_ckpt(path, part="opt_state", opt_state_template=opt_state_spec)
+                state = load_ckpt(path, part="state")
+                config = self.get_config(load_ckpt(path, part="config"), use_cli=False)
                 return model, optimizer, opt_state, state, config
 
             case _:
@@ -683,6 +689,9 @@ class TrainMixin(
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
         self.logger.log_file("info.json", get_info_json())
 
+    def log_model_size(self, model: PyTree) -> None:
+        logger.info("Model size: %s", f"{get_param_count(model):,}")
+
     def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
         return eqx.is_inexact_array(item)
 

{xax-0.2.6 → xax-0.2.7}/xax/utils/jaxpr.py
@@ -3,10 +3,10 @@
 from pathlib import Path
 
 import jax
-import jax.core
+import jax.extend.core
 
 
-def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) -> None:
+def save_jaxpr_dot(closed_jaxpr: jax.extend.core.ClosedJaxpr, filename: str | Path) -> None:
     """Save the JAXPR to a DOT file.
 
     Example usage:
@@ -30,15 +30,15 @@ def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) ->
     with open(filename, "w") as f:
         f.write("digraph Jaxpr {\n")
 
-        var_names: dict[jax.core.Var, str] = {}
+        var_names: dict[jax.extend.core.Var, str] = {}
         var_count = 0
 
-        def get_var_name(var: jax.core.Var) -> str:
+        def get_var_name(var: jax.extend.core.Var) -> str:
             """Get a unique name for a variable."""
             nonlocal var_names, var_count
 
             # Handle Literal objects specially since they're not hashable
-            if isinstance(var, jax.core.Literal):
+            if isinstance(var, jax.extend.core.Literal):
                 # Create a name based on the literal value
                 name = f"lit_{var.val}"
                 return name
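
This is a mechanical rename: recent JAX releases moved ClosedJaxpr, Var, and Literal out of the public jax.core namespace, re-exporting them from jax.extend.core, so the utility keeps working without behavior changes. A small sketch of producing a ClosedJaxpr to feed save_jaxpr_dot; the traced function f is a made-up example:

import jax
import jax.extend.core
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(3))  # returns a ClosedJaxpr
assert isinstance(closed_jaxpr, jax.extend.core.ClosedJaxpr)
# save_jaxpr_dot(closed_jaxpr, "f.dot")  # writes the Graphviz rendering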

{xax-0.2.6 → xax-0.2.7}/xax/utils/pytree.py
@@ -57,7 +57,7 @@ def pytree_has_nans(pytree: PyTree) -> Array:
 
 def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
     """Update a pytree based on a condition."""
-    # Tricky, need use tree_map because where expects array leafs.
+    # Tricky, need use tree.map because where expects array leafs.
     return jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)
 
 
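For context, update_pytree broadcasts a single boolean condition across every leaf pair; the tree map is required because jnp.where operates on arrays, not containers. A minimal usage sketch:

import jax
import jax.numpy as jnp

def update_pytree(cond, new, original):
    """Pick leaves from `new` where cond is true, else from `original`."""
    return jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)

new = {"w": jnp.ones(3)}
old = {"w": jnp.zeros(3)}
print(update_pytree(jnp.array(False), new, old))  # keeps the original zeros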

{xax-0.2.6 → xax-0.2.7}/xax/utils/types/frozen_dict.py
@@ -138,7 +138,7 @@ class FrozenDict(Mapping[K, V]):
 
 def unfreeze(x: FrozenDict[K, V] | dict[str, Any]) -> dict[Any, Any]:  # noqa: ANN401
     if isinstance(x, FrozenDict):
-        return jax.tree_util.tree_map(lambda y: y, x._dict)
+        return jax.tree.map(lambda y: y, x._dict)
     elif isinstance(x, dict):
         ys = {}
         for key, value in x.items():
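
jax.tree.map is the current spelling of jax.tree_util.tree_map (the jax.tree namespace exists in recent JAX releases), and mapping the identity function is an idiomatic way to rebuild the containers, producing a mutable shallow copy of the frozen dict's contents. A sketch:

import jax

nested = {"a": {"b": 1}, "c": (2, 3)}
copy = jax.tree.map(lambda y: y, nested)  # same leaves, freshly built containers
assert copy == nested and copy is not nested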

{xax-0.2.6 → xax-0.2.7/xax.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.2.6
+Version: 0.2.7
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte