xax 0.4.1__tar.gz → 0.4.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. {xax-0.4.1/xax.egg-info → xax-0.4.3}/PKG-INFO +3 -1
  2. {xax-0.4.1 → xax-0.4.3}/pyproject.toml +1 -0
  3. {xax-0.4.1 → xax-0.4.3}/setup.py +1 -0
  4. {xax-0.4.1 → xax-0.4.3}/xax/__init__.py +4 -2
  5. xax-0.4.3/xax/task/launchers/single_process.py +141 -0
  6. xax-0.4.3/xax/task/loggers/wandb.py +307 -0
  7. xax-0.4.3/xax/task/mixins/logger.py +203 -0
  8. {xax-0.4.1 → xax-0.4.3}/xax/utils/types/frozen_dict.py +4 -0
  9. {xax-0.4.1 → xax-0.4.3/xax.egg-info}/PKG-INFO +3 -1
  10. {xax-0.4.1 → xax-0.4.3}/xax.egg-info/SOURCES.txt +1 -0
  11. {xax-0.4.1 → xax-0.4.3}/xax.egg-info/requires.txt +3 -0
  12. xax-0.4.1/xax/task/launchers/single_process.py +0 -31
  13. xax-0.4.1/xax/task/mixins/logger.py +0 -92
  14. {xax-0.4.1 → xax-0.4.3}/LICENSE +0 -0
  15. {xax-0.4.1 → xax-0.4.3}/MANIFEST.in +0 -0
  16. {xax-0.4.1 → xax-0.4.3}/README.md +0 -0
  17. {xax-0.4.1 → xax-0.4.3}/setup.cfg +0 -0
  18. {xax-0.4.1 → xax-0.4.3}/xax/cli/__init__.py +0 -0
  19. {xax-0.4.1 → xax-0.4.3}/xax/cli/edit_config.py +0 -0
  20. {xax-0.4.1 → xax-0.4.3}/xax/core/__init__.py +0 -0
  21. {xax-0.4.1 → xax-0.4.3}/xax/core/conf.py +0 -0
  22. {xax-0.4.1 → xax-0.4.3}/xax/core/state.py +0 -0
  23. {xax-0.4.1 → xax-0.4.3}/xax/nn/__init__.py +0 -0
  24. {xax-0.4.1 → xax-0.4.3}/xax/nn/attention.py +0 -0
  25. {xax-0.4.1 → xax-0.4.3}/xax/nn/distributions.py +0 -0
  26. {xax-0.4.1 → xax-0.4.3}/xax/nn/embeddings.py +0 -0
  27. {xax-0.4.1 → xax-0.4.3}/xax/nn/functions.py +0 -0
  28. {xax-0.4.1 → xax-0.4.3}/xax/nn/geom.py +0 -0
  29. {xax-0.4.1 → xax-0.4.3}/xax/nn/losses.py +0 -0
  30. {xax-0.4.1 → xax-0.4.3}/xax/nn/metrics.py +0 -0
  31. {xax-0.4.1 → xax-0.4.3}/xax/nn/parallel.py +0 -0
  32. {xax-0.4.1 → xax-0.4.3}/xax/nn/ssm.py +0 -0
  33. {xax-0.4.1 → xax-0.4.3}/xax/py.typed +0 -0
  34. {xax-0.4.1 → xax-0.4.3}/xax/requirements-dev.txt +0 -0
  35. {xax-0.4.1 → xax-0.4.3}/xax/requirements.txt +0 -0
  36. {xax-0.4.1 → xax-0.4.3}/xax/task/__init__.py +0 -0
  37. {xax-0.4.1 → xax-0.4.3}/xax/task/base.py +0 -0
  38. {xax-0.4.1 → xax-0.4.3}/xax/task/launchers/__init__.py +0 -0
  39. {xax-0.4.1 → xax-0.4.3}/xax/task/launchers/base.py +0 -0
  40. {xax-0.4.1 → xax-0.4.3}/xax/task/launchers/cli.py +0 -0
  41. {xax-0.4.1 → xax-0.4.3}/xax/task/logger.py +0 -0
  42. {xax-0.4.1 → xax-0.4.3}/xax/task/loggers/__init__.py +0 -0
  43. {xax-0.4.1 → xax-0.4.3}/xax/task/loggers/callback.py +0 -0
  44. {xax-0.4.1 → xax-0.4.3}/xax/task/loggers/json.py +0 -0
  45. {xax-0.4.1 → xax-0.4.3}/xax/task/loggers/state.py +0 -0
  46. {xax-0.4.1 → xax-0.4.3}/xax/task/loggers/stdout.py +0 -0
  47. {xax-0.4.1 → xax-0.4.3}/xax/task/loggers/tensorboard.py +0 -0
  48. {xax-0.4.1 → xax-0.4.3}/xax/task/mixins/__init__.py +0 -0
  49. {xax-0.4.1 → xax-0.4.3}/xax/task/mixins/artifacts.py +0 -0
  50. {xax-0.4.1 → xax-0.4.3}/xax/task/mixins/checkpointing.py +0 -0
  51. {xax-0.4.1 → xax-0.4.3}/xax/task/mixins/compile.py +0 -0
  52. {xax-0.4.1 → xax-0.4.3}/xax/task/mixins/cpu_stats.py +0 -0
  53. {xax-0.4.1 → xax-0.4.3}/xax/task/mixins/data_loader.py +0 -0
  54. {xax-0.4.1 → xax-0.4.3}/xax/task/mixins/gpu_stats.py +0 -0
  55. {xax-0.4.1 → xax-0.4.3}/xax/task/mixins/process.py +0 -0
  56. {xax-0.4.1 → xax-0.4.3}/xax/task/mixins/runnable.py +0 -0
  57. {xax-0.4.1 → xax-0.4.3}/xax/task/mixins/step_wrapper.py +0 -0
  58. {xax-0.4.1 → xax-0.4.3}/xax/task/mixins/supervised.py +0 -0
  59. {xax-0.4.1 → xax-0.4.3}/xax/task/mixins/train.py +0 -0
  60. {xax-0.4.1 → xax-0.4.3}/xax/task/script.py +0 -0
  61. {xax-0.4.1 → xax-0.4.3}/xax/task/task.py +0 -0
  62. {xax-0.4.1 → xax-0.4.3}/xax/utils/__init__.py +0 -0
  63. {xax-0.4.1 → xax-0.4.3}/xax/utils/data/__init__.py +0 -0
  64. {xax-0.4.1 → xax-0.4.3}/xax/utils/data/collate.py +0 -0
  65. {xax-0.4.1 → xax-0.4.3}/xax/utils/debugging.py +0 -0
  66. {xax-0.4.1 → xax-0.4.3}/xax/utils/experiments.py +0 -0
  67. {xax-0.4.1 → xax-0.4.3}/xax/utils/jax.py +0 -0
  68. {xax-0.4.1 → xax-0.4.3}/xax/utils/jaxpr.py +0 -0
  69. {xax-0.4.1 → xax-0.4.3}/xax/utils/logging.py +0 -0
  70. {xax-0.4.1 → xax-0.4.3}/xax/utils/numpy.py +0 -0
  71. {xax-0.4.1 → xax-0.4.3}/xax/utils/profile.py +0 -0
  72. {xax-0.4.1 → xax-0.4.3}/xax/utils/pytree.py +0 -0
  73. {xax-0.4.1 → xax-0.4.3}/xax/utils/tensorboard.py +0 -0
  74. {xax-0.4.1 → xax-0.4.3}/xax/utils/text.py +0 -0
  75. {xax-0.4.1 → xax-0.4.3}/xax/utils/types/__init__.py +0 -0
  76. {xax-0.4.1 → xax-0.4.3}/xax/utils/types/hashable_array.py +0 -0
  77. {xax-0.4.1 → xax-0.4.3}/xax.egg-info/dependency_links.txt +0 -0
  78. {xax-0.4.1 → xax-0.4.3}/xax.egg-info/entry_points.txt +0 -0
  79. {xax-0.4.1 → xax-0.4.3}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.4.1
3
+ Version: 0.4.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -31,6 +31,8 @@ Requires-Dist: pytest; extra == "dev"
31
31
  Requires-Dist: types-pillow; extra == "dev"
32
32
  Requires-Dist: types-psutil; extra == "dev"
33
33
  Requires-Dist: types-requests; extra == "dev"
34
+ Provides-Extra: wandb
35
+ Requires-Dist: wandb[media]; extra == "wandb"
34
36
  Dynamic: author
35
37
  Dynamic: description
36
38
  Dynamic: description-content-type
@@ -35,6 +35,7 @@ explicit_package_bases = true
35
35
  [[tool.mypy.overrides]]
36
36
 
37
37
  module = [
38
+ "chex.*",
38
39
  "optax.*",
39
40
  "setuptools.*",
40
41
  "tensorboard.*",
@@ -33,6 +33,7 @@ setup(
33
33
  tests_require=requirements_dev,
34
34
  extras_require={
35
35
  "dev": requirements_dev,
36
+ "wandb": ["wandb[media]"],
36
37
  },
37
38
  package_data={
38
39
  "xax": [
@@ -12,7 +12,7 @@ and running the update script:
12
12
  python -m scripts.update_api --inplace
13
13
  """
14
14
 
15
- __version__ = "0.4.1"
15
+ __version__ = "0.4.3"
16
16
 
17
17
  # This list shouldn't be modified by hand; instead, run the update script.
18
18
  __all__ = [
@@ -171,6 +171,7 @@ __all__ = [
171
171
  "uncolored",
172
172
  "wrapped",
173
173
  "FrozenDict",
174
+ "freeze_dict",
174
175
  "HashableArray",
175
176
  "hashable_array",
176
177
  ]
@@ -371,6 +372,7 @@ NAME_MAP: dict[str, str] = {
371
372
  "uncolored": "utils.text",
372
373
  "wrapped": "utils.text",
373
374
  "FrozenDict": "utils.types.frozen_dict",
375
+ "freeze_dict": "utils.types.frozen_dict",
374
376
  "HashableArray": "utils.types.hashable_array",
375
377
  "hashable_array": "utils.types.hashable_array",
376
378
  }
@@ -572,7 +574,7 @@ if IMPORT_ALL or TYPE_CHECKING:
572
574
  uncolored,
573
575
  wrapped,
574
576
  )
575
- from xax.utils.types.frozen_dict import FrozenDict
577
+ from xax.utils.types.frozen_dict import FrozenDict, freeze_dict
576
578
  from xax.utils.types.hashable_array import HashableArray, hashable_array
577
579
 
578
580
  del TYPE_CHECKING, IMPORT_ALL
@@ -0,0 +1,141 @@
1
+ """Defines a launcher to train a model locally, in a single process."""
2
+
3
+ import logging
4
+ import os
5
+ import shutil
6
+ import subprocess
7
+ from typing import TYPE_CHECKING
8
+
9
+ import jax
10
+
11
+ from xax.task.base import RawConfigType
12
+ from xax.task.launchers.base import BaseLauncher
13
+ from xax.task.mixins.gpu_stats import get_num_gpus
14
+ from xax.utils.logging import configure_logging
15
+
16
+ if TYPE_CHECKING:
17
+ from xax.task.mixins.runnable import Config, RunnableMixin
18
+
19
+
20
def get_gpu_memory_info() -> dict[int, tuple[float, float]]:
    """Query ``nvidia-smi`` for the memory of every GPU it reports.

    Lines whose fields cannot be parsed (for example ``[N/A]`` values that
    some devices report) are skipped individually, so one odd GPU does not
    discard the information for all the others.

    Returns:
        Dictionary mapping GPU index to ``(total_memory_mb, used_memory_mb)``.
        Empty if ``nvidia-smi`` is missing, exits with an error, or times out.
    """
    # `nounits` makes nvidia-smi print bare numbers, so no " MiB" stripping is needed.
    command = [
        "nvidia-smi",
        "--query-gpu=index,memory.total,memory.used",
        "--format=csv,noheader,nounits",
    ]
    log = logging.getLogger(__name__)

    try:
        # The timeout keeps a hung driver from blocking training startup indefinitely.
        result = subprocess.run(command, capture_output=True, text=True, check=True, timeout=30)
    except (OSError, subprocess.SubprocessError) as e:
        log.warning("Failed to get GPU memory info: %s", e)
        return {}

    gpu_info: dict[int, tuple[float, float]] = {}
    for raw_line in result.stdout.splitlines():
        parts = [part.strip() for part in raw_line.split(",")]
        if len(parts) < 3:
            # Blank or truncated line.
            continue
        try:
            gpu_info[int(parts[0])] = (float(parts[1]), float(parts[2]))
        except ValueError:
            log.warning("Could not parse nvidia-smi output line: %r", raw_line)
    return gpu_info
52
+
53
+
54
def select_best_gpu() -> int | None:
    """Pick the GPU that currently has the most free memory.

    Free memory is ``total - used`` as reported by ``get_gpu_memory_info``.
    Only GPUs whose free memory exceeds -1.0 are considered.

    Returns:
        Index of the GPU with the most free memory (the first one wins on
        ties), or None when no candidate GPU is found.
    """
    free_by_gpu = {gpu_id: total - used for gpu_id, (total, used) in get_gpu_memory_info().items()}

    candidates = [gpu_id for gpu_id, free in free_by_gpu.items() if free > -1.0]
    if not candidates:
        return None

    # `max` keeps the earliest of several equally-good GPUs.
    return max(candidates, key=free_by_gpu.__getitem__)
76
+
77
+
78
def configure_gpu_devices(logger: logging.Logger | None = None) -> None:
    """Restrict this process to the single GPU with the most free memory.

    Acts only when ``get_num_gpus`` reports more than one GPU and the user has
    not already chosen devices through ``CUDA_VISIBLE_DEVICES``.

    Args:
        logger: Logger for progress messages; ``configure_logging()`` is used if omitted.
    """
    if logger is None:
        logger = configure_logging()

    num_gpus = get_num_gpus()

    if num_gpus == 1:
        logger.info("Single GPU detected, using default device selection")
        return
    if num_gpus < 1:
        return

    # An explicit device selection by the user always wins over auto-selection.
    if "CUDA_VISIBLE_DEVICES" in os.environ:
        logger.info(
            "CUDA_VISIBLE_DEVICES is already set to %r; skipping automatic GPU selection",
            os.environ["CUDA_VISIBLE_DEVICES"],
        )
        return

    logger.info("Multiple GPUs detected (%d), selecting GPU with most available memory", num_gpus)

    best_gpu = select_best_gpu()
    if best_gpu is None:
        logger.warning("Could not determine best GPU, using default device selection")
        return

    logger.info("Selected GPU %d for training", best_gpu)

    # nvidia-smi numbers GPUs in PCI bus order, while CUDA defaults to
    # "fastest first"; align them so the index below names the same card.
    os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
    os.environ["CUDA_VISIBLE_DEVICES"] = str(best_gpu)

    # Configure JAX to use the selected device (best effort).
    try:
        devices = jax.devices("gpu")
        if not devices:
            return
        target = devices[0]
        if len(devices) > 1:
            # CUDA_VISIBLE_DEVICES only takes effect if set before the JAX backend
            # starts; seeing several GPUs means it was initialised earlier. Fall back
            # to matching by id (assumes JAX ids follow nvidia-smi order — TODO confirm).
            logger.warning("JAX was initialized before GPU selection; CUDA_VISIBLE_DEVICES had no effect")
            target = next((device for device in devices if device.id == best_gpu), devices[0])
        jax.config.update("jax_default_device", target)
        logger.info("Configured JAX to use device: %s", target)
    except Exception as e:
        logger.warning("Failed to configure JAX device: %s", e)
109
+
110
+
111
def configure_devices(logger: logging.Logger | None = None) -> None:
    """Apply automatic device configuration for the current machine.

    GPU selection is attempted only when an ``nvidia-smi`` executable is found
    on ``PATH``.

    Args:
        logger: Logger for progress messages; created with ``configure_logging()`` if omitted.
    """
    active_logger = configure_logging() if logger is None else logger
    if shutil.which("nvidia-smi") is None:
        return
    configure_gpu_devices(active_logger)
117
+
118
+
119
def run_single_process_training(
    task: "type[RunnableMixin[Config]]",
    *cfgs: RawConfigType,
    use_cli: bool | list[str] = True,
    logger: logging.Logger | None = None,
) -> None:
    """Build a task from its configs and run it in the current process.

    Args:
        task: The task class to instantiate via ``task.get_task``.
        cfgs: Raw configs forwarded to ``task.get_task``.
        use_cli: CLI handling option forwarded to ``task.get_task``.
        logger: Logger whose handlers the task attaches to; ``configure_logging()`` is used if omitted.
    """
    active_logger = logger if logger is not None else configure_logging()
    task_instance = task.get_task(*cfgs, use_cli=use_cli)
    task_instance.add_logger_handlers(active_logger)
    task_instance.run()
130
+
131
+
132
class SingleProcessLauncher(BaseLauncher):
    """Launcher that configures devices and then runs a task in this process."""

    def launch(
        self,
        task: "type[RunnableMixin[Config]]",
        *cfgs: RawConfigType,
        use_cli: bool | list[str] = True,
    ) -> None:
        """Set up logging and devices, then train ``task`` in-process."""
        launch_logger = configure_logging()
        configure_devices(launch_logger)
        run_single_process_training(task, *cfgs, use_cli=use_cli, logger=launch_logger)
@@ -0,0 +1,307 @@
1
+ # mypy: disable-error-code="import-not-found"
2
+ """Defines a Weights & Biases logger backend."""
3
+
4
+ import logging
5
+ import os
6
+ from enum import Enum
7
+ from pathlib import Path
8
+ from typing import Any, TypeVar
9
+
10
+ import numpy as np
11
+
12
+ from xax.nn.parallel import is_master
13
+ from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
14
+ from xax.utils.jax import as_float
15
+
16
# Module-level logger used for W&B lifecycle messages in this file.
logger: logging.Logger = logging.getLogger(__name__)

# NOTE(review): `T` is not referenced anywhere in this module — confirm it can be removed.
T = TypeVar("T")
19
+
20
+
21
def sanitize_metric_name(name: str) -> str:
    """Strip characters outside the Basic Multilingual Plane from ``name``.

    W&B mishandles characters that need four bytes in UTF-8 (code points
    above U+FFFF, such as most emoji), so they are dropped.

    Args:
        name: Metric name to clean.

    Returns:
        ``name`` with every code point above U+FFFF removed.
    """
    bmp_limit = "\uffff"
    # Single-character strings compare by code point, so this keeps ord(ch) <= 0xFFFF.
    return "".join(filter(lambda ch: ch <= bmp_limit, name))
36
+
37
+
38
class WandbConfigResumeOption(str, Enum):
    """String options whose ``.value`` is passed to ``wandb.init(resume=...)``."""

    ALLOW = "allow"
    NEVER = "never"
    MUST = "must"
    AUTO = "auto"


class WandbConfigModeOption(str, Enum):
    """String options whose ``.value`` is passed to ``wandb.init(mode=...)``."""

    ONLINE = "online"
    OFFLINE = "offline"
    DISABLED = "disabled"
    SHARED = "shared"


class WandbConfigReinitOption(str, Enum):
    """String options whose ``.value`` is passed to ``wandb.init(reinit=...)``."""

    RETURN_PREVIOUS = "return_previous"
    FINISH_PREVIOUS = "finish_previous"


# `resume` is either one of the enum options above or a plain boolean.
WandbConfigResume = WandbConfigResumeOption | bool
# `None` is forwarded to wandb unchanged (NOTE(review): presumably defers to wandb's own default — confirm).
WandbConfigMode = WandbConfigModeOption | None
59
+
60
+
61
class WandbLogger(LoggerImpl):
    """Logger backend that writes log lines to a Weights & Biases run.

    Only the master process (as reported by ``is_master()``) initializes a
    run and sends data; on every other process the logging methods return
    immediately.
    """

    def __init__(
        self,
        project: str | None = None,
        entity: str | None = None,
        name: str | None = None,
        run_directory: str | Path | None = None,
        config: dict[str, Any] | None = None,
        tags: list[str] | None = None,
        notes: str | None = None,
        log_interval_seconds: float = 10.0,
        reinit: WandbConfigReinitOption = WandbConfigReinitOption.RETURN_PREVIOUS,
        resume: WandbConfigResume = False,
        mode: WandbConfigMode = None,
    ) -> None:
        """Defines a logger which writes to Weights & Biases.

        Args:
            project: The name of the W&B project to log to.
            entity: The W&B entity (team or user) to log to.
            name: The name of this run.
            run_directory: The root run directory. If provided, wandb will save
                files to a ``wandb`` subdirectory here.
            config: Configuration dictionary to log.
            tags: List of tags for this run.
            notes: Notes about this run.
            log_interval_seconds: The interval between successive log lines.
            reinit: Option passed to ``wandb.init(reinit=...)``, controlling what
                happens when a run is already active in this process.
            resume: Whether to resume a previous run; a bool or a
                ``WandbConfigResumeOption``.
            mode: Mode for wandb ("online", "offline", "disabled" or "shared");
                None is forwarded to wandb unchanged.

        Raises:
            RuntimeError: If the optional ``wandb`` package is not installed.
        """
        try:
            import wandb as _wandb  # noqa: F401,PLC0415
        except ImportError as e:
            raise RuntimeError(
                "WandbLogger requires the 'wandb' package. Install it with: pip install xax[wandb]"
            ) from e

        self._wandb = _wandb

        super().__init__(log_interval_seconds)

        self.project = project
        self.entity = entity
        self.name = name
        self.config = config
        self.tags = tags
        self.notes = notes
        self.reinit = reinit
        self.resume: WandbConfigResume = resume
        self.mode: WandbConfigMode = mode

        # Always define `wandb_dir`, because `start()` reads it even when no run
        # directory was given.
        self.wandb_dir: Path | None = None
        if run_directory is not None:
            self.wandb_dir = Path(run_directory).expanduser().resolve() / "wandb"
            self.wandb_dir.mkdir(parents=True, exist_ok=True)

        # Set by `start()`; stays None on non-master processes.
        self.run: Any = None
        self._started = False

        # Files queued by `log_file`, uploaded as artifacts on the next `write`.
        self.files: dict[str, str] = {}

        self.start()

    def start(self) -> None:
        """Initialize the W&B run on the master process; repeated calls are no-ops."""
        if self._started or not is_master():
            return

        # wandb reads WANDB_DIR to decide where to keep its local files.
        if self.wandb_dir is not None:
            os.environ["WANDB_DIR"] = str(self.wandb_dir)

        self.run = self._wandb.init(
            project=self.project,
            entity=self.entity,
            name=self.name,
            config=self.config,
            tags=self.tags,
            notes=self.notes,
            reinit=self.reinit.value,
            # Enum options are unwrapped to their string values; bools / None pass through.
            resume=self.resume.value if isinstance(self.resume, WandbConfigResumeOption) else self.resume,
            mode=self.mode.value if isinstance(self.mode, WandbConfigModeOption) else self.mode,
        )

        self._started = True
        logger.info("W&B run initialized: %s", self.run.url if self.run else "No URL available")

    def stop(self) -> None:
        """Finish the W&B run, if this process started one."""
        if not self._started or not is_master():
            return

        if self.run is not None:
            self.run.finish()
        self._started = False

    def log_file(self, name: str, contents: str) -> None:
        """Store a file to be logged with the next write call.

        Args:
            name: The name of the file.
            contents: The contents of the file.
        """
        if not is_master():
            return
        self.files[name] = contents

    @staticmethod
    def _expand_histogram(bucket_limits: Any, bucket_counts: Any) -> list[Any]:
        """Rebuild approximate samples from histogram buckets for ``wandb.Histogram``.

        Each non-empty bucket contributes ``count`` copies of a representative
        value: the midpoint of adjacent limits, or ``bucket_limits[0]`` for the
        first bucket.
        """
        # NOTE(review): this materialises one entry per sample; for very large
        # histograms consider wandb.Histogram(np_histogram=...) instead.
        values: list[Any] = []
        for i, count in enumerate(bucket_counts):
            if count > 0:
                if i == 0:
                    representative = bucket_limits[0]
                else:
                    representative = (bucket_limits[i - 1] + bucket_limits[i]) / 2
                values.extend([representative] * count)
        return values

    @staticmethod
    def _mesh_to_object3d_data(mesh_value: Any) -> dict[str, Any]:
        """Convert a logged mesh into the dict passed to ``wandb.Object3D``.

        Arrays with a leading batch dimension (``ndim == 3``) keep only the
        first batch element. ``uint8`` colors are rescaled to the 0-1 range.
        """

        def first_batch(array: Any) -> Any:
            return array[0] if array.ndim == 3 else array

        # NOTE(review): W&B's "lidar/beta" viewer documents a "points" key; confirm that
        # "vertices"/"faces"/"colors" are actually rendered.
        obj3d_data: dict[str, Any] = {
            "type": "lidar/beta",
            "vertices": first_batch(mesh_value.vertices).tolist(),
        }

        if mesh_value.faces is not None:
            obj3d_data["faces"] = first_batch(mesh_value.faces).tolist()

        if mesh_value.colors is not None:
            colors = first_batch(mesh_value.colors)
            if colors.dtype == np.uint8:
                colors = colors.astype(np.float32) / 255.0
            obj3d_data["colors"] = colors.tolist()

        return obj3d_data

    def write(self, line: LogLine) -> None:
        """Writes the current log line to W&B.

        All metric kinds are collected into one dict and sent with a single
        ``run.log`` call at the line's step. Queued files are uploaded as
        artifacts and the queue is cleared.

        Args:
            line: The line to write.
        """
        if not is_master() or not self._started or self.run is None:
            return

        global_step = line.state.num_steps.item()
        metrics: dict[str, Any] = {}

        for namespace, scalars in line.scalars.items():
            for scalar_key, scalar_value in scalars.items():
                metrics[sanitize_metric_name(f"{namespace}/{scalar_key}")] = as_float(scalar_value.value)

        # Distributions are logged as two scalar series (mean and std).
        for namespace, distributions in line.distributions.items():
            for distribution_key, distribution_value in distributions.items():
                base_key = sanitize_metric_name(f"{namespace}/{distribution_key}")
                metrics[f"{base_key}/mean"] = float(distribution_value.mean)
                metrics[f"{base_key}/std"] = float(distribution_value.std)

        for namespace, histograms in line.histograms.items():
            for histogram_key, histogram_value in histograms.items():
                values = self._expand_histogram(histogram_value.bucket_limits, histogram_value.bucket_counts)
                if values:
                    key = sanitize_metric_name(f"{namespace}/{histogram_key}")
                    metrics[key] = self._wandb.Histogram(values)

        for namespace, strings in line.strings.items():
            for string_key, string_value in strings.items():
                key = sanitize_metric_name(f"{namespace}/{string_key}")
                metrics[key] = self._wandb.Html(f"<pre>{string_value.value}</pre>")

        for namespace, images in line.images.items():
            for image_key, image_value in images.items():
                key = sanitize_metric_name(f"{namespace}/{image_key}")
                metrics[key] = self._wandb.Image(image_value.image)

        for namespace, videos in line.videos.items():
            for video_key, video_value in videos.items():
                key = sanitize_metric_name(f"{namespace}/{video_key}")
                # wandb.Video expects (T, C, H, W); frames are stored as (T, H, W, C).
                frames = video_value.frames.transpose(0, 3, 1, 2)
                metrics[key] = self._wandb.Video(frames, fps=video_value.fps, format="mp4")

        for namespace, meshes in line.meshes.items():
            for mesh_key, mesh_value in meshes.items():
                key = sanitize_metric_name(f"{namespace}/{mesh_key}")
                metrics[key] = self._wandb.Object3D(self._mesh_to_object3d_data(mesh_value))

        # Upload queued files; any file whose name contains "code" goes to a shared
        # "training_code" artifact of type "code".
        for file_name, contents in self.files.items():
            is_training_code = "code" in file_name
            artifact_name = "training_code" if is_training_code else f"{self.run.name}_{sanitize_metric_name(file_name)}"
            artifact = self._wandb.Artifact(
                name=artifact_name,
                type="code" if is_training_code else "unspecified",
            )
            with artifact.new_file(file_name) as f:
                f.write(contents)
            artifact.save()
        self.files.clear()

        if metrics:
            self.run.log(metrics, step=global_step)

    def write_error_summary(self, error_summary: LogErrorSummary) -> None:
        """Error summaries are not sent to W&B."""

    def write_error(self, error: LogError) -> None:
        """Errors are not sent to W&B."""

    def write_status(self, status: LogStatus) -> None:
        """Status lines are not sent to W&B."""

    def write_ping(self, ping: LogPing) -> None:
        """Pings are not sent to W&B."""
@@ -0,0 +1,203 @@
1
+ """Defines a mixin for incorporating some logging functionality."""
2
+
3
+ import os
4
+ from dataclasses import dataclass
5
+ from enum import Enum
6
+ from pathlib import Path
7
+ from types import TracebackType
8
+ from typing import Any, Generic, Self, TypeVar
9
+
10
+ import jax
11
+
12
+ from xax.core.conf import field
13
+ from xax.core.state import State
14
+ from xax.task.base import BaseConfig, BaseTask
15
+ from xax.task.logger import Logger, LoggerImpl
16
+ from xax.task.loggers.json import JsonLogger
17
+ from xax.task.loggers.state import StateLogger
18
+ from xax.task.loggers.stdout import StdoutLogger
19
+ from xax.task.loggers.tensorboard import TensorboardLogger
20
+ from xax.task.loggers.wandb import (
21
+ WandbConfigMode,
22
+ WandbConfigModeOption,
23
+ WandbConfigReinitOption,
24
+ WandbConfigResume,
25
+ WandbLogger,
26
+ )
27
+ from xax.task.mixins.artifacts import ArtifactsMixin
28
+ from xax.utils.text import is_interactive_session
29
+
30
+
31
class LoggerBackend(str, Enum):
    """Selects which experiment-tracking backend ``LoggerMixin`` attaches."""

    TENSORBOARD = "tensorboard"
    WANDB = "wandb"


@jax.tree_util.register_dataclass
@dataclass
class LoggerConfig(BaseConfig):
    """Configuration for console/JSON logging and the experiment-tracking backend."""

    log_interval_seconds: float = field(
        value=1.0,
        help="The interval between successive log lines.",
    )
    logger_backend: LoggerBackend = field(
        value=LoggerBackend.TENSORBOARD,
        help="The logger backend to use (tensorboard or wandb).",
    )
    tensorboard_log_interval_seconds: float = field(
        value=10.0,
        help="The interval between successive Tensorboard log lines.",
    )
    wandb_project: str | None = field(
        value=None,
        help="The name of the W&B project to log to.",
    )
    wandb_entity: str | None = field(
        value=None,
        help="The W&B entity (team or user) to log to.",
    )
    wandb_name: str | None = field(
        value=None,
        help="The name of this run in W&B.",
    )
    wandb_tags: list[str] | None = field(
        value=None,
        help="List of tags for this W&B run.",
    )
    wandb_notes: str | None = field(
        value=None,
        help="Notes about this W&B run.",
    )
    wandb_log_interval_seconds: float = field(
        value=10.0,
        help="The interval between successive W&B log lines.",
    )
    wandb_mode: WandbConfigMode = field(
        value=WandbConfigModeOption.ONLINE,
        help="Mode for wandb (online, offline, disabled, or shared).",
    )
    # The type only admits a bool or a resume option, so the help text must not promise run-ID strings.
    wandb_resume: WandbConfigResume = field(
        value=False,
        help="Whether to resume a previous run: true/false, or one of allow, never, must, auto.",
    )
    wandb_reinit: WandbConfigReinitOption = field(
        value=WandbConfigReinitOption.RETURN_PREVIOUS,
        help="What wandb.init() does if a run is already active in this process (return_previous or finish_previous).",
    )


Config = TypeVar("Config", bound=LoggerConfig)
90
+
91
+
92
def get_env_var(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is not set.

    Returns:
        ``default`` if unset; otherwise True only when the value, with
        surrounding whitespace removed, is exactly ``"1"``.
    """
    raw_value = os.environ.get(name)
    return default if raw_value is None else raw_value.strip() == "1"
96
+
97
+
98
+ class LoggerMixin(BaseTask[Config], Generic[Config]):
99
+ logger: Logger
100
+
101
+ def __init__(self, config: Config) -> None:
102
+ super().__init__(config)
103
+
104
+ self.logger = Logger()
105
+
106
+ def log_directory(self) -> Path | None:
107
+ return None
108
+
109
+ def add_logger(self, *logger: LoggerImpl) -> None:
110
+ self.logger.add_logger(*logger)
111
+
112
+ def set_loggers(self) -> None:
113
+ self.add_logger(
114
+ StdoutLogger(
115
+ log_interval_seconds=self.config.log_interval_seconds,
116
+ )
117
+ if is_interactive_session()
118
+ else JsonLogger(
119
+ log_interval_seconds=self.config.log_interval_seconds,
120
+ )
121
+ )
122
+
123
+ # If this is also an ArtifactsMixin, we should default add some
124
+ # additional loggers which log data to the artifacts directory.
125
+ if isinstance(self, ArtifactsMixin):
126
+ self.add_logger(
127
+ StateLogger(
128
+ run_directory=self.exp_dir,
129
+ ),
130
+ self._create_logger_backend(),
131
+ )
132
+
133
+ def _create_logger_backend(self) -> LoggerImpl:
134
+ match self.config.logger_backend:
135
+ case LoggerBackend.TENSORBOARD:
136
+ return TensorboardLogger(
137
+ run_directory=self.exp_dir if isinstance(self, ArtifactsMixin) else "./",
138
+ log_interval_seconds=self.config.tensorboard_log_interval_seconds,
139
+ )
140
+ case LoggerBackend.WANDB:
141
+ run_config = {}
142
+ if hasattr(self.config, "__dict__"):
143
+ # Convert config to a serializable dictionary
144
+ run_config = self._config_to_dict(self.config)
145
+
146
+ return WandbLogger(
147
+ project=self.config.wandb_project,
148
+ entity=self.config.wandb_entity,
149
+ name=self.config.wandb_name,
150
+ run_directory=self.exp_dir if isinstance(self, ArtifactsMixin) else None,
151
+ config=run_config,
152
+ tags=self.config.wandb_tags,
153
+ notes=self.config.wandb_notes,
154
+ log_interval_seconds=self.config.wandb_log_interval_seconds,
155
+ reinit=self.config.wandb_reinit,
156
+ resume=self.config.wandb_resume,
157
+ mode=self.config.wandb_mode,
158
+ )
159
+ case _:
160
+ # This shouldn't happen, as validation should take care of this
161
+ raise Exception(f"Invalid logger_backend '{self.config.logger_backend}'")
162
+
163
+ def _config_to_dict(self, config: Config) -> dict[str, Any]:
164
+ """Convert a config object to a dictionary for W&B logging.
165
+
166
+ Args:
167
+ config: The configuration object to convert.
168
+
169
+ Returns:
170
+ A dictionary representation of the config.
171
+ """
172
+ if hasattr(config, "__dict__"):
173
+ result: dict[str, Any] = {}
174
+ for key, value in config.__dict__.items():
175
+ if not key.startswith("_"):
176
+ # Recursively convert nested configs
177
+ if hasattr(value, "__dict__"):
178
+ result[key] = self._config_to_dict(value)
179
+ elif isinstance(value, (list, tuple)):
180
+ # Handle lists/tuples that might contain configs
181
+ result[key] = [
182
+ self._config_to_dict(item) if hasattr(item, "__dict__") else item for item in value
183
+ ]
184
+ elif isinstance(value, dict):
185
+ # Handle dicts that might contain configs
186
+ result[key] = {
187
+ k: self._config_to_dict(v) if hasattr(v, "__dict__") else v for k, v in value.items()
188
+ }
189
+ else:
190
+ result[key] = value
191
+ return result
192
+ return {}
193
+
194
+ def write_logs(self, state: State) -> None:
195
+ self.logger.write(state)
196
+
197
+ def __enter__(self) -> Self:
198
+ self.logger.__enter__()
199
+ return self
200
+
201
+ def __exit__(self, t: type[BaseException] | None, e: BaseException | None, tr: TracebackType | None) -> None:
202
+ self.logger.__exit__(t, e, tr)
203
+ return super().__exit__(t, e, tr)
@@ -146,3 +146,7 @@ def unfreeze(x: FrozenDict[K, V] | dict[str, Any]) -> dict[Any, Any]: # noqa: A
146
146
  return ys
147
147
  else:
148
148
  return x
149
+
150
+
151
def freeze_dict(x: Mapping[K, V]) -> FrozenDict[K, V]:
    """Build a ``FrozenDict`` from the mapping ``x``.

    Args:
        x: Mapping passed to the ``FrozenDict`` constructor.

    Returns:
        The resulting ``FrozenDict``.
    """
    frozen: FrozenDict[K, V] = FrozenDict(x)
    return frozen
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: xax
3
- Version: 0.4.1
3
+ Version: 0.4.3
4
4
  Summary: A library for fast Jax experimentation
5
5
  Home-page: https://github.com/kscalelabs/xax
6
6
  Author: Benjamin Bolte
@@ -31,6 +31,8 @@ Requires-Dist: pytest; extra == "dev"
31
31
  Requires-Dist: types-pillow; extra == "dev"
32
32
  Requires-Dist: types-psutil; extra == "dev"
33
33
  Requires-Dist: types-requests; extra == "dev"
34
+ Provides-Extra: wandb
35
+ Requires-Dist: wandb[media]; extra == "wandb"
34
36
  Dynamic: author
35
37
  Dynamic: description
36
38
  Dynamic: description-content-type
@@ -44,6 +44,7 @@ xax/task/loggers/json.py
44
44
  xax/task/loggers/state.py
45
45
  xax/task/loggers/stdout.py
46
46
  xax/task/loggers/tensorboard.py
47
+ xax/task/loggers/wandb.py
47
48
  xax/task/mixins/__init__.py
48
49
  xax/task/mixins/artifacts.py
49
50
  xax/task/mixins/checkpointing.py
@@ -23,3 +23,6 @@ pytest
23
23
  types-pillow
24
24
  types-psutil
25
25
  types-requests
26
+
27
+ [wandb]
28
+ wandb[media]
@@ -1,31 +0,0 @@
1
- """Defines a launcher to train a model locally, in a single process."""
2
-
3
- from typing import TYPE_CHECKING
4
-
5
- from xax.task.base import RawConfigType
6
- from xax.task.launchers.base import BaseLauncher
7
- from xax.utils.logging import configure_logging
8
-
9
- if TYPE_CHECKING:
10
- from xax.task.mixins.runnable import Config, RunnableMixin
11
-
12
-
13
- def run_single_process_training(
14
- task: "type[RunnableMixin[Config]]",
15
- *cfgs: RawConfigType,
16
- use_cli: bool | list[str] = True,
17
- ) -> None:
18
- logger = configure_logging()
19
- task_obj = task.get_task(*cfgs, use_cli=use_cli)
20
- task_obj.add_logger_handlers(logger)
21
- task_obj.run()
22
-
23
-
24
- class SingleProcessLauncher(BaseLauncher):
25
- def launch(
26
- self,
27
- task: "type[RunnableMixin[Config]]",
28
- *cfgs: RawConfigType,
29
- use_cli: bool | list[str] = True,
30
- ) -> None:
31
- run_single_process_training(task, *cfgs, use_cli=use_cli)
@@ -1,92 +0,0 @@
1
- """Defines a mixin for incorporating some logging functionality."""
2
-
3
- import os
4
- from dataclasses import dataclass
5
- from pathlib import Path
6
- from types import TracebackType
7
- from typing import Generic, Self, TypeVar
8
-
9
- import jax
10
-
11
- from xax.core.conf import field
12
- from xax.core.state import State
13
- from xax.task.base import BaseConfig, BaseTask
14
- from xax.task.logger import Logger, LoggerImpl
15
- from xax.task.loggers.json import JsonLogger
16
- from xax.task.loggers.state import StateLogger
17
- from xax.task.loggers.stdout import StdoutLogger
18
- from xax.task.loggers.tensorboard import TensorboardLogger
19
- from xax.task.mixins.artifacts import ArtifactsMixin
20
- from xax.utils.text import is_interactive_session
21
-
22
-
23
- @jax.tree_util.register_dataclass
24
- @dataclass
25
- class LoggerConfig(BaseConfig):
26
- log_interval_seconds: float = field(
27
- value=1.0,
28
- help="The interval between successive log lines.",
29
- )
30
- tensorboard_log_interval_seconds: float = field(
31
- value=10.0,
32
- help="The interval between successive Tensorboard log lines.",
33
- )
34
-
35
-
36
- Config = TypeVar("Config", bound=LoggerConfig)
37
-
38
-
39
- def get_env_var(name: str, default: bool) -> bool:
40
- if name not in os.environ:
41
- return default
42
- return os.environ[name].strip() == "1"
43
-
44
-
45
- class LoggerMixin(BaseTask[Config], Generic[Config]):
46
- logger: Logger
47
-
48
- def __init__(self, config: Config) -> None:
49
- super().__init__(config)
50
-
51
- self.logger = Logger()
52
-
53
- def log_directory(self) -> Path | None:
54
- return None
55
-
56
- def add_logger(self, *logger: LoggerImpl) -> None:
57
- self.logger.add_logger(*logger)
58
-
59
- def set_loggers(self) -> None:
60
- self.add_logger(
61
- StdoutLogger(
62
- log_interval_seconds=self.config.log_interval_seconds,
63
- )
64
- if is_interactive_session()
65
- else JsonLogger(
66
- log_interval_seconds=self.config.log_interval_seconds,
67
- )
68
- )
69
-
70
- # If this is also an ArtifactsMixin, we should default add some
71
- # additional loggers which log data to the artifacts directory.
72
- if isinstance(self, ArtifactsMixin):
73
- self.add_logger(
74
- StateLogger(
75
- run_directory=self.exp_dir,
76
- ),
77
- TensorboardLogger(
78
- run_directory=self.exp_dir,
79
- log_interval_seconds=self.config.tensorboard_log_interval_seconds,
80
- ),
81
- )
82
-
83
- def write_logs(self, state: State) -> None:
84
- self.logger.write(state)
85
-
86
- def __enter__(self) -> Self:
87
- self.logger.__enter__()
88
- return self
89
-
90
- def __exit__(self, t: type[BaseException] | None, e: BaseException | None, tr: TracebackType | None) -> None:
91
- self.logger.__exit__(t, e, tr)
92
- return super().__exit__(t, e, tr)
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes