xmanager-slurm 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of xmanager-slurm might be problematic. Click here for more details.

@@ -0,0 +1,200 @@
1
+ import asyncio
2
+ import base64
3
+ import dataclasses
4
+ import enum
5
+ import functools
6
+ import logging
7
+ import zlib
8
+ from typing import Awaitable, Callable, Concatenate, Coroutine, Mapping, ParamSpec, TypeVar
9
+
10
+ import backoff
11
+ import cloudpickle
12
+ from xmanager import xm
13
+
14
+ import xm_slurm
15
+ from xm_slurm import job_blocks, status
16
+ from xm_slurm.experiment import SlurmAuxiliaryUnit, SlurmExperiment
17
+
18
+ P = ParamSpec("P")
19
+ T = TypeVar("T")
20
+
21
+
22
+ async def _monitor_parameter_controller(
23
+ aux_unit: SlurmAuxiliaryUnit,
24
+ local_parameter_controller_coro: Coroutine[None, None, T],
25
+ *,
26
+ poll_interval: float = 30.0,
27
+ ) -> None:
28
+ local_controller_finished = asyncio.Event()
29
+ local_parameter_controller = asyncio.create_task(local_parameter_controller_coro)
30
+
31
+ @local_parameter_controller.add_done_callback
32
+ def _(future: asyncio.Task[T]) -> None:
33
+ try:
34
+ _ = future.result()
35
+ except asyncio.CancelledError:
36
+ logging.info("Local parameter controller was cancelled, resuming on remote controller.")
37
+ pass
38
+ except Exception:
39
+ logging.error("Local parameter controller failed, stopping remote controller.")
40
+ aux_unit.stop(
41
+ mark_as_failed=True, mark_as_completed=False, message="Local controller failed."
42
+ )
43
+ raise
44
+ else:
45
+ logging.info(
46
+ "Local parameter controller finished before remote controller started, "
47
+ "stopping remote controller."
48
+ )
49
+ local_controller_finished.set()
50
+ aux_unit.stop(mark_as_completed=True, message="Local parameter controller finished.")
51
+
52
+ @backoff.on_predicate(
53
+ backoff.constant,
54
+ lambda aux_unit_status: aux_unit_status is status.SlurmWorkUnitStatusEnum.PENDING,
55
+ jitter=None,
56
+ interval=poll_interval,
57
+ )
58
+ async def wait_for_remote_controller() -> status.SlurmWorkUnitStatusEnum:
59
+ logging.info("Waiting for remote parameter controller to start.")
60
+ if local_controller_finished.is_set():
61
+ return status.SlurmWorkUnitStatusEnum.COMPLETED
62
+ return (await aux_unit.get_status()).status
63
+
64
+ logging.info("Monitoring remote parameter controller.")
65
+ # TODO(jfarebro): make get_status() more resilient to errors when initially scheduling.
66
+ # We run into issues if we call get_status() too quickly when Slurm hasn't ingested the job.
67
+ await asyncio.sleep(15)
68
+ match await wait_for_remote_controller():
69
+ case status.SlurmWorkUnitStatusEnum.RUNNING:
70
+ logging.info("Remote parameter controller started.")
71
+ local_parameter_controller.cancel("Remote parameter controller started.")
72
+ case status.SlurmWorkUnitStatusEnum.COMPLETED:
73
+ if local_parameter_controller.done():
74
+ logging.info("Local parameter controller finished, stopping remote controller.")
75
+ aux_unit.stop(
76
+ mark_as_completed=True, message="Local parameter controller finished."
77
+ )
78
+ else:
79
+ logging.info("Remote parameter controller finished, stopping local controller.")
80
+ local_parameter_controller.cancel()
81
+ case status.SlurmWorkUnitStatusEnum.FAILED:
82
+ logging.error("Remote parameter controller failed, stopping local controller.")
83
+ local_parameter_controller.cancel()
84
+ case status.SlurmWorkUnitStatusEnum.CANCELLED:
85
+ logging.info("Remote parameter controller was cancelled, stopping local controller.")
86
+ local_parameter_controller.cancel()
87
+ case status.SlurmWorkUnitStatusEnum.PENDING:
88
+ raise RuntimeError("Remote parameter controller is still pending, invalid state.")
89
+
90
+
91
+ class ParameterControllerMode(enum.Enum):
92
+ AUTO = enum.auto()
93
+ REMOTE_ONLY = enum.auto()
94
+ # TODO(jfarebro): is it possible to get LOCAL_ONLY?
95
+ # We'd need to have a dummy job type as we need to return a JobType?
96
+
97
+
98
+ def parameter_controller(
99
+ *,
100
+ executable: xm.Executable,
101
+ executor: xm.Executor,
102
+ controller_mode: ParameterControllerMode = ParameterControllerMode.AUTO,
103
+ controller_name: str = "parameter_controller",
104
+ controller_args: xm.UserArgs | None = None,
105
+ controller_env_vars: Mapping[str, str] | None = None,
106
+ ) -> Callable[
107
+ [
108
+ Callable[Concatenate[SlurmExperiment, P], T]
109
+ | Callable[Concatenate[SlurmExperiment, P], Awaitable[T]],
110
+ ],
111
+ Callable[P, xm.AuxiliaryUnitJob],
112
+ ]:
113
+ """Converts a function to a controller which can be added to an experiment.
114
+
115
+ Calling the wrapped function would return an xm.JobGenerator which would run
116
+ it as an auxiliary unit on the specified executor.
117
+
118
+ Args:
119
+ executable: An executable that has a Python entrypoint with all the necessary dependencies.
120
+ executor: The executor to launch the controller job on.
121
+ controller_name: Name of the parameter controller job.
122
+ controller_args: Mapping of flag names and values to be used by the XM
123
+ client running inside the parameter controller job.
124
+ controller_env_vars: Mapping of env variable names and values to be passed
125
+ to the parameter controller job.
126
+
127
+ Returns:
128
+ A decorator to be applied to the function.
129
+ """
130
+
131
+ def decorator(
132
+ f: Callable[Concatenate[SlurmExperiment, P], T]
133
+ | Callable[Concatenate[SlurmExperiment, P], Awaitable[T]],
134
+ ) -> Callable[P, xm.AuxiliaryUnitJob]:
135
+ @functools.wraps(f)
136
+ def make_controller(*args: P.args, **kwargs: P.kwargs) -> xm.AuxiliaryUnitJob:
137
+ # Modify the function to read the experiment from the API so that it can be pickled.
138
+
139
+ async def job_generator(aux_unit: SlurmAuxiliaryUnit) -> None:
140
+ experiment_id = aux_unit.experiment.experiment_id
141
+
142
+ async def local_controller(*args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]:
143
+ if asyncio.iscoroutinefunction(f):
144
+ return await f(aux_unit.experiment, *args, **kwargs)
145
+ else:
146
+ return f(aux_unit.experiment, *args, **kwargs)
147
+
148
+ async def remote_controller(*args: P.args, **kwargs: P.kwargs) -> T | Awaitable[T]:
149
+ async with xm_slurm.get_experiment(experiment_id=experiment_id) as exp:
150
+ if asyncio.iscoroutinefunction(f):
151
+ return await f(exp, *args, **kwargs)
152
+ else:
153
+ return f(exp, *args, **kwargs)
154
+
155
+ remote_controller_serialized = base64.urlsafe_b64encode(
156
+ zlib.compress(
157
+ cloudpickle.dumps(
158
+ functools.partial(remote_controller, *args, **kwargs),
159
+ )
160
+ )
161
+ )
162
+
163
+ parameter_controller_executable = dataclasses.replace(
164
+ executable,
165
+ args=xm.merge_args(
166
+ job_blocks.get_args_for_python_entrypoint(
167
+ xm.ModuleName("xm_slurm.scripts._cloudpickle")
168
+ ),
169
+ xm.SequentialArgs.from_collection({
170
+ "cloudpickled_fn": remote_controller_serialized.decode("ascii"),
171
+ }),
172
+ xm.SequentialArgs.from_collection(controller_args),
173
+ ),
174
+ env_vars=controller_env_vars or {},
175
+ )
176
+
177
+ await aux_unit.add(
178
+ xm.Job(
179
+ executor=executor,
180
+ executable=parameter_controller_executable,
181
+ name=controller_name,
182
+ )
183
+ )
184
+
185
+ # Launch local parameter controller and monitor for when it starts running
186
+ # so we can kill the local controller.
187
+ if controller_mode is ParameterControllerMode.AUTO:
188
+ aux_unit._create_task(
189
+ _monitor_parameter_controller(aux_unit, local_controller(*args, **kwargs))
190
+ )
191
+
192
+ return xm.AuxiliaryUnitJob(
193
+ job_generator,
194
+ importance=xm.Importance.HIGH,
195
+ termination_delay_secs=0, # TODO: add support for termination delay.?
196
+ )
197
+
198
+ return make_controller
199
+
200
+ return decorator
xm_slurm/job_blocks.py CHANGED
@@ -1,6 +1,13 @@
1
+ from typing import Mapping, TypedDict
2
+
1
3
  from xmanager import xm
2
4
 
3
5
 
6
+ class JobArgs(TypedDict, total=False):
7
+ args: xm.UserArgs
8
+ env_vars: Mapping[str, str]
9
+
10
+
4
11
  def get_args_for_python_entrypoint(
5
12
  entrypoint: xm.ModuleName | xm.CommandList,
6
13
  ) -> xm.SequentialArgs:
xm_slurm/packageables.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import importlib.resources as resources
2
2
  import pathlib
3
3
  import sys
4
- from typing import Mapping, Sequence
4
+ from typing import Literal, Mapping, Sequence
5
5
 
6
6
  import immutabledict
7
7
  from xmanager import xm
@@ -29,8 +29,8 @@ def docker_image(
29
29
  return xm.Packageable(
30
30
  executor_spec=SlurmSpec(),
31
31
  executable_spec=DockerImage(image=image),
32
- args=args,
33
- env_vars=env_vars,
32
+ args=xm.SequentialArgs.from_collection(args),
33
+ env_vars=dict(env_vars),
34
34
  )
35
35
 
36
36
 
@@ -40,6 +40,7 @@ def docker_container(
40
40
  dockerfile: pathlib.Path | None = None,
41
41
  context: pathlib.Path | None = None,
42
42
  target: str | None = None,
43
+ ssh: Sequence[str] | Literal[True] | None = None,
43
44
  build_args: Mapping[str, str] = immutabledict.immutabledict(),
44
45
  cache_from: str | Sequence[str] | None = None,
45
46
  labels: Mapping[str, str] = immutabledict.immutabledict(),
@@ -54,6 +55,7 @@ def docker_container(
54
55
  dockerfile: The path to the dockerfile.
55
56
  context: The path to the docker context.
56
57
  target: The docker build target.
58
+ ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
57
59
  build_args: Build arguments to docker.
58
60
  cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
59
61
  labels: The container labels.
@@ -69,6 +71,12 @@ def docker_container(
69
71
  if dockerfile is None:
70
72
  dockerfile = context / "Dockerfile"
71
73
  dockerfile = dockerfile.resolve()
74
+
75
+ if ssh is None:
76
+ ssh = []
77
+ elif ssh is True:
78
+ ssh = ["default"]
79
+
72
80
  if cache_from is None and isinstance(executor_spec, SlurmSpec):
73
81
  cache_from = executor_spec.tag
74
82
  if cache_from is None:
@@ -82,13 +90,14 @@ def docker_container(
82
90
  dockerfile=dockerfile,
83
91
  context=context,
84
92
  target=target,
93
+ ssh=ssh,
85
94
  build_args=build_args,
86
95
  cache_from=cache_from,
87
96
  workdir=workdir,
88
97
  labels=labels,
89
98
  ),
90
- args=args,
91
- env_vars=env_vars,
99
+ args=xm.SequentialArgs.from_collection(args),
100
+ env_vars=dict(env_vars),
92
101
  )
93
102
 
94
103
 
@@ -99,14 +108,18 @@ def python_container(
99
108
  context: pathlib.Path | None = None,
100
109
  requirements: pathlib.Path | None = None,
101
110
  base_image: str = "docker.io/python:{major}.{minor}-slim",
111
+ extra_system_packages: Sequence[str] = (),
112
+ extra_python_packages: Sequence[str] = (),
102
113
  cache_from: str | Sequence[str] | None = None,
103
114
  labels: Mapping[str, str] = immutabledict.immutabledict(),
115
+ ssh: Sequence[str] | Literal[True] | None = None,
104
116
  args: xm.UserArgs | None = None,
105
117
  env_vars: Mapping[str, str] = immutabledict.immutabledict(),
106
118
  ) -> xm.Packageable:
107
119
  """Creates a Python container from a base image using pip from a `requirements.txt` file.
108
120
 
109
121
  NOTE: The base image will use the Python version of the current interpreter.
122
+ NOTE: uv is used to install packages from `requirements`.
110
123
 
111
124
  Args:
112
125
  executor_spec: The executor specification, where will the container be stored at.
@@ -114,8 +127,11 @@ def python_container(
114
127
  context: The path to the docker context.
115
128
  requirements: The path to the pip requirements file.
116
129
  base_image: The base image to use. NOTE: The base image must contain the Python runtime.
130
+ extra_system_packages: Additional system packages to install. NOTE: These are installed via `apt-get`.
131
+ extra_python_packages: Additional Python packages to install. NOTE: These are installed via `uv pip`.
117
132
  cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
118
133
  labels: The container labels.
134
+ ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
119
135
  args: The user arguments to pass to the executable.
120
136
  env_vars: The environment variables to pass to the executable.
121
137
 
@@ -144,11 +160,11 @@ def python_container(
144
160
  executor_spec=executor_spec,
145
161
  dockerfile=dockerfile,
146
162
  context=context,
163
+ ssh=ssh,
147
164
  build_args={
148
165
  "PIP_REQUIREMENTS": requirements.relative_to(context).as_posix(),
149
- "PYTHON_MAJOR": str(sys.version_info.major),
150
- "PYTHON_MINOR": str(sys.version_info.minor),
151
- "PYTHON_MICRO": str(sys.version_info.micro),
166
+ "EXTRA_SYSTEM_PACKAGES": " ".join(extra_system_packages),
167
+ "EXTRA_PYTHON_PACKAGES": " ".join(extra_python_packages),
152
168
  "BASE_IMAGE": base_image.format_map({
153
169
  "major": sys.version_info.major,
154
170
  "minor": sys.version_info.minor,
@@ -173,6 +189,7 @@ def mamba_container(
173
189
  base_image: str = "gcr.io/distroless/base-debian10",
174
190
  cache_from: str | Sequence[str] | None = None,
175
191
  labels: Mapping[str, str] = immutabledict.immutabledict(),
192
+ ssh: Sequence[str] | Literal[True] | None = None,
176
193
  args: xm.UserArgs | None = None,
177
194
  env_vars: Mapping[str, str] = immutabledict.immutabledict(),
178
195
  ) -> xm.Packageable:
@@ -188,6 +205,7 @@ def mamba_container(
188
205
  base_image: The base image to use.
189
206
  cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
190
207
  labels: The container labels.
208
+ ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
191
209
  args: The user arguments to pass to the executable.
192
210
  env_vars: The environment variables to pass to the executable.
193
211
 
@@ -216,6 +234,7 @@ def mamba_container(
216
234
  executor_spec=executor_spec,
217
235
  dockerfile=dockerfile,
218
236
  context=context,
237
+ ssh=ssh,
219
238
  build_args={
220
239
  "CONDA_ENVIRONMENT": environment.relative_to(context).as_posix(),
221
240
  "BASE_IMAGE": base_image,
@@ -232,26 +251,32 @@ def mamba_container(
232
251
  conda_container = mamba_container
233
252
 
234
253
 
235
- def pdm_container(
254
+ def uv_container(
236
255
  *,
237
256
  executor_spec: xm.ExecutorSpec,
238
257
  entrypoint: xm.ModuleName | xm.CommandList,
239
258
  context: pathlib.Path | None = None,
240
- base_image: str = "docker.io/python:{major}.{minor}-slim",
259
+ base_image: str = "docker.io/python:{major}.{minor}-slim-bookworm",
260
+ extra_system_packages: Sequence[str] = (),
261
+ extra_python_packages: Sequence[str] = (),
241
262
  cache_from: str | Sequence[str] | None = None,
242
263
  labels: Mapping[str, str] = immutabledict.immutabledict(),
264
+ ssh: Sequence[str] | Literal[True] | None = None,
243
265
  args: xm.UserArgs | None = None,
244
266
  env_vars: Mapping[str, str] = immutabledict.immutabledict(),
245
267
  ) -> xm.Packageable:
246
- """Creates a Python container from a base image using pdm from a `pdm.lock` file.
268
+ """Creates a Python container from a base image using uv from a `uv.lock` file.
247
269
 
248
270
  Args:
249
271
  executor_spec: The executor specification, where will the container be stored at.
250
272
  entrypoint: The entrypoint to run in the container.
251
273
  context: The path to the docker context.
252
274
  base_image: The base image to use. NOTE: The base image must contain the Python runtime.
275
+ extra_system_packages: Additional system packages to install. NOTE: These are installed via `apt-get`.
276
+ extra_python_packages: Additional Python packages to install. NOTE: These are installed via `uv pip`.
253
277
  cache_from: Where to pull the BuildKit cache from. See `--cache-from` in `docker build`.
254
278
  labels: The container labels.
279
+ ssh: A list of SSH sockets/keys for the docker build step or `True` to use the default SSH agent.
255
280
  args: The user arguments to pass to the executable.
256
281
  env_vars: The environment variables to pass to the executable.
257
282
 
@@ -263,20 +288,22 @@ def pdm_container(
263
288
  if context is None:
264
289
  context = utils.find_project_root()
265
290
  context = context.resolve()
266
- if not (context / "pdm.lock").exists():
267
- raise ValueError(f"PDM lockfile `{context / 'pdm.lock'}` doesn't exist.")
291
+ if not (context / "pyproject.toml").exists():
292
+ raise ValueError(f"Python project file `{context / 'pyproject.toml'}` doesn't exist.")
293
+ if not (context / "uv.lock").exists():
294
+ raise ValueError(f"UV lock file `{context / 'uv.lock'}` doesn't exist.")
268
295
 
269
296
  with resources.as_file(
270
- resources.files("xm_slurm.templates").joinpath("docker/pdm.Dockerfile")
297
+ resources.files("xm_slurm.templates").joinpath("docker/uv.Dockerfile")
271
298
  ) as dockerfile:
272
299
  return docker_container(
273
300
  executor_spec=executor_spec,
274
301
  dockerfile=dockerfile,
275
302
  context=context,
303
+ ssh=ssh,
276
304
  build_args={
277
- "PYTHON_MAJOR": str(sys.version_info.major),
278
- "PYTHON_MINOR": str(sys.version_info.minor),
279
- "PYTHON_MICRO": str(sys.version_info.micro),
305
+ "EXTRA_SYSTEM_PACKAGES": " ".join(extra_system_packages),
306
+ "EXTRA_PYTHON_PACKAGES": " ".join(extra_python_packages),
280
307
  "BASE_IMAGE": base_image.format_map({
281
308
  "major": sys.version_info.major,
282
309
  "minor": sys.version_info.minor,
@@ -286,7 +313,7 @@ def pdm_container(
286
313
  cache_from=cache_from,
287
314
  labels=labels,
288
315
  # We must specify the workdir manually for apptainer support
289
- workdir=pathlib.Path("/workspace/src"),
316
+ workdir=pathlib.Path("/workspace"),
290
317
  args=args,
291
318
  env_vars=env_vars,
292
319
  )
@@ -11,8 +11,8 @@ from xm_slurm.packaging import registry
11
11
  from xm_slurm.packaging.docker.abc import DockerClient
12
12
 
13
13
  FLAGS = flags.FLAGS
14
- REMOTE_BUILD = flags.DEFINE_enum(
15
- "xm_builder", "local", ["local", "gcp", "azure"], "Remote build provider."
14
+ DOCKER_CLIENT_PROVIDER = flags.DEFINE_enum(
15
+ "xm_docker_client", "docker", ["docker"], "Docker image build client."
16
16
  )
17
17
 
18
18
  IndexedContainer = registry.IndexedContainer
@@ -20,19 +20,13 @@ IndexedContainer = registry.IndexedContainer
20
20
 
21
21
  @functools.cache
22
22
  def docker_client() -> DockerClient:
23
- match REMOTE_BUILD.value:
24
- case "local":
23
+ match DOCKER_CLIENT_PROVIDER.value:
24
+ case "docker":
25
25
  from xm_slurm.packaging.docker.local import LocalDockerClient
26
26
 
27
27
  return LocalDockerClient()
28
- case "gcp":
29
- from xm_slurm.packaging.docker.cloud import GoogleCloudRemoteDockerClient
30
-
31
- return GoogleCloudRemoteDockerClient()
32
- case "azure":
33
- raise NotImplementedError("Azure remote build is not yet supported.")
34
28
  case _:
35
- raise ValueError(f"Unknown remote build provider: {REMOTE_BUILD.value}")
29
+ raise ValueError(f"Unknown build client: {DOCKER_CLIENT_PROVIDER.value}")
36
30
 
37
31
 
38
32
  @registry.register(Dockerfile)
@@ -9,7 +9,6 @@ import shlex
9
9
  import shutil
10
10
  import subprocess
11
11
  import tempfile
12
- import typing
13
12
  from typing import Sequence
14
13
 
15
14
  from xmanager import xm
@@ -37,7 +36,7 @@ class LocalDockerClient(DockerClient):
37
36
  BUILDKIT = enum.auto()
38
37
  BUILDAH = enum.auto()
39
38
 
40
- def __init__(self):
39
+ def __init__(self) -> None:
41
40
  if "XM_DOCKER_CLIENT" in os.environ:
42
41
  client_call = shlex.split(os.environ["XM_DOCKER_CLIENT"])
43
42
  elif shutil.which("docker"):
@@ -80,7 +79,7 @@ class LocalDockerClient(DockerClient):
80
79
  return None
81
80
 
82
81
  def _parse_credentials_from_config(
83
- config_path: pathlib.Path
82
+ config_path: pathlib.Path,
84
83
  ) -> RemoteRepositoryCredentials | None:
85
84
  """Parse credentials from the Docker configuration file."""
86
85
  if not config_path.exists():
@@ -138,9 +137,13 @@ class LocalDockerClient(DockerClient):
138
137
  targets: Sequence[IndexedContainer[xm.Packageable]],
139
138
  ) -> list[IndexedContainer[RemoteImage]]:
140
139
  executors_by_executables = packaging_utils.collect_executors_by_executable(targets)
141
- executors_by_executables = typing.cast(
142
- dict[Dockerfile, list[SlurmSpec]], executors_by_executables
143
- )
140
+ for executable, executors in executors_by_executables.items():
141
+ assert isinstance(
142
+ executable, Dockerfile
143
+ ), "All executables must be Dockerfiles when building Docker images."
144
+ assert all(
145
+ isinstance(executor, SlurmSpec) and executor.tag for executor in executors
146
+ ), "All executors must be SlurmSpecs with tags when building Docker images."
144
147
 
145
148
  with tempfile.TemporaryDirectory() as tempdir:
146
149
  hcl_file = pathlib.Path(tempdir) / "docker-bake.hcl"
@@ -157,12 +160,10 @@ class LocalDockerClient(DockerClient):
157
160
  try:
158
161
  command = DockerBakeCommand(
159
162
  targets=list(
160
- set(
161
- [
162
- packaging_utils.hash_digest(target.value.executable_spec)
163
- for target in targets
164
- ]
165
- )
163
+ set([
164
+ packaging_utils.hash_digest(target.value.executable_spec)
165
+ for target in targets
166
+ ])
166
167
  ),
167
168
  files=[hcl_file],
168
169
  metadata_file=metadata_file,
@@ -10,8 +10,7 @@ import re
10
10
  import select
11
11
  import shutil
12
12
  import subprocess
13
- import typing
14
- from typing import Callable, Concatenate, Hashable, Literal, ParamSpec, Sequence, TypeVar
13
+ from typing import Callable, Concatenate, Hashable, ParamSpec, Sequence, TypeVar
15
14
 
16
15
  from xmanager import xm
17
16
 
@@ -23,6 +22,7 @@ ReturnT = TypeVar("ReturnT")
23
22
 
24
23
 
25
24
  def hash_digest(obj: Hashable) -> str:
25
+ # TODO(jfarebro): Need a better way to hash these objects
26
26
  # obj_hash = hash(obj)
27
27
  # unsigned_obj_hash = obj_hash.from_bytes(
28
28
  # obj_hash.to_bytes((obj_hash.bit_length() + 7) // 8, "big", signed=True),
@@ -54,7 +54,7 @@ def parallel_map(
54
54
 
55
55
 
56
56
  # Cursor commands to filter out from the command data stream
57
- cursor_commands_regex = re.compile(
57
+ _CURSOR_ESCAPE_SEQUENCES_REGEX = re.compile(
58
58
  rb"\x1b\[\?25[hl]" # Matches cursor show/hide commands (CSI ?25h and CSI ?25l)
59
59
  rb"|\x1b\[[0-9;]*[Hf]" # Matches cursor position commands (CSI n;mH and CSI n;mf)
60
60
  rb"|\x1b\[s" # Matches cursor save position (CSI s)
@@ -64,54 +64,6 @@ cursor_commands_regex = re.compile(
64
64
  )
65
65
 
66
66
 
67
- @typing.overload
68
- def run_command(
69
- args: Sequence[str] | xm.SequentialArgs,
70
- env: dict[str, str] | None = ...,
71
- tty: bool = ...,
72
- cwd: str | os.PathLike[str] | None = ...,
73
- check: bool = ...,
74
- return_stdout: Literal[False] = False,
75
- return_stderr: Literal[False] = False,
76
- ) -> subprocess.CompletedProcess[None]: ...
77
-
78
-
79
- @typing.overload
80
- def run_command(
81
- args: Sequence[str] | xm.SequentialArgs,
82
- env: dict[str, str] | None = ...,
83
- tty: bool = ...,
84
- cwd: str | os.PathLike[str] | None = ...,
85
- check: bool = ...,
86
- return_stdout: Literal[True] = True,
87
- return_stderr: Literal[False] = False,
88
- ) -> subprocess.CompletedProcess[str]: ...
89
-
90
-
91
- @typing.overload
92
- def run_command(
93
- args: Sequence[str] | xm.SequentialArgs,
94
- env: dict[str, str] | None = ...,
95
- tty: bool = ...,
96
- cwd: str | os.PathLike[str] | None = ...,
97
- check: bool = ...,
98
- return_stdout: Literal[False] = False,
99
- return_stderr: Literal[True] = True,
100
- ) -> subprocess.CompletedProcess[str]: ...
101
-
102
-
103
- @typing.overload
104
- def run_command(
105
- args: Sequence[str] | xm.SequentialArgs,
106
- env: dict[str, str] | None = ...,
107
- tty: bool = ...,
108
- cwd: str | os.PathLike[str] | None = ...,
109
- check: bool = ...,
110
- return_stdout: Literal[True] = True,
111
- return_stderr: Literal[True] = True,
112
- ) -> subprocess.CompletedProcess[str]: ...
113
-
114
-
115
67
  def run_command(
116
68
  args: Sequence[str] | xm.SequentialArgs,
117
69
  env: dict[str, str] | None = None,
@@ -120,7 +72,7 @@ def run_command(
120
72
  check: bool = False,
121
73
  return_stdout: bool = False,
122
74
  return_stderr: bool = False,
123
- ) -> subprocess.CompletedProcess[str] | subprocess.CompletedProcess[None]:
75
+ ) -> subprocess.CompletedProcess[str]:
124
76
  if isinstance(args, xm.SequentialArgs):
125
77
  args = args.to_list()
126
78
  args = list(args)
@@ -171,7 +123,7 @@ def run_command(
171
123
  fds.remove(fd)
172
124
  continue
173
125
 
174
- data = re.sub(cursor_commands_regex, b"", data)
126
+ data = _CURSOR_ESCAPE_SEQUENCES_REGEX.sub(b"", data)
175
127
 
176
128
  if fd == stdout_master:
177
129
  if return_stdout:
@@ -186,8 +138,8 @@ def run_command(
186
138
  else:
187
139
  raise RuntimeError("Unexpected file descriptor")
188
140
 
189
- stdout = stdout_data.decode(errors="replace") if stdout_data else None
190
- stderr = stderr_data.decode(errors="replace") if stderr_data else None
141
+ stdout = stdout_data.decode(errors="replace") if stdout_data else ""
142
+ stderr = stderr_data.decode(errors="replace") if stderr_data else ""
191
143
 
192
144
  retcode = process.poll()
193
145
  assert retcode is not None