flyte 2.0.0b22__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/runtime.py +43 -5
  3. flyte/_cache/cache.py +4 -2
  4. flyte/_cache/local_cache.py +216 -0
  5. flyte/_code_bundle/_ignore.py +1 -1
  6. flyte/_code_bundle/_packaging.py +4 -4
  7. flyte/_code_bundle/_utils.py +14 -8
  8. flyte/_code_bundle/bundle.py +13 -5
  9. flyte/_constants.py +1 -0
  10. flyte/_context.py +4 -1
  11. flyte/_custom_context.py +73 -0
  12. flyte/_debug/constants.py +0 -1
  13. flyte/_debug/vscode.py +6 -1
  14. flyte/_deploy.py +223 -59
  15. flyte/_environment.py +5 -0
  16. flyte/_excepthook.py +1 -1
  17. flyte/_image.py +144 -82
  18. flyte/_initialize.py +95 -12
  19. flyte/_interface.py +2 -0
  20. flyte/_internal/controllers/_local_controller.py +65 -24
  21. flyte/_internal/controllers/_trace.py +1 -1
  22. flyte/_internal/controllers/remote/_action.py +13 -11
  23. flyte/_internal/controllers/remote/_client.py +1 -1
  24. flyte/_internal/controllers/remote/_controller.py +9 -4
  25. flyte/_internal/controllers/remote/_core.py +16 -16
  26. flyte/_internal/controllers/remote/_informer.py +4 -4
  27. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  28. flyte/_internal/imagebuild/docker_builder.py +139 -84
  29. flyte/_internal/imagebuild/image_builder.py +7 -13
  30. flyte/_internal/imagebuild/remote_builder.py +65 -13
  31. flyte/_internal/imagebuild/utils.py +51 -3
  32. flyte/_internal/resolvers/_task_module.py +5 -38
  33. flyte/_internal/resolvers/default.py +2 -2
  34. flyte/_internal/runtime/convert.py +42 -20
  35. flyte/_internal/runtime/entrypoints.py +24 -1
  36. flyte/_internal/runtime/io.py +21 -8
  37. flyte/_internal/runtime/resources_serde.py +20 -6
  38. flyte/_internal/runtime/reuse.py +1 -1
  39. flyte/_internal/runtime/rusty.py +20 -5
  40. flyte/_internal/runtime/task_serde.py +33 -27
  41. flyte/_internal/runtime/taskrunner.py +10 -1
  42. flyte/_internal/runtime/trigger_serde.py +160 -0
  43. flyte/_internal/runtime/types_serde.py +1 -1
  44. flyte/_keyring/file.py +39 -9
  45. flyte/_logging.py +79 -12
  46. flyte/_map.py +31 -12
  47. flyte/_module.py +70 -0
  48. flyte/_pod.py +2 -2
  49. flyte/_resources.py +213 -31
  50. flyte/_run.py +107 -41
  51. flyte/_task.py +66 -10
  52. flyte/_task_environment.py +96 -24
  53. flyte/_task_plugins.py +4 -2
  54. flyte/_trigger.py +1000 -0
  55. flyte/_utils/__init__.py +2 -1
  56. flyte/_utils/asyn.py +3 -1
  57. flyte/_utils/docker_credentials.py +173 -0
  58. flyte/_utils/module_loader.py +17 -2
  59. flyte/_version.py +3 -3
  60. flyte/cli/_abort.py +3 -3
  61. flyte/cli/_build.py +1 -3
  62. flyte/cli/_common.py +78 -7
  63. flyte/cli/_create.py +178 -3
  64. flyte/cli/_delete.py +23 -1
  65. flyte/cli/_deploy.py +49 -11
  66. flyte/cli/_get.py +79 -34
  67. flyte/cli/_params.py +8 -6
  68. flyte/cli/_plugins.py +209 -0
  69. flyte/cli/_run.py +127 -11
  70. flyte/cli/_serve.py +64 -0
  71. flyte/cli/_update.py +37 -0
  72. flyte/cli/_user.py +17 -0
  73. flyte/cli/main.py +30 -4
  74. flyte/config/_config.py +2 -0
  75. flyte/config/_internal.py +1 -0
  76. flyte/config/_reader.py +3 -3
  77. flyte/connectors/__init__.py +11 -0
  78. flyte/connectors/_connector.py +270 -0
  79. flyte/connectors/_server.py +197 -0
  80. flyte/connectors/utils.py +135 -0
  81. flyte/errors.py +10 -1
  82. flyte/extend.py +8 -1
  83. flyte/extras/_container.py +6 -1
  84. flyte/git/_config.py +11 -9
  85. flyte/io/__init__.py +2 -0
  86. flyte/io/_dataframe/__init__.py +2 -0
  87. flyte/io/_dataframe/basic_dfs.py +1 -1
  88. flyte/io/_dataframe/dataframe.py +12 -8
  89. flyte/io/_dir.py +551 -120
  90. flyte/io/_file.py +538 -141
  91. flyte/models.py +57 -12
  92. flyte/remote/__init__.py +6 -1
  93. flyte/remote/_action.py +18 -16
  94. flyte/remote/_client/_protocols.py +39 -4
  95. flyte/remote/_client/auth/_channel.py +10 -6
  96. flyte/remote/_client/controlplane.py +17 -5
  97. flyte/remote/_console.py +3 -2
  98. flyte/remote/_data.py +4 -3
  99. flyte/remote/_logs.py +3 -3
  100. flyte/remote/_run.py +47 -7
  101. flyte/remote/_secret.py +26 -17
  102. flyte/remote/_task.py +21 -9
  103. flyte/remote/_trigger.py +306 -0
  104. flyte/remote/_user.py +33 -0
  105. flyte/storage/__init__.py +6 -1
  106. flyte/storage/_parallel_reader.py +274 -0
  107. flyte/storage/_storage.py +185 -103
  108. flyte/types/__init__.py +16 -0
  109. flyte/types/_interface.py +2 -2
  110. flyte/types/_pickle.py +17 -4
  111. flyte/types/_string_literals.py +8 -9
  112. flyte/types/_type_engine.py +26 -19
  113. flyte/types/_utils.py +1 -1
  114. {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/runtime.py +43 -5
  115. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/METADATA +8 -1
  116. flyte-2.0.0b30.dist-info/RECORD +192 -0
  117. flyte/_protos/__init__.py +0 -0
  118. flyte/_protos/common/authorization_pb2.py +0 -66
  119. flyte/_protos/common/authorization_pb2.pyi +0 -108
  120. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  121. flyte/_protos/common/identifier_pb2.py +0 -99
  122. flyte/_protos/common/identifier_pb2.pyi +0 -120
  123. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  124. flyte/_protos/common/identity_pb2.py +0 -48
  125. flyte/_protos/common/identity_pb2.pyi +0 -72
  126. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  127. flyte/_protos/common/list_pb2.py +0 -36
  128. flyte/_protos/common/list_pb2.pyi +0 -71
  129. flyte/_protos/common/list_pb2_grpc.py +0 -4
  130. flyte/_protos/common/policy_pb2.py +0 -37
  131. flyte/_protos/common/policy_pb2.pyi +0 -27
  132. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  133. flyte/_protos/common/role_pb2.py +0 -37
  134. flyte/_protos/common/role_pb2.pyi +0 -53
  135. flyte/_protos/common/role_pb2_grpc.py +0 -4
  136. flyte/_protos/common/runtime_version_pb2.py +0 -28
  137. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  138. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  139. flyte/_protos/imagebuilder/definition_pb2.py +0 -60
  140. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -153
  141. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  142. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  143. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  144. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  145. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  146. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  147. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  148. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  149. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  150. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  151. flyte/_protos/secret/definition_pb2.py +0 -49
  152. flyte/_protos/secret/definition_pb2.pyi +0 -93
  153. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  154. flyte/_protos/secret/payload_pb2.py +0 -62
  155. flyte/_protos/secret/payload_pb2.pyi +0 -94
  156. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  157. flyte/_protos/secret/secret_pb2.py +0 -38
  158. flyte/_protos/secret/secret_pb2.pyi +0 -6
  159. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  160. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  161. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  162. flyte/_protos/workflow/common_pb2.py +0 -27
  163. flyte/_protos/workflow/common_pb2.pyi +0 -14
  164. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  165. flyte/_protos/workflow/environment_pb2.py +0 -29
  166. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  167. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  168. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  169. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  170. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  171. flyte/_protos/workflow/queue_service_pb2.py +0 -111
  172. flyte/_protos/workflow/queue_service_pb2.pyi +0 -168
  173. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  174. flyte/_protos/workflow/run_definition_pb2.py +0 -123
  175. flyte/_protos/workflow/run_definition_pb2.pyi +0 -352
  176. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  177. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  178. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  179. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  180. flyte/_protos/workflow/run_service_pb2.py +0 -137
  181. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  182. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  183. flyte/_protos/workflow/state_service_pb2.py +0 -67
  184. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  185. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  186. flyte/_protos/workflow/task_definition_pb2.py +0 -82
  187. flyte/_protos/workflow/task_definition_pb2.pyi +0 -88
  188. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  189. flyte/_protos/workflow/task_service_pb2.py +0 -60
  190. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  191. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  192. flyte-2.0.0b22.dist-info/RECORD +0 -250
  193. {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/debug.py +0 -0
  194. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  195. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +0 -0
  196. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  197. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/__init__.py CHANGED
@@ -9,15 +9,17 @@ import sys
9
9
  from ._build import build
10
10
  from ._cache import Cache, CachePolicy, CacheRequest
11
11
  from ._context import ctx
12
+ from ._custom_context import custom_context, get_custom_context
12
13
  from ._deploy import build_images, deploy
13
14
  from ._environment import Environment
14
15
  from ._excepthook import custom_excepthook
15
16
  from ._group import group
16
17
  from ._image import Image
17
- from ._initialize import init, init_from_config
18
+ from ._initialize import current_domain, init, init_from_config
19
+ from ._logging import logger
18
20
  from ._map import map
19
21
  from ._pod import PodTemplate
20
- from ._resources import GPU, TPU, Device, Resources
22
+ from ._resources import AMD_GPU, GPU, HABANA_GAUDI, TPU, Device, DeviceClass, Neuron, Resources
21
23
  from ._retry import RetryStrategy
22
24
  from ._reusable_environment import ReusePolicy
23
25
  from ._run import run, with_runcontext
@@ -25,6 +27,7 @@ from ._secret import Secret, SecretRequest
25
27
  from ._task_environment import TaskEnvironment
26
28
  from ._timeout import Timeout, TimeoutType
27
29
  from ._trace import trace
30
+ from ._trigger import Cron, FixedRate, Trigger, TriggerTime
28
31
  from ._version import __version__
29
32
 
30
33
  sys.excepthook = custom_excepthook
@@ -59,14 +62,20 @@ def version() -> str:
59
62
 
60
63
 
61
64
  __all__ = [
65
+ "AMD_GPU",
62
66
  "GPU",
67
+ "HABANA_GAUDI",
63
68
  "TPU",
64
69
  "Cache",
65
70
  "CachePolicy",
66
71
  "CacheRequest",
72
+ "Cron",
67
73
  "Device",
74
+ "DeviceClass",
68
75
  "Environment",
76
+ "FixedRate",
69
77
  "Image",
78
+ "Neuron",
70
79
  "PodTemplate",
71
80
  "Resources",
72
81
  "RetryStrategy",
@@ -76,16 +85,23 @@ __all__ = [
76
85
  "TaskEnvironment",
77
86
  "Timeout",
78
87
  "TimeoutType",
88
+ "Trigger",
89
+ "TriggerTime",
79
90
  "__version__",
80
91
  "build",
81
92
  "build_images",
82
93
  "ctx",
94
+ "current_domain",
95
+ "custom_context",
83
96
  "deploy",
97
+ "get_custom_context",
84
98
  "group",
85
99
  "init",
86
100
  "init_from_config",
101
+ "logger",
87
102
  "map",
88
103
  "run",
89
104
  "trace",
105
+ "version",
90
106
  "with_runcontext",
91
107
  ]
flyte/_bin/runtime.py CHANGED
@@ -12,6 +12,9 @@ from typing import Any, List
12
12
 
13
13
  import click
14
14
 
15
+ from flyte._utils.helpers import str2bool
16
+ from flyte.models import PathRewrite
17
+
15
18
  # Todo: work with pvditt to make these the names
16
19
  # ACTION_NAME = "_U_ACTION_NAME"
17
20
  # RUN_NAME = "_U_RUN_NAME"
@@ -25,11 +28,12 @@ PROJECT_NAME = "FLYTE_INTERNAL_EXECUTION_PROJECT"
25
28
  DOMAIN_NAME = "FLYTE_INTERNAL_EXECUTION_DOMAIN"
26
29
  ORG_NAME = "_U_ORG_NAME"
27
30
  ENDPOINT_OVERRIDE = "_U_EP_OVERRIDE"
31
+ INSECURE_SKIP_VERIFY_OVERRIDE = "_U_INSECURE_SKIP_VERIFY"
28
32
  RUN_OUTPUT_BASE_DIR = "_U_RUN_BASE"
29
33
  FLYTE_ENABLE_VSCODE_KEY = "_F_E_VS"
30
34
 
31
- # TODO: Remove this after proper auth is implemented
32
35
  _UNION_EAGER_API_KEY_ENV_VAR = "_UNION_EAGER_API_KEY"
36
+ _F_PATH_REWRITE = "_F_PATH_REWRITE"
33
37
 
34
38
 
35
39
  @click.group()
@@ -94,6 +98,7 @@ def main(
94
98
  import flyte
95
99
  import flyte._utils as utils
96
100
  import flyte.errors
101
+ import flyte.storage as storage
97
102
  from flyte._initialize import init
98
103
  from flyte._internal.controllers import create_controller
99
104
  from flyte._internal.imagebuild.image_builder import ImageCache
@@ -136,19 +141,40 @@ def main(
136
141
  controller_kwargs["insecure"] = True
137
142
  logger.debug(f"Using controller endpoint: {ep} with kwargs: {controller_kwargs}")
138
143
 
139
- bundle = CodeBundle(tgz=tgz, pkl=pkl, destination=dest, computed_version=version)
144
+ # Check for insecure_skip_verify override (e.g. for self-signed certs)
145
+ insecure_skip_verify_str = os.getenv(INSECURE_SKIP_VERIFY_OVERRIDE, "")
146
+ if str2bool(insecure_skip_verify_str):
147
+ controller_kwargs["insecure_skip_verify"] = True
148
+ logger.info("SSL certificate verification disabled (insecure_skip_verify=True)")
149
+
150
+ bundle = None
151
+ if tgz or pkl:
152
+ bundle = CodeBundle(tgz=tgz, pkl=pkl, destination=dest, computed_version=version)
140
153
  init(org=org, project=project, domain=domain, image_builder="remote", **controller_kwargs)
141
154
  # Controller is created with the same kwargs as init, so that it can be used to run tasks
142
155
  controller = create_controller(ct="remote", **controller_kwargs)
143
156
 
144
157
  ic = ImageCache.from_transport(image_cache) if image_cache else None
145
158
 
159
+ path_rewrite_cfg = os.getenv(_F_PATH_REWRITE, None)
160
+ path_rewrite = None
161
+ if path_rewrite_cfg:
162
+ potential_path_rewrite = PathRewrite.from_str(path_rewrite_cfg)
163
+ if storage.exists_sync(potential_path_rewrite.new_prefix):
164
+ path_rewrite = potential_path_rewrite
165
+ logger.info(f"Path rewrite configured for {path_rewrite.new_prefix}")
166
+ else:
167
+ logger.error(
168
+ f"Path rewrite failed for path {potential_path_rewrite.new_prefix}, "
169
+ f"not found, reverting to original path {potential_path_rewrite.old_prefix}"
170
+ )
171
+
146
172
  # Create a coroutine to load the task and run it
147
173
  task_coroutine = load_and_run_task(
148
174
  resolver=resolver,
149
175
  resolver_args=resolver_args,
150
176
  action=ActionID(name=name, run_name=run_name, project=project, domain=domain, org=org),
151
- raw_data_path=RawDataPath(path=raw_data_path),
177
+ raw_data_path=RawDataPath(path=raw_data_path, path_rewrite=path_rewrite),
152
178
  checkpoints=Checkpoints(checkpoint_path, prev_checkpoint),
153
179
  code_bundle=bundle,
154
180
  input_path=inputs,
@@ -166,8 +192,20 @@ def main(
166
192
  async def _run_and_stop():
167
193
  loop = asyncio.get_event_loop()
168
194
  loop.set_exception_handler(flyte.errors.silence_grpc_polling_error)
169
- await utils.run_coros(controller_failure, task_coroutine)
170
- await controller.stop()
195
+ try:
196
+ await utils.run_coros(controller_failure, task_coroutine)
197
+ await controller.stop()
198
+ except flyte.errors.RuntimeSystemError as e:
199
+ logger.error(f"Runtime system error: {e}")
200
+ from flyte._internal.runtime.convert import convert_from_native_to_error
201
+ from flyte._internal.runtime.io import upload_error
202
+
203
+ logger.error(f"Flyte runtime failed for action {name} with run name {run_name}, error: {e}")
204
+ err = convert_from_native_to_error(e)
205
+ path = await upload_error(err.err, outputs_path)
206
+ logger.error(f"Run {run_name} Action {name} failed with error: {err}. Uploaded error to {path}")
207
+ await controller.stop()
208
+ raise
171
209
 
172
210
  asyncio.run(_run_and_stop())
173
211
  logger.warning(f"Flyte runtime completed for action {name} with run name {run_name}")
flyte/_cache/cache.py CHANGED
@@ -77,14 +77,16 @@ class Cache:
77
77
  def __post_init__(self):
78
78
  if self.behavior not in get_args(CacheBehavior):
79
79
  raise ValueError(f"Invalid cache behavior: {self.behavior}. Must be one of ['auto', 'override', 'disable']")
80
- if self.behavior == "disable":
81
- return
82
80
 
81
+ # Still setup _ignore_inputs when cache is disabled to prevent _ignored_inputs attribute not found error
83
82
  if isinstance(self.ignored_inputs, str):
84
83
  self._ignored_inputs = (self.ignored_inputs,)
85
84
  else:
86
85
  self._ignored_inputs = self.ignored_inputs
87
86
 
87
+ if self.behavior == "disable":
88
+ return
89
+
88
90
  # Normalize policies so that self._policies is always a list
89
91
  if self.policies is None:
90
92
  from flyte._cache.defaults import get_default_policies
flyte/_cache/local_cache.py ADDED
@@ -0,0 +1,216 @@
1
+ import sqlite3
2
+ from pathlib import Path
3
+
4
+ try:
5
+ import aiosqlite
6
+
7
+ HAS_AIOSQLITE = True
8
+ except ImportError:
9
+ HAS_AIOSQLITE = False
10
+
11
+ from flyteidl2.task import common_pb2
12
+
13
+ from flyte._internal.runtime import convert
14
+ from flyte._logging import logger
15
+ from flyte.config import auto
16
+
17
+ DEFAULT_CACHE_DIR = "~/.flyte"
18
+ CACHE_LOCATION = "local-cache/cache.db"
19
+
20
+
21
+ class LocalTaskCache(object):
22
+ """
23
+ This class implements a persistent store able to cache the result of local task executions.
24
+ """
25
+
26
+ _conn: "aiosqlite.Connection | None" = None
27
+ _conn_sync: sqlite3.Connection | None = None
28
+ _initialized: bool = False
29
+
30
+ @staticmethod
31
+ def _get_cache_path() -> str:
32
+ """Get the cache database path, creating directory if needed."""
33
+ config = auto()
34
+ if config.source:
35
+ cache_dir = config.source.parent
36
+ else:
37
+ cache_dir = Path(DEFAULT_CACHE_DIR).expanduser()
38
+
39
+ cache_path = cache_dir / CACHE_LOCATION
40
+ # Ensure the directory exists
41
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
42
+ logger.info(f"Use local cache path: {cache_path}")
43
+ return str(cache_path)
44
+
45
+ @staticmethod
46
+ async def initialize():
47
+ """Initialize the cache with database connection."""
48
+ if not LocalTaskCache._initialized:
49
+ if HAS_AIOSQLITE:
50
+ await LocalTaskCache._initialize_async()
51
+ else:
52
+ LocalTaskCache._initialize_sync()
53
+
54
+ @staticmethod
55
+ async def _initialize_async():
56
+ """Initialize async cache connection."""
57
+ db_path = LocalTaskCache._get_cache_path()
58
+ conn = await aiosqlite.connect(db_path)
59
+ await conn.execute("""
60
+ CREATE TABLE IF NOT EXISTS task_cache (
61
+ key TEXT PRIMARY KEY,
62
+ value BLOB
63
+ )
64
+ """)
65
+ await conn.commit()
66
+ LocalTaskCache._conn = conn
67
+ LocalTaskCache._initialized = True
68
+
69
+ @staticmethod
70
+ def _initialize_sync():
71
+ """Initialize sync cache connection."""
72
+ db_path = LocalTaskCache._get_cache_path()
73
+ conn = sqlite3.connect(db_path)
74
+ conn.execute("""
75
+ CREATE TABLE IF NOT EXISTS task_cache (
76
+ key TEXT PRIMARY KEY,
77
+ value BLOB
78
+ )
79
+ """)
80
+ conn.commit()
81
+ LocalTaskCache._conn_sync = conn
82
+ LocalTaskCache._initialized = True
83
+
84
+ @staticmethod
85
+ async def clear():
86
+ """Clear all cache entries."""
87
+ if not LocalTaskCache._initialized:
88
+ await LocalTaskCache.initialize()
89
+
90
+ if HAS_AIOSQLITE:
91
+ await LocalTaskCache._clear_async()
92
+ else:
93
+ LocalTaskCache._clear_sync()
94
+
95
+ @staticmethod
96
+ async def _clear_async():
97
+ """Clear all cache entries (async)."""
98
+ if LocalTaskCache._conn is None:
99
+ raise RuntimeError("Cache not properly initialized")
100
+ await LocalTaskCache._conn.execute("DELETE FROM task_cache")
101
+ await LocalTaskCache._conn.commit()
102
+
103
+ @staticmethod
104
+ def _clear_sync():
105
+ """Clear all cache entries (sync)."""
106
+ if LocalTaskCache._conn_sync is None:
107
+ raise RuntimeError("Cache not properly initialized")
108
+ LocalTaskCache._conn_sync.execute("DELETE FROM task_cache")
109
+ LocalTaskCache._conn_sync.commit()
110
+
111
+ @staticmethod
112
+ async def get(cache_key: str) -> convert.Outputs | None:
113
+ if not LocalTaskCache._initialized:
114
+ await LocalTaskCache.initialize()
115
+
116
+ if HAS_AIOSQLITE:
117
+ return await LocalTaskCache._get_async(cache_key)
118
+ else:
119
+ return LocalTaskCache._get_sync(cache_key)
120
+
121
+ @staticmethod
122
+ async def _get_async(cache_key: str) -> convert.Outputs | None:
123
+ """Get cache entry (async)."""
124
+ if LocalTaskCache._conn is None:
125
+ raise RuntimeError("Cache not properly initialized")
126
+
127
+ async with LocalTaskCache._conn.execute("SELECT value FROM task_cache WHERE key = ?", (cache_key,)) as cursor:
128
+ row = await cursor.fetchone()
129
+ if row:
130
+ outputs_bytes = row[0]
131
+ outputs = common_pb2.Outputs()
132
+ outputs.ParseFromString(outputs_bytes)
133
+ return convert.Outputs(proto_outputs=outputs)
134
+ return None
135
+
136
+ @staticmethod
137
+ def _get_sync(cache_key: str) -> convert.Outputs | None:
138
+ """Get cache entry (sync)."""
139
+ if LocalTaskCache._conn_sync is None:
140
+ raise RuntimeError("Cache not properly initialized")
141
+
142
+ cursor = LocalTaskCache._conn_sync.execute("SELECT value FROM task_cache WHERE key = ?", (cache_key,))
143
+ row = cursor.fetchone()
144
+ if row:
145
+ outputs_bytes = row[0]
146
+ outputs = common_pb2.Outputs()
147
+ outputs.ParseFromString(outputs_bytes)
148
+ return convert.Outputs(proto_outputs=outputs)
149
+ return None
150
+
151
+ @staticmethod
152
+ async def set(
153
+ cache_key: str,
154
+ value: convert.Outputs,
155
+ ) -> None:
156
+ if not LocalTaskCache._initialized:
157
+ await LocalTaskCache.initialize()
158
+
159
+ if HAS_AIOSQLITE:
160
+ await LocalTaskCache._set_async(cache_key, value)
161
+ else:
162
+ LocalTaskCache._set_sync(cache_key, value)
163
+
164
+ @staticmethod
165
+ async def _set_async(
166
+ cache_key: str,
167
+ value: convert.Outputs,
168
+ ) -> None:
169
+ """Set cache entry (async)."""
170
+ if LocalTaskCache._conn is None:
171
+ raise RuntimeError("Cache not properly initialized")
172
+
173
+ output_bytes = value.proto_outputs.SerializeToString()
174
+ await LocalTaskCache._conn.execute(
175
+ "INSERT OR REPLACE INTO task_cache (key, value) VALUES (?, ?)", (cache_key, output_bytes)
176
+ )
177
+ await LocalTaskCache._conn.commit()
178
+
179
+ @staticmethod
180
+ def _set_sync(
181
+ cache_key: str,
182
+ value: convert.Outputs,
183
+ ) -> None:
184
+ """Set cache entry (sync)."""
185
+ if LocalTaskCache._conn_sync is None:
186
+ raise RuntimeError("Cache not properly initialized")
187
+
188
+ output_bytes = value.proto_outputs.SerializeToString()
189
+ LocalTaskCache._conn_sync.execute(
190
+ "INSERT OR REPLACE INTO task_cache (key, value) VALUES (?, ?)", (cache_key, output_bytes)
191
+ )
192
+ LocalTaskCache._conn_sync.commit()
193
+
194
+ @staticmethod
195
+ async def close():
196
+ """Close the database connection."""
197
+ if HAS_AIOSQLITE:
198
+ await LocalTaskCache._close_async()
199
+ else:
200
+ LocalTaskCache._close_sync()
201
+
202
+ @staticmethod
203
+ async def _close_async():
204
+ """Close async database connection."""
205
+ if LocalTaskCache._conn:
206
+ await LocalTaskCache._conn.close()
207
+ LocalTaskCache._conn = None
208
+ LocalTaskCache._initialized = False
209
+
210
+ @staticmethod
211
+ def _close_sync():
212
+ """Close sync database connection."""
213
+ if LocalTaskCache._conn_sync:
214
+ LocalTaskCache._conn_sync.close()
215
+ LocalTaskCache._conn_sync = None
216
+ LocalTaskCache._initialized = False
flyte/_code_bundle/_ignore.py CHANGED
@@ -79,7 +79,7 @@ class StandardIgnore(Ignore):
79
79
  by fed with custom ignore patterns from cli."""
80
80
 
81
81
  def __init__(self, root: Path, patterns: Optional[List[str]] = None):
82
- super().__init__(root)
82
+ super().__init__(root.resolve())
83
83
  self.patterns = patterns if patterns else STANDARD_IGNORE_PATTERNS
84
84
 
85
85
  def _is_ignored(self, path: pathlib.Path) -> bool:
flyte/_code_bundle/_packaging.py CHANGED
@@ -32,15 +32,15 @@ def print_ls_tree(source: os.PathLike, ls: typing.List[str]):
32
32
  f"File structure:\n:open_file_folder: {source}",
33
33
  guide_style="bold bright_blue",
34
34
  )
35
- trees = {pathlib.Path(source): tree_root}
36
-
35
+ source_path = pathlib.Path(source).resolve()
36
+ trees = {source_path: tree_root}
37
37
  for f in ls:
38
38
  fpp = pathlib.Path(f)
39
39
  if fpp.parent not in trees:
40
40
  # add trees for all intermediate folders
41
41
  current = tree_root
42
- current_path = pathlib.Path(source)
43
- for subdir in fpp.parent.relative_to(source).parts:
42
+ current_path = source_path # pathlib.Path(source)
43
+ for subdir in fpp.parent.relative_to(source_path).parts:
44
44
  current_path = current_path / subdir
45
45
  if current_path not in trees:
46
46
  current = current.add(f"{subdir}", guide_style="bold bright_blue")
flyte/_code_bundle/_utils.py CHANGED
@@ -193,15 +193,15 @@ def list_all_files(source_path: pathlib.Path, deref_symlinks, ignore_group: Opti
193
193
  def _file_is_in_directory(file: str, directory: str) -> bool:
194
194
  """Return True if file is in directory and in its children."""
195
195
  try:
196
- return os.path.commonpath([file, directory]) == directory
197
- except ValueError as e:
198
- # ValueError is raised by windows if the paths are not from the same drive
199
- logger.debug(f"{file} and {directory} are not in the same drive: {e!s}")
196
+ return pathlib.Path(file).resolve().is_relative_to(pathlib.Path(directory).resolve())
197
+ except OSError as e:
198
+ # OSError can be raised if paths cannot be resolved (permissions, broken symlinks, etc.)
199
+ logger.debug(f"Failed to resolve paths for {file} and {directory}: {e!s}")
200
200
  return False
201
201
 
202
202
 
203
203
  def list_imported_modules_as_files(source_path: str, modules: List[ModuleType]) -> List[str]:
204
- """Copies modules into destination that are in modules. The module files are copied only if:
204
+ """Lists the files of modules that have been loaded. The files are only included if:
205
205
 
206
206
  1. Not a site-packages. These are installed packages and not user files.
207
207
  2. Not in the sys.base_prefix or sys.prefix. These are also installed and not user files.
@@ -211,7 +211,7 @@ def list_imported_modules_as_files(source_path: str, modules: List[ModuleType])
211
211
  import flyte
212
212
  from flyte._utils.lazy_module import is_imported
213
213
 
214
- files = []
214
+ files = set()
215
215
  flyte_root = os.path.dirname(flyte.__file__)
216
216
 
217
217
  # These directories contain installed packages or modules from the Python standard library.
@@ -244,9 +244,15 @@ def list_imported_modules_as_files(source_path: str, modules: List[ModuleType])
244
244
  logger.debug(f"{mod_file} is not in {source_path}")
245
245
  continue
246
246
 
247
- files.append(mod_file)
247
+ if not pathlib.Path(mod_file).is_file():
248
+ # Some modules have a __file__ attribute that are relative to the base package. Let's skip these,
249
+ # can add more rigorous logic to really pull out the correct file location if we need to.
250
+ logger.debug(f"Skipping {mod_file} from {mod.__name__} because it is not a file")
251
+ continue
252
+
253
+ files.add(mod_file)
248
254
 
249
- return files
255
+ return list(files)
250
256
 
251
257
 
252
258
  def add_imported_modules_from_source(source_path: str, destination: str, modules: List[ModuleType]):
flyte/_code_bundle/bundle.py CHANGED
@@ -8,7 +8,7 @@ from pathlib import Path
8
8
  from typing import ClassVar, Type
9
9
 
10
10
  from async_lru import alru_cache
11
- from flyteidl.core.tasks_pb2 import TaskTemplate
11
+ from flyteidl2.core.tasks_pb2 import TaskTemplate
12
12
 
13
13
  from flyte._logging import log, logger
14
14
  from flyte._utils import AsyncLRUCache
@@ -104,7 +104,7 @@ async def build_pkl_bundle(
104
104
  import shutil
105
105
 
106
106
  # Copy the bundle to the given path
107
- shutil.copy(dest, copy_bundle_to)
107
+ shutil.copy(dest, copy_bundle_to, follow_symlinks=True)
108
108
  local_path = copy_bundle_to / dest.name
109
109
  return CodeBundle(pkl=str(local_path), computed_version=str_digest)
110
110
  return CodeBundle(pkl=str(dest), computed_version=str_digest)
@@ -169,6 +169,8 @@ async def download_bundle(bundle: CodeBundle) -> pathlib.Path:
169
169
 
170
170
  :return: The path to the downloaded code bundle.
171
171
  """
172
+ import sys
173
+
172
174
  import flyte.storage as storage
173
175
 
174
176
  dest = pathlib.Path(bundle.destination)
@@ -185,16 +187,22 @@ async def download_bundle(bundle: CodeBundle) -> pathlib.Path:
185
187
  # NOTE the os.path.join(destination, ''). This is to ensure that the given path is in fact a directory and all
186
188
  # downloaded data should be copied into this directory. We do this to account for a difference in behavior in
187
189
  # fsspec, which requires a trailing slash in case of pre-existing directory.
188
- process = await asyncio.create_subprocess_exec(
189
- "tar",
190
+ args = [
190
191
  "-xvf",
191
192
  str(downloaded_bundle),
192
193
  "-C",
193
194
  str(dest),
195
+ ]
196
+ if sys.platform != "darwin":
197
+ args.insert(0, "--overwrite")
198
+
199
+ process = await asyncio.create_subprocess_exec(
200
+ "tar",
201
+ *args,
194
202
  stdout=asyncio.subprocess.PIPE,
195
203
  stderr=asyncio.subprocess.PIPE,
196
204
  )
197
- stdout, stderr = await process.communicate()
205
+ _stdout, stderr = await process.communicate()
198
206
 
199
207
  if process.returncode != 0:
200
208
  raise RuntimeError(stderr.decode())
flyte/_constants.py ADDED
@@ -0,0 +1 @@
1
+ FLYTE_SYS_PATH = "_F_SYS_PATH" # The paths that will be appended to sys.path at runtime
flyte/_context.py CHANGED
@@ -135,7 +135,10 @@ root_context_var = contextvars.ContextVar("root", default=Context(data=ContextDa
135
135
 
136
136
 
137
137
  def ctx() -> Optional[TaskContext]:
138
- """Retrieve the current task context from the context variable."""
138
+ """
139
+ Returns flyte.models.TaskContext if within a task context, else None
140
+ Note: Only use this in task code and not module level.
141
+ """
139
142
  return internal_ctx().data.task_context
140
143
 
141
144
 
flyte/_custom_context.py ADDED
@@ -0,0 +1,73 @@
1
+ from __future__ import annotations
2
+
3
+ from contextlib import contextmanager
4
+
5
+ from flyte._context import ctx
6
+
7
+ from ._context import internal_ctx
8
+
9
+
10
def get_custom_context() -> dict[str, str]:
    """
    Return the custom context of the currently running action.

    Inside a task this yields the key/value metadata that was attached to the action, and the
    context propagates automatically to sub-actions. When called outside of a task context, or
    when no custom context was set, an empty dict is returned.

    Example:
    ```python
    import flyte

    env = flyte.TaskEnvironment(name="...")

    @env.task
    def t1():
        # context can be retrieved with `get_custom_context`
        ctx = flyte.get_custom_context()
        print(ctx)  # {'project': '...', 'entity': '...'}
    ```

    :return: Dictionary of context key-value pairs
    """
    task_ctx = ctx()
    custom = None if task_ctx is None else task_ctx.custom_context
    return {} if custom is None else custom
36
+
37
+
38
@contextmanager
def custom_context(**context: str):
    """
    Synchronous context manager that sets custom context for tasks spawned within this block.

    The given key/value pairs are merged over the current task context's custom context (the new
    values win on key collisions). The merged context is installed for the duration of the ``with``
    block and the previous task context is restored on exit. Outside of a task context this is a no-op.

    Example:
    ```python
    import flyte

    env = flyte.TaskEnvironment(name="...")

    @env.task
    def t1():
        ctx = flyte.get_custom_context()
        print(ctx)

    @env.task
    def main():
        # context can be passed via a context manager
        with flyte.custom_context(project="my-project"):
            t1()  # will have {'project': 'my-project'} as context
    ```

    :param context: Key-value pairs to set as input context
    """
    # Named `current` (not `ctx`) so the module-level `ctx` import is not shadowed.
    current = internal_ctx()
    tctx = current.data.task_context
    if tctx is None:
        # Not running inside a task: there is no task context to attach the values to.
        yield
        return

    # `custom_context` may be None on the task context (get_custom_context guards for this too);
    # unpacking None with ** would raise TypeError, so fall back to an empty mapping.
    merged = {**(tctx.custom_context or {}), **context}
    with current.replace_task_context(tctx.replace(custom_context=merged)):
        yield
    # Leaving the `with` restores the previous task context.
flyte/_debug/constants.py CHANGED
@@ -12,7 +12,6 @@ DEFAULT_CODE_SERVER_REMOTE_PATHS = {
12
12
  }
13
13
  DEFAULT_CODE_SERVER_EXTENSIONS = [
14
14
  "https://raw.githubusercontent.com/flyteorg/flytetools/master/flytekitplugins/flyin/ms-python.python-2023.20.0.vsix",
15
- "https://raw.githubusercontent.com/flyteorg/flytetools/master/flytekitplugins/flyin/ms-toolsai.jupyter-2023.9.100.vsix",
16
15
  ]
17
16
 
18
17
  # Duration to pause the checking of the heartbeat file until the next one
flyte/_debug/vscode.py CHANGED
@@ -256,7 +256,12 @@ def prepare_launch_json(ctx: click.Context, pid: int):
256
256
 
257
257
 
258
258
  async def _start_vscode_server(ctx: click.Context):
259
- await asyncio.gather(download_tgz(ctx.params["dest"], ctx.params["version"], ctx.params["tgz"]), download_vscode())
259
+ if ctx.params["tgz"] is None:
260
+ await download_vscode()
261
+ else:
262
+ await asyncio.gather(
263
+ download_tgz(ctx.params["dest"], ctx.params["version"], ctx.params["tgz"]), download_vscode()
264
+ )
260
265
  child_process = multiprocessing.Process(
261
266
  target=lambda cmd: asyncio.run(asyncio.run(execute_command(cmd))),
262
267
  kwargs={"cmd": f"code-server --bind-addr 0.0.0.0:6060 --disable-workspace-trust --auth none {os.getcwd()}"},