xax 0.1.16-py3-none-any.whl → 0.2.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +4 -1
- xax/core/state.py +26 -1
- xax/nn/geom.py +34 -0
- xax/requirements.txt +5 -5
- xax/task/base.py +1 -1
- xax/task/logger.py +107 -2
- xax/task/loggers/tensorboard.py +16 -0
- xax/task/mixins/checkpointing.py +124 -50
- xax/task/mixins/data_loader.py +2 -1
- xax/task/mixins/train.py +153 -27
- xax/utils/experiments.py +29 -12
- xax/utils/tensorboard.py +91 -3
- {xax-0.1.16.dist-info → xax-0.2.1.dist-info}/METADATA +6 -6
- {xax-0.1.16.dist-info → xax-0.2.1.dist-info}/RECORD +17 -17
- {xax-0.1.16.dist-info → xax-0.2.1.dist-info}/WHEEL +0 -0
- {xax-0.1.16.dist-info → xax-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.16.dist-info → xax-0.2.1.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """

-__version__ = "0.1.16"
+__version__ = "0.2.1"

 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -44,6 +44,7 @@ __all__ = [
     "euler_to_quat",
     "get_projected_gravity_vector_from_quat",
     "quat_to_euler",
+    "quat_to_rotmat",
     "rotate_vector_by_quat",
     "cross_entropy",
     "cast_norm_type",
@@ -206,6 +207,7 @@ NAME_MAP: dict[str, str] = {
     "euler_to_quat": "nn.geom",
     "get_projected_gravity_vector_from_quat": "nn.geom",
     "quat_to_euler": "nn.geom",
+    "quat_to_rotmat": "nn.geom",
     "rotate_vector_by_quat": "nn.geom",
     "cross_entropy": "nn.losses",
     "cast_norm_type": "nn.norm",
@@ -369,6 +371,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         euler_to_quat,
         get_projected_gravity_vector_from_quat,
         quat_to_euler,
+        quat_to_rotmat,
         rotate_vector_by_quat,
     )
     from xax.nn.losses import cross_entropy
xax/core/state.py
CHANGED
@@ -12,6 +12,14 @@ from xax.core.conf import field
 Phase = Literal["train", "valid"]


+def _phase_to_int(phase: Phase) -> int:
+    return {"train": 0, "valid": 1}[phase]
+
+
+def _int_to_phase(i: int) -> Phase:
+    return cast(Phase, ["train", "valid"][i])
+
+
 class StateDict(TypedDict, total=False):
     num_steps: NotRequired[int]
     num_samples: NotRequired[int]
@@ -35,7 +43,7 @@ class State:

     @property
     def phase(self) -> Phase:
-        return cast(Phase, ["train", "valid"][self._phase])
+        return _int_to_phase(self._phase)

     @classmethod
     def init_state(cls) -> "State":
@@ -74,3 +82,20 @@ class State:
             case _:
                 raise ValueError(f"Invalid phase: {phase}")
         return State(**{**asdict(self), **kwargs, **extra_kwargs})
+
+    def to_dict(self) -> dict[str, int | float | str]:
+        return {
+            "num_steps": int(self.num_steps),
+            "num_samples": int(self.num_samples),
+            "num_valid_steps": int(self.num_valid_steps),
+            "num_valid_samples": int(self.num_valid_samples),
+            "start_time_s": float(self.start_time_s),
+            "elapsed_time_s": float(self.elapsed_time_s),
+            "phase": str(self.phase),
+        }
+
+    @classmethod
+    def from_dict(cls, d: dict[str, int | float | str]) -> "State":
+        if "phase" in d:
+            d["_phase"] = _phase_to_int(cast(Phase, d.pop("phase")))
+        return cls(**d)  # type: ignore[arg-type]
xax/nn/geom.py
CHANGED
@@ -177,3 +177,37 @@ def cubic_bezier_interpolation(y_start: Array, y_end: Array, x: Array) -> Array:
     y_diff = y_end - y_start
     bezier = x**3 + 3 * (x**2 * (1 - x))
     return y_start + y_diff * bezier
+
+
+def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
+    """Converts a quaternion to a rotation matrix.
+
+    Args:
+        quat: The quaternion to convert, shape (*, 4).
+        eps: A small epsilon value to avoid division by zero.
+
+    Returns:
+        The rotation matrix, shape (*, 3, 3).
+    """
+    quat = quat / (jnp.linalg.norm(quat, axis=-1, keepdims=True) + eps)
+    w, x, y, z = jnp.split(quat, 4, axis=-1)
+
+    xx = 1 - 2 * (y * y + z * z)
+    xy = 2 * (x * y - z * w)
+    xz = 2 * (x * z + y * w)
+    yx = 2 * (x * y + z * w)
+    yy = 1 - 2 * (x * x + z * z)
+    yz = 2 * (y * z - x * w)
+    zx = 2 * (x * z - y * w)
+    zy = 2 * (y * z + x * w)
+    zz = 1 - 2 * (x * x + y * y)
+
+    # Corrected stacking: row-major order
+    return jnp.concatenate(
+        [
+            jnp.concatenate([xx, xy, xz], axis=-1)[..., None, :],
+            jnp.concatenate([yx, yy, yz], axis=-1)[..., None, :],
+            jnp.concatenate([zx, zy, zz], axis=-1)[..., None, :],
+        ],
+        axis=-2,
+    )
xax/requirements.txt
CHANGED
@@ -2,16 +2,16 @@

 # Core ML/JAX dependencies
 attrs
+chex
+dpshdl
+equinox
+importlib-resources
 jax
 jaxtyping
-equinox
 optax
-
-chex
-importlib-resources
+orbax-checkpoint

 # Data processing and serialization
-cloudpickle
 pillow

 # Configuration and project management
xax/task/base.py
CHANGED
@@ -79,7 +79,7 @@ class BaseTask(Generic[Config]):
     def on_training_end(self, state: State) -> State:
         return state

-    def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
+    def on_after_checkpoint_save(self, ckpt_path: Path, state: State | None) -> State | None:
         return state

     @functools.cached_property
xax/task/logger.py
CHANGED
@@ -18,11 +18,22 @@ from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
 from types import TracebackType
-from typing import
+from typing import (
+    Any,
+    Callable,
+    Iterator,
+    Literal,
+    Self,
+    Sequence,
+    TypeVar,
+    cast,
+    get_args,
+)

 import jax
 import jax.numpy as jnp
 import numpy as np
+from jax._src.core import ClosedJaxpr
 from jaxtyping import Array
 from PIL import Image, ImageDraw, ImageFont
 from PIL.Image import Image as PILImage
@@ -194,7 +205,10 @@ def tile_images(images: list[PILImage], sep: int = 0) -> PILImage:
     return tiled


-def as_numpy(array: Array) -> np.ndarray:
+def as_numpy(array: Array | np.ndarray) -> np.ndarray:
+    """Convert a JAX array or numpy array to numpy array."""
+    if isinstance(array, np.ndarray):
+        return array
     array = jax.device_get(array)
     if jax.dtypes.issubdtype(array.dtype, jnp.floating):
         array = array.astype(jnp.float32)
@@ -205,6 +219,13 @@ def as_numpy(array: Array) -> np.ndarray:
     return np.array(array)


+def as_numpy_opt(array: Array | np.ndarray | None) -> np.ndarray | None:
+    """Convert an optional JAX array or numpy array to numpy array."""
+    if array is None:
+        return None
+    return as_numpy(array)
+
+
 @dataclass(kw_only=True)
 class LogString:
     value: str
@@ -252,6 +273,19 @@ class LogHistogram:
     bucket_counts: list[int]


+@dataclass(kw_only=True)
+class LogMesh:
+    vertices: np.ndarray
+    colors: np.ndarray | None
+    faces: np.ndarray | None
+    config_dict: dict[str, Any] | None  # noqa: ANN401
+
+
+@dataclass(kw_only=True)
+class LogGraph:
+    computation: ClosedJaxpr
+
+
 @dataclass(kw_only=True)
 class LogLine:
     state: State
@@ -261,6 +295,7 @@ class LogLine:
     strings: dict[str, dict[str, LogString]]
     images: dict[str, dict[str, LogImage]]
     videos: dict[str, dict[str, LogVideo]]
+    meshes: dict[str, dict[str, LogMesh]]


 @dataclass(kw_only=True)
@@ -533,6 +568,7 @@ class Logger:
         self.strings: dict[str, dict[str, Callable[[], LogString]]] = defaultdict(dict)
         self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
         self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
+        self.meshes: dict[str, dict[str, Callable[[], LogMesh]]] = defaultdict(dict)
         self.default_namespace = default_namespace
         self.loggers: list[LoggerImpl] = []

@@ -560,6 +596,7 @@ class Logger:
             strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
             images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
             videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
+            meshes={k: {kk: v() for kk, v in v.items()} for k, v in self.meshes.items()},
         )

     def clear(self) -> None:
@@ -569,6 +606,7 @@ class Logger:
         self.strings.clear()
         self.images.clear()
         self.videos.clear()
+        self.meshes.clear()

     def write(self, state: State) -> None:
         """Writes the current step's logging information.
@@ -1051,6 +1089,73 @@ class Logger:

         self.videos[namespace][key] = video_future

+    def log_mesh(
+        self,
+        key: str,
+        vertices: np.ndarray | Array | Callable[[], np.ndarray | Array],
+        colors: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
+        faces: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
+        config_dict: dict[str, Any] | None = None,
+        *,
+        namespace: str | None = None,
+    ) -> None:
+        if not self.active:
+            raise RuntimeError("The logger is not active")
+        namespace = self.resolve_namespace(namespace)
+
+        @functools.lru_cache(maxsize=None)
+        def mesh_future() -> LogMesh:
+            with ContextTimer() as timer:
+                # Get the raw values
+                vertices_val = vertices() if callable(vertices) else vertices
+                colors_val = colors() if callable(colors) else colors
+                faces_val = faces() if callable(faces) else faces
+
+                # Convert to numpy arrays with proper type handling
+                vertices_np = as_numpy(vertices_val)
+                colors_np = as_numpy_opt(colors_val)
+                faces_np = as_numpy_opt(faces_val)
+
+                # Checks vertices shape.
+                if vertices_np.ndim == 2:
+                    vertices_np = vertices_np[None]
+                if vertices_np.shape[-1] != 3 or vertices_np.ndim != 3:
+                    raise ValueError("Vertices must have shape (N, 3) or (B, N, 3)")
+
+                # Checks colors shape.
+                if colors_np is not None:
+                    if colors_np.ndim == 2:
+                        colors_np = colors_np[None]
+                    if colors_np.shape[-1] != 3 or colors_np.ndim != 3:
+                        raise ValueError("Colors must have shape (N, 3) or (B, N, 3)")
+
+                # Checks faces shape.
+                if faces_np is not None:
+                    if faces_np.ndim == 2:
+                        faces_np = faces_np[None]
+                    if faces_np.shape[-1] != 3 or faces_np.ndim != 3:
+                        raise ValueError("Faces must have shape (N, 3) or (B, N, 3)")
+
+                # Ensures colors dtype is uint8.
+                if colors_np is not None:
+                    if colors_np.dtype != np.uint8:
+                        colors_np = (colors_np * 255).astype(np.uint8)
+
+                # Ensures faces dtype is int32.
+                if faces_np is not None:
+                    if faces_np.dtype != np.int32:
+                        faces_np = faces_np.astype(np.int32)
+
+            logger.debug("Mesh Key: %s, Time: %s", key, timer.elapsed_time)
+            return LogMesh(
+                vertices=vertices_np,
+                colors=colors_np,
+                faces=faces_np,
+                config_dict=config_dict,
+            )
+
+        self.meshes[namespace][key] = mesh_future
+
     def __enter__(self) -> Self:
         self.active = True
         for logger in self.loggers:
xax/task/loggers/tensorboard.py
CHANGED
@@ -70,6 +70,9 @@ class TensorboardLogger(LoggerImpl):
         self._started = True

     def worker_thread(self) -> None:
+        if os.environ.get("DISABLE_TENSORBOARD", "0") == "1":
+            return
+
         time.sleep(self.wait_seconds)

         port = int(os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT))
@@ -213,6 +216,19 @@ class TensorboardLogger(LoggerImpl):
                     video_value.frames,
                     fps=video_value.fps,
                     global_step=line.state.num_steps,
+                    walltime=walltime,
+                )
+
+        for namespace, meshes in line.meshes.items():
+            for mesh_key, mesh_value in meshes.items():
+                writer.add_mesh(
+                    f"{namespace}/{mesh_key}",
+                    vertices=mesh_value.vertices,
+                    faces=mesh_value.faces,
+                    colors=mesh_value.colors,
+                    config_dict=mesh_value.config_dict,
+                    global_step=line.state.num_steps,
+                    walltime=walltime,
                 )

         for name, contents in self.files.items():
xax/task/mixins/checkpointing.py
CHANGED
@@ -6,9 +6,9 @@ import logging
 import tarfile
 from dataclasses import asdict, dataclass
 from pathlib import Path
-from typing import
+from typing import Generic, Literal, TypeVar, cast, overload

-import
+import equinox as eqx
 import jax
 import optax
 from jaxtyping import PyTree
@@ -63,8 +63,7 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):

     def get_init_ckpt_path(self) -> Path | None:
         if self._exp_dir is not None:
-            ckpt_path = self.get_ckpt_path()
-            if ckpt_path.exists():
+            if (ckpt_path := self.get_ckpt_path()).exists():
                 return ckpt_path
         if self.config.load_from_ckpt_path is not None:
             ckpt_path = Path(self.config.load_from_ckpt_path)
@@ -84,93 +83,129 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
         return False

     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
-
-
+        *,
+        part: Literal["all"],
+        model_template: PyTree,
+        optimizer_template: PyTree,
+        opt_state_template: PyTree,
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...

     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
-
-
+        *,
+        part: Literal["model_state_config"],
+        model_template: PyTree,
+    ) -> tuple[PyTree, State, Config]: ...

     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
+        *,
         part: Literal["model"],
+        model_template: PyTree,
     ) -> PyTree: ...

     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
+        *,
         part: Literal["opt"],
+        optimizer_template: PyTree,
     ) -> optax.GradientTransformation: ...

     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
+        *,
         part: Literal["opt_state"],
+        opt_state_template: PyTree,
     ) -> optax.OptState: ...

     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
+        *,
         part: Literal["state"],
     ) -> State: ...

     @overload
-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
+        *,
         part: Literal["config"],
-    ) -> DictConfig: ...
+    ) -> Config: ...

-    def load_checkpoint(
+    def load_ckpt_with_template(
         self,
         path: Path,
+        *,
         part: CheckpointPart = "all",
+        model_template: PyTree | None = None,
+        optimizer_template: PyTree | None = None,
+        opt_state_template: PyTree | None = None,
     ) -> (
-        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
-        | tuple[PyTree, State, DictConfig]
+        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
+        | tuple[PyTree, State, Config]
         | PyTree
         | optax.GradientTransformation
         | optax.OptState
         | State
-        | DictConfig
+        | Config
     ):
+        """Load a checkpoint.
+
+        Args:
+            path: Path to the checkpoint directory
+            part: Which part of the checkpoint to load
+            model_template: Template model with correct structure but uninitialized weights
+            optimizer_template: Template optimizer with correct structure but uninitialized weights
+            opt_state_template: Template optimizer state with correct structure but uninitialized weights
+
+        Returns:
+            The requested checkpoint components
+        """
         with tarfile.open(path, "r:gz") as tar:

             def get_model() -> PyTree:
+                if model_template is None:
+                    raise ValueError("model_template must be provided to load model weights")
                 if (model := tar.extractfile("model")) is None:
                     raise ValueError(f"Checkpoint does not contain a model file: {path}")
-                return
+                return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)

             def get_opt() -> optax.GradientTransformation:
-                if
-                    raise ValueError(
-
+                if optimizer_template is None:
+                    raise ValueError("optimizer_template must be provided to load optimizer")
+                if (opt := tar.extractfile("optimizer")) is None:
+                    raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
+                return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)

             def get_opt_state() -> optax.OptState:
+                if opt_state_template is None:
+                    raise ValueError("opt_state_template must be provided to load optimizer state")
                 if (opt_state := tar.extractfile("opt_state")) is None:
-                    raise ValueError(f"Checkpoint does not contain an
-                return
+                    raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
+                return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)

             def get_state() -> State:
                 if (state := tar.extractfile("state")) is None:
                     raise ValueError(f"Checkpoint does not contain a state file: {path}")
                 return State(**json.loads(state.read().decode()))

-            def get_config() -> DictConfig:
+            def get_config() -> Config:
                 if (config := tar.extractfile("config")) is None:
                     raise ValueError(f"Checkpoint does not contain a config file: {path}")
-                return cast(DictConfig, OmegaConf.load(config))
+                return self.get_config(cast(DictConfig, OmegaConf.load(config)), use_cli=False)

             match part:
                 case "model":
@@ -192,51 +227,90 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):

     def save_checkpoint(
         self,
-        model: PyTree,
-        optimizer: optax.GradientTransformation,
-        opt_state: optax.OptState,
-        state: State,
+        model: PyTree | None = None,
+        optimizer: optax.GradientTransformation | None = None,
+        opt_state: optax.OptState | None = None,
+        aux_data: PyTree | None = None,
+        state: State | None = None,
     ) -> Path:
+        """Save a checkpoint.
+
+        Args:
+            model: The model to save
+            state: The current training state
+            optimizer: The optimizer to save
+            aux_data: Additional data to save
+            opt_state: The optimizer state to save
+
+        Returns:
+            Path to the saved checkpoint
+        """
         ckpt_path = self.get_ckpt_path(state)

         if not is_master():
             return ckpt_path

         # Gets the path to the last checkpoint
         logger.info("Saving checkpoint to %s", ckpt_path)
         last_ckpt_path = self.get_ckpt_path()
         ckpt_path.parent.mkdir(exist_ok=True, parents=True)

         # Potentially removes the last checkpoint
         if last_ckpt_path.exists() and self.config.only_save_most_recent:
             if (base_ckpt := last_ckpt_path.resolve()).is_file():
                 base_ckpt.unlink()

-        #
+        # Save the checkpoint components
         with tarfile.open(ckpt_path, "w:gz") as tar:

-            def add_file(name: str,
+            def add_file(name: str, buf: io.BytesIO) -> None:
+                tarinfo = tarfile.TarInfo(name)
+                tarinfo.size = buf.tell()
+                buf.seek(0)
+                tar.addfile(tarinfo, buf)
+
+            # Save model using Equinox
+            if model is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, model)
+                    add_file("model", buf)
+
+            # Save optimizer using Equinox
+            if optimizer is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, optimizer)
+                    add_file("optimizer", buf)
+
+            # Save optimizer state using Equinox
+            if opt_state is not None:
                 with io.BytesIO() as buf:
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    eqx.tree_serialise_leaves(buf, opt_state)
+                    add_file("opt_state", buf)
+
+            # Save aux data using Equinox.
+            if aux_data is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, aux_data)
+                    add_file("aux_data", buf)
+
+            # Save state and config as JSON
+            def add_file_bytes(name: str, data: bytes) -> None:  # noqa: ANN401
+                info = tarfile.TarInfo(name=name)
+                info.size = len(data)
+                tar.addfile(info, io.BytesIO(data))
+
+            if state is not None:
+                add_file_bytes("state", json.dumps(asdict(state), indent=2).encode())
+            add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
+
+        # Updates the symlink to the new checkpoint
         last_ckpt_path.unlink(missing_ok=True)
         try:
             last_ckpt_path.symlink_to(ckpt_path.relative_to(last_ckpt_path.parent))
         except FileExistsError:
             logger.exception("Exception while trying to update %s", ckpt_path)

         # Calls the base callback
         self.on_after_checkpoint_save(ckpt_path, state)

         return ckpt_path
xax/task/mixins/data_loader.py
CHANGED
@@ -9,6 +9,7 @@ import jax
 from dpshdl.dataloader import CollatedDataloaderItem, Dataloader
 from dpshdl.dataset import Dataset, ErrorHandlingDataset
 from dpshdl.prefetcher import Prefetcher
+from jaxtyping import PRNGKeyArray
 from omegaconf import II, MISSING

 from xax.core.conf import field, is_missing
@@ -103,7 +104,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
             "or `get_data_iterator` to return an iterator for the given dataset."
         )

-    def get_data_iterator(self, phase: Phase) -> Iterator:
+    def get_data_iterator(self, phase: Phase, key: PRNGKeyArray) -> Iterator:
         raise NotImplementedError(
             "You must implement either the `get_dataset` method to return the dataset for the given phase, "
             "or `get_data_iterator` to return an iterator for the given dataset."
xax/task/mixins/train.py
CHANGED
@@ -11,7 +11,8 @@ import textwrap
 import time
 import traceback
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, is_dataclass
+from dataclasses import asdict, dataclass, is_dataclass
+from pathlib import Path
 from threading import Thread
 from typing import (
     Any,
@@ -33,14 +34,13 @@ import jax.numpy as jnp
 import numpy as np
 import optax
 from jaxtyping import Array, PRNGKeyArray, PyTree
-from omegaconf import DictConfig

 from xax.core.conf import field
 from xax.core.state import Phase, State
 from xax.nn.functions import set_random_seed
 from xax.nn.parallel import is_master
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
-from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
+from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
@@ -50,11 +50,12 @@ from xax.utils.experiments import (
     TrainingFinishedError,
     diff_configs,
     get_diff_string,
+    get_info_json,
     get_state_file_string,
     get_training_code,
 )
 from xax.utils.jax import jit as xax_jit
-from xax.utils.logging import LOG_STATUS
+from xax.utils.logging import LOG_PING, LOG_STATUS
 from xax.utils.text import highlight_exception_message, show_info
 from xax.utils.types.frozen_dict import FrozenDict

@@ -340,20 +341,19 @@ class TrainMixin(

         if init_ckpt_path is not None:
             logger.info("Loading checkpoint from %s", init_ckpt_path)
-
-
-
-
-                logger.warning("Loaded config differs from current config:\n%s", config_diff)
-                return model, optimizer, opt_state, state
+            model, state, config = self.load_ckpt(init_ckpt_path, part="model_state_config")
+            config_diff = get_diff_string(diff_configs(asdict(config), asdict(self.config)))
+            if config_diff:
+                logger.warning("Loaded config differs from current config:\n%s", config_diff)

-
-            model, state, config = self.load_checkpoint(init_ckpt_path, "model_state_config")
-            config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
-            if config_diff:
-                logger.warning("Loaded config differs from current config:\n%s", config_diff)
+            if not load_optimizer:
                 return model, state

+            optimizer = self.load_ckpt(init_ckpt_path, part="opt")
+            opt_state = self.load_ckpt(init_ckpt_path, part="opt_state", model=model, optimizer=optimizer)
+            return model, optimizer, opt_state, state
+
+        logger.info("Starting a new training run")
         model = self.get_model(key)
         state = State.init_state()

@@ -365,6 +365,131 @@ class TrainMixin(

         return model, optimizer, opt_state, state

+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["all"],
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["model_state_config"],
+    ) -> tuple[PyTree, State, Config]: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["model"],
+    ) -> PyTree: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["opt"],
+    ) -> optax.GradientTransformation: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["opt_state"],
+        model: PyTree | None = None,
+        optimizer: optax.GradientTransformation | None = None,
+    ) -> optax.OptState: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["state"],
+    ) -> State: ...
+
+    @overload
+    def load_ckpt(
+        self,
+        path: Path,
+        *,
+        part: Literal["config"],
+    ) -> Config: ...
+
+    def load_ckpt(
+        self,
+        path: str | Path,
+        *,
+        part: CheckpointPart = "all",
+        model: PyTree | None = None,
+        optimizer: optax.GradientTransformation | None = None,
+    ) -> (
+        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
+        | tuple[PyTree, State, Config]
+        | PyTree
+        | optax.GradientTransformation
+        | optax.OptState
+        | State
+        | Config
+    ):
+        path = Path(path)
+
+        # This key isn't used for anything, it's just a required argument.
+        key = jax.random.PRNGKey(0)
+
+        match part:
+            case "model_state_config":
+                model_spec = eqx.filter_eval_shape(self.get_model, key)
+                return self.load_ckpt_with_template(path, part="model_state_config", model_template=model_spec)
+
+            case "model":
+                model_spec = eqx.filter_eval_shape(self.get_model, key)
+                return self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+
+            case "config":
+                return self.load_ckpt_with_template(path, part="config")
+
+            case "opt":
+                optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
+                return self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+
+            case "opt_state":
+                if model is None:
+                    model_spec = eqx.filter_eval_shape(self.get_model, key)
+                    model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+                if optimizer is None:
+                    optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
+                    optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
+                return self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
+
+            case "state":
+                return self.load_ckpt_with_template(path, part="state")
+
+            case "config":
+                return self.load_ckpt_with_template(path, part="config")
+
+            case "all":
+                model_spec = eqx.filter_eval_shape(self.get_model, key)
+                model = self.load_ckpt_with_template(path, part="model", model_template=model_spec)
+                optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
+                optimizer = self.load_ckpt_with_template(path, part="opt", optimizer_template=optimizer_spec)
+                opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
+                opt_state = self.load_ckpt_with_template(path, part="opt_state", opt_state_template=opt_state_spec)
+                state = self.load_ckpt_with_template(path, part="state")
+                config = self.load_ckpt_with_template(path, part="config")
+                return model, optimizer, opt_state, state, config
+
+            case _:
+                raise ValueError(f"Unknown checkpoint part: {part}")
+
     def get_output(self, model: PyTree, batch: Batch, state: State) -> Output:
         """Gets the output from the model.

@@ -519,8 +644,7 @@ class TrainMixin(
             self._last_printed_remaining_time = state.elapsed_time_s
             remaining_seconds = remaining_percent * state.elapsed_time_s / (1 - remaining_percent)
             termination_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time() + remaining_seconds))
-
-            jax.debug.print("Estimated finish time: {}", termination_time)
+            logger.log(LOG_PING, "Estimated finish time: %s", termination_time)

     def get_remaining_percent(self, state: State) -> float | None:
         if self.config.max_steps is None:
@@ -554,6 +678,7 @@ class TrainMixin(
         self.logger.log_file("state.txt", get_state_file_string(self))
         self.logger.log_file("training_code.py", get_training_code(self))
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
+        self.logger.log_file("info.json", get_info_json())

     def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
         return eqx.is_inexact_array(item)
@@ -627,16 +752,16 @@ class TrainMixin(

             if self.should_checkpoint(state):
                 model = eqx.combine(model_arr, model_static)
-                self.save_checkpoint(model, optimizer, opt_state, state)
+                self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)

         # After finishing training, save the final checkpoint.
         model = eqx.combine(model_arr, model_static)
-        self.save_checkpoint(model, optimizer, opt_state, state)
+        self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)

     @contextlib.contextmanager
-    def get_train_iterator(self) -> Generator[Iterator[Batch], None, None]:
+    def get_train_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
         try:
-            train_iterator: Iterator[Batch] = self.get_data_iterator("train")
+            train_iterator: Iterator[Batch] = self.get_data_iterator("train", key=key)
             yield train_iterator
             return
         except NotImplementedError:
@@ -653,9 +778,9 @@ class TrainMixin(
             logger.info("Closing train prefetcher")

     @contextlib.contextmanager
-    def get_valid_iterator(self) -> Generator[Iterator[Batch], None, None]:
+    def get_valid_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
         try:
-            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid")
+            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid", key=key)
             yield valid_iterator
             return
         except NotImplementedError:
@@ -699,12 +824,13 @@ class TrainMixin(
             state = self.on_training_start(state)

             def on_exit() -> None:
-                self.save_checkpoint(model, optimizer, opt_state, state)
+                self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)

             # Handle user-defined interrupts during the training loop.
             self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)

-
+            key, tkey, vkey = jax.random.split(key, 3)
+            with self.get_train_iterator(tkey) as train_pf, self.get_valid_iterator(vkey) as valid_pf:
                 try:
                     self.train_loop(
                         model=model,
@@ -721,7 +847,7 @@ class TrainMixin(
                         f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
                         important=True,
                     )
-                    self.save_checkpoint(model, optimizer, opt_state, state)
+                    self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)

                 except (KeyboardInterrupt, bdb.BdbQuit):
                     if is_master():
@@ -731,7 +857,7 @@ class TrainMixin(
                     exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
                     sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
                     sys.stdout.flush()
-                    self.save_checkpoint(model, optimizer, opt_state, state)
+                    self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)

                 finally:
                     state = self.on_training_end(state)
xax/utils/experiments.py
CHANGED
@@ -7,6 +7,7 @@ import functools
 import hashlib
 import inspect
 import itertools
+import json
 import logging
 import math
 import os
@@ -24,7 +25,7 @@ import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
 from types import TracebackType
-from typing import Any, Iterator, Self, TypeVar, cast
+from typing import Any, Iterator, Mapping, Self, Sequence, TypeVar, cast
 from urllib.parse import urlparse

 import git
@@ -116,9 +117,7 @@ class StateTimer:

     def log_dict(self) -> dict[str, int | float | tuple[int | float, bool]]:
         return {
-            "steps": (self.step_timer.steps, True),
             "steps/second": self.step_timer.steps_per_second,
-            "samples": (self.sample_timer.steps, True),
             "samples/second": (self.sample_timer.steps_per_second, True),
             "dt": self.iter_timer.iter_seconds,
         }
@@ -204,8 +203,8 @@ class MinGradScaleError(TrainingFinishedError):


 def diff_configs(
-    first:
-    second:
+    first: Mapping | Sequence,
+    second: Mapping | Sequence,
     prefix: str | None = None,
 ) -> tuple[list[str], list[str]]:
     """Returns the difference between two configs.
@@ -232,7 +231,7 @@ def diff_configs(

     any_config = (ListConfig, DictConfig)

-    if isinstance(first,
+    if isinstance(first, Mapping) and isinstance(second, Mapping):
         first_keys, second_keys = cast(set[str], set(first.keys())), cast(set[str], set(second.keys()))

         # Gets the new keys in each config.
@@ -242,11 +241,12 @@ def diff_configs(
         # Gets the new sub-keys in each config.
         for key in first_keys.intersection(second_keys):
             sub_prefix = key if prefix is None else f"{prefix}.{key}"
-            if
-                if
-
-
-
+            if isinstance(first, DictConfig) and isinstance(second, DictConfig):
+                if OmegaConf.is_missing(first, key) or OmegaConf.is_missing(second, key):
+                    if not OmegaConf.is_missing(first, key):
+                        new_first += [get_diff_string(sub_prefix, first[key])]
+                    if not OmegaConf.is_missing(second, key):
+                        new_second += [get_diff_string(sub_prefix, second[key])]
             elif isinstance(first[key], any_config) and isinstance(second[key], any_config):
                 sub_new_first, sub_new_second = diff_configs(first[key], second[key], prefix=sub_prefix)
                 new_first, new_second = new_first + sub_new_first, new_second + sub_new_second
@@ -255,7 +255,7 @@ def diff_configs(
                 new_first += [get_diff_string(sub_prefix, first_val)]
                 new_second += [get_diff_string(sub_prefix, second_val)]

-    elif isinstance(first,
+    elif isinstance(first, Sequence) and isinstance(second, Sequence):
         if len(first) > len(second):
             for i in range(len(second), len(first)):
                 new_first += [get_diff_string(prefix, first[i])]
@@ -470,16 +470,33 @@ def get_command_line_string() -> str:
     return " ".join(sys.argv)


+def get_environment_variables() -> str:
+    return "\n".join([f"{key}={value}" for key, value in sorted(os.environ.items())])
+
+
 def get_state_file_string(obj: object) -> str:
     return "\n\n".join(
         [
             f"=== Command Line ===\n\n{get_command_line_string()}",
             f"=== Git State ===\n\n{get_git_state(obj)}",
             f"=== Packages ===\n\n{get_packages_with_versions()}",
+            f"=== Environment Variables ===\n\n{get_environment_variables()}",
         ]
     )


+def get_info_json() -> str:
+    return json.dumps(
+        {
+            "process_id": os.getpid(),
+            "job": {
+                "start_time": datetime.datetime.now().isoformat(),
+            },
+        },
+        indent=2,
+    )
+
+
 def get_training_code(obj: object) -> str:
     """Gets the text from the file containing the provided object.
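
For reference, the new info.json artifact logged from train.py is tiny; a sketch of reading it back (the concrete values depend on the process and start time):

    import json

    from xax.utils.experiments import get_info_json

    info = json.loads(get_info_json())
    print(info["process_id"])         # os.getpid() of the training process
    print(info["job"]["start_time"])  # ISO-8601 timestamp captured at call time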
xax/utils/tensorboard.py
CHANGED
@@ -2,11 +2,12 @@

 import functools
 import io
+import json
 import os
 import tempfile
 import time
 from pathlib import Path
-from typing import Literal, TypedDict
+from typing import Any, Literal, TypedDict

 import numpy as np
 import PIL.Image
@@ -14,9 +15,15 @@ from PIL.Image import Image as PILImage
 from tensorboard.compat.proto.config_pb2 import RunMetadata
 from tensorboard.compat.proto.event_pb2 import Event, TaggedRunMetadata
 from tensorboard.compat.proto.graph_pb2 import GraphDef
-from tensorboard.compat.proto.summary_pb2 import
+from tensorboard.compat.proto.summary_pb2 import (
+    HistogramProto,
+    Summary,
+    SummaryMetadata,
+)
 from tensorboard.compat.proto.tensor_pb2 import TensorProto
 from tensorboard.compat.proto.tensor_shape_pb2 import TensorShapeProto
+from tensorboard.plugins.mesh import metadata as mesh_metadata
+from tensorboard.plugins.mesh.plugin_data_pb2 import MeshPluginData
 from tensorboard.plugins.text.plugin_data_pb2 import TextPluginData
 from tensorboard.summary.writer.event_file_writer import EventFileWriter

@@ -84,6 +91,68 @@ def make_histogram(values: np.ndarray, bins: str | np.ndarray, max_bins: int | None
     )


+def _get_json_config(config_dict: dict[str, Any] | None) -> str:
+    json_config = "{}"
+    if config_dict is not None:
+        json_config = json.dumps(config_dict, sort_keys=True)
+    return json_config
+
+
+def make_mesh_summary(
+    tag: str,
+    vertices: np.ndarray,
+    colors: np.ndarray | None,
+    faces: np.ndarray | None,
+    config_dict: dict[str, Any] | None,
+    display_name: str | None = None,
+    description: str | None = None,
+) -> Summary:
+    json_config = _get_json_config(config_dict)
+
+    summaries = []
+    tensors = [
+        (vertices, MeshPluginData.VERTEX),
+        (faces, MeshPluginData.FACE),
+        (colors, MeshPluginData.COLOR),
+    ]
+    # Filter out None tensors and explicitly type the list
+    valid_tensors = [(t, content_type) for t, content_type in tensors if t is not None]
+    components = mesh_metadata.get_components_bitmask([content_type for (_, content_type) in valid_tensors])
+
+    for tensor, content_type in valid_tensors:  # Now we know tensor is not None
+        tensor_metadata = mesh_metadata.create_summary_metadata(
+            tag,
+            display_name,
+            content_type,
+            components,
+            tensor.shape,  # Safe now since tensor is not None
+            description,
+            json_config=json_config,
+        )
+
+        tensor_proto = TensorProto(
+            dtype="DT_FLOAT",
+            float_val=tensor.reshape(-1).tolist(),  # Safe now since tensor is not None
+            tensor_shape=TensorShapeProto(
+                dim=[
+                    TensorShapeProto.Dim(size=tensor.shape[0]),  # Safe now since tensor is not None
+                    TensorShapeProto.Dim(size=tensor.shape[1]),
+                    TensorShapeProto.Dim(size=tensor.shape[2]),
+                ]
+            ),
+        )
+
+        tensor_summary = Summary.Value(
+            tag=mesh_metadata.get_instance_name(tag, content_type),
+            tensor=tensor_proto,
+            metadata=tensor_metadata,
+        )
+
+        summaries.append(tensor_summary)
+
+    return Summary(value=summaries)
+
+
 class TensorboardProtobufWriter:
     def __init__(
         self,
@@ -454,6 +523,9 @@ class TensorboardWriter:
         weighted_sum = float((bin_centers * bucket_counts).sum())
         weighted_sum_squares = float((bin_centers**2 * bucket_counts).sum())

+        # Convert bin edges to list of floats explicitly
+        bucket_limits: list[float | np.ndarray] = [float(x) for x in bin_edges[1:]]
+
         self.add_histogram_raw(
             tag=tag,
             min=float(bin_edges[0]),
@@ -461,12 +533,28 @@ class TensorboardWriter:
             num=int(total_counts),
             sum=weighted_sum,
             sum_squares=weighted_sum_squares,
-            bucket_limits=
+            bucket_limits=bucket_limits,  # Now properly typed
             bucket_counts=bucket_counts.tolist(),
             global_step=global_step,
             walltime=walltime,
         )

+    def add_mesh(
+        self,
+        tag: str,
+        vertices: np.ndarray,
+        colors: np.ndarray | None,
+        faces: np.ndarray | None,
+        config_dict: dict[str, Any] | None,
+        global_step: int | None = None,
+        walltime: float | None = None,
+    ) -> None:
+        self.pb_writer.add_summary(
+            make_mesh_summary(tag, vertices, colors, faces, config_dict),
+            global_step=global_step,
+            walltime=walltime,
+        )
+

 class TensorboardWriterKwargs(TypedDict):
     max_queue_size: int
{xax-0.1.16.dist-info → xax-0.2.1.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.1.16
+Version: 0.2.1
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
@@ -8,14 +8,14 @@ Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: attrs
+Requires-Dist: chex
+Requires-Dist: dpshdl
+Requires-Dist: equinox
+Requires-Dist: importlib-resources
 Requires-Dist: jax
 Requires-Dist: jaxtyping
-Requires-Dist: equinox
 Requires-Dist: optax
-Requires-Dist: dpshdl
-Requires-Dist: chex
-Requires-Dist: importlib-resources
-Requires-Dist: cloudpickle
+Requires-Dist: orbax-checkpoint
 Requires-Dist: pillow
 Requires-Dist: omegaconf
 Requires-Dist: gitpython
{xax-0.1.16.dist-info → xax-0.2.1.dist-info}/RECORD
CHANGED
@@ -1,23 +1,23 @@
-xax/__init__.py,sha256=
+xax/__init__.py,sha256=kd-88OQGnuHb91PXwroAfLb0bMfbe37fXqpECRrjhoU,14182
 xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
-xax/requirements.txt,sha256=
+xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
 xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/core/conf.py,sha256=Wuo5WLRWuRTgb8eaihvnG_NZskTu0-P3JkIcl_hKINM,5124
-xax/core/state.py,sha256=
+xax/core/state.py,sha256=XejW1tGINYFFcNrscK8eZQsq02J7_RXa461QpmyWuLk,3337
 xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
 xax/nn/equinox.py,sha256=5fdOKRXqAVZPsV-aEez3i1wamr_oBYnG74GP1jEthjM,4843
 xax/nn/export.py,sha256=7Yemw3T33QGEP8RkmTkpu6tRVOhut2RUJmttNFfCgFw,5537
 xax/nn/functions.py,sha256=CI_OmspaQwN9nl4hwefIU3_I7m6gBZwJ9aGK1JGUgr0,2713
-xax/nn/geom.py,sha256=
+xax/nn/geom.py,sha256=rImNlkHWeoNcY7f84nknizJ6uzsrMhbAtKeb2xAWxNY,6215
 xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
 xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
 xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
 xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
 xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-xax/task/base.py,sha256=
-xax/task/logger.py,sha256=
+xax/task/base.py,sha256=OnXi2hiKPGwt6ng1dutnoQSiw7lEiWFlC_vx99_JsbQ,7694
+xax/task/logger.py,sha256=peGtfnvnBKr9l6tx1V6XAsvPs0HP6ubV_aE7IJtOMNk,40868
 xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
 xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
 xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -29,37 +29,37 @@ xax/task/loggers/callback.py,sha256=lyuZX6Bir7xJM07ifdQIl1jlclgkiS82UO9V4y7wgPs,
 xax/task/loggers/json.py,sha256=_tKum6jk_gqVzO-4MqSNXbE-Mmn-yJzkRAT-N1y2zes,4139
 xax/task/loggers/state.py,sha256=6bG-NRsSUzAukYiglCT0oDj8zRMpffH4e1TKWGw1x4k,959
 xax/task/loggers/stdout.py,sha256=oeIgPkj4RyJgBuWaJK9ncLa65iBNJCWXhSF8fx3_54c,6564
-xax/task/loggers/tensorboard.py,sha256=
+xax/task/loggers/tensorboard.py,sha256=KOL9l60tLctX-VAdNwe49H48SAJeGxph3sflJpojA-4,8337
 xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
 xax/task/mixins/artifacts.py,sha256=2ezmZGzPGe3nhsd9KRkeHWWXdbT9m7drzimIfw6v1XY,2892
-xax/task/mixins/checkpointing.py,sha256=
+xax/task/mixins/checkpointing.py,sha256=2nJgqFcV-D8W-4j8TR3PvVh1g5hQUOo-_quKO-XlE4U,11398
 xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
 xax/task/mixins/cpu_stats.py,sha256=vAjEc3HpPnl56m7vshYX0dXAHJrB98DzVdsYSRqQllc,9371
-xax/task/mixins/data_loader.py,sha256=
+xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
 xax/task/mixins/gpu_stats.py,sha256=4HU6teEDlqMitLbSx7fbyL4qBJ0PgGy0Ly_Pzife8yo,8795
 xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
 xax/task/mixins/process.py,sha256=d1opVgvc6bOFXb7R58b07F4P5lbSZIzYaajtE0eBbpw,1477
 xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
 xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
-xax/task/mixins/train.py,sha256=
+xax/task/mixins/train.py,sha256=v9oi9tNsNBYo-Ne_98nCG9qHX6sxvymHjsRDnL6GL-U,30871
 xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
-xax/utils/experiments.py,sha256=
+xax/utils/experiments.py,sha256=Hzl46_9IH5_9cKzxit-FyVUWBH-_lBs00ZciuIdnWO8,29811
 xax/utils/jax.py,sha256=tC0NNelbrSTzwNGluiwLGKtoHhVpgdzrv-xherB3VtY,4752
 xax/utils/jaxpr.py,sha256=S80nyEkv188RInzq3kCAdkQCU-bf6s0oPTrCE_LjkRs,2298
 xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
 xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
 xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
 xax/utils/pytree.py,sha256=VFWhT0MQ99KjQyEYM6NFbqYq4_hOZwB23uhowMB4U34,8754
-xax/utils/tensorboard.py,sha256=
+xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
 xax/utils/text.py,sha256=zo1sAoZe59GkpcpaHBVOQ0OekSMGXvOAyNa3lOJozCY,10628
 xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
 xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 xax/utils/types/frozen_dict.py,sha256=ZCMGfSfr2_b2qZbq9ywPD0zej5tpVSId2JftXpwfB5k,4686
 xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
-xax-0.1.16.dist-info/licenses/LICENSE,sha256=
-xax-0.1.16.dist-info/METADATA,sha256=
-xax-0.1.16.dist-info/WHEEL,sha256=
-xax-0.1.16.dist-info/top_level.txt,sha256=
-xax-0.1.16.dist-info/RECORD,,
+xax-0.2.1.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
+xax-0.2.1.dist-info/METADATA,sha256=2pOZLKMIcLoQTM-tRqRvVkF57PZyMoALM87UI5B4dtk,1882
+xax-0.2.1.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
+xax-0.2.1.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
+xax-0.2.1.dist-info/RECORD,,
{xax-0.1.16.dist-info → xax-0.2.1.dist-info}/WHEEL
File without changes
{xax-0.1.16.dist-info → xax-0.2.1.dist-info}/licenses/LICENSE
File without changes
{xax-0.1.16.dist-info → xax-0.2.1.dist-info}/top_level.txt
File without changes