xax 0.2.0__tar.gz → 0.2.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.2.0/xax.egg-info → xax-0.2.2}/PKG-INFO +1 -1
  2. {xax-0.2.0 → xax-0.2.2}/xax/__init__.py +4 -1
  3. {xax-0.2.0 → xax-0.2.2}/xax/nn/geom.py +34 -0
  4. {xax-0.2.0 → xax-0.2.2}/xax/task/mixins/checkpointing.py +9 -12
  5. {xax-0.2.0 → xax-0.2.2}/xax/task/mixins/cpu_stats.py +12 -9
  6. {xax-0.2.0 → xax-0.2.2}/xax/task/mixins/gpu_stats.py +14 -11
  7. {xax-0.2.0 → xax-0.2.2}/xax/task/mixins/process.py +14 -8
  8. {xax-0.2.0 → xax-0.2.2}/xax/task/mixins/train.py +133 -19
  9. {xax-0.2.0 → xax-0.2.2/xax.egg-info}/PKG-INFO +1 -1
  10. {xax-0.2.0 → xax-0.2.2}/LICENSE +0 -0
  11. {xax-0.2.0 → xax-0.2.2}/MANIFEST.in +0 -0
  12. {xax-0.2.0 → xax-0.2.2}/README.md +0 -0
  13. {xax-0.2.0 → xax-0.2.2}/pyproject.toml +0 -0
  14. {xax-0.2.0 → xax-0.2.2}/setup.cfg +0 -0
  15. {xax-0.2.0 → xax-0.2.2}/setup.py +0 -0
  16. {xax-0.2.0 → xax-0.2.2}/xax/core/__init__.py +0 -0
  17. {xax-0.2.0 → xax-0.2.2}/xax/core/conf.py +0 -0
  18. {xax-0.2.0 → xax-0.2.2}/xax/core/state.py +0 -0
  19. {xax-0.2.0 → xax-0.2.2}/xax/nn/__init__.py +0 -0
  20. {xax-0.2.0 → xax-0.2.2}/xax/nn/embeddings.py +0 -0
  21. {xax-0.2.0 → xax-0.2.2}/xax/nn/equinox.py +0 -0
  22. {xax-0.2.0 → xax-0.2.2}/xax/nn/export.py +0 -0
  23. {xax-0.2.0 → xax-0.2.2}/xax/nn/functions.py +0 -0
  24. {xax-0.2.0 → xax-0.2.2}/xax/nn/losses.py +0 -0
  25. {xax-0.2.0 → xax-0.2.2}/xax/nn/norm.py +0 -0
  26. {xax-0.2.0 → xax-0.2.2}/xax/nn/parallel.py +0 -0
  27. {xax-0.2.0 → xax-0.2.2}/xax/nn/ssm.py +0 -0
  28. {xax-0.2.0 → xax-0.2.2}/xax/py.typed +0 -0
  29. {xax-0.2.0 → xax-0.2.2}/xax/requirements-dev.txt +0 -0
  30. {xax-0.2.0 → xax-0.2.2}/xax/requirements.txt +0 -0
  31. {xax-0.2.0 → xax-0.2.2}/xax/task/__init__.py +0 -0
  32. {xax-0.2.0 → xax-0.2.2}/xax/task/base.py +0 -0
  33. {xax-0.2.0 → xax-0.2.2}/xax/task/launchers/__init__.py +0 -0
  34. {xax-0.2.0 → xax-0.2.2}/xax/task/launchers/base.py +0 -0
  35. {xax-0.2.0 → xax-0.2.2}/xax/task/launchers/cli.py +0 -0
  36. {xax-0.2.0 → xax-0.2.2}/xax/task/launchers/single_process.py +0 -0
  37. {xax-0.2.0 → xax-0.2.2}/xax/task/logger.py +0 -0
  38. {xax-0.2.0 → xax-0.2.2}/xax/task/loggers/__init__.py +0 -0
  39. {xax-0.2.0 → xax-0.2.2}/xax/task/loggers/callback.py +0 -0
  40. {xax-0.2.0 → xax-0.2.2}/xax/task/loggers/json.py +0 -0
  41. {xax-0.2.0 → xax-0.2.2}/xax/task/loggers/state.py +0 -0
  42. {xax-0.2.0 → xax-0.2.2}/xax/task/loggers/stdout.py +0 -0
  43. {xax-0.2.0 → xax-0.2.2}/xax/task/loggers/tensorboard.py +0 -0
  44. {xax-0.2.0 → xax-0.2.2}/xax/task/mixins/__init__.py +0 -0
  45. {xax-0.2.0 → xax-0.2.2}/xax/task/mixins/artifacts.py +0 -0
  46. {xax-0.2.0 → xax-0.2.2}/xax/task/mixins/compile.py +0 -0
  47. {xax-0.2.0 → xax-0.2.2}/xax/task/mixins/data_loader.py +0 -0
  48. {xax-0.2.0 → xax-0.2.2}/xax/task/mixins/logger.py +0 -0
  49. {xax-0.2.0 → xax-0.2.2}/xax/task/mixins/runnable.py +0 -0
  50. {xax-0.2.0 → xax-0.2.2}/xax/task/mixins/step_wrapper.py +0 -0
  51. {xax-0.2.0 → xax-0.2.2}/xax/task/script.py +0 -0
  52. {xax-0.2.0 → xax-0.2.2}/xax/task/task.py +0 -0
  53. {xax-0.2.0 → xax-0.2.2}/xax/utils/__init__.py +0 -0
  54. {xax-0.2.0 → xax-0.2.2}/xax/utils/data/__init__.py +0 -0
  55. {xax-0.2.0 → xax-0.2.2}/xax/utils/data/collate.py +0 -0
  56. {xax-0.2.0 → xax-0.2.2}/xax/utils/debugging.py +0 -0
  57. {xax-0.2.0 → xax-0.2.2}/xax/utils/experiments.py +0 -0
  58. {xax-0.2.0 → xax-0.2.2}/xax/utils/jax.py +0 -0
  59. {xax-0.2.0 → xax-0.2.2}/xax/utils/jaxpr.py +0 -0
  60. {xax-0.2.0 → xax-0.2.2}/xax/utils/logging.py +0 -0
  61. {xax-0.2.0 → xax-0.2.2}/xax/utils/numpy.py +0 -0
  62. {xax-0.2.0 → xax-0.2.2}/xax/utils/profile.py +0 -0
  63. {xax-0.2.0 → xax-0.2.2}/xax/utils/pytree.py +0 -0
  64. {xax-0.2.0 → xax-0.2.2}/xax/utils/tensorboard.py +0 -0
  65. {xax-0.2.0 → xax-0.2.2}/xax/utils/text.py +0 -0
  66. {xax-0.2.0 → xax-0.2.2}/xax/utils/types/__init__.py +0 -0
  67. {xax-0.2.0 → xax-0.2.2}/xax/utils/types/frozen_dict.py +0 -0
  68. {xax-0.2.0 → xax-0.2.2}/xax/utils/types/hashable_array.py +0 -0
  69. {xax-0.2.0 → xax-0.2.2}/xax.egg-info/SOURCES.txt +0 -0
  70. {xax-0.2.0 → xax-0.2.2}/xax.egg-info/dependency_links.txt +0 -0
  71. {xax-0.2.0 → xax-0.2.2}/xax.egg-info/requires.txt +0 -0
  72. {xax-0.2.0 → xax-0.2.2}/xax.egg-info/top_level.txt +0 -0
{xax-0.2.0/xax.egg-info → xax-0.2.2}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.2.0
+Version: 0.2.2
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
{xax-0.2.0 → xax-0.2.2}/xax/__init__.py

@@ -12,7 +12,7 @@ and running the update script:
 python -m scripts.update_api --inplace
 """
 
-__version__ = "0.2.0"
+__version__ = "0.2.2"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -44,6 +44,7 @@ __all__ = [
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
+    "quat_to_rotmat",
    "rotate_vector_by_quat",
     "cross_entropy",
     "cast_norm_type",
@@ -206,6 +207,7 @@ NAME_MAP: dict[str, str] = {
     "euler_to_quat": "nn.geom",
     "get_projected_gravity_vector_from_quat": "nn.geom",
     "quat_to_euler": "nn.geom",
+    "quat_to_rotmat": "nn.geom",
     "rotate_vector_by_quat": "nn.geom",
     "cross_entropy": "nn.losses",
     "cast_norm_type": "nn.norm",
@@ -369,6 +371,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         euler_to_quat,
         get_projected_gravity_vector_from_quat,
         quat_to_euler,
+        quat_to_rotmat,
         rotate_vector_by_quat,
     )
     from xax.nn.losses import cross_entropy
{xax-0.2.0 → xax-0.2.2}/xax/nn/geom.py

@@ -177,3 +177,37 @@ def cubic_bezier_interpolation(y_start: Array, y_end: Array, x: Array) -> Array:
     y_diff = y_end - y_start
     bezier = x**3 + 3 * (x**2 * (1 - x))
     return y_start + y_diff * bezier
+
+
+def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
+    """Converts a quaternion to a rotation matrix.
+
+    Args:
+        quat: The quaternion to convert, shape (*, 4).
+        eps: A small epsilon value to avoid division by zero.
+
+    Returns:
+        The rotation matrix, shape (*, 3, 3).
+    """
+    quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
+    w, x, y, z = jnp.split(quat, 4, axis=-1)
+
+    xx = 1 - 2 * (y * y + z * z)
+    xy = 2 * (x * y - z * w)
+    xz = 2 * (x * z + y * w)
+    yx = 2 * (x * y + z * w)
+    yy = 1 - 2 * (x * x + z * z)
+    yz = 2 * (y * z - x * w)
+    zx = 2 * (x * z - y * w)
+    zy = 2 * (y * z + x * w)
+    zz = 1 - 2 * (x * x + y * y)
+
+    # Corrected stacking: row-major order
+    return jnp.concatenate(
+        [
+            jnp.concatenate([xx, xy, xz], axis=-1)[..., None, :],
+            jnp.concatenate([yx, yy, yz], axis=-1)[..., None, :],
+            jnp.concatenate([zx, zy, zz], axis=-1)[..., None, :],
+        ],
+        axis=-2,
+    )
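
The new quat_to_rotmat normalizes a (w, x, y, z) quaternion and assembles the standard 3x3 rotation matrix row by row; per the __all__ change above it is re-exported at the top level. A minimal usage sketch (the rotation values are illustrative, not from the package):

    import jax.numpy as jnp
    import xax  # assumes xax >= 0.2.2, which re-exports quat_to_rotmat

    # A 90-degree rotation about the z-axis: quat = (cos(t/2), 0, 0, sin(t/2)).
    quat = jnp.array([jnp.cos(jnp.pi / 4), 0.0, 0.0, jnp.sin(jnp.pi / 4)])
    rotmat = xax.quat_to_rotmat(quat)  # shape (3, 3)

    # Rotating the x-axis should give (approximately) the y-axis.
    print(rotmat @ jnp.array([1.0, 0.0, 0.0]))  # ~[0.0, 1.0, 0.0]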
{xax-0.2.0 → xax-0.2.2}/xax/task/mixins/checkpointing.py

@@ -63,10 +63,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
 
     def get_init_ckpt_path(self) -> Path | None:
         if self._exp_dir is not None:
-            ckpt_path = self.get_ckpt_path()
-            if not ckpt_path.exists():
-                logger.warning("No checkpoint found in experiment directory: %s", ckpt_path)
-            else:
+            if (ckpt_path := self.get_ckpt_path()).exists():
                 return ckpt_path
         if self.config.load_from_ckpt_path is not None:
             ckpt_path = Path(self.config.load_from_ckpt_path)
@@ -86,7 +83,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
         return False
 
     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
         *,
@@ -97,7 +94,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
 
     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
         *,
@@ -106,7 +103,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     ) -> tuple[PyTree, State, Config]: ...
 
     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
         *,
@@ -115,7 +112,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     ) -> PyTree: ...
 
     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
         *,
@@ -124,7 +121,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     ) -> optax.GradientTransformation: ...
 
     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
         *,
@@ -133,7 +130,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     ) -> optax.OptState: ...
 
     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
         *,
@@ -141,14 +138,14 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     ) -> State: ...
 
     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
         *,
         part: Literal["config"],
     ) -> Config: ...
 
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
         *,
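
Only the name changes in these hunks: the renamed load_ckpt_with_template still requires the caller to supply an eval-shape template for whichever part is loaded. A hedged sketch of direct use (task is a hypothetical instance of a class mixing in CheckpointingMixin; ckpt_path is an illustrative Path):

    import equinox as eqx
    import jax

    key = jax.random.PRNGKey(0)  # only used to build the shape template
    model_spec = eqx.filter_eval_shape(task.get_model, key)
    model, state, config = task.load_ckpt_with_template(
        ckpt_path,
        part="model_state_config",
        model_template=model_spec,
    )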
{xax-0.2.0 → xax-0.2.2}/xax/task/mixins/cpu_stats.py

@@ -218,33 +218,36 @@ class CPUStatsMonitor:
 class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
     """Defines a task mixin for getting CPU statistics."""
 
-    _cpu_stats_monitor: CPUStatsMonitor
+    _cpu_stats_monitor: CPUStatsMonitor | None
 
     def __init__(self, config: Config) -> None:
         super().__init__(config)
 
-        self._cpu_stats_monitor = CPUStatsMonitor(
-            ping_interval=self.config.cpu_stats.ping_interval,
-            context=self._mp_ctx,
-            manager=self._mp_manager,
-        )
+        if (ctx := self.multiprocessing_context) is not None and (mgr := self.multiprocessing_manager) is not None:
+            self._cpu_stats_monitor = CPUStatsMonitor(self.config.cpu_stats.ping_interval, ctx, mgr)
+        else:
+            self._cpu_stats_monitor = None
 
     def on_training_start(self, state: State) -> State:
         state = super().on_training_start(state)
 
-        self._cpu_stats_monitor.start()
+        if (monitor := self._cpu_stats_monitor) is not None:
+            monitor.start()
         return state
 
     def on_training_end(self, state: State) -> State:
         state = super().on_training_end(state)
 
-        self._cpu_stats_monitor.stop()
+        if (monitor := self._cpu_stats_monitor) is not None:
+            monitor.stop()
         return state
 
     def on_step_start(self, state: State) -> State:
         state = super().on_step_start(state)
 
-        monitor = self._cpu_stats_monitor
+        if (monitor := self._cpu_stats_monitor) is None:
+            return state
+
         stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()
 
         if stats is not None:
{xax-0.2.0 → xax-0.2.2}/xax/task/mixins/gpu_stats.py

@@ -234,24 +234,27 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
     def __init__(self, config: Config) -> None:
         super().__init__(config)
 
-        self._gpu_stats_monitor = None
-        if shutil.which("nvidia-smi") is not None:
-            self._gpu_stats_monitor = GPUStatsMonitor(
-                config.gpu_stats.ping_interval,
-                self._mp_ctx,
-                self._mp_manager,
-            )
+        if (
+            shutil.which("nvidia-smi") is not None
+            and (ctx := self.multiprocessing_context) is not None
+            and (mgr := self.multiprocessing_manager) is not None
+        ):
+            self._gpu_stats_monitor = GPUStatsMonitor(config.gpu_stats.ping_interval, ctx, mgr)
+        else:
+            self._gpu_stats_monitor = None
 
     def on_training_start(self, state: State) -> State:
         state = super().on_training_start(state)
-        if self._gpu_stats_monitor is not None:
-            self._gpu_stats_monitor.start()
+
+        if (monitor := self._gpu_stats_monitor) is not None:
+            monitor.start()
         return state
 
     def on_training_end(self, state: State) -> State:
         state = super().on_training_end(state)
-        if self._gpu_stats_monitor is not None:
-            self._gpu_stats_monitor.stop()
+
+        if (monitor := self._gpu_stats_monitor) is not None:
+            monitor.stop()
         return state
 
     def on_step_start(self, state: State) -> State:
{xax-0.2.0 → xax-0.2.2}/xax/task/mixins/process.py

@@ -20,6 +20,7 @@ logger: logging.Logger = logging.getLogger(__name__)
 @dataclass
 class ProcessConfig(BaseConfig):
     multiprocessing_context: str | None = field("spawn", help="The multiprocessing context to use")
+    disable_multiprocessing: bool = field(False, help="If set, disable multiprocessing")
 
 
 Config = TypeVar("Config", bound=ProcessConfig)
@@ -28,27 +29,32 @@ Config = TypeVar("Config", bound=ProcessConfig)
 class ProcessMixin(BaseTask[Config], Generic[Config]):
     """Defines a base trainer mixin for handling monitoring processes."""
 
-    _mp_ctx: BaseContext
-    _mp_manager: SyncManager
+    _mp_ctx: BaseContext | None
+    _mp_manager: SyncManager | None
 
     def __init__(self, config: Config) -> None:
         super().__init__(config)
 
-        self._mp_ctx = mp.get_context(config.multiprocessing_context)
-        self._mp_manager = self._mp_ctx.Manager()
+        if self.config.disable_multiprocessing:
+            self._mp_ctx = None
+            self._mp_manager = None
+        else:
+            self._mp_ctx = mp.get_context(config.multiprocessing_context)
+            self._mp_manager = self._mp_ctx.Manager()
 
     @property
-    def multiprocessing_context(self) -> BaseContext:
+    def multiprocessing_context(self) -> BaseContext | None:
         return self._mp_ctx
 
     @property
-    def multiprocessing_manager(self) -> SyncManager:
+    def multiprocessing_manager(self) -> SyncManager | None:
         return self._mp_manager
 
     def on_training_end(self, state: State) -> State:
         state = super().on_training_end(state)
 
-        self._mp_manager.shutdown()
-        self._mp_manager.join()
+        if self._mp_manager is not None:
+            self._mp_manager.shutdown()
+            self._mp_manager.join()
 
         return state
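
When the new flag is set, the context and manager are simply None, and the guards added in the CPU/GPU stats mixins above skip their monitor subprocesses entirely. A hedged sketch of opting out (MyConfig and MyTask are hypothetical user-defined subclasses whose config inherits from ProcessConfig):

    config = MyConfig(disable_multiprocessing=True)
    task = MyTask(config)

    assert task.multiprocessing_context is None
    assert task.multiprocessing_manager is None  # stats monitors will not start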
{xax-0.2.0 → xax-0.2.2}/xax/task/mixins/train.py

@@ -12,6 +12,7 @@ import time
 import traceback
 from abc import ABC, abstractmethod
 from dataclasses import asdict, dataclass, is_dataclass
+from pathlib import Path
 from threading import Thread
 from typing import (
     Any,
@@ -39,7 +40,7 @@ from xax.core.state import Phase, State
 from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
-from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
+from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
@@ -54,7 +55,7 @@ from xax.utils.experiments import (
     get_training_code,
 )
 from xax.utils.jax import jit as xax_jit
-from xax.utils.logging import LOG_STATUS
+from xax.utils.logging import LOG_PING, LOG_STATUS
 from xax.utils.text import highlight_exception_message, show_info
 from xax.utils.types.frozen_dict import FrozenDict
 
@@ -340,12 +341,7 @@ class TrainMixin(
 
         if init_ckpt_path is not None:
             logger.info("Loading checkpoint from %s", init_ckpt_path)
-            model_spec = eqx.filter_eval_shape(self.get_model, key)
-            model, state, config = self.load_checkpoint(
-                init_ckpt_path,
-                part="model_state_config",
-                model_template=model_spec,
-            )
+            model, state, config = self.load_ckpt(init_ckpt_path, part="model_state_config")
             config_diff = get_diff_string(diff_configs(asdict(config), asdict(self.config)))
             if config_diff:
                 logger.warning("Loaded config differs from current config:\n%s", config_diff)
@@ -353,17 +349,11 @@
             if not load_optimizer:
                 return model, state
 
-            # Loads the optimizer.
-            optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
-            optimizer = self.load_checkpoint(init_ckpt_path, part="opt", optimizer_template=optimizer_spec)
-
-            # Loads the optimizer state.
-            opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
-            opt_state = self.load_checkpoint(init_ckpt_path, part="opt_state", opt_state_template=opt_state_spec)
-
+            optimizer = self.load_ckpt(init_ckpt_path, part="opt")
+            opt_state = self.load_ckpt(init_ckpt_path, part="opt_state", model=model, optimizer=optimizer)
             return model, optimizer, opt_state, state
 
-        logger.info("No checkpoint found. Initializing a new model.")
+        logger.info("Starting a new training run")
         model = self.get_model(key)
         state = State.init_state()
@@ -375,6 +365,131 @@ class TrainMixin(
 
         return model, optimizer, opt_state, state
 
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["all"],
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["model_state_config"],
+    ) -> tuple[PyTree, State, Config]: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["model"],
+    ) -> PyTree: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["opt"],
+    ) -> optax.GradientTransformation: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["opt_state"],
+        model: PyTree | None = None,
+        optimizer: optax.GradientTransformation | None = None,
+    ) -> optax.OptState: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["state"],
+    ) -> State: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["config"],
+    ) -> Config: ...
+
+    def load_ckpt(
+        self,
+        path: str | Path,
+        *,
+        part: CheckpointPart = "all",
+        model: PyTree | None = None,
+        optimizer: optax.GradientTransformation | None = None,
+    ) -> (
+        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
+        | tuple[PyTree, State, Config]
+        | PyTree
+        | optax.GradientTransformation
+        | optax.OptState
+        | State
+        | Config
+    ):
+        path = Path(path)
+
+        # This key isn't used for anything, it's just a required argument.
+        key = jax.random.PRNGKey(0)
+
+        match part:
+            case "model_state_config":
+                model_spec = eqx.filter_eval_shape(self.get_model, key)
+                return self.load_ckpt_with_template(path, part="model_state_config", model_template=model_spec)
+
+            case "model":
+                model_spec = eqx.filter_eval_shape(self.get_model, key)
+                return self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+
+            case "config":
+                return self.load_ckpt_with_template(path, part="config")
+
+            case "opt":
+                optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
+                return self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+
+            case "opt_state":
+                if model is None:
+                    model_spec = eqx.filter_eval_shape(self.get_model, key)
+                    model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+                if optimizer is None:
+                    optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
+                    optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
+                return self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
+
+            case "state":
+                return self.load_ckpt_with_template(path, part="state")
+
+            case "config":
+                return self.load_ckpt_with_template(path, part="config")
+
+            case "all":
+                model_spec = eqx.filter_eval_shape(self.get_model, key)
+                model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+                optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
+                optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
+                opt_state = self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
+                state = self.load_ckpt_with_template(path, part="state")
+                config = self.load_ckpt_with_template(path, part="config")
+                return model, optimizer, opt_state, state, config
+
+            case _:
+                raise ValueError(f"Unknown checkpoint part: {part}")
+
     def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
         """Gets the output from the model.
@@ -529,8 +644,7 @@ class TrainMixin(
             self._last_printed_remaining_time = state.elapsed_time_s
             remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
             termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
-            # logger.info("Estimated finish time: %s", termination_time)
-            jax.debug.print("Estimated finish time: {}", termination_time)
+            logger.log(LOG_PING, "Estimated finish time: %s", termination_time)
 
     def get_remaining_percent(self, state: State) -> float | None:
         if self.config.max_steps is None:
{xax-0.2.0 → xax-0.2.2/xax.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.2.0
+Version: 0.2.2
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte