flyte-0.0.1b3-py3-none-any.whl → flyte-0.2.0a0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte has been flagged as potentially problematic; see this release's page on the package registry for details.
- flyte/__init__.py +20 -4
- flyte/_bin/runtime.py +33 -7
- flyte/_build.py +3 -2
- flyte/_cache/cache.py +1 -2
- flyte/_code_bundle/_packaging.py +1 -1
- flyte/_code_bundle/_utils.py +0 -16
- flyte/_code_bundle/bundle.py +43 -12
- flyte/_context.py +8 -2
- flyte/_deploy.py +56 -15
- flyte/_environment.py +45 -4
- flyte/_excepthook.py +37 -0
- flyte/_group.py +2 -1
- flyte/_image.py +8 -4
- flyte/_initialize.py +112 -254
- flyte/_interface.py +3 -3
- flyte/_internal/controllers/__init__.py +19 -6
- flyte/_internal/controllers/_local_controller.py +83 -8
- flyte/_internal/controllers/_trace.py +2 -1
- flyte/_internal/controllers/remote/__init__.py +27 -7
- flyte/_internal/controllers/remote/_action.py +7 -2
- flyte/_internal/controllers/remote/_client.py +5 -1
- flyte/_internal/controllers/remote/_controller.py +159 -26
- flyte/_internal/controllers/remote/_core.py +13 -5
- flyte/_internal/controllers/remote/_informer.py +4 -4
- flyte/_internal/controllers/remote/_service_protocol.py +6 -6
- flyte/_internal/imagebuild/docker_builder.py +12 -1
- flyte/_internal/imagebuild/image_builder.py +16 -11
- flyte/_internal/runtime/convert.py +164 -21
- flyte/_internal/runtime/entrypoints.py +1 -1
- flyte/_internal/runtime/io.py +3 -3
- flyte/_internal/runtime/task_serde.py +140 -20
- flyte/_internal/runtime/taskrunner.py +4 -3
- flyte/_internal/runtime/types_serde.py +1 -1
- flyte/_logging.py +12 -1
- flyte/_map.py +215 -0
- flyte/_pod.py +19 -0
- flyte/_protos/common/list_pb2.py +3 -3
- flyte/_protos/common/list_pb2.pyi +2 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +28 -24
- flyte/_protos/logs/dataplane/payload_pb2.pyi +11 -2
- flyte/_protos/workflow/common_pb2.py +27 -0
- flyte/_protos/workflow/common_pb2.pyi +14 -0
- flyte/_protos/workflow/environment_pb2.py +29 -0
- flyte/_protos/workflow/environment_pb2.pyi +12 -0
- flyte/_protos/workflow/queue_service_pb2.py +40 -41
- flyte/_protos/workflow/queue_service_pb2.pyi +35 -30
- flyte/_protos/workflow/queue_service_pb2_grpc.py +15 -15
- flyte/_protos/workflow/run_definition_pb2.py +61 -61
- flyte/_protos/workflow/run_definition_pb2.pyi +8 -4
- flyte/_protos/workflow/run_service_pb2.py +20 -24
- flyte/_protos/workflow/run_service_pb2.pyi +2 -6
- flyte/_protos/workflow/state_service_pb2.py +36 -28
- flyte/_protos/workflow/state_service_pb2.pyi +19 -15
- flyte/_protos/workflow/state_service_pb2_grpc.py +28 -28
- flyte/_protos/workflow/task_definition_pb2.py +29 -22
- flyte/_protos/workflow/task_definition_pb2.pyi +21 -5
- flyte/_protos/workflow/task_service_pb2.py +27 -11
- flyte/_protos/workflow/task_service_pb2.pyi +29 -1
- flyte/_protos/workflow/task_service_pb2_grpc.py +34 -0
- flyte/_run.py +166 -95
- flyte/_task.py +110 -28
- flyte/_task_environment.py +55 -72
- flyte/_trace.py +6 -14
- flyte/_utils/__init__.py +6 -0
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +0 -2
- flyte/_utils/helpers.py +45 -19
- flyte/_utils/org_discovery.py +57 -0
- flyte/_version.py +2 -2
- flyte/cli/__init__.py +3 -0
- flyte/cli/_abort.py +28 -0
- flyte/{_cli → cli}/_common.py +73 -23
- flyte/cli/_create.py +145 -0
- flyte/{_cli → cli}/_delete.py +4 -4
- flyte/{_cli → cli}/_deploy.py +26 -14
- flyte/cli/_gen.py +163 -0
- flyte/{_cli → cli}/_get.py +98 -23
- {union/_cli → flyte/cli}/_params.py +106 -147
- flyte/{_cli → cli}/_run.py +99 -20
- flyte/cli/main.py +166 -0
- flyte/config/__init__.py +3 -0
- flyte/config/_config.py +216 -0
- flyte/config/_internal.py +64 -0
- flyte/config/_reader.py +207 -0
- flyte/errors.py +29 -0
- flyte/extras/_container.py +33 -43
- flyte/io/__init__.py +17 -1
- flyte/io/_dir.py +2 -2
- flyte/io/_file.py +3 -4
- flyte/io/{structured_dataset → _structured_dataset}/basic_dfs.py +1 -1
- flyte/io/{structured_dataset → _structured_dataset}/structured_dataset.py +1 -1
- flyte/{_datastructures.py → models.py} +56 -7
- flyte/remote/__init__.py +2 -1
- flyte/remote/_client/_protocols.py +2 -0
- flyte/remote/_client/auth/_auth_utils.py +14 -0
- flyte/remote/_client/auth/_channel.py +34 -3
- flyte/remote/_client/auth/_token_client.py +3 -3
- flyte/remote/_client/controlplane.py +13 -13
- flyte/remote/_console.py +1 -1
- flyte/remote/_data.py +10 -6
- flyte/remote/_logs.py +89 -29
- flyte/remote/_project.py +8 -9
- flyte/remote/_run.py +228 -131
- flyte/remote/_secret.py +12 -12
- flyte/remote/_task.py +179 -15
- flyte/report/_report.py +4 -4
- flyte/storage/__init__.py +5 -0
- flyte/storage/_config.py +233 -0
- flyte/storage/_storage.py +23 -3
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +371 -0
- flyte/types/__init__.py +23 -0
- flyte/types/_interface.py +22 -7
- flyte/{io/pickle/transformer.py → types/_pickle.py} +2 -1
- flyte/types/_type_engine.py +95 -18
- flyte-0.2.0a0.dist-info/METADATA +249 -0
- flyte-0.2.0a0.dist-info/RECORD +218 -0
- {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/entry_points.txt +1 -1
- flyte/_api_commons.py +0 -3
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_create.py +0 -42
- flyte/_cli/main.py +0 -72
- flyte/_internal/controllers/pbhash.py +0 -39
- flyte/io/_dataframe.py +0 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte-0.0.1b3.dist-info/METADATA +0 -179
- flyte-0.0.1b3.dist-info/RECORD +0 -390
- union/__init__.py +0 -54
- union/_api_commons.py +0 -3
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +0 -113
- union/_build.py +0 -25
- union/_cache/__init__.py +0 -12
- union/_cache/cache.py +0 -141
- union/_cache/defaults.py +0 -9
- union/_cache/policy_function_body.py +0 -42
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +0 -263
- union/_cli/_create.py +0 -40
- union/_cli/_delete.py +0 -23
- union/_cli/_deploy.py +0 -120
- union/_cli/_get.py +0 -162
- union/_cli/_run.py +0 -150
- union/_cli/main.py +0 -72
- union/_code_bundle/__init__.py +0 -8
- union/_code_bundle/_ignore.py +0 -113
- union/_code_bundle/_packaging.py +0 -187
- union/_code_bundle/_utils.py +0 -342
- union/_code_bundle/bundle.py +0 -176
- union/_context.py +0 -146
- union/_datastructures.py +0 -295
- union/_deploy.py +0 -185
- union/_doc.py +0 -29
- union/_docstring.py +0 -26
- union/_environment.py +0 -43
- union/_group.py +0 -31
- union/_hash.py +0 -23
- union/_image.py +0 -760
- union/_initialize.py +0 -585
- union/_interface.py +0 -84
- union/_internal/__init__.py +0 -3
- union/_internal/controllers/__init__.py +0 -77
- union/_internal/controllers/_local_controller.py +0 -77
- union/_internal/controllers/pbhash.py +0 -39
- union/_internal/controllers/remote/__init__.py +0 -40
- union/_internal/controllers/remote/_action.py +0 -131
- union/_internal/controllers/remote/_client.py +0 -43
- union/_internal/controllers/remote/_controller.py +0 -169
- union/_internal/controllers/remote/_core.py +0 -341
- union/_internal/controllers/remote/_informer.py +0 -260
- union/_internal/controllers/remote/_service_protocol.py +0 -44
- union/_internal/imagebuild/__init__.py +0 -11
- union/_internal/imagebuild/docker_builder.py +0 -416
- union/_internal/imagebuild/image_builder.py +0 -243
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +0 -31
- union/_internal/resolvers/common.py +0 -24
- union/_internal/resolvers/default.py +0 -27
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +0 -163
- union/_internal/runtime/entrypoints.py +0 -121
- union/_internal/runtime/io.py +0 -136
- union/_internal/runtime/resources_serde.py +0 -134
- union/_internal/runtime/task_serde.py +0 -202
- union/_internal/runtime/taskrunner.py +0 -179
- union/_internal/runtime/types_serde.py +0 -53
- union/_logging.py +0 -124
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +0 -66
- union/_protos/common/authorization_pb2.pyi +0 -106
- union/_protos/common/identifier_pb2.py +0 -71
- union/_protos/common/identifier_pb2.pyi +0 -82
- union/_protos/common/identity_pb2.py +0 -48
- union/_protos/common/identity_pb2.pyi +0 -72
- union/_protos/common/identity_pb2_grpc.py +0 -4
- union/_protos/common/list_pb2.py +0 -36
- union/_protos/common/list_pb2.pyi +0 -69
- union/_protos/common/list_pb2_grpc.py +0 -4
- union/_protos/common/policy_pb2.py +0 -37
- union/_protos/common/policy_pb2.pyi +0 -27
- union/_protos/common/policy_pb2_grpc.py +0 -4
- union/_protos/common/role_pb2.py +0 -37
- union/_protos/common/role_pb2.pyi +0 -51
- union/_protos/common/role_pb2_grpc.py +0 -4
- union/_protos/common/runtime_version_pb2.py +0 -28
- union/_protos/common/runtime_version_pb2.pyi +0 -24
- union/_protos/common/runtime_version_pb2_grpc.py +0 -4
- union/_protos/logs/dataplane/payload_pb2.py +0 -96
- union/_protos/logs/dataplane/payload_pb2.pyi +0 -168
- union/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- union/_protos/secret/definition_pb2.py +0 -49
- union/_protos/secret/definition_pb2.pyi +0 -93
- union/_protos/secret/definition_pb2_grpc.py +0 -4
- union/_protos/secret/payload_pb2.py +0 -62
- union/_protos/secret/payload_pb2.pyi +0 -94
- union/_protos/secret/payload_pb2_grpc.py +0 -4
- union/_protos/secret/secret_pb2.py +0 -38
- union/_protos/secret/secret_pb2.pyi +0 -6
- union/_protos/secret/secret_pb2_grpc.py +0 -198
- union/_protos/validate/validate/validate_pb2.py +0 -76
- union/_protos/workflow/node_execution_service_pb2.py +0 -26
- union/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- union/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- union/_protos/workflow/queue_service_pb2.py +0 -75
- union/_protos/workflow/queue_service_pb2.pyi +0 -103
- union/_protos/workflow/queue_service_pb2_grpc.py +0 -172
- union/_protos/workflow/run_definition_pb2.py +0 -100
- union/_protos/workflow/run_definition_pb2.pyi +0 -256
- union/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- union/_protos/workflow/run_logs_service_pb2.py +0 -41
- union/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- union/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- union/_protos/workflow/run_service_pb2.py +0 -133
- union/_protos/workflow/run_service_pb2.pyi +0 -173
- union/_protos/workflow/run_service_pb2_grpc.py +0 -412
- union/_protos/workflow/state_service_pb2.py +0 -58
- union/_protos/workflow/state_service_pb2.pyi +0 -69
- union/_protos/workflow/state_service_pb2_grpc.py +0 -138
- union/_protos/workflow/task_definition_pb2.py +0 -72
- union/_protos/workflow/task_definition_pb2.pyi +0 -65
- union/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- union/_protos/workflow/task_service_pb2.py +0 -44
- union/_protos/workflow/task_service_pb2.pyi +0 -31
- union/_protos/workflow/task_service_pb2_grpc.py +0 -104
- union/_resources.py +0 -226
- union/_retry.py +0 -32
- union/_reusable_environment.py +0 -25
- union/_run.py +0 -374
- union/_secret.py +0 -61
- union/_task.py +0 -354
- union/_task_environment.py +0 -186
- union/_timeout.py +0 -47
- union/_tools.py +0 -27
- union/_utils/__init__.py +0 -11
- union/_utils/asyn.py +0 -119
- union/_utils/file_handling.py +0 -71
- union/_utils/helpers.py +0 -46
- union/_utils/lazy_module.py +0 -54
- union/_utils/uv_script_parser.py +0 -49
- union/_version.py +0 -21
- union/connectors/__init__.py +0 -0
- union/errors.py +0 -128
- union/extras/__init__.py +0 -5
- union/extras/_container.py +0 -263
- union/io/__init__.py +0 -11
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +0 -425
- union/io/_file.py +0 -418
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +0 -117
- union/io/structured_dataset/__init__.py +0 -122
- union/io/structured_dataset/basic_dfs.py +0 -219
- union/io/structured_dataset/structured_dataset.py +0 -1057
- union/py.typed +0 -0
- union/remote/__init__.py +0 -23
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +0 -129
- union/remote/_client/auth/__init__.py +0 -12
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +0 -391
- union/remote/_client/auth/_authenticators/client_credentials.py +0 -73
- union/remote/_client/auth/_authenticators/device_code.py +0 -120
- union/remote/_client/auth/_authenticators/external_command.py +0 -77
- union/remote/_client/auth/_authenticators/factory.py +0 -200
- union/remote/_client/auth/_authenticators/pkce.py +0 -515
- union/remote/_client/auth/_channel.py +0 -184
- union/remote/_client/auth/_client_config.py +0 -83
- union/remote/_client/auth/_default_html.py +0 -32
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +0 -204
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +0 -144
- union/remote/_client/auth/_keyring.py +0 -154
- union/remote/_client/auth/_token_client.py +0 -258
- union/remote/_client/auth/errors.py +0 -16
- union/remote/_client/controlplane.py +0 -86
- union/remote/_data.py +0 -149
- union/remote/_logs.py +0 -74
- union/remote/_project.py +0 -86
- union/remote/_run.py +0 -820
- union/remote/_secret.py +0 -132
- union/remote/_task.py +0 -193
- union/report/__init__.py +0 -3
- union/report/_report.py +0 -178
- union/report/_template.html +0 -124
- union/storage/__init__.py +0 -24
- union/storage/_remote_fs.py +0 -34
- union/storage/_storage.py +0 -247
- union/storage/_utils.py +0 -5
- union/types/__init__.py +0 -11
- union/types/_renderer.py +0 -162
- union/types/_string_literals.py +0 -120
- union/types/_type_engine.py +0 -2131
- union/types/_utils.py +0 -80
- /union/_protos/common/authorization_pb2_grpc.py → /flyte/_protos/workflow/common_pb2_grpc.py +0 -0
- /union/_protos/common/identifier_pb2_grpc.py → /flyte/_protos/workflow/environment_pb2_grpc.py +0 -0
- /flyte/io/{structured_dataset → _structured_dataset}/__init__.py +0 -0
- {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/WHEEL +0 -0
- {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/top_level.txt +0 -0
flyte/remote/_secret.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
from dataclasses import dataclass
|
|
4
|
-
from typing import
|
|
4
|
+
from typing import AsyncIterator, Literal, Union
|
|
5
5
|
|
|
6
6
|
import rich.repr
|
|
7
7
|
|
|
8
|
-
from flyte.
|
|
9
|
-
from flyte._initialize import get_client, get_common_config, requires_client
|
|
8
|
+
from flyte._initialize import ensure_client, get_client, get_common_config
|
|
10
9
|
from flyte._protos.secret import definition_pb2, payload_pb2
|
|
10
|
+
from flyte.syncify import syncify
|
|
11
11
|
|
|
12
12
|
SecretTypes = Literal["regular", "image_pull"]
|
|
13
13
|
|
|
@@ -16,10 +16,10 @@ SecretTypes = Literal["regular", "image_pull"]
|
|
|
16
16
|
class Secret:
|
|
17
17
|
pb2: definition_pb2.Secret
|
|
18
18
|
|
|
19
|
+
@syncify
|
|
19
20
|
@classmethod
|
|
20
|
-
@requires_client
|
|
21
|
-
@syncer.wrap
|
|
22
21
|
async def create(cls, name: str, value: Union[str, bytes], type: SecretTypes = "regular"):
|
|
22
|
+
ensure_client()
|
|
23
23
|
cfg = get_common_config()
|
|
24
24
|
secret_type = (
|
|
25
25
|
definition_pb2.SecretType.SECRET_TYPE_GENERIC
|
|
@@ -49,10 +49,10 @@ class Secret:
|
|
|
49
49
|
),
|
|
50
50
|
)
|
|
51
51
|
|
|
52
|
+
@syncify
|
|
52
53
|
@classmethod
|
|
53
|
-
@requires_client
|
|
54
|
-
@syncer.wrap
|
|
55
54
|
async def get(cls, name: str) -> Secret:
|
|
55
|
+
ensure_client()
|
|
56
56
|
cfg = get_common_config()
|
|
57
57
|
resp = await get_client().secrets_service.GetSecret(
|
|
58
58
|
request=payload_pb2.GetSecretRequest(
|
|
@@ -66,10 +66,10 @@ class Secret:
|
|
|
66
66
|
)
|
|
67
67
|
return Secret(pb2=resp.secret)
|
|
68
68
|
|
|
69
|
+
@syncify
|
|
69
70
|
@classmethod
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
async def listall(cls, limit: int = 100) -> Union[Iterator[Secret], AsyncGenerator[Secret, None]]:
|
|
71
|
+
async def listall(cls, limit: int = 100) -> AsyncIterator[Secret]:
|
|
72
|
+
ensure_client()
|
|
73
73
|
cfg = get_common_config()
|
|
74
74
|
token = None
|
|
75
75
|
while True:
|
|
@@ -88,10 +88,10 @@ class Secret:
|
|
|
88
88
|
if not token:
|
|
89
89
|
break
|
|
90
90
|
|
|
91
|
+
@syncify
|
|
91
92
|
@classmethod
|
|
92
|
-
@requires_client
|
|
93
|
-
@syncer.wrap
|
|
94
93
|
async def delete(cls, name):
|
|
94
|
+
ensure_client()
|
|
95
95
|
cfg = get_common_config()
|
|
96
96
|
await get_client().secrets_service.DeleteSecret( # type: ignore
|
|
97
97
|
request=payload_pb2.DeleteSecretRequest(
|
flyte/remote/_task.py
CHANGED
|
@@ -3,17 +3,34 @@ from __future__ import annotations
|
|
|
3
3
|
import functools
|
|
4
4
|
from dataclasses import dataclass
|
|
5
5
|
from threading import Lock
|
|
6
|
-
from typing import Any, Callable, Coroutine, Dict, Literal, Optional, Union
|
|
6
|
+
from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Iterator, Literal, Optional, Tuple, Union
|
|
7
7
|
|
|
8
8
|
import rich.repr
|
|
9
|
+
from google.protobuf import timestamp
|
|
9
10
|
|
|
10
11
|
import flyte
|
|
11
12
|
import flyte.errors
|
|
12
|
-
from flyte._api_commons import syncer
|
|
13
13
|
from flyte._context import internal_ctx
|
|
14
|
-
from flyte.
|
|
15
|
-
from flyte.
|
|
14
|
+
from flyte._initialize import ensure_client, get_client, get_common_config
|
|
15
|
+
from flyte._logging import logger
|
|
16
|
+
from flyte._protos.common import list_pb2
|
|
16
17
|
from flyte._protos.workflow import task_definition_pb2, task_service_pb2
|
|
18
|
+
from flyte.models import NativeInterface
|
|
19
|
+
from flyte.syncify import syncify
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _repr_task_metadata(metadata: task_definition_pb2.TaskMetadata) -> rich.repr.Result:
|
|
23
|
+
"""
|
|
24
|
+
Rich representation of the task metadata.
|
|
25
|
+
"""
|
|
26
|
+
if metadata.deployed_by:
|
|
27
|
+
if metadata.deployed_by.user:
|
|
28
|
+
yield "deployed_by", f"User: {metadata.deployed_by.user.spec.email}"
|
|
29
|
+
else:
|
|
30
|
+
yield "deployed_by", f"App: {metadata.deployed_by.application.spec.name}"
|
|
31
|
+
yield "short_name", metadata.short_name
|
|
32
|
+
yield "deployed_at", timestamp.to_datetime(metadata.deployed_at)
|
|
33
|
+
yield "environment_name", metadata.environment_name
|
|
17
34
|
|
|
18
35
|
|
|
19
36
|
class LazyEntity:
|
|
@@ -22,8 +39,8 @@ class LazyEntity:
|
|
|
22
39
|
The entity is derived from RemoteEntity so that it behaves exactly like the mimicked entity.
|
|
23
40
|
"""
|
|
24
41
|
|
|
25
|
-
def __init__(self, name: str, getter: Callable[..., Coroutine[Any, Any,
|
|
26
|
-
self._task: Optional[
|
|
42
|
+
def __init__(self, name: str, getter: Callable[..., Coroutine[Any, Any, TaskDetails]], *args, **kwargs):
|
|
43
|
+
self._task: Optional[TaskDetails] = None
|
|
27
44
|
self._getter = getter
|
|
28
45
|
self._name = name
|
|
29
46
|
self._mutex = Lock()
|
|
@@ -32,8 +49,8 @@ class LazyEntity:
|
|
|
32
49
|
def name(self) -> str:
|
|
33
50
|
return self._name
|
|
34
51
|
|
|
35
|
-
@
|
|
36
|
-
async def fetch(self) ->
|
|
52
|
+
@syncify
|
|
53
|
+
async def fetch(self) -> TaskDetails:
|
|
37
54
|
"""
|
|
38
55
|
Forwards all other attributes to task, causing the task to be fetched!
|
|
39
56
|
"""
|
|
@@ -48,7 +65,7 @@ class LazyEntity:
|
|
|
48
65
|
"""
|
|
49
66
|
Forwards the call to the underlying task. The entity will be fetched if not already present
|
|
50
67
|
"""
|
|
51
|
-
tk = await self.fetch.aio(
|
|
68
|
+
tk = await self.fetch.aio()
|
|
52
69
|
return await tk(*args, **kwargs)
|
|
53
70
|
|
|
54
71
|
def __repr__(self) -> str:
|
|
@@ -62,7 +79,7 @@ AutoVersioning = Literal["latest", "current"]
|
|
|
62
79
|
|
|
63
80
|
|
|
64
81
|
@dataclass
|
|
65
|
-
class
|
|
82
|
+
class TaskDetails:
|
|
66
83
|
pb2: task_definition_pb2.TaskDetails
|
|
67
84
|
|
|
68
85
|
@classmethod
|
|
@@ -87,10 +104,19 @@ class Task:
|
|
|
87
104
|
if version is None and auto_version not in ["latest", "current"]:
|
|
88
105
|
raise ValueError("auto_version must be either 'latest' or 'current'.")
|
|
89
106
|
|
|
90
|
-
async def deferred_get(_version: str | None, _auto_version: AutoVersioning | None) ->
|
|
107
|
+
async def deferred_get(_version: str | None, _auto_version: AutoVersioning | None) -> TaskDetails:
|
|
91
108
|
if _version is None:
|
|
92
109
|
if _auto_version == "latest":
|
|
93
|
-
|
|
110
|
+
tasks = []
|
|
111
|
+
async for x in Task.listall.aio(
|
|
112
|
+
by_task_name=name,
|
|
113
|
+
sort_by=("created_at", "desc"),
|
|
114
|
+
limit=1,
|
|
115
|
+
):
|
|
116
|
+
tasks.append(x)
|
|
117
|
+
if not tasks:
|
|
118
|
+
raise flyte.errors.ReferenceTaskError(f"Task {name} not found.")
|
|
119
|
+
_version = tasks[0].version
|
|
94
120
|
elif _auto_version == "current":
|
|
95
121
|
ctx = flyte.ctx()
|
|
96
122
|
if ctx is None:
|
|
@@ -136,6 +162,20 @@ class Task:
|
|
|
136
162
|
"""
|
|
137
163
|
return self.pb2.spec.task_template.type
|
|
138
164
|
|
|
165
|
+
@property
|
|
166
|
+
def default_input_args(self) -> Tuple[str, ...]:
|
|
167
|
+
"""
|
|
168
|
+
The default input arguments of the task.
|
|
169
|
+
"""
|
|
170
|
+
return tuple(x.name for x in self.pb2.spec.default_inputs)
|
|
171
|
+
|
|
172
|
+
@property
|
|
173
|
+
def required_args(self) -> Tuple[str, ...]:
|
|
174
|
+
"""
|
|
175
|
+
The required input arguments of the task.
|
|
176
|
+
"""
|
|
177
|
+
return tuple(x for x, _ in self.interface.inputs.items() if x not in self.default_input_args)
|
|
178
|
+
|
|
139
179
|
@functools.cached_property
|
|
140
180
|
def interface(self) -> NativeInterface:
|
|
141
181
|
"""
|
|
@@ -143,7 +183,7 @@ class Task:
|
|
|
143
183
|
"""
|
|
144
184
|
import flyte.types as types
|
|
145
185
|
|
|
146
|
-
return types.guess_interface(self.pb2.spec.task_template.interface)
|
|
186
|
+
return types.guess_interface(self.pb2.spec.task_template.interface, default_inputs=self.pb2.spec.default_inputs)
|
|
147
187
|
|
|
148
188
|
@property
|
|
149
189
|
def cache(self) -> flyte.Cache:
|
|
@@ -180,6 +220,19 @@ class Task:
|
|
|
180
220
|
"""
|
|
181
221
|
Forwards the call to the underlying task. The entity will be fetched if not already present
|
|
182
222
|
"""
|
|
223
|
+
# TODO support kwargs, for this we need ordered inputs to be stored in the task spec.
|
|
224
|
+
if len(args) > 0:
|
|
225
|
+
raise flyte.errors.ReferenceTaskError(
|
|
226
|
+
f"Reference task {self.name} does not support positional arguments"
|
|
227
|
+
f"currently. Please use keyword arguments."
|
|
228
|
+
)
|
|
229
|
+
if len(self.required_args) > 0:
|
|
230
|
+
if len(args) + len(kwargs) < len(self.required_args):
|
|
231
|
+
raise ValueError(
|
|
232
|
+
f"Task {self.name} requires at least {self.required_args} arguments, "
|
|
233
|
+
f"but only received args:{args} kwargs{kwargs}."
|
|
234
|
+
)
|
|
235
|
+
|
|
183
236
|
ctx = internal_ctx()
|
|
184
237
|
if ctx.is_task_context():
|
|
185
238
|
# If we are in a task context, that implies we are executing a Run.
|
|
@@ -187,7 +240,7 @@ class Task:
|
|
|
187
240
|
# We will also check if we are not initialized, It is not expected to be not initialized
|
|
188
241
|
from flyte._internal.controllers import get_controller
|
|
189
242
|
|
|
190
|
-
controller =
|
|
243
|
+
controller = get_controller()
|
|
191
244
|
if controller:
|
|
192
245
|
return await controller.submit_task_ref(self.pb2, *args, **kwargs)
|
|
193
246
|
raise flyte.errors
|
|
@@ -205,13 +258,18 @@ class Task:
|
|
|
205
258
|
env: Optional[Dict[str, str]] = None,
|
|
206
259
|
secrets: Optional[flyte.SecretRequest] = None,
|
|
207
260
|
**kwargs: Any,
|
|
208
|
-
) ->
|
|
261
|
+
) -> TaskDetails:
|
|
209
262
|
raise NotImplementedError
|
|
210
263
|
|
|
211
264
|
def __rich_repr__(self) -> rich.repr.Result:
|
|
212
265
|
"""
|
|
213
266
|
Rich representation of the task.
|
|
214
267
|
"""
|
|
268
|
+
yield "friendly_name", self.pb2.spec.short_name
|
|
269
|
+
yield "environment", self.pb2.spec.environment
|
|
270
|
+
yield "default_inputs_keys", self.default_input_args
|
|
271
|
+
yield "required_args", self.required_args
|
|
272
|
+
yield "raw_default_inputs", [str(x) for x in self.pb2.spec.default_inputs]
|
|
215
273
|
yield "project", self.pb2.task_id.project
|
|
216
274
|
yield "domain", self.pb2.task_id.domain
|
|
217
275
|
yield "name", self.name
|
|
@@ -223,5 +281,111 @@ class Task:
|
|
|
223
281
|
yield "resources", self.resources
|
|
224
282
|
|
|
225
283
|
|
|
284
|
+
@dataclass
|
|
285
|
+
class Task:
|
|
286
|
+
pb2: task_definition_pb2.Task
|
|
287
|
+
|
|
288
|
+
def __init__(self, pb2: task_definition_pb2.Task):
|
|
289
|
+
self.pb2 = pb2
|
|
290
|
+
|
|
291
|
+
@property
|
|
292
|
+
def name(self) -> str:
|
|
293
|
+
"""
|
|
294
|
+
The name of the task.
|
|
295
|
+
"""
|
|
296
|
+
return self.pb2.task_id.name
|
|
297
|
+
|
|
298
|
+
@property
|
|
299
|
+
def version(self) -> str:
|
|
300
|
+
"""
|
|
301
|
+
The version of the task.
|
|
302
|
+
"""
|
|
303
|
+
return self.pb2.task_id.version
|
|
304
|
+
|
|
305
|
+
@classmethod
|
|
306
|
+
def get(cls, name: str, version: str | None = None, auto_version: AutoVersioning | None = None) -> LazyEntity:
|
|
307
|
+
"""
|
|
308
|
+
Get a task by its ID or name. If both are provided, the ID will take precedence.
|
|
309
|
+
|
|
310
|
+
Either version or auto_version are required parameters.
|
|
311
|
+
|
|
312
|
+
:param name: The name of the task.
|
|
313
|
+
:param version: The version of the task.
|
|
314
|
+
:param auto_version: If set to "latest", the latest-by-time ordered from now, version of the task will be used.
|
|
315
|
+
If set to "current", the version will be derived from the callee tasks context. This is useful if you are
|
|
316
|
+
deploying all environments with the same version. If auto_version is current, you can only access the task from
|
|
317
|
+
within a task context.
|
|
318
|
+
"""
|
|
319
|
+
return TaskDetails.get(name, version=version, auto_version=auto_version)
|
|
320
|
+
|
|
321
|
+
@syncify
|
|
322
|
+
@classmethod
|
|
323
|
+
async def listall(
|
|
324
|
+
cls,
|
|
325
|
+
by_task_name: str | None = None,
|
|
326
|
+
sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
|
|
327
|
+
limit: int = 100,
|
|
328
|
+
) -> Union[AsyncIterator[Task], Iterator[Task]]:
|
|
329
|
+
"""
|
|
330
|
+
Get all runs for the current project and domain.
|
|
331
|
+
|
|
332
|
+
:param by_task_name: If provided, only tasks with this name will be returned.
|
|
333
|
+
:param sort_by: The sorting criteria for the project list, in the format (field, order).
|
|
334
|
+
:param limit: The maximum number of tasks to return.
|
|
335
|
+
:return: An iterator of runs.
|
|
336
|
+
"""
|
|
337
|
+
ensure_client()
|
|
338
|
+
token = None
|
|
339
|
+
sort_by = sort_by or ("created_at", "asc")
|
|
340
|
+
sort_pb2 = list_pb2.Sort(
|
|
341
|
+
key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
|
|
342
|
+
)
|
|
343
|
+
cfg = get_common_config()
|
|
344
|
+
filters = []
|
|
345
|
+
if by_task_name:
|
|
346
|
+
filters.append(
|
|
347
|
+
list_pb2.Filter(
|
|
348
|
+
function=list_pb2.Filter.Function.EQUAL,
|
|
349
|
+
field="name",
|
|
350
|
+
values=[by_task_name],
|
|
351
|
+
)
|
|
352
|
+
)
|
|
353
|
+
original_limit = limit
|
|
354
|
+
if limit > cfg.batch_size:
|
|
355
|
+
limit = cfg.batch_size
|
|
356
|
+
retrieved = 0
|
|
357
|
+
while True:
|
|
358
|
+
resp = await get_client().task_service.ListTasks(
|
|
359
|
+
task_service_pb2.ListTasksRequest(
|
|
360
|
+
org=cfg.org,
|
|
361
|
+
request=list_pb2.ListRequest(
|
|
362
|
+
sort_by=sort_pb2,
|
|
363
|
+
filters=filters,
|
|
364
|
+
limit=limit,
|
|
365
|
+
token=token,
|
|
366
|
+
),
|
|
367
|
+
)
|
|
368
|
+
)
|
|
369
|
+
token = resp.token
|
|
370
|
+
for t in resp.tasks:
|
|
371
|
+
retrieved += 1
|
|
372
|
+
yield cls(t)
|
|
373
|
+
if not token or retrieved >= original_limit:
|
|
374
|
+
logger.debug(f"Retrieved {retrieved} tasks, stopping iteration.")
|
|
375
|
+
break
|
|
376
|
+
|
|
377
|
+
def __rich_repr__(self) -> rich.repr.Result:
|
|
378
|
+
"""
|
|
379
|
+
Rich representation of the task.
|
|
380
|
+
"""
|
|
381
|
+
yield "project", self.pb2.task_id.project
|
|
382
|
+
yield "domain", self.pb2.task_id.domain
|
|
383
|
+
yield "name", self.pb2.task_id.name
|
|
384
|
+
yield "version", self.pb2.task_id.version
|
|
385
|
+
yield "short_name", self.pb2.metadata.short_name
|
|
386
|
+
for t in _repr_task_metadata(self.pb2.metadata):
|
|
387
|
+
yield t
|
|
388
|
+
|
|
389
|
+
|
|
226
390
|
if __name__ == "__main__":
|
|
227
391
|
tk = Task.get(name="example_task")
|
flyte/report/_report.py
CHANGED
|
@@ -4,10 +4,10 @@ import string
|
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
5
|
from typing import TYPE_CHECKING, Dict, List, Union
|
|
6
6
|
|
|
7
|
-
from flyte._api_commons import syncer
|
|
8
7
|
from flyte._internal.runtime import io
|
|
9
8
|
from flyte._logging import logger
|
|
10
9
|
from flyte._tools import ipython_check
|
|
10
|
+
from flyte.syncify import syncify
|
|
11
11
|
|
|
12
12
|
if TYPE_CHECKING:
|
|
13
13
|
from IPython.core.display import HTML
|
|
@@ -112,7 +112,7 @@ def get_tab(name: str, /, create_if_missing: bool = True) -> Tab:
|
|
|
112
112
|
return report.get_tab(name, create_if_missing=create_if_missing)
|
|
113
113
|
|
|
114
114
|
|
|
115
|
-
@
|
|
115
|
+
@syncify
|
|
116
116
|
async def log(content: str, do_flush: bool = False):
|
|
117
117
|
"""
|
|
118
118
|
Log content to the main tab. The content should be a valid HTML string, but not a complete HTML document,
|
|
@@ -126,7 +126,7 @@ async def log(content: str, do_flush: bool = False):
|
|
|
126
126
|
await flush.aio()
|
|
127
127
|
|
|
128
128
|
|
|
129
|
-
@
|
|
129
|
+
@syncify
|
|
130
130
|
async def flush():
|
|
131
131
|
"""
|
|
132
132
|
Flush the report.
|
|
@@ -149,7 +149,7 @@ async def flush():
|
|
|
149
149
|
logger.debug(f"Report flushed to {final_path}")
|
|
150
150
|
|
|
151
151
|
|
|
152
|
-
@
|
|
152
|
+
@syncify
|
|
153
153
|
async def replace(content: str, do_flush: bool = False):
|
|
154
154
|
"""
|
|
155
155
|
Get the report. Replaces the content of the main tab.
|
flyte/storage/__init__.py
CHANGED
|
@@ -1,4 +1,8 @@
|
|
|
1
1
|
__all__ = [
|
|
2
|
+
"ABFS",
|
|
3
|
+
"GCS",
|
|
4
|
+
"S3",
|
|
5
|
+
"Storage",
|
|
2
6
|
"get",
|
|
3
7
|
"get_random_local_directory",
|
|
4
8
|
"get_random_local_path",
|
|
@@ -11,6 +15,7 @@ __all__ = [
|
|
|
11
15
|
"put_stream",
|
|
12
16
|
]
|
|
13
17
|
|
|
18
|
+
from ._config import ABFS, GCS, S3, Storage
|
|
14
19
|
from ._storage import (
|
|
15
20
|
get,
|
|
16
21
|
get_random_local_directory,
|
flyte/storage/_config.py
ADDED
|
@@ -0,0 +1,233 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import datetime
|
|
4
|
+
import os
|
|
5
|
+
import typing
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import ClassVar
|
|
8
|
+
|
|
9
|
+
from flyte.config import set_if_exists
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass(init=True, repr=True, eq=True, frozen=True)
class Storage(object):
    """
    Data storage configuration that applies across any provider.

    Provider-specific subclasses add their own fields and translate the configuration
    into keyword arguments for an fsspec filesystem via ``get_fsspec_kwargs``.
    """

    # Retry count for storage requests.
    retries: int = 3
    # Initial delay between retries.
    backoff: datetime.timedelta = datetime.timedelta(seconds=5)
    enable_debug: bool = False
    attach_execution_metadata: bool = True

    # Dataclass field name -> environment variable read by ``auto``.
    _KEY_ENV_VAR_MAPPING: ClassVar[typing.Dict[str, str]] = {
        "enable_debug": "UNION_STORAGE_DEBUG",
        "retries": "UNION_STORAGE_RETRIES",
        "backoff": "UNION_STORAGE_BACKOFF_SECONDS",
    }

    def get_fsspec_kwargs(self, anonymous: bool = False, **kwargs) -> typing.Dict[str, typing.Any]:
        """
        Returns the configuration as kwargs for constructing an fsspec filesystem.

        The provider-agnostic base contributes no settings of its own, so the caller's
        ``kwargs`` are passed through unchanged instead of being dropped.

        :param anonymous: Unused by the base class; provider subclasses use it to request unsigned access.
        :param kwargs: Caller-supplied filesystem options.
        :return: A new dict containing the caller's filesystem kwargs.
        """
        return dict(kwargs)

    @classmethod
    def _auto_as_kwargs(cls) -> typing.Dict[str, typing.Any]:
        """
        Read the shared storage settings from the environment as constructor kwargs.

        Environment variables are strings, so each one is converted to its field's type:
        retries -> ``int``, backoff (seconds, may be fractional) -> ``datetime.timedelta``,
        debug -> ``bool`` ("1", "true", "yes", "on", case-insensitive, mean True).
        Unset or blank variables are omitted so the dataclass defaults apply.

        :raises ValueError: If the retries or backoff variable is not a valid number.
        """
        mapping = cls._KEY_ENV_VAR_MAPPING

        def _read(field_name: str) -> typing.Optional[str]:
            # Treat unset and whitespace-only values alike as "not configured".
            raw = os.getenv(mapping[field_name])
            if raw is None or not raw.strip():
                return None
            return raw.strip()

        kwargs: typing.Dict[str, typing.Any] = {}

        enable_debug = _read("enable_debug")
        if enable_debug is not None:
            kwargs["enable_debug"] = enable_debug.lower() in ("1", "true", "yes", "on")

        retries = _read("retries")
        if retries is not None:
            try:
                kwargs["retries"] = int(retries)
            except ValueError as err:
                raise ValueError(f"{mapping['retries']} must be an integer, got {retries!r}") from err

        backoff = _read("backoff")
        if backoff is not None:
            try:
                kwargs["backoff"] = datetime.timedelta(seconds=float(backoff))
            except (ValueError, OverflowError) as err:
                raise ValueError(f"{mapping['backoff']} must be a number of seconds, got {backoff!r}") from err

        return kwargs

    @classmethod
    def auto(cls) -> Storage:
        """
        Construct the config object automatically from environment variables.
        """
        return cls(**cls._auto_as_kwargs())
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
@dataclass(init=True, repr=True, eq=True, frozen=True)
class S3(Storage):
    """
    S3 specific configuration.

    ``endpoint``, ``access_key_id`` and ``secret_access_key`` are forwarded to the
    filesystem inside a nested ``config`` dict by ``get_fsspec_kwargs``.
    """

    endpoint: typing.Optional[str] = None
    access_key_id: typing.Optional[str] = None
    secret_access_key: typing.Optional[str] = None

    # Field name -> environment variable; also includes the shared Storage variables.
    _KEY_ENV_VAR_MAPPING: ClassVar[typing.Dict[str, str]] = {
        "endpoint": "FLYTE_AWS_ENDPOINT",
        "access_key_id": "FLYTE_AWS_ACCESS_KEY_ID",
        "secret_access_key": "FLYTE_AWS_SECRET_ACCESS_KEY",
    } | Storage._KEY_ENV_VAR_MAPPING

    # Refer to https://github.com/developmentseed/obstore/blob/33654fc37f19a657689eb93327b621e9f9e01494/obstore/python/obstore/store/_aws.pyi#L11
    # for key and secret
    _CONFIG_KEY_FSSPEC_S3_KEY_ID: ClassVar[str] = "access_key_id"
    _CONFIG_KEY_FSSPEC_S3_SECRET: ClassVar[str] = "secret_access_key"
    _CONFIG_KEY_ENDPOINT: ClassVar[str] = "endpoint_url"
    _KEY_SKIP_SIGNATURE: ClassVar[str] = "skip_signature"

    @classmethod
    def auto(cls) -> S3:
        """
        Construct an S3 config from environment variables.

        Reads the endpoint/credential variables listed in ``_KEY_ENV_VAR_MAPPING`` on top of
        the shared Storage settings. Returns an instance of ``cls`` so subclasses keep their type.

        :return: Config
        """
        kwargs = super()._auto_as_kwargs()
        for field_name in ("endpoint", "access_key_id", "secret_access_key"):
            # set_if_exists (flyte.config) decides whether an unset (None) value is added.
            value = os.getenv(cls._KEY_ENV_VAR_MAPPING[field_name], None)
            kwargs = set_if_exists(kwargs, field_name, value)
        return cls(**kwargs)

    @classmethod
    def for_sandbox(cls) -> S3:
        """
        Construct an S3 config for the local sandbox object store.

        Shared Storage settings are still read from the environment; the endpoint and
        credentials are fixed to the sandbox values and override any environment values.

        :return: Config for the sandbox.
        """
        sandbox_overrides = {
            "endpoint": "http://localhost:4566",
            "access_key_id": "minio",
            "secret_access_key": "miniostorage",
        }
        return cls(**(super()._auto_as_kwargs() | sandbox_overrides))

    def get_fsspec_kwargs(self, anonymous: bool = False, **kwargs) -> typing.Dict[str, typing.Any]:
        """
        Build filesystem kwargs for S3.

        Credential and endpoint settings are collected into a nested ``config`` dict. A
        caller-supplied kwarg with the same key takes precedence over the dataclass field and
        is removed from the top level. ``retries``/``backoff`` may also be overridden via kwargs
        and are converted into ``retry_config``.

        :param anonymous: When True, sets ``skip_signature`` in ``config``.
        :param kwargs: Caller-supplied filesystem options (may override the fields above).
        :return: The filesystem kwargs, always including ``client_options`` and ``retry_config``.
        """
        config: typing.Dict[str, typing.Any] = {}
        for key, configured in (
            (self._CONFIG_KEY_FSSPEC_S3_KEY_ID, self.access_key_id),
            (self._CONFIG_KEY_FSSPEC_S3_SECRET, self.secret_access_key),
            (self._CONFIG_KEY_ENDPOINT, self.endpoint),
        ):
            if key in kwargs or configured:
                config[key] = kwargs.pop(key, configured)

        retries = kwargs.pop("retries", self.retries)
        backoff = kwargs.pop("backoff", self.backoff)

        if anonymous:
            config[self._KEY_SKIP_SIGNATURE] = True

        # NOTE(review): max_backoff is fixed at 16s; a configured backoff above that is presumably capped.
        retry_config = {
            "max_retries": retries,
            "backoff": {
                "base": 2,
                "init_backoff": backoff,
                "max_backoff": datetime.timedelta(seconds=16),
            },
            "retry_timeout": datetime.timedelta(minutes=3),
        }

        # allow_http permits plain-http endpoints such as the sandbox's http://localhost:4566.
        client_options = {"timeout": "99999s", "allow_http": True}

        if config:
            kwargs["config"] = config
        kwargs["client_options"] = client_options
        kwargs["retry_config"] = retry_config

        return kwargs
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
@dataclass(init=True, repr=True, eq=True, frozen=True)
class GCS(Storage):
    """
    Any GCS specific configuration.
    """

    gsutil_parallelism: bool = False

    # Field name -> environment variable read by ``auto``.
    _KEY_ENV_VAR_MAPPING: ClassVar[dict[str, str]] = {
        "gsutil_parallelism": "GCP_GSUTIL_PARALLELISM",
    }

    @classmethod
    def auto(cls) -> GCS:
        """
        Construct a GCS config from environment variables.

        ``GCP_GSUTIL_PARALLELISM`` is parsed as a boolean: "1", "true", "yes" and "on"
        (case-insensitive) mean True, and any other non-blank value means False. An unset
        or blank variable leaves the default. Returns an instance of ``cls`` so subclasses
        keep their type.
        """
        raw = os.getenv(cls._KEY_ENV_VAR_MAPPING["gsutil_parallelism"], None)

        kwargs: typing.Dict[str, typing.Any] = {}
        if raw is not None and raw.strip():
            # The field is a bool; storing the raw string would make "false" truthy.
            kwargs["gsutil_parallelism"] = raw.strip().lower() in ("1", "true", "yes", "on")
        return cls(**kwargs)

    def get_fsspec_kwargs(self, anonymous: bool = False, **kwargs) -> typing.Dict[str, typing.Any]:
        """
        Build filesystem kwargs for GCS.

        No GCS-specific settings are added; the caller's kwargs are returned unchanged.
        NOTE(review): neither ``anonymous`` nor ``gsutil_parallelism`` is applied here — confirm intended.

        :param anonymous: Currently unused for GCS.
        :param kwargs: Caller-supplied filesystem options.
        :return: The caller's kwargs.
        """
        return kwargs
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@dataclass(init=True, repr=True, eq=True, frozen=True)
class ABFS(Storage):
    """
    Any Azure Blob Storage specific configuration.

    Configured credentials are forwarded to the filesystem inside a nested ``config``
    dict by ``get_fsspec_kwargs``.
    """

    account_name: typing.Optional[str] = None
    account_key: typing.Optional[str] = None
    tenant_id: typing.Optional[str] = None
    client_id: typing.Optional[str] = None
    client_secret: typing.Optional[str] = None

    # Field name -> environment variable read by ``auto``.
    _KEY_ENV_VAR_MAPPING: ClassVar[dict[str, str]] = {
        "account_name": "AZURE_STORAGE_ACCOUNT_NAME",
        "account_key": "AZURE_STORAGE_ACCOUNT_KEY",
        "tenant_id": "AZURE_TENANT_ID",
        "client_id": "AZURE_CLIENT_ID",
        "client_secret": "AZURE_CLIENT_SECRET",
    }
    _KEY_SKIP_SIGNATURE: ClassVar[str] = "skip_signature"

    @classmethod
    def auto(cls) -> ABFS:
        """
        Construct an Azure config from the environment variables in ``_KEY_ENV_VAR_MAPPING``.

        Returns an instance of ``cls`` so subclasses keep their type.
        """
        kwargs: typing.Dict[str, typing.Any] = {}
        for field_name in ("account_name", "account_key", "tenant_id", "client_id", "client_secret"):
            # set_if_exists (flyte.config) decides whether an unset (None) value is added.
            value = os.getenv(cls._KEY_ENV_VAR_MAPPING[field_name], None)
            kwargs = set_if_exists(kwargs, field_name, value)
        return cls(**kwargs)

    def get_fsspec_kwargs(self, anonymous: bool = False, **kwargs) -> typing.Dict[str, typing.Any]:
        """
        Build filesystem kwargs for Azure Blob Storage.

        Each credential is placed in a nested ``config`` dict, preferring a caller-supplied
        kwarg over the dataclass field.

        :param anonymous: When True, sets ``skip_signature`` in ``config``.
        :param kwargs: Caller-supplied filesystem options (may override the credential fields).
        :return: The filesystem kwargs, always including ``client_options``.
        """
        config: typing.Dict[str, typing.Any] = {}
        for key in ("account_name", "account_key", "client_id", "client_secret", "tenant_id"):
            configured = getattr(self, key)
            if key in kwargs or configured:
                # NOTE(review): this copies (get) rather than moves (pop) the kwarg, so an explicit
                # credential also stays at the top level of the returned kwargs — confirm intended.
                config[key] = kwargs.get(key, configured)

        if anonymous:
            config[self._KEY_SKIP_SIGNATURE] = True

        # allow_http is passed as a bool.
        client_options = {"timeout": "99999s", "allow_http": True}

        if config:
            kwargs["config"] = config
        kwargs["client_options"] = client_options

        return kwargs
|
flyte/storage/_storage.py
CHANGED
|
@@ -74,7 +74,27 @@ def get_underlying_filesystem(
|
|
|
74
74
|
|
|
75
75
|
storage_config = get_storage()
|
|
76
76
|
if storage_config:
|
|
77
|
-
kwargs
|
|
77
|
+
kwargs = storage_config.get_fsspec_kwargs(anonymous, **kwargs)
|
|
78
|
+
elif protocol:
|
|
79
|
+
match protocol:
|
|
80
|
+
case "s3":
|
|
81
|
+
# If the protocol is s3, we can use the s3 filesystem
|
|
82
|
+
from flyte.storage import S3
|
|
83
|
+
|
|
84
|
+
kwargs = S3.auto().get_fsspec_kwargs(anonymous=anonymous, **kwargs)
|
|
85
|
+
case "gs":
|
|
86
|
+
# If the protocol is gs, we can use the gs filesystem
|
|
87
|
+
from flyte.storage import GCS
|
|
88
|
+
|
|
89
|
+
kwargs = GCS.auto().get_fsspec_kwargs(anonymous=anonymous, **kwargs)
|
|
90
|
+
case "abfs" | "abfss":
|
|
91
|
+
# If the protocol is abfs or abfss, we can use the abfs filesystem
|
|
92
|
+
from flyte.storage import ABFS
|
|
93
|
+
|
|
94
|
+
kwargs = ABFS.auto().get_fsspec_kwargs(anonymous=anonymous, **kwargs)
|
|
95
|
+
case _:
|
|
96
|
+
pass
|
|
97
|
+
|
|
78
98
|
return fsspec.filesystem(protocol, **kwargs)
|
|
79
99
|
|
|
80
100
|
|
|
@@ -127,7 +147,7 @@ async def _get_from_filesystem(
|
|
|
127
147
|
return to_path
|
|
128
148
|
|
|
129
149
|
|
|
130
|
-
async def put(from_path: str, to_path: Optional[str] = None, recursive: bool = False, **kwargs):
|
|
150
|
+
async def put(from_path: str, to_path: Optional[str] = None, recursive: bool = False, **kwargs) -> str:
|
|
131
151
|
if not to_path:
|
|
132
152
|
from flyte._context import internal_ctx
|
|
133
153
|
|
|
@@ -142,7 +162,7 @@ async def put(from_path: str, to_path: Optional[str] = None, recursive: bool = F
|
|
|
142
162
|
else:
|
|
143
163
|
dst = file_system.put(from_path, to_path, recursive=recursive, **kwargs)
|
|
144
164
|
if isinstance(dst, (str, pathlib.Path)):
|
|
145
|
-
return dst
|
|
165
|
+
return str(dst)
|
|
146
166
|
else:
|
|
147
167
|
return to_path
|
|
148
168
|
|