xax 0.3.12__tar.gz → 0.4.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {xax-0.3.12/xax.egg-info → xax-0.4.4}/PKG-INFO +3 -1
- {xax-0.3.12 → xax-0.4.4}/pyproject.toml +1 -0
- {xax-0.3.12 → xax-0.4.4}/setup.py +1 -0
- {xax-0.3.12 → xax-0.4.4}/xax/__init__.py +28 -10
- {xax-0.3.12 → xax-0.4.4}/xax/nn/geom.py +42 -13
- xax-0.4.4/xax/task/launchers/single_process.py +141 -0
- xax-0.4.4/xax/task/loggers/wandb.py +307 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/__init__.py +2 -1
- {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/artifacts.py +1 -1
- xax-0.4.4/xax/task/mixins/logger.py +169 -0
- xax-0.4.4/xax/task/mixins/supervised.py +368 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/train.py +37 -346
- {xax-0.3.12 → xax-0.4.4}/xax/task/task.py +26 -2
- {xax-0.3.12 → xax-0.4.4}/xax/utils/debugging.py +20 -4
- {xax-0.3.12 → xax-0.4.4}/xax/utils/experiments.py +2 -2
- {xax-0.3.12 → xax-0.4.4}/xax/utils/pytree.py +3 -5
- {xax-0.3.12 → xax-0.4.4}/xax/utils/types/frozen_dict.py +4 -0
- {xax-0.3.12 → xax-0.4.4/xax.egg-info}/PKG-INFO +3 -1
- {xax-0.3.12 → xax-0.4.4}/xax.egg-info/SOURCES.txt +2 -0
- {xax-0.3.12 → xax-0.4.4}/xax.egg-info/requires.txt +3 -0
- xax-0.3.12/xax/task/launchers/single_process.py +0 -31
- xax-0.3.12/xax/task/mixins/logger.py +0 -92
- {xax-0.3.12 → xax-0.4.4}/LICENSE +0 -0
- {xax-0.3.12 → xax-0.4.4}/MANIFEST.in +0 -0
- {xax-0.3.12 → xax-0.4.4}/README.md +0 -0
- {xax-0.3.12 → xax-0.4.4}/setup.cfg +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/cli/__init__.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/cli/edit_config.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/core/__init__.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/core/conf.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/core/state.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/nn/__init__.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/nn/attention.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/nn/distributions.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/nn/embeddings.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/nn/functions.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/nn/losses.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/nn/metrics.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/nn/parallel.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/nn/ssm.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/py.typed +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/requirements-dev.txt +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/requirements.txt +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/__init__.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/base.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/launchers/__init__.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/launchers/base.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/launchers/cli.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/logger.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/loggers/__init__.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/loggers/callback.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/loggers/json.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/loggers/state.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/loggers/stdout.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/loggers/tensorboard.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/checkpointing.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/compile.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/cpu_stats.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/data_loader.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/gpu_stats.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/process.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/runnable.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/step_wrapper.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/task/script.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/utils/__init__.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/utils/data/__init__.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/utils/data/collate.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/utils/jax.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/utils/jaxpr.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/utils/logging.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/utils/numpy.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/utils/profile.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/utils/tensorboard.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/utils/text.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/utils/types/__init__.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax/utils/types/hashable_array.py +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax.egg-info/dependency_links.txt +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax.egg-info/entry_points.txt +0 -0
- {xax-0.3.12 → xax-0.4.4}/xax.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.3.12
+Version: 0.4.4
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
@@ -31,6 +31,8 @@ Requires-Dist: pytest; extra == "dev"
 Requires-Dist: types-pillow; extra == "dev"
 Requires-Dist: types-psutil; extra == "dev"
 Requires-Dist: types-requests; extra == "dev"
+Provides-Extra: wandb
+Requires-Dist: wandb[media]; extra == "wandb"
 Dynamic: author
 Dynamic: description
 Dynamic: description-content-type
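
Not part of the diff: going by the metadata above, the optional Weights & Biases support should be installable through the new extra, e.g. pip install "xax[wandb]", which would pull in wandb[media] as declared.
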
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.3.12"
+__version__ = "0.4.4"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -53,6 +53,7 @@ __all__ = [
     "quat_mul",
     "quat_to_euler",
     "quat_to_rotmat",
+    "quat_to_yaw",
     "rotate_vector_by_quat",
     "rotation6d_to_rotation_matrix",
     "rotation_matrix_to_quat",
@@ -93,16 +94,19 @@ __all__ = [
     "DataloaderConfig",
     "GPUStatsOptions",
     "StepContext",
+    "InitParams",
     "ValidStepTimer",
     "Script",
     "ScriptConfig",
     "Config",
+    "SupervisedConfig",
+    "SupervisedTask",
     "Task",
     "collate",
     "collate_non_null",
-    "
+    "breakpoint_if_nonfinite",
     "get_named_leaves",
-    "
+    "log_if_nonfinite",
     "BaseFileDownloader",
     "ContextTimer",
     "CumulativeTimer",
@@ -167,6 +171,7 @@ __all__ = [
     "uncolored",
     "wrapped",
     "FrozenDict",
+    "freeze_dict",
     "HashableArray",
     "hashable_array",
 ]
@@ -198,7 +203,10 @@ if "XLA_FLAGS" in os.environ:
     # If Nvidia GPU is detected (meaning, is `nvidia-smi` available?), disable
     # Triton GEMM kernels. See https://github.com/NVIDIA/JAX-Toolbox
     if shutil.which("nvidia-smi") is not None:
-        xla_flags += [
+        xla_flags += [
+            "--xla_gpu_enable_latency_hiding_scheduler=true",
+            "--xla_gpu_enable_triton_gemm=false",
+        ]
     os.environ["XLA_FLAGS"] = " ".join(xla_flags)
 
 # If this flag is set, eagerly imports the entire package (not recommended).
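
Not part of the diff: a minimal sketch of how the new flags could be observed, assuming (as the hunk context above suggests) that the block only runs when XLA_FLAGS is already set in the environment and that nvidia-smi is on the PATH.

    import os

    # Seed XLA_FLAGS so the package's import-time block runs (assumption based
    # on the `if "XLA_FLAGS" in os.environ:` context shown above).
    os.environ.setdefault("XLA_FLAGS", "")

    import xax  # noqa: E402,F401  # imported after the environment is prepared

    # On an NVIDIA machine, the variable should now include the two new flags.
    print(os.environ.get("XLA_FLAGS"))
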
@@ -246,6 +254,7 @@ NAME_MAP: dict[str, str] = {
     "quat_mul": "nn.geom",
     "quat_to_euler": "nn.geom",
     "quat_to_rotmat": "nn.geom",
+    "quat_to_yaw": "nn.geom",
     "rotate_vector_by_quat": "nn.geom",
     "rotation6d_to_rotation_matrix": "nn.geom",
     "rotation_matrix_to_quat": "nn.geom",
@@ -286,16 +295,19 @@ NAME_MAP: dict[str, str] = {
     "DataloaderConfig": "task.mixins.data_loader",
     "GPUStatsOptions": "task.mixins.gpu_stats",
     "StepContext": "task.mixins.step_wrapper",
+    "InitParams": "task.mixins.train",
     "ValidStepTimer": "task.mixins.train",
     "Script": "task.script",
     "ScriptConfig": "task.script",
     "Config": "task.task",
+    "SupervisedConfig": "task.task",
+    "SupervisedTask": "task.task",
     "Task": "task.task",
     "collate": "utils.data.collate",
     "collate_non_null": "utils.data.collate",
-    "
+    "breakpoint_if_nonfinite": "utils.debugging",
     "get_named_leaves": "utils.debugging",
-    "
+    "log_if_nonfinite": "utils.debugging",
     "BaseFileDownloader": "utils.experiments",
     "ContextTimer": "utils.experiments",
     "CumulativeTimer": "utils.experiments",
@@ -360,6 +372,7 @@ NAME_MAP: dict[str, str] = {
     "uncolored": "utils.text",
     "wrapped": "utils.text",
     "FrozenDict": "utils.types.frozen_dict",
+    "freeze_dict": "utils.types.frozen_dict",
     "HashableArray": "utils.types.hashable_array",
     "hashable_array": "utils.types.hashable_array",
 }
@@ -443,6 +456,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         quat_mul,
         quat_to_euler,
         quat_to_rotmat,
+        quat_to_yaw,
         rotate_vector_by_quat,
         rotation6d_to_rotation_matrix,
         rotation_matrix_to_quat,
@@ -482,11 +496,15 @@ if IMPORT_ALL or TYPE_CHECKING:
     from xax.task.mixins.data_loader import DataloaderConfig
     from xax.task.mixins.gpu_stats import GPUStatsOptions
     from xax.task.mixins.step_wrapper import StepContext
-    from xax.task.mixins.train import Batch, Output, ValidStepTimer
+    from xax.task.mixins.train import Batch, InitParams, Output, ValidStepTimer
     from xax.task.script import Script, ScriptConfig
-    from xax.task.task import Config, Task
+    from xax.task.task import Config, SupervisedConfig, SupervisedTask, Task
     from xax.utils.data.collate import CollateMode, collate, collate_non_null
-    from xax.utils.debugging import
+    from xax.utils.debugging import (
+        breakpoint_if_nonfinite,
+        get_named_leaves,
+        log_if_nonfinite,
+    )
     from xax.utils.experiments import (
         BaseFileDownloader,
         ContextTimer,
@@ -556,7 +574,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         uncolored,
         wrapped,
     )
-    from xax.utils.types.frozen_dict import FrozenDict
+    from xax.utils.types.frozen_dict import FrozenDict, freeze_dict
     from xax.utils.types.hashable_array import HashableArray, hashable_array
 
 del TYPE_CHECKING, IMPORT_ALL
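
Not part of the diff: a quick sketch checking that the names added to __all__ and NAME_MAP above resolve from the top-level namespace, assuming the package's usual lazy-import machinery maps them through NAME_MAP.

    import xax

    # Each newly exported name should be reachable as a top-level attribute.
    new_names = (
        "quat_to_yaw",
        "InitParams",
        "SupervisedConfig",
        "SupervisedTask",
        "freeze_dict",
        "breakpoint_if_nonfinite",
        "log_if_nonfinite",
    )
    for name in new_names:
        assert hasattr(xax, name), name
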
@@ -1,6 +1,7 @@
 """Defines geometry functions."""
 
 import chex
+import jax
 from jax import numpy as jnp
 from jaxtyping import Array
 
@@ -15,30 +16,53 @@ def quat_to_euler(quat_4: Array, eps: float = 1e-6) -> Array:
     Returns:
         The roll, pitch, yaw angles with shape (*, 3).
     """
-
-
+    # Normalize with clamping
+    norm_sq = jnp.sum(quat_4**2, axis=-1, keepdims=True)
+    inv_norm = jax.lax.rsqrt(jnp.maximum(norm_sq, eps))
+    quat_4 = quat_4 * inv_norm
+
+    w, x, y, z = jnp.unstack(quat_4, axis=-1)
 
     # Roll (x-axis rotation)
     sinr_cosp = 2.0 * (w * x + y * z)
     cosr_cosp = 1.0 - 2.0 * (x * x + y * y)
-    roll =
+    roll = jax.lax.atan2(sinr_cosp, cosr_cosp)
 
     # Pitch (y-axis rotation)
     sinp = 2.0 * (w * y - z * x)
-
-
-    pitch = jnp.where(
-        jnp.abs(sinp) >= 1.0,
-        jnp.sign(sinp) * jnp.pi / 2.0,  # Use 90 degrees if out of range
-        jnp.arcsin(sinp),
-    )
+    sinp = jnp.clip(sinp, -1.0, 1.0)  # Clamp to valid domain
+    pitch = jax.lax.asin(sinp)
 
     # Yaw (z-axis rotation)
     siny_cosp = 2.0 * (w * z + x * y)
     cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
-    yaw =
+    yaw = jax.lax.atan2(siny_cosp, cosy_cosp)
+
+    return jnp.stack([roll, pitch, yaw], axis=-1)
+
+
+def quat_to_yaw(quat_4: Array, eps: float = 1e-6) -> Array:
+    """Converts a quaternion to a yaw angle.
+
+    Args:
+        quat_4: The quaternion to convert, shape (*, 4).
+        eps: A small epsilon value to avoid division by zero.
+
+    Returns:
+        The yaw angle, shape (*).
+    """
+    # Normalize using a max + safe norm to handle extremely small values robustly
+    norm_sq = jnp.sum(quat_4**2, axis=-1, keepdims=True)
+    inv_norm = jax.lax.rsqrt(jnp.maximum(norm_sq, eps))
+    quat_4 = quat_4 * inv_norm
+
+    w, x, y, z = jnp.unstack(quat_4, axis=-1)
+
+    # Compute components with clamping to avoid rounding errors near limits
+    siny_cosp = 2.0 * (w * z + x * y)
+    cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
 
-    return
+    return jax.lax.atan2(siny_cosp, cosy_cosp)
 
 
 def euler_to_quat(euler_3: Array) -> Array:
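
Not part of the diff: a minimal usage sketch for the new quat_to_yaw, using the (w, x, y, z) ordering implied by the code above; its result should agree with the yaw component returned by quat_to_euler.

    import jax.numpy as jnp
    import xax

    # Quaternion rotating 90 degrees about the z-axis, in (w, x, y, z) order.
    quat = jnp.array([jnp.cos(jnp.pi / 4), 0.0, 0.0, jnp.sin(jnp.pi / 4)])

    roll, pitch, yaw = xax.quat_to_euler(quat)
    assert jnp.allclose(yaw, xax.quat_to_yaw(quat), atol=1e-5)
    print(float(yaw))  # ~1.5708 (pi / 2)
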
@@ -89,7 +113,12 @@ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Ar
     return rotate_vector_by_quat(jnp.array([0, 0, -9.81]), quat, inverse=True, eps=eps)
 
 
-def rotate_vector_by_quat(
+def rotate_vector_by_quat(
+    vector: Array,
+    quat: Array,
+    inverse: bool = False,
+    eps: float = 1e-6,
+) -> Array:
     """Rotates a vector by a quaternion.
 
     Args:
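
Not part of the diff: a small sketch against the reflowed signature above; rotating a vector by the identity quaternion should return it unchanged.

    import jax.numpy as jnp
    import xax

    gravity = jnp.array([0.0, 0.0, -9.81])
    identity_quat = jnp.array([1.0, 0.0, 0.0, 0.0])  # (w, x, y, z)

    rotated = xax.rotate_vector_by_quat(gravity, identity_quat, inverse=True)
    assert jnp.allclose(rotated, gravity, atol=1e-5)
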
@@ -0,0 +1,141 @@
+"""Defines a launcher to train a model locally, in a single process."""
+
+import logging
+import os
+import shutil
+import subprocess
+from typing import TYPE_CHECKING
+
+import jax
+
+from xax.task.base import RawConfigType
+from xax.task.launchers.base import BaseLauncher
+from xax.task.mixins.gpu_stats import get_num_gpus
+from xax.utils.logging import configure_logging
+
+if TYPE_CHECKING:
+    from xax.task.mixins.runnable import Config, RunnableMixin
+
+
+def get_gpu_memory_info() -> dict[int, tuple[float, float]]:
+    """Get memory information for all GPUs.
+
+    Returns:
+        Dictionary mapping GPU index to (total_memory_mb, used_memory_mb)
+    """
+    command = "nvidia-smi --query-gpu=index,memory.total,memory.used --format=csv,noheader"
+
+    try:
+        with subprocess.Popen(command.split(), stdout=subprocess.PIPE, universal_newlines=True) as proc:
+            stdout = proc.stdout
+            assert stdout is not None
+
+            gpu_info = {}
+            for line in stdout:
+                line = line.strip()
+                if not line:
+                    continue
+
+                parts = line.split(", ")
+                if len(parts) >= 3:
+                    gpu_id = int(parts[0])
+                    total_mem = float(parts[1].replace(" MiB", ""))
+                    used_mem = float(parts[2].replace(" MiB", ""))
+                    gpu_info[gpu_id] = (total_mem, used_mem)
+
+            return gpu_info
+
+    except Exception as e:
+        logger = configure_logging()
+        logger.warning("Failed to get GPU memory info: %s", e)
+        return {}
+
+
+def select_best_gpu() -> int | None:
+    """Select the GPU with the most available memory.
+
+    Returns:
+        GPU index with most available memory, or None if no GPUs found
+    """
+    gpu_info = get_gpu_memory_info()
+
+    if not gpu_info:
+        return None
+
+    # Find GPU with most available memory
+    best_gpu = None
+    max_available: float = -1.0
+
+    for gpu_id, (total_mem, used_mem) in gpu_info.items():
+        available_mem = total_mem - used_mem
+        if available_mem > max_available:
+            max_available = available_mem
+            best_gpu = gpu_id
+
+    return best_gpu
+
+
+def configure_gpu_devices(logger: logging.Logger | None = None) -> None:
+    if logger is None:
+        logger = configure_logging()
+
+    # If there are multiple devices, choose the one with the most
+    # available memory (i.e., the one which is likely not being used
+    # by other processes) and use only that device.
+    num_gpus = get_num_gpus()
+
+    if num_gpus > 1:
+        logger.info("Multiple GPUs detected (%d), selecting GPU with most available memory", num_gpus)
+
+        best_gpu = select_best_gpu()
+        if best_gpu is not None:
+            logger.info("Selected GPU %d for training", best_gpu)
+
+            # Set CUDA_VISIBLE_DEVICES to only show the selected GPU
+            os.environ["CUDA_VISIBLE_DEVICES"] = str(best_gpu)
+
+            # Configure JAX to use the selected device
+            try:
+                devices = jax.devices("gpu")
+                if devices:
+                    jax.config.update("jax_default_device", devices[0])
+                    logger.info("Configured JAX to use device: %s", devices[0])
+            except Exception as e:
+                logger.warning("Failed to configure JAX device: %s", e)
+        else:
+            logger.warning("Could not determine best GPU, using default device selection")
+    elif num_gpus == 1:
+        logger.info("Single GPU detected, using default device selection")
+
+
+def configure_devices(logger: logging.Logger | None = None) -> None:
+    if logger is None:
+        logger = configure_logging()
+
+    if shutil.which("nvidia-smi") is not None:
+        configure_gpu_devices(logger)
+
+
+def run_single_process_training(
+    task: "type[RunnableMixin[Config]]",
+    *cfgs: RawConfigType,
+    use_cli: bool | list[str] = True,
+    logger: logging.Logger | None = None,
+) -> None:
+    if logger is None:
+        logger = configure_logging()
+    task_obj = task.get_task(*cfgs, use_cli=use_cli)
+    task_obj.add_logger_handlers(logger)
+    task_obj.run()
+
+
+class SingleProcessLauncher(BaseLauncher):
+    def launch(
+        self,
+        task: "type[RunnableMixin[Config]]",
+        *cfgs: RawConfigType,
+        use_cli: bool | list[str] = True,
+    ) -> None:
+        logger = configure_logging()
+        configure_devices(logger)
+        run_single_process_training(task, *cfgs, use_cli=use_cli, logger=logger)