xax 0.3.12__tar.gz → 0.4.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. {xax-0.3.12/xax.egg-info → xax-0.4.4}/PKG-INFO +3 -1
  2. {xax-0.3.12 → xax-0.4.4}/pyproject.toml +1 -0
  3. {xax-0.3.12 → xax-0.4.4}/setup.py +1 -0
  4. {xax-0.3.12 → xax-0.4.4}/xax/__init__.py +28 -10
  5. {xax-0.3.12 → xax-0.4.4}/xax/nn/geom.py +42 -13
  6. xax-0.4.4/xax/task/launchers/single_process.py +141 -0
  7. xax-0.4.4/xax/task/loggers/wandb.py +307 -0
  8. {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/__init__.py +2 -1
  9. {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/artifacts.py +1 -1
  10. xax-0.4.4/xax/task/mixins/logger.py +169 -0
  11. xax-0.4.4/xax/task/mixins/supervised.py +368 -0
  12. {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/train.py +37 -346
  13. {xax-0.3.12 → xax-0.4.4}/xax/task/task.py +26 -2
  14. {xax-0.3.12 → xax-0.4.4}/xax/utils/debugging.py +20 -4
  15. {xax-0.3.12 → xax-0.4.4}/xax/utils/experiments.py +2 -2
  16. {xax-0.3.12 → xax-0.4.4}/xax/utils/pytree.py +3 -5
  17. {xax-0.3.12 → xax-0.4.4}/xax/utils/types/frozen_dict.py +4 -0
  18. {xax-0.3.12 → xax-0.4.4/xax.egg-info}/PKG-INFO +3 -1
  19. {xax-0.3.12 → xax-0.4.4}/xax.egg-info/SOURCES.txt +2 -0
  20. {xax-0.3.12 → xax-0.4.4}/xax.egg-info/requires.txt +3 -0
  21. xax-0.3.12/xax/task/launchers/single_process.py +0 -31
  22. xax-0.3.12/xax/task/mixins/logger.py +0 -92
  23. {xax-0.3.12 → xax-0.4.4}/LICENSE +0 -0
  24. {xax-0.3.12 → xax-0.4.4}/MANIFEST.in +0 -0
  25. {xax-0.3.12 → xax-0.4.4}/README.md +0 -0
  26. {xax-0.3.12 → xax-0.4.4}/setup.cfg +0 -0
  27. {xax-0.3.12 → xax-0.4.4}/xax/cli/__init__.py +0 -0
  28. {xax-0.3.12 → xax-0.4.4}/xax/cli/edit_config.py +0 -0
  29. {xax-0.3.12 → xax-0.4.4}/xax/core/__init__.py +0 -0
  30. {xax-0.3.12 → xax-0.4.4}/xax/core/conf.py +0 -0
  31. {xax-0.3.12 → xax-0.4.4}/xax/core/state.py +0 -0
  32. {xax-0.3.12 → xax-0.4.4}/xax/nn/__init__.py +0 -0
  33. {xax-0.3.12 → xax-0.4.4}/xax/nn/attention.py +0 -0
  34. {xax-0.3.12 → xax-0.4.4}/xax/nn/distributions.py +0 -0
  35. {xax-0.3.12 → xax-0.4.4}/xax/nn/embeddings.py +0 -0
  36. {xax-0.3.12 → xax-0.4.4}/xax/nn/functions.py +0 -0
  37. {xax-0.3.12 → xax-0.4.4}/xax/nn/losses.py +0 -0
  38. {xax-0.3.12 → xax-0.4.4}/xax/nn/metrics.py +0 -0
  39. {xax-0.3.12 → xax-0.4.4}/xax/nn/parallel.py +0 -0
  40. {xax-0.3.12 → xax-0.4.4}/xax/nn/ssm.py +0 -0
  41. {xax-0.3.12 → xax-0.4.4}/xax/py.typed +0 -0
  42. {xax-0.3.12 → xax-0.4.4}/xax/requirements-dev.txt +0 -0
  43. {xax-0.3.12 → xax-0.4.4}/xax/requirements.txt +0 -0
  44. {xax-0.3.12 → xax-0.4.4}/xax/task/__init__.py +0 -0
  45. {xax-0.3.12 → xax-0.4.4}/xax/task/base.py +0 -0
  46. {xax-0.3.12 → xax-0.4.4}/xax/task/launchers/__init__.py +0 -0
  47. {xax-0.3.12 → xax-0.4.4}/xax/task/launchers/base.py +0 -0
  48. {xax-0.3.12 → xax-0.4.4}/xax/task/launchers/cli.py +0 -0
  49. {xax-0.3.12 → xax-0.4.4}/xax/task/logger.py +0 -0
  50. {xax-0.3.12 → xax-0.4.4}/xax/task/loggers/__init__.py +0 -0
  51. {xax-0.3.12 → xax-0.4.4}/xax/task/loggers/callback.py +0 -0
  52. {xax-0.3.12 → xax-0.4.4}/xax/task/loggers/json.py +0 -0
  53. {xax-0.3.12 → xax-0.4.4}/xax/task/loggers/state.py +0 -0
  54. {xax-0.3.12 → xax-0.4.4}/xax/task/loggers/stdout.py +0 -0
  55. {xax-0.3.12 → xax-0.4.4}/xax/task/loggers/tensorboard.py +0 -0
  56. {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/checkpointing.py +0 -0
  57. {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/compile.py +0 -0
  58. {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/cpu_stats.py +0 -0
  59. {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/data_loader.py +0 -0
  60. {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/gpu_stats.py +0 -0
  61. {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/process.py +0 -0
  62. {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/runnable.py +0 -0
  63. {xax-0.3.12 → xax-0.4.4}/xax/task/mixins/step_wrapper.py +0 -0
  64. {xax-0.3.12 → xax-0.4.4}/xax/task/script.py +0 -0
  65. {xax-0.3.12 → xax-0.4.4}/xax/utils/__init__.py +0 -0
  66. {xax-0.3.12 → xax-0.4.4}/xax/utils/data/__init__.py +0 -0
  67. {xax-0.3.12 → xax-0.4.4}/xax/utils/data/collate.py +0 -0
  68. {xax-0.3.12 → xax-0.4.4}/xax/utils/jax.py +0 -0
  69. {xax-0.3.12 → xax-0.4.4}/xax/utils/jaxpr.py +0 -0
  70. {xax-0.3.12 → xax-0.4.4}/xax/utils/logging.py +0 -0
  71. {xax-0.3.12 → xax-0.4.4}/xax/utils/numpy.py +0 -0
  72. {xax-0.3.12 → xax-0.4.4}/xax/utils/profile.py +0 -0
  73. {xax-0.3.12 → xax-0.4.4}/xax/utils/tensorboard.py +0 -0
  74. {xax-0.3.12 → xax-0.4.4}/xax/utils/text.py +0 -0
  75. {xax-0.3.12 → xax-0.4.4}/xax/utils/types/__init__.py +0 -0
  76. {xax-0.3.12 → xax-0.4.4}/xax/utils/types/hashable_array.py +0 -0
  77. {xax-0.3.12 → xax-0.4.4}/xax.egg-info/dependency_links.txt +0 -0
  78. {xax-0.3.12 → xax-0.4.4}/xax.egg-info/entry_points.txt +0 -0
  79. {xax-0.3.12 → xax-0.4.4}/xax.egg-info/top_level.txt +0 -0
{xax-0.3.12/xax.egg-info → xax-0.4.4}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: xax
-Version: 0.3.12
+Version: 0.4.4
 Summary: A library for fast Jax experimentation
 Home-page: https://github.com/kscalelabs/xax
 Author: Benjamin Bolte
@@ -31,6 +31,8 @@ Requires-Dist: pytest; extra == "dev"
 Requires-Dist: types-pillow; extra == "dev"
 Requires-Dist: types-psutil; extra == "dev"
 Requires-Dist: types-requests; extra == "dev"
+Provides-Extra: wandb
+Requires-Dist: wandb[media]; extra == "wandb"
 Dynamic: author
 Dynamic: description
 Dynamic: description-content-type
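
The Provides-Extra / Requires-Dist lines above (mirrored by the extras_require entry in setup.py below) add an optional wandb extra that pulls in wandb[media]. The new xax/task/loggers/wandb.py module itself is not shown in this diff, so the following is only a hedged sketch of how a downstream project might guard its import; the module's public names are not visible here and are not assumed.

# Hedged sketch: the optional wandb logger requires installing the extra,
# e.g. `pip install "xax[wandb]"`, which pulls in wandb[media].
try:
    import xax.task.loggers.wandb as xax_wandb_logger  # module path taken from the file listing above
except ImportError:  # extra not installed; fall back to the built-in loggers
    xax_wandb_logger = None
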
{xax-0.3.12 → xax-0.4.4}/pyproject.toml
@@ -35,6 +35,7 @@ explicit_package_bases = true
 [[tool.mypy.overrides]]
 
 module = [
+    "chex.*",
     "optax.*",
     "setuptools.*",
     "tensorboard.*",
{xax-0.3.12 → xax-0.4.4}/setup.py
@@ -33,6 +33,7 @@ setup(
     tests_require=requirements_dev,
     extras_require={
         "dev": requirements_dev,
+        "wandb": ["wandb[media]"],
     },
     package_data={
         "xax": [
{xax-0.3.12 → xax-0.4.4}/xax/__init__.py
@@ -12,7 +12,7 @@ and running the update script:
     python -m scripts.update_api --inplace
 """
 
-__version__ = "0.3.12"
+__version__ = "0.4.4"
 
 # This list shouldn't be modified by hand; instead, run the update script.
 __all__ = [
@@ -53,6 +53,7 @@ __all__ = [
     "quat_mul",
     "quat_to_euler",
     "quat_to_rotmat",
+    "quat_to_yaw",
     "rotate_vector_by_quat",
     "rotation6d_to_rotation_matrix",
     "rotation_matrix_to_quat",
@@ -93,16 +94,19 @@ __all__ = [
     "DataloaderConfig",
     "GPUStatsOptions",
     "StepContext",
+    "InitParams",
     "ValidStepTimer",
     "Script",
     "ScriptConfig",
     "Config",
+    "SupervisedConfig",
+    "SupervisedTask",
     "Task",
     "collate",
     "collate_non_null",
-    "breakpoint_if_nan",
+    "breakpoint_if_nonfinite",
     "get_named_leaves",
-    "log_if_nan",
+    "log_if_nonfinite",
     "BaseFileDownloader",
     "ContextTimer",
     "CumulativeTimer",
@@ -167,6 +171,7 @@ __all__ = [
     "uncolored",
     "wrapped",
     "FrozenDict",
+    "freeze_dict",
     "HashableArray",
     "hashable_array",
 ]
@@ -198,7 +203,10 @@ if "XLA_FLAGS" in os.environ:
     # If Nvidia GPU is detected (meaning, is `nvidia-smi` available?), disable
     # Triton GEMM kernels. See https://github.com/NVIDIA/JAX-Toolbox
     if shutil.which("nvidia-smi") is not None:
-        xla_flags += ["--xla_gpu_enable_latency_hiding_scheduler=true", "--xla_gpu_enable_triton_gemm=false"]
+        xla_flags += [
+            "--xla_gpu_enable_latency_hiding_scheduler=true",
+            "--xla_gpu_enable_triton_gemm=false",
+        ]
     os.environ["XLA_FLAGS"] = " ".join(xla_flags)
 
 # If this flag is set, eagerly imports the entire package (not recommended).
@@ -246,6 +254,7 @@ NAME_MAP: dict[str, str] = {
     "quat_mul": "nn.geom",
     "quat_to_euler": "nn.geom",
     "quat_to_rotmat": "nn.geom",
+    "quat_to_yaw": "nn.geom",
     "rotate_vector_by_quat": "nn.geom",
     "rotation6d_to_rotation_matrix": "nn.geom",
     "rotation_matrix_to_quat": "nn.geom",
@@ -286,16 +295,19 @@ NAME_MAP: dict[str, str] = {
     "DataloaderConfig": "task.mixins.data_loader",
     "GPUStatsOptions": "task.mixins.gpu_stats",
     "StepContext": "task.mixins.step_wrapper",
+    "InitParams": "task.mixins.train",
     "ValidStepTimer": "task.mixins.train",
     "Script": "task.script",
     "ScriptConfig": "task.script",
     "Config": "task.task",
+    "SupervisedConfig": "task.task",
+    "SupervisedTask": "task.task",
     "Task": "task.task",
     "collate": "utils.data.collate",
     "collate_non_null": "utils.data.collate",
-    "breakpoint_if_nan": "utils.debugging",
+    "breakpoint_if_nonfinite": "utils.debugging",
     "get_named_leaves": "utils.debugging",
-    "log_if_nan": "utils.debugging",
+    "log_if_nonfinite": "utils.debugging",
     "BaseFileDownloader": "utils.experiments",
     "ContextTimer": "utils.experiments",
     "CumulativeTimer": "utils.experiments",
@@ -360,6 +372,7 @@ NAME_MAP: dict[str, str] = {
     "uncolored": "utils.text",
     "wrapped": "utils.text",
     "FrozenDict": "utils.types.frozen_dict",
+    "freeze_dict": "utils.types.frozen_dict",
     "HashableArray": "utils.types.hashable_array",
     "hashable_array": "utils.types.hashable_array",
 }
@@ -443,6 +456,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         quat_mul,
         quat_to_euler,
         quat_to_rotmat,
+        quat_to_yaw,
         rotate_vector_by_quat,
         rotation6d_to_rotation_matrix,
         rotation_matrix_to_quat,
@@ -482,11 +496,15 @@ if IMPORT_ALL or TYPE_CHECKING:
     from xax.task.mixins.data_loader import DataloaderConfig
     from xax.task.mixins.gpu_stats import GPUStatsOptions
     from xax.task.mixins.step_wrapper import StepContext
-    from xax.task.mixins.train import Batch, Output, ValidStepTimer
+    from xax.task.mixins.train import Batch, InitParams, Output, ValidStepTimer
     from xax.task.script import Script, ScriptConfig
-    from xax.task.task import Config, Task
+    from xax.task.task import Config, SupervisedConfig, SupervisedTask, Task
     from xax.utils.data.collate import CollateMode, collate, collate_non_null
-    from xax.utils.debugging import breakpoint_if_nan, get_named_leaves, log_if_nan
+    from xax.utils.debugging import (
+        breakpoint_if_nonfinite,
+        get_named_leaves,
+        log_if_nonfinite,
+    )
     from xax.utils.experiments import (
         BaseFileDownloader,
         ContextTimer,
@@ -556,7 +574,7 @@ if IMPORT_ALL or TYPE_CHECKING:
         uncolored,
         wrapped,
     )
-    from xax.utils.types.frozen_dict import FrozenDict
+    from xax.utils.types.frozen_dict import FrozenDict, freeze_dict
     from xax.utils.types.hashable_array import HashableArray, hashable_array
 
 del TYPE_CHECKING, IMPORT_ALL
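
For orientation, here is a small usage sketch of the renamed and newly exported top-level names. It assumes the package keeps its lazy-import mechanism, which resolves attribute access through NAME_MAP as shown above; only the names themselves are confirmed by this diff, not their call signatures.

import xax

print(xax.__version__)  # "0.4.4"

# New public names resolve lazily to their defining modules via NAME_MAP:
yaw_fn = xax.quat_to_yaw              # -> xax.nn.geom
task_cls = xax.SupervisedTask         # -> xax.task.task
freeze = xax.freeze_dict              # -> xax.utils.types.frozen_dict

# The debugging helpers were renamed from *_if_nan to *_if_nonfinite:
check = xax.breakpoint_if_nonfinite   # formerly xax.breakpoint_if_nan
log_check = xax.log_if_nonfinite      # formerly xax.log_if_nan
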
{xax-0.3.12 → xax-0.4.4}/xax/nn/geom.py
@@ -1,6 +1,7 @@
 """Defines geometry functions."""
 
 import chex
+import jax
 from jax import numpy as jnp
 from jaxtyping import Array
 
@@ -15,30 +16,53 @@ def quat_to_euler(quat_4: Array, eps: float = 1e-6) -> Array:
     Returns:
         The roll, pitch, yaw angles with shape (*, 3).
     """
-    quat_4 = quat_4 / (jnp.linalg.norm(quat_4, axis=-1, keepdims=True) + eps)
-    w, x, y, z = jnp.split(quat_4, 4, axis=-1)
+    # Normalize with clamping
+    norm_sq = jnp.sum(quat_4**2, axis=-1, keepdims=True)
+    inv_norm = jax.lax.rsqrt(jnp.maximum(norm_sq, eps))
+    quat_4 = quat_4 * inv_norm
+
+    w, x, y, z = jnp.unstack(quat_4, axis=-1)
 
     # Roll (x-axis rotation)
     sinr_cosp = 2.0 * (w * x + y * z)
     cosr_cosp = 1.0 - 2.0 * (x * x + y * y)
-    roll = jnp.arctan2(sinr_cosp, cosr_cosp)
+    roll = jax.lax.atan2(sinr_cosp, cosr_cosp)
 
     # Pitch (y-axis rotation)
     sinp = 2.0 * (w * y - z * x)
-
-    # Handle edge cases where |sinp| >= 1
-    pitch = jnp.where(
-        jnp.abs(sinp) >= 1.0,
-        jnp.sign(sinp) * jnp.pi / 2.0,  # Use 90 degrees if out of range
-        jnp.arcsin(sinp),
-    )
+    sinp = jnp.clip(sinp, -1.0, 1.0)  # Clamp to valid domain
+    pitch = jax.lax.asin(sinp)
 
     # Yaw (z-axis rotation)
     siny_cosp = 2.0 * (w * z + x * y)
     cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
-    yaw = jnp.arctan2(siny_cosp, cosy_cosp)
+    yaw = jax.lax.atan2(siny_cosp, cosy_cosp)
+
+    return jnp.stack([roll, pitch, yaw], axis=-1)
+
+
+def quat_to_yaw(quat_4: Array, eps: float = 1e-6) -> Array:
+    """Converts a quaternion to a yaw angle.
+
+    Args:
+        quat_4: The quaternion to convert, shape (*, 4).
+        eps: A small epsilon value to avoid division by zero.
+
+    Returns:
+        The yaw angle, shape (*).
+    """
+    # Normalize using a max + safe norm to handle extremely small values robustly
+    norm_sq = jnp.sum(quat_4**2, axis=-1, keepdims=True)
+    inv_norm = jax.lax.rsqrt(jnp.maximum(norm_sq, eps))
+    quat_4 = quat_4 * inv_norm
+
+    w, x, y, z = jnp.unstack(quat_4, axis=-1)
+
+    # Compute components with clamping to avoid rounding errors near limits
+    siny_cosp = 2.0 * (w * z + x * y)
+    cosy_cosp = 1.0 - 2.0 * (y * y + z * z)
 
-    return jnp.concatenate([roll, pitch, yaw], axis=-1)
+    return jax.lax.atan2(siny_cosp, cosy_cosp)
 
 
 def euler_to_quat(euler_3: Array) -> Array:
@@ -89,7 +113,12 @@ def get_projected_gravity_vector_from_quat(quat: Array, eps: float = 1e-6) -> Ar
     return rotate_vector_by_quat(jnp.array([0, 0, -9.81]), quat, inverse=True, eps=eps)
 
 
-def rotate_vector_by_quat(vector: Array, quat: Array, inverse: bool = False, eps: float = 1e-6) -> Array:
+def rotate_vector_by_quat(
+    vector: Array,
+    quat: Array,
+    inverse: bool = False,
+    eps: float = 1e-6,
+) -> Array:
     """Rotates a vector by a quaternion.
 
     Args:
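
The new quat_to_yaw mirrors the yaw branch of quat_to_euler, so the two should agree on the last Euler component. A minimal sketch using only the signatures shown above, with quaternions in (w, x, y, z) order:

import jax.numpy as jnp

from xax.nn.geom import quat_to_euler, quat_to_yaw

# Identity rotation and a 90-degree rotation about the z-axis.
quats = jnp.array([
    [1.0, 0.0, 0.0, 0.0],
    [0.70710678, 0.0, 0.0, 0.70710678],
])

rpy = quat_to_euler(quats)  # shape (2, 3): roll, pitch, yaw
yaw = quat_to_yaw(quats)    # shape (2,)

assert jnp.allclose(rpy[..., -1], yaw, atol=1e-5)  # yaw of the second quaternion is ~pi/2
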
xax-0.4.4/xax/task/launchers/single_process.py
@@ -0,0 +1,141 @@
+"""Defines a launcher to train a model locally, in a single process."""
+
+import logging
+import os
+import shutil
+import subprocess
+from typing import TYPE_CHECKING
+
+import jax
+
+from xax.task.base import RawConfigType
+from xax.task.launchers.base import BaseLauncher
+from xax.task.mixins.gpu_stats import get_num_gpus
+from xax.utils.logging import configure_logging
+
+if TYPE_CHECKING:
+    from xax.task.mixins.runnable import Config, RunnableMixin
+
+
+def get_gpu_memory_info() -> dict[int, tuple[float, float]]:
+    """Get memory information for all GPUs.
+
+    Returns:
+        Dictionary mapping GPU index to (total_memory_mb, used_memory_mb)
+    """
+    command = "nvidia-smi --query-gpu=index,memory.total,memory.used --format=csv,noheader"
+
+    try:
+        with subprocess.Popen(command.split(), stdout=subprocess.PIPE, universal_newlines=True) as proc:
+            stdout = proc.stdout
+            assert stdout is not None
+
+            gpu_info = {}
+            for line in stdout:
+                line = line.strip()
+                if not line:
+                    continue
+
+                parts = line.split(", ")
+                if len(parts) >= 3:
+                    gpu_id = int(parts[0])
+                    total_mem = float(parts[1].replace(" MiB", ""))
+                    used_mem = float(parts[2].replace(" MiB", ""))
+                    gpu_info[gpu_id] = (total_mem, used_mem)
+
+            return gpu_info
+
+    except Exception as e:
+        logger = configure_logging()
+        logger.warning("Failed to get GPU memory info: %s", e)
+        return {}
+
+
+def select_best_gpu() -> int | None:
+    """Select the GPU with the most available memory.
+
+    Returns:
+        GPU index with most available memory, or None if no GPUs found
+    """
+    gpu_info = get_gpu_memory_info()
+
+    if not gpu_info:
+        return None
+
+    # Find GPU with most available memory
+    best_gpu = None
+    max_available: float = -1.0
+
+    for gpu_id, (total_mem, used_mem) in gpu_info.items():
+        available_mem = total_mem - used_mem
+        if available_mem > max_available:
+            max_available = available_mem
+            best_gpu = gpu_id
+
+    return best_gpu
+
+
+def configure_gpu_devices(logger: logging.Logger | None = None) -> None:
+    if logger is None:
+        logger = configure_logging()
+
+    # If there are multiple devices, choose the one with the most
+    # available memory (i.e., the one which is likely not being used
+    # by other processes) and use only that device.
+    num_gpus = get_num_gpus()
+
+    if num_gpus > 1:
+        logger.info("Multiple GPUs detected (%d), selecting GPU with most available memory", num_gpus)
+
+        best_gpu = select_best_gpu()
+        if best_gpu is not None:
+            logger.info("Selected GPU %d for training", best_gpu)
+
+            # Set CUDA_VISIBLE_DEVICES to only show the selected GPU
+            os.environ["CUDA_VISIBLE_DEVICES"] = str(best_gpu)
+
+            # Configure JAX to use the selected device
+            try:
+                devices = jax.devices("gpu")
+                if devices:
+                    jax.config.update("jax_default_device", devices[0])
+                    logger.info("Configured JAX to use device: %s", devices[0])
+            except Exception as e:
+                logger.warning("Failed to configure JAX device: %s", e)
+        else:
+            logger.warning("Could not determine best GPU, using default device selection")
+    elif num_gpus == 1:
+        logger.info("Single GPU detected, using default device selection")
+
+
+def configure_devices(logger: logging.Logger | None = None) -> None:
+    if logger is None:
+        logger = configure_logging()
+
+    if shutil.which("nvidia-smi") is not None:
+        configure_gpu_devices(logger)
+
+
+def run_single_process_training(
+    task: "type[RunnableMixin[Config]]",
+    *cfgs: RawConfigType,
+    use_cli: bool | list[str] = True,
+    logger: logging.Logger | None = None,
+) -> None:
+    if logger is None:
+        logger = configure_logging()
+    task_obj = task.get_task(*cfgs, use_cli=use_cli)
+    task_obj.add_logger_handlers(logger)
+    task_obj.run()
+
+
+class SingleProcessLauncher(BaseLauncher):
+    def launch(
+        self,
+        task: "type[RunnableMixin[Config]]",
+        *cfgs: RawConfigType,
+        use_cli: bool | list[str] = True,
+    ) -> None:
+        logger = configure_logging()
+        configure_devices(logger)
+        run_single_process_training(task, *cfgs, use_cli=use_cli, logger=logger)
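
The GPU-selection helpers above are plain module-level functions, so they can be exercised on their own. A minimal sketch using only the functions defined in this file (both degrade gracefully on machines without nvidia-smi):

from xax.task.launchers.single_process import get_gpu_memory_info, select_best_gpu

info = get_gpu_memory_info()  # {} if nvidia-smi is missing or fails
best = select_best_gpu()      # None if no GPU info could be gathered

if best is not None and best in info:
    total_mb, used_mb = info[best]
    print(f"GPU {best}: {total_mb - used_mb:.0f} MiB free of {total_mb:.0f} MiB")
else:
    print("No NVIDIA GPU detected; JAX falls back to its default device selection.")
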