xax 0.1.16__tar.gz → 0.2.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.1.16/xax.egg-info → xax-0.2.1}/PKG-INFO +6 -6
- {xax-0.1.16 → xax-0.2.1}/pyproject.toml +0 -1
- {xax-0.1.16 → xax-0.2.1}/xax/__init__.py +4 -1
- {xax-0.1.16 → xax-0.2.1}/xax/core/state.py +26 -1
- {xax-0.1.16 → xax-0.2.1}/xax/nn/geom.py +34 -0
- {xax-0.1.16 → xax-0.2.1}/xax/requirements.txt +5 -5
- {xax-0.1.16 → xax-0.2.1}/xax/task/base.py +1 -1
- {xax-0.1.16 → xax-0.2.1}/xax/task/logger.py +107 -2
- {xax-0.1.16 → xax-0.2.1}/xax/task/loggers/tensorboard.py +16 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/checkpointing.py +124 -50
- {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/data_loader.py +2 -1
- {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/train.py +153 -27
- {xax-0.1.16 → xax-0.2.1}/xax/utils/experiments.py +29 -12
- {xax-0.1.16 → xax-0.2.1}/xax/utils/tensorboard.py +91 -3
- {xax-0.1.16 → xax-0.2.1/xax.egg-info}/PKG-INFO +6 -6
- {xax-0.1.16 → xax-0.2.1}/xax.egg-info/requires.txt +5 -5
- {xax-0.1.16 → xax-0.2.1}/LICENSE +0 -0
- {xax-0.1.16 → xax-0.2.1}/MANIFEST.in +0 -0
- {xax-0.1.16 → xax-0.2.1}/README.md +0 -0
- {xax-0.1.16 → xax-0.2.1}/setup.cfg +0 -0
- {xax-0.1.16 → xax-0.2.1}/setup.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/core/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/core/conf.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/nn/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/nn/embeddings.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/nn/equinox.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/nn/export.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/nn/functions.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/nn/losses.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/nn/norm.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/nn/parallel.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/nn/ssm.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/py.typed +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/requirements-dev.txt +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/launchers/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/launchers/base.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/launchers/cli.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/launchers/single_process.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/loggers/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/loggers/callback.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/loggers/json.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/loggers/state.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/loggers/stdout.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/compile.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/logger.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/process.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/runnable.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/script.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/task/task.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/utils/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/utils/data/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/utils/data/collate.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/utils/debugging.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/utils/jax.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/utils/jaxpr.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/utils/logging.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/utils/numpy.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/utils/profile.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/utils/pytree.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/utils/text.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/utils/types/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.1.16 → xax-0.2.1}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: xax
|
3
|
-
Version: 0.1.16
|
3
|
+
Version: 0.2.1
|
4
4
|
Summary: A library for fast Jax experimentation
|
5
5
|
Home-page: https://github.com/kscalelabs/xax
|
6
6
|
Author: Benjamin Bolte
|
@@ -8,14 +8,14 @@ Requires-Python: >=3.11
|
|
8
8
|
Description-Content-Type: text/markdown
|
9
9
|
License-File: LICENSE
|
10
10
|
Requires-Dist: attrs
|
11
|
+
Requires-Dist: chex
|
12
|
+
Requires-Dist: dpshdl
|
13
|
+
Requires-Dist: equinox
|
14
|
+
Requires-Dist: importlib-resources
|
11
15
|
Requires-Dist: jax
|
12
16
|
Requires-Dist: jaxtyping
|
13
|
-
Requires-Dist: equinox
|
14
17
|
Requires-Dist: optax
|
15
|
-
Requires-Dist:
|
16
|
-
Requires-Dist: chex
|
17
|
-
Requires-Dist: importlib-resources
|
18
|
-
Requires-Dist: cloudpickle
|
18
|
+
Requires-Dist: orbax-checkpoint
|
19
19
|
Requires-Dist: pillow
|
20
20
|
Requires-Dist: omegaconf
|
21
21
|
Requires-Dist: gitpython
|
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.1.16"
|
15
|
+
__version__ = "0.2.1"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -44,6 +44,7 @@ __all__ = [
|
|
44
44
|
"euler_to_quat",
|
45
45
|
"get_projected_gravity_vector_from_quat",
|
46
46
|
"quat_to_euler",
|
47
|
+
"quat_to_rotmat",
|
47
48
|
"rotate_vector_by_quat",
|
48
49
|
"cross_entropy",
|
49
50
|
"cast_norm_type",
|
@@ -206,6 +207,7 @@ NAME_MAP: dict[str, str] = {
|
|
206
207
|
"euler_to_quat": "nn.geom",
|
207
208
|
"get_projected_gravity_vector_from_quat": "nn.geom",
|
208
209
|
"quat_to_euler": "nn.geom",
|
210
|
+
"quat_to_rotmat": "nn.geom",
|
209
211
|
"rotate_vector_by_quat": "nn.geom",
|
210
212
|
"cross_entropy": "nn.losses",
|
211
213
|
"cast_norm_type": "nn.norm",
|
@@ -369,6 +371,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
369
371
|
euler_to_quat,
|
370
372
|
get_projected_gravity_vector_from_quat,
|
371
373
|
quat_to_euler,
|
374
|
+
quat_to_rotmat,
|
372
375
|
rotate_vector_by_quat,
|
373
376
|
)
|
374
377
|
from xax.nn.losses import cross_entropy
|
@@ -12,6 +12,14 @@ from xax.core.conf import field
|
|
12
12
|
Phase = Literal["train", "valid"]
|
13
13
|
|
14
14
|
|
15
|
+
def _phase_to_int(phase: Phase) -> int:
|
16
|
+
return {"train": 0, "valid": 1}[phase]
|
17
|
+
|
18
|
+
|
19
|
+
def _int_to_phase(i: int) -> Phase:
|
20
|
+
return cast(Phase, ["train", "valid"][i])
|
21
|
+
|
22
|
+
|
15
23
|
class StateDict(TypedDict, total=False):
|
16
24
|
num_steps: NotRequired[int]
|
17
25
|
num_samples: NotRequired[int]
|
@@ -35,7 +43,7 @@ class State:
|
|
35
43
|
|
36
44
|
@property
|
37
45
|
def phase(self) -> Phase:
|
38
|
-
return
|
46
|
+
return _int_to_phase(self._phase)
|
39
47
|
|
40
48
|
@classmethod
|
41
49
|
def init_state(cls) -> "State":
|
@@ -74,3 +82,20 @@ class State:
|
|
74
82
|
case _:
|
75
83
|
raise ValueError(f"Invalid phase: {phase}")
|
76
84
|
return State(**{**asdict(self), **kwargs, **extra_kwargs})
|
85
|
+
|
86
|
+
def to_dict(self) -> dict[str, int | float | str]:
|
87
|
+
return {
|
88
|
+
"num_steps": int(self.num_steps),
|
89
|
+
"num_samples": int(self.num_samples),
|
90
|
+
"num_valid_steps": int(self.num_valid_steps),
|
91
|
+
"num_valid_samples": int(self.num_valid_samples),
|
92
|
+
"start_time_s": float(self.start_time_s),
|
93
|
+
"elapsed_time_s": float(self.elapsed_time_s),
|
94
|
+
"phase": str(self.phase),
|
95
|
+
}
|
96
|
+
|
97
|
+
@classmethod
|
98
|
+
def from_dict(cls, d: dict[str, int | float | str]) -> "State":
|
99
|
+
if "phase" in d:
|
100
|
+
d["_phase"] = _phase_to_int(cast(Phase, d.pop("phase")))
|
101
|
+
return cls(**d) # type: ignore[arg-type]
|
@@ -177,3 +177,37 @@ def cubic_bezier_interpolation(y_start: Array, y_end: Array, x: Array) -> Array:
|
|
177
177
|
y_diff = y_end - y_start
|
178
178
|
bezier = x**3 + 3 * (x**2 * (1 - x))
|
179
179
|
return y_start + y_diff * bezier
|
180
|
+
|
181
|
+
|
182
|
+
def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
|
183
|
+
"""Converts a quaternion to a rotation matrix.
|
184
|
+
|
185
|
+
Args:
|
186
|
+
quat: The quaternion to convert, shape (*, 4).
|
187
|
+
eps: A small epsilon value to avoid division by zero.
|
188
|
+
|
189
|
+
Returns:
|
190
|
+
The rotation matrix, shape (*, 3, 3).
|
191
|
+
"""
|
192
|
+
quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
|
193
|
+
w, x, y, z = jnp.split(quat, 4, axis=-1)
|
194
|
+
|
195
|
+
xx = 1 - 2 * (y * y + z * z)
|
196
|
+
xy = 2 * (x * y - z * w)
|
197
|
+
xz = 2 * (x * z + y * w)
|
198
|
+
yx = 2 * (x * y + z * w)
|
199
|
+
yy = 1 - 2 * (x * x + z * z)
|
200
|
+
yz = 2 * (y * z - x * w)
|
201
|
+
zx = 2 * (x * z - y * w)
|
202
|
+
zy = 2 * (y * z + x * w)
|
203
|
+
zz = 1 - 2 * (x * x + y * y)
|
204
|
+
|
205
|
+
# Corrected stacking: row-major order
|
206
|
+
return jnp.concatenate(
|
207
|
+
[
|
208
|
+
jnp.concatenate([xx, xy, xz], axis=-1)[..., None, :],
|
209
|
+
jnp.concatenate([yx, yy, yz], axis=-1)[..., None, :],
|
210
|
+
jnp.concatenate([zx, zy, zz], axis=-1)[..., None, :],
|
211
|
+
],
|
212
|
+
axis=-2,
|
213
|
+
)
|
@@ -2,16 +2,16 @@
|
|
2
2
|
|
3
3
|
# Core ML/JAX dependencies
|
4
4
|
attrs
|
5
|
+
chex
|
6
|
+
dpshdl
|
7
|
+
equinox
|
8
|
+
importlib-resources
|
5
9
|
jax
|
6
10
|
jaxtyping
|
7
|
-
equinox
|
8
11
|
optax
|
9
|
-
|
10
|
-
chex
|
11
|
-
importlib-resources
|
12
|
+
orbax-checkpoint
|
12
13
|
|
13
14
|
# Data processing and serialization
|
14
|
-
cloudpickle
|
15
15
|
pillow
|
16
16
|
|
17
17
|
# Configuration and project management
|
@@ -79,7 +79,7 @@ class BaseTask(Generic[Config]):
|
|
79
79
|
def on_training_end(self, state: State) -> State:
|
80
80
|
return state
|
81
81
|
|
82
|
-
def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
|
82
|
+
def on_after_checkpoint_save(self, ckpt_path: Path, state: State | None) -> State | None:
|
83
83
|
return state
|
84
84
|
|
85
85
|
@functools.cached_property
|
@@ -18,11 +18,22 @@ from abc import ABC, abstractmethod
|
|
18
18
|
from collections import defaultdict
|
19
19
|
from dataclasses import dataclass
|
20
20
|
from types import TracebackType
|
21
|
-
from typing import
|
21
|
+
from typing import (
|
22
|
+
Any,
|
23
|
+
Callable,
|
24
|
+
Iterator,
|
25
|
+
Literal,
|
26
|
+
Self,
|
27
|
+
Sequence,
|
28
|
+
TypeVar,
|
29
|
+
cast,
|
30
|
+
get_args,
|
31
|
+
)
|
22
32
|
|
23
33
|
import jax
|
24
34
|
import jax.numpy as jnp
|
25
35
|
import numpy as np
|
36
|
+
from jax._src.core import ClosedJaxpr
|
26
37
|
from jaxtyping import Array
|
27
38
|
from PIL import Image, ImageDraw, ImageFont
|
28
39
|
from PIL.Image import Image as PILImage
|
@@ -194,7 +205,10 @@ def tile_images(images: list[PILImage], sep: int = 0) -> PILImage:
|
|
194
205
|
return tiled
|
195
206
|
|
196
207
|
|
197
|
-
def as_numpy(array: Array) -> np.ndarray:
|
208
|
+
def as_numpy(array: Array | np.ndarray) -> np.ndarray:
|
209
|
+
"""Convert a JAX array or numpy array to numpy array."""
|
210
|
+
if isinstance(array, np.ndarray):
|
211
|
+
return array
|
198
212
|
array = jax.device_get(array)
|
199
213
|
if jax.dtypes.issubdtype(array.dtype, jnp.floating):
|
200
214
|
array = array.astype(jnp.float32)
|
@@ -205,6 +219,13 @@ def as_numpy(array: Array) -> np.ndarray:
|
|
205
219
|
return np.array(array)
|
206
220
|
|
207
221
|
|
222
|
+
def as_numpy_opt(array: Array | np.ndarray | None) -> np.ndarray | None:
|
223
|
+
"""Convert an optional JAX array or numpy array to numpy array."""
|
224
|
+
if array is None:
|
225
|
+
return None
|
226
|
+
return as_numpy(array)
|
227
|
+
|
228
|
+
|
208
229
|
@dataclass(kw_only=True)
|
209
230
|
class LogString:
|
210
231
|
value: str
|
@@ -252,6 +273,19 @@ class LogHistogram:
|
|
252
273
|
bucket_counts: list[int]
|
253
274
|
|
254
275
|
|
276
|
+
@dataclass(kw_only=True)
|
277
|
+
class LogMesh:
|
278
|
+
vertices: np.ndarray
|
279
|
+
colors: np.ndarray | None
|
280
|
+
faces: np.ndarray | None
|
281
|
+
config_dict: dict[str, Any] | None # noqa: ANN401
|
282
|
+
|
283
|
+
|
284
|
+
@dataclass(kw_only=True)
|
285
|
+
class LogGraph:
|
286
|
+
computation: ClosedJaxpr
|
287
|
+
|
288
|
+
|
255
289
|
@dataclass(kw_only=True)
|
256
290
|
class LogLine:
|
257
291
|
state: State
|
@@ -261,6 +295,7 @@ class LogLine:
|
|
261
295
|
strings: dict[str, dict[str, LogString]]
|
262
296
|
images: dict[str, dict[str, LogImage]]
|
263
297
|
videos: dict[str, dict[str, LogVideo]]
|
298
|
+
meshes: dict[str, dict[str, LogMesh]]
|
264
299
|
|
265
300
|
|
266
301
|
@dataclass(kw_only=True)
|
@@ -533,6 +568,7 @@ class Logger:
|
|
533
568
|
self.strings: dict[str, dict[str, Callable[[], LogString]]] = defaultdict(dict)
|
534
569
|
self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
|
535
570
|
self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
|
571
|
+
self.meshes: dict[str, dict[str, Callable[[], LogMesh]]] = defaultdict(dict)
|
536
572
|
self.default_namespace = default_namespace
|
537
573
|
self.loggers: list[LoggerImpl] = []
|
538
574
|
|
@@ -560,6 +596,7 @@ class Logger:
|
|
560
596
|
strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
|
561
597
|
images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
|
562
598
|
videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
|
599
|
+
meshes={k: {kk: v() for kk, v in v.items()} for k, v in self.meshes.items()},
|
563
600
|
)
|
564
601
|
|
565
602
|
def clear(self) -> None:
|
@@ -569,6 +606,7 @@ class Logger:
|
|
569
606
|
self.strings.clear()
|
570
607
|
self.images.clear()
|
571
608
|
self.videos.clear()
|
609
|
+
self.meshes.clear()
|
572
610
|
|
573
611
|
def write(self, state: State) -> None:
|
574
612
|
"""Writes the current step's logging information.
|
@@ -1051,6 +1089,73 @@ class Logger:
|
|
1051
1089
|
|
1052
1090
|
self.videos[namespace][key] = video_future
|
1053
1091
|
|
1092
|
+
def log_mesh(
|
1093
|
+
self,
|
1094
|
+
key: str,
|
1095
|
+
vertices: np.ndarray | Array | Callable[[], np.ndarray | Array],
|
1096
|
+
colors: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
|
1097
|
+
faces: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
|
1098
|
+
config_dict: dict[str, Any] | None = None,
|
1099
|
+
*,
|
1100
|
+
namespace: str | None = None,
|
1101
|
+
) -> None:
|
1102
|
+
if not self.active:
|
1103
|
+
raise RuntimeError("The logger is not active")
|
1104
|
+
namespace = self.resolve_namespace(namespace)
|
1105
|
+
|
1106
|
+
@functools.lru_cache(maxsize=None)
|
1107
|
+
def mesh_future() -> LogMesh:
|
1108
|
+
with ContextTimer() as timer:
|
1109
|
+
# Get the raw values
|
1110
|
+
vertices_val = vertices() if callable(vertices) else vertices
|
1111
|
+
colors_val = colors() if callable(colors) else colors
|
1112
|
+
faces_val = faces() if callable(faces) else faces
|
1113
|
+
|
1114
|
+
# Convert to numpy arrays with proper type handling
|
1115
|
+
vertices_np = as_numpy(vertices_val)
|
1116
|
+
colors_np = as_numpy_opt(colors_val)
|
1117
|
+
faces_np = as_numpy_opt(faces_val)
|
1118
|
+
|
1119
|
+
# Checks vertices shape.
|
1120
|
+
if vertices_np.ndim == 2:
|
1121
|
+
vertices_np = vertices_np[None]
|
1122
|
+
if vertices_np.shape[-1] != 3 or vertices_np.ndim != 3:
|
1123
|
+
raise ValueError("Vertices must have shape (N, 3) or (B, N, 3)")
|
1124
|
+
|
1125
|
+
# Checks colors shape.
|
1126
|
+
if colors_np is not None:
|
1127
|
+
if colors_np.ndim == 2:
|
1128
|
+
colors_np = colors_np[None]
|
1129
|
+
if colors_np.shape[-1] != 3 or colors_np.ndim != 3:
|
1130
|
+
raise ValueError("Colors must have shape (N, 3) or (B, N, 3)")
|
1131
|
+
|
1132
|
+
# Checks faces shape.
|
1133
|
+
if faces_np is not None:
|
1134
|
+
if faces_np.ndim == 2:
|
1135
|
+
faces_np = faces_np[None]
|
1136
|
+
if faces_np.shape[-1] != 3 or faces_np.ndim != 3:
|
1137
|
+
raise ValueError("Faces must have shape (N, 3) or (B, N, 3)")
|
1138
|
+
|
1139
|
+
# Ensures colors dtype is uint8.
|
1140
|
+
if colors_np is not None:
|
1141
|
+
if colors_np.dtype != np.uint8:
|
1142
|
+
colors_np = (colors_np * 255).astype(np.uint8)
|
1143
|
+
|
1144
|
+
# Ensures faces dtype is int32.
|
1145
|
+
if faces_np is not None:
|
1146
|
+
if faces_np.dtype != np.int32:
|
1147
|
+
faces_np = faces_np.astype(np.int32)
|
1148
|
+
|
1149
|
+
logger.debug("Mesh Key: %s, Time: %s", key, timer.elapsed_time)
|
1150
|
+
return LogMesh(
|
1151
|
+
vertices=vertices_np,
|
1152
|
+
colors=colors_np,
|
1153
|
+
faces=faces_np,
|
1154
|
+
config_dict=config_dict,
|
1155
|
+
)
|
1156
|
+
|
1157
|
+
self.meshes[namespace][key] = mesh_future
|
1158
|
+
|
1054
1159
|
def __enter__(self) -> Self:
|
1055
1160
|
self.active = True
|
1056
1161
|
for logger in self.loggers:
|
@@ -70,6 +70,9 @@ class TensorboardLogger(LoggerImpl):
|
|
70
70
|
self._started = True
|
71
71
|
|
72
72
|
def worker_thread(self) -> None:
|
73
|
+
if os.environ.get("DISABLE_TENSORBOARD", "0") == "1":
|
74
|
+
return
|
75
|
+
|
73
76
|
time.sleep(self.wait_seconds)
|
74
77
|
|
75
78
|
port = int(os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT))
|
@@ -213,6 +216,19 @@ class TensorboardLogger(LoggerImpl):
|
|
213
216
|
video_value.frames,
|
214
217
|
fps=video_value.fps,
|
215
218
|
global_step=line.state.num_steps,
|
219
|
+
walltime=walltime,
|
220
|
+
)
|
221
|
+
|
222
|
+
for namespace, meshes in line.meshes.items():
|
223
|
+
for mesh_key, mesh_value in meshes.items():
|
224
|
+
writer.add_mesh(
|
225
|
+
f"{namespace}/{mesh_key}",
|
226
|
+
vertices=mesh_value.vertices,
|
227
|
+
faces=mesh_value.faces,
|
228
|
+
colors=mesh_value.colors,
|
229
|
+
config_dict=mesh_value.config_dict,
|
230
|
+
global_step=line.state.num_steps,
|
231
|
+
walltime=walltime,
|
216
232
|
)
|
217
233
|
|
218
234
|
for name, contents in self.files.items():
|
@@ -6,9 +6,9 @@ import logging
|
|
6
6
|
import tarfile
|
7
7
|
from dataclasses import asdict, dataclass
|
8
8
|
from pathlib import Path
|
9
|
-
from typing import
|
9
|
+
from typing import Generic, Literal, TypeVar, cast, overload
|
10
10
|
|
11
|
-
import
|
11
|
+
import equinox as eqx
|
12
12
|
import jax
|
13
13
|
import optax
|
14
14
|
from jaxtyping import PyTree
|
@@ -63,8 +63,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
63
63
|
|
64
64
|
def get_init_ckpt_path(self) -> Path | None:
|
65
65
|
if self._exp_dir is not None:
|
66
|
-
ckpt_path
|
67
|
-
if ckpt_path.exists():
|
66
|
+
if (ckpt_path := self.get_ckpt_path()).exists():
|
68
67
|
return ckpt_path
|
69
68
|
if self.config.load_from_ckpt_path is not None:
|
70
69
|
ckpt_path = Path(self.config.load_from_ckpt_path)
|
@@ -84,93 +83,129 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
84
83
|
return False
|
85
84
|
|
86
85
|
@overload
|
87
|
-
def
|
86
|
+
def load_ckpt_with_template(
|
88
87
|
self,
|
89
88
|
path: Path,
|
90
|
-
|
91
|
-
|
89
|
+
*,
|
90
|
+
part: Literal["all"],
|
91
|
+
model_template: PyTree,
|
92
|
+
optimizer_template: PyTree,
|
93
|
+
opt_state_template: PyTree,
|
94
|
+
) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
|
92
95
|
|
93
96
|
@overload
|
94
|
-
def
|
97
|
+
def load_ckpt_with_template(
|
95
98
|
self,
|
96
99
|
path: Path,
|
97
|
-
|
98
|
-
|
100
|
+
*,
|
101
|
+
part: Literal["model_state_config"],
|
102
|
+
model_template: PyTree,
|
103
|
+
) -> tuple[PyTree, State, Config]: ...
|
99
104
|
|
100
105
|
@overload
|
101
|
-
def
|
106
|
+
def load_ckpt_with_template(
|
102
107
|
self,
|
103
108
|
path: Path,
|
109
|
+
*,
|
104
110
|
part: Literal["model"],
|
111
|
+
model_template: PyTree,
|
105
112
|
) -> PyTree: ...
|
106
113
|
|
107
114
|
@overload
|
108
|
-
def
|
115
|
+
def load_ckpt_with_template(
|
109
116
|
self,
|
110
117
|
path: Path,
|
118
|
+
*,
|
111
119
|
part: Literal["opt"],
|
120
|
+
optimizer_template: PyTree,
|
112
121
|
) -> optax.GradientTransformation: ...
|
113
122
|
|
114
123
|
@overload
|
115
|
-
def
|
124
|
+
def load_ckpt_with_template(
|
116
125
|
self,
|
117
126
|
path: Path,
|
127
|
+
*,
|
118
128
|
part: Literal["opt_state"],
|
129
|
+
opt_state_template: PyTree,
|
119
130
|
) -> optax.OptState: ...
|
120
131
|
|
121
132
|
@overload
|
122
|
-
def
|
133
|
+
def load_ckpt_with_template(
|
123
134
|
self,
|
124
135
|
path: Path,
|
136
|
+
*,
|
125
137
|
part: Literal["state"],
|
126
138
|
) -> State: ...
|
127
139
|
|
128
140
|
@overload
|
129
|
-
def
|
141
|
+
def load_ckpt_with_template(
|
130
142
|
self,
|
131
143
|
path: Path,
|
144
|
+
*,
|
132
145
|
part: Literal["config"],
|
133
|
-
) ->
|
146
|
+
) -> Config: ...
|
134
147
|
|
135
|
-
def
|
148
|
+
def load_ckpt_with_template(
|
136
149
|
self,
|
137
150
|
path: Path,
|
151
|
+
*,
|
138
152
|
part: CheckpointPart = "all",
|
153
|
+
model_template: PyTree | None = None,
|
154
|
+
optimizer_template: PyTree | None = None,
|
155
|
+
opt_state_template: PyTree | None = None,
|
139
156
|
) -> (
|
140
|
-
tuple[PyTree, optax.GradientTransformation, optax.OptState, State,
|
141
|
-
| tuple[PyTree, State,
|
157
|
+
tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
|
158
|
+
| tuple[PyTree, State, Config]
|
142
159
|
| PyTree
|
143
160
|
| optax.GradientTransformation
|
144
161
|
| optax.OptState
|
145
162
|
| State
|
146
|
-
|
|
163
|
+
| Config
|
147
164
|
):
|
165
|
+
"""Load a checkpoint.
|
166
|
+
|
167
|
+
Args:
|
168
|
+
path: Path to the checkpoint directory
|
169
|
+
part: Which part of the checkpoint to load
|
170
|
+
model_template: Template model with correct structure but uninitialized weights
|
171
|
+
optimizer_template: Template optimizer with correct structure but uninitialized weights
|
172
|
+
opt_state_template: Template optimizer state with correct structure but uninitialized weights
|
173
|
+
|
174
|
+
Returns:
|
175
|
+
The requested checkpoint components
|
176
|
+
"""
|
148
177
|
with tarfile.open(path, "r:gz") as tar:
|
149
178
|
|
150
179
|
def get_model() -> PyTree:
|
180
|
+
if model_template is None:
|
181
|
+
raise ValueError("model_template must be provided to load model weights")
|
151
182
|
if (model := tar.extractfile("model")) is None:
|
152
183
|
raise ValueError(f"Checkpoint does not contain a model file: {path}")
|
153
|
-
return
|
184
|
+
return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
|
154
185
|
|
155
186
|
def get_opt() -> optax.GradientTransformation:
|
156
|
-
if
|
157
|
-
raise ValueError(
|
158
|
-
|
187
|
+
if optimizer_template is None:
|
188
|
+
raise ValueError("optimizer_template must be provided to load optimizer")
|
189
|
+
if (opt := tar.extractfile("optimizer")) is None:
|
190
|
+
raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
|
191
|
+
return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
|
159
192
|
|
160
193
|
def get_opt_state() -> optax.OptState:
|
194
|
+
if opt_state_template is None:
|
195
|
+
raise ValueError("opt_state_template must be provided to load optimizer state")
|
161
196
|
if (opt_state := tar.extractfile("opt_state")) is None:
|
162
|
-
raise ValueError(f"Checkpoint does not contain an
|
163
|
-
return
|
197
|
+
raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
|
198
|
+
return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
|
164
199
|
|
165
200
|
def get_state() -> State:
|
166
201
|
if (state := tar.extractfile("state")) is None:
|
167
202
|
raise ValueError(f"Checkpoint does not contain a state file: {path}")
|
168
203
|
return State(**json.loads(state.read().decode()))
|
169
204
|
|
170
|
-
def get_config() ->
|
205
|
+
def get_config() -> Config:
|
171
206
|
if (config := tar.extractfile("config")) is None:
|
172
207
|
raise ValueError(f"Checkpoint does not contain a config file: {path}")
|
173
|
-
return cast(DictConfig, OmegaConf.load(config))
|
208
|
+
return self.get_config(cast(DictConfig, OmegaConf.load(config)), use_cli=False)
|
174
209
|
|
175
210
|
match part:
|
176
211
|
case "model":
|
@@ -192,51 +227,90 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
192
227
|
|
193
228
|
def save_checkpoint(
|
194
229
|
self,
|
195
|
-
model: PyTree,
|
196
|
-
optimizer: optax.GradientTransformation,
|
197
|
-
opt_state: optax.OptState,
|
198
|
-
|
230
|
+
model: PyTree | None = None,
|
231
|
+
optimizer: optax.GradientTransformation | None = None,
|
232
|
+
opt_state: optax.OptState | None = None,
|
233
|
+
aux_data: PyTree | None = None,
|
234
|
+
state: State | None = None,
|
199
235
|
) -> Path:
|
236
|
+
"""Save a checkpoint.
|
237
|
+
|
238
|
+
Args:
|
239
|
+
model: The model to save
|
240
|
+
state: The current training state
|
241
|
+
optimizer: The optimizer to save
|
242
|
+
aux_data: Additional data to save
|
243
|
+
opt_state: The optimizer state to save
|
244
|
+
|
245
|
+
Returns:
|
246
|
+
Path to the saved checkpoint
|
247
|
+
"""
|
200
248
|
ckpt_path = self.get_ckpt_path(state)
|
201
249
|
|
202
250
|
if not is_master():
|
203
251
|
return ckpt_path
|
204
252
|
|
205
|
-
# Gets the path to the last checkpoint
|
253
|
+
# Gets the path to the last checkpoint
|
206
254
|
logger.info("Saving checkpoint to %s", ckpt_path)
|
207
255
|
last_ckpt_path = self.get_ckpt_path()
|
208
256
|
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
|
209
257
|
|
210
|
-
# Potentially removes the last checkpoint
|
258
|
+
# Potentially removes the last checkpoint
|
211
259
|
if last_ckpt_path.exists() and self.config.only_save_most_recent:
|
212
260
|
if (base_ckpt := last_ckpt_path.resolve()).is_file():
|
213
261
|
base_ckpt.unlink()
|
214
262
|
|
215
|
-
#
|
263
|
+
# Save the checkpoint components
|
216
264
|
with tarfile.open(ckpt_path, "w:gz") as tar:
|
217
265
|
|
218
|
-
def add_file(name: str,
|
266
|
+
def add_file(name: str, buf: io.BytesIO) -> None:
|
267
|
+
tarinfo = tarfile.TarInfo(name)
|
268
|
+
tarinfo.size = buf.tell()
|
269
|
+
buf.seek(0)
|
270
|
+
tar.addfile(tarinfo, buf)
|
271
|
+
|
272
|
+
# Save model using Equinox
|
273
|
+
if model is not None:
|
274
|
+
with io.BytesIO() as buf:
|
275
|
+
eqx.tree_serialise_leaves(buf, model)
|
276
|
+
add_file("model", buf)
|
277
|
+
|
278
|
+
# Save optimizer using Equinox
|
279
|
+
if optimizer is not None:
|
280
|
+
with io.BytesIO() as buf:
|
281
|
+
eqx.tree_serialise_leaves(buf, optimizer)
|
282
|
+
add_file("optimizer", buf)
|
283
|
+
|
284
|
+
# Save optimizer state using Equinox
|
285
|
+
if opt_state is not None:
|
219
286
|
with io.BytesIO() as buf:
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
287
|
+
eqx.tree_serialise_leaves(buf, opt_state)
|
288
|
+
add_file("opt_state", buf)
|
289
|
+
|
290
|
+
# Save aux data using Equinox.
|
291
|
+
if aux_data is not None:
|
292
|
+
with io.BytesIO() as buf:
|
293
|
+
eqx.tree_serialise_leaves(buf, aux_data)
|
294
|
+
add_file("aux_data", buf)
|
295
|
+
|
296
|
+
# Save state and config as JSON
|
297
|
+
def add_file_bytes(name: str, data: bytes) -> None: # noqa: ANN401
|
298
|
+
info = tarfile.TarInfo(name=name)
|
299
|
+
info.size = len(data)
|
300
|
+
tar.addfile(info, io.BytesIO(data))
|
301
|
+
|
302
|
+
if state is not None:
|
303
|
+
add_file_bytes("state", json.dumps(asdict(state), indent=2).encode())
|
304
|
+
add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
|
305
|
+
|
306
|
+
# Updates the symlink to the new checkpoint
|
233
307
|
last_ckpt_path.unlink(missing_ok=True)
|
234
308
|
try:
|
235
309
|
last_ckpt_path.symlink_to(ckpt_path.relative_to(last_ckpt_path.parent))
|
236
310
|
except FileExistsError:
|
237
311
|
logger.exception("Exception while trying to update %s", ckpt_path)
|
238
312
|
|
239
|
-
# Calls the base callback
|
313
|
+
# Calls the base callback
|
240
314
|
self.on_after_checkpoint_save(ckpt_path, state)
|
241
315
|
|
242
316
|
return ckpt_path
|