xax 0.1.16__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.1.16/xax.egg-info → xax-0.2.1}/PKG-INFO +6 -6
  2. {xax-0.1.16 → xax-0.2.1}/pyproject.toml +0 -1
  3. {xax-0.1.16 → xax-0.2.1}/xax/__init__.py +4 -1
  4. {xax-0.1.16 → xax-0.2.1}/xax/core/state.py +26 -1
  5. {xax-0.1.16 → xax-0.2.1}/xax/nn/geom.py +34 -0
  6. {xax-0.1.16 → xax-0.2.1}/xax/requirements.txt +5 -5
  7. {xax-0.1.16 → xax-0.2.1}/xax/task/base.py +1 -1
  8. {xax-0.1.16 → xax-0.2.1}/xax/task/logger.py +107 -2
  9. {xax-0.1.16 → xax-0.2.1}/xax/task/loggers/tensorboard.py +16 -0
  10. {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/checkpointing.py +124 -50
  11. {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/data_loader.py +2 -1
  12. {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/train.py +153 -27
  13. {xax-0.1.16 → xax-0.2.1}/xax/utils/experiments.py +29 -12
  14. {xax-0.1.16 → xax-0.2.1}/xax/utils/tensorboard.py +91 -3
  15. {xax-0.1.16 → xax-0.2.1/xax.egg-info}/PKG-INFO +6 -6
  16. {xax-0.1.16 → xax-0.2.1}/xax.egg-info/requires.txt +5 -5
  17. {xax-0.1.16 → xax-0.2.1}/LICENSE +0 -0
  18. {xax-0.1.16 → xax-0.2.1}/MANIFEST.in +0 -0
  19. {xax-0.1.16 → xax-0.2.1}/README.md +0 -0
  20. {xax-0.1.16 → xax-0.2.1}/setup.cfg +0 -0
  21. {xax-0.1.16 → xax-0.2.1}/setup.py +0 -0
  22. {xax-0.1.16 → xax-0.2.1}/xax/core/__init__.py +0 -0
  23. {xax-0.1.16 → xax-0.2.1}/xax/core/conf.py +0 -0
  24. {xax-0.1.16 → xax-0.2.1}/xax/nn/__init__.py +0 -0
  25. {xax-0.1.16 → xax-0.2.1}/xax/nn/embeddings.py +0 -0
  26. {xax-0.1.16 → xax-0.2.1}/xax/nn/equinox.py +0 -0
  27. {xax-0.1.16 → xax-0.2.1}/xax/nn/export.py +0 -0
  28. {xax-0.1.16 → xax-0.2.1}/xax/nn/functions.py +0 -0
  29. {xax-0.1.16 → xax-0.2.1}/xax/nn/losses.py +0 -0
  30. {xax-0.1.16 → xax-0.2.1}/xax/nn/norm.py +0 -0
  31. {xax-0.1.16 → xax-0.2.1}/xax/nn/parallel.py +0 -0
  32. {xax-0.1.16 → xax-0.2.1}/xax/nn/ssm.py +0 -0
  33. {xax-0.1.16 → xax-0.2.1}/xax/py.typed +0 -0
  34. {xax-0.1.16 → xax-0.2.1}/xax/requirements-dev.txt +0 -0
  35. {xax-0.1.16 → xax-0.2.1}/xax/task/__init__.py +0 -0
  36. {xax-0.1.16 → xax-0.2.1}/xax/task/launchers/__init__.py +0 -0
  37. {xax-0.1.16 → xax-0.2.1}/xax/task/launchers/base.py +0 -0
  38. {xax-0.1.16 → xax-0.2.1}/xax/task/launchers/cli.py +0 -0
  39. {xax-0.1.16 → xax-0.2.1}/xax/task/launchers/single_process.py +0 -0
  40. {xax-0.1.16 → xax-0.2.1}/xax/task/loggers/__init__.py +0 -0
  41. {xax-0.1.16 → xax-0.2.1}/xax/task/loggers/callback.py +0 -0
  42. {xax-0.1.16 → xax-0.2.1}/xax/task/loggers/json.py +0 -0
  43. {xax-0.1.16 → xax-0.2.1}/xax/task/loggers/state.py +0 -0
  44. {xax-0.1.16 → xax-0.2.1}/xax/task/loggers/stdout.py +0 -0
  45. {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/__init__.py +0 -0
  46. {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/artifacts.py +0 -0
  47. {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/compile.py +0 -0
  48. {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/cpu_stats.py +0 -0
  49. {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/gpu_stats.py +0 -0
  50. {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/logger.py +0 -0
  51. {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/process.py +0 -0
  52. {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/runnable.py +0 -0
  53. {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/step_wrapper.py +0 -0
  54. {xax-0.1.16 → xax-0.2.1}/xax/task/script.py +0 -0
  55. {xax-0.1.16 → xax-0.2.1}/xax/task/task.py +0 -0
  56. {xax-0.1.16 → xax-0.2.1}/xax/utils/__init__.py +0 -0
  57. {xax-0.1.16 → xax-0.2.1}/xax/utils/data/__init__.py +0 -0
  58. {xax-0.1.16 → xax-0.2.1}/xax/utils/data/collate.py +0 -0
  59. {xax-0.1.16 → xax-0.2.1}/xax/utils/debugging.py +0 -0
  60. {xax-0.1.16 → xax-0.2.1}/xax/utils/jax.py +0 -0
  61. {xax-0.1.16 → xax-0.2.1}/xax/utils/jaxpr.py +0 -0
  62. {xax-0.1.16 → xax-0.2.1}/xax/utils/logging.py +0 -0
  63. {xax-0.1.16 → xax-0.2.1}/xax/utils/numpy.py +0 -0
  64. {xax-0.1.16 → xax-0.2.1}/xax/utils/profile.py +0 -0
  65. {xax-0.1.16 → xax-0.2.1}/xax/utils/pytree.py +0 -0
  66. {xax-0.1.16 → xax-0.2.1}/xax/utils/text.py +0 -0
  67. {xax-0.1.16 → xax-0.2.1}/xax/utils/types/__init__.py +0 -0
  68. {xax-0.1.16 → xax-0.2.1}/xax/utils/types/frozen_dict.py +0 -0
  69. {xax-0.1.16 → xax-0.2.1}/xax/utils/types/hashable_array.py +0 -0
  70. {xax-0.1.16 → xax-0.2.1}/xax.egg-info/SOURCES.txt +0 -0
  71. {xax-0.1.16 → xax-0.2.1}/xax.egg-info/dependency_links.txt +0 -0
  72. {xax-0.1.16 → xax-0.2.1}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.16
3
+ Version: 0.2.1
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -8,14 +8,14 @@ Requires-Python: >=3.11
8
8
  Description-Content-Type: text/markdown
9
9
  License-File: LICENSE
10
10
  Requires-Dist: attrs
11
+ Requires-Dist: chex
12
+ Requires-Dist: dpshdl
13
+ Requires-Dist: equinox
14
+ Requires-Dist: importlib-resources
11
15
  Requires-Dist: jax
12
16
  Requires-Dist: jaxtyping
13
- Requires-Dist: equinox
14
17
  Requires-Dist: optax
15
- Requires-Dist: dpshdl
16
- Requires-Dist: chex
17
- Requires-Dist: importlib-resources
18
- Requires-Dist: cloudpickle
18
+ Requires-Dist: orbax-checkpoint
19
19
  Requires-Dist: pillow
20
20
  Requires-Dist: omegaconf
21
21
  Requires-Dist: gitpython
@@ -35,7 +35,6 @@ explicit_package_bases = true
35
35
  [[tool.mypy.overrides]]
36
36
 
37
37
  module = [
38
- "cloudpickle.*",
39
38
  "optax.*",
40
39
  "setuptools.*",
41
40
  "tensorboard.*",
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.16"
15
+ __version__ = "0.2.1"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -44,6 +44,7 @@ __all__ = [
44
44
  "euler_to_quat",
45
45
  "get_projected_gravity_vector_from_quat",
46
46
  "quat_to_euler",
47
+ "quat_to_rotmat",
47
48
  "rotate_vector_by_quat",
48
49
  "cross_entropy",
49
50
  "cast_norm_type",
@@ -206,6 +207,7 @@ NAME_MAP: dict[str, str] = {
206
207
  "euler_to_quat": "nn.geom",
207
208
  "get_projected_gravity_vector_from_quat": "nn.geom",
208
209
  "quat_to_euler": "nn.geom",
210
+ "quat_to_rotmat": "nn.geom",
209
211
  "rotate_vector_by_quat": "nn.geom",
210
212
  "cross_entropy": "nn.losses",
211
213
  "cast_norm_type": "nn.norm",
@@ -369,6 +371,7 @@ if IMPORT_ALL or TYPE_CHECKING:
369
371
  euler_to_quat,
370
372
  get_projected_gravity_vector_from_quat,
371
373
  quat_to_euler,
374
+ quat_to_rotmat,
372
375
  rotate_vector_by_quat,
373
376
  )
374
377
  from xax.nn.losses import cross_entropy
@@ -12,6 +12,14 @@ from xax.core.conf import field
12
12
  Phase = Literal["train", "valid"]
13
13
 
14
14
 
15
+ def _phase_to_int(phase: Phase) -> int:
16
+ return {"train": 0, "valid": 1}[phase]
17
+
18
+
19
+ def _int_to_phase(i: int) -> Phase:
20
+ return cast(Phase, ["train", "valid"][i])
21
+
22
+
15
23
  class StateDict(TypedDict, total=False):
16
24
  num_steps: NotRequired[int]
17
25
  num_samples: NotRequired[int]
@@ -35,7 +43,7 @@ class State:
35
43
 
36
44
  @property
37
45
  def phase(self) -> Phase:
38
- return cast(Phase, ["train", "valid"][self._phase])
46
+ return _int_to_phase(self._phase)
39
47
 
40
48
  @classmethod
41
49
  def init_state(cls) -> "State":
@@ -74,3 +82,20 @@ class State:
74
82
  case _:
75
83
  raise ValueError(f"Invalid phase: {phase}")
76
84
  return State(**{**asdict(self), **kwargs, **extra_kwargs})
85
+
86
+ def to_dict(self) -> dict[str, int | float | str]:
87
+ return {
88
+ "num_steps": int(self.num_steps),
89
+ "num_samples": int(self.num_samples),
90
+ "num_valid_steps": int(self.num_valid_steps),
91
+ "num_valid_samples": int(self.num_valid_samples),
92
+ "start_time_s": float(self.start_time_s),
93
+ "elapsed_time_s": float(self.elapsed_time_s),
94
+ "phase": str(self.phase),
95
+ }
96
+
97
+ @classmethod
98
+ def from_dict(cls, d: dict[str, int | float | str]) -> "State":
99
+ if "phase" in d:
100
+ d["_phase"] = _phase_to_int(cast(Phase, d.pop("phase")))
101
+ return cls(**d) # type: ignore[arg-type]
@@ -177,3 +177,37 @@ def cubic_bezier_interpolation(y_start: Array, y_end: Array, x: Array) -> Array:
177
177
  y_diff = y_end - y_start
178
178
  bezier = x**3 + 3 * (x**2 * (1 - x))
179
179
  return y_start + y_diff * bezier
180
+
181
+
182
+ def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
183
+ """Converts a quaternion to a rotation matrix.
184
+
185
+ Args:
186
+ quat: The quaternion to convert, shape (*, 4).
187
+ eps: A small epsilon value to avoid division by zero.
188
+
189
+ Returns:
190
+ The rotation matrix, shape (*, 3, 3).
191
+ """
192
+ quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
193
+ w, x, y, z = jnp.split(quat, 4, axis=-1)
194
+
195
+ xx = 1 - 2 * (y * y + z * z)
196
+ xy = 2 * (x * y - z * w)
197
+ xz = 2 * (x * z + y * w)
198
+ yx = 2 * (x * y + z * w)
199
+ yy = 1 - 2 * (x * x + z * z)
200
+ yz = 2 * (y * z - x * w)
201
+ zx = 2 * (x * z - y * w)
202
+ zy = 2 * (y * z + x * w)
203
+ zz = 1 - 2 * (x * x + y * y)
204
+
205
+ # Corrected stacking: row-major order
206
+ return jnp.concatenate(
207
+ [
208
+ jnp.concatenate([xx, xy, xz], axis=-1)[..., None, :],
209
+ jnp.concatenate([yx, yy, yz], axis=-1)[..., None, :],
210
+ jnp.concatenate([zx, zy, zz], axis=-1)[..., None, :],
211
+ ],
212
+ axis=-2,
213
+ )
@@ -2,16 +2,16 @@
2
2
 
3
3
  # Core ML/JAX dependencies
4
4
  attrs
5
+ chex
6
+ dpshdl
7
+ equinox
8
+ importlib-resources
5
9
  jax
6
10
  jaxtyping
7
- equinox
8
11
  optax
9
- dpshdl
10
- chex
11
- importlib-resources
12
+ orbax-checkpoint
12
13
 
13
14
  # Data processing and serialization
14
- cloudpickle
15
15
  pillow
16
16
 
17
17
  # Configuration and project management
@@ -79,7 +79,7 @@ class BaseTask(Generic[Config]):
79
79
  def on_training_end(self, state: State) -> State:
80
80
  return state
81
81
 
82
- def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
82
+ def on_after_checkpoint_save(self, ckpt_path: Path, state: State | None) -> State | None:
83
83
  return state
84
84
 
85
85
  @functools.cached_property
@@ -18,11 +18,22 @@ from abc import ABC, abstractmethod
18
18
  from collections import defaultdict
19
19
  from dataclasses import dataclass
20
20
  from types import TracebackType
21
- from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, cast, get_args
21
+ from typing import (
22
+ Any,
23
+ Callable,
24
+ Iterator,
25
+ Literal,
26
+ Self,
27
+ Sequence,
28
+ TypeVar,
29
+ cast,
30
+ get_args,
31
+ )
22
32
 
23
33
  import jax
24
34
  import jax.numpy as jnp
25
35
  import numpy as np
36
+ from jax._src.core import ClosedJaxpr
26
37
  from jaxtyping import Array
27
38
  from PIL import Image, ImageDraw, ImageFont
28
39
  from PIL.Image import Image as PILImage
@@ -194,7 +205,10 @@ def tile_images(images: list[PILImage], sep: int = 0) -> PILImage:
194
205
  return tiled
195
206
 
196
207
 
197
- def as_numpy(array: Array) -> np.ndarray:
208
+ def as_numpy(array: Array | np.ndarray) -> np.ndarray:
209
+ """Convert a JAX array or numpy array to numpy array."""
210
+ if isinstance(array, np.ndarray):
211
+ return array
198
212
  array = jax.device_get(array)
199
213
  if jax.dtypes.issubdtype(array.dtype, jnp.floating):
200
214
  array = array.astype(jnp.float32)
@@ -205,6 +219,13 @@ def as_numpy(array: Array) -> np.ndarray:
205
219
  return np.array(array)
206
220
 
207
221
 
222
+ def as_numpy_opt(array: Array | np.ndarray | None) -> np.ndarray | None:
223
+ """Convert an optional JAX array or numpy array to numpy array."""
224
+ if array is None:
225
+ return None
226
+ return as_numpy(array)
227
+
228
+
208
229
  @dataclass(kw_only=True)
209
230
  class LogString:
210
231
  value: str
@@ -252,6 +273,19 @@ class LogHistogram:
252
273
  bucket_counts: list[int]
253
274
 
254
275
 
276
+ @dataclass(kw_only=True)
277
+ class LogMesh:
278
+ vertices: np.ndarray
279
+ colors: np.ndarray | None
280
+ faces: np.ndarray | None
281
+ config_dict: dict[str, Any] | None # noqa: ANN401
282
+
283
+
284
+ @dataclass(kw_only=True)
285
+ class LogGraph:
286
+ computation: ClosedJaxpr
287
+
288
+
255
289
  @dataclass(kw_only=True)
256
290
  class LogLine:
257
291
  state: State
@@ -261,6 +295,7 @@ class LogLine:
261
295
  strings: dict[str, dict[str, LogString]]
262
296
  images: dict[str, dict[str, LogImage]]
263
297
  videos: dict[str, dict[str, LogVideo]]
298
+ meshes: dict[str, dict[str, LogMesh]]
264
299
 
265
300
 
266
301
  @dataclass(kw_only=True)
@@ -533,6 +568,7 @@ class Logger:
533
568
  self.strings: dict[str, dict[str, Callable[[], LogString]]] = defaultdict(dict)
534
569
  self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
535
570
  self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
571
+ self.meshes: dict[str, dict[str, Callable[[], LogMesh]]] = defaultdict(dict)
536
572
  self.default_namespace = default_namespace
537
573
  self.loggers: list[LoggerImpl] = []
538
574
 
@@ -560,6 +596,7 @@ class Logger:
560
596
  strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
561
597
  images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
562
598
  videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
599
+ meshes={k: {kk: v() for kk, v in v.items()} for k, v in self.meshes.items()},
563
600
  )
564
601
 
565
602
  def clear(self) -> None:
@@ -569,6 +606,7 @@ class Logger:
569
606
  self.strings.clear()
570
607
  self.images.clear()
571
608
  self.videos.clear()
609
+ self.meshes.clear()
572
610
 
573
611
  def write(self, state: State) -> None:
574
612
  """Writes the current step's logging information.
@@ -1051,6 +1089,73 @@ class Logger:
1051
1089
 
1052
1090
  self.videos[namespace][key] = video_future
1053
1091
 
1092
+ def log_mesh(
1093
+ self,
1094
+ key: str,
1095
+ vertices: np.ndarray | Array | Callable[[], np.ndarray | Array],
1096
+ colors: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
1097
+ faces: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
1098
+ config_dict: dict[str, Any] | None = None,
1099
+ *,
1100
+ namespace: str | None = None,
1101
+ ) -> None:
1102
+ if not self.active:
1103
+ raise RuntimeError("The logger is not active")
1104
+ namespace = self.resolve_namespace(namespace)
1105
+
1106
+ @functools.lru_cache(maxsize=None)
1107
+ def mesh_future() -> LogMesh:
1108
+ with ContextTimer() as timer:
1109
+ # Get the raw values
1110
+ vertices_val = vertices() if callable(vertices) else vertices
1111
+ colors_val = colors() if callable(colors) else colors
1112
+ faces_val = faces() if callable(faces) else faces
1113
+
1114
+ # Convert to numpy arrays with proper type handling
1115
+ vertices_np = as_numpy(vertices_val)
1116
+ colors_np = as_numpy_opt(colors_val)
1117
+ faces_np = as_numpy_opt(faces_val)
1118
+
1119
+ # Checks vertices shape.
1120
+ if vertices_np.ndim == 2:
1121
+ vertices_np = vertices_np[None]
1122
+ if vertices_np.shape[-1] != 3 or vertices_np.ndim != 3:
1123
+ raise ValueError("Vertices must have shape (N, 3) or (B, N, 3)")
1124
+
1125
+ # Checks colors shape.
1126
+ if colors_np is not None:
1127
+ if colors_np.ndim == 2:
1128
+ colors_np = colors_np[None]
1129
+ if colors_np.shape[-1] != 3 or colors_np.ndim != 3:
1130
+ raise ValueError("Colors must have shape (N, 3) or (B, N, 3)")
1131
+
1132
+ # Checks faces shape.
1133
+ if faces_np is not None:
1134
+ if faces_np.ndim == 2:
1135
+ faces_np = faces_np[None]
1136
+ if faces_np.shape[-1] != 3 or faces_np.ndim != 3:
1137
+ raise ValueError("Faces must have shape (N, 3) or (B, N, 3)")
1138
+
1139
+ # Ensures colors dtype is uint8.
1140
+ if colors_np is not None:
1141
+ if colors_np.dtype != np.uint8:
1142
+ colors_np = (colors_np * 255).astype(np.uint8)
1143
+
1144
+ # Ensures faces dtype is int32.
1145
+ if faces_np is not None:
1146
+ if faces_np.dtype != np.int32:
1147
+ faces_np = faces_np.astype(np.int32)
1148
+
1149
+ logger.debug("Mesh Key: %s, Time: %s", key, timer.elapsed_time)
1150
+ return LogMesh(
1151
+ vertices=vertices_np,
1152
+ colors=colors_np,
1153
+ faces=faces_np,
1154
+ config_dict=config_dict,
1155
+ )
1156
+
1157
+ self.meshes[namespace][key] = mesh_future
1158
+
1054
1159
  def __enter__(self) -> Self:
1055
1160
  self.active = True
1056
1161
  for logger in self.loggers:
@@ -70,6 +70,9 @@ class TensorboardLogger(LoggerImpl):
70
70
  self._started = True
71
71
 
72
72
  def worker_thread(self) -> None:
73
+ if os.environ.get("DISABLE_TENSORBOARD", "0") == "1":
74
+ return
75
+
73
76
  time.sleep(self.wait_seconds)
74
77
 
75
78
  port = int(os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT))
@@ -213,6 +216,19 @@ class TensorboardLogger(LoggerImpl):
213
216
  video_value.frames,
214
217
  fps=video_value.fps,
215
218
  global_step=line.state.num_steps,
219
+ walltime=walltime,
220
+ )
221
+
222
+ for namespace, meshes in line.meshes.items():
223
+ for mesh_key, mesh_value in meshes.items():
224
+ writer.add_mesh(
225
+ f"{namespace}/{mesh_key}",
226
+ vertices=mesh_value.vertices,
227
+ faces=mesh_value.faces,
228
+ colors=mesh_value.colors,
229
+ config_dict=mesh_value.config_dict,
230
+ global_step=line.state.num_steps,
231
+ walltime=walltime,
216
232
  )
217
233
 
218
234
  for name, contents in self.files.items():
@@ -6,9 +6,9 @@ import logging
6
6
  import tarfile
7
7
  from dataclasses import asdict, dataclass
8
8
  from pathlib import Path
9
- from typing import Any, Callable, Generic, Literal, TypeVar, cast, overload
9
+ from typing import Generic, Literal, TypeVar, cast, overload
10
10
 
11
- import cloudpickle
11
+ import equinox as eqx
12
12
  import jax
13
13
  import optax
14
14
  from jaxtyping import PyTree
@@ -63,8 +63,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
63
63
 
64
64
  def get_init_ckpt_path(self) -> Path | None:
65
65
  if self._exp_dir is not None:
66
- ckpt_path = self.get_ckpt_path()
67
- if ckpt_path.exists():
66
+ if (ckpt_path := self.get_ckpt_path()).exists():
68
67
  return ckpt_path
69
68
  if self.config.load_from_ckpt_path is not None:
70
69
  ckpt_path = Path(self.config.load_from_ckpt_path)
@@ -84,93 +83,129 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
84
83
  return False
85
84
 
86
85
  @overload
87
- def load_checkpoint(
86
+ def load_ckpt_with_template(
88
87
  self,
89
88
  path: Path,
90
- part: Literal["all"] = "all",
91
- ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
89
+ *,
90
+ part: Literal["all"],
91
+ model_template: PyTree,
92
+ optimizer_template: PyTree,
93
+ opt_state_template: PyTree,
94
+ ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
92
95
 
93
96
  @overload
94
- def load_checkpoint(
97
+ def load_ckpt_with_template(
95
98
  self,
96
99
  path: Path,
97
- part: Literal["model_state_config"] = "model_state_config",
98
- ) -> tuple[PyTree, State, DictConfig]: ...
100
+ *,
101
+ part: Literal["model_state_config"],
102
+ model_template: PyTree,
103
+ ) -> tuple[PyTree, State, Config]: ...
99
104
 
100
105
  @overload
101
- def load_checkpoint(
106
+ def load_ckpt_with_template(
102
107
  self,
103
108
  path: Path,
109
+ *,
104
110
  part: Literal["model"],
111
+ model_template: PyTree,
105
112
  ) -> PyTree: ...
106
113
 
107
114
  @overload
108
- def load_checkpoint(
115
+ def load_ckpt_with_template(
109
116
  self,
110
117
  path: Path,
118
+ *,
111
119
  part: Literal["opt"],
120
+ optimizer_template: PyTree,
112
121
  ) -> optax.GradientTransformation: ...
113
122
 
114
123
  @overload
115
- def load_checkpoint(
124
+ def load_ckpt_with_template(
116
125
  self,
117
126
  path: Path,
127
+ *,
118
128
  part: Literal["opt_state"],
129
+ opt_state_template: PyTree,
119
130
  ) -> optax.OptState: ...
120
131
 
121
132
  @overload
122
- def load_checkpoint(
133
+ def load_ckpt_with_template(
123
134
  self,
124
135
  path: Path,
136
+ *,
125
137
  part: Literal["state"],
126
138
  ) -> State: ...
127
139
 
128
140
  @overload
129
- def load_checkpoint(
141
+ def load_ckpt_with_template(
130
142
  self,
131
143
  path: Path,
144
+ *,
132
145
  part: Literal["config"],
133
- ) -> DictConfig: ...
146
+ ) -> Config: ...
134
147
 
135
- def load_checkpoint(
148
+ def load_ckpt_with_template(
136
149
  self,
137
150
  path: Path,
151
+ *,
138
152
  part: CheckpointPart = "all",
153
+ model_template: PyTree | None = None,
154
+ optimizer_template: PyTree | None = None,
155
+ opt_state_template: PyTree | None = None,
139
156
  ) -> (
140
- tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
141
- | tuple[PyTree, State, DictConfig]
157
+ tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
158
+ | tuple[PyTree, State, Config]
142
159
  | PyTree
143
160
  | optax.GradientTransformation
144
161
  | optax.OptState
145
162
  | State
146
- | DictConfig
163
+ | Config
147
164
  ):
165
+ """Load a checkpoint.
166
+
167
+ Args:
168
+ path: Path to the checkpoint directory
169
+ part: Which part of the checkpoint to load
170
+ model_template: Template model with correct structure but uninitialized weights
171
+ optimizer_template: Template optimizer with correct structure but uninitialized weights
172
+ opt_state_template: Template optimizer state with correct structure but uninitialized weights
173
+
174
+ Returns:
175
+ The requested checkpoint components
176
+ """
148
177
  with tarfile.open(path, "r:gz") as tar:
149
178
 
150
179
  def get_model() -> PyTree:
180
+ if model_template is None:
181
+ raise ValueError("model_template must be provided to load model weights")
151
182
  if (model := tar.extractfile("model")) is None:
152
183
  raise ValueError(f"Checkpoint does not contain a model file: {path}")
153
- return cloudpickle.load(model)
184
+ return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
154
185
 
155
186
  def get_opt() -> optax.GradientTransformation:
156
- if (opt := tar.extractfile("opt")) is None:
157
- raise ValueError(f"Checkpoint does not contain an opt file: {path}")
158
- return cloudpickle.load(opt)
187
+ if optimizer_template is None:
188
+ raise ValueError("optimizer_template must be provided to load optimizer")
189
+ if (opt := tar.extractfile("optimizer")) is None:
190
+ raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
191
+ return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
159
192
 
160
193
  def get_opt_state() -> optax.OptState:
194
+ if opt_state_template is None:
195
+ raise ValueError("opt_state_template must be provided to load optimizer state")
161
196
  if (opt_state := tar.extractfile("opt_state")) is None:
162
- raise ValueError(f"Checkpoint does not contain an opt_state file: {path}")
163
- return cloudpickle.load(opt_state)
197
+ raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
198
+ return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
164
199
 
165
200
  def get_state() -> State:
166
201
  if (state := tar.extractfile("state")) is None:
167
202
  raise ValueError(f"Checkpoint does not contain a state file: {path}")
168
203
  return State(**json.loads(state.read().decode()))
169
204
 
170
- def get_config() -> DictConfig:
205
+ def get_config() -> Config:
171
206
  if (config := tar.extractfile("config")) is None:
172
207
  raise ValueError(f"Checkpoint does not contain a config file: {path}")
173
- return cast(DictConfig, OmegaConf.load(config))
208
+ return self.get_config(cast(DictConfig, OmegaConf.load(config)), use_cli=False)
174
209
 
175
210
  match part:
176
211
  case "model":
@@ -192,51 +227,90 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
192
227
 
193
228
  def save_checkpoint(
194
229
  self,
195
- model: PyTree,
196
- optimizer: optax.GradientTransformation,
197
- opt_state: optax.OptState,
198
- state: State,
230
+ model: PyTree | None = None,
231
+ optimizer: optax.GradientTransformation | None = None,
232
+ opt_state: optax.OptState | None = None,
233
+ aux_data: PyTree | None = None,
234
+ state: State | None = None,
199
235
  ) -> Path:
236
+ """Save a checkpoint.
237
+
238
+ Args:
239
+ model: The model to save
240
+ state: The current training state
241
+ optimizer: The optimizer to save
242
+ aux_data: Additional data to save
243
+ opt_state: The optimizer state to save
244
+
245
+ Returns:
246
+ Path to the saved checkpoint
247
+ """
200
248
  ckpt_path = self.get_ckpt_path(state)
201
249
 
202
250
  if not is_master():
203
251
  return ckpt_path
204
252
 
205
- # Gets the path to the last checkpoint.
253
+ # Gets the path to the last checkpoint
206
254
  logger.info("Saving checkpoint to %s", ckpt_path)
207
255
  last_ckpt_path = self.get_ckpt_path()
208
256
  ckpt_path.parent.mkdir(exist_ok=True, parents=True)
209
257
 
210
- # Potentially removes the last checkpoint.
258
+ # Potentially removes the last checkpoint
211
259
  if last_ckpt_path.exists() and self.config.only_save_most_recent:
212
260
  if (base_ckpt := last_ckpt_path.resolve()).is_file():
213
261
  base_ckpt.unlink()
214
262
 
215
- # Combines all temporary files into a single checkpoint TAR file.
263
+ # Save the checkpoint components
216
264
  with tarfile.open(ckpt_path, "w:gz") as tar:
217
265
 
218
- def add_file(name: str, write_fn: Callable[[io.BytesIO], Any]) -> None:
266
+ def add_file(name: str, buf: io.BytesIO) -> None:
267
+ tarinfo = tarfile.TarInfo(name)
268
+ tarinfo.size = buf.tell()
269
+ buf.seek(0)
270
+ tar.addfile(tarinfo, buf)
271
+
272
+ # Save model using Equinox
273
+ if model is not None:
274
+ with io.BytesIO() as buf:
275
+ eqx.tree_serialise_leaves(buf, model)
276
+ add_file("model", buf)
277
+
278
+ # Save optimizer using Equinox
279
+ if optimizer is not None:
280
+ with io.BytesIO() as buf:
281
+ eqx.tree_serialise_leaves(buf, optimizer)
282
+ add_file("optimizer", buf)
283
+
284
+ # Save optimizer state using Equinox
285
+ if opt_state is not None:
219
286
  with io.BytesIO() as buf:
220
- write_fn(buf)
221
- tarinfo = tarfile.TarInfo(name)
222
- tarinfo.size = buf.tell()
223
- buf.seek(0)
224
- tar.addfile(tarinfo, buf)
225
-
226
- add_file("model", lambda buf: cloudpickle.dump(model, buf))
227
- add_file("opt", lambda buf: cloudpickle.dump(optimizer, buf))
228
- add_file("opt_state", lambda buf: cloudpickle.dump(opt_state, buf))
229
- add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
230
- add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
231
-
232
- # Updates the symlink to the new checkpoint.
287
+ eqx.tree_serialise_leaves(buf, opt_state)
288
+ add_file("opt_state", buf)
289
+
290
+ # Save aux data using Equinox.
291
+ if aux_data is not None:
292
+ with io.BytesIO() as buf:
293
+ eqx.tree_serialise_leaves(buf, aux_data)
294
+ add_file("aux_data", buf)
295
+
296
+ # Save state and config as JSON
297
+ def add_file_bytes(name: str, data: bytes) -> None: # noqa: ANN401
298
+ info = tarfile.TarInfo(name=name)
299
+ info.size = len(data)
300
+ tar.addfile(info, io.BytesIO(data))
301
+
302
+ if state is not None:
303
+ add_file_bytes("state", json.dumps(asdict(state), indent=2).encode())
304
+ add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
305
+
306
+ # Updates the symlink to the new checkpoint
233
307
  last_ckpt_path.unlink(missing_ok=True)
234
308
  try:
235
309
  last_ckpt_path.symlink_to(ckpt_path.relative_to(last_ckpt_path.parent))
236
310
  except FileExistsError:
237
311
  logger.exception("Exception while trying to update %s", ckpt_path)
238
312
 
239
- # Calls the base callback.
313
+ # Calls the base callback
240
314
  self.on_after_checkpoint_save(ckpt_path, state)
241
315
 
242
316
  return ckpt_path