flyte 2.0.0b23__py3-none-any.whl → 2.0.0b25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic; see the registry's advisory page for more details.

Files changed (162)
  1. flyte/__init__.py +11 -2
  2. flyte/_cache/local_cache.py +4 -3
  3. flyte/_code_bundle/_utils.py +3 -3
  4. flyte/_code_bundle/bundle.py +12 -5
  5. flyte/_context.py +4 -1
  6. flyte/_custom_context.py +73 -0
  7. flyte/_deploy.py +31 -7
  8. flyte/_image.py +48 -16
  9. flyte/_initialize.py +69 -26
  10. flyte/_internal/controllers/_local_controller.py +1 -0
  11. flyte/_internal/controllers/_trace.py +1 -1
  12. flyte/_internal/controllers/remote/_action.py +9 -10
  13. flyte/_internal/controllers/remote/_client.py +1 -1
  14. flyte/_internal/controllers/remote/_controller.py +4 -2
  15. flyte/_internal/controllers/remote/_core.py +10 -13
  16. flyte/_internal/controllers/remote/_informer.py +3 -3
  17. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  18. flyte/_internal/imagebuild/docker_builder.py +45 -59
  19. flyte/_internal/imagebuild/remote_builder.py +51 -11
  20. flyte/_internal/imagebuild/utils.py +51 -3
  21. flyte/_internal/runtime/convert.py +39 -18
  22. flyte/_internal/runtime/io.py +8 -7
  23. flyte/_internal/runtime/resources_serde.py +20 -6
  24. flyte/_internal/runtime/reuse.py +1 -1
  25. flyte/_internal/runtime/task_serde.py +7 -10
  26. flyte/_internal/runtime/taskrunner.py +10 -1
  27. flyte/_internal/runtime/trigger_serde.py +13 -13
  28. flyte/_internal/runtime/types_serde.py +1 -1
  29. flyte/_keyring/file.py +2 -2
  30. flyte/_map.py +65 -13
  31. flyte/_pod.py +2 -2
  32. flyte/_resources.py +175 -31
  33. flyte/_run.py +37 -21
  34. flyte/_task.py +27 -6
  35. flyte/_task_environment.py +37 -10
  36. flyte/_utils/module_loader.py +2 -2
  37. flyte/_version.py +3 -3
  38. flyte/cli/_common.py +47 -5
  39. flyte/cli/_create.py +4 -0
  40. flyte/cli/_deploy.py +8 -0
  41. flyte/cli/_get.py +4 -0
  42. flyte/cli/_params.py +4 -4
  43. flyte/cli/_run.py +50 -7
  44. flyte/cli/_update.py +4 -3
  45. flyte/config/_config.py +2 -0
  46. flyte/config/_internal.py +1 -0
  47. flyte/config/_reader.py +3 -3
  48. flyte/errors.py +1 -1
  49. flyte/extend.py +4 -0
  50. flyte/extras/_container.py +6 -1
  51. flyte/git/_config.py +11 -9
  52. flyte/io/_dataframe/basic_dfs.py +1 -1
  53. flyte/io/_dataframe/dataframe.py +12 -8
  54. flyte/io/_dir.py +48 -15
  55. flyte/io/_file.py +48 -11
  56. flyte/models.py +12 -8
  57. flyte/remote/_action.py +18 -16
  58. flyte/remote/_client/_protocols.py +4 -3
  59. flyte/remote/_client/auth/_channel.py +1 -1
  60. flyte/remote/_client/controlplane.py +4 -8
  61. flyte/remote/_data.py +4 -3
  62. flyte/remote/_logs.py +3 -3
  63. flyte/remote/_run.py +5 -5
  64. flyte/remote/_secret.py +20 -13
  65. flyte/remote/_task.py +7 -8
  66. flyte/remote/_trigger.py +25 -27
  67. flyte/storage/_parallel_reader.py +274 -0
  68. flyte/storage/_storage.py +66 -2
  69. flyte/types/_interface.py +2 -2
  70. flyte/types/_pickle.py +1 -1
  71. flyte/types/_string_literals.py +8 -9
  72. flyte/types/_type_engine.py +25 -17
  73. flyte/types/_utils.py +1 -1
  74. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/METADATA +2 -1
  75. flyte-2.0.0b25.dist-info/RECORD +184 -0
  76. flyte/_protos/__init__.py +0 -0
  77. flyte/_protos/common/authorization_pb2.py +0 -66
  78. flyte/_protos/common/authorization_pb2.pyi +0 -108
  79. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  80. flyte/_protos/common/identifier_pb2.py +0 -117
  81. flyte/_protos/common/identifier_pb2.pyi +0 -142
  82. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  83. flyte/_protos/common/identity_pb2.py +0 -48
  84. flyte/_protos/common/identity_pb2.pyi +0 -72
  85. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  86. flyte/_protos/common/list_pb2.py +0 -36
  87. flyte/_protos/common/list_pb2.pyi +0 -71
  88. flyte/_protos/common/list_pb2_grpc.py +0 -4
  89. flyte/_protos/common/policy_pb2.py +0 -37
  90. flyte/_protos/common/policy_pb2.pyi +0 -27
  91. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  92. flyte/_protos/common/role_pb2.py +0 -37
  93. flyte/_protos/common/role_pb2.pyi +0 -53
  94. flyte/_protos/common/role_pb2_grpc.py +0 -4
  95. flyte/_protos/common/runtime_version_pb2.py +0 -28
  96. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  97. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  98. flyte/_protos/imagebuilder/definition_pb2.py +0 -60
  99. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -153
  100. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  101. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  102. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  103. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  104. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  105. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  106. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  107. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  108. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  109. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  110. flyte/_protos/secret/definition_pb2.py +0 -49
  111. flyte/_protos/secret/definition_pb2.pyi +0 -93
  112. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  113. flyte/_protos/secret/payload_pb2.py +0 -62
  114. flyte/_protos/secret/payload_pb2.pyi +0 -94
  115. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  116. flyte/_protos/secret/secret_pb2.py +0 -38
  117. flyte/_protos/secret/secret_pb2.pyi +0 -6
  118. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  119. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  120. flyte/_protos/workflow/common_pb2.py +0 -38
  121. flyte/_protos/workflow/common_pb2.pyi +0 -63
  122. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  123. flyte/_protos/workflow/environment_pb2.py +0 -29
  124. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  125. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  126. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  127. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  128. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  129. flyte/_protos/workflow/queue_service_pb2.py +0 -117
  130. flyte/_protos/workflow/queue_service_pb2.pyi +0 -182
  131. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -206
  132. flyte/_protos/workflow/run_definition_pb2.py +0 -123
  133. flyte/_protos/workflow/run_definition_pb2.pyi +0 -354
  134. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  135. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  136. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  137. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  138. flyte/_protos/workflow/run_service_pb2.py +0 -147
  139. flyte/_protos/workflow/run_service_pb2.pyi +0 -203
  140. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -480
  141. flyte/_protos/workflow/state_service_pb2.py +0 -67
  142. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  143. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  144. flyte/_protos/workflow/task_definition_pb2.py +0 -86
  145. flyte/_protos/workflow/task_definition_pb2.pyi +0 -105
  146. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  147. flyte/_protos/workflow/task_service_pb2.py +0 -61
  148. flyte/_protos/workflow/task_service_pb2.pyi +0 -62
  149. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  150. flyte/_protos/workflow/trigger_definition_pb2.py +0 -66
  151. flyte/_protos/workflow/trigger_definition_pb2.pyi +0 -117
  152. flyte/_protos/workflow/trigger_definition_pb2_grpc.py +0 -4
  153. flyte/_protos/workflow/trigger_service_pb2.py +0 -96
  154. flyte/_protos/workflow/trigger_service_pb2.pyi +0 -110
  155. flyte/_protos/workflow/trigger_service_pb2_grpc.py +0 -281
  156. flyte-2.0.0b23.dist-info/RECORD +0 -262
  157. {flyte-2.0.0b23.data → flyte-2.0.0b25.data}/scripts/debug.py +0 -0
  158. {flyte-2.0.0b23.data → flyte-2.0.0b25.data}/scripts/runtime.py +0 -0
  159. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/WHEEL +0 -0
  160. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/entry_points.txt +0 -0
  161. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/licenses/LICENSE +0 -0
  162. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/top_level.txt +0 -0
flyte/remote/_secret.py CHANGED
@@ -4,9 +4,9 @@ from dataclasses import dataclass
4
4
  from typing import AsyncIterator, Literal, Union
5
5
 
6
6
  import rich.repr
7
+ from flyteidl2.secret import definition_pb2, payload_pb2
7
8
 
8
- from flyte._initialize import ensure_client, get_client, get_common_config
9
- from flyte._protos.secret import definition_pb2, payload_pb2
9
+ from flyte._initialize import ensure_client, get_client, get_init_config
10
10
  from flyte.remote._common import ToJSONMixin
11
11
  from flyte.syncify import syncify
12
12
 
@@ -21,12 +21,19 @@ class Secret(ToJSONMixin):
21
21
  @classmethod
22
22
  async def create(cls, name: str, value: Union[str, bytes], type: SecretTypes = "regular"):
23
23
  ensure_client()
24
- cfg = get_common_config()
25
- secret_type = (
26
- definition_pb2.SecretType.SECRET_TYPE_GENERIC
27
- if type == "regular"
28
- else definition_pb2.SecretType.SECRET_TYPE_IMAGE_PULL_SECRET
29
- )
24
+ cfg = get_init_config()
25
+ project = cfg.project
26
+ domain = cfg.domain
27
+
28
+ if type == "regular":
29
+ secret_type = definition_pb2.SecretType.SECRET_TYPE_GENERIC
30
+
31
+ else:
32
+ secret_type = definition_pb2.SecretType.SECRET_TYPE_IMAGE_PULL_SECRET
33
+ if project or domain:
34
+ raise ValueError(
35
+ f"Project `{project}` or domain `{domain}` should not be set when creating the image pull secret."
36
+ )
30
37
 
31
38
  if isinstance(value, str):
32
39
  secret = definition_pb2.SecretSpec(
@@ -42,8 +49,8 @@ class Secret(ToJSONMixin):
42
49
  request=payload_pb2.CreateSecretRequest(
43
50
  id=definition_pb2.SecretIdentifier(
44
51
  organization=cfg.org,
45
- project=cfg.project,
46
- domain=cfg.domain,
52
+ project=project,
53
+ domain=domain,
47
54
  name=name,
48
55
  ),
49
56
  secret_spec=secret,
@@ -54,7 +61,7 @@ class Secret(ToJSONMixin):
54
61
  @classmethod
55
62
  async def get(cls, name: str) -> Secret:
56
63
  ensure_client()
57
- cfg = get_common_config()
64
+ cfg = get_init_config()
58
65
  resp = await get_client().secrets_service.GetSecret(
59
66
  request=payload_pb2.GetSecretRequest(
60
67
  id=definition_pb2.SecretIdentifier(
@@ -71,7 +78,7 @@ class Secret(ToJSONMixin):
71
78
  @classmethod
72
79
  async def listall(cls, limit: int = 100) -> AsyncIterator[Secret]:
73
80
  ensure_client()
74
- cfg = get_common_config()
81
+ cfg = get_init_config()
75
82
  token = None
76
83
  while True:
77
84
  resp = await get_client().secrets_service.ListSecrets( # type: ignore
@@ -93,7 +100,7 @@ class Secret(ToJSONMixin):
93
100
  @classmethod
94
101
  async def delete(cls, name):
95
102
  ensure_client()
96
- cfg = get_common_config()
103
+ cfg = get_init_config()
97
104
  await get_client().secrets_service.DeleteSecret( # type: ignore
98
105
  request=payload_pb2.DeleteSecretRequest(
99
106
  id=definition_pb2.SecretIdentifier(
flyte/remote/_task.py CHANGED
@@ -6,19 +6,18 @@ from dataclasses import dataclass
6
6
  from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Iterator, Literal, Optional, Tuple, Union, cast
7
7
 
8
8
  import rich.repr
9
- from flyteidl.core import literals_pb2
10
- from google.protobuf import timestamp
9
+ from flyteidl2.common import identifier_pb2, list_pb2
10
+ from flyteidl2.core import literals_pb2
11
+ from flyteidl2.task import task_definition_pb2, task_service_pb2
11
12
 
12
13
  import flyte
13
14
  import flyte.errors
14
15
  from flyte._cache.cache import CacheBehavior
15
16
  from flyte._context import internal_ctx
16
- from flyte._initialize import ensure_client, get_client, get_common_config
17
+ from flyte._initialize import ensure_client, get_client, get_init_config
17
18
  from flyte._internal.runtime.resources_serde import get_proto_resources
18
19
  from flyte._internal.runtime.task_serde import get_proto_retry_strategy, get_proto_timeout, get_security_context
19
20
  from flyte._logging import logger
20
- from flyte._protos.common import identifier_pb2, list_pb2
21
- from flyte._protos.workflow import task_definition_pb2, task_service_pb2
22
21
  from flyte.models import NativeInterface
23
22
  from flyte.syncify import syncify
24
23
 
@@ -35,7 +34,7 @@ def _repr_task_metadata(metadata: task_definition_pb2.TaskMetadata) -> rich.repr
35
34
  else:
36
35
  yield "deployed_by", f"App: {metadata.deployed_by.application.spec.name}"
37
36
  yield "short_name", metadata.short_name
38
- yield "deployed_at", timestamp.to_datetime(metadata.deployed_at)
37
+ yield "deployed_at", metadata.deployed_at.ToDatetime()
39
38
  yield "environment_name", metadata.environment_name
40
39
 
41
40
 
@@ -151,7 +150,7 @@ class TaskDetails(ToJSONMixin):
151
150
  if ctx is None:
152
151
  raise ValueError("auto_version=current can only be used within a task context.")
153
152
  _version = ctx.version
154
- cfg = get_common_config()
153
+ cfg = get_init_config()
155
154
  task_id = task_definition_pb2.TaskIdentifier(
156
155
  org=cfg.org,
157
156
  project=project or cfg.project,
@@ -451,7 +450,7 @@ class Task(ToJSONMixin):
451
450
  sort_pb2 = list_pb2.Sort(
452
451
  key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
453
452
  )
454
- cfg = get_common_config()
453
+ cfg = get_init_config()
455
454
  filters = []
456
455
  if by_task_name:
457
456
  filters.append(
flyte/remote/_trigger.py CHANGED
@@ -5,12 +5,13 @@ from functools import cached_property
5
5
  from typing import AsyncIterator
6
6
 
7
7
  import grpc.aio
8
+ from flyteidl2.common import identifier_pb2, list_pb2
9
+ from flyteidl2.task import common_pb2, task_definition_pb2
10
+ from flyteidl2.trigger import trigger_definition_pb2, trigger_service_pb2
8
11
 
9
12
  import flyte
10
- from flyte._initialize import ensure_client, get_client, get_common_config
13
+ from flyte._initialize import ensure_client, get_client, get_init_config
11
14
  from flyte._internal.runtime import trigger_serde
12
- from flyte._protos.common import identifier_pb2, list_pb2
13
- from flyte._protos.workflow import common_pb2, task_definition_pb2, trigger_definition_pb2, trigger_service_pb2
14
15
  from flyte.syncify import syncify
15
16
 
16
17
  from ._common import ToJSONMixin
@@ -28,7 +29,7 @@ class TriggerDetails(ToJSONMixin):
28
29
  Retrieve detailed information about a specific trigger by its name.
29
30
  """
30
31
  ensure_client()
31
- cfg = get_common_config()
32
+ cfg = get_init_config()
32
33
  resp = await get_client().trigger_service.GetTriggerDetails(
33
34
  request=trigger_service_pb2.GetTriggerDetailsRequest(
34
35
  name=identifier_pb2.TriggerName(
@@ -101,7 +102,7 @@ class Trigger(ToJSONMixin):
101
102
  :param task_name: Optional name of the task to associate with the trigger.
102
103
  """
103
104
  ensure_client()
104
- cfg = get_common_config()
105
+ cfg = get_init_config()
105
106
 
106
107
  # Fetch the task to ensure it exists and to get its input definitions
107
108
  try:
@@ -121,15 +122,12 @@ class Trigger(ToJSONMixin):
121
122
 
122
123
  resp = await get_client().trigger_service.DeployTrigger(
123
124
  request=trigger_service_pb2.DeployTriggerRequest(
124
- id=identifier_pb2.TriggerIdentifier(
125
- name=identifier_pb2.TriggerName(
126
- name=trigger.name,
127
- task_name=task_name,
128
- org=cfg.org,
129
- project=cfg.project,
130
- domain=cfg.domain,
131
- ),
132
- revision=1,
125
+ name=identifier_pb2.TriggerName(
126
+ name=trigger.name,
127
+ task_name=task_name,
128
+ org=cfg.org,
129
+ project=cfg.project,
130
+ domain=cfg.domain,
133
131
  ),
134
132
  spec=trigger_definition_pb2.TriggerSpec(
135
133
  active=task_trigger.spec.active,
@@ -155,7 +153,7 @@ class Trigger(ToJSONMixin):
155
153
  """
156
154
  Retrieve a trigger by its name and associated task name.
157
155
  """
158
- return await TriggerDetails.get(name=name, task_name=task_name)
156
+ return await TriggerDetails.get.aio(name=name, task_name=task_name)
159
157
 
160
158
  @syncify
161
159
  @classmethod
@@ -166,9 +164,9 @@ class Trigger(ToJSONMixin):
166
164
  List all triggers associated with a specific task or all tasks if no task name is provided.
167
165
  """
168
166
  ensure_client()
169
- cfg = get_common_config()
167
+ cfg = get_init_config()
170
168
  token = None
171
- # task_name_id = None TODO: implement listing by task name only
169
+ task_name_id = None
172
170
  project_id = None
173
171
  task_id = None
174
172
  if task_name and task_version:
@@ -179,13 +177,13 @@ class Trigger(ToJSONMixin):
179
177
  org=cfg.org,
180
178
  version=task_version,
181
179
  )
182
- # elif task_name: TODO: implement listing by task name only
183
- # task_name_id = task_definition_pb2.TaskName(
184
- # name=task_name,
185
- # project=cfg.project,
186
- # domain=cfg.domain,
187
- # org=cfg.org,
188
- # )
180
+ elif task_name:
181
+ task_name_id = task_definition_pb2.TaskName(
182
+ name=task_name,
183
+ project=cfg.project,
184
+ domain=cfg.domain,
185
+ org=cfg.org,
186
+ )
189
187
  else:
190
188
  project_id = identifier_pb2.ProjectIdentifier(
191
189
  organization=cfg.org,
@@ -198,7 +196,7 @@ class Trigger(ToJSONMixin):
198
196
  request=trigger_service_pb2.ListTriggersRequest(
199
197
  project_id=project_id,
200
198
  task_id=task_id,
201
- # task_name=task_name_id,
199
+ task_name=task_name_id,
202
200
  request=list_pb2.ListRequest(
203
201
  limit=limit,
204
202
  token=token,
@@ -218,7 +216,7 @@ class Trigger(ToJSONMixin):
218
216
  Pause a trigger by its name and associated task name.
219
217
  """
220
218
  ensure_client()
221
- cfg = get_common_config()
219
+ cfg = get_init_config()
222
220
  await get_client().trigger_service.UpdateTriggers(
223
221
  request=trigger_service_pb2.UpdateTriggersRequest(
224
222
  names=[
@@ -241,7 +239,7 @@ class Trigger(ToJSONMixin):
241
239
  Delete a trigger by its name.
242
240
  """
243
241
  ensure_client()
244
- cfg = get_common_config()
242
+ cfg = get_init_config()
245
243
  await get_client().trigger_service.DeleteTriggers(
246
244
  request=trigger_service_pb2.DeleteTriggersRequest(
247
245
  names=[
@@ -0,0 +1,274 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import dataclasses
5
+ import io
6
+ import os
7
+ import pathlib
8
+ import sys
9
+ import tempfile
10
+ import typing
11
+ from typing import Any, Hashable, Protocol
12
+
13
+ import aiofiles
14
+ import aiofiles.os
15
+ import obstore
16
+
17
+ if typing.TYPE_CHECKING:
18
+ from obstore import Bytes, ObjectMeta
19
+ from obstore.store import ObjectStore
20
+
21
+ CHUNK_SIZE = int(os.getenv("FLYTE_IO_CHUNK_SIZE", str(16 * 1024 * 1024)))
22
+ MAX_CONCURRENCY = int(os.getenv("FLYTE_IO_MAX_CONCURRENCY", str(32)))
23
+
24
+
25
+ class DownloadQueueEmpty(RuntimeError):
26
+ pass
27
+
28
+
29
+ class BufferProtocol(Protocol):
30
+ async def write(self, offset, length, value: Bytes) -> None: ...
31
+
32
+ async def read(self) -> memoryview: ...
33
+
34
+ @property
35
+ def complete(self) -> bool: ...
36
+
37
+
38
+ @dataclasses.dataclass
39
+ class _MemoryBuffer:
40
+ arr: bytearray
41
+ pending: int
42
+ _closed: bool = False
43
+
44
+ async def write(self, offset: int, length: int, value: Bytes) -> None:
45
+ self.arr[offset : offset + length] = memoryview(value)
46
+ self.pending -= length
47
+
48
+ async def read(self) -> memoryview:
49
+ return memoryview(self.arr)
50
+
51
+ @property
52
+ def complete(self) -> bool:
53
+ return self.pending == 0
54
+
55
+ @classmethod
56
+ def new(cls, size):
57
+ return cls(arr=bytearray(size), pending=size)
58
+
59
+
60
+ @dataclasses.dataclass
61
+ class _FileBuffer:
62
+ path: pathlib.Path
63
+ pending: int
64
+ _handle: io.FileIO | None = None
65
+ _closed: bool = False
66
+
67
+ async def write(self, offset: int, length: int, value: Bytes) -> None:
68
+ async with aiofiles.open(self.path, mode="r+b") as f:
69
+ await f.seek(offset)
70
+ await f.write(value)
71
+ self.pending -= length
72
+
73
+ async def read(self) -> memoryview:
74
+ async with aiofiles.open(self.path, mode="rb") as f:
75
+ return memoryview(await f.read())
76
+
77
+ @property
78
+ def complete(self) -> bool:
79
+ return self.pending == 0
80
+
81
+ @classmethod
82
+ def new(cls, path: pathlib.Path, size: int):
83
+ path.parent.mkdir(parents=True, exist_ok=True)
84
+ path.touch()
85
+ return cls(path=path, pending=size)
86
+
87
+
88
+ @dataclasses.dataclass
89
+ class Chunk:
90
+ offset: int
91
+ length: int
92
+
93
+
94
+ @dataclasses.dataclass
95
+ class Source:
96
+ id: Hashable
97
+ path: pathlib.Path # Should be str, represents the fully qualified prefix of a file (no bucket)
98
+ length: int
99
+ metadata: Any | None = None
100
+
101
+
102
+ @dataclasses.dataclass
103
+ class DownloadTask:
104
+ source: Source
105
+ chunk: Chunk
106
+ target: pathlib.Path | None = None
107
+
108
+
109
+ class ObstoreParallelReader:
110
+ def __init__(
111
+ self,
112
+ store: ObjectStore,
113
+ *,
114
+ chunk_size=CHUNK_SIZE,
115
+ max_concurrency=MAX_CONCURRENCY,
116
+ ):
117
+ self._store = store
118
+ self._chunk_size = chunk_size
119
+ self._max_concurrency = max_concurrency
120
+
121
+ def _chunks(self, size) -> typing.Iterator[tuple[int, int]]:
122
+ cs = self._chunk_size
123
+ for offset in range(0, size, cs):
124
+ length = min(cs, size - offset)
125
+ yield offset, length
126
+
127
+ async def _as_completed(self, gen: typing.AsyncGenerator[DownloadTask, None], transformer=None):
128
+ inq: asyncio.Queue = asyncio.Queue(self._max_concurrency * 2)
129
+ outq: asyncio.Queue = asyncio.Queue()
130
+ sentinel = object()
131
+ done = asyncio.Event()
132
+
133
+ active: dict[Hashable, _FileBuffer | _MemoryBuffer] = {}
134
+
135
+ async def _fill():
136
+ # Helper function to fill the input queue, this is because the generator is async because it does list/head
137
+ # calls on the object store which are async.
138
+ try:
139
+ counter = 0
140
+ async for task in gen:
141
+ if task.source.id not in active:
142
+ active[task.source.id] = (
143
+ _FileBuffer.new(task.target, task.source.length)
144
+ if task.target is not None
145
+ else _MemoryBuffer.new(task.source.length)
146
+ )
147
+ await inq.put(task)
148
+ counter += 1
149
+ await inq.put(sentinel)
150
+ if counter == 0:
151
+ raise DownloadQueueEmpty
152
+ except asyncio.CancelledError:
153
+ # document why we need to swallow this
154
+ pass
155
+
156
+ async def _worker():
157
+ try:
158
+ while not done.is_set():
159
+ task = await inq.get()
160
+ if task is sentinel:
161
+ inq.put_nowait(sentinel)
162
+ break
163
+ chunk_source_offset = task.chunk.offset
164
+ buf = active[task.source.id]
165
+ data_to_write = await obstore.get_range_async(
166
+ self._store,
167
+ str(task.source.path),
168
+ start=chunk_source_offset,
169
+ end=chunk_source_offset + task.chunk.length,
170
+ )
171
+ await buf.write(
172
+ task.chunk.offset,
173
+ task.chunk.length,
174
+ data_to_write,
175
+ )
176
+ if not buf.complete:
177
+ continue
178
+ if transformer is not None:
179
+ result = await transformer(buf)
180
+ elif task.target is not None:
181
+ result = task.target
182
+ else:
183
+ result = task.source
184
+ outq.put_nowait((task.source.id, result))
185
+ del active[task.source.id]
186
+ except asyncio.CancelledError:
187
+ pass
188
+ finally:
189
+ done.set()
190
+
191
+ # Yield results as they are completed
192
+ if sys.version_info >= (3, 11):
193
+ async with asyncio.TaskGroup() as tg:
194
+ tg.create_task(_fill())
195
+ for _ in range(self._max_concurrency):
196
+ tg.create_task(_worker())
197
+ while not done.is_set():
198
+ yield await outq.get()
199
+ else:
200
+ fill_task = asyncio.create_task(_fill())
201
+ worker_tasks = [asyncio.create_task(_worker()) for _ in range(self._max_concurrency)]
202
+ try:
203
+ while not done.is_set():
204
+ yield await outq.get()
205
+ except Exception as e:
206
+ if not fill_task.done():
207
+ fill_task.cancel()
208
+ for wt in worker_tasks:
209
+ if not wt.done():
210
+ wt.cancel()
211
+ raise e
212
+ finally:
213
+ await asyncio.gather(fill_task, *worker_tasks, return_exceptions=True)
214
+
215
+ # Drain the output queue
216
+ try:
217
+ while True:
218
+ yield outq.get_nowait()
219
+ except asyncio.QueueEmpty:
220
+ pass
221
+
222
+ async def download_files(
223
+ self, src_prefix: pathlib.Path, target_prefix: pathlib.Path, *paths, destination_file_name: str | None = None
224
+ ) -> None:
225
+ """
226
+ src_prefix: Prefix you want to download from in the object store, not including the bucket name, nor file name.
227
+ Should be replaced with string
228
+ target_prefix: Local directory to download to
229
+ paths: Specific paths (relative to src_prefix) to download. If empty, download everything
230
+ """
231
+
232
+ async def _list_downloadable() -> typing.AsyncGenerator[ObjectMeta, None]:
233
+ if paths:
234
+ # For specific file paths, use async head
235
+ for path_ in paths:
236
+ path = src_prefix / path_
237
+ x = await obstore.head_async(self._store, str(path))
238
+ yield x
239
+ return
240
+
241
+ # Use obstore.list() for recursive listing (all files in all subdirectories)
242
+ # obstore.list() returns an async iterator that yields batches (lists) of objects
243
+ async for batch in obstore.list(self._store, prefix=str(src_prefix)):
244
+ for obj in batch:
245
+ yield obj
246
+
247
+ async def _gen(tmp_dir: str) -> typing.AsyncGenerator[DownloadTask, None]:
248
+ async for obj in _list_downloadable():
249
+ path = pathlib.Path(obj["path"]) # e.g. Path(prefix/file.txt), needs to be changed to str.
250
+ size = obj["size"]
251
+ source = Source(id=path, path=path, length=size)
252
+ # Strip src_prefix from path for destination
253
+ rel_path = path.relative_to(src_prefix) # doesn't work on windows
254
+ for offset, length in self._chunks(size):
255
+ yield DownloadTask(
256
+ source=source,
257
+ target=tmp_dir / rel_path, # doesn't work on windows
258
+ chunk=Chunk(offset, length),
259
+ )
260
+
261
+ def _transform_decorator(tmp_dir: str):
262
+ async def _transformer(buf: _FileBuffer) -> None:
263
+ if len(paths) == 1 and destination_file_name is not None:
264
+ target = target_prefix / destination_file_name
265
+ else:
266
+ target = target_prefix / buf.path.relative_to(tmp_dir)
267
+ await aiofiles.os.makedirs(target.parent, exist_ok=True)
268
+ return await aiofiles.os.replace(buf.path, target) # mv buf.path target
269
+
270
+ return _transformer
271
+
272
+ with tempfile.TemporaryDirectory() as temporary_dir:
273
+ async for _ in self._as_completed(_gen(temporary_dir), transformer=_transform_decorator(temporary_dir)):
274
+ pass
flyte/storage/_storage.py CHANGED
@@ -147,12 +147,76 @@ def _get_anonymous_filesystem(from_path):
147
147
  return get_underlying_filesystem(get_protocol(from_path), anonymous=True, asynchronous=True)
148
148
 
149
149
 
150
+ async def _get_obstore_bypass(from_path: str, to_path: str | pathlib.Path, recursive: bool = False, **kwargs) -> str:
151
+ from obstore.store import ObjectStore
152
+
153
+ from flyte.storage._parallel_reader import ObstoreParallelReader
154
+
155
+ fs = get_underlying_filesystem(path=from_path)
156
+ bucket, prefix = fs._split_path(from_path) # pylint: disable=W0212
157
+ store: ObjectStore = fs._construct_store(bucket)
158
+
159
+ download_kwargs = {}
160
+ if "chunk_size" in kwargs:
161
+ download_kwargs["chunk_size"] = kwargs["chunk_size"]
162
+ if "max_concurrency" in kwargs:
163
+ download_kwargs["max_concurrency"] = kwargs["max_concurrency"]
164
+
165
+ reader = ObstoreParallelReader(store, **download_kwargs)
166
+ target_path = pathlib.Path(to_path) if isinstance(to_path, str) else to_path
167
+
168
+ # if recursive, just download the prefix to the target path
169
+ if recursive:
170
+ logger.debug(f"Downloading recursively {prefix=} to {target_path=}")
171
+ await reader.download_files(
172
+ prefix,
173
+ target_path,
174
+ )
175
+ return str(to_path)
176
+
177
+ # if not recursive, we need to split out the file name from the prefix
178
+ else:
179
+ path_for_reader = pathlib.Path(prefix).name
180
+ final_prefix = pathlib.Path(prefix).parent
181
+ logger.debug(f"Downloading single file {final_prefix=}, {path_for_reader=} to {target_path=}")
182
+ await reader.download_files(
183
+ final_prefix,
184
+ target_path.parent,
185
+ path_for_reader,
186
+ destination_file_name=target_path.name,
187
+ )
188
+ return str(target_path)
189
+
190
+
150
191
  async def get(from_path: str, to_path: Optional[str | pathlib.Path] = None, recursive: bool = False, **kwargs) -> str:
151
192
  if not to_path:
152
- name = pathlib.Path(from_path).name
193
+ name = pathlib.Path(from_path).name # may need to be adjusted for windows
153
194
  to_path = get_random_local_path(file_path_or_file_name=name)
154
195
  logger.debug(f"Storing file from {from_path} to {to_path}")
196
+ else:
197
+ # Only apply directory logic for single files (not recursive)
198
+ if not recursive:
199
+ to_path_str = str(to_path)
200
+ # Check for trailing separator BEFORE converting to Path (which normalizes and removes it)
201
+ ends_with_sep = to_path_str.endswith(os.sep)
202
+ to_path_obj = pathlib.Path(to_path)
203
+
204
+ # If path ends with os.sep or is an existing directory, append source filename
205
+ if ends_with_sep or (to_path_obj.exists() and to_path_obj.is_dir()):
206
+ source_filename = pathlib.Path(from_path).name # may need to be adjusted for windows
207
+ to_path = to_path_obj / source_filename
208
+ # For recursive=True, keep to_path as-is (it's the destination directory for contents)
209
+
155
210
  file_system = get_underlying_filesystem(path=from_path)
211
+
212
+ # Check if we should use obstore bypass
213
+ if (
214
+ _is_obstore_supported_protocol(file_system.protocol)
215
+ and hasattr(file_system, "_split_path")
216
+ and hasattr(file_system, "_construct_store")
217
+ ):
218
+ return await _get_obstore_bypass(from_path, to_path, recursive, **kwargs)
219
+
156
220
  try:
157
221
  return await _get_from_filesystem(file_system, from_path, to_path, recursive=recursive, **kwargs)
158
222
  except (OSError, GenericError) as oe:
@@ -195,7 +259,7 @@ async def put(from_path: str, to_path: Optional[str] = None, recursive: bool = F
195
259
  from flyte._context import internal_ctx
196
260
 
197
261
  ctx = internal_ctx()
198
- name = pathlib.Path(from_path).name if not recursive else None # don't pass a name for folders
262
+ name = pathlib.Path(from_path).name
199
263
  to_path = ctx.raw_data.get_random_remote_path(file_name=name)
200
264
 
201
265
  file_system = get_underlying_filesystem(path=to_path)
flyte/types/_interface.py CHANGED
@@ -1,9 +1,9 @@
1
1
  import inspect
2
2
  from typing import Any, Dict, Iterable, Tuple, Type, cast
3
3
 
4
- from flyteidl.core import interface_pb2, literals_pb2
4
+ from flyteidl2.core import interface_pb2, literals_pb2
5
+ from flyteidl2.task import common_pb2
5
6
 
6
- from flyte._protos.workflow import common_pb2
7
7
  from flyte.models import NativeInterface
8
8
 
9
9
 
flyte/types/_pickle.py CHANGED
@@ -6,7 +6,7 @@ from typing import Type
6
6
 
7
7
  import aiofiles
8
8
  import cloudpickle
9
- from flyteidl.core import literals_pb2, types_pb2
9
+ from flyteidl2.core import literals_pb2, types_pb2
10
10
 
11
11
  import flyte.storage as storage
12
12
 
@@ -3,11 +3,10 @@ import json
3
3
  from typing import Any, Dict, Union
4
4
 
5
5
  import msgpack
6
- from flyteidl.core import literals_pb2
6
+ from flyteidl2.core import literals_pb2
7
+ from flyteidl2.task import common_pb2
7
8
  from google.protobuf.json_format import MessageToDict
8
9
 
9
- from flyte._protos.workflow import run_definition_pb2
10
-
11
10
 
12
11
  def _primitive_to_string(primitive: literals_pb2.Primitive) -> Any:
13
12
  """
@@ -88,9 +87,9 @@ def _dict_literal_repr(lmd: Dict[str, literals_pb2.Literal]) -> Dict[str, Any]:
88
87
  def literal_string_repr(
89
88
  lm: Union[
90
89
  literals_pb2.Literal,
91
- run_definition_pb2.NamedLiteral,
92
- run_definition_pb2.Inputs,
93
- run_definition_pb2.Outputs,
90
+ common_pb2.NamedLiteral,
91
+ common_pb2.Inputs,
92
+ common_pb2.Outputs,
94
93
  literals_pb2.LiteralMap,
95
94
  Dict[str, literals_pb2.Literal],
96
95
  ],
@@ -105,13 +104,13 @@ def literal_string_repr(
105
104
  return _literal_string_repr(lm)
106
105
  case literals_pb2.LiteralMap():
107
106
  return _dict_literal_repr(lm.literals)
108
- case run_definition_pb2.NamedLiteral():
107
+ case common_pb2.NamedLiteral():
109
108
  lmd = {lm.name: lm.value}
110
109
  return _dict_literal_repr(lmd)
111
- case run_definition_pb2.Inputs():
110
+ case common_pb2.Inputs():
112
111
  lmd = {n.name: n.value for n in lm.literals}
113
112
  return _dict_literal_repr(lmd)
114
- case run_definition_pb2.Outputs():
113
+ case common_pb2.Outputs():
115
114
  lmd = {n.name: n.value for n in lm.literals}
116
115
  return _dict_literal_repr(lmd)
117
116
  case dict():