flyte 2.0.0b23__py3-none-any.whl → 2.0.0b25__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- flyte/__init__.py +11 -2
- flyte/_cache/local_cache.py +4 -3
- flyte/_code_bundle/_utils.py +3 -3
- flyte/_code_bundle/bundle.py +12 -5
- flyte/_context.py +4 -1
- flyte/_custom_context.py +73 -0
- flyte/_deploy.py +31 -7
- flyte/_image.py +48 -16
- flyte/_initialize.py +69 -26
- flyte/_internal/controllers/_local_controller.py +1 -0
- flyte/_internal/controllers/_trace.py +1 -1
- flyte/_internal/controllers/remote/_action.py +9 -10
- flyte/_internal/controllers/remote/_client.py +1 -1
- flyte/_internal/controllers/remote/_controller.py +4 -2
- flyte/_internal/controllers/remote/_core.py +10 -13
- flyte/_internal/controllers/remote/_informer.py +3 -3
- flyte/_internal/controllers/remote/_service_protocol.py +7 -7
- flyte/_internal/imagebuild/docker_builder.py +45 -59
- flyte/_internal/imagebuild/remote_builder.py +51 -11
- flyte/_internal/imagebuild/utils.py +51 -3
- flyte/_internal/runtime/convert.py +39 -18
- flyte/_internal/runtime/io.py +8 -7
- flyte/_internal/runtime/resources_serde.py +20 -6
- flyte/_internal/runtime/reuse.py +1 -1
- flyte/_internal/runtime/task_serde.py +7 -10
- flyte/_internal/runtime/taskrunner.py +10 -1
- flyte/_internal/runtime/trigger_serde.py +13 -13
- flyte/_internal/runtime/types_serde.py +1 -1
- flyte/_keyring/file.py +2 -2
- flyte/_map.py +65 -13
- flyte/_pod.py +2 -2
- flyte/_resources.py +175 -31
- flyte/_run.py +37 -21
- flyte/_task.py +27 -6
- flyte/_task_environment.py +37 -10
- flyte/_utils/module_loader.py +2 -2
- flyte/_version.py +3 -3
- flyte/cli/_common.py +47 -5
- flyte/cli/_create.py +4 -0
- flyte/cli/_deploy.py +8 -0
- flyte/cli/_get.py +4 -0
- flyte/cli/_params.py +4 -4
- flyte/cli/_run.py +50 -7
- flyte/cli/_update.py +4 -3
- flyte/config/_config.py +2 -0
- flyte/config/_internal.py +1 -0
- flyte/config/_reader.py +3 -3
- flyte/errors.py +1 -1
- flyte/extend.py +4 -0
- flyte/extras/_container.py +6 -1
- flyte/git/_config.py +11 -9
- flyte/io/_dataframe/basic_dfs.py +1 -1
- flyte/io/_dataframe/dataframe.py +12 -8
- flyte/io/_dir.py +48 -15
- flyte/io/_file.py +48 -11
- flyte/models.py +12 -8
- flyte/remote/_action.py +18 -16
- flyte/remote/_client/_protocols.py +4 -3
- flyte/remote/_client/auth/_channel.py +1 -1
- flyte/remote/_client/controlplane.py +4 -8
- flyte/remote/_data.py +4 -3
- flyte/remote/_logs.py +3 -3
- flyte/remote/_run.py +5 -5
- flyte/remote/_secret.py +20 -13
- flyte/remote/_task.py +7 -8
- flyte/remote/_trigger.py +25 -27
- flyte/storage/_parallel_reader.py +274 -0
- flyte/storage/_storage.py +66 -2
- flyte/types/_interface.py +2 -2
- flyte/types/_pickle.py +1 -1
- flyte/types/_string_literals.py +8 -9
- flyte/types/_type_engine.py +25 -17
- flyte/types/_utils.py +1 -1
- {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/METADATA +2 -1
- flyte-2.0.0b25.dist-info/RECORD +184 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +0 -66
- flyte/_protos/common/authorization_pb2.pyi +0 -108
- flyte/_protos/common/authorization_pb2_grpc.py +0 -4
- flyte/_protos/common/identifier_pb2.py +0 -117
- flyte/_protos/common/identifier_pb2.pyi +0 -142
- flyte/_protos/common/identifier_pb2_grpc.py +0 -4
- flyte/_protos/common/identity_pb2.py +0 -48
- flyte/_protos/common/identity_pb2.pyi +0 -72
- flyte/_protos/common/identity_pb2_grpc.py +0 -4
- flyte/_protos/common/list_pb2.py +0 -36
- flyte/_protos/common/list_pb2.pyi +0 -71
- flyte/_protos/common/list_pb2_grpc.py +0 -4
- flyte/_protos/common/policy_pb2.py +0 -37
- flyte/_protos/common/policy_pb2.pyi +0 -27
- flyte/_protos/common/policy_pb2_grpc.py +0 -4
- flyte/_protos/common/role_pb2.py +0 -37
- flyte/_protos/common/role_pb2.pyi +0 -53
- flyte/_protos/common/role_pb2_grpc.py +0 -4
- flyte/_protos/common/runtime_version_pb2.py +0 -28
- flyte/_protos/common/runtime_version_pb2.pyi +0 -24
- flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/definition_pb2.py +0 -60
- flyte/_protos/imagebuilder/definition_pb2.pyi +0 -153
- flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/payload_pb2.py +0 -32
- flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
- flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
- flyte/_protos/imagebuilder/service_pb2.py +0 -29
- flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
- flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
- flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
- flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/definition_pb2.py +0 -49
- flyte/_protos/secret/definition_pb2.pyi +0 -93
- flyte/_protos/secret/definition_pb2_grpc.py +0 -4
- flyte/_protos/secret/payload_pb2.py +0 -62
- flyte/_protos/secret/payload_pb2.pyi +0 -94
- flyte/_protos/secret/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/secret_pb2.py +0 -38
- flyte/_protos/secret/secret_pb2.pyi +0 -6
- flyte/_protos/secret/secret_pb2_grpc.py +0 -198
- flyte/_protos/validate/validate/validate_pb2.py +0 -76
- flyte/_protos/workflow/common_pb2.py +0 -38
- flyte/_protos/workflow/common_pb2.pyi +0 -63
- flyte/_protos/workflow/common_pb2_grpc.py +0 -4
- flyte/_protos/workflow/environment_pb2.py +0 -29
- flyte/_protos/workflow/environment_pb2.pyi +0 -12
- flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
- flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
- flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- flyte/_protos/workflow/queue_service_pb2.py +0 -117
- flyte/_protos/workflow/queue_service_pb2.pyi +0 -182
- flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -206
- flyte/_protos/workflow/run_definition_pb2.py +0 -123
- flyte/_protos/workflow/run_definition_pb2.pyi +0 -354
- flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
- flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- flyte/_protos/workflow/run_service_pb2.py +0 -147
- flyte/_protos/workflow/run_service_pb2.pyi +0 -203
- flyte/_protos/workflow/run_service_pb2_grpc.py +0 -480
- flyte/_protos/workflow/state_service_pb2.py +0 -67
- flyte/_protos/workflow/state_service_pb2.pyi +0 -76
- flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
- flyte/_protos/workflow/task_definition_pb2.py +0 -86
- flyte/_protos/workflow/task_definition_pb2.pyi +0 -105
- flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/task_service_pb2.py +0 -61
- flyte/_protos/workflow/task_service_pb2.pyi +0 -62
- flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
- flyte/_protos/workflow/trigger_definition_pb2.py +0 -66
- flyte/_protos/workflow/trigger_definition_pb2.pyi +0 -117
- flyte/_protos/workflow/trigger_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/trigger_service_pb2.py +0 -96
- flyte/_protos/workflow/trigger_service_pb2.pyi +0 -110
- flyte/_protos/workflow/trigger_service_pb2_grpc.py +0 -281
- flyte-2.0.0b23.dist-info/RECORD +0 -262
- {flyte-2.0.0b23.data → flyte-2.0.0b25.data}/scripts/debug.py +0 -0
- {flyte-2.0.0b23.data → flyte-2.0.0b25.data}/scripts/runtime.py +0 -0
- {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/WHEEL +0 -0
- {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/entry_points.txt +0 -0
- {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/licenses/LICENSE +0 -0
- {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/top_level.txt +0 -0
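Before the per-file diffs, two substitutions recur throughout this release pair: the vendored flyte/_protos stubs (all deleted below) are replaced by imports from the standalone flyteidl2 distribution, and config reads now go through get_init_config() from flyte._initialize. A sketch of the import migration, with the flyteidl2 module paths taken from the hunks below (verify against the installed package):

# Before (2.0.0b23): stubs vendored under the private flyte._protos package
# from flyte._protos.common import identifier_pb2, list_pb2
# from flyte._protos.workflow import task_definition_pb2, task_service_pb2

# After (2.0.0b25): the same stubs ship in the standalone flyteidl2 package
from flyteidl2.common import identifier_pb2, list_pb2
from flyteidl2.task import task_definition_pb2, task_service_pb2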
flyte/remote/_secret.py
CHANGED
@@ -4,9 +4,9 @@ from dataclasses import dataclass
 from typing import AsyncIterator, Literal, Union
 
 import rich.repr
+from flyteidl2.secret import definition_pb2, payload_pb2
 
-from flyte._initialize import ensure_client, get_client,
-from flyte._protos.secret import definition_pb2, payload_pb2
+from flyte._initialize import ensure_client, get_client, get_init_config
 from flyte.remote._common import ToJSONMixin
 from flyte.syncify import syncify
 
@@ -21,12 +21,19 @@ class Secret(ToJSONMixin):
     @classmethod
     async def create(cls, name: str, value: Union[str, bytes], type: SecretTypes = "regular"):
         ensure_client()
-        cfg =
-
-
-
-
-
+        cfg = get_init_config()
+        project = cfg.project
+        domain = cfg.domain
+
+        if type == "regular":
+            secret_type = definition_pb2.SecretType.SECRET_TYPE_GENERIC
+
+        else:
+            secret_type = definition_pb2.SecretType.SECRET_TYPE_IMAGE_PULL_SECRET
+            if project or domain:
+                raise ValueError(
+                    f"Project `{project}` or domain `{domain}` should not be set when creating the image pull secret."
+                )
 
         if isinstance(value, str):
             secret = definition_pb2.SecretSpec(
@@ -42,8 +49,8 @@ class Secret(ToJSONMixin):
             request=payload_pb2.CreateSecretRequest(
                 id=definition_pb2.SecretIdentifier(
                     organization=cfg.org,
-                    project=
-                    domain=
+                    project=project,
+                    domain=domain,
                     name=name,
                 ),
                 secret_spec=secret,
@@ -54,7 +61,7 @@ class Secret(ToJSONMixin):
     @classmethod
     async def get(cls, name: str) -> Secret:
         ensure_client()
-        cfg =
+        cfg = get_init_config()
         resp = await get_client().secrets_service.GetSecret(
             request=payload_pb2.GetSecretRequest(
                 id=definition_pb2.SecretIdentifier(
@@ -71,7 +78,7 @@ class Secret(ToJSONMixin):
     @classmethod
     async def listall(cls, limit: int = 100) -> AsyncIterator[Secret]:
         ensure_client()
-        cfg =
+        cfg = get_init_config()
         token = None
         while True:
             resp = await get_client().secrets_service.ListSecrets(  # type: ignore
@@ -93,7 +100,7 @@ class Secret(ToJSONMixin):
     @classmethod
     async def delete(cls, name):
         ensure_client()
-        cfg =
+        cfg = get_init_config()
         await get_client().secrets_service.DeleteSecret(  # type: ignore
             request=payload_pb2.DeleteSecretRequest(
                 id=definition_pb2.SecretIdentifier(
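The rewritten create branches on the secret type: regular secrets become SECRET_TYPE_GENERIC and are scoped to the project/domain from the init config, while image pull secrets become SECRET_TYPE_IMAGE_PULL_SECRET and now refuse creation while a project or domain is configured. A usage sketch under stated assumptions: the flyte.init keyword names, the "image_pull" literal, and the flyte.remote.Secret import path are inferred rather than shown in this diff, and create is assumed callable synchronously through its syncify wrapper.

import flyte
from flyte.remote import Secret

# Regular secret: inherits the project/domain captured by flyte.init(...)
flyte.init(org="acme", project="ml", domain="development")
Secret.create(name="my-api-key", value="s3cr3t")

# Image pull secret: must be org-wide, so per the new validation any
# configured project/domain makes create() raise ValueError.
flyte.init(org="acme")
Secret.create(name="registry-creds", value=b"<dockerconfigjson>", type="image_pull")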
flyte/remote/_task.py
CHANGED
@@ -6,19 +6,18 @@ from dataclasses import dataclass
 from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Iterator, Literal, Optional, Tuple, Union, cast
 
 import rich.repr
-from
-from
+from flyteidl2.common import identifier_pb2, list_pb2
+from flyteidl2.core import literals_pb2
+from flyteidl2.task import task_definition_pb2, task_service_pb2
 
 import flyte
 import flyte.errors
 from flyte._cache.cache import CacheBehavior
 from flyte._context import internal_ctx
-from flyte._initialize import ensure_client, get_client,
+from flyte._initialize import ensure_client, get_client, get_init_config
 from flyte._internal.runtime.resources_serde import get_proto_resources
 from flyte._internal.runtime.task_serde import get_proto_retry_strategy, get_proto_timeout, get_security_context
 from flyte._logging import logger
-from flyte._protos.common import identifier_pb2, list_pb2
-from flyte._protos.workflow import task_definition_pb2, task_service_pb2
 from flyte.models import NativeInterface
 from flyte.syncify import syncify
 
@@ -35,7 +34,7 @@ def _repr_task_metadata(metadata: task_definition_pb2.TaskMetadata) -> rich.repr
     else:
         yield "deployed_by", f"App: {metadata.deployed_by.application.spec.name}"
     yield "short_name", metadata.short_name
-    yield "deployed_at",
+    yield "deployed_at", metadata.deployed_at.ToDatetime()
     yield "environment_name", metadata.environment_name
 
 
@@ -151,7 +150,7 @@ class TaskDetails(ToJSONMixin):
            if ctx is None:
                raise ValueError("auto_version=current can only be used within a task context.")
            _version = ctx.version
-        cfg =
+        cfg = get_init_config()
         task_id = task_definition_pb2.TaskIdentifier(
             org=cfg.org,
             project=project or cfg.project,
@@ -451,7 +450,7 @@ class Task(ToJSONMixin):
         sort_pb2 = list_pb2.Sort(
             key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
         )
-        cfg =
+        cfg = get_init_config()
         filters = []
         if by_task_name:
             filters.append(
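The _repr_task_metadata fix replaces the previous deployed_at rendering with metadata.deployed_at.ToDatetime(), the standard protobuf well-known-type conversion from a Timestamp message to a Python datetime. Outside of flyte, the same conversion looks like this:

from google.protobuf.timestamp_pb2 import Timestamp

ts = Timestamp()
ts.GetCurrentTime()        # fill the message with the current time
dt = ts.ToDatetime()       # convert to a naive UTC datetime.datetime
print(dt.isoformat())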
flyte/remote/_trigger.py
CHANGED
@@ -5,12 +5,13 @@ from functools import cached_property
 from typing import AsyncIterator
 
 import grpc.aio
+from flyteidl2.common import identifier_pb2, list_pb2
+from flyteidl2.task import common_pb2, task_definition_pb2
+from flyteidl2.trigger import trigger_definition_pb2, trigger_service_pb2
 
 import flyte
-from flyte._initialize import ensure_client, get_client,
+from flyte._initialize import ensure_client, get_client, get_init_config
 from flyte._internal.runtime import trigger_serde
-from flyte._protos.common import identifier_pb2, list_pb2
-from flyte._protos.workflow import common_pb2, task_definition_pb2, trigger_definition_pb2, trigger_service_pb2
 from flyte.syncify import syncify
 
 from ._common import ToJSONMixin
@@ -28,7 +29,7 @@ class TriggerDetails(ToJSONMixin):
         Retrieve detailed information about a specific trigger by its name.
         """
         ensure_client()
-        cfg =
+        cfg = get_init_config()
         resp = await get_client().trigger_service.GetTriggerDetails(
             request=trigger_service_pb2.GetTriggerDetailsRequest(
                 name=identifier_pb2.TriggerName(
@@ -101,7 +102,7 @@ class Trigger(ToJSONMixin):
         :param task_name: Optional name of the task to associate with the trigger.
         """
         ensure_client()
-        cfg =
+        cfg = get_init_config()
 
         # Fetch the task to ensure it exists and to get its input definitions
         try:
@@ -121,15 +122,12 @@ class Trigger(ToJSONMixin):
 
         resp = await get_client().trigger_service.DeployTrigger(
             request=trigger_service_pb2.DeployTriggerRequest(
-
-                name=
-
-
-
-                    domain=cfg.domain,
-                ),
-                revision=1,
+                name=identifier_pb2.TriggerName(
+                    name=trigger.name,
+                    task_name=task_name,
+                    org=cfg.org,
+                    project=cfg.project,
+                    domain=cfg.domain,
                 ),
                 spec=trigger_definition_pb2.TriggerSpec(
                     active=task_trigger.spec.active,
@@ -155,7 +153,7 @@ class Trigger(ToJSONMixin):
         """
         Retrieve a trigger by its name and associated task name.
         """
-        return await TriggerDetails.get(name=name, task_name=task_name)
+        return await TriggerDetails.get.aio(name=name, task_name=task_name)
 
     @syncify
     @classmethod
@@ -166,9 +164,9 @@ class Trigger(ToJSONMixin):
         List all triggers associated with a specific task or all tasks if no task name is provided.
         """
         ensure_client()
-        cfg =
+        cfg = get_init_config()
         token = None
-
+        task_name_id = None
         project_id = None
         task_id = None
         if task_name and task_version:
@@ -179,13 +177,13 @@ class Trigger(ToJSONMixin):
                 org=cfg.org,
                 version=task_version,
             )
-
-
-
-
-
-
-
+        elif task_name:
+            task_name_id = task_definition_pb2.TaskName(
+                name=task_name,
+                project=cfg.project,
+                domain=cfg.domain,
+                org=cfg.org,
+            )
         else:
             project_id = identifier_pb2.ProjectIdentifier(
                 organization=cfg.org,
@@ -198,7 +196,7 @@ class Trigger(ToJSONMixin):
             request=trigger_service_pb2.ListTriggersRequest(
                 project_id=project_id,
                 task_id=task_id,
-
+                task_name=task_name_id,
                 request=list_pb2.ListRequest(
                     limit=limit,
                     token=token,
@@ -218,7 +216,7 @@ class Trigger(ToJSONMixin):
         Pause a trigger by its name and associated task name.
         """
         ensure_client()
-        cfg =
+        cfg = get_init_config()
         await get_client().trigger_service.UpdateTriggers(
             request=trigger_service_pb2.UpdateTriggersRequest(
                 names=[
@@ -241,7 +239,7 @@ class Trigger(ToJSONMixin):
         Delete a trigger by its name.
         """
         ensure_client()
-        cfg =
+        cfg = get_init_config()
         await get_client().trigger_service.DeleteTriggers(
             request=trigger_service_pb2.DeleteTriggersRequest(
                 names=[
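Two behavioral fixes stand out above: the listing method gains a name-only filter (a task_definition_pb2.TaskName sent as task_name on the request), and the nested TriggerDetails.get call now goes through .aio. The latter is the flyte.syncify convention: a wrapped callable runs synchronously when invoked directly and exposes its coroutine via .aio for callers that are already async. A toy stand-in (not flyte's actual implementation) showing why the await needed the .aio accessor:

import asyncio
import functools


def syncify(func):
    # Toy stand-in: direct calls drive the coroutine on a fresh event loop
    # (which misbehaves if a loop is already running), while `.aio` hands
    # the raw coroutine function to already-async callers.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return asyncio.run(func(*args, **kwargs))

    wrapper.aio = func
    return wrapper


@syncify
async def get(name: str) -> str:
    return f"trigger:{name}"


print(get("nightly"))                # sync context: call the wrapper directly


async def main():
    print(await get.aio("hourly"))   # async context: must go through .aio


asyncio.run(main())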
flyte/storage/_parallel_reader.py
ADDED
@@ -0,0 +1,274 @@
+from __future__ import annotations
+
+import asyncio
+import dataclasses
+import io
+import os
+import pathlib
+import sys
+import tempfile
+import typing
+from typing import Any, Hashable, Protocol
+
+import aiofiles
+import aiofiles.os
+import obstore
+
+if typing.TYPE_CHECKING:
+    from obstore import Bytes, ObjectMeta
+    from obstore.store import ObjectStore
+
+CHUNK_SIZE = int(os.getenv("FLYTE_IO_CHUNK_SIZE", str(16 * 1024 * 1024)))
+MAX_CONCURRENCY = int(os.getenv("FLYTE_IO_MAX_CONCURRENCY", str(32)))
+
+
+class DownloadQueueEmpty(RuntimeError):
+    pass
+
+
+class BufferProtocol(Protocol):
+    async def write(self, offset, length, value: Bytes) -> None: ...
+
+    async def read(self) -> memoryview: ...
+
+    @property
+    def complete(self) -> bool: ...
+
+
+@dataclasses.dataclass
+class _MemoryBuffer:
+    arr: bytearray
+    pending: int
+    _closed: bool = False
+
+    async def write(self, offset: int, length: int, value: Bytes) -> None:
+        self.arr[offset : offset + length] = memoryview(value)
+        self.pending -= length
+
+    async def read(self) -> memoryview:
+        return memoryview(self.arr)
+
+    @property
+    def complete(self) -> bool:
+        return self.pending == 0
+
+    @classmethod
+    def new(cls, size):
+        return cls(arr=bytearray(size), pending=size)
+
+
+@dataclasses.dataclass
+class _FileBuffer:
+    path: pathlib.Path
+    pending: int
+    _handle: io.FileIO | None = None
+    _closed: bool = False
+
+    async def write(self, offset: int, length: int, value: Bytes) -> None:
+        async with aiofiles.open(self.path, mode="r+b") as f:
+            await f.seek(offset)
+            await f.write(value)
+        self.pending -= length
+
+    async def read(self) -> memoryview:
+        async with aiofiles.open(self.path, mode="rb") as f:
+            return memoryview(await f.read())
+
+    @property
+    def complete(self) -> bool:
+        return self.pending == 0
+
+    @classmethod
+    def new(cls, path: pathlib.Path, size: int):
+        path.parent.mkdir(parents=True, exist_ok=True)
+        path.touch()
+        return cls(path=path, pending=size)
+
+
+@dataclasses.dataclass
+class Chunk:
+    offset: int
+    length: int
+
+
+@dataclasses.dataclass
+class Source:
+    id: Hashable
+    path: pathlib.Path  # Should be str, represents the fully qualified prefix of a file (no bucket)
+    length: int
+    metadata: Any | None = None
+
+
+@dataclasses.dataclass
+class DownloadTask:
+    source: Source
+    chunk: Chunk
+    target: pathlib.Path | None = None
+
+
+class ObstoreParallelReader:
+    def __init__(
+        self,
+        store: ObjectStore,
+        *,
+        chunk_size=CHUNK_SIZE,
+        max_concurrency=MAX_CONCURRENCY,
+    ):
+        self._store = store
+        self._chunk_size = chunk_size
+        self._max_concurrency = max_concurrency
+
+    def _chunks(self, size) -> typing.Iterator[tuple[int, int]]:
+        cs = self._chunk_size
+        for offset in range(0, size, cs):
+            length = min(cs, size - offset)
+            yield offset, length
+
+    async def _as_completed(self, gen: typing.AsyncGenerator[DownloadTask, None], transformer=None):
+        inq: asyncio.Queue = asyncio.Queue(self._max_concurrency * 2)
+        outq: asyncio.Queue = asyncio.Queue()
+        sentinel = object()
+        done = asyncio.Event()
+
+        active: dict[Hashable, _FileBuffer | _MemoryBuffer] = {}
+
+        async def _fill():
+            # Helper function to fill the input queue, this is because the generator is async because it does list/head
+            # calls on the object store which are async.
+            try:
+                counter = 0
+                async for task in gen:
+                    if task.source.id not in active:
+                        active[task.source.id] = (
+                            _FileBuffer.new(task.target, task.source.length)
+                            if task.target is not None
+                            else _MemoryBuffer.new(task.source.length)
+                        )
+                    await inq.put(task)
+                    counter += 1
+                await inq.put(sentinel)
+                if counter == 0:
+                    raise DownloadQueueEmpty
+            except asyncio.CancelledError:
+                # document why we need to swallow this
+                pass
+
+        async def _worker():
+            try:
+                while not done.is_set():
+                    task = await inq.get()
+                    if task is sentinel:
+                        inq.put_nowait(sentinel)
+                        break
+                    chunk_source_offset = task.chunk.offset
+                    buf = active[task.source.id]
+                    data_to_write = await obstore.get_range_async(
+                        self._store,
+                        str(task.source.path),
+                        start=chunk_source_offset,
+                        end=chunk_source_offset + task.chunk.length,
+                    )
+                    await buf.write(
+                        task.chunk.offset,
+                        task.chunk.length,
+                        data_to_write,
+                    )
+                    if not buf.complete:
+                        continue
+                    if transformer is not None:
+                        result = await transformer(buf)
+                    elif task.target is not None:
+                        result = task.target
+                    else:
+                        result = task.source
+                    outq.put_nowait((task.source.id, result))
+                    del active[task.source.id]
+            except asyncio.CancelledError:
+                pass
+            finally:
+                done.set()
+
+        # Yield results as they are completed
+        if sys.version_info >= (3, 11):
+            async with asyncio.TaskGroup() as tg:
+                tg.create_task(_fill())
+                for _ in range(self._max_concurrency):
+                    tg.create_task(_worker())
+                while not done.is_set():
+                    yield await outq.get()
+        else:
+            fill_task = asyncio.create_task(_fill())
+            worker_tasks = [asyncio.create_task(_worker()) for _ in range(self._max_concurrency)]
+            try:
+                while not done.is_set():
+                    yield await outq.get()
+            except Exception as e:
+                if not fill_task.done():
+                    fill_task.cancel()
+                for wt in worker_tasks:
+                    if not wt.done():
+                        wt.cancel()
+                raise e
+            finally:
+                await asyncio.gather(fill_task, *worker_tasks, return_exceptions=True)
+
+        # Drain the output queue
+        try:
+            while True:
+                yield outq.get_nowait()
+        except asyncio.QueueEmpty:
+            pass
+
+    async def download_files(
+        self, src_prefix: pathlib.Path, target_prefix: pathlib.Path, *paths, destination_file_name: str | None = None
+    ) -> None:
+        """
+        src_prefix: Prefix you want to download from in the object store, not including the bucket name, nor file name.
+            Should be replaced with string
+        target_prefix: Local directory to download to
+        paths: Specific paths (relative to src_prefix) to download. If empty, download everything
+        """
+
+        async def _list_downloadable() -> typing.AsyncGenerator[ObjectMeta, None]:
+            if paths:
+                # For specific file paths, use async head
+                for path_ in paths:
+                    path = src_prefix / path_
+                    x = await obstore.head_async(self._store, str(path))
+                    yield x
+                return
+
+            # Use obstore.list() for recursive listing (all files in all subdirectories)
+            # obstore.list() returns an async iterator that yields batches (lists) of objects
+            async for batch in obstore.list(self._store, prefix=str(src_prefix)):
+                for obj in batch:
+                    yield obj
+
+        async def _gen(tmp_dir: str) -> typing.AsyncGenerator[DownloadTask, None]:
+            async for obj in _list_downloadable():
+                path = pathlib.Path(obj["path"])  # e.g. Path(prefix/file.txt), needs to be changed to str.
+                size = obj["size"]
+                source = Source(id=path, path=path, length=size)
+                # Strip src_prefix from path for destination
+                rel_path = path.relative_to(src_prefix)  # doesn't work on windows
+                for offset, length in self._chunks(size):
+                    yield DownloadTask(
+                        source=source,
+                        target=tmp_dir / rel_path,  # doesn't work on windows
+                        chunk=Chunk(offset, length),
+                    )
+
+        def _transform_decorator(tmp_dir: str):
+            async def _transformer(buf: _FileBuffer) -> None:
+                if len(paths) == 1 and destination_file_name is not None:
+                    target = target_prefix / destination_file_name
+                else:
+                    target = target_prefix / buf.path.relative_to(tmp_dir)
+                await aiofiles.os.makedirs(target.parent, exist_ok=True)
+                return await aiofiles.os.replace(buf.path, target)  # mv buf.path target
+
+            return _transformer
+
+        with tempfile.TemporaryDirectory() as temporary_dir:
+            async for _ in self._as_completed(_gen(temporary_dir), transformer=_transform_decorator(temporary_dir)):
+                pass
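The new ObstoreParallelReader splits each object into fixed-size byte ranges (16 MiB by default, overridable via the FLYTE_IO_CHUNK_SIZE and FLYTE_IO_MAX_CONCURRENCY environment variables), fans the ranges out to a pool of workers that each issue obstore.get_range_async calls, and tracks per-file completion with a pending-byte counter so a file is finalized only once every range has landed. The range arithmetic, extracted from _chunks into a standalone, runnable form:

# Standalone replica of ObstoreParallelReader._chunks: split `size` bytes
# into (offset, length) ranges of at most `chunk_size` bytes each.
def chunks(size: int, chunk_size: int = 16 * 1024 * 1024):
    for offset in range(0, size, chunk_size):
        yield offset, min(chunk_size, size - offset)

# A 40 MiB object becomes two full 16 MiB ranges plus an 8 MiB tail:
print(list(chunks(40 * 1024 * 1024)))
# [(0, 16777216), (16777216, 16777216), (33554432, 8388608)]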
flyte/storage/_storage.py
CHANGED
@@ -147,12 +147,76 @@ def _get_anonymous_filesystem(from_path):
     return get_underlying_filesystem(get_protocol(from_path), anonymous=True, asynchronous=True)
 
 
+async def _get_obstore_bypass(from_path: str, to_path: str | pathlib.Path, recursive: bool = False, **kwargs) -> str:
+    from obstore.store import ObjectStore
+
+    from flyte.storage._parallel_reader import ObstoreParallelReader
+
+    fs = get_underlying_filesystem(path=from_path)
+    bucket, prefix = fs._split_path(from_path)  # pylint: disable=W0212
+    store: ObjectStore = fs._construct_store(bucket)
+
+    download_kwargs = {}
+    if "chunk_size" in kwargs:
+        download_kwargs["chunk_size"] = kwargs["chunk_size"]
+    if "max_concurrency" in kwargs:
+        download_kwargs["max_concurrency"] = kwargs["max_concurrency"]
+
+    reader = ObstoreParallelReader(store, **download_kwargs)
+    target_path = pathlib.Path(to_path) if isinstance(to_path, str) else to_path
+
+    # if recursive, just download the prefix to the target path
+    if recursive:
+        logger.debug(f"Downloading recursively {prefix=} to {target_path=}")
+        await reader.download_files(
+            prefix,
+            target_path,
+        )
+        return str(to_path)
+
+    # if not recursive, we need to split out the file name from the prefix
+    else:
+        path_for_reader = pathlib.Path(prefix).name
+        final_prefix = pathlib.Path(prefix).parent
+        logger.debug(f"Downloading single file {final_prefix=}, {path_for_reader=} to {target_path=}")
+        await reader.download_files(
+            final_prefix,
+            target_path.parent,
+            path_for_reader,
+            destination_file_name=target_path.name,
+        )
+        return str(target_path)
+
+
 async def get(from_path: str, to_path: Optional[str | pathlib.Path] = None, recursive: bool = False, **kwargs) -> str:
     if not to_path:
-        name = pathlib.Path(from_path).name
+        name = pathlib.Path(from_path).name  # may need to be adjusted for windows
         to_path = get_random_local_path(file_path_or_file_name=name)
         logger.debug(f"Storing file from {from_path} to {to_path}")
+    else:
+        # Only apply directory logic for single files (not recursive)
+        if not recursive:
+            to_path_str = str(to_path)
+            # Check for trailing separator BEFORE converting to Path (which normalizes and removes it)
+            ends_with_sep = to_path_str.endswith(os.sep)
+            to_path_obj = pathlib.Path(to_path)
+
+            # If path ends with os.sep or is an existing directory, append source filename
+            if ends_with_sep or (to_path_obj.exists() and to_path_obj.is_dir()):
+                source_filename = pathlib.Path(from_path).name  # may need to be adjusted for windows
+                to_path = to_path_obj / source_filename
+        # For recursive=True, keep to_path as-is (it's the destination directory for contents)
+
     file_system = get_underlying_filesystem(path=from_path)
+
+    # Check if we should use obstore bypass
+    if (
+        _is_obstore_supported_protocol(file_system.protocol)
+        and hasattr(file_system, "_split_path")
+        and hasattr(file_system, "_construct_store")
+    ):
+        return await _get_obstore_bypass(from_path, to_path, recursive, **kwargs)
+
     try:
         return await _get_from_filesystem(file_system, from_path, to_path, recursive=recursive, **kwargs)
     except (OSError, GenericError) as oe:
@@ -195,7 +259,7 @@ async def put(from_path: str, to_path: Optional[str] = None, recursive: bool = F
         from flyte._context import internal_ctx
 
         ctx = internal_ctx()
-        name = pathlib.Path(from_path).name
+        name = pathlib.Path(from_path).name
         to_path = ctx.raw_data.get_random_remote_path(file_name=name)
 
     file_system = get_underlying_filesystem(path=to_path)
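get() now has a fast path: when the source filesystem speaks an obstore-supported protocol and exposes the private _split_path/_construct_store hooks, the download is delegated to _get_obstore_bypass, with optional chunk_size and max_concurrency kwargs forwarded to the parallel reader. A hedged sketch of the call (the s3:// URL is a placeholder, and flyte.storage re-exporting this module's get is assumed):

import asyncio

import flyte.storage as storage


async def main():
    # kwargs are forwarded to ObstoreParallelReader when the obstore
    # bypass is taken; otherwise the regular fsspec path is used.
    local = await storage.get(
        "s3://my-bucket/datasets/train.parquet",  # placeholder source
        "/tmp/train.parquet",
        chunk_size=8 * 1024 * 1024,               # 8 MiB ranges
        max_concurrency=16,                       # concurrent range reads
    )
    print(local)


asyncio.run(main())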
flyte/types/_interface.py
CHANGED
@@ -1,9 +1,9 @@
 import inspect
 from typing import Any, Dict, Iterable, Tuple, Type, cast
 
-from
+from flyteidl2.core import interface_pb2, literals_pb2
+from flyteidl2.task import common_pb2
 
-from flyte._protos.workflow import common_pb2
 from flyte.models import NativeInterface
 
 
flyte/types/_pickle.py
CHANGED
flyte/types/_string_literals.py
CHANGED
@@ -3,11 +3,10 @@ import json
 from typing import Any, Dict, Union
 
 import msgpack
-from
+from flyteidl2.core import literals_pb2
+from flyteidl2.task import common_pb2
 from google.protobuf.json_format import MessageToDict
 
-from flyte._protos.workflow import run_definition_pb2
-
 
 def _primitive_to_string(primitive: literals_pb2.Primitive) -> Any:
     """
@@ -88,9 +87,9 @@ def _dict_literal_repr(lmd: Dict[str, literals_pb2.Literal]) -> Dict[str, Any]:
 def literal_string_repr(
     lm: Union[
         literals_pb2.Literal,
-
-
-
+        common_pb2.NamedLiteral,
+        common_pb2.Inputs,
+        common_pb2.Outputs,
         literals_pb2.LiteralMap,
         Dict[str, literals_pb2.Literal],
     ],
@@ -105,13 +104,13 @@ def literal_string_repr(
             return _literal_string_repr(lm)
         case literals_pb2.LiteralMap():
             return _dict_literal_repr(lm.literals)
-        case
+        case common_pb2.NamedLiteral():
             lmd = {lm.name: lm.value}
             return _dict_literal_repr(lmd)
-        case
+        case common_pb2.Inputs():
             lmd = {n.name: n.value for n in lm.literals}
             return _dict_literal_repr(lmd)
-        case
+        case common_pb2.Outputs():
             lmd = {n.name: n.value for n in lm.literals}
             return _dict_literal_repr(lmd)
         case dict():
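literal_string_repr keeps its match/case dispatch but swaps the deleted run_definition_pb2 wrappers for the common_pb2 equivalents; a pattern like `case common_pb2.Inputs():` is a class pattern that matches on the message type alone and then reads fields off the bound subject. The same dispatch shape on plain dataclasses (stand-ins for the protobuf messages, so this runs anywhere):

from dataclasses import dataclass


@dataclass
class NamedLiteral:  # stand-in for the common_pb2 wrapper message
    name: str
    value: int


@dataclass
class Inputs:        # stand-in: a list of NamedLiteral entries
    literals: list


def string_repr(lm):
    # Class patterns select on type; fields come from the matched object.
    match lm:
        case NamedLiteral():
            return {lm.name: lm.value}
        case Inputs():
            return {n.name: n.value for n in lm.literals}
        case dict():
            return dict(lm)
        case _:
            raise ValueError(f"unsupported: {type(lm).__name__}")


print(string_repr(NamedLiteral("x", 1)))            # {'x': 1}
print(string_repr(Inputs([NamedLiteral("y", 2)])))  # {'y': 2}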