flyte 2.0.0b22__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (197)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/runtime.py +43 -5
  3. flyte/_cache/cache.py +4 -2
  4. flyte/_cache/local_cache.py +216 -0
  5. flyte/_code_bundle/_ignore.py +1 -1
  6. flyte/_code_bundle/_packaging.py +4 -4
  7. flyte/_code_bundle/_utils.py +14 -8
  8. flyte/_code_bundle/bundle.py +13 -5
  9. flyte/_constants.py +1 -0
  10. flyte/_context.py +4 -1
  11. flyte/_custom_context.py +73 -0
  12. flyte/_debug/constants.py +0 -1
  13. flyte/_debug/vscode.py +6 -1
  14. flyte/_deploy.py +223 -59
  15. flyte/_environment.py +5 -0
  16. flyte/_excepthook.py +1 -1
  17. flyte/_image.py +144 -82
  18. flyte/_initialize.py +95 -12
  19. flyte/_interface.py +2 -0
  20. flyte/_internal/controllers/_local_controller.py +65 -24
  21. flyte/_internal/controllers/_trace.py +1 -1
  22. flyte/_internal/controllers/remote/_action.py +13 -11
  23. flyte/_internal/controllers/remote/_client.py +1 -1
  24. flyte/_internal/controllers/remote/_controller.py +9 -4
  25. flyte/_internal/controllers/remote/_core.py +16 -16
  26. flyte/_internal/controllers/remote/_informer.py +4 -4
  27. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  28. flyte/_internal/imagebuild/docker_builder.py +139 -84
  29. flyte/_internal/imagebuild/image_builder.py +7 -13
  30. flyte/_internal/imagebuild/remote_builder.py +65 -13
  31. flyte/_internal/imagebuild/utils.py +51 -3
  32. flyte/_internal/resolvers/_task_module.py +5 -38
  33. flyte/_internal/resolvers/default.py +2 -2
  34. flyte/_internal/runtime/convert.py +42 -20
  35. flyte/_internal/runtime/entrypoints.py +24 -1
  36. flyte/_internal/runtime/io.py +21 -8
  37. flyte/_internal/runtime/resources_serde.py +20 -6
  38. flyte/_internal/runtime/reuse.py +1 -1
  39. flyte/_internal/runtime/rusty.py +20 -5
  40. flyte/_internal/runtime/task_serde.py +33 -27
  41. flyte/_internal/runtime/taskrunner.py +10 -1
  42. flyte/_internal/runtime/trigger_serde.py +160 -0
  43. flyte/_internal/runtime/types_serde.py +1 -1
  44. flyte/_keyring/file.py +39 -9
  45. flyte/_logging.py +79 -12
  46. flyte/_map.py +31 -12
  47. flyte/_module.py +70 -0
  48. flyte/_pod.py +2 -2
  49. flyte/_resources.py +213 -31
  50. flyte/_run.py +107 -41
  51. flyte/_task.py +66 -10
  52. flyte/_task_environment.py +96 -24
  53. flyte/_task_plugins.py +4 -2
  54. flyte/_trigger.py +1000 -0
  55. flyte/_utils/__init__.py +2 -1
  56. flyte/_utils/asyn.py +3 -1
  57. flyte/_utils/docker_credentials.py +173 -0
  58. flyte/_utils/module_loader.py +17 -2
  59. flyte/_version.py +3 -3
  60. flyte/cli/_abort.py +3 -3
  61. flyte/cli/_build.py +1 -3
  62. flyte/cli/_common.py +78 -7
  63. flyte/cli/_create.py +178 -3
  64. flyte/cli/_delete.py +23 -1
  65. flyte/cli/_deploy.py +49 -11
  66. flyte/cli/_get.py +79 -34
  67. flyte/cli/_params.py +8 -6
  68. flyte/cli/_plugins.py +209 -0
  69. flyte/cli/_run.py +127 -11
  70. flyte/cli/_serve.py +64 -0
  71. flyte/cli/_update.py +37 -0
  72. flyte/cli/_user.py +17 -0
  73. flyte/cli/main.py +30 -4
  74. flyte/config/_config.py +2 -0
  75. flyte/config/_internal.py +1 -0
  76. flyte/config/_reader.py +3 -3
  77. flyte/connectors/__init__.py +11 -0
  78. flyte/connectors/_connector.py +270 -0
  79. flyte/connectors/_server.py +197 -0
  80. flyte/connectors/utils.py +135 -0
  81. flyte/errors.py +10 -1
  82. flyte/extend.py +8 -1
  83. flyte/extras/_container.py +6 -1
  84. flyte/git/_config.py +11 -9
  85. flyte/io/__init__.py +2 -0
  86. flyte/io/_dataframe/__init__.py +2 -0
  87. flyte/io/_dataframe/basic_dfs.py +1 -1
  88. flyte/io/_dataframe/dataframe.py +12 -8
  89. flyte/io/_dir.py +551 -120
  90. flyte/io/_file.py +538 -141
  91. flyte/models.py +57 -12
  92. flyte/remote/__init__.py +6 -1
  93. flyte/remote/_action.py +18 -16
  94. flyte/remote/_client/_protocols.py +39 -4
  95. flyte/remote/_client/auth/_channel.py +10 -6
  96. flyte/remote/_client/controlplane.py +17 -5
  97. flyte/remote/_console.py +3 -2
  98. flyte/remote/_data.py +4 -3
  99. flyte/remote/_logs.py +3 -3
  100. flyte/remote/_run.py +47 -7
  101. flyte/remote/_secret.py +26 -17
  102. flyte/remote/_task.py +21 -9
  103. flyte/remote/_trigger.py +306 -0
  104. flyte/remote/_user.py +33 -0
  105. flyte/storage/__init__.py +6 -1
  106. flyte/storage/_parallel_reader.py +274 -0
  107. flyte/storage/_storage.py +185 -103
  108. flyte/types/__init__.py +16 -0
  109. flyte/types/_interface.py +2 -2
  110. flyte/types/_pickle.py +17 -4
  111. flyte/types/_string_literals.py +8 -9
  112. flyte/types/_type_engine.py +26 -19
  113. flyte/types/_utils.py +1 -1
  114. {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/runtime.py +43 -5
  115. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/METADATA +8 -1
  116. flyte-2.0.0b30.dist-info/RECORD +192 -0
  117. flyte/_protos/__init__.py +0 -0
  118. flyte/_protos/common/authorization_pb2.py +0 -66
  119. flyte/_protos/common/authorization_pb2.pyi +0 -108
  120. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  121. flyte/_protos/common/identifier_pb2.py +0 -99
  122. flyte/_protos/common/identifier_pb2.pyi +0 -120
  123. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  124. flyte/_protos/common/identity_pb2.py +0 -48
  125. flyte/_protos/common/identity_pb2.pyi +0 -72
  126. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  127. flyte/_protos/common/list_pb2.py +0 -36
  128. flyte/_protos/common/list_pb2.pyi +0 -71
  129. flyte/_protos/common/list_pb2_grpc.py +0 -4
  130. flyte/_protos/common/policy_pb2.py +0 -37
  131. flyte/_protos/common/policy_pb2.pyi +0 -27
  132. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  133. flyte/_protos/common/role_pb2.py +0 -37
  134. flyte/_protos/common/role_pb2.pyi +0 -53
  135. flyte/_protos/common/role_pb2_grpc.py +0 -4
  136. flyte/_protos/common/runtime_version_pb2.py +0 -28
  137. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  138. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  139. flyte/_protos/imagebuilder/definition_pb2.py +0 -60
  140. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -153
  141. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  142. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  143. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  144. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  145. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  146. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  147. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  148. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  149. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  150. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  151. flyte/_protos/secret/definition_pb2.py +0 -49
  152. flyte/_protos/secret/definition_pb2.pyi +0 -93
  153. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  154. flyte/_protos/secret/payload_pb2.py +0 -62
  155. flyte/_protos/secret/payload_pb2.pyi +0 -94
  156. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  157. flyte/_protos/secret/secret_pb2.py +0 -38
  158. flyte/_protos/secret/secret_pb2.pyi +0 -6
  159. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  160. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  161. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  162. flyte/_protos/workflow/common_pb2.py +0 -27
  163. flyte/_protos/workflow/common_pb2.pyi +0 -14
  164. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  165. flyte/_protos/workflow/environment_pb2.py +0 -29
  166. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  167. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  168. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  169. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  170. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  171. flyte/_protos/workflow/queue_service_pb2.py +0 -111
  172. flyte/_protos/workflow/queue_service_pb2.pyi +0 -168
  173. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  174. flyte/_protos/workflow/run_definition_pb2.py +0 -123
  175. flyte/_protos/workflow/run_definition_pb2.pyi +0 -352
  176. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  177. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  178. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  179. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  180. flyte/_protos/workflow/run_service_pb2.py +0 -137
  181. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  182. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  183. flyte/_protos/workflow/state_service_pb2.py +0 -67
  184. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  185. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  186. flyte/_protos/workflow/task_definition_pb2.py +0 -82
  187. flyte/_protos/workflow/task_definition_pb2.pyi +0 -88
  188. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  189. flyte/_protos/workflow/task_service_pb2.py +0 -60
  190. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  191. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  192. flyte-2.0.0b22.dist-info/RECORD +0 -250
  193. {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/debug.py +0 -0
  194. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  195. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +0 -0
  196. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  197. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/connectors/_connector.py ADDED
@@ -0,0 +1,270 @@
+ import asyncio
+ import json
+ import typing
+ from abc import ABC, abstractmethod
+ from dataclasses import asdict, dataclass
+ from typing import Any, Dict, List, Optional
+
+ from flyteidl2.core import tasks_pb2
+ from flyteidl2.core.execution_pb2 import TaskExecution, TaskLog
+ from flyteidl2.plugins import connector_pb2
+ from flyteidl2.plugins.connector_pb2 import Connector as ConnectorProto
+ from flyteidl2.plugins.connector_pb2 import (
+     GetTaskLogsResponse,
+     GetTaskMetricsResponse,
+     TaskCategory,
+     TaskExecutionMetadata,
+ )
+ from google.protobuf import json_format
+ from google.protobuf.struct_pb2 import Struct
+
+ from flyte import Secret
+ from flyte._context import internal_ctx
+ from flyte._initialize import get_init_config
+ from flyte._internal.runtime.convert import convert_from_native_to_outputs
+ from flyte._internal.runtime.task_serde import get_proto_task
+ from flyte._logging import logger
+ from flyte._task import TaskTemplate
+ from flyte.connectors.utils import is_terminal_phase
+ from flyte.models import NativeInterface, SerializationContext
+ from flyte.types._type_engine import dataclass_from_dict
+
+
+ @dataclass(frozen=True)
+ class ConnectorRegistryKey:
+     task_type_name: str
+     task_type_version: int
+
+
+ @dataclass
+ class ResourceMeta:
+     """
+     This is the metadata for the job. For example, the id of the job.
+     """
+
+     def encode(self) -> bytes:
+         """
+         Encode the resource meta to bytes.
+         """
+         return json.dumps(asdict(self)).encode("utf-8")
+
+     @classmethod
+     def decode(cls, data: bytes) -> "ResourceMeta":
+         """
+         Decode the resource meta from bytes.
+         """
+         return dataclass_from_dict(cls, json.loads(data.decode("utf-8")))
+
+
+ @dataclass
+ class Resource:
+     """
+     This is the output resource of the job.
+
+     Attributes
+     ----------
+     phase : TaskExecution.Phase
+         The phase of the job.
+     message : Optional[str]
+         The return message from the job.
+     log_links : Optional[List[TaskLog]]
+         The log links of the job. For example, the link to the BigQuery Console.
+     outputs : Optional[Union[LiteralMap, typing.Dict[str, Any]]]
+         The outputs of the job. If return python native types, the agent will convert them to flyte literals.
+     custom_info : Optional[typing.Dict[str, Any]]
+         The custom info of the job. For example, the job config.
+     """
+
+     phase: TaskExecution.Phase
+     message: Optional[str] = None
+     log_links: Optional[List[TaskLog]] = None
+     outputs: Optional[Dict[str, Any]] = None
+     custom_info: Optional[typing.Dict[str, Any]] = None
+
+
+ class AsyncConnector(ABC):
+     """
+     This is the base class for all async connectors, and it defines the interface that all connectors must implement.
+     The connector service is responsible for invoking connectors.
+     The executor will communicate with the connector service to create tasks, get the status of tasks, and delete tasks.
+
+     All the connectors should be registered in the ConnectorRegistry.
+     Connector Service will look up the connector based on the task type and version.
+     """
+
+     name = "Async Connector"
+     task_type_name: str
+     task_type_version: int = 0
+     metadata_type: ResourceMeta
+
+     @abstractmethod
+     async def create(
+         self,
+         task_template: tasks_pb2.TaskTemplate,
+         output_prefix: str,
+         inputs: Optional[Dict[str, typing.Any]] = None,
+         task_execution_metadata: Optional[TaskExecutionMetadata] = None,
+         **kwargs,
+     ) -> ResourceMeta:
+         """
+         Return a resource meta that can be used to get the status of the task.
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     async def get(self, resource_meta: ResourceMeta, **kwargs) -> Resource:
+         """
+         Return the status of the task, and return the outputs in some cases. For example, bigquery job
+         can't write the structured dataset to the output location, so it returns the output literals to the propeller,
+         and the propeller will write the structured dataset to the blob store.
+         """
+         raise NotImplementedError
+
+     @abstractmethod
+     async def delete(self, resource_meta: ResourceMeta, **kwargs):
+         """
+         Delete the task. This call should be idempotent. It should raise an error if fails to delete the task.
+         """
+         raise NotImplementedError
+
+     async def get_metrics(self, resource_meta: ResourceMeta, **kwargs) -> GetTaskMetricsResponse:
+         """
+         Return the metrics for the task.
+         """
+         raise NotImplementedError
+
+     async def get_logs(self, resource_meta: ResourceMeta, **kwargs) -> GetTaskLogsResponse:
+         """
+         Return the metrics for the task.
+         """
+         raise NotImplementedError
+
+
+ class ConnectorRegistry(object):
+     """
+     This is the registry for all connectors.
+     The connector service will look up the connector registry based on the task type and version.
+     """
+
+     _REGISTRY: typing.ClassVar[Dict[ConnectorRegistryKey, AsyncConnector]] = {}
+     _METADATA: typing.ClassVar[Dict[str, ConnectorProto]] = {}
+
+     @staticmethod
+     def register(connector: AsyncConnector, override: bool = False):
+         key = ConnectorRegistryKey(
+             task_type_name=connector.task_type_name, task_type_version=connector.task_type_version
+         )
+         if key in ConnectorRegistry._REGISTRY and override is False:
+             raise ValueError(
+                 f"Duplicate connector for task type: {connector.task_type_name}"
+                 f" and version: {connector.task_type_version}"
+             )
+         ConnectorRegistry._REGISTRY[key] = connector
+
+         task_category = TaskCategory(name=connector.task_type_name, version=connector.task_type_version)
+
+         if connector.name in ConnectorRegistry._METADATA:
+             connector_metadata = ConnectorRegistry._get_connector_metadata(connector.name)
+             connector_metadata.supported_task_categories.append(task_category)
+         else:
+             connector_metadata = ConnectorProto(
+                 name=connector.name,
+                 supported_task_categories=[task_category],
+             )
+             ConnectorRegistry._METADATA[connector.name] = connector_metadata
+
+     @staticmethod
+     def get_connector(task_type_name: str, task_type_version: int = 0) -> AsyncConnector:
+         key = ConnectorRegistryKey(task_type_name=task_type_name, task_type_version=task_type_version)
+         if key not in ConnectorRegistry._REGISTRY:
+             raise FlyteConnectorNotFound(
+                 f"Cannot find connector for task type: {task_type_name} and version: {task_type_version}"
+             )
+         return ConnectorRegistry._REGISTRY[key]
+
+     @staticmethod
+     def _list_connectors() -> List[ConnectorProto]:
+         return list(ConnectorRegistry._METADATA.values())
+
+     @staticmethod
+     def _get_connector_metadata(name: str) -> ConnectorProto:
+         if name not in ConnectorRegistry._METADATA:
+             raise FlyteConnectorNotFound(f"Cannot find connector for name: {name}.")
+         return ConnectorRegistry._METADATA[name]
+
+
+ class ConnectorSecretsMixin:
+     def __init__(self, secrets: Dict[str, str]):
+         # Key is the id of the secret, value is the secret name.
+         self._secrets = secrets
+
+     @property
+     def secrets(self) -> List[Secret]:
+         return [Secret(key=k, as_env_var=v) for k, v in self._secrets.items()]
+
+
+ class AsyncConnectorExecutorMixin:
+     """
+     This mixin class is used to run the connector task locally, and it's only used for local execution.
+     Task should inherit from this class if the task can be run in the connector.
+     """
+
+     async def execute(self, **kwargs) -> Any:
+         task = typing.cast(TaskTemplate, self)
+         connector = ConnectorRegistry.get_connector(task.task_type, task.task_type_version)
+
+         ctx = internal_ctx()
+         tctx = internal_ctx().data.task_context
+         cfg = get_init_config()
+
+         if tctx is None:
+             raise RuntimeError("Task context is not set.")
+
+         sc = SerializationContext(
+             project=tctx.action.project,
+             domain=tctx.action.domain,
+             org=tctx.action.org,
+             code_bundle=tctx.code_bundle,
+             version=tctx.version,
+             image_cache=tctx.compiled_image_cache,
+             root_dir=cfg.root_dir,
+         )
+         tt = get_proto_task(task, sc)
+         resource_meta = await connector.create(task_template=tt, output_prefix=ctx.raw_data.path, inputs=kwargs)
+         resource = Resource(phase=TaskExecution.RUNNING)
+
+         while not is_terminal_phase(resource.phase):
+             resource = await connector.get(resource_meta=resource_meta)
+
+             if resource.log_links:
+                 for link in resource.log_links:
+                     logger.info(f"{link.name}: {link.uri}")
+             await asyncio.sleep(1)
+
+         if resource.phase != TaskExecution.SUCCEEDED:
+             raise RuntimeError(f"Failed to run the task {task.name} with error: {resource.message}")
+
+         # TODO: Support abort
+
+         if resource.outputs is None:
+             return None
+         return tuple(resource.outputs.values())
+
+
+ async def get_resource_proto(resource: Resource) -> connector_pb2.Resource:
+     if resource.outputs:
+         interface = NativeInterface.from_types(inputs={}, outputs={k: type(v) for k, v in resource.outputs.items()})
+         outputs = await convert_from_native_to_outputs(tuple(resource.outputs.values()), interface)
+     else:
+         outputs = None
+
+     return connector_pb2.Resource(
+         phase=resource.phase,
+         message=resource.message,
+         log_links=resource.log_links,
+         outputs=outputs,
+         custom_info=(json_format.Parse(json.dumps(resource.custom_info), Struct()) if resource.custom_info else None),
+     )
+
+
+ class FlyteConnectorNotFound(ValueError): ...
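
Note: the new flyte/connectors/_connector.py module is the authoring surface for connectors: implement AsyncConnector, describe the external job with a ResourceMeta dataclass, and register the implementation with ConnectorRegistry. The following is a minimal sketch of how those pieces fit together; the EchoConnector, its task type name, and its job handle are hypothetical and only illustrate the interface shown above.

from dataclasses import dataclass

from flyteidl2.core import tasks_pb2
from flyteidl2.core.execution_pb2 import TaskExecution

# Importing from the private module added in this release; the package __init__
# may also re-export these names.
from flyte.connectors._connector import AsyncConnector, ConnectorRegistry, Resource, ResourceMeta


@dataclass
class EchoMeta(ResourceMeta):
    # Hypothetical handle for the external job; encode()/decode() round-trip it as JSON.
    job_id: str


class EchoConnector(AsyncConnector):
    name = "Echo Connector"          # shown in the "Connector Metadata" table
    task_type_name = "echo_task"     # hypothetical task type
    task_type_version = 0
    metadata_type = EchoMeta         # used by the service to decode resource_meta bytes

    async def create(self, task_template: tasks_pb2.TaskTemplate, output_prefix: str,
                     inputs=None, task_execution_metadata=None, **kwargs) -> EchoMeta:
        # Submit the external job and return a handle that later calls can poll.
        return EchoMeta(job_id="job-123")

    async def get(self, resource_meta: EchoMeta, **kwargs) -> Resource:
        # Report a terminal phase; native outputs are converted to Flyte literals upstream.
        return Resource(phase=TaskExecution.SUCCEEDED, outputs={"o0": "hello"})

    async def delete(self, resource_meta: EchoMeta, **kwargs):
        # Cancellation must be idempotent.
        ...


# Keyed by (task_type_name, task_type_version); duplicate registrations raise unless override=True.
ConnectorRegistry.register(EchoConnector())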
flyte/connectors/_server.py ADDED
@@ -0,0 +1,197 @@
+ import inspect
+ import os
+ import sys
+ from http import HTTPStatus
+ from typing import Callable, Dict, List, Tuple, Type, Union
+
+ import grpc
+ from flyteidl2.core.security_pb2 import Connection
+ from flyteidl2.plugins.connector_pb2 import (
+     CreateTaskRequest,
+     CreateTaskResponse,
+     DeleteTaskRequest,
+     DeleteTaskResponse,
+     GetConnectorRequest,
+     GetConnectorResponse,
+     GetTaskLogsRequest,
+     GetTaskLogsResponse,
+     GetTaskMetricsRequest,
+     GetTaskMetricsResponse,
+     GetTaskRequest,
+     GetTaskResponse,
+     ListConnectorsRequest,
+     ListConnectorsResponse,
+ )
+ from flyteidl2.service.connector_pb2_grpc import (
+     AsyncConnectorServiceServicer,
+     ConnectorMetadataServiceServicer,
+ )
+ from prometheus_client import Counter, Summary
+
+ from flyte._internal.runtime.convert import Inputs, convert_from_inputs_to_native
+ from flyte._logging import logger
+ from flyte.connectors._connector import ConnectorRegistry, FlyteConnectorNotFound, get_resource_proto
+ from flyte.connectors.utils import _start_grpc_server
+ from flyte.models import NativeInterface, _has_default
+ from flyte.syncify import syncify
+ from flyte.types import TypeEngine
+
+ metric_prefix = "flyte_connector_"
+ create_operation = "create"
+ get_operation = "get"
+ delete_operation = "delete"
+
+ # Follow the naming convention. https://prometheus.io/docs/practices/naming/
+ request_success_count = Counter(
+     f"{metric_prefix}requests_success_total",
+     "Total number of successful requests",
+     ["task_type", "operation"],
+ )
+ request_failure_count = Counter(
+     f"{metric_prefix}requests_failure_total",
+     "Total number of failed requests",
+     ["task_type", "operation", "error_code"],
+ )
+ request_latency = Summary(
+     f"{metric_prefix}request_latency_seconds",
+     "Time spent processing connector request",
+     ["task_type", "operation"],
+ )
+ input_literal_size = Summary(f"{metric_prefix}input_literal_bytes", "Size of input literal", ["task_type"])
+
+
+ def _handle_exception(e: Exception, context: grpc.ServicerContext, task_type: str, operation: str):
+     if isinstance(e, FlyteConnectorNotFound):
+         error_message = f"Cannot find connector for task type: {task_type}."
+         logger.error(error_message)
+         context.set_code(grpc.StatusCode.NOT_FOUND)
+         context.set_details(error_message)
+         request_failure_count.labels(task_type=task_type, operation=operation, error_code=HTTPStatus.NOT_FOUND).inc()
+     else:
+         error_message = f"failed to {operation} {task_type} task with error:\n {e}."
+         logger.error(error_message)
+         context.set_code(grpc.StatusCode.INTERNAL)
+         context.set_details(error_message)
+         request_failure_count.labels(
+             task_type=task_type, operation=operation, error_code=HTTPStatus.INTERNAL_SERVER_ERROR
+         ).inc()
+
+
+ class ConnectorService:
+     @syncify
+     @classmethod
+     async def run(cls, port: int, prometheus_port: int, worker: int, timeout: int | None, modules: List[str] | None):
+         working_dir = os.getcwd()
+         if all(os.path.realpath(path) != working_dir for path in sys.path):
+             sys.path.append(working_dir)
+         await _start_grpc_server(port, prometheus_port, worker, timeout, modules)
+
+
+ def record_connector_metrics(func: Callable):
+     async def wrapper(
+         self,
+         request: Union[CreateTaskRequest, GetTaskRequest, DeleteTaskRequest],
+         context: grpc.ServicerContext,
+         *args,
+         **kwargs,
+     ):
+         if isinstance(request, CreateTaskRequest):
+             task_type = request.template.type
+             operation = create_operation
+             if request.inputs:
+                 input_literal_size.labels(task_type=task_type).observe(request.inputs.ByteSize())
+         elif isinstance(request, GetTaskRequest):
+             task_type = request.task_category.name
+             operation = get_operation
+         elif isinstance(request, DeleteTaskRequest):
+             task_type = request.task_category.name
+             operation = delete_operation
+         else:
+             context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+             context.set_details("Method not implemented!")
+             return None
+
+         try:
+             with request_latency.labels(task_type=task_type, operation=operation).time():
+                 res = await func(self, request, context, *args, **kwargs)
+             request_success_count.labels(task_type=task_type, operation=operation).inc()
+             return res
+         except Exception as e:
+             _handle_exception(e, context, task_type, operation)
+
+     return wrapper
+
+
+ def _get_connection_kwargs(request: Connection) -> Dict[str, str]:
+     kwargs = {}
+
+     for k, v in request.secrets.items():
+         kwargs[k] = v
+     for k, v in request.configs.items():
+         kwargs[k] = v
+
+     return kwargs
+
+
+ class AsyncConnectorService(AsyncConnectorServiceServicer):
+     @record_connector_metrics
+     async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse:
+         template = request.template
+         connector = ConnectorRegistry.get_connector(template.type, template.task_type_version)
+         logger.info(f"{connector.name} start creating the job")
+         python_interface_inputs: Dict[str, Tuple[Type, Type[_has_default] | Type[inspect._empty]]] = {
+             name: (TypeEngine.guess_python_type(lt.type), inspect.Parameter.empty)
+             for name, lt in template.interface.inputs.variables.items()
+         }
+         native_interface = NativeInterface.from_types(inputs=python_interface_inputs, outputs={})
+         native_inputs = await convert_from_inputs_to_native(native_interface, Inputs(proto_inputs=request.inputs))
+         resource_meta = await connector.create(
+             task_template=request.template,
+             inputs=native_inputs,
+             output_prefix=request.output_prefix,
+             task_execution_metadata=request.task_execution_metadata,
+             connection=_get_connection_kwargs(request.connection),
+         )
+         return CreateTaskResponse(resource_meta=resource_meta.encode())
+
+     @record_connector_metrics
+     async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse:
+         connector = ConnectorRegistry.get_connector(request.task_category.name, request.task_category.version)
+         logger.info(f"{connector.name} start checking the status of the job")
+         res = await connector.get(
+             resource_meta=connector.metadata_type.decode(request.resource_meta),
+             connection=_get_connection_kwargs(request.connection),
+         )
+         return GetTaskResponse(resource=await get_resource_proto(res))
+
+     @record_connector_metrics
+     async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
+         connector = ConnectorRegistry.get_connector(request.task_category.name, request.task_category.version)
+         logger.info(f"{connector.name} start deleting the job")
+         await connector.delete(
+             resource_meta=connector.metadata_type.decode(request.resource_meta),
+             connection=_get_connection_kwargs(request.connection),
+         )
+         return DeleteTaskResponse()
+
+     async def GetTaskMetrics(
+         self, request: GetTaskMetricsRequest, context: grpc.ServicerContext
+     ) -> GetTaskMetricsResponse:
+         connector = ConnectorRegistry.get_connector(request.task_category.name, request.task_category.version)
+         logger.info(f"{connector.name} start getting metrics of the job")
+         return await connector.get_metrics(resource_meta=connector.metadata_type.decode(request.resource_meta))
+
+     async def GetTaskLogs(self, request: GetTaskLogsRequest, context: grpc.ServicerContext) -> GetTaskLogsResponse:
+         connector = ConnectorRegistry.get_connector(request.task_category.name, request.task_category.version)
+         logger.info(f"{connector.name} start getting logs of the job")
+         return await connector.get_logs(resource_meta=connector.metadata_type.decode(request.resource_meta))
+
+
+ class ConnectorMetadataService(ConnectorMetadataServiceServicer):
+     async def GetConnector(self, request: GetConnectorRequest, context: grpc.ServicerContext) -> GetConnectorResponse:
+         return GetConnectorResponse(connector=ConnectorRegistry._get_connector_metadata(request.name))
+
+     async def ListConnectors(
+         self, request: ListConnectorsRequest, context: grpc.ServicerContext
+     ) -> ListConnectorsResponse:
+         return ListConnectorsResponse(connectors=ConnectorRegistry._list_connectors())
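
Note: flyte/connectors/_server.py exposes the registry over gRPC (AsyncConnectorService for the task lifecycle, ConnectorMetadataService for discovery) and wraps each call in Prometheus counters and latency summaries. A rough sketch of starting the service in-process, assuming @syncify makes ConnectorService.run callable synchronously; the ports and the my_connectors module name are illustrative, not defaults from the package.

from flyte.connectors._server import ConnectorService

# Loads connectors from the "flyte.connectors" entry-point group plus the listed
# modules, starts the Prometheus HTTP endpoint, then serves gRPC until terminated.
ConnectorService.run(
    port=8000,                   # gRPC port registered via add_insecure_port
    prometheus_port=9090,        # metrics endpoint started by prometheus_client
    worker=10,                   # thread pool size for the server and health servicer
    timeout=None,                # None blocks in wait_for_termination
    modules=["my_connectors"],   # hypothetical module that registers connectors on import
)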
flyte/connectors/utils.py ADDED
@@ -0,0 +1,135 @@
+ import importlib
+ from concurrent import futures
+ from importlib.metadata import entry_points
+ from typing import List
+
+ import click
+ import grpc
+ from flyteidl2.core.execution_pb2 import TaskExecution
+ from flyteidl2.service import connector_pb2
+ from flyteidl2.service.connector_pb2_grpc import (
+     add_AsyncConnectorServiceServicer_to_server,
+     add_ConnectorMetadataServiceServicer_to_server,
+ )
+ from rich.console import Console
+ from rich.table import Table
+
+ from flyte import logger
+
+
+ def is_terminal_phase(phase: TaskExecution.Phase) -> bool:
+     """
+     Return true if the phase is terminal.
+     """
+     return phase in [TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED]
+
+
+ def convert_to_flyte_phase(state: str) -> TaskExecution.Phase:
+     """
+     Convert the state from the connector to the phase in flyte.
+     """
+     state = state.lower()
+     if state in ["failed", "timeout", "timedout", "canceled", "cancelled", "skipped"]:
+         return TaskExecution.FAILED
+     if state in ["internal_error"]:
+         return TaskExecution.RETRYABLE_FAILED
+     elif state in ["done", "succeeded", "success", "completed"]:
+         return TaskExecution.SUCCEEDED
+     elif state in ["running", "terminating"]:
+         return TaskExecution.RUNNING
+     elif state in ["pending"]:
+         return TaskExecution.INITIALIZING
+     raise ValueError(f"Unrecognized state: {state}")
+
+
+ async def _start_grpc_server(
+     port: int, prometheus_port: int, worker: int, timeout: int | None, modules: List[str] | None
+ ):
+     try:
+         from flyte.connectors._server import (
+             AsyncConnectorService,
+             ConnectorMetadataService,
+         )
+     except ImportError as e:
+         raise ImportError(
+             "Flyte connector dependencies are not installed."
+             " Please install it using `pip install flyteplugins-connector`"
+         ) from e
+
+     click.secho("🚀 Starting the connector service...")
+     _load_connectors(modules)
+     _start_http_server(prometheus_port)
+
+     print_metadata()
+
+     server = grpc.aio.server(futures.ThreadPoolExecutor(max_workers=worker))
+
+     add_AsyncConnectorServiceServicer_to_server(AsyncConnectorService(), server)
+     add_ConnectorMetadataServiceServicer_to_server(ConnectorMetadataService(), server)
+     _start_health_check_server(server, worker)
+
+     server.add_insecure_port(f"[::]:{port}")
+     await server.start()
+     await server.wait_for_termination(timeout)
+
+
+ def _start_http_server(prometheus_port: int):
+     try:
+         from prometheus_client import start_http_server
+
+         click.secho("Starting up the server to expose the prometheus metrics...")
+         start_http_server(prometheus_port)
+     except ImportError as e:
+         click.secho(f"Failed to start the prometheus server with error {e}", fg="red")
+
+
+ def _start_health_check_server(server: grpc.Server, worker: int):
+     try:
+         from grpc_health.v1 import health, health_pb2, health_pb2_grpc
+
+         health_servicer = health.HealthServicer(
+             experimental_non_blocking=True,
+             experimental_thread_pool=futures.ThreadPoolExecutor(max_workers=worker),
+         )
+
+         for service in connector_pb2.DESCRIPTOR.services_by_name.values():
+             health_servicer.set(service.full_name, health_pb2.HealthCheckResponse.SERVING)
+         health_servicer.set(health.SERVICE_NAME, health_pb2.HealthCheckResponse.SERVING)
+
+         health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)
+
+     except ImportError as e:
+         click.secho(f"Failed to start the health check servicer with error {e}", fg="red")
+
+
+ def print_metadata():
+     from flyte.connectors import ConnectorRegistry
+
+     connectors = ConnectorRegistry._list_connectors()
+
+     table = Table(title="Connector Metadata")
+     table.add_column("Connector Name", style="cyan", no_wrap=True)
+     table.add_column("Support Task Types", style="cyan")
+
+     for connector in connectors:
+         categories = ""
+         for category in connector.supported_task_categories:
+             categories += f"{category.name} ({category.version}) "
+         table.add_row(connector.name, categories)
+
+     console = Console()
+     console.print(table)
+
+
+ def _load_connectors(modules: List[str] | None):
+     plugins = entry_points(group="flyte.connectors")
+     for ep in plugins:
+         try:
+             logger.info(f"Loading connector: {ep.name}")
+             ep.load()
+         except Exception as e:
+             logger.warning(f"Failed to load type transformer {ep.name} with error: {e}")
+
+     if modules:
+         for m in modules:
+             importlib.import_module(m)
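
Note: flyte/connectors/utils.py also carries the polling helpers a connector's get() typically relies on: convert_to_flyte_phase maps a backend-specific status string onto a TaskExecution phase, and is_terminal_phase decides when to stop polling. A small usage sketch; the status strings stand in for whatever the external system reports.

from flyte.connectors.utils import convert_to_flyte_phase, is_terminal_phase

phase = convert_to_flyte_phase("RUNNING")    # case-insensitive -> TaskExecution.RUNNING
assert not is_terminal_phase(phase)          # keep polling

phase = convert_to_flyte_phase("succeeded")  # -> TaskExecution.SUCCEEDED
assert is_terminal_phase(phase)              # stop polling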
flyte/errors.py CHANGED
@@ -174,7 +174,7 @@ class RuntimeDataValidationError(RuntimeUserError):
 
      def __init__(self, var: str, e: Exception | str, task_name: str = ""):
          super().__init__(
-             "DataValiationError", f"In task {task_name} variable {var}, failed to serialize/deserialize because of {e}"
+             "DataValidationError", f"In task {task_name} variable {var}, failed to serialize/deserialize because of {e}"
          )
 
 
@@ -232,3 +232,12 @@ class SlowDownError(RuntimeUserError):
 
      def __init__(self, message: str):
          super().__init__("SlowDownError", message, "user")
+
+
+ class OnlyAsyncIOSupportedError(RuntimeUserError):
+     """
+     This error is raised when the user tries to use sync IO in an async task.
+     """
+
+     def __init__(self, message: str):
+         super().__init__("OnlyAsyncIOSupportedError", message, "user")
flyte/extend.py CHANGED
@@ -1,12 +1,19 @@
  from ._initialize import is_initialized
+ from ._internal.imagebuild.image_builder import ImageBuildEngine
+ from ._internal.runtime.entrypoints import download_code_bundle
+ from ._internal.runtime.resources_serde import get_proto_resources
  from ._resources import PRIMARY_CONTAINER_DEFAULT_NAME, pod_spec_from_resources
- from ._task import AsyncFunctionTaskTemplate
+ from ._task import AsyncFunctionTaskTemplate, TaskTemplate
  from ._task_plugins import TaskPluginRegistry
 
  __all__ = [
      "PRIMARY_CONTAINER_DEFAULT_NAME",
      "AsyncFunctionTaskTemplate",
+     "ImageBuildEngine",
      "TaskPluginRegistry",
+     "TaskTemplate",
+     "download_code_bundle",
+     "get_proto_resources",
      "is_initialized",
      "pod_spec_from_resources",
  ]
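
Note: flyte.extend now re-exports ImageBuildEngine, TaskTemplate, download_code_bundle, and get_proto_resources alongside the existing plugin hooks. A short, hedged sketch of the kind of use this enables, assuming get_proto_resources accepts the user-facing flyte.Resources object.

from flyte import Resources
from flyte.extend import get_proto_resources

# Convert a user-facing Resources request into its protobuf form, e.g. while a
# custom task plugin assembles its container spec.
proto_resources = get_proto_resources(Resources(cpu=1, memory="1Gi"))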
flyte/extras/_container.py CHANGED
@@ -2,7 +2,7 @@ import os
  import pathlib
  from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
 
- from flyteidl.core import tasks_pb2
+ from flyteidl2.core import tasks_pb2
 
  from flyte import Image, storage
  from flyte._logging import logger
@@ -83,6 +83,11 @@ class ContainerTask(TaskTemplate):
              self._image = Image.from_debian_base()
          else:
              self._image = Image.from_base(image)
+
+         if command and any(not isinstance(c, str) for c in command):
+             raise ValueError("All elements in the command list must be strings.")
+         if arguments and any(not isinstance(a, str) for a in arguments):
+             raise ValueError("All elements in the arguments list must be strings.")
          self._cmd = command
          self._args = arguments
          self._input_data_dir = input_data_dir
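
Note: the ContainerTask change validates command and arguments up front, so a non-string element now fails at construction time rather than during serialization. A hedged sketch of the behaviour; the task name, image, and the exact import path and constructor arguments beyond those visible in this hunk are assumptions.

from flyte.extras import ContainerTask

# Accepted: every element of command is a string.
echo = ContainerTask(name="echo", image="alpine:3.20", command=["echo", "hello"])

# Rejected at construction time with
# ValueError: All elements in the command list must be strings.
bad = ContainerTask(name="bad-echo", image="alpine:3.20", command=["echo", 42])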