flytekitplugins-k8sdataservice 1.15.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27) hide show
  1. flytekitplugins_k8sdataservice-1.15.0/PKG-INFO +30 -0
  2. flytekitplugins_k8sdataservice-1.15.0/README.md +104 -0
  3. flytekitplugins_k8sdataservice-1.15.0/flytekitplugins/k8sdataservice/__init__.py +15 -0
  4. flytekitplugins_k8sdataservice-1.15.0/flytekitplugins/k8sdataservice/agent.py +98 -0
  5. flytekitplugins_k8sdataservice-1.15.0/flytekitplugins/k8sdataservice/k8s/__init__.py +0 -0
  6. flytekitplugins_k8sdataservice-1.15.0/flytekitplugins/k8sdataservice/k8s/kube_config.py +20 -0
  7. flytekitplugins_k8sdataservice-1.15.0/flytekitplugins/k8sdataservice/k8s/manager.py +219 -0
  8. flytekitplugins_k8sdataservice-1.15.0/flytekitplugins/k8sdataservice/sensor.py +67 -0
  9. flytekitplugins_k8sdataservice-1.15.0/flytekitplugins/k8sdataservice/task.py +71 -0
  10. flytekitplugins_k8sdataservice-1.15.0/flytekitplugins_k8sdataservice.egg-info/PKG-INFO +30 -0
  11. flytekitplugins_k8sdataservice-1.15.0/flytekitplugins_k8sdataservice.egg-info/SOURCES.txt +25 -0
  12. flytekitplugins_k8sdataservice-1.15.0/flytekitplugins_k8sdataservice.egg-info/dependency_links.txt +1 -0
  13. flytekitplugins_k8sdataservice-1.15.0/flytekitplugins_k8sdataservice.egg-info/entry_points.txt +2 -0
  14. flytekitplugins_k8sdataservice-1.15.0/flytekitplugins_k8sdataservice.egg-info/namespace_packages.txt +1 -0
  15. flytekitplugins_k8sdataservice-1.15.0/flytekitplugins_k8sdataservice.egg-info/requires.txt +3 -0
  16. flytekitplugins_k8sdataservice-1.15.0/flytekitplugins_k8sdataservice.egg-info/top_level.txt +3 -0
  17. flytekitplugins_k8sdataservice-1.15.0/setup.cfg +4 -0
  18. flytekitplugins_k8sdataservice-1.15.0/setup.py +38 -0
  19. flytekitplugins_k8sdataservice-1.15.0/tests/k8sdataservice/k8s/test_kube_config.py +23 -0
  20. flytekitplugins_k8sdataservice-1.15.0/tests/k8sdataservice/k8s/test_manager.py +168 -0
  21. flytekitplugins_k8sdataservice-1.15.0/tests/k8sdataservice/test_agent.py +319 -0
  22. flytekitplugins_k8sdataservice-1.15.0/tests/k8sdataservice/test_sensor.py +134 -0
  23. flytekitplugins_k8sdataservice-1.15.0/tests/k8sdataservice/test_task.py +128 -0
  24. flytekitplugins_k8sdataservice-1.15.0/tests/k8sdataservice/utils/test_resources.py +90 -0
  25. flytekitplugins_k8sdataservice-1.15.0/utils/__init__.py +0 -0
  26. flytekitplugins_k8sdataservice-1.15.0/utils/infra.py +9 -0
  27. flytekitplugins_k8sdataservice-1.15.0/utils/resources.py +24 -0
@@ -0,0 +1,30 @@
1
+ Metadata-Version: 2.2
2
+ Name: flytekitplugins-k8sdataservice
3
+ Version: 1.15.0
4
+ Summary: Flytekit K8s Data Service Plugin
5
+ Author: LinkedIn
6
+ Author-email: shuliang@linkedin.com
7
+ License: apache2
8
+ Classifier: Intended Audience :: Science/Research
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: License :: OSI Approved :: Apache Software License
11
+ Classifier: Programming Language :: Python :: 3.9
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Topic :: Scientific/Engineering
16
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Classifier: Topic :: Software Development
18
+ Classifier: Topic :: Software Development :: Libraries
19
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
20
+ Requires-Python: >=3.9
21
+ Requires-Dist: flytekit<2.0.0,>=1.11.0
22
+ Requires-Dist: kubernetes<24.0.0,>=23.6.0
23
+ Requires-Dist: flyteidl<2.0.0,>=1.11.0
24
+ Dynamic: author
25
+ Dynamic: author-email
26
+ Dynamic: classifier
27
+ Dynamic: license
28
+ Dynamic: requires-dist
29
+ Dynamic: requires-python
30
+ Dynamic: summary
@@ -0,0 +1,104 @@
1
+ # K8s Stateful Service Plugin
2
+
3
+ This plugin provides support for Kubernetes StatefulSet and Service integration, enabling seamless provisioning and coordination with any Kubernetes services or Flyte tasks. It is especially suited for deep learning use cases at scale, where distributed and parallelized data loading and caching across nodes are required.
4
+
5
+ ## Features
6
+ - **Predictable and Reliable Endpoints**: The service creates consistent endpoints, facilitating communication between services or tasks within the same Kubernetes cluster.
7
+ - **Reusable Across Runs**: Service tasks can persist across task runs, ensuring consistency. Alternatively, a cleanup sensor can release cluster resources when they are no longer needed.
8
+ - **Conventional Pod Naming**: Pods in the StatefulSet follow a conventional naming pattern. For instance, if the StatefulSet name is `foo` and replicas are set to 2, the pod endpoints will be `foo-0.foo:1234` and `foo-1.foo:1234`. This simplifies endpoint construction for training or inference scripts. For example, gRPC endpoints can directly use `foo-0.foo:1234` and `foo-1.foo:1234`.
9
+
10
+ ## Installation
11
+
12
+ Install the plugin via pip:
13
+
14
+ ```bash
15
+ pip install flytekitplugins-k8sdataservice
16
+ ```
17
+
18
+ ## Usage
19
+
20
+ Below is an example demonstrating how to provision and run a service in Kubernetes, making it reachable within the cluster.
21
+
22
+ **Note**: Utility functions are available to generate unique service names that can be reused across training or inference scripts.
23
+
24
+ ### Example Usage
25
+
26
+ #### Provisioning a Data Service
27
+ ```python
28
+ from flytekitplugins.k8sdataservice import DataServiceConfig, DataServiceTask, CleanupSensor
29
+ from utils.infra import gen_infra_name
30
+ from flytekit import kwtypes, Resources, task, workflow
31
+
32
+ # Generate a unique infrastructure name
33
+ name = gen_infra_name()
34
+
35
+ def k8s_data_service():
36
+ gnn_config = DataServiceConfig(
37
+ Name=name,
38
+ Requests=Resources(cpu='1', mem='1Gi'),
39
+ Limits=Resources(cpu='2', mem='2Gi'),
40
+ Replicas=1,
41
+ Image="busybox:latest",
42
+ Command=[
42
+ "sh",
44
+ "-c",
45
+ "echo Hello Flyte K8s Stateful Service! && sleep 3600"
46
+ ],
47
+ )
48
+
49
+ gnn_task = DataServiceTask(
50
+ name="K8s Stateful Data Service",
51
+ inputs=kwtypes(ds=str),
52
+ task_config=gnn_config,
53
+ )
54
+ return gnn_task
55
+
56
+ # Define a cleanup sensor
57
+ gnn_sensor = CleanupSensor(name="Cleanup")
58
+
59
+ # Define a workflow to test the data service
60
+ @workflow
61
+ def test_dataservice_wf(name: str):
62
+ k8s_data_service()(ds="OSS Flyte K8s Data Service Demo") \
63
+ >> gnn_sensor(
64
+ release_name=name,
65
+ cleanup_data_service=True, cluster="my-cluster",
66
+ )
67
+
68
+ if __name__ == "__main__":
69
+ out = test_dataservice_wf(name="example")
70
+ print(f"Running test_dataservice_wf() {out}")
71
+ ```
72
+
73
+ #### Accessing the Data Service
74
+ Other tasks or services that need to access the data service can do so in multiple ways. For example, using environment variables:
75
+
76
+ ```python
77
+ from kubernetes.client import V1PodSpec, V1Container, V1EnvVar
78
+
79
+ PRIMARY_CONTAINER_NAME = "primary"
80
+ FLYTE_POD_SPEC = V1PodSpec(
81
+ containers=[
82
+ V1Container(
83
+ name=PRIMARY_CONTAINER_NAME,
84
+ env=[
85
+ V1EnvVar(name="MY_DATASERVICES", value=f"{name}-0.{name}:40000 {name}-1.{name}:40000"),
86
+ ],
87
+ )
88
+ ],
89
+ )
90
+
91
+ task_config = MPIJob(
92
+ launcher=Launcher(replicas=1, pod_template=FLYTE_POD_SPEC),
93
+ worker=Worker(replicas=1, pod_template=FLYTE_POD_SPEC),
94
+ )
95
+
96
+ @task(task_config=task_config)
97
+ def mpi_task() -> str:
98
+ return "your script uses the envs to communicate with the data service "
99
+ ```
100
+
101
+ ### Key Points
102
+ - The `DataServiceConfig` defines resource requests, limits, replicas, and the container image/command.
103
+ - The `CleanupSensor` ensures resources are cleaned up when required.
104
+ - The workflow connects the service provisioning and cleanup process for streamlined operations.
@@ -0,0 +1,15 @@
1
+ """
2
+ .. currentmodule:: flytekitplugins.k8sdataservice
3
+
4
+ This package contains the Flytekit K8s data service plugin: a task and its config, the agent that provisions the service, and a cleanup sensor.
5
+
6
+ .. autosummary::
7
+ :template: custom.rst
8
+ :toctree: generated/
9
+
10
+ DataServiceTask
11
+ """
12
+
13
+ from .agent import DataServiceAgent # noqa: F401
14
+ from .sensor import CleanupSensor # noqa: F401
15
+ from .task import DataServiceConfig, DataServiceTask # noqa: F401
@@ -0,0 +1,98 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ from flyteidl.core.execution_pb2 import TaskExecution
5
+ from flytekitplugins.k8sdataservice.k8s.manager import K8sManager
6
+ from flytekitplugins.k8sdataservice.task import DataServiceConfig
7
+
8
+ from flytekit import logger
9
+ from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta
10
+ from flytekit.models.literals import LiteralMap
11
+ from flytekit.models.task import TaskTemplate
12
+
13
+
14
@dataclass
class DataServiceMetadata(ResourceMeta):
    """Resource metadata produced by ``DataServiceAgent.create`` and consumed by ``get``/``delete``."""

    # Configuration the data service was requested with.
    dataservice_config: DataServiceConfig
    # Name returned by the K8s manager, or the existing release name when one is reused.
    name: str
18
+
19
+
20
class DataServiceAgent(AsyncAgentBase):
    """Async agent that creates, polls and deletes a K8s data service.

    All Kubernetes calls are delegated to ``K8sManager``. The agent turns the task
    template's custom config into manager calls and maps the manager's status strings
    onto Flyte task phases.
    """

    name = "K8s DataService Async Agent"

    def __init__(self):
        self.k8s_manager = K8sManager()
        super().__init__(task_type_name="dataservicetask", metadata_type=DataServiceMetadata)
        self.config = None

    def create(
        self, task_template: TaskTemplate, output_prefix: str, inputs: Optional[LiteralMap] = None, **kwargs
    ) -> DataServiceMetadata:
        """Create the data service, or adopt an existing release, and return its metadata.

        :param task_template: task template whose ``custom`` dict carries the data service config
        :param output_prefix: unused by this agent
        :param inputs: unused by this agent
        :return: metadata holding the config and the name of the data service
        :raises KeyError: if ``Image``, ``Command`` or ``Cluster`` is missing from the config
        """
        graph_engine_config = task_template.custom
        self.k8s_manager.set_configs(graph_engine_config)
        # ``self.config`` is never populated, so log the config that is actually in effect.
        logger.info(f"Loaded data service config {graph_engine_config}")
        existing_release_name = graph_engine_config.get("ExistingReleaseName", None)
        logger.info(f"The existing data service release name is {existing_release_name}")

        if not existing_release_name:
            # No release to reuse (None or empty string): provision a new one.
            logger.info("Creating K8s data service resources...")
            name = self.k8s_manager.create_data_service()
            logger.info(f'Data service {name} with image {graph_engine_config["Image"]} completed')
        else:
            name = existing_release_name
            logger.info(f"User configs to use the existing data service release name: {name}.")

        dataservice_config = DataServiceConfig(
            Name=graph_engine_config.get("Name", None),
            Image=graph_engine_config["Image"],
            Command=graph_engine_config["Command"],
            Cluster=graph_engine_config["Cluster"],
            ExistingReleaseName=existing_release_name,
        )
        metadata = DataServiceMetadata(dataservice_config=dataservice_config, name=name)
        logger.info(f"Created DataService metadata {metadata}")
        return metadata

    def get(self, resource_meta: DataServiceMetadata) -> Resource:
        """Report the current phase of the data service's StatefulSet.

        :param resource_meta: metadata returned by ``create``
        :return: a ``Resource`` whose phase is ``None`` when the manager reports a status
            this agent does not recognise (including its error strings)
        """
        logger.info("K8s Data Service get is called")
        data = resource_meta.dataservice_config
        # The config may come back either as the dataclass or as a plain dict.
        data_dict = data.__dict__ if isinstance(data, DataServiceConfig) else data
        logger.info(f"The data_dict is {data_dict}")
        self.k8s_manager.set_configs(data_dict)
        # Prefer the configured name; fall back to the name recorded at creation time so
        # auto-generated names and dict-shaped configs are looked up correctly.
        name = data_dict.get("Name") or resource_meta.name
        logger.info(f"Get the stateful set name {name}")

        k8s_status = self.k8s_manager.check_stateful_set_status(name)
        flyte_state = None
        if k8s_status in ("failed", "timeout", "timedout", "canceled", "skipped", "internal_error"):
            flyte_state = TaskExecution.FAILED
        elif k8s_status in ("done", "succeeded", "success"):
            flyte_state = TaskExecution.SUCCEEDED
        elif k8s_status in ("running", "terminating", "pending"):
            flyte_state = TaskExecution.RUNNING
        else:
            logger.error(f"Unrecognized state: {k8s_status}")
        outputs = {
            "data_service_name": name,
        }
        # TODO: Add logs for StatefulSet.
        return Resource(phase=flyte_state, outputs=outputs)

    def delete(self, resource_meta: DataServiceMetadata):
        """Delete the StatefulSet and Service named in ``resource_meta``.

        :param resource_meta: metadata returned by ``create``
        """
        logger.info("DataService delete is called")
        data = resource_meta.dataservice_config
        data_dict = data.__dict__ if isinstance(data, DataServiceConfig) else data
        self.k8s_manager.set_configs(data_dict)

        name = resource_meta.name
        logger.info(f"To delete the DataService (e.g., StatefulSet and Service) with name {name}")
        self.k8s_manager.delete_stateful_set(name)
        self.k8s_manager.delete_service(name)


AgentRegistry.register(DataServiceAgent())
@@ -0,0 +1,20 @@
1
+ from kubernetes import config
2
+
3
+ from flytekit import logger
4
+
5
+
6
class KubeConfig:
    """Loads Kubernetes client configuration for the agent."""

    def __init__(self):
        """No state is kept; configuration is loaded on demand."""

    def load_kube_config(self) -> None:
        """Load the in-cluster Kubernetes config prior to K8s client usage.

        A ``ConfigException`` is logged as a warning and not re-raised.
        """
        logger.info("Attempting to load in-cluster configuration.")
        try:
            config.load_incluster_config()  # This will use the service account credentials
        except config.ConfigException as err:
            logger.warning(f"Failed to load in-cluster configuration. {err}")
        else:
            logger.info("Successfully loaded in-cluster configuration using the agent service account.")
@@ -0,0 +1,219 @@
1
+ import uuid
2
+
3
+ from flytekitplugins.k8sdataservice.k8s.kube_config import KubeConfig
4
+ from kubernetes import client
5
+ from kubernetes.client.rest import ApiException
6
+ from utils.resources import cleanup_resources, convert_flyte_to_k8s_fields
7
+
8
+ from flytekit import logger
9
+
10
# Value of the "app" label attached to every Service created by K8sManager.
APPNAME = "service-name"
# Container resources used when the task config carries no Requests/Limits.
DEFAULT_RESOURCES = client.V1ResourceRequirements(
    requests={"cpu": "2", "memory": "10G"},
    limits={"cpu": "6", "memory": "16G"},
)
14
+
15
+
16
class K8sManager:
    """Wrapper around the Kubernetes API for a data service (a StatefulSet plus a Service).

    ``set_configs`` must be called before any create/get-resources method, because those
    read ``self.data_service_config``, ``self.name``, ``self.labels`` and ``self.namespace``.
    """

    def __init__(self):
        self.config = KubeConfig()
        self.config.load_kube_config()
        self.apps_v1_api = client.AppsV1Api()
        self.core_v1_api = client.CoreV1Api()

    def set_configs(self, data_service_config):
        """Store the config dict and reset name, labels and namespace for a new operation.

        :param data_service_config: dict with keys such as ``Name``, ``Image``, ``Command``,
            ``Port``, ``Replicas``, ``Limits`` and ``Requests``
        """
        self.data_service_config = data_service_config
        self.labels = {}
        self.namespace = "flyte"
        self.name = data_service_config.get("Name", None)
        if self.name is None:
            # No name supplied: generate one with a short random suffix.
            self.name = f"k8s-dataservice-{uuid.uuid4().hex[:8]}"

    def _int_setting(self, key: str, default: int) -> int:
        """Return config ``key`` as an int, using ``default`` when it is missing or ``None``.

        The task serializes every key even when unset, so a plain ``.get(key, default)``
        returns ``None`` instead of the default. Numeric values may also arrive as floats.
        """
        value = self.data_service_config.get(key)
        return default if value is None else int(value)

    def create_data_service(self) -> str:
        """Create the Service and then the StatefulSet; return the StatefulSet name."""
        svc_name = self.create_service()
        logger.info(f"Created service: {svc_name}")
        stateful_set_obj = self.create_stateful_set_object()
        return self.create_stateful_set(stateful_set_obj)

    def create_stateful_set(self, stateful_set_object) -> str:
        """Post the StatefulSet to the API server and return its name.

        :raises ApiException: re-raised after logging when the API call fails
        """
        try:
            api_response = self.apps_v1_api.create_namespaced_stateful_set(
                namespace=self.namespace, body=stateful_set_object
            )
        except ApiException as e:
            logger.error(f"Exception when calling AppsV1Api->create_namespaced_stateful_set: {e}\n")
            raise
        logger.info(f"Created statefulset in K8s API server: {api_response}")
        return api_response.metadata.name

    def create_stateful_set_object(self):
        """Build the ``V1StatefulSet`` body from the current config."""
        container = self._create_container()
        template = self._create_pod_template(container)
        spec = self._create_stateful_set_spec(template)
        return client.V1StatefulSet(
            api_version="apps/v1",
            kind="StatefulSet",
            metadata=client.V1ObjectMeta(
                labels=self.labels,
                name=self.name,
                annotations={},
            ),
            spec=spec,
        )

    def _create_container(self):
        """Build the single container of the pod, exposing the service port."""
        ss_replicas = self._int_setting("Replicas", 1)
        port = self._int_setting("Port", 40000)
        ss_env = [
            client.V1EnvVar(name="GE_BASE_PORT", value=str(port)),
            client.V1EnvVar(name="GE_COUNT", value=str(ss_replicas)),
            client.V1EnvVar(name="SERVER_PORT", value=str(port)),
        ]
        return client.V1Container(
            name=self.name,
            image=self.data_service_config["Image"],
            image_pull_policy="IfNotPresent",
            ports=[client.V1ContainerPort(container_port=port, name="graph-engine")],
            command=self.data_service_config["Command"],
            env=ss_env,
            resources=self.get_resources(),
        )

    def _create_pod_template(self, container):
        """Build the pod template; pods run as non-root user/group 1001."""
        self.labels.update({"app.kubernetes.io/instance": self.name})
        return client.V1PodTemplateSpec(
            metadata=client.V1ObjectMeta(
                labels=self.labels,
                annotations={},
            ),
            spec=client.V1PodSpec(
                containers=[container],
                security_context=client.V1PodSecurityContext(
                    fs_group=1001,
                    run_as_group=1001,
                    run_as_non_root=True,
                    run_as_user=1001,
                ),
            ),
        )

    def _create_stateful_set_spec(self, template):
        """Build the StatefulSet spec, selecting pods by the instance label."""
        return client.V1StatefulSetSpec(
            replicas=self._int_setting("Replicas", 1),
            selector=client.V1LabelSelector(
                match_labels={"app.kubernetes.io/instance": self.name},
            ),
            service_name=self.name,
            template=template,
        )

    def create_service(self) -> str:
        """Create the ClusterIP Service fronting the StatefulSet and return its name.

        :raises ApiException: re-raised after logging when the API call fails
        :raises ValueError: if the API response carries no metadata name
        """
        namespace = self.namespace
        logger.info(f"creating a service at namespace {namespace} with name {self.name}")
        port = self._int_setting("Port", 40000)
        self.labels.update({"app.kubernetes.io/instance": self.name, "app": APPNAME})
        body = client.V1Service(
            api_version="v1",
            kind="Service",
            metadata=client.V1ObjectMeta(
                name=self.name,
                labels=self.labels,
                namespace=namespace,
            ),
            spec=client.V1ServiceSpec(
                selector={"app.kubernetes.io/instance": self.name},
                type="ClusterIP",
                ports=[
                    client.V1ServicePort(
                        port=port,
                        target_port=port,
                        name=self.name,
                    )
                ],
            ),
        )
        logger.info(
            f"Service configuration in namespace {namespace} and name {self.name} is completed, posting request to K8s API server..."
        )
        try:
            api_response = self.core_v1_api.create_namespaced_service(namespace=namespace, body=body)
        except ApiException as e:
            logger.error(f"Exception when calling CoreV1Api->create_namespaced_service: {e}")
            raise
        # This will not happen in K8s API, but in case.
        if api_response is None or not hasattr(api_response, "metadata") or not hasattr(api_response.metadata, "name"):
            raise ValueError("Invalid response from Kubernetes API - missing metadata or name")
        return api_response.metadata.name

    def check_stateful_set_status(self, name) -> str:
        """Return ``pending``, ``success``, ``running`` or ``failed`` for the StatefulSet.

        On an ``ApiException`` the error is logged and an error-message string is returned
        instead of one of the status keywords.
        """
        try:
            stateful_set = self.apps_v1_api.read_namespaced_stateful_set(name=name, namespace=self.namespace)
            status = stateful_set.status
            logger.info(f"StatefulSet status: {status}")
            conditions = status.conditions if status and status.conditions else []
            logger.info(f"StatefulSet conditions: {conditions}")

            # A missing status or replica count is treated like zero replicas (pending),
            # rather than failing on attribute access or a None comparison.
            replicas = status.replicas if status else None
            available = status.available_replicas if status else None
            ready = status.ready_replicas if status else None

            if not replicas:
                logger.info(f"StatefulSet {name} is pending. replicas: {replicas}, available: {available}")
                return "pending"

            if replicas == available or replicas == ready:
                logger.info(f"StatefulSet {name} has succeeded. replicas: {replicas}, available: {available}")
                return "success"

            if available is not None and available >= 0:
                logger.info(f"StatefulSet {name} is running. replicas: {replicas}, available: {available}")
                return "running"

            logger.info(f"StatefulSet {name} status is unknown. Replicas: {replicas}, available: {available}")
            return "failed"
        except ApiException as e:
            logger.error(f"Exception when calling AppsV1Api->read_namespaced_stateful_set: {e}")
            return f"Error checking status of StatefulSet {name}: {e}"

    def delete_stateful_set(self, name: str):
        """Delete the StatefulSet; API errors are logged and not re-raised (best effort)."""
        try:
            self.apps_v1_api.delete_namespaced_stateful_set(
                name=name, namespace=self.namespace, body=client.V1DeleteOptions()
            )
            logger.info(f"Deleted StatefulSet: {name}")
        except ApiException as e:
            logger.error(f"Exception when calling AppsV1Api->delete_namespaced_stateful_set: {e}")

    def delete_service(self, name: str):
        """Delete the Service; API errors are logged and not re-raised (best effort)."""
        try:
            self.core_v1_api.delete_namespaced_service(
                name=name, namespace=self.namespace, body=client.V1DeleteOptions()
            )
            logger.info(f"Deleted Service: {name}")
        except ApiException as e:
            logger.error(f"Exception when calling CoreV1Api->delete_namespaced_service: {e}")

    def get_resources(self) -> client.V1ResourceRequirements:
        """Return container resources: the defaults, overridden by config ``Limits``/``Requests``."""
        # Start from a fresh copy so overrides never leak into the shared module-level
        # defaults, and from there into later tasks.
        res = client.V1ResourceRequirements(
            requests=dict(DEFAULT_RESOURCES.requests),
            limits=dict(DEFAULT_RESOURCES.limits),
        )
        flyteidl_limits = self.data_service_config.get("Limits", None)
        flyteidl_requests = self.data_service_config.get("Requests", None)
        logger.info(f"Flyte Resources: limits: {flyteidl_limits} and requests {flyteidl_requests}")
        if flyteidl_limits is not None:
            res.limits = convert_flyte_to_k8s_fields(flyteidl_limits)
            logger.info(f"Resources limits updated is: {res.limits}")
        if flyteidl_requests is not None:
            res.requests = convert_flyte_to_k8s_fields(flyteidl_requests)
            logger.info(f"Resources requests updated is: {res.requests}")
        cleanup_resources(res)
        logger.info(f"Resources cleaned up is: {res}")
        return res
@@ -0,0 +1,67 @@
1
+ from flytekitplugins.k8sdataservice.k8s.kube_config import KubeConfig
2
+ from kubernetes import client
3
+ from kubernetes.client.rest import ApiException
4
+
5
+ from flytekit import logger
6
+ from flytekit.sensor.base_sensor import BaseSensor
7
+
8
# API group and version constants; neither is referenced by the sensor code in this module.
TRAININGJOB_API_GROUP = "kubeflow.org"
VERSION = "v1"
10
+
11
+
12
class CleanupSensor(BaseSensor):
    """Sensor that deletes a data service's Service and StatefulSet when asked to."""

    def __init__(self, name: str, namespace: str = "flyte", **kwargs):
        """
        Initialize the CleanupSensor class with relevant configurations for monitoring and managing the k8s data service.

        :param name: sensor task name
        :param namespace: Kubernetes namespace holding the data service resources
        """
        # The module imports only ``client`` from kubernetes, so the exception type is
        # imported here; referring to ``kubernetes.config`` would raise a NameError.
        from kubernetes.config import ConfigException

        super().__init__(name=name, task_type="sensor", **kwargs)
        self.k8s_config = KubeConfig()
        try:
            self.k8s_config.load_kube_config()
        except ConfigException as e:
            logger.error(f"Failed to load kubernetes config: {e}")
            raise
        self.apps_v1_api = client.AppsV1Api()
        self.core_v1_api = client.CoreV1Api()
        self.custom_api = client.CustomObjectsApi()
        self.namespace = namespace

    async def poke(self, release_name: str, cleanup_data_service: bool, cluster: str) -> bool:
        """Delete the graph engine resources according to the user's configuration.

        1. This has to be done in the control plane by design. We don't expect any user's running pod to be
           authorized to manage resources.
        2. This cannot be done in the async agent because the delete callback is only invoked on abort or
           failure. A separate task deletes the graph engine without complicating the regular agent flow.
        3. Poking on the training job's status is planned; the initial implementation skips it for
           simplicity, which is also why the sensor API is used (forward compatibility).

        :param release_name: name of the Service/StatefulSet to delete
        :param cleanup_data_service: when False, nothing is deleted
        :param cluster: cluster name, used only in log output here
        :return: always True, so the sensor stops polling after one poke
        """
        self.release_name = release_name
        self.cleanup_data_service = cleanup_data_service
        self.cluster = cluster
        return await self._handle_cleanup()

    async def _handle_cleanup(self) -> bool:
        """Delete the data service if requested; return True in both cases."""
        if not self.cleanup_data_service:
            logger.info(
                f"User decides to not to clean up the graph engine: {self.release_name} in cluster {self.cluster}, namespace {self.namespace}"
            )
            logger.info("DataService sensor will stop polling")
            return True
        logger.info(f"The training job is in terminal stage, deleting graph engine {self.release_name}")
        self.delete_data_service()
        return True

    def delete_data_service(self):
        """
        Delete the data service's associated Kubernetes resources (StatefulSet and Service).

        API errors are logged per resource and not re-raised, so one failed deletion does
        not prevent the other from being attempted.
        """

        def delete_resource(resource_type: str, delete_fn):
            try:
                delete_fn(name=self.release_name, namespace=self.namespace, body=client.V1DeleteOptions())
                logger.info(f"Deleted {resource_type}: {self.release_name}")
            except ApiException as e:
                logger.error(f"Error deleting {resource_type}: {e}")

        logger.info(f"Sensor got the release name: {self.release_name}")
        delete_resource("Service", self.core_v1_api.delete_namespaced_service)
        delete_resource("StatefulSet", self.apps_v1_api.delete_namespaced_stateful_set)
@@ -0,0 +1,71 @@
1
+ from dataclasses import asdict, dataclass
2
+ from typing import Any, Dict, List, Optional, Type
3
+
4
+ from google.protobuf import json_format
5
+ from google.protobuf.struct_pb2 import Struct
6
+
7
+ from flytekit import Resources, kwtypes, logger
8
+ from flytekit.configuration import SerializationSettings
9
+ from flytekit.core.base_task import PythonTask
10
+ from flytekit.core.interface import Interface
11
+ from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
12
+
13
+
14
@dataclass
class DataServiceConfig:
    """DataServiceConfig should be used to configure a DataServiceTask."""

    # Name of the StatefulSet/Service.
    Name: Optional[str] = None
    # Container resource requests and limits.
    Requests: Optional[Resources] = None
    Limits: Optional[Resources] = None
    # Port exposed by the container and the Service.
    Port: Optional[int] = None
    # Container image and command.
    Image: Optional[str] = None
    Command: Optional[List[str]] = None
    # Number of StatefulSet replicas.
    Replicas: Optional[int] = None
    # Name of an already running release to reuse.
    ExistingReleaseName: Optional[str] = None
    # Cluster identifier carried along with the config.
    Cluster: Optional[str] = None
27
+
28
+
29
class DataServiceTask(AsyncAgentExecutorMixin, PythonTask[DataServiceConfig]):
    """Flyte task of type ``dataservicetask``, configured by a ``DataServiceConfig``."""

    _TASK_TYPE = "dataservicetask"

    def __init__(
        self,
        name: str,
        task_config: Optional[DataServiceConfig],
        inputs: Optional[Dict[str, Type]] = None,
        **kwargs,
    ):
        """Build the task with the given inputs and a single string output called ``name``."""
        task_interface = Interface(inputs=inputs, outputs=kwtypes(name=str))
        super().__init__(
            name=name,
            task_config=task_config,
            interface=task_interface,
            task_type=self._TASK_TYPE,
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """Serialize the task config into the dict stored as the task template's ``custom`` field.

        ``Limits`` and ``Requests`` are only included when set; all other keys are always present.
        An empty dict is returned when there is no task config.
        """
        logger.info("get_custom is invoked")
        cfg = self.task_config
        payload = {}
        if cfg is not None:
            payload = {
                "Name": cfg.Name,
                "Image": cfg.Image,
                "Command": cfg.Command,
                "Port": cfg.Port,
                "Replicas": cfg.Replicas,
                "ExistingReleaseName": cfg.ExistingReleaseName,
                "Cluster": cfg.Cluster,
            }
            # Resource entries are optional and appended after the fixed keys.
            for key, resources in (("Limits", cfg.Limits), ("Requests", cfg.Requests)):
                if resources is not None:
                    payload[key] = asdict(resources)
        struct = Struct()
        struct.update(payload)
        return json_format.MessageToDict(struct)
@@ -0,0 +1,30 @@
1
+ Metadata-Version: 2.2
2
+ Name: flytekitplugins-k8sdataservice
3
+ Version: 1.15.0
4
+ Summary: Flytekit K8s Data Service Plugin
5
+ Author: LinkedIn
6
+ Author-email: shuliang@linkedin.com
7
+ License: apache2
8
+ Classifier: Intended Audience :: Science/Research
9
+ Classifier: Intended Audience :: Developers
10
+ Classifier: License :: OSI Approved :: Apache Software License
11
+ Classifier: Programming Language :: Python :: 3.9
12
+ Classifier: Programming Language :: Python :: 3.10
13
+ Classifier: Programming Language :: Python :: 3.11
14
+ Classifier: Programming Language :: Python :: 3.12
15
+ Classifier: Topic :: Scientific/Engineering
16
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
17
+ Classifier: Topic :: Software Development
18
+ Classifier: Topic :: Software Development :: Libraries
19
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
20
+ Requires-Python: >=3.9
21
+ Requires-Dist: flytekit<2.0.0,>=1.11.0
22
+ Requires-Dist: kubernetes<24.0.0,>=23.6.0
23
+ Requires-Dist: flyteidl<2.0.0,>=1.11.0
24
+ Dynamic: author
25
+ Dynamic: author-email
26
+ Dynamic: classifier
27
+ Dynamic: license
28
+ Dynamic: requires-dist
29
+ Dynamic: requires-python
30
+ Dynamic: summary