flytekitplugins-slurm 1.15.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flytekitplugins_slurm-1.15.1/PKG-INFO +30 -0
- flytekitplugins_slurm-1.15.1/README.md +5 -0
- flytekitplugins_slurm-1.15.1/flytekitplugins/slurm/__init__.py +4 -0
- flytekitplugins_slurm-1.15.1/flytekitplugins/slurm/function/agent.py +190 -0
- flytekitplugins_slurm-1.15.1/flytekitplugins/slurm/function/task.py +79 -0
- flytekitplugins_slurm-1.15.1/flytekitplugins/slurm/script/agent.py +191 -0
- flytekitplugins_slurm-1.15.1/flytekitplugins/slurm/script/task.py +136 -0
- flytekitplugins_slurm-1.15.1/flytekitplugins/slurm/ssh_utils.py +190 -0
- flytekitplugins_slurm-1.15.1/flytekitplugins_slurm.egg-info/PKG-INFO +30 -0
- flytekitplugins_slurm-1.15.1/flytekitplugins_slurm.egg-info/SOURCES.txt +17 -0
- flytekitplugins_slurm-1.15.1/flytekitplugins_slurm.egg-info/dependency_links.txt +1 -0
- flytekitplugins_slurm-1.15.1/flytekitplugins_slurm.egg-info/entry_points.txt +2 -0
- flytekitplugins_slurm-1.15.1/flytekitplugins_slurm.egg-info/namespace_packages.txt +1 -0
- flytekitplugins_slurm-1.15.1/flytekitplugins_slurm.egg-info/requires.txt +3 -0
- flytekitplugins_slurm-1.15.1/flytekitplugins_slurm.egg-info/top_level.txt +1 -0
- flytekitplugins_slurm-1.15.1/setup.cfg +4 -0
- flytekitplugins_slurm-1.15.1/setup.py +41 -0
- flytekitplugins_slurm-1.15.1/tests/test_slurm_fn_task.py +55 -0
- flytekitplugins_slurm-1.15.1/tests/test_slurm_shell_task.py +77 -0
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
|
+
Name: flytekitplugins-slurm
|
|
3
|
+
Version: 1.15.1
|
|
4
|
+
Summary: This package holds the Slurm plugins for flytekit
|
|
5
|
+
Author: flyteorg
|
|
6
|
+
Author-email: admin@flyte.org
|
|
7
|
+
License: apache2
|
|
8
|
+
Classifier: Intended Audience :: Science/Research
|
|
9
|
+
Classifier: Intended Audience :: Developers
|
|
10
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Classifier: Topic :: Scientific/Engineering
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
+
Classifier: Topic :: Software Development
|
|
18
|
+
Classifier: Topic :: Software Development :: Libraries
|
|
19
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
20
|
+
Requires-Python: >=3.9
|
|
21
|
+
Requires-Dist: flytekit>=1.15.0
|
|
22
|
+
Requires-Dist: flyteidl>=1.15.0
|
|
23
|
+
Requires-Dist: asyncssh
|
|
24
|
+
Dynamic: author
|
|
25
|
+
Dynamic: author-email
|
|
26
|
+
Dynamic: classifier
|
|
27
|
+
Dynamic: license
|
|
28
|
+
Dynamic: requires-dist
|
|
29
|
+
Dynamic: requires-python
|
|
30
|
+
Dynamic: summary
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
# Flytekit Slurm Plugin
|
|
2
|
+
|
|
3
|
+
The Slurm agent is designed to integrate Flyte workflows with Slurm-managed high-performance computing (HPC) clusters, enabling users to leverage Slurm's capabilities for compute resource allocation, scheduling, and monitoring.
|
|
4
|
+
|
|
5
|
+
This [guide](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md) provides a concise overview of the design philosophy behind the Slurm agent and explains how to set up a local environment for testing the agent.
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
import tempfile
|
|
2
|
+
import uuid
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
5
|
+
|
|
6
|
+
from asyncssh import SSHClientConnection
|
|
7
|
+
|
|
8
|
+
from flytekit import logger
|
|
9
|
+
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
|
|
10
|
+
from flytekit.extend.backend.utils import convert_to_flyte_phase
|
|
11
|
+
from flytekit.models.literals import LiteralMap
|
|
12
|
+
from flytekit.models.task import TaskTemplate
|
|
13
|
+
|
|
14
|
+
from ..ssh_utils import ssh_connect
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass
class SlurmJobMetadata(ResourceMeta):
    """Identifies a Slurm job submitted by the Slurm function agent.

    ``SlurmFunctionAgent.create`` returns this object, and ``get``/``delete`` receive it back,
    so it carries everything needed to reach the job again.

    Attributes:
        job_id (str): Slurm job id.
        ssh_config (Dict[str, Union[str, List[str], Tuple[str, ...]]]): SSH client connection
            options used to reach the Slurm cluster. See the ssh_utils module for the
            available options.
    """

    job_id: str
    ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class SlurmCluster:
    """Key for the SSH connection pool: a Slurm cluster as a (host, username) pair.

    Attributes:
        host (str): Hostname or address of the Slurm login node to connect to.
        username (Optional[str]): User to authenticate as, if given.
    """

    host: str
    username: Optional[str] = None

    def __hash__(self):
        # Hash exactly the fields that the generated __eq__ compares.
        identity = (self.host, self.username)
        return hash(identity)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class SlurmFunctionAgent(AsyncAgentBase):
    """Agent for ``slurm_fn`` tasks: runs the Flyte entrypoint as a Slurm batch job over SSH.

    ``create`` uploads a generated batch script via SFTP and submits it with ``sbatch``.
    ``get`` polls ``scontrol show job``, and ``delete`` cancels the job with ``scancel``.
    """

    name = "Slurm Function Agent"

    # SSH connection pool for multi-host environment, keyed by (host, username).
    # NOTE: this is a class attribute, so the pool is shared by all instances.
    ssh_config_to_ssh_conn: Dict[SlurmCluster, SSHClientConnection] = {}

    def __init__(self) -> None:
        super(SlurmFunctionAgent, self).__init__(task_type_name="slurm_fn", metadata_type=SlurmJobMetadata)

    async def create(
        self,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
        **kwargs,
    ) -> SlurmJobMetadata:
        """Upload the batch script to the Slurm cluster and submit it with sbatch.

        Args:
            task_template: Serialized task. ``custom`` must contain ``ssh_config``, ``sbatch_conf``
                and ``script``; the container args form the Flyte entrypoint.
            inputs: Not used by this method.

        Returns:
            SlurmJobMetadata: The Slurm job id and the SSH config needed to reach the job again.
        """
        # A unique remote path keeps concurrent submissions from overwriting each other's scripts.
        unique_script_name = f"/tmp/task_{uuid.uuid4().hex}.slurm"

        # Retrieve task config
        ssh_config = task_template.custom["ssh_config"]
        sbatch_conf = task_template.custom["sbatch_conf"]
        script = task_template.custom["script"]

        # Construct command for Slurm cluster
        cmd, script = _get_sbatch_cmd_and_script(
            sbatch_conf=sbatch_conf,
            entrypoint=" ".join(task_template.container.args),
            script=script,
            batch_script_path=unique_script_name,
        )

        # Run Slurm job
        conn = await self._get_or_create_ssh_connection(ssh_config)
        with tempfile.NamedTemporaryFile("w") as f:
            f.write(script)
            f.flush()
            async with conn.start_sftp_client() as sftp:
                await sftp.put(f.name, unique_script_name)
            res = await conn.run(cmd, check=True)

        # The job id is the last whitespace-separated token of sbatch's output
        # (e.g. "Submitted batch job 123").
        job_id = res.stdout.split()[-1]

        return SlurmJobMetadata(job_id=job_id, ssh_config=ssh_config)

    async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource:
        """Poll the Slurm job and translate its state into a Flyte phase.

        When the job's StdOut file can be read, its contents are returned as the message.
        Otherwise the message is "No stdout available".
        """
        import shlex

        conn = await self._get_or_create_ssh_connection(resource_meta.ssh_config)
        job_res = await conn.run(f"scontrol show job {resource_meta.job_id}", check=True)

        # Determine the current flyte phase from Slurm job state.
        # Default to "running" if no JobState field is found.
        job_state = "running"
        msg = "No stdout available"
        # scontrol prints Key=Value pairs separated by spaces and newlines, so split on any whitespace.
        for token in job_res.stdout.split():
            key, sep, value = token.partition("=")
            if not sep:
                continue
            if key == "JobState":
                job_state = value.strip().lower()
            elif key == "StdOut":
                stdout_path = value.strip()
                # The StdOut file may not exist yet (e.g. while the job is pending).
                # A failed read must not abort status polling.
                msg_res = await conn.run(f"cat {shlex.quote(stdout_path)}", check=False)
                if msg_res.exit_status == 0:
                    msg = msg_res.stdout

        cur_phase = convert_to_flyte_phase(job_state)

        return Resource(phase=cur_phase, message=msg)

    async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None:
        """Cancel the Slurm job with scancel."""
        conn = await self._get_or_create_ssh_connection(resource_meta.ssh_config)
        _ = await conn.run(f"scancel {resource_meta.job_id}", check=True)

    async def _get_or_create_ssh_connection(
        self, ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]]
    ) -> SSHClientConnection:
        """Get an existing SSH connection or create a new one if needed.

        A pooled connection is health-checked with a trivial remote command. If the check fails,
        the connection is closed and replaced by a fresh one.

        Args:
            ssh_config (Dict[str, Union[str, List[str], Tuple[str, ...]]]): SSH configuration dictionary.

        Returns:
            SSHClientConnection: An active SSH connection, either pre-existing or newly established.
        """
        host = ssh_config.get("host")
        username = ssh_config.get("username")

        ssh_cluster_config = SlurmCluster(host=host, username=username)
        conn = self.ssh_config_to_ssh_conn.get(ssh_cluster_config)
        if conn is None:
            logger.info("ssh connection key not found, creating new connection")
            conn = await ssh_connect(ssh_config=ssh_config)
            self.ssh_config_to_ssh_conn[ssh_cluster_config] = conn
        else:
            try:
                await conn.run("echo [TEST] SSH connection", check=True)
                logger.info("re-using existing connection")
            except Exception as e:
                logger.info(f"Re-establishing SSH connection due to error: {e}")
                # Release the stale connection before replacing it in the pool.
                conn.close()
                conn = await ssh_connect(ssh_config=ssh_config)
                self.ssh_config_to_ssh_conn[ssh_cluster_config] = conn

        return conn
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def _get_sbatch_cmd_and_script(
|
|
146
|
+
sbatch_conf: Dict[str, str],
|
|
147
|
+
entrypoint: str,
|
|
148
|
+
script: Optional[str] = None,
|
|
149
|
+
batch_script_path: str = "/tmp/task.slurm",
|
|
150
|
+
) -> Tuple[str, str]:
|
|
151
|
+
"""Construct the Slurm sbatch command and the batch script content.
|
|
152
|
+
|
|
153
|
+
Flyte entrypoint, pyflyte-execute, is run within a bash shell process.
|
|
154
|
+
|
|
155
|
+
Args:
|
|
156
|
+
sbatch_conf (Dict[str, str]): Options of sbatch command.
|
|
157
|
+
entrypoint (str): Flyte entrypoint.
|
|
158
|
+
script (Optional[str], optional): User-defined script where "{task.fn}" serves as a placeholder for the
|
|
159
|
+
task function execution. Users should insert "{task.fn}" at the desired
|
|
160
|
+
execution point within the script. If the script is not provided, the task
|
|
161
|
+
function will be executed directly. Defaults to None.
|
|
162
|
+
batch_script_path (str, optional): Absolute path of the batch script on Slurm cluster.
|
|
163
|
+
Defaults to "/tmp/task.slurm".
|
|
164
|
+
|
|
165
|
+
Returns:
|
|
166
|
+
Tuple[str, str]: A tuple containing:
|
|
167
|
+
- cmd: Slurm sbatch command
|
|
168
|
+
- script: The batch script content
|
|
169
|
+
"""
|
|
170
|
+
# Setup sbatch options
|
|
171
|
+
cmd = ["sbatch"]
|
|
172
|
+
for opt, val in sbatch_conf.items():
|
|
173
|
+
cmd.extend([f"--{opt}", str(val)])
|
|
174
|
+
|
|
175
|
+
# Assign the batch script to run
|
|
176
|
+
cmd.append(batch_script_path)
|
|
177
|
+
|
|
178
|
+
if script is None:
|
|
179
|
+
script = f"""#!/bin/bash -i
|
|
180
|
+
{entrypoint}
|
|
181
|
+
"""
|
|
182
|
+
else:
|
|
183
|
+
script = script.replace("{task.fn}", entrypoint)
|
|
184
|
+
|
|
185
|
+
cmd = " ".join(cmd)
|
|
186
|
+
|
|
187
|
+
return cmd, script
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
# Register a SlurmFunctionAgent instance with Flyte's AgentRegistry when this module is imported.
AgentRegistry.register(SlurmFunctionAgent())
|
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
3
|
+
|
|
4
|
+
from flytekit import FlyteContextManager, PythonFunctionTask
|
|
5
|
+
from flytekit.configuration import SerializationSettings
|
|
6
|
+
from flytekit.extend import TaskPlugins
|
|
7
|
+
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
|
|
8
|
+
from flytekit.image_spec import ImageSpec
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class SlurmFunction(object):
    """Configure Slurm settings. Note that we focus on sbatch command now.

    Args:
        ssh_config: Options of SSH client connection. For available options, please refer to
            ``SSHConfig`` in the ssh_utils module. ``host`` is required.
        sbatch_conf: Options of sbatch command. If not provided, defaults to an empty dict.
        script: User-defined script where "{task.fn}" serves as a placeholder for the
            task function execution. Users should insert "{task.fn}" at the desired
            execution point within the script. If the script is not provided, the task
            function will be executed directly.

    Attributes:
        ssh_config (Dict[str, Union[str, List[str], Tuple[str, ...]]]): SSH client configuration options.
        sbatch_conf (Optional[Dict[str, str]]): Slurm sbatch command options.
        script (Optional[str]): Custom script template for task execution.

    Raises:
        ValueError: If ``ssh_config`` has no non-empty ``host``.
    """

    ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]]
    sbatch_conf: Optional[Dict[str, str]] = None
    script: Optional[str] = None

    def __post_init__(self):
        # Use an explicit check: indexing a missing "host" raised KeyError, and `assert`
        # is stripped when Python runs with -O.
        if not self.ssh_config.get("host"):
            raise ValueError("'host' must be specified in ssh_config.")
        if self.sbatch_conf is None:
            self.sbatch_conf = {}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class SlurmFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[SlurmFunction]):
    """Python function task of type ``slurm_fn`` whose execution is delegated to a Slurm cluster.

    In a local execution the agent flow is emulated. Otherwise the wrapped function is called
    directly.
    """

    _TASK_TYPE = "slurm_fn"

    def __init__(
        self,
        task_config: SlurmFunction,
        task_function: Callable,
        container_image: Optional[Union[str, ImageSpec]] = None,
        **kwargs,
    ):
        super().__init__(
            task_type=self._TASK_TYPE,
            task_config=task_config,
            task_function=task_function,
            container_image=container_image,
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """Serialize the Slurm settings that the agent reads from ``task_template.custom``."""
        cfg = self.task_config
        return dict(
            ssh_config=cfg.ssh_config,
            sbatch_conf=cfg.sbatch_conf,
            script=cfg.script,
        )

    def execute(self, **kwargs) -> Any:
        """Run through the agent in local executions; otherwise call the function directly."""
        state = FlyteContextManager.current_context().execution_state
        if not (state and state.is_local_execution()):
            # Execute the task with a direct python function call
            return PythonFunctionTask.execute(self, **kwargs)
        # Mimic the propeller's behavior in local agent test
        return AsyncAgentExecutorMixin.execute(self, **kwargs)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
# Associate the SlurmFunction config type with the SlurmFunctionTask plugin class.
TaskPlugins.register_pythontask_plugin(SlurmFunction, SlurmFunctionTask)
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
import tempfile
|
|
2
|
+
import uuid
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, List, Optional, Tuple, Type
|
|
5
|
+
|
|
6
|
+
from asyncssh import SSHClientConnection
|
|
7
|
+
|
|
8
|
+
import flytekit
|
|
9
|
+
from flytekit.core.type_engine import TypeEngine
|
|
10
|
+
from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
|
|
11
|
+
from flytekit.extend.backend.utils import convert_to_flyte_phase
|
|
12
|
+
from flytekit.extras.tasks.shell import OutputLocation, _PythonFStringInterpolizer
|
|
13
|
+
from flytekit.models.literals import LiteralMap
|
|
14
|
+
from flytekit.models.task import TaskTemplate
|
|
15
|
+
|
|
16
|
+
from ..ssh_utils import SlurmCluster, get_ssh_conn
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class SlurmJobMetadata(ResourceMeta):
    """Tracks a Slurm job submitted by the Slurm script agent.

    Attributes:
        job_id (str): Slurm job id parsed from the sbatch output.
        ssh_config (Dict[str, Any]): SSH connection options used to reach the cluster again.
        outputs (Dict[str, str]): Output variable name -> interpolated output location.
    """

    job_id: str
    ssh_config: Dict[str, Any]
    outputs: Dict[str, str]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class SlurmScriptAgent(AsyncAgentBase):
    """Agent for ``slurm`` tasks: submits batch scripts to a Slurm cluster over SSH.

    If the task template carries a user-defined ``script``, it is interpolated with the task
    inputs and output locations, uploaded to the cluster and submitted. Otherwise the script at
    ``batch_script_path`` is assumed to already exist on the cluster.
    """

    name = "Slurm Script Agent"

    # SSH connection pool for multi-host environment.
    # NOTE: this is a class attribute, so the pool is shared by all instances.
    slurm_cluster_to_ssh_conn: Dict[SlurmCluster, SSHClientConnection] = {}

    # Dummy script content; a user script equal to this is rejected by create().
    DUMMY_SCRIPT = "#!/bin/bash"

    def __init__(self) -> None:
        super(SlurmScriptAgent, self).__init__(task_type_name="slurm", metadata_type=SlurmJobMetadata)

    async def create(
        self,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
        **kwargs,
    ) -> SlurmJobMetadata:
        """Submit the task's batch script with sbatch.

        Args:
            task_template: Serialized task whose ``custom`` holds the Slurm settings.
            inputs: Task inputs, used to interpolate a user-defined script.

        Returns:
            SlurmJobMetadata: Job id, SSH config and interpolated output locations.

        Raises:
            ValueError: If the user-defined script is just the dummy shebang line.
        """
        # A unique remote path keeps concurrent submissions from overwriting each other's scripts.
        uniq_script_path = f"/tmp/task_{uuid.uuid4().hex}.slurm"
        outputs = {}

        # Retrieve task config
        ssh_config = task_template.custom["ssh_config"]
        batch_script_args = task_template.custom["batch_script_args"]
        sbatch_conf = task_template.custom["sbatch_conf"]

        # Construct sbatch command for Slurm cluster
        upload_script = False
        if "script" in task_template.custom:
            script = task_template.custom["script"]
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            if script == self.DUMMY_SCRIPT:
                raise ValueError("Please write the user-defined batch script content.")
            script, outputs = self._interpolate_script(
                script,
                input_literal_map=inputs,
                python_input_types=task_template.custom["python_input_types"],
                output_locs=task_template.custom["output_locs"],
            )

            batch_script_path = uniq_script_path
            upload_script = True
        else:
            # Assume the batch script is already on Slurm
            batch_script_path = task_template.custom["batch_script_path"]
        cmd = _get_sbatch_cmd(
            sbatch_conf=sbatch_conf, batch_script_path=batch_script_path, batch_script_args=batch_script_args
        )

        # Run Slurm job
        conn = await get_ssh_conn(ssh_config=ssh_config, slurm_cluster_to_ssh_conn=self.slurm_cluster_to_ssh_conn)
        if upload_script:
            with tempfile.NamedTemporaryFile("w") as f:
                f.write(script)
                f.flush()
                async with conn.start_sftp_client() as sftp:
                    await sftp.put(f.name, batch_script_path)
        res = await conn.run(cmd, check=True)

        # The job id is the last whitespace-separated token of sbatch's output.
        job_id = res.stdout.split()[-1]

        return SlurmJobMetadata(job_id=job_id, ssh_config=ssh_config, outputs=outputs)

    async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource:
        """Poll ``scontrol show job`` and map the Slurm job state to a Flyte phase.

        The job's StdOut contents, when readable, are returned as the message. The output
        locations recorded at submission time are passed through as outputs.
        """
        import shlex

        conn = await get_ssh_conn(
            ssh_config=resource_meta.ssh_config, slurm_cluster_to_ssh_conn=self.slurm_cluster_to_ssh_conn
        )
        job_res = await conn.run(f"scontrol show job {resource_meta.job_id}", check=True)

        # Determine the current flyte phase from Slurm job state.
        # Default to "running" if no JobState field is found.
        msg = ""
        job_state = "running"
        # scontrol prints Key=Value pairs separated by spaces and newlines, so split on any whitespace.
        for token in job_res.stdout.split():
            key, sep, value = token.partition("=")
            if not sep:
                continue
            if key == "JobState":
                job_state = value.strip().lower()
            elif key == "StdOut":
                stdout_path = value.strip()
                # The StdOut file may not exist yet (e.g. while the job is pending).
                # A failed read must not abort status polling.
                msg_res = await conn.run(f"cat {shlex.quote(stdout_path)}", check=False)
                if msg_res.exit_status == 0:
                    msg = msg_res.stdout
        cur_phase = convert_to_flyte_phase(job_state)

        return Resource(phase=cur_phase, message=msg, outputs=resource_meta.outputs)

    async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None:
        """Cancel the Slurm job with scancel."""
        conn = await get_ssh_conn(
            ssh_config=resource_meta.ssh_config, slurm_cluster_to_ssh_conn=self.slurm_cluster_to_ssh_conn
        )
        _ = await conn.run(f"scancel {resource_meta.job_id}", check=True)

    def _interpolate_script(
        self,
        script: str,
        input_literal_map: Optional[LiteralMap] = None,
        python_input_types: Optional[Dict[str, Type]] = None,
        output_locs: Optional[List[OutputLocation]] = None,
    ) -> Tuple[str, Dict[str, str]]:
        """
        Interpolate the user-defined script with the specified input and output arguments.

        Args:
            script (str): The user-defined script with placeholders for dynamic input/output.
            input_literal_map (Optional[LiteralMap]): The Flyte LiteralMap of inputs.
            python_input_types (Optional[Dict[str, Type]]): Mapping of input names to their Python/typing types.
            output_locs (Optional[List[OutputLocation]]): List of output locations to be interpolated.

        Returns:
            Tuple[str, Dict[str, str]]:
                - A two-element tuple in which the first element is the interpolated script (str),
                  and the second is a dictionary mapping each output variable name to its final location (str).
        """
        input_kwargs = TypeEngine.literal_map_to_kwargs(
            flytekit.current_context(),
            lm=input_literal_map,
            python_types={} if python_input_types is None else python_input_types,
        )
        interpolizer = _PythonFStringInterpolizer()

        # Interpolate output locations with input values
        outputs = {}
        if output_locs is not None:
            for oloc in output_locs:
                outputs[oloc.var] = interpolizer.interpolate(oloc.location, inputs=input_kwargs)

        # Interpolate the script; it may reference both inputs and the resolved output locations
        script = interpolizer.interpolate(script, inputs=input_kwargs, outputs=outputs)

        return script, outputs
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def _get_sbatch_cmd(sbatch_conf: Dict[str, str], batch_script_path: str, batch_script_args: List[str] = None) -> str:
|
|
164
|
+
"""
|
|
165
|
+
Construct the Slurm sbatch command.
|
|
166
|
+
|
|
167
|
+
Args:
|
|
168
|
+
sbatch_conf (Dict[str, str]): Slurm sbatch configuration options (e.g., partition, job-name, etc.).
|
|
169
|
+
batch_script_path (str): Absolute path on the Slurm cluster of the script to run.
|
|
170
|
+
batch_script_args (List[str], optional): Additional arguments to pass to the batch script.
|
|
171
|
+
|
|
172
|
+
Returns:
|
|
173
|
+
str: The sbatch command string that can be executed on the Slurm cluster.
|
|
174
|
+
"""
|
|
175
|
+
cmd = ["sbatch"]
|
|
176
|
+
for opt, val in sbatch_conf.items():
|
|
177
|
+
cmd.extend([f"--{opt}", str(val)])
|
|
178
|
+
|
|
179
|
+
# Assign the batch script to run
|
|
180
|
+
cmd.append(batch_script_path)
|
|
181
|
+
|
|
182
|
+
# Add args if present
|
|
183
|
+
if batch_script_args:
|
|
184
|
+
for arg in batch_script_args:
|
|
185
|
+
cmd.append(arg)
|
|
186
|
+
|
|
187
|
+
cmd = " ".join(cmd)
|
|
188
|
+
return cmd
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
# Register a SlurmScriptAgent instance with Flyte's AgentRegistry when this module is imported.
AgentRegistry.register(SlurmScriptAgent())
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import Any, Dict, List, Optional, Type
|
|
3
|
+
|
|
4
|
+
from flytekit.configuration import SerializationSettings
|
|
5
|
+
from flytekit.core.base_task import PythonTask
|
|
6
|
+
from flytekit.core.interface import Interface
|
|
7
|
+
from flytekit.extend import TaskPlugins
|
|
8
|
+
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
|
|
9
|
+
from flytekit.extras.tasks.shell import OutputLocation
|
|
10
|
+
from flytekit.types.directory import FlyteDirectory
|
|
11
|
+
from flytekit.types.file import FlyteFile
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class Slurm(object):
    """Slurm settings for script-based tasks; only the ``sbatch`` command is supported for now.

    Attributes:
        ssh_config (Dict[str, Any]): SSH client connection options. See ``SSHConfig`` in the
            ssh_utils module for the available options.
        sbatch_conf (Optional[Dict[str, str]]): Options of the sbatch command, see
            https://slurm.schedmd.com/sbatch.html. Replaced by an empty dict when None.
        batch_script_args (Optional[List[str]]): Additional args for the batch script on the Slurm cluster.
    """

    ssh_config: Dict[str, Any]
    sbatch_conf: Optional[Dict[str, str]] = None
    batch_script_args: Optional[List[str]] = None

    def __post_init__(self):
        self.sbatch_conf = {} if self.sbatch_conf is None else self.sbatch_conf
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class SlurmRemoteScript(Slurm):
    """Slurm config for running a batch script that already exists on the Slurm cluster.

    A separate config class is needed because sharing ``Slurm`` between ``SlurmTask`` and
    ``SlurmShellTask`` would collide in the task plugin registration.

    Attributes:
        batch_script_path (Optional[str]): Path of the batch script on the Slurm cluster.
            Required in practice: ``None`` is rejected in ``__post_init__``.

    Raises:
        ValueError: If ``batch_script_path`` is not provided.
    """

    # Needs a default because it follows defaulted fields of the base dataclass;
    # the None default is validated away in __post_init__.
    batch_script_path: Optional[str] = field(default=None)

    def __post_init__(self):
        super().__post_init__()
        if self.batch_script_path is None:
            raise ValueError("batch_script_path must be provided")
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class SlurmTask(AsyncAgentExecutorMixin, PythonTask[SlurmRemoteScript]):
    """Task of type ``slurm`` that submits a batch script already present on the Slurm cluster."""

    _TASK_TYPE = "slurm"

    def __init__(
        self,
        name: str,
        task_config: SlurmRemoteScript,
        **kwargs,
    ):
        super().__init__(
            name=name,
            task_type=self._TASK_TYPE,
            task_config=task_config,
            # Dummy interface (no inputs/outputs), will support this after discussion
            interface=Interface(None, None),
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """Serialize the Slurm settings that the agent reads from ``task_template.custom``."""
        config = self.task_config
        return dict(
            ssh_config=config.ssh_config,
            batch_script_path=config.batch_script_path,
            batch_script_args=config.batch_script_args,
            sbatch_conf=config.sbatch_conf,
        )
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class SlurmShellTask(AsyncAgentExecutorMixin, PythonTask[Slurm]):
    """Task of type ``slurm`` whose batch script is written inline by the user.

    The script and the declared output locations are serialized for the agent, which
    interpolates them with the task inputs before submission.
    """

    _TASK_TYPE = "slurm"

    def __init__(
        self,
        name: str,
        task_config: Slurm,
        script: str,
        inputs: Optional[Dict[str, Type]] = None,
        output_locs: Optional[List[OutputLocation]] = None,
        **kwargs,
    ):
        self._inputs = inputs
        self._output_locs = [] if output_locs is None else output_locs
        self._script = script

        super().__init__(
            name=name,
            task_type=self._TASK_TYPE,
            task_config=task_config,
            interface=Interface(inputs=inputs, outputs=self._validate_output_locs()),
            **kwargs,
        )

    def _validate_output_locs(self) -> Dict[str, Type]:
        """Check the declared output locations and build the output-name -> type mapping.

        Raises:
            ValueError: On a None entry, a non-OutputLocation entry, a missing location,
                or an output type that is not FlyteFile/FlyteDirectory (or derived).
        """
        validated: Dict[str, Type] = {}
        for loc in self._output_locs:
            if loc is None:
                raise ValueError("OutputLocation cannot be none")
            if not isinstance(loc, OutputLocation):
                raise ValueError("Every output type should be an output location on the file-system")
            if loc.location is None:
                raise ValueError(f"Output Location not provided for output var {loc.var}")
            if not issubclass(loc.var_type, (FlyteFile, FlyteDirectory)):
                raise ValueError(
                    "Currently only outputs of type FlyteFile/FlyteDirectory and their derived types are supported"
                )
            validated[loc.var] = loc.var_type
        return validated

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """Serialize the Slurm settings, script and I/O metadata read by the agent."""
        cfg = self.task_config
        return dict(
            ssh_config=cfg.ssh_config,
            batch_script_args=cfg.batch_script_args,
            sbatch_conf=cfg.sbatch_conf,
            script=self._script,
            python_input_types=self._inputs,
            output_locs=self._output_locs,
        )

    @property
    def script(self) -> str:
        """The user-defined batch script template."""
        return self._script
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
# Associate the SlurmRemoteScript config type with the SlurmTask plugin class.
TaskPlugins.register_pythontask_plugin(SlurmRemoteScript, SlurmTask)
|
|
136
|
+
# Associate the Slurm config type with the SlurmShellTask plugin class.
TaskPlugins.register_pythontask_plugin(Slurm, SlurmShellTask)
|
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utilities of asyncssh connections.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
import sys
|
|
7
|
+
from dataclasses import asdict, dataclass
|
|
8
|
+
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
|
|
9
|
+
|
|
10
|
+
import asyncssh
|
|
11
|
+
from asyncssh import SSHClientConnection
|
|
12
|
+
|
|
13
|
+
from flytekit import logger
|
|
14
|
+
from flytekit.extend.backend.utils import get_agent_secret
|
|
15
|
+
|
|
16
|
+
# Type variable so SSHConfig.from_dict is typed to return the class it is called on.
T = TypeVar("T", bound="SSHConfig")
|
|
17
|
+
# Name of the agent secret that may hold a private key for SSH authentication to Slurm.
SLURM_PRIVATE_KEY = "FLYTE_SLURM_PRIVATE_KEY"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@dataclass
class SlurmCluster:
    """Identity of a Slurm cluster from an SSH client's point of view.

    Two instances are equal, and hash identically, when ``host`` and ``username`` match.
    This lets instances serve as keys of an SSH connection pool.

    Attributes:
        host (str): The hostname or address to connect to.
        username (Optional[str]): The username to authenticate as on the server.
    """

    host: str
    username: Optional[str] = None

    def __hash__(self):
        key = (self.host, self.username)
        return hash(key)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass(frozen=True)
class SSHConfig:
    """Subset of asyncssh's SSHClientConnectionOptions accepted by the Slurm agents.

    Mirrors the official options but exposes only a few of them, with some fields made
    optional or required. For the full list, see:
    https://asyncssh.readthedocs.io/en/latest/api.html#asyncssh.SSHClientConnectionOptions

    Attributes:
        host (str): The hostname or address to connect to.
        username (Optional[str]): The username to authenticate as on the server.
        client_keys (Union[str, List[str], Tuple[str, ...]]): File paths to private keys used for
            client public key authentication. Defaults to an empty tuple, meaning no key files were
            given. ``ssh_connect`` then falls back to the FLYTE_SLURM_PRIVATE_KEY agent secret if
            its first connection attempt fails.
    """

    host: str
    username: Optional[str] = None
    client_keys: Union[str, List[str], Tuple[str, ...]] = ()

    @classmethod
    def from_dict(cls: Type[T], ssh_config: Dict[str, Any]) -> T:
        """Build a config from a plain dict; unknown keys make the constructor raise TypeError."""
        return cls(**ssh_config)

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain ``{name: value}`` dict."""
        return asdict(self)

    def __eq__(self, other):
        if isinstance(other, SSHConfig):
            mine = (self.host, self.username, self.client_keys)
            theirs = (other.host, other.username, other.client_keys)
            return mine == theirs
        return False
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _write_private_key(private_key: str) -> str:
    """Write ``private_key`` to ./slurm_private_key readable only by the owner; return its absolute path."""
    private_key_path = os.path.abspath("./slurm_private_key")
    # Create the file with mode 0o600 so the secret is not readable by other users.
    # os.open's mode only applies on creation, so chmod explicitly in case the file
    # already existed with wider permissions (done before writing the secret).
    fd = os.open(private_key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    os.chmod(private_key_path, 0o600)
    with os.fdopen(fd, "w") as f:
        f.write(private_key)
    return private_key_path


async def ssh_connect(ssh_config: Dict[str, Any]) -> SSHClientConnection:
    """Make an SSH client connection.

    The first attempt uses the options in ``ssh_config`` as given. Per the original
    design, asyncssh may also pick up settings from the OpenSSH client config
    (~/.ssh/config). If that attempt fails, a second attempt is made with
    ``client_keys`` set to the private key from the FLYTE_SLURM_PRIVATE_KEY agent
    secret (if set) followed by any user-provided ``client_keys``.

    Args:
        ssh_config (Dict[str, Any]): Options of SSH client connection defined in SSHConfig.

    Returns:
        SSHClientConnection: An SSH client connection object.

    Raises:
        TypeError: If ``ssh_config`` contains keys that SSHConfig does not define.
        ValueError: If both the FLYTE_SLURM_PRIVATE_KEY secret and ssh_config['client_keys'] are missing.
        Exception: Whatever ``asyncssh.connect`` raises when the fallback attempt also fails.
    """
    # Validate ssh_config against SSHConfig (unknown keys raise TypeError).
    ssh_config = SSHConfig.from_dict(ssh_config).to_dict()
    # This is required to avoid the error "asyncssh.misc.HostKeyNotVerifiable" when connecting to a new host.
    ssh_config["known_hosts"] = None

    # First attempt: the options exactly as provided by the user.
    try:
        return await asyncssh.connect(**ssh_config)
    except Exception as e:
        logger.info(
            "Failed to make an SSH connection using the default OpenSSH client config (~/.ssh/config) or "
            f"the provided private keys. Error details:\n{e}"
        )

    try:
        default_client_key = get_agent_secret(secret_key=SLURM_PRIVATE_KEY)
    except ValueError:
        logger.info("The secret for key FLYTE_SLURM_PRIVATE_KEY is not set.")
        default_client_key = None

    # Normalize user keys to a list; a single path may be given as a plain string.
    # An empty tuple/list/string means "no user keys".
    user_client_keys = ssh_config.get("client_keys")
    if not user_client_keys:
        user_client_keys = []
    elif isinstance(user_client_keys, str):
        user_client_keys = [user_client_keys]
    else:
        user_client_keys = list(user_client_keys)

    if default_client_key is None and not user_client_keys:
        raise ValueError(
            "Both the secret for key FLYTE_SLURM_PRIVATE_KEY and ssh_config['client_keys'] are missing. "
            "At least one must be set."
        )

    client_keys = []
    if default_client_key is not None:
        # asyncssh takes key file paths here, so the secret is materialized on disk.
        client_keys.append(_write_private_key(default_client_key))
    client_keys.extend(user_client_keys)

    ssh_config["client_keys"] = client_keys
    logger.info(f"Updated SSH config: {ssh_config}")
    try:
        return await asyncssh.connect(**ssh_config)
    except Exception as e:
        # Surface the failure to the caller instead of terminating the whole process.
        logger.error(
            "Failed to make an SSH connection using the provided private keys. Please verify your setup. "
            f"Error details:\n{e}"
        )
        raise
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
async def get_ssh_conn(
    ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]],
    slurm_cluster_to_ssh_conn: Dict[SlurmCluster, SSHClientConnection],
) -> SSHClientConnection:
    """
    Get an existing SSH connection or create a new one if needed.

    Connections are cached per SlurmCluster, keyed by (host, username). A cached
    connection is health-checked by running a trivial ``echo`` command. If that
    check fails, the stale connection is closed and replaced by a fresh one.

    Args:
        ssh_config (Dict[str, Union[str, List[str], Tuple[str, ...]]]):
            SSH configuration dictionary, including host and username.
        slurm_cluster_to_ssh_conn (Dict[SlurmCluster, SSHClientConnection]):
            A mapping of SlurmCluster to existing SSHClientConnection objects.
            Updated in place whenever a new connection is created.

    Returns:
        SSHClientConnection: A connection for the cluster described by ``ssh_config``.
    """
    slurm_cluster = SlurmCluster(host=ssh_config.get("host"), username=ssh_config.get("username"))

    conn = slurm_cluster_to_ssh_conn.get(slurm_cluster)
    if conn is None:
        logger.info("SSH connection key not found, creating new connection")
        conn = await ssh_connect(ssh_config=ssh_config)
        slurm_cluster_to_ssh_conn[slurm_cluster] = conn
        return conn

    try:
        await conn.run("echo [TEST] SSH connection", check=True)
        logger.info("Re-using existing connection")
    except Exception as e:
        logger.info(f"Re-establishing SSH connection due to error: {e}")
        # Release the stale connection before replacing it so it is not leaked.
        # Best-effort: a failure to close must not prevent reconnecting.
        try:
            conn.close()
        except Exception as close_err:
            logger.info(f"Failed to close stale SSH connection: {close_err}")
        conn = await ssh_connect(ssh_config=ssh_config)
        slurm_cluster_to_ssh_conn[slurm_cluster] = conn

    return conn
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
if __name__ == "__main__":
    import asyncio

    async def test_connect():
        """Manual smoke test: connect to a hard-coded host and return the echo output."""
        connection = await ssh_connect({"host": "aws2", "username": "ubuntu"})
        result = await connection.run("echo [TEST] SSH connection", check=True)
        return result.stdout

    logger.info(asyncio.run(test_connect()))
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
Metadata-Version: 2.2
|
|
2
|
+
Name: flytekitplugins-slurm
|
|
3
|
+
Version: 1.15.1
|
|
4
|
+
Summary: This package holds the Slurm plugins for flytekit
|
|
5
|
+
Author: flyteorg
|
|
6
|
+
Author-email: admin@flyte.org
|
|
7
|
+
License: apache2
|
|
8
|
+
Classifier: Intended Audience :: Science/Research
|
|
9
|
+
Classifier: Intended Audience :: Developers
|
|
10
|
+
Classifier: License :: OSI Approved :: Apache Software License
|
|
11
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
15
|
+
Classifier: Topic :: Scientific/Engineering
|
|
16
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
17
|
+
Classifier: Topic :: Software Development
|
|
18
|
+
Classifier: Topic :: Software Development :: Libraries
|
|
19
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
20
|
+
Requires-Python: >=3.9
|
|
21
|
+
Requires-Dist: flytekit>=1.15.0
|
|
22
|
+
Requires-Dist: flyteidl>=1.15.0
|
|
23
|
+
Requires-Dist: asyncssh
|
|
24
|
+
Dynamic: author
|
|
25
|
+
Dynamic: author-email
|
|
26
|
+
Dynamic: classifier
|
|
27
|
+
Dynamic: license
|
|
28
|
+
Dynamic: requires-dist
|
|
29
|
+
Dynamic: requires-python
|
|
30
|
+
Dynamic: summary
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
README.md
|
|
2
|
+
setup.py
|
|
3
|
+
flytekitplugins/slurm/__init__.py
|
|
4
|
+
flytekitplugins/slurm/ssh_utils.py
|
|
5
|
+
flytekitplugins/slurm/function/agent.py
|
|
6
|
+
flytekitplugins/slurm/function/task.py
|
|
7
|
+
flytekitplugins/slurm/script/agent.py
|
|
8
|
+
flytekitplugins/slurm/script/task.py
|
|
9
|
+
flytekitplugins_slurm.egg-info/PKG-INFO
|
|
10
|
+
flytekitplugins_slurm.egg-info/SOURCES.txt
|
|
11
|
+
flytekitplugins_slurm.egg-info/dependency_links.txt
|
|
12
|
+
flytekitplugins_slurm.egg-info/entry_points.txt
|
|
13
|
+
flytekitplugins_slurm.egg-info/namespace_packages.txt
|
|
14
|
+
flytekitplugins_slurm.egg-info/requires.txt
|
|
15
|
+
flytekitplugins_slurm.egg-info/top_level.txt
|
|
16
|
+
tests/test_slurm_fn_task.py
|
|
17
|
+
tests/test_slurm_shell_task.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
flytekitplugins
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
flytekitplugins
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from setuptools import setup
|
|
2
|
+
|
|
3
|
+
PLUGIN_NAME = "slurm"

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>=1.15.0", "flyteidl>=1.15.0", "asyncssh"]

__version__ = "1.15.1"

# The plugin root package plus its "script" and "function" subpackages.
_packages = [f"flytekitplugins.{PLUGIN_NAME}{suffix}" for suffix in ("", ".script", ".function")]

_classifiers = [
    "Intended Audience :: Science/Research",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: Apache Software License",
    # Python 3.9 through 3.12.
    *(f"Programming Language :: Python :: 3.{minor}" for minor in range(9, 13)),
    "Topic :: Scientific/Engineering",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Software Development",
    "Topic :: Software Development :: Libraries",
    "Topic :: Software Development :: Libraries :: Python Modules",
]

setup(
    name=microlib_name,
    version=__version__,
    author="flyteorg",
    author_email="admin@flyte.org",
    description="This package holds the Slurm plugins for flytekit",
    license="apache2",
    python_requires=">=3.9",
    namespace_packages=["flytekitplugins"],
    packages=_packages,
    install_requires=plugin_requires,
    classifiers=_classifiers,
    # Registers the plugin with flytekit's plugin discovery.
    entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
)
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import os.path
|
|
2
|
+
from unittest import mock
|
|
3
|
+
from flytekit.core import context_manager
|
|
4
|
+
import flytekit
|
|
5
|
+
from flytekit import StructuredDataset, StructuredDatasetTransformerEngine, task, ImageSpec
|
|
6
|
+
from flytekit.configuration import Image, ImageConfig, SerializationSettings, FastSerializationSettings, DefaultImages
|
|
7
|
+
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager, ExecutionState
|
|
8
|
+
from flytekitplugins.slurm import SlurmFunction
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def test_slurm_task():
    """SlurmFunction config is kept on the task and serialized unchanged by get_custom."""
    # Batch script template; "{task.fn}" is a literal placeholder, not an f-string field.
    script_file = """#!/bin/bash -i

echo Run function with sbatch...

# Run the user-defined task function
{task.fn}
"""
    expected_ssh = {"host": "your-slurm-host", "username": "ubuntu"}
    expected_sbatch = {"partition": "debug", "job-name": "tiny-slurm", "output": "/home/ubuntu/fn_task.log"}

    @task(
        task_config=SlurmFunction(
            ssh_config={"host": "your-slurm-host", "username": "ubuntu"},
            sbatch_conf={"partition": "debug", "job-name": "tiny-slurm", "output": "/home/ubuntu/fn_task.log"},
            script=script_file,
        )
    )
    def plus_one(x: int) -> int:
        return x + 1

    config = plus_one.task_config
    assert config is not None
    assert config.ssh_config == expected_ssh
    assert config.sbatch_conf == expected_sbatch
    assert config.script == script_file

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    custom = plus_one.get_custom(settings)
    for key, expected in (("ssh_config", expected_ssh), ("sbatch_conf", expected_sbatch), ("script", script_file)):
        assert custom[key] == expected
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from flytekit.configuration import Image, ImageConfig, SerializationSettings
|
|
2
|
+
from flytekitplugins.slurm import Slurm, SlurmRemoteScript, SlurmTask, SlurmShellTask
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def test_slurm_shell_task():
    """Test SlurmTask and SlurmShellTask."""
    batch_script_path = "/script/path/on/the/slurm/cluster"
    script = """#!/bin/bash -i

echo "Run a Flyte SlurmShellTask...\n"
"""
    host_and_user = {"host": "<your-slurm-host>", "username": "ubuntu"}

    def sbatch_conf_for(log_name):
        # Both tasks share partition and job name; only the output log file differs.
        return {"partition": "debug", "job-name": "tiny-slurm", "output": f"/home/ubuntu/{log_name}"}

    # SlurmTask: points at a batch script path on the cluster.
    slurm_task = SlurmTask(
        name="test-slurm-task",
        task_config=SlurmRemoteScript(
            ssh_config=dict(host_and_user),
            sbatch_conf=sbatch_conf_for("slurm_task.log"),
            batch_script_path=batch_script_path,
        ),
    )
    remote_cfg = slurm_task.task_config
    assert remote_cfg is not None
    assert remote_cfg.ssh_config == host_and_user
    assert remote_cfg.sbatch_conf == sbatch_conf_for("slurm_task.log")
    assert remote_cfg.batch_script_path == batch_script_path

    # SlurmShellTask: carries the script content itself.
    slurm_shell_task = SlurmShellTask(
        name="test-slurm-shell-task",
        script=script,
        task_config=Slurm(
            ssh_config=dict(host_and_user),
            sbatch_conf=sbatch_conf_for("slurm_shell_task.log"),
        ),
    )
    shell_cfg = slurm_shell_task.task_config
    assert shell_cfg is not None
    assert shell_cfg.ssh_config == host_and_user
    assert shell_cfg.sbatch_conf == sbatch_conf_for("slurm_shell_task.log")
    assert slurm_shell_task.script == script

    # Dummy serialization settings; get_custom only needs a valid instance.
    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    custom = slurm_task.get_custom(settings)
    assert custom["ssh_config"] == host_and_user
    assert custom["sbatch_conf"] == sbatch_conf_for("slurm_task.log")
    assert custom["batch_script_path"] == batch_script_path
    assert custom["batch_script_args"] is None

    shell_custom = slurm_shell_task.get_custom(settings)
    assert shell_custom["ssh_config"] == host_and_user
    assert shell_custom["sbatch_conf"] == sbatch_conf_for("slurm_shell_task.log")
    assert shell_custom["script"] == script
    assert shell_custom["batch_script_args"] is None
|