xax 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +52 -2
- xax/core/conf.py +1 -1
- xax/nn/equinox.py +6 -3
- xax/nn/functions.py +8 -5
- xax/nn/geom.py +49 -0
- xax/task/base.py +2 -2
- xax/task/logger.py +11 -6
- xax/task/loggers/callback.py +6 -0
- xax/task/loggers/json.py +14 -2
- xax/task/loggers/state.py +26 -1
- xax/task/loggers/stdout.py +4 -2
- xax/task/loggers/tensorboard.py +19 -1
- xax/task/mixins/artifacts.py +11 -8
- xax/task/mixins/checkpointing.py +108 -143
- xax/task/mixins/train.py +21 -17
- xax/utils/experiments.py +2 -1
- xax/utils/jaxpr.py +5 -5
- xax/utils/pytree.py +9 -2
- xax/utils/text.py +2 -2
- xax/utils/types/frozen_dict.py +2 -2
- {xax-0.2.6.dist-info → xax-0.2.8.dist-info}/METADATA +1 -1
- {xax-0.2.6.dist-info → xax-0.2.8.dist-info}/RECORD +25 -25
- {xax-0.2.6.dist-info → xax-0.2.8.dist-info}/WHEEL +1 -1
- {xax-0.2.6.dist-info → xax-0.2.8.dist-info}/licenses/LICENSE +0 -0
- {xax-0.2.6.dist-info → xax-0.2.8.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.2.6"
|
15
|
+
__version__ = "0.2.8"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -43,9 +43,12 @@ __all__ = [
|
|
43
43
|
"cubic_bezier_interpolation",
|
44
44
|
"euler_to_quat",
|
45
45
|
"get_projected_gravity_vector_from_quat",
|
46
|
+
"normalize",
|
46
47
|
"quat_to_euler",
|
47
48
|
"quat_to_rotmat",
|
48
49
|
"rotate_vector_by_quat",
|
50
|
+
"rotation6d_to_rotation_matrix",
|
51
|
+
"rotation_matrix_to_rotation6d",
|
49
52
|
"cross_entropy",
|
50
53
|
"cast_norm_type",
|
51
54
|
"get_norm",
|
@@ -57,8 +60,18 @@ __all__ = [
|
|
57
60
|
"BaseLauncher",
|
58
61
|
"CliLauncher",
|
59
62
|
"SingleProcessLauncher",
|
63
|
+
"LogDistribution",
|
64
|
+
"LogError",
|
65
|
+
"LogErrorSummary",
|
66
|
+
"LogGraph",
|
67
|
+
"LogHistogram",
|
60
68
|
"LogImage",
|
61
69
|
"LogLine",
|
70
|
+
"LogMesh",
|
71
|
+
"LogPing",
|
72
|
+
"LogScalar",
|
73
|
+
"LogStatus",
|
74
|
+
"LogVideo",
|
62
75
|
"Logger",
|
63
76
|
"LoggerImpl",
|
64
77
|
"CallbackLogger",
|
@@ -66,6 +79,7 @@ __all__ = [
|
|
66
79
|
"StateLogger",
|
67
80
|
"StdoutLogger",
|
68
81
|
"TensorboardLogger",
|
82
|
+
"load_ckpt",
|
69
83
|
"CPUStatsOptions",
|
70
84
|
"DataloaderConfig",
|
71
85
|
"GPUStatsOptions",
|
@@ -115,6 +129,7 @@ __all__ = [
|
|
115
129
|
"compute_nan_ratio",
|
116
130
|
"flatten_array",
|
117
131
|
"flatten_pytree",
|
132
|
+
"get_pytree_param_count",
|
118
133
|
"pytree_has_nans",
|
119
134
|
"reshuffle_pytree",
|
120
135
|
"reshuffle_pytree_along_dims",
|
@@ -207,9 +222,12 @@ NAME_MAP: dict[str, str] = {
|
|
207
222
|
"cubic_bezier_interpolation": "nn.geom",
|
208
223
|
"euler_to_quat": "nn.geom",
|
209
224
|
"get_projected_gravity_vector_from_quat": "nn.geom",
|
225
|
+
"normalize": "nn.geom",
|
210
226
|
"quat_to_euler": "nn.geom",
|
211
227
|
"quat_to_rotmat": "nn.geom",
|
212
228
|
"rotate_vector_by_quat": "nn.geom",
|
229
|
+
"rotation6d_to_rotation_matrix": "nn.geom",
|
230
|
+
"rotation_matrix_to_rotation6d": "nn.geom",
|
213
231
|
"cross_entropy": "nn.losses",
|
214
232
|
"cast_norm_type": "nn.norm",
|
215
233
|
"get_norm": "nn.norm",
|
@@ -221,8 +239,18 @@ NAME_MAP: dict[str, str] = {
|
|
221
239
|
"BaseLauncher": "task.launchers.base",
|
222
240
|
"CliLauncher": "task.launchers.cli",
|
223
241
|
"SingleProcessLauncher": "task.launchers.single_process",
|
242
|
+
"LogDistribution": "task.logger",
|
243
|
+
"LogError": "task.logger",
|
244
|
+
"LogErrorSummary": "task.logger",
|
245
|
+
"LogGraph": "task.logger",
|
246
|
+
"LogHistogram": "task.logger",
|
224
247
|
"LogImage": "task.logger",
|
225
248
|
"LogLine": "task.logger",
|
249
|
+
"LogMesh": "task.logger",
|
250
|
+
"LogPing": "task.logger",
|
251
|
+
"LogScalar": "task.logger",
|
252
|
+
"LogStatus": "task.logger",
|
253
|
+
"LogVideo": "task.logger",
|
226
254
|
"Logger": "task.logger",
|
227
255
|
"LoggerImpl": "task.logger",
|
228
256
|
"CallbackLogger": "task.loggers.callback",
|
@@ -230,6 +258,7 @@ NAME_MAP: dict[str, str] = {
|
|
230
258
|
"StateLogger": "task.loggers.state",
|
231
259
|
"StdoutLogger": "task.loggers.stdout",
|
232
260
|
"TensorboardLogger": "task.loggers.tensorboard",
|
261
|
+
"load_ckpt": "task.mixins.checkpointing",
|
233
262
|
"CPUStatsOptions": "task.mixins.cpu_stats",
|
234
263
|
"DataloaderConfig": "task.mixins.data_loader",
|
235
264
|
"GPUStatsOptions": "task.mixins.gpu_stats",
|
@@ -279,6 +308,7 @@ NAME_MAP: dict[str, str] = {
|
|
279
308
|
"compute_nan_ratio": "utils.pytree",
|
280
309
|
"flatten_array": "utils.pytree",
|
281
310
|
"flatten_pytree": "utils.pytree",
|
311
|
+
"get_pytree_param_count": "utils.pytree",
|
282
312
|
"pytree_has_nans": "utils.pytree",
|
283
313
|
"reshuffle_pytree": "utils.pytree",
|
284
314
|
"reshuffle_pytree_along_dims": "utils.pytree",
|
@@ -372,9 +402,12 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
372
402
|
cubic_bezier_interpolation,
|
373
403
|
euler_to_quat,
|
374
404
|
get_projected_gravity_vector_from_quat,
|
405
|
+
normalize,
|
375
406
|
quat_to_euler,
|
376
407
|
quat_to_rotmat,
|
377
408
|
rotate_vector_by_quat,
|
409
|
+
rotation6d_to_rotation_matrix,
|
410
|
+
rotation_matrix_to_rotation6d,
|
378
411
|
)
|
379
412
|
from xax.nn.losses import cross_entropy
|
380
413
|
from xax.nn.norm import NormType, cast_norm_type, get_norm
|
@@ -384,12 +417,28 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
384
417
|
from xax.task.launchers.base import BaseLauncher
|
385
418
|
from xax.task.launchers.cli import CliLauncher
|
386
419
|
from xax.task.launchers.single_process import SingleProcessLauncher
|
387
|
-
from xax.task.logger import Logger, LoggerImpl, LogImage, LogLine
|
420
|
+
from xax.task.logger import (
|
421
|
+
LogDistribution,
|
422
|
+
LogError,
|
423
|
+
LogErrorSummary,
|
424
|
+
Logger,
|
425
|
+
LoggerImpl,
|
426
|
+
LogGraph,
|
427
|
+
LogHistogram,
|
428
|
+
LogImage,
|
429
|
+
LogLine,
|
430
|
+
LogMesh,
|
431
|
+
LogPing,
|
432
|
+
LogScalar,
|
433
|
+
LogStatus,
|
434
|
+
LogVideo,
|
435
|
+
)
|
388
436
|
from xax.task.loggers.callback import CallbackLogger
|
389
437
|
from xax.task.loggers.json import JsonLogger
|
390
438
|
from xax.task.loggers.state import StateLogger
|
391
439
|
from xax.task.loggers.stdout import StdoutLogger
|
392
440
|
from xax.task.loggers.tensorboard import TensorboardLogger
|
441
|
+
from xax.task.mixins.checkpointing import load_ckpt
|
393
442
|
from xax.task.mixins.cpu_stats import CPUStatsOptions
|
394
443
|
from xax.task.mixins.data_loader import DataloaderConfig
|
395
444
|
from xax.task.mixins.gpu_stats import GPUStatsOptions
|
@@ -439,6 +488,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
439
488
|
compute_nan_ratio,
|
440
489
|
flatten_array,
|
441
490
|
flatten_pytree,
|
491
|
+
get_pytree_param_count,
|
442
492
|
pytree_has_nans,
|
443
493
|
reshuffle_pytree,
|
444
494
|
reshuffle_pytree_along_dims,
|
xax/core/conf.py
CHANGED
@@ -26,7 +26,7 @@ def field(value: FieldType, **kwargs: str) -> FieldType:
|
|
26
26
|
metadata: dict[str, Any] = {}
|
27
27
|
metadata.update(kwargs)
|
28
28
|
|
29
|
-
if hasattr(value, "__call__"):
|
29
|
+
if hasattr(value, "__call__"): # noqa: B004
|
30
30
|
return field_base(default_factory=value, metadata=metadata)
|
31
31
|
if value.__class__.__hash__ is None:
|
32
32
|
return field_base(default_factory=lambda: value, metadata=metadata)
|
xax/nn/equinox.py
CHANGED
@@ -68,8 +68,8 @@ def _infer_activation(activation: ActivationFunction) -> Callable:
|
|
68
68
|
return lambda x: x
|
69
69
|
try:
|
70
70
|
return getattr(jax.nn, activation)
|
71
|
-
except AttributeError:
|
72
|
-
raise ValueError(f"Activation function `{activation}` not found in `jax.nn`")
|
71
|
+
except AttributeError as err:
|
72
|
+
raise ValueError(f"Activation function `{activation}` not found in `jax.nn`") from err
|
73
73
|
|
74
74
|
|
75
75
|
def make_eqx_mlp(hyperparams: MLPHyperParams, *, key: PRNGKeyArray) -> eqx.nn.MLP:
|
@@ -100,7 +100,7 @@ def make_eqx_mlp(hyperparams: MLPHyperParams, *, key: PRNGKeyArray) -> eqx.nn.ML
|
|
100
100
|
def export_eqx_mlp(
|
101
101
|
model: eqx.nn.MLP,
|
102
102
|
output_path: str | Path,
|
103
|
-
dtype: jax.numpy.dtype = eqx._misc.default_floating_dtype(),
|
103
|
+
dtype: jax.numpy.dtype | None = None,
|
104
104
|
) -> None:
|
105
105
|
"""Serialize an Equinox MLP to a .eqx file.
|
106
106
|
|
@@ -109,6 +109,9 @@ def export_eqx_mlp(
|
|
109
109
|
output_path: The path to save the exported model.
|
110
110
|
dtype: The dtype of the model.
|
111
111
|
"""
|
112
|
+
if dtype is None:
|
113
|
+
dtype = eqx._misc.default_floating_dtype()
|
114
|
+
|
112
115
|
activation = model.activation.__name__
|
113
116
|
final_activation = model.final_activation.__name__
|
114
117
|
|
xax/nn/functions.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
# mypy: disable-error-code="override"
|
2
|
-
"""Defines helper
|
2
|
+
"""Defines helper Jax functions."""
|
3
3
|
|
4
4
|
import random
|
5
5
|
from dataclasses import is_dataclass
|
@@ -58,13 +58,16 @@ def recursive_chunk(item: Any, num_chunks: int, dim: int = 0) -> Iterable[Any]:
|
|
58
58
|
yield from np.array_split(item, num_chunks, axis=dim)
|
59
59
|
elif is_dataclass(item):
|
60
60
|
yield from (
|
61
|
-
item.__class__(**{k: i for k, i in zip(item.__dict__, ii)})
|
62
|
-
for ii in zip(*(recursive_chunk(v, num_chunks, dim) for v in item.__dict__.values()))
|
61
|
+
item.__class__(**{k: i for k, i in zip(item.__dict__, ii, strict=True)})
|
62
|
+
for ii in zip(*(recursive_chunk(v, num_chunks, dim) for v in item.__dict__.values()), strict=False)
|
63
63
|
)
|
64
64
|
elif isinstance(item, Mapping):
|
65
|
-
yield from (
|
65
|
+
yield from (
|
66
|
+
dict(zip(item, ii, strict=False))
|
67
|
+
for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item.values()), strict=False)
|
68
|
+
)
|
66
69
|
elif isinstance(item, Sequence):
|
67
|
-
yield from (list(ii) for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item)))
|
70
|
+
yield from (list(ii) for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item), strict=False))
|
68
71
|
else:
|
69
72
|
yield from (item for _ in range(num_chunks))
|
70
73
|
|
xax/nn/geom.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
"""Defines geometry functions."""
|
2
2
|
|
3
|
+
import chex
|
3
4
|
from jax import numpy as jnp
|
4
5
|
from jaxtyping import Array
|
5
6
|
|
@@ -211,3 +212,51 @@ def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
|
|
211
212
|
],
|
212
213
|
axis=-2,
|
213
214
|
)
|
215
|
+
|
216
|
+
|
217
|
+
def normalize(v: jnp.ndarray, axis: int = -1, eps: float = 1e-8) -> jnp.ndarray:
|
218
|
+
norm = jnp.linalg.norm(v, axis=axis, keepdims=True)
|
219
|
+
return v / jnp.clip(norm, a_min=eps)
|
220
|
+
|
221
|
+
|
222
|
+
def rotation6d_to_rotation_matrix(r6d: jnp.ndarray) -> jnp.ndarray:
|
223
|
+
"""Convert 6D rotation representation to rotation matrix.
|
224
|
+
|
225
|
+
From https://arxiv.org/pdf/1812.07035, Appendix B
|
226
|
+
|
227
|
+
Args:
|
228
|
+
r6d: The 6D rotation representation, shape (*, 6).
|
229
|
+
|
230
|
+
Returns:
|
231
|
+
The rotation matrix, shape (*, 3, 3).
|
232
|
+
"""
|
233
|
+
chex.assert_shape(r6d, (..., 6))
|
234
|
+
shape = r6d.shape
|
235
|
+
flat = r6d.reshape(-1, 6)
|
236
|
+
a_1 = flat[:, 0:3]
|
237
|
+
a_2 = flat[:, 3:6]
|
238
|
+
|
239
|
+
b_1 = normalize(a_1, axis=-1)
|
240
|
+
|
241
|
+
# Reordered Gram-Schmidt orthonormalization.
|
242
|
+
b_3 = normalize(jnp.cross(b_1, a_2), axis=-1)
|
243
|
+
b_2 = jnp.cross(b_3, b_1)
|
244
|
+
|
245
|
+
rotation_matrix = jnp.stack([b_1, b_2, b_3], axis=-1)
|
246
|
+
return rotation_matrix.reshape(shape[:-1] + (3, 3))
|
247
|
+
|
248
|
+
|
249
|
+
def rotation_matrix_to_rotation6d(rotation_matrix: jnp.ndarray) -> jnp.ndarray:
|
250
|
+
"""Convert rotation matrix to 6D rotation representation.
|
251
|
+
|
252
|
+
Args:
|
253
|
+
rotation_matrix: The rotation matrix, shape (*, 3, 3).
|
254
|
+
|
255
|
+
Returns:
|
256
|
+
The 6D rotation representation, shape (*, 6).
|
257
|
+
"""
|
258
|
+
chex.assert_shape(rotation_matrix, (..., 3, 3))
|
259
|
+
shape = rotation_matrix.shape
|
260
|
+
# Simply concatenate a1 and a2 from SO(3)
|
261
|
+
r6d = jnp.concatenate([rotation_matrix[..., 0], rotation_matrix[..., 1]], axis=-1)
|
262
|
+
return r6d.reshape(shape[:-2] + (6,))
|
xax/task/base.py
CHANGED
@@ -184,8 +184,8 @@ class BaseTask(Generic[Config]):
|
|
184
184
|
|
185
185
|
# Attempts to load any paths as configs.
|
186
186
|
is_path = [Path(arg).is_file() or (task_path / arg).is_file() for arg in args]
|
187
|
-
paths = [arg for arg, is_path in zip(args, is_path) if is_path]
|
188
|
-
non_paths = [arg for arg, is_path in zip(args, is_path) if not is_path]
|
187
|
+
paths = [arg for arg, is_path in zip(args, is_path, strict=True) if is_path]
|
188
|
+
non_paths = [arg for arg, is_path in zip(args, is_path, strict=True) if not is_path]
|
189
189
|
if paths:
|
190
190
|
cfg = OmegaConf.merge(cfg, *(get_config(path, task_path) for path in paths))
|
191
191
|
cfg = OmegaConf.merge(cfg, OmegaConf.from_cli(non_paths))
|
xax/task/logger.py
CHANGED
@@ -462,11 +462,11 @@ class LoggerImpl(ABC):
|
|
462
462
|
|
463
463
|
self.tickers = {phase: IntervalTicker(log_interval_seconds) for phase in get_args(Phase)}
|
464
464
|
|
465
|
-
|
466
|
-
|
465
|
+
@abstractmethod
|
466
|
+
def start(self) -> None: ...
|
467
467
|
|
468
|
-
|
469
|
-
|
468
|
+
@abstractmethod
|
469
|
+
def stop(self) -> None: ...
|
470
470
|
|
471
471
|
@abstractmethod
|
472
472
|
def write(self, line: LogLine) -> None:
|
@@ -476,6 +476,7 @@ class LoggerImpl(ABC):
|
|
476
476
|
line: The line to write.
|
477
477
|
"""
|
478
478
|
|
479
|
+
@abstractmethod
|
479
480
|
def write_error_summary(self, error_summary: LogErrorSummary) -> None:
|
480
481
|
"""Handles writing an error summary.
|
481
482
|
|
@@ -483,6 +484,7 @@ class LoggerImpl(ABC):
|
|
483
484
|
error_summary: The error summary to write.
|
484
485
|
"""
|
485
486
|
|
487
|
+
@abstractmethod
|
486
488
|
def write_error(self, error: LogError) -> None:
|
487
489
|
"""Handles writing an error line.
|
488
490
|
|
@@ -490,6 +492,7 @@ class LoggerImpl(ABC):
|
|
490
492
|
error: The error information to write.
|
491
493
|
"""
|
492
494
|
|
495
|
+
@abstractmethod
|
493
496
|
def write_status(self, status: LogStatus) -> None:
|
494
497
|
"""Handles writing a status line.
|
495
498
|
|
@@ -497,6 +500,7 @@ class LoggerImpl(ABC):
|
|
497
500
|
status: The status to write.
|
498
501
|
"""
|
499
502
|
|
503
|
+
@abstractmethod
|
500
504
|
def write_ping(self, ping: LogPing) -> None:
|
501
505
|
"""Handles writing a ping line.
|
502
506
|
|
@@ -504,6 +508,7 @@ class LoggerImpl(ABC):
|
|
504
508
|
ping: The ping to write.
|
505
509
|
"""
|
506
510
|
|
511
|
+
@abstractmethod
|
507
512
|
def log_file(self, name: str, contents: str) -> None:
|
508
513
|
"""Logs a large text file.
|
509
514
|
|
@@ -621,7 +626,7 @@ class Logger:
|
|
621
626
|
return
|
622
627
|
line = self.pack(state)
|
623
628
|
self.clear()
|
624
|
-
for lg in (lg for lg, should_log in zip(self.loggers, should_log) if should_log):
|
629
|
+
for lg in (lg for lg, should_log in zip(self.loggers, should_log, strict=False) if should_log):
|
625
630
|
lg.write(line)
|
626
631
|
|
627
632
|
def write_error_summary(self, error_summary: str) -> None:
|
@@ -1045,7 +1050,7 @@ class Logger:
|
|
1045
1050
|
line_spacing=line_spacing,
|
1046
1051
|
centered=centered,
|
1047
1052
|
)
|
1048
|
-
for img, label in zip(images, labels)
|
1053
|
+
for img, label in zip(images, labels, strict=True)
|
1049
1054
|
]
|
1050
1055
|
tiled = tile_images([img.image for img in labeled], sep)
|
1051
1056
|
|
xax/task/loggers/callback.py
CHANGED
@@ -25,6 +25,12 @@ class CallbackLogger(LoggerImpl):
|
|
25
25
|
self.ping_callback = ping_callback
|
26
26
|
self.file_callback = file_callback
|
27
27
|
|
28
|
+
def start(self) -> None:
|
29
|
+
pass
|
30
|
+
|
31
|
+
def stop(self) -> None:
|
32
|
+
pass
|
33
|
+
|
28
34
|
def write(self, line: LogLine) -> None:
|
29
35
|
self.callback(line)
|
30
36
|
|
xax/task/loggers/json.py
CHANGED
@@ -2,13 +2,13 @@
|
|
2
2
|
|
3
3
|
import json
|
4
4
|
import sys
|
5
|
-
from dataclasses import asdict
|
6
5
|
from typing import Any, Literal, Mapping, TextIO
|
7
6
|
|
8
7
|
from jaxtyping import Array
|
9
8
|
|
10
9
|
from xax.task.logger import (
|
11
10
|
LogError,
|
11
|
+
LogErrorSummary,
|
12
12
|
LoggerImpl,
|
13
13
|
LogLine,
|
14
14
|
LogPing,
|
@@ -58,6 +58,12 @@ class JsonLogger(LoggerImpl):
|
|
58
58
|
self.line_sep = line_sep
|
59
59
|
self.remove_unicode_from_namespaces = remove_unicode_from_namespaces
|
60
60
|
|
61
|
+
def start(self) -> None:
|
62
|
+
pass
|
63
|
+
|
64
|
+
def stop(self) -> None:
|
65
|
+
pass
|
66
|
+
|
61
67
|
@property
|
62
68
|
def fp(self) -> TextIO:
|
63
69
|
return self.log_stream
|
@@ -67,7 +73,7 @@ class JsonLogger(LoggerImpl):
|
|
67
73
|
return self.err_log_stream
|
68
74
|
|
69
75
|
def get_json(self, line: LogLine) -> str:
|
70
|
-
data: dict = {"state": asdict(line.state)}
|
76
|
+
data: dict = {"state": line.state.to_dict()}
|
71
77
|
|
72
78
|
def add_logs(log: Mapping[str, Mapping[str, LogScalar | LogString]], data: dict) -> None:
|
73
79
|
for namespace, values in log.items():
|
@@ -88,6 +94,12 @@ class JsonLogger(LoggerImpl):
|
|
88
94
|
if self.flush_immediately:
|
89
95
|
self.fp.flush()
|
90
96
|
|
97
|
+
def write_error_summary(self, error_summary: LogErrorSummary) -> None:
|
98
|
+
pass
|
99
|
+
|
100
|
+
def log_file(self, name: str, contents: str) -> None:
|
101
|
+
pass
|
102
|
+
|
91
103
|
def write_error(self, error: LogError) -> None:
|
92
104
|
self.err_fp.write(error.message)
|
93
105
|
if error.location is not None:
|
xax/task/loggers/state.py
CHANGED
@@ -3,7 +3,14 @@
|
|
3
3
|
from pathlib import Path
|
4
4
|
from typing import Literal
|
5
5
|
|
6
|
-
from xax.task.logger import LoggerImpl, LogLine
|
6
|
+
from xax.task.logger import (
|
7
|
+
LogError,
|
8
|
+
LogErrorSummary,
|
9
|
+
LoggerImpl,
|
10
|
+
LogLine,
|
11
|
+
LogPing,
|
12
|
+
LogStatus,
|
13
|
+
)
|
7
14
|
|
8
15
|
|
9
16
|
class StateLogger(LoggerImpl):
|
@@ -30,3 +37,21 @@ class StateLogger(LoggerImpl):
|
|
30
37
|
|
31
38
|
def write(self, line: LogLine) -> None:
|
32
39
|
pass
|
40
|
+
|
41
|
+
def start(self) -> None:
|
42
|
+
pass
|
43
|
+
|
44
|
+
def stop(self) -> None:
|
45
|
+
pass
|
46
|
+
|
47
|
+
def write_error_summary(self, error_summary: LogErrorSummary) -> None:
|
48
|
+
pass
|
49
|
+
|
50
|
+
def write_error(self, error: LogError) -> None:
|
51
|
+
pass
|
52
|
+
|
53
|
+
def write_status(self, status: LogStatus) -> None:
|
54
|
+
pass
|
55
|
+
|
56
|
+
def write_ping(self, ping: LogPing) -> None:
|
57
|
+
pass
|
xax/task/loggers/stdout.py
CHANGED
@@ -79,11 +79,13 @@ class StdoutLogger(LoggerImpl):
|
|
79
79
|
self.error_summary: tuple[str, datetime.datetime] | None = None
|
80
80
|
|
81
81
|
def start(self) -> None:
|
82
|
-
|
82
|
+
pass
|
83
83
|
|
84
84
|
def stop(self) -> None:
|
85
85
|
self.write_queues()
|
86
|
-
|
86
|
+
|
87
|
+
def log_file(self, name: str, contents: str) -> None:
|
88
|
+
pass
|
87
89
|
|
88
90
|
def write_separator(self) -> None:
|
89
91
|
self.write_fp.write("\033[2J\033[H")
|
xax/task/loggers/tensorboard.py
CHANGED
@@ -12,7 +12,7 @@ from typing import TypeVar
|
|
12
12
|
|
13
13
|
from xax.core.state import Phase
|
14
14
|
from xax.nn.parallel import is_master
|
15
|
-
from xax.task.logger import LoggerImpl, LogLine
|
15
|
+
from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
|
16
16
|
from xax.utils.jax import as_float
|
17
17
|
from xax.utils.logging import LOG_STATUS, port_is_busy
|
18
18
|
from xax.utils.tensorboard import TensorboardWriter, TensorboardWriters
|
@@ -236,3 +236,21 @@ class TensorboardLogger(LoggerImpl):
|
|
236
236
|
for name, contents in self.files.items():
|
237
237
|
writer.add_text(name, contents)
|
238
238
|
self.files.clear()
|
239
|
+
|
240
|
+
def start(self) -> None:
|
241
|
+
pass
|
242
|
+
|
243
|
+
def stop(self) -> None:
|
244
|
+
pass
|
245
|
+
|
246
|
+
def write_error(self, error: LogError) -> None:
|
247
|
+
pass
|
248
|
+
|
249
|
+
def write_error_summary(self, error_summary: LogErrorSummary) -> None:
|
250
|
+
pass
|
251
|
+
|
252
|
+
def write_ping(self, ping: LogPing) -> None:
|
253
|
+
pass
|
254
|
+
|
255
|
+
def write_status(self, status: LogStatus) -> None:
|
256
|
+
pass
|
xax/task/mixins/artifacts.py
CHANGED
@@ -31,11 +31,13 @@ Config = TypeVar("Config", bound=ArtifactsConfig)
|
|
31
31
|
|
32
32
|
class ArtifactsMixin(BaseTask[Config]):
|
33
33
|
_exp_dir: Path | None
|
34
|
+
_stage_dir: Path | None
|
34
35
|
|
35
36
|
def __init__(self, config: Config) -> None:
|
36
37
|
super().__init__(config)
|
37
38
|
|
38
39
|
self._exp_dir = None
|
40
|
+
self._stage_dir = None
|
39
41
|
|
40
42
|
@functools.cached_property
|
41
43
|
def run_dir(self) -> Path:
|
@@ -75,15 +77,16 @@ class ArtifactsMixin(BaseTask[Config]):
|
|
75
77
|
logger.log(LOG_STATUS, self._exp_dir)
|
76
78
|
return self._exp_dir
|
77
79
|
|
78
|
-
@functools.lru_cache(maxsize=None)
|
79
80
|
def stage_environment(self) -> Path | None:
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
81
|
+
if self._stage_dir is None:
|
82
|
+
stage_dir = (self.exp_dir / "code").resolve()
|
83
|
+
try:
|
84
|
+
stage_environment(self, stage_dir)
|
85
|
+
except Exception:
|
86
|
+
logger.exception("Failed to stage environment!")
|
87
|
+
return None
|
88
|
+
self._stage_dir = stage_dir
|
89
|
+
return self._stage_dir
|
87
90
|
|
88
91
|
def on_training_end(self, state: State) -> State:
|
89
92
|
state = super().on_training_end(state)
|
xax/task/mixins/checkpointing.py
CHANGED
@@ -52,6 +52,114 @@ class CheckpointingConfig(ArtifactsConfig):
|
|
52
52
|
Config = TypeVar("Config", bound=CheckpointingConfig)
|
53
53
|
|
54
54
|
|
55
|
+
@overload
|
56
|
+
def load_ckpt(
|
57
|
+
path: Path,
|
58
|
+
*,
|
59
|
+
part: Literal["all"],
|
60
|
+
model_template: PyTree,
|
61
|
+
optimizer_template: PyTree,
|
62
|
+
opt_state_template: PyTree,
|
63
|
+
) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
|
64
|
+
|
65
|
+
|
66
|
+
@overload
|
67
|
+
def load_ckpt(
|
68
|
+
path: Path,
|
69
|
+
*,
|
70
|
+
part: Literal["model_state_config"],
|
71
|
+
model_template: PyTree,
|
72
|
+
) -> tuple[PyTree, State, DictConfig]: ...
|
73
|
+
|
74
|
+
|
75
|
+
@overload
|
76
|
+
def load_ckpt(path: Path, *, part: Literal["model"], model_template: PyTree) -> PyTree: ...
|
77
|
+
|
78
|
+
|
79
|
+
@overload
|
80
|
+
def load_ckpt(path: Path, *, part: Literal["opt"], optimizer_template: PyTree) -> optax.GradientTransformation: ...
|
81
|
+
|
82
|
+
|
83
|
+
@overload
|
84
|
+
def load_ckpt(path: Path, *, part: Literal["opt_state"], opt_state_template: PyTree) -> optax.OptState: ...
|
85
|
+
|
86
|
+
|
87
|
+
@overload
|
88
|
+
def load_ckpt(path: Path, *, part: Literal["state"]) -> State: ...
|
89
|
+
|
90
|
+
|
91
|
+
@overload
|
92
|
+
def load_ckpt(path: Path, *, part: Literal["config"]) -> DictConfig: ...
|
93
|
+
|
94
|
+
|
95
|
+
def load_ckpt(
|
96
|
+
path: str | Path,
|
97
|
+
*,
|
98
|
+
part: CheckpointPart = "model",
|
99
|
+
model_template: PyTree | None = None,
|
100
|
+
optimizer_template: PyTree | None = None,
|
101
|
+
opt_state_template: PyTree | None = None,
|
102
|
+
) -> (
|
103
|
+
tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
|
104
|
+
| tuple[PyTree, State, DictConfig]
|
105
|
+
| PyTree
|
106
|
+
| optax.GradientTransformation
|
107
|
+
| optax.OptState
|
108
|
+
| State
|
109
|
+
| DictConfig
|
110
|
+
):
|
111
|
+
with tarfile.open(path, "r:gz") as tar:
|
112
|
+
|
113
|
+
def get_model() -> PyTree:
|
114
|
+
if model_template is None:
|
115
|
+
raise ValueError("model_template must be provided to load model weights")
|
116
|
+
if (model := tar.extractfile("model")) is None:
|
117
|
+
raise ValueError(f"Checkpoint does not contain a model file: {path}")
|
118
|
+
return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
|
119
|
+
|
120
|
+
def get_opt() -> optax.GradientTransformation:
|
121
|
+
if optimizer_template is None:
|
122
|
+
raise ValueError("optimizer_template must be provided to load optimizer")
|
123
|
+
if (opt := tar.extractfile("optimizer")) is None:
|
124
|
+
raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
|
125
|
+
return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
|
126
|
+
|
127
|
+
def get_opt_state() -> optax.OptState:
|
128
|
+
if opt_state_template is None:
|
129
|
+
raise ValueError("opt_state_template must be provided to load optimizer state")
|
130
|
+
if (opt_state := tar.extractfile("opt_state")) is None:
|
131
|
+
raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
|
132
|
+
return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
|
133
|
+
|
134
|
+
def get_state() -> State:
|
135
|
+
if (state := tar.extractfile("state")) is None:
|
136
|
+
raise ValueError(f"Checkpoint does not contain a state file: {path}")
|
137
|
+
return State.from_dict(**json.loads(state.read().decode()))
|
138
|
+
|
139
|
+
def get_config() -> DictConfig:
|
140
|
+
if (config := tar.extractfile("config")) is None:
|
141
|
+
raise ValueError(f"Checkpoint does not contain a config file: {path}")
|
142
|
+
return cast(DictConfig, OmegaConf.load(config))
|
143
|
+
|
144
|
+
match part:
|
145
|
+
case "model":
|
146
|
+
return get_model()
|
147
|
+
case "opt":
|
148
|
+
return get_opt()
|
149
|
+
case "opt_state":
|
150
|
+
return get_opt_state()
|
151
|
+
case "state":
|
152
|
+
return get_state()
|
153
|
+
case "config":
|
154
|
+
return get_config()
|
155
|
+
case "model_state_config":
|
156
|
+
return get_model(), get_state(), get_config()
|
157
|
+
case "all":
|
158
|
+
return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
|
159
|
+
case _:
|
160
|
+
raise ValueError(f"Invalid checkpoint part: {part}")
|
161
|
+
|
162
|
+
|
55
163
|
class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
56
164
|
def __init__(self, config: Config) -> None:
|
57
165
|
super().__init__(config)
|
@@ -82,149 +190,6 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
82
190
|
return True
|
83
191
|
return False
|
84
192
|
|
85
|
-
@overload
|
86
|
-
def load_ckpt_with_template(
|
87
|
-
self,
|
88
|
-
path: Path,
|
89
|
-
*,
|
90
|
-
part: Literal["all"],
|
91
|
-
model_template: PyTree,
|
92
|
-
optimizer_template: PyTree,
|
93
|
-
opt_state_template: PyTree,
|
94
|
-
) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...
|
95
|
-
|
96
|
-
@overload
|
97
|
-
def load_ckpt_with_template(
|
98
|
-
self,
|
99
|
-
path: Path,
|
100
|
-
*,
|
101
|
-
part: Literal["model_state_config"],
|
102
|
-
model_template: PyTree,
|
103
|
-
) -> tuple[PyTree, State, Config]: ...
|
104
|
-
|
105
|
-
@overload
|
106
|
-
def load_ckpt_with_template(
|
107
|
-
self,
|
108
|
-
path: Path,
|
109
|
-
*,
|
110
|
-
part: Literal["model"],
|
111
|
-
model_template: PyTree,
|
112
|
-
) -> PyTree: ...
|
113
|
-
|
114
|
-
@overload
|
115
|
-
def load_ckpt_with_template(
|
116
|
-
self,
|
117
|
-
path: Path,
|
118
|
-
*,
|
119
|
-
part: Literal["opt"],
|
120
|
-
optimizer_template: PyTree,
|
121
|
-
) -> optax.GradientTransformation: ...
|
122
|
-
|
123
|
-
@overload
|
124
|
-
def load_ckpt_with_template(
|
125
|
-
self,
|
126
|
-
path: Path,
|
127
|
-
*,
|
128
|
-
part: Literal["opt_state"],
|
129
|
-
opt_state_template: PyTree,
|
130
|
-
) -> optax.OptState: ...
|
131
|
-
|
132
|
-
@overload
|
133
|
-
def load_ckpt_with_template(
|
134
|
-
self,
|
135
|
-
path: Path,
|
136
|
-
*,
|
137
|
-
part: Literal["state"],
|
138
|
-
) -> State: ...
|
139
|
-
|
140
|
-
@overload
|
141
|
-
def load_ckpt_with_template(
|
142
|
-
self,
|
143
|
-
path: Path,
|
144
|
-
*,
|
145
|
-
part: Literal["config"],
|
146
|
-
) -> Config: ...
|
147
|
-
|
148
|
-
def load_ckpt_with_template(
|
149
|
-
self,
|
150
|
-
path: Path,
|
151
|
-
*,
|
152
|
-
part: CheckpointPart = "all",
|
153
|
-
model_template: PyTree | None = None,
|
154
|
-
optimizer_template: PyTree | None = None,
|
155
|
-
opt_state_template: PyTree | None = None,
|
156
|
-
) -> (
|
157
|
-
tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
|
158
|
-
| tuple[PyTree, State, Config]
|
159
|
-
| PyTree
|
160
|
-
| optax.GradientTransformation
|
161
|
-
| optax.OptState
|
162
|
-
| State
|
163
|
-
| Config
|
164
|
-
):
|
165
|
-
"""Load a checkpoint.
|
166
|
-
|
167
|
-
Args:
|
168
|
-
path: Path to the checkpoint directory
|
169
|
-
part: Which part of the checkpoint to load
|
170
|
-
model_template: Template model with correct structure but uninitialized weights
|
171
|
-
optimizer_template: Template optimizer with correct structure but uninitialized weights
|
172
|
-
opt_state_template: Template optimizer state with correct structure but uninitialized weights
|
173
|
-
|
174
|
-
Returns:
|
175
|
-
The requested checkpoint components
|
176
|
-
"""
|
177
|
-
with tarfile.open(path, "r:gz") as tar:
|
178
|
-
|
179
|
-
def get_model() -> PyTree:
|
180
|
-
if model_template is None:
|
181
|
-
raise ValueError("model_template must be provided to load model weights")
|
182
|
-
if (model := tar.extractfile("model")) is None:
|
183
|
-
raise ValueError(f"Checkpoint does not contain a model file: {path}")
|
184
|
-
return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)
|
185
|
-
|
186
|
-
def get_opt() -> optax.GradientTransformation:
|
187
|
-
if optimizer_template is None:
|
188
|
-
raise ValueError("optimizer_template must be provided to load optimizer")
|
189
|
-
if (opt := tar.extractfile("optimizer")) is None:
|
190
|
-
raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
|
191
|
-
return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)
|
192
|
-
|
193
|
-
def get_opt_state() -> optax.OptState:
|
194
|
-
if opt_state_template is None:
|
195
|
-
raise ValueError("opt_state_template must be provided to load optimizer state")
|
196
|
-
if (opt_state := tar.extractfile("opt_state")) is None:
|
197
|
-
raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
|
198
|
-
return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)
|
199
|
-
|
200
|
-
def get_state() -> State:
|
201
|
-
if (state := tar.extractfile("state")) is None:
|
202
|
-
raise ValueError(f"Checkpoint does not contain a state file: {path}")
|
203
|
-
return State.from_dict(**json.loads(state.read().decode()))
|
204
|
-
|
205
|
-
def get_config() -> Config:
|
206
|
-
if (config := tar.extractfile("config")) is None:
|
207
|
-
raise ValueError(f"Checkpoint does not contain a config file: {path}")
|
208
|
-
return self.get_config(cast(DictConfig, OmegaConf.load(config)), use_cli=False)
|
209
|
-
|
210
|
-
match part:
|
211
|
-
case "model":
|
212
|
-
return get_model()
|
213
|
-
case "opt":
|
214
|
-
return get_opt()
|
215
|
-
case "opt_state":
|
216
|
-
return get_opt_state()
|
217
|
-
case "state":
|
218
|
-
return get_state()
|
219
|
-
case "config":
|
220
|
-
return get_config()
|
221
|
-
case "model_state_config":
|
222
|
-
return get_model(), get_state(), get_config()
|
223
|
-
case "all":
|
224
|
-
return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
|
225
|
-
case _:
|
226
|
-
raise ValueError(f"Invalid checkpoint part: {part}")
|
227
|
-
|
228
193
|
def save_checkpoint(
|
229
194
|
self,
|
230
195
|
model: PyTree | None = None,
|
xax/task/mixins/train.py
CHANGED
@@ -40,7 +40,7 @@ from xax.core.state import Phase, State
|
|
40
40
|
from xax.nn.functions import set_random_seed
|
41
41
|
from xax.nn.parallel import is_master
|
42
42
|
from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
|
43
|
-
from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart
|
43
|
+
from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin, CheckpointPart, load_ckpt
|
44
44
|
from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
|
45
45
|
from xax.task.mixins.logger import LoggerConfig, LoggerMixin
|
46
46
|
from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
|
@@ -57,6 +57,7 @@ from xax.utils.experiments import (
|
|
57
57
|
)
|
58
58
|
from xax.utils.jax import jit as xax_jit
|
59
59
|
from xax.utils.logging import LOG_PING, LOG_STATUS
|
60
|
+
from xax.utils.pytree import get_pytree_param_count
|
60
61
|
from xax.utils.text import highlight_exception_message, show_info
|
61
62
|
from xax.utils.types.frozen_dict import FrozenDict
|
62
63
|
|
@@ -360,6 +361,7 @@ class TrainMixin(
|
|
360
361
|
model = self.get_model(key)
|
361
362
|
state = State.init_state()
|
362
363
|
|
364
|
+
self.log_model_size(model)
|
363
365
|
if not load_optimizer:
|
364
366
|
return model, state
|
365
367
|
|
@@ -450,44 +452,43 @@ class TrainMixin(
|
|
450
452
|
match part:
|
451
453
|
case "model_state_config":
|
452
454
|
model_spec = eqx.filter_eval_shape(self.get_model, key)
|
453
|
-
|
455
|
+
model, state, config = load_ckpt(path, part="model_state_config", model_template=model_spec)
|
456
|
+
config = self.get_config(config, use_cli=False)
|
457
|
+
return model, state, config
|
454
458
|
|
455
459
|
case "model":
|
456
460
|
model_spec = eqx.filter_eval_shape(self.get_model, key)
|
457
|
-
return
|
458
|
-
|
459
|
-
case "config":
|
460
|
-
return self.load_ckpt_with_template(path, part="config")
|
461
|
+
return load_ckpt(path, part="model", model_template=model_spec)
|
461
462
|
|
462
463
|
case "opt":
|
463
464
|
optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
|
464
|
-
return
|
465
|
+
return load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
|
465
466
|
|
466
467
|
case "opt_state":
|
467
468
|
if model is None:
|
468
469
|
model_spec = eqx.filter_eval_shape(self.get_model, key)
|
469
|
-
model =
|
470
|
+
model = load_ckpt(path, part="model", model_template=model_spec)
|
470
471
|
if optimizer is None:
|
471
472
|
optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
|
472
|
-
optimizer =
|
473
|
+
optimizer = load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
|
473
474
|
opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
|
474
|
-
return
|
475
|
+
return load_ckpt(path, part="opt_state", opt_state_template=opt_state_spec)
|
475
476
|
|
476
477
|
case "state":
|
477
|
-
return
|
478
|
+
return load_ckpt(path, part="state")
|
478
479
|
|
479
480
|
case "config":
|
480
|
-
return self.
|
481
|
+
return self.get_config(load_ckpt(path, part="config"), use_cli=False)
|
481
482
|
|
482
483
|
case "all":
|
483
484
|
model_spec = eqx.filter_eval_shape(self.get_model, key)
|
484
|
-
model =
|
485
|
+
model = load_ckpt(path, part="model", model_template=model_spec)
|
485
486
|
optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
|
486
|
-
optimizer =
|
487
|
+
optimizer = load_ckpt(path, part="opt", optimizer_template=optimizer_spec)
|
487
488
|
opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
|
488
|
-
opt_state =
|
489
|
-
state =
|
490
|
-
config = self.
|
489
|
+
opt_state = load_ckpt(path, part="opt_state", opt_state_template=opt_state_spec)
|
490
|
+
state = load_ckpt(path, part="state")
|
491
|
+
config = self.get_config(load_ckpt(path, part="config"), use_cli=False)
|
491
492
|
return model, optimizer, opt_state, state, config
|
492
493
|
|
493
494
|
case _:
|
@@ -683,6 +684,9 @@ class TrainMixin(
|
|
683
684
|
self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
|
684
685
|
self.logger.log_file("info.json", get_info_json())
|
685
686
|
|
687
|
+
def log_model_size(self, model: PyTree) -> None:
|
688
|
+
logger.info("Model size: %s", f"{get_pytree_param_count(model):,}")
|
689
|
+
|
686
690
|
def model_partition_fn(self, item: Any) -> bool: # noqa: ANN401
|
687
691
|
return eqx.is_inexact_array(item)
|
688
692
|
|
xax/utils/experiments.py
CHANGED
@@ -749,7 +749,8 @@ class BaseFileDownloader(ABC):
|
|
749
749
|
f"We detected some HTML elements in the downloaded file. "
|
750
750
|
f"This most likely means that the download triggered an unhandled API response by GDrive. "
|
751
751
|
f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
|
752
|
-
f"the response:\n\n{text}"
|
752
|
+
f"the response:\n\n{text}",
|
753
|
+
stacklevel=2,
|
753
754
|
)
|
754
755
|
|
755
756
|
@classmethod
|
xax/utils/jaxpr.py
CHANGED
@@ -3,10 +3,10 @@
|
|
3
3
|
from pathlib import Path
|
4
4
|
|
5
5
|
import jax
|
6
|
-
import jax.core
|
6
|
+
import jax.extend.core
|
7
7
|
|
8
8
|
|
9
|
-
def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) -> None:
|
9
|
+
def save_jaxpr_dot(closed_jaxpr: jax.extend.core.ClosedJaxpr, filename: str | Path) -> None:
|
10
10
|
"""Save the JAXPR to a DOT file.
|
11
11
|
|
12
12
|
Example usage:
|
@@ -30,15 +30,15 @@ def save_jaxpr_dot(closed_jaxpr: jax.core.ClosedJaxpr, filename: str | Path) ->
|
|
30
30
|
with open(filename, "w") as f:
|
31
31
|
f.write("digraph Jaxpr {\n")
|
32
32
|
|
33
|
-
var_names: dict[jax.core.Var, str] = {}
|
33
|
+
var_names: dict[jax.extend.core.Var, str] = {}
|
34
34
|
var_count = 0
|
35
35
|
|
36
|
-
def get_var_name(var: jax.core.Var) -> str:
|
36
|
+
def get_var_name(var: jax.extend.core.Var) -> str:
|
37
37
|
"""Get a unique name for a variable."""
|
38
38
|
nonlocal var_names, var_count
|
39
39
|
|
40
40
|
# Handle Literal objects specially since they're not hashable
|
41
|
-
if isinstance(var, jax.core.Literal):
|
41
|
+
if isinstance(var, jax.extend.core.Literal):
|
42
42
|
# Create a name based on the literal value
|
43
43
|
name = f"lit_{var.val}"
|
44
44
|
return name
|
xax/utils/pytree.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
"""Utils for accessing, modifying, and otherwise manipulating pytrees."""
|
2
2
|
|
3
3
|
import chex
|
4
|
+
import equinox as eqx
|
4
5
|
import jax
|
5
6
|
import jax.numpy as jnp
|
6
7
|
from jax import Array
|
@@ -57,7 +58,7 @@ def pytree_has_nans(pytree: PyTree) -> Array:
|
|
57
58
|
|
58
59
|
def update_pytree(cond: Array, new: PyTree, original: PyTree) -> PyTree:
|
59
60
|
"""Update a pytree based on a condition."""
|
60
|
-
# Tricky, need use
|
61
|
+
# Tricky, need use tree.map because where expects array leafs.
|
61
62
|
return jax.tree.map(lambda x, y: jnp.where(cond, x, y), new, original)
|
62
63
|
|
63
64
|
|
@@ -124,7 +125,7 @@ def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArr
|
|
124
125
|
def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
|
125
126
|
"""Reshuffle a rollout array across arbitrary batch dimensions independently of each other."""
|
126
127
|
rngs = jax.random.split(rng, len(batch_shape))
|
127
|
-
perms = [jax.random.permutation(rng_i, dim) for rng_i, dim in zip(rngs, batch_shape)]
|
128
|
+
perms = [jax.random.permutation(rng_i, dim) for rng_i, dim in zip(rngs, batch_shape, strict=True)]
|
128
129
|
# n-dimensional index grid from permutations
|
129
130
|
idx_grids = jnp.meshgrid(*perms, indexing="ij")
|
130
131
|
|
@@ -236,3 +237,9 @@ def reshuffle_pytree_along_dims(
|
|
236
237
|
return x
|
237
238
|
|
238
239
|
return jax.tree.map_with_path(restore_transpose, reshuffled_transposed)
|
240
|
+
|
241
|
+
|
242
|
+
def get_pytree_param_count(pytree: PyTree) -> int:
|
243
|
+
"""Calculates the total number of parameters in a PyTree."""
|
244
|
+
leaves, _ = jax.tree.flatten(pytree)
|
245
|
+
return sum(x.size for x in leaves if isinstance(x, jnp.ndarray) and eqx.is_inexact_array(x))
|
xax/utils/text.py
CHANGED
@@ -192,7 +192,7 @@ def render_text_blocks(
|
|
192
192
|
if any(len(row) != len(blocks[0]) for row in blocks):
|
193
193
|
raise ValueError("All rows must have the same number of blocks in order to align them")
|
194
194
|
widths = [[max(len(line) for line in i.lines) if i.width is None else i.width for i in r] for r in blocks]
|
195
|
-
row_widths = [max(i) for i in zip(*widths)]
|
195
|
+
row_widths = [max(i) for i in zip(*widths, strict=True)]
|
196
196
|
for row in blocks:
|
197
197
|
for i, block in enumerate(row):
|
198
198
|
block.width = row_widths[i]
|
@@ -263,7 +263,7 @@ def render_text_blocks(
|
|
263
263
|
if i >= len(block.lines)
|
264
264
|
else colored(pad(block.lines[i], width, block.center), block.color, bold=block.bold)
|
265
265
|
)
|
266
|
-
for block, width in zip(row, get_widths(row))
|
266
|
+
for block, width in zip(row, get_widths(row), strict=True)
|
267
267
|
]
|
268
268
|
)
|
269
269
|
+ " │"
|
xax/utils/types/frozen_dict.py
CHANGED
@@ -133,12 +133,12 @@ class FrozenDict(Mapping[K, V]):
|
|
133
133
|
|
134
134
|
@classmethod
|
135
135
|
def tree_unflatten(cls, keys: tuple[K, ...], values: tuple[Any, ...]) -> "FrozenDict[K, V]":
|
136
|
-
return cls({k: v for k, v in zip(keys, values)}, __unsafe_skip_copy__=True)
|
136
|
+
return cls({k: v for k, v in zip(keys, values, strict=True)}, __unsafe_skip_copy__=True)
|
137
137
|
|
138
138
|
|
139
139
|
def unfreeze(x: FrozenDict[K, V] | dict[str, Any]) -> dict[Any, Any]: # noqa: ANN401
|
140
140
|
if isinstance(x, FrozenDict):
|
141
|
-
return jax.
|
141
|
+
return jax.tree.map(lambda y: y, x._dict)
|
142
142
|
elif isinstance(x, dict):
|
143
143
|
ys = {}
|
144
144
|
for key, value in x.items():
|
@@ -1,23 +1,23 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=bNV5gnH1foOgZvLf-Cx_fJrhgwk0YFggqB3kNhbNuxg,15502
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
5
5
|
xax/core/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
-
xax/core/conf.py,sha256=
|
6
|
+
xax/core/conf.py,sha256=d7Dp_GwKnaxtkztlSrJSM_LR0UYJX_FWTtceIWCBkxc,5138
|
7
7
|
xax/core/state.py,sha256=yO25lMoLCUTJlHyLzQxlDbsHC_GZ3HkrKAq5huA7AkU,4552
|
8
8
|
xax/nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
9
9
|
xax/nn/embeddings.py,sha256=bQGxBFxkLwi2MQLkRfGaHPH5P_KKB21HdI7VNWTKIOQ,11847
|
10
|
-
xax/nn/equinox.py,sha256=
|
10
|
+
xax/nn/equinox.py,sha256=JZuSApD4bL0UK5W1nrQtucWYvNWUha07J6LTLk_RX-Y,4910
|
11
11
|
xax/nn/export.py,sha256=pRfM2B4hB2EvljysC6AjtgB_7Cn7JtaP3dhYU2stZtY,5545
|
12
|
-
xax/nn/functions.py,sha256=
|
13
|
-
xax/nn/geom.py,sha256=
|
12
|
+
xax/nn/functions.py,sha256=bA5kJYzMtFM8eUqBC086i355zJMAO7k_vPFNSDBI9-s,2814
|
13
|
+
xax/nn/geom.py,sha256=B8QE-L-xJWhf9KygTByPUAWe7Clpek4GlTABpsJFMBs,7702
|
14
14
|
xax/nn/losses.py,sha256=Q_NVnm5n4UPBvp5nI_1aUptfXnqFYoUeFwySiyvopHg,272
|
15
15
|
xax/nn/norm.py,sha256=WgZ3QCrUnf-YecwhEtVPcr99fKK3ECl_UeiAs2uv7oo,564
|
16
16
|
xax/nn/parallel.py,sha256=fnTiT7MsG7eQrJvqwjIz2Ifo3P27TuxIJzmpGYSa_dQ,4608
|
17
17
|
xax/nn/ssm.py,sha256=8dLAcQ1hBaMT-kkHvwGu_ecxJeTY32WeMYmd4T4KtxA,10745
|
18
18
|
xax/task/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
-
xax/task/base.py,sha256=
|
20
|
-
xax/task/logger.py,sha256=
|
19
|
+
xax/task/base.py,sha256=ymGld72zTSw50PBpX4WwqQDyJgQSagDrLh56aqvLT40,7720
|
20
|
+
xax/task/logger.py,sha256=W_BpluYvQai1lh1dDCAj-2_mWUC1buhwJncHygDffjc,41125
|
21
21
|
xax/task/script.py,sha256=bMMIJoUtpSBvPp6-7bejTrajTXvSg0794sYLKdPIToE,972
|
22
22
|
xax/task/task.py,sha256=UHMpnv__gqMcfbC_L-Hhk-DCnUYlFVsgbNf-v8o8B7U,1424
|
23
23
|
xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -25,14 +25,14 @@ xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,73
|
|
25
25
|
xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,1402
|
26
26
|
xax/task/launchers/single_process.py,sha256=IoML-30g5c526yxkpbWSOtG_KpNQMakT7xujzB1gIAo,846
|
27
27
|
xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
28
|
-
xax/task/loggers/callback.py,sha256=
|
29
|
-
xax/task/loggers/json.py,sha256=
|
30
|
-
xax/task/loggers/state.py,sha256=
|
31
|
-
xax/task/loggers/stdout.py,sha256=
|
32
|
-
xax/task/loggers/tensorboard.py,sha256=
|
28
|
+
xax/task/loggers/callback.py,sha256=zQuV1xCvz47Q3UQqP1D5mBhbVzptvmPR_7hX25vqSk0,1667
|
29
|
+
xax/task/loggers/json.py,sha256=6A5wL7kspsXnpPhI_vu0scgd2Z2-WLhw4gbBFm7eZMM,4377
|
30
|
+
xax/task/loggers/state.py,sha256=0Jy0NYnY4c0qt0LvNlaTaCKOSqk5SCKln5VdyuQGnIc,1407
|
31
|
+
xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,6619
|
32
|
+
xax/task/loggers/tensorboard.py,sha256=gkAalLsYPGjZiiMlqvDWIhNpYCfKWNvnPz3brIv3JaQ,8725
|
33
33
|
xax/task/mixins/__init__.py,sha256=D3oU31rB9FeOr9MPLleLt5JFbftUr4sBTwgnwQdc2qA,809
|
34
|
-
xax/task/mixins/artifacts.py,sha256=
|
35
|
-
xax/task/mixins/checkpointing.py,sha256=
|
34
|
+
xax/task/mixins/artifacts.py,sha256=Ma7fwsp-SA1w6GcuBSskszj5TB83yxYJm4Ns_EnqkI4,3018
|
35
|
+
xax/task/mixins/checkpointing.py,sha256=zqospBFnTbGt_iriiduVfXazINPbzWpwmIs91KAniMY,10147
|
36
36
|
xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,3922
|
37
37
|
xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
|
38
38
|
xax/task/mixins/data_loader.py,sha256=Tp7zqPdfH2_JuE6J6EP-fEtCQpq9MjKlGHYK7Zh-goU,6599
|
@@ -41,25 +41,25 @@ xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,280
|
|
41
41
|
xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
|
42
42
|
xax/task/mixins/runnable.py,sha256=IYIsLd2k09g-_y6o44EhJqT7E6BpsyEMmsyLSuzqjtc,1979
|
43
43
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
44
|
-
xax/task/mixins/train.py,sha256=
|
44
|
+
xax/task/mixins/train.py,sha256=VJbtoAjtLWJRRxRoMtJt13HkW_R8fOw7u6oBA5gcurk,31264
|
45
45
|
xax/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
46
46
|
xax/utils/debugging.py,sha256=OtUdu-3tQsQtik0Q9UM-SNV46IbPjwrAfZcywzoB5d4,1940
|
47
|
-
xax/utils/experiments.py,sha256=
|
47
|
+
xax/utils/experiments.py,sha256=bj8BftSHT3fFzfiJ0Co0WvqWo0rUS8kQnQYpVvH8FTM,29942
|
48
48
|
xax/utils/jax.py,sha256=KQYUHjN6t6JIWa11aRSO3edcsAgTscw_dExxI6kCd9g,6767
|
49
|
-
xax/utils/jaxpr.py,sha256=
|
49
|
+
xax/utils/jaxpr.py,sha256=H7pWl48ROXIB1-ZPWYfOn-ou3EBMxYWIwc_A0reJQoo,2333
|
50
50
|
xax/utils/logging.py,sha256=GAhTne2rdB4Fa1lzk06DMO15U8MTejn6XTClShC-ZtU,6622
|
51
51
|
xax/utils/numpy.py,sha256=_jOXVi-d2AtJnRftPkRK5MDMzsU8slgw-Jjv4GRm6ns,1197
|
52
52
|
xax/utils/profile.py,sha256=-aFdWpgYFvBsBZXSLL4zXrFe3zzsDqzmx4q5f2WOtpQ,1628
|
53
|
-
xax/utils/pytree.py,sha256=
|
53
|
+
xax/utils/pytree.py,sha256=Qp4u-jNOo9nd-2vdiuGHXFCTDERo4-zAGwJX_7kG7WM,9045
|
54
54
|
xax/utils/tensorboard.py,sha256=P0oIFvX2Qts1H4lkpizhRIpQdD0MNppVMeut0Z94yCs,19878
|
55
|
-
xax/utils/text.py,sha256=
|
55
|
+
xax/utils/text.py,sha256=xS02aSzdywl3KIaNSpKWcxdd37oYlUJtu9wIjkc1wVc,10654
|
56
56
|
xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
57
57
|
xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
|
58
58
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
59
|
-
xax/utils/types/frozen_dict.py,sha256=
|
59
|
+
xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
|
60
60
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
61
|
-
xax-0.2.
|
62
|
-
xax-0.2.
|
63
|
-
xax-0.2.
|
64
|
-
xax-0.2.
|
65
|
-
xax-0.2.
|
61
|
+
xax-0.2.8.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
62
|
+
xax-0.2.8.dist-info/METADATA,sha256=zv6QvF5HWciHILyNE-biwqpNIM6QPROSUaQ6aC-dGRU,1879
|
63
|
+
xax-0.2.8.dist-info/WHEEL,sha256=lTU6B6eIfYoiQJTZNc-fyaR6BpL6ehTzU3xGYxn2n8k,91
|
64
|
+
xax-0.2.8.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
65
|
+
xax-0.2.8.dist-info/RECORD,,
|
File without changes
|
File without changes
|