xax 0.1.16__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.1.16/xax.egg-info → xax-0.2.0}/PKG-INFO +6 -6
- {xax-0.1.16 → xax-0.2.0}/pyproject.toml +0 -1
- {xax-0.1.16 → xax-0.2.0}/xax/__init__.py +1 -1
- {xax-0.1.16 → xax-0.2.0}/xax/core/state.py +26 -1
- {xax-0.1.16 → xax-0.2.0}/xax/requirements.txt +5 -5
- {xax-0.1.16 → xax-0.2.0}/xax/task/base.py +1 -1
- {xax-0.1.16 → xax-0.2.0}/xax/task/logger.py +107 -2
- {xax-0.1.16 → xax-0.2.0}/xax/task/loggers/tensorboard.py +16 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/checkpointing.py +118 -41
- {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/data_loader.py +2 -1
- {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/train.py +35 -23
- {xax-0.1.16 → xax-0.2.0}/xax/utils/experiments.py +29 -12
- {xax-0.1.16 → xax-0.2.0}/xax/utils/tensorboard.py +91 -3
- {xax-0.1.16 → xax-0.2.0/xax.egg-info}/PKG-INFO +6 -6
- {xax-0.1.16 → xax-0.2.0}/xax.egg-info/requires.txt +5 -5
- {xax-0.1.16 → xax-0.2.0}/LICENSE +0 -0
- {xax-0.1.16 → xax-0.2.0}/MANIFEST.in +0 -0
- {xax-0.1.16 → xax-0.2.0}/README.md +0 -0
- {xax-0.1.16 → xax-0.2.0}/setup.cfg +0 -0
- {xax-0.1.16 → xax-0.2.0}/setup.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/core/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/core/conf.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/nn/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/nn/embeddings.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/nn/equinox.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/nn/export.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/nn/functions.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/nn/geom.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/nn/losses.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/nn/norm.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/nn/parallel.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/nn/ssm.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/py.typed +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/requirements-dev.txt +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/launchers/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/launchers/base.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/launchers/cli.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/launchers/single_process.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/loggers/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/loggers/callback.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/loggers/json.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/loggers/state.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/loggers/stdout.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/compile.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/logger.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/process.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/runnable.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/script.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/task/task.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/utils/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/utils/data/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/utils/data/collate.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/utils/debugging.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/utils/jax.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/utils/jaxpr.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/utils/logging.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/utils/numpy.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/utils/profile.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/utils/pytree.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/utils/text.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/utils/types/__init__.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.1.16 → xax-0.2.0}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: xax
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: A library for fast Jax experimentation
|
5
5
|
Home-page: https://github.com/kscalelabs/xax
|
6
6
|
Author: Benjamin Bolte
|
@@ -8,14 +8,14 @@ Requires-Python: >=3.11
|
|
8
8
|
Description-Content-Type: text/markdown
|
9
9
|
License-File: LICENSE
|
10
10
|
Requires-Dist: attrs
|
11
|
+
Requires-Dist: chex
|
12
|
+
Requires-Dist: dpshdl
|
13
|
+
Requires-Dist: equinox
|
14
|
+
Requires-Dist: importlib-resources
|
11
15
|
Requires-Dist: jax
|
12
16
|
Requires-Dist: jaxtyping
|
13
|
-
Requires-Dist: equinox
|
14
17
|
Requires-Dist: optax
|
15
|
-
Requires-Dist:
|
16
|
-
Requires-Dist: chex
|
17
|
-
Requires-Dist: importlib-resources
|
18
|
-
Requires-Dist: cloudpickle
|
18
|
+
Requires-Dist: orbax-checkpoint
|
19
19
|
Requires-Dist: pillow
|
20
20
|
Requires-Dist: omegaconf
|
21
21
|
Requires-Dist: gitpython
|
@@ -12,6 +12,14 @@ from xax.core.conf import field
|
|
12
12
|
Phase = Literal["train", "valid"]
|
13
13
|
|
14
14
|
|
15
|
+
def _phase_to_int(phase: Phase) -> int:
|
16
|
+
return {"train": 0, "valid": 1}[phase]
|
17
|
+
|
18
|
+
|
19
|
+
def _int_to_phase(i: int) -> Phase:
|
20
|
+
return cast(Phase, ["train", "valid"][i])
|
21
|
+
|
22
|
+
|
15
23
|
class StateDict(TypedDict, total=False):
|
16
24
|
num_steps: NotRequired[int]
|
17
25
|
num_samples: NotRequired[int]
|
@@ -35,7 +43,7 @@ class State:
|
|
35
43
|
|
36
44
|
@property
|
37
45
|
def phase(self) -> Phase:
|
38
|
-
return
|
46
|
+
return _int_to_phase(self._phase)
|
39
47
|
|
40
48
|
@classmethod
|
41
49
|
def init_state(cls) -> "State":
|
@@ -74,3 +82,20 @@ class State:
|
|
74
82
|
case _:
|
75
83
|
raise ValueError(f"Invalid phase: {phase}")
|
76
84
|
return State(**{**asdict(self), **kwargs, **extra_kwargs})
|
85
|
+
|
86
|
+
def to_dict(self) -> dict[str, int | float | str]:
|
87
|
+
return {
|
88
|
+
"num_steps": int(self.num_steps),
|
89
|
+
"num_samples": int(self.num_samples),
|
90
|
+
"num_valid_steps": int(self.num_valid_steps),
|
91
|
+
"num_valid_samples": int(self.num_valid_samples),
|
92
|
+
"start_time_s": float(self.start_time_s),
|
93
|
+
"elapsed_time_s": float(self.elapsed_time_s),
|
94
|
+
"phase": str(self.phase),
|
95
|
+
}
|
96
|
+
|
97
|
+
@classmethod
|
98
|
+
def from_dict(cls, d: dict[str, int | float | str]) -> "State":
|
99
|
+
if "phase" in d:
|
100
|
+
d["_phase"] = _phase_to_int(cast(Phase, d.pop("phase")))
|
101
|
+
return cls(**d) # type: ignore[arg-type]
|
@@ -2,16 +2,16 @@
|
|
2
2
|
|
3
3
|
# Core ML/JAX dependencies
|
4
4
|
attrs
|
5
|
+
chex
|
6
|
+
dpshdl
|
7
|
+
equinox
|
8
|
+
importlib-resources
|
5
9
|
jax
|
6
10
|
jaxtyping
|
7
|
-
equinox
|
8
11
|
optax
|
9
|
-
|
10
|
-
chex
|
11
|
-
importlib-resources
|
12
|
+
orbax-checkpoint
|
12
13
|
|
13
14
|
# Data processing and serialization
|
14
|
-
cloudpickle
|
15
15
|
pillow
|
16
16
|
|
17
17
|
# Configuration and project management
|
@@ -79,7 +79,7 @@ class BaseTask(Generic[Config]):
|
|
79
79
|
def on_training_end(self, state: State) -> State:
|
80
80
|
return state
|
81
81
|
|
82
|
-
def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
|
82
|
+
def on_after_checkpoint_save(self, ckpt_path: Path, state: State | None) -> State | None:
|
83
83
|
return state
|
84
84
|
|
85
85
|
@functools.cached_property
|
@@ -18,11 +18,22 @@ from abc import ABC, abstractmethod
|
|
18
18
|
from collections import defaultdict
|
19
19
|
from dataclasses import dataclass
|
20
20
|
from types import TracebackType
|
21
|
-
from typing import
|
21
|
+
from typing import (
|
22
|
+
Any,
|
23
|
+
Callable,
|
24
|
+
Iterator,
|
25
|
+
Literal,
|
26
|
+
Self,
|
27
|
+
Sequence,
|
28
|
+
TypeVar,
|
29
|
+
cast,
|
30
|
+
get_args,
|
31
|
+
)
|
22
32
|
|
23
33
|
import jax
|
24
34
|
import jax.numpy as jnp
|
25
35
|
import numpy as np
|
36
|
+
from jax._src.core import ClosedJaxpr
|
26
37
|
from jaxtyping import Array
|
27
38
|
from PIL import Image, ImageDraw, ImageFont
|
28
39
|
from PIL.Image import Image as PILImage
|
@@ -194,7 +205,10 @@ def tile_images(images: list[PILImage], sep: int = 0) -> PILImage:
|
|
194
205
|
return tiled
|
195
206
|
|
196
207
|
|
197
|
-
def as_numpy(array: Array) -> np.ndarray:
|
208
|
+
def as_numpy(array: Array | np.ndarray) -> np.ndarray:
|
209
|
+
"""Convert a JAX array or numpy array to numpy array."""
|
210
|
+
if isinstance(array, np.ndarray):
|
211
|
+
return array
|
198
212
|
array = jax.device_get(array)
|
199
213
|
if jax.dtypes.issubdtype(array.dtype, jnp.floating):
|
200
214
|
array = array.astype(jnp.float32)
|
@@ -205,6 +219,13 @@ def as_numpy(array: Array) -> np.ndarray:
|
|
205
219
|
return np.array(array)
|
206
220
|
|
207
221
|
|
222
|
+
def as_numpy_opt(array: Array | np.ndarray | None) -> np.ndarray | None:
|
223
|
+
"""Convert an optional JAX array or numpy array to numpy array."""
|
224
|
+
if array is None:
|
225
|
+
return None
|
226
|
+
return as_numpy(array)
|
227
|
+
|
228
|
+
|
208
229
|
@dataclass(kw_only=True)
|
209
230
|
class LogString:
|
210
231
|
value: str
|
@@ -252,6 +273,19 @@ class LogHistogram:
|
|
252
273
|
bucket_counts: list[int]
|
253
274
|
|
254
275
|
|
276
|
+
@dataclass(kw_only=True)
|
277
|
+
class LogMesh:
|
278
|
+
vertices: np.ndarray
|
279
|
+
colors: np.ndarray | None
|
280
|
+
faces: np.ndarray | None
|
281
|
+
config_dict: dict[str, Any] | None # noqa: ANN401
|
282
|
+
|
283
|
+
|
284
|
+
@dataclass(kw_only=True)
|
285
|
+
class LogGraph:
|
286
|
+
computation: ClosedJaxpr
|
287
|
+
|
288
|
+
|
255
289
|
@dataclass(kw_only=True)
|
256
290
|
class LogLine:
|
257
291
|
state: State
|
@@ -261,6 +295,7 @@ class LogLine:
|
|
261
295
|
strings: dict[str, dict[str, LogString]]
|
262
296
|
images: dict[str, dict[str, LogImage]]
|
263
297
|
videos: dict[str, dict[str, LogVideo]]
|
298
|
+
meshes: dict[str, dict[str, LogMesh]]
|
264
299
|
|
265
300
|
|
266
301
|
@dataclass(kw_only=True)
|
@@ -533,6 +568,7 @@ class Logger:
|
|
533
568
|
self.strings: dict[str, dict[str, Callable[[], LogString]]] = defaultdict(dict)
|
534
569
|
self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
|
535
570
|
self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
|
571
|
+
self.meshes: dict[str, dict[str, Callable[[], LogMesh]]] = defaultdict(dict)
|
536
572
|
self.default_namespace = default_namespace
|
537
573
|
self.loggers: list[LoggerImpl] = []
|
538
574
|
|
@@ -560,6 +596,7 @@ class Logger:
|
|
560
596
|
strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
|
561
597
|
images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
|
562
598
|
videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
|
599
|
+
meshes={k: {kk: v() for kk, v in v.items()} for k, v in self.meshes.items()},
|
563
600
|
)
|
564
601
|
|
565
602
|
def clear(self) -> None:
|
@@ -569,6 +606,7 @@ class Logger:
|
|
569
606
|
self.strings.clear()
|
570
607
|
self.images.clear()
|
571
608
|
self.videos.clear()
|
609
|
+
self.meshes.clear()
|
572
610
|
|
573
611
|
def write(self, state: State) -> None:
|
574
612
|
"""Writes the current step's logging information.
|
@@ -1051,6 +1089,73 @@ class Logger:
|
|
1051
1089
|
|
1052
1090
|
self.videos[namespace][key] = video_future
|
1053
1091
|
|
1092
|
+
def log_mesh(
|
1093
|
+
self,
|
1094
|
+
key: str,
|
1095
|
+
vertices: np.ndarray | Array | Callable[[], np.ndarray | Array],
|
1096
|
+
colors: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
|
1097
|
+
faces: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
|
1098
|
+
config_dict: dict[str, Any] | None = None,
|
1099
|
+
*,
|
1100
|
+
namespace: str | None = None,
|
1101
|
+
) -> None:
|
1102
|
+
if not self.active:
|
1103
|
+
raise RuntimeError("The logger is not active")
|
1104
|
+
namespace = self.resolve_namespace(namespace)
|
1105
|
+
|
1106
|
+
@functools.lru_cache(maxsize=None)
|
1107
|
+
def mesh_future() -> LogMesh:
|
1108
|
+
with ContextTimer() as timer:
|
1109
|
+
# Get the raw values
|
1110
|
+
vertices_val = vertices() if callable(vertices) else vertices
|
1111
|
+
colors_val = colors() if callable(colors) else colors
|
1112
|
+
faces_val = faces() if callable(faces) else faces
|
1113
|
+
|
1114
|
+
# Convert to numpy arrays with proper type handling
|
1115
|
+
vertices_np = as_numpy(vertices_val)
|
1116
|
+
colors_np = as_numpy_opt(colors_val)
|
1117
|
+
faces_np = as_numpy_opt(faces_val)
|
1118
|
+
|
1119
|
+
# Checks vertices shape.
|
1120
|
+
if vertices_np.ndim == 2:
|
1121
|
+
vertices_np = vertices_np[None]
|
1122
|
+
if vertices_np.shape[-1] != 3 or vertices_np.ndim != 3:
|
1123
|
+
raise ValueError("Vertices must have shape (N, 3) or (B, N, 3)")
|
1124
|
+
|
1125
|
+
# Checks colors shape.
|
1126
|
+
if colors_np is not None:
|
1127
|
+
if colors_np.ndim == 2:
|
1128
|
+
colors_np = colors_np[None]
|
1129
|
+
if colors_np.shape[-1] != 3 or colors_np.ndim != 3:
|
1130
|
+
raise ValueError("Colors must have shape (N, 3) or (B, N, 3)")
|
1131
|
+
|
1132
|
+
# Checks faces shape.
|
1133
|
+
if faces_np is not None:
|
1134
|
+
if faces_np.ndim == 2:
|
1135
|
+
faces_np = faces_np[None]
|
1136
|
+
if faces_np.shape[-1] != 3 or faces_np.ndim != 3:
|
1137
|
+
raise ValueError("Faces must have shape (N, 3) or (B, N, 3)")
|
1138
|
+
|
1139
|
+
# Ensures colors dtype is uint8.
|
1140
|
+
if colors_np is not None:
|
1141
|
+
if colors_np.dtype != np.uint8:
|
1142
|
+
colors_np = (colors_np * 255).astype(np.uint8)
|
1143
|
+
|
1144
|
+
# Ensures faces dtype is int32.
|
1145
|
+
if faces_np is not None:
|
1146
|
+
if faces_np.dtype != np.int32:
|
1147
|
+
faces_np = faces_np.astype(np.int32)
|
1148
|
+
|
1149
|
+
logger.debug("Mesh Key: %s, Time: %s", key, timer.elapsed_time)
|
1150
|
+
return LogMesh(
|
1151
|
+
vertices=vertices_np,
|
1152
|
+
colors=colors_np,
|
1153
|
+
faces=faces_np,
|
1154
|
+
config_dict=config_dict,
|
1155
|
+
)
|
1156
|
+
|
1157
|
+
self.meshes[namespace][key] = mesh_future
|
1158
|
+
|
1054
1159
|
def __enter__(self) -> Self:
|
1055
1160
|
self.active = True
|
1056
1161
|
for logger in self.loggers:
|
@@ -70,6 +70,9 @@ class TensorboardLogger(LoggerImpl):
|
|
70
70
|
self._started = True
|
71
71
|
|
72
72
|
def worker_thread(self) -> None:
|
73
|
+
if os.environ.get("DISABLE_TENSORBOARD", "0") == "1":
|
74
|
+
return
|
75
|
+
|
73
76
|
time.sleep(self.wait_seconds)
|
74
77
|
|
75
78
|
port = int(os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT))
|
@@ -213,6 +216,19 @@ class TensorboardLogger(LoggerImpl):
|
|
213
216
|
video_value.frames,
|
214
217
|
fps=video_value.fps,
|
215
218
|
global_step=line.state.num_steps,
|
219
|
+
walltime=walltime,
|
220
|
+
)
|
221
|
+
|
222
|
+
for namespace, meshes in line.meshes.items():
|
223
|
+
for mesh_key, mesh_value in meshes.items():
|
224
|
+
writer.add_mesh(
|
225
|
+
f"{namespace}/{mesh_key}",
|
226
|
+
vertices=mesh_value.vertices,
|
227
|
+
faces=mesh_value.faces,
|
228
|
+
colors=mesh_value.colors,
|
229
|
+
config_dict=mesh_value.config_dict,
|
230
|
+
global_step=line.state.num_steps,
|
231
|
+
walltime=walltime,
|
216
232
|
)
|
217
233
|
|
218
234
|
for name, contents in self.files.items():
|
@@ -6,9 +6,9 @@ import logging
|
|
6
6
|
import tarfile
|
7
7
|
from dataclasses import asdict, dataclass
|
8
8
|
from pathlib import Path
|
9
|
-
from typing import
|
9
|
+
from typing import Generic, Literal, TypeVar, cast, overload
|
10
10
|
|
11
|
-
import
|
11
|
+
import equinox as eqx
|
12
12
|
import jax
|
13
13
|
import optax
|
14
14
|
from jaxtyping import PyTree
|
@@ -64,7 +64,9 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
64
64
|
def get_init_ckpt_path(self) -> Path | None:
|
65
65
|
if self._exp_dir is not None:
|
66
66
|
ckpt_path = self.get_ckpt_path()
|
67
|
-
if ckpt_path.exists():
|
67
|
+
if not ckpt_path.exists():
|
68
|
+
logger.warning("No checkpoint found in experiment directory: %s", ckpt_path)
|
69
|
+
else:
|
68
70
|
return ckpt_path
|
69
71
|
if self.config.load_from_ckpt_path is not None:
|
70
72
|
ckpt_path = Path(self.config.load_from_ckpt_path)
|
@@ -87,41 +89,54 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
87
89
|
def load_checkpoint(
|
88
90
|
self,
|
89
91
|
path: Path,
|
90
|
-
|
91
|
-
|
92
|
+
*,
|
93
|
+
part: Literal["all"],
|
94
|
+
model_template: PyTree,
|
95
|
+
optimizer_template: PyTree,
|
96
|
+
opt_state_template: PyTree,
|
97
|
+
) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
|
92
98
|
|
93
99
|
@overload
|
94
100
|
def load_checkpoint(
|
95
101
|
self,
|
96
102
|
path: Path,
|
97
|
-
|
98
|
-
|
103
|
+
*,
|
104
|
+
part: Literal["model_state_config"],
|
105
|
+
model_template: PyTree,
|
106
|
+
) -> tuple[PyTree, State, Config]: ...
|
99
107
|
|
100
108
|
@overload
|
101
109
|
def load_checkpoint(
|
102
110
|
self,
|
103
111
|
path: Path,
|
112
|
+
*,
|
104
113
|
part: Literal["model"],
|
114
|
+
model_template: PyTree,
|
105
115
|
) -> PyTree: ...
|
106
116
|
|
107
117
|
@overload
|
108
118
|
def load_checkpoint(
|
109
119
|
self,
|
110
120
|
path: Path,
|
121
|
+
*,
|
111
122
|
part: Literal["opt"],
|
123
|
+
optimizer_template: PyTree,
|
112
124
|
) -> optax.GradientTransformation: ...
|
113
125
|
|
114
126
|
@overload
|
115
127
|
def load_checkpoint(
|
116
128
|
self,
|
117
129
|
path: Path,
|
130
|
+
*,
|
118
131
|
part: Literal["opt_state"],
|
132
|
+
opt_state_template: PyTree,
|
119
133
|
) -> optax.OptState: ...
|
120
134
|
|
121
135
|
@overload
|
122
136
|
def load_checkpoint(
|
123
137
|
self,
|
124
138
|
path: Path,
|
139
|
+
*,
|
125
140
|
part: Literal["state"],
|
126
141
|
) -> State: ...
|
127
142
|
|
@@ -129,48 +144,71 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
129
144
|
def load_checkpoint(
|
130
145
|
self,
|
131
146
|
path: Path,
|
147
|
+
*,
|
132
148
|
part: Literal["config"],
|
133
|
-
) ->
|
149
|
+
) -> Config: ...
|
134
150
|
|
135
151
|
def load_checkpoint(
|
136
152
|
self,
|
137
153
|
path: Path,
|
154
|
+
*,
|
138
155
|
part: CheckpointPart = "all",
|
156
|
+
model_template: PyTree | None = None,
|
157
|
+
optimizer_template: PyTree | None = None,
|
158
|
+
opt_state_template: PyTree | None = None,
|
139
159
|
) -> (
|
140
|
-
tuple[PyTree, optax.GradientTransformation, optax.OptState, State,
|
141
|
-
| tuple[PyTree, State,
|
160
|
+
tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
|
161
|
+
| tuple[PyTree, State, Config]
|
142
162
|
| PyTree
|
143
163
|
| optax.GradientTransformation
|
144
164
|
| optax.OptState
|
145
165
|
| State
|
146
|
-
|
|
166
|
+
| Config
|
147
167
|
):
|
168
|
+
"""Load a checkpoint.
|
169
|
+
|
170
|
+
Args:
|
171
|
+
path: Path to the checkpoint directory
|
172
|
+
part: Which part of the checkpoint to load
|
173
|
+
model_template: Template model with correct structure but uninitialized weights
|
174
|
+
optimizer_template: Template optimizer with correct structure but uninitialized weights
|
175
|
+
opt_state_template: Template optimizer state with correct structure but uninitialized weights
|
176
|
+
|
177
|
+
Returns:
|
178
|
+
The requested checkpoint components
|
179
|
+
"""
|
148
180
|
with tarfile.open(path, "r:gz") as tar:
|
149
181
|
|
150
182
|
def get_model() -> PyTree:
|
183
|
+
if model_template is None:
|
184
|
+
raise ValueError("model_template must be provided to load model weights")
|
151
185
|
if (model := tar.extractfile("model")) is None:
|
152
186
|
raise ValueError(f"Checkpoint does not contain a model file: {path}")
|
153
|
-
return
|
187
|
+
return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
|
154
188
|
|
155
189
|
def get_opt() -> optax.GradientTransformation:
|
156
|
-
if
|
157
|
-
raise ValueError(
|
158
|
-
|
190
|
+
if optimizer_template is None:
|
191
|
+
raise ValueError("optimizer_template must be provided to load optimizer")
|
192
|
+
if (opt := tar.extractfile("optimizer")) is None:
|
193
|
+
raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
|
194
|
+
return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
|
159
195
|
|
160
196
|
def get_opt_state() -> optax.OptState:
|
197
|
+
if opt_state_template is None:
|
198
|
+
raise ValueError("opt_state_template must be provided to load optimizer state")
|
161
199
|
if (opt_state := tar.extractfile("opt_state")) is None:
|
162
|
-
raise ValueError(f"Checkpoint does not contain an
|
163
|
-
return
|
200
|
+
raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
|
201
|
+
return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
|
164
202
|
|
165
203
|
def get_state() -> State:
|
166
204
|
if (state := tar.extractfile("state")) is None:
|
167
205
|
raise ValueError(f"Checkpoint does not contain a state file: {path}")
|
168
206
|
return State(**json.loads(state.read().decode()))
|
169
207
|
|
170
|
-
def get_config() ->
|
208
|
+
def get_config() -> Config:
|
171
209
|
if (config := tar.extractfile("config")) is None:
|
172
210
|
raise ValueError(f"Checkpoint does not contain a config file: {path}")
|
173
|
-
return cast(DictConfig, OmegaConf.load(config))
|
211
|
+
return self.get_config(cast(DictConfig, OmegaConf.load(config)), use_cli=False)
|
174
212
|
|
175
213
|
match part:
|
176
214
|
case "model":
|
@@ -192,51 +230,90 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
192
230
|
|
193
231
|
def save_checkpoint(
|
194
232
|
self,
|
195
|
-
model: PyTree,
|
196
|
-
optimizer: optax.GradientTransformation,
|
197
|
-
opt_state: optax.OptState,
|
198
|
-
|
233
|
+
model: PyTree | None = None,
|
234
|
+
optimizer: optax.GradientTransformation | None = None,
|
235
|
+
opt_state: optax.OptState | None = None,
|
236
|
+
aux_data: PyTree | None = None,
|
237
|
+
state: State | None = None,
|
199
238
|
) -> Path:
|
239
|
+
"""Save a checkpoint.
|
240
|
+
|
241
|
+
Args:
|
242
|
+
model: The model to save
|
243
|
+
state: The current training state
|
244
|
+
optimizer: The optimizer to save
|
245
|
+
aux_data: Additional data to save
|
246
|
+
opt_state: The optimizer state to save
|
247
|
+
|
248
|
+
Returns:
|
249
|
+
Path to the saved checkpoint
|
250
|
+
"""
|
200
251
|
ckpt_path = self.get_ckpt_path(state)
|
201
252
|
|
202
253
|
if not is_master():
|
203
254
|
return ckpt_path
|
204
255
|
|
205
|
-
# Gets the path to the last checkpoint
|
256
|
+
# Gets the path to the last checkpoint
|
206
257
|
logger.info("Saving checkpoint to %s", ckpt_path)
|
207
258
|
last_ckpt_path = self.get_ckpt_path()
|
208
259
|
ckpt_path.parent.mkdir(exist_ok=True, parents=True)
|
209
260
|
|
210
|
-
# Potentially removes the last checkpoint
|
261
|
+
# Potentially removes the last checkpoint
|
211
262
|
if last_ckpt_path.exists() and self.config.only_save_most_recent:
|
212
263
|
if (base_ckpt := last_ckpt_path.resolve()).is_file():
|
213
264
|
base_ckpt.unlink()
|
214
265
|
|
215
|
-
#
|
266
|
+
# Save the checkpoint components
|
216
267
|
with tarfile.open(ckpt_path, "w:gz") as tar:
|
217
268
|
|
218
|
-
def add_file(name: str,
|
269
|
+
def add_file(name: str, buf: io.BytesIO) -> None:
|
270
|
+
tarinfo = tarfile.TarInfo(name)
|
271
|
+
tarinfo.size = buf.tell()
|
272
|
+
buf.seek(0)
|
273
|
+
tar.addfile(tarinfo, buf)
|
274
|
+
|
275
|
+
# Save model using Equinox
|
276
|
+
if model is not None:
|
277
|
+
with io.BytesIO() as buf:
|
278
|
+
eqx.tree_serialise_leaves(buf, model)
|
279
|
+
add_file("model", buf)
|
280
|
+
|
281
|
+
# Save optimizer using Equinox
|
282
|
+
if optimizer is not None:
|
283
|
+
with io.BytesIO() as buf:
|
284
|
+
eqx.tree_serialise_leaves(buf, optimizer)
|
285
|
+
add_file("optimizer", buf)
|
286
|
+
|
287
|
+
# Save optimizer state using Equinox
|
288
|
+
if opt_state is not None:
|
219
289
|
with io.BytesIO() as buf:
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
|
290
|
+
eqx.tree_serialise_leaves(buf, opt_state)
|
291
|
+
add_file("opt_state", buf)
|
292
|
+
|
293
|
+
# Save aux data using Equinox.
|
294
|
+
if aux_data is not None:
|
295
|
+
with io.BytesIO() as buf:
|
296
|
+
eqx.tree_serialise_leaves(buf, aux_data)
|
297
|
+
add_file("aux_data", buf)
|
298
|
+
|
299
|
+
# Save state and config as JSON
|
300
|
+
def add_file_bytes(name: str, data: bytes) -> None: # noqa: ANN401
|
301
|
+
info = tarfile.TarInfo(name=name)
|
302
|
+
info.size = len(data)
|
303
|
+
tar.addfile(info, io.BytesIO(data))
|
304
|
+
|
305
|
+
if state is not None:
|
306
|
+
add_file_bytes("state", json.dumps(asdict(state), indent=2).encode())
|
307
|
+
add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
|
308
|
+
|
309
|
+
# Updates the symlink to the new checkpoint
|
233
310
|
last_ckpt_path.unlink(missing_ok=True)
|
234
311
|
try:
|
235
312
|
last_ckpt_path.symlink_to(ckpt_path.relative_to(last_ckpt_path.parent))
|
236
313
|
except FileExistsError:
|
237
314
|
logger.exception("Exception while trying to update %s", ckpt_path)
|
238
315
|
|
239
|
-
# Calls the base callback
|
316
|
+
# Calls the base callback
|
240
317
|
self.on_after_checkpoint_save(ckpt_path, state)
|
241
318
|
|
242
319
|
return ckpt_path
|
@@ -9,6 +9,7 @@ import jax
|
|
9
9
|
from dpshdl.dataloader import CollatedDataloaderItem, Dataloader
|
10
10
|
from dpshdl.dataset import Dataset, ErrorHandlingDataset
|
11
11
|
from dpshdl.prefetcher import Prefetcher
|
12
|
+
from jaxtyping import PRNGKeyArray
|
12
13
|
from omegaconf import II, MISSING
|
13
14
|
|
14
15
|
from xax.core.conf import field, is_missing
|
@@ -103,7 +104,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
|
|
103
104
|
"or `get_data_iterator` to return an iterator for the given dataset."
|
104
105
|
)
|
105
106
|
|
106
|
-
def get_data_iterator(self, phase: Phase) -> Iterator:
|
107
|
+
def get_data_iterator(self, phase: Phase, key: PRNGKeyArray) -> Iterator:
|
107
108
|
raise NotImplementedError(
|
108
109
|
"You must implement either the `get_dataset` method to return the dataset for the given phase, "
|
109
110
|
"or `get_data_iterator` to return an iterator for the given dataset."
|
@@ -11,7 +11,7 @@ import textwrap
|
|
11
11
|
import time
|
12
12
|
import traceback
|
13
13
|
from abc import ABC, abstractmethod
|
14
|
-
from dataclasses import dataclass, is_dataclass
|
14
|
+
from dataclasses import asdict, dataclass, is_dataclass
|
15
15
|
from threading import Thread
|
16
16
|
from typing import (
|
17
17
|
Any,
|
@@ -33,7 +33,6 @@ import jax.numpy as jnp
|
|
33
33
|
import numpy as np
|
34
34
|
import optax
|
35
35
|
from jaxtyping import Array, PRNGKeyArray, PyTree
|
36
|
-
from omegaconf import DictConfig
|
37
36
|
|
38
37
|
from xax.core.conf import field
|
39
38
|
from xax.core.state import Phase, State
|
@@ -50,6 +49,7 @@ from xax.utils.experiments import (
|
|
50
49
|
TrainingFinishedError,
|
51
50
|
diff_configs,
|
52
51
|
get_diff_string,
|
52
|
+
get_info_json,
|
53
53
|
get_state_file_string,
|
54
54
|
get_training_code,
|
55
55
|
)
|
@@ -340,20 +340,30 @@ class TrainMixin(
|
|
340
340
|
|
341
341
|
if init_ckpt_path is not None:
|
342
342
|
logger.info("Loading checkpoint from %s", init_ckpt_path)
|
343
|
-
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
343
|
+
model_spec = eqx.filter_eval_shape(self.get_model, key)
|
344
|
+
model, state, config = self.load_checkpoint(
|
345
|
+
init_ckpt_path,
|
346
|
+
part="model_state_config",
|
347
|
+
model_template=model_spec,
|
348
|
+
)
|
349
|
+
config_diff = get_diff_string(diff_configs(asdict(config), asdict(self.config)))
|
350
|
+
if config_diff:
|
351
|
+
logger.warning("Loaded config differs from current config:\n%s", config_diff)
|
349
352
|
|
350
|
-
|
351
|
-
model, state, config = self.load_checkpoint(init_ckpt_path, "model_state_config")
|
352
|
-
config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
|
353
|
-
if config_diff:
|
354
|
-
logger.warning("Loaded config differs from current config:\n%s", config_diff)
|
353
|
+
if not load_optimizer:
|
355
354
|
return model, state
|
356
355
|
|
356
|
+
# Loads the optimizer.
|
357
|
+
optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
|
358
|
+
optimizer = self.load_checkpoint(init_ckpt_path, part="opt", optimizer_template=optimizer_spec)
|
359
|
+
|
360
|
+
# Loads the optimizer state.
|
361
|
+
opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
|
362
|
+
opt_state = self.load_checkpoint(init_ckpt_path, part="opt_state", opt_state_template=opt_state_spec)
|
363
|
+
|
364
|
+
return model, optimizer, opt_state, state
|
365
|
+
|
366
|
+
logger.info("No checkpoint found. Initializing a new model.")
|
357
367
|
model = self.get_model(key)
|
358
368
|
state = State.init_state()
|
359
369
|
|
@@ -554,6 +564,7 @@ class TrainMixin(
|
|
554
564
|
self.logger.log_file("state.txt", get_state_file_string(self))
|
555
565
|
self.logger.log_file("training_code.py", get_training_code(self))
|
556
566
|
self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
|
567
|
+
self.logger.log_file("info.json", get_info_json())
|
557
568
|
|
558
569
|
def model_partition_fn(self, item: Any) -> bool: # noqa: ANN401
|
559
570
|
return eqx.is_inexact_array(item)
|
@@ -627,16 +638,16 @@ class TrainMixin(
|
|
627
638
|
|
628
639
|
if self.should_checkpoint(state):
|
629
640
|
model = eqx.combine(model_arr, model_static)
|
630
|
-
self.save_checkpoint(model, optimizer, opt_state, state)
|
641
|
+
self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
|
631
642
|
|
632
643
|
# After finishing training, save the final checkpoint.
|
633
644
|
model = eqx.combine(model_arr, model_static)
|
634
|
-
self.save_checkpoint(model, optimizer, opt_state, state)
|
645
|
+
self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
|
635
646
|
|
636
647
|
@contextlib.contextmanager
|
637
|
-
def get_train_iterator(self) -> Generator[Iterator[Batch], None, None]:
|
648
|
+
def get_train_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
|
638
649
|
try:
|
639
|
-
train_iterator: Iterator[Batch] = self.get_data_iterator("train")
|
650
|
+
train_iterator: Iterator[Batch] = self.get_data_iterator("train", key=key)
|
640
651
|
yield train_iterator
|
641
652
|
return
|
642
653
|
except NotImplementedError:
|
@@ -653,9 +664,9 @@ class TrainMixin(
|
|
653
664
|
logger.info("Closing train prefetcher")
|
654
665
|
|
655
666
|
@contextlib.contextmanager
|
656
|
-
def get_valid_iterator(self) -> Generator[Iterator[Batch], None, None]:
|
667
|
+
def get_valid_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
|
657
668
|
try:
|
658
|
-
valid_iterator: Iterator[Batch] = self.get_data_iterator("valid")
|
669
|
+
valid_iterator: Iterator[Batch] = self.get_data_iterator("valid", key=key)
|
659
670
|
yield valid_iterator
|
660
671
|
return
|
661
672
|
except NotImplementedError:
|
@@ -699,12 +710,13 @@ class TrainMixin(
|
|
699
710
|
state = self.on_training_start(state)
|
700
711
|
|
701
712
|
def on_exit() -> None:
|
702
|
-
self.save_checkpoint(model, optimizer, opt_state, state)
|
713
|
+
self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
|
703
714
|
|
704
715
|
# Handle user-defined interrupts during the training loop.
|
705
716
|
self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)
|
706
717
|
|
707
|
-
|
718
|
+
key, tkey, vkey = jax.random.split(key, 3)
|
719
|
+
with self.get_train_iterator(tkey) as train_pf, self.get_valid_iterator(vkey) as valid_pf:
|
708
720
|
try:
|
709
721
|
self.train_loop(
|
710
722
|
model=model,
|
@@ -721,7 +733,7 @@ class TrainMixin(
|
|
721
733
|
f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
|
722
734
|
important=True,
|
723
735
|
)
|
724
|
-
self.save_checkpoint(model, optimizer, opt_state, state)
|
736
|
+
self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
|
725
737
|
|
726
738
|
except (KeyboardInterrupt, bdb.BdbQuit):
|
727
739
|
if is_master():
|
@@ -731,7 +743,7 @@ class TrainMixin(
|
|
731
743
|
exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
|
732
744
|
sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
|
733
745
|
sys.stdout.flush()
|
734
|
-
self.save_checkpoint(model, optimizer, opt_state, state)
|
746
|
+
self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)
|
735
747
|
|
736
748
|
finally:
|
737
749
|
state = self.on_training_end(state)
|
@@ -7,6 +7,7 @@ import functools
|
|
7
7
|
import hashlib
|
8
8
|
import inspect
|
9
9
|
import itertools
|
10
|
+
import json
|
10
11
|
import logging
|
11
12
|
import math
|
12
13
|
import os
|
@@ -24,7 +25,7 @@ import warnings
|
|
24
25
|
from abc import ABC, abstractmethod
|
25
26
|
from pathlib import Path
|
26
27
|
from types import TracebackType
|
27
|
-
from typing import Any, Iterator, Self, TypeVar, cast
|
28
|
+
from typing import Any, Iterator, Mapping, Self, Sequence, TypeVar, cast
|
28
29
|
from urllib.parse import urlparse
|
29
30
|
|
30
31
|
import git
|
@@ -116,9 +117,7 @@ class StateTimer:
|
|
116
117
|
|
117
118
|
def log_dict(self) -> dict[str, int | float | tuple[int | float, bool]]:
|
118
119
|
return {
|
119
|
-
"steps": (self.step_timer.steps, True),
|
120
120
|
"steps/second": self.step_timer.steps_per_second,
|
121
|
-
"samples": (self.sample_timer.steps, True),
|
122
121
|
"samples/second": (self.sample_timer.steps_per_second, True),
|
123
122
|
"dt": self.iter_timer.iter_seconds,
|
124
123
|
}
|
@@ -204,8 +203,8 @@ class MinGradScaleError(TrainingFinishedError):
|
|
204
203
|
|
205
204
|
|
206
205
|
def diff_configs(
|
207
|
-
first:
|
208
|
-
second:
|
206
|
+
first: Mapping | Sequence,
|
207
|
+
second: Mapping | Sequence,
|
209
208
|
prefix: str | None = None,
|
210
209
|
) -> tuple[list[str], list[str]]:
|
211
210
|
"""Returns the difference between two configs.
|
@@ -232,7 +231,7 @@ def diff_configs(
|
|
232
231
|
|
233
232
|
any_config = (ListConfig, DictConfig)
|
234
233
|
|
235
|
-
if isinstance(first,
|
234
|
+
if isinstance(first, Mapping) and isinstance(second, Mapping):
|
236
235
|
first_keys, second_keys = cast(set[str], set(first.keys())), cast(set[str], set(second.keys()))
|
237
236
|
|
238
237
|
# Gets the new keys in each config.
|
@@ -242,11 +241,12 @@ def diff_configs(
|
|
242
241
|
# Gets the new sub-keys in each config.
|
243
242
|
for key in first_keys.intersection(second_keys):
|
244
243
|
sub_prefix = key if prefix is None else f"{prefix}.{key}"
|
245
|
-
if
|
246
|
-
if
|
247
|
-
|
248
|
-
|
249
|
-
|
244
|
+
if isinstance(first, DictConfig) and isinstance(second, DictConfig):
|
245
|
+
if OmegaConf.is_missing(first, key) or OmegaConf.is_missing(second, key):
|
246
|
+
if not OmegaConf.is_missing(first, key):
|
247
|
+
new_first += [get_diff_string(sub_prefix, first[key])]
|
248
|
+
if not OmegaConf.is_missing(second, key):
|
249
|
+
new_second += [get_diff_string(sub_prefix, second[key])]
|
250
250
|
elif isinstance(first[key], any_config) and isinstance(second[key], any_config):
|
251
251
|
sub_new_first, sub_new_second = diff_configs(first[key], second[key], prefix=sub_prefix)
|
252
252
|
new_first, new_second = new_first + sub_new_first, new_second + sub_new_second
|
@@ -255,7 +255,7 @@ def diff_configs(
|
|
255
255
|
new_first += [get_diff_string(sub_prefix, first_val)]
|
256
256
|
new_second += [get_diff_string(sub_prefix, second_val)]
|
257
257
|
|
258
|
-
elif isinstance(first,
|
258
|
+
elif isinstance(first, Sequence) and isinstance(second, Sequence):
|
259
259
|
if len(first) > len(second):
|
260
260
|
for i in range(len(second), len(first)):
|
261
261
|
new_first += [get_diff_string(prefix, first[i])]
|
@@ -470,16 +470,33 @@ def get_command_line_string() -> str:
|
|
470
470
|
return " ".join(sys.argv)
|
471
471
|
|
472
472
|
|
473
|
+
def get_environment_variables() -> str:
|
474
|
+
return "\n".join([f"{key}={value}" for key, value in sorted(os.environ.items())])
|
475
|
+
|
476
|
+
|
473
477
|
def get_state_file_string(obj: object) -> str:
|
474
478
|
return "\n\n".join(
|
475
479
|
[
|
476
480
|
f"=== Command Line ===\n\n{get_command_line_string()}",
|
477
481
|
f"=== Git State ===\n\n{get_git_state(obj)}",
|
478
482
|
f"=== Packages ===\n\n{get_packages_with_versions()}",
|
483
|
+
f"=== Environment Variables ===\n\n{get_environment_variables()}",
|
479
484
|
]
|
480
485
|
)
|
481
486
|
|
482
487
|
|
488
|
+
def get_info_json() -> str:
|
489
|
+
return json.dumps(
|
490
|
+
{
|
491
|
+
"process_id": os.getpid(),
|
492
|
+
"job": {
|
493
|
+
"start_time": datetime.datetime.now().isoformat(),
|
494
|
+
},
|
495
|
+
},
|
496
|
+
indent=2,
|
497
|
+
)
|
498
|
+
|
499
|
+
|
483
500
|
def get_training_code(obj: object) -> str:
|
484
501
|
"""Gets the text from the file containing the provided object.
|
485
502
|
|
@@ -2,11 +2,12 @@
|
|
2
2
|
|
3
3
|
import functools
|
4
4
|
import io
|
5
|
+
import json
|
5
6
|
import os
|
6
7
|
import tempfile
|
7
8
|
import time
|
8
9
|
from pathlib import Path
|
9
|
-
from typing import Literal, TypedDict
|
10
|
+
from typing import Any, Literal, TypedDict
|
10
11
|
|
11
12
|
import numpy as np
|
12
13
|
import PIL.Image
|
@@ -14,9 +15,15 @@ from PIL.Image import Image as PILImage
|
|
14
15
|
from tensorboard.compat.proto.config_pb2 import RunMetadata
|
15
16
|
from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
|
16
17
|
from tensorboard.compat.proto.graph_pb2 import GraphDef
|
17
|
-
from tensorboard.compat.proto.summary_pb2 import
|
18
|
+
from tensorboard.compat.proto.summary_pb2 import (
|
19
|
+
HistogramProto,
|
20
|
+
Summary,
|
21
|
+
SummaryMetadata,
|
22
|
+
)
|
18
23
|
from tensorboard.compat.proto.tensor_pb2 import TensorProto
|
19
24
|
from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
|
25
|
+
from tensorboard.plugins.mesh import metadata as mesh_metadata
|
26
|
+
from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData
|
20
27
|
from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
|
21
28
|
from tensorboard.summary.writer.event_file_writer import EventFileWriter
|
22
29
|
|
@@ -84,6 +91,68 @@ def make_histogram(values: np.ndarray, bins: str | np.ndarray, max_bins: int | N
|
|
84
91
|
)
|
85
92
|
|
86
93
|
|
94
|
+
def _get_json_config(config_dict: dict[str, Any] | None) -> str:
|
95
|
+
json_config = "{}"
|
96
|
+
if config_dict is not None:
|
97
|
+
json_config = json.dumps(config_dict, sort_keys=True)
|
98
|
+
return json_config
|
99
|
+
|
100
|
+
|
101
|
+
def make_mesh_summary(
|
102
|
+
tag: str,
|
103
|
+
vertices: np.ndarray,
|
104
|
+
colors: np.ndarray | None,
|
105
|
+
faces: np.ndarray | None,
|
106
|
+
config_dict: dict[str, Any] | None,
|
107
|
+
display_name: str | None = None,
|
108
|
+
description: str | None = None,
|
109
|
+
) -> Summary:
|
110
|
+
json_config = _get_json_config(config_dict)
|
111
|
+
|
112
|
+
summaries = []
|
113
|
+
tensors = [
|
114
|
+
(vertices, MeshPluginData.VERTEX),
|
115
|
+
(faces, MeshPluginData.FACE),
|
116
|
+
(colors, MeshPluginData.COLOR),
|
117
|
+
]
|
118
|
+
# Filter out None tensors and explicitly type the list
|
119
|
+
valid_tensors = [(t, content_type) for t, content_type in tensors if t is not None]
|
120
|
+
components = mesh_metadata.get_components_bitmask([content_type for (_, content_type) in valid_tensors])
|
121
|
+
|
122
|
+
for tensor, content_type in valid_tensors: # Now we know tensor is not None
|
123
|
+
tensor_metadata = mesh_metadata.create_summary_metadata(
|
124
|
+
tag,
|
125
|
+
display_name,
|
126
|
+
content_type,
|
127
|
+
components,
|
128
|
+
tensor.shape, # Safe now since tensor is not None
|
129
|
+
description,
|
130
|
+
json_config=json_config,
|
131
|
+
)
|
132
|
+
|
133
|
+
tensor_proto = TensorProto(
|
134
|
+
dtype="DT_FLOAT",
|
135
|
+
float_val=tensor.reshape(-1).tolist(), # Safe now since tensor is not None
|
136
|
+
tensor_shape=TensorShapeProto(
|
137
|
+
dim=[
|
138
|
+
TensorShapeProto.Dim(size=tensor.shape[0]), # Safe now since tensor is not None
|
139
|
+
TensorShapeProto.Dim(size=tensor.shape[1]),
|
140
|
+
TensorShapeProto.Dim(size=tensor.shape[2]),
|
141
|
+
]
|
142
|
+
),
|
143
|
+
)
|
144
|
+
|
145
|
+
tensor_summary = Summary.Value(
|
146
|
+
tag=mesh_metadata.get_instance_name(tag, content_type),
|
147
|
+
tensor=tensor_proto,
|
148
|
+
metadata=tensor_metadata,
|
149
|
+
)
|
150
|
+
|
151
|
+
summaries.append(tensor_summary)
|
152
|
+
|
153
|
+
return Summary(value=summaries)
|
154
|
+
|
155
|
+
|
87
156
|
class TensorboardProtobufWriter:
|
88
157
|
def __init__(
|
89
158
|
self,
|
@@ -454,6 +523,9 @@ class TensorboardWriter:
|
|
454
523
|
weighted_sum = float((bin_centers * bucket_counts).sum())
|
455
524
|
weighted_sum_squares = float((bin_centers**2 * bucket_counts).sum())
|
456
525
|
|
526
|
+
# Convert bin edges to list of floats explicitly
|
527
|
+
bucket_limits: list[float | np.ndarray] = [float(x) for x in bin_edges[1:]]
|
528
|
+
|
457
529
|
self.add_histogram_raw(
|
458
530
|
tag=tag,
|
459
531
|
min=float(bin_edges[0]),
|
@@ -461,12 +533,28 @@ class TensorboardWriter:
|
|
461
533
|
num=int(total_counts),
|
462
534
|
sum=weighted_sum,
|
463
535
|
sum_squares=weighted_sum_squares,
|
464
|
-
bucket_limits=
|
536
|
+
bucket_limits=bucket_limits, # Now properly typed
|
465
537
|
bucket_counts=bucket_counts.tolist(),
|
466
538
|
global_step=global_step,
|
467
539
|
walltime=walltime,
|
468
540
|
)
|
469
541
|
|
542
|
+
def add_mesh(
|
543
|
+
self,
|
544
|
+
tag: str,
|
545
|
+
vertices: np.ndarray,
|
546
|
+
colors: np.ndarray | None,
|
547
|
+
faces: np.ndarray | None,
|
548
|
+
config_dict: dict[str, Any] | None,
|
549
|
+
global_step: int | None = None,
|
550
|
+
walltime: float | None = None,
|
551
|
+
) -> None:
|
552
|
+
self.pb_writer.add_summary(
|
553
|
+
make_mesh_summary(tag, vertices, colors, faces, config_dict),
|
554
|
+
global_step=global_step,
|
555
|
+
walltime=walltime,
|
556
|
+
)
|
557
|
+
|
470
558
|
|
471
559
|
class TensorboardWriterKwargs(TypedDict):
|
472
560
|
max_queue_size: int
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: xax
|
3
|
-
Version: 0.1.16
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: A library for fast Jax experimentation
|
5
5
|
Home-page: https://github.com/kscalelabs/xax
|
6
6
|
Author: Benjamin Bolte
|
@@ -8,14 +8,14 @@ Requires-Python: >=3.11
|
|
8
8
|
Description-Content-Type: text/markdown
|
9
9
|
License-File: LICENSE
|
10
10
|
Requires-Dist: attrs
|
11
|
+
Requires-Dist: chex
|
12
|
+
Requires-Dist: dpshdl
|
13
|
+
Requires-Dist: equinox
|
14
|
+
Requires-Dist: importlib-resources
|
11
15
|
Requires-Dist: jax
|
12
16
|
Requires-Dist: jaxtyping
|
13
|
-
Requires-Dist: equinox
|
14
17
|
Requires-Dist: optax
|
15
|
-
Requires-Dist: dpshdl
|
16
|
-
Requires-Dist: chex
|
17
|
-
Requires-Dist: importlib-resources
|
18
|
-
Requires-Dist: cloudpickle
|
18
|
+
Requires-Dist: orbax-checkpoint
|
19
19
|
Requires-Dist: pillow
|
20
20
|
Requires-Dist: omegaconf
|
21
21
|
Requires-Dist: gitpython
|
{xax-0.1.16 → xax-0.2.0}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|