xax 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +1 -1
- xax/task/launchers/single_process.py +112 -2
- xax/task/loggers/wandb.py +307 -0
- xax/task/mixins/logger.py +116 -5
- {xax-0.4.2.dist-info → xax-0.4.3.dist-info}/METADATA +3 -1
- {xax-0.4.2.dist-info → xax-0.4.3.dist-info}/RECORD +10 -9
- {xax-0.4.2.dist-info → xax-0.4.3.dist-info}/WHEEL +0 -0
- {xax-0.4.2.dist-info → xax-0.4.3.dist-info}/entry_points.txt +0 -0
- {xax-0.4.2.dist-info → xax-0.4.3.dist-info}/licenses/LICENSE +0 -0
- {xax-0.4.2.dist-info → xax-0.4.3.dist-info}/top_level.txt +0 -0
xax/__init__.py
CHANGED
@@ -1,21 +1,129 @@
|
|
1
1
|
"""Defines a launcher to train a model locally, in a single process."""
|
2
2
|
|
3
|
+
import logging
|
4
|
+
import os
|
5
|
+
import shutil
|
6
|
+
import subprocess
|
3
7
|
from typing import TYPE_CHECKING
|
4
8
|
|
9
|
+
import jax
|
10
|
+
|
5
11
|
from xax.task.base import RawConfigType
|
6
12
|
from xax.task.launchers.base import BaseLauncher
|
13
|
+
from xax.task.mixins.gpu_stats import get_num_gpus
|
7
14
|
from xax.utils.logging import configure_logging
|
8
15
|
|
9
16
|
if TYPE_CHECKING:
|
10
17
|
from xax.task.mixins.runnable import Config, RunnableMixin
|
11
18
|
|
12
19
|
|
20
|
+
def _parse_gpu_memory_csv(output: str) -> dict[int, tuple[float, float]]:
    """Parse ``nvidia-smi --format=csv,noheader`` rows into a memory table.

    Args:
        output: Raw text with one ``index, memory.total, memory.used`` row
            per line, e.g. ``"0, 16384 MiB, 1024 MiB"``.

    Returns:
        Dictionary mapping GPU index to (total_memory_mb, used_memory_mb).
        Rows that are blank, too short or not numeric are skipped.
    """
    gpu_info: dict[int, tuple[float, float]] = {}
    for row in output.splitlines():
        # Split on the bare comma and strip each field, so the result does not
        # depend on the exact spacing used after the separator.
        fields = [field.strip() for field in row.split(",")]
        if len(fields) < 3:
            continue
        try:
            gpu_id = int(fields[0])
            total_mem = float(fields[1].replace("MiB", ""))
            used_mem = float(fields[2].replace("MiB", ""))
        except ValueError:
            # A single unreadable row (for instance a "[N/A]" memory field)
            # must not discard the readings of the remaining GPUs.
            continue
        gpu_info[gpu_id] = (total_mem, used_mem)
    return gpu_info


def get_gpu_memory_info() -> dict[int, tuple[float, float]]:
    """Get memory information for all GPUs.

    Runs ``nvidia-smi`` once and parses its CSV output. This is best-effort:
    if the command cannot be started, exits with a non-zero status or emits
    undecodable output, a warning is logged and an empty dictionary is
    returned.

    Returns:
        Dictionary mapping GPU index to (total_memory_mb, used_memory_mb)
    """
    command = ["nvidia-smi", "--query-gpu=index,memory.total,memory.used", "--format=csv,noheader"]

    try:
        # check=True turns a failing nvidia-smi into an exception, so partial
        # output from a failed run is never treated as a valid table.
        completed = subprocess.run(command, capture_output=True, text=True, check=True)
    except (OSError, ValueError, subprocess.SubprocessError) as e:
        logger = configure_logging()
        logger.warning("Failed to get GPU memory info: %s", e)
        return {}

    return _parse_gpu_memory_csv(completed.stdout)
|
52
|
+
|
53
|
+
|
54
|
+
def select_best_gpu() -> int | None:
    """Pick the GPU that currently has the most free memory.

    Free memory is computed as ``total - used`` from the table returned by
    ``get_gpu_memory_info()``. When several GPUs tie, the first one in that
    table wins.

    Returns:
        GPU index with most available memory, or None if no GPUs found
    """
    free_by_gpu = {index: total - used for index, (total, used) in get_gpu_memory_info().items()}
    if not free_by_gpu:
        return None

    candidate = max(free_by_gpu, key=free_by_gpu.__getitem__)

    # A reading of -1.0 MiB or less never qualifies as a selectable GPU.
    if free_by_gpu[candidate] > -1.0:
        return candidate
    return None
|
76
|
+
|
77
|
+
|
78
|
+
def configure_gpu_devices(logger: logging.Logger | None = None) -> None:
    """Restrict the process to the GPU with the most free memory.

    With more than one GPU, the GPU reported by ``select_best_gpu()`` is
    exported through ``CUDA_VISIBLE_DEVICES`` and JAX's default device is set
    to the first visible GPU. With exactly one GPU nothing is changed.

    Args:
        logger: Logger used for progress messages. When None, one is obtained
            from ``configure_logging()``.
    """
    if logger is None:
        logger = configure_logging()

    # If there are multiple devices, choose the one with the most
    # available memory (i.e., the one which is likely not being used
    # by other processes) and use only that device.
    num_gpus = get_num_gpus()

    if num_gpus > 1:
        logger.info("Multiple GPUs detected (%d), selecting GPU with most available memory", num_gpus)

        best_gpu = select_best_gpu()
        if best_gpu is not None:
            logger.info("Selected GPU %d for training", best_gpu)

            # The index comes from nvidia-smi, which enumerates GPUs in PCI bus
            # order. CUDA's default enumeration is "fastest first", so the same
            # number could name a different GPU unless the order is pinned.
            os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

            # Set CUDA_VISIBLE_DEVICES to only show the selected GPU
            os.environ["CUDA_VISIBLE_DEVICES"] = str(best_gpu)

            # Configure JAX to use the selected device. This is best-effort: a
            # failure is logged and JAX's own device selection is kept.
            try:
                devices = jax.devices("gpu")
                if devices:
                    jax.config.update("jax_default_device", devices[0])
                    logger.info("Configured JAX to use device: %s", devices[0])
            except Exception as e:
                logger.warning("Failed to configure JAX device: %s", e)
        else:
            logger.warning("Could not determine best GPU, using default device selection")
    elif num_gpus == 1:
        logger.info("Single GPU detected, using default device selection")
|
109
|
+
|
110
|
+
|
111
|
+
def configure_devices(logger: logging.Logger | None = None) -> None:
|
112
|
+
if logger is None:
|
113
|
+
logger = configure_logging()
|
114
|
+
|
115
|
+
if shutil.which("nvidia-smi") is not None:
|
116
|
+
configure_gpu_devices(logger)
|
117
|
+
|
118
|
+
|
13
119
|
def run_single_process_training(
    task: "type[RunnableMixin[Config]]",
    *cfgs: RawConfigType,
    use_cli: bool | list[str] = True,
    logger: logging.Logger | None = None,
) -> None:
    """Build a task from its configs and run it in the current process.

    Args:
        task: The task class; its ``get_task`` constructs the task object.
        *cfgs: Raw configs forwarded to ``task.get_task``.
        use_cli: Forwarded to ``task.get_task`` unchanged.
        logger: Logger handed to the task's ``add_logger_handlers``. When
            None, one is obtained from ``configure_logging()``.
    """
    active_logger = logger if logger is not None else configure_logging()

    task_obj = task.get_task(*cfgs, use_cli=use_cli)
    task_obj.add_logger_handlers(active_logger)
    task_obj.run()
|
@@ -28,4 +136,6 @@ class SingleProcessLauncher(BaseLauncher):
|
|
28
136
|
*cfgs: RawConfigType,
|
29
137
|
use_cli: bool | list[str] = True,
|
30
138
|
) -> None:
|
31
|
-
|
139
|
+
logger = configure_logging()
|
140
|
+
configure_devices(logger)
|
141
|
+
run_single_process_training(task, *cfgs, use_cli=use_cli, logger=logger)
|
@@ -0,0 +1,307 @@
|
|
1
|
+
# mypy: disable-error-code="import-not-found"
|
2
|
+
"""Defines a Weights & Biases logger backend."""
|
3
|
+
|
4
|
+
import logging
|
5
|
+
import os
|
6
|
+
from enum import Enum
|
7
|
+
from pathlib import Path
|
8
|
+
from typing import Any, TypeVar
|
9
|
+
|
10
|
+
import numpy as np
|
11
|
+
|
12
|
+
from xax.nn.parallel import is_master
|
13
|
+
from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
|
14
|
+
from xax.utils.jax import as_float
|
15
|
+
|
16
|
+
logger: logging.Logger = logging.getLogger(__name__)
|
17
|
+
|
18
|
+
T = TypeVar("T")
|
19
|
+
|
20
|
+
|
21
|
+
def sanitize_metric_name(name: str) -> str:
    """Drop characters outside the Basic Multilingual Plane from a metric name.

    W&B has issues with 4-byte unicode characters in metric names, so every
    character whose code point is above U+FFFF is removed. All other
    characters are kept in their original order.

    Args:
        name: The metric name to sanitize.

    Returns:
        The sanitized metric name.
    """
    # Single-character strings compare by code point, so this keeps exactly
    # the characters that fit in UCS-2.
    last_bmp_char = "\uffff"
    kept: list[str] = []
    for symbol in name:
        if symbol <= last_bmp_char:
            kept.append(symbol)
    return "".join(kept)
|
36
|
+
|
37
|
+
|
38
|
+
class WandbConfigResumeOption(str, Enum):
    """String options that ``WandbLogger`` forwards as ``wandb.init(resume=...)``."""

    ALLOW = "allow"
    NEVER = "never"
    MUST = "must"
    AUTO = "auto"
|
43
|
+
|
44
|
+
|
45
|
+
class WandbConfigModeOption(str, Enum):
    """String options that ``WandbLogger`` forwards as ``wandb.init(mode=...)``."""

    ONLINE = "online"
    OFFLINE = "offline"
    DISABLED = "disabled"
    SHARED = "shared"
|
50
|
+
|
51
|
+
|
52
|
+
class WandbConfigReinitOption(str, Enum):
    """String options that ``WandbLogger`` forwards as ``wandb.init(reinit=...)``."""

    RETURN_PREVIOUS = "return_previous"
    FINISH_PREVIOUS = "finish_previous"
|
55
|
+
|
56
|
+
|
57
|
+
WandbConfigResume = WandbConfigResumeOption | bool
|
58
|
+
WandbConfigMode = WandbConfigModeOption | None
|
59
|
+
|
60
|
+
|
61
|
+
class WandbLogger(LoggerImpl):
    """Logger backend that writes log lines to Weights & Biases.

    Only the master process (as reported by ``is_master()``) initializes a
    run and sends data; ``start``, ``stop``, ``log_file`` and ``write``
    return immediately on any other process.
    """

    def __init__(
        self,
        project: str | None = None,
        entity: str | None = None,
        name: str | None = None,
        run_directory: str | Path | None = None,
        config: dict[str, Any] | None = None,
        tags: list[str] | None = None,
        notes: str | None = None,
        log_interval_seconds: float = 10.0,
        reinit: WandbConfigReinitOption = WandbConfigReinitOption.RETURN_PREVIOUS,
        resume: WandbConfigResume = False,
        mode: WandbConfigMode = None,
    ) -> None:
        """Defines a logger which writes to Weights & Biases.

        The run is started at the end of construction (see ``start``).

        Args:
            project: The name of the W&B project to log to.
            entity: The W&B entity (team or user) to log to.
            name: The name of this run.
            run_directory: The root run directory. If provided, a ``wandb``
                subdirectory is created there and exported as ``WANDB_DIR``
                when the run starts.
            config: Configuration dictionary to log.
            tags: List of tags for this run.
            notes: Notes about this run.
            log_interval_seconds: The interval between successive log lines.
            reinit: Option forwarded as ``wandb.init(reinit=...)``.
            resume: Either a ``WandbConfigResumeOption`` or a bool, forwarded
                as ``wandb.init(resume=...)``.
            mode: A ``WandbConfigModeOption`` ("online", "offline",
                "disabled" or "shared") or None, forwarded as
                ``wandb.init(mode=...)``.

        Raises:
            RuntimeError: If the ``wandb`` package is not installed.
        """
        try:
            import wandb as _wandb  # noqa: F401,PLC0415
        except ImportError as e:
            raise RuntimeError(
                "WandbLogger requires the 'wandb' package. Install it with: pip install xax[wandb]"
            ) from e

        self._wandb = _wandb

        super().__init__(log_interval_seconds)

        self.project = project
        self.entity = entity
        self.name = name
        self.config = config
        self.tags = tags
        self.notes = notes
        self.reinit = reinit
        self.resume: WandbConfigResume = resume
        self.mode: WandbConfigMode = mode

        # Always define the attribute: start() reads it, and leaving it unset
        # when no run directory is given would raise AttributeError there.
        self.wandb_dir: Path | None = None
        if run_directory is not None:
            self.wandb_dir = Path(run_directory).expanduser().resolve() / "wandb"
            self.wandb_dir.mkdir(parents=True, exist_ok=True)

        self._started = False

        # Stays None on non-master processes, where start() never calls
        # wandb.init().
        self.run: Any = None

        # Store pending files to log
        self.files: dict[str, str] = {}

        self.start()

    def start(self) -> None:
        """Initialize the W&B run (master process only, at most once)."""
        if self._started or not is_master():
            return

        # Set wandb environment variables if needed
        if self.wandb_dir is not None:
            os.environ["WANDB_DIR"] = str(self.wandb_dir)

        # Initialize wandb run. Enum options are unwrapped to their plain
        # string values; bool / None values are passed through untouched.
        self.run = self._wandb.init(
            project=self.project,
            entity=self.entity,
            name=self.name,
            config=self.config,
            tags=self.tags,
            notes=self.notes,
            reinit=self.reinit.value,
            resume=self.resume.value if isinstance(self.resume, WandbConfigResumeOption) else self.resume,
            mode=self.mode.value if isinstance(self.mode, WandbConfigModeOption) else self.mode,
        )

        self._started = True
        logger.info("W&B run initialized: %s", self.run.url if self.run else "No URL available")

    def stop(self) -> None:
        """Finish the W&B run if one was started by this logger."""
        if not self._started or not is_master():
            return

        if self.run is not None:
            self.run.finish()
        self._started = False

    def log_file(self, name: str, contents: str) -> None:
        """Store a file to be logged with the next write call.

        Args:
            name: The name of the file.
            contents: The contents of the file.
        """
        if not is_master():
            return
        self.files[name] = contents

    def write(self, line: LogLine) -> None:
        """Writes the current log line to W&B.

        All values in the line are gathered into one metrics dictionary that
        is logged in a single ``run.log`` call at the line's step. Pending
        files stored by ``log_file`` are saved as artifacts first.

        Args:
            line: The line to write.
        """
        import html  # noqa: PLC0415

        if not is_master() or not self._started:
            return

        # Get step information
        global_step = line.state.num_steps.item()

        # Dictionary to collect all metrics for this step
        metrics: dict[str, Any] = {}

        # Log scalars
        for namespace, scalars in line.scalars.items():
            for scalar_key, scalar_value in scalars.items():
                key = sanitize_metric_name(f"{namespace}/{scalar_key}")
                metrics[key] = as_float(scalar_value.value)

        # Log distributions as custom metrics (mean and std)
        for namespace, distributions in line.distributions.items():
            for distribution_key, distribution_value in distributions.items():
                base_key = sanitize_metric_name(f"{namespace}/{distribution_key}")
                metrics[f"{base_key}/mean"] = float(distribution_value.mean)
                metrics[f"{base_key}/std"] = float(distribution_value.std)

        # Log histograms
        for namespace, histograms in line.histograms.items():
            for histogram_key, histogram_value in histograms.items():
                key = sanitize_metric_name(f"{namespace}/{histogram_key}")
                # Only bucket counts and limits are available here, so the
                # samples are approximated: each non-empty bucket contributes
                # `count` copies of a representative value.
                values = []
                for i, count in enumerate(histogram_value.bucket_counts):
                    if count > 0:
                        # Use the midpoint of each bucket; the first bucket has
                        # no lower limit in this list, so its limit is used.
                        if i == 0:
                            val = histogram_value.bucket_limits[0]
                        else:
                            val = (histogram_value.bucket_limits[i - 1] + histogram_value.bucket_limits[i]) / 2
                        values.extend([val] * count)

                if values:
                    # wandb.Histogram accepts lists directly
                    metrics[key] = self._wandb.Histogram(values)

        # Log strings as HTML
        for namespace, strings in line.strings.items():
            for string_key, string_value in strings.items():
                key = sanitize_metric_name(f"{namespace}/{string_key}")
                # The text is shown verbatim inside <pre>, so markup characters
                # must be escaped or "<", ">" and "&" would be read as HTML.
                metrics[key] = self._wandb.Html(f"<pre>{html.escape(str(string_value.value))}</pre>")

        # Log images
        for namespace, images in line.images.items():
            for image_key, image_value in images.items():
                key = sanitize_metric_name(f"{namespace}/{image_key}")
                # Convert PIL image to wandb.Image
                metrics[key] = self._wandb.Image(image_value.image)

        # Log videos
        for namespace, videos in line.videos.items():
            for video_key, video_value in videos.items():
                key = sanitize_metric_name(f"{namespace}/{video_key}")
                # wandb.Video expects shape (time, channels, height, width)
                # Our format is (T, H, W, C) so we need to transpose to (T, C, H, W)
                frames = video_value.frames.transpose(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
                metrics[key] = self._wandb.Video(frames, fps=video_value.fps, format="mp4")

        # Log meshes (3D objects)
        for namespace, meshes in line.meshes.items():
            for mesh_key, mesh_value in meshes.items():
                key = sanitize_metric_name(f"{namespace}/{mesh_key}")
                vertices = mesh_value.vertices

                # Handle batch dimension - take first batch if present
                if vertices.ndim == 3:
                    vertices = vertices[0]

                # NOTE(review): wandb's documented "lidar/beta" dictionary uses
                # "points" / "boxes" / "vectors" keys; confirm that the
                # "vertices" / "faces" / "colors" keys below are honoured.
                obj3d_data = {
                    "type": "lidar/beta",
                    "vertices": vertices.tolist(),
                }

                if mesh_value.faces is not None:
                    faces = mesh_value.faces
                    if faces.ndim == 3:
                        faces = faces[0]
                    obj3d_data["faces"] = faces.tolist()

                if mesh_value.colors is not None:
                    colors = mesh_value.colors
                    if colors.ndim == 3:
                        colors = colors[0]
                    # Convert colors to 0-1 range if they're in 0-255 range
                    if colors.dtype == np.uint8:
                        colors = colors.astype(np.float32) / 255.0
                    obj3d_data["colors"] = colors.tolist()

                metrics[key] = self._wandb.Object3D(obj3d_data)

        # Save any pending files as artifacts. They are kept pending while no
        # run object exists, because the artifact name is derived from the run.
        if self.run is not None:
            for name, contents in self.files.items():
                key = sanitize_metric_name(name)
                key = f"{self.run.name}_{key}"
                # Files whose name contains "code" share one "training_code"
                # artifact of type "code"; everything else is "unspecified".
                is_training_code = "code" in name
                artifact = self._wandb.Artifact(
                    name=key if not is_training_code else "training_code",
                    type="code" if is_training_code else "unspecified",
                )
                with artifact.new_file(name) as f:
                    f.write(contents)
                artifact.save()
            self.files.clear()

        # Log all metrics at once
        if metrics and self.run:
            self.run.log(metrics, step=global_step)

    def write_error_summary(self, error_summary: LogErrorSummary) -> None:
        """Ignore error summaries; this backend does not forward them."""
        pass

    def write_error(self, error: LogError) -> None:
        """Ignore errors; this backend does not forward them."""
        pass

    def write_status(self, status: LogStatus) -> None:
        """Ignore status messages; this backend does not forward them."""
        pass

    def write_ping(self, ping: LogPing) -> None:
        """Ignore pings; this backend does not forward them."""
        pass
|
xax/task/mixins/logger.py
CHANGED
@@ -2,9 +2,10 @@
|
|
2
2
|
|
3
3
|
import os
|
4
4
|
from dataclasses import dataclass
|
5
|
+
from enum import Enum
|
5
6
|
from pathlib import Path
|
6
7
|
from types import TracebackType
|
7
|
-
from typing import Generic, Self, TypeVar
|
8
|
+
from typing import Any, Generic, Self, TypeVar
|
8
9
|
|
9
10
|
import jax
|
10
11
|
|
@@ -16,10 +17,22 @@ from xax.task.loggers.json import JsonLogger
|
|
16
17
|
from xax.task.loggers.state import StateLogger
|
17
18
|
from xax.task.loggers.stdout import StdoutLogger
|
18
19
|
from xax.task.loggers.tensorboard import TensorboardLogger
|
20
|
+
from xax.task.loggers.wandb import (
|
21
|
+
WandbConfigMode,
|
22
|
+
WandbConfigModeOption,
|
23
|
+
WandbConfigReinitOption,
|
24
|
+
WandbConfigResume,
|
25
|
+
WandbLogger,
|
26
|
+
)
|
19
27
|
from xax.task.mixins.artifacts import ArtifactsMixin
|
20
28
|
from xax.utils.text import is_interactive_session
|
21
29
|
|
22
30
|
|
31
|
+
class LoggerBackend(str, Enum):
    """Logger backends selectable through ``LoggerConfig.logger_backend``."""

    TENSORBOARD = "tensorboard"
    WANDB = "wandb"
|
34
|
+
|
35
|
+
|
23
36
|
@jax.tree_util.register_dataclass
|
24
37
|
@dataclass
|
25
38
|
class LoggerConfig(BaseConfig):
|
@@ -27,10 +40,50 @@ class LoggerConfig(BaseConfig):
|
|
27
40
|
value=1.0,
|
28
41
|
help="The interval between successive log lines.",
|
29
42
|
)
|
43
|
+
logger_backend: LoggerBackend = field(
|
44
|
+
value=LoggerBackend.TENSORBOARD,
|
45
|
+
help="The logger backend to use",
|
46
|
+
)
|
30
47
|
tensorboard_log_interval_seconds: float = field(
|
31
48
|
value=10.0,
|
32
49
|
help="The interval between successive Tensorboard log lines.",
|
33
50
|
)
|
51
|
+
wandb_project: str | None = field(
|
52
|
+
value=None,
|
53
|
+
help="The name of the W&B project to log to.",
|
54
|
+
)
|
55
|
+
wandb_entity: str | None = field(
|
56
|
+
value=None,
|
57
|
+
help="The W&B entity (team or user) to log to.",
|
58
|
+
)
|
59
|
+
wandb_name: str | None = field(
|
60
|
+
value=None,
|
61
|
+
help="The name of this run in W&B.",
|
62
|
+
)
|
63
|
+
wandb_tags: list[str] | None = field(
|
64
|
+
value=None,
|
65
|
+
help="List of tags for this W&B run.",
|
66
|
+
)
|
67
|
+
wandb_notes: str | None = field(
|
68
|
+
value=None,
|
69
|
+
help="Notes about this W&B run.",
|
70
|
+
)
|
71
|
+
wandb_log_interval_seconds: float = field(
|
72
|
+
value=10.0,
|
73
|
+
help="The interval between successive W&B log lines.",
|
74
|
+
)
|
75
|
+
wandb_mode: WandbConfigMode = field(
|
76
|
+
value=WandbConfigModeOption.ONLINE,
|
77
|
+
help="Mode for wandb (online, offline, or disabled).",
|
78
|
+
)
|
79
|
+
wandb_resume: WandbConfigResume = field(
|
80
|
+
value=False,
|
81
|
+
help="Whether to resume a previous run. Can be a run ID string.",
|
82
|
+
)
|
83
|
+
wandb_reinit: WandbConfigReinitOption = field(
|
84
|
+
value=WandbConfigReinitOption.RETURN_PREVIOUS,
|
85
|
+
help="Whether to allow multiple wandb.init() calls in the same process.",
|
86
|
+
)
|
34
87
|
|
35
88
|
|
36
89
|
Config = TypeVar("Config", bound=LoggerConfig)
|
@@ -74,12 +127,70 @@ class LoggerMixin(BaseTask[Config], Generic[Config]):
|
|
74
127
|
StateLogger(
|
75
128
|
run_directory=self.exp_dir,
|
76
129
|
),
|
77
|
-
|
78
|
-
run_directory=self.exp_dir,
|
79
|
-
log_interval_seconds=self.config.tensorboard_log_interval_seconds,
|
80
|
-
),
|
130
|
+
self._create_logger_backend(),
|
81
131
|
)
|
82
132
|
|
133
|
+
def _create_logger_backend(self) -> LoggerImpl:
|
134
|
+
match self.config.logger_backend:
|
135
|
+
case LoggerBackend.TENSORBOARD:
|
136
|
+
return TensorboardLogger(
|
137
|
+
run_directory=self.exp_dir if isinstance(self, ArtifactsMixin) else "./",
|
138
|
+
log_interval_seconds=self.config.tensorboard_log_interval_seconds,
|
139
|
+
)
|
140
|
+
case LoggerBackend.WANDB:
|
141
|
+
run_config = {}
|
142
|
+
if hasattr(self.config, "__dict__"):
|
143
|
+
# Convert config to a serializable dictionary
|
144
|
+
run_config = self._config_to_dict(self.config)
|
145
|
+
|
146
|
+
return WandbLogger(
|
147
|
+
project=self.config.wandb_project,
|
148
|
+
entity=self.config.wandb_entity,
|
149
|
+
name=self.config.wandb_name,
|
150
|
+
run_directory=self.exp_dir if isinstance(self, ArtifactsMixin) else None,
|
151
|
+
config=run_config,
|
152
|
+
tags=self.config.wandb_tags,
|
153
|
+
notes=self.config.wandb_notes,
|
154
|
+
log_interval_seconds=self.config.wandb_log_interval_seconds,
|
155
|
+
reinit=self.config.wandb_reinit,
|
156
|
+
resume=self.config.wandb_resume,
|
157
|
+
mode=self.config.wandb_mode,
|
158
|
+
)
|
159
|
+
case _:
|
160
|
+
# This shouldn't happen, as validation should take care of this
|
161
|
+
raise Exception(f"Invalid logger_backend '{self.config.logger_backend}'")
|
162
|
+
|
163
|
+
def _config_to_dict(self, config: Config) -> dict[str, Any]:
    """Convert a config object to a dictionary for W&B logging.

    Public attributes (names not starting with ``_``) are copied; nested
    values are converted recursively:

    * enum members become their ``value`` (their ``__dict__`` only holds
      underscore-prefixed entries, so they would otherwise turn into ``{}``),
    * dicts keep their keys and have their values converted,
    * lists and tuples become lists of converted items,
    * any other object with a ``__dict__`` becomes a dict of its public
      attributes,
    * everything else is passed through unchanged.

    Args:
        config: The configuration object to convert.

    Returns:
        A dictionary representation of the config, or an empty dictionary
        if ``config`` has no ``__dict__``.
    """

    def to_serializable(value: Any) -> Any:  # noqa: ANN401
        # Enums and containers are checked before the generic __dict__ case
        # because enum members and some container subclasses carry a __dict__.
        if isinstance(value, Enum):
            return value.value
        if isinstance(value, dict):
            return {k: to_serializable(v) for k, v in value.items()}
        if isinstance(value, (list, tuple)):
            return [to_serializable(item) for item in value]
        if hasattr(value, "__dict__"):
            return public_attributes(value)
        return value

    def public_attributes(obj: Any) -> dict[str, Any]:  # noqa: ANN401
        return {key: to_serializable(value) for key, value in vars(obj).items() if not key.startswith("_")}

    if not hasattr(config, "__dict__"):
        return {}
    return public_attributes(config)
|
193
|
+
|
83
194
|
def write_logs(self, state: State) -> None:
|
84
195
|
self.logger.write(state)
|
85
196
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: xax
|
3
|
-
Version: 0.4.
|
3
|
+
Version: 0.4.3
|
4
4
|
Summary: A library for fast Jax experimentation
|
5
5
|
Home-page: https://github.com/kscalelabs/xax
|
6
6
|
Author: Benjamin Bolte
|
@@ -31,6 +31,8 @@ Requires-Dist: pytest; extra == "dev"
|
|
31
31
|
Requires-Dist: types-pillow; extra == "dev"
|
32
32
|
Requires-Dist: types-psutil; extra == "dev"
|
33
33
|
Requires-Dist: types-requests; extra == "dev"
|
34
|
+
Provides-Extra: wandb
|
35
|
+
Requires-Dist: wandb[media]; extra == "wandb"
|
34
36
|
Dynamic: author
|
35
37
|
Dynamic: description
|
36
38
|
Dynamic: description-content-type
|
@@ -1,4 +1,4 @@
|
|
1
|
-
xax/__init__.py,sha256=
|
1
|
+
xax/__init__.py,sha256=ckZ0hysU32yz3jOglRyau1iw2Tt3Ct0QP9qffIORZS4,17242
|
2
2
|
xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
3
|
xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
|
4
4
|
xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
|
@@ -25,13 +25,14 @@ xax/task/task.py,sha256=Iy02wRUti5lDX1rfDHIgX87dGYeayJxJ9nzJzp_lMq0,1960
|
|
25
25
|
xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
26
26
|
xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
|
27
27
|
xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,1402
|
28
|
-
xax/task/launchers/single_process.py,sha256=
|
28
|
+
xax/task/launchers/single_process.py,sha256=D67OtUGLifZa3wxWsB1RN97D-OjrA9gopAVdb-IlRzA,4471
|
29
29
|
xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
30
30
|
xax/task/loggers/callback.py,sha256=zQuV1xCvz47Q3UQqP1D5mBhbVzptvmPR_7hX25vqSk0,1667
|
31
31
|
xax/task/loggers/json.py,sha256=6A5wL7kspsXnpPhI_vu0scgd2Z2-WLhw4gbBFm7eZMM,4377
|
32
32
|
xax/task/loggers/state.py,sha256=0Jy0NYnY4c0qt0LvNlaTaCKOSqk5SCKln5VdyuQGnIc,1407
|
33
33
|
xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,6619
|
34
34
|
xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
|
35
|
+
xax/task/loggers/wandb.py,sha256=iLs8VnFPJ_yryzLX5njrGOvZrqNehnVZM09xfUzRMuU,11392
|
35
36
|
xax/task/mixins/__init__.py,sha256=wYc4zfutdMyEmzCVV421gSf25ZXW9htNTSY_TW6vL_8,894
|
36
37
|
xax/task/mixins/artifacts.py,sha256=UN26TW22ARduO6Bjs0yRu4-V6-Md9MPbXLKDnS28m44,3861
|
37
38
|
xax/task/mixins/checkpointing.py,sha256=v50IZ7j58DWmEu-_6Zh_02R5KUVGhrMkg5n-MYM_J4c,11484
|
@@ -39,7 +40,7 @@ xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,39
|
|
39
40
|
xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
|
40
41
|
xax/task/mixins/data_loader.py,sha256=BKfOVWXR70vbyHMFlnlUiQQHXHH5zTj5WtmsymNCFB4,6722
|
41
42
|
xax/task/mixins/gpu_stats.py,sha256=USOyhXldxbsrl6eCtoFKTWUm_lfeG0cUCkQNUpXRdtA,8880
|
42
|
-
xax/task/mixins/logger.py,sha256=
|
43
|
+
xax/task/mixins/logger.py,sha256=mgJlw8ZYshm8F88jo1RuRNBIvYwYk7CF2o_7V6LPN_Q,7136
|
43
44
|
xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
|
44
45
|
xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
|
45
46
|
xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
|
@@ -61,9 +62,9 @@ xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,706
|
|
61
62
|
xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
62
63
|
xax/utils/types/frozen_dict.py,sha256=QBdBblKenWffJyADQaFF_0iAEVa9EgMylV63udA2vJ4,4771
|
63
64
|
xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
|
64
|
-
xax-0.4.
|
65
|
-
xax-0.4.
|
66
|
-
xax-0.4.
|
67
|
-
xax-0.4.
|
68
|
-
xax-0.4.
|
69
|
-
xax-0.4.
|
65
|
+
xax-0.4.3.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
|
66
|
+
xax-0.4.3.dist-info/METADATA,sha256=kfEAbt8WILGev-MQY5uFmrTrCFI3uVI5R0lnigd2ZXw,1314
|
67
|
+
xax-0.4.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
68
|
+
xax-0.4.3.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
|
69
|
+
xax-0.4.3.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
|
70
|
+
xax-0.4.3.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|
File without changes
|