flyte 2.0.0b25__py3-none-any.whl → 2.0.0b28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic; see the advisory on the package registry page for more details.

Files changed (54)
  1. flyte/__init__.py +2 -0
  2. flyte/_bin/runtime.py +8 -0
  3. flyte/_code_bundle/_utils.py +4 -4
  4. flyte/_code_bundle/bundle.py +1 -1
  5. flyte/_constants.py +1 -0
  6. flyte/_deploy.py +0 -1
  7. flyte/_excepthook.py +1 -1
  8. flyte/_initialize.py +10 -0
  9. flyte/_interface.py +2 -0
  10. flyte/_internal/imagebuild/docker_builder.py +3 -1
  11. flyte/_internal/imagebuild/remote_builder.py +3 -1
  12. flyte/_internal/resolvers/_task_module.py +4 -37
  13. flyte/_internal/runtime/convert.py +3 -2
  14. flyte/_internal/runtime/entrypoints.py +24 -1
  15. flyte/_internal/runtime/rusty.py +3 -3
  16. flyte/_internal/runtime/task_serde.py +19 -4
  17. flyte/_internal/runtime/trigger_serde.py +2 -2
  18. flyte/_map.py +2 -35
  19. flyte/_module.py +68 -0
  20. flyte/_resources.py +38 -0
  21. flyte/_run.py +23 -6
  22. flyte/_task.py +1 -2
  23. flyte/_task_plugins.py +4 -2
  24. flyte/_trigger.py +623 -5
  25. flyte/_utils/__init__.py +2 -1
  26. flyte/_utils/asyn.py +3 -1
  27. flyte/_utils/docker_credentials.py +173 -0
  28. flyte/_utils/module_loader.py +15 -0
  29. flyte/_version.py +3 -3
  30. flyte/cli/_common.py +15 -3
  31. flyte/cli/_create.py +100 -3
  32. flyte/cli/_deploy.py +38 -4
  33. flyte/cli/_plugins.py +208 -0
  34. flyte/cli/_run.py +69 -6
  35. flyte/cli/_serve.py +154 -0
  36. flyte/cli/main.py +6 -0
  37. flyte/connectors/__init__.py +3 -0
  38. flyte/connectors/_connector.py +270 -0
  39. flyte/connectors/_server.py +183 -0
  40. flyte/connectors/utils.py +26 -0
  41. flyte/models.py +13 -4
  42. flyte/remote/_client/auth/_channel.py +9 -5
  43. flyte/remote/_console.py +3 -2
  44. flyte/remote/_secret.py +6 -4
  45. flyte/remote/_trigger.py +2 -2
  46. flyte/types/_type_engine.py +1 -2
  47. {flyte-2.0.0b25.data → flyte-2.0.0b28.data}/scripts/runtime.py +8 -0
  48. {flyte-2.0.0b25.dist-info → flyte-2.0.0b28.dist-info}/METADATA +6 -2
  49. {flyte-2.0.0b25.dist-info → flyte-2.0.0b28.dist-info}/RECORD +54 -46
  50. {flyte-2.0.0b25.data → flyte-2.0.0b28.data}/scripts/debug.py +0 -0
  51. {flyte-2.0.0b25.dist-info → flyte-2.0.0b28.dist-info}/WHEEL +0 -0
  52. {flyte-2.0.0b25.dist-info → flyte-2.0.0b28.dist-info}/entry_points.txt +0 -0
  53. {flyte-2.0.0b25.dist-info → flyte-2.0.0b28.dist-info}/licenses/LICENSE +0 -0
  54. {flyte-2.0.0b25.dist-info → flyte-2.0.0b28.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,270 @@
1
+ import asyncio
2
+ import json
3
+ import typing
4
+ from abc import ABC, abstractmethod
5
+ from dataclasses import asdict, dataclass
6
+ from typing import Any, Dict, List, Optional
7
+
8
+ from flyteidl2.core import tasks_pb2
9
+ from flyteidl2.core.execution_pb2 import TaskExecution, TaskLog
10
+ from flyteidl2.plugins import connector_pb2
11
+ from flyteidl2.plugins.connector_pb2 import (
12
+ Connector,
13
+ GetTaskLogsResponse,
14
+ GetTaskMetricsResponse,
15
+ TaskCategory,
16
+ TaskExecutionMetadata,
17
+ )
18
+ from google.protobuf import json_format
19
+ from google.protobuf.struct_pb2 import Struct
20
+
21
+ from flyte import Secret
22
+ from flyte._context import internal_ctx
23
+ from flyte._initialize import get_init_config
24
+ from flyte._internal.runtime.convert import convert_from_native_to_outputs
25
+ from flyte._internal.runtime.task_serde import get_proto_task
26
+ from flyte._logging import logger
27
+ from flyte._task import TaskTemplate
28
+ from flyte.connectors.utils import is_terminal_phase
29
+ from flyte.models import NativeInterface, SerializationContext
30
+ from flyte.types._type_engine import dataclass_from_dict
31
+
32
+
33
@dataclass(frozen=True)
class ConnectorRegistryKey:
    """Immutable, hashable key that identifies a connector by task type name and version."""

    task_type_name: str
    task_type_version: int
37
+
38
+
39
@dataclass
class ResourceMeta:
    """
    This is the metadata for the job. For example, the id of the job.
    """

    def encode(self) -> bytes:
        """
        Encode the resource meta to bytes.
        """
        payload = asdict(self)
        return json.dumps(payload).encode("utf-8")

    @classmethod
    def decode(cls, data: bytes) -> "ResourceMeta":
        """
        Decode the resource meta from bytes.
        """
        payload = json.loads(data.decode("utf-8"))
        return dataclass_from_dict(cls, payload)
57
+
58
+
59
@dataclass
class Resource:
    """
    This is the output resource of the job.

    Attributes
    ----------
    phase : TaskExecution.Phase
        The phase of the job.
    message : Optional[str]
        The return message from the job.
    log_links : Optional[List[TaskLog]]
        The log links of the job. For example, the link to the BigQuery Console.
    outputs : Optional[Union[LiteralMap, typing.Dict[str, Any]]]
        The outputs of the job. If return python native types, the agent will convert them to flyte literals.
    custom_info : Optional[typing.Dict[str, Any]]
        The custom info of the job. For example, the job config.
    """

    phase: TaskExecution.Phase
    message: Optional[str] = None
    log_links: Optional[List[TaskLog]] = None
    outputs: Optional[Dict[str, Any]] = None
    custom_info: Optional[typing.Dict[str, Any]] = None
83
+
84
+
85
class AsyncConnector(ABC):
    """
    This is the base class for all async connectors, and it defines the interface that all connectors must implement.
    The connector service is responsible for invoking connectors.
    The executor will communicate with the connector service to create tasks, get the status of tasks, and delete tasks.

    All the connectors should be registered in the ConnectorRegistry.
    Connector Service will look up the connector based on the task type and version.
    """

    # Human-readable connector name used in registry metadata and log lines.
    name = "Async Connector"
    # Task type + version this connector handles; subclasses must set task_type_name.
    task_type_name: str
    task_type_version: int = 0
    # The concrete ResourceMeta subclass used to decode resource_meta bytes from requests.
    metadata_type: ResourceMeta

    @abstractmethod
    async def create(
        self,
        task_template: tasks_pb2.TaskTemplate,
        output_prefix: str,
        inputs: Optional[Dict[str, typing.Any]] = None,
        task_execution_metadata: Optional[TaskExecutionMetadata] = None,
        **kwargs,
    ) -> ResourceMeta:
        """
        Return a resource meta that can be used to get the status of the task.
        """
        raise NotImplementedError

    @abstractmethod
    async def get(self, resource_meta: ResourceMeta, **kwargs) -> Resource:
        """
        Return the status of the task, and return the outputs in some cases. For example, bigquery job
        can't write the structured dataset to the output location, so it returns the output literals to the propeller,
        and the propeller will write the structured dataset to the blob store.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete(self, resource_meta: ResourceMeta, **kwargs):
        """
        Delete the task. This call should be idempotent. It should raise an error if fails to delete the task.
        """
        raise NotImplementedError

    async def get_metrics(self, resource_meta: ResourceMeta, **kwargs) -> GetTaskMetricsResponse:
        """
        Return the metrics for the task.
        """
        raise NotImplementedError

    async def get_logs(self, resource_meta: ResourceMeta, **kwargs) -> GetTaskLogsResponse:
        """
        Return the logs for the task.
        """
        raise NotImplementedError
141
+
142
+
143
class ConnectorRegistry(object):
    """
    This is the registry for all connectors.
    The connector service will look up the connector registry based on the task type and version.
    """

    # Connectors keyed by (task type name, version); metadata keyed by connector name.
    _REGISTRY: typing.ClassVar[Dict[ConnectorRegistryKey, AsyncConnector]] = {}
    _METADATA: typing.ClassVar[Dict[str, Connector]] = {}

    @staticmethod
    def register(connector: AsyncConnector, override: bool = False):
        """Register a connector; raises ValueError on duplicate (type, version) unless override is True."""
        registry_key = ConnectorRegistryKey(
            task_type_name=connector.task_type_name, task_type_version=connector.task_type_version
        )
        if not override and registry_key in ConnectorRegistry._REGISTRY:
            raise ValueError(
                f"Duplicate connector for task type: {connector.task_type_name}"
                f" and version: {connector.task_type_version}"
            )
        ConnectorRegistry._REGISTRY[registry_key] = connector

        category = TaskCategory(name=connector.task_type_name, version=connector.task_type_version)

        # A connector registered under an existing name accumulates task categories on its metadata.
        if connector.name in ConnectorRegistry._METADATA:
            ConnectorRegistry.get_connector_metadata(connector.name).supported_task_categories.append(category)
        else:
            ConnectorRegistry._METADATA[connector.name] = Connector(
                name=connector.name,
                supported_task_categories=[category],
            )

    @staticmethod
    def get_connector(task_type_name: str, task_type_version: int = 0) -> AsyncConnector:
        """Look up a connector by task type and version; raises FlyteConnectorNotFound if absent."""
        registry_key = ConnectorRegistryKey(task_type_name=task_type_name, task_type_version=task_type_version)
        try:
            return ConnectorRegistry._REGISTRY[registry_key]
        except KeyError:
            raise FlyteConnectorNotFound(
                f"Cannot find connector for task type: {task_type_name} and version: {task_type_version}"
            ) from None

    @staticmethod
    def list_connectors() -> List[Connector]:
        """Return metadata for every registered connector."""
        return list(ConnectorRegistry._METADATA.values())

    @staticmethod
    def get_connector_metadata(name: str) -> Connector:
        """Return metadata for the named connector; raises FlyteConnectorNotFound if absent."""
        try:
            return ConnectorRegistry._METADATA[name]
        except KeyError:
            raise FlyteConnectorNotFound(f"Cannot find connector for name: {name}.") from None
194
+
195
+
196
class ConnectorSecretsMixin:
    """Mixin that turns a mapping of secret ids to names into flyte Secret objects."""

    def __init__(self, secrets: Dict[str, str]):
        # Key is the id of the secret, value is the secret name.
        self._secrets = secrets

    @property
    def secrets(self) -> List[Secret]:
        """Expose each (id, name) pair as a Secret delivered via environment variable."""
        return [Secret(key=secret_id, as_env_var=env_name) for secret_id, env_name in self._secrets.items()]
204
+
205
+
206
class AsyncConnectorExecutorMixin:
    """
    This mixin class is used to run the connector task locally, and it's only used for local execution.
    Task should inherit from this class if the task can be run in the connector.
    """

    async def execute(self, **kwargs) -> Any:
        """
        Run the task through its registered connector: create the remote job, poll it until it
        reaches a terminal phase, and return the outputs as a tuple (or None when there are none).

        :raises RuntimeError: if the task context is not set, or the job ends in a non-SUCCEEDED phase.
        """
        task = typing.cast(TaskTemplate, self)
        connector = ConnectorRegistry.get_connector(task.task_type, task.task_type_version)

        ctx = internal_ctx()
        tctx = ctx.data.task_context
        cfg = get_init_config()

        if tctx is None:
            raise RuntimeError("Task context is not set.")

        sc = SerializationContext(
            project=tctx.action.project,
            domain=tctx.action.domain,
            org=tctx.action.org,
            code_bundle=tctx.code_bundle,
            version=tctx.version,
            image_cache=tctx.compiled_image_cache,
            root_dir=cfg.root_dir,
        )
        tt = get_proto_task(task, sc)
        resource_meta = await connector.create(task_template=tt, output_prefix=ctx.raw_data.path, inputs=kwargs)

        # Poll until the job reaches a terminal phase. Sleep only *between* polls so a job that
        # is already terminal is returned immediately instead of after an extra 1s wait.
        while True:
            resource = await connector.get(resource_meta=resource_meta)
            if resource.log_links:
                for link in resource.log_links:
                    logger.info(f"{link.name}: {link.uri}")
            if is_terminal_phase(resource.phase):
                break
            await asyncio.sleep(1)

        if resource.phase != TaskExecution.SUCCEEDED:
            raise RuntimeError(f"Failed to run the task {task.name} with error: {resource.message}")

        # TODO: Support abort

        if resource.outputs is None:
            return None
        return tuple(resource.outputs.values())
252
+
253
+
254
async def get_resource_proto(resource: Resource) -> connector_pb2.Resource:
    """Convert a native Resource into its connector_pb2.Resource protobuf representation."""
    outputs = None
    if resource.outputs:
        # Build an outputs-only interface from the runtime types of the values, then convert.
        interface = NativeInterface.from_types(inputs={}, outputs={k: type(v) for k, v in resource.outputs.items()})
        outputs = await convert_from_native_to_outputs(tuple(resource.outputs.values()), interface)

    custom = None
    if resource.custom_info:
        custom = json_format.Parse(json.dumps(resource.custom_info), Struct())

    return connector_pb2.Resource(
        phase=resource.phase,
        message=resource.message,
        log_links=resource.log_links,
        outputs=outputs,
        custom_info=custom,
    )
268
+
269
+
270
class FlyteConnectorNotFound(ValueError):
    """Raised when no connector is registered for a requested task type/version or name."""
@@ -0,0 +1,183 @@
1
+ import inspect
2
+ from http import HTTPStatus
3
+ from typing import Callable, Dict, Tuple, Type, Union
4
+
5
+ import grpc
6
+ from flyteidl2.core.security_pb2 import Connection
7
+ from flyteidl2.plugins.connector_pb2 import (
8
+ CreateTaskRequest,
9
+ CreateTaskResponse,
10
+ DeleteTaskRequest,
11
+ DeleteTaskResponse,
12
+ GetConnectorRequest,
13
+ GetConnectorResponse,
14
+ GetTaskLogsRequest,
15
+ GetTaskLogsResponse,
16
+ GetTaskMetricsRequest,
17
+ GetTaskMetricsResponse,
18
+ GetTaskRequest,
19
+ GetTaskResponse,
20
+ ListConnectorsRequest,
21
+ ListConnectorsResponse,
22
+ )
23
+ from flyteidl2.service.connector_pb2_grpc import (
24
+ AsyncConnectorServiceServicer,
25
+ ConnectorMetadataServiceServicer,
26
+ )
27
+ from prometheus_client import Counter, Summary
28
+
29
+ from flyte._internal.runtime.convert import Inputs, convert_from_inputs_to_native
30
+ from flyte._logging import logger
31
+ from flyte.connectors._connector import ConnectorRegistry, FlyteConnectorNotFound, get_resource_proto
32
+ from flyte.models import NativeInterface, _has_default
33
+ from flyte.types import TypeEngine
34
+
35
# Metric/operation labels shared by the connector service handlers below.
metric_prefix = "flyte_connector_"
create_operation = "create"
get_operation = "get"
delete_operation = "delete"

# Follow the naming convention. https://prometheus.io/docs/practices/naming/
request_success_count = Counter(
    f"{metric_prefix}requests_success_total",
    "Total number of successful requests",
    ["task_type", "operation"],
)
request_failure_count = Counter(
    f"{metric_prefix}requests_failure_total",
    "Total number of failed requests",
    ["task_type", "operation", "error_code"],
)
request_latency = Summary(
    f"{metric_prefix}request_latency_seconds",
    "Time spent processing connector request",
    ["task_type", "operation"],
)
input_literal_size = Summary(f"{metric_prefix}input_literal_bytes", "Size of input literal", ["task_type"])
57
+
58
+
59
def _handle_exception(e: Exception, context: grpc.ServicerContext, task_type: str, operation: str):
    """Translate a handler exception into a gRPC status on the context and bump the failure counter."""
    if isinstance(e, FlyteConnectorNotFound):
        error_message = f"Cannot find connector for task type: {task_type}."
        logger.error(error_message)
        context.set_code(grpc.StatusCode.NOT_FOUND)
        context.set_details(error_message)
        request_failure_count.labels(task_type=task_type, operation=operation, error_code=HTTPStatus.NOT_FOUND).inc()
        return

    # Any other exception is reported as an internal server error.
    error_message = f"failed to {operation} {task_type} task with error:\n {e}."
    logger.error(error_message)
    context.set_code(grpc.StatusCode.INTERNAL)
    context.set_details(error_message)
    request_failure_count.labels(
        task_type=task_type, operation=operation, error_code=HTTPStatus.INTERNAL_SERVER_ERROR
    ).inc()
74
+
75
+
76
def record_connector_metrics(func: Callable):
    """
    Decorator for connector service RPC handlers that records success/failure counters and
    request latency, labeled by task type and operation. Unknown request types are rejected
    with UNIMPLEMENTED; handler exceptions are converted to gRPC statuses via _handle_exception
    (in which case the wrapper returns None).
    """
    # Local import keeps the module's dependency surface unchanged.
    from functools import wraps

    @wraps(func)  # preserve the wrapped handler's name/doc for logging and introspection
    async def wrapper(
        self,
        request: Union[CreateTaskRequest, GetTaskRequest, DeleteTaskRequest],
        context: grpc.ServicerContext,
        *args,
        **kwargs,
    ):
        if isinstance(request, CreateTaskRequest):
            task_type = request.template.type
            operation = create_operation
            if request.inputs:
                # Only create requests carry input literals worth sizing.
                input_literal_size.labels(task_type=task_type).observe(request.inputs.ByteSize())
        elif isinstance(request, GetTaskRequest):
            task_type = request.task_category.name
            operation = get_operation
        elif isinstance(request, DeleteTaskRequest):
            task_type = request.task_category.name
            operation = delete_operation
        else:
            context.set_code(grpc.StatusCode.UNIMPLEMENTED)
            context.set_details("Method not implemented!")
            return None

        try:
            with request_latency.labels(task_type=task_type, operation=operation).time():
                res = await func(self, request, context, *args, **kwargs)
            request_success_count.labels(task_type=task_type, operation=operation).inc()
            return res
        except Exception as e:
            _handle_exception(e, context, task_type, operation)

    return wrapper
109
+
110
+
111
def _get_connection_kwargs(request: Connection) -> Dict[str, str]:
    """Flatten a Connection's secrets and configs into one kwargs dict (configs win on key clashes)."""
    return {**request.secrets, **request.configs}
120
+
121
+
122
class AsyncConnectorService(AsyncConnectorServiceServicer):
    """gRPC servicer that dispatches task lifecycle RPCs to the registered connectors."""

    @record_connector_metrics
    async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse:
        """Create the external job and return the encoded resource meta for later lookups."""
        template = request.template
        connector = ConnectorRegistry.get_connector(template.type, template.task_type_version)
        logger.info(f"{connector.name} start creating the job")
        # Guess python types from the task interface so the proto input literals can be
        # converted to native values before they are handed to the connector.
        guessed_inputs = {
            name: (TypeEngine.guess_python_type(variable.type), inspect.Parameter.empty)
            for name, variable in template.interface.inputs.variables.items()
        }
        native_interface = NativeInterface.from_types(inputs=guessed_inputs, outputs={})
        native_inputs = await convert_from_inputs_to_native(native_interface, Inputs(proto_inputs=request.inputs))
        resource_meta = await connector.create(
            task_template=request.template,
            inputs=native_inputs,
            output_prefix=request.output_prefix,
            task_execution_metadata=request.task_execution_metadata,
            connection=_get_connection_kwargs(request.connection),
        )
        return CreateTaskResponse(resource_meta=resource_meta.encode())

    @record_connector_metrics
    async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse:
        """Return the current status (and possibly outputs) of a previously created job."""
        connector = ConnectorRegistry.get_connector(request.task_category.name, request.task_category.version)
        logger.info(f"{connector.name} start checking the status of the job")
        status = await connector.get(
            resource_meta=connector.metadata_type.decode(request.resource_meta),
            connection=_get_connection_kwargs(request.connection),
        )
        return GetTaskResponse(resource=await get_resource_proto(status))

    @record_connector_metrics
    async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse:
        """Delete the external job; connectors implement this idempotently."""
        connector = ConnectorRegistry.get_connector(request.task_category.name, request.task_category.version)
        logger.info(f"{connector.name} start deleting the job")
        await connector.delete(
            resource_meta=connector.metadata_type.decode(request.resource_meta),
            connection=_get_connection_kwargs(request.connection),
        )
        return DeleteTaskResponse()

    async def GetTaskMetrics(
        self, request: GetTaskMetricsRequest, context: grpc.ServicerContext
    ) -> GetTaskMetricsResponse:
        """Fetch metrics for the job from its connector."""
        connector = ConnectorRegistry.get_connector(request.task_category.name, request.task_category.version)
        logger.info(f"{connector.name} start getting metrics of the job")
        return await connector.get_metrics(resource_meta=connector.metadata_type.decode(request.resource_meta))

    async def GetTaskLogs(self, request: GetTaskLogsRequest, context: grpc.ServicerContext) -> GetTaskLogsResponse:
        """Fetch logs for the job from its connector."""
        connector = ConnectorRegistry.get_connector(request.task_category.name, request.task_category.version)
        logger.info(f"{connector.name} start getting logs of the job")
        return await connector.get_logs(resource_meta=connector.metadata_type.decode(request.resource_meta))
174
+
175
+
176
class ConnectorMetadataService(ConnectorMetadataServiceServicer):
    """gRPC servicer exposing the connector registry's metadata."""

    async def GetConnector(self, request: GetConnectorRequest, context: grpc.ServicerContext) -> GetConnectorResponse:
        """Return metadata for a single connector looked up by name."""
        metadata = ConnectorRegistry.get_connector_metadata(request.name)
        return GetConnectorResponse(connector=metadata)

    async def ListConnectors(
        self, request: ListConnectorsRequest, context: grpc.ServicerContext
    ) -> ListConnectorsResponse:
        """Return metadata for every registered connector."""
        return ListConnectorsResponse(connectors=ConnectorRegistry.list_connectors())
@@ -0,0 +1,26 @@
1
+ from flyteidl2.core.execution_pb2 import TaskExecution
2
+
3
+
4
def is_terminal_phase(phase: TaskExecution.Phase) -> bool:
    """
    Return true if the phase is terminal.
    """
    terminal_phases = (TaskExecution.SUCCEEDED, TaskExecution.ABORTED, TaskExecution.FAILED)
    return phase in terminal_phases
9
+
10
+
11
def convert_to_flyte_phase(state: str) -> TaskExecution.Phase:
    """
    Convert the state from the connector to the phase in flyte.
    """
    state = state.lower()
    # Case-insensitive mapping from connector state names to flyte execution phases.
    phase_by_state = {
        **dict.fromkeys(("failed", "timeout", "timedout", "canceled", "cancelled", "skipped"), TaskExecution.FAILED),
        "internal_error": TaskExecution.RETRYABLE_FAILED,
        **dict.fromkeys(("done", "succeeded", "success", "completed"), TaskExecution.SUCCEEDED),
        **dict.fromkeys(("running", "terminating"), TaskExecution.RUNNING),
        "pending": TaskExecution.INITIALIZING,
    }
    try:
        return phase_by_state[state]
    except KeyError:
        raise ValueError(f"Unrecognized state: {state}") from None
flyte/models.py CHANGED
@@ -355,10 +355,18 @@ class NativeInterface:
355
355
  """
356
356
  Extract the native interface from the given function. This is used to create a native interface for the task.
357
357
  """
358
+ # Get function parameters, defaults, varargs info (POSITIONAL_ONLY, VAR_POSITIONAL, KEYWORD_ONLY, etc.).
358
359
  sig = inspect.signature(func)
359
360
 
360
361
  # Extract parameter details (name, type, default value)
361
362
  param_info = {}
363
+ try:
364
+ # Get fully evaluated, real Python types for type checking.
365
+ hints = typing.get_type_hints(func, include_extras=True)
366
+ except Exception as e:
367
+ logger.warning(f"Could not get type hints for function {func.__name__}: {e}")
368
+ raise
369
+
362
370
  for name, param in sig.parameters.items():
363
371
  if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
364
372
  raise ValueError(f"Function {func.__name__} cannot have variable positional or keyword arguments.")
@@ -366,13 +374,14 @@ class NativeInterface:
366
374
  logger.warning(
367
375
  f"Function {func.__name__} has parameter {name} without type annotation. Data will be pickled."
368
376
  )
369
- if typing.get_origin(param.annotation) is Literal:
370
- param_info[name] = (literal_to_enum(param.annotation), param.default)
377
+ arg_type = hints.get(name, param.annotation)
378
+ if typing.get_origin(arg_type) is Literal:
379
+ param_info[name] = (literal_to_enum(arg_type), param.default)
371
380
  else:
372
- param_info[name] = (param.annotation, param.default)
381
+ param_info[name] = (arg_type, param.default)
373
382
 
374
383
  # Get return type
375
- outputs = extract_return_annotation(sig.return_annotation)
384
+ outputs = extract_return_annotation(hints.get("return", sig.return_annotation))
376
385
  return cls(inputs=param_info, outputs=outputs)
377
386
 
378
387
  def convert_to_kwargs(self, *args, **kwargs) -> Dict[str, Any]:
@@ -7,6 +7,7 @@ import httpx
7
7
  from grpc.experimental.aio import init_grpc_aio
8
8
 
9
9
  from flyte._logging import logger
10
+ from flyte._utils.org_discovery import hostname_from_url
10
11
 
11
12
  from ._authenticators.base import get_async_session
12
13
  from ._authenticators.factory import (
@@ -30,16 +31,19 @@ def bootstrap_ssl_from_server(endpoint: str) -> grpc.ChannelCredentials:
30
31
  :param endpoint: The endpoint URL to retrieve the SSL certificate from, may include port number
31
32
  :return: gRPC channel credentials created from the retrieved certificate
32
33
  """
34
+ hostname = hostname_from_url(endpoint)
35
+
33
36
  # Get port from endpoint or use 443
34
- endpoint_parts = endpoint.rsplit(":", 1)
37
+ endpoint_parts = hostname.rsplit(":", 1)
35
38
  if len(endpoint_parts) == 2 and endpoint_parts[1].isdigit():
36
39
  server_address = (endpoint_parts[0], int(endpoint_parts[1]))
37
40
  else:
38
- logger.warning(f"Unrecognized port in endpoint [{endpoint}], defaulting to 443.")
39
- server_address = (endpoint, 443)
41
+ logger.warning(f"Unrecognized port in endpoint [{hostname}], defaulting to 443.")
42
+ server_address = (hostname, 443)
40
43
 
41
- # Run the blocking SSL certificate retrieval in a thread pool
42
- cert = ssl.get_server_certificate(server_address)
44
+ # Run the blocking SSL certificate retrieval with a timeout
45
+ logger.debug(f"Retrieving SSL certificate from {server_address}")
46
+ cert = ssl.get_server_certificate(server_address, timeout=10)
43
47
  return grpc.ssl_channel_credentials(str.encode(cert))
44
48
 
45
49
 
flyte/remote/_console.py CHANGED
@@ -9,8 +9,9 @@ def _get_http_domain(endpoint: str, insecure: bool) -> str:
9
9
  else:
10
10
  domain = parsed.netloc or parsed.path
11
11
  # TODO: make console url configurable
12
- if domain.split(":")[0] == "localhost":
13
- domain = "localhost:8080"
12
+ domain_split = domain.split(":")
13
+ if domain_split[0] == "localhost":
14
+ domain = domain if len(domain_split) > 1 else f"{domain}:8080"
14
15
  return f"{scheme}://{domain}"
15
16
 
16
17
 
flyte/remote/_secret.py CHANGED
@@ -79,21 +79,23 @@ class Secret(ToJSONMixin):
79
79
  async def listall(cls, limit: int = 100) -> AsyncIterator[Secret]:
80
80
  ensure_client()
81
81
  cfg = get_init_config()
82
- token = None
82
+ per_cluster_tokens = None
83
83
  while True:
84
84
  resp = await get_client().secrets_service.ListSecrets( # type: ignore
85
85
  request=payload_pb2.ListSecretsRequest(
86
86
  organization=cfg.org,
87
87
  project=cfg.project,
88
88
  domain=cfg.domain,
89
- token=token,
89
+ per_cluster_tokens=per_cluster_tokens,
90
90
  limit=limit,
91
91
  ),
92
92
  )
93
- token = resp.token
93
+ per_cluster_tokens = resp.per_cluster_tokens
94
+ round_items = [v for _, v in per_cluster_tokens.items() if v]
95
+ has_next = any(round_items)
94
96
  for r in resp.secrets:
95
97
  yield cls(r)
96
- if not token:
98
+ if not has_next:
97
99
  break
98
100
 
99
101
  @syncify
flyte/remote/_trigger.py CHANGED
@@ -284,9 +284,9 @@ class Trigger(ToJSONMixin):
284
284
  return self.pb2.active
285
285
 
286
286
  def _rich_automation(self, automation: common_pb2.TriggerAutomationSpec):
287
- if automation.type == common_pb2.TriggerAutomationSpec.TYPE_NONE:
287
+ if automation.type == common_pb2.TriggerAutomationSpec.type.TYPE_NONE:
288
288
  yield "none", None
289
- elif automation.type == common_pb2.TriggerAutomationSpec.TYPE_SCHEDULE:
289
+ elif automation.type == common_pb2.TriggerAutomationSpec.type.TYPE_SCHEDULE:
290
290
  if automation.schedule.cron_expression is not None:
291
291
  yield "cron", automation.schedule.cron_expression
292
292
  elif automation.schedule.rate is not None:
@@ -1911,7 +1911,6 @@ def _get_element_type(element_property: typing.Dict[str, str]) -> Type:
1911
1911
  return str
1912
1912
 
1913
1913
 
1914
- # pr: han-ru is this still needed?
1915
1914
  def dataclass_from_dict(cls: type, src: typing.Dict[str, typing.Any]) -> typing.Any:
1916
1915
  """
1917
1916
  Utility function to construct a dataclass object from dict
@@ -1991,7 +1990,7 @@ def _handle_flyte_console_float_input_to_int(lv: Literal) -> int:
1991
1990
 
1992
1991
  def _check_and_convert_void(lv: Literal) -> None:
1993
1992
  if not lv.scalar.HasField("none_type"):
1994
- raise TypeTransformerFailedError(f"Cannot convert literal {lv} to None")
1993
+ raise TypeTransformerFailedError(f"Cannot convert literal '{lv}' to None")
1995
1994
  return None
1996
1995
 
1997
1996
 
@@ -12,6 +12,7 @@ from typing import Any, List
12
12
 
13
13
  import click
14
14
 
15
+ from flyte._utils.helpers import str2bool
15
16
  from flyte.models import PathRewrite
16
17
 
17
18
  # Todo: work with pvditt to make these the names
@@ -27,6 +28,7 @@ PROJECT_NAME = "FLYTE_INTERNAL_EXECUTION_PROJECT"
27
28
  DOMAIN_NAME = "FLYTE_INTERNAL_EXECUTION_DOMAIN"
28
29
  ORG_NAME = "_U_ORG_NAME"
29
30
  ENDPOINT_OVERRIDE = "_U_EP_OVERRIDE"
31
+ INSECURE_SKIP_VERIFY_OVERRIDE = "_U_INSECURE_SKIP_VERIFY"
30
32
  RUN_OUTPUT_BASE_DIR = "_U_RUN_BASE"
31
33
  FLYTE_ENABLE_VSCODE_KEY = "_F_E_VS"
32
34
 
@@ -139,6 +141,12 @@ def main(
139
141
  controller_kwargs["insecure"] = True
140
142
  logger.debug(f"Using controller endpoint: {ep} with kwargs: {controller_kwargs}")
141
143
 
144
+ # Check for insecure_skip_verify override (e.g. for self-signed certs)
145
+ insecure_skip_verify_str = os.getenv(INSECURE_SKIP_VERIFY_OVERRIDE, "")
146
+ if str2bool(insecure_skip_verify_str):
147
+ controller_kwargs["insecure_skip_verify"] = True
148
+ logger.info("SSL certificate verification disabled (insecure_skip_verify=True)")
149
+
142
150
  bundle = None
143
151
  if tgz or pkl:
144
152
  bundle = CodeBundle(tgz=tgz, pkl=pkl, destination=dest, computed_version=version)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flyte
3
- Version: 2.0.0b25
3
+ Version: 2.0.0b28
4
4
  Summary: Add your description here
5
5
  Author-email: Ketan Umare <kumare3@users.noreply.github.com>
6
6
  Requires-Python: >=3.10
@@ -25,9 +25,13 @@ Requires-Dist: async-lru>=2.0.5
25
25
  Requires-Dist: mashumaro
26
26
  Requires-Dist: dataclasses_json
27
27
  Requires-Dist: aiolimiter>=1.2.1
28
- Requires-Dist: flyteidl2==2.0.0a10
28
+ Requires-Dist: flyteidl2==2.0.0a11
29
29
  Provides-Extra: aiosqlite
30
30
  Requires-Dist: aiosqlite>=0.21.0; extra == "aiosqlite"
31
+ Provides-Extra: connector
32
+ Requires-Dist: grpcio-health-checking; extra == "connector"
33
+ Requires-Dist: httpx; extra == "connector"
34
+ Requires-Dist: prometheus-client; extra == "connector"
31
35
  Dynamic: license-file
32
36
 
33
37
  # Flyte 2 SDK 🚀