xax 0.1.15.tar.gz → 0.2.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. {xax-0.1.15/xax.egg-info → xax-0.2.0}/PKG-INFO +6 -6
  2. {xax-0.1.15 → xax-0.2.0}/pyproject.toml +0 -1
  3. {xax-0.1.15 → xax-0.2.0}/xax/__init__.py +1 -1
  4. {xax-0.1.15 → xax-0.2.0}/xax/core/state.py +26 -1
  5. {xax-0.1.15 → xax-0.2.0}/xax/requirements.txt +5 -5
  6. {xax-0.1.15 → xax-0.2.0}/xax/task/base.py +1 -1
  7. {xax-0.1.15 → xax-0.2.0}/xax/task/logger.py +149 -12
  8. {xax-0.1.15 → xax-0.2.0}/xax/task/loggers/json.py +12 -4
  9. {xax-0.1.15 → xax-0.2.0}/xax/task/loggers/stdout.py +21 -16
  10. {xax-0.1.15 → xax-0.2.0}/xax/task/loggers/tensorboard.py +18 -2
  11. {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/checkpointing.py +118 -41
  12. {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/cpu_stats.py +10 -10
  13. {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/data_loader.py +2 -1
  14. {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/gpu_stats.py +3 -3
  15. {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/train.py +59 -29
  16. {xax-0.1.15 → xax-0.2.0}/xax/utils/experiments.py +34 -30
  17. {xax-0.1.15 → xax-0.2.0}/xax/utils/tensorboard.py +91 -3
  18. {xax-0.1.15 → xax-0.2.0/xax.egg-info}/PKG-INFO +6 -6
  19. {xax-0.1.15 → xax-0.2.0}/xax.egg-info/requires.txt +5 -5
  20. {xax-0.1.15 → xax-0.2.0}/LICENSE +0 -0
  21. {xax-0.1.15 → xax-0.2.0}/MANIFEST.in +0 -0
  22. {xax-0.1.15 → xax-0.2.0}/README.md +0 -0
  23. {xax-0.1.15 → xax-0.2.0}/setup.cfg +0 -0
  24. {xax-0.1.15 → xax-0.2.0}/setup.py +0 -0
  25. {xax-0.1.15 → xax-0.2.0}/xax/core/__init__.py +0 -0
  26. {xax-0.1.15 → xax-0.2.0}/xax/core/conf.py +0 -0
  27. {xax-0.1.15 → xax-0.2.0}/xax/nn/__init__.py +0 -0
  28. {xax-0.1.15 → xax-0.2.0}/xax/nn/embeddings.py +0 -0
  29. {xax-0.1.15 → xax-0.2.0}/xax/nn/equinox.py +0 -0
  30. {xax-0.1.15 → xax-0.2.0}/xax/nn/export.py +0 -0
  31. {xax-0.1.15 → xax-0.2.0}/xax/nn/functions.py +0 -0
  32. {xax-0.1.15 → xax-0.2.0}/xax/nn/geom.py +0 -0
  33. {xax-0.1.15 → xax-0.2.0}/xax/nn/losses.py +0 -0
  34. {xax-0.1.15 → xax-0.2.0}/xax/nn/norm.py +0 -0
  35. {xax-0.1.15 → xax-0.2.0}/xax/nn/parallel.py +0 -0
  36. {xax-0.1.15 → xax-0.2.0}/xax/nn/ssm.py +0 -0
  37. {xax-0.1.15 → xax-0.2.0}/xax/py.typed +0 -0
  38. {xax-0.1.15 → xax-0.2.0}/xax/requirements-dev.txt +0 -0
  39. {xax-0.1.15 → xax-0.2.0}/xax/task/__init__.py +0 -0
  40. {xax-0.1.15 → xax-0.2.0}/xax/task/launchers/__init__.py +0 -0
  41. {xax-0.1.15 → xax-0.2.0}/xax/task/launchers/base.py +0 -0
  42. {xax-0.1.15 → xax-0.2.0}/xax/task/launchers/cli.py +0 -0
  43. {xax-0.1.15 → xax-0.2.0}/xax/task/launchers/single_process.py +0 -0
  44. {xax-0.1.15 → xax-0.2.0}/xax/task/loggers/__init__.py +0 -0
  45. {xax-0.1.15 → xax-0.2.0}/xax/task/loggers/callback.py +0 -0
  46. {xax-0.1.15 → xax-0.2.0}/xax/task/loggers/state.py +0 -0
  47. {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/__init__.py +0 -0
  48. {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/artifacts.py +0 -0
  49. {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/compile.py +0 -0
  50. {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/logger.py +0 -0
  51. {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/process.py +0 -0
  52. {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/runnable.py +0 -0
  53. {xax-0.1.15 → xax-0.2.0}/xax/task/mixins/step_wrapper.py +0 -0
  54. {xax-0.1.15 → xax-0.2.0}/xax/task/script.py +0 -0
  55. {xax-0.1.15 → xax-0.2.0}/xax/task/task.py +0 -0
  56. {xax-0.1.15 → xax-0.2.0}/xax/utils/__init__.py +0 -0
  57. {xax-0.1.15 → xax-0.2.0}/xax/utils/data/__init__.py +0 -0
  58. {xax-0.1.15 → xax-0.2.0}/xax/utils/data/collate.py +0 -0
  59. {xax-0.1.15 → xax-0.2.0}/xax/utils/debugging.py +0 -0
  60. {xax-0.1.15 → xax-0.2.0}/xax/utils/jax.py +0 -0
  61. {xax-0.1.15 → xax-0.2.0}/xax/utils/jaxpr.py +0 -0
  62. {xax-0.1.15 → xax-0.2.0}/xax/utils/logging.py +0 -0
  63. {xax-0.1.15 → xax-0.2.0}/xax/utils/numpy.py +0 -0
  64. {xax-0.1.15 → xax-0.2.0}/xax/utils/profile.py +0 -0
  65. {xax-0.1.15 → xax-0.2.0}/xax/utils/pytree.py +0 -0
  66. {xax-0.1.15 → xax-0.2.0}/xax/utils/text.py +0 -0
  67. {xax-0.1.15 → xax-0.2.0}/xax/utils/types/__init__.py +0 -0
  68. {xax-0.1.15 → xax-0.2.0}/xax/utils/types/frozen_dict.py +0 -0
  69. {xax-0.1.15 → xax-0.2.0}/xax/utils/types/hashable_array.py +0 -0
  70. {xax-0.1.15 → xax-0.2.0}/xax.egg-info/SOURCES.txt +0 -0
  71. {xax-0.1.15 → xax-0.2.0}/xax.egg-info/dependency_links.txt +0 -0
  72. {xax-0.1.15 → xax-0.2.0}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.1.15
3
+ Version: 0.2.0
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -8,14 +8,14 @@ Requires-Python: >=3.11
8
8
  Description-Content-Type: text/markdown
9
9
  License-File: LICENSE
10
10
  Requires-Dist: attrs
11
+ Requires-Dist: chex
12
+ Requires-Dist: dpshdl
13
+ Requires-Dist: equinox
14
+ Requires-Dist: importlib-resources
11
15
  Requires-Dist: jax
12
16
  Requires-Dist: jaxtyping
13
- Requires-Dist: equinox
14
17
  Requires-Dist: optax
15
- Requires-Dist: dpshdl
16
- Requires-Dist: chex
17
- Requires-Dist: importlib-resources
18
- Requires-Dist: cloudpickle
18
+ Requires-Dist: orbax-checkpoint
19
19
  Requires-Dist: pillow
20
20
  Requires-Dist: omegaconf
21
21
  Requires-Dist: gitpython
@@ -35,7 +35,6 @@ explicit_package_bases = true
35
35
  [[tool.mypy.overrides]]
36
36
 
37
37
  module = [
38
- "cloudpickle.*",
39
38
  "optax.*",
40
39
  "setuptools.*",
41
40
  "tensorboard.*",
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.1.15"
15
+ __version__ = "0.2.0"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -12,6 +12,14 @@ from xax.core.conf import field
12
12
  Phase = Literal["train", "valid"]
13
13
 
14
14
 
15
+ def _phase_to_int(phase: Phase) -> int:
16
+ return {"train": 0, "valid": 1}[phase]
17
+
18
+
19
+ def _int_to_phase(i: int) -> Phase:
20
+ return cast(Phase, ["train", "valid"][i])
21
+
22
+
15
23
  class StateDict(TypedDict, total=False):
16
24
  num_steps: NotRequired[int]
17
25
  num_samples: NotRequired[int]
@@ -35,7 +43,7 @@ class State:
35
43
 
36
44
  @property
37
45
  def phase(self) -> Phase:
38
- return cast(Phase, ["train", "valid"][self._phase])
46
+ return _int_to_phase(self._phase)
39
47
 
40
48
  @classmethod
41
49
  def init_state(cls) -> "State":
@@ -74,3 +82,20 @@ class State:
74
82
  case _:
75
83
  raise ValueError(f"Invalid phase: {phase}")
76
84
  return State(**{**asdict(self), **kwargs, **extra_kwargs})
85
+
86
+ def to_dict(self) -> dict[str, int | float | str]:
87
+ return {
88
+ "num_steps": int(self.num_steps),
89
+ "num_samples": int(self.num_samples),
90
+ "num_valid_steps": int(self.num_valid_steps),
91
+ "num_valid_samples": int(self.num_valid_samples),
92
+ "start_time_s": float(self.start_time_s),
93
+ "elapsed_time_s": float(self.elapsed_time_s),
94
+ "phase": str(self.phase),
95
+ }
96
+
97
+ @classmethod
98
+ def from_dict(cls, d: dict[str, int | float | str]) -> "State":
99
+ if "phase" in d:
100
+ d["_phase"] = _phase_to_int(cast(Phase, d.pop("phase")))
101
+ return cls(**d) # type: ignore[arg-type]
@@ -2,16 +2,16 @@
2
2
 
3
3
  # Core ML/JAX dependencies
4
4
  attrs
5
+ chex
6
+ dpshdl
7
+ equinox
8
+ importlib-resources
5
9
  jax
6
10
  jaxtyping
7
- equinox
8
11
  optax
9
- dpshdl
10
- chex
11
- importlib-resources
12
+ orbax-checkpoint
12
13
 
13
14
  # Data processing and serialization
14
- cloudpickle
15
15
  pillow
16
16
 
17
17
  # Configuration and project management
@@ -79,7 +79,7 @@ class BaseTask(Generic[Config]):
79
79
  def on_training_end(self, state: State) -> State:
80
80
  return state
81
81
 
82
- def on_after_checkpoint_save(self, ckpt_path: Path, state: State) -> State:
82
+ def on_after_checkpoint_save(self, ckpt_path: Path, state: State | None) -> State | None:
83
83
  return state
84
84
 
85
85
  @functools.cached_property
@@ -18,11 +18,22 @@ from abc import ABC, abstractmethod
18
18
  from collections import defaultdict
19
19
  from dataclasses import dataclass
20
20
  from types import TracebackType
21
- from typing import Callable, Iterator, Literal, Self, Sequence, TypeVar, cast, get_args
21
+ from typing import (
22
+ Any,
23
+ Callable,
24
+ Iterator,
25
+ Literal,
26
+ Self,
27
+ Sequence,
28
+ TypeVar,
29
+ cast,
30
+ get_args,
31
+ )
22
32
 
23
33
  import jax
24
34
  import jax.numpy as jnp
25
35
  import numpy as np
36
+ from jax._src.core import ClosedJaxpr
26
37
  from jaxtyping import Array
27
38
  from PIL import Image, ImageDraw, ImageFont
28
39
  from PIL.Image import Image as PILImage
@@ -194,7 +205,10 @@ def tile_images(images: list[PILImage], sep: int = 0) -> PILImage:
194
205
  return tiled
195
206
 
196
207
 
197
- def as_numpy(array: Array) -> np.ndarray:
208
+ def as_numpy(array: Array | np.ndarray) -> np.ndarray:
209
+ """Convert a JAX array or numpy array to numpy array."""
210
+ if isinstance(array, np.ndarray):
211
+ return array
198
212
  array = jax.device_get(array)
199
213
  if jax.dtypes.issubdtype(array.dtype, jnp.floating):
200
214
  array = array.astype(jnp.float32)
@@ -205,6 +219,19 @@ def as_numpy(array: Array) -> np.ndarray:
205
219
  return np.array(array)
206
220
 
207
221
 
222
+ def as_numpy_opt(array: Array | np.ndarray | None) -> np.ndarray | None:
223
+ """Convert an optional JAX array or numpy array to numpy array."""
224
+ if array is None:
225
+ return None
226
+ return as_numpy(array)
227
+
228
+
229
+ @dataclass(kw_only=True)
230
+ class LogString:
231
+ value: str
232
+ secondary: bool
233
+
234
+
208
235
  @dataclass(kw_only=True)
209
236
  class LogImage:
210
237
  image: PILImage
@@ -223,6 +250,12 @@ class LogVideo:
223
250
  fps: int
224
251
 
225
252
 
253
+ @dataclass(kw_only=True)
254
+ class LogScalar:
255
+ value: Number
256
+ secondary: bool
257
+
258
+
226
259
  @dataclass(kw_only=True)
227
260
  class LogDistribution:
228
261
  mean: Number
@@ -240,15 +273,29 @@ class LogHistogram:
240
273
  bucket_counts: list[int]
241
274
 
242
275
 
276
+ @dataclass(kw_only=True)
277
+ class LogMesh:
278
+ vertices: np.ndarray
279
+ colors: np.ndarray | None
280
+ faces: np.ndarray | None
281
+ config_dict: dict[str, Any] | None # noqa: ANN401
282
+
283
+
284
+ @dataclass(kw_only=True)
285
+ class LogGraph:
286
+ computation: ClosedJaxpr
287
+
288
+
243
289
  @dataclass(kw_only=True)
244
290
  class LogLine:
245
291
  state: State
246
- scalars: dict[str, dict[str, Number]]
292
+ scalars: dict[str, dict[str, LogScalar]]
247
293
  distributions: dict[str, dict[str, LogDistribution]]
248
294
  histograms: dict[str, dict[str, LogHistogram]]
249
- strings: dict[str, dict[str, str]]
295
+ strings: dict[str, dict[str, LogString]]
250
296
  images: dict[str, dict[str, LogImage]]
251
297
  videos: dict[str, dict[str, LogVideo]]
298
+ meshes: dict[str, dict[str, LogMesh]]
252
299
 
253
300
 
254
301
  @dataclass(kw_only=True)
@@ -515,12 +562,13 @@ class Logger:
515
562
  """Defines an intermediate container which holds values to log somewhere else."""
516
563
 
517
564
  def __init__(self, default_namespace: str = DEFAULT_NAMESPACE) -> None:
518
- self.scalars: dict[str, dict[str, Callable[[], Number]]] = defaultdict(dict)
565
+ self.scalars: dict[str, dict[str, Callable[[], LogScalar]]] = defaultdict(dict)
519
566
  self.distributions: dict[str, dict[str, Callable[[], LogDistribution]]] = defaultdict(dict)
520
567
  self.histograms: dict[str, dict[str, Callable[[], LogHistogram]]] = defaultdict(dict)
521
- self.strings: dict[str, dict[str, Callable[[], str]]] = defaultdict(dict)
568
+ self.strings: dict[str, dict[str, Callable[[], LogString]]] = defaultdict(dict)
522
569
  self.images: dict[str, dict[str, Callable[[], LogImage]]] = defaultdict(dict)
523
570
  self.videos: dict[str, dict[str, Callable[[], LogVideo]]] = defaultdict(dict)
571
+ self.meshes: dict[str, dict[str, Callable[[], LogMesh]]] = defaultdict(dict)
524
572
  self.default_namespace = default_namespace
525
573
  self.loggers: list[LoggerImpl] = []
526
574
 
@@ -548,6 +596,7 @@ class Logger:
548
596
  strings={k: {kk: v() for kk, v in v.items()} for k, v in self.strings.items()},
549
597
  images={k: {kk: v() for kk, v in v.items()} for k, v in self.images.items()},
550
598
  videos={k: {kk: v() for kk, v in v.items()} for k, v in self.videos.items()},
599
+ meshes={k: {kk: v() for kk, v in v.items()} for k, v in self.meshes.items()},
551
600
  )
552
601
 
553
602
  def clear(self) -> None:
@@ -557,6 +606,7 @@ class Logger:
557
606
  self.strings.clear()
558
607
  self.images.clear()
559
608
  self.videos.clear()
609
+ self.meshes.clear()
560
610
 
561
611
  def write(self, state: State) -> None:
562
612
  """Writes the current step's logging information.
@@ -616,13 +666,23 @@ class Logger:
616
666
  def resolve_namespace(self, namespace: str | None = None) -> str:
617
667
  return "_".join([self.default_namespace if namespace is None else namespace] + NAMESPACE_STACK)
618
668
 
619
- def log_scalar(self, key: str, value: Callable[[], Number] | Number, *, namespace: str | None = None) -> None:
669
+ def log_scalar(
670
+ self,
671
+ key: str,
672
+ value: Callable[[], Number] | Number,
673
+ *,
674
+ namespace: str | None = None,
675
+ secondary: bool = False,
676
+ ) -> None:
620
677
  """Logs a scalar value.
621
678
 
622
679
  Args:
623
680
  key: The key being logged
624
681
  value: The scalar value being logged
625
682
  namespace: An optional logging namespace
683
+ secondary: If set, treat this as a secondary value (meaning, it is
684
+ less important than other values, and some downstream loggers
685
+ will not display it)
626
686
  """
627
687
  if not self.active:
628
688
  raise RuntimeError("The logger is not active")
@@ -632,11 +692,11 @@ class Logger:
632
692
  assert value.ndim == 0, f"Scalar must be a 0D array, got shape {value.shape}"
633
693
 
634
694
  @functools.lru_cache(maxsize=None)
635
- def scalar_future() -> Number:
695
+ def scalar_future() -> LogScalar:
636
696
  with ContextTimer() as timer:
637
697
  value_concrete = value() if callable(value) else value
638
698
  logger.debug("Scalar Key: %s, Time: %s", key, timer.elapsed_time)
639
- return value_concrete
699
+ return LogScalar(value=value_concrete, secondary=secondary)
640
700
 
641
701
  self.scalars[namespace][key] = scalar_future
642
702
 
@@ -770,21 +830,31 @@ class Logger:
770
830
 
771
831
  self.histograms[namespace][key] = histogram_future
772
832
 
773
- def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
833
+ def log_string(
834
+ self,
835
+ key: str,
836
+ value: Callable[[], str] | str,
837
+ *,
838
+ namespace: str | None = None,
839
+ secondary: bool = False,
840
+ ) -> None:
774
841
  """Logs a string value.
775
842
 
776
843
  Args:
777
844
  key: The key being logged
778
845
  value: The string value being logged
779
846
  namespace: An optional logging namespace
847
+ secondary: If set, treat this as a secondary value (meaning, it is
848
+ less important than other values, and some downstream loggers
849
+ will not display it)
780
850
  """
781
851
  if not self.active:
782
852
  raise RuntimeError("The logger is not active")
783
853
  namespace = self.resolve_namespace(namespace)
784
854
 
785
855
  @functools.lru_cache(maxsize=None)
786
- def value_future() -> str:
787
- return value() if callable(value) else value
856
+ def value_future() -> LogString:
857
+ return LogString(value=value() if callable(value) else value, secondary=secondary)
788
858
 
789
859
  self.strings[namespace][key] = value_future
790
860
 
@@ -1019,6 +1089,73 @@ class Logger:
1019
1089
 
1020
1090
  self.videos[namespace][key] = video_future
1021
1091
 
1092
+ def log_mesh(
1093
+ self,
1094
+ key: str,
1095
+ vertices: np.ndarray | Array | Callable[[], np.ndarray | Array],
1096
+ colors: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
1097
+ faces: np.ndarray | Array | None | Callable[[], np.ndarray | Array | None] = None,
1098
+ config_dict: dict[str, Any] | None = None,
1099
+ *,
1100
+ namespace: str | None = None,
1101
+ ) -> None:
1102
+ if not self.active:
1103
+ raise RuntimeError("The logger is not active")
1104
+ namespace = self.resolve_namespace(namespace)
1105
+
1106
+ @functools.lru_cache(maxsize=None)
1107
+ def mesh_future() -> LogMesh:
1108
+ with ContextTimer() as timer:
1109
+ # Get the raw values
1110
+ vertices_val = vertices() if callable(vertices) else vertices
1111
+ colors_val = colors() if callable(colors) else colors
1112
+ faces_val = faces() if callable(faces) else faces
1113
+
1114
+ # Convert to numpy arrays with proper type handling
1115
+ vertices_np = as_numpy(vertices_val)
1116
+ colors_np = as_numpy_opt(colors_val)
1117
+ faces_np = as_numpy_opt(faces_val)
1118
+
1119
+ # Checks vertices shape.
1120
+ if vertices_np.ndim == 2:
1121
+ vertices_np = vertices_np[None]
1122
+ if vertices_np.shape[-1] != 3 or vertices_np.ndim != 3:
1123
+ raise ValueError("Vertices must have shape (N, 3) or (B, N, 3)")
1124
+
1125
+ # Checks colors shape.
1126
+ if colors_np is not None:
1127
+ if colors_np.ndim == 2:
1128
+ colors_np = colors_np[None]
1129
+ if colors_np.shape[-1] != 3 or colors_np.ndim != 3:
1130
+ raise ValueError("Colors must have shape (N, 3) or (B, N, 3)")
1131
+
1132
+ # Checks faces shape.
1133
+ if faces_np is not None:
1134
+ if faces_np.ndim == 2:
1135
+ faces_np = faces_np[None]
1136
+ if faces_np.shape[-1] != 3 or faces_np.ndim != 3:
1137
+ raise ValueError("Faces must have shape (N, 3) or (B, N, 3)")
1138
+
1139
+ # Ensures colors dtype is uint8.
1140
+ if colors_np is not None:
1141
+ if colors_np.dtype != np.uint8:
1142
+ colors_np = (colors_np * 255).astype(np.uint8)
1143
+
1144
+ # Ensures faces dtype is int32.
1145
+ if faces_np is not None:
1146
+ if faces_np.dtype != np.int32:
1147
+ faces_np = faces_np.astype(np.int32)
1148
+
1149
+ logger.debug("Mesh Key: %s, Time: %s", key, timer.elapsed_time)
1150
+ return LogMesh(
1151
+ vertices=vertices_np,
1152
+ colors=colors_np,
1153
+ faces=faces_np,
1154
+ config_dict=config_dict,
1155
+ )
1156
+
1157
+ self.meshes[namespace][key] = mesh_future
1158
+
1022
1159
  def __enter__(self) -> Self:
1023
1160
  self.active = True
1024
1161
  for logger in self.loggers:
@@ -3,11 +3,19 @@
3
3
  import json
4
4
  import sys
5
5
  from dataclasses import asdict
6
- from typing import Any, Literal, TextIO
6
+ from typing import Any, Literal, Mapping, TextIO
7
7
 
8
8
  from jaxtyping import Array
9
9
 
10
- from xax.task.logger import LogError, LoggerImpl, LogLine, LogPing, LogStatus
10
+ from xax.task.logger import (
11
+ LogError,
12
+ LoggerImpl,
13
+ LogLine,
14
+ LogPing,
15
+ LogScalar,
16
+ LogStatus,
17
+ LogString,
18
+ )
11
19
 
12
20
 
13
21
  def get_json_value(value: Any) -> Any: # noqa: ANN401
@@ -61,14 +69,14 @@ class JsonLogger(LoggerImpl):
61
69
  def get_json(self, line: LogLine) -> str:
62
70
  data: dict = {"state": asdict(line.state)}
63
71
 
64
- def add_logs(log: dict[str, dict[str, Any]], data: dict) -> None:
72
+ def add_logs(log: Mapping[str, Mapping[str, LogScalar | LogString]], data: dict) -> None:
65
73
  for namespace, values in log.items():
66
74
  if self.remove_unicode_from_namespaces:
67
75
  namespace = namespace.encode("ascii", errors="ignore").decode("ascii").strip()
68
76
  if namespace not in data:
69
77
  data[namespace] = {}
70
78
  for k, v in values.items():
71
- data[namespace][k] = get_json_value(v)
79
+ data[namespace][k] = get_json_value(v.value)
72
80
 
73
81
  add_logs(line.scalars, data)
74
82
  add_logs(line.strings, data)
@@ -4,11 +4,20 @@ import datetime
4
4
  import logging
5
5
  import sys
6
6
  from collections import deque
7
- from typing import Any, Deque, TextIO
7
+ from typing import Any, Deque, Mapping, TextIO
8
8
 
9
9
  from jaxtyping import Array
10
10
 
11
- from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
11
+ from xax.task.logger import (
12
+ LogError,
13
+ LogErrorSummary,
14
+ LoggerImpl,
15
+ LogLine,
16
+ LogPing,
17
+ LogScalar,
18
+ LogStatus,
19
+ LogString,
20
+ )
12
21
  from xax.utils.text import Color, colored, format_timedelta
13
22
 
14
23
 
@@ -95,20 +104,17 @@ class StdoutLogger(LoggerImpl):
95
104
  def write_log_window(self, line: LogLine) -> None:
96
105
  namespace_to_lines: dict[str, dict[str, str]] = {}
97
106
 
98
- def add_logs(log: dict[str, dict[str, Any]], namespace_to_lines: dict[str, dict[str, str]]) -> None:
107
+ def add_logs(
108
+ log: Mapping[str, Mapping[str, LogScalar | LogString]],
109
+ namespace_to_lines: dict[str, dict[str, str]],
110
+ ) -> None:
99
111
  for namespace, values in log.items():
100
- if not self.log_timers and namespace.startswith("⌛"):
101
- continue
102
- if not self.log_perf and namespace.startswith("🔧"):
103
- continue
104
- if not self.log_optim and namespace.startswith("📉"):
105
- continue
106
- if not self.log_fp and namespace.startswith("⚖️"):
107
- continue
108
- if namespace not in namespace_to_lines:
109
- namespace_to_lines[namespace] = {}
110
112
  for k, v in values.items():
111
- v_str = as_str(v, self.precision)
113
+ if v.secondary:
114
+ continue
115
+ if namespace not in namespace_to_lines:
116
+ namespace_to_lines[namespace] = {}
117
+ v_str = as_str(v.value, self.precision)
112
118
  namespace_to_lines[namespace][k] = v_str
113
119
 
114
120
  add_logs(line.scalars, namespace_to_lines)
@@ -116,9 +122,8 @@ class StdoutLogger(LoggerImpl):
116
122
  if not namespace_to_lines:
117
123
  return
118
124
 
119
- self.write_fp.write("\n")
120
125
  for namespace, lines in sorted(namespace_to_lines.items()):
121
- self.write_fp.write(f"{colored(namespace, 'cyan', bold=True)}\n")
126
+ self.write_fp.write(f"\n{colored(namespace, 'cyan', bold=True)}\n")
122
127
  for k, v in lines.items():
123
128
  self.write_fp.write(f" ↪ {k}: {v}\n")
124
129
 
@@ -70,6 +70,9 @@ class TensorboardLogger(LoggerImpl):
70
70
  self._started = True
71
71
 
72
72
  def worker_thread(self) -> None:
73
+ if os.environ.get("DISABLE_TENSORBOARD", "0") == "1":
74
+ return
75
+
73
76
  time.sleep(self.wait_seconds)
74
77
 
75
78
  port = int(os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT))
@@ -158,7 +161,7 @@ class TensorboardLogger(LoggerImpl):
158
161
  for scalar_key, scalar_value in scalars.items():
159
162
  writer.add_scalar(
160
163
  f"{namespace}/{scalar_key}",
161
- as_float(scalar_value),
164
+ as_float(scalar_value.value),
162
165
  global_step=line.state.num_steps,
163
166
  walltime=walltime,
164
167
  )
@@ -192,7 +195,7 @@ class TensorboardLogger(LoggerImpl):
192
195
  for string_key, string_value in strings.items():
193
196
  writer.add_text(
194
197
  f"{namespace}/{string_key}",
195
- string_value,
198
+ string_value.value,
196
199
  global_step=line.state.num_steps,
197
200
  walltime=walltime,
198
201
  )
@@ -213,6 +216,19 @@ class TensorboardLogger(LoggerImpl):
213
216
  video_value.frames,
214
217
  fps=video_value.fps,
215
218
  global_step=line.state.num_steps,
219
+ walltime=walltime,
220
+ )
221
+
222
+ for namespace, meshes in line.meshes.items():
223
+ for mesh_key, mesh_value in meshes.items():
224
+ writer.add_mesh(
225
+ f"{namespace}/{mesh_key}",
226
+ vertices=mesh_value.vertices,
227
+ faces=mesh_value.faces,
228
+ colors=mesh_value.colors,
229
+ config_dict=mesh_value.config_dict,
230
+ global_step=line.state.num_steps,
231
+ walltime=walltime,
216
232
  )
217
233
 
218
234
  for name, contents in self.files.items():