xax 0.2.7__tar.gz → 0.2.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.2.7/xax.egg-info → xax-0.2.9}/PKG-INFO +1 -1
- {xax-0.2.7 → xax-0.2.9}/pyproject.toml +1 -0
- {xax-0.2.7 → xax-0.2.9}/xax/__init__.py +50 -5
- {xax-0.2.7 → xax-0.2.9}/xax/core/conf.py +1 -1
- {xax-0.2.7 → xax-0.2.9}/xax/nn/equinox.py +6 -3
- {xax-0.2.7 → xax-0.2.9}/xax/nn/functions.py +7 -4
- {xax-0.2.7 → xax-0.2.9}/xax/nn/geom.py +49 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/base.py +2 -2
- {xax-0.2.7 → xax-0.2.9}/xax/task/logger.py +11 -6
- {xax-0.2.7 → xax-0.2.9}/xax/task/loggers/callback.py +6 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/loggers/json.py +13 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/loggers/state.py +26 -1
- {xax-0.2.7 → xax-0.2.9}/xax/task/loggers/stdout.py +4 -2
- {xax-0.2.7 → xax-0.2.9}/xax/task/loggers/tensorboard.py +19 -1
- {xax-0.2.7 → xax-0.2.9}/xax/task/mixins/artifacts.py +11 -8
- {xax-0.2.7 → xax-0.2.9}/xax/task/mixins/train.py +2 -7
- {xax-0.2.7 → xax-0.2.9}/xax/utils/experiments.py +2 -1
- {xax-0.2.7 → xax-0.2.9}/xax/utils/pytree.py +8 -1
- {xax-0.2.7 → xax-0.2.9}/xax/utils/text.py +2 -2
- {xax-0.2.7 → xax-0.2.9}/xax/utils/types/frozen_dict.py +1 -1
- {xax-0.2.7 → xax-0.2.9/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.2.7 → xax-0.2.9}/LICENSE +0 -0
- {xax-0.2.7 → xax-0.2.9}/MANIFEST.in +0 -0
- {xax-0.2.7 → xax-0.2.9}/README.md +0 -0
- {xax-0.2.7 → xax-0.2.9}/setup.cfg +0 -0
- {xax-0.2.7 → xax-0.2.9}/setup.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/core/__init__.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/core/state.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/nn/__init__.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/nn/embeddings.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/nn/export.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/nn/losses.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/nn/norm.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/nn/parallel.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/nn/ssm.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/py.typed +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/requirements-dev.txt +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/requirements.txt +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/__init__.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/launchers/__init__.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/launchers/base.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/launchers/cli.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/launchers/single_process.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/loggers/__init__.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/mixins/__init__.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/mixins/compile.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/mixins/logger.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/mixins/process.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/mixins/runnable.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/script.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/task/task.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/utils/__init__.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/utils/data/__init__.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/utils/data/collate.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/utils/debugging.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/utils/jax.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/utils/jaxpr.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/utils/logging.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/utils/numpy.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/utils/profile.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/utils/tensorboard.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/utils/types/__init__.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax.egg-info/requires.txt +0 -0
- {xax-0.2.7 → xax-0.2.9}/xax.egg-info/top_level.txt +0 -0
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.2.7"
|
15
|
+
__version__ = "0.2.9"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -43,9 +43,12 @@ __all__ = [
|
|
43
43
|
"cubic_bezier_interpolation",
|
44
44
|
"euler_to_quat",
|
45
45
|
"get_projected_gravity_vector_from_quat",
|
46
|
+
"normalize",
|
46
47
|
"quat_to_euler",
|
47
48
|
"quat_to_rotmat",
|
48
49
|
"rotate_vector_by_quat",
|
50
|
+
"rotation6d_to_rotation_matrix",
|
51
|
+
"rotation_matrix_to_rotation6d",
|
49
52
|
"cross_entropy",
|
50
53
|
"cast_norm_type",
|
51
54
|
"get_norm",
|
@@ -57,8 +60,18 @@ __all__ = [
|
|
57
60
|
"BaseLauncher",
|
58
61
|
"CliLauncher",
|
59
62
|
"SingleProcessLauncher",
|
63
|
+
"LogDistribution",
|
64
|
+
"LogError",
|
65
|
+
"LogErrorSummary",
|
66
|
+
"LogGraph",
|
67
|
+
"LogHistogram",
|
60
68
|
"LogImage",
|
61
69
|
"LogLine",
|
70
|
+
"LogMesh",
|
71
|
+
"LogPing",
|
72
|
+
"LogScalar",
|
73
|
+
"LogStatus",
|
74
|
+
"LogVideo",
|
62
75
|
"Logger",
|
63
76
|
"LoggerImpl",
|
64
77
|
"CallbackLogger",
|
@@ -72,7 +85,6 @@ __all__ = [
|
|
72
85
|
"GPUStatsOptions",
|
73
86
|
"StepContext",
|
74
87
|
"ValidStepTimer",
|
75
|
-
"get_param_count",
|
76
88
|
"Script",
|
77
89
|
"ScriptConfig",
|
78
90
|
"Config",
|
@@ -117,6 +129,7 @@ __all__ = [
|
|
117
129
|
"compute_nan_ratio",
|
118
130
|
"flatten_array",
|
119
131
|
"flatten_pytree",
|
132
|
+
"get_pytree_param_count",
|
120
133
|
"pytree_has_nans",
|
121
134
|
"reshuffle_pytree",
|
122
135
|
"reshuffle_pytree_along_dims",
|
@@ -209,9 +222,12 @@ NAME_MAP: dict[str, str] = {
|
|
209
222
|
"cubic_bezier_interpolation": "nn.geom",
|
210
223
|
"euler_to_quat": "nn.geom",
|
211
224
|
"get_projected_gravity_vector_from_quat": "nn.geom",
|
225
|
+
"normalize": "nn.geom",
|
212
226
|
"quat_to_euler": "nn.geom",
|
213
227
|
"quat_to_rotmat": "nn.geom",
|
214
228
|
"rotate_vector_by_quat": "nn.geom",
|
229
|
+
"rotation6d_to_rotation_matrix": "nn.geom",
|
230
|
+
"rotation_matrix_to_rotation6d": "nn.geom",
|
215
231
|
"cross_entropy": "nn.losses",
|
216
232
|
"cast_norm_type": "nn.norm",
|
217
233
|
"get_norm": "nn.norm",
|
@@ -223,8 +239,18 @@ NAME_MAP: dict[str, str] = {
|
|
223
239
|
"BaseLauncher": "task.launchers.base",
|
224
240
|
"CliLauncher": "task.launchers.cli",
|
225
241
|
"SingleProcessLauncher": "task.launchers.single_process",
|
242
|
+
"LogDistribution": "task.logger",
|
243
|
+
"LogError": "task.logger",
|
244
|
+
"LogErrorSummary": "task.logger",
|
245
|
+
"LogGraph": "task.logger",
|
246
|
+
"LogHistogram": "task.logger",
|
226
247
|
"LogImage": "task.logger",
|
227
248
|
"LogLine": "task.logger",
|
249
|
+
"LogMesh": "task.logger",
|
250
|
+
"LogPing": "task.logger",
|
251
|
+
"LogScalar": "task.logger",
|
252
|
+
"LogStatus": "task.logger",
|
253
|
+
"LogVideo": "task.logger",
|
228
254
|
"Logger": "task.logger",
|
229
255
|
"LoggerImpl": "task.logger",
|
230
256
|
"CallbackLogger": "task.loggers.callback",
|
@@ -238,7 +264,6 @@ NAME_MAP: dict[str, str] = {
|
|
238
264
|
"GPUStatsOptions": "task.mixins.gpu_stats",
|
239
265
|
"StepContext": "task.mixins.step_wrapper",
|
240
266
|
"ValidStepTimer": "task.mixins.train",
|
241
|
-
"get_param_count": "task.mixins.train",
|
242
267
|
"Script": "task.script",
|
243
268
|
"ScriptConfig": "task.script",
|
244
269
|
"Config": "task.task",
|
@@ -283,6 +308,7 @@ NAME_MAP: dict[str, str] = {
|
|
283
308
|
"compute_nan_ratio": "utils.pytree",
|
284
309
|
"flatten_array": "utils.pytree",
|
285
310
|
"flatten_pytree": "utils.pytree",
|
311
|
+
"get_pytree_param_count": "utils.pytree",
|
286
312
|
"pytree_has_nans": "utils.pytree",
|
287
313
|
"reshuffle_pytree": "utils.pytree",
|
288
314
|
"reshuffle_pytree_along_dims": "utils.pytree",
|
@@ -376,9 +402,12 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
376
402
|
cubic_bezier_interpolation,
|
377
403
|
euler_to_quat,
|
378
404
|
get_projected_gravity_vector_from_quat,
|
405
|
+
normalize,
|
379
406
|
quat_to_euler,
|
380
407
|
quat_to_rotmat,
|
381
408
|
rotate_vector_by_quat,
|
409
|
+
rotation6d_to_rotation_matrix,
|
410
|
+
rotation_matrix_to_rotation6d,
|
382
411
|
)
|
383
412
|
from xax.nn.losses import cross_entropy
|
384
413
|
from xax.nn.norm import NormType, cast_norm_type, get_norm
|
@@ -388,7 +417,22 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
388
417
|
from xax.task.launchers.base import BaseLauncher
|
389
418
|
from xax.task.launchers.cli import CliLauncher
|
390
419
|
from xax.task.launchers.single_process import SingleProcessLauncher
|
391
|
-
from xax.task.logger import Logger, LoggerImpl, LogImage, LogLine
|
420
|
+
from xax.task.logger import (
|
421
|
+
LogDistribution,
|
422
|
+
LogError,
|
423
|
+
LogErrorSummary,
|
424
|
+
Logger,
|
425
|
+
LoggerImpl,
|
426
|
+
LogGraph,
|
427
|
+
LogHistogram,
|
428
|
+
LogImage,
|
429
|
+
LogLine,
|
430
|
+
LogMesh,
|
431
|
+
LogPing,
|
432
|
+
LogScalar,
|
433
|
+
LogStatus,
|
434
|
+
LogVideo,
|
435
|
+
)
|
392
436
|
from xax.task.loggers.callback import CallbackLogger
|
393
437
|
from xax.task.loggers.json import JsonLogger
|
394
438
|
from xax.task.loggers.state import StateLogger
|
@@ -399,7 +443,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
399
443
|
from xax.task.mixins.data_loader import DataloaderConfig
|
400
444
|
from xax.task.mixins.gpu_stats import GPUStatsOptions
|
401
445
|
from xax.task.mixins.step_wrapper import StepContext
|
402
|
-
from xax.task.mixins.train import Batch, Output, ValidStepTimer, get_param_count
|
446
|
+
from xax.task.mixins.train import Batch, Output, ValidStepTimer
|
403
447
|
from xax.task.script import Script, ScriptConfig
|
404
448
|
from xax.task.task import Config, Task
|
405
449
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
@@ -444,6 +488,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
444
488
|
compute_nan_ratio,
|
445
489
|
flatten_array,
|
446
490
|
flatten_pytree,
|
491
|
+
get_pytree_param_count,
|
447
492
|
pytree_has_nans,
|
448
493
|
reshuffle_pytree,
|
449
494
|
reshuffle_pytree_along_dims,
|
@@ -26,7 +26,7 @@ def field(value: FieldType, **kwargs: str) -> FieldType:
|
|
26
26
|
metadata: dict[str, Any] = {}
|
27
27
|
metadata.update(kwargs)
|
28
28
|
|
29
|
-
if hasattr(value, "__call__"):
|
29
|
+
if hasattr(value, "__call__"): # noqa: B004
|
30
30
|
return field_base(default_factory=value, metadata=metadata)
|
31
31
|
if value.__class__.__hash__ is None:
|
32
32
|
return field_base(default_factory=lambda: value, metadata=metadata)
|
@@ -68,8 +68,8 @@ def _infer_activation(activation: ActivationFunction) -> Callable:
|
|
68
68
|
return lambda x: x
|
69
69
|
try:
|
70
70
|
return getattr(jax.nn, activation)
|
71
|
-
except AttributeError:
|
72
|
-
raise ValueError(f"Activation function `{activation}` not found in `jax.nn`")
|
71
|
+
except AttributeError as err:
|
72
|
+
raise ValueError(f"Activation function `{activation}` not found in `jax.nn`") from err
|
73
73
|
|
74
74
|
|
75
75
|
def make_eqx_mlp(hyperparams: MLPHyperParams, *, key: PRNGKeyArray) -> eqx.nn.MLP:
|
@@ -100,7 +100,7 @@ def make_eqx_mlp(hyperparams: MLPHyperParams, *, key: PRNGKeyArray) -> eqx.nn.ML
|
|
100
100
|
def export_eqx_mlp(
|
101
101
|
model: eqx.nn.MLP,
|
102
102
|
output_path: str | Path,
|
103
|
-
dtype: jax.numpy.dtype =
|
103
|
+
dtype: jax.numpy.dtype | None = None,
|
104
104
|
) -> None:
|
105
105
|
"""Serialize an Equinox MLP to a .eqx file.
|
106
106
|
|
@@ -109,6 +109,9 @@ def export_eqx_mlp(
|
|
109
109
|
output_path: The path to save the exported model.
|
110
110
|
dtype: The dtype of the model.
|
111
111
|
"""
|
112
|
+
if dtype is None:
|
113
|
+
dtype = eqx._misc.default_floating_dtype()
|
114
|
+
|
112
115
|
activation = model.activation.__name__
|
113
116
|
final_activation = model.final_activation.__name__
|
114
117
|
|
@@ -58,13 +58,16 @@ def recursive_chunk(item: Any, num_chunks: int, dim: int = 0) -> Iterable[Any]:
|
|
58
58
|
yield from np.array_split(item, num_chunks, axis=dim)
|
59
59
|
elif is_dataclass(item):
|
60
60
|
yield from (
|
61
|
-
item.__class__(**{k: i for k, i in zip(item.__dict__, ii)})
|
62
|
-
for ii in zip(*(recursive_chunk(v, num_chunks, dim) for v in item.__dict__.values()))
|
61
|
+
item.__class__(**{k: i for k, i in zip(item.__dict__, ii, strict=True)})
|
62
|
+
for ii in zip(*(recursive_chunk(v, num_chunks, dim) for v in item.__dict__.values()), strict=False)
|
63
63
|
)
|
64
64
|
elif isinstance(item, Mapping):
|
65
|
-
yield from (
|
65
|
+
yield from (
|
66
|
+
dict(zip(item, ii, strict=False))
|
67
|
+
for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item.values()), strict=False)
|
68
|
+
)
|
66
69
|
elif isinstance(item, Sequence):
|
67
|
-
yield from (list(ii) for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item)))
|
70
|
+
yield from (list(ii) for ii in zip(*(recursive_chunk(i, num_chunks, dim) for i in item), strict=False))
|
68
71
|
else:
|
69
72
|
yield from (item for _ in range(num_chunks))
|
70
73
|
|
@@ -1,5 +1,6 @@
|
|
1
1
|
"""Defines geometry functions."""
|
2
2
|
|
3
|
+
import chex
|
3
4
|
from jax import numpy as jnp
|
4
5
|
from jaxtyping import Array
|
5
6
|
|
@@ -211,3 +212,51 @@ def quat_to_rotmat(quat: Array, eps: float = 1e-6) -> Array:
|
|
211
212
|
],
|
212
213
|
axis=-2,
|
213
214
|
)
|
215
|
+
|
216
|
+
|
217
|
+
def normalize(v: jnp.ndarray, axis: int = -1, eps: float = 1e-8) -> jnp.ndarray:
|
218
|
+
norm = jnp.linalg.norm(v, axis=axis, keepdims=True)
|
219
|
+
return v / jnp.clip(norm, a_min=eps)
|
220
|
+
|
221
|
+
|
222
|
+
def rotation6d_to_rotation_matrix(r6d: jnp.ndarray) -> jnp.ndarray:
|
223
|
+
"""Convert 6D rotation representation to rotation matrix.
|
224
|
+
|
225
|
+
From https://arxiv.org/pdf/1812.07035, Appendix B
|
226
|
+
|
227
|
+
Args:
|
228
|
+
r6d: The 6D rotation representation, shape (*, 6).
|
229
|
+
|
230
|
+
Returns:
|
231
|
+
The rotation matrix, shape (*, 3, 3).
|
232
|
+
"""
|
233
|
+
chex.assert_shape(r6d, (..., 6))
|
234
|
+
shape = r6d.shape
|
235
|
+
flat = r6d.reshape(-1, 6)
|
236
|
+
a_1 = flat[:, 0:3]
|
237
|
+
a_2 = flat[:, 3:6]
|
238
|
+
|
239
|
+
b_1 = normalize(a_1, axis=-1)
|
240
|
+
|
241
|
+
# Reordered Gram-Schmidt orthonormalization.
|
242
|
+
b_3 = normalize(jnp.cross(b_1, a_2), axis=-1)
|
243
|
+
b_2 = jnp.cross(b_3, b_1)
|
244
|
+
|
245
|
+
rotation_matrix = jnp.stack([b_1, b_2, b_3], axis=-1)
|
246
|
+
return rotation_matrix.reshape(shape[:-1] + (3, 3))
|
247
|
+
|
248
|
+
|
249
|
+
def rotation_matrix_to_rotation6d(rotation_matrix: jnp.ndarray) -> jnp.ndarray:
|
250
|
+
"""Convert rotation matrix to 6D rotation representation.
|
251
|
+
|
252
|
+
Args:
|
253
|
+
rotation_matrix: The rotation matrix, shape (*, 3, 3).
|
254
|
+
|
255
|
+
Returns:
|
256
|
+
The 6D rotation representation, shape (*, 6).
|
257
|
+
"""
|
258
|
+
chex.assert_shape(rotation_matrix, (..., 3, 3))
|
259
|
+
shape = rotation_matrix.shape
|
260
|
+
# Simply concatenate a1 and a2 from SO(3)
|
261
|
+
r6d = jnp.concatenate([rotation_matrix[..., 0], rotation_matrix[..., 1]], axis=-1)
|
262
|
+
return r6d.reshape(shape[:-2] + (6,))
|
@@ -184,8 +184,8 @@ class BaseTask(Generic[Config]):
|
|
184
184
|
|
185
185
|
# Attempts to load any paths as configs.
|
186
186
|
is_path = [Path(arg).is_file() or (task_path / arg).is_file() for arg in args]
|
187
|
-
paths = [arg for arg, is_path in zip(args, is_path) if is_path]
|
188
|
-
non_paths = [arg for arg, is_path in zip(args, is_path) if not is_path]
|
187
|
+
paths = [arg for arg, is_path in zip(args, is_path, strict=True) if is_path]
|
188
|
+
non_paths = [arg for arg, is_path in zip(args, is_path, strict=True) if not is_path]
|
189
189
|
if paths:
|
190
190
|
cfg = OmegaConf.merge(cfg, *(get_config(path, task_path) for path in paths))
|
191
191
|
cfg = OmegaConf.merge(cfg, OmegaConf.from_cli(non_paths))
|
@@ -462,11 +462,11 @@ class LoggerImpl(ABC):
|
|
462
462
|
|
463
463
|
self.tickers = {phase: IntervalTicker(log_interval_seconds) for phase in get_args(Phase)}
|
464
464
|
|
465
|
-
|
466
|
-
|
465
|
+
@abstractmethod
|
466
|
+
def start(self) -> None: ...
|
467
467
|
|
468
|
-
|
469
|
-
|
468
|
+
@abstractmethod
|
469
|
+
def stop(self) -> None: ...
|
470
470
|
|
471
471
|
@abstractmethod
|
472
472
|
def write(self, line: LogLine) -> None:
|
@@ -476,6 +476,7 @@ class LoggerImpl(ABC):
|
|
476
476
|
line: The line to write.
|
477
477
|
"""
|
478
478
|
|
479
|
+
@abstractmethod
|
479
480
|
def write_error_summary(self, error_summary: LogErrorSummary) -> None:
|
480
481
|
"""Handles writing an error summary.
|
481
482
|
|
@@ -483,6 +484,7 @@ class LoggerImpl(ABC):
|
|
483
484
|
error_summary: The error summary to write.
|
484
485
|
"""
|
485
486
|
|
487
|
+
@abstractmethod
|
486
488
|
def write_error(self, error: LogError) -> None:
|
487
489
|
"""Handles writing an error line.
|
488
490
|
|
@@ -490,6 +492,7 @@ class LoggerImpl(ABC):
|
|
490
492
|
error: The error information to write.
|
491
493
|
"""
|
492
494
|
|
495
|
+
@abstractmethod
|
493
496
|
def write_status(self, status: LogStatus) -> None:
|
494
497
|
"""Handles writing a status line.
|
495
498
|
|
@@ -497,6 +500,7 @@ class LoggerImpl(ABC):
|
|
497
500
|
status: The status to write.
|
498
501
|
"""
|
499
502
|
|
503
|
+
@abstractmethod
|
500
504
|
def write_ping(self, ping: LogPing) -> None:
|
501
505
|
"""Handles writing a ping line.
|
502
506
|
|
@@ -504,6 +508,7 @@ class LoggerImpl(ABC):
|
|
504
508
|
ping: The ping to write.
|
505
509
|
"""
|
506
510
|
|
511
|
+
@abstractmethod
|
507
512
|
def log_file(self, name: str, contents: str) -> None:
|
508
513
|
"""Logs a large text file.
|
509
514
|
|
@@ -621,7 +626,7 @@ class Logger:
|
|
621
626
|
return
|
622
627
|
line = self.pack(state)
|
623
628
|
self.clear()
|
624
|
-
for lg in (lg for lg, should_log in zip(self.loggers, should_log) if should_log):
|
629
|
+
for lg in (lg for lg, should_log in zip(self.loggers, should_log, strict=False) if should_log):
|
625
630
|
lg.write(line)
|
626
631
|
|
627
632
|
def write_error_summary(self, error_summary: str) -> None:
|
@@ -1045,7 +1050,7 @@ class Logger:
|
|
1045
1050
|
line_spacing=line_spacing,
|
1046
1051
|
centered=centered,
|
1047
1052
|
)
|
1048
|
-
for img, label in zip(images, labels)
|
1053
|
+
for img, label in zip(images, labels, strict=True)
|
1049
1054
|
]
|
1050
1055
|
tiled = tile_images([img.image for img in labeled], sep)
|
1051
1056
|
|
@@ -25,6 +25,12 @@ class CallbackLogger(LoggerImpl):
|
|
25
25
|
self.ping_callback = ping_callback
|
26
26
|
self.file_callback = file_callback
|
27
27
|
|
28
|
+
def start(self) -> None:
|
29
|
+
pass
|
30
|
+
|
31
|
+
def stop(self) -> None:
|
32
|
+
pass
|
33
|
+
|
28
34
|
def write(self, line: LogLine) -> None:
|
29
35
|
self.callback(line)
|
30
36
|
|
@@ -8,6 +8,7 @@ from jaxtyping import Array
|
|
8
8
|
|
9
9
|
from xax.task.logger import (
|
10
10
|
LogError,
|
11
|
+
LogErrorSummary,
|
11
12
|
LoggerImpl,
|
12
13
|
LogLine,
|
13
14
|
LogPing,
|
@@ -57,6 +58,12 @@ class JsonLogger(LoggerImpl):
|
|
57
58
|
self.line_sep = line_sep
|
58
59
|
self.remove_unicode_from_namespaces = remove_unicode_from_namespaces
|
59
60
|
|
61
|
+
def start(self) -> None:
|
62
|
+
pass
|
63
|
+
|
64
|
+
def stop(self) -> None:
|
65
|
+
pass
|
66
|
+
|
60
67
|
@property
|
61
68
|
def fp(self) -> TextIO:
|
62
69
|
return self.log_stream
|
@@ -87,6 +94,12 @@ class JsonLogger(LoggerImpl):
|
|
87
94
|
if self.flush_immediately:
|
88
95
|
self.fp.flush()
|
89
96
|
|
97
|
+
def write_error_summary(self, error_summary: LogErrorSummary) -> None:
|
98
|
+
pass
|
99
|
+
|
100
|
+
def log_file(self, name: str, contents: str) -> None:
|
101
|
+
pass
|
102
|
+
|
90
103
|
def write_error(self, error: LogError) -> None:
|
91
104
|
self.err_fp.write(error.message)
|
92
105
|
if error.location is not None:
|
@@ -3,7 +3,14 @@
|
|
3
3
|
from pathlib import Path
|
4
4
|
from typing import Literal
|
5
5
|
|
6
|
-
from xax.task.logger import LoggerImpl, LogLine
|
6
|
+
from xax.task.logger import (
|
7
|
+
LogError,
|
8
|
+
LogErrorSummary,
|
9
|
+
LoggerImpl,
|
10
|
+
LogLine,
|
11
|
+
LogPing,
|
12
|
+
LogStatus,
|
13
|
+
)
|
7
14
|
|
8
15
|
|
9
16
|
class StateLogger(LoggerImpl):
|
@@ -30,3 +37,21 @@ class StateLogger(LoggerImpl):
|
|
30
37
|
|
31
38
|
def write(self, line: LogLine) -> None:
|
32
39
|
pass
|
40
|
+
|
41
|
+
def start(self) -> None:
|
42
|
+
pass
|
43
|
+
|
44
|
+
def stop(self) -> None:
|
45
|
+
pass
|
46
|
+
|
47
|
+
def write_error_summary(self, error_summary: LogErrorSummary) -> None:
|
48
|
+
pass
|
49
|
+
|
50
|
+
def write_error(self, error: LogError) -> None:
|
51
|
+
pass
|
52
|
+
|
53
|
+
def write_status(self, status: LogStatus) -> None:
|
54
|
+
pass
|
55
|
+
|
56
|
+
def write_ping(self, ping: LogPing) -> None:
|
57
|
+
pass
|
@@ -79,11 +79,13 @@ class StdoutLogger(LoggerImpl):
|
|
79
79
|
self.error_summary: tuple[str, datetime.datetime] | None = None
|
80
80
|
|
81
81
|
def start(self) -> None:
|
82
|
-
|
82
|
+
pass
|
83
83
|
|
84
84
|
def stop(self) -> None:
|
85
85
|
self.write_queues()
|
86
|
-
|
86
|
+
|
87
|
+
def log_file(self, name: str, contents: str) -> None:
|
88
|
+
pass
|
87
89
|
|
88
90
|
def write_separator(self) -> None:
|
89
91
|
self.write_fp.write("\033[2J\033[H")
|
@@ -12,7 +12,7 @@ from typing import TypeVar
|
|
12
12
|
|
13
13
|
from xax.core.state import Phase
|
14
14
|
from xax.nn.parallel import is_master
|
15
|
-
from xax.task.logger import LoggerImpl, LogLine
|
15
|
+
from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
|
16
16
|
from xax.utils.jax import as_float
|
17
17
|
from xax.utils.logging import LOG_STATUS, port_is_busy
|
18
18
|
from xax.utils.tensorboard import TensorboardWriter, TensorboardWriters
|
@@ -236,3 +236,21 @@ class TensorboardLogger(LoggerImpl):
|
|
236
236
|
for name, contents in self.files.items():
|
237
237
|
writer.add_text(name, contents)
|
238
238
|
self.files.clear()
|
239
|
+
|
240
|
+
def start(self) -> None:
|
241
|
+
pass
|
242
|
+
|
243
|
+
def stop(self) -> None:
|
244
|
+
pass
|
245
|
+
|
246
|
+
def write_error(self, error: LogError) -> None:
|
247
|
+
pass
|
248
|
+
|
249
|
+
def write_error_summary(self, error_summary: LogErrorSummary) -> None:
|
250
|
+
pass
|
251
|
+
|
252
|
+
def write_ping(self, ping: LogPing) -> None:
|
253
|
+
pass
|
254
|
+
|
255
|
+
def write_status(self, status: LogStatus) -> None:
|
256
|
+
pass
|
@@ -31,11 +31,13 @@ Config = TypeVar("Config", bound=ArtifactsConfig)
|
|
31
31
|
|
32
32
|
class ArtifactsMixin(BaseTask[Config]):
|
33
33
|
_exp_dir: Path | None
|
34
|
+
_stage_dir: Path | None
|
34
35
|
|
35
36
|
def __init__(self, config: Config) -> None:
|
36
37
|
super().__init__(config)
|
37
38
|
|
38
39
|
self._exp_dir = None
|
40
|
+
self._stage_dir = None
|
39
41
|
|
40
42
|
@functools.cached_property
|
41
43
|
def run_dir(self) -> Path:
|
@@ -75,15 +77,16 @@ class ArtifactsMixin(BaseTask[Config]):
|
|
75
77
|
logger.log(LOG_STATUS, self._exp_dir)
|
76
78
|
return self._exp_dir
|
77
79
|
|
78
|
-
@functools.lru_cache(maxsize=None)
|
79
80
|
def stage_environment(self) -> Path | None:
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
81
|
+
if self._stage_dir is None:
|
82
|
+
stage_dir = (self.exp_dir / "code").resolve()
|
83
|
+
try:
|
84
|
+
stage_environment(self, stage_dir)
|
85
|
+
except Exception:
|
86
|
+
logger.exception("Failed to stage environment!")
|
87
|
+
return None
|
88
|
+
self._stage_dir = stage_dir
|
89
|
+
return self._stage_dir
|
87
90
|
|
88
91
|
def on_training_end(self, state: State) -> State:
|
89
92
|
state = super().on_training_end(state)
|
@@ -57,6 +57,7 @@ from xax.utils.experiments import (
|
|
57
57
|
)
|
58
58
|
from xax.utils.jax import jit as xax_jit
|
59
59
|
from xax.utils.logging import LOG_PING, LOG_STATUS
|
60
|
+
from xax.utils.pytree import get_pytree_param_count
|
60
61
|
from xax.utils.text import highlight_exception_message, show_info
|
61
62
|
from xax.utils.types.frozen_dict import FrozenDict
|
62
63
|
|
@@ -96,12 +97,6 @@ def batches_per_step_schedule(schedule: list[int] | None) -> list[int] | None:
|
|
96
97
|
return list(itertools.accumulate([0] + schedule))
|
97
98
|
|
98
99
|
|
99
|
-
def get_param_count(pytree: PyTree) -> int:
|
100
|
-
"""Calculates the total number of parameters in a PyTree."""
|
101
|
-
leaves, _ = jax.tree.flatten(pytree)
|
102
|
-
return sum(x.size for x in leaves if isinstance(x, jnp.ndarray))
|
103
|
-
|
104
|
-
|
105
100
|
class ValidStepTimer:
|
106
101
|
def __init__(
|
107
102
|
self,
|
@@ -690,7 +685,7 @@ class TrainMixin(
|
|
690
685
|
self.logger.log_file("info.json", get_info_json())
|
691
686
|
|
692
687
|
def log_model_size(self, model: PyTree) -> None:
|
693
|
-
logger.info("Model size: %s", f"{get_param_count(model):,}")
|
688
|
+
logger.info("Model size: %s", f"{get_pytree_param_count(model):,}")
|
694
689
|
|
695
690
|
def model_partition_fn(self, item: Any) -> bool: # noqa: ANN401
|
696
691
|
return eqx.is_inexact_array(item)
|
@@ -749,7 +749,8 @@ class BaseFileDownloader(ABC):
|
|
749
749
|
f"We detected some HTML elements in the downloaded file. "
|
750
750
|
f"This most likely means that the download triggered an unhandled API response by GDrive. "
|
751
751
|
f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
|
752
|
-
f"the response:\n\n{text}"
|
752
|
+
f"the response:\n\n{text}",
|
753
|
+
stacklevel=2,
|
753
754
|
)
|
754
755
|
|
755
756
|
@classmethod
|
@@ -1,6 +1,7 @@
|
|
1
1
|
"""Utils for accessing, modifying, and otherwise manipulating pytrees."""
|
2
2
|
|
3
3
|
import chex
|
4
|
+
import equinox as eqx
|
4
5
|
import jax
|
5
6
|
import jax.numpy as jnp
|
6
7
|
from jax import Array
|
@@ -124,7 +125,7 @@ def reshuffle_pytree(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArr
|
|
124
125
|
def reshuffle_pytree_independently(data: PyTree, batch_shape: tuple[int, ...], rng: PRNGKeyArray) -> PyTree:
|
125
126
|
"""Reshuffle a rollout array across arbitrary batch dimensions independently of each other."""
|
126
127
|
rngs = jax.random.split(rng, len(batch_shape))
|
127
|
-
perms = [jax.random.permutation(rng_i, dim) for rng_i, dim in zip(rngs, batch_shape)]
|
128
|
+
perms = [jax.random.permutation(rng_i, dim) for rng_i, dim in zip(rngs, batch_shape, strict=True)]
|
128
129
|
# n-dimensional index grid from permutations
|
129
130
|
idx_grids = jnp.meshgrid(*perms, indexing="ij")
|
130
131
|
|
@@ -236,3 +237,9 @@ def reshuffle_pytree_along_dims(
|
|
236
237
|
return x
|
237
238
|
|
238
239
|
return jax.tree.map_with_path(restore_transpose, reshuffled_transposed)
|
240
|
+
|
241
|
+
|
242
|
+
def get_pytree_param_count(pytree: PyTree) -> int:
|
243
|
+
"""Calculates the total number of parameters in a PyTree."""
|
244
|
+
leaves, _ = jax.tree.flatten(pytree)
|
245
|
+
return sum(x.size for x in leaves if isinstance(x, jnp.ndarray) and eqx.is_inexact_array(x))
|
@@ -192,7 +192,7 @@ def render_text_blocks(
|
|
192
192
|
if any(len(row) != len(blocks[0]) for row in blocks):
|
193
193
|
raise ValueError("All rows must have the same number of blocks in order to align them")
|
194
194
|
widths = [[max(len(line) for line in i.lines) if i.width is None else i.width for i in r] for r in blocks]
|
195
|
-
row_widths = [max(i) for i in zip(*widths)]
|
195
|
+
row_widths = [max(i) for i in zip(*widths, strict=True)]
|
196
196
|
for row in blocks:
|
197
197
|
for i, block in enumerate(row):
|
198
198
|
block.width = row_widths[i]
|
@@ -263,7 +263,7 @@ def render_text_blocks(
|
|
263
263
|
if i >= len(block.lines)
|
264
264
|
else colored(pad(block.lines[i], width, block.center), block.color, bold=block.bold)
|
265
265
|
)
|
266
|
-
for block, width in zip(row, get_widths(row))
|
266
|
+
for block, width in zip(row, get_widths(row), strict=True)
|
267
267
|
]
|
268
268
|
)
|
269
269
|
+ " │"
|
@@ -133,7 +133,7 @@ class FrozenDict(Mapping[K, V]):
|
|
133
133
|
|
134
134
|
@classmethod
|
135
135
|
def tree_unflatten(cls, keys: tuple[K, ...], values: tuple[Any, ...]) -> "FrozenDict[K, V]":
|
136
|
-
return cls({k: v for k, v in zip(keys, values)}, __unsafe_skip_copy__=True)
|
136
|
+
return cls({k: v for k, v in zip(keys, values, strict=True)}, __unsafe_skip_copy__=True)
|
137
137
|
|
138
138
|
|
139
139
|
def unfreeze(x: FrozenDict[K, V] | dict[str, Any]) -> dict[Any, Any]: # noqa: ANN401
|
{xax-0.2.7 → xax-0.2.9}/LICENSE
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
{xax-0.2.7 → xax-0.2.9}/setup.py
RENAMED
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|