xax 0.1.15__tar.gz → 0.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.1.15/xax.egg-info → xax-0.2.0}/PKG-INFO +6 -6
- {xax-0.1.15 → xax-0.2.0}/pyproject.toml +0 -1
- {xax-0.1.15 → xax-0.2.0}/xax/__init__.py +1 -1
- {xax-0.1.15 → xax-0.2.0}/xax/core/state.py +26 -1
- {xax-0.1.15 → xax-0.2.0}/xax/requirements.txt +5 -5
- {xax-0.1.15 → xax-0.2.0}/xax/task/base.py +1 -1
- {xax-0.1.15 → xax-0.2.0}/xax/task/logger.py +149 -12
- {xax-0.1.15 → xax-0.2.0}/xax/task/loggers/json.py +12 -4
- {xax-0.1.15 → xax-0.2.0}/xax/task/loggers/stdout.py +21 -16
- {xax-0.1.15 → xax-0.2.0}/xax/task/loggers/tensorboard.py +18 -2
- {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/checkpointing.py +118 -41
- {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/cpu_stats.py +10 -10
- {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/data_loader.py +2 -1
- {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/gpu_stats.py +3 -3
- {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/train.py +59 -29
- {xax-0.1.15 → xax-0.2.0}/xax/utils/experiments.py +34 -30
- {xax-0.1.15 → xax-0.2.0}/xax/utils/tensorboard.py +91 -3
- {xax-0.1.15 → xax-0.2.0/xax.egg-info}/PKG-INFO +6 -6
- {xax-0.1.15 → xax-0.2.0}/xax.egg-info/requires.txt +5 -5
- {xax-0.1.15 → xax-0.2.0}/LICENSE +0 -0
- {xax-0.1.15 → xax-0.2.0}/MANIFEST.in +0 -0
- {xax-0.1.15 → xax-0.2.0}/README.md +0 -0
- {xax-0.1.15 → xax-0.2.0}/setup.cfg +0 -0
- {xax-0.1.15 → xax-0.2.0}/setup.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/core/__init__.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/core/conf.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/nn/__init__.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/nn/embeddings.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/nn/equinox.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/nn/export.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/nn/functions.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/nn/geom.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/nn/losses.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/nn/norm.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/nn/parallel.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/nn/ssm.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/py.typed +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/requirements-dev.txt +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/__init__.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/launchers/__init__.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/launchers/base.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/launchers/cli.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/launchers/single_process.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/loggers/__init__.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/loggers/callback.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/loggers/state.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/__init__.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/compile.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/logger.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/process.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/runnable.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/script.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/task/task.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/utils/__init__.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/utils/data/__init__.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/utils/data/collate.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/utils/debugging.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/utils/jax.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/utils/jaxpr.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/utils/logging.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/utils/numpy.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/utils/profile.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/utils/pytree.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/utils/text.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/utils/types/__init__.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/utils/types/frozen_dict.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax.egg-info/SOURCES.txt +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.1.15 → xax-0.2.0}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: xax
|
3
|
-
Version: 0.1.15
|
3
|
+
Version: 0.2.0
|
4
4
|
Summary: A library for fast Jax experimentation
|
5
5
|
Home-page: https://github.com/kscalelabs/xax
|
6
6
|
Author: Benjamin Bolte
|
@@ -8,14 +8,14 @@ Requires-Python: >=3.11
|
|
8
8
|
Description-Content-Type: text/markdown
|
9
9
|
License-File: LICENSE
|
10
10
|
Requires-Dist: attrs
|
11
|
+
Requires-Dist: chex
|
12
|
+
Requires-Dist: dpshdl
|
13
|
+
Requires-Dist: equinox
|
14
|
+
Requires-Dist: importlib-resources
|
11
15
|
Requires-Dist: jax
|
12
16
|
Requires-Dist: jaxtyping
|
13
|
-
Requires-Dist: equinox
|
14
17
|
Requires-Dist: optax
|
15
|
-
Requires-Dist:
|
16
|
-
Requires-Dist: chex
|
17
|
-
Requires-Dist: importlib-resources
|
18
|
-
Requires-Dist: cloudpickle
|
18
|
+
Requires-Dist: orbax-checkpoint
|
19
19
|
Requires-Dist: pillow
|
20
20
|
Requires-Dist: omegaconf
|
21
21
|
Requires-Dist: gitpython
|
@@ -12,6 +12,14 @@ from xax.core.conf import field
|
|
12
12
|
Phase = Literal["train", "valid"]
|
13
13
|
|
14
14
|
|
15
|
+
def _phase_to_int(phase: Phase) -> int:
    """Encodes a phase name as its integer code (``"train"`` -> 0, ``"valid"`` -> 1).

    Raises:
        KeyError: If ``phase`` is not one of the known phase names.
    """
    return {"train": 0, "valid": 1}[phase]


def _int_to_phase(i: int) -> Phase:
    """Decodes an integer phase code back into its name (inverse of ``_phase_to_int``).

    Raises:
        IndexError: If ``i`` is not a valid phase code. Negative codes are
            rejected explicitly; plain list indexing would otherwise wrap them
            around and silently map e.g. ``-1`` to ``"valid"``.
    """
    if i < 0:
        raise IndexError(f"Invalid phase index: {i}")
    return cast(Phase, ["train", "valid"][i])
|
21
|
+
|
22
|
+
|
15
23
|
class StateDict(TypedDict, total=False):
|
16
24
|
num_steps: NotRequired[int]
|
17
25
|
num_samples: NotRequired[int]
|
@@ -35,7 +43,7 @@ class State:
|
|
35
43
|
|
36
44
|
@property
|
37
45
|
def phase(self) -> Phase:
|
38
|
-
return
|
46
|
+
return _int_to_phase(self._phase)
|
39
47
|
|
40
48
|
@classmethod
|
41
49
|
def init_state(cls) -> "State":
|
@@ -74,3 +82,20 @@ class State:
|
|
74
82
|
case _:
|
75
83
|
raise ValueError(f"Invalid phase: {phase}")
|
76
84
|
return State(**{**asdict(self), **kwargs, **extra_kwargs})
|
85
|
+
|
86
|
+
def to_dict(self) -> dict[str, int | float | str]:
|
87
|
+
return {
|
88
|
+
"num_steps": int(self.num_steps),
|
89
|
+
"num_samples": int(self.num_samples),
|
90
|
+
"num_valid_steps": int(self.num_valid_steps),
|
91
|
+
"num_valid_samples": int(self.num_valid_samples),
|
92
|
+
"start_time_s": float(self.start_time_s),
|
93
|
+
"elapsed_time_s": float(self.elapsed_time_s),
|
94
|
+
"phase": str(self.phase),
|
95
|
+
}
|
96
|
+
|
97
|
+
@classmethod
def from_dict(cls, d: dict[str, int | float | str]) -> "State":
    """Rebuilds a ``State`` from a dict in the format produced by ``to_dict``.

    A human-readable ``"phase"`` entry is translated into the internal
    ``"_phase"`` integer code; every other entry is passed to the constructor
    as a keyword argument. The caller's dict is left untouched (the original
    implementation popped ``"phase"`` from it and inserted ``"_phase"``).

    Raises:
        KeyError: If ``"phase"`` is present but not a known phase name.
    """
    kwargs = dict(d)  # Work on a copy so the caller's mapping is not mutated.
    if "phase" in kwargs:
        kwargs["_phase"] = _phase_to_int(cast(Phase, kwargs.pop("phase")))
    return cls(**kwargs)  # type: ignore[arg-type]
|
@@ -2,16 +2,16 @@
|
|
2
2
|
|
3
3
|
# Core ML/JAX dependencies
|
4
4
|
attrs
|
5
|
+
chex
|
6
|
+
dpshdl
|
7
|
+
equinox
|
8
|
+
importlib-resources
|
5
9
|
jax
|
6
10
|
jaxtyping
|
7
|
-
equinox
|
8
11
|
optax
|
9
|
-
|
10
|
-
chex
|
11
|
-
importlib-resources
|
12
|
+
orbax-checkpoint
|
12
13
|
|
13
14
|
# Data processing and serialization
|
14
|
-
cloudpickle
|
15
15
|
pillow
|
16
16
|
|
17
17
|
# Configuration and project management
|
@@ -79,7 +79,7 @@ class BaseTask(Generic[Config]):
|
|
79
79
|
def on_training_end(self, state: State) -> State:
|
80
80
|
return state
|
81
81
|
|
82
|
-
def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
|
82
|
+
def on_after_checkpoint_save(self, ckpt_path: Path, state: State | None) -> State | None:
|
83
83
|
return state
|
84
84
|
|
85
85
|
@functools.cached_property
|
@@ -18,11 +18,22 @@ from abc import ABC, abstractmethod
|
|
18
18
|
from collections import defaultdict
|
19
19
|
from dataclasses import dataclass
|
20
20
|
from types import TracebackType
|
21
|
-
from typing import
|
21
|
+
from typing import (
|
22
|
+
Any,
|
23
|
+
Callable,
|
24
|
+
Iterator,
|
25
|
+
Literal,
|
26
|
+
Self,
|
27
|
+
Sequence,
|
28
|
+
TypeVar,
|
29
|
+
cast,
|
30
|
+
get_args,
|
31
|
+
)
|
22
32
|
|
23
33
|
import jax
|
24
34
|
import jax.numpy as jnp
|
25
35
|
import numpy as np
|
36
|
+
from jax._src.core import ClosedJaxpr
|
26
37
|
from jaxtyping import Array
|
27
38
|
from PIL import Image, ImageDraw, ImageFont
|
28
39
|
from PIL.Image import Image as PILImage
|
@@ -194,7 +205,10 @@ def tile_images(images: list[PILImage], sep: int = 0) -> PILImage:
|
|
194
205
|
return tiled
|
195
206
|
|
196
207
|
|
197
|
-
def as_numpy(array: Array) -> np.ndarray:
|
208
|
+
def as_numpy(array: Array | np.ndarray) -> np.ndarray:
|
209
|
+
"""Convert a JAX array or numpy array to numpy array."""
|
210
|
+
if isinstance(array, np.ndarray):
|
211
|
+
return array
|
198
212
|
array = jax.device_get(array)
|
199
213
|
if jax.dtypes.issubdtype(array.dtype, jnp.floating):
|
200
214
|
array = array.astype(jnp.float32)
|
@@ -205,6 +219,19 @@ def as_numpy(array: Array) -> np.ndarray:
|
|
205
219
|
return np.array(array)
|
206
220
|
|
207
221
|
|
222
|
+
def as_numpy_opt(array: Array | np.ndarray | None) -> np.ndarray | None:
|
223
|
+
"""Convert an optional JAX array or numpy array to numpy array."""
|
224
|
+
if array is None:
|
225
|
+
return None
|
226
|
+
return as_numpy(array)
|
227
|
+
|
228
|
+
|
229
|
+
@dataclass(kw_only=True)
|
230
|
+
class LogString:
|
231
|
+
value: str
|
232
|
+
secondary: bool
|
233
|
+
|
234
|
+
|
208
235
|
@dataclass(kw_only=True)
|
209
236
|
class LogImage:
|
210
237
|
image: PILImage
|
@@ -223,6 +250,12 @@ class LogVideo:
|
|
223
250
|
fps: int
|
224
251
|
|
225
252
|
|
253
|
+
@dataclass(kw_only=True)
class LogScalar:
    """A single scalar value queued for logging, plus its display priority."""

    # The resolved scalar (loggers read it via ``.value``).
    value: Number
    # When True the value is lower priority; the stdout logger skips such values.
    secondary: bool
|
257
|
+
|
258
|
+
|
226
259
|
@dataclass(kw_only=True)
|
227
260
|
class LogDistribution:
|
228
261
|
mean: Number
|
@@ -240,15 +273,29 @@ class LogHistogram:
|
|
240
273
|
bucket_counts: list[int]
|
241
274
|
|
242
275
|
|
276
|
+
@dataclass(kw_only=True)
|
277
|
+
class LogMesh:
|
278
|
+
vertices: np.ndarray
|
279
|
+
colors: np.ndarray | None
|
280
|
+
faces: np.ndarray | None
|
281
|
+
config_dict: dict[str, Any] | None # noqa: ANN401
|
282
|
+
|
283
|
+
|
284
|
+
@dataclass(kw_only=True)
|
285
|
+
class LogGraph:
|
286
|
+
computation: ClosedJaxpr
|
287
|
+
|
288
|
+
|
243
289
|
@dataclass(kw_only=True)
|
244
290
|
class LogLine:
|
245
291
|
state: State
|
246
|
-
scalars: dict[str, dict[str,
|
292
|
+
scalars: dict[str, dict[str, LogScalar]]
|
247
293
|
distributions: dict[str, dict[str, LogDistribution]]
|
248
294
|
histograms: dict[str, dict[str, LogHistogram]]
|
249
|
-
strings: dict[str, dict[str,
|
295
|
+
strings: dict[str, dict[str, LogString]]
|
250
296
|
images: dict[str, dict[str, LogImage]]
|
251
297
|
videos: dict[str, dict[str, LogVideo]]
|
298
|
+
meshes: dict[str, dict[str, LogMesh]]
|
252
299
|
|
253
300
|
|
254
301
|
@dataclass(kw_only=True)
|
@@ -515,12 +562,13 @@ class Logger:
|
|
515
562
|
"""Defines an intermediate container which holds values to log somewhere else."""
|
516
563
|
|
517
564
|
def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
|
518
|
-
self.scalars: dict[str, dict[str, Callable[[],
|
565
|
+
self.scalars: dict[str, dict[str, Callable[[], LogScalar]]] = defaultdict(dict)
|
519
566
|
self.distributions: dict[str, dict[str, Callable[[], LogDistribution]]] = defaultdict(dict)
|
520
567
|
self.histograms: dict[str, dict[str, Callable[[], LogHistogram]]] = defaultdict(dict)
|
521
|
-
self.strings: dict[str, dict[str, Callable[[],
|
568
|
+
self.strings: dict[str, dict[str, Callable[[], LogString]]] = defaultdict(dict)
|
522
569
|
self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
|
523
570
|
self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
|
571
|
+
self.meshes: dict[str, dict[str, Callable[[], LogMesh]]] = defaultdict(dict)
|
524
572
|
self.default_namespace = default_namespace
|
525
573
|
self.loggers: list[LoggerImpl] = []
|
526
574
|
|
@@ -548,6 +596,7 @@ class Logger:
|
|
548
596
|
strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
|
549
597
|
images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
|
550
598
|
videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
|
599
|
+
meshes={k: {kk: v() for kk, v in v.items()} for k, v in self.meshes.items()},
|
551
600
|
)
|
552
601
|
|
553
602
|
def clear(self) -> None:
|
@@ -557,6 +606,7 @@ class Logger:
|
|
557
606
|
self.strings.clear()
|
558
607
|
self.images.clear()
|
559
608
|
self.videos.clear()
|
609
|
+
self.meshes.clear()
|
560
610
|
|
561
611
|
def write(self, state: State) -> None:
|
562
612
|
"""Writes the current step's logging information.
|
@@ -616,13 +666,23 @@ class Logger:
|
|
616
666
|
def resolve_namespace(self, namespace: str | None = None) -> str:
|
617
667
|
return "_".join([self.default_namespace if namespace is None else namespace] + NAMESPACE_STACK)
|
618
668
|
|
619
|
-
def log_scalar(
|
669
|
+
def log_scalar(
|
670
|
+
self,
|
671
|
+
key: str,
|
672
|
+
value: Callable[[], Number] | Number,
|
673
|
+
*,
|
674
|
+
namespace: str | None = None,
|
675
|
+
secondary: bool = False,
|
676
|
+
) -> None:
|
620
677
|
"""Logs a scalar value.
|
621
678
|
|
622
679
|
Args:
|
623
680
|
key: The key being logged
|
624
681
|
value: The scalar value being logged
|
625
682
|
namespace: An optional logging namespace
|
683
|
+
secondary: If set, treat this as a secondary value (meaning, it is
|
684
|
+
less important than other values, and some downstream loggers
|
685
|
+
will not display it)
|
626
686
|
"""
|
627
687
|
if not self.active:
|
628
688
|
raise RuntimeError("The logger is not active")
|
@@ -632,11 +692,11 @@ class Logger:
|
|
632
692
|
assert value.ndim == 0, f"Scalar must be a 0D array, got shape {value.shape}"
|
633
693
|
|
634
694
|
@functools.lru_cache(maxsize=None)
|
635
|
-
def scalar_future() ->
|
695
|
+
def scalar_future() -> LogScalar:
|
636
696
|
with ContextTimer() as timer:
|
637
697
|
value_concrete = value() if callable(value) else value
|
638
698
|
logger.debug("Scalar Key: %s, Time: %s", key, timer.elapsed_time)
|
639
|
-
return value_concrete
|
699
|
+
return LogScalar(value=value_concrete, secondary=secondary)
|
640
700
|
|
641
701
|
self.scalars[namespace][key] = scalar_future
|
642
702
|
|
@@ -770,21 +830,31 @@ class Logger:
|
|
770
830
|
|
771
831
|
self.histograms[namespace][key] = histogram_future
|
772
832
|
|
773
|
-
def log_string(
|
833
|
+
def log_string(
|
834
|
+
self,
|
835
|
+
key: str,
|
836
|
+
value: Callable[[], str] | str,
|
837
|
+
*,
|
838
|
+
namespace: str | None = None,
|
839
|
+
secondary: bool = False,
|
840
|
+
) -> None:
|
774
841
|
"""Logs a string value.
|
775
842
|
|
776
843
|
Args:
|
777
844
|
key: The key being logged
|
778
845
|
value: The string value being logged
|
779
846
|
namespace: An optional logging namespace
|
847
|
+
secondary: If set, treat this as a secondary value (meaning, it is
|
848
|
+
less important than other values, and some downstream loggers
|
849
|
+
will not display it)
|
780
850
|
"""
|
781
851
|
if not self.active:
|
782
852
|
raise RuntimeError("The logger is not active")
|
783
853
|
namespace = self.resolve_namespace(namespace)
|
784
854
|
|
785
855
|
@functools.lru_cache(maxsize=None)
|
786
|
-
def value_future() ->
|
787
|
-
return value() if callable(value) else value
|
856
|
+
def value_future() -> LogString:
|
857
|
+
return LogString(value=value() if callable(value) else value, secondary=secondary)
|
788
858
|
|
789
859
|
self.strings[namespace][key] = value_future
|
790
860
|
|
@@ -1019,6 +1089,73 @@ class Logger:
|
|
1019
1089
|
|
1020
1090
|
self.videos[namespace][key] = video_future
|
1021
1091
|
|
1092
|
+
def log_mesh(
    self,
    key: str,
    vertices: np.ndarray | Array | Callable[[], np.ndarray | Array],
    colors: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
    faces: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
    config_dict: dict[str, Any] | None = None,
    *,
    namespace: str | None = None,
) -> None:
    """Logs a 3D mesh or point cloud.

    Inputs are resolved lazily. Callables are only invoked, and arrays only
    converted and validated, when the queued future is evaluated at write time.

    Args:
        key: The key being logged
        vertices: Vertex positions with shape ``(N, 3)`` or ``(B, N, 3)``,
            or a callable returning them
        colors: Optional per-vertex colors with shape ``(N, 3)`` or
            ``(B, N, 3)``, or a callable returning them. Integer colors are
            treated as lying in ``[0, 255]``; floating-point (and boolean)
            colors as lying in ``[0, 1]``. Both are clipped to that range and
            converted to ``uint8``.
        faces: Optional triangle vertex indices with shape ``(F, 3)`` or
            ``(B, F, 3)``, or a callable returning them; converted to ``int32``
        config_dict: Optional configuration stored unchanged alongside the mesh
        namespace: An optional logging namespace

    Raises:
        RuntimeError: If the logger is not active.
        ValueError: Raised when the queued future is evaluated, if vertices,
            colors or faces do not have a trailing dimension of 3 with rank 2 or 3.
    """
    if not self.active:
        raise RuntimeError("The logger is not active")
    namespace = self.resolve_namespace(namespace)

    def _as_batched(arr: np.ndarray, name: str) -> np.ndarray:
        # Promote an unbatched (N, 3) array to (1, N, 3) and reject any other
        # shape. ``ndim`` is checked before ``shape[-1]`` so that a 0-D input
        # raises ValueError instead of IndexError.
        if arr.ndim == 2:
            arr = arr[None]
        if arr.ndim != 3 or arr.shape[-1] != 3:
            raise ValueError(f"{name} must have shape (N, 3) or (B, N, 3)")
        return arr

    def _as_uint8_colors(arr: np.ndarray) -> np.ndarray:
        # Integer colors are assumed to already be 0..255. Scaling them by 255
        # (as before) wrapped around on the uint8 cast. Everything else
        # (floats, bools) is assumed to be 0..1 and rescaled. Clipping first
        # keeps out-of-range values from wrapping silently.
        if arr.dtype == np.uint8:
            return arr
        if np.issubdtype(arr.dtype, np.integer):
            return np.clip(arr, 0, 255).astype(np.uint8)
        return (np.clip(arr.astype(np.float32), 0.0, 1.0) * 255).astype(np.uint8)

    @functools.lru_cache(maxsize=None)
    def mesh_future() -> LogMesh:
        with ContextTimer() as timer:
            # Resolve any callables into concrete values.
            vertices_val = vertices() if callable(vertices) else vertices
            colors_val = colors() if callable(colors) else colors
            faces_val = faces() if callable(faces) else faces

            # Convert to numpy arrays.
            vertices_np = as_numpy(vertices_val)
            colors_np = as_numpy_opt(colors_val)
            faces_np = as_numpy_opt(faces_val)

            # Validate shapes and add a batch dimension where missing.
            vertices_np = _as_batched(vertices_np, "Vertices")
            if colors_np is not None:
                colors_np = _as_batched(colors_np, "Colors")
            if faces_np is not None:
                faces_np = _as_batched(faces_np, "Faces")

            # Normalize dtypes.
            if colors_np is not None:
                colors_np = _as_uint8_colors(colors_np)
            if faces_np is not None and faces_np.dtype != np.int32:
                faces_np = faces_np.astype(np.int32)

        logger.debug("Mesh Key: %s, Time: %s", key, timer.elapsed_time)
        return LogMesh(
            vertices=vertices_np,
            colors=colors_np,
            faces=faces_np,
            config_dict=config_dict,
        )

    self.meshes[namespace][key] = mesh_future
|
1158
|
+
|
1022
1159
|
def __enter__(self) -> Self:
|
1023
1160
|
self.active = True
|
1024
1161
|
for logger in self.loggers:
|
@@ -3,11 +3,19 @@
|
|
3
3
|
import json
|
4
4
|
import sys
|
5
5
|
from dataclasses import asdict
|
6
|
-
from typing import Any, Literal, TextIO
|
6
|
+
from typing import Any, Literal, Mapping, TextIO
|
7
7
|
|
8
8
|
from jaxtyping import Array
|
9
9
|
|
10
|
-
from xax.task.logger import
|
10
|
+
from xax.task.logger import (
|
11
|
+
LogError,
|
12
|
+
LoggerImpl,
|
13
|
+
LogLine,
|
14
|
+
LogPing,
|
15
|
+
LogScalar,
|
16
|
+
LogStatus,
|
17
|
+
LogString,
|
18
|
+
)
|
11
19
|
|
12
20
|
|
13
21
|
def get_json_value(value: Any) -> Any: # noqa: ANN401
|
@@ -61,14 +69,14 @@ class JsonLogger(LoggerImpl):
|
|
61
69
|
def get_json(self, line: LogLine) -> str:
|
62
70
|
data: dict = {"state": asdict(line.state)}
|
63
71
|
|
64
|
-
def add_logs(log:
|
72
|
+
def add_logs(log: Mapping[str, Mapping[str, LogScalar | LogString]], data: dict) -> None:
|
65
73
|
for namespace, values in log.items():
|
66
74
|
if self.remove_unicode_from_namespaces:
|
67
75
|
namespace = namespace.encode("ascii", errors="ignore").decode("ascii").strip()
|
68
76
|
if namespace not in data:
|
69
77
|
data[namespace] = {}
|
70
78
|
for k, v in values.items():
|
71
|
-
data[namespace][k] = get_json_value(v)
|
79
|
+
data[namespace][k] = get_json_value(v.value)
|
72
80
|
|
73
81
|
add_logs(line.scalars, data)
|
74
82
|
add_logs(line.strings, data)
|
@@ -4,11 +4,20 @@ import datetime
|
|
4
4
|
import logging
|
5
5
|
import sys
|
6
6
|
from collections import deque
|
7
|
-
from typing import Any, Deque, TextIO
|
7
|
+
from typing import Any, Deque, Mapping, TextIO
|
8
8
|
|
9
9
|
from jaxtyping import Array
|
10
10
|
|
11
|
-
from xax.task.logger import
|
11
|
+
from xax.task.logger import (
|
12
|
+
LogError,
|
13
|
+
LogErrorSummary,
|
14
|
+
LoggerImpl,
|
15
|
+
LogLine,
|
16
|
+
LogPing,
|
17
|
+
LogScalar,
|
18
|
+
LogStatus,
|
19
|
+
LogString,
|
20
|
+
)
|
12
21
|
from xax.utils.text import Color, colored, format_timedelta
|
13
22
|
|
14
23
|
|
@@ -95,20 +104,17 @@ class StdoutLogger(LoggerImpl):
|
|
95
104
|
def write_log_window(self, line: LogLine) -> None:
|
96
105
|
namespace_to_lines: dict[str, dict[str, str]] = {}
|
97
106
|
|
98
|
-
def add_logs(
|
107
|
+
def add_logs(
|
108
|
+
log: Mapping[str, Mapping[str, LogScalar | LogString]],
|
109
|
+
namespace_to_lines: dict[str, dict[str, str]],
|
110
|
+
) -> None:
|
99
111
|
for namespace, values in log.items():
|
100
|
-
if not self.log_timers and namespace.startswith("⌛"):
|
101
|
-
continue
|
102
|
-
if not self.log_perf and namespace.startswith("🔧"):
|
103
|
-
continue
|
104
|
-
if not self.log_optim and namespace.startswith("📉"):
|
105
|
-
continue
|
106
|
-
if not self.log_fp and namespace.startswith("⚖️"):
|
107
|
-
continue
|
108
|
-
if namespace not in namespace_to_lines:
|
109
|
-
namespace_to_lines[namespace] = {}
|
110
112
|
for k, v in values.items():
|
111
|
-
|
113
|
+
if v.secondary:
|
114
|
+
continue
|
115
|
+
if namespace not in namespace_to_lines:
|
116
|
+
namespace_to_lines[namespace] = {}
|
117
|
+
v_str = as_str(v.value, self.precision)
|
112
118
|
namespace_to_lines[namespace][k] = v_str
|
113
119
|
|
114
120
|
add_logs(line.scalars, namespace_to_lines)
|
@@ -116,9 +122,8 @@ class StdoutLogger(LoggerImpl):
|
|
116
122
|
if not namespace_to_lines:
|
117
123
|
return
|
118
124
|
|
119
|
-
self.write_fp.write("\n")
|
120
125
|
for namespace, lines in sorted(namespace_to_lines.items()):
|
121
|
-
self.write_fp.write(f"{colored(namespace, 'cyan', bold=True)}\n")
|
126
|
+
self.write_fp.write(f"\n{colored(namespace, 'cyan', bold=True)}\n")
|
122
127
|
for k, v in lines.items():
|
123
128
|
self.write_fp.write(f" ↪ {k}: {v}\n")
|
124
129
|
|
@@ -70,6 +70,9 @@ class TensorboardLogger(LoggerImpl):
|
|
70
70
|
self._started = True
|
71
71
|
|
72
72
|
def worker_thread(self) -> None:
|
73
|
+
if os.environ.get("DISABLE_TENSORBOARD", "0") == "1":
|
74
|
+
return
|
75
|
+
|
73
76
|
time.sleep(self.wait_seconds)
|
74
77
|
|
75
78
|
port = int(os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT))
|
@@ -158,7 +161,7 @@ class TensorboardLogger(LoggerImpl):
|
|
158
161
|
for scalar_key, scalar_value in scalars.items():
|
159
162
|
writer.add_scalar(
|
160
163
|
f"{namespace}/{scalar_key}",
|
161
|
-
as_float(scalar_value),
|
164
|
+
as_float(scalar_value.value),
|
162
165
|
global_step=line.state.num_steps,
|
163
166
|
walltime=walltime,
|
164
167
|
)
|
@@ -192,7 +195,7 @@ class TensorboardLogger(LoggerImpl):
|
|
192
195
|
for string_key, string_value in strings.items():
|
193
196
|
writer.add_text(
|
194
197
|
f"{namespace}/{string_key}",
|
195
|
-
string_value,
|
198
|
+
string_value.value,
|
196
199
|
global_step=line.state.num_steps,
|
197
200
|
walltime=walltime,
|
198
201
|
)
|
@@ -213,6 +216,19 @@ class TensorboardLogger(LoggerImpl):
|
|
213
216
|
video_value.frames,
|
214
217
|
fps=video_value.fps,
|
215
218
|
global_step=line.state.num_steps,
|
219
|
+
walltime=walltime,
|
220
|
+
)
|
221
|
+
|
222
|
+
for namespace, meshes in line.meshes.items():
|
223
|
+
for mesh_key, mesh_value in meshes.items():
|
224
|
+
writer.add_mesh(
|
225
|
+
f"{namespace}/{mesh_key}",
|
226
|
+
vertices=mesh_value.vertices,
|
227
|
+
faces=mesh_value.faces,
|
228
|
+
colors=mesh_value.colors,
|
229
|
+
config_dict=mesh_value.config_dict,
|
230
|
+
global_step=line.state.num_steps,
|
231
|
+
walltime=walltime,
|
216
232
|
)
|
217
233
|
|
218
234
|
for name, contents in self.files.items():
|