xax 0.1.16__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.1.16"
+__version__ = "0.2.1"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -44,6 +44,7 @@ __all__ = [
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
+    "quat_to_rotmat",
     "rotate_vector_by_quat",
     "cross_entropy",
     "cast_norm_type",
@@ -206,6 +207,7 @@ NAME_MAP: dict[str, str] = {
    "euler_to_quat": "nn.geom",
    "get_projected_gravity_vector_from_quat": "nn.geom",
    "quat_to_euler": "nn.geom",
+   "quat_to_rotmat": "nn.geom",
    "rotate_vector_by_quat": "nn.geom",
    "cross_entropy": "nn.losses",
    "cast_norm_type": "nn.norm",
@@ -369,6 +371,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         euler_to_quat,
         get_projected_gravity_vector_from_quat,
         quat_to_euler,
+        quat_to_rotmat,
         rotate_vector_by_quat,
     )
     from xax.nn.losses import cross_entropy
xax/core/state.py CHANGED
@@ -12,6 +12,14 @@ from xax.core.conf import field
 Phase = Literal["train", "valid"]
 
 
+def _phase_to_int(phase: Phase) -> int:
+    return {"train": 0, "valid": 1}[phase]
+
+
+def _int_to_phase(i: int) -> Phase:
+    return cast(Phase, ["train", "valid"][i])
+
+
 class StateDict(TypedDict, total=False):
     num_steps: NotRequired[int]
     num_samples: NotRequired[int]
@@ -35,7 +43,7 @@ class State:
 
     @property
     def phase(self) -> Phase:
-        return cast(Phase, ["train", "valid"][self._phase])
+        return _int_to_phase(self._phase)
 
     @classmethod
     def init_state(cls) -> "State":
@@ -74,3 +82,20 @@ class State:
             case _:
                 raise ValueError(f"Invalid phase: {phase}")
         return State(**{**asdict(self), **kwargs, **extra_kwargs})
+
+    def to_dict(self) -> dict[str, int | float | str]:
+        return {
+            "num_steps": int(self.num_steps),
+            "num_samples": int(self.num_samples),
+            "num_valid_steps": int(self.num_valid_steps),
+            "num_valid_samples": int(self.num_valid_samples),
+            "start_time_s": float(self.start_time_s),
+            "elapsed_time_s": float(self.elapsed_time_s),
+            "phase": str(self.phase),
+        }
+
+    @classmethod
+    def from_dict(cls, d: dict[str, int | float | str]) -> "State":
+        if "phase" in d:
+            d["_phase"] = _phase_to_int(cast(Phase, d.pop("phase")))
+        return cls(**d)  # type: ignore[arg-type]
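The new `to_dict` / `from_dict` pair makes `State` round-trippable through plain JSON, with the private `_phase` int surfaced as a `"phase"` string. A minimal sketch of the round trip, assuming `init_state()` populates every field listed in `to_dict()`:

```python
import json

from xax.core.state import State

state = State.init_state()
payload = json.dumps(state.to_dict())            # e.g. {"num_steps": 0, ..., "phase": "train"}
restored = State.from_dict(json.loads(payload))  # "phase" is mapped back to the internal _phase int
assert restored.phase == state.phase
```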
xax/nn/geom.py CHANGED
@@ -177,3 +177,37 @@ def cubic_bezier_interpolation(y_start: Array, y_end: Array, x: Array) -> Array:
     y_diff = y_end - y_start
     bezier = x**3 + 3 * (x**2 * (1 - x))
     return y_start + y_diff * bezier
+
+
+def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
+    """Converts a quaternion to a rotation matrix.
+
+    Args:
+        quat: The quaternion to convert, shape (*, 4).
+        eps: A small epsilon value to avoid division by zero.
+
+    Returns:
+        The rotation matrix, shape (*, 3, 3).
+    """
+    quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
+    w, x, y, z = jnp.split(quat, 4, axis=-1)
+
+    xx = 1 - 2 * (y * y + z * z)
+    xy = 2 * (x * y - z * w)
+    xz = 2 * (x * z + y * w)
+    yx = 2 * (x * y + z * w)
+    yy = 1 - 2 * (x * x + z * z)
+    yz = 2 * (y * z - x * w)
+    zx = 2 * (x * z - y * w)
+    zy = 2 * (y * z + x * w)
+    zz = 1 - 2 * (x * x + y * y)
+
+    # Corrected stacking: row-major order
+    return jnp.concatenate(
+        [
+            jnp.concatenate([xx, xy, xz], axis=-1)[..., None, :],
+            jnp.concatenate([yx, yy, yz], axis=-1)[..., None, :],
+            jnp.concatenate([zx, zy, zz], axis=-1)[..., None, :],
+        ],
+        axis=-2,
+    )
xax/requirements.txt CHANGED
@@ -2,16 +2,16 @@
 
 # Core ML/JAX dependencies
 attrs
+chex
+dpshdl
+equinox
+importlib-resources
 jax
 jaxtyping
-equinox
 optax
-dpshdl
-chex
-importlib-resources
+orbax-checkpoint
 
 # Data processing and serialization
-cloudpickle
 pillow
 
 # Configuration and project management
xax/task/base.py CHANGED
@@ -79,7 +79,7 @@ class BaseTask(Generic[Config]):
     def on_training_end(self, state: State) -> State:
         return state
 
-    def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
+    def on_after_checkpoint_save(self, ckpt_path: Path, state: State | None) -> State | None:
         return state
 
     @functools.cached_property
xax/task/logger.py CHANGED
@@ -18,11 +18,22 @@ from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
 from types import TracebackType
-from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, cast, get_args
+from typing import (
+    Any,
+    Callable,
+    Iterator,
+    Literal,
+    Self,
+    Sequence,
+    TypeVar,
+    cast,
+    get_args,
+)
 
 import jax
 import jax.numpy as jnp
 import numpy as np
+from jax._src.core import ClosedJaxpr
 from jaxtyping import Array
 from PIL import Image, ImageDraw, ImageFont
 from PIL.Image import Image as PILImage
@@ -194,7 +205,10 @@ def tile_images(images: list[PILImage], sep: int = 0) -> PILImage:
     return tiled
 
 
-def as_numpy(array: Array) -> np.ndarray:
+def as_numpy(array: Array | np.ndarray) -> np.ndarray:
+    """Convert a JAX array or numpy array to numpy array."""
+    if isinstance(array, np.ndarray):
+        return array
     array = jax.device_get(array)
     if jax.dtypes.issubdtype(array.dtype, jnp.floating):
         array = array.astype(jnp.float32)
@@ -205,6 +219,13 @@ def as_numpy(array: Array) -> np.ndarray:
     return np.array(array)
 
 
+def as_numpy_opt(array: Array | np.ndarray | None) -> np.ndarray | None:
+    """Convert an optional JAX array or numpy array to numpy array."""
+    if array is None:
+        return None
+    return as_numpy(array)
+
+
 @dataclass(kw_only=True)
 class LogString:
     value: str
@@ -252,6 +273,19 @@ class LogHistogram:
     bucket_counts: list[int]
 
 
+@dataclass(kw_only=True)
+class LogMesh:
+    vertices: np.ndarray
+    colors: np.ndarray | None
+    faces: np.ndarray | None
+    config_dict: dict[str, Any] | None  # noqa: ANN401
+
+
+@dataclass(kw_only=True)
+class LogGraph:
+    computation: ClosedJaxpr
+
+
 @dataclass(kw_only=True)
 class LogLine:
     state: State
@@ -261,6 +295,7 @@ class LogLine:
     strings: dict[str, dict[str, LogString]]
     images: dict[str, dict[str, LogImage]]
     videos: dict[str, dict[str, LogVideo]]
+    meshes: dict[str, dict[str, LogMesh]]
 
 
 @dataclass(kw_only=True)
@@ -533,6 +568,7 @@ class Logger:
         self.strings: dict[str, dict[str, Callable[[], LogString]]] = defaultdict(dict)
         self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
         self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
+        self.meshes: dict[str, dict[str, Callable[[], LogMesh]]] = defaultdict(dict)
         self.default_namespace = default_namespace
         self.loggers: list[LoggerImpl] = []
 
@@ -560,6 +596,7 @@
             strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
             images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
             videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
+            meshes={k: {kk: v() for kk, v in v.items()} for k, v in self.meshes.items()},
         )
 
     def clear(self) -> None:
@@ -569,6 +606,7 @@
         self.strings.clear()
         self.images.clear()
         self.videos.clear()
+        self.meshes.clear()
 
     def write(self, state: State) -> None:
         """Writes the current step's logging information.
@@ -1051,6 +1089,73 @@
 
         self.videos[namespace][key] = video_future
 
+    def log_mesh(
+        self,
+        key: str,
+        vertices: np.ndarray | Array | Callable[[], np.ndarray | Array],
+        colors: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
+        faces: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
+        config_dict: dict[str, Any] | None = None,
+        *,
+        namespace: str | None = None,
+    ) -> None:
+        if not self.active:
+            raise RuntimeError("The logger is not active")
+        namespace = self.resolve_namespace(namespace)
+
+        @functools.lru_cache(maxsize=None)
+        def mesh_future() -> LogMesh:
+            with ContextTimer() as timer:
+                # Get the raw values
+                vertices_val = vertices() if callable(vertices) else vertices
+                colors_val = colors() if callable(colors) else colors
+                faces_val = faces() if callable(faces) else faces
+
+                # Convert to numpy arrays with proper type handling
+                vertices_np = as_numpy(vertices_val)
+                colors_np = as_numpy_opt(colors_val)
+                faces_np = as_numpy_opt(faces_val)
+
+                # Checks vertices shape.
+                if vertices_np.ndim == 2:
+                    vertices_np = vertices_np[None]
+                if vertices_np.shape[-1] != 3 or vertices_np.ndim != 3:
+                    raise ValueError("Vertices must have shape (N, 3) or (B, N, 3)")
+
+                # Checks colors shape.
+                if colors_np is not None:
+                    if colors_np.ndim == 2:
+                        colors_np = colors_np[None]
+                    if colors_np.shape[-1] != 3 or colors_np.ndim != 3:
+                        raise ValueError("Colors must have shape (N, 3) or (B, N, 3)")
+
+                # Checks faces shape.
+                if faces_np is not None:
+                    if faces_np.ndim == 2:
+                        faces_np = faces_np[None]
+                    if faces_np.shape[-1] != 3 or faces_np.ndim != 3:
+                        raise ValueError("Faces must have shape (N, 3) or (B, N, 3)")
+
+                # Ensures colors dtype is uint8.
+                if colors_np is not None:
+                    if colors_np.dtype != np.uint8:
+                        colors_np = (colors_np * 255).astype(np.uint8)
+
+                # Ensures faces dtype is int32.
+                if faces_np is not None:
+                    if faces_np.dtype != np.int32:
+                        faces_np = faces_np.astype(np.int32)
+
+            logger.debug("Mesh Key: %s, Time: %s", key, timer.elapsed_time)
+            return LogMesh(
+                vertices=vertices_np,
+                colors=colors_np,
+                faces=faces_np,
+                config_dict=config_dict,
+            )
+
+        self.meshes[namespace][key] = mesh_future
+
     def __enter__(self) -> Self:
         self.active = True
         for logger in self.loggers:
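The new `log_mesh` call mirrors the existing `log_image` / `log_video` pattern: values are wrapped in a cached future and only materialized at `write` time. A usage sketch from inside a running task (the cube data is made up; vertices may be `(N, 3)` or batched `(B, N, 3)`, and float colors are rescaled to `uint8`):

```python
import numpy as np

def log_cube(self) -> None:  # hypothetical method on a task that owns a Logger
    vertices = np.array(
        [[x, y, z] for x in (0.0, 1.0) for y in (0.0, 1.0) for z in (0.0, 1.0)],
        dtype=np.float32,
    )  # (8, 3); broadcast to (1, 8, 3) internally
    colors = np.full_like(vertices, 0.5)  # floats in [0, 1] become uint8
    self.logger.log_mesh("cube", vertices=vertices, colors=colors, namespace="mesh")
```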
xax/task/loggers/tensorboard.py CHANGED
@@ -70,6 +70,9 @@ class TensorboardLogger(LoggerImpl):
         self._started = True
 
     def worker_thread(self) -> None:
+        if os.environ.get("DISABLE_TENSORBOARD", "0") == "1":
+            return
+
         time.sleep(self.wait_seconds)
 
         port = int(os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT))
@@ -213,6 +216,19 @@ class TensorboardLogger(LoggerImpl):
                     video_value.frames,
                     fps=video_value.fps,
                     global_step=line.state.num_steps,
+                    walltime=walltime,
+                )
+
+        for namespace, meshes in line.meshes.items():
+            for mesh_key, mesh_value in meshes.items():
+                writer.add_mesh(
+                    f"{namespace}/{mesh_key}",
+                    vertices=mesh_value.vertices,
+                    faces=mesh_value.faces,
+                    colors=mesh_value.colors,
+                    config_dict=mesh_value.config_dict,
+                    global_step=line.state.num_steps,
+                    walltime=walltime,
                 )
 
         for name, contents in self.files.items():
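The background thread that launches the TensorBoard viewer can now be switched off, which appears useful on headless CI runners. A sketch (the entry point is hypothetical; judging from the early return above, event files are still written, only the viewer thread exits):

```python
import os

os.environ["DISABLE_TENSORBOARD"] = "1"  # worker_thread() returns before launching TensorBoard
# ... then run the task as usual
```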
xax/task/mixins/checkpointing.py CHANGED
@@ -6,9 +6,9 @@ import logging
 import tarfile
 from dataclasses import asdict, dataclass
 from pathlib import Path
-from typing import Any, Callable, Generic, Literal, TypeVar, cast, overload
+from typing import Generic, Literal, TypeVar, cast, overload
 
-import cloudpickle
+import equinox as eqx
 import jax
 import optax
 from jaxtyping import PyTree
@@ -63,8 +63,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
 
     def get_init_ckpt_path(self) -> Path | None:
         if self._exp_dir is not None:
-            ckpt_path = self.get_ckpt_path()
-            if ckpt_path.exists():
+            if (ckpt_path := self.get_ckpt_path()).exists():
                 return ckpt_path
         if self.config.load_from_ckpt_path is not None:
             ckpt_path = Path(self.config.load_from_ckpt_path)
@@ -84,93 +83,129 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
         return False
 
     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
-        part: Literal["all"] = "all",
-    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+        *,
+        part: Literal["all"],
+        model_template: PyTree,
+        optimizer_template: PyTree,
+        opt_state_template: PyTree,
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
 
     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
-        part: Literal["model_state_config"] = "model_state_config",
-    ) -> tuple[PyTree, State, DictConfig]: ...
+        *,
+        part: Literal["model_state_config"],
+        model_template: PyTree,
+    ) -> tuple[PyTree, State, Config]: ...
 
     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
+        *,
         part: Literal["model"],
+        model_template: PyTree,
     ) -> PyTree: ...
 
     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
+        *,
         part: Literal["opt"],
+        optimizer_template: PyTree,
     ) -> optax.GradientTransformation: ...
 
     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
+        *,
         part: Literal["opt_state"],
+        opt_state_template: PyTree,
     ) -> optax.OptState: ...
 
     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
+        *,
         part: Literal["state"],
     ) -> State: ...
 
     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
+        *,
         part: Literal["config"],
-    ) -> DictConfig: ...
+    ) -> Config: ...
 
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
+        *,
         part: CheckpointPart = "all",
+        model_template: PyTree | None = None,
+        optimizer_template: PyTree | None = None,
+        opt_state_template: PyTree | None = None,
     ) -> (
-        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
-        | tuple[PyTree, State, DictConfig]
+        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
+        | tuple[PyTree, State, Config]
         | PyTree
         | optax.GradientTransformation
         | optax.OptState
         | State
-        | DictConfig
+        | Config
     ):
+        """Load a checkpoint.
+
+        Args:
+            path: Path to the checkpoint directory
+            part: Which part of the checkpoint to load
+            model_template: Template model with correct structure but uninitialized weights
+            optimizer_template: Template optimizer with correct structure but uninitialized weights
+            opt_state_template: Template optimizer state with correct structure but uninitialized weights
+
+        Returns:
+            The requested checkpoint components
+        """
         with tarfile.open(path, "r:gz") as tar:
 
             def get_model() -> PyTree:
+                if model_template is None:
+                    raise ValueError("model_template must be provided to load model weights")
                 if (model := tar.extractfile("model")) is None:
                     raise ValueError(f"Checkpoint does not contain a model file: {path}")
-                return cloudpickle.load(model)
+                return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
 
             def get_opt() -> optax.GradientTransformation:
-                if (opt := tar.extractfile("opt")) is None:
-                    raise ValueError(f"Checkpoint does not contain an opt file: {path}")
-                return cloudpickle.load(opt)
+                if optimizer_template is None:
+                    raise ValueError("optimizer_template must be provided to load optimizer")
+                if (opt := tar.extractfile("optimizer")) is None:
+                    raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
+                return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
 
             def get_opt_state() -> optax.OptState:
+                if opt_state_template is None:
+                    raise ValueError("opt_state_template must be provided to load optimizer state")
                 if (opt_state := tar.extractfile("opt_state")) is None:
-                    raise ValueError(f"Checkpoint does not contain an opt_state file: {path}")
-                return cloudpickle.load(opt_state)
+                    raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
+                return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
 
             def get_state() -> State:
                 if (state := tar.extractfile("state")) is None:
                     raise ValueError(f"Checkpoint does not contain a state file: {path}")
                 return State(**json.loads(state.read().decode()))
 
-            def get_config() -> DictConfig:
+            def get_config() -> Config:
                 if (config := tar.extractfile("config")) is None:
                     raise ValueError(f"Checkpoint does not contain a config file: {path}")
-                return cast(DictConfig, OmegaConf.load(config))
+                return self.get_config(cast(DictConfig, OmegaConf.load(config)), use_cli=False)
 
             match part:
                 case "model":
@@ -192,51 +227,90 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
 
     def save_checkpoint(
         self,
-        model: PyTree,
-        optimizer: optax.GradientTransformation,
-        opt_state: optax.OptState,
-        state: State,
+        model: PyTree | None = None,
+        optimizer: optax.GradientTransformation | None = None,
+        opt_state: optax.OptState | None = None,
+        aux_data: PyTree | None = None,
+        state: State | None = None,
     ) -> Path:
+        """Save a checkpoint.
+
+        Args:
+            model: The model to save
+            state: The current training state
+            optimizer: The optimizer to save
+            aux_data: Additional data to save
+            opt_state: The optimizer state to save
+
+        Returns:
+            Path to the saved checkpoint
+        """
         ckpt_path = self.get_ckpt_path(state)
 
         if not is_master():
             return ckpt_path
 
-        # Gets the path to the last checkpoint.
+        # Gets the path to the last checkpoint
         logger.info("Saving checkpoint to %s", ckpt_path)
         last_ckpt_path = self.get_ckpt_path()
         ckpt_path.parent.mkdir(exist_ok=True, parents=True)
 
-        # Potentially removes the last checkpoint.
+        # Potentially removes the last checkpoint
        if last_ckpt_path.exists() and self.config.only_save_most_recent:
             if (base_ckpt := last_ckpt_path.resolve()).is_file():
                 base_ckpt.unlink()
 
-        # Combines all temporary files into a single checkpoint TAR file.
+        # Save the checkpoint components
         with tarfile.open(ckpt_path, "w:gz") as tar:
 
-            def add_file(name: str, write_fn: Callable[[io.BytesIO], Any]) -> None:
+            def add_file(name: str, buf: io.BytesIO) -> None:
+                tarinfo = tarfile.TarInfo(name)
+                tarinfo.size = buf.tell()
+                buf.seek(0)
+                tar.addfile(tarinfo, buf)
+
+            # Save model using Equinox
+            if model is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, model)
+                    add_file("model", buf)
+
+            # Save optimizer using Equinox
+            if optimizer is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, optimizer)
+                    add_file("optimizer", buf)
+
+            # Save optimizer state using Equinox
+            if opt_state is not None:
                 with io.BytesIO() as buf:
-                    write_fn(buf)
-                    tarinfo = tarfile.TarInfo(name)
-                    tarinfo.size = buf.tell()
-                    buf.seek(0)
-                    tar.addfile(tarinfo, buf)
-
-            add_file("model", lambda buf: cloudpickle.dump(model, buf))
-            add_file("opt", lambda buf: cloudpickle.dump(optimizer, buf))
-            add_file("opt_state", lambda buf: cloudpickle.dump(opt_state, buf))
-            add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
-            add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
-
-        # Updates the symlink to the new checkpoint.
+                    eqx.tree_serialise_leaves(buf, opt_state)
+                    add_file("opt_state", buf)
+
+            # Save aux data using Equinox.
+            if aux_data is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, aux_data)
+                    add_file("aux_data", buf)
+
+            # Save state and config as JSON
+            def add_file_bytes(name: str, data: bytes) -> None:  # noqa: ANN401
+                info = tarfile.TarInfo(name=name)
+                info.size = len(data)
+                tar.addfile(info, io.BytesIO(data))
+
+            if state is not None:
+                add_file_bytes("state", json.dumps(asdict(state), indent=2).encode())
+            add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
+
+        # Updates the symlink to the new checkpoint
         last_ckpt_path.unlink(missing_ok=True)
         try:
             last_ckpt_path.symlink_to(ckpt_path.relative_to(last_ckpt_path.parent))
         except FileExistsError:
             logger.exception("Exception while trying to update %s", ckpt_path)
 
-        # Calls the base callback.
+        # Calls the base callback
         self.on_after_checkpoint_save(ckpt_path, state)
 
         return ckpt_path
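Since weights are now serialized with `eqx.tree_serialise_leaves` rather than pickled via cloudpickle, loading requires a template PyTree with the right structure. A sketch of the new flow (`task`, `make_model`, and `ckpt_path` are hypothetical names), following the same `filter_eval_shape` pattern the new `load_ckpt` wrapper uses internally:

```python
import equinox as eqx
import jax

# Build a shape-only template without allocating real weights, then fill it
# from the checkpoint archive.
template = eqx.filter_eval_shape(make_model, jax.random.PRNGKey(0))
model = task.load_ckpt_with_template(ckpt_path, part="model", model_template=template)

# Saving is now keyword-only and per-component; omitted parts are simply skipped.
task.save_checkpoint(model=model, state=state)
```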
xax/task/mixins/data_loader.py CHANGED
@@ -9,6 +9,7 @@ import jax
 from dpshdl.dataloader import CollatedDataloaderItem, Dataloader
 from dpshdl.dataset import Dataset, ErrorHandlingDataset
 from dpshdl.prefetcher import Prefetcher
+from jaxtyping import PRNGKeyArray
 from omegaconf import II, MISSING
 
 from xax.core.conf import field, is_missing
@@ -103,7 +104,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
             "or `get_data_iterator` to return an iterator for the given dataset."
         )
 
-    def get_data_iterator(self, phase: Phase) -> Iterator:
+    def get_data_iterator(self, phase: Phase, key: PRNGKeyArray) -> Iterator:
         raise NotImplementedError(
             "You must implement either the `get_dataset` method to return the dataset for the given phase, "
             "or `get_data_iterator` to return an iterator for the given dataset."
xax/task/mixins/train.py CHANGED
@@ -11,7 +11,8 @@ import textwrap
 import time
 import traceback
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, is_dataclass
+from dataclasses import asdict, dataclass, is_dataclass
+from pathlib import Path
 from threading import Thread
 from typing import (
     Any,
@@ -33,14 +34,13 @@ import jax.numpy as jnp
 import numpy as np
 import optax
 from jaxtyping import Array, PRNGKeyArray, PyTree
-from omegaconf import DictConfig
 
 from xax.core.conf import field
 from xax.core.state import Phase, State
 from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
-from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
+from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
@@ -50,11 +50,12 @@ from xax.utils.experiments import (
     TrainingFinishedError,
     diff_configs,
     get_diff_string,
+    get_info_json,
     get_state_file_string,
     get_training_code,
 )
 from xax.utils.jax import jit as xax_jit
-from xax.utils.logging import LOG_STATUS
+from xax.utils.logging import LOG_PING, LOG_STATUS
 from xax.utils.text import highlight_exception_message, show_info
 from xax.utils.types.frozen_dict import FrozenDict
 
@@ -340,20 +341,19 @@ class TrainMixin(
 
         if init_ckpt_path is not None:
             logger.info("Loading checkpoint from %s", init_ckpt_path)
-            if load_optimizer:
-                model, optimizer, opt_state, state, config = self.load_checkpoint(init_ckpt_path)
-                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
-                if config_diff:
-                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
-                return model, optimizer, opt_state, state
+            model, state, config = self.load_ckpt(init_ckpt_path, part="model_state_config")
+            config_diff = get_diff_string(diff_configs(asdict(config), asdict(self.config)))
+            if config_diff:
+                logger.warning("Loaded config differs from current config:\n%s", config_diff)
 
-            else:
-                model, state, config = self.load_checkpoint(init_ckpt_path, "model_state_config")
-                config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
-                if config_diff:
-                    logger.warning("Loaded config differs from current config:\n%s", config_diff)
+            if not load_optimizer:
                 return model, state
 
+            optimizer = self.load_ckpt(init_ckpt_path, part="opt")
+            opt_state = self.load_ckpt(init_ckpt_path, part="opt_state", model=model, optimizer=optimizer)
+            return model, optimizer, opt_state, state
+
+        logger.info("Starting a new training run")
         model = self.get_model(key)
         state = State.init_state()
@@ -365,6 +365,131 @@
 
         return model, optimizer, opt_state, state
 
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["all"],
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["model_state_config"],
+    ) -> tuple[PyTree, State, Config]: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["model"],
+    ) -> PyTree: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["opt"],
+    ) -> optax.GradientTransformation: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["opt_state"],
+        model: PyTree | None = None,
+        optimizer: optax.GradientTransformation | None = None,
+    ) -> optax.OptState: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["state"],
+    ) -> State: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["config"],
+    ) -> Config: ...
+
+    def load_ckpt(
+        self,
+        path: str | Path,
+        *,
+        part: CheckpointPart = "all",
+        model: PyTree | None = None,
+        optimizer: optax.GradientTransformation | None = None,
+    ) -> (
+        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
+        | tuple[PyTree, State, Config]
+        | PyTree
+        | optax.GradientTransformation
+        | optax.OptState
+        | State
+        | Config
+    ):
+        path = Path(path)
+
+        # This key isn't used for anything, it's just a required argument.
+        key = jax.random.PRNGKey(0)
+
+        match part:
+            case "model_state_config":
+                model_spec = eqx.filter_eval_shape(self.get_model, key)
+                return self.load_ckpt_with_template(path, part="model_state_config", model_template=model_spec)
+
+            case "model":
+                model_spec = eqx.filter_eval_shape(self.get_model, key)
+                return self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+
+            case "config":
+                return self.load_ckpt_with_template(path, part="config")
+
+            case "opt":
+                optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
+                return self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+
+            case "opt_state":
+                if model is None:
+                    model_spec = eqx.filter_eval_shape(self.get_model, key)
+                    model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+                if optimizer is None:
+                    optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
+                    optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
+                return self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
+
+            case "state":
+                return self.load_ckpt_with_template(path, part="state")
+
+            case "config":
+                return self.load_ckpt_with_template(path, part="config")
+
+            case "all":
+                model_spec = eqx.filter_eval_shape(self.get_model, key)
+                model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+                optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
+                optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
+                opt_state = self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
+                state = self.load_ckpt_with_template(path, part="state")
+                config = self.load_ckpt_with_template(path, part="config")
+                return model, optimizer, opt_state, state, config
+
+            case _:
+                raise ValueError(f"Unknown checkpoint part: {part}")
+
     def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
         """Gets the output from the model.
 
@@ -519,8 +644,7 @@
             self._last_printed_remaining_time = state.elapsed_time_s
             remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
             termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
-            # logger.info("Estimated finish time: %s", termination_time)
-            jax.debug.print("Estimated finish time: {}", termination_time)
+            logger.log(LOG_PING, "Estimated finish time: %s", termination_time)
 
     def get_remaining_percent(self, state: State) -> float | None:
         if self.config.max_steps is None:
@@ -554,6 +678,7 @@
         self.logger.log_file("state.txt", get_state_file_string(self))
         self.logger.log_file("training_code.py", get_training_code(self))
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
+        self.logger.log_file("info.json", get_info_json())
 
     def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
         return eqx.is_inexact_array(item)
@@ -627,16 +752,16 @@
 
             if self.should_checkpoint(state):
                 model = eqx.combine(model_arr, model_static)
-                self.save_checkpoint(model, optimizer, opt_state, state)
+                self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
         # After finishing training, save the final checkpoint.
         model = eqx.combine(model_arr, model_static)
-        self.save_checkpoint(model, optimizer, opt_state, state)
+        self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
     @contextlib.contextmanager
-    def get_train_iterator(self) -> Generator[Iterator[Batch], None, None]:
+    def get_train_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
         try:
-            train_iterator: Iterator[Batch] = self.get_data_iterator("train")
+            train_iterator: Iterator[Batch] = self.get_data_iterator("train", key=key)
             yield train_iterator
             return
         except NotImplementedError:
@@ -653,9 +778,9 @@
             logger.info("Closing train prefetcher")
 
     @contextlib.contextmanager
-    def get_valid_iterator(self) -> Generator[Iterator[Batch], None, None]:
+    def get_valid_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
         try:
-            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid")
+            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid", key=key)
             yield valid_iterator
             return
         except NotImplementedError:
@@ -699,12 +824,13 @@
         state = self.on_training_start(state)
 
         def on_exit() -> None:
-            self.save_checkpoint(model, optimizer, opt_state, state)
+            self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
         # Handle user-defined interrupts during the training loop.
         self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
 
-        with self.get_train_iterator() as train_pf, self.get_valid_iterator() as valid_pf:
+        key, tkey, vkey = jax.random.split(key, 3)
+        with self.get_train_iterator(tkey) as train_pf, self.get_valid_iterator(vkey) as valid_pf:
             try:
                 self.train_loop(
                     model=model,
@@ -721,7 +847,7 @@
                     f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
                     important=True,
                 )
-                self.save_checkpoint(model, optimizer, opt_state, state)
+                self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
             except (KeyboardInterrupt, bdb.BdbQuit):
                 if is_master():
@@ -731,7 +857,7 @@
                 exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
                 sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
                 sys.stdout.flush()
-                self.save_checkpoint(model, optimizer, opt_state, state)
+                self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
 
             finally:
                 state = self.on_training_end(state)
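The task-level `load_ckpt` wrapper builds the Equinox templates internally from `get_model` / `get_optimizer`, so callers never pass templates themselves. A usage sketch (the task instance and path are hypothetical):

```python
# Load individual parts, or everything at once.
model = task.load_ckpt("exp_dir/checkpoints/ckpt.bin", part="model")
model, optimizer, opt_state, state, config = task.load_ckpt(
    "exp_dir/checkpoints/ckpt.bin", part="all"
)
```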
xax/utils/experiments.py CHANGED
@@ -7,6 +7,7 @@ import functools
 import hashlib
 import inspect
 import itertools
+import json
 import logging
 import math
 import os
@@ -24,7 +25,7 @@ import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
 from types import TracebackType
-from typing import Any, Iterator, Self, TypeVar, cast
+from typing import Any, Iterator, Mapping, Self, Sequence, TypeVar, cast
 from urllib.parse import urlparse
 
 import git
@@ -116,9 +117,7 @@ class StateTimer:
 
     def log_dict(self) -> dict[str, int | float | tuple[int | float, bool]]:
         return {
-            "steps": (self.step_timer.steps, True),
             "steps/second": self.step_timer.steps_per_second,
-            "samples": (self.sample_timer.steps, True),
             "samples/second": (self.sample_timer.steps_per_second, True),
             "dt": self.iter_timer.iter_seconds,
         }
@@ -204,8 +203,8 @@
 
 
 def diff_configs(
-    first: ListConfig | DictConfig,
-    second: ListConfig | DictConfig,
+    first: Mapping | Sequence,
+    second: Mapping | Sequence,
     prefix: str | None = None,
 ) -> tuple[list[str], list[str]]:
     """Returns the difference between two configs.
@@ -232,7 +231,7 @@ def diff_configs(
 
     any_config = (ListConfig, DictConfig)
 
-    if isinstance(first, DictConfig) and isinstance(second, DictConfig):
+    if isinstance(first, Mapping) and isinstance(second, Mapping):
         first_keys, second_keys = cast(set[str], set(first.keys())), cast(set[str], set(second.keys()))
 
         # Gets the new keys in each config.
@@ -242,11 +241,12 @@
         # Gets the new sub-keys in each config.
         for key in first_keys.intersection(second_keys):
             sub_prefix = key if prefix is None else f"{prefix}.{key}"
-            if OmegaConf.is_missing(first, key) or OmegaConf.is_missing(second, key):
-                if not OmegaConf.is_missing(first, key):
-                    new_first += [get_diff_string(sub_prefix, first[key])]
-                if not OmegaConf.is_missing(second, key):
-                    new_second += [get_diff_string(sub_prefix, second[key])]
+            if isinstance(first, DictConfig) and isinstance(second, DictConfig):
+                if OmegaConf.is_missing(first, key) or OmegaConf.is_missing(second, key):
+                    if not OmegaConf.is_missing(first, key):
+                        new_first += [get_diff_string(sub_prefix, first[key])]
+                    if not OmegaConf.is_missing(second, key):
+                        new_second += [get_diff_string(sub_prefix, second[key])]
             elif isinstance(first[key], any_config) and isinstance(second[key], any_config):
                 sub_new_first, sub_new_second = diff_configs(first[key], second[key], prefix=sub_prefix)
                 new_first, new_second = new_first + sub_new_first, new_second + sub_new_second
@@ -255,7 +255,7 @@
                 new_first += [get_diff_string(sub_prefix, first_val)]
                 new_second += [get_diff_string(sub_prefix, second_val)]
 
-    elif isinstance(first, ListConfig) and isinstance(second, ListConfig):
+    elif isinstance(first, Sequence) and isinstance(second, Sequence):
         if len(first) > len(second):
             for i in range(len(second), len(first)):
                 new_first += [get_diff_string(prefix, first[i])]
@@ -470,16 +470,33 @@ def get_command_line_string() -> str:
     return " ".join(sys.argv)
 
 
+def get_environment_variables() -> str:
+    return "\n".join([f"{key}={value}" for key, value in sorted(os.environ.items())])
+
+
 def get_state_file_string(obj: object) -> str:
     return "\n\n".join(
         [
             f"=== Command Line ===\n\n{get_command_line_string()}",
             f"=== Git State ===\n\n{get_git_state(obj)}",
             f"=== Packages ===\n\n{get_packages_with_versions()}",
+            f"=== Environment Variables ===\n\n{get_environment_variables()}",
         ]
     )
 
 
+def get_info_json() -> str:
+    return json.dumps(
+        {
+            "process_id": os.getpid(),
+            "job": {
+                "start_time": datetime.datetime.now().isoformat(),
+            },
+        },
+        indent=2,
+    )
+
+
 def get_training_code(obj: object) -> str:
     """Gets the text from the file containing the provided object.
 
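With the signature relaxed from OmegaConf containers to `Mapping | Sequence`, `diff_configs` now also accepts plain dicts, which is what `train.py` passes via `asdict(config)` above. A sketch:

```python
from xax.utils.experiments import diff_configs

only_in_first, only_in_second = diff_configs({"lr": 1e-3, "bs": 32}, {"lr": 3e-4, "bs": 32})
print(only_in_first)   # diff strings for values that changed, as seen from the first config
print(only_in_second)  # ...and as seen from the second; equal values ("bs") are skipped
```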
xax/utils/tensorboard.py CHANGED
@@ -2,11 +2,12 @@
 
 import functools
 import io
+import json
 import os
 import tempfile
 import time
 from pathlib import Path
-from typing import Literal, TypedDict
+from typing import Any, Literal, TypedDict
 
 import numpy as np
 import PIL.Image
@@ -14,9 +15,15 @@ from PIL.Image import Image as PILImage
 from tensorboard.compat.proto.config_pb2 import RunMetadata
 from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
 from tensorboard.compat.proto.graph_pb2 import GraphDef
-from tensorboard.compat.proto.summary_pb2 import HistogramProto, Summary, SummaryMetadata
+from tensorboard.compat.proto.summary_pb2 import (
+    HistogramProto,
+    Summary,
+    SummaryMetadata,
+)
 from tensorboard.compat.proto.tensor_pb2 import TensorProto
 from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
+from tensorboard.plugins.mesh import metadata as mesh_metadata
+from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData
 from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
 from tensorboard.summary.writer.event_file_writer import EventFileWriter
 
@@ -84,6 +91,68 @@ def make_histogram(values: np.ndarray, bins: str | np.ndarray, max_bins: int | N
     )
 
 
+def _get_json_config(config_dict: dict[str, Any] | None) -> str:
+    json_config = "{}"
+    if config_dict is not None:
+        json_config = json.dumps(config_dict, sort_keys=True)
+    return json_config
+
+
+def make_mesh_summary(
+    tag: str,
+    vertices: np.ndarray,
+    colors: np.ndarray | None,
+    faces: np.ndarray | None,
+    config_dict: dict[str, Any] | None,
+    display_name: str | None = None,
+    description: str | None = None,
+) -> Summary:
+    json_config = _get_json_config(config_dict)
+
+    summaries = []
+    tensors = [
+        (vertices, MeshPluginData.VERTEX),
+        (faces, MeshPluginData.FACE),
+        (colors, MeshPluginData.COLOR),
+    ]
+    # Filter out None tensors and explicitly type the list
+    valid_tensors = [(t, content_type) for t, content_type in tensors if t is not None]
+    components = mesh_metadata.get_components_bitmask([content_type for (_, content_type) in valid_tensors])
+
+    for tensor, content_type in valid_tensors:  # Now we know tensor is not None
+        tensor_metadata = mesh_metadata.create_summary_metadata(
+            tag,
+            display_name,
+            content_type,
+            components,
+            tensor.shape,  # Safe now since tensor is not None
+            description,
+            json_config=json_config,
+        )
+
+        tensor_proto = TensorProto(
+            dtype="DT_FLOAT",
+            float_val=tensor.reshape(-1).tolist(),  # Safe now since tensor is not None
+            tensor_shape=TensorShapeProto(
+                dim=[
+                    TensorShapeProto.Dim(size=tensor.shape[0]),  # Safe now since tensor is not None
+                    TensorShapeProto.Dim(size=tensor.shape[1]),
+                    TensorShapeProto.Dim(size=tensor.shape[2]),
+                ]
+            ),
+        )
+
+        tensor_summary = Summary.Value(
+            tag=mesh_metadata.get_instance_name(tag, content_type),
+            tensor=tensor_proto,
+            metadata=tensor_metadata,
+        )
+
+        summaries.append(tensor_summary)
+
+    return Summary(value=summaries)
+
+
 class TensorboardProtobufWriter:
     def __init__(
         self,
@@ -454,6 +523,9 @@ class TensorboardWriter:
         weighted_sum = float((bin_centers * bucket_counts).sum())
         weighted_sum_squares = float((bin_centers**2 * bucket_counts).sum())
 
+        # Convert bin edges to list of floats explicitly
+        bucket_limits: list[float | np.ndarray] = [float(x) for x in bin_edges[1:]]
+
         self.add_histogram_raw(
             tag=tag,
             min=float(bin_edges[0]),
@@ -461,12 +533,28 @@
             num=int(total_counts),
             sum=weighted_sum,
             sum_squares=weighted_sum_squares,
-            bucket_limits=bin_edges[1:].tolist(),  # TensorBoard expects right bin edges
+            bucket_limits=bucket_limits,  # Now properly typed
            bucket_counts=bucket_counts.tolist(),
             global_step=global_step,
             walltime=walltime,
         )
 
+    def add_mesh(
+        self,
+        tag: str,
+        vertices: np.ndarray,
+        colors: np.ndarray | None,
+        faces: np.ndarray | None,
+        config_dict: dict[str, Any] | None,
+        global_step: int | None = None,
+        walltime: float | None = None,
+    ) -> None:
+        self.pb_writer.add_summary(
+            make_mesh_summary(tag, vertices, colors, faces, config_dict),
+            global_step=global_step,
+            walltime=walltime,
+        )
+
 
 class TensorboardWriterKwargs(TypedDict):
     max_queue_size: int
xax-0.1.16.dist-info/METADATA → xax-0.2.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.1.16
+Version: 0.2.1
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
@@ -8,14 +8,14 @@ Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: attrs
+Requires-Dist: chex
+Requires-Dist: dpshdl
+Requires-Dist: equinox
+Requires-Dist: importlib-resources
 Requires-Dist: jax
 Requires-Dist: jaxtyping
-Requires-Dist: equinox
 Requires-Dist: optax
-Requires-Dist: dpshdl
-Requires-Dist: chex
-Requires-Dist: importlib-resources
-Requires-Dist: cloudpickle
+Requires-Dist: orbax-checkpoint
 Requires-Dist: pillow
 Requires-Dist: omegaconf
 Requires-Dist: gitpython
xax-0.1.16.dist-info/RECORD → xax-0.2.1.dist-info/RECORD CHANGED
@@ -1,23 +1,23 @@
-xax/__init__.py,sha256=pSWV5RtPBJynHr7dCqscbnMkETZPUyw8D6MHK4CuS90,14104
+xax/__init__.py,sha256=kd-88OQGnuHb91PXwroAfLb0bMfbe37fXqpECRrjhoU,14182
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
-xax/requirements.txt,sha256=9LAEZ5c5gqRSARRVA6xJsVTa4MebPZuC4yOkkwkZJFw,297
+xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
 xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
-xax/core/state.py,sha256=WwW0qDm-be9MMOT-bGWEFvaWF4iq2FP9xRSn1zq_4A8,2507
+xax/core/state.py,sha256=XejW1tGINYFFcNrscK8eZQsq02J7_RXa461QpmyWuLk,3337
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
 xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
 xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
 xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
-xax/nn/geom.py,sha256=PN0Ndn575aVtsSfxi67RghHB7luRkqtpS7bPbT1LpLE,5201
+xax/nn/geom.py,sha256=rImNlkHWeoNcY7f84nknizJ6uzsrMhbAtKeb2xAWxNY,6215
 xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
 xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
 xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-xax/task/base.py,sha256=DqgGIlo5kEWpYix3DdPCEkCgVLUOocjyFr8okaSUq-k,7680
-xax/task/logger.py,sha256=Upx7cCZvaVIs75CHTfIzYmsuaFRsGu0FvziTZuazT4k,37083
+xax/task/base.py,sha256=OnXi2hiKPGwt6ng1dutnoQSiw7lEiWFlC_vx99_JsbQ,7694
+xax/task/logger.py,sha256=peGtfnvnBKr9l6tx1V6XAsvPs0HP6ubV_aE7IJtOMNk,40868
 xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
 xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
 xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -29,37 +29,37 @@ xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,
 xax/task/loggers/json.py,sha256=_tKum6jk_gqVzO-4MqSNXbE-Mmn-yJzkRAT-N1y2zes,4139
 xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
 xax/task/loggers/stdout.py,sha256=oeIgPkj4RyJgBuWaJK9ncLa65iBNJCWXhSF8fx3_54c,6564
-xax/task/loggers/tensorboard.py,sha256=HjR-wiCWe0z3nivRzxEZIltzSzka1828bwxWVmMU5Sk,7718
+xax/task/loggers/tensorboard.py,sha256=KOL9l60tLctX-VAdNwe49H48SAJeGxph3sflJpojA-4,8337
 xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
 xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
-xax/task/mixins/checkpointing.py,sha256=nRddgtasagf0oTZE9LE5IN5JY7jy4BD_M0rlqYp4sCM,8554
+xax/task/mixins/checkpointing.py,sha256=2nJgqFcV-D8W-4j8TR3PvVh1g5hQUOo-_quKO-XlE4U,11398
 xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
 xax/task/mixins/cpu_stats.py,sha256=vAjEc3HpPnl56m7vshYX0dXAHJrB98DzVdsYSRqQllc,9371
-xax/task/mixins/data_loader.py,sha256=WjMWk9uACfBMMClLMcLPkE0WNIvlCZnmqyyqLqJpjX0,6545
+xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
 xax/task/mixins/gpu_stats.py,sha256=4HU6teEDlqMitLbSx7fbyL4qBJ0PgGy0Ly_Pzife8yo,8795
 xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
 xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=4Xr8b5LFueFh-f3k8MIJMv3M46_Aaf65YwCbjtSBQ-U,26393
+xax/task/mixins/train.py,sha256=v9oi9tNsNBYo-Ne_98nCG9qHX6sxvymHjsRDnL6GL-U,30871
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
-xax/utils/experiments.py,sha256=vm_hWfaty_wEHVdoU2ALiBiGJze7IoDJIfXi6pd_a9I,29360
+xax/utils/experiments.py,sha256=Hzl46_9IH5_9cKzxit-FyVUWBH-_lBs00ZciuIdnWO8,29811
 xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
 xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
 xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
 xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
 xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
 xax/utils/pytree.py,sha256=VFWhT0MQ99KjQyEYM6NFbqYq4_hOZwB23uhowMB4U34,8754
-xax/utils/tensorboard.py,sha256=21czW8WC2SAmwEhz6RLJc_q5HFvNKM4iR1ZycSO5qPE,17058
+xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
 xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
 xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.1.16.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
-xax-0.1.16.dist-info/METADATA,sha256=gfh7iFi7Wz3fJDf2w1KKs8H0uanhn2HFsR67TvP6uZM,1878
-xax-0.1.16.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
-xax-0.1.16.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
-xax-0.1.16.dist-info/RECORD,,
+xax-0.2.1.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.2.1.dist-info/METADATA,sha256=2pOZLKMIcLoQTM-tRqRvVkF57PZyMoALM87UI5B4dtk,1882
+xax-0.2.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.2.1.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.2.1.dist-info/RECORD,,