flyte 0.2.0b12__py3-none-any.whl → 0.2.0b14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +5 -0
- flyte/_excepthook.py +37 -0
- flyte/_internal/controllers/remote/_action.py +5 -0
- flyte/_internal/controllers/remote/_controller.py +43 -3
- flyte/_internal/controllers/remote/_core.py +7 -0
- flyte/_internal/runtime/convert.py +61 -7
- flyte/_internal/runtime/task_serde.py +1 -1
- flyte/_protos/workflow/queue_service_pb2.py +30 -29
- flyte/_protos/workflow/queue_service_pb2.pyi +5 -2
- flyte/_protos/workflow/state_service_pb2.py +36 -28
- flyte/_protos/workflow/state_service_pb2.pyi +19 -15
- flyte/_protos/workflow/state_service_pb2_grpc.py +28 -28
- flyte/_run.py +6 -0
- flyte/_version.py +2 -2
- flyte/cli/_common.py +2 -2
- flyte/cli/_delete.py +1 -1
- flyte/cli/_deploy.py +1 -1
- flyte/cli/_get.py +2 -2
- flyte/cli/_run.py +1 -4
- flyte/cli/main.py +19 -18
- flyte/remote/_client/auth/_channel.py +6 -0
- flyte/remote/_data.py +3 -1
- flyte/syncify/_api.py +103 -35
- flyte/types/_type_engine.py +83 -9
- flyte-0.2.0b14.dist-info/METADATA +249 -0
- {flyte-0.2.0b12.dist-info → flyte-0.2.0b14.dist-info}/RECORD +29 -28
- flyte-0.2.0b12.dist-info/METADATA +0 -181
- {flyte-0.2.0b12.dist-info → flyte-0.2.0b14.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b12.dist-info → flyte-0.2.0b14.dist-info}/entry_points.txt +0 -0
- {flyte-0.2.0b12.dist-info → flyte-0.2.0b14.dist-info}/top_level.txt +0 -0
|
@@ -15,15 +15,15 @@ class StateServiceStub(object):
|
|
|
15
15
|
Args:
|
|
16
16
|
channel: A grpc.Channel.
|
|
17
17
|
"""
|
|
18
|
-
self.
|
|
19
|
-
'/cloudidl.workflow.StateService/
|
|
20
|
-
request_serializer=workflow_dot_state__service__pb2.
|
|
21
|
-
response_deserializer=workflow_dot_state__service__pb2.
|
|
18
|
+
self.Put = channel.stream_stream(
|
|
19
|
+
'/cloudidl.workflow.StateService/Put',
|
|
20
|
+
request_serializer=workflow_dot_state__service__pb2.PutRequest.SerializeToString,
|
|
21
|
+
response_deserializer=workflow_dot_state__service__pb2.PutResponse.FromString,
|
|
22
22
|
)
|
|
23
|
-
self.
|
|
24
|
-
'/cloudidl.workflow.StateService/
|
|
25
|
-
request_serializer=workflow_dot_state__service__pb2.
|
|
26
|
-
response_deserializer=workflow_dot_state__service__pb2.
|
|
23
|
+
self.Get = channel.stream_stream(
|
|
24
|
+
'/cloudidl.workflow.StateService/Get',
|
|
25
|
+
request_serializer=workflow_dot_state__service__pb2.GetRequest.SerializeToString,
|
|
26
|
+
response_deserializer=workflow_dot_state__service__pb2.GetResponse.FromString,
|
|
27
27
|
)
|
|
28
28
|
self.Watch = channel.unary_stream(
|
|
29
29
|
'/cloudidl.workflow.StateService/Watch',
|
|
@@ -36,15 +36,15 @@ class StateServiceServicer(object):
|
|
|
36
36
|
"""provides an interface for managing the state of actions.
|
|
37
37
|
"""
|
|
38
38
|
|
|
39
|
-
def
|
|
40
|
-
"""
|
|
39
|
+
def Put(self, request_iterator, context):
|
|
40
|
+
"""put the state of an action.
|
|
41
41
|
"""
|
|
42
42
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
43
43
|
context.set_details('Method not implemented!')
|
|
44
44
|
raise NotImplementedError('Method not implemented!')
|
|
45
45
|
|
|
46
|
-
def
|
|
47
|
-
"""
|
|
46
|
+
def Get(self, request_iterator, context):
|
|
47
|
+
"""get the state of an action.
|
|
48
48
|
"""
|
|
49
49
|
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
|
|
50
50
|
context.set_details('Method not implemented!')
|
|
@@ -60,15 +60,15 @@ class StateServiceServicer(object):
|
|
|
60
60
|
|
|
61
61
|
def add_StateServiceServicer_to_server(servicer, server):
|
|
62
62
|
rpc_method_handlers = {
|
|
63
|
-
'
|
|
64
|
-
servicer.
|
|
65
|
-
request_deserializer=workflow_dot_state__service__pb2.
|
|
66
|
-
response_serializer=workflow_dot_state__service__pb2.
|
|
63
|
+
'Put': grpc.stream_stream_rpc_method_handler(
|
|
64
|
+
servicer.Put,
|
|
65
|
+
request_deserializer=workflow_dot_state__service__pb2.PutRequest.FromString,
|
|
66
|
+
response_serializer=workflow_dot_state__service__pb2.PutResponse.SerializeToString,
|
|
67
67
|
),
|
|
68
|
-
'
|
|
69
|
-
servicer.
|
|
70
|
-
request_deserializer=workflow_dot_state__service__pb2.
|
|
71
|
-
response_serializer=workflow_dot_state__service__pb2.
|
|
68
|
+
'Get': grpc.stream_stream_rpc_method_handler(
|
|
69
|
+
servicer.Get,
|
|
70
|
+
request_deserializer=workflow_dot_state__service__pb2.GetRequest.FromString,
|
|
71
|
+
response_serializer=workflow_dot_state__service__pb2.GetResponse.SerializeToString,
|
|
72
72
|
),
|
|
73
73
|
'Watch': grpc.unary_stream_rpc_method_handler(
|
|
74
74
|
servicer.Watch,
|
|
@@ -87,7 +87,7 @@ class StateService(object):
|
|
|
87
87
|
"""
|
|
88
88
|
|
|
89
89
|
@staticmethod
|
|
90
|
-
def
|
|
90
|
+
def Put(request_iterator,
|
|
91
91
|
target,
|
|
92
92
|
options=(),
|
|
93
93
|
channel_credentials=None,
|
|
@@ -97,14 +97,14 @@ class StateService(object):
|
|
|
97
97
|
wait_for_ready=None,
|
|
98
98
|
timeout=None,
|
|
99
99
|
metadata=None):
|
|
100
|
-
return grpc.experimental.
|
|
101
|
-
workflow_dot_state__service__pb2.
|
|
102
|
-
workflow_dot_state__service__pb2.
|
|
100
|
+
return grpc.experimental.stream_stream(request_iterator, target, '/cloudidl.workflow.StateService/Put',
|
|
101
|
+
workflow_dot_state__service__pb2.PutRequest.SerializeToString,
|
|
102
|
+
workflow_dot_state__service__pb2.PutResponse.FromString,
|
|
103
103
|
options, channel_credentials,
|
|
104
104
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
105
105
|
|
|
106
106
|
@staticmethod
|
|
107
|
-
def
|
|
107
|
+
def Get(request_iterator,
|
|
108
108
|
target,
|
|
109
109
|
options=(),
|
|
110
110
|
channel_credentials=None,
|
|
@@ -114,9 +114,9 @@ class StateService(object):
|
|
|
114
114
|
wait_for_ready=None,
|
|
115
115
|
timeout=None,
|
|
116
116
|
metadata=None):
|
|
117
|
-
return grpc.experimental.
|
|
118
|
-
workflow_dot_state__service__pb2.
|
|
119
|
-
workflow_dot_state__service__pb2.
|
|
117
|
+
return grpc.experimental.stream_stream(request_iterator, target, '/cloudidl.workflow.StateService/Get',
|
|
118
|
+
workflow_dot_state__service__pb2.GetRequest.SerializeToString,
|
|
119
|
+
workflow_dot_state__service__pb2.GetResponse.FromString,
|
|
120
120
|
options, channel_credentials,
|
|
121
121
|
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
|
|
122
122
|
|
flyte/_run.py
CHANGED
|
@@ -63,6 +63,7 @@ class _Runner:
|
|
|
63
63
|
raw_data_path: str | None = None,
|
|
64
64
|
metadata_path: str | None = None,
|
|
65
65
|
run_base_dir: str | None = None,
|
|
66
|
+
overwrite_cache: bool = False,
|
|
66
67
|
):
|
|
67
68
|
init_config = _get_init_config()
|
|
68
69
|
client = init_config.client if init_config else None
|
|
@@ -81,6 +82,7 @@ class _Runner:
|
|
|
81
82
|
self._raw_data_path = raw_data_path
|
|
82
83
|
self._metadata_path = metadata_path or "/tmp"
|
|
83
84
|
self._run_base_dir = run_base_dir or "/tmp/base"
|
|
85
|
+
self._overwrite_cache = overwrite_cache
|
|
84
86
|
|
|
85
87
|
@requires_initialization
|
|
86
88
|
async def _run_remote(self, obj: TaskTemplate[P, R] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run:
|
|
@@ -182,6 +184,9 @@ class _Runner:
|
|
|
182
184
|
project_id=project_id,
|
|
183
185
|
task_spec=task_spec,
|
|
184
186
|
inputs=inputs.proto_inputs,
|
|
187
|
+
run_spec=run_definition_pb2.RunSpec(
|
|
188
|
+
overwrite_cache=self._overwrite_cache,
|
|
189
|
+
),
|
|
185
190
|
),
|
|
186
191
|
)
|
|
187
192
|
return Run(pb2=resp.run)
|
|
@@ -414,6 +419,7 @@ def with_runcontext(
|
|
|
414
419
|
interactive_mode: bool | None = None,
|
|
415
420
|
raw_data_path: str | None = None,
|
|
416
421
|
run_base_dir: str | None = None,
|
|
422
|
+
overwrite_cache: bool = False,
|
|
417
423
|
) -> _Runner:
|
|
418
424
|
"""
|
|
419
425
|
Launch a new run with the given parameters as the context.
|
flyte/_version.py
CHANGED
|
@@ -17,5 +17,5 @@ __version__: str
|
|
|
17
17
|
__version_tuple__: VERSION_TUPLE
|
|
18
18
|
version_tuple: VERSION_TUPLE
|
|
19
19
|
|
|
20
|
-
__version__ = version = '0.2.
|
|
21
|
-
__version_tuple__ = version_tuple = (0, 2, 0, '
|
|
20
|
+
__version__ = version = '0.2.0b14'
|
|
21
|
+
__version_tuple__ = version_tuple = (0, 2, 0, 'b14')
|
flyte/cli/_common.py
CHANGED
|
@@ -137,8 +137,8 @@ class InvokeBaseMixin:
|
|
|
137
137
|
if e.code() == grpc.StatusCode.INVALID_ARGUMENT:
|
|
138
138
|
raise click.ClickException(f"Invalid argument provided. Please check your input. Error: {e.details()}")
|
|
139
139
|
raise click.ClickException(f"RPC error invoking command: {e!s}") from e
|
|
140
|
-
except flyte.errors.InitializationError:
|
|
141
|
-
raise click.ClickException("
|
|
140
|
+
except flyte.errors.InitializationError as e:
|
|
141
|
+
raise click.ClickException(f"Initialization failed. Pass remote config for CLI. (Reason: {e})")
|
|
142
142
|
except flyte.errors.BaseRuntimeError as e:
|
|
143
143
|
raise click.ClickException(f"{e.kind} failure, {e.code}. {e}") from e
|
|
144
144
|
except click.exceptions.Exit as e:
|
flyte/cli/_delete.py
CHANGED
|
@@ -10,7 +10,7 @@ def delete():
|
|
|
10
10
|
"""
|
|
11
11
|
|
|
12
12
|
|
|
13
|
-
@
|
|
13
|
+
@delete.command(cls=common.CommandBase)
|
|
14
14
|
@click.argument("name", type=str, required=True)
|
|
15
15
|
@click.pass_obj
|
|
16
16
|
def secret(cfg: common.CLIConfig, name: str, project: str | None = None, domain: str | None = None):
|
flyte/cli/_deploy.py
CHANGED
|
@@ -147,6 +147,6 @@ deploy = EnvFiles(
|
|
|
147
147
|
name="deploy",
|
|
148
148
|
help="""
|
|
149
149
|
Deploy one or more environments from a python file.
|
|
150
|
-
|
|
150
|
+
This command will create or update environments in the Flyte system.
|
|
151
151
|
""",
|
|
152
152
|
)
|
flyte/cli/_get.py
CHANGED
|
@@ -20,7 +20,7 @@ def get():
|
|
|
20
20
|
Using a `get` subcommand without any arguments will retrieve a list of available resources to get.
|
|
21
21
|
For example:
|
|
22
22
|
|
|
23
|
-
* `get project` (without
|
|
23
|
+
* `get project` (without specifying a project), will list all projects.
|
|
24
24
|
* `get project my_project` will return the details of the project named `my_project`.
|
|
25
25
|
|
|
26
26
|
In some cases, a partially specified command will act as a filter and return available further parameters.
|
|
@@ -143,7 +143,7 @@ def action(
|
|
|
143
143
|
"--pretty",
|
|
144
144
|
is_flag=True,
|
|
145
145
|
default=False,
|
|
146
|
-
help="Show logs in
|
|
146
|
+
help="Show logs in an auto-scrolling box, where number of lines is limited to `--lines`",
|
|
147
147
|
)
|
|
148
148
|
@click.option(
|
|
149
149
|
"--attempt", "-a", type=int, default=None, help="Attempt number to show logs for, defaults to the latest attempt."
|
flyte/cli/_run.py
CHANGED
|
@@ -66,7 +66,7 @@ class RunArguments:
|
|
|
66
66
|
["--follow", "-f"],
|
|
67
67
|
is_flag=True,
|
|
68
68
|
default=False,
|
|
69
|
-
help="Wait and watch logs for the parent action. If not provided, the
|
|
69
|
+
help="Wait and watch logs for the parent action. If not provided, the CLI will exit after "
|
|
70
70
|
"successfully launching a remote execution with a link to the UI.",
|
|
71
71
|
)
|
|
72
72
|
},
|
|
@@ -99,8 +99,6 @@ class RunTaskCommand(click.Command):
|
|
|
99
99
|
|
|
100
100
|
obj = CLIConfig(flyte.config.auto(), ctx)
|
|
101
101
|
|
|
102
|
-
if not self.run_args.local:
|
|
103
|
-
assert obj.endpoint, "CLI Config should have an endpoint"
|
|
104
102
|
obj.init(self.run_args.project, self.run_args.domain)
|
|
105
103
|
|
|
106
104
|
async def _run():
|
|
@@ -108,7 +106,6 @@ class RunTaskCommand(click.Command):
|
|
|
108
106
|
|
|
109
107
|
r = flyte.with_runcontext(
|
|
110
108
|
copy_style=self.run_args.copy_style,
|
|
111
|
-
version=self.run_args.copy_style,
|
|
112
109
|
mode="local" if self.run_args.local else "remote",
|
|
113
110
|
name=self.run_args.name,
|
|
114
111
|
).run(self.obj, **ctx.params)
|
flyte/cli/main.py
CHANGED
|
@@ -5,31 +5,31 @@ from flyte._logging import initialize_logger, logger
|
|
|
5
5
|
from ._abort import abort
|
|
6
6
|
from ._common import CLIConfig
|
|
7
7
|
from ._create import create
|
|
8
|
+
from ._delete import delete
|
|
8
9
|
from ._deploy import deploy
|
|
9
10
|
from ._gen import gen
|
|
10
11
|
from ._get import get
|
|
11
12
|
from ._run import run
|
|
12
13
|
|
|
13
|
-
click.rich_click.COMMAND_GROUPS = {
|
|
14
|
-
"flyte": [
|
|
15
|
-
{
|
|
16
|
-
"name": "Running workflows",
|
|
17
|
-
"commands": ["run", "abort"],
|
|
18
|
-
},
|
|
19
|
-
{
|
|
20
|
-
"name": "Management",
|
|
21
|
-
"commands": ["create", "deploy", "get"],
|
|
22
|
-
},
|
|
23
|
-
{
|
|
24
|
-
"name": "Documentation generation",
|
|
25
|
-
"commands": ["gen"],
|
|
26
|
-
},
|
|
27
|
-
]
|
|
28
|
-
}
|
|
29
|
-
|
|
30
14
|
help_config = click.RichHelpConfiguration(
|
|
31
15
|
use_markdown=True,
|
|
32
16
|
use_markdown_emoji=True,
|
|
17
|
+
command_groups={
|
|
18
|
+
"flyte": [
|
|
19
|
+
{
|
|
20
|
+
"name": "Run and stop tasks",
|
|
21
|
+
"commands": ["run", "abort"],
|
|
22
|
+
},
|
|
23
|
+
{
|
|
24
|
+
"name": "Management",
|
|
25
|
+
"commands": ["create", "deploy", "get", "delete"],
|
|
26
|
+
},
|
|
27
|
+
{
|
|
28
|
+
"name": "Documentation generation",
|
|
29
|
+
"commands": ["gen"],
|
|
30
|
+
},
|
|
31
|
+
]
|
|
32
|
+
},
|
|
33
33
|
)
|
|
34
34
|
|
|
35
35
|
|
|
@@ -103,7 +103,7 @@ def main(
|
|
|
103
103
|
config_file: str | None,
|
|
104
104
|
):
|
|
105
105
|
"""
|
|
106
|
-
The Flyte CLI is the
|
|
106
|
+
The Flyte CLI is the command line interface for working with the Flyte SDK and backend.
|
|
107
107
|
|
|
108
108
|
It follows a simple verb/noun structure,
|
|
109
109
|
where the top-level commands are verbs that describe the action to be taken,
|
|
@@ -163,3 +163,4 @@ main.add_command(get) # type: ignore
|
|
|
163
163
|
main.add_command(create) # type: ignore
|
|
164
164
|
main.add_command(abort) # type: ignore
|
|
165
165
|
main.add_command(gen) # type: ignore
|
|
166
|
+
main.add_command(delete) # type: ignore
|
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import os
|
|
1
2
|
import ssl
|
|
2
3
|
import typing
|
|
3
4
|
|
|
@@ -15,6 +16,11 @@ from ._authenticators.factory import (
|
|
|
15
16
|
get_async_proxy_authenticator,
|
|
16
17
|
)
|
|
17
18
|
|
|
19
|
+
# Set environment variables for gRPC, this reduces log spew and avoids unnecessary warnings
|
|
20
|
+
if "GRPC_VERBOSITY" not in os.environ:
|
|
21
|
+
os.environ["GRPC_VERBOSITY"] = "ERROR"
|
|
22
|
+
os.environ["GRPC_CPP_MIN_LOG_LEVEL"] = "ERROR"
|
|
23
|
+
|
|
18
24
|
# Initialize gRPC AIO early enough so it can be used in the main thread
|
|
19
25
|
init_grpc_aio()
|
|
20
26
|
|
flyte/remote/_data.py
CHANGED
|
@@ -16,7 +16,7 @@ from flyteidl.service import dataproxy_pb2
|
|
|
16
16
|
from google.protobuf import duration_pb2
|
|
17
17
|
|
|
18
18
|
from flyte._initialize import CommonInit, ensure_client, get_client, get_common_config
|
|
19
|
-
from flyte.errors import RuntimeSystemError
|
|
19
|
+
from flyte.errors import InitializationError, RuntimeSystemError
|
|
20
20
|
|
|
21
21
|
_UPLOAD_EXPIRES_IN = timedelta(seconds=60)
|
|
22
22
|
|
|
@@ -83,6 +83,8 @@ async def _upload_single_file(
|
|
|
83
83
|
raise RuntimeSystemError(
|
|
84
84
|
"PermissionDenied", f"Failed to get signed url for {fp}, please check your permissions."
|
|
85
85
|
)
|
|
86
|
+
elif e.code() == grpc.StatusCode.UNAVAILABLE:
|
|
87
|
+
raise InitializationError("EndpointUnavailable", "user", "Service is unavailable.")
|
|
86
88
|
else:
|
|
87
89
|
raise RuntimeSystemError(e.code().value, f"Failed to get signed url for {fp}.")
|
|
88
90
|
except Exception as e:
|
flyte/syncify/_api.py
CHANGED
|
@@ -5,6 +5,7 @@ import atexit
|
|
|
5
5
|
import concurrent.futures
|
|
6
6
|
import functools
|
|
7
7
|
import inspect
|
|
8
|
+
import logging
|
|
8
9
|
import threading
|
|
9
10
|
from typing import (
|
|
10
11
|
Any,
|
|
@@ -21,6 +22,8 @@ from typing import (
|
|
|
21
22
|
overload,
|
|
22
23
|
)
|
|
23
24
|
|
|
25
|
+
from flyte._logging import logger
|
|
26
|
+
|
|
24
27
|
P = ParamSpec("P")
|
|
25
28
|
R_co = TypeVar("R_co", covariant=True)
|
|
26
29
|
T = TypeVar("T")
|
|
@@ -50,7 +53,7 @@ class SyncGenFunction(Protocol[P, R_co]):
|
|
|
50
53
|
|
|
51
54
|
class _BackgroundLoop:
|
|
52
55
|
"""
|
|
53
|
-
A background event loop that runs in a separate thread and used the
|
|
56
|
+
A background event loop that runs in a separate thread and used the `Syncify` decorator to run asynchronous
|
|
54
57
|
functions or methods synchronously.
|
|
55
58
|
"""
|
|
56
59
|
|
|
@@ -105,6 +108,19 @@ class _BackgroundLoop:
|
|
|
105
108
|
yield future.result()
|
|
106
109
|
except (StopAsyncIteration, StopIteration):
|
|
107
110
|
break
|
|
111
|
+
except Exception as e:
|
|
112
|
+
if logger.getEffectiveLevel() > logging.DEBUG:
|
|
113
|
+
# If the log level is not DEBUG, we will remove the extra stack frames to avoid confusion for the
|
|
114
|
+
# user
|
|
115
|
+
# This is because the stack trace will include the Syncify wrapper and the background loop thread
|
|
116
|
+
tb = e.__traceback__
|
|
117
|
+
while tb and tb.tb_next:
|
|
118
|
+
if tb.tb_frame.f_code.co_name == "":
|
|
119
|
+
break
|
|
120
|
+
tb = tb.tb_next
|
|
121
|
+
raise e.with_traceback(tb)
|
|
122
|
+
# If the log level is DEBUG, we will keep the extra stack frames to help with debugging
|
|
123
|
+
raise e
|
|
108
124
|
|
|
109
125
|
def call_in_loop_sync(self, coro: Coroutine[Any, Any, R_co]) -> R_co | Iterator[R_co]:
|
|
110
126
|
"""
|
|
@@ -144,6 +160,19 @@ class _BackgroundLoop:
|
|
|
144
160
|
yield v
|
|
145
161
|
except StopAsyncIteration:
|
|
146
162
|
break
|
|
163
|
+
except Exception as e:
|
|
164
|
+
if logger.getEffectiveLevel() > logging.DEBUG:
|
|
165
|
+
# If the log level is not DEBUG, we will remove the extra stack frames to avoid confusion for the
|
|
166
|
+
# user.
|
|
167
|
+
# This is because the stack trace will include the Syncify wrapper and the background loop thread
|
|
168
|
+
tb = e.__traceback__
|
|
169
|
+
while tb and tb.tb_next:
|
|
170
|
+
if tb.tb_frame.f_code.co_name == "":
|
|
171
|
+
break
|
|
172
|
+
tb = tb.tb_next
|
|
173
|
+
raise e.with_traceback(tb)
|
|
174
|
+
# If the log level is DEBUG, we will keep the extra stack frames to help with debugging
|
|
175
|
+
raise e
|
|
147
176
|
|
|
148
177
|
async def aio(self, coro: Coroutine[Any, Any, R_co]) -> R_co:
|
|
149
178
|
"""
|
|
@@ -152,12 +181,25 @@ class _BackgroundLoop:
|
|
|
152
181
|
if self.is_in_loop():
|
|
153
182
|
# If we are already in the background loop, just run the coroutine
|
|
154
183
|
return await coro
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
184
|
+
try:
|
|
185
|
+
# Otherwise, run it in the background loop and wait for the result
|
|
186
|
+
future: concurrent.futures.Future[R_co] = asyncio.run_coroutine_threadsafe(coro, self.loop)
|
|
187
|
+
# Wrap the future in an asyncio Future to await it in an async context
|
|
188
|
+
aio_future: asyncio.Future[R_co] = asyncio.wrap_future(future)
|
|
189
|
+
# await for the future to complete and return its result
|
|
190
|
+
return await aio_future
|
|
191
|
+
except Exception as e:
|
|
192
|
+
if logger.getEffectiveLevel() > logging.DEBUG:
|
|
193
|
+
# If the log level is not DEBUG, we will remove the extra stack frames to avoid confusion for the user
|
|
194
|
+
# This is because the stack trace will include the Syncify wrapper and the background loop thread
|
|
195
|
+
tb = e.__traceback__
|
|
196
|
+
while tb and tb.tb_next:
|
|
197
|
+
if tb.tb_frame.f_code.co_name == "":
|
|
198
|
+
break
|
|
199
|
+
tb = tb.tb_next
|
|
200
|
+
raise e.with_traceback(tb)
|
|
201
|
+
# If the log level is DEBUG, we will keep the extra stack frames to help with debugging
|
|
202
|
+
raise e
|
|
161
203
|
|
|
162
204
|
|
|
163
205
|
class _SyncWrapper:
|
|
@@ -175,23 +217,6 @@ class _SyncWrapper:
|
|
|
175
217
|
self._bg_loop = bg_loop
|
|
176
218
|
self._underlying_obj = underlying_obj
|
|
177
219
|
|
|
178
|
-
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
179
|
-
if threading.current_thread().name == self._bg_loop.thread.name:
|
|
180
|
-
# If we are already in the background loop thread, we can call the function directly
|
|
181
|
-
raise AssertionError(
|
|
182
|
-
f"Deadlock detected: blocking call used in syncify thread {self._bg_loop.thread.name} "
|
|
183
|
-
f"when calling function {self.fn}, use .aio() if in an async call."
|
|
184
|
-
)
|
|
185
|
-
# bind method if needed
|
|
186
|
-
coro_fn = self.fn
|
|
187
|
-
|
|
188
|
-
if inspect.isasyncgenfunction(coro_fn):
|
|
189
|
-
# Handle async iterator by converting to sync iterator
|
|
190
|
-
async_gen = coro_fn(*args, **kwargs)
|
|
191
|
-
return self._bg_loop.iterate_in_loop_sync(async_gen)
|
|
192
|
-
else:
|
|
193
|
-
return self._bg_loop.call_in_loop_sync(coro_fn(*args, **kwargs))
|
|
194
|
-
|
|
195
220
|
def __get__(self, instance: Any, owner: Any) -> Any:
|
|
196
221
|
"""
|
|
197
222
|
This method is called when the wrapper is accessed as a method of a class instance.
|
|
@@ -212,20 +237,63 @@ class _SyncWrapper:
|
|
|
212
237
|
functools.update_wrapper(wrapper, self.fn)
|
|
213
238
|
return wrapper
|
|
214
239
|
|
|
240
|
+
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
|
241
|
+
if threading.current_thread().name == self._bg_loop.thread.name:
|
|
242
|
+
# If we are already in the background loop thread, we can call the function directly
|
|
243
|
+
raise AssertionError(
|
|
244
|
+
f"Deadlock detected: blocking call used in syncify thread {self._bg_loop.thread.name} "
|
|
245
|
+
f"when calling function {self.fn}, use .aio() if in an async call."
|
|
246
|
+
)
|
|
247
|
+
try:
|
|
248
|
+
# bind method if needed
|
|
249
|
+
coro_fn = self.fn
|
|
250
|
+
|
|
251
|
+
if inspect.isasyncgenfunction(coro_fn):
|
|
252
|
+
# Handle async iterator by converting to sync iterator
|
|
253
|
+
async_gen = coro_fn(*args, **kwargs)
|
|
254
|
+
return self._bg_loop.iterate_in_loop_sync(async_gen)
|
|
255
|
+
else:
|
|
256
|
+
return self._bg_loop.call_in_loop_sync(coro_fn(*args, **kwargs))
|
|
257
|
+
except Exception as e:
|
|
258
|
+
if logger.getEffectiveLevel() > logging.DEBUG:
|
|
259
|
+
# If the log level is not DEBUG, we will remove the extra stack frames to avoid confusion for the user
|
|
260
|
+
# This is because the stack trace will include the Syncify wrapper and the background loop thread
|
|
261
|
+
tb = e.__traceback__
|
|
262
|
+
while tb and tb.tb_next:
|
|
263
|
+
if tb.tb_frame.f_code.co_name == self.fn.__name__:
|
|
264
|
+
break
|
|
265
|
+
tb = tb.tb_next
|
|
266
|
+
raise e.with_traceback(tb)
|
|
267
|
+
# If the log level is DEBUG, we will keep the extra stack frames to help with debugging
|
|
268
|
+
raise e
|
|
269
|
+
|
|
215
270
|
def aio(self, *args: Any, **kwargs: Any) -> Any:
|
|
216
271
|
fn = self.fn
|
|
217
272
|
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
228
|
-
|
|
273
|
+
try:
|
|
274
|
+
if inspect.isasyncgenfunction(fn):
|
|
275
|
+
# If the function is an async generator, we need to handle it differently
|
|
276
|
+
async_iter = fn(*args, **kwargs)
|
|
277
|
+
return self._bg_loop.iterate_in_loop(async_iter)
|
|
278
|
+
else:
|
|
279
|
+
# If we are already in the background loop, just return the coroutine
|
|
280
|
+
coro = fn(*args, **kwargs)
|
|
281
|
+
if hasattr(coro, "__aiter__"):
|
|
282
|
+
# If the coroutine is an async iterator, we need to handle it differently
|
|
283
|
+
return self._bg_loop.iterate_in_loop(coro)
|
|
284
|
+
return self._bg_loop.aio(coro)
|
|
285
|
+
except Exception as e:
|
|
286
|
+
if logger.getEffectiveLevel() > logging.DEBUG:
|
|
287
|
+
# If the log level is not DEBUG, we will remove the extra stack frames to avoid confusion for the user
|
|
288
|
+
# This is because the stack trace will include the Syncify wrapper and the background loop thread
|
|
289
|
+
tb = e.__traceback__
|
|
290
|
+
while tb and tb.tb_next:
|
|
291
|
+
if tb.tb_frame.f_code.co_name == self.fn.__name__:
|
|
292
|
+
break
|
|
293
|
+
tb = tb.tb_next
|
|
294
|
+
raise e.with_traceback(tb)
|
|
295
|
+
# If the log level is DEBUG, we will keep the extra stack frames to help with debugging
|
|
296
|
+
raise e
|
|
229
297
|
|
|
230
298
|
|
|
231
299
|
class Syncify:
|
flyte/types/_type_engine.py
CHANGED
|
@@ -35,6 +35,7 @@ from mashumaro.jsonschema.models import Context, JSONSchema
|
|
|
35
35
|
from mashumaro.jsonschema.plugins import BasePlugin
|
|
36
36
|
from mashumaro.jsonschema.schema import Instance
|
|
37
37
|
from mashumaro.mixins.json import DataClassJSONMixin
|
|
38
|
+
from pydantic import BaseModel
|
|
38
39
|
from typing_extensions import Annotated, get_args, get_origin
|
|
39
40
|
|
|
40
41
|
import flyte.storage as storage
|
|
@@ -352,6 +353,79 @@ class RestrictedTypeTransformer(TypeTransformer[T], ABC):
|
|
|
352
353
|
raise RestrictedTypeError(f"Transformer for type {self.python_type} is restricted currently")
|
|
353
354
|
|
|
354
355
|
|
|
356
|
+
class PydanticTransformer(TypeTransformer[BaseModel]):
|
|
357
|
+
def __init__(self):
|
|
358
|
+
super().__init__("Pydantic Transformer", BaseModel, enable_type_assertions=False)
|
|
359
|
+
|
|
360
|
+
def get_literal_type(self, t: Type[BaseModel]) -> LiteralType:
|
|
361
|
+
schema = t.model_json_schema()
|
|
362
|
+
fields = t.__annotations__.items()
|
|
363
|
+
|
|
364
|
+
literal_type = {}
|
|
365
|
+
for name, python_type in fields:
|
|
366
|
+
try:
|
|
367
|
+
literal_type[name] = TypeEngine.to_literal_type(python_type)
|
|
368
|
+
except Exception as e:
|
|
369
|
+
logger.warning(
|
|
370
|
+
"Field {} of type {} cannot be converted to a literal type. Error: {}".format(name, python_type, e)
|
|
371
|
+
)
|
|
372
|
+
|
|
373
|
+
# This is for attribute access in FlytePropeller.
|
|
374
|
+
ts = TypeStructure(tag="", dataclass_type=literal_type)
|
|
375
|
+
|
|
376
|
+
meta_struct = struct_pb2.Struct()
|
|
377
|
+
meta_struct.update(
|
|
378
|
+
{
|
|
379
|
+
CACHE_KEY_METADATA: {
|
|
380
|
+
SERIALIZATION_FORMAT: MESSAGEPACK,
|
|
381
|
+
}
|
|
382
|
+
}
|
|
383
|
+
)
|
|
384
|
+
|
|
385
|
+
return LiteralType(
|
|
386
|
+
simple=SimpleType.STRUCT,
|
|
387
|
+
metadata=schema,
|
|
388
|
+
structure=ts,
|
|
389
|
+
annotation=TypeAnnotation(annotations=meta_struct),
|
|
390
|
+
)
|
|
391
|
+
|
|
392
|
+
async def to_literal(
|
|
393
|
+
self,
|
|
394
|
+
python_val: BaseModel,
|
|
395
|
+
python_type: Type[BaseModel],
|
|
396
|
+
expected: LiteralType,
|
|
397
|
+
) -> Literal:
|
|
398
|
+
json_str = python_val.model_dump_json()
|
|
399
|
+
dict_obj = json.loads(json_str)
|
|
400
|
+
msgpack_bytes = msgpack.dumps(dict_obj)
|
|
401
|
+
return Literal(scalar=Scalar(binary=Binary(value=msgpack_bytes, tag=MESSAGEPACK)))
|
|
402
|
+
|
|
403
|
+
def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[BaseModel]) -> BaseModel:
|
|
404
|
+
if binary_idl_object.tag == MESSAGEPACK:
|
|
405
|
+
dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False)
|
|
406
|
+
json_str = json.dumps(dict_obj)
|
|
407
|
+
python_val = expected_python_type.model_validate_json(
|
|
408
|
+
json_data=json_str, strict=False, context={"deserialize": True}
|
|
409
|
+
)
|
|
410
|
+
return python_val
|
|
411
|
+
else:
|
|
412
|
+
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
|
|
413
|
+
|
|
414
|
+
async def to_python_value(self, lv: Literal, expected_python_type: Type[BaseModel]) -> BaseModel:
|
|
415
|
+
"""
|
|
416
|
+
There are two kinds of literal values to handle:
|
|
417
|
+
1. Protobuf Structs (from the UI)
|
|
418
|
+
2. Binary scalars (from other sources)
|
|
419
|
+
We need to account for both cases accordingly.
|
|
420
|
+
"""
|
|
421
|
+
if lv and lv.HasField("scalar") and lv.scalar.HasField("binary"):
|
|
422
|
+
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore
|
|
423
|
+
|
|
424
|
+
json_str = _json_format.MessageToJson(lv.scalar.generic)
|
|
425
|
+
python_val = expected_python_type.model_validate_json(json_str, strict=False, context={"deserialize": True})
|
|
426
|
+
return python_val
|
|
427
|
+
|
|
428
|
+
|
|
355
429
|
class PydanticSchemaPlugin(BasePlugin):
|
|
356
430
|
"""This allows us to generate proper schemas for Pydantic models."""
|
|
357
431
|
|
|
@@ -562,9 +636,8 @@ class DataclassTransformer(TypeTransformer[object]):
|
|
|
562
636
|
|
|
563
637
|
# This is for attribute access in FlytePropeller.
|
|
564
638
|
ts = TypeStructure(tag="", dataclass_type=literal_type)
|
|
565
|
-
from google.protobuf.struct_pb2 import Struct
|
|
566
639
|
|
|
567
|
-
meta_struct = Struct()
|
|
640
|
+
meta_struct = struct_pb2.Struct()
|
|
568
641
|
meta_struct.update(
|
|
569
642
|
{
|
|
570
643
|
CACHE_KEY_METADATA: {
|
|
@@ -627,7 +700,7 @@ class DataclassTransformer(TypeTransformer[object]):
|
|
|
627
700
|
field.type = self._get_origin_type_in_annotation(cast(type, field.type))
|
|
628
701
|
return python_type
|
|
629
702
|
|
|
630
|
-
|
|
703
|
+
def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[T]) -> T:
|
|
631
704
|
if binary_idl_object.tag == MESSAGEPACK:
|
|
632
705
|
if issubclass(expected_python_type, DataClassJSONMixin):
|
|
633
706
|
dict_obj = msgpack.loads(binary_idl_object.value, strict_map_key=False)
|
|
@@ -652,9 +725,10 @@ class DataclassTransformer(TypeTransformer[object]):
|
|
|
652
725
|
"user defined datatypes in Flytekit"
|
|
653
726
|
)
|
|
654
727
|
|
|
655
|
-
if lv.scalar and lv.scalar.binary:
|
|
656
|
-
return
|
|
728
|
+
if lv.HasField("scalar") and lv.scalar.HasField("binary"):
|
|
729
|
+
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore
|
|
657
730
|
|
|
731
|
+
# todo: revisit this, it should always be a binary in v2.
|
|
658
732
|
json_str = _json_format.MessageToJson(lv.scalar.generic)
|
|
659
733
|
|
|
660
734
|
# The `from_json` function is provided from mashumaro's `DataClassJSONMixin`.
|
|
@@ -970,11 +1044,10 @@ class TypeEngine(typing.Generic[T]):
|
|
|
970
1044
|
return cls._REGISTRY[python_type.__origin__]
|
|
971
1045
|
|
|
972
1046
|
# Handling UnionType specially - PEP 604
|
|
973
|
-
|
|
974
|
-
import types
|
|
1047
|
+
import types
|
|
975
1048
|
|
|
976
|
-
|
|
977
|
-
|
|
1049
|
+
if isinstance(python_type, types.UnionType):
|
|
1050
|
+
return cls._REGISTRY[types.UnionType]
|
|
978
1051
|
|
|
979
1052
|
if python_type in cls._REGISTRY:
|
|
980
1053
|
return cls._REGISTRY[python_type]
|
|
@@ -2041,6 +2114,7 @@ def _register_default_type_transformers():
|
|
|
2041
2114
|
TypeEngine.register(DictTransformer())
|
|
2042
2115
|
TypeEngine.register(EnumTransformer())
|
|
2043
2116
|
TypeEngine.register(ProtobufTransformer())
|
|
2117
|
+
TypeEngine.register(PydanticTransformer())
|
|
2044
2118
|
|
|
2045
2119
|
# inner type is. Also unsupported are typing's Tuples. Even though you can look inside them, Flyte's type system
|
|
2046
2120
|
# doesn't support these currently.
|