flyteplugins-pytorch 2.0.0b24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flyteplugins/pytorch/__init__.py
@@ -0,0 +1,3 @@
+ __all__ = ["Elastic"]
+
+ from flyteplugins.pytorch.task import Elastic
flyteplugins/pytorch/task.py
@@ -0,0 +1,191 @@
+ import os
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, Literal, Optional, Union
+
+ import flyte
+ import flyte.report
+ from cloudpickle import cloudpickle
+ from flyte._context import internal_ctx
+ from flyte._logging import logger
+ from flyte._task import P, R
+ from flyte.extend import AsyncFunctionTaskTemplate, TaskPluginRegistry
+ from flyte.models import SerializationContext, TaskContext
+ from flyteidl.plugins.kubeflow import common_pb2
+ from flyteidl.plugins.kubeflow.pytorch_pb2 import (
+     DistributedPyTorchTrainingReplicaSpec,
+     DistributedPyTorchTrainingTask,
+     ElasticConfig,
+ )
+ from google.protobuf.json_format import MessageToDict
+ from torch.distributed import run
+ from torch.distributed.launcher.api import LaunchConfig, elastic_launch
+
+
+ @dataclass
+ class RunPolicy:
+     """
+     RunPolicy describes the execution policy applied to a Kubeflow PyTorchJob (pod cleanup, TTL, active deadline, and retries).
+
+     Args:
+         clean_pod_policy (str, optional): Policy for cleaning up pods after the PyTorchJob completes.
+             Allowed values are "None", "all", or "Running". Defaults to None.
+         ttl_seconds_after_finished (int, optional): Defines the TTL (in seconds) for cleaning
+             up finished PyTorchJobs. Defaults to None.
+         active_deadline_seconds (int, optional): Specifies the duration (in seconds) since
+             startTime during which the job can remain active before it is terminated.
+             Must be a positive integer. Applies only to pods where restartPolicy is
+             OnFailure or Always. Defaults to None.
+         backoff_limit (int, optional): Number of retries before marking this job as failed.
+             Defaults to None.
+     """
+
+     clean_pod_policy: Optional[Literal["None", "all", "Running"]] = None
+     ttl_seconds_after_finished: Optional[int] = None
+     active_deadline_seconds: Optional[int] = None
+     backoff_limit: Optional[int] = None
+
+
+ @dataclass
+ class Elastic:
+     """
+     Elastic defines the configuration for running a PyTorch elastic job using torch.distributed.
+
+     Args:
+         nnodes (Union[int, str]): Number of nodes to use. Can be a fixed int or a range
+             string (e.g., "2:4" for elastic training).
+         nproc_per_node (int): Number of processes to launch per node.
+         rdzv_backend (Literal): Rendezvous backend to use. Typically "c10d". Defaults to "c10d".
+         run_policy (RunPolicy, optional): Run policy applied to the job execution.
+             Defaults to None.
+         monitor_interval (int): Interval (in seconds) to monitor the job's state.
+             Defaults to 3.
+         max_restarts (int): Maximum number of worker group restarts before failing the job.
+             Defaults to 3.
+         rdzv_configs (Dict[str, Any]): Rendezvous configuration key-value pairs.
+             Defaults to {"timeout": 900, "join_timeout": 900}.
+     """
+
+     nnodes: Union[int, str]
+     nproc_per_node: int
+     rdzv_backend: Literal["c10d", "etcd", "etcd-v2"] = "c10d"
+     run_policy: Optional[RunPolicy] = None
+     monitor_interval: int = 3
+     max_restarts: int = 3
+     rdzv_configs: Dict[str, Any] = field(default_factory=lambda: {"timeout": 900, "join_timeout": 900})
+
+
+ def launcher_entrypoint(tctx: TaskContext, fn: bytes, kwargs: dict):
+     func = cloudpickle.loads(fn)
+     flyte.init(
+         org=tctx.action.org,
+         project=tctx.action.project,
+         domain=tctx.action.domain,
+         root_dir=tctx.run_base_dir,
+     )
+
+     with internal_ctx().replace_task_context(tctx):
+         return func(**kwargs)
+
+
+ @dataclass(kw_only=True)
+ class TorchFunctionTask(AsyncFunctionTaskTemplate):
+     """
+     Plugin to transform local python code for execution as a PyTorch job.
+     """
+
+     task_type: str = "pytorch"
+     task_type_version: int = 1
+     plugin_config: Elastic
+
+     def __post_init__(self):
+         super().__post_init__()
+         self.task_type = "python-task" if self.plugin_config.nnodes == 1 else "pytorch"
+         self.min_nodes, self.max_nodes = run.parse_min_max_nnodes(str(self.plugin_config.nnodes))
+
+     async def pre(self, *args: P.args, **kwargs: P.kwargs) -> Dict[str, Any]:
+         # If OMP_NUM_THREADS is not set, set it to 1 to avoid overloading the system.
+         # Doing so to copy the default behavior of torchrun.
+         # See https://github.com/pytorch/pytorch/blob/eea4ece256d74c6f25c1f4eab37b3f2f4aeefd4d/torch/distributed/run.py#L791
+         if "OMP_NUM_THREADS" not in os.environ and self.plugin_config.nproc_per_node > 1:
+             omp_num_threads = 1
+             logger.warning(
+                 "\n*****************************************\n"
+                 "Setting OMP_NUM_THREADS environment variable for each process to be "
+                 "%s in default, to avoid your system being overloaded, "
+                 "please further tune the variable for optimal performance in "
+                 "your application as needed. \n"
+                 "*****************************************",
+                 omp_num_threads,
+             )
+             os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
+         return {}
+
+     async def execute(self, *args: P.args, **kwargs: P.kwargs) -> R:
+         tctx = internal_ctx().data.task_context
+         if tctx.mode == "local":
+             return self.func(**kwargs)
+
+         config = LaunchConfig(
+             run_id=flyte.ctx().action.run_name,
+             min_nodes=self.min_nodes,
+             max_nodes=self.max_nodes,
+             nproc_per_node=self.plugin_config.nproc_per_node,
+             rdzv_backend=self.plugin_config.rdzv_backend,
+             rdzv_configs=self.plugin_config.rdzv_configs,
+             rdzv_endpoint=os.environ.get("PET_RDZV_ENDPOINT", "localhost:0"),
+             max_restarts=self.plugin_config.max_restarts,
+             monitor_interval=self.plugin_config.monitor_interval,
+         )
+
+         out = elastic_launch(config=config, entrypoint=launcher_entrypoint)(
+             tctx,
+             cloudpickle.dumps(self.func),
+             kwargs,
+         )
+
+         # `out` is a dictionary of rank (not local rank) -> result
+         # Rank 0 returns the result of the task function
+         if 0 in out:
+             return out[0]
+         return None
+
+     def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]:
+         """
+         Converts the ElasticConfig to a DistributedPyTorchTrainingTask
+         """
+         elastic_config = ElasticConfig(
+             rdzv_backend=self.plugin_config.rdzv_backend,
+             min_replicas=self.min_nodes,
+             max_replicas=self.max_nodes,
+             nproc_per_node=self.plugin_config.nproc_per_node,
+             max_restarts=self.plugin_config.max_restarts,
+         )
+
+         policy = None
+         if self.plugin_config.run_policy:
+             policy = common_pb2.RunPolicy(
+                 clean_pod_policy=(
+                     # https://github.com/flyteorg/flyte/blob/4caa5639ee6453d86c823181083423549f08f702/flyteidl/protos/flyteidl/plugins/kubeflow/common.proto#L9-L13
+                     common_pb2.CleanPodPolicy.Value(
+                         "CLEANPOD_POLICY_" + self.plugin_config.run_policy.clean_pod_policy.upper()
+                     )
+                     if self.plugin_config.run_policy.clean_pod_policy
+                     else None
+                 ),
+                 ttl_seconds_after_finished=self.plugin_config.run_policy.ttl_seconds_after_finished,
+                 active_deadline_seconds=self.plugin_config.run_policy.active_deadline_seconds,
+                 backoff_limit=self.plugin_config.run_policy.backoff_limit,
+             )
+
+         torch_job = DistributedPyTorchTrainingTask(
+             worker_replicas=DistributedPyTorchTrainingReplicaSpec(
+                 replicas=self.max_nodes,
+             ),
+             run_policy=policy,
+             elastic_config=elastic_config,
+         )
+
+         return MessageToDict(torch_job)
+
+
+ TaskPluginRegistry.register(config_type=Elastic, plugin=TorchFunctionTask)
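
For orientation, here is a minimal sketch of constructing the configuration classes defined in `task.py` above. It uses only names defined or imported in that file: `Elastic` (re-exported from `flyteplugins.pytorch`), `RunPolicy`, and torchrun's `parse_min_max_nnodes`, which `TorchFunctionTask.__post_init__` uses to split an `nnodes` range string into minimum and maximum node counts.

```python
# Minimal sketch: building the plugin configuration dataclasses from task.py.
from torch.distributed import run

from flyteplugins.pytorch import Elastic
from flyteplugins.pytorch.task import RunPolicy

elastic = Elastic(
    nnodes="2:4",          # either a fixed int or a "min:max" range string
    nproc_per_node=8,      # worker processes launched on each node
    rdzv_backend="c10d",
    max_restarts=3,
    run_policy=RunPolicy(clean_pod_policy="Running", backoff_limit=2),
)

# TorchFunctionTask.__post_init__ derives the node range the same way:
min_nodes, max_nodes = run.parse_min_max_nnodes(str(elastic.nnodes))
print(min_nodes, max_nodes)  # 2 4
```

Note that `__post_init__` also downgrades the task type to a plain `python-task` when `nnodes == 1`.
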
flyteplugins_pytorch-2.0.0b24.dist-info/METADATA
@@ -0,0 +1,22 @@
+ Metadata-Version: 2.4
+ Name: flyteplugins-pytorch
+ Version: 2.0.0b24
+ Summary: pytorch plugin for flyte
+ Author-email: Kevin Liao <kevinliao852@users.noreply.github.com>
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: torch
+ Requires-Dist: flyte
+
+ # Union PyTorch Plugin
+
+ Union can execute **PyTorch distributed training jobs** natively on a Kubernetes cluster, managing the lifecycle of worker pods, rendezvous coordination, spin-up, and tear-down. It leverages the open-source **TorchElastic (torch.distributed.elastic)** launcher and the **Kubeflow PyTorch Operator**, enabling fault-tolerant and elastic training across multiple nodes.
+
+ Each job behaves like a transient PyTorch cluster: worker groups are created for that specific job and torn down automatically after completion. Elastic training allows nodes to scale in and out, and failed workers can be restarted without bringing down the entire job.
+
+ To install the plugin, run the following command:
+
+ ```bash
+ pip install --pre flyteplugins-pytorch
+ ```
+
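
The README stops at installation. As a hypothetical usage sketch (not taken from this package), the snippet below assumes the flyte 2.x `TaskEnvironment` API and that its `plugin_config` argument is matched against the registered config type (`Elastic` maps to `TorchFunctionTask` via `TaskPluginRegistry`); check the flyte documentation for the exact signatures.

```python
# Hypothetical sketch only: flyte.TaskEnvironment, its plugin_config argument,
# and flyte.run are assumptions about the flyte 2.x SDK, not part of this package.
import flyte

from flyteplugins.pytorch import Elastic

env = flyte.TaskEnvironment(
    name="pytorch-elastic",
    plugin_config=Elastic(nnodes="2:4", nproc_per_node=2),
)


@env.task
def train(epochs: int) -> float:
    # Each worker process runs this body under the TorchElastic launcher;
    # the torch.distributed training loop would go here.
    return 0.0


if __name__ == "__main__":
    flyte.init()  # flyte.init is also used by the plugin's launcher_entrypoint in task.py
    flyte.run(train, epochs=3)
```
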
flyteplugins_pytorch-2.0.0b24.dist-info/RECORD
@@ -0,0 +1,6 @@
+ flyteplugins/pytorch/__init__.py,sha256=0Z7JvW6zJfyVX7375qTOBM51im_Q9fzzujP7FhotY5M,69
+ flyteplugins/pytorch/task.py,sha256=HiTlrzvlxzLhRhdZXywPYpzcav6rjYtPrMAeVHxNphg,7901
+ flyteplugins_pytorch-2.0.0b24.dist-info/METADATA,sha256=SgFA82956c77HfxEZ9JAApIyLzmjbVt8vUwBocv_ffo,1047
+ flyteplugins_pytorch-2.0.0b24.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ flyteplugins_pytorch-2.0.0b24.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
+ flyteplugins_pytorch-2.0.0b24.dist-info/RECORD,,
flyteplugins_pytorch-2.0.0b24.dist-info/WHEEL
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (80.9.0)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
flyteplugins_pytorch-2.0.0b24.dist-info/top_level.txt
@@ -0,0 +1 @@
+ flyteplugins