flyteplugins-pytorch 2.0.0b24__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,191 @@ flyteplugins/pytorch/task.py
+import os
+from dataclasses import dataclass, field
+from typing import Any, Dict, Literal, Optional, Union
+
+import flyte
+import flyte.report
+from cloudpickle import cloudpickle
+from flyte._context import internal_ctx
+from flyte._logging import logger
+from flyte._task import P, R
+from flyte.extend import AsyncFunctionTaskTemplate, TaskPluginRegistry
+from flyte.models import SerializationContext, TaskContext
+from flyteidl.plugins.kubeflow import common_pb2
+from flyteidl.plugins.kubeflow.pytorch_pb2 import (
+    DistributedPyTorchTrainingReplicaSpec,
+    DistributedPyTorchTrainingTask,
+    ElasticConfig,
+)
+from google.protobuf.json_format import MessageToDict
+from torch.distributed import run
+from torch.distributed.launcher.api import LaunchConfig, elastic_launch
+
+
+@dataclass
+class RunPolicy:
+    """
+    RunPolicy describes the policies applied to the execution of a Kubeflow job.
+
+    Args:
+        clean_pod_policy (str, optional): Policy for cleaning up pods after the PyTorchJob completes.
+            Allowed values are "None", "all", or "Running". Defaults to None.
+        ttl_seconds_after_finished (int, optional): Defines the TTL (in seconds) for cleaning
+            up finished PyTorchJobs. Defaults to None.
+        active_deadline_seconds (int, optional): Specifies the duration (in seconds) since
+            startTime during which the job can remain active before it is terminated.
+            Must be a positive integer. Applies only to pods where restartPolicy is
+            OnFailure or Always. Defaults to None.
+        backoff_limit (int, optional): Number of retries before marking this job as failed.
+            Defaults to None.
+    """
+
+    clean_pod_policy: Optional[Literal["None", "all", "Running"]] = None
+    ttl_seconds_after_finished: Optional[int] = None
+    active_deadline_seconds: Optional[int] = None
+    backoff_limit: Optional[int] = None
+
+
+@dataclass
+class Elastic:
+    """
+    Elastic defines the configuration for running a PyTorch elastic job using torch.distributed.
+
+    Args:
+        nnodes (Union[int, str]): Number of nodes to use. Can be a fixed int or a range
+            string (e.g., "2:4" for elastic training).
+        nproc_per_node (int): Number of processes to launch per node.
+        rdzv_backend (str): Rendezvous backend to use. One of "c10d", "etcd", or "etcd-v2". Defaults to "c10d".
+        run_policy (RunPolicy, optional): Run policy applied to the job execution.
+            Defaults to None.
+        monitor_interval (int): Interval (in seconds) to monitor the job's state.
+            Defaults to 3.
+        max_restarts (int): Maximum number of worker group restarts before failing the job.
+            Defaults to 3.
+        rdzv_configs (Dict[str, Any]): Rendezvous configuration key-value pairs.
+            Defaults to {"timeout": 900, "join_timeout": 900}.
+    """
+
+    nnodes: Union[int, str]
+    nproc_per_node: int
+    rdzv_backend: Literal["c10d", "etcd", "etcd-v2"] = "c10d"
+    run_policy: Optional[RunPolicy] = None
+    monitor_interval: int = 3
+    max_restarts: int = 3
+    rdzv_configs: Dict[str, Any] = field(default_factory=lambda: {"timeout": 900, "join_timeout": 900})
+
+
+def launcher_entrypoint(tctx: TaskContext, fn: bytes, kwargs: dict):
+    func = cloudpickle.loads(fn)
+    flyte.init(
+        org=tctx.action.org,
+        project=tctx.action.project,
+        domain=tctx.action.domain,
+        root_dir=tctx.run_base_dir,
+    )
+
+    with internal_ctx().replace_task_context(tctx):
+        return func(**kwargs)
+
+
+@dataclass(kw_only=True)
+class TorchFunctionTask(AsyncFunctionTaskTemplate):
+    """
+    Plugin to transform local python code for execution as a PyTorch job.
+    """
+
+    task_type: str = "pytorch"
+    task_type_version: int = 1
+    plugin_config: Elastic
+
+    def __post_init__(self):
+        super().__post_init__()
+        self.task_type = "python-task" if self.plugin_config.nnodes == 1 else "pytorch"
+        self.min_nodes, self.max_nodes = run.parse_min_max_nnodes(str(self.plugin_config.nnodes))
+
+    async def pre(self, *args: P.args, **kwargs: P.kwargs) -> Dict[str, Any]:
+        # If OMP_NUM_THREADS is not set, set it to 1 to avoid overloading the system.
+        # This copies the default behavior of torchrun.
+        # See https://github.com/pytorch/pytorch/blob/eea4ece256d74c6f25c1f4eab37b3f2f4aeefd4d/torch/distributed/run.py#L791
+        if "OMP_NUM_THREADS" not in os.environ and self.plugin_config.nproc_per_node > 1:
+            omp_num_threads = 1
+            logger.warning(
+                "\n*****************************************\n"
+                "Setting OMP_NUM_THREADS environment variable for each process to be "
+                "%s in default, to avoid your system being overloaded, "
+                "please further tune the variable for optimal performance in "
+                "your application as needed. \n"
+                "*****************************************",
+                omp_num_threads,
+            )
+            os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
+        return {}
+
+    async def execute(self, *args: P.args, **kwargs: P.kwargs) -> R:
+        tctx = internal_ctx().data.task_context
+        if tctx.mode == "local":
+            return self.func(**kwargs)
+
+        config = LaunchConfig(
+            run_id=flyte.ctx().action.run_name,
+            min_nodes=self.min_nodes,
+            max_nodes=self.max_nodes,
+            nproc_per_node=self.plugin_config.nproc_per_node,
+            rdzv_backend=self.plugin_config.rdzv_backend,
+            rdzv_configs=self.plugin_config.rdzv_configs,
+            rdzv_endpoint=os.environ.get("PET_RDZV_ENDPOINT", "localhost:0"),
+            max_restarts=self.plugin_config.max_restarts,
+            monitor_interval=self.plugin_config.monitor_interval,
+        )
+
+        out = elastic_launch(config=config, entrypoint=launcher_entrypoint)(
+            tctx,
+            cloudpickle.dumps(self.func),
+            kwargs,
+        )
+
+        # `out` is a dictionary of rank (not local rank) -> result
+        # Rank 0 returns the result of the task function
+        if 0 in out:
+            return out[0]
+        return None
+
+    def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]:
+        """
+        Converts the Elastic config into a DistributedPyTorchTrainingTask, serialized as a dict.
+        """
+        elastic_config = ElasticConfig(
+            rdzv_backend=self.plugin_config.rdzv_backend,
+            min_replicas=self.min_nodes,
+            max_replicas=self.max_nodes,
+            nproc_per_node=self.plugin_config.nproc_per_node,
+            max_restarts=self.plugin_config.max_restarts,
+        )
+
+        policy = None
+        if self.plugin_config.run_policy:
+            policy = common_pb2.RunPolicy(
+                clean_pod_policy=(
+                    # https://github.com/flyteorg/flyte/blob/4caa5639ee6453d86c823181083423549f08f702/flyteidl/protos/flyteidl/plugins/kubeflow/common.proto#L9-L13
+                    common_pb2.CleanPodPolicy.Value(
+                        "CLEANPOD_POLICY_" + self.plugin_config.run_policy.clean_pod_policy.upper()
+                    )
+                    if self.plugin_config.run_policy.clean_pod_policy
+                    else None
+                ),
+                ttl_seconds_after_finished=self.plugin_config.run_policy.ttl_seconds_after_finished,
+                active_deadline_seconds=self.plugin_config.run_policy.active_deadline_seconds,
+                backoff_limit=self.plugin_config.run_policy.backoff_limit,
+            )
+
+        torch_job = DistributedPyTorchTrainingTask(
+            worker_replicas=DistributedPyTorchTrainingReplicaSpec(
+                replicas=self.max_nodes,
+            ),
+            run_policy=policy,
+            elastic_config=elastic_config,
+        )
+
+        return MessageToDict(torch_job)
+
+
+TaskPluginRegistry.register(config_type=Elastic, plugin=TorchFunctionTask)
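For orientation, here is a short editorial sketch (not part of the wheel) of how the `nnodes` range string and the other `Elastic` fields end up in the Kubeflow `ElasticConfig` produced by `TorchFunctionTask.custom_config`. The configuration values are illustrative, and the import path follows this wheel's RECORD.

```python
# Illustrative sketch only; mirrors the translation done in custom_config above.
from google.protobuf.json_format import MessageToDict
from flyteidl.plugins.kubeflow.pytorch_pb2 import ElasticConfig
from torch.distributed import run

from flyteplugins.pytorch.task import Elastic

cfg = Elastic(nnodes="2:4", nproc_per_node=8)                     # example values
min_nodes, max_nodes = run.parse_min_max_nnodes(str(cfg.nnodes))  # "2:4" -> (2, 4)

elastic = ElasticConfig(
    rdzv_backend=cfg.rdzv_backend,      # "c10d" by default
    min_replicas=min_nodes,
    max_replicas=max_nodes,
    nproc_per_node=cfg.nproc_per_node,
    max_restarts=cfg.max_restarts,
)
print(MessageToDict(elastic))
```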
@@ -0,0 +1,22 @@ flyteplugins_pytorch-2.0.0b24.dist-info/METADATA
+Metadata-Version: 2.4
+Name: flyteplugins-pytorch
+Version: 2.0.0b24
+Summary: pytorch plugin for flyte
+Author-email: Kevin Liao <kevinliao852@users.noreply.github.com>
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+Requires-Dist: torch
+Requires-Dist: flyte
+
+# Union PyTorch Plugin
+
+Union can execute **PyTorch distributed training jobs** natively on a Kubernetes cluster, managing worker pod lifecycle, rendezvous coordination, spin-up, and teardown. It leverages the open-source **TorchElastic (torch.distributed.elastic)** launcher and the **Kubeflow PyTorch Operator**, enabling fault-tolerant and elastic training across multiple nodes.
+
+This is akin to running a transient PyTorch cluster: worker groups are created for the specific job and torn down automatically after completion. Elastic training allows nodes to scale in and out, and failed workers can be restarted without bringing down the entire job.
+
+To install the plugin, run the following command:
+
+```bash
+pip install --pre flyteplugins-pytorch
+```
+
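Once installed, the `Elastic` dataclass is what routes a task to `TorchFunctionTask` (via the `TaskPluginRegistry.register` call at the bottom of `task.py`). The sketch below is an editorial illustration of that wiring, not an excerpt from the package: the `flyte.TaskEnvironment` / `plugin_config` usage and all values shown are assumptions inferred from the `plugin_config` field above, so consult the flyte 2.x documentation for the exact task API.

```python
# Hedged usage sketch: plugin_config wiring is an assumption, values are illustrative.
import flyte

from flyteplugins.pytorch.task import Elastic

env = flyte.TaskEnvironment(          # assumed entry point for attaching a plugin config
    name="torch-elastic",
    plugin_config=Elastic(
        nnodes="2:4",                 # elastic range: minimum 2, maximum 4 nodes
        nproc_per_node=2,             # worker processes launched per node
        max_restarts=3,
    ),
)


@env.task
async def train(epochs: int) -> float:
    # Runs in every torchelastic worker; the return value of rank 0 becomes the
    # task result (see TorchFunctionTask.execute in task.py).
    ...
    return 0.0
```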
@@ -0,0 +1,6 @@ flyteplugins_pytorch-2.0.0b24.dist-info/RECORD
+flyteplugins/pytorch/__init__.py,sha256=0Z7JvW6zJfyVX7375qTOBM51im_Q9fzzujP7FhotY5M,69
+flyteplugins/pytorch/task.py,sha256=HiTlrzvlxzLhRhdZXywPYpzcav6rjYtPrMAeVHxNphg,7901
+flyteplugins_pytorch-2.0.0b24.dist-info/METADATA,sha256=SgFA82956c77HfxEZ9JAApIyLzmjbVt8vUwBocv_ffo,1047
+flyteplugins_pytorch-2.0.0b24.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+flyteplugins_pytorch-2.0.0b24.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
+flyteplugins_pytorch-2.0.0b24.dist-info/RECORD,,
@@ -0,0 +1 @@ flyteplugins_pytorch-2.0.0b24.dist-info/top_level.txt
+flyteplugins