xax 0.1.14__tar.gz → 0.1.16__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.1.14/xax.egg-info → xax-0.1.16}/PKG-INFO +1 -1
- {xax-0.1.14 → xax-0.1.16}/xax/__init__.py +4 -1
- {xax-0.1.14 → xax-0.1.16}/xax/nn/geom.py +26 -5
- {xax-0.1.14 → xax-0.1.16}/xax/task/logger.py +42 -10
- {xax-0.1.14 → xax-0.1.16}/xax/task/loggers/json.py +12 -4
- {xax-0.1.14 → xax-0.1.16}/xax/task/loggers/stdout.py +21 -16
- {xax-0.1.14 → xax-0.1.16}/xax/task/loggers/tensorboard.py +2 -2
- {xax-0.1.14 → xax-0.1.16}/xax/task/mixins/cpu_stats.py +10 -10
- {xax-0.1.14 → xax-0.1.16}/xax/task/mixins/gpu_stats.py +3 -3
- {xax-0.1.14 → xax-0.1.16}/xax/task/mixins/train.py +27 -11
- {xax-0.1.14 → xax-0.1.16}/xax/utils/experiments.py +21 -20
- {xax-0.1.14 → xax-0.1.16/xax.egg-info}/PKG-INFO +1 -1
- {xax-0.1.14 → xax-0.1.16}/LICENSE +0 -0
- {xax-0.1.14 → xax-0.1.16}/MANIFEST.in +0 -0
- {xax-0.1.14 → xax-0.1.16}/README.md +0 -0
- {xax-0.1.14 → xax-0.1.16}/pyproject.toml +0 -0
- {xax-0.1.14 → xax-0.1.16}/setup.cfg +0 -0
- {xax-0.1.14 → xax-0.1.16}/setup.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/core/__init__.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/core/conf.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/core/state.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/nn/__init__.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/nn/embeddings.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/nn/equinox.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/nn/export.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/nn/functions.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/nn/losses.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/nn/norm.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/nn/parallel.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/nn/ssm.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/py.typed +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/requirements-dev.txt +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/requirements.txt +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/__init__.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/base.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/launchers/__init__.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/launchers/base.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/launchers/cli.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/launchers/single_process.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/loggers/__init__.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/loggers/callback.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/loggers/state.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/mixins/__init__.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/mixins/compile.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/mixins/logger.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/mixins/process.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/mixins/runnable.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/script.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/task/task.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/__init__.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/data/__init__.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/data/collate.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/debugging.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/jax.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/jaxpr.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/logging.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/numpy.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/profile.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/pytree.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/tensorboard.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/text.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/types/__init__.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax.egg-info/requires.txt +0 -0
- {xax-0.1.14 → xax-0.1.16}/xax.egg-info/top_level.txt +0 -0
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.1.14"
|
15
|
+
__version__ = "0.1.16"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -40,6 +40,7 @@ __all__ = [
|
|
40
40
|
"load_eqx_mlp",
|
41
41
|
"make_eqx_mlp",
|
42
42
|
"save_eqx",
|
43
|
+
"cubic_bezier_interpolation",
|
43
44
|
"euler_to_quat",
|
44
45
|
"get_projected_gravity_vector_from_quat",
|
45
46
|
"quat_to_euler",
|
@@ -201,6 +202,7 @@ NAME_MAP: dict[str, str] = {
|
|
201
202
|
"load_eqx_mlp": "nn.equinox",
|
202
203
|
"make_eqx_mlp": "nn.equinox",
|
203
204
|
"save_eqx": "nn.equinox",
|
205
|
+
"cubic_bezier_interpolation": "nn.geom",
|
204
206
|
"euler_to_quat": "nn.geom",
|
205
207
|
"get_projected_gravity_vector_from_quat": "nn.geom",
|
206
208
|
"quat_to_euler": "nn.geom",
|
@@ -363,6 +365,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
363
365
|
save_eqx,
|
364
366
|
)
|
365
367
|
from xax.nn.geom import (
|
368
|
+
cubic_bezier_interpolation,
|
366
369
|
euler_to_quat,
|
367
370
|
get_projected_gravity_vector_from_quat,
|
368
371
|
quat_to_euler,
|
@@ -1,10 +1,10 @@
|
|
1
1
|
"""Defines geometry functions."""
|
2
2
|
|
3
|
-
import jax
|
4
3
|
from jax import numpy as jnp
|
4
|
+
from jaxtyping import Array
|
5
5
|
|
6
6
|
|
7
|
-
def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
|
7
|
+
def quat_to_euler(quat_4: Array, eps: float = 1e-6) -> Array:
|
8
8
|
"""Normalizes and converts a quaternion (w, x, y, z) to roll, pitch, yaw.
|
9
9
|
|
10
10
|
Args:
|
@@ -40,7 +40,7 @@ def quat_to_euler(quat_4: jax.Array, eps: float = 1e-6) -> jax.Array:
|
|
40
40
|
return jnp.concatenate([roll, pitch, yaw], axis=-1)
|
41
41
|
|
42
42
|
|
43
|
-
def euler_to_quat(euler_3: jax.Array) -> jax.Array:
|
43
|
+
def euler_to_quat(euler_3: Array) -> Array:
|
44
44
|
"""Converts roll, pitch, yaw angles to a quaternion (w, x, y, z).
|
45
45
|
|
46
46
|
Args:
|
@@ -75,7 +75,7 @@ def euler_to_quat(euler_3: jax.Array) -> jax.Array:
|
|
75
75
|
return quat
|
76
76
|
|
77
77
|
|
78
|
-
def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -> jax.Array:
|
78
|
+
def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Array:
|
79
79
|
"""Calculates the gravity vector projected onto the local frame given a quaternion orientation.
|
80
80
|
|
81
81
|
Args:
|
@@ -101,7 +101,7 @@ def get_projected_gravity_vector_from_quat(quat: jax.Array, eps: float = 1e-6) -
|
|
101
101
|
return jnp.concatenate([gx, gy, -gz], axis=-1)
|
102
102
|
|
103
103
|
|
104
|
-
def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6) -> jax.Array:
|
104
|
+
def rotate_vector_by_quat(vector: Array, quat: Array, eps: float = 1e-6) -> Array:
|
105
105
|
"""Rotates a vector by a quaternion.
|
106
106
|
|
107
107
|
Args:
|
@@ -156,3 +156,24 @@ def rotate_vector_by_quat(vector: jax.Array, quat: jax.Array, eps: float = 1e-6)
|
|
156
156
|
)
|
157
157
|
|
158
158
|
return jnp.concatenate([xx, yy, zz], axis=-1)
|
159
|
+
|
160
|
+
|
161
|
+
def cubic_bezier_interpolation(y_start: Array, y_end: Array, x: Array) -> Array:
|
162
|
+
"""Cubic bezier interpolation.
|
163
|
+
|
164
|
+
This is a cubic bezier curve that starts at y_start and ends at y_end,
|
165
|
+
and is controlled by the parameter x. The curve is defined by the following formula:
|
166
|
+
|
167
|
+
y(x) = y_start + (y_end - y_start) * (x**3 + 3 * (x**2 * (1 - x)))
|
168
|
+
|
169
|
+
Args:
|
170
|
+
y_start: The start value, shape (*).
|
171
|
+
y_end: The end value, shape (*).
|
172
|
+
x: The interpolation parameter, shape (*).
|
173
|
+
|
174
|
+
Returns:
|
175
|
+
The interpolated value, shape (*).
|
176
|
+
"""
|
177
|
+
y_diff = y_end - y_start
|
178
|
+
bezier = x**3 + 3 * (x**2 * (1 - x))
|
179
|
+
return y_start + y_diff * bezier
|
@@ -205,6 +205,12 @@ def as_numpy(array: Array) -> np.ndarray:
|
|
205
205
|
return np.array(array)
|
206
206
|
|
207
207
|
|
208
|
+
@dataclass(kw_only=True)
|
209
|
+
class LogString:
|
210
|
+
value: str
|
211
|
+
secondary: bool
|
212
|
+
|
213
|
+
|
208
214
|
@dataclass(kw_only=True)
|
209
215
|
class LogImage:
|
210
216
|
image: PILImage
|
@@ -223,6 +229,12 @@ class LogVideo:
|
|
223
229
|
fps: int
|
224
230
|
|
225
231
|
|
232
|
+
@dataclass(kw_only=True)
|
233
|
+
class LogScalar:
|
234
|
+
value: Number
|
235
|
+
secondary: bool
|
236
|
+
|
237
|
+
|
226
238
|
@dataclass(kw_only=True)
|
227
239
|
class LogDistribution:
|
228
240
|
mean: Number
|
@@ -243,10 +255,10 @@ class LogHistogram:
|
|
243
255
|
@dataclass(kw_only=True)
|
244
256
|
class LogLine:
|
245
257
|
state: State
|
246
|
-
scalars: dict[str, dict[str,
|
258
|
+
scalars: dict[str, dict[str, LogScalar]]
|
247
259
|
distributions: dict[str, dict[str, LogDistribution]]
|
248
260
|
histograms: dict[str, dict[str, LogHistogram]]
|
249
|
-
strings: dict[str, dict[str,
|
261
|
+
strings: dict[str, dict[str, LogString]]
|
250
262
|
images: dict[str, dict[str, LogImage]]
|
251
263
|
videos: dict[str, dict[str, LogVideo]]
|
252
264
|
|
@@ -515,10 +527,10 @@ class Logger:
|
|
515
527
|
"""Defines an intermediate container which holds values to log somewhere else."""
|
516
528
|
|
517
529
|
def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
|
518
|
-
self.scalars: dict[str, dict[str, Callable[[],
|
530
|
+
self.scalars: dict[str, dict[str, Callable[[], LogScalar]]] = defaultdict(dict)
|
519
531
|
self.distributions: dict[str, dict[str, Callable[[], LogDistribution]]] = defaultdict(dict)
|
520
532
|
self.histograms: dict[str, dict[str, Callable[[], LogHistogram]]] = defaultdict(dict)
|
521
|
-
self.strings: dict[str, dict[str, Callable[[],
|
533
|
+
self.strings: dict[str, dict[str, Callable[[], LogString]]] = defaultdict(dict)
|
522
534
|
self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
|
523
535
|
self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
|
524
536
|
self.default_namespace = default_namespace
|
@@ -616,13 +628,23 @@ class Logger:
|
|
616
628
|
def resolve_namespace(self, namespace: str | None = None) -> str:
|
617
629
|
return "_".join([self.default_namespace if namespace is None else namespace] + NAMESPACE_STACK)
|
618
630
|
|
619
|
-
def log_scalar(
|
631
|
+
def log_scalar(
|
632
|
+
self,
|
633
|
+
key: str,
|
634
|
+
value: Callable[[], Number] | Number,
|
635
|
+
*,
|
636
|
+
namespace: str | None = None,
|
637
|
+
secondary: bool = False,
|
638
|
+
) -> None:
|
620
639
|
"""Logs a scalar value.
|
621
640
|
|
622
641
|
Args:
|
623
642
|
key: The key being logged
|
624
643
|
value: The scalar value being logged
|
625
644
|
namespace: An optional logging namespace
|
645
|
+
secondary: If set, treat this as a secondary value (meaning, it is
|
646
|
+
less important than other values, and some downstream loggers
|
647
|
+
will not display it)
|
626
648
|
"""
|
627
649
|
if not self.active:
|
628
650
|
raise RuntimeError("The logger is not active")
|
@@ -632,11 +654,11 @@ class Logger:
|
|
632
654
|
assert value.ndim == 0, f"Scalar must be a 0D array, got shape {value.shape}"
|
633
655
|
|
634
656
|
@functools.lru_cache(maxsize=None)
|
635
|
-
def scalar_future() ->
|
657
|
+
def scalar_future() -> LogScalar:
|
636
658
|
with ContextTimer() as timer:
|
637
659
|
value_concrete = value() if callable(value) else value
|
638
660
|
logger.debug("Scalar Key: %s, Time: %s", key, timer.elapsed_time)
|
639
|
-
return value_concrete
|
661
|
+
return LogScalar(value=value_concrete, secondary=secondary)
|
640
662
|
|
641
663
|
self.scalars[namespace][key] = scalar_future
|
642
664
|
|
@@ -770,21 +792,31 @@ class Logger:
|
|
770
792
|
|
771
793
|
self.histograms[namespace][key] = histogram_future
|
772
794
|
|
773
|
-
def log_string(
|
795
|
+
def log_string(
|
796
|
+
self,
|
797
|
+
key: str,
|
798
|
+
value: Callable[[], str] | str,
|
799
|
+
*,
|
800
|
+
namespace: str | None = None,
|
801
|
+
secondary: bool = False,
|
802
|
+
) -> None:
|
774
803
|
"""Logs a string value.
|
775
804
|
|
776
805
|
Args:
|
777
806
|
key: The key being logged
|
778
807
|
value: The string value being logged
|
779
808
|
namespace: An optional logging namespace
|
809
|
+
secondary: If set, treat this as a secondary value (meaning, it is
|
810
|
+
less important than other values, and some downstream loggers
|
811
|
+
will not display it)
|
780
812
|
"""
|
781
813
|
if not self.active:
|
782
814
|
raise RuntimeError("The logger is not active")
|
783
815
|
namespace = self.resolve_namespace(namespace)
|
784
816
|
|
785
817
|
@functools.lru_cache(maxsize=None)
|
786
|
-
def value_future() ->
|
787
|
-
return value() if callable(value) else value
|
818
|
+
def value_future() -> LogString:
|
819
|
+
return LogString(value=value() if callable(value) else value, secondary=secondary)
|
788
820
|
|
789
821
|
self.strings[namespace][key] = value_future
|
790
822
|
|
@@ -3,11 +3,19 @@
|
|
3
3
|
import json
|
4
4
|
import sys
|
5
5
|
from dataclasses import asdict
|
6
|
-
from typing import Any, Literal, TextIO
|
6
|
+
from typing import Any, Literal, Mapping, TextIO
|
7
7
|
|
8
8
|
from jaxtyping import Array
|
9
9
|
|
10
|
-
from xax.task.logger import
|
10
|
+
from xax.task.logger import (
|
11
|
+
LogError,
|
12
|
+
LoggerImpl,
|
13
|
+
LogLine,
|
14
|
+
LogPing,
|
15
|
+
LogScalar,
|
16
|
+
LogStatus,
|
17
|
+
LogString,
|
18
|
+
)
|
11
19
|
|
12
20
|
|
13
21
|
def get_json_value(value: Any) -> Any: # noqa: ANN401
|
@@ -61,14 +69,14 @@ class JsonLogger(LoggerImpl):
|
|
61
69
|
def get_json(self, line: LogLine) -> str:
|
62
70
|
data: dict = {"state": asdict(line.state)}
|
63
71
|
|
64
|
-
def add_logs(log:
|
72
|
+
def add_logs(log: Mapping[str, Mapping[str, LogScalar | LogString]], data: dict) -> None:
|
65
73
|
for namespace, values in log.items():
|
66
74
|
if self.remove_unicode_from_namespaces:
|
67
75
|
namespace = namespace.encode("ascii", errors="ignore").decode("ascii").strip()
|
68
76
|
if namespace not in data:
|
69
77
|
data[namespace] = {}
|
70
78
|
for k, v in values.items():
|
71
|
-
data[namespace][k] = get_json_value(v)
|
79
|
+
data[namespace][k] = get_json_value(v.value)
|
72
80
|
|
73
81
|
add_logs(line.scalars, data)
|
74
82
|
add_logs(line.strings, data)
|
@@ -4,11 +4,20 @@ import datetime
|
|
4
4
|
import logging
|
5
5
|
import sys
|
6
6
|
from collections import deque
|
7
|
-
from typing import Any, Deque, TextIO
|
7
|
+
from typing import Any, Deque, Mapping, TextIO
|
8
8
|
|
9
9
|
from jaxtyping import Array
|
10
10
|
|
11
|
-
from xax.task.logger import
|
11
|
+
from xax.task.logger import (
|
12
|
+
LogError,
|
13
|
+
LogErrorSummary,
|
14
|
+
LoggerImpl,
|
15
|
+
LogLine,
|
16
|
+
LogPing,
|
17
|
+
LogScalar,
|
18
|
+
LogStatus,
|
19
|
+
LogString,
|
20
|
+
)
|
12
21
|
from xax.utils.text import Color, colored, format_timedelta
|
13
22
|
|
14
23
|
|
@@ -95,20 +104,17 @@ class StdoutLogger(LoggerImpl):
|
|
95
104
|
def write_log_window(self, line: LogLine) -> None:
|
96
105
|
namespace_to_lines: dict[str, dict[str, str]] = {}
|
97
106
|
|
98
|
-
def add_logs(
|
107
|
+
def add_logs(
|
108
|
+
log: Mapping[str, Mapping[str, LogScalar | LogString]],
|
109
|
+
namespace_to_lines: dict[str, dict[str, str]],
|
110
|
+
) -> None:
|
99
111
|
for namespace, values in log.items():
|
100
|
-
if not self.log_timers and namespace.startswith("⌛"):
|
101
|
-
continue
|
102
|
-
if not self.log_perf and namespace.startswith("🔧"):
|
103
|
-
continue
|
104
|
-
if not self.log_optim and namespace.startswith("📉"):
|
105
|
-
continue
|
106
|
-
if not self.log_fp and namespace.startswith("⚖️"):
|
107
|
-
continue
|
108
|
-
if namespace not in namespace_to_lines:
|
109
|
-
namespace_to_lines[namespace] = {}
|
110
112
|
for k, v in values.items():
|
111
|
-
|
113
|
+
if v.secondary:
|
114
|
+
continue
|
115
|
+
if namespace not in namespace_to_lines:
|
116
|
+
namespace_to_lines[namespace] = {}
|
117
|
+
v_str = as_str(v.value, self.precision)
|
112
118
|
namespace_to_lines[namespace][k] = v_str
|
113
119
|
|
114
120
|
add_logs(line.scalars, namespace_to_lines)
|
@@ -116,9 +122,8 @@ class StdoutLogger(LoggerImpl):
|
|
116
122
|
if not namespace_to_lines:
|
117
123
|
return
|
118
124
|
|
119
|
-
self.write_fp.write("\n")
|
120
125
|
for namespace, lines in sorted(namespace_to_lines.items()):
|
121
|
-
self.write_fp.write(f"{colored(namespace, 'cyan', bold=True)}\n")
|
126
|
+
self.write_fp.write(f"\n{colored(namespace, 'cyan', bold=True)}\n")
|
122
127
|
for k, v in lines.items():
|
123
128
|
self.write_fp.write(f" ↪ {k}: {v}\n")
|
124
129
|
|
@@ -158,7 +158,7 @@ class TensorboardLogger(LoggerImpl):
|
|
158
158
|
for scalar_key, scalar_value in scalars.items():
|
159
159
|
writer.add_scalar(
|
160
160
|
f"{namespace}/{scalar_key}",
|
161
|
-
as_float(scalar_value),
|
161
|
+
as_float(scalar_value.value),
|
162
162
|
global_step=line.state.num_steps,
|
163
163
|
walltime=walltime,
|
164
164
|
)
|
@@ -192,7 +192,7 @@ class TensorboardLogger(LoggerImpl):
|
|
192
192
|
for string_key, string_value in strings.items():
|
193
193
|
writer.add_text(
|
194
194
|
f"{namespace}/{string_key}",
|
195
|
-
string_value,
|
195
|
+
string_value.value,
|
196
196
|
global_step=line.state.num_steps,
|
197
197
|
walltime=walltime,
|
198
198
|
)
|
@@ -248,15 +248,15 @@ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
|
|
248
248
|
stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()
|
249
249
|
|
250
250
|
if stats is not None:
|
251
|
-
self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
|
252
|
-
self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
|
253
|
-
self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
|
254
|
-
self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
|
255
|
-
self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
|
256
|
-
self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
|
257
|
-
self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
|
258
|
-
self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
|
259
|
-
self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
|
260
|
-
self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
|
251
|
+
self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu", secondary=True)
|
252
|
+
self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu", secondary=True)
|
253
|
+
self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu", secondary=True)
|
254
|
+
self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem", secondary=True)
|
255
|
+
self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem", secondary=True)
|
256
|
+
self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem", secondary=True)
|
257
|
+
self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem", secondary=True)
|
258
|
+
self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem", secondary=True)
|
259
|
+
self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem", secondary=True)
|
260
|
+
self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem", secondary=True)
|
261
261
|
|
262
262
|
return state
|
@@ -264,8 +264,8 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
|
|
264
264
|
for gpu_stat in stats.values():
|
265
265
|
if gpu_stat is None:
|
266
266
|
continue
|
267
|
-
self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
|
268
|
-
self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
|
269
|
-
self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")
|
267
|
+
self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu", secondary=True)
|
268
|
+
self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu", secondary=True)
|
269
|
+
self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu", secondary=True)
|
270
270
|
|
271
271
|
return state
|
@@ -50,8 +50,7 @@ from xax.utils.experiments import (
|
|
50
50
|
TrainingFinishedError,
|
51
51
|
diff_configs,
|
52
52
|
get_diff_string,
|
53
|
-
|
54
|
-
get_packages_with_versions,
|
53
|
+
get_state_file_string,
|
55
54
|
get_training_code,
|
56
55
|
)
|
57
56
|
from xax.utils.jax import jit as xax_jit
|
@@ -219,7 +218,12 @@ class TrainMixin(
|
|
219
218
|
return state.replace(elapsed_time_s=time.time() - state.start_time_s)
|
220
219
|
|
221
220
|
def log_train_step(
|
222
|
-
self,
|
221
|
+
self,
|
222
|
+
model: PyTree,
|
223
|
+
batch: Batch,
|
224
|
+
output: Output,
|
225
|
+
metrics: FrozenDict[str, Array],
|
226
|
+
state: State,
|
223
227
|
) -> None:
|
224
228
|
"""Override this function to do logging during the training phase.
|
225
229
|
|
@@ -235,7 +239,12 @@ class TrainMixin(
|
|
235
239
|
"""
|
236
240
|
|
237
241
|
def log_valid_step(
|
238
|
-
self,
|
242
|
+
self,
|
243
|
+
model: PyTree,
|
244
|
+
batch: Batch,
|
245
|
+
output: Output,
|
246
|
+
metrics: FrozenDict[str, Array],
|
247
|
+
state: State,
|
239
248
|
) -> None:
|
240
249
|
"""Override this function to do logging during the validation phase.
|
241
250
|
|
@@ -253,12 +262,20 @@ class TrainMixin(
|
|
253
262
|
def log_state_timers(self, state: State) -> None:
|
254
263
|
timer = self.state_timers[state.phase]
|
255
264
|
timer.step(state)
|
256
|
-
for
|
257
|
-
|
258
|
-
|
265
|
+
for k, v in timer.log_dict().items():
|
266
|
+
if isinstance(v, tuple):
|
267
|
+
v, secondary = v
|
268
|
+
else:
|
269
|
+
secondary = False
|
270
|
+
self.logger.log_scalar(k, v, namespace="⌛ timers", secondary=secondary)
|
259
271
|
|
260
272
|
def log_step(
|
261
|
-
self,
|
273
|
+
self,
|
274
|
+
model: PyTree,
|
275
|
+
batch: Batch,
|
276
|
+
output: Output,
|
277
|
+
metrics: FrozenDict[str, Array],
|
278
|
+
state: State,
|
262
279
|
) -> None:
|
263
280
|
phase = state.phase
|
264
281
|
|
@@ -534,9 +551,8 @@ class TrainMixin(
|
|
534
551
|
logger.log(LOG_STATUS, self.task_path)
|
535
552
|
logger.log(LOG_STATUS, self.task_name)
|
536
553
|
logger.log(LOG_STATUS, "JAX devices: %s", jax.devices())
|
537
|
-
self.logger.log_file("
|
538
|
-
self.logger.log_file("
|
539
|
-
self.logger.log_file("training_code.txt", get_training_code(self))
|
554
|
+
self.logger.log_file("state.txt", get_state_file_string(self))
|
555
|
+
self.logger.log_file("training_code.py", get_training_code(self))
|
540
556
|
self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
|
541
557
|
|
542
558
|
def model_partition_fn(self, item: Any) -> bool: # noqa: ANN401
|
@@ -114,28 +114,15 @@ class StateTimer:
|
|
114
114
|
self.sample_timer.step(state.num_samples if state.phase == "train" else state.num_valid_samples, cur_time)
|
115
115
|
self.iter_timer.step(cur_time)
|
116
116
|
|
117
|
-
def log_dict(self) -> dict[str,
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
"
|
123
|
-
"
|
124
|
-
}
|
125
|
-
|
126
|
-
# Logs sample statistics.
|
127
|
-
logs["⌛ samples"] = {
|
128
|
-
"total": self.sample_timer.steps,
|
129
|
-
"per-second": self.sample_timer.steps_per_second,
|
130
|
-
}
|
131
|
-
|
132
|
-
# Logs full iteration statistics.
|
133
|
-
logs["⌛ dt"] = {
|
134
|
-
"iter": self.iter_timer.iter_seconds,
|
117
|
+
def log_dict(self) -> dict[str, int | float | tuple[int | float, bool]]:
|
118
|
+
return {
|
119
|
+
"steps": (self.step_timer.steps, True),
|
120
|
+
"steps/second": self.step_timer.steps_per_second,
|
121
|
+
"samples": (self.sample_timer.steps, True),
|
122
|
+
"samples/second": (self.sample_timer.steps_per_second, True),
|
123
|
+
"dt": self.iter_timer.iter_seconds,
|
135
124
|
}
|
136
125
|
|
137
|
-
return logs
|
138
|
-
|
139
126
|
|
140
127
|
class IntervalTicker:
|
141
128
|
def __init__(self, interval: float) -> None:
|
@@ -479,6 +466,20 @@ def get_packages_with_versions() -> str:
|
|
479
466
|
return "\n".join([f"{key}=={version}" for key, version in sorted(packages)])
|
480
467
|
|
481
468
|
|
469
|
+
def get_command_line_string() -> str:
|
470
|
+
return " ".join(sys.argv)
|
471
|
+
|
472
|
+
|
473
|
+
def get_state_file_string(obj: object) -> str:
|
474
|
+
return "\n\n".join(
|
475
|
+
[
|
476
|
+
f"=== Command Line ===\n\n{get_command_line_string()}",
|
477
|
+
f"=== Git State ===\n\n{get_git_state(obj)}",
|
478
|
+
f"=== Packages ===\n\n{get_packages_with_versions()}",
|
479
|
+
]
|
480
|
+
)
|
481
|
+
|
482
|
+
|
482
483
|
def get_training_code(obj: object) -> str:
|
483
484
|
"""Gets the text from the file containing the provided object.
|
484
485
|
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|