xax 0.1.16__tar.gz → 0.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.1.16/xax.egg-info → xax-0.2.0}/PKG-INFO +6 -6
  2. {xax-0.1.16 → xax-0.2.0}/pyproject.toml +0 -1
  3. {xax-0.1.16 → xax-0.2.0}/xax/__init__.py +1 -1
  4. {xax-0.1.16 → xax-0.2.0}/xax/core/state.py +26 -1
  5. {xax-0.1.16 → xax-0.2.0}/xax/requirements.txt +5 -5
  6. {xax-0.1.16 → xax-0.2.0}/xax/task/base.py +1 -1
  7. {xax-0.1.16 → xax-0.2.0}/xax/task/logger.py +107 -2
  8. {xax-0.1.16 → xax-0.2.0}/xax/task/loggers/tensorboard.py +16 -0
  9. {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/checkpointing.py +118 -41
  10. {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/data_loader.py +2 -1
  11. {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/train.py +35 -23
  12. {xax-0.1.16 → xax-0.2.0}/xax/utils/experiments.py +29 -12
  13. {xax-0.1.16 → xax-0.2.0}/xax/utils/tensorboard.py +91 -3
  14. {xax-0.1.16 → xax-0.2.0/xax.egg-info}/PKG-INFO +6 -6
  15. {xax-0.1.16 → xax-0.2.0}/xax.egg-info/requires.txt +5 -5
  16. {xax-0.1.16 → xax-0.2.0}/LICENSE +0 -0
  17. {xax-0.1.16 → xax-0.2.0}/MANIFEST.in +0 -0
  18. {xax-0.1.16 → xax-0.2.0}/README.md +0 -0
  19. {xax-0.1.16 → xax-0.2.0}/setup.cfg +0 -0
  20. {xax-0.1.16 → xax-0.2.0}/setup.py +0 -0
  21. {xax-0.1.16 → xax-0.2.0}/xax/core/__init__.py +0 -0
  22. {xax-0.1.16 → xax-0.2.0}/xax/core/conf.py +0 -0
  23. {xax-0.1.16 → xax-0.2.0}/xax/nn/__init__.py +0 -0
  24. {xax-0.1.16 → xax-0.2.0}/xax/nn/embeddings.py +0 -0
  25. {xax-0.1.16 → xax-0.2.0}/xax/nn/equinox.py +0 -0
  26. {xax-0.1.16 → xax-0.2.0}/xax/nn/export.py +0 -0
  27. {xax-0.1.16 → xax-0.2.0}/xax/nn/functions.py +0 -0
  28. {xax-0.1.16 → xax-0.2.0}/xax/nn/geom.py +0 -0
  29. {xax-0.1.16 → xax-0.2.0}/xax/nn/losses.py +0 -0
  30. {xax-0.1.16 → xax-0.2.0}/xax/nn/norm.py +0 -0
  31. {xax-0.1.16 → xax-0.2.0}/xax/nn/parallel.py +0 -0
  32. {xax-0.1.16 → xax-0.2.0}/xax/nn/ssm.py +0 -0
  33. {xax-0.1.16 → xax-0.2.0}/xax/py.typed +0 -0
  34. {xax-0.1.16 → xax-0.2.0}/xax/requirements-dev.txt +0 -0
  35. {xax-0.1.16 → xax-0.2.0}/xax/task/__init__.py +0 -0
  36. {xax-0.1.16 → xax-0.2.0}/xax/task/launchers/__init__.py +0 -0
  37. {xax-0.1.16 → xax-0.2.0}/xax/task/launchers/base.py +0 -0
  38. {xax-0.1.16 → xax-0.2.0}/xax/task/launchers/cli.py +0 -0
  39. {xax-0.1.16 → xax-0.2.0}/xax/task/launchers/single_process.py +0 -0
  40. {xax-0.1.16 → xax-0.2.0}/xax/task/loggers/__init__.py +0 -0
  41. {xax-0.1.16 → xax-0.2.0}/xax/task/loggers/callback.py +0 -0
  42. {xax-0.1.16 → xax-0.2.0}/xax/task/loggers/json.py +0 -0
  43. {xax-0.1.16 → xax-0.2.0}/xax/task/loggers/state.py +0 -0
  44. {xax-0.1.16 → xax-0.2.0}/xax/task/loggers/stdout.py +0 -0
  45. {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/__init__.py +0 -0
  46. {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/artifacts.py +0 -0
  47. {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/compile.py +0 -0
  48. {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/cpu_stats.py +0 -0
  49. {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/gpu_stats.py +0 -0
  50. {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/logger.py +0 -0
  51. {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/process.py +0 -0
  52. {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/runnable.py +0 -0
  53. {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/step_wrapper.py +0 -0
  54. {xax-0.1.16 → xax-0.2.0}/xax/task/script.py +0 -0
  55. {xax-0.1.16 → xax-0.2.0}/xax/task/task.py +0 -0
  56. {xax-0.1.16 → xax-0.2.0}/xax/utils/__init__.py +0 -0
  57. {xax-0.1.16 → xax-0.2.0}/xax/utils/data/__init__.py +0 -0
  58. {xax-0.1.16 → xax-0.2.0}/xax/utils/data/collate.py +0 -0
  59. {xax-0.1.16 → xax-0.2.0}/xax/utils/debugging.py +0 -0
  60. {xax-0.1.16 → xax-0.2.0}/xax/utils/jax.py +0 -0
  61. {xax-0.1.16 → xax-0.2.0}/xax/utils/jaxpr.py +0 -0
  62. {xax-0.1.16 → xax-0.2.0}/xax/utils/logging.py +0 -0
  63. {xax-0.1.16 → xax-0.2.0}/xax/utils/numpy.py +0 -0
  64. {xax-0.1.16 → xax-0.2.0}/xax/utils/profile.py +0 -0
  65. {xax-0.1.16 → xax-0.2.0}/xax/utils/pytree.py +0 -0
  66. {xax-0.1.16 → xax-0.2.0}/xax/utils/text.py +0 -0
  67. {xax-0.1.16 → xax-0.2.0}/xax/utils/types/__init__.py +0 -0
  68. {xax-0.1.16 → xax-0.2.0}/xax/utils/types/frozen_dict.py +0 -0
  69. {xax-0.1.16 → xax-0.2.0}/xax/utils/types/hashable_array.py +0 -0
  70. {xax-0.1.16 → xax-0.2.0}/xax.egg-info/SOURCES.txt +0 -0
  71. {xax-0.1.16 → xax-0.2.0}/xax.egg-info/dependency_links.txt +0 -0
  72. {xax-0.1.16 → xax-0.2.0}/xax.egg-info/top_level.txt +0 -0
{xax-0.1.16/xax.egg-info → xax-0.2.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.1.16
+Version: 0.2.0
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
@@ -8,14 +8,14 @@ Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: attrs
+Requires-Dist: chex
+Requires-Dist: dpshdl
+Requires-Dist: equinox
+Requires-Dist: importlib-resources
 Requires-Dist: jax
 Requires-Dist: jaxtyping
-Requires-Dist: equinox
 Requires-Dist: optax
-Requires-Dist: dpshdl
-Requires-Dist: chex
-Requires-Dist: importlib-resources
-Requires-Dist: cloudpickle
+Requires-Dist: orbax-checkpoint
 Requires-Dist: pillow
 Requires-Dist: omegaconf
 Requires-Dist: gitpython
{xax-0.1.16 → xax-0.2.0}/pyproject.toml

@@ -35,7 +35,6 @@ explicit_package_bases = true
 
 [[tool.mypy.overrides]]
 
 module = [
-    "cloudpickle.*",
     "optax.*",
     "setuptools.*",
     "tensorboard.*",
{xax-0.1.16 → xax-0.2.0}/xax/__init__.py

@@ -12,7 +12,7 @@ and running the update script:
 python -m scripts.update_api --inplace
 """
 
-__version__ = "0.1.16"
+__version__ = "0.2.0"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
{xax-0.1.16 → xax-0.2.0}/xax/core/state.py

@@ -12,6 +12,14 @@ from xax.core.conf import field
 Phase = Literal["train", "valid"]
 
 
+def _phase_to_int(phase: Phase) -> int:
+    return {"train": 0, "valid": 1}[phase]
+
+
+def _int_to_phase(i: int) -> Phase:
+    return cast(Phase, ["train", "valid"][i])
+
+
 class StateDict(TypedDict, total=False):
     num_steps: NotRequired[int]
     num_samples: NotRequired[int]
@@ -35,7 +43,7 @@ class State:
 
     @property
     def phase(self) -> Phase:
-        return cast(Phase, ["train", "valid"][self._phase])
+        return _int_to_phase(self._phase)
 
     @classmethod
     def init_state(cls) -> "State":
@@ -74,3 +82,20 @@ class State:
             case _:
                 raise ValueError(f"Invalid phase: {phase}")
         return State(**{**asdict(self), **kwargs, **extra_kwargs})
+
+    def to_dict(self) -> dict[str, int | float | str]:
+        return {
+            "num_steps": int(self.num_steps),
+            "num_samples": int(self.num_samples),
+            "num_valid_steps": int(self.num_valid_steps),
+            "num_valid_samples": int(self.num_valid_samples),
+            "start_time_s": float(self.start_time_s),
+            "elapsed_time_s": float(self.elapsed_time_s),
+            "phase": str(self.phase),
+        }
+
+    @classmethod
+    def from_dict(cls, d: dict[str, int | float | str]) -> "State":
+        if "phase" in d:
+            d["_phase"] = _phase_to_int(cast(Phase, d.pop("phase")))
+        return cls(**d)  # type: ignore[arg-type]
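
The new to_dict/from_dict helpers give State a JSON-friendly form, with phase stored as a string and mapped back to the internal _phase integer on load. A minimal round-trip sketch (assuming the dataclass constructor accepts the _phase field, as the diff implies):

    from xax.core.state import State

    state = State.init_state()
    d = state.to_dict()            # {"num_steps": 0, ..., "phase": "train"}
    restored = State.from_dict(d)  # "phase" is popped and re-encoded as _phase
    assert restored.phase == state.phase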
{xax-0.1.16 → xax-0.2.0}/xax/requirements.txt

@@ -2,16 +2,16 @@
 
 # Core ML/JAX dependencies
 attrs
+chex
+dpshdl
+equinox
+importlib-resources
 jax
 jaxtyping
-equinox
 optax
-dpshdl
-chex
-importlib-resources
+orbax-checkpoint
 
 # Data processing and serialization
-cloudpickle
 pillow
 
 # Configuration and project management
{xax-0.1.16 → xax-0.2.0}/xax/task/base.py

@@ -79,7 +79,7 @@ class BaseTask(Generic[Config]):
     def on_training_end(self, state: State) -> State:
         return state
 
-    def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
+    def on_after_checkpoint_save(self, ckpt_path: Path, state: State | None) -> State | None:
         return state
 
     @functools.cached_property
{xax-0.1.16 → xax-0.2.0}/xax/task/logger.py

@@ -18,11 +18,22 @@ from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
 from types import TracebackType
-from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, cast, get_args
+from typing import (
+    Any,
+    Callable,
+    Iterator,
+    Literal,
+    Self,
+    Sequence,
+    TypeVar,
+    cast,
+    get_args,
+)
 
 import jax
 import jax.numpy as jnp
 import numpy as np
+from jax._src.core import ClosedJaxpr
 from jaxtyping import Array
 from PIL import Image, ImageDraw, ImageFont
 from PIL.Image import Image as PILImage
@@ -194,7 +205,10 @@ def tile_images(images: list[PILImage], sep: int = 0) -> PILImage:
     return tiled
 
 
-def as_numpy(array: Array) -> np.ndarray:
+def as_numpy(array: Array | np.ndarray) -> np.ndarray:
+    """Convert a JAX array or numpy array to numpy array."""
+    if isinstance(array, np.ndarray):
+        return array
     array = jax.device_get(array)
     if jax.dtypes.issubdtype(array.dtype, jnp.floating):
         array = array.astype(jnp.float32)
@@ -205,6 +219,13 @@ def as_numpy(array: Array) -> np.ndarray:
     return np.array(array)
 
 
+def as_numpy_opt(array: Array | np.ndarray | None) -> np.ndarray | None:
+    """Convert an optional JAX array or numpy array to numpy array."""
+    if array is None:
+        return None
+    return as_numpy(array)
+
+
 @dataclass(kw_only=True)
 class LogString:
     value: str
@@ -252,6 +273,19 @@ class LogHistogram:
     bucket_counts: list[int]
 
 
+@dataclass(kw_only=True)
+class LogMesh:
+    vertices: np.ndarray
+    colors: np.ndarray | None
+    faces: np.ndarray | None
+    config_dict: dict[str, Any] | None  # noqa: ANN401
+
+
+@dataclass(kw_only=True)
+class LogGraph:
+    computation: ClosedJaxpr
+
+
 @dataclass(kw_only=True)
 class LogLine:
     state: State
@@ -261,6 +295,7 @@ class LogLine:
     strings: dict[str, dict[str, LogString]]
     images: dict[str, dict[str, LogImage]]
     videos: dict[str, dict[str, LogVideo]]
+    meshes: dict[str, dict[str, LogMesh]]
 
 
 @dataclass(kw_only=True)
@@ -533,6 +568,7 @@ class Logger:
         self.strings: dict[str, dict[str, Callable[[], LogString]]] = defaultdict(dict)
         self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
         self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
+        self.meshes: dict[str, dict[str, Callable[[], LogMesh]]] = defaultdict(dict)
         self.default_namespace = default_namespace
         self.loggers: list[LoggerImpl] = []
 
@@ -560,6 +596,7 @@
             strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
             images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
             videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
+            meshes={k: {kk: v() for kk, v in v.items()} for k, v in self.meshes.items()},
         )
 
     def clear(self) -> None:
@@ -569,6 +606,7 @@
         self.strings.clear()
         self.images.clear()
         self.videos.clear()
+        self.meshes.clear()
 
     def write(self, state: State) -> None:
         """Writes the current step's logging information.
@@ -1051,6 +1089,73 @@
 
         self.videos[namespace][key] = video_future
 
+    def log_mesh(
+        self,
+        key: str,
+        vertices: np.ndarray | Array | Callable[[], np.ndarray | Array],
+        colors: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
+        faces: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
+        config_dict: dict[str, Any] | None = None,
+        *,
+        namespace: str | None = None,
+    ) -> None:
+        if not self.active:
+            raise RuntimeError("The logger is not active")
+        namespace = self.resolve_namespace(namespace)
+
+        @functools.lru_cache(maxsize=None)
+        def mesh_future() -> LogMesh:
+            with ContextTimer() as timer:
+                # Get the raw values
+                vertices_val = vertices() if callable(vertices) else vertices
+                colors_val = colors() if callable(colors) else colors
+                faces_val = faces() if callable(faces) else faces
+
+                # Convert to numpy arrays with proper type handling
+                vertices_np = as_numpy(vertices_val)
+                colors_np = as_numpy_opt(colors_val)
+                faces_np = as_numpy_opt(faces_val)
+
+                # Checks vertices shape.
+                if vertices_np.ndim == 2:
+                    vertices_np = vertices_np[None]
+                if vertices_np.shape[-1] != 3 or vertices_np.ndim != 3:
+                    raise ValueError("Vertices must have shape (N, 3) or (B, N, 3)")
+
+                # Checks colors shape.
+                if colors_np is not None:
+                    if colors_np.ndim == 2:
+                        colors_np = colors_np[None]
+                    if colors_np.shape[-1] != 3 or colors_np.ndim != 3:
+                        raise ValueError("Colors must have shape (N, 3) or (B, N, 3)")
+
+                # Checks faces shape.
+                if faces_np is not None:
+                    if faces_np.ndim == 2:
+                        faces_np = faces_np[None]
+                    if faces_np.shape[-1] != 3 or faces_np.ndim != 3:
+                        raise ValueError("Faces must have shape (N, 3) or (B, N, 3)")
+
+                # Ensures colors dtype is uint8.
+                if colors_np is not None:
+                    if colors_np.dtype != np.uint8:
+                        colors_np = (colors_np * 255).astype(np.uint8)
+
+                # Ensures faces dtype is int32.
+                if faces_np is not None:
+                    if faces_np.dtype != np.int32:
+                        faces_np = faces_np.astype(np.int32)
+
+            logger.debug("Mesh Key: %s, Time: %s", key, timer.elapsed_time)
+            return LogMesh(
+                vertices=vertices_np,
+                colors=colors_np,
+                faces=faces_np,
+                config_dict=config_dict,
+            )
+
+        self.meshes[namespace][key] = mesh_future
+
     def __enter__(self) -> Self:
         self.active = True
         for logger in self.loggers:
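
A rough usage sketch for the new log_mesh API (the surrounding function is illustrative; only Logger.log_mesh itself comes from this diff). Vertices may be (N, 3) or batched (B, N, 3); float colors are scaled to uint8 and faces are cast to int32 before logging:

    import numpy as np

    from xax.task.logger import Logger

    def log_point_cloud(logger: Logger) -> None:
        # (B, N, 3) float vertices; a bare (N, 3) array is batched automatically.
        vertices = np.random.rand(1, 128, 3).astype(np.float32)
        # Float colors in [0, 1] are converted to uint8 inside log_mesh.
        colors = np.random.rand(1, 128, 3)
        logger.log_mesh("point_cloud", vertices, colors=colors, namespace="mesh")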
{xax-0.1.16 → xax-0.2.0}/xax/task/loggers/tensorboard.py

@@ -70,6 +70,9 @@ class TensorboardLogger(LoggerImpl):
         self._started = True
 
     def worker_thread(self) -> None:
+        if os.environ.get("DISABLE_TENSORBOARD", "0") == "1":
+            return
+
         time.sleep(self.wait_seconds)
 
         port = int(os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT))
@@ -213,6 +216,19 @@ class TensorboardLogger(LoggerImpl):
                     video_value.frames,
                     fps=video_value.fps,
                     global_step=line.state.num_steps,
+                    walltime=walltime,
+                )
+
+        for namespace, meshes in line.meshes.items():
+            for mesh_key, mesh_value in meshes.items():
+                writer.add_mesh(
+                    f"{namespace}/{mesh_key}",
+                    vertices=mesh_value.vertices,
+                    faces=mesh_value.faces,
+                    colors=mesh_value.colors,
+                    config_dict=mesh_value.config_dict,
+                    global_step=line.state.num_steps,
+                    walltime=walltime,
                 )
 
         for name, contents in self.files.items():
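
The worker thread that launches the TensorBoard viewer can now be skipped via an environment variable; presumably the event files are still written, and only the subprocess launch is suppressed:

    import os

    # Set before the task starts; worker_thread then returns immediately.
    os.environ["DISABLE_TENSORBOARD"] = "1"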
{xax-0.1.16 → xax-0.2.0}/xax/task/mixins/checkpointing.py

@@ -6,9 +6,9 @@ import logging
 import tarfile
 from dataclasses import asdict, dataclass
 from pathlib import Path
-from typing import Any, Callable, Generic, Literal, TypeVar, cast, overload
+from typing import Generic, Literal, TypeVar, cast, overload
 
-import cloudpickle
+import equinox as eqx
 import jax
 import optax
 from jaxtyping import PyTree
@@ -64,7 +64,9 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def get_init_ckpt_path(self) -> Path | None:
         if self._exp_dir is not None:
             ckpt_path = self.get_ckpt_path()
-            if ckpt_path.exists():
+            if not ckpt_path.exists():
+                logger.warning("No checkpoint found in experiment directory: %s", ckpt_path)
+            else:
                 return ckpt_path
         if self.config.load_from_ckpt_path is not None:
             ckpt_path = Path(self.config.load_from_ckpt_path)
@@ -87,41 +89,54 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def load_checkpoint(
         self,
         path: Path,
-        part: Literal["all"] = "all",
-    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+        *,
+        part: Literal["all"],
+        model_template: PyTree,
+        optimizer_template: PyTree,
+        opt_state_template: PyTree,
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
 
     @overload
     def load_checkpoint(
         self,
         path: Path,
-        part: Literal["model_state_config"] = "model_state_config",
-    ) -> tuple[PyTree, State, DictConfig]: ...
+        *,
+        part: Literal["model_state_config"],
+        model_template: PyTree,
+    ) -> tuple[PyTree, State, Config]: ...
 
     @overload
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["model"],
+        model_template: PyTree,
     ) -> PyTree: ...
 
     @overload
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["opt"],
+        optimizer_template: PyTree,
     ) -> optax.GradientTransformation: ...
 
     @overload
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["opt_state"],
+        opt_state_template: PyTree,
     ) -> optax.OptState: ...
 
     @overload
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["state"],
     ) -> State: ...
 
@@ -129,48 +144,71 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["config"],
-    ) -> DictConfig: ...
+    ) -> Config: ...
 
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: CheckpointPart = "all",
+        model_template: PyTree | None = None,
+        optimizer_template: PyTree | None = None,
+        opt_state_template: PyTree | None = None,
     ) -> (
-        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
-        | tuple[PyTree, State, DictConfig]
+        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
+        | tuple[PyTree, State, Config]
        | PyTree
        | optax.GradientTransformation
        | optax.OptState
        | State
-        | DictConfig
+        | Config
    ):
+        """Load a checkpoint.
+
+        Args:
+            path: Path to the checkpoint directory
+            part: Which part of the checkpoint to load
+            model_template: Template model with correct structure but uninitialized weights
+            optimizer_template: Template optimizer with correct structure but uninitialized weights
+            opt_state_template: Template optimizer state with correct structure but uninitialized weights
+
+        Returns:
+            The requested checkpoint components
+        """
         with tarfile.open(path, "r:gz") as tar:
 
             def get_model() -> PyTree:
+                if model_template is None:
+                    raise ValueError("model_template must be provided to load model weights")
                 if (model := tar.extractfile("model")) is None:
                     raise ValueError(f"Checkpoint does not contain a model file: {path}")
-                return cloudpickle.load(model)
+                return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
 
             def get_opt() -> optax.GradientTransformation:
-                if (opt := tar.extractfile("opt")) is None:
-                    raise ValueError(f"Checkpoint does not contain an opt file: {path}")
-                return cloudpickle.load(opt)
+                if optimizer_template is None:
+                    raise ValueError("optimizer_template must be provided to load optimizer")
+                if (opt := tar.extractfile("optimizer")) is None:
+                    raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
+                return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
 
             def get_opt_state() -> optax.OptState:
+                if opt_state_template is None:
+                    raise ValueError("opt_state_template must be provided to load optimizer state")
                 if (opt_state := tar.extractfile("opt_state")) is None:
-                    raise ValueError(f"Checkpoint does not contain an opt_state file: {path}")
-                return cloudpickle.load(opt_state)
+                    raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
+                return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
 
             def get_state() -> State:
                 if (state := tar.extractfile("state")) is None:
                     raise ValueError(f"Checkpoint does not contain a state file: {path}")
                 return State(**json.loads(state.read().decode()))
 
-            def get_config() -> DictConfig:
+            def get_config() -> Config:
                 if (config := tar.extractfile("config")) is None:
                     raise ValueError(f"Checkpoint does not contain a config file: {path}")
-                return cast(DictConfig, OmegaConf.load(config))
+                return self.get_config(cast(DictConfig, OmegaConf.load(config)), use_cli=False)
 
             match part:
                 case "model":
@@ -192,51 +230,90 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
 
     def save_checkpoint(
         self,
-        model: PyTree,
-        optimizer: optax.GradientTransformation,
-        opt_state: optax.OptState,
-        state: State,
+        model: PyTree | None = None,
+        optimizer: optax.GradientTransformation | None = None,
+        opt_state: optax.OptState | None = None,
+        aux_data: PyTree | None = None,
+        state: State | None = None,
     ) -> Path:
+        """Save a checkpoint.
+
+        Args:
+            model: The model to save
+            state: The current training state
+            optimizer: The optimizer to save
+            aux_data: Additional data to save
+            opt_state: The optimizer state to save
+
+        Returns:
+            Path to the saved checkpoint
+        """
         ckpt_path = self.get_ckpt_path(state)
 
         if not is_master():
             return ckpt_path
 
-        # Gets the path to the last checkpoint.
+        # Gets the path to the last checkpoint
         logger.info("Saving checkpoint to %s", ckpt_path)
         last_ckpt_path = self.get_ckpt_path()
         ckpt_path.parent.mkdir(exist_ok=True, parents=True)
 
-        # Potentially removes the last checkpoint.
+        # Potentially removes the last checkpoint
         if last_ckpt_path.exists() and self.config.only_save_most_recent:
             if (base_ckpt := last_ckpt_path.resolve()).is_file():
                 base_ckpt.unlink()
 
-        # Combines all temporary files into a single checkpoint TAR file.
+        # Save the checkpoint components
         with tarfile.open(ckpt_path, "w:gz") as tar:
 
-            def add_file(name: str, write_fn: Callable[[io.BytesIO], Any]) -> None:
+            def add_file(name: str, buf: io.BytesIO) -> None:
+                tarinfo = tarfile.TarInfo(name)
+                tarinfo.size = buf.tell()
+                buf.seek(0)
+                tar.addfile(tarinfo, buf)
+
+            # Save model using Equinox
+            if model is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, model)
+                    add_file("model", buf)
+
+            # Save optimizer using Equinox
+            if optimizer is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, optimizer)
+                    add_file("optimizer", buf)
+
+            # Save optimizer state using Equinox
+            if opt_state is not None:
                 with io.BytesIO() as buf:
-                    write_fn(buf)
-                    tarinfo = tarfile.TarInfo(name)
-                    tarinfo.size = buf.tell()
-                    buf.seek(0)
-                    tar.addfile(tarinfo, buf)
-
-            add_file("model", lambda buf: cloudpickle.dump(model, buf))
-            add_file("opt", lambda buf: cloudpickle.dump(optimizer, buf))
-            add_file("opt_state", lambda buf: cloudpickle.dump(opt_state, buf))
-            add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
-            add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
-
-        # Updates the symlink to the new checkpoint.
+                    eqx.tree_serialise_leaves(buf, opt_state)
+                    add_file("opt_state", buf)
+
+            # Save aux data using Equinox.
+            if aux_data is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, aux_data)
+                    add_file("aux_data", buf)
+
+            # Save state and config as JSON
+            def add_file_bytes(name: str, data: bytes) -> None:  # noqa: ANN401
+                info = tarfile.TarInfo(name=name)
+                info.size = len(data)
+                tar.addfile(info, io.BytesIO(data))
+
+            if state is not None:
+                add_file_bytes("state", json.dumps(asdict(state), indent=2).encode())
+            add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
+
+        # Updates the symlink to the new checkpoint
         last_ckpt_path.unlink(missing_ok=True)
         try:
             last_ckpt_path.symlink_to(ckpt_path.relative_to(last_ckpt_path.parent))
         except FileExistsError:
             logger.exception("Exception while trying to update %s", ckpt_path)
 
-        # Calls the base callback.
+        # Calls the base callback
         self.on_after_checkpoint_save(ckpt_path, state)
 
         return ckpt_path
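
Because eqx.tree_serialise_leaves stores only array leaves, loading now requires a template pytree with matching structure. A sketch of the intended pattern, mirroring how train.py builds templates with eqx.filter_eval_shape (the task object here is an illustrative CheckpointingMixin subclass):

    from pathlib import Path

    import equinox as eqx
    import jax

    def restore_model(task, ckpt_path: Path):
        # Build an abstract model with the right tree structure but no
        # allocated weights, then fill its leaves from the checkpoint.
        key = jax.random.PRNGKey(0)
        model_template = eqx.filter_eval_shape(task.get_model, key)
        return task.load_checkpoint(ckpt_path, part="model", model_template=model_template)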
{xax-0.1.16 → xax-0.2.0}/xax/task/mixins/data_loader.py

@@ -9,6 +9,7 @@ import jax
 from dpshdl.dataloader import CollatedDataloaderItem, Dataloader
 from dpshdl.dataset import Dataset, ErrorHandlingDataset
 from dpshdl.prefetcher import Prefetcher
+from jaxtyping import PRNGKeyArray
 from omegaconf import II, MISSING
 
 from xax.core.conf import field, is_missing
@@ -103,7 +104,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
             "or `get_data_iterator` to return an iterator for the given dataset."
         )
 
-    def get_data_iterator(self, phase: Phase) -> Iterator:
+    def get_data_iterator(self, phase: Phase, key: PRNGKeyArray) -> Iterator:
         raise NotImplementedError(
             "You must implement either the `get_dataset` method to return the dataset for the given phase, "
             "or `get_data_iterator` to return an iterator for the given dataset."
         )
{xax-0.1.16 → xax-0.2.0}/xax/task/mixins/train.py

@@ -11,7 +11,7 @@ import textwrap
 import time
 import traceback
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, is_dataclass
+from dataclasses import asdict, dataclass, is_dataclass
 from threading import Thread
 from typing import (
     Any,
@@ -33,7 +33,6 @@ import jax.numpy as jnp
 import numpy as np
 import optax
 from jaxtyping import Array, PRNGKeyArray, PyTree
-from omegaconf import DictConfig
 
 from xax.core.conf import field
 from xax.core.state import Phase, State
@@ -50,6 +49,7 @@ from xax.utils.experiments import (
     TrainingFinishedError,
     diff_configs,
     get_diff_string,
+    get_info_json,
     get_state_file_string,
     get_training_code,
 )
@@ -340,20 +340,30 @@
 
         if init_ckpt_path is not None:
             logger.info("Loading checkpoint from %s", init_ckpt_path)
-            if load_optimizer:
-                model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
-                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
-                if config_diff:
-                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
-                return model, optimizer, opt_state, state
+            model_spec = eqx.filter_eval_shape(self.get_model, key)
+            model, state, config = self.load_checkpoint(
+                init_ckpt_path,
+                part="model_state_config",
+                model_template=model_spec,
+            )
+            config_diff = get_diff_string(diff_configs(asdict(config), asdict(self.config)))
+            if config_diff:
+                logger.warning("Loaded config differs from current config:\n%s", config_diff)
 
-            else:
-                model, state, config = self.load_checkpoint(init_ckpt_path, "model_state_config")
-                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
-                if config_diff:
-                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
+            if not load_optimizer:
                 return model, state
 
+            # Loads the optimizer.
+            optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
+            optimizer = self.load_checkpoint(init_ckpt_path, part="opt", optimizer_template=optimizer_spec)
+
+            # Loads the optimizer state.
+            opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
+            opt_state = self.load_checkpoint(init_ckpt_path, part="opt_state", opt_state_template=opt_state_spec)
+
+            return model, optimizer, opt_state, state
+
+        logger.info("No checkpoint found. Initializing a new model.")
         model = self.get_model(key)
         state = State.init_state()
 
@@ -554,6 +564,7 @@
         self.logger.log_file("state.txt", get_state_file_string(self))
         self.logger.log_file("training_code.py", get_training_code(self))
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
+        self.logger.log_file("info.json", get_info_json())
 
     def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
         return eqx.is_inexact_array(item)
@@ -627,16 +638,16 @@
 
             if self.should_checkpoint(state):
                 model = eqx.combine(model_arr, model_static)
-                self.save_checkpoint(model, optimizer, opt_state, state)
+                self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
         # After finishing training, save the final checkpoint.
         model = eqx.combine(model_arr, model_static)
-        self.save_checkpoint(model, optimizer, opt_state, state)
+        self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
     @contextlib.contextmanager
-    def get_train_iterator(self) -> Generator[Iterator[Batch], None, None]:
+    def get_train_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
         try:
-            train_iterator: Iterator[Batch] = self.get_data_iterator("train")
+            train_iterator: Iterator[Batch] = self.get_data_iterator("train", key=key)
             yield train_iterator
             return
         except NotImplementedError:
@@ -653,9 +664,9 @@
             logger.info("Closing train prefetcher")
 
     @contextlib.contextmanager
-    def get_valid_iterator(self) -> Generator[Iterator[Batch], None, None]:
+    def get_valid_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
         try:
-            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid")
+            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid", key=key)
             yield valid_iterator
             return
         except NotImplementedError:
@@ -699,12 +710,13 @@
             state = self.on_training_start(state)
 
             def on_exit() -> None:
-                self.save_checkpoint(model, optimizer, opt_state, state)
+                self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
             # Handle user-defined interrupts during the training loop.
             self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
 
-            with self.get_train_iterator() as train_pf, self.get_valid_iterator() as valid_pf:
+            key, tkey, vkey = jax.random.split(key, 3)
+            with self.get_train_iterator(tkey) as train_pf, self.get_valid_iterator(vkey) as valid_pf:
                 try:
                     self.train_loop(
                         model=model,
@@ -721,7 +733,7 @@
                         f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
                         important=True,
                     )
-                    self.save_checkpoint(model, optimizer, opt_state, state)
+                    self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
                 except (KeyboardInterrupt, bdb.BdbQuit):
                     if is_master():
@@ -731,7 +743,7 @@
                     exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
                     sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
                     sys.stdout.flush()
-                    self.save_checkpoint(model, optimizer, opt_state, state)
+                    self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
                 finally:
                     state = self.on_training_end(state)
{xax-0.1.16 → xax-0.2.0}/xax/utils/experiments.py

@@ -7,6 +7,7 @@ import functools
 import hashlib
 import inspect
 import itertools
+import json
 import logging
 import math
 import os
@@ -24,7 +25,7 @@ import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
 from types import TracebackType
-from typing import Any, Iterator, Self, TypeVar, cast
+from typing import Any, Iterator, Mapping, Self, Sequence, TypeVar, cast
 from urllib.parse import urlparse
 
 import git
@@ -116,9 +117,7 @@ class StateTimer:
 
     def log_dict(self) -> dict[str, int | float | tuple[int | float, bool]]:
         return {
-            "steps": (self.step_timer.steps, True),
             "steps/second": self.step_timer.steps_per_second,
-            "samples": (self.sample_timer.steps, True),
             "samples/second": (self.sample_timer.steps_per_second, True),
             "dt": self.iter_timer.iter_seconds,
         }
@@ -204,8 +203,8 @@ class MinGradScaleError(TrainingFinishedError):
 
 
 def diff_configs(
-    first: ListConfig | DictConfig,
-    second: ListConfig | DictConfig,
+    first: Mapping | Sequence,
+    second: Mapping | Sequence,
     prefix: str | None = None,
 ) -> tuple[list[str], list[str]]:
     """Returns the difference between two configs.
@@ -232,7 +231,7 @@ def diff_configs(
 
     any_config = (ListConfig, DictConfig)
 
-    if isinstance(first, DictConfig) and isinstance(second, DictConfig):
+    if isinstance(first, Mapping) and isinstance(second, Mapping):
         first_keys, second_keys = cast(set[str], set(first.keys())), cast(set[str], set(second.keys()))
 
         # Gets the new keys in each config.
@@ -242,11 +241,12 @@ def diff_configs(
         # Gets the new sub-keys in each config.
         for key in first_keys.intersection(second_keys):
             sub_prefix = key if prefix is None else f"{prefix}.{key}"
-            if OmegaConf.is_missing(first, key) or OmegaConf.is_missing(second, key):
-                if not OmegaConf.is_missing(first, key):
-                    new_first += [get_diff_string(sub_prefix, first[key])]
-                if not OmegaConf.is_missing(second, key):
-                    new_second += [get_diff_string(sub_prefix, second[key])]
+            if isinstance(first, DictConfig) and isinstance(second, DictConfig):
+                if OmegaConf.is_missing(first, key) or OmegaConf.is_missing(second, key):
+                    if not OmegaConf.is_missing(first, key):
+                        new_first += [get_diff_string(sub_prefix, first[key])]
+                    if not OmegaConf.is_missing(second, key):
+                        new_second += [get_diff_string(sub_prefix, second[key])]
             elif isinstance(first[key], any_config) and isinstance(second[key], any_config):
                 sub_new_first, sub_new_second = diff_configs(first[key], second[key], prefix=sub_prefix)
                 new_first, new_second = new_first + sub_new_first, new_second + sub_new_second
@@ -255,7 +255,7 @@ def diff_configs(
             new_first += [get_diff_string(sub_prefix, first_val)]
             new_second += [get_diff_string(sub_prefix, second_val)]
 
-    elif isinstance(first, ListConfig) and isinstance(second, ListConfig):
+    elif isinstance(first, Sequence) and isinstance(second, Sequence):
         if len(first) > len(second):
             for i in range(len(second), len(first)):
                 new_first += [get_diff_string(prefix, first[i])]
@@ -470,16 +470,33 @@ def get_command_line_string() -> str:
     return " ".join(sys.argv)
 
 
+def get_environment_variables() -> str:
+    return "\n".join([f"{key}={value}" for key, value in sorted(os.environ.items())])
+
+
 def get_state_file_string(obj: object) -> str:
     return "\n\n".join(
         [
             f"=== Command Line ===\n\n{get_command_line_string()}",
             f"=== Git State ===\n\n{get_git_state(obj)}",
             f"=== Packages ===\n\n{get_packages_with_versions()}",
+            f"=== Environment Variables ===\n\n{get_environment_variables()}",
         ]
     )
 
 
+def get_info_json() -> str:
+    return json.dumps(
+        {
+            "process_id": os.getpid(),
+            "job": {
+                "start_time": datetime.datetime.now().isoformat(),
+            },
+        },
+        indent=2,
+    )
+
+
 def get_training_code(obj: object) -> str:
     """Gets the text from the file containing the provided object.
 
{xax-0.1.16 → xax-0.2.0}/xax/utils/tensorboard.py

@@ -2,11 +2,12 @@
 
 import functools
 import io
+import json
 import os
 import tempfile
 import time
 from pathlib import Path
-from typing import Literal, TypedDict
+from typing import Any, Literal, TypedDict
 
 import numpy as np
 import PIL.Image
@@ -14,9 +15,15 @@ from PIL.Image import Image as PILImage
 from tensorboard.compat.proto.config_pb2 import RunMetadata
 from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
 from tensorboard.compat.proto.graph_pb2 import GraphDef
-from tensorboard.compat.proto.summary_pb2 import HistogramProto, Summary, SummaryMetadata
+from tensorboard.compat.proto.summary_pb2 import (
+    HistogramProto,
+    Summary,
+    SummaryMetadata,
+)
 from tensorboard.compat.proto.tensor_pb2 import TensorProto
 from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
+from tensorboard.plugins.mesh import metadata as mesh_metadata
+from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData
 from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
 from tensorboard.summary.writer.event_file_writer import EventFileWriter
@@ -84,6 +91,68 @@ def make_histogram(values: np.ndarray, bins: str | np.ndarray, max_bins: int | N
     )
 
 
+def _get_json_config(config_dict: dict[str, Any] | None) -> str:
+    json_config = "{}"
+    if config_dict is not None:
+        json_config = json.dumps(config_dict, sort_keys=True)
+    return json_config
+
+
+def make_mesh_summary(
+    tag: str,
+    vertices: np.ndarray,
+    colors: np.ndarray | None,
+    faces: np.ndarray | None,
+    config_dict: dict[str, Any] | None,
+    display_name: str | None = None,
+    description: str | None = None,
+) -> Summary:
+    json_config = _get_json_config(config_dict)
+
+    summaries = []
+    tensors = [
+        (vertices, MeshPluginData.VERTEX),
+        (faces, MeshPluginData.FACE),
+        (colors, MeshPluginData.COLOR),
+    ]
+    # Filter out None tensors and explicitly type the list
+    valid_tensors = [(t, content_type) for t, content_type in tensors if t is not None]
+    components = mesh_metadata.get_components_bitmask([content_type for (_, content_type) in valid_tensors])
+
+    for tensor, content_type in valid_tensors:  # Now we know tensor is not None
+        tensor_metadata = mesh_metadata.create_summary_metadata(
+            tag,
+            display_name,
+            content_type,
+            components,
+            tensor.shape,  # Safe now since tensor is not None
+            description,
+            json_config=json_config,
+        )
+
+        tensor_proto = TensorProto(
+            dtype="DT_FLOAT",
+            float_val=tensor.reshape(-1).tolist(),  # Safe now since tensor is not None
+            tensor_shape=TensorShapeProto(
+                dim=[
+                    TensorShapeProto.Dim(size=tensor.shape[0]),  # Safe now since tensor is not None
+                    TensorShapeProto.Dim(size=tensor.shape[1]),
+                    TensorShapeProto.Dim(size=tensor.shape[2]),
+                ]
+            ),
+        )
+
+        tensor_summary = Summary.Value(
+            tag=mesh_metadata.get_instance_name(tag, content_type),
+            tensor=tensor_proto,
+            metadata=tensor_metadata,
+        )
+
+        summaries.append(tensor_summary)
+
+    return Summary(value=summaries)
+
+
 class TensorboardProtobufWriter:
     def __init__(
         self,
@@ -454,6 +523,9 @@ class TensorboardWriter:
         weighted_sum = float((bin_centers * bucket_counts).sum())
         weighted_sum_squares = float((bin_centers**2 * bucket_counts).sum())
 
+        # Convert bin edges to list of floats explicitly
+        bucket_limits: list[float | np.ndarray] = [float(x) for x in bin_edges[1:]]
+
         self.add_histogram_raw(
             tag=tag,
             min=float(bin_edges[0]),
@@ -461,12 +533,28 @@ class TensorboardWriter:
             num=int(total_counts),
             sum=weighted_sum,
             sum_squares=weighted_sum_squares,
-            bucket_limits=bin_edges[1:].tolist(),  # TensorBoard expects right bin edges
+            bucket_limits=bucket_limits,  # Now properly typed
             bucket_counts=bucket_counts.tolist(),
             global_step=global_step,
             walltime=walltime,
         )
 
+    def add_mesh(
+        self,
+        tag: str,
+        vertices: np.ndarray,
+        colors: np.ndarray | None,
+        faces: np.ndarray | None,
+        config_dict: dict[str, Any] | None,
+        global_step: int | None = None,
+        walltime: float | None = None,
+    ) -> None:
+        self.pb_writer.add_summary(
+            make_mesh_summary(tag, vertices, colors, faces, config_dict),
+            global_step=global_step,
+            walltime=walltime,
+        )
+
 
 class TensorboardWriterKwargs(TypedDict):
     max_queue_size: int
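
A small sketch of the new add_mesh writer method (the writer construction is elided; shapes follow the (B, N, 3) convention enforced upstream by log_mesh):

    import numpy as np

    def write_triangle(writer) -> None:  # `writer` is a TensorboardWriter (illustrative)
        vertices = np.zeros((1, 3, 3), dtype=np.float32)  # one triangle: (B, N, 3)
        faces = np.array([[[0, 1, 2]]], dtype=np.int32)   # vertex indices, (B, F, 3)
        colors = np.full((1, 3, 3), 255, dtype=np.uint8)  # per-vertex RGB
        writer.add_mesh("mesh/triangle", vertices=vertices, colors=colors,
                        faces=faces, config_dict=None, global_step=0)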
{xax-0.1.16 → xax-0.2.0/xax.egg-info}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.1.16
+Version: 0.2.0
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
@@ -8,14 +8,14 @@ Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: attrs
+Requires-Dist: chex
+Requires-Dist: dpshdl
+Requires-Dist: equinox
+Requires-Dist: importlib-resources
 Requires-Dist: jax
 Requires-Dist: jaxtyping
-Requires-Dist: equinox
 Requires-Dist: optax
-Requires-Dist: dpshdl
-Requires-Dist: chex
-Requires-Dist: importlib-resources
-Requires-Dist: cloudpickle
+Requires-Dist: orbax-checkpoint
 Requires-Dist: pillow
 Requires-Dist: omegaconf
 Requires-Dist: gitpython
{xax-0.1.16 → xax-0.2.0}/xax.egg-info/requires.txt

@@ -1,12 +1,12 @@
 attrs
+chex
+dpshdl
+equinox
+importlib-resources
 jax
 jaxtyping
-equinox
 optax
-dpshdl
-chex
-importlib-resources
-cloudpickle
+orbax-checkpoint
 pillow
 omegaconf
 gitpython