xax 0.3.14__tar.gz → 0.4.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79) hide show
  1. {xax-0.3.14/xax.egg-info → xax-0.4.4}/PKG-INFO +3 -1
  2. {xax-0.3.14 → xax-0.4.4}/pyproject.toml +1 -0
  3. {xax-0.3.14 → xax-0.4.4}/setup.py +1 -0
  4. {xax-0.3.14 → xax-0.4.4}/xax/__init__.py +12 -4
  5. xax-0.4.4/xax/task/launchers/single_process.py +141 -0
  6. xax-0.4.4/xax/task/loggers/wandb.py +307 -0
  7. {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/__init__.py +2 -1
  8. xax-0.4.4/xax/task/mixins/logger.py +169 -0
  9. xax-0.4.4/xax/task/mixins/supervised.py +368 -0
  10. {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/train.py +36 -345
  11. {xax-0.3.14 → xax-0.4.4}/xax/task/task.py +26 -2
  12. {xax-0.3.14 → xax-0.4.4}/xax/utils/experiments.py +2 -2
  13. {xax-0.3.14 → xax-0.4.4}/xax/utils/types/frozen_dict.py +4 -0
  14. {xax-0.3.14 → xax-0.4.4/xax.egg-info}/PKG-INFO +3 -1
  15. {xax-0.3.14 → xax-0.4.4}/xax.egg-info/SOURCES.txt +2 -0
  16. {xax-0.3.14 → xax-0.4.4}/xax.egg-info/requires.txt +3 -0
  17. xax-0.3.14/xax/task/launchers/single_process.py +0 -31
  18. xax-0.3.14/xax/task/mixins/logger.py +0 -92
  19. {xax-0.3.14 → xax-0.4.4}/LICENSE +0 -0
  20. {xax-0.3.14 → xax-0.4.4}/MANIFEST.in +0 -0
  21. {xax-0.3.14 → xax-0.4.4}/README.md +0 -0
  22. {xax-0.3.14 → xax-0.4.4}/setup.cfg +0 -0
  23. {xax-0.3.14 → xax-0.4.4}/xax/cli/__init__.py +0 -0
  24. {xax-0.3.14 → xax-0.4.4}/xax/cli/edit_config.py +0 -0
  25. {xax-0.3.14 → xax-0.4.4}/xax/core/__init__.py +0 -0
  26. {xax-0.3.14 → xax-0.4.4}/xax/core/conf.py +0 -0
  27. {xax-0.3.14 → xax-0.4.4}/xax/core/state.py +0 -0
  28. {xax-0.3.14 → xax-0.4.4}/xax/nn/__init__.py +0 -0
  29. {xax-0.3.14 → xax-0.4.4}/xax/nn/attention.py +0 -0
  30. {xax-0.3.14 → xax-0.4.4}/xax/nn/distributions.py +0 -0
  31. {xax-0.3.14 → xax-0.4.4}/xax/nn/embeddings.py +0 -0
  32. {xax-0.3.14 → xax-0.4.4}/xax/nn/functions.py +0 -0
  33. {xax-0.3.14 → xax-0.4.4}/xax/nn/geom.py +0 -0
  34. {xax-0.3.14 → xax-0.4.4}/xax/nn/losses.py +0 -0
  35. {xax-0.3.14 → xax-0.4.4}/xax/nn/metrics.py +0 -0
  36. {xax-0.3.14 → xax-0.4.4}/xax/nn/parallel.py +0 -0
  37. {xax-0.3.14 → xax-0.4.4}/xax/nn/ssm.py +0 -0
  38. {xax-0.3.14 → xax-0.4.4}/xax/py.typed +0 -0
  39. {xax-0.3.14 → xax-0.4.4}/xax/requirements-dev.txt +0 -0
  40. {xax-0.3.14 → xax-0.4.4}/xax/requirements.txt +0 -0
  41. {xax-0.3.14 → xax-0.4.4}/xax/task/__init__.py +0 -0
  42. {xax-0.3.14 → xax-0.4.4}/xax/task/base.py +0 -0
  43. {xax-0.3.14 → xax-0.4.4}/xax/task/launchers/__init__.py +0 -0
  44. {xax-0.3.14 → xax-0.4.4}/xax/task/launchers/base.py +0 -0
  45. {xax-0.3.14 → xax-0.4.4}/xax/task/launchers/cli.py +0 -0
  46. {xax-0.3.14 → xax-0.4.4}/xax/task/logger.py +0 -0
  47. {xax-0.3.14 → xax-0.4.4}/xax/task/loggers/__init__.py +0 -0
  48. {xax-0.3.14 → xax-0.4.4}/xax/task/loggers/callback.py +0 -0
  49. {xax-0.3.14 → xax-0.4.4}/xax/task/loggers/json.py +0 -0
  50. {xax-0.3.14 → xax-0.4.4}/xax/task/loggers/state.py +0 -0
  51. {xax-0.3.14 → xax-0.4.4}/xax/task/loggers/stdout.py +0 -0
  52. {xax-0.3.14 → xax-0.4.4}/xax/task/loggers/tensorboard.py +0 -0
  53. {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/artifacts.py +0 -0
  54. {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/checkpointing.py +0 -0
  55. {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/compile.py +0 -0
  56. {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/cpu_stats.py +0 -0
  57. {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/data_loader.py +0 -0
  58. {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/gpu_stats.py +0 -0
  59. {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/process.py +0 -0
  60. {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/runnable.py +0 -0
  61. {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/step_wrapper.py +0 -0
  62. {xax-0.3.14 → xax-0.4.4}/xax/task/script.py +0 -0
  63. {xax-0.3.14 → xax-0.4.4}/xax/utils/__init__.py +0 -0
  64. {xax-0.3.14 → xax-0.4.4}/xax/utils/data/__init__.py +0 -0
  65. {xax-0.3.14 → xax-0.4.4}/xax/utils/data/collate.py +0 -0
  66. {xax-0.3.14 → xax-0.4.4}/xax/utils/debugging.py +0 -0
  67. {xax-0.3.14 → xax-0.4.4}/xax/utils/jax.py +0 -0
  68. {xax-0.3.14 → xax-0.4.4}/xax/utils/jaxpr.py +0 -0
  69. {xax-0.3.14 → xax-0.4.4}/xax/utils/logging.py +0 -0
  70. {xax-0.3.14 → xax-0.4.4}/xax/utils/numpy.py +0 -0
  71. {xax-0.3.14 → xax-0.4.4}/xax/utils/profile.py +0 -0
  72. {xax-0.3.14 → xax-0.4.4}/xax/utils/pytree.py +0 -0
  73. {xax-0.3.14 → xax-0.4.4}/xax/utils/tensorboard.py +0 -0
  74. {xax-0.3.14 → xax-0.4.4}/xax/utils/text.py +0 -0
  75. {xax-0.3.14 → xax-0.4.4}/xax/utils/types/__init__.py +0 -0
  76. {xax-0.3.14 → xax-0.4.4}/xax/utils/types/hashable_array.py +0 -0
  77. {xax-0.3.14 → xax-0.4.4}/xax.egg-info/dependency_links.txt +0 -0
  78. {xax-0.3.14 → xax-0.4.4}/xax.egg-info/entry_points.txt +0 -0
  79. {xax-0.3.14 → xax-0.4.4}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.3.14
3
+ Version: 0.4.4
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -31,6 +31,8 @@ Requires-Dist: pytest; extra == "dev"
31
31
  Requires-Dist: types-pillow; extra == "dev"
32
32
  Requires-Dist: types-psutil; extra == "dev"
33
33
  Requires-Dist: types-requests; extra == "dev"
34
+ Provides-Extra: wandb
35
+ Requires-Dist: wandb[media]; extra == "wandb"
34
36
  Dynamic: author
35
37
  Dynamic: description
36
38
  Dynamic: description-content-type
@@ -35,6 +35,7 @@ explicit_package_bases = true
35
35
  [[tool.mypy.overrides]]
36
36
 
37
37
  module = [
38
+ "chex.*",
38
39
  "optax.*",
39
40
  "setuptools.*",
40
41
  "tensorboard.*",
@@ -33,6 +33,7 @@ setup(
33
33
  tests_require=requirements_dev,
34
34
  extras_require={
35
35
  "dev": requirements_dev,
36
+ "wandb": ["wandb[media]"],
36
37
  },
37
38
  package_data={
38
39
  "xax": [
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.3.14"
15
+ __version__ = "0.4.4"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -94,10 +94,13 @@ __all__ = [
94
94
  "DataloaderConfig",
95
95
  "GPUStatsOptions",
96
96
  "StepContext",
97
+ "InitParams",
97
98
  "ValidStepTimer",
98
99
  "Script",
99
100
  "ScriptConfig",
100
101
  "Config",
102
+ "SupervisedConfig",
103
+ "SupervisedTask",
101
104
  "Task",
102
105
  "collate",
103
106
  "collate_non_null",
@@ -168,6 +171,7 @@ __all__ = [
168
171
  "uncolored",
169
172
  "wrapped",
170
173
  "FrozenDict",
174
+ "freeze_dict",
171
175
  "HashableArray",
172
176
  "hashable_array",
173
177
  ]
@@ -291,10 +295,13 @@ NAME_MAP: dict[str, str] = {
291
295
  "DataloaderConfig": "task.mixins.data_loader",
292
296
  "GPUStatsOptions": "task.mixins.gpu_stats",
293
297
  "StepContext": "task.mixins.step_wrapper",
298
+ "InitParams": "task.mixins.train",
294
299
  "ValidStepTimer": "task.mixins.train",
295
300
  "Script": "task.script",
296
301
  "ScriptConfig": "task.script",
297
302
  "Config": "task.task",
303
+ "SupervisedConfig": "task.task",
304
+ "SupervisedTask": "task.task",
298
305
  "Task": "task.task",
299
306
  "collate": "utils.data.collate",
300
307
  "collate_non_null": "utils.data.collate",
@@ -365,6 +372,7 @@ NAME_MAP: dict[str, str] = {
365
372
  "uncolored": "utils.text",
366
373
  "wrapped": "utils.text",
367
374
  "FrozenDict": "utils.types.frozen_dict",
375
+ "freeze_dict": "utils.types.frozen_dict",
368
376
  "HashableArray": "utils.types.hashable_array",
369
377
  "hashable_array": "utils.types.hashable_array",
370
378
  }
@@ -488,9 +496,9 @@ if IMPORT_ALL or TYPE_CHECKING:
488
496
  from xax.task.mixins.data_loader import DataloaderConfig
489
497
  from xax.task.mixins.gpu_stats import GPUStatsOptions
490
498
  from xax.task.mixins.step_wrapper import StepContext
491
- from xax.task.mixins.train import Batch, Output, ValidStepTimer
499
+ from xax.task.mixins.train import Batch, InitParams, Output, ValidStepTimer
492
500
  from xax.task.script import Script, ScriptConfig
493
- from xax.task.task import Config, Task
501
+ from xax.task.task import Config, SupervisedConfig, SupervisedTask, Task
494
502
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
495
503
  from xax.utils.debugging import (
496
504
  breakpoint_if_nonfinite,
@@ -566,7 +574,7 @@ if IMPORT_ALL or TYPE_CHECKING:
566
574
  uncolored,
567
575
  wrapped,
568
576
  )
569
- from xax.utils.types.frozen_dict import FrozenDict
577
+ from xax.utils.types.frozen_dict import FrozenDict, freeze_dict
570
578
  from xax.utils.types.hashable_array import HashableArray, hashable_array
571
579
 
572
580
  del TYPE_CHECKING, IMPORT_ALL
@@ -0,0 +1,141 @@
1
+ """Defines a launcher to train a model locally, in a single process."""
2
+
3
+ import logging
4
+ import os
5
+ import shutil
6
+ import subprocess
7
+ from typing import TYPE_CHECKING
8
+
9
+ import jax
10
+
11
+ from xax.task.base import RawConfigType
12
+ from xax.task.launchers.base import BaseLauncher
13
+ from xax.task.mixins.gpu_stats import get_num_gpus
14
+ from xax.utils.logging import configure_logging
15
+
16
+ if TYPE_CHECKING:
17
+ from xax.task.mixins.runnable import Config, RunnableMixin
18
+
19
+
20
def get_gpu_memory_info() -> dict[int, tuple[float, float]]:
    """Query ``nvidia-smi`` for the memory usage of every GPU.

    Returns:
        A mapping from GPU index to ``(total_memory_mb, used_memory_mb)``.
        If running or parsing the query fails for any reason, a warning is
        logged and an empty mapping is returned.
    """
    query = "nvidia-smi --query-gpu=index,memory.total,memory.used --format=csv,noheader"

    def to_mib(field: str) -> float:
        # Values are reported with a " MiB" unit suffix.
        return float(field.replace(" MiB", ""))

    try:
        with subprocess.Popen(query.split(), stdout=subprocess.PIPE, universal_newlines=True) as proc:
            stream = proc.stdout
            assert stream is not None

            memory_by_gpu = {}
            for raw_line in stream:
                fields = raw_line.strip().split(", ")
                # Blank lines split into a single empty field, so this check
                # also skips them.
                if len(fields) < 3:
                    continue
                memory_by_gpu[int(fields[0])] = (to_mib(fields[1]), to_mib(fields[2]))

            return memory_by_gpu

    except Exception as e:
        # Best-effort probe: any failure means "no GPU information".
        configure_logging().warning("Failed to get GPU memory info: %s", e)
        return {}
52
+
53
+
54
def select_best_gpu() -> int | None:
    """Pick the GPU that currently has the most free memory.

    Returns:
        The index of the GPU whose ``total - used`` memory is largest (the
        first such GPU on ties), or None if no GPU information is available.
    """
    winner: int | None = None
    most_free: float = -1.0

    # An empty mapping simply leaves ``winner`` as None.
    for index, (total, used) in get_gpu_memory_info().items():
        free = total - used
        if free > most_free:
            winner, most_free = index, free

    return winner
76
+
77
+
78
def configure_gpu_devices(logger: logging.Logger | None = None) -> None:
    """Restrict this process to a single GPU when several are present.

    With more than one GPU, the one with the most available memory (i.e. the
    one least likely to be in use by another process) is exported through
    ``CUDA_VISIBLE_DEVICES`` and set as the default JAX device. With one GPU
    or none, device selection is left untouched.

    Args:
        logger: Logger used for progress messages. When None, one is obtained
            from ``configure_logging()``.
    """
    log = configure_logging() if logger is None else logger

    gpu_count = get_num_gpus()
    if gpu_count <= 1:
        if gpu_count == 1:
            log.info("Single GPU detected, using default device selection")
        return

    log.info("Multiple GPUs detected (%d), selecting GPU with most available memory", gpu_count)

    chosen = select_best_gpu()
    if chosen is None:
        log.warning("Could not determine best GPU, using default device selection")
        return

    log.info("Selected GPU %d for training", chosen)

    # Hide every GPU except the chosen one from CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(chosen)

    # Point JAX at the (now only) visible GPU.
    try:
        gpu_devices = jax.devices("gpu")
        if gpu_devices:
            jax.config.update("jax_default_device", gpu_devices[0])
            log.info("Configured JAX to use device: %s", gpu_devices[0])
    except Exception as e:
        log.warning("Failed to configure JAX device: %s", e)
109
+
110
+
111
+ def configure_devices(logger: logging.Logger | None = None) -> None:
112
+ if logger is None:
113
+ logger = configure_logging()
114
+
115
+ if shutil.which("nvidia-smi") is not None:
116
+ configure_gpu_devices(logger)
117
+
118
+
119
def run_single_process_training(
    task: "type[RunnableMixin[Config]]",
    *cfgs: RawConfigType,
    use_cli: bool | list[str] = True,
    logger: logging.Logger | None = None,
) -> None:
    """Build the task from its configs and run it in the current process.

    Args:
        task: The task class to instantiate through ``get_task``.
        *cfgs: Raw configs forwarded to ``get_task``.
        use_cli: Forwarded to ``get_task``.
        logger: Logger to which the task attaches its handlers. When None,
            one is obtained from ``configure_logging()``.
    """
    log = logger if logger is not None else configure_logging()

    instance = task.get_task(*cfgs, use_cli=use_cli)
    instance.add_logger_handlers(log)
    instance.run()
130
+
131
+
132
class SingleProcessLauncher(BaseLauncher):
    """Launcher that runs a task locally, inside the current process."""

    def launch(
        self,
        task: "type[RunnableMixin[Config]]",
        *cfgs: RawConfigType,
        use_cli: bool | list[str] = True,
    ) -> None:
        """Set up logging and devices, then run ``task`` in this process.

        Args:
            task: The task class to run.
            *cfgs: Raw configs forwarded to the task.
            use_cli: Forwarded to the task construction.
        """
        log = configure_logging()
        configure_devices(log)
        run_single_process_training(task, *cfgs, use_cli=use_cli, logger=log)
@@ -0,0 +1,307 @@
1
+ # mypy: disable-error-code="import-not-found"
2
+ """Defines a Weights & Biases logger backend."""
3
+
4
+ import logging
5
+ import os
6
+ from enum import Enum
7
+ from pathlib import Path
8
+ from typing import Any, TypeVar
9
+
10
+ import numpy as np
11
+
12
+ from xax.nn.parallel import is_master
13
+ from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
14
+ from xax.utils.jax import as_float
15
+
16
+ logger: logging.Logger = logging.getLogger(__name__)
17
+
18
+ T = TypeVar("T")
19
+
20
+
21
def sanitize_metric_name(name: str) -> str:
    """Strip characters outside the Basic Multilingual Plane from a name.

    W&B has issues with 4-byte unicode characters in metric names, so every
    character whose code point is above 0xFFFF is dropped.

    Args:
        name: The metric name to sanitize.

    Returns:
        The name with all characters above U+FFFF removed.
    """
    bmp_limit = 0xFFFF
    kept = [ch for ch in name if ord(ch) <= bmp_limit]
    return "".join(kept)
36
+
37
+
38
class WandbConfigResumeOption(str, Enum):
    """String options for the ``resume`` argument forwarded to ``wandb.init``."""

    ALLOW = "allow"
    NEVER = "never"
    MUST = "must"
    AUTO = "auto"
43
+
44
+
45
class WandbConfigModeOption(str, Enum):
    """String options for the ``mode`` argument forwarded to ``wandb.init``."""

    ONLINE = "online"
    OFFLINE = "offline"
    DISABLED = "disabled"
    SHARED = "shared"
50
+
51
+
52
class WandbConfigReinitOption(str, Enum):
    """String options for the ``reinit`` argument forwarded to ``wandb.init``."""

    RETURN_PREVIOUS = "return_previous"
    FINISH_PREVIOUS = "finish_previous"
55
+
56
+
57
# Accepted values for the logger's `resume` argument: a named option or a bool.
WandbConfigResume = WandbConfigResumeOption | bool
# Accepted values for the logger's `mode` argument: a named option or None.
WandbConfigMode = WandbConfigModeOption | None
59
+
60
+
61
class WandbLogger(LoggerImpl):
    """Logger backend that forwards log lines to a Weights & Biases run.

    Only the master process (as reported by ``is_master()``) talks to W&B;
    on every other process the logging methods return without doing anything.
    """

    def __init__(
        self,
        project: str | None = None,
        entity: str | None = None,
        name: str | None = None,
        run_directory: str | Path | None = None,
        config: dict[str, Any] | None = None,
        tags: list[str] | None = None,
        notes: str | None = None,
        log_interval_seconds: float = 10.0,
        reinit: WandbConfigReinitOption = WandbConfigReinitOption.RETURN_PREVIOUS,
        resume: WandbConfigResume = False,
        mode: WandbConfigMode = None,
    ) -> None:
        """Defines a logger which writes to Weights & Biases.

        The run is started immediately at the end of construction.

        Args:
            project: The name of the W&B project to log to.
            entity: The W&B entity (team or user) to log to.
            name: The name of this run.
            run_directory: The root run directory. If provided, wandb will save
                files to a ``wandb`` subdirectory here (created if missing).
            config: Configuration dictionary to log.
            tags: List of tags for this run.
            notes: Notes about this run.
            log_interval_seconds: The interval between successive log lines.
            reinit: How repeated ``wandb.init()`` calls in the same process
                are handled; forwarded to ``wandb.init``.
            resume: Whether to resume a previous run; either a bool or a
                ``WandbConfigResumeOption``, forwarded to ``wandb.init``.
            mode: Mode for wandb (a ``WandbConfigModeOption``), or None to
                leave the choice to wandb.

        Raises:
            RuntimeError: If the ``wandb`` package is not installed.
        """
        try:
            import wandb as _wandb  # noqa: F401,PLC0415
        except ImportError as e:
            raise RuntimeError(
                "WandbLogger requires the 'wandb' package. Install it with: pip install xax[wandb]"
            ) from e

        self._wandb = _wandb

        super().__init__(log_interval_seconds)

        self.project = project
        self.entity = entity
        self.name = name
        self.config = config
        self.tags = tags
        self.notes = notes
        self.reinit = reinit
        self.resume: WandbConfigResume = resume
        self.mode: WandbConfigMode = mode

        # These attributes are read by start(), stop() and write(), so they
        # must exist even when no run directory is given or when the run is
        # never started (e.g. on non-master processes).
        self.wandb_dir: Path | None = None
        self.run: Any | None = None

        # Set wandb directory if run_directory is provided
        if run_directory is not None:
            self.wandb_dir = Path(run_directory).expanduser().resolve() / "wandb"
            self.wandb_dir.mkdir(parents=True, exist_ok=True)

        self._started = False

        # Store pending files to log
        self.files: dict[str, str] = {}

        self.start()

    def start(self) -> None:
        """Initialize the W&B run (master process only, at most once)."""
        if self._started or not is_master():
            return

        # Tell wandb where to store its files, if a run directory was given.
        if self.wandb_dir is not None:
            os.environ["WANDB_DIR"] = str(self.wandb_dir)

        # wandb expects plain strings / bools, not enum members.
        resume = self.resume.value if isinstance(self.resume, WandbConfigResumeOption) else self.resume
        mode = self.mode.value if isinstance(self.mode, WandbConfigModeOption) else self.mode

        self.run = self._wandb.init(
            project=self.project,
            entity=self.entity,
            name=self.name,
            config=self.config,
            tags=self.tags,
            notes=self.notes,
            reinit=self.reinit.value,
            resume=resume,
            mode=mode,
        )

        self._started = True
        logger.info("W&B run initialized: %s", self.run.url if self.run else "No URL available")

    def stop(self) -> None:
        """Finish the W&B run, if one was started on this process."""
        if not self._started or not is_master():
            return

        if self.run is not None:
            self.run.finish()
        self._started = False

    def log_file(self, name: str, contents: str) -> None:
        """Store a file to be logged with the next write call.

        Args:
            name: The name of the file.
            contents: The contents of the file.
        """
        if not is_master():
            return
        self.files[name] = contents

    @staticmethod
    def _histogram_samples(bucket_limits: Any, bucket_counts: Any) -> list[Any]:
        """Rebuild an approximate sample list from histogram buckets.

        Each non-empty bucket contributes ``count`` copies of its midpoint;
        the first bucket has no lower limit, so its own limit is used.
        """
        samples: list[Any] = []
        for i, count in enumerate(bucket_counts):
            if count > 0:
                if i == 0:
                    val = bucket_limits[0]
                else:
                    val = (bucket_limits[i - 1] + bucket_limits[i]) / 2
                # List repetition needs an integer count.
                samples.extend([val] * int(count))
        return samples

    @staticmethod
    def _first_in_batch(array: Any) -> Any:
        """Return the first batch element of a 3-D array, else the array itself."""
        return array[0] if array.ndim == 3 else array

    def write(self, line: LogLine) -> None:
        """Writes the current log line to W&B.

        All values of the line are collected into a single metrics dictionary
        and logged in one call at the line's step. Pending files registered
        through ``log_file`` are saved as artifacts.

        Args:
            line: The line to write.
        """
        if not is_master() or not self._started:
            return

        import html  # noqa: PLC0415

        # Get step information
        global_step = line.state.num_steps.item()

        # Dictionary to collect all metrics for this step
        metrics: dict[str, Any] = {}

        # Log scalars
        for namespace, scalars in line.scalars.items():
            for scalar_key, scalar_value in scalars.items():
                key = sanitize_metric_name(f"{namespace}/{scalar_key}")
                metrics[key] = as_float(scalar_value.value)

        # Log distributions as custom metrics (mean and std)
        for namespace, distributions in line.distributions.items():
            for distribution_key, distribution_value in distributions.items():
                base_key = sanitize_metric_name(f"{namespace}/{distribution_key}")
                metrics[f"{base_key}/mean"] = float(distribution_value.mean)
                metrics[f"{base_key}/std"] = float(distribution_value.std)

        # Log histograms. W&B expects a list of values or a numpy array, so
        # the data is reconstructed from the histogram bins.
        for namespace, histograms in line.histograms.items():
            for histogram_key, histogram_value in histograms.items():
                key = sanitize_metric_name(f"{namespace}/{histogram_key}")
                values = self._histogram_samples(histogram_value.bucket_limits, histogram_value.bucket_counts)
                if values:
                    # wandb.Histogram accepts lists directly
                    metrics[key] = self._wandb.Histogram(values)

        # Log strings as HTML
        for namespace, strings in line.strings.items():
            for string_key, string_value in strings.items():
                key = sanitize_metric_name(f"{namespace}/{string_key}")
                # Escape the text so that "<", ">" and "&" are displayed
                # literally instead of being interpreted as markup.
                escaped = html.escape(str(string_value.value))
                metrics[key] = self._wandb.Html(f"<pre>{escaped}</pre>")

        # Log images
        for namespace, images in line.images.items():
            for image_key, image_value in images.items():
                key = sanitize_metric_name(f"{namespace}/{image_key}")
                # Convert PIL image to wandb.Image
                metrics[key] = self._wandb.Image(image_value.image)

        # Log videos
        for namespace, videos in line.videos.items():
            for video_key, video_value in videos.items():
                key = sanitize_metric_name(f"{namespace}/{video_key}")
                # wandb.Video expects shape (time, channels, height, width)
                # Our format is (T, H, W, C) so we need to transpose to (T, C, H, W)
                frames = video_value.frames.transpose(0, 3, 1, 2)
                metrics[key] = self._wandb.Video(frames, fps=video_value.fps, format="mp4")

        # Log meshes (3D objects)
        for namespace, meshes in line.meshes.items():
            for mesh_key, mesh_value in meshes.items():
                key = sanitize_metric_name(f"{namespace}/{mesh_key}")
                # W&B Object3D expects vertices and faces in specific format
                # vertices: (batch_size, num_vertices, 3) or (num_vertices, 3)
                # faces: (batch_size, num_faces, 3) or (num_faces, 3)
                # Only the first batch element is logged when a batch
                # dimension is present.
                obj3d_data = {
                    "type": "lidar/beta",
                    "vertices": self._first_in_batch(mesh_value.vertices).tolist(),
                }

                if mesh_value.faces is not None:
                    obj3d_data["faces"] = self._first_in_batch(mesh_value.faces).tolist()

                if mesh_value.colors is not None:
                    colors = self._first_in_batch(mesh_value.colors)
                    # Convert colors to 0-1 range if they're in 0-255 range
                    if colors.dtype == np.uint8:
                        colors = colors.astype(np.float32) / 255.0
                    obj3d_data["colors"] = colors.tolist()

                metrics[key] = self._wandb.Object3D(obj3d_data)

        # Save any pending files as artifacts. Without a run there is no run
        # name to build the artifact name from, so the files stay pending.
        if self.files and self.run is not None:
            for name, contents in self.files.items():
                key = sanitize_metric_name(name)
                key = f"{self.run.name}_{key}"
                is_training_code = "code" in name
                artifact = self._wandb.Artifact(
                    name=key if not is_training_code else "training_code",
                    type="code" if is_training_code else "unspecified",
                )
                with artifact.new_file(name) as f:
                    f.write(contents)
                artifact.save()
            self.files.clear()

        # Log all metrics at once
        if metrics and self.run:
            self.run.log(metrics, step=global_step)

    def write_error_summary(self, error_summary: LogErrorSummary) -> None:
        """Error summaries are not forwarded to W&B."""
        pass

    def write_error(self, error: LogError) -> None:
        """Errors are not forwarded to W&B."""
        pass

    def write_status(self, status: LogStatus) -> None:
        """Status messages are not forwarded to W&B."""
        pass

    def write_ping(self, ping: LogPing) -> None:
        """Pings are not forwarded to W&B."""
        pass
@@ -10,4 +10,5 @@ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
10
10
  from xax.task.mixins.process import ProcessConfig, ProcessMixin
11
11
  from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
12
12
  from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
13
- from xax.task.mixins.train import TrainConfig, TrainMixin
13
+ from xax.task.mixins.supervised import SupervisedConfig, SupervisedMixin
14
+ from xax.task.mixins.train import InitParams, TrainConfig, TrainMixin