xmanager-slurm 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xmanager-slurm might be problematic; see the registry's release page for more details.

@@ -0,0 +1,200 @@
1
+ import asyncio
2
+ import base64
3
+ import dataclasses
4
+ import enum
5
+ import functools
6
+ import logging
7
+ import zlib
8
+ from typing import Awaitable, Callable, Concatenate, Coroutine, Mapping, ParamSpec, TypeVar
9
+
10
+ import backoff
11
+ import cloudpickle
12
+ from xmanager import xm
13
+
14
+ import xm_slurm
15
+ from xm_slurm import job_blocks, status
16
+ from xm_slurm.experiment import SlurmAuxiliaryUnit, SlurmExperiment
17
+
18
+ P = ParamSpec("P")
19
+ T = TypeVar("T")
20
+
21
+
22
+ async def _monitor_parameter_controller(
23
+ aux_unit: SlurmAuxiliaryUnit,
24
+ local_parameter_controller_coro: Coroutine[None, None, T],
25
+ *,
26
+ poll_interval: float = 30.0,
27
+ ) -> None:
28
+ local_controller_finished = asyncio.Event()
29
+ local_parameter_controller = asyncio.create_task(local_parameter_controller_coro)
30
+
31
+ @local_parameter_controller.add_done_callback
32
+ def _(future: asyncio.Task[T]) -> None:
33
+ try:
34
+ _ = future.result()
35
+ except asyncio.CancelledError:
36
+ logging.info("Local parameter controller was cancelled, resuming on remote controller.")
37
+ pass
38
+ except Exception:
39
+ logging.error("Local parameter controller failed, stopping remote controller.")
40
+ aux_unit.stop(
41
+ mark_as_failed=True, mark_as_completed=False, message="Local controller failed."
42
+ )
43
+ raise
44
+ else:
45
+ logging.info(
46
+ "Local parameter controller finished before remote controller started, "
47
+ "stopping remote controller."
48
+ )
49
+ local_controller_finished.set()
50
+ aux_unit.stop(mark_as_completed=True, message="Local parameter controller finished.")
51
+
52
+ @backoff.on_predicate(
53
+ backoff.constant,
54
+ lambda aux_unit_status: aux_unit_status is status.SlurmWorkUnitStatusEnum.PENDING,
55
+ jitter=None,
56
+ interval=poll_interval,
57
+ )
58
+ async def wait_for_remote_controller() -> status.SlurmWorkUnitStatusEnum:
59
+ logging.info("Waiting for remote parameter controller to start.")
60
+ if local_controller_finished.is_set():
61
+ return status.SlurmWorkUnitStatusEnum.COMPLETED
62
+ return (await aux_unit.get_status()).status
63
+
64
+ logging.info("Monitoring remote parameter controller.")
65
+ # TODO(jfarebro): make get_status() more resiliant to errors when initially scheduling.
66
+ # We run into issues if we call get_status() too quickly when Slurm hasn't ingested the job.
67
+ await asyncio.sleep(15)
68
+ match await wait_for_remote_controller():
69
+ case status.SlurmWorkUnitStatusEnum.RUNNING:
70
+ logging.info("Remote parameter controller started.")
71
+ local_parameter_controller.cancel("Remote parameter controller started.")
72
+ case status.SlurmWorkUnitStatusEnum.COMPLETED:
73
+ if local_parameter_controller.done():
74
+ logging.info("Local parameter controller finished, stopping remote controller.")
75
+ aux_unit.stop(
76
+ mark_as_completed=True, message="Local parameter controller finished."
77
+ )
78
+ else:
79
+ logging.info("Remote parameter controller finished, stopping local controller.")
80
+ local_parameter_controller.cancel()
81
+ case status.SlurmWorkUnitStatusEnum.FAILED:
82
+ logging.error("Remote parameter controller failed, stopping local controller.")
83
+ local_parameter_controller.cancel()
84
+ case status.SlurmWorkUnitStatusEnum.CANCELLED:
85
+ logging.info("Remote parameter controller was cancelled, stopping local controller.")
86
+ local_parameter_controller.cancel()
87
+ case status.SlurmWorkUnitStatusEnum.PENDING:
88
+ raise RuntimeError("Remote parameter controller is still pending, invalid state.")
89
+
90
+
91
class ParameterControllerMode(enum.Enum):
    """Selects where a controller created by `parameter_controller` runs.

    Members:
        AUTO: Launch the remote controller job and also run the controller
            locally while the remote job is pending.
        REMOTE_ONLY: Only launch the remote controller job.
    """

    # TODO(jfarebro): is it possible to get LOCAL_ONLY?
    # We'd need to have a dummy job type as we need to return a JobType?
    AUTO = enum.auto()
    REMOTE_ONLY = enum.auto()
96
+
97
+
98
def parameter_controller(
    *,
    executable: xm.Executable,
    executor: xm.Executor,
    controller_mode: ParameterControllerMode = ParameterControllerMode.AUTO,
    controller_name: str = "parameter_controller",
    controller_args: xm.UserArgs | None = None,
    controller_env_vars: Mapping[str, str] | None = None,
) -> Callable[
    [
        Callable[Concatenate[SlurmExperiment, P], T]
        | Callable[Concatenate[SlurmExperiment, P], Awaitable[T]],
    ],
    Callable[P, xm.AuxiliaryUnitJob],
]:
    """Converts a function to a controller which can be added to an experiment.

    The decorated function receives the experiment as its first argument, followed
    by the arguments passed to the wrapper. Calling the wrapper returns an
    ``xm.AuxiliaryUnitJob``. Its job generator cloudpickles the function and adds a
    job that runs it through the ``xm_slurm.scripts._cloudpickle`` entrypoint on
    ``executor``. In ``AUTO`` mode the function is also run locally until the
    remote job leaves the pending state (see ``_monitor_parameter_controller``).

    Args:
        executable: An executable that has a Python entrypoint with all the necessary
            dependencies. Its ``args`` and ``env_vars`` are replaced for the
            controller job.
        executor: The executor to launch the controller job on.
        controller_mode: ``AUTO`` additionally runs the controller locally while the
            remote job is pending; ``REMOTE_ONLY`` only launches the remote job.
        controller_name: Name of the parameter controller job.
        controller_args: Mapping of flag names and values to be used by the XM
            client running inside the parameter controller job.
        controller_env_vars: Mapping of env variable names and values to be passed
            to the parameter controller job.

    Returns:
        A decorator to be applied to the function.
    """

    def decorator(
        f: Callable[Concatenate[SlurmExperiment, P], T]
        | Callable[Concatenate[SlurmExperiment, P], Awaitable[T]],
    ) -> Callable[P, xm.AuxiliaryUnitJob]:
        # `asyncio.iscoroutinefunction` is deprecated as of Python 3.14.
        import inspect

        # `f` never changes, so decide once whether it must be awaited.
        is_async = inspect.iscoroutinefunction(f)

        async def call_controller(
            experiment: SlurmExperiment, *args: P.args, **kwargs: P.kwargs
        ) -> T:
            """Invokes `f` with `experiment`, awaiting it if it is a coroutine function."""
            if is_async:
                return await f(experiment, *args, **kwargs)
            # NOTE: a synchronous `f` runs on the event loop thread and blocks it
            # until it returns.
            return f(experiment, *args, **kwargs)

        @functools.wraps(f)
        def make_controller(*args: P.args, **kwargs: P.kwargs) -> xm.AuxiliaryUnitJob:
            async def job_generator(aux_unit: SlurmAuxiliaryUnit) -> None:
                # The remote controller only captures the experiment ID and re-reads
                # the experiment from the API, so that it can be pickled.
                experiment_id = aux_unit.experiment.experiment_id

                async def local_controller(*args: P.args, **kwargs: P.kwargs) -> T:
                    return await call_controller(aux_unit.experiment, *args, **kwargs)

                async def remote_controller(*args: P.args, **kwargs: P.kwargs) -> T:
                    async with xm_slurm.get_experiment(experiment_id=experiment_id) as exp:
                        return await call_controller(exp, *args, **kwargs)

                remote_controller_serialized = base64.urlsafe_b64encode(
                    zlib.compress(
                        cloudpickle.dumps(
                            functools.partial(remote_controller, *args, **kwargs),
                        )
                    )
                )

                # The controller job runs the `_cloudpickle` entrypoint with the
                # serialized function, followed by the user's controller args.
                parameter_controller_executable = dataclasses.replace(
                    executable,
                    args=xm.merge_args(
                        job_blocks.get_args_for_python_entrypoint(
                            xm.ModuleName("xm_slurm.scripts._cloudpickle")
                        ),
                        xm.SequentialArgs.from_collection({
                            "cloudpickled_fn": remote_controller_serialized.decode("ascii"),
                        }),
                        xm.SequentialArgs.from_collection(controller_args),
                    ),
                    env_vars=controller_env_vars or {},
                )

                await aux_unit.add(
                    xm.Job(
                        executor=executor,
                        executable=parameter_controller_executable,
                        name=controller_name,
                    )
                )

                # Launch the local parameter controller and monitor the remote job
                # so the local controller can be cancelled once the remote one runs.
                if controller_mode is ParameterControllerMode.AUTO:
                    aux_unit._create_task(
                        _monitor_parameter_controller(aux_unit, local_controller(*args, **kwargs))
                    )

            return xm.AuxiliaryUnitJob(
                job_generator,
                importance=xm.Importance.HIGH,
                termination_delay_secs=0,  # TODO: add support for a termination delay.
            )

        return make_controller

    return decorator
xm_slurm/job_blocks.py CHANGED
@@ -1,6 +1,13 @@
1
+ from typing import Mapping, TypedDict
2
+
1
3
  from xmanager import xm
2
4
 
3
5
 
6
class JobArgs(TypedDict, total=False):
    """Arguments and environment for a job; every key is optional (``total=False``)."""

    # User arguments for the job's executable.
    args: xm.UserArgs
    # Environment variable names mapped to their values.
    env_vars: Mapping[str, str]
9
+
10
+
4
11
  def get_args_for_python_entrypoint(
5
12
  entrypoint: xm.ModuleName | xm.CommandList,
6
13
  ) -> xm.SequentialArgs:
xm_slurm/packageables.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import importlib.resources as resources
2
2
  import pathlib
3
3
  import sys
4
- from typing import Mapping, Sequence
4
+ from typing import Literal, Mapping, Sequence
5
5
 
6
6
  import immutabledict
7
7
  from xmanager import xm
@@ -29,8 +29,8 @@ def docker_image(
29
29
  return xm.Packageable(
30
30
  executor_spec=SlurmSpec(),
31
31
  executable_spec=DockerImage(image=image),
32
- args=args,
33
- env_vars=env_vars,
32
+ args=xm.SequentialArgs.from_collection(args),
33
+ env_vars=dict(env_vars),
34
34
  )
35
35
 
36
36
 
@@ -40,7 +40,7 @@ def docker_container(
40
40
  dockerfile: pathlib.Path | None = None,
41
41
  context: pathlib.Path | None = None,
42
42
  target: str | None = None,
43
- ssh: list[str] | None = None,
43
+ ssh: Sequence[str] | Literal[True] | None = None,
44
44
  build_args: Mapping[str, str] = immutabledict.immutabledict(),
45
45
  cache_from: str | Sequence[str] | None = None,
46
46
  labels: Mapping[str, str] = immutabledict.immutabledict(),
@@ -55,7 +55,7 @@ def docker_container(
55
55
  dockerfile: The path to the dockerfile.
56
56
  context: The path to the docker context.
57
57
  target: The docker build target.
58
- ssh: A list of SSH sockets/keys for the docker build step.
58
+ ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
59
59
  build_args: Build arguments to docker.
60
60
  cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
61
61
  labels: The container labels.
@@ -71,8 +71,12 @@ def docker_container(
71
71
  if dockerfile is None:
72
72
  dockerfile = context / "Dockerfile"
73
73
  dockerfile = dockerfile.resolve()
74
+
74
75
  if ssh is None:
75
76
  ssh = []
77
+ elif ssh is True:
78
+ ssh = ["default"]
79
+
76
80
  if cache_from is None and isinstance(executor_spec, SlurmSpec):
77
81
  cache_from = executor_spec.tag
78
82
  if cache_from is None:
@@ -92,8 +96,8 @@ def docker_container(
92
96
  workdir=workdir,
93
97
  labels=labels,
94
98
  ),
95
- args=args,
96
- env_vars=env_vars,
99
+ args=xm.SequentialArgs.from_collection(args),
100
+ env_vars=dict(env_vars),
97
101
  )
98
102
 
99
103
 
@@ -104,14 +108,18 @@ def python_container(
104
108
  context: pathlib.Path | None = None,
105
109
  requirements: pathlib.Path | None = None,
106
110
  base_image: str = "docker.io/python:{major}.{minor}-slim",
111
+ extra_system_packages: Sequence[str] = (),
112
+ extra_python_packages: Sequence[str] = (),
107
113
  cache_from: str | Sequence[str] | None = None,
108
114
  labels: Mapping[str, str] = immutabledict.immutabledict(),
115
+ ssh: Sequence[str] | Literal[True] | None = None,
109
116
  args: xm.UserArgs | None = None,
110
117
  env_vars: Mapping[str, str] = immutabledict.immutabledict(),
111
118
  ) -> xm.Packageable:
112
119
  """Creates a Python container from a base image using pip from a `requirements.txt` file.
113
120
 
114
121
  NOTE: The base image will use the Python version of the current interpreter.
122
+ NOTE: uv is used to install packages from `requirements`.
115
123
 
116
124
  Args:
117
125
  executor_spec: The executor specification, where will the container be stored at.
@@ -119,8 +127,11 @@ def python_container(
119
127
  context: The path to the docker context.
120
128
  requirements: The path to the pip requirements file.
121
129
  base_image: The base image to use. NOTE: The base image must contain the Python runtime.
130
+ extra_system_packages: Additional system packages to install. NOTE: These are installed via `apt-get`.
131
+ extra_python_packages: Additional Python packages to install. NOTE: These are installed via `uv pip`.
122
132
  cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
123
133
  labels: The container labels.
134
+ ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
124
135
  args: The user arguments to pass to the executable.
125
136
  env_vars: The environment variables to pass to the executable.
126
137
 
@@ -149,11 +160,11 @@ def python_container(
149
160
  executor_spec=executor_spec,
150
161
  dockerfile=dockerfile,
151
162
  context=context,
163
+ ssh=ssh,
152
164
  build_args={
153
165
  "PIP_REQUIREMENTS": requirements.relative_to(context).as_posix(),
154
- "PYTHON_MAJOR": str(sys.version_info.major),
155
- "PYTHON_MINOR": str(sys.version_info.minor),
156
- "PYTHON_MICRO": str(sys.version_info.micro),
166
+ "EXTRA_SYSTEM_PACKAGES": " ".join(extra_system_packages),
167
+ "EXTRA_PYTHON_PACKAGES": " ".join(extra_python_packages),
157
168
  "BASE_IMAGE": base_image.format_map({
158
169
  "major": sys.version_info.major,
159
170
  "minor": sys.version_info.minor,
@@ -178,6 +189,7 @@ def mamba_container(
178
189
  base_image: str = "gcr.io/distroless/base-debian10",
179
190
  cache_from: str | Sequence[str] | None = None,
180
191
  labels: Mapping[str, str] = immutabledict.immutabledict(),
192
+ ssh: Sequence[str] | Literal[True] | None = None,
181
193
  args: xm.UserArgs | None = None,
182
194
  env_vars: Mapping[str, str] = immutabledict.immutabledict(),
183
195
  ) -> xm.Packageable:
@@ -193,6 +205,7 @@ def mamba_container(
193
205
  base_image: The base image to use.
194
206
  cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
195
207
  labels: The container labels.
208
+ ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
196
209
  args: The user arguments to pass to the executable.
197
210
  env_vars: The environment variables to pass to the executable.
198
211
 
@@ -221,6 +234,7 @@ def mamba_container(
221
234
  executor_spec=executor_spec,
222
235
  dockerfile=dockerfile,
223
236
  context=context,
237
+ ssh=ssh,
224
238
  build_args={
225
239
  "CONDA_ENVIRONMENT": environment.relative_to(context).as_posix(),
226
240
  "BASE_IMAGE": base_image,
@@ -237,26 +251,32 @@ def mamba_container(
237
251
  conda_container = mamba_container
238
252
 
239
253
 
240
- def pdm_container(
254
+ def uv_container(
241
255
  *,
242
256
  executor_spec: xm.ExecutorSpec,
243
257
  entrypoint: xm.ModuleName | xm.CommandList,
244
258
  context: pathlib.Path | None = None,
245
- base_image: str = "docker.io/python:{major}.{minor}-slim",
259
+ base_image: str = "docker.io/python:{major}.{minor}-slim-bookworm",
260
+ extra_system_packages: Sequence[str] = (),
261
+ extra_python_packages: Sequence[str] = (),
246
262
  cache_from: str | Sequence[str] | None = None,
247
263
  labels: Mapping[str, str] = immutabledict.immutabledict(),
264
+ ssh: Sequence[str] | Literal[True] | None = None,
248
265
  args: xm.UserArgs | None = None,
249
266
  env_vars: Mapping[str, str] = immutabledict.immutabledict(),
250
267
  ) -> xm.Packageable:
251
- """Creates a Python container from a base image using pdm from a `pdm.lock` file.
268
+ """Creates a Python container from a base image using uv from a `uv.lock` file.
252
269
 
253
270
  Args:
254
271
  executor_spec: The executor specification, where will the container be stored at.
255
272
  entrypoint: The entrypoint to run in the container.
256
273
  context: The path to the docker context.
257
274
  base_image: The base image to use. NOTE: The base image must contain the Python runtime.
275
+ extra_system_packages: Additional system packages to install. NOTE: These are installed via `apt-get`.
276
+ extra_python_packages: Additional Python packages to install. NOTE: These are installed via `uv pip`.
258
277
  cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
259
278
  labels: The container labels.
279
+ ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
260
280
  args: The user arguments to pass to the executable.
261
281
  env_vars: The environment variables to pass to the executable.
262
282
 
@@ -268,20 +288,22 @@ def pdm_container(
268
288
  if context is None:
269
289
  context = utils.find_project_root()
270
290
  context = context.resolve()
271
- if not (context / "pdm.lock").exists():
272
- raise ValueError(f"PDM lockfile `{context / 'pdm.lock'}` doesn't exist.")
291
+ if not (context / "pyproject.toml").exists():
292
+ raise ValueError(f"Python project file `{context / 'pyproject.toml'}` doesn't exist.")
293
+ if not (context / "uv.lock").exists():
294
+ raise ValueError(f"UV lock file `{context / 'uv.lock'}` doesn't exist.")
273
295
 
274
296
  with resources.as_file(
275
- resources.files("xm_slurm.templates").joinpath("docker/pdm.Dockerfile")
297
+ resources.files("xm_slurm.templates").joinpath("docker/uv.Dockerfile")
276
298
  ) as dockerfile:
277
299
  return docker_container(
278
300
  executor_spec=executor_spec,
279
301
  dockerfile=dockerfile,
280
302
  context=context,
303
+ ssh=ssh,
281
304
  build_args={
282
- "PYTHON_MAJOR": str(sys.version_info.major),
283
- "PYTHON_MINOR": str(sys.version_info.minor),
284
- "PYTHON_MICRO": str(sys.version_info.micro),
305
+ "EXTRA_SYSTEM_PACKAGES": " ".join(extra_system_packages),
306
+ "EXTRA_PYTHON_PACKAGES": " ".join(extra_python_packages),
285
307
  "BASE_IMAGE": base_image.format_map({
286
308
  "major": sys.version_info.major,
287
309
  "minor": sys.version_info.minor,
@@ -291,7 +313,7 @@ def pdm_container(
291
313
  cache_from=cache_from,
292
314
  labels=labels,
293
315
  # We must specify the workdir manually for apptainer support
294
- workdir=pathlib.Path("/workspace/src"),
316
+ workdir=pathlib.Path("/workspace"),
295
317
  args=args,
296
318
  env_vars=env_vars,
297
319
  )
@@ -11,8 +11,8 @@ from xm_slurm.packaging import registry
11
11
  from xm_slurm.packaging.docker.abc import DockerClient
12
12
 
13
13
FLAGS = flags.FLAGS

# Selects the client used to build Docker images; consumed by `docker_client()`.
DOCKER_CLIENT_PROVIDER = flags.DEFINE_enum(
    name="xm_docker_client",
    default="docker",
    enum_values=["docker"],
    help="Docker image build client.",
)
  )
17
17
 
18
18
  IndexedContainer = registry.IndexedContainer
@@ -20,19 +20,13 @@ IndexedContainer = registry.IndexedContainer
20
20
 
21
21
@functools.cache
def docker_client() -> DockerClient:
    """Returns the Docker client selected by the `--xm_docker_client` flag.

    The client is constructed on the first call and cached by `functools.cache`.

    Raises:
        ValueError: If the flag names an unsupported client.
    """
    provider = DOCKER_CLIENT_PROVIDER.value
    if provider != "docker":
        raise ValueError(f"Unknown build client: {provider}")

    # Deferred import: the local client module is only loaded once this runs.
    from xm_slurm.packaging.docker.local import LocalDockerClient

    return LocalDockerClient()
36
30
 
37
31
 
38
32
  @registry.register(Dockerfile)
@@ -9,7 +9,6 @@ import shlex
9
9
  import shutil
10
10
  import subprocess
11
11
  import tempfile
12
- import typing
13
12
  from typing import Sequence
14
13
 
15
14
  from xmanager import xm
@@ -37,7 +36,7 @@ class LocalDockerClient(DockerClient):
37
36
  BUILDKIT = enum.auto()
38
37
  BUILDAH = enum.auto()
39
38
 
40
- def __init__(self):
39
+ def __init__(self) -> None:
41
40
  if "XM_DOCKER_CLIENT" in os.environ:
42
41
  client_call = shlex.split(os.environ["XM_DOCKER_CLIENT"])
43
42
  elif shutil.which("docker"):
@@ -80,7 +79,7 @@ class LocalDockerClient(DockerClient):
80
79
  return None
81
80
 
82
81
  def _parse_credentials_from_config(
83
- config_path: pathlib.Path
82
+ config_path: pathlib.Path,
84
83
  ) -> RemoteRepositoryCredentials | None:
85
84
  """Parse credentials from the Docker configuration file."""
86
85
  if not config_path.exists():
@@ -138,9 +137,13 @@ class LocalDockerClient(DockerClient):
138
137
  targets: Sequence[IndexedContainer[xm.Packageable]],
139
138
  ) -> list[IndexedContainer[RemoteImage]]:
140
139
  executors_by_executables = packaging_utils.collect_executors_by_executable(targets)
141
- executors_by_executables = typing.cast(
142
- dict[Dockerfile, list[SlurmSpec]], executors_by_executables
143
- )
140
+ for executable, executors in executors_by_executables.items():
141
+ assert isinstance(
142
+ executable, Dockerfile
143
+ ), "All executables must be Dockerfiles when building Docker images."
144
+ assert all(
145
+ isinstance(executor, SlurmSpec) and executor.tag for executor in executors
146
+ ), "All executors must be SlurmSpecs with tags when building Docker images."
144
147
 
145
148
  with tempfile.TemporaryDirectory() as tempdir:
146
149
  hcl_file = pathlib.Path(tempdir) / "docker-bake.hcl"
@@ -157,12 +160,10 @@ class LocalDockerClient(DockerClient):
157
160
  try:
158
161
  command = DockerBakeCommand(
159
162
  targets=list(
160
- set(
161
- [
162
- packaging_utils.hash_digest(target.value.executable_spec)
163
- for target in targets
164
- ]
165
- )
163
+ set([
164
+ packaging_utils.hash_digest(target.value.executable_spec)
165
+ for target in targets
166
+ ])
166
167
  ),
167
168
  files=[hcl_file],
168
169
  metadata_file=metadata_file,
@@ -10,8 +10,7 @@ import re
10
10
  import select
11
11
  import shutil
12
12
  import subprocess
13
- import typing
14
- from typing import Callable, Concatenate, Hashable, Literal, ParamSpec, Sequence, TypeVar
13
+ from typing import Callable, Concatenate, Hashable, ParamSpec, Sequence, TypeVar
15
14
 
16
15
  from xmanager import xm
17
16
 
@@ -23,6 +22,7 @@ ReturnT = TypeVar("ReturnT")
23
22
 
24
23
 
25
24
  def hash_digest(obj: Hashable) -> str:
25
+ # TODO(jfarebro): Need a better way to hash these objects
26
26
  # obj_hash = hash(obj)
27
27
  # unsigned_obj_hash = obj_hash.from_bytes(
28
28
  # obj_hash.to_bytes((obj_hash.bit_length() + 7) // 8, "big", signed=True),
@@ -54,7 +54,7 @@ def parallel_map(
54
54
 
55
55
 
56
56
  # Cursor commands to filter out from the command data stream
57
- cursor_commands_regex = re.compile(
57
+ _CURSOR_ESCAPE_SEQUENCES_REGEX = re.compile(
58
58
  rb"\x1b\[\?25[hl]" # Matches cursor show/hide commands (CSI ?25h and CSI ?25l)
59
59
  rb"|\x1b\[[0-9;]*[Hf]" # Matches cursor position commands (CSI n;mH and CSI n;mf)
60
60
  rb"|\x1b\[s" # Matches cursor save position (CSI s)
@@ -64,54 +64,6 @@ cursor_commands_regex = re.compile(
64
64
  )
65
65
 
66
66
 
67
- @typing.overload
68
- def run_command(
69
- args: Sequence[str] | xm.SequentialArgs,
70
- env: dict[str, str] | None = ...,
71
- tty: bool = ...,
72
- cwd: str | os.PathLike[str] | None = ...,
73
- check: bool = ...,
74
- return_stdout: Literal[False] = False,
75
- return_stderr: Literal[False] = False,
76
- ) -> subprocess.CompletedProcess[None]: ...
77
-
78
-
79
- @typing.overload
80
- def run_command(
81
- args: Sequence[str] | xm.SequentialArgs,
82
- env: dict[str, str] | None = ...,
83
- tty: bool = ...,
84
- cwd: str | os.PathLike[str] | None = ...,
85
- check: bool = ...,
86
- return_stdout: Literal[True] = True,
87
- return_stderr: Literal[False] = False,
88
- ) -> subprocess.CompletedProcess[str]: ...
89
-
90
-
91
- @typing.overload
92
- def run_command(
93
- args: Sequence[str] | xm.SequentialArgs,
94
- env: dict[str, str] | None = ...,
95
- tty: bool = ...,
96
- cwd: str | os.PathLike[str] | None = ...,
97
- check: bool = ...,
98
- return_stdout: Literal[False] = False,
99
- return_stderr: Literal[True] = True,
100
- ) -> subprocess.CompletedProcess[str]: ...
101
-
102
-
103
- @typing.overload
104
- def run_command(
105
- args: Sequence[str] | xm.SequentialArgs,
106
- env: dict[str, str] | None = ...,
107
- tty: bool = ...,
108
- cwd: str | os.PathLike[str] | None = ...,
109
- check: bool = ...,
110
- return_stdout: Literal[True] = True,
111
- return_stderr: Literal[True] = True,
112
- ) -> subprocess.CompletedProcess[str]: ...
113
-
114
-
115
67
  def run_command(
116
68
  args: Sequence[str] | xm.SequentialArgs,
117
69
  env: dict[str, str] | None = None,
@@ -120,7 +72,7 @@ def run_command(
120
72
  check: bool = False,
121
73
  return_stdout: bool = False,
122
74
  return_stderr: bool = False,
123
- ) -> subprocess.CompletedProcess[str] | subprocess.CompletedProcess[None]:
75
+ ) -> subprocess.CompletedProcess[str]:
124
76
  if isinstance(args, xm.SequentialArgs):
125
77
  args = args.to_list()
126
78
  args = list(args)
@@ -171,7 +123,7 @@ def run_command(
171
123
  fds.remove(fd)
172
124
  continue
173
125
 
174
- data = re.sub(cursor_commands_regex, b"", data)
126
+ data = _CURSOR_ESCAPE_SEQUENCES_REGEX.sub(b"", data)
175
127
 
176
128
  if fd == stdout_master:
177
129
  if return_stdout:
@@ -186,8 +138,8 @@ def run_command(
186
138
  else:
187
139
  raise RuntimeError("Unexpected file descriptor")
188
140
 
189
- stdout = stdout_data.decode(errors="replace") if stdout_data else None
190
- stderr = stderr_data.decode(errors="replace") if stderr_data else None
141
+ stdout = stdout_data.decode(errors="replace") if stdout_data else ""
142
+ stderr = stderr_data.decode(errors="replace") if stderr_data else ""
191
143
 
192
144
  retcode = process.poll()
193
145
  assert retcode is not None