flyte 2.0.0b25__py3-none-any.whl → 2.0.0b28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +2 -0
- flyte/_bin/runtime.py +8 -0
- flyte/_code_bundle/_utils.py +4 -4
- flyte/_code_bundle/bundle.py +1 -1
- flyte/_constants.py +1 -0
- flyte/_deploy.py +0 -1
- flyte/_excepthook.py +1 -1
- flyte/_initialize.py +10 -0
- flyte/_interface.py +2 -0
- flyte/_internal/imagebuild/docker_builder.py +3 -1
- flyte/_internal/imagebuild/remote_builder.py +3 -1
- flyte/_internal/resolvers/_task_module.py +4 -37
- flyte/_internal/runtime/convert.py +3 -2
- flyte/_internal/runtime/entrypoints.py +24 -1
- flyte/_internal/runtime/rusty.py +3 -3
- flyte/_internal/runtime/task_serde.py +19 -4
- flyte/_internal/runtime/trigger_serde.py +2 -2
- flyte/_map.py +2 -35
- flyte/_module.py +68 -0
- flyte/_resources.py +38 -0
- flyte/_run.py +23 -6
- flyte/_task.py +1 -2
- flyte/_task_plugins.py +4 -2
- flyte/_trigger.py +623 -5
- flyte/_utils/__init__.py +2 -1
- flyte/_utils/asyn.py +3 -1
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/module_loader.py +15 -0
- flyte/_version.py +3 -3
- flyte/cli/_common.py +15 -3
- flyte/cli/_create.py +100 -3
- flyte/cli/_deploy.py +38 -4
- flyte/cli/_plugins.py +208 -0
- flyte/cli/_run.py +69 -6
- flyte/cli/_serve.py +154 -0
- flyte/cli/main.py +6 -0
- flyte/connectors/__init__.py +3 -0
- flyte/connectors/_connector.py +270 -0
- flyte/connectors/_server.py +183 -0
- flyte/connectors/utils.py +26 -0
- flyte/models.py +13 -4
- flyte/remote/_client/auth/_channel.py +9 -5
- flyte/remote/_console.py +3 -2
- flyte/remote/_secret.py +6 -4
- flyte/remote/_trigger.py +2 -2
- flyte/types/_type_engine.py +1 -2
- {flyte-2.0.0b25.data → flyte-2.0.0b28.data}/scripts/runtime.py +8 -0
- {flyte-2.0.0b25.dist-info → flyte-2.0.0b28.dist-info}/METADATA +6 -2
- {flyte-2.0.0b25.dist-info → flyte-2.0.0b28.dist-info}/RECORD +54 -46
- {flyte-2.0.0b25.data → flyte-2.0.0b28.data}/scripts/debug.py +0 -0
- {flyte-2.0.0b25.dist-info → flyte-2.0.0b28.dist-info}/WHEEL +0 -0
- {flyte-2.0.0b25.dist-info → flyte-2.0.0b28.dist-info}/entry_points.txt +0 -0
- {flyte-2.0.0b25.dist-info → flyte-2.0.0b28.dist-info}/licenses/LICENSE +0 -0
- {flyte-2.0.0b25.dist-info → flyte-2.0.0b28.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,270 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import typing
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from dataclasses import asdict, dataclass
|
|
6
|
+
from typing import Any, Dict, List, Optional
|
|
7
|
+
|
|
8
|
+
from flyteidl2.core import tasks_pb2
|
|
9
|
+
from flyteidl2.core.execution_pb2 import TaskExecution, TaskLog
|
|
10
|
+
from flyteidl2.plugins import connector_pb2
|
|
11
|
+
from flyteidl2.plugins.connector_pb2 import (
|
|
12
|
+
Connector,
|
|
13
|
+
GetTaskLogsResponse,
|
|
14
|
+
GetTaskMetricsResponse,
|
|
15
|
+
TaskCategory,
|
|
16
|
+
TaskExecutionMetadata,
|
|
17
|
+
)
|
|
18
|
+
from google.protobuf import json_format
|
|
19
|
+
from google.protobuf.struct_pb2 import Struct
|
|
20
|
+
|
|
21
|
+
from flyte import Secret
|
|
22
|
+
from flyte._context import internal_ctx
|
|
23
|
+
from flyte._initialize import get_init_config
|
|
24
|
+
from flyte._internal.runtime.convert import convert_from_native_to_outputs
|
|
25
|
+
from flyte._internal.runtime.task_serde import get_proto_task
|
|
26
|
+
from flyte._logging import logger
|
|
27
|
+
from flyte._task import TaskTemplate
|
|
28
|
+
from flyte.connectors.utils import is_terminal_phase
|
|
29
|
+
from flyte.models import NativeInterface, SerializationContext
|
|
30
|
+
from flyte.types._type_engine import dataclass_from_dict
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass(frozen=True)
class ConnectorRegistryKey:
    """
    Hashable lookup key for the ConnectorRegistry.

    A connector is uniquely identified by its task type name together with the
    task type version; frozen so instances can be used as dictionary keys.
    """

    # Task type handled by the connector (matches AsyncConnector.task_type_name).
    task_type_name: str
    # Task type version; allows several connectors for the same task type name.
    task_type_version: int
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class ResourceMeta:
    """
    This is the metadata for the job. For example, the id of the job.

    Subclasses add fields describing the remote job; the base class provides
    JSON round-tripping so the metadata can travel as opaque bytes in requests.
    """

    def encode(self) -> bytes:
        """
        Encode the resource meta to bytes.
        """
        payload = asdict(self)
        return json.dumps(payload).encode("utf-8")

    @classmethod
    def decode(cls, data: bytes) -> "ResourceMeta":
        """
        Decode the resource meta from bytes.
        """
        parsed = json.loads(data.decode("utf-8"))
        return dataclass_from_dict(cls, parsed)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass
class Resource:
    """
    This is the output resource of the job.

    Attributes
    ----------
    phase : TaskExecution.Phase
        The phase of the job.
    message : Optional[str]
        The return message from the job.
    log_links : Optional[List[TaskLog]]
        The log links of the job. For example, the link to the BigQuery Console.
    outputs : Optional[Union[LiteralMap, typing.Dict[str, Any]]]
        The outputs of the job. If return python native types, the agent will convert them to flyte literals.
    custom_info : Optional[typing.Dict[str, Any]]
        The custom info of the job. For example, the job config.
    """

    phase: TaskExecution.Phase
    message: Optional[str] = None
    log_links: Optional[List[TaskLog]] = None
    # NOTE(review): the docstring above mentions LiteralMap as a possible outputs
    # type, but the declared field type is a plain dict — confirm which is intended.
    outputs: Optional[Dict[str, Any]] = None
    custom_info: Optional[typing.Dict[str, Any]] = None
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class AsyncConnector(ABC):
    """
    This is the base class for all async connectors, and it defines the interface that all connectors must implement.
    The connector service is responsible for invoking connectors.
    The executor will communicate with the connector service to create tasks, get the status of tasks, and delete tasks.

    All the connectors should be registered in the ConnectorRegistry.
    Connector Service will look up the connector based on the task type and version.
    """

    # Human-readable connector name; also the key in the registry's metadata map.
    name = "Async Connector"
    # Task type this connector serves; subclasses must set this.
    task_type_name: str
    # Version of the served task type.
    task_type_version: int = 0
    # Concrete ResourceMeta subclass used to decode resource_meta bytes from requests.
    metadata_type: ResourceMeta

    @abstractmethod
    async def create(
        self,
        task_template: tasks_pb2.TaskTemplate,
        output_prefix: str,
        inputs: Optional[Dict[str, typing.Any]] = None,
        task_execution_metadata: Optional[TaskExecutionMetadata] = None,
        **kwargs,
    ) -> ResourceMeta:
        """
        Return a resource meta that can be used to get the status of the task.
        """
        raise NotImplementedError

    @abstractmethod
    async def get(self, resource_meta: ResourceMeta, **kwargs) -> Resource:
        """
        Return the status of the task, and return the outputs in some cases. For example, bigquery job
        can't write the structured dataset to the output location, so it returns the output literals to the propeller,
        and the propeller will write the structured dataset to the blob store.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete(self, resource_meta: ResourceMeta, **kwargs):
        """
        Delete the task. This call should be idempotent. It should raise an error if fails to delete the task.
        """
        raise NotImplementedError

    async def get_metrics(self, resource_meta: ResourceMeta, **kwargs) -> GetTaskMetricsResponse:
        """
        Return the metrics for the task.
        """
        raise NotImplementedError

    async def get_logs(self, resource_meta: ResourceMeta, **kwargs) -> GetTaskLogsResponse:
        """
        Return the logs for the task.
        """
        raise NotImplementedError
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class ConnectorRegistry(object):
    """
    This is the registry for all connectors.
    The connector service will look up the connector registry based on the task type and version.
    """

    # Connectors keyed by (task type name, task type version).
    _REGISTRY: "typing.ClassVar[Dict[ConnectorRegistryKey, AsyncConnector]]" = {}
    # Per-connector-name metadata describing the categories each connector serves.
    _METADATA: "typing.ClassVar[Dict[str, Connector]]" = {}

    @staticmethod
    def register(connector: "AsyncConnector", override: bool = False):
        """Register a connector; raise ValueError on a duplicate unless override is set."""
        key = ConnectorRegistryKey(
            task_type_name=connector.task_type_name, task_type_version=connector.task_type_version
        )
        if not override and key in ConnectorRegistry._REGISTRY:
            raise ValueError(
                f"Duplicate connector for task type: {connector.task_type_name}"
                f" and version: {connector.task_type_version}"
            )
        ConnectorRegistry._REGISTRY[key] = connector

        task_category = TaskCategory(name=connector.task_type_name, version=connector.task_type_version)

        existing = ConnectorRegistry._METADATA.get(connector.name)
        if existing is None:
            ConnectorRegistry._METADATA[connector.name] = Connector(
                name=connector.name,
                supported_task_categories=[task_category],
            )
        else:
            # One connector instance may serve several task categories; extend in place.
            existing.supported_task_categories.append(task_category)

    @staticmethod
    def get_connector(task_type_name: str, task_type_version: int = 0) -> "AsyncConnector":
        """Return the registered connector or raise FlyteConnectorNotFound."""
        key = ConnectorRegistryKey(task_type_name=task_type_name, task_type_version=task_type_version)
        connector = ConnectorRegistry._REGISTRY.get(key)
        if connector is None:
            raise FlyteConnectorNotFound(
                f"Cannot find connector for task type: {task_type_name} and version: {task_type_version}"
            )
        return connector

    @staticmethod
    def list_connectors() -> "List[Connector]":
        """Return metadata for every registered connector."""
        return list(ConnectorRegistry._METADATA.values())

    @staticmethod
    def get_connector_metadata(name: str) -> "Connector":
        """Return the named connector's metadata or raise FlyteConnectorNotFound."""
        metadata = ConnectorRegistry._METADATA.get(name)
        if metadata is None:
            raise FlyteConnectorNotFound(f"Cannot find connector for name: {name}.")
        return metadata
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class ConnectorSecretsMixin:
    """Mixin that exposes a secrets mapping as flyte Secret declarations."""

    def __init__(self, secrets: Dict[str, str]):
        # Key is the id of the secret, value is the secret name.
        self._secrets = secrets

    @property
    def secrets(self) -> "List[Secret]":
        """Build a Secret object for every entry in the mapping."""
        declarations = []
        for secret_id, secret_name in self._secrets.items():
            declarations.append(Secret(key=secret_id, as_env_var=secret_name))
        return declarations
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
class AsyncConnectorExecutorMixin:
    """
    This mixin class is used to run the connector task locally, and it's only used for local execution.
    Task should inherit from this class if the task can be run in the connector.
    """

    async def execute(self, **kwargs) -> Any:
        """Run the task through its registered connector and poll until completion.

        Serializes the task to a proto template, asks the connector to create the
        remote job, then polls ``connector.get`` roughly once per second until a
        terminal phase is reached.

        :param kwargs: Native Python inputs for the task.
        :return: Tuple of the job's output values, or None if there are no outputs.
        :raises RuntimeError: If the task context is not set, or the job finishes
            in any phase other than SUCCEEDED.
        """
        task = typing.cast(TaskTemplate, self)
        connector = ConnectorRegistry.get_connector(task.task_type, task.task_type_version)

        ctx = internal_ctx()
        tctx = internal_ctx().data.task_context
        cfg = get_init_config()

        if tctx is None:
            raise RuntimeError("Task context is not set.")

        sc = SerializationContext(
            project=tctx.action.project,
            domain=tctx.action.domain,
            org=tctx.action.org,
            code_bundle=tctx.code_bundle,
            version=tctx.version,
            image_cache=tctx.compiled_image_cache,
            root_dir=cfg.root_dir,
        )
        tt = get_proto_task(task, sc)
        resource_meta = await connector.create(task_template=tt, output_prefix=ctx.raw_data.path, inputs=kwargs)
        resource = Resource(phase=TaskExecution.RUNNING)

        while True:
            resource = await connector.get(resource_meta=resource_meta)

            if resource.log_links:
                for link in resource.log_links:
                    logger.info(f"{link.name}: {link.uri}")
            # Stop as soon as a terminal phase is observed; the original loop slept
            # one extra second after the job had already finished.
            if is_terminal_phase(resource.phase):
                break
            await asyncio.sleep(1)

        if resource.phase != TaskExecution.SUCCEEDED:
            raise RuntimeError(f"Failed to run the task {task.name} with error: {resource.message}")

        # TODO: Support abort

        if resource.outputs is None:
            return None
        return tuple(resource.outputs.values())
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
async def get_resource_proto(resource: "Resource") -> "connector_pb2.Resource":
    """
    Convert a native Resource into its connector_pb2.Resource protobuf form.

    Native output values are converted to flyte literals via an interface inferred
    from their runtime types; custom_info is serialized into a protobuf Struct.
    """
    converted_outputs = None
    if resource.outputs:
        # Infer an output interface from the runtime type of each value.
        inferred = NativeInterface.from_types(
            inputs={}, outputs={key: type(value) for key, value in resource.outputs.items()}
        )
        converted_outputs = await convert_from_native_to_outputs(tuple(resource.outputs.values()), inferred)

    custom = None
    if resource.custom_info:
        custom = json_format.Parse(json.dumps(resource.custom_info), Struct())

    return connector_pb2.Resource(
        phase=resource.phase,
        message=resource.message,
        log_links=resource.log_links,
        outputs=converted_outputs,
        custom_info=custom,
    )
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class FlyteConnectorNotFound(ValueError):
    """Raised when no connector is registered for a task type/version or name."""
|
|
@@ -0,0 +1,183 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from http import HTTPStatus
|
|
3
|
+
from typing import Callable, Dict, Tuple, Type, Union
|
|
4
|
+
|
|
5
|
+
import grpc
|
|
6
|
+
from flyteidl2.core.security_pb2 import Connection
|
|
7
|
+
from flyteidl2.plugins.connector_pb2 import (
|
|
8
|
+
CreateTaskRequest,
|
|
9
|
+
CreateTaskResponse,
|
|
10
|
+
DeleteTaskRequest,
|
|
11
|
+
DeleteTaskResponse,
|
|
12
|
+
GetConnectorRequest,
|
|
13
|
+
GetConnectorResponse,
|
|
14
|
+
GetTaskLogsRequest,
|
|
15
|
+
GetTaskLogsResponse,
|
|
16
|
+
GetTaskMetricsRequest,
|
|
17
|
+
GetTaskMetricsResponse,
|
|
18
|
+
GetTaskRequest,
|
|
19
|
+
GetTaskResponse,
|
|
20
|
+
ListConnectorsRequest,
|
|
21
|
+
ListConnectorsResponse,
|
|
22
|
+
)
|
|
23
|
+
from flyteidl2.service.connector_pb2_grpc import (
|
|
24
|
+
AsyncConnectorServiceServicer,
|
|
25
|
+
ConnectorMetadataServiceServicer,
|
|
26
|
+
)
|
|
27
|
+
from prometheus_client import Counter, Summary
|
|
28
|
+
|
|
29
|
+
from flyte._internal.runtime.convert import Inputs, convert_from_inputs_to_native
|
|
30
|
+
from flyte._logging import logger
|
|
31
|
+
from flyte.connectors._connector import ConnectorRegistry, FlyteConnectorNotFound, get_resource_proto
|
|
32
|
+
from flyte.models import NativeInterface, _has_default
|
|
33
|
+
from flyte.types import TypeEngine
|
|
34
|
+
|
|
35
|
+
# Prefix shared by all connector Prometheus metric names.
metric_prefix = "flyte_connector_"
# Values for the "operation" label on the request metrics below.
create_operation = "create"
get_operation = "get"
delete_operation = "delete"

# Follow the naming convention. https://prometheus.io/docs/practices/naming/
# Count of RPCs that completed without raising.
request_success_count = Counter(
    f"{metric_prefix}requests_success_total",
    "Total number of successful requests",
    ["task_type", "operation"],
)
# Count of RPCs that raised, labeled with the HTTP-style error code.
request_failure_count = Counter(
    f"{metric_prefix}requests_failure_total",
    "Total number of failed requests",
    ["task_type", "operation", "error_code"],
)
# Wall-clock time spent inside each handler.
request_latency = Summary(
    f"{metric_prefix}request_latency_seconds",
    "Time spent processing connector request",
    ["task_type", "operation"],
)
# Serialized size of the inputs literal map on create requests.
input_literal_size = Summary(f"{metric_prefix}input_literal_bytes", "Size of input literal", ["task_type"])
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _handle_exception(e: Exception, context: "grpc.ServicerContext", task_type: str, operation: str):
    """Log a handler failure, set the gRPC status on the context, and bump the failure counter."""
    if isinstance(e, FlyteConnectorNotFound):
        error_message = f"Cannot find connector for task type: {task_type}."
        status = grpc.StatusCode.NOT_FOUND
        error_code = HTTPStatus.NOT_FOUND
    else:
        error_message = f"failed to {operation} {task_type} task with error:\n {e}."
        status = grpc.StatusCode.INTERNAL
        error_code = HTTPStatus.INTERNAL_SERVER_ERROR

    logger.error(error_message)
    context.set_code(status)
    context.set_details(error_message)
    request_failure_count.labels(task_type=task_type, operation=operation, error_code=error_code).inc()
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def record_connector_metrics(func: Callable):
    """Decorator for async connector RPC handlers that records Prometheus metrics.

    Extracts the task type and operation from the request message, observes the
    input literal size for create requests, times the handler, and increments the
    success/failure counters. Unknown request types receive UNIMPLEMENTED.

    :param func: The async RPC handler (self, request, context, ...) to wrap.
    :return: The metrics-recording wrapper coroutine function.
    """
    from functools import wraps  # stdlib; imported locally to leave the module import block untouched

    @wraps(func)  # preserve the wrapped handler's name/docstring for debugging and introspection
    async def wrapper(
        self,
        request: Union[CreateTaskRequest, GetTaskRequest, DeleteTaskRequest],
        context: grpc.ServicerContext,
        *args,
        **kwargs,
    ):
        if isinstance(request, CreateTaskRequest):
            task_type = request.template.type
            operation = create_operation
            if request.inputs:
                input_literal_size.labels(task_type=task_type).observe(request.inputs.ByteSize())
        elif isinstance(request, GetTaskRequest):
            task_type = request.task_category.name
            operation = get_operation
        elif isinstance(request, DeleteTaskRequest):
            task_type = request.task_category.name
            operation = delete_operation
        else:
            context.set_code(grpc.StatusCode.UNIMPLEMENTED)
            context.set_details("Method not implemented!")
            return None

        try:
            with request_latency.labels(task_type=task_type, operation=operation).time():
                res = await func(self, request, context, *args, **kwargs)
            request_success_count.labels(task_type=task_type, operation=operation).inc()
            return res
        except Exception as e:
            # The error is reported via the gRPC context; the wrapper then
            # implicitly returns None, matching the original behavior.
            _handle_exception(e, context, task_type, operation)

    return wrapper
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _get_connection_kwargs(request: Connection) -> Dict[str, str]:
|
|
112
|
+
kwargs = {}
|
|
113
|
+
|
|
114
|
+
for k, v in request.secrets.items():
|
|
115
|
+
kwargs[k] = v
|
|
116
|
+
for k, v in request.configs.items():
|
|
117
|
+
kwargs[k] = v
|
|
118
|
+
|
|
119
|
+
return kwargs
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class AsyncConnectorService(AsyncConnectorServiceServicer):
    """gRPC servicer that dispatches task lifecycle RPCs to registered connectors."""

    @record_connector_metrics
    async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse:
        """Create the remote job via the connector registered for the template's task type."""
        template = request.template
        connector = ConnectorRegistry.get_connector(template.type, template.task_type_version)
        logger.info(f"{connector.name} start creating the job")
        # Build a native interface from the template's declared input variables so the
        # proto input literals can be converted to native Python values for the connector.
        python_interface_inputs: Dict[str, Tuple[Type, Type[_has_default] | Type[inspect._empty]]] = {
            name: (TypeEngine.guess_python_type(lt.type), inspect.Parameter.empty)
            for name, lt in template.interface.inputs.variables.items()
        }
        native_interface = NativeInterface.from_types(inputs=python_interface_inputs, outputs={})
        native_inputs = await convert_from_inputs_to_native(native_interface, Inputs(proto_inputs=request.inputs))
        resource_meta = await connector.create(
            task_template=request.template,
            inputs=native_inputs,
            output_prefix=request.output_prefix,
            task_execution_metadata=request.task_execution_metadata,
            connection=_get_connection_kwargs(request.connection),
        )
        # The connector's ResourceMeta is serialized to opaque bytes for the caller.
        return CreateTaskResponse(resource_meta=resource_meta.encode())

    @record_connector_metrics
    async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse:
        """Return the remote job's current status (and outputs, if any) as a proto Resource."""
        connector = ConnectorRegistry.get_connector(request.task_category.name, request.task_category.version)
        logger.info(f"{connector.name} start checking the status of the job")
        res = await connector.get(
            resource_meta=connector.metadata_type.decode(request.resource_meta),
            connection=_get_connection_kwargs(request.connection),
        )
        return GetTaskResponse(resource=await get_resource_proto(res))

    @record_connector_metrics
    async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
        """Delete the remote job; the connector's delete() is expected to be idempotent."""
        connector = ConnectorRegistry.get_connector(request.task_category.name, request.task_category.version)
        logger.info(f"{connector.name} start deleting the job")
        await connector.delete(
            resource_meta=connector.metadata_type.decode(request.resource_meta),
            connection=_get_connection_kwargs(request.connection),
        )
        return DeleteTaskResponse()

    # NOTE(review): the metrics/logs RPCs below are not wrapped with
    # @record_connector_metrics — presumably intentional, but confirm.
    async def GetTaskMetrics(
        self, request: GetTaskMetricsRequest, context: grpc.ServicerContext
    ) -> GetTaskMetricsResponse:
        """Proxy the metrics request to the connector for this task category."""
        connector = ConnectorRegistry.get_connector(request.task_category.name, request.task_category.version)
        logger.info(f"{connector.name} start getting metrics of the job")
        return await connector.get_metrics(resource_meta=connector.metadata_type.decode(request.resource_meta))

    async def GetTaskLogs(self, request: GetTaskLogsRequest, context: grpc.ServicerContext) -> GetTaskLogsResponse:
        """Proxy the logs request to the connector for this task category."""
        connector = ConnectorRegistry.get_connector(request.task_category.name, request.task_category.version)
        logger.info(f"{connector.name} start getting logs of the job")
        return await connector.get_logs(resource_meta=connector.metadata_type.decode(request.resource_meta))
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
class ConnectorMetadataService(ConnectorMetadataServiceServicer):
    """gRPC servicer exposing registry metadata: which connectors exist and what they serve."""

    async def GetConnector(self, request: GetConnectorRequest, context: grpc.ServicerContext) -> GetConnectorResponse:
        """Return metadata for the connector with the requested name."""
        return GetConnectorResponse(connector=ConnectorRegistry.get_connector_metadata(request.name))

    async def ListConnectors(
        self, request: ListConnectorsRequest, context: grpc.ServicerContext
    ) -> ListConnectorsResponse:
        """Return metadata for every registered connector."""
        return ListConnectorsResponse(connectors=ConnectorRegistry.list_connectors())
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from flyteidl2.core.execution_pb2 import TaskExecution
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def is_terminal_phase(phase: "TaskExecution.Phase") -> bool:
    """
    Return True if the phase is terminal (SUCCEEDED, ABORTED, or FAILED).
    """
    terminal_phases = (TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED)
    return phase in terminal_phases
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def convert_to_flyte_phase(state: str) -> "TaskExecution.Phase":
    """
    Convert the state from the connector to the phase in flyte.

    Matching is case-insensitive; an unrecognized state raises ValueError.
    """
    state = state.lower()
    # Guard-clause form: each recognized group returns immediately.
    if state in ("failed", "timeout", "timedout", "canceled", "cancelled", "skipped"):
        return TaskExecution.FAILED
    if state == "internal_error":
        return TaskExecution.RETRYABLE_FAILED
    if state in ("done", "succeeded", "success", "completed"):
        return TaskExecution.SUCCEEDED
    if state in ("running", "terminating"):
        return TaskExecution.RUNNING
    if state == "pending":
        return TaskExecution.INITIALIZING
    raise ValueError(f"Unrecognized state: {state}")
|
flyte/models.py
CHANGED
|
@@ -355,10 +355,18 @@ class NativeInterface:
|
|
|
355
355
|
"""
|
|
356
356
|
Extract the native interface from the given function. This is used to create a native interface for the task.
|
|
357
357
|
"""
|
|
358
|
+
# Get function parameters, defaults, varargs info (POSITIONAL_ONLY, VAR_POSITIONAL, KEYWORD_ONLY, etc.).
|
|
358
359
|
sig = inspect.signature(func)
|
|
359
360
|
|
|
360
361
|
# Extract parameter details (name, type, default value)
|
|
361
362
|
param_info = {}
|
|
363
|
+
try:
|
|
364
|
+
# Get fully evaluated, real Python types for type checking.
|
|
365
|
+
hints = typing.get_type_hints(func, include_extras=True)
|
|
366
|
+
except Exception as e:
|
|
367
|
+
logger.warning(f"Could not get type hints for function {func.__name__}: {e}")
|
|
368
|
+
raise
|
|
369
|
+
|
|
362
370
|
for name, param in sig.parameters.items():
|
|
363
371
|
if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
|
|
364
372
|
raise ValueError(f"Function {func.__name__} cannot have variable positional or keyword arguments.")
|
|
@@ -366,13 +374,14 @@ class NativeInterface:
|
|
|
366
374
|
logger.warning(
|
|
367
375
|
f"Function {func.__name__} has parameter {name} without type annotation. Data will be pickled."
|
|
368
376
|
)
|
|
369
|
-
|
|
370
|
-
|
|
377
|
+
arg_type = hints.get(name, param.annotation)
|
|
378
|
+
if typing.get_origin(arg_type) is Literal:
|
|
379
|
+
param_info[name] = (literal_to_enum(arg_type), param.default)
|
|
371
380
|
else:
|
|
372
|
-
param_info[name] = (
|
|
381
|
+
param_info[name] = (arg_type, param.default)
|
|
373
382
|
|
|
374
383
|
# Get return type
|
|
375
|
-
outputs = extract_return_annotation(sig.return_annotation)
|
|
384
|
+
outputs = extract_return_annotation(hints.get("return", sig.return_annotation))
|
|
376
385
|
return cls(inputs=param_info, outputs=outputs)
|
|
377
386
|
|
|
378
387
|
def convert_to_kwargs(self, *args, **kwargs) -> Dict[str, Any]:
|
|
@@ -7,6 +7,7 @@ import httpx
|
|
|
7
7
|
from grpc.experimental.aio import init_grpc_aio
|
|
8
8
|
|
|
9
9
|
from flyte._logging import logger
|
|
10
|
+
from flyte._utils.org_discovery import hostname_from_url
|
|
10
11
|
|
|
11
12
|
from ._authenticators.base import get_async_session
|
|
12
13
|
from ._authenticators.factory import (
|
|
@@ -30,16 +31,19 @@ def bootstrap_ssl_from_server(endpoint: str) -> grpc.ChannelCredentials:
|
|
|
30
31
|
:param endpoint: The endpoint URL to retrieve the SSL certificate from, may include port number
|
|
31
32
|
:return: gRPC channel credentials created from the retrieved certificate
|
|
32
33
|
"""
|
|
34
|
+
hostname = hostname_from_url(endpoint)
|
|
35
|
+
|
|
33
36
|
# Get port from endpoint or use 443
|
|
34
|
-
endpoint_parts =
|
|
37
|
+
endpoint_parts = hostname.rsplit(":", 1)
|
|
35
38
|
if len(endpoint_parts) == 2 and endpoint_parts[1].isdigit():
|
|
36
39
|
server_address = (endpoint_parts[0], int(endpoint_parts[1]))
|
|
37
40
|
else:
|
|
38
|
-
logger.warning(f"Unrecognized port in endpoint [{
|
|
39
|
-
server_address = (
|
|
41
|
+
logger.warning(f"Unrecognized port in endpoint [{hostname}], defaulting to 443.")
|
|
42
|
+
server_address = (hostname, 443)
|
|
40
43
|
|
|
41
|
-
# Run the blocking SSL certificate retrieval
|
|
42
|
-
|
|
44
|
+
# Run the blocking SSL certificate retrieval with a timeout
|
|
45
|
+
logger.debug(f"Retrieving SSL certificate from {server_address}")
|
|
46
|
+
cert = ssl.get_server_certificate(server_address, timeout=10)
|
|
43
47
|
return grpc.ssl_channel_credentials(str.encode(cert))
|
|
44
48
|
|
|
45
49
|
|
flyte/remote/_console.py
CHANGED
|
@@ -9,8 +9,9 @@ def _get_http_domain(endpoint: str, insecure: bool) -> str:
|
|
|
9
9
|
else:
|
|
10
10
|
domain = parsed.netloc or parsed.path
|
|
11
11
|
# TODO: make console url configurable
|
|
12
|
-
|
|
13
|
-
|
|
12
|
+
domain_split = domain.split(":")
|
|
13
|
+
if domain_split[0] == "localhost":
|
|
14
|
+
domain = domain if len(domain_split) > 1 else f"{domain}:8080"
|
|
14
15
|
return f"{scheme}://{domain}"
|
|
15
16
|
|
|
16
17
|
|
flyte/remote/_secret.py
CHANGED
|
@@ -79,21 +79,23 @@ class Secret(ToJSONMixin):
|
|
|
79
79
|
async def listall(cls, limit: int = 100) -> AsyncIterator[Secret]:
|
|
80
80
|
ensure_client()
|
|
81
81
|
cfg = get_init_config()
|
|
82
|
-
|
|
82
|
+
per_cluster_tokens = None
|
|
83
83
|
while True:
|
|
84
84
|
resp = await get_client().secrets_service.ListSecrets( # type: ignore
|
|
85
85
|
request=payload_pb2.ListSecretsRequest(
|
|
86
86
|
organization=cfg.org,
|
|
87
87
|
project=cfg.project,
|
|
88
88
|
domain=cfg.domain,
|
|
89
|
-
|
|
89
|
+
per_cluster_tokens=per_cluster_tokens,
|
|
90
90
|
limit=limit,
|
|
91
91
|
),
|
|
92
92
|
)
|
|
93
|
-
|
|
93
|
+
per_cluster_tokens = resp.per_cluster_tokens
|
|
94
|
+
round_items = [v for _, v in per_cluster_tokens.items() if v]
|
|
95
|
+
has_next = any(round_items)
|
|
94
96
|
for r in resp.secrets:
|
|
95
97
|
yield cls(r)
|
|
96
|
-
if not
|
|
98
|
+
if not has_next:
|
|
97
99
|
break
|
|
98
100
|
|
|
99
101
|
@syncify
|
flyte/remote/_trigger.py
CHANGED
|
@@ -284,9 +284,9 @@ class Trigger(ToJSONMixin):
|
|
|
284
284
|
return self.pb2.active
|
|
285
285
|
|
|
286
286
|
def _rich_automation(self, automation: common_pb2.TriggerAutomationSpec):
|
|
287
|
-
if automation.type == common_pb2.TriggerAutomationSpec.TYPE_NONE:
|
|
287
|
+
if automation.type == common_pb2.TriggerAutomationSpec.type.TYPE_NONE:
|
|
288
288
|
yield "none", None
|
|
289
|
-
elif automation.type == common_pb2.TriggerAutomationSpec.TYPE_SCHEDULE:
|
|
289
|
+
elif automation.type == common_pb2.TriggerAutomationSpec.type.TYPE_SCHEDULE:
|
|
290
290
|
if automation.schedule.cron_expression is not None:
|
|
291
291
|
yield "cron", automation.schedule.cron_expression
|
|
292
292
|
elif automation.schedule.rate is not None:
|
flyte/types/_type_engine.py
CHANGED
|
@@ -1911,7 +1911,6 @@ def _get_element_type(element_property: typing.Dict[str, str]) -> Type:
|
|
|
1911
1911
|
return str
|
|
1912
1912
|
|
|
1913
1913
|
|
|
1914
|
-
# pr: han-ru is this still needed?
|
|
1915
1914
|
def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing.Any:
|
|
1916
1915
|
"""
|
|
1917
1916
|
Utility function to construct a dataclass object from dict
|
|
@@ -1991,7 +1990,7 @@ def _handle_flyte_console_float_input_to_int(lv: Literal) -> int:
|
|
|
1991
1990
|
|
|
1992
1991
|
def _check_and_convert_void(lv: Literal) -> None:
|
|
1993
1992
|
if not lv.scalar.HasField("none_type"):
|
|
1994
|
-
raise TypeTransformerFailedError(f"Cannot convert literal {lv} to None")
|
|
1993
|
+
raise TypeTransformerFailedError(f"Cannot convert literal '{lv}' to None")
|
|
1995
1994
|
return None
|
|
1996
1995
|
|
|
1997
1996
|
|
|
@@ -12,6 +12,7 @@ from typing import Any, List
|
|
|
12
12
|
|
|
13
13
|
import click
|
|
14
14
|
|
|
15
|
+
from flyte._utils.helpers import str2bool
|
|
15
16
|
from flyte.models import PathRewrite
|
|
16
17
|
|
|
17
18
|
# Todo: work with pvditt to make these the names
|
|
@@ -27,6 +28,7 @@ PROJECT_NAME = "FLYTE_INTERNAL_EXECUTION_PROJECT"
|
|
|
27
28
|
DOMAIN_NAME = "FLYTE_INTERNAL_EXECUTION_DOMAIN"
|
|
28
29
|
ORG_NAME = "_U_ORG_NAME"
|
|
29
30
|
ENDPOINT_OVERRIDE = "_U_EP_OVERRIDE"
|
|
31
|
+
INSECURE_SKIP_VERIFY_OVERRIDE = "_U_INSECURE_SKIP_VERIFY"
|
|
30
32
|
RUN_OUTPUT_BASE_DIR = "_U_RUN_BASE"
|
|
31
33
|
FLYTE_ENABLE_VSCODE_KEY = "_F_E_VS"
|
|
32
34
|
|
|
@@ -139,6 +141,12 @@ def main(
|
|
|
139
141
|
controller_kwargs["insecure"] = True
|
|
140
142
|
logger.debug(f"Using controller endpoint: {ep} with kwargs: {controller_kwargs}")
|
|
141
143
|
|
|
144
|
+
# Check for insecure_skip_verify override (e.g. for self-signed certs)
|
|
145
|
+
insecure_skip_verify_str = os.getenv(INSECURE_SKIP_VERIFY_OVERRIDE, "")
|
|
146
|
+
if str2bool(insecure_skip_verify_str):
|
|
147
|
+
controller_kwargs["insecure_skip_verify"] = True
|
|
148
|
+
logger.info("SSL certificate verification disabled (insecure_skip_verify=True)")
|
|
149
|
+
|
|
142
150
|
bundle = None
|
|
143
151
|
if tgz or pkl:
|
|
144
152
|
bundle = CodeBundle(tgz=tgz, pkl=pkl, destination=dest, computed_version=version)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: flyte
|
|
3
|
-
Version: 2.0.
|
|
3
|
+
Version: 2.0.0b28
|
|
4
4
|
Summary: Add your description here
|
|
5
5
|
Author-email: Ketan Umare <kumare3@users.noreply.github.com>
|
|
6
6
|
Requires-Python: >=3.10
|
|
@@ -25,9 +25,13 @@ Requires-Dist: async-lru>=2.0.5
|
|
|
25
25
|
Requires-Dist: mashumaro
|
|
26
26
|
Requires-Dist: dataclasses_json
|
|
27
27
|
Requires-Dist: aiolimiter>=1.2.1
|
|
28
|
-
Requires-Dist: flyteidl2==2.0.
|
|
28
|
+
Requires-Dist: flyteidl2==2.0.0a11
|
|
29
29
|
Provides-Extra: aiosqlite
|
|
30
30
|
Requires-Dist: aiosqlite>=0.21.0; extra == "aiosqlite"
|
|
31
|
+
Provides-Extra: connector
|
|
32
|
+
Requires-Dist: grpcio-health-checking; extra == "connector"
|
|
33
|
+
Requires-Dist: httpx; extra == "connector"
|
|
34
|
+
Requires-Dist: prometheus-client; extra == "connector"
|
|
31
35
|
Dynamic: license-file
|
|
32
36
|
|
|
33
37
|
# Flyte 2 SDK 🚀
|