xax 0.3.14__tar.gz → 0.4.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.3.14/xax.egg-info → xax-0.4.4}/PKG-INFO +3 -1
- {xax-0.3.14 → xax-0.4.4}/pyproject.toml +1 -0
- {xax-0.3.14 → xax-0.4.4}/setup.py +1 -0
- {xax-0.3.14 → xax-0.4.4}/xax/__init__.py +12 -4
- xax-0.4.4/xax/task/launchers/single_process.py +141 -0
- xax-0.4.4/xax/task/loggers/wandb.py +307 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/__init__.py +2 -1
- xax-0.4.4/xax/task/mixins/logger.py +169 -0
- xax-0.4.4/xax/task/mixins/supervised.py +368 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/train.py +36 -345
- {xax-0.3.14 → xax-0.4.4}/xax/task/task.py +26 -2
- {xax-0.3.14 → xax-0.4.4}/xax/utils/experiments.py +2 -2
- {xax-0.3.14 → xax-0.4.4}/xax/utils/types/frozen_dict.py +4 -0
- {xax-0.3.14 → xax-0.4.4/xax.egg-info}/PKG-INFO +3 -1
- {xax-0.3.14 → xax-0.4.4}/xax.egg-info/SOURCES.txt +2 -0
- {xax-0.3.14 → xax-0.4.4}/xax.egg-info/requires.txt +3 -0
- xax-0.3.14/xax/task/launchers/single_process.py +0 -31
- xax-0.3.14/xax/task/mixins/logger.py +0 -92
- {xax-0.3.14 → xax-0.4.4}/LICENSE +0 -0
- {xax-0.3.14 → xax-0.4.4}/MANIFEST.in +0 -0
- {xax-0.3.14 → xax-0.4.4}/README.md +0 -0
- {xax-0.3.14 → xax-0.4.4}/setup.cfg +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/cli/__init__.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/cli/edit_config.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/core/__init__.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/core/conf.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/core/state.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/nn/__init__.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/nn/attention.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/nn/distributions.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/nn/embeddings.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/nn/functions.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/nn/geom.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/nn/losses.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/nn/metrics.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/nn/parallel.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/nn/ssm.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/py.typed +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/requirements-dev.txt +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/requirements.txt +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/__init__.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/base.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/launchers/__init__.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/launchers/base.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/launchers/cli.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/logger.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/loggers/__init__.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/loggers/callback.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/loggers/json.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/loggers/state.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/loggers/stdout.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/artifacts.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/compile.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/process.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/runnable.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/task/script.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/utils/__init__.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/utils/data/__init__.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/utils/data/collate.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/utils/debugging.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/utils/jax.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/utils/jaxpr.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/utils/logging.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/utils/numpy.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/utils/profile.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/utils/pytree.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/utils/tensorboard.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/utils/text.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/utils/types/__init__.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.3.14 → xax-0.4.4}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: xax
|
3
|
-
Version: 0.3.14
|
3
|
+
Version: 0.4.4
|
4
4
|
Summary: A library for fast Jax experimentation
|
5
5
|
Home-page: https://github.com/kscalelabs/xax
|
6
6
|
Author: Benjamin Bolte
|
@@ -31,6 +31,8 @@ Requires-Dist: pytest; extra == "dev"
|
|
31
31
|
Requires-Dist: types-pillow; extra == "dev"
|
32
32
|
Requires-Dist: types-psutil; extra == "dev"
|
33
33
|
Requires-Dist: types-requests; extra == "dev"
|
34
|
+
Provides-Extra: wandb
|
35
|
+
Requires-Dist: wandb[media]; extra == "wandb"
|
34
36
|
Dynamic: author
|
35
37
|
Dynamic: description
|
36
38
|
Dynamic: description-content-type
|
@@ -12,7 +12,7 @@ and running the update script:
|
|
12
12
|
python -m scripts.update_api --inplace
|
13
13
|
"""
|
14
14
|
|
15
|
-
__version__ = "0.
|
15
|
+
__version__ = "0.4.4"
|
16
16
|
|
17
17
|
# This list shouldn't be modified by hand; instead, run the update script.
|
18
18
|
__all__ = [
|
@@ -94,10 +94,13 @@ __all__ = [
|
|
94
94
|
"DataloaderConfig",
|
95
95
|
"GPUStatsOptions",
|
96
96
|
"StepContext",
|
97
|
+
"InitParams",
|
97
98
|
"ValidStepTimer",
|
98
99
|
"Script",
|
99
100
|
"ScriptConfig",
|
100
101
|
"Config",
|
102
|
+
"SupervisedConfig",
|
103
|
+
"SupervisedTask",
|
101
104
|
"Task",
|
102
105
|
"collate",
|
103
106
|
"collate_non_null",
|
@@ -168,6 +171,7 @@ __all__ = [
|
|
168
171
|
"uncolored",
|
169
172
|
"wrapped",
|
170
173
|
"FrozenDict",
|
174
|
+
"freeze_dict",
|
171
175
|
"HashableArray",
|
172
176
|
"hashable_array",
|
173
177
|
]
|
@@ -291,10 +295,13 @@ NAME_MAP: dict[str, str] = {
|
|
291
295
|
"DataloaderConfig": "task.mixins.data_loader",
|
292
296
|
"GPUStatsOptions": "task.mixins.gpu_stats",
|
293
297
|
"StepContext": "task.mixins.step_wrapper",
|
298
|
+
"InitParams": "task.mixins.train",
|
294
299
|
"ValidStepTimer": "task.mixins.train",
|
295
300
|
"Script": "task.script",
|
296
301
|
"ScriptConfig": "task.script",
|
297
302
|
"Config": "task.task",
|
303
|
+
"SupervisedConfig": "task.task",
|
304
|
+
"SupervisedTask": "task.task",
|
298
305
|
"Task": "task.task",
|
299
306
|
"collate": "utils.data.collate",
|
300
307
|
"collate_non_null": "utils.data.collate",
|
@@ -365,6 +372,7 @@ NAME_MAP: dict[str, str] = {
|
|
365
372
|
"uncolored": "utils.text",
|
366
373
|
"wrapped": "utils.text",
|
367
374
|
"FrozenDict": "utils.types.frozen_dict",
|
375
|
+
"freeze_dict": "utils.types.frozen_dict",
|
368
376
|
"HashableArray": "utils.types.hashable_array",
|
369
377
|
"hashable_array": "utils.types.hashable_array",
|
370
378
|
}
|
@@ -488,9 +496,9 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
488
496
|
from xax.task.mixins.data_loader import DataloaderConfig
|
489
497
|
from xax.task.mixins.gpu_stats import GPUStatsOptions
|
490
498
|
from xax.task.mixins.step_wrapper import StepContext
|
491
|
-
from xax.task.mixins.train import Batch, Output, ValidStepTimer
|
499
|
+
from xax.task.mixins.train import Batch, InitParams, Output, ValidStepTimer
|
492
500
|
from xax.task.script import Script, ScriptConfig
|
493
|
-
from xax.task.task import Config, Task
|
501
|
+
from xax.task.task import Config, SupervisedConfig, SupervisedTask, Task
|
494
502
|
from xax.utils.data.collate import CollateMode, collate, collate_non_null
|
495
503
|
from xax.utils.debugging import (
|
496
504
|
breakpoint_if_nonfinite,
|
@@ -566,7 +574,7 @@ if IMPORT_ALL or TYPE_CHECKING:
|
|
566
574
|
uncolored,
|
567
575
|
wrapped,
|
568
576
|
)
|
569
|
-
from xax.utils.types.frozen_dict import FrozenDict
|
577
|
+
from xax.utils.types.frozen_dict import FrozenDict, freeze_dict
|
570
578
|
from xax.utils.types.hashable_array import HashableArray, hashable_array
|
571
579
|
|
572
580
|
del TYPE_CHECKING, IMPORT_ALL
|
@@ -0,0 +1,141 @@
|
|
1
|
+
"""Defines a launcher to train a model locally, in a single process."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import shutil
|
6
|
+
import subprocess
|
7
|
+
from typing import TYPE_CHECKING
|
8
|
+
|
9
|
+
import jax
|
10
|
+
|
11
|
+
from xax.task.base import RawConfigType
|
12
|
+
from xax.task.launchers.base import BaseLauncher
|
13
|
+
from xax.task.mixins.gpu_stats import get_num_gpus
|
14
|
+
from xax.utils.logging import configure_logging
|
15
|
+
|
16
|
+
if TYPE_CHECKING:
|
17
|
+
from xax.task.mixins.runnable import Config, RunnableMixin
|
18
|
+
|
19
|
+
|
20
|
+
def get_gpu_memory_info() -> dict[int, tuple[float, float]]:
    """Get memory information for all GPUs by querying ``nvidia-smi``.

    Each line of ``nvidia-smi`` output is expected to look like
    ``"0, 24576 MiB, 1024 MiB"`` (index, total memory, used memory). Lines that
    cannot be parsed (for example ``[N/A]`` memory readings) are skipped one at
    a time, so a single bad line does not discard the readings for every GPU.

    Returns:
        Dictionary mapping GPU index to (total_memory_mb, used_memory_mb).
        Empty when ``nvidia-smi`` cannot be run or reports no usable GPUs.
    """
    command = "nvidia-smi --query-gpu=index,memory.total,memory.used --format=csv,noheader"
    log = logging.getLogger(__name__)

    try:
        result = subprocess.run(command.split(), stdout=subprocess.PIPE, text=True, check=False)
    except (OSError, subprocess.SubprocessError) as e:
        # Best-effort: a missing or broken nvidia-smi simply means "no info".
        log.warning("Failed to get GPU memory info: %s", e)
        return {}

    if result.returncode != 0:
        log.warning("nvidia-smi exited with status %d while querying GPU memory", result.returncode)

    gpu_info: dict[int, tuple[float, float]] = {}
    for raw_line in result.stdout.splitlines():
        parts = [part.strip() for part in raw_line.split(",")]
        if len(parts) < 3:
            # Blank or unexpected lines (e.g. error banners) carry no reading.
            continue
        try:
            gpu_id = int(parts[0])
            total_mem = float(parts[1].replace("MiB", "").strip())
            used_mem = float(parts[2].replace("MiB", "").strip())
        except ValueError:
            log.debug("Skipping unparsable nvidia-smi line: %r", raw_line)
            continue
        gpu_info[gpu_id] = (total_mem, used_mem)

    return gpu_info
|
52
|
+
|
53
|
+
|
54
|
+
def select_best_gpu() -> int | None:
    """Pick the GPU that currently has the most free memory.

    Free memory is ``total - used`` as reported by ``get_gpu_memory_info``.
    When several GPUs tie, the one listed first wins.

    Returns:
        The chosen GPU index, or None when no GPU information is available.
    """
    memory_by_gpu = get_gpu_memory_info()
    if not memory_by_gpu:
        return None

    chosen: int | None = None
    most_free = -1.0
    for index, (total_mb, used_mb) in memory_by_gpu.items():
        free_mb = total_mb - used_mb
        # Strictly greater, so an earlier GPU is kept on ties.
        if free_mb > most_free:
            chosen, most_free = index, free_mb
    return chosen
|
76
|
+
|
77
|
+
|
78
|
+
def configure_gpu_devices(logger: logging.Logger | None = None) -> None:
    """Restrict a multi-GPU machine to the GPU with the most free memory.

    When more than one GPU is detected, the GPU with the most available memory
    (i.e. the one most likely not in use by other processes) is selected,
    ``CUDA_VISIBLE_DEVICES`` is set to it, and JAX's default device is pointed
    at it. If ``CUDA_VISIBLE_DEVICES`` is already set, it is treated as an
    explicit user choice (including ``""`` to hide all GPUs) and left untouched.

    Args:
        logger: Logger to report through; ``configure_logging()`` provides one
            when omitted.
    """
    if logger is None:
        logger = configure_logging()

    num_gpus = get_num_gpus()

    if num_gpus <= 1:
        if num_gpus == 1:
            logger.info("Single GPU detected, using default device selection")
        return

    existing_selection = os.environ.get("CUDA_VISIBLE_DEVICES")
    if existing_selection is not None:
        logger.info(
            "CUDA_VISIBLE_DEVICES is already set to %r; skipping automatic GPU selection",
            existing_selection,
        )
        return

    logger.info("Multiple GPUs detected (%d), selecting GPU with most available memory", num_gpus)

    best_gpu = select_best_gpu()
    if best_gpu is None:
        logger.warning("Could not determine best GPU, using default device selection")
        return

    logger.info("Selected GPU %d for training", best_gpu)

    # nvidia-smi numbers GPUs in PCI bus order, while CUDA enumerates them
    # fastest-first by default; ask CUDA for PCI bus order so the index below
    # refers to the GPU whose memory was actually measured.
    os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
    os.environ["CUDA_VISIBLE_DEVICES"] = str(best_gpu)

    # Configure JAX to use the selected device.
    try:
        devices = jax.devices("gpu")
        if devices:
            # CUDA_VISIBLE_DEVICES only takes effect if JAX's GPU backend was
            # not initialized yet; then exactly one device is visible.
            # Otherwise every GPU is still visible, so look the selected one
            # up by id instead of blindly taking devices[0].
            if len(devices) == 1:
                device = devices[0]
            else:
                device = next((d for d in devices if d.id == best_gpu), devices[0])
            jax.config.update("jax_default_device", device)
            logger.info("Configured JAX to use device: %s", device)
    except Exception as e:
        logger.warning("Failed to configure JAX device: %s", e)
|
109
|
+
|
110
|
+
|
111
|
+
def configure_devices(logger: logging.Logger | None = None) -> None:
    """Configure the accelerator devices this process should use.

    Only NVIDIA GPUs are handled. GPU configuration runs when ``nvidia-smi``
    is found on the PATH; otherwise nothing is changed.

    Args:
        logger: Logger to report through; ``configure_logging()`` provides one
            when omitted.
    """
    active_logger = configure_logging() if logger is None else logger
    has_nvidia_smi = shutil.which("nvidia-smi") is not None
    if has_nvidia_smi:
        configure_gpu_devices(active_logger)
|
117
|
+
|
118
|
+
|
119
|
+
def run_single_process_training(
    task: "type[RunnableMixin[Config]]",
    *cfgs: RawConfigType,
    use_cli: bool | list[str] = True,
    logger: logging.Logger | None = None,
) -> None:
    """Build a task instance from its configs and run it in this process.

    Args:
        task: The task class; ``task.get_task`` builds the instance.
        cfgs: Raw configs passed through to ``task.get_task``.
        use_cli: Passed through to ``task.get_task``.
        logger: Logger handed to the task's ``add_logger_handlers``;
            ``configure_logging()`` provides one when omitted.
    """
    active_logger = logger if logger is not None else configure_logging()
    task_instance = task.get_task(*cfgs, use_cli=use_cli)
    task_instance.add_logger_handlers(active_logger)
    task_instance.run()
|
130
|
+
|
131
|
+
|
132
|
+
class SingleProcessLauncher(BaseLauncher):
    """Launcher that runs a task locally, inside the current process."""

    def launch(
        self,
        task: "type[RunnableMixin[Config]]",
        *cfgs: RawConfigType,
        use_cli: bool | list[str] = True,
    ) -> None:
        """Set up logging and devices, then run ``task`` in this process.

        Args:
            task: The task class to launch.
            cfgs: Raw configs forwarded to ``run_single_process_training``.
            use_cli: Forwarded to ``run_single_process_training``.
        """
        launch_logger = configure_logging()
        configure_devices(launch_logger)
        run_single_process_training(task, *cfgs, use_cli=use_cli, logger=launch_logger)
|
@@ -0,0 +1,307 @@
|
|
1
|
+
# mypy: disable-error-code="import-not-found"
|
2
|
+
"""Defines a Weights & Biases logger backend."""
|
3
|
+
|
4
|
+
import logging
|
5
|
+
import os
|
6
|
+
from enum import Enum
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import Any, TypeVar
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
|
12
|
+
from xax.nn.parallel import is_master
|
13
|
+
from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
|
14
|
+
from xax.utils.jax import as_float
|
15
|
+
|
16
|
+
# Module-level logger for this backend's own diagnostic messages.
logger: logging.Logger = logging.getLogger(name=__name__)

T = TypeVar("T")
|
19
|
+
|
20
|
+
|
21
|
+
def sanitize_metric_name(name: str) -> str:
    """Drop characters outside the Basic Multilingual Plane from a metric name.

    W&B has trouble with characters that take four bytes in UTF-8 (code
    points above U+FFFF, such as most emoji), so they are removed.

    Args:
        name: The metric name to clean.

    Returns:
        ``name`` with every code point above U+FFFF removed.
    """
    # Single-character strings compare by code point, so this keeps exactly
    # the characters that fit in UCS-2.
    kept = [ch for ch in name if ch <= "\uffff"]
    return "".join(kept)
|
36
|
+
|
37
|
+
|
38
|
+
class WandbConfigResumeOption(str, Enum):
|
39
|
+
ALLOW = "allow"
|
40
|
+
NEVER = "never"
|
41
|
+
MUST = "must"
|
42
|
+
AUTO = "auto"
|
43
|
+
|
44
|
+
|
45
|
+
class WandbConfigModeOption(str, Enum):
|
46
|
+
ONLINE = "online"
|
47
|
+
OFFLINE = "offline"
|
48
|
+
DISABLED = "disabled"
|
49
|
+
SHARED = "shared"
|
50
|
+
|
51
|
+
|
52
|
+
class WandbConfigReinitOption(str, Enum):
|
53
|
+
RETURN_PREVIOUS = "return_previous"
|
54
|
+
FINISH_PREVIOUS = "finish_previous"
|
55
|
+
|
56
|
+
|
57
|
+
WandbConfigResume = WandbConfigResumeOption | bool
|
58
|
+
WandbConfigMode = WandbConfigModeOption | None
|
59
|
+
|
60
|
+
|
61
|
+
class WandbLogger(LoggerImpl):
    """Logger backend that writes log lines to a Weights & Biases run.

    Only the master process (as reported by ``is_master()``) creates a W&B
    run; on every other process the logging methods return without doing
    anything.
    """

    def __init__(
        self,
        project: str | None = None,
        entity: str | None = None,
        name: str | None = None,
        run_directory: str | Path | None = None,
        config: dict[str, Any] | None = None,
        tags: list[str] | None = None,
        notes: str | None = None,
        log_interval_seconds: float = 10.0,
        reinit: WandbConfigReinitOption = WandbConfigReinitOption.RETURN_PREVIOUS,
        resume: WandbConfigResume = False,
        mode: WandbConfigMode = None,
    ) -> None:
        """Defines a logger which writes to Weights & Biases.

        The W&B run is started immediately (on the master process).

        Args:
            project: The name of the W&B project to log to.
            entity: The W&B entity (team or user) to log to.
            name: The name of this run.
            run_directory: The root run directory. If provided, wandb will save
                files to a ``wandb`` subdirectory here.
            config: Configuration dictionary to log.
            tags: List of tags for this run.
            notes: Notes about this run.
            log_interval_seconds: The interval between successive log lines.
            reinit: How ``wandb.init()`` should treat a run that is already
                active in this process.
            resume: Resume behavior forwarded to ``wandb.init()``; either a
                ``WandbConfigResumeOption`` or a bool.
            mode: W&B mode ("online", "offline", "disabled" or "shared"), or
                None to leave the choice to W&B.

        Raises:
            RuntimeError: If the ``wandb`` package is not installed.
        """
        try:
            import wandb as _wandb  # noqa: F401,PLC0415
        except ImportError as e:
            raise RuntimeError(
                "WandbLogger requires the 'wandb' package. Install it with: pip install xax[wandb]"
            ) from e

        self._wandb = _wandb

        super().__init__(log_interval_seconds)

        self.project = project
        self.entity = entity
        self.name = name
        self.config = config
        self.tags = tags
        self.notes = notes
        self.reinit = reinit
        self.resume: WandbConfigResume = resume
        self.mode: WandbConfigMode = mode

        # Assigned unconditionally so ``start``/``stop``/``write`` can always
        # read them, even without a run directory or off the master process.
        self.wandb_dir: Path | None = None
        self.run: Any = None

        if run_directory is not None:
            self.wandb_dir = Path(run_directory).expanduser().resolve() / "wandb"
            self.wandb_dir.mkdir(parents=True, exist_ok=True)

        self._started = False

        # Files registered via ``log_file``; uploaded on the next ``write``.
        self.files: dict[str, str] = {}

        self.start()

    def start(self) -> None:
        """Initialize the W&B run; a no-op if already started or not on the master process."""
        if self._started or not is_master():
            return

        # W&B reads WANDB_DIR to decide where its local run files go.
        if self.wandb_dir is not None:
            os.environ["WANDB_DIR"] = str(self.wandb_dir)

        self.run = self._wandb.init(  # pyright
            project=self.project,
            entity=self.entity,
            name=self.name,
            config=self.config,
            tags=self.tags,
            notes=self.notes,
            reinit=self.reinit.value,
            resume=self.resume.value if isinstance(self.resume, WandbConfigResumeOption) else self.resume,
            mode=self.mode.value if isinstance(self.mode, WandbConfigModeOption) else self.mode,
        )

        self._started = True
        logger.info("W&B run initialized: %s", self.run.url if self.run else "No URL available")

    def stop(self) -> None:
        """Finish the W&B run, if this process started one."""
        if not self._started or not is_master():
            return

        if self.run is not None:
            self.run.finish()
        self.run = None
        self._started = False

    def log_file(self, name: str, contents: str) -> None:
        """Store a file to be logged with the next write call.

        Args:
            name: The name of the file.
            contents: The contents of the file.
        """
        if not is_master():
            return
        self.files[name] = contents

    def write(self, line: LogLine) -> None:
        """Writes the current log line to W&B.

        Every value in the line is collected into one dictionary and sent with
        a single ``run.log`` call at the line's global step. Files registered
        via ``log_file`` are uploaded as artifacts first.

        Args:
            line: The line to write.
        """
        import html  # noqa: PLC0415

        if not is_master() or not self._started:
            return

        global_step = line.state.num_steps.item()
        metrics: dict[str, Any] = {}

        # Scalars.
        for namespace, scalars in line.scalars.items():
            for scalar_key, scalar_value in scalars.items():
                key = sanitize_metric_name(f"{namespace}/{scalar_key}")
                metrics[key] = as_float(scalar_value.value)

        # Distributions are summarized by their mean and standard deviation.
        for namespace, distributions in line.distributions.items():
            for distribution_key, distribution_value in distributions.items():
                base_key = sanitize_metric_name(f"{namespace}/{distribution_key}")
                metrics[f"{base_key}/mean"] = float(distribution_value.mean)
                metrics[f"{base_key}/std"] = float(distribution_value.std)

        # Histograms are rebuilt as approximate samples, which wandb.Histogram accepts.
        for namespace, histograms in line.histograms.items():
            for histogram_key, histogram_value in histograms.items():
                values = self._histogram_samples(histogram_value.bucket_limits, histogram_value.bucket_counts)
                if values.size > 0:
                    key = sanitize_metric_name(f"{namespace}/{histogram_key}")
                    metrics[key] = self._wandb.Histogram(values)

        # Strings are shown as preformatted HTML. The text is escaped so that
        # characters such as "<" and "&" render literally instead of as markup.
        for namespace, strings in line.strings.items():
            for string_key, string_value in strings.items():
                key = sanitize_metric_name(f"{namespace}/{string_key}")
                metrics[key] = self._wandb.Html(f"<pre>{html.escape(str(string_value.value))}</pre>")

        # Images.
        for namespace, images in line.images.items():
            for image_key, image_value in images.items():
                key = sanitize_metric_name(f"{namespace}/{image_key}")
                metrics[key] = self._wandb.Image(image_value.image)

        # Videos: wandb.Video expects (T, C, H, W); frames are stored as (T, H, W, C).
        for namespace, videos in line.videos.items():
            for video_key, video_value in videos.items():
                key = sanitize_metric_name(f"{namespace}/{video_key}")
                frames = video_value.frames.transpose(0, 3, 1, 2)
                metrics[key] = self._wandb.Video(frames, fps=video_value.fps, format="mp4")

        # Meshes are logged as 3D point clouds.
        for namespace, meshes in line.meshes.items():
            for mesh_key, mesh_value in meshes.items():
                key = sanitize_metric_name(f"{namespace}/{mesh_key}")
                metrics[key] = self._mesh_to_object3d(mesh_value)

        self._log_pending_files()

        if metrics and self.run is not None:
            self.run.log(metrics, step=global_step)

    @staticmethod
    def _histogram_samples(bucket_limits: Any, bucket_counts: Any) -> np.ndarray:
        """Approximates the raw samples behind a bucketed histogram.

        Each bucket is represented by one value repeated ``count`` times. The
        first bucket uses ``bucket_limits[0]``; bucket ``i > 0`` uses the
        midpoint of ``bucket_limits[i - 1]`` and ``bucket_limits[i]``.
        ``np.repeat`` keeps this cheap even for histograms over many samples.
        Limits and counts are assumed to be aligned bucket-for-bucket; any
        surplus entries in the longer sequence are ignored.

        Args:
            bucket_limits: Per-bucket limits.
            bucket_counts: Number of samples in each bucket.

        Returns:
            A 1-D float array of approximate samples (empty if every count is 0).
        """
        limits = np.asarray(bucket_limits, dtype=np.float64).reshape(-1)
        counts = np.clip(np.asarray(bucket_counts, dtype=np.int64).reshape(-1), 0, None)
        num_buckets = min(limits.size, counts.size)
        if num_buckets == 0:
            return np.empty(0, dtype=np.float64)

        limits = limits[:num_buckets]
        representatives = np.empty(num_buckets, dtype=np.float64)
        representatives[0] = limits[0]
        representatives[1:] = (limits[:-1] + limits[1:]) / 2
        return np.repeat(representatives, counts[:num_buckets])

    def _mesh_to_object3d(self, mesh_value: Any) -> Any:
        """Converts a logged mesh into a ``wandb.Object3D`` point cloud.

        W&B's ``"lidar/beta"`` dictionary format reads the ``points`` key as a
        NumPy array. Each row is ``[x, y, z]``, or ``[x, y, z, r, g, b]`` with
        RGB in 0-255. Faces have no representation in this format, so only the
        vertices are logged, plus per-vertex colors when their shape matches.
        """
        vertices = np.asarray(mesh_value.vertices, dtype=np.float32)
        # Handle a leading batch dimension by taking the first element.
        if vertices.ndim == 3:
            vertices = vertices[0]
        points = vertices

        if mesh_value.colors is not None:
            colors = np.asarray(mesh_value.colors)
            if colors.ndim == 3:
                colors = colors[0]
            if (
                points.ndim == 2
                and colors.ndim == 2
                and colors.shape[0] == points.shape[0]
                and colors.shape[1] >= 3
            ):
                rgb = colors[:, :3].astype(np.float32)
                # uint8 colors are already 0-255; other dtypes are assumed to be 0-1.
                if colors.dtype != np.uint8:
                    rgb = rgb * 255.0
                points = np.concatenate([points, np.clip(rgb, 0.0, 255.0)], axis=1)

        return self._wandb.Object3D({"type": "lidar/beta", "points": points})

    @staticmethod
    def _artifact_name(raw: str) -> str:
        """Maps ``raw`` onto the characters W&B accepts in artifact names.

        Artifact names may only contain alphanumerics, dashes, underscores and
        dots. Anything else, such as a space or slash in a user-supplied run
        name, is replaced with an underscore.
        """
        return "".join(ch if (ch.isascii() and ch.isalnum()) or ch in "-_." else "_" for ch in raw)

    def _log_pending_files(self) -> None:
        """Uploads files registered via ``log_file`` as artifacts of this run, then clears them.

        Files whose name contains "code" go to a ``training_code`` artifact of
        type "code". Every other file gets its own artifact, named after the
        run and the file.
        """
        if not self.files or self.run is None:
            return

        run_name = self.run.name or "run"
        for file_name, contents in self.files.items():
            is_training_code = "code" in file_name
            if is_training_code:
                artifact_name = "training_code"
            else:
                artifact_name = self._artifact_name(f"{run_name}_{sanitize_metric_name(file_name)}")
            artifact = self._wandb.Artifact(
                name=artifact_name,
                type="code" if is_training_code else "unspecified",
            )
            with artifact.new_file(file_name) as f:
                f.write(contents)
            self.run.log_artifact(artifact)
        self.files.clear()

    def write_error_summary(self, error_summary: LogErrorSummary) -> None:
        """Error summaries are not forwarded to W&B."""

    def write_error(self, error: LogError) -> None:
        """Errors are not forwarded to W&B."""

    def write_status(self, status: LogStatus) -> None:
        """Status messages are not forwarded to W&B."""

    def write_ping(self, ping: LogPing) -> None:
        """Pings are not forwarded to W&B."""
|
@@ -10,4 +10,5 @@ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
|
|
10
10
|
from xax.task.mixins.process import ProcessConfig, ProcessMixin
|
11
11
|
from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
|
12
12
|
from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
|
13
|
-
from xax.task.mixins.
|
13
|
+
from xax.task.mixins.supervised import SupervisedConfig, SupervisedMixin
|
14
|
+
from xax.task.mixins.train import InitParams, TrainConfig, TrainMixin
|