flytekitplugins-slurm 1.15.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,30 @@
1
+ Metadata-Version: 2.2
2
+ Name: flytekitplugins-slurm
3
+ Version: 1.15.1
4
+ Summary: This package holds the Slurm plugins for flytekit
5
+ Author: flyteorg
6
+ Author-email: admin@flyte.org
7
+ License: apache2
8
+ Classifier: Intended Audience :: Science/Research
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: License :: OSI Approved :: Apache Software License
11
+ Classifier: Programming Language :: Python :: 3.9
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Topic :: Scientific/Engineering
16
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Classifier: Topic :: Software Development
18
+ Classifier: Topic :: Software Development :: Libraries
19
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
20
+ Requires-Python: >=3.9
21
+ Requires-Dist: flytekit>=1.15.0
22
+ Requires-Dist: flyteidl>=1.15.0
23
+ Requires-Dist: asyncssh
24
+ Dynamic: author
25
+ Dynamic: author-email
26
+ Dynamic: classifier
27
+ Dynamic: license
28
+ Dynamic: requires-dist
29
+ Dynamic: requires-python
30
+ Dynamic: summary
@@ -0,0 +1,5 @@
1
+ # Flytekit Slurm Plugin
2
+
3
+ The Slurm agent integrates Flyte workflows with Slurm-managed high-performance computing (HPC) clusters, enabling users to leverage Slurm's compute resource allocation, scheduling, and monitoring capabilities.
4
+
5
+ This [guide](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md) provides a concise overview of the design philosophy behind the Slurm agent and explains how to set up a local environment for testing the agent.
@@ -0,0 +1,4 @@
1
+ from .function.agent import SlurmFunctionAgent
2
+ from .function.task import SlurmFunction, SlurmFunctionTask
3
+ from .script.agent import SlurmScriptAgent
4
+ from .script.task import Slurm, SlurmRemoteScript, SlurmShellTask, SlurmTask
@@ -0,0 +1,190 @@
1
+ import tempfile
2
+ import uuid
3
+ from dataclasses import dataclass
4
+ from typing import Dict, List, Optional, Tuple, Union
5
+
6
+ from asyncssh import SSHClientConnection
7
+
8
+ from flytekit import logger
9
+ from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
10
+ from flytekit.extend.backend.utils import convert_to_flyte_phase
11
+ from flytekit.models.literals import LiteralMap
12
+ from flytekit.models.task import TaskTemplate
13
+
14
+ from ..ssh_utils import ssh_connect
15
+
16
+
17
@dataclass
class SlurmJobMetadata(ResourceMeta):
    """Handle for a Slurm job submitted by the function agent.

    An instance is returned by ``SlurmFunctionAgent.create`` and passed back to
    ``get`` and ``delete``, which use it to reconnect to the cluster and address the job.

    Attributes:
        job_id: Job id parsed from the ``sbatch`` submission output.
        ssh_config: SSH client connection options used to reach the cluster; the
            accepted keys are those of ``SSHConfig`` in the ``ssh_utils`` module.
    """

    job_id: str
    ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]]
34
+
35
+
36
@dataclass
class SlurmCluster:
    """Key of the SSH connection pool: one entry per (host, username) pair."""

    host: str
    username: Optional[str] = None

    def __hash__(self):
        # Hash exactly the fields that the generated __eq__ compares.
        identity = (self.host, self.username)
        return hash(identity)
43
+
44
+
45
class SlurmFunctionAgent(AsyncAgentBase):
    """Agent that runs ``slurm_fn`` tasks as Slurm batch jobs over SSH.

    ``create`` uploads a generated batch script via SFTP and submits it with ``sbatch``;
    ``get`` polls the job with ``scontrol show job``; ``delete`` cancels it with ``scancel``.
    """

    name = "Slurm Function Agent"

    # SSH connection pool for multi-host environment, keyed by (host, username).
    # Declared at class level, so it is shared by all instances of this class.
    ssh_config_to_ssh_conn: Dict[SlurmCluster, SSHClientConnection] = {}

    def __init__(self) -> None:
        super(SlurmFunctionAgent, self).__init__(task_type_name="slurm_fn", metadata_type=SlurmJobMetadata)

    async def create(
        self,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
        **kwargs,
    ) -> SlurmJobMetadata:
        """Upload the batch script to the cluster and submit it with ``sbatch``.

        Args:
            task_template: Template whose ``custom`` holds ``ssh_config``, ``sbatch_conf``
                and ``script``, and whose container args form the Flyte entrypoint.
            inputs: Unused here.

        Returns:
            SlurmJobMetadata: The job id (last whitespace-separated token of sbatch's
            stdout) and the SSH config needed to reach the cluster again.
        """
        unique_script_name = f"/tmp/task_{uuid.uuid4().hex}.slurm"

        # Retrieve task config
        ssh_config = task_template.custom["ssh_config"]
        sbatch_conf = task_template.custom["sbatch_conf"]
        script = task_template.custom["script"]

        # Construct command for Slurm cluster
        cmd, script = _get_sbatch_cmd_and_script(
            sbatch_conf=sbatch_conf,
            entrypoint=" ".join(task_template.container.args),
            script=script,
            batch_script_path=unique_script_name,
        )

        # Write the script to a local temp file, copy it to the cluster, then submit it.
        conn = await self._get_or_create_ssh_connection(ssh_config)
        with tempfile.NamedTemporaryFile("w") as f:
            f.write(script)
            f.flush()
            async with conn.start_sftp_client() as sftp:
                await sftp.put(f.name, unique_script_name)
        res = await conn.run(cmd, check=True)

        # Retrieve Slurm job id
        job_id = res.stdout.split()[-1]

        return SlurmJobMetadata(job_id=job_id, ssh_config=ssh_config)

    async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource:
        """Poll the job and map its Slurm state to a Flyte phase.

        Parses ``JobState`` and ``StdOut`` from the ``Key=Value`` output of
        ``scontrol show job``. The content of the job's stdout file becomes the
        resource message when the file can be read.
        """
        import shlex

        conn = await self._get_or_create_ssh_connection(resource_meta.ssh_config)
        job_res = await conn.run(f"scontrol show job {resource_meta.job_id}", check=True)

        # Determine the current flyte phase from Slurm job state
        job_state = "running"
        msg = "No stdout available"
        # scontrol prints Key=Value pairs across several lines; splitting on any
        # whitespace keeps newlines out of the tokens.
        for token in job_res.stdout.split():
            key, _, value = token.partition("=")
            if key == "JobState":
                job_state = value.strip().lower()
            elif key == "StdOut":
                stdout_path = value.strip()
                # The stdout file is only created once the job starts (e.g. not while
                # PENDING), so a failed read must not fail the whole poll.
                msg_res = await conn.run(f"cat {shlex.quote(stdout_path)}", check=False)
                if msg_res.exit_status == 0:
                    msg = msg_res.stdout

        cur_phase = convert_to_flyte_phase(job_state)

        return Resource(phase=cur_phase, message=msg)

    async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None:
        """Cancel the Slurm job with ``scancel``."""
        conn = await self._get_or_create_ssh_connection(resource_meta.ssh_config)
        _ = await conn.run(f"scancel {resource_meta.job_id}", check=True)

    async def _get_or_create_ssh_connection(
        self, ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]]
    ) -> SSHClientConnection:
        """Get an existing SSH connection or create a new one if needed.

        A pooled connection is health-checked with a trivial ``echo``. If the check
        fails, the stale connection is closed and replaced.

        Args:
            ssh_config (Dict[str, Union[str, List[str], Tuple[str, ...]]]): SSH configuration dictionary.

        Returns:
            SSHClientConnection: An active SSH connection, either pre-existing or newly established.
        """
        host = ssh_config.get("host")
        username = ssh_config.get("username")
        ssh_cluster_config = SlurmCluster(host=host, username=username)

        conn = self.ssh_config_to_ssh_conn.get(ssh_cluster_config)
        if conn is None:
            logger.info("ssh connection key not found, creating new connection")
        else:
            try:
                await conn.run("echo [TEST] SSH connection", check=True)
            except Exception as e:
                logger.info(f"Re-establishing SSH connection due to error: {e}")
                # Release the broken connection before replacing it in the pool.
                conn.close()
            else:
                logger.info("re-using existing connection")
                return conn

        conn = await ssh_connect(ssh_config=ssh_config)
        self.ssh_config_to_ssh_conn[ssh_cluster_config] = conn
        return conn
143
+
144
+
145
+ def _get_sbatch_cmd_and_script(
146
+ sbatch_conf: Dict[str, str],
147
+ entrypoint: str,
148
+ script: Optional[str] = None,
149
+ batch_script_path: str = "/tmp/task.slurm",
150
+ ) -> Tuple[str, str]:
151
+ """Construct the Slurm sbatch command and the batch script content.
152
+
153
+ Flyte entrypoint, pyflyte-execute, is run within a bash shell process.
154
+
155
+ Args:
156
+ sbatch_conf (Dict[str, str]): Options of sbatch command.
157
+ entrypoint (str): Flyte entrypoint.
158
+ script (Optional[str], optional): User-defined script where "{task.fn}" serves as a placeholder for the
159
+ task function execution. Users should insert "{task.fn}" at the desired
160
+ execution point within the script. If the script is not provided, the task
161
+ function will be executed directly. Defaults to None.
162
+ batch_script_path (str, optional): Absolute path of the batch script on Slurm cluster.
163
+ Defaults to "/tmp/task.slurm".
164
+
165
+ Returns:
166
+ Tuple[str, str]: A tuple containing:
167
+ - cmd: Slurm sbatch command
168
+ - script: The batch script content
169
+ """
170
+ # Setup sbatch options
171
+ cmd = ["sbatch"]
172
+ for opt, val in sbatch_conf.items():
173
+ cmd.extend([f"--{opt}", str(val)])
174
+
175
+ # Assign the batch script to run
176
+ cmd.append(batch_script_path)
177
+
178
+ if script is None:
179
+ script = f"""#!/bin/bash -i
180
+ {entrypoint}
181
+ """
182
+ else:
183
+ script = script.replace("{task.fn}", entrypoint)
184
+
185
+ cmd = " ".join(cmd)
186
+
187
+ return cmd, script
188
+
189
+
190
+ AgentRegistry.register(SlurmFunctionAgent())
@@ -0,0 +1,79 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ from flytekit import FlyteContextManager, PythonFunctionTask
5
+ from flytekit.configuration import SerializationSettings
6
+ from flytekit.extend import TaskPlugins
7
+ from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
8
+ from flytekit.image_spec import ImageSpec
9
+
10
+
11
@dataclass
class SlurmFunction(object):
    """Configure Slurm settings for a Python function task. Note that we focus on sbatch command now.

    Attributes:
        ssh_config (Dict[str, Union[str, List[str], Tuple[str, ...]]]): SSH client connection
            options; must contain a non-None ``"host"``. For available options, see ``SSHConfig``
            in ``flytekitplugins.slurm.ssh_utils``.
        sbatch_conf (Optional[Dict[str, str]]): Options of the sbatch command (each key is
            rendered as ``--key``). Defaults to an empty dict.
        script (Optional[str]): User-defined script where ``"{task.fn}"`` marks where the task
            function is executed. If not provided, the task function is executed directly.

    Raises:
        ValueError: If ``ssh_config`` has no ``"host"`` entry or it is None.
    """

    ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]]
    sbatch_conf: Optional[Dict[str, str]] = None
    script: Optional[str] = None

    def __post_init__(self):
        # Explicit check rather than `assert`, so validation survives `python -O`;
        # `.get` turns a missing key into the same clear error instead of a KeyError.
        if self.ssh_config.get("host") is None:
            raise ValueError("'host' must be specified in ssh_config.")
        if self.sbatch_conf is None:
            self.sbatch_conf = {}
38
+
39
+
40
class SlurmFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[SlurmFunction]):
    """Python function task whose execution is delegated to the Slurm function agent.

    The task type ``"slurm_fn"`` matches the type the Slurm function agent registers for.
    """

    _TASK_TYPE = "slurm_fn"

    def __init__(
        self,
        task_config: SlurmFunction,
        task_function: Callable,
        container_image: Optional[Union[str, ImageSpec]] = None,
        **kwargs,
    ):
        super().__init__(
            task_config=task_config,
            task_type=self._TASK_TYPE,
            task_function=task_function,
            container_image=container_image,
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """Serialize the Slurm settings the agent reads from ``task_template.custom``."""
        cfg = self.task_config
        return dict(
            ssh_config=cfg.ssh_config,
            sbatch_conf=cfg.sbatch_conf,
            script=cfg.script,
        )

    def execute(self, **kwargs) -> Any:
        """Run via the agent during local execution; otherwise call the function directly."""
        state = FlyteContextManager.current_context().execution_state
        if not (state and state.is_local_execution()):
            # Execute the task with a direct python function call
            return PythonFunctionTask.execute(self, **kwargs)
        # Mimic the propeller's behavior in local agent test
        return AsyncAgentExecutorMixin.execute(self, **kwargs)
77
+
78
+
79
+ TaskPlugins.register_pythontask_plugin(SlurmFunction, SlurmFunctionTask)
@@ -0,0 +1,191 @@
1
+ import tempfile
2
+ import uuid
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Tuple, Type
5
+
6
+ from asyncssh import SSHClientConnection
7
+
8
+ import flytekit
9
+ from flytekit.core.type_engine import TypeEngine
10
+ from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
11
+ from flytekit.extend.backend.utils import convert_to_flyte_phase
12
+ from flytekit.extras.tasks.shell import OutputLocation, _PythonFStringInterpolizer
13
+ from flytekit.models.literals import LiteralMap
14
+ from flytekit.models.task import TaskTemplate
15
+
16
+ from ..ssh_utils import SlurmCluster, get_ssh_conn
17
+
18
+
19
@dataclass
class SlurmJobMetadata(ResourceMeta):
    """Handle for a Slurm job submitted by the script agent.

    Attributes:
        job_id: Job id parsed from the ``sbatch`` submission output.
        ssh_config: SSH connection options used to reach the cluster again.
        outputs: Output variable name -> interpolated output location; ``get`` returns
            this mapping as the resource outputs.
    """

    job_id: str
    ssh_config: Dict[str, Any]
    outputs: Dict[str, str]
33
+
34
+
35
class SlurmScriptAgent(AsyncAgentBase):
    """Agent that runs ``slurm`` tasks (shell batch scripts) on a Slurm cluster over SSH.

    ``create`` either uploads an interpolated user script or points sbatch at a script
    already on the cluster; ``get`` polls with ``scontrol show job``; ``delete`` runs ``scancel``.
    """

    name = "Slurm Script Agent"

    # SSH connection pool for multi-host environment. Declared at class level, so it is
    # shared by all instances of this class.
    slurm_cluster_to_ssh_conn: Dict[SlurmCluster, SSHClientConnection] = {}

    # Dummy script content
    DUMMY_SCRIPT = "#!/bin/bash"

    def __init__(self) -> None:
        super(SlurmScriptAgent, self).__init__(task_type_name="slurm", metadata_type=SlurmJobMetadata)

    async def create(
        self,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
        **kwargs,
    ) -> SlurmJobMetadata:
        """Submit the batch script with ``sbatch``.

        If ``task_template.custom`` contains ``"script"``, the script is interpolated with
        the inputs and output locations and uploaded to a unique ``/tmp`` path. Otherwise
        ``"batch_script_path"`` is assumed to already exist on the cluster.

        Raises:
            ValueError: If the provided script is just the dummy ``#!/bin/bash`` placeholder.
        """
        uniq_script_path = f"/tmp/task_{uuid.uuid4().hex}.slurm"
        outputs = {}

        # Retrieve task config
        ssh_config = task_template.custom["ssh_config"]
        batch_script_args = task_template.custom["batch_script_args"]
        sbatch_conf = task_template.custom["sbatch_conf"]

        # Construct sbatch command for Slurm cluster
        upload_script = False
        if "script" in task_template.custom:
            script = task_template.custom["script"]
            # Explicit check rather than `assert`, so it still runs under `python -O`.
            if script == self.DUMMY_SCRIPT:
                raise ValueError("Please write the user-defined batch script content.")
            script, outputs = self._interpolate_script(
                script,
                input_literal_map=inputs,
                python_input_types=task_template.custom["python_input_types"],
                output_locs=task_template.custom["output_locs"],
            )

            batch_script_path = uniq_script_path
            upload_script = True
        else:
            # Assume the batch script is already on Slurm
            batch_script_path = task_template.custom["batch_script_path"]
        cmd = _get_sbatch_cmd(
            sbatch_conf=sbatch_conf, batch_script_path=batch_script_path, batch_script_args=batch_script_args
        )

        # Run Slurm job
        conn = await get_ssh_conn(ssh_config=ssh_config, slurm_cluster_to_ssh_conn=self.slurm_cluster_to_ssh_conn)
        if upload_script:
            with tempfile.NamedTemporaryFile("w") as f:
                f.write(script)
                f.flush()
                async with conn.start_sftp_client() as sftp:
                    await sftp.put(f.name, batch_script_path)
        res = await conn.run(cmd, check=True)

        # Retrieve Slurm job id
        job_id = res.stdout.split()[-1]

        return SlurmJobMetadata(job_id=job_id, ssh_config=ssh_config, outputs=outputs)

    async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource:
        """Poll the job, map its Slurm state to a Flyte phase and return its stdout as message."""
        import shlex

        conn = await get_ssh_conn(
            ssh_config=resource_meta.ssh_config, slurm_cluster_to_ssh_conn=self.slurm_cluster_to_ssh_conn
        )
        job_res = await conn.run(f"scontrol show job {resource_meta.job_id}", check=True)

        # Determine the current flyte phase from Slurm job state
        msg = ""
        job_state = "running"
        # scontrol prints Key=Value pairs across several lines; split on any whitespace.
        for token in job_res.stdout.split():
            key, _, value = token.partition("=")
            if key == "JobState":
                job_state = value.strip().lower()
            elif key == "StdOut":
                stdout_path = value.strip()
                # The stdout file does not exist before the job starts (e.g. while PENDING),
                # so a failed read must not fail the whole poll.
                msg_res = await conn.run(f"cat {shlex.quote(stdout_path)}", check=False)
                if msg_res.exit_status == 0:
                    msg = msg_res.stdout
        cur_phase = convert_to_flyte_phase(job_state)

        return Resource(phase=cur_phase, message=msg, outputs=resource_meta.outputs)

    async def delete(self, resource_meta: SlurmJobMetadata, **kwargs) -> None:
        """Cancel the Slurm job with ``scancel``."""
        conn = await get_ssh_conn(
            ssh_config=resource_meta.ssh_config, slurm_cluster_to_ssh_conn=self.slurm_cluster_to_ssh_conn
        )
        _ = await conn.run(f"scancel {resource_meta.job_id}", check=True)

    def _interpolate_script(
        self,
        script: str,
        input_literal_map: Optional[LiteralMap] = None,
        python_input_types: Optional[Dict[str, Type]] = None,
        output_locs: Optional[List[OutputLocation]] = None,
    ) -> Tuple[str, Dict[str, str]]:
        """
        Interpolate the user-defined script with the specified input and output arguments.

        Args:
            script (str): The user-defined script with placeholders for dynamic input/output.
            input_literal_map (Optional[LiteralMap]): The Flyte LiteralMap of inputs; None means no inputs.
            python_input_types (Optional[Dict[str, Type]]): Mapping of input names to their Python/typing types.
            output_locs (Optional[List[OutputLocation]]): List of output locations to be interpolated.

        Returns:
            Tuple[str, Dict[str, str]]:
                - A two-element tuple in which the first element is the interpolated script (str),
                  and the second is a dictionary mapping each output variable name to its final location (str).
        """
        if input_literal_map is None:
            # No inputs were provided, so there is nothing to convert.
            input_kwargs = {}
        else:
            input_kwargs = TypeEngine.literal_map_to_kwargs(
                flytekit.current_context(),
                lm=input_literal_map,
                python_types={} if python_input_types is None else python_input_types,
            )
        interpolizer = _PythonFStringInterpolizer()

        # Interpolate output locations with input values
        outputs = {}
        for oloc in output_locs or []:
            outputs[oloc.var] = interpolizer.interpolate(oloc.location, inputs=input_kwargs)

        # Interpolate the script
        script = interpolizer.interpolate(script, inputs=input_kwargs, outputs=outputs)

        return script, outputs
161
+
162
+
163
+ def _get_sbatch_cmd(sbatch_conf: Dict[str, str], batch_script_path: str, batch_script_args: List[str] = None) -> str:
164
+ """
165
+ Construct the Slurm sbatch command.
166
+
167
+ Args:
168
+ sbatch_conf (Dict[str, str]): Slurm sbatch configuration options (e.g., partition, job-name, etc.).
169
+ batch_script_path (str): Absolute path on the Slurm cluster of the script to run.
170
+ batch_script_args (List[str], optional): Additional arguments to pass to the batch script.
171
+
172
+ Returns:
173
+ str: The sbatch command string that can be executed on the Slurm cluster.
174
+ """
175
+ cmd = ["sbatch"]
176
+ for opt, val in sbatch_conf.items():
177
+ cmd.extend([f"--{opt}", str(val)])
178
+
179
+ # Assign the batch script to run
180
+ cmd.append(batch_script_path)
181
+
182
+ # Add args if present
183
+ if batch_script_args:
184
+ for arg in batch_script_args:
185
+ cmd.append(arg)
186
+
187
+ cmd = " ".join(cmd)
188
+ return cmd
189
+
190
+
191
+ AgentRegistry.register(SlurmScriptAgent())
@@ -0,0 +1,136 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import Any, Dict, List, Optional, Type
3
+
4
+ from flytekit.configuration import SerializationSettings
5
+ from flytekit.core.base_task import PythonTask
6
+ from flytekit.core.interface import Interface
7
+ from flytekit.extend import TaskPlugins
8
+ from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
9
+ from flytekit.extras.tasks.shell import OutputLocation
10
+ from flytekit.types.directory import FlyteDirectory
11
+ from flytekit.types.file import FlyteFile
12
+
13
+
14
@dataclass
class Slurm(object):
    """Slurm settings for a shell-script task submitted with ``sbatch``.

    Attributes:
        ssh_config: SSH client connection options; see ``SSHConfig`` in
            ``flytekitplugins.slurm.ssh_utils`` for the accepted keys.
        sbatch_conf: Options of the sbatch command, see https://slurm.schedmd.com/sbatch.html.
            ``None`` is normalized to an empty dict.
        batch_script_args: Additional arguments placed after the batch script path.
    """

    ssh_config: Dict[str, Any]
    sbatch_conf: Optional[Dict[str, str]] = None
    batch_script_args: Optional[List[str]] = None

    def __post_init__(self):
        self.sbatch_conf = {} if self.sbatch_conf is None else self.sbatch_conf
37
+
38
+
39
@dataclass
class SlurmRemoteScript(Slurm):
    """Slurm settings for running a batch script that already exists on the cluster.

    Kept as a separate config type from ``Slurm`` so that ``SlurmTask`` and
    ``SlurmShellTask`` are each registered against their own config class.

    Attributes:
        batch_script_path: Path of the batch script on the Slurm cluster (required).
    """

    batch_script_path: str = field(default=None)

    def __post_init__(self):
        super().__post_init__()
        if self.batch_script_path is not None:
            return
        raise ValueError("batch_script_path must be provided")
49
+
50
+
51
class SlurmTask(AsyncAgentExecutorMixin, PythonTask[SlurmRemoteScript]):
    """Task that submits a batch script already present on the Slurm cluster.

    The interface is currently empty (no inputs or outputs).
    """

    _TASK_TYPE = "slurm"

    def __init__(
        self,
        name: str,
        task_config: SlurmRemoteScript,
        **kwargs,
    ):
        # Dummy interface, will support this after discussion
        empty_interface = Interface(None, None)
        super().__init__(
            task_type=self._TASK_TYPE,
            name=name,
            task_config=task_config,
            interface=empty_interface,
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """Serialize the settings the script agent reads from ``task_template.custom``."""
        cfg = self.task_config
        return dict(
            ssh_config=cfg.ssh_config,
            batch_script_path=cfg.batch_script_path,
            batch_script_args=cfg.batch_script_args,
            sbatch_conf=cfg.sbatch_conf,
        )
76
+
77
+
78
class SlurmShellTask(AsyncAgentExecutorMixin, PythonTask[Slurm]):
    """Task that runs a user-written batch script on a Slurm cluster.

    Args:
        name: Task name.
        task_config: Slurm settings.
        script: Batch script content, interpolated with inputs/outputs by the script agent.
        inputs: Mapping of input names to Python types.
        output_locs: Declared outputs; each must be an ``OutputLocation`` whose type is
            ``FlyteFile``/``FlyteDirectory`` or a subclass.
    """

    _TASK_TYPE = "slurm"

    def __init__(
        self,
        name: str,
        task_config: Slurm,
        script: str,
        inputs: Optional[Dict[str, Type]] = None,
        output_locs: Optional[List[OutputLocation]] = None,
        **kwargs,
    ):
        # Set before _validate_output_locs(), which reads self._output_locs.
        self._inputs = inputs
        self._output_locs = [] if output_locs is None else output_locs
        self._script = script

        output_types = self._validate_output_locs()

        super().__init__(
            name=name,
            task_type=self._TASK_TYPE,
            task_config=task_config,
            interface=Interface(inputs=inputs, outputs=output_types),
            **kwargs,
        )

    def _validate_output_locs(self) -> Dict[str, Type]:
        """Check every output location and return the ``{var: var_type}`` output interface."""
        validated: Dict[str, Type] = {}
        for loc in self._output_locs:
            if loc is None:
                raise ValueError("OutputLocation cannot be none")
            if not isinstance(loc, OutputLocation):
                raise ValueError("Every output type should be an output location on the file-system")
            if loc.location is None:
                raise ValueError(f"Output Location not provided for output var {loc.var}")
            if not (issubclass(loc.var_type, FlyteFile) or issubclass(loc.var_type, FlyteDirectory)):
                raise ValueError(
                    "Currently only outputs of type FlyteFile/FlyteDirectory and their derived types are supported"
                )
            validated[loc.var] = loc.var_type
        return validated

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """Serialize the settings the script agent reads from ``task_template.custom``."""
        cfg = self.task_config
        return dict(
            ssh_config=cfg.ssh_config,
            batch_script_args=cfg.batch_script_args,
            sbatch_conf=cfg.sbatch_conf,
            script=self._script,
            python_input_types=self._inputs,
            output_locs=self._output_locs,
        )

    @property
    def script(self) -> str:
        """The raw (uninterpolated) batch script."""
        return self._script
133
+
134
+
135
+ TaskPlugins.register_pythontask_plugin(SlurmRemoteScript, SlurmTask)
136
+ TaskPlugins.register_pythontask_plugin(Slurm, SlurmShellTask)
@@ -0,0 +1,190 @@
1
+ """
2
+ Utilities of asyncssh connections.
3
+ """
4
+
5
+ import os
6
+ import sys
7
+ from dataclasses import asdict, dataclass
8
+ from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
9
+
10
+ import asyncssh
11
+ from asyncssh import SSHClientConnection
12
+
13
+ from flytekit import logger
14
+ from flytekit.extend.backend.utils import get_agent_secret
15
+
16
# Type variable bound to SSHConfig, used to type the SSHConfig.from_dict classmethod.
T = TypeVar("T", bound="SSHConfig")
# Name of the agent secret read by ssh_connect as a fallback private key.
SLURM_PRIVATE_KEY = "FLYTE_SLURM_PRIVATE_KEY"
18
+
19
+
20
@dataclass
class SlurmCluster:
    """Identity of a Slurm cluster endpoint: the (host, username) pair.

    Used as the key of SSH connection pools, hence the explicit ``__hash__``.

    Attributes:
        host: The hostname or address to connect to.
        username: The username to authenticate as on the server, if specified.
    """

    host: str
    username: Optional[str] = None

    def __hash__(self):
        identity = (self.host, self.username)
        return hash(identity)
34
+
35
+
36
@dataclass(frozen=True)
class SSHConfig:
    """Validated subset of asyncssh's ``SSHClientConnectionOptions``.

    Only the options below are accepted. For the full option list, see
    https://asyncssh.readthedocs.io/en/latest/api.html#asyncssh.SSHClientConnectionOptions

    Attributes:
        host: The hostname or address to connect to.
        username: The username to authenticate as on the server.
        client_keys: Path(s) to private keys used for public key authentication.
            Defaults to an empty tuple.
    """

    host: str
    username: Optional[str] = None
    client_keys: Union[str, List[str], Tuple[str, ...]] = ()

    @classmethod
    def from_dict(cls: Type[T], ssh_config: Dict[str, Any]) -> T:
        # Keys that are not fields make the generated __init__ raise TypeError,
        # which is what turns this into a validation step.
        return cls(**ssh_config)

    def to_dict(self) -> Dict[str, Any]:
        return asdict(self)

    def __eq__(self, other):
        if isinstance(other, SSHConfig):
            mine = (self.host, self.username, self.client_keys)
            theirs = (other.host, other.username, other.client_keys)
            return mine == theirs
        return False
68
+
69
+
70
async def ssh_connect(ssh_config: Dict[str, Any]) -> SSHClientConnection:
    """Make an SSH client connection.

    The first attempt uses the options as given, which lets asyncssh rely on OpenSSH
    client config files or the provided keys. If that fails, a second attempt uses the
    private key from the ``FLYTE_SLURM_PRIVATE_KEY`` agent secret (when set) followed by
    any user-provided ``client_keys``.

    Args:
        ssh_config (Dict[str, Any]): Options of SSH client connection defined in SSHConfig.

    Returns:
        SSHClientConnection: An SSH client connection object.

    Raises:
        TypeError: If ``ssh_config`` contains keys that ``SSHConfig`` does not accept.
        ValueError: If the first attempt fails and both the FLYTE_SLURM_PRIVATE_KEY secret
            and ssh_config['client_keys'] are missing.
        Exception: Whatever ``asyncssh.connect`` raised on the second attempt, re-raised.
    """
    # Validate ssh_config (and work on a fresh dict rather than the caller's).
    ssh_config = SSHConfig.from_dict(ssh_config).to_dict()
    # This is required to avoid the error "asyncssh.misc.HostKeyNotVerifiable" when connecting to a new host.
    ssh_config["known_hosts"] = None

    # Make the first SSH connection using either OpenSSH client config files or
    # a user-defined private key. If using OpenSSH config, it will attempt to
    # load settings from ~/.ssh/config.
    try:
        return await asyncssh.connect(**ssh_config)
    except Exception as e:
        logger.info(
            "Failed to make an SSH connection using the default OpenSSH client config (~/.ssh/config) or "
            f"the provided private keys. Error details:\n{e}"
        )

    try:
        default_client_key = get_agent_secret(secret_key=SLURM_PRIVATE_KEY)
    except ValueError:
        logger.info("The secret for key FLYTE_SLURM_PRIVATE_KEY is not set.")
        default_client_key = None

    user_client_keys = ssh_config.get("client_keys")
    # Any empty value ((), [], "") means the user supplied no keys.
    if default_client_key is None and not user_client_keys:
        raise ValueError(
            "Both the secret for key FLYTE_SLURM_PRIVATE_KEY and ssh_config['client_keys'] are missing. "
            "At least one must be set."
        )

    client_keys = []
    if default_client_key is not None:
        # asyncssh is given the secret key by file path. The file is created
        # owner-read/write only, and its mode is tightened before writing in case
        # it already existed with broader permissions.
        private_key_path = os.path.abspath("./slurm_private_key")
        fd = os.open(private_key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        os.chmod(private_key_path, 0o600)
        with os.fdopen(fd, "w") as f:
            f.write(default_client_key)
        client_keys.append(private_key_path)

    if user_client_keys:
        client_keys.extend([user_client_keys] if isinstance(user_client_keys, str) else user_client_keys)

    ssh_config["client_keys"] = client_keys
    logger.info(f"Updated SSH config: {ssh_config}")
    try:
        return await asyncssh.connect(**ssh_config)
    except Exception as e:
        logger.info(
            "Failed to make an SSH connection using the provided private keys. Please verify your setup. "
            f"Error details:\n{e}"
        )
        # Re-raise instead of sys.exit(1): this coroutine is awaited from the agents'
        # create/get/delete, and exiting would terminate the whole hosting process.
        raise
135
+
136
+
137
async def get_ssh_conn(
    ssh_config: Dict[str, Union[str, List[str], Tuple[str, ...]]],
    slurm_cluster_to_ssh_conn: Dict[SlurmCluster, SSHClientConnection],
) -> SSHClientConnection:
    """
    Get an existing SSH connection or create a new one if needed.

    Connections are cached in ``slurm_cluster_to_ssh_conn`` under
    ``SlurmCluster(host, username)``. A cached connection is health-checked with a
    trivial ``echo`` command. If the check fails, the stale connection is closed and
    replaced by a new one.

    Args:
        ssh_config (Dict[str, Union[str, List[str], Tuple[str, ...]]]):
            SSH configuration dictionary, including host and username.
        slurm_cluster_to_ssh_conn (Dict[SlurmCluster, SSHClientConnection]):
            A mapping of SlurmCluster to existing SSHClientConnection objects; updated in place.

    Returns:
        SSHClientConnection: The cached connection if it passed the health check,
        otherwise a newly created (and now cached) connection.
    """
    slurm_cluster = SlurmCluster(host=ssh_config.get("host"), username=ssh_config.get("username"))

    conn = slurm_cluster_to_ssh_conn.get(slurm_cluster)
    if conn is None:
        logger.info("SSH connection key not found, creating new connection")
    else:
        try:
            await conn.run("echo [TEST] SSH connection", check=True)
        except Exception as e:
            logger.info(f"Re-establishing SSH connection due to error: {e}")
            # Release the broken connection before replacing it in the pool.
            conn.close()
        else:
            logger.info("Re-using existing connection")
            return conn

    conn = await ssh_connect(ssh_config=ssh_config)
    slurm_cluster_to_ssh_conn[slurm_cluster] = conn
    return conn
178
+
179
+
180
if __name__ == "__main__":
    import asyncio

    async def test_connect():
        # Manual smoke test; "aws2" is presumably a host alias from the developer's
        # ~/.ssh/config — adjust before running.
        conn = await ssh_connect({"host": "aws2", "username": "ubuntu"})
        try:
            res = await conn.run("echo [TEST] SSH connection", check=True)
            return res.stdout
        finally:
            # Close the connection so the event loop shuts down cleanly.
            conn.close()
            await conn.wait_closed()

    logger.info(asyncio.run(test_connect()))
@@ -0,0 +1,30 @@
1
+ Metadata-Version: 2.2
2
+ Name: flytekitplugins-slurm
3
+ Version: 1.15.1
4
+ Summary: This package holds the Slurm plugins for flytekit
5
+ Author: flyteorg
6
+ Author-email: admin@flyte.org
7
+ License: apache2
8
+ Classifier: Intended Audience :: Science/Research
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: License :: OSI Approved :: Apache Software License
11
+ Classifier: Programming Language :: Python :: 3.9
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Topic :: Scientific/Engineering
16
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Classifier: Topic :: Software Development
18
+ Classifier: Topic :: Software Development :: Libraries
19
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
20
+ Requires-Python: >=3.9
21
+ Requires-Dist: flytekit>=1.15.0
22
+ Requires-Dist: flyteidl>=1.15.0
23
+ Requires-Dist: asyncssh
24
+ Dynamic: author
25
+ Dynamic: author-email
26
+ Dynamic: classifier
27
+ Dynamic: license
28
+ Dynamic: requires-dist
29
+ Dynamic: requires-python
30
+ Dynamic: summary
@@ -0,0 +1,17 @@
1
+ README.md
2
+ setup.py
3
+ flytekitplugins/slurm/__init__.py
4
+ flytekitplugins/slurm/ssh_utils.py
5
+ flytekitplugins/slurm/function/agent.py
6
+ flytekitplugins/slurm/function/task.py
7
+ flytekitplugins/slurm/script/agent.py
8
+ flytekitplugins/slurm/script/task.py
9
+ flytekitplugins_slurm.egg-info/PKG-INFO
10
+ flytekitplugins_slurm.egg-info/SOURCES.txt
11
+ flytekitplugins_slurm.egg-info/dependency_links.txt
12
+ flytekitplugins_slurm.egg-info/entry_points.txt
13
+ flytekitplugins_slurm.egg-info/namespace_packages.txt
14
+ flytekitplugins_slurm.egg-info/requires.txt
15
+ flytekitplugins_slurm.egg-info/top_level.txt
16
+ tests/test_slurm_fn_task.py
17
+ tests/test_slurm_shell_task.py
@@ -0,0 +1,2 @@
1
+ [flytekit.plugins]
2
+ slurm = flytekitplugins.slurm
@@ -0,0 +1,3 @@
1
+ flytekit>=1.15.0
2
+ flyteidl>=1.15.0
3
+ asyncssh
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,41 @@
1
+ from setuptools import setup
2
+
3
PLUGIN_NAME = "slurm"

microlib_name = "flytekitplugins-" + PLUGIN_NAME

plugin_requires = ["flytekit>=1.15.0", "flyteidl>=1.15.0", "asyncssh"]

__version__ = "1.15.1"

# Dotted import path of the plugin's top-level package; sub-packages hang off it.
_PACKAGE_ROOT = f"flytekitplugins.{PLUGIN_NAME}"
_SUBPACKAGES = ("script", "function")

# Supported CPython minor versions (3.x), advertised as trove classifiers.
_PYTHON_MINORS = (9, 10, 11, 12)

_CLASSIFIERS = [
    "Intended Audience :: Science/Research",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: Apache Software License",
    *(f"Programming Language :: Python :: 3.{minor}" for minor in _PYTHON_MINORS),
    "Topic :: Scientific/Engineering",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Software Development",
    "Topic :: Software Development :: Libraries",
    "Topic :: Software Development :: Libraries :: Python Modules",
]

setup(
    name=microlib_name,
    version=__version__,
    author="flyteorg",
    author_email="admin@flyte.org",
    description="This package holds the Slurm plugins for flytekit",
    namespace_packages=["flytekitplugins"],
    packages=[_PACKAGE_ROOT] + [f"{_PACKAGE_ROOT}.{sub}" for sub in _SUBPACKAGES],
    install_requires=plugin_requires,
    license="apache2",
    python_requires=">=3.9",
    classifiers=_CLASSIFIERS,
    # Registers the plugin with flytekit's plugin loader.
    entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}={_PACKAGE_ROOT}"]},
)
@@ -0,0 +1,55 @@
1
+ import os.path
2
+ from unittest import mock
3
+ from flytekit.core import context_manager
4
+ import flytekit
5
+ from flytekit import StructuredDataset, StructuredDatasetTransformerEngine, task, ImageSpec
6
+ from flytekit.configuration import Image, ImageConfig, SerializationSettings, FastSerializationSettings, DefaultImages
7
+ from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager, ExecutionState
8
+ from flytekitplugins.slurm import SlurmFunction
9
+
10
+
11
def test_slurm_task():
    """Check that a SlurmFunction config is kept on the task and round-trips via get_custom."""
    script_file = """#!/bin/bash -i

echo Run function with sbatch...

# Run the user-defined task function
{task.fn}
"""
    # Single source of truth for the expected configuration, so the task_config
    # assertions and the serialized get_custom() assertions cannot drift apart.
    ssh_config = {"host": "your-slurm-host", "username": "ubuntu"}
    sbatch_conf = {
        "partition": "debug",
        "job-name": "tiny-slurm",
        "output": "/home/ubuntu/fn_task.log",
    }

    @task(
        task_config=SlurmFunction(
            # Pass copies so an in-place mutation by the plugin would be caught below.
            ssh_config=dict(ssh_config),
            sbatch_conf=dict(sbatch_conf),
            script=script_file,
        )
    )
    def plus_one(x: int) -> int:
        return x + 1

    assert plus_one.task_config is not None
    assert plus_one.task_config.ssh_config == ssh_config
    assert plus_one.task_config.sbatch_conf == sbatch_conf
    assert plus_one.task_config.script == script_file

    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    retrieved_settings = plus_one.get_custom(settings)
    assert retrieved_settings["ssh_config"] == ssh_config
    assert retrieved_settings["sbatch_conf"] == sbatch_conf
    assert retrieved_settings["script"] == script_file
@@ -0,0 +1,77 @@
1
+ from flytekit.configuration import Image, ImageConfig, SerializationSettings
2
+ from flytekitplugins.slurm import Slurm, SlurmRemoteScript, SlurmTask, SlurmShellTask
3
+
4
+
5
def test_slurm_shell_task():
    """Test SlurmTask and SlurmShellTask config storage and get_custom serialization."""
    batch_script_path = "/script/path/on/the/slurm/cluster"
    script = """#!/bin/bash -i

echo "Run a Flyte SlurmShellTask...\n"
"""
    # Expected values shared by both tasks; only the sbatch output log differs.
    ssh_config = {"host": "<your-slurm-host>", "username": "ubuntu"}
    base_sbatch_conf = {"partition": "debug", "job-name": "tiny-slurm"}
    task_sbatch_conf = {**base_sbatch_conf, "output": "/home/ubuntu/slurm_task.log"}
    shell_sbatch_conf = {**base_sbatch_conf, "output": "/home/ubuntu/slurm_shell_task.log"}

    # SlurmTask: runs a batch script that already exists on the Slurm cluster.
    # Copies are passed so an in-place mutation by the plugin would be caught below.
    slurm_task = SlurmTask(
        name="test-slurm-task",
        task_config=SlurmRemoteScript(
            ssh_config=dict(ssh_config),
            sbatch_conf=dict(task_sbatch_conf),
            batch_script_path=batch_script_path,
        ),
    )

    assert slurm_task.task_config is not None
    assert slurm_task.task_config.ssh_config == ssh_config
    assert slurm_task.task_config.sbatch_conf == task_sbatch_conf
    assert slurm_task.task_config.batch_script_path == batch_script_path

    # SlurmShellTask: the script body is supplied inline by the user.
    slurm_shell_task = SlurmShellTask(
        name="test-slurm-shell-task",
        script=script,
        task_config=Slurm(
            ssh_config=dict(ssh_config),
            sbatch_conf=dict(shell_sbatch_conf),
        ),
    )

    assert slurm_shell_task.task_config is not None
    assert slurm_shell_task.task_config.ssh_config == ssh_config
    assert slurm_shell_task.task_config.sbatch_conf == shell_sbatch_conf
    assert slurm_shell_task.script == script

    # Define dummy SerializationSettings
    default_img = Image(name="default", fqn="test", tag="tag")
    settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
    )

    custom = slurm_task.get_custom(settings)
    assert custom["ssh_config"] == ssh_config
    assert custom["sbatch_conf"] == task_sbatch_conf
    assert custom["batch_script_path"] == batch_script_path
    # No script arguments were configured, so the serialized value is None.
    assert custom["batch_script_args"] is None

    shell_custom = slurm_shell_task.get_custom(settings)
    assert shell_custom["ssh_config"] == ssh_config
    assert shell_custom["sbatch_conf"] == shell_sbatch_conf
    assert shell_custom["script"] == script
    assert shell_custom["batch_script_args"] is None