flyte 0.0.1b3__py3-none-any.whl → 0.2.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +20 -4
- flyte/_bin/runtime.py +33 -7
- flyte/_build.py +3 -2
- flyte/_cache/cache.py +1 -2
- flyte/_code_bundle/_packaging.py +1 -1
- flyte/_code_bundle/_utils.py +0 -16
- flyte/_code_bundle/bundle.py +43 -12
- flyte/_context.py +8 -2
- flyte/_deploy.py +56 -15
- flyte/_environment.py +45 -4
- flyte/_excepthook.py +37 -0
- flyte/_group.py +2 -1
- flyte/_image.py +8 -4
- flyte/_initialize.py +112 -254
- flyte/_interface.py +3 -3
- flyte/_internal/controllers/__init__.py +19 -6
- flyte/_internal/controllers/_local_controller.py +83 -8
- flyte/_internal/controllers/_trace.py +2 -1
- flyte/_internal/controllers/remote/__init__.py +27 -7
- flyte/_internal/controllers/remote/_action.py +7 -2
- flyte/_internal/controllers/remote/_client.py +5 -1
- flyte/_internal/controllers/remote/_controller.py +159 -26
- flyte/_internal/controllers/remote/_core.py +13 -5
- flyte/_internal/controllers/remote/_informer.py +4 -4
- flyte/_internal/controllers/remote/_service_protocol.py +6 -6
- flyte/_internal/imagebuild/docker_builder.py +12 -1
- flyte/_internal/imagebuild/image_builder.py +16 -11
- flyte/_internal/runtime/convert.py +164 -21
- flyte/_internal/runtime/entrypoints.py +1 -1
- flyte/_internal/runtime/io.py +3 -3
- flyte/_internal/runtime/task_serde.py +140 -20
- flyte/_internal/runtime/taskrunner.py +4 -3
- flyte/_internal/runtime/types_serde.py +1 -1
- flyte/_logging.py +12 -1
- flyte/_map.py +215 -0
- flyte/_pod.py +19 -0
- flyte/_protos/common/list_pb2.py +3 -3
- flyte/_protos/common/list_pb2.pyi +2 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +28 -24
- flyte/_protos/logs/dataplane/payload_pb2.pyi +11 -2
- flyte/_protos/workflow/common_pb2.py +27 -0
- flyte/_protos/workflow/common_pb2.pyi +14 -0
- flyte/_protos/workflow/environment_pb2.py +29 -0
- flyte/_protos/workflow/environment_pb2.pyi +12 -0
- flyte/_protos/workflow/queue_service_pb2.py +40 -41
- flyte/_protos/workflow/queue_service_pb2.pyi +35 -30
- flyte/_protos/workflow/queue_service_pb2_grpc.py +15 -15
- flyte/_protos/workflow/run_definition_pb2.py +61 -61
- flyte/_protos/workflow/run_definition_pb2.pyi +8 -4
- flyte/_protos/workflow/run_service_pb2.py +20 -24
- flyte/_protos/workflow/run_service_pb2.pyi +2 -6
- flyte/_protos/workflow/state_service_pb2.py +36 -28
- flyte/_protos/workflow/state_service_pb2.pyi +19 -15
- flyte/_protos/workflow/state_service_pb2_grpc.py +28 -28
- flyte/_protos/workflow/task_definition_pb2.py +29 -22
- flyte/_protos/workflow/task_definition_pb2.pyi +21 -5
- flyte/_protos/workflow/task_service_pb2.py +27 -11
- flyte/_protos/workflow/task_service_pb2.pyi +29 -1
- flyte/_protos/workflow/task_service_pb2_grpc.py +34 -0
- flyte/_run.py +166 -95
- flyte/_task.py +110 -28
- flyte/_task_environment.py +55 -72
- flyte/_trace.py +6 -14
- flyte/_utils/__init__.py +6 -0
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +0 -2
- flyte/_utils/helpers.py +45 -19
- flyte/_utils/org_discovery.py +57 -0
- flyte/_version.py +2 -2
- flyte/cli/__init__.py +3 -0
- flyte/cli/_abort.py +28 -0
- flyte/{_cli → cli}/_common.py +73 -23
- flyte/cli/_create.py +145 -0
- flyte/{_cli → cli}/_delete.py +4 -4
- flyte/{_cli → cli}/_deploy.py +26 -14
- flyte/cli/_gen.py +163 -0
- flyte/{_cli → cli}/_get.py +98 -23
- {union/_cli → flyte/cli}/_params.py +106 -147
- flyte/{_cli → cli}/_run.py +99 -20
- flyte/cli/main.py +166 -0
- flyte/config/__init__.py +3 -0
- flyte/config/_config.py +216 -0
- flyte/config/_internal.py +64 -0
- flyte/config/_reader.py +207 -0
- flyte/errors.py +29 -0
- flyte/extras/_container.py +33 -43
- flyte/io/__init__.py +17 -1
- flyte/io/_dir.py +2 -2
- flyte/io/_file.py +3 -4
- flyte/io/{structured_dataset → _structured_dataset}/basic_dfs.py +1 -1
- flyte/io/{structured_dataset → _structured_dataset}/structured_dataset.py +1 -1
- flyte/{_datastructures.py → models.py} +56 -7
- flyte/remote/__init__.py +2 -1
- flyte/remote/_client/_protocols.py +2 -0
- flyte/remote/_client/auth/_auth_utils.py +14 -0
- flyte/remote/_client/auth/_channel.py +34 -3
- flyte/remote/_client/auth/_token_client.py +3 -3
- flyte/remote/_client/controlplane.py +13 -13
- flyte/remote/_console.py +1 -1
- flyte/remote/_data.py +10 -6
- flyte/remote/_logs.py +89 -29
- flyte/remote/_project.py +8 -9
- flyte/remote/_run.py +228 -131
- flyte/remote/_secret.py +12 -12
- flyte/remote/_task.py +179 -15
- flyte/report/_report.py +4 -4
- flyte/storage/__init__.py +5 -0
- flyte/storage/_config.py +233 -0
- flyte/storage/_storage.py +23 -3
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +371 -0
- flyte/types/__init__.py +23 -0
- flyte/types/_interface.py +22 -7
- flyte/{io/pickle/transformer.py → types/_pickle.py} +2 -1
- flyte/types/_type_engine.py +95 -18
- flyte-0.2.0a0.dist-info/METADATA +249 -0
- flyte-0.2.0a0.dist-info/RECORD +218 -0
- {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/entry_points.txt +1 -1
- flyte/_api_commons.py +0 -3
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_create.py +0 -42
- flyte/_cli/main.py +0 -72
- flyte/_internal/controllers/pbhash.py +0 -39
- flyte/io/_dataframe.py +0 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte-0.0.1b3.dist-info/METADATA +0 -179
- flyte-0.0.1b3.dist-info/RECORD +0 -390
- union/__init__.py +0 -54
- union/_api_commons.py +0 -3
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +0 -113
- union/_build.py +0 -25
- union/_cache/__init__.py +0 -12
- union/_cache/cache.py +0 -141
- union/_cache/defaults.py +0 -9
- union/_cache/policy_function_body.py +0 -42
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +0 -263
- union/_cli/_create.py +0 -40
- union/_cli/_delete.py +0 -23
- union/_cli/_deploy.py +0 -120
- union/_cli/_get.py +0 -162
- union/_cli/_run.py +0 -150
- union/_cli/main.py +0 -72
- union/_code_bundle/__init__.py +0 -8
- union/_code_bundle/_ignore.py +0 -113
- union/_code_bundle/_packaging.py +0 -187
- union/_code_bundle/_utils.py +0 -342
- union/_code_bundle/bundle.py +0 -176
- union/_context.py +0 -146
- union/_datastructures.py +0 -295
- union/_deploy.py +0 -185
- union/_doc.py +0 -29
- union/_docstring.py +0 -26
- union/_environment.py +0 -43
- union/_group.py +0 -31
- union/_hash.py +0 -23
- union/_image.py +0 -760
- union/_initialize.py +0 -585
- union/_interface.py +0 -84
- union/_internal/__init__.py +0 -3
- union/_internal/controllers/__init__.py +0 -77
- union/_internal/controllers/_local_controller.py +0 -77
- union/_internal/controllers/pbhash.py +0 -39
- union/_internal/controllers/remote/__init__.py +0 -40
- union/_internal/controllers/remote/_action.py +0 -131
- union/_internal/controllers/remote/_client.py +0 -43
- union/_internal/controllers/remote/_controller.py +0 -169
- union/_internal/controllers/remote/_core.py +0 -341
- union/_internal/controllers/remote/_informer.py +0 -260
- union/_internal/controllers/remote/_service_protocol.py +0 -44
- union/_internal/imagebuild/__init__.py +0 -11
- union/_internal/imagebuild/docker_builder.py +0 -416
- union/_internal/imagebuild/image_builder.py +0 -243
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +0 -31
- union/_internal/resolvers/common.py +0 -24
- union/_internal/resolvers/default.py +0 -27
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +0 -163
- union/_internal/runtime/entrypoints.py +0 -121
- union/_internal/runtime/io.py +0 -136
- union/_internal/runtime/resources_serde.py +0 -134
- union/_internal/runtime/task_serde.py +0 -202
- union/_internal/runtime/taskrunner.py +0 -179
- union/_internal/runtime/types_serde.py +0 -53
- union/_logging.py +0 -124
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +0 -66
- union/_protos/common/authorization_pb2.pyi +0 -106
- union/_protos/common/identifier_pb2.py +0 -71
- union/_protos/common/identifier_pb2.pyi +0 -82
- union/_protos/common/identity_pb2.py +0 -48
- union/_protos/common/identity_pb2.pyi +0 -72
- union/_protos/common/identity_pb2_grpc.py +0 -4
- union/_protos/common/list_pb2.py +0 -36
- union/_protos/common/list_pb2.pyi +0 -69
- union/_protos/common/list_pb2_grpc.py +0 -4
- union/_protos/common/policy_pb2.py +0 -37
- union/_protos/common/policy_pb2.pyi +0 -27
- union/_protos/common/policy_pb2_grpc.py +0 -4
- union/_protos/common/role_pb2.py +0 -37
- union/_protos/common/role_pb2.pyi +0 -51
- union/_protos/common/role_pb2_grpc.py +0 -4
- union/_protos/common/runtime_version_pb2.py +0 -28
- union/_protos/common/runtime_version_pb2.pyi +0 -24
- union/_protos/common/runtime_version_pb2_grpc.py +0 -4
- union/_protos/logs/dataplane/payload_pb2.py +0 -96
- union/_protos/logs/dataplane/payload_pb2.pyi +0 -168
- union/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- union/_protos/secret/definition_pb2.py +0 -49
- union/_protos/secret/definition_pb2.pyi +0 -93
- union/_protos/secret/definition_pb2_grpc.py +0 -4
- union/_protos/secret/payload_pb2.py +0 -62
- union/_protos/secret/payload_pb2.pyi +0 -94
- union/_protos/secret/payload_pb2_grpc.py +0 -4
- union/_protos/secret/secret_pb2.py +0 -38
- union/_protos/secret/secret_pb2.pyi +0 -6
- union/_protos/secret/secret_pb2_grpc.py +0 -198
- union/_protos/validate/validate/validate_pb2.py +0 -76
- union/_protos/workflow/node_execution_service_pb2.py +0 -26
- union/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- union/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- union/_protos/workflow/queue_service_pb2.py +0 -75
- union/_protos/workflow/queue_service_pb2.pyi +0 -103
- union/_protos/workflow/queue_service_pb2_grpc.py +0 -172
- union/_protos/workflow/run_definition_pb2.py +0 -100
- union/_protos/workflow/run_definition_pb2.pyi +0 -256
- union/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- union/_protos/workflow/run_logs_service_pb2.py +0 -41
- union/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- union/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- union/_protos/workflow/run_service_pb2.py +0 -133
- union/_protos/workflow/run_service_pb2.pyi +0 -173
- union/_protos/workflow/run_service_pb2_grpc.py +0 -412
- union/_protos/workflow/state_service_pb2.py +0 -58
- union/_protos/workflow/state_service_pb2.pyi +0 -69
- union/_protos/workflow/state_service_pb2_grpc.py +0 -138
- union/_protos/workflow/task_definition_pb2.py +0 -72
- union/_protos/workflow/task_definition_pb2.pyi +0 -65
- union/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- union/_protos/workflow/task_service_pb2.py +0 -44
- union/_protos/workflow/task_service_pb2.pyi +0 -31
- union/_protos/workflow/task_service_pb2_grpc.py +0 -104
- union/_resources.py +0 -226
- union/_retry.py +0 -32
- union/_reusable_environment.py +0 -25
- union/_run.py +0 -374
- union/_secret.py +0 -61
- union/_task.py +0 -354
- union/_task_environment.py +0 -186
- union/_timeout.py +0 -47
- union/_tools.py +0 -27
- union/_utils/__init__.py +0 -11
- union/_utils/asyn.py +0 -119
- union/_utils/file_handling.py +0 -71
- union/_utils/helpers.py +0 -46
- union/_utils/lazy_module.py +0 -54
- union/_utils/uv_script_parser.py +0 -49
- union/_version.py +0 -21
- union/connectors/__init__.py +0 -0
- union/errors.py +0 -128
- union/extras/__init__.py +0 -5
- union/extras/_container.py +0 -263
- union/io/__init__.py +0 -11
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +0 -425
- union/io/_file.py +0 -418
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +0 -117
- union/io/structured_dataset/__init__.py +0 -122
- union/io/structured_dataset/basic_dfs.py +0 -219
- union/io/structured_dataset/structured_dataset.py +0 -1057
- union/py.typed +0 -0
- union/remote/__init__.py +0 -23
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +0 -129
- union/remote/_client/auth/__init__.py +0 -12
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +0 -391
- union/remote/_client/auth/_authenticators/client_credentials.py +0 -73
- union/remote/_client/auth/_authenticators/device_code.py +0 -120
- union/remote/_client/auth/_authenticators/external_command.py +0 -77
- union/remote/_client/auth/_authenticators/factory.py +0 -200
- union/remote/_client/auth/_authenticators/pkce.py +0 -515
- union/remote/_client/auth/_channel.py +0 -184
- union/remote/_client/auth/_client_config.py +0 -83
- union/remote/_client/auth/_default_html.py +0 -32
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +0 -204
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +0 -144
- union/remote/_client/auth/_keyring.py +0 -154
- union/remote/_client/auth/_token_client.py +0 -258
- union/remote/_client/auth/errors.py +0 -16
- union/remote/_client/controlplane.py +0 -86
- union/remote/_data.py +0 -149
- union/remote/_logs.py +0 -74
- union/remote/_project.py +0 -86
- union/remote/_run.py +0 -820
- union/remote/_secret.py +0 -132
- union/remote/_task.py +0 -193
- union/report/__init__.py +0 -3
- union/report/_report.py +0 -178
- union/report/_template.html +0 -124
- union/storage/__init__.py +0 -24
- union/storage/_remote_fs.py +0 -34
- union/storage/_storage.py +0 -247
- union/storage/_utils.py +0 -5
- union/types/__init__.py +0 -11
- union/types/_renderer.py +0 -162
- union/types/_string_literals.py +0 -120
- union/types/_type_engine.py +0 -2131
- union/types/_utils.py +0 -80
- /union/_protos/common/authorization_pb2_grpc.py → /flyte/_protos/workflow/common_pb2_grpc.py +0 -0
- /union/_protos/common/identifier_pb2_grpc.py → /flyte/_protos/workflow/environment_pb2_grpc.py +0 -0
- /flyte/io/{structured_dataset → _structured_dataset}/__init__.py +0 -0
- {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/WHEEL +0 -0
- {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/top_level.txt +0 -0
flyte/_initialize.py
CHANGED
|
@@ -1,260 +1,27 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
import datetime
|
|
4
3
|
import functools
|
|
5
|
-
import os
|
|
6
4
|
import threading
|
|
7
5
|
import typing
|
|
8
6
|
from dataclasses import dataclass, replace
|
|
9
|
-
from datetime import timedelta
|
|
10
7
|
from pathlib import Path
|
|
11
|
-
from typing import TYPE_CHECKING,
|
|
8
|
+
from typing import TYPE_CHECKING, Callable, List, Literal, Optional, TypeVar
|
|
12
9
|
|
|
13
10
|
from flyte.errors import InitializationError
|
|
11
|
+
from flyte.syncify import syncify
|
|
14
12
|
|
|
15
|
-
from .
|
|
16
|
-
from ._logging import initialize_logger
|
|
13
|
+
from ._logging import initialize_logger, logger
|
|
17
14
|
from ._tools import ipython_check
|
|
18
15
|
|
|
19
16
|
if TYPE_CHECKING:
|
|
17
|
+
from flyte.config import Config
|
|
20
18
|
from flyte.remote._client.auth import AuthType, ClientConfig
|
|
21
19
|
from flyte.remote._client.controlplane import ClientSet
|
|
20
|
+
from flyte.storage import Storage
|
|
22
21
|
|
|
23
22
|
Mode = Literal["local", "remote"]
|
|
24
23
|
|
|
25
24
|
|
|
26
|
-
def set_if_exists(d: dict, k: str, val: typing.Any) -> dict:
|
|
27
|
-
"""
|
|
28
|
-
Given a dict ``d`` sets the key ``k`` with value of config ``v``, if the config value ``v`` is set
|
|
29
|
-
and return the updated dictionary.
|
|
30
|
-
"""
|
|
31
|
-
exists = isinstance(val, bool) or bool(val is not None and val)
|
|
32
|
-
if exists:
|
|
33
|
-
d[k] = val
|
|
34
|
-
return d
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
@dataclass(init=True, repr=True, eq=True, frozen=True)
|
|
38
|
-
class Storage(object):
|
|
39
|
-
"""
|
|
40
|
-
Data storage configuration that applies across any provider.
|
|
41
|
-
"""
|
|
42
|
-
|
|
43
|
-
retries: int = 3
|
|
44
|
-
backoff: datetime.timedelta = datetime.timedelta(seconds=5)
|
|
45
|
-
enable_debug: bool = False
|
|
46
|
-
attach_execution_metadata: bool = True
|
|
47
|
-
|
|
48
|
-
_KEY_ENV_VAR_MAPPING: ClassVar[typing.Dict[str, str]] = {
|
|
49
|
-
"enable_debug": "UNION_STORAGE_DEBUG",
|
|
50
|
-
"retries": "UNION_STORAGE_RETRIES",
|
|
51
|
-
"backoff": "UNION_STORAGE_BACKOFF_SECONDS",
|
|
52
|
-
}
|
|
53
|
-
|
|
54
|
-
def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
|
|
55
|
-
"""
|
|
56
|
-
Returns the configuration as kwargs for constructing an fsspec filesystem.
|
|
57
|
-
"""
|
|
58
|
-
return {}
|
|
59
|
-
|
|
60
|
-
@classmethod
|
|
61
|
-
def _auto_as_kwargs(cls) -> Dict[str, Any]:
|
|
62
|
-
retries = os.getenv(cls._KEY_ENV_VAR_MAPPING["retries"])
|
|
63
|
-
backoff = os.getenv(cls._KEY_ENV_VAR_MAPPING["backoff"])
|
|
64
|
-
enable_debug = os.getenv(cls._KEY_ENV_VAR_MAPPING["enable_debug"])
|
|
65
|
-
|
|
66
|
-
kwargs: Dict[str, Any] = {}
|
|
67
|
-
kwargs = set_if_exists(kwargs, "enable_debug", enable_debug)
|
|
68
|
-
kwargs = set_if_exists(kwargs, "retries", retries)
|
|
69
|
-
kwargs = set_if_exists(kwargs, "backoff", backoff)
|
|
70
|
-
return kwargs
|
|
71
|
-
|
|
72
|
-
@classmethod
|
|
73
|
-
def auto(cls) -> Storage:
|
|
74
|
-
"""
|
|
75
|
-
Construct the config object automatically from environment variables.
|
|
76
|
-
"""
|
|
77
|
-
return cls(**cls._auto_as_kwargs())
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
@dataclass(init=True, repr=True, eq=True, frozen=True)
|
|
81
|
-
class S3(Storage):
|
|
82
|
-
"""
|
|
83
|
-
S3 specific configuration
|
|
84
|
-
"""
|
|
85
|
-
|
|
86
|
-
endpoint: typing.Optional[str] = None
|
|
87
|
-
access_key_id: typing.Optional[str] = None
|
|
88
|
-
secret_access_key: typing.Optional[str] = None
|
|
89
|
-
|
|
90
|
-
_KEY_ENV_VAR_MAPPING: ClassVar[typing.Dict[str, str]] = {
|
|
91
|
-
"endpoint": "FLYTE_AWS_ENDPOINT",
|
|
92
|
-
"access_key_id": "FLYTE_AWS_ACCESS_KEY_ID",
|
|
93
|
-
"secret_access_key": "FLYTE_AWS_SECRET_ACCESS_KEY",
|
|
94
|
-
} | Storage._KEY_ENV_VAR_MAPPING
|
|
95
|
-
|
|
96
|
-
# Refer to https://github.com/developmentseed/obstore/blob/33654fc37f19a657689eb93327b621e9f9e01494/obstore/python/obstore/store/_aws.pyi#L11
|
|
97
|
-
# for key and secret
|
|
98
|
-
_CONFIG_KEY_FSSPEC_S3_KEY_ID: ClassVar = "access_key_id"
|
|
99
|
-
_CONFIG_KEY_FSSPEC_S3_SECRET: ClassVar = "secret_access_key"
|
|
100
|
-
_CONFIG_KEY_ENDPOINT: ClassVar = "endpoint_url"
|
|
101
|
-
_KEY_SKIP_SIGNATURE: ClassVar = "skip_signature"
|
|
102
|
-
|
|
103
|
-
@classmethod
|
|
104
|
-
def auto(cls) -> S3:
|
|
105
|
-
"""
|
|
106
|
-
:return: Config
|
|
107
|
-
"""
|
|
108
|
-
endpoint = os.getenv(cls._KEY_ENV_VAR_MAPPING["endpoint"], None)
|
|
109
|
-
access_key_id = os.getenv(cls._KEY_ENV_VAR_MAPPING["access_key_id"], None)
|
|
110
|
-
secret_access_key = os.getenv(cls._KEY_ENV_VAR_MAPPING["secret_access_key"], None)
|
|
111
|
-
|
|
112
|
-
kwargs = super()._auto_as_kwargs()
|
|
113
|
-
kwargs = set_if_exists(kwargs, "endpoint", endpoint)
|
|
114
|
-
kwargs = set_if_exists(kwargs, "access_key_id", access_key_id)
|
|
115
|
-
kwargs = set_if_exists(kwargs, "secret_access_key", secret_access_key)
|
|
116
|
-
|
|
117
|
-
return S3(**kwargs)
|
|
118
|
-
|
|
119
|
-
@classmethod
|
|
120
|
-
def for_sandbox(cls) -> S3:
|
|
121
|
-
"""
|
|
122
|
-
:return:
|
|
123
|
-
"""
|
|
124
|
-
kwargs = super()._auto_as_kwargs()
|
|
125
|
-
final_kwargs = kwargs | {
|
|
126
|
-
"endpoint": "http://localhost:4566",
|
|
127
|
-
"access_key_id": "minio",
|
|
128
|
-
"secret_access_key": "miniostorage",
|
|
129
|
-
}
|
|
130
|
-
return S3(**final_kwargs)
|
|
131
|
-
|
|
132
|
-
def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
|
|
133
|
-
# Construct the config object
|
|
134
|
-
config: Dict[str, Any] = {}
|
|
135
|
-
if self._CONFIG_KEY_FSSPEC_S3_KEY_ID in kwargs or self.access_key_id:
|
|
136
|
-
config[self._CONFIG_KEY_FSSPEC_S3_KEY_ID] = kwargs.pop(
|
|
137
|
-
self._CONFIG_KEY_FSSPEC_S3_KEY_ID, self.access_key_id
|
|
138
|
-
)
|
|
139
|
-
if self._CONFIG_KEY_FSSPEC_S3_SECRET in kwargs or self.secret_access_key:
|
|
140
|
-
config[self._CONFIG_KEY_FSSPEC_S3_SECRET] = kwargs.pop(
|
|
141
|
-
self._CONFIG_KEY_FSSPEC_S3_SECRET, self.secret_access_key
|
|
142
|
-
)
|
|
143
|
-
if self._CONFIG_KEY_ENDPOINT in kwargs or self.endpoint:
|
|
144
|
-
config["endpoint_url"] = kwargs.pop(self._CONFIG_KEY_ENDPOINT, self.endpoint)
|
|
145
|
-
|
|
146
|
-
retries = kwargs.pop("retries", self.retries)
|
|
147
|
-
backoff = kwargs.pop("backoff", self.backoff)
|
|
148
|
-
|
|
149
|
-
if anonymous:
|
|
150
|
-
config[self._KEY_SKIP_SIGNATURE] = True
|
|
151
|
-
|
|
152
|
-
retry_config = {
|
|
153
|
-
"max_retries": retries,
|
|
154
|
-
"backoff": {
|
|
155
|
-
"base": 2,
|
|
156
|
-
"init_backoff": backoff,
|
|
157
|
-
"max_backoff": timedelta(seconds=16),
|
|
158
|
-
},
|
|
159
|
-
"retry_timeout": timedelta(minutes=3),
|
|
160
|
-
}
|
|
161
|
-
|
|
162
|
-
client_options = {"timeout": "99999s", "allow_http": True}
|
|
163
|
-
|
|
164
|
-
if config:
|
|
165
|
-
kwargs["config"] = config
|
|
166
|
-
kwargs["client_options"] = client_options or None
|
|
167
|
-
kwargs["retry_config"] = retry_config or None
|
|
168
|
-
|
|
169
|
-
return kwargs
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
@dataclass(init=True, repr=True, eq=True, frozen=True)
|
|
173
|
-
class GCS(Storage):
|
|
174
|
-
"""
|
|
175
|
-
Any GCS specific configuration.
|
|
176
|
-
"""
|
|
177
|
-
|
|
178
|
-
gsutil_parallelism: bool = False
|
|
179
|
-
|
|
180
|
-
_KEY_ENV_VAR_MAPPING: ClassVar[dict[str, str]] = {
|
|
181
|
-
"gsutil_parallelism": "GCP_GSUTIL_PARALLELISM",
|
|
182
|
-
}
|
|
183
|
-
|
|
184
|
-
@classmethod
|
|
185
|
-
def auto(cls) -> GCS:
|
|
186
|
-
gsutil_parallelism = os.getenv(cls._KEY_ENV_VAR_MAPPING["gsutil_parallelism"], None)
|
|
187
|
-
|
|
188
|
-
kwargs: Dict[str, Any] = {}
|
|
189
|
-
kwargs = set_if_exists(kwargs, "gsutil_parallelism", gsutil_parallelism)
|
|
190
|
-
return GCS(**kwargs)
|
|
191
|
-
|
|
192
|
-
def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
|
|
193
|
-
return kwargs
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
@dataclass(init=True, repr=True, eq=True, frozen=True)
|
|
197
|
-
class ABFS(Storage):
|
|
198
|
-
"""
|
|
199
|
-
Any Azure Blob Storage specific configuration.
|
|
200
|
-
"""
|
|
201
|
-
|
|
202
|
-
account_name: typing.Optional[str] = None
|
|
203
|
-
account_key: typing.Optional[str] = None
|
|
204
|
-
tenant_id: typing.Optional[str] = None
|
|
205
|
-
client_id: typing.Optional[str] = None
|
|
206
|
-
client_secret: typing.Optional[str] = None
|
|
207
|
-
|
|
208
|
-
_KEY_ENV_VAR_MAPPING: ClassVar[dict[str, str]] = {
|
|
209
|
-
"account_name": "AZURE_STORAGE_ACCOUNT_NAME",
|
|
210
|
-
"account_key": "AZURE_STORAGE_ACCOUNT_KEY",
|
|
211
|
-
"tenant_id": "AZURE_TENANT_ID",
|
|
212
|
-
"client_id": "AZURE_CLIENT_ID",
|
|
213
|
-
"client_secret": "AZURE_CLIENT_SECRET",
|
|
214
|
-
}
|
|
215
|
-
_KEY_SKIP_SIGNATURE: ClassVar = "skip_signature"
|
|
216
|
-
|
|
217
|
-
@classmethod
|
|
218
|
-
def auto(cls) -> ABFS:
|
|
219
|
-
account_name = os.getenv(cls._KEY_ENV_VAR_MAPPING["account_name"], None)
|
|
220
|
-
account_key = os.getenv(cls._KEY_ENV_VAR_MAPPING["account_key"], None)
|
|
221
|
-
tenant_id = os.getenv(cls._KEY_ENV_VAR_MAPPING["tenant_id"], None)
|
|
222
|
-
client_id = os.getenv(cls._KEY_ENV_VAR_MAPPING["client_id"], None)
|
|
223
|
-
client_secret = os.getenv(cls._KEY_ENV_VAR_MAPPING["client_secret"], None)
|
|
224
|
-
|
|
225
|
-
kwargs: Dict[str, Any] = {}
|
|
226
|
-
kwargs = set_if_exists(kwargs, "account_name", account_name)
|
|
227
|
-
kwargs = set_if_exists(kwargs, "account_key", account_key)
|
|
228
|
-
kwargs = set_if_exists(kwargs, "tenant_id", tenant_id)
|
|
229
|
-
kwargs = set_if_exists(kwargs, "client_id", client_id)
|
|
230
|
-
kwargs = set_if_exists(kwargs, "client_secret", client_secret)
|
|
231
|
-
return ABFS(**kwargs)
|
|
232
|
-
|
|
233
|
-
def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
|
|
234
|
-
config: Dict[str, Any] = {}
|
|
235
|
-
if "account_name" in kwargs or self.account_name:
|
|
236
|
-
config["account_name"] = kwargs.get("account_name", self.account_name)
|
|
237
|
-
if "account_key" in kwargs or self.account_key:
|
|
238
|
-
config["account_key"] = kwargs.get("account_key", self.account_key)
|
|
239
|
-
if "client_id" in kwargs or self.client_id:
|
|
240
|
-
config["client_id"] = kwargs.get("client_id", self.client_id)
|
|
241
|
-
if "client_secret" in kwargs or self.client_secret:
|
|
242
|
-
config["client_secret"] = kwargs.get("client_secret", self.client_secret)
|
|
243
|
-
if "tenant_id" in kwargs or self.tenant_id:
|
|
244
|
-
config["tenant_id"] = kwargs.get("tenant_id", self.tenant_id)
|
|
245
|
-
|
|
246
|
-
if anonymous:
|
|
247
|
-
config[self._KEY_SKIP_SIGNATURE] = True
|
|
248
|
-
|
|
249
|
-
client_options = {"timeout": "99999s", "allow_http": "true"}
|
|
250
|
-
|
|
251
|
-
if config:
|
|
252
|
-
kwargs["config"] = config
|
|
253
|
-
kwargs["client_options"] = client_options
|
|
254
|
-
|
|
255
|
-
return kwargs
|
|
256
|
-
|
|
257
|
-
|
|
258
25
|
@dataclass(init=True, repr=True, eq=True, frozen=True, kw_only=True)
|
|
259
26
|
class CommonInit:
|
|
260
27
|
"""
|
|
@@ -265,6 +32,7 @@ class CommonInit:
|
|
|
265
32
|
org: str | None = None
|
|
266
33
|
project: str | None = None
|
|
267
34
|
domain: str | None = None
|
|
35
|
+
batch_size: int = 1000
|
|
268
36
|
|
|
269
37
|
|
|
270
38
|
@dataclass(init=True, kw_only=True, repr=True, eq=True, frozen=True)
|
|
@@ -303,11 +71,10 @@ async def _initialize_client(
|
|
|
303
71
|
"""
|
|
304
72
|
from flyte.remote._client.controlplane import ClientSet
|
|
305
73
|
|
|
306
|
-
if endpoint
|
|
74
|
+
if endpoint:
|
|
307
75
|
return await ClientSet.for_endpoint(
|
|
308
76
|
endpoint,
|
|
309
77
|
insecure=insecure,
|
|
310
|
-
api_key=api_key,
|
|
311
78
|
insecure_skip_verify=insecure_skip_verify,
|
|
312
79
|
auth_type=auth_type,
|
|
313
80
|
headless=headless,
|
|
@@ -320,10 +87,29 @@ async def _initialize_client(
|
|
|
320
87
|
rpc_retries=rpc_retries,
|
|
321
88
|
http_proxy_url=http_proxy_url,
|
|
322
89
|
)
|
|
323
|
-
|
|
90
|
+
elif api_key:
|
|
91
|
+
return await ClientSet.for_api_key(
|
|
92
|
+
api_key,
|
|
93
|
+
insecure=insecure,
|
|
94
|
+
insecure_skip_verify=insecure_skip_verify,
|
|
95
|
+
auth_type=auth_type,
|
|
96
|
+
headless=headless,
|
|
97
|
+
ca_cert_file_path=ca_cert_file_path,
|
|
98
|
+
command=command,
|
|
99
|
+
proxy_command=proxy_command,
|
|
100
|
+
client_id=client_id,
|
|
101
|
+
client_credentials_secret=client_credentials_secret,
|
|
102
|
+
client_config=client_config,
|
|
103
|
+
rpc_retries=rpc_retries,
|
|
104
|
+
http_proxy_url=http_proxy_url,
|
|
105
|
+
)
|
|
324
106
|
|
|
107
|
+
raise InitializationError(
|
|
108
|
+
"MissingEndpointOrApiKeyError", "user", "Either endpoint or api_key must be provided to initialize the client."
|
|
109
|
+
)
|
|
325
110
|
|
|
326
|
-
|
|
111
|
+
|
|
112
|
+
@syncify
|
|
327
113
|
async def init(
|
|
328
114
|
org: str | None = None,
|
|
329
115
|
project: str | None = None,
|
|
@@ -345,6 +131,7 @@ async def init(
|
|
|
345
131
|
rpc_retries: int = 3,
|
|
346
132
|
http_proxy_url: str | None = None,
|
|
347
133
|
storage: Storage | None = None,
|
|
134
|
+
batch_size: int = 1000,
|
|
348
135
|
) -> None:
|
|
349
136
|
"""
|
|
350
137
|
Initialize the Flyte system with the given configuration. This method should be called before any other Flyte
|
|
@@ -353,14 +140,13 @@ async def init(
|
|
|
353
140
|
:param project: Optional project name (not used in this implementation)
|
|
354
141
|
:param domain: Optional domain name (not used in this implementation)
|
|
355
142
|
:param root_dir: Optional root directory from which to determine how to load files, and find paths to files.
|
|
143
|
+
This is useful for determining the root directory for the current project, and for locating files like config etc.
|
|
144
|
+
also use to determine all the code that needs to be copied to the remote location.
|
|
356
145
|
defaults to the editable install directory if the cwd is in a Python editable install, else just the cwd.
|
|
357
146
|
:param log_level: Optional logging level for the logger, default is set using the default initialization policies
|
|
358
147
|
:param api_key: Optional API key for authentication
|
|
359
148
|
:param endpoint: Optional API endpoint URL
|
|
360
149
|
:param headless: Optional Whether to run in headless mode
|
|
361
|
-
:param mode: Optional execution model (local, remote). Default is local. When local is used,
|
|
362
|
-
the execution will be done locally. When remote is used, the execution will be sent to a remote server,
|
|
363
|
-
In the remote case, the endpoint or api_key must be set.
|
|
364
150
|
:param insecure_skip_verify: Whether to skip SSL certificate verification
|
|
365
151
|
:param auth_client_config: Optional client configuration for authentication
|
|
366
152
|
:param auth_type: The authentication type to use (Pkce, ClientSecret, ExternalCommand, DeviceFlow)
|
|
@@ -378,10 +164,12 @@ async def init(
|
|
|
378
164
|
:param insecure: insecure flag for the client
|
|
379
165
|
:param storage: Optional blob store (S3, GCS, Azure) configuration if needed to access (i.e. using Minio)
|
|
380
166
|
:param org: Optional organization override for the client. Should be set by auth instead.
|
|
167
|
+
:param batch_size: Optional batch size for operations that use listings, defaults to 1000, so limit larger than
|
|
168
|
+
batch_size will be split into multiple requests.
|
|
381
169
|
|
|
382
170
|
:return: None
|
|
383
171
|
"""
|
|
384
|
-
from flyte._utils import get_cwd_editable_install
|
|
172
|
+
from flyte._utils import get_cwd_editable_install, org_from_endpoint, sanitize_endpoint
|
|
385
173
|
|
|
386
174
|
interactive_mode = ipython_check()
|
|
387
175
|
|
|
@@ -391,6 +179,8 @@ async def init(
|
|
|
391
179
|
|
|
392
180
|
global _init_config # noqa: PLW0603
|
|
393
181
|
|
|
182
|
+
endpoint = sanitize_endpoint(endpoint)
|
|
183
|
+
|
|
394
184
|
with _init_lock:
|
|
395
185
|
client = None
|
|
396
186
|
if endpoint or api_key:
|
|
@@ -418,10 +208,67 @@ async def init(
|
|
|
418
208
|
domain=domain,
|
|
419
209
|
client=client,
|
|
420
210
|
storage=storage,
|
|
421
|
-
org=org,
|
|
211
|
+
org=org or org_from_endpoint(endpoint),
|
|
212
|
+
batch_size=batch_size,
|
|
422
213
|
)
|
|
423
214
|
|
|
424
215
|
|
|
216
|
+
@syncify
async def init_from_config(
    path_or_config: str | Config | None = None, root_dir: Path | None = None, log_level: int | None = None
) -> None:
    """
    Initialize the Flyte system using a configuration file or Config object. This method should be called before any
    other Flyte remote API methods are called. Thread-safe implementation.

    :param path_or_config: Path to the configuration file or Config object. When a path is given together with
      root_dir, the path is joined onto root_dir before it is checked and loaded; otherwise it is used as given.
    :param root_dir: Optional root directory from which to determine how to load files, and find paths to
      files like config etc. For example if one uses the copy-style=="all", it is essential to determine the
      root directory for the current project. If not provided, it defaults to the editable install directory or
      if not available, the current working directory.
    :param log_level: Optional logging level for the framework logger,
      default is set using the default initialization policies
    :raises InitializationError: If a configuration file path is given and no file exists at the location that
      would be loaded.
    :return: None
    """
    import flyte.config as config

    cfg: config.Config
    if path_or_config is None or isinstance(path_or_config, str):
        # If a string is passed, treat it as a path to the config file
        config_path: str | None = path_or_config
        if path_or_config:
            if root_dir:
                # Resolve against root_dir *before* the existence check so that the file being validated is
                # the same file that is loaded below. Joining an absolute path leaves it unchanged.
                config_path = str(root_dir / path_or_config)
            if not Path(config_path).exists():
                raise InitializationError(
                    "ConfigFileNotFoundError",
                    "user",
                    f"Configuration file '{config_path}' does not exist., current working directory is {Path.cwd()}",
                )
        # A None or empty path is handed to config.auto unchanged.
        cfg = config.auto(config_path)
    else:
        # If a Config object is passed, use it directly
        cfg = path_or_config

    logger.debug(f"Flyte config initialized as {cfg}")
    # Forward the task and platform sections of the config to init().
    await init.aio(
        org=cfg.task.org,
        project=cfg.task.project,
        domain=cfg.task.domain,
        endpoint=cfg.platform.endpoint,
        insecure=cfg.platform.insecure,
        insecure_skip_verify=cfg.platform.insecure_skip_verify,
        ca_cert_file_path=cfg.platform.ca_cert_file_path,
        auth_type=cfg.platform.auth_mode,
        command=cfg.platform.command,
        proxy_command=cfg.platform.proxy_command,
        client_id=cfg.platform.client_id,
        client_credentials_secret=cfg.platform.client_credentials_secret,
        root_dir=root_dir,
        log_level=log_level,
    )
|
|
270
|
+
|
|
271
|
+
|
|
425
272
|
def _get_init_config() -> Optional[_InitConfig]:
|
|
426
273
|
"""
|
|
427
274
|
Get the current initialization configuration. Thread-safe implementation.
|
|
@@ -449,7 +296,7 @@ def get_common_config() -> CommonInit:
|
|
|
449
296
|
return cfg
|
|
450
297
|
|
|
451
298
|
|
|
452
|
-
def get_storage() -> Storage:
|
|
299
|
+
def get_storage() -> Storage | None:
|
|
453
300
|
"""
|
|
454
301
|
Get the current storage configuration. Thread-safe implementation.
|
|
455
302
|
|
|
@@ -463,9 +310,6 @@ def get_storage() -> Storage:
|
|
|
463
310
|
"Configuration has not been initialized. Call flyte.init() with a valid endpoint or",
|
|
464
311
|
" api-key before using this function.",
|
|
465
312
|
)
|
|
466
|
-
if cfg.storage is None:
|
|
467
|
-
# return default local storage
|
|
468
|
-
return typing.cast(Storage, cfg.replace(storage=Storage()).storage)
|
|
469
313
|
return cfg.storage
|
|
470
314
|
|
|
471
315
|
|
|
@@ -495,19 +339,33 @@ def is_initialized() -> bool:
|
|
|
495
339
|
return _get_init_config() is not None
|
|
496
340
|
|
|
497
341
|
|
|
498
|
-
def initialize_in_cluster() -> None:
    """
    Initialize the system for in-cluster execution.

    This calls init() without any arguments, so every setting is left at the default that init() defines.

    :return: None
    """
    init()
|
|
505
349
|
|
|
506
350
|
|
|
507
351
|
# Define a generic type variable for the decorated function
|
|
508
352
|
T = TypeVar("T", bound=Callable)
|
|
509
353
|
|
|
510
354
|
|
|
355
|
+
def ensure_client() -> None:
    """
    Ensure that the client is initialized. If not, raise an InitializationError.
    This function is used to check if the client is initialized before executing any Flyte remote API methods.

    :raises InitializationError: If no initialization config is available or the config has no client.
    :return: None
    """
    # Fetch the config once so the None check and the client check look at the same object, the same way
    # requires_client does; fetching it twice could observe two different configs.
    init_config = _get_init_config()
    if init_config is None or init_config.client is None:
        raise InitializationError(
            "ClientNotInitializedError",
            "user",
            "Client has not been initialized. Call flyte.init() with a valid endpoint"
            " or api-key before using this function.",
        )
|
|
367
|
+
|
|
368
|
+
|
|
511
369
|
def requires_client(func: T) -> T:
|
|
512
370
|
"""
|
|
513
371
|
Decorator that checks if the client has been initialized before executing the function.
|
|
@@ -518,7 +376,7 @@ def requires_client(func: T) -> T:
|
|
|
518
376
|
"""
|
|
519
377
|
|
|
520
378
|
@functools.wraps(func)
|
|
521
|
-
def wrapper(*args, **kwargs) -> T:
|
|
379
|
+
async def wrapper(*args, **kwargs) -> T:
|
|
522
380
|
init_config = _get_init_config()
|
|
523
381
|
if init_config is None or init_config.client is None:
|
|
524
382
|
raise InitializationError(
|
flyte/_interface.py
CHANGED
|
@@ -62,12 +62,12 @@ def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Di
|
|
|
62
62
|
# Options 1 and 2
|
|
63
63
|
bases = return_annotation.__bases__ # type: ignore
|
|
64
64
|
if len(bases) == 1 and bases[0] is tuple and hasattr(return_annotation, "_fields"):
|
|
65
|
-
|
|
65
|
+
# Task returns named tuple
|
|
66
66
|
return dict(get_type_hints(cast(Type, return_annotation), include_extras=True))
|
|
67
67
|
|
|
68
68
|
if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple: # type: ignore
|
|
69
69
|
# Handle option 3
|
|
70
|
-
|
|
70
|
+
# Task returns unnamed typing.Tuple
|
|
71
71
|
if len(return_annotation.__args__) == 1: # type: ignore
|
|
72
72
|
raise TypeError("Tuples should be used to indicate multiple return values, found only one return variable.")
|
|
73
73
|
ra = get_args(return_annotation)
|
|
@@ -80,5 +80,5 @@ def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Di
|
|
|
80
80
|
|
|
81
81
|
else:
|
|
82
82
|
# Handle all other single return types
|
|
83
|
-
|
|
83
|
+
# Task returns unnamed native tuple
|
|
84
84
|
return {default_output_name(): cast(Type, return_annotation)}
|
|
@@ -1,8 +1,9 @@
|
|
|
1
|
+
import concurrent.futures
|
|
1
2
|
import threading
|
|
2
|
-
from typing import Any, Literal, Optional, Protocol, Tuple, TypeVar
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, Tuple, TypeVar
|
|
3
4
|
|
|
4
|
-
from flyte._datastructures import ActionID, NativeInterface
|
|
5
5
|
from flyte._task import TaskTemplate
|
|
6
|
+
from flyte.models import ActionID, NativeInterface
|
|
6
7
|
|
|
7
8
|
from ._trace import TraceInfo
|
|
8
9
|
|
|
@@ -10,6 +11,9 @@ __all__ = ["Controller", "ControllerType", "TraceInfo", "create_controller", "ge
|
|
|
10
11
|
|
|
11
12
|
from ..._protos.workflow import task_definition_pb2
|
|
12
13
|
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
import concurrent.futures
|
|
16
|
+
|
|
13
17
|
ControllerType = Literal["local", "remote"]
|
|
14
18
|
|
|
15
19
|
R = TypeVar("R")
|
|
@@ -28,6 +32,15 @@ class Controller(Protocol):
|
|
|
28
32
|
"""
|
|
29
33
|
...
|
|
30
34
|
|
|
35
|
+
def submit_sync(self, _task: TaskTemplate, *args, **kwargs) -> concurrent.futures.Future:
    """
    Synchronous counterpart of the async submit method.

    Implementations are expected to delegate to the async submit method and hand back a
    concurrent.futures.Future, which the caller can either block on or wrap in an async future.

    This entry point is used when:
      a) a synchronous task is kicked off locally, or
      b) a running task (of either kind) kicks off a downstream synchronous task.

    :param _task: The task template to submit
    :param args: Arguments
    :param kwargs: Keyword arguments
    :return: A concurrent Future for the submitted task
    """
    ...
|
|
43
|
+
|
|
31
44
|
async def submit_task_ref(self, _task: task_definition_pb2.TaskDetails, *args, **kwargs) -> Any:
|
|
32
45
|
"""
|
|
33
46
|
Submit a task reference to the controller asynchronously and wait for the result. This is async and will block
|
|
@@ -47,12 +60,12 @@ class Controller(Protocol):
|
|
|
47
60
|
async def watch_for_errors(self): ...
|
|
48
61
|
|
|
49
62
|
async def get_action_outputs(
|
|
50
|
-
self, _interface: NativeInterface,
|
|
63
|
+
self, _interface: NativeInterface, _func: Callable, *args, **kwargs
|
|
51
64
|
) -> Tuple[TraceInfo, bool]:
|
|
52
65
|
"""
|
|
53
66
|
This method returns the outputs of the action, if it is available.
|
|
54
67
|
:param _interface: NativeInterface
|
|
55
|
-
:param
|
|
68
|
+
:param _func: Function name
|
|
56
69
|
:param args: Arguments
|
|
57
70
|
:param kwargs: Keyword arguments
|
|
58
71
|
:return: TraceInfo object and a boolean indicating if the action was found.
|
|
@@ -81,13 +94,13 @@ class _ControllerState:
|
|
|
81
94
|
lock = threading.Lock()
|
|
82
95
|
|
|
83
96
|
|
|
84
|
-
|
|
97
|
+
def get_controller() -> Controller:
    """
    Return the controller currently stored on _ControllerState.

    :raises RuntimeError: If no controller has been created yet.
    :return: The stored controller instance
    """
    controller = _ControllerState.controller
    if controller is None:
        raise RuntimeError("Controller is not initialized. Please call create_controller() first.")
    return controller
|
|
91
104
|
|
|
92
105
|
|
|
93
106
|
def create_controller(
|