xax 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
xax/__init__.py CHANGED
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.4.1"
15
+ __version__ = "0.4.3"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -171,6 +171,7 @@ __all__ = [
171
171
  "uncolored",
172
172
  "wrapped",
173
173
  "FrozenDict",
174
+ "freeze_dict",
174
175
  "HashableArray",
175
176
  "hashable_array",
176
177
  ]
@@ -371,6 +372,7 @@ NAME_MAP: dict[str, str] = {
371
372
  "uncolored": "utils.text",
372
373
  "wrapped": "utils.text",
373
374
  "FrozenDict": "utils.types.frozen_dict",
375
+ "freeze_dict": "utils.types.frozen_dict",
374
376
  "HashableArray": "utils.types.hashable_array",
375
377
  "hashable_array": "utils.types.hashable_array",
376
378
  }
@@ -572,7 +574,7 @@ if IMPORT_ALL or TYPE_CHECKING:
572
574
  uncolored,
573
575
  wrapped,
574
576
  )
575
- from xax.utils.types.frozen_dict import FrozenDict
577
+ from xax.utils.types.frozen_dict import FrozenDict, freeze_dict
576
578
  from xax.utils.types.hashable_array import HashableArray, hashable_array
577
579
 
578
580
  del TYPE_CHECKING, IMPORT_ALL
@@ -1,21 +1,129 @@
1
1
  """Defines a launcher to train a model locally, in a single process."""
2
2
 
3
+ import logging
4
+ import os
5
+ import shutil
6
+ import subprocess
3
7
  from typing import TYPE_CHECKING
4
8
 
9
+ import jax
10
+
5
11
  from xax.task.base import RawConfigType
6
12
  from xax.task.launchers.base import BaseLauncher
13
+ from xax.task.mixins.gpu_stats import get_num_gpus
7
14
  from xax.utils.logging import configure_logging
8
15
 
9
16
  if TYPE_CHECKING:
10
17
  from xax.task.mixins.runnable import Config, RunnableMixin
11
18
 
12
19
 
20
def get_gpu_memory_info() -> dict[int, tuple[float, float]]:
    """Get memory information for all GPUs reported by ``nvidia-smi``.

    ``nvidia-smi`` is queried once and each CSV row (one per GPU) is parsed
    independently. A row that cannot be parsed, such as one reporting
    ``[N/A]`` for memory, is skipped on its own instead of discarding the
    information for every GPU.

    Returns:
        Dictionary mapping GPU index to ``(total_memory_mb, used_memory_mb)``.
        Empty if ``nvidia-smi`` is missing, exits with an error, or times out.
    """
    command = [
        "nvidia-smi",
        "--query-gpu=index,memory.total,memory.used",
        "--format=csv,noheader,nounits",
    ]

    try:
        # `check=True` turns a non-zero exit status into an error instead of
        # silently parsing partial output. stderr is captured so driver errors
        # land in the warning below rather than on the terminal. The timeout
        # guards against nvidia-smi hanging on a wedged driver.
        result = subprocess.run(command, capture_output=True, text=True, check=True, timeout=30)
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        # Best effort: no GPU info simply means "use default device selection".
        logger = configure_logging()
        logger.warning("Failed to get GPU memory info: %s", e)
        return {}

    gpu_info: dict[int, tuple[float, float]] = {}
    for raw_line in result.stdout.splitlines():
        line = raw_line.strip()
        if not line:
            continue

        parts = [part.strip() for part in line.split(",")]
        if len(parts) < 3:
            continue

        try:
            gpu_id = int(parts[0])
            # `nounits` should already drop the unit; strip it anyway in case
            # the installed nvidia-smi ignores that flag.
            total_mem = float(parts[1].removesuffix("MiB").strip())
            used_mem = float(parts[2].removesuffix("MiB").strip())
        except ValueError:
            # e.g. "[N/A]" on GPUs that do not report memory usage.
            continue

        gpu_info[gpu_id] = (total_mem, used_mem)

    return gpu_info
52
+
53
+
54
def select_best_gpu() -> int | None:
    """Pick the GPU index with the largest amount of free memory.

    Free memory is ``total - used`` as reported by ``get_gpu_memory_info``.
    When several GPUs tie, the one reported first wins.

    Returns:
        The chosen GPU index. Returns ``None`` when no GPU information is
        available, or when no GPU reports more than -1 MiB of free memory.
    """
    # Only GPUs whose free memory beats the -1.0 floor are eligible; the
    # comparison is False for NaN readings, so those are dropped as well.
    candidates = [
        (total_mem - used_mem, gpu_id)
        for gpu_id, (total_mem, used_mem) in get_gpu_memory_info().items()
        if total_mem - used_mem > -1.0
    ]
    if not candidates:
        return None

    # `max` keeps the first maximal element, matching a strict `>` scan.
    _, best_gpu = max(candidates, key=lambda pair: pair[0])
    return best_gpu
76
+
77
+
78
def configure_gpu_devices(logger: logging.Logger | None = None) -> None:
    """Restrict this process to the least-loaded GPU when several are present.

    If more than one GPU is detected, the one with the most free memory (per
    ``nvidia-smi``) is exposed through ``CUDA_VISIBLE_DEVICES`` and made JAX's
    default device. If ``CUDA_VISIBLE_DEVICES`` is already set, it is treated
    as an explicit user choice and left untouched.

    Args:
        logger: Logger used for progress messages. When omitted,
            ``configure_logging()`` is used.
    """
    if logger is None:
        logger = configure_logging()

    num_gpus = get_num_gpus()

    if num_gpus > 1:
        logger.info("Multiple GPUs detected (%d), selecting GPU with most available memory", num_gpus)

        # nvidia-smi lists every physical GPU regardless of this variable, so
        # overriding it could select a GPU the user deliberately excluded.
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            logger.info(
                "CUDA_VISIBLE_DEVICES is already set (%r); keeping the existing device selection",
                os.environ["CUDA_VISIBLE_DEVICES"],
            )
            return

        best_gpu = select_best_gpu()
        if best_gpu is None:
            logger.warning("Could not determine best GPU, using default device selection")
            return

        logger.info("Selected GPU %d for training", best_gpu)

        # nvidia-smi numbers GPUs by PCI bus ID, while CUDA defaults to
        # "fastest first" ordering. Align them so the index names the same GPU.
        os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
        os.environ["CUDA_VISIBLE_DEVICES"] = str(best_gpu)

        # Best effort: failing to pin a default device must not abort training.
        try:
            devices = jax.devices("gpu")
            if devices:
                device = devices[0]
                if len(devices) > 1:
                    # The env var is presumably ignored if JAX had already
                    # initialised its GPU backend, in which case all GPUs are
                    # still listed; prefer the device whose id matches.
                    # NOTE(review): assumes device ids follow nvidia-smi
                    # numbering in that situation -- confirm.
                    device = next((d for d in devices if d.id == best_gpu), devices[0])
                jax.config.update("jax_default_device", device)
                logger.info("Configured JAX to use device: %s", device)
        except Exception as e:
            logger.warning("Failed to configure JAX device: %s", e)

    elif num_gpus == 1:
        logger.info("Single GPU detected, using default device selection")
109
+
110
+
111
+ def configure_devices(logger: logging.Logger | None = None) -> None:
112
+ if logger is None:
113
+ logger = configure_logging()
114
+
115
+ if shutil.which("nvidia-smi") is not None:
116
+ configure_gpu_devices(logger)
117
+
118
+
13
119
  def run_single_process_training(
14
120
  task: "type[RunnableMixin[Config]]",
15
121
  *cfgs: RawConfigType,
16
122
  use_cli: bool | list[str] = True,
123
+ logger: logging.Logger | None = None,
17
124
  ) -> None:
18
- logger = configure_logging()
125
+ if logger is None:
126
+ logger = configure_logging()
19
127
  task_obj = task.get_task(*cfgs, use_cli=use_cli)
20
128
  task_obj.add_logger_handlers(logger)
21
129
  task_obj.run()
@@ -28,4 +136,6 @@ class SingleProcessLauncher(BaseLauncher):
28
136
  *cfgs: RawConfigType,
29
137
  use_cli: bool | list[str] = True,
30
138
  ) -> None:
31
- run_single_process_training(task, *cfgs, use_cli=use_cli)
139
+ logger = configure_logging()
140
+ configure_devices(logger)
141
+ run_single_process_training(task, *cfgs, use_cli=use_cli, logger=logger)
@@ -0,0 +1,307 @@
1
+ # mypy: disable-error-code="import-not-found"
2
+ """Defines a Weights & Biases logger backend."""
3
+
4
+ import logging
5
+ import os
6
+ from enum import Enum
7
+ from pathlib import Path
8
+ from typing import Any, TypeVar
9
+
10
+ import numpy as np
11
+
12
+ from xax.nn.parallel import is_master
13
+ from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
14
+ from xax.utils.jax import as_float
15
+
16
+ logger: logging.Logger = logging.getLogger(__name__)
17
+
18
+ T = TypeVar("T")
19
+
20
+
21
def sanitize_metric_name(name: str) -> str:
    """Drop characters outside the Basic Multilingual Plane from a metric name.

    W&B has trouble with characters that need four bytes in UTF-8 (code
    points above U+FFFF, e.g. most emoji), so such characters are removed.

    Args:
        name: The metric name to sanitize.

    Returns:
        ``name`` with every code point above U+FFFF removed.
    """
    first_astral_code_point = 0x10000
    kept = filter(lambda ch: ord(ch) < first_astral_code_point, name)
    return "".join(kept)
36
+
37
+
38
class WandbConfigResumeOption(str, Enum):
    """Named resume policies forwarded to ``wandb.init(resume=...)`` by value."""

    ALLOW = "allow"
    NEVER = "never"
    MUST = "must"
    AUTO = "auto"
43
+
44
+
45
class WandbConfigModeOption(str, Enum):
    """Run modes forwarded to ``wandb.init(mode=...)`` by value."""

    ONLINE = "online"
    OFFLINE = "offline"
    DISABLED = "disabled"
    SHARED = "shared"
49
+ SHARED = "shared"
50
+
51
+
52
class WandbConfigReinitOption(str, Enum):
    """Choices forwarded to ``wandb.init(reinit=...)`` by value.

    They control what ``wandb.init`` does when a run is already active in this
    process. The two original members keep their position. ``DEFAULT`` and
    ``CREATE_NEW`` are appended so every string choice accepted by
    ``wandb.init`` can be expressed from config.
    """

    RETURN_PREVIOUS = "return_previous"
    FINISH_PREVIOUS = "finish_previous"
    DEFAULT = "default"
    CREATE_NEW = "create_new"
55
+
56
+
57
# Accepted `resume` values: one of the named resume policies, or a plain bool.
WandbConfigResume = WandbConfigResumeOption | bool
# Accepted `mode` values: a named run mode, or None for "no explicit mode".
WandbConfigMode = WandbConfigModeOption | None
59
+
60
+
61
class WandbLogger(LoggerImpl):
    """Logger backend that writes log lines to a Weights & Biases run.

    The W&B run is created and written to only on the master process
    (``is_master()``). On every other process, ``start``, ``stop``,
    ``log_file`` and ``write`` return without doing anything.
    """

    def __init__(
        self,
        project: str | None = None,
        entity: str | None = None,
        name: str | None = None,
        run_directory: str | Path | None = None,
        config: dict[str, Any] | None = None,
        tags: list[str] | None = None,
        notes: str | None = None,
        log_interval_seconds: float = 10.0,
        reinit: WandbConfigReinitOption = WandbConfigReinitOption.RETURN_PREVIOUS,
        resume: WandbConfigResume = False,
        mode: WandbConfigMode = None,
    ) -> None:
        """Defines a logger which writes to Weights & Biases.

        Args:
            project: The name of the W&B project to log to.
            entity: The W&B entity (team or user) to log to.
            name: The name of this run.
            run_directory: The root run directory. If provided, wandb will save
                files to a ``wandb`` subdirectory here, created eagerly.
            config: Configuration dictionary to log.
            tags: List of tags for this run.
            notes: Notes about this run.
            log_interval_seconds: The interval between successive log lines.
            reinit: How ``wandb.init()`` behaves if a run is already active in
                this process. A plain string or bool is passed through as-is.
            resume: Resume policy passed to ``wandb.init()``; a
                ``WandbConfigResumeOption`` or a boolean.
            mode: Mode for wandb ("online", "offline", "disabled" or
                "shared"); ``None`` passes no explicit mode.

        Raises:
            RuntimeError: If the ``wandb`` package is not installed.
        """
        try:
            import wandb as _wandb  # noqa: F401,PLC0415
        except ImportError as e:
            raise RuntimeError(
                "WandbLogger requires the 'wandb' package. Install it with: pip install xax[wandb]"
            ) from e

        self._wandb = _wandb

        super().__init__(log_interval_seconds)

        self.project = project
        self.entity = entity
        self.name = name
        self.config = config
        self.tags = tags
        self.notes = notes
        self.reinit = reinit
        self.resume: WandbConfigResume = resume
        self.mode: WandbConfigMode = mode

        # Always define `wandb_dir`; `start()` reads it even when no run
        # directory was given (previously an AttributeError in that case).
        self.wandb_dir: Path | None = None
        if run_directory is not None:
            self.wandb_dir = Path(run_directory).expanduser().resolve() / "wandb"
            self.wandb_dir.mkdir(parents=True, exist_ok=True)

        # The active W&B run; stays None on non-master processes.
        self.run: Any = None
        self._started = False

        # Files queued by `log_file`, uploaded on the next `write`.
        self.files: dict[str, str] = {}

        self.start()

    @staticmethod
    def _option_value(option: Any) -> Any:
        """Unwrap an ``Enum`` option to its value; pass anything else through."""
        return option.value if isinstance(option, Enum) else option

    @staticmethod
    def _artifact_name(raw: str) -> str:
        """Replace characters that W&B does not allow in artifact names.

        W&B artifact names may only contain ASCII letters, digits, ``-``,
        ``_`` and ``.``. Anything else (e.g. ``/`` in a nested file name or a
        space) becomes ``_``.
        """
        return "".join(ch if (ch.isascii() and ch.isalnum()) or ch in "-_." else "_" for ch in raw)

    def start(self) -> None:
        """Initialize the W&B run (master process only, at most once until stopped)."""
        if self._started or not is_master():
            return

        # Point wandb at the run directory, if one was given.
        if self.wandb_dir is not None:
            os.environ["WANDB_DIR"] = str(self.wandb_dir)

        self.run = self._wandb.init(
            project=self.project,
            entity=self.entity,
            name=self.name,
            config=self.config,
            tags=self.tags,
            notes=self.notes,
            # Enum options are unwrapped; raw strings/bools are passed as-is
            # (previously `reinit.value` crashed on a plain string).
            reinit=self._option_value(self.reinit),
            resume=self._option_value(self.resume),
            mode=self._option_value(self.mode),
        )

        self._started = True
        logger.info("W&B run initialized: %s", self.run.url if self.run else "No URL available")

    def stop(self) -> None:
        """Finish the W&B run, if this process started one."""
        if not self._started or not is_master():
            return

        if self.run is not None:
            self.run.finish()
        self._started = False

    def log_file(self, name: str, contents: str) -> None:
        """Store a file to be logged with the next write call.

        Args:
            name: The name of the file.
            contents: The contents of the file.
        """
        if not is_master():
            return
        self.files[name] = contents

    def _make_histogram(self, histogram_value: Any) -> Any:
        """Rebuild approximate samples from histogram buckets for ``wandb.Histogram``.

        Bucket ``i`` contributes ``bucket_counts[i]`` copies of a representative
        value. That value is ``bucket_limits[0]`` for the first bucket and the
        midpoint of ``bucket_limits[i - 1]`` and ``bucket_limits[i]`` otherwise.

        Returns:
            A ``wandb.Histogram``, or ``None`` when every bucket is empty.
        """
        counts = np.asarray(histogram_value.bucket_counts).reshape(-1)
        if counts.size == 0:
            return None
        limits = np.asarray(histogram_value.bucket_limits, dtype=np.float64).reshape(-1)

        centers = np.empty(counts.size, dtype=np.float64)
        centers[0] = limits[0]
        centers[1:] = (limits[: counts.size - 1] + limits[1 : counts.size]) / 2

        # Counts may arrive as floats or NumPy scalars. `np.repeat` needs
        # non-negative ints and avoids materialising a huge Python list.
        repeats = np.clip(counts.astype(np.int64), 0, None)
        samples = np.repeat(centers, repeats)
        if samples.size == 0:
            return None
        return self._wandb.Histogram(samples)

    def _make_point_cloud(self, mesh_value: Any) -> Any:
        """Convert a logged mesh into a ``wandb.Object3D`` point cloud.

        ``wandb.Object3D`` accepts a NumPy point cloud of shape ``(N, 3)``
        (xyz) or ``(N, 6)`` (xyz + rgb). Faces cannot be represented in a point
        cloud and are ignored. Only the first element of a batched
        (3-dimensional) vertex/color array is used.
        """
        vertices = np.asarray(mesh_value.vertices, dtype=np.float32)
        if vertices.ndim == 3:
            vertices = vertices[0]
        points = vertices[:, :3]

        colors = mesh_value.colors
        if colors is not None:
            colors = np.asarray(colors)
            if colors.ndim == 3:
                colors = colors[0]
            # Only per-vertex colors can be attached to points.
            if colors.ndim == 2 and colors.shape[0] == points.shape[0] and colors.shape[1] >= 3:
                rgb = colors[:, :3].astype(np.float32)
                if not np.issubdtype(colors.dtype, np.integer):
                    # NOTE(review): assumes float colors are in [0, 1] -- confirm
                    # against LogMesh; W&B point-cloud rgb is 0-255.
                    rgb = rgb * 255.0
                points = np.concatenate([points, np.clip(rgb, 0.0, 255.0)], axis=1)

        # A dict payload of type "lidar/beta" is read from its "points" key,
        # so the former "vertices"/"faces"/"colors" dict produced no geometry.
        return self._wandb.Object3D(points)

    def _log_pending_files(self) -> None:
        """Upload files queued by ``log_file`` as W&B artifacts, then clear the queue."""
        for name, contents in self.files.items():
            is_training_code = "code" in name
            if is_training_code:
                artifact_name = "training_code"
            else:
                artifact_name = self._artifact_name(f"{self.run.name}_{sanitize_metric_name(name)}")
            artifact = self._wandb.Artifact(
                name=artifact_name,
                type="code" if is_training_code else "unspecified",
            )
            with artifact.new_file(name) as f:
                f.write(contents)
            # Log against this logger's own run rather than the global one.
            self.run.log_artifact(artifact)
        self.files.clear()

    def write(self, line: LogLine) -> None:
        """Writes the current log line to W&B.

        All metrics for the line are collected into one dict and logged with a
        single ``run.log`` call at the line's step.

        Args:
            line: The line to write.
        """
        import html  # noqa: PLC0415

        if not is_master() or not self._started:
            return

        global_step = line.state.num_steps.item()
        metrics: dict[str, Any] = {}

        for namespace, scalars in line.scalars.items():
            for scalar_key, scalar_value in scalars.items():
                key = sanitize_metric_name(f"{namespace}/{scalar_key}")
                metrics[key] = as_float(scalar_value.value)

        # Distributions are summarised by their mean and standard deviation.
        for namespace, distributions in line.distributions.items():
            for distribution_key, distribution_value in distributions.items():
                base_key = sanitize_metric_name(f"{namespace}/{distribution_key}")
                metrics[f"{base_key}/mean"] = float(distribution_value.mean)
                metrics[f"{base_key}/std"] = float(distribution_value.std)

        for namespace, histograms in line.histograms.items():
            for histogram_key, histogram_value in histograms.items():
                histogram = self._make_histogram(histogram_value)
                if histogram is not None:
                    metrics[sanitize_metric_name(f"{namespace}/{histogram_key}")] = histogram

        for namespace, strings in line.strings.items():
            for string_key, string_value in strings.items():
                key = sanitize_metric_name(f"{namespace}/{string_key}")
                # Escape so "<", ">" and "&" are shown verbatim, not parsed as markup.
                escaped = html.escape(str(string_value.value))
                metrics[key] = self._wandb.Html(f"<pre>{escaped}</pre>")

        for namespace, images in line.images.items():
            for image_key, image_value in images.items():
                key = sanitize_metric_name(f"{namespace}/{image_key}")
                metrics[key] = self._wandb.Image(image_value.image)

        for namespace, videos in line.videos.items():
            for video_key, video_value in videos.items():
                key = sanitize_metric_name(f"{namespace}/{video_key}")
                # wandb.Video expects (T, C, H, W); frames are stored as (T, H, W, C).
                frames = video_value.frames.transpose(0, 3, 1, 2)
                metrics[key] = self._wandb.Video(frames, fps=video_value.fps, format="mp4")

        for namespace, meshes in line.meshes.items():
            for mesh_key, mesh_value in meshes.items():
                key = sanitize_metric_name(f"{namespace}/{mesh_key}")
                metrics[key] = self._make_point_cloud(mesh_value)

        self._log_pending_files()

        if metrics and self.run:
            self.run.log(metrics, step=global_step)

    def write_error_summary(self, error_summary: LogErrorSummary) -> None:
        """Ignored: this backend only forwards metrics and files to W&B."""

    def write_error(self, error: LogError) -> None:
        """Ignored: this backend only forwards metrics and files to W&B."""

    def write_status(self, status: LogStatus) -> None:
        """Ignored: this backend only forwards metrics and files to W&B."""

    def write_ping(self, ping: LogPing) -> None:
        """Ignored: this backend only forwards metrics and files to W&B."""
xax/task/mixins/logger.py CHANGED
@@ -2,9 +2,10 @@
2
2
 
3
3
  import os
4
4
  from dataclasses import dataclass
5
+ from enum import Enum
5
6
  from pathlib import Path
6
7
  from types import TracebackType
7
- from typing import Generic, Self, TypeVar
8
+ from typing import Any, Generic, Self, TypeVar
8
9
 
9
10
  import jax
10
11
 
@@ -16,10 +17,22 @@ from xax.task.loggers.json import JsonLogger
16
17
  from xax.task.loggers.state import StateLogger
17
18
  from xax.task.loggers.stdout import StdoutLogger
18
19
  from xax.task.loggers.tensorboard import TensorboardLogger
20
+ from xax.task.loggers.wandb import (
21
+ WandbConfigMode,
22
+ WandbConfigModeOption,
23
+ WandbConfigReinitOption,
24
+ WandbConfigResume,
25
+ WandbLogger,
26
+ )
19
27
  from xax.task.mixins.artifacts import ArtifactsMixin
20
28
  from xax.utils.text import is_interactive_session
21
29
 
22
30
 
31
class LoggerBackend(str, Enum):
    """Selects which experiment-tracking logger ``LoggerMixin`` creates."""

    TENSORBOARD = "tensorboard"
    WANDB = "wandb"
34
+
35
+
23
36
  @jax.tree_util.register_dataclass
24
37
  @dataclass
25
38
  class LoggerConfig(BaseConfig):
@@ -27,10 +40,50 @@ class LoggerConfig(BaseConfig):
27
40
  value=1.0,
28
41
  help="The interval between successive log lines.",
29
42
  )
43
+ logger_backend: LoggerBackend = field(
44
+ value=LoggerBackend.TENSORBOARD,
45
+ help="The logger backend to use",
46
+ )
30
47
  tensorboard_log_interval_seconds: float = field(
31
48
  value=10.0,
32
49
  help="The interval between successive Tensorboard log lines.",
33
50
  )
51
+ wandb_project: str | None = field(
52
+ value=None,
53
+ help="The name of the W&B project to log to.",
54
+ )
55
+ wandb_entity: str | None = field(
56
+ value=None,
57
+ help="The W&B entity (team or user) to log to.",
58
+ )
59
+ wandb_name: str | None = field(
60
+ value=None,
61
+ help="The name of this run in W&B.",
62
+ )
63
+ wandb_tags: list[str] | None = field(
64
+ value=None,
65
+ help="List of tags for this W&B run.",
66
+ )
67
+ wandb_notes: str | None = field(
68
+ value=None,
69
+ help="Notes about this W&B run.",
70
+ )
71
+ wandb_log_interval_seconds: float = field(
72
+ value=10.0,
73
+ help="The interval between successive W&B log lines.",
74
+ )
75
+ wandb_mode: WandbConfigMode = field(
76
+ value=WandbConfigModeOption.ONLINE,
77
+ help="Mode for wandb (online, offline, or disabled).",
78
+ )
79
+ wandb_resume: WandbConfigResume = field(
80
+ value=False,
81
+ help="Whether to resume a previous run. Can be a run ID string.",
82
+ )
83
+ wandb_reinit: WandbConfigReinitOption = field(
84
+ value=WandbConfigReinitOption.RETURN_PREVIOUS,
85
+ help="Whether to allow multiple wandb.init() calls in the same process.",
86
+ )
34
87
 
35
88
 
36
89
  Config = TypeVar("Config", bound=LoggerConfig)
@@ -74,12 +127,70 @@ class LoggerMixin(BaseTask[Config], Generic[Config]):
74
127
  StateLogger(
75
128
  run_directory=self.exp_dir,
76
129
  ),
77
- TensorboardLogger(
78
- run_directory=self.exp_dir,
79
- log_interval_seconds=self.config.tensorboard_log_interval_seconds,
80
- ),
130
+ self._create_logger_backend(),
81
131
  )
82
132
 
133
+ def _create_logger_backend(self) -> LoggerImpl:
134
+ match self.config.logger_backend:
135
+ case LoggerBackend.TENSORBOARD:
136
+ return TensorboardLogger(
137
+ run_directory=self.exp_dir if isinstance(self, ArtifactsMixin) else "./",
138
+ log_interval_seconds=self.config.tensorboard_log_interval_seconds,
139
+ )
140
+ case LoggerBackend.WANDB:
141
+ run_config = {}
142
+ if hasattr(self.config, "__dict__"):
143
+ # Convert config to a serializable dictionary
144
+ run_config = self._config_to_dict(self.config)
145
+
146
+ return WandbLogger(
147
+ project=self.config.wandb_project,
148
+ entity=self.config.wandb_entity,
149
+ name=self.config.wandb_name,
150
+ run_directory=self.exp_dir if isinstance(self, ArtifactsMixin) else None,
151
+ config=run_config,
152
+ tags=self.config.wandb_tags,
153
+ notes=self.config.wandb_notes,
154
+ log_interval_seconds=self.config.wandb_log_interval_seconds,
155
+ reinit=self.config.wandb_reinit,
156
+ resume=self.config.wandb_resume,
157
+ mode=self.config.wandb_mode,
158
+ )
159
+ case _:
160
+ # This shouldn't happen, as validation should take care of this
161
+ raise Exception(f"Invalid logger_backend '{self.config.logger_backend}'")
162
+
163
def _config_to_dict(self, config: Config) -> dict[str, Any]:
    """Convert a config object into plain Python data for W&B logging.

    Values are converted recursively:

    * ``Enum`` members become their ``.value``. Enum members have a
      ``__dict__`` whose keys are all underscore-prefixed, so the generic
      attribute walk used to turn them into ``{}``.
    * Mappings (``dict`` or any other ``Mapping``) become ``dict``.
    * Non-string sequences (lists, tuples, ...) become ``list``.
    * Objects with a ``__dict__`` become a ``dict`` of their attributes whose
      names do not start with ``_``.
    * Everything else is returned unchanged.

    Args:
        config: The configuration object to convert.

    Returns:
        A dictionary representation of the config, or ``{}`` if ``config`` does
        not convert to a dictionary.
    """
    from collections.abc import Mapping, Sequence  # noqa: PLC0415

    def convert(value: Any) -> Any:
        if isinstance(value, Enum):
            return value.value
        # Checked before `__dict__`: mapping types may keep their data in
        # private attributes, which the attribute walk would skip.
        if isinstance(value, Mapping):
            return {key: convert(item) for key, item in value.items()}
        if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
            return [convert(item) for item in value]
        if hasattr(value, "__dict__"):
            return {key: convert(item) for key, item in vars(value).items() if not key.startswith("_")}
        return value

    converted = convert(config)
    return converted if isinstance(converted, dict) else {}
193
+
83
194
  def write_logs(self, state: State) -> None:
84
195
  self.logger.write(state)
85
196
 
@@ -146,3 +146,7 @@ def unfreeze(x: FrozenDict[K, V] | dict[str, Any]) -> dict[Any, Any]: # noqa: A
146
146
  return ys
147
147
  else:
148
148
  return x
149
+
150
+
151
def freeze_dict(x: Mapping[K, V]) -> FrozenDict[K, V]:
    """Wrap the mapping ``x`` in a ``FrozenDict``.

    This is a function-style alias for calling ``FrozenDict(x)`` directly.

    Args:
        x: The mapping to wrap.

    Returns:
        ``FrozenDict(x)``.
    """
    frozen: FrozenDict[K, V] = FrozenDict(x)
    return frozen
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.4.1
3
+ Version: 0.4.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -31,6 +31,8 @@ Requires-Dist: pytest; extra == "dev"
31
31
  Requires-Dist: types-pillow; extra == "dev"
32
32
  Requires-Dist: types-psutil; extra == "dev"
33
33
  Requires-Dist: types-requests; extra == "dev"
34
+ Provides-Extra: wandb
35
+ Requires-Dist: wandb[media]; extra == "wandb"
34
36
  Dynamic: author
35
37
  Dynamic: description
36
38
  Dynamic: description-content-type
@@ -1,4 +1,4 @@
1
- xax/__init__.py,sha256=CrRpqzOhz7HdzNOF55VHrtqFl-_9pudbSK-tDmVP5ZU,17164
1
+ xax/__init__.py,sha256=ckZ0hysU32yz3jOglRyau1iw2Tt3Ct0QP9qffIORZS4,17242
2
2
  xax/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  xax/requirements-dev.txt,sha256=qkscNkFzWd1S5fump-AKH53rR65v2x5FmboFdy_kKvs,128
4
4
  xax/requirements.txt,sha256=6qY-84e-sTmlfJNrSjwONQKqzAn5h8G_oGIhnhmfSr4,302
@@ -25,13 +25,14 @@ xax/task/task.py,sha256=Iy02wRUti5lDX1rfDHIgX87dGYeayJxJ9nzJzp_lMq0,1960
25
25
  xax/task/launchers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
26
  xax/task/launchers/base.py,sha256=8LB_r6YISKu1vq1zk3aVYmiedRr9MxE5IMRocs6unFI,731
27
27
  xax/task/launchers/cli.py,sha256=cK7Nm-3fO-W2gTxpn3FEThsT2NvneS2w0UjA1Nt-84A,1402
28
- xax/task/launchers/single_process.py,sha256=wdEUT-B-FE9aemmt1tB_rKKRNy60aiDhslsy2i-ojWo,896
28
+ xax/task/launchers/single_process.py,sha256=D67OtUGLifZa3wxWsB1RN97D-OjrA9gopAVdb-IlRzA,4471
29
29
  xax/task/loggers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
30
  xax/task/loggers/callback.py,sha256=zQuV1xCvz47Q3UQqP1D5mBhbVzptvmPR_7hX25vqSk0,1667
31
31
  xax/task/loggers/json.py,sha256=6A5wL7kspsXnpPhI_vu0scgd2Z2-WLhw4gbBFm7eZMM,4377
32
32
  xax/task/loggers/state.py,sha256=0Jy0NYnY4c0qt0LvNlaTaCKOSqk5SCKln5VdyuQGnIc,1407
33
33
  xax/task/loggers/stdout.py,sha256=giKSW2R83YkgRefm3BLkE7t8Pbj5Dux4AgsdJxYIbGo,6619
34
34
  xax/task/loggers/tensorboard.py,sha256=sRyBbeBeVXDTYhPZIKIapW0JEfL9hqqzhNTeIcSd374,8883
35
+ xax/task/loggers/wandb.py,sha256=iLs8VnFPJ_yryzLX5njrGOvZrqNehnVZM09xfUzRMuU,11392
35
36
  xax/task/mixins/__init__.py,sha256=wYc4zfutdMyEmzCVV421gSf25ZXW9htNTSY_TW6vL_8,894
36
37
  xax/task/mixins/artifacts.py,sha256=UN26TW22ARduO6Bjs0yRu4-V6-Md9MPbXLKDnS28m44,3861
37
38
  xax/task/mixins/checkpointing.py,sha256=v50IZ7j58DWmEu-_6Zh_02R5KUVGhrMkg5n-MYM_J4c,11484
@@ -39,7 +40,7 @@ xax/task/mixins/compile.py,sha256=PG5aF3W9v_xGiImHgUJ7gmwuQQoSQWufdpl2N_mlLX0,39
39
40
  xax/task/mixins/cpu_stats.py,sha256=rO_9a82ZdsNec61ya4FpYE-rWqPhpijRSXsOfc6caFA,9595
40
41
  xax/task/mixins/data_loader.py,sha256=BKfOVWXR70vbyHMFlnlUiQQHXHH5zTj5WtmsymNCFB4,6722
41
42
  xax/task/mixins/gpu_stats.py,sha256=USOyhXldxbsrl6eCtoFKTWUm_lfeG0cUCkQNUpXRdtA,8880
42
- xax/task/mixins/logger.py,sha256=6oXsJJyNUx6YT3q58FVXMZBUpMgjVkGre6BXFN20cVI,2808
43
+ xax/task/mixins/logger.py,sha256=mgJlw8ZYshm8F88jo1RuRNBIvYwYk7CF2o_7V6LPN_Q,7136
43
44
  xax/task/mixins/process.py,sha256=hqDEsMp_SL6ee97iq26-G0g49OcWZZaX82JD4F22eJU,1781
44
45
  xax/task/mixins/runnable.py,sha256=pcLrYc_TycZUY9zZim05Skc2FWk3IZKFnu6p3UDMonM,1966
45
46
  xax/task/mixins/step_wrapper.py,sha256=-Yu5Nft2CRw1JvZt6J_94SM1vqX8fk08IDK95Pmd2ew,1648
@@ -59,11 +60,11 @@ xax/utils/text.py,sha256=xS02aSzdywl3KIaNSpKWcxdd37oYlUJtu9wIjkc1wVc,10654
59
60
  xax/utils/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
60
61
  xax/utils/data/collate.py,sha256=Rd9vMomr_S_zCa_Hi4dO-8ntzAfVwndIUtuXFA3iNcc,7066
61
62
  xax/utils/types/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
62
- xax/utils/types/frozen_dict.py,sha256=ebtHENhyUzSjyJTlbMaLtcckQIJ7EtgJiok_40TJZpo,4689
63
+ xax/utils/types/frozen_dict.py,sha256=QBdBblKenWffJyADQaFF_0iAEVa9EgMylV63udA2vJ4,4771
63
64
  xax/utils/types/hashable_array.py,sha256=l5iIcFmkYzfGeaZmcSoeFkthFASqM8xJYK3AXhZQYwc,992
64
- xax-0.4.1.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
65
- xax-0.4.1.dist-info/METADATA,sha256=l2TJ6pust33Gak_9w7voIwyEQIfWXxu2VIB7q38LVxg,1246
66
- xax-0.4.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
67
- xax-0.4.1.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
68
- xax-0.4.1.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
69
- xax-0.4.1.dist-info/RECORD,,
65
+ xax-0.4.3.dist-info/licenses/LICENSE,sha256=HCN2bImAzUOXldAZZI7JZ9PYq6OwMlDAP_PpX1HnuN0,1071
66
+ xax-0.4.3.dist-info/METADATA,sha256=kfEAbt8WILGev-MQY5uFmrTrCFI3uVI5R0lnigd2ZXw,1314
67
+ xax-0.4.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
68
+ xax-0.4.3.dist-info/entry_points.txt,sha256=uRC6rx5ce0bf-FblJaZSBMxxKFfMyoWTf8OWbBmLSe8,61
69
+ xax-0.4.3.dist-info/top_level.txt,sha256=g4Au_r2XhvZ-lTybviH-Fh9g0zF4DAYHYxPue1-xbs8,4
70
+ xax-0.4.3.dist-info/RECORD,,
File without changes