flyte 2.0.0b32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (204) hide show
  1. flyte/__init__.py +108 -0
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +195 -0
  5. flyte/_bin/serve.py +178 -0
  6. flyte/_build.py +26 -0
  7. flyte/_cache/__init__.py +12 -0
  8. flyte/_cache/cache.py +147 -0
  9. flyte/_cache/defaults.py +9 -0
  10. flyte/_cache/local_cache.py +216 -0
  11. flyte/_cache/policy_function_body.py +42 -0
  12. flyte/_code_bundle/__init__.py +8 -0
  13. flyte/_code_bundle/_ignore.py +121 -0
  14. flyte/_code_bundle/_packaging.py +218 -0
  15. flyte/_code_bundle/_utils.py +347 -0
  16. flyte/_code_bundle/bundle.py +266 -0
  17. flyte/_constants.py +1 -0
  18. flyte/_context.py +155 -0
  19. flyte/_custom_context.py +73 -0
  20. flyte/_debug/__init__.py +0 -0
  21. flyte/_debug/constants.py +38 -0
  22. flyte/_debug/utils.py +17 -0
  23. flyte/_debug/vscode.py +307 -0
  24. flyte/_deploy.py +408 -0
  25. flyte/_deployer.py +109 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +122 -0
  29. flyte/_excepthook.py +37 -0
  30. flyte/_group.py +32 -0
  31. flyte/_hash.py +8 -0
  32. flyte/_image.py +1055 -0
  33. flyte/_initialize.py +628 -0
  34. flyte/_interface.py +119 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +129 -0
  37. flyte/_internal/controllers/_local_controller.py +239 -0
  38. flyte/_internal/controllers/_trace.py +48 -0
  39. flyte/_internal/controllers/remote/__init__.py +58 -0
  40. flyte/_internal/controllers/remote/_action.py +211 -0
  41. flyte/_internal/controllers/remote/_client.py +47 -0
  42. flyte/_internal/controllers/remote/_controller.py +583 -0
  43. flyte/_internal/controllers/remote/_core.py +465 -0
  44. flyte/_internal/controllers/remote/_informer.py +381 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +3 -0
  47. flyte/_internal/imagebuild/docker_builder.py +706 -0
  48. flyte/_internal/imagebuild/image_builder.py +277 -0
  49. flyte/_internal/imagebuild/remote_builder.py +386 -0
  50. flyte/_internal/imagebuild/utils.py +78 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +21 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +486 -0
  57. flyte/_internal/runtime/entrypoints.py +204 -0
  58. flyte/_internal/runtime/io.py +188 -0
  59. flyte/_internal/runtime/resources_serde.py +152 -0
  60. flyte/_internal/runtime/reuse.py +125 -0
  61. flyte/_internal/runtime/rusty.py +193 -0
  62. flyte/_internal/runtime/task_serde.py +362 -0
  63. flyte/_internal/runtime/taskrunner.py +209 -0
  64. flyte/_internal/runtime/trigger_serde.py +160 -0
  65. flyte/_internal/runtime/types_serde.py +54 -0
  66. flyte/_keyring/__init__.py +0 -0
  67. flyte/_keyring/file.py +115 -0
  68. flyte/_logging.py +300 -0
  69. flyte/_map.py +312 -0
  70. flyte/_module.py +72 -0
  71. flyte/_pod.py +30 -0
  72. flyte/_resources.py +473 -0
  73. flyte/_retry.py +32 -0
  74. flyte/_reusable_environment.py +102 -0
  75. flyte/_run.py +724 -0
  76. flyte/_secret.py +96 -0
  77. flyte/_task.py +550 -0
  78. flyte/_task_environment.py +316 -0
  79. flyte/_task_plugins.py +47 -0
  80. flyte/_timeout.py +47 -0
  81. flyte/_tools.py +27 -0
  82. flyte/_trace.py +119 -0
  83. flyte/_trigger.py +1000 -0
  84. flyte/_utils/__init__.py +30 -0
  85. flyte/_utils/asyn.py +121 -0
  86. flyte/_utils/async_cache.py +139 -0
  87. flyte/_utils/coro_management.py +27 -0
  88. flyte/_utils/docker_credentials.py +173 -0
  89. flyte/_utils/file_handling.py +72 -0
  90. flyte/_utils/helpers.py +134 -0
  91. flyte/_utils/lazy_module.py +54 -0
  92. flyte/_utils/module_loader.py +104 -0
  93. flyte/_utils/org_discovery.py +57 -0
  94. flyte/_utils/uv_script_parser.py +49 -0
  95. flyte/_version.py +34 -0
  96. flyte/app/__init__.py +22 -0
  97. flyte/app/_app_environment.py +157 -0
  98. flyte/app/_deploy.py +125 -0
  99. flyte/app/_input.py +160 -0
  100. flyte/app/_runtime/__init__.py +3 -0
  101. flyte/app/_runtime/app_serde.py +347 -0
  102. flyte/app/_types.py +101 -0
  103. flyte/app/extras/__init__.py +3 -0
  104. flyte/app/extras/_fastapi.py +151 -0
  105. flyte/cli/__init__.py +12 -0
  106. flyte/cli/_abort.py +28 -0
  107. flyte/cli/_build.py +114 -0
  108. flyte/cli/_common.py +468 -0
  109. flyte/cli/_create.py +371 -0
  110. flyte/cli/_delete.py +45 -0
  111. flyte/cli/_deploy.py +293 -0
  112. flyte/cli/_gen.py +176 -0
  113. flyte/cli/_get.py +370 -0
  114. flyte/cli/_option.py +33 -0
  115. flyte/cli/_params.py +554 -0
  116. flyte/cli/_plugins.py +209 -0
  117. flyte/cli/_run.py +597 -0
  118. flyte/cli/_serve.py +64 -0
  119. flyte/cli/_update.py +37 -0
  120. flyte/cli/_user.py +17 -0
  121. flyte/cli/main.py +221 -0
  122. flyte/config/__init__.py +3 -0
  123. flyte/config/_config.py +248 -0
  124. flyte/config/_internal.py +73 -0
  125. flyte/config/_reader.py +225 -0
  126. flyte/connectors/__init__.py +11 -0
  127. flyte/connectors/_connector.py +270 -0
  128. flyte/connectors/_server.py +197 -0
  129. flyte/connectors/utils.py +135 -0
  130. flyte/errors.py +243 -0
  131. flyte/extend.py +19 -0
  132. flyte/extras/__init__.py +5 -0
  133. flyte/extras/_container.py +286 -0
  134. flyte/git/__init__.py +3 -0
  135. flyte/git/_config.py +21 -0
  136. flyte/io/__init__.py +29 -0
  137. flyte/io/_dataframe/__init__.py +131 -0
  138. flyte/io/_dataframe/basic_dfs.py +223 -0
  139. flyte/io/_dataframe/dataframe.py +1026 -0
  140. flyte/io/_dir.py +910 -0
  141. flyte/io/_file.py +914 -0
  142. flyte/io/_hashing_io.py +342 -0
  143. flyte/models.py +479 -0
  144. flyte/py.typed +0 -0
  145. flyte/remote/__init__.py +35 -0
  146. flyte/remote/_action.py +738 -0
  147. flyte/remote/_app.py +57 -0
  148. flyte/remote/_client/__init__.py +0 -0
  149. flyte/remote/_client/_protocols.py +189 -0
  150. flyte/remote/_client/auth/__init__.py +12 -0
  151. flyte/remote/_client/auth/_auth_utils.py +14 -0
  152. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  153. flyte/remote/_client/auth/_authenticators/base.py +403 -0
  154. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  155. flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
  156. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  157. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  158. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  159. flyte/remote/_client/auth/_channel.py +213 -0
  160. flyte/remote/_client/auth/_client_config.py +85 -0
  161. flyte/remote/_client/auth/_default_html.py +32 -0
  162. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  163. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  164. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  165. flyte/remote/_client/auth/_keyring.py +152 -0
  166. flyte/remote/_client/auth/_token_client.py +260 -0
  167. flyte/remote/_client/auth/errors.py +16 -0
  168. flyte/remote/_client/controlplane.py +128 -0
  169. flyte/remote/_common.py +30 -0
  170. flyte/remote/_console.py +19 -0
  171. flyte/remote/_data.py +161 -0
  172. flyte/remote/_logs.py +185 -0
  173. flyte/remote/_project.py +88 -0
  174. flyte/remote/_run.py +386 -0
  175. flyte/remote/_secret.py +142 -0
  176. flyte/remote/_task.py +527 -0
  177. flyte/remote/_trigger.py +306 -0
  178. flyte/remote/_user.py +33 -0
  179. flyte/report/__init__.py +3 -0
  180. flyte/report/_report.py +182 -0
  181. flyte/report/_template.html +124 -0
  182. flyte/storage/__init__.py +36 -0
  183. flyte/storage/_config.py +237 -0
  184. flyte/storage/_parallel_reader.py +274 -0
  185. flyte/storage/_remote_fs.py +34 -0
  186. flyte/storage/_storage.py +456 -0
  187. flyte/storage/_utils.py +5 -0
  188. flyte/syncify/__init__.py +56 -0
  189. flyte/syncify/_api.py +375 -0
  190. flyte/types/__init__.py +52 -0
  191. flyte/types/_interface.py +40 -0
  192. flyte/types/_pickle.py +145 -0
  193. flyte/types/_renderer.py +162 -0
  194. flyte/types/_string_literals.py +119 -0
  195. flyte/types/_type_engine.py +2254 -0
  196. flyte/types/_utils.py +80 -0
  197. flyte-2.0.0b32.data/scripts/debug.py +38 -0
  198. flyte-2.0.0b32.data/scripts/runtime.py +195 -0
  199. flyte-2.0.0b32.dist-info/METADATA +351 -0
  200. flyte-2.0.0b32.dist-info/RECORD +204 -0
  201. flyte-2.0.0b32.dist-info/WHEEL +5 -0
  202. flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
  203. flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
  204. flyte-2.0.0b32.dist-info/top_level.txt +1 -0
flyte/_run.py ADDED
@@ -0,0 +1,724 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import pathlib
5
+ import sys
6
+ import uuid
7
+ from dataclasses import dataclass
8
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast
9
+
10
+ import flyte.errors
11
+ from flyte._context import contextual_run, internal_ctx
12
+ from flyte._environment import Environment
13
+ from flyte._initialize import (
14
+ _get_init_config,
15
+ get_client,
16
+ get_init_config,
17
+ get_storage,
18
+ requires_initialization,
19
+ requires_storage,
20
+ )
21
+ from flyte._logging import LogFormat, logger
22
+ from flyte._task import F, P, R, TaskTemplate
23
+ from flyte.models import (
24
+ ActionID,
25
+ Checkpoints,
26
+ CodeBundle,
27
+ RawDataPath,
28
+ SerializationContext,
29
+ TaskContext,
30
+ )
31
+ from flyte.syncify import syncify
32
+
33
+ from ._constants import FLYTE_SYS_PATH
34
+
35
+ if TYPE_CHECKING:
36
+ from flyte.remote import Run
37
+ from flyte.remote._task import LazyEntity
38
+
39
+ from ._code_bundle import CopyFiles
40
+ from ._internal.imagebuild.image_builder import ImageCache
41
+
42
# Execution mode of a run; `_Runner.run` dispatches on this value.
Mode = Literal["local", "remote", "hybrid"]
# Scope used for cache lookups; translated to the protobuf enum when a remote run is created.
CacheLookupScope = Literal["global", "project-domain"]
44
+
45
+
46
@dataclass(frozen=True)
class _CacheKey:
    """Hashable key of the module-level run cache (one entry per task object and dry-run flag)."""

    # `id()` of the task object that was run.
    obj_id: int
    # Whether the cached artifacts were produced by a dry run.
    dry_run: bool
50
+
51
+
52
@dataclass(frozen=True)
class _CacheValue:
    """Artifacts remembered for a task between runs so they do not have to be rebuilt."""

    # Code bundle built for the task, or None when no bundle was built.
    code_bundle: Optional[CodeBundle]
    # Image cache returned by the image build, or None (e.g. for dry runs).
    image_cache: Optional[ImageCache]
56
+
57
+
58
# Process-wide cache of built code bundles / image caches, keyed by task object identity and dry-run flag.
_RUN_CACHE: Dict[_CacheKey, _CacheValue] = dict()
59
+
60
+
61
async def _get_code_bundle_for_run(name: str) -> CodeBundle | None:
    """
    Look up an existing run by name and return the code bundle recorded in its resolved task spec.
    This is used in hybrid mode to reuse the code bundle of a run that already exists.

    :param name: Name of the run to look up.
    :return: The extracted code bundle, or None when no run was returned for the given name.
    """
    from flyte._internal.runtime.task_serde import extract_code_bundle
    from flyte.remote import Run

    existing_run = await Run.get.aio(name=name)
    if not existing_run:
        return None

    details = await existing_run.details.aio()
    return extract_code_bundle(details.action_details.pb2.resolved_task_spec)
75
+
76
+
77
class _Runner:
    """
    Holds the options of a run context and executes a task in local, remote or hybrid mode.

    This is the type `with_runcontext` is annotated to return; `run` is the public entry point and dispatches to
    `_run_remote`, `_run_hybrid` or `_run_local` depending on the effective mode.
    """

    def __init__(
        self,
        force_mode: Mode | None = None,
        name: Optional[str] = None,
        service_account: Optional[str] = None,
        version: Optional[str] = None,
        copy_style: CopyFiles = "loaded_modules",
        dry_run: bool = False,
        copy_bundle_to: pathlib.Path | None = None,
        interactive_mode: bool | None = None,
        raw_data_path: str | None = None,
        metadata_path: str | None = None,
        run_base_dir: str | None = None,
        overwrite_cache: bool = False,
        project: str | None = None,
        domain: str | None = None,
        env_vars: Dict[str, str] | None = None,
        labels: Dict[str, str] | None = None,
        annotations: Dict[str, str] | None = None,
        interruptible: bool | None = None,
        log_level: int | None = None,
        log_format: LogFormat = "console",
        disable_run_cache: bool = False,
        queue: Optional[str] = None,
        custom_context: Dict[str, str] | None = None,
        cache_lookup_scope: CacheLookupScope = "global",
    ):
        """
        Store the run options and resolve the effective run mode.

        When `force_mode` is not given, the mode is "remote" if the init config has a client and "local" otherwise.
        When `interactive_mode` is None it is detected with `ipython_check()`; an explicit True or False is kept.
        All other arguments are stored unchanged for use by the mode-specific run methods.
        """
        from flyte._tools import ipython_check

        init_config = _get_init_config()
        client = init_config.client if init_config else None
        if not force_mode and client is not None:
            force_mode = "remote"
        force_mode = force_mode or "local"
        logger.debug(f"Effective run mode: `{force_mode}`, client configured: `{client is not None}`")
        self._mode = force_mode
        self._name = name
        self._service_account = service_account
        self._version = version
        self._copy_files = copy_style
        self._dry_run = dry_run
        self._copy_bundle_to = copy_bundle_to
        # Only auto-detect when the caller did not decide; an explicit False must not be overridden.
        self._interactive_mode = interactive_mode if interactive_mode is not None else ipython_check()
        self._raw_data_path = raw_data_path
        self._metadata_path = metadata_path
        self._run_base_dir = run_base_dir
        self._overwrite_cache = overwrite_cache
        self._project = project
        self._domain = domain
        self._env_vars = env_vars
        self._labels = labels
        self._annotations = annotations
        self._interruptible = interruptible
        self._log_level = log_level
        self._log_format = log_format
        self._disable_run_cache = disable_run_cache
        self._queue = queue
        self._custom_context = custom_context or {}
        self._cache_lookup_scope = cache_lookup_scope

    @staticmethod
    def _sys_paths_under(root_dir: pathlib.Path) -> List[str]:
        """Return the `sys.path` entries located under `root_dir`, each rewritten as `./<relative path>`."""
        root_prefix = str(root_dir)
        relative_paths: List[str] = []
        for entry in sys.path:
            if not entry.startswith(root_prefix):
                continue
            try:
                relative_paths.append(f"./{pathlib.Path(entry).relative_to(root_dir)}")
            except ValueError:
                # The string prefix also matches siblings (root `/a/b` vs entry `/a/bc`); those are not under root.
                continue
        return relative_paths

    @requires_initialization
    async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run:
        """
        Create a run for the task through the run service, or build a dry-run placeholder.

        For a `LazyEntity` the task spec and version are taken from the fetched task. For a `TaskTemplate` the images
        and the code bundle are built (or reused from the run cache) and the task is translated to its wire format.

        :param obj: The task to run, either a local `TaskTemplate` or a `LazyEntity`.
        :param args: Positional inputs of the task.
        :param kwargs: Keyword inputs of the task.
        :return: The created `Run`; in dry-run mode a `Run` subclass carrying `task_spec`, `inputs` and `code_bundle`.
        :raises ValueError: If the task has no environment, no version can be determined, or an environment variable
            value is not a string.
        :raises flyte.errors.InitializationError: If no client is configured for a non dry-run.
        :raises flyte.errors.RuntimeSystemError: If the service is unavailable or run creation fails.
        :raises flyte.errors.RuntimeUserError: If the request is invalid or the run name already exists.
        """
        import grpc
        from flyteidl2.common import identifier_pb2
        from flyteidl2.core import literals_pb2, security_pb2
        from flyteidl2.task import run_pb2
        from flyteidl2.workflow import run_definition_pb2, run_service_pb2
        from google.protobuf import wrappers_pb2

        from flyte.remote import Run
        from flyte.remote._task import LazyEntity

        from ._code_bundle import build_code_bundle, build_pkl_bundle
        from ._deploy import build_images
        from ._internal.runtime.convert import convert_from_native_to_inputs
        from ._internal.runtime.task_serde import translate_task_to_wire

        cfg = get_init_config()
        project = self._project or cfg.project
        domain = self._domain or cfg.domain

        if isinstance(obj, LazyEntity):
            task = await obj.fetch.aio()
            task_spec = task.pb2.spec
            inputs = await convert_from_native_to_inputs(
                task.interface, *args, custom_context=self._custom_context, **kwargs
            )
            version = task.pb2.task_id.version
            code_bundle = None
        else:
            task = cast(TaskTemplate[P, R, F], obj)
            if obj.parent_env is None:
                raise ValueError("Task is not attached to an environment. Please attach the task to an environment")

            cache_key = _CacheKey(obj_id=id(obj), dry_run=self._dry_run)
            cached_value = None if self._disable_run_cache else _RUN_CACHE.get(cache_key)
            if cached_value is not None:
                code_bundle = cached_value.code_bundle
                image_cache = cached_value.image_cache
            else:
                if not self._dry_run:
                    image_cache = await build_images.aio(cast(Environment, obj.parent_env()))
                else:
                    image_cache = None

                if self._interactive_mode:
                    code_bundle = await build_pkl_bundle(
                        obj,
                        upload_to_controlplane=not self._dry_run,
                        copy_bundle_to=self._copy_bundle_to,
                    )
                elif self._copy_files != "none":
                    code_bundle = await build_code_bundle(
                        from_dir=cfg.root_dir,
                        dryrun=self._dry_run,
                        copy_bundle_to=self._copy_bundle_to,
                        copy_style=self._copy_files,
                    )
                else:
                    code_bundle = None
                if not self._disable_run_cache:
                    _RUN_CACHE[cache_key] = _CacheValue(code_bundle=code_bundle, image_cache=image_cache)

            version = self._version or (
                code_bundle.computed_version if code_bundle and code_bundle.computed_version else None
            )
            if not version:
                raise ValueError("Version is required when running a task")
            s_ctx = SerializationContext(
                code_bundle=code_bundle,
                version=version,
                image_cache=image_cache,
                root_dir=cfg.root_dir,
            )
            task_spec = translate_task_to_wire(obj, s_ctx)
            inputs = await convert_from_native_to_inputs(
                obj.native_interface, *args, custom_context=self._custom_context, **kwargs
            )

        # Work on a copy so the dictionary handed in by the caller is never modified.
        env = dict(self._env_vars) if self._env_vars else {}
        if env.get("LOG_LEVEL") is None:
            if self._log_level:
                env["LOG_LEVEL"] = str(self._log_level)
            else:
                env["LOG_LEVEL"] = str(logger.getEffectiveLevel())
        env["LOG_FORMAT"] = self._log_format

        # These paths will be appended to sys.path at runtime.
        if cfg.sync_local_sys_paths:
            env[FLYTE_SYS_PATH] = ":".join(self._sys_paths_under(cfg.root_dir))

        if not self._dry_run:
            if get_client() is None:
                # This can only happen, if the user forces flyte.run(mode="remote") without initializing the client
                raise flyte.errors.InitializationError(
                    "ClientNotInitializedError",
                    "user",
                    "flyte.run requires client to be initialized. "
                    "Call flyte.init() with a valid endpoint or api-key before using this function.",
                )
            # A named run is addressed by a full run id; otherwise only the project is sent.
            run_id = None
            project_id = None
            if self._name:
                run_id = identifier_pb2.RunIdentifier(
                    project=project,
                    domain=domain,
                    org=cfg.org,
                    name=self._name,
                )
            else:
                project_id = identifier_pb2.ProjectIdentifier(
                    name=project,
                    domain=domain,
                    organization=cfg.org,
                )
            # Fill in task id inside the task template if it's not provided.
            # Maybe this should be done here, or the backend.
            if task_spec.task_template.id.project == "":
                task_spec.task_template.id.project = project if project else ""
            if task_spec.task_template.id.domain == "":
                task_spec.task_template.id.domain = domain if domain else ""
            if task_spec.task_template.id.org == "":
                task_spec.task_template.id.org = cfg.org if cfg.org else ""
            if task_spec.task_template.id.version == "":
                task_spec.task_template.id.version = version

            kv_pairs: List[literals_pb2.KeyValuePair] = []
            for k, v in env.items():
                if not isinstance(v, str):
                    raise ValueError(f"Environment variable {k} must be a string, got {type(v)}")
                kv_pairs.append(literals_pb2.KeyValuePair(key=k, value=v))

            env_kv = run_pb2.Envs(values=kv_pairs)
            annotations = run_pb2.Annotations(values=self._annotations)
            labels = run_pb2.Labels(values=self._labels)
            raw_data_storage = (
                run_pb2.RawDataStorage(raw_data_prefix=self._raw_data_path) if self._raw_data_path else None
            )
            security_context = (
                security_pb2.SecurityContext(run_as=security_pb2.Identity(k8s_service_account=self._service_account))
                if self._service_account
                else None
            )

            def _to_cache_lookup_scope(scope: CacheLookupScope | None = None) -> run_pb2.CacheLookupScope:
                """Translate the string scope into the corresponding protobuf enum value."""
                if scope == "global":
                    return run_pb2.CacheLookupScope.CACHE_LOOKUP_SCOPE_GLOBAL
                elif scope == "project-domain":
                    return run_pb2.CacheLookupScope.CACHE_LOOKUP_SCOPE_PROJECT_DOMAIN
                elif scope is None:
                    return run_pb2.CacheLookupScope.CACHE_LOOKUP_SCOPE_UNSPECIFIED
                else:
                    raise ValueError(f"Unknown cache lookup scope: {scope}")

            try:
                resp = await get_client().run_service.CreateRun(
                    run_service_pb2.CreateRunRequest(
                        run_id=run_id,
                        project_id=project_id,
                        task_spec=task_spec,
                        inputs=inputs.proto_inputs,
                        run_spec=run_pb2.RunSpec(
                            overwrite_cache=self._overwrite_cache,
                            interruptible=wrappers_pb2.BoolValue(value=self._interruptible)
                            if self._interruptible is not None
                            else None,
                            annotations=annotations,
                            labels=labels,
                            envs=env_kv,
                            cluster=self._queue or task.queue,
                            raw_data_storage=raw_data_storage,
                            security_context=security_context,
                            cache_config=run_pb2.CacheConfig(
                                overwrite_cache=self._overwrite_cache,
                                cache_lookup_scope=_to_cache_lookup_scope(self._cache_lookup_scope)
                                if self._cache_lookup_scope
                                else None,
                            ),
                        ),
                    ),
                )
                return Run(pb2=resp.run)
            except grpc.aio.AioRpcError as e:
                # Every branch chains the gRPC error so the original status stays attached to the raised error.
                status = e.code()
                if status == grpc.StatusCode.UNAVAILABLE:
                    raise flyte.errors.RuntimeSystemError(
                        "SystemUnavailableError",
                        "Flyte system is currently unavailable. check your configuration, or the service status.",
                    ) from e
                elif status == grpc.StatusCode.INVALID_ARGUMENT:
                    raise flyte.errors.RuntimeUserError("InvalidArgumentError", e.details()) from e
                elif status == grpc.StatusCode.ALREADY_EXISTS:
                    # TODO maybe this should be a pass and return existing run?
                    raise flyte.errors.RuntimeUserError(
                        "RunAlreadyExistsError",
                        f"A run with the name '{self._name}' already exists. Please choose a different name.",
                    ) from e
                else:
                    raise flyte.errors.RuntimeSystemError(
                        "RunCreationError",
                        f"Failed to create run: {e.details()}",
                    ) from e

        # Only reached for dry runs: the branch above always returns or raises.
        class DryRun(Run):
            """Placeholder run that exposes what would have been submitted."""

            def __init__(self, _task_spec, _inputs, _code_bundle):
                super().__init__(
                    pb2=run_definition_pb2.Run(
                        action=run_definition_pb2.Action(
                            id=identifier_pb2.ActionIdentifier(
                                name="a0",
                                run=identifier_pb2.RunIdentifier(name="dry-run"),
                            )
                        )
                    )
                )
                self.task_spec = _task_spec
                self.inputs = _inputs
                self.code_bundle = _code_bundle

        return DryRun(_task_spec=task_spec, _inputs=inputs, _code_bundle=code_bundle)

    @requires_storage
    @requires_initialization
    async def _run_hybrid(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: P.kwargs) -> R:
        """
        Run a task in hybrid mode. This means that the parent action will be run locally, but the child actions will be
        run in the cluster remotely. This is currently only used for testing,
        over the longer term we will productize this.

        :param obj: The task to run.
        :param args: Positional inputs of the task.
        :param kwargs: Keyword inputs of the task.
        :return: The outputs of the task.
        :raises ValueError: If the task has no environment, no version can be determined, the configured storage is
            not S3, GCS or ABFS, or `run_base_dir` was not set.
        """
        import flyte.report
        from flyte._code_bundle import build_code_bundle, build_pkl_bundle
        from flyte._deploy import build_images
        from flyte.models import RawDataPath
        from flyte.storage import ABFS, GCS, S3

        from ._internal import create_controller
        from ._internal.runtime.taskrunner import run_task

        cfg = get_init_config()

        if obj.parent_env is None:
            raise ValueError("Task is not attached to an environment. Please attach the task to an environment.")

        image_cache = await build_images.aio(cast(Environment, obj.parent_env()))

        code_bundle = None
        if self._name is not None:
            # Check if remote run service has this run name already and if exists, then extract the code bundle from it.
            code_bundle = await _get_code_bundle_for_run(name=self._name)

        if not code_bundle:
            if self._interactive_mode:
                code_bundle = await build_pkl_bundle(
                    obj,
                    upload_to_controlplane=not self._dry_run,
                    copy_bundle_to=self._copy_bundle_to,
                )
            elif self._copy_files != "none":
                code_bundle = await build_code_bundle(
                    from_dir=cfg.root_dir,
                    dryrun=self._dry_run,
                    copy_bundle_to=self._copy_bundle_to,
                    copy_style=self._copy_files,
                )
            else:
                code_bundle = None

        version = self._version or (
            code_bundle.computed_version if code_bundle and code_bundle.computed_version else None
        )
        if not version:
            raise ValueError("Version is required when running a task")

        project = cfg.project
        domain = cfg.domain
        org = cfg.org
        action_name = "a0"
        run_name = self._name
        random_id = str(uuid.uuid4())[:6]

        controller = create_controller("remote", endpoint="localhost:8090", insecure=True)
        action = ActionID(name=action_name, run_name=run_name, project=project, domain=domain, org=org)

        inputs = obj.native_interface.convert_to_kwargs(*args, **kwargs)
        # TODO: Ideally we should get this from runService
        # The API should be:
        #   create new run, from run, in mode hybrid -> new run id, output_base, raw_data_path, inputs_path
        storage = get_storage()
        if not isinstance(storage, (S3, GCS, ABFS)):
            raise ValueError(f"Unsupported storage type: {type(storage)}")
        if self._run_base_dir is None:
            # Adjacent literals form one message; separate arguments would turn the message into a tuple.
            raise ValueError(
                "Run base dir is required when running task, please set it in the run context:"
                " flyte.with_runcontext(run_base_dir='s3://bucket/metadata/outputs')"
            )
        output_path = self._run_base_dir
        run_base_dir = self._run_base_dir
        raw_data_path = f"{output_path}/rd/{random_id}"
        raw_data_path_obj = RawDataPath(path=raw_data_path)
        checkpoint_path = f"{raw_data_path}/checkpoint"
        prev_checkpoint = f"{raw_data_path}/prev_checkpoint"
        checkpoints = Checkpoints(checkpoint_path, prev_checkpoint)

        async def _run_task() -> Tuple[Any, Optional[Exception]]:
            """Execute the task under a task context built from the values computed above."""
            ctx = internal_ctx()
            tctx = TaskContext(
                action=action,
                checkpoints=checkpoints,
                code_bundle=code_bundle,
                output_path=output_path,
                version=version if version else "na",
                raw_data_path=raw_data_path_obj,
                compiled_image_cache=image_cache,
                run_base_dir=run_base_dir,
                report=flyte.report.Report(name=action.name),
                custom_context=self._custom_context,
            )
            async with ctx.replace_task_context(tctx):
                return await run_task(tctx=tctx, controller=controller, task=obj, inputs=inputs)

        outputs, err = await contextual_run(_run_task)
        if err:
            raise err
        return outputs

    async def _run_local(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: P.kwargs) -> Run:
        """
        Run the task with the local controller and wrap its outputs in a `Run` object.

        Metadata is placed under `<metadata_path>/<action name>` (default base `/tmp/flyte/metadata`) and raw data
        under `/tmp/flyte/raw_data/<action name>` unless `raw_data_path` was configured.

        :param obj: The task to run.
        :param args: Positional inputs of the task.
        :param kwargs: Keyword inputs of the task.
        :return: A `Run` subclass whose `outputs()` returns the task outputs and whose `url` is the metadata path.
        """
        from flyteidl2.common import identifier_pb2

        from flyte._internal.controllers import create_controller
        from flyte._internal.controllers._local_controller import LocalController
        from flyte.remote import Run
        from flyte.report import Report

        controller = cast(LocalController, create_controller("local"))

        action = ActionID.create_random() if self._name is None else ActionID(name=self._name)

        if self._metadata_path is None:
            metadata_dir = pathlib.Path("/") / "tmp" / "flyte" / "metadata" / action.name
        else:
            metadata_dir = pathlib.Path(self._metadata_path) / action.name
        output_path = metadata_dir / "a0"
        if self._raw_data_path is None:
            default_raw_dir = pathlib.Path("/") / "tmp" / "flyte" / "raw_data" / action.name
            raw_data_path = RawDataPath(path=str(default_raw_dir))
        else:
            raw_data_path = RawDataPath(path=self._raw_data_path)

        ctx = internal_ctx()
        tctx = TaskContext(
            action=action,
            checkpoints=Checkpoints(
                prev_checkpoint_path=ctx.raw_data.path,
                checkpoint_path=ctx.raw_data.path,
            ),
            code_bundle=None,
            output_path=str(output_path),
            run_base_dir=str(metadata_dir),
            version="na",
            raw_data_path=raw_data_path,
            compiled_image_cache=None,
            report=Report(name=action.name),
            mode="local",
            custom_context=self._custom_context,
        )

        with ctx.replace_task_context(tctx):
            # make the local version always runs on a different thread, returns a wrapped future.
            if obj._call_as_synchronous:
                fut = controller.submit_sync(obj, *args, **kwargs)
                outputs = await asyncio.wrap_future(fut)
            else:
                outputs = await controller.submit(obj, *args, **kwargs)

        class _LocalRun(Run):
            """`Run` wrapper around outputs that were already computed locally."""

            def __init__(self, outputs: Tuple[Any, ...] | Any):
                from flyteidl2.workflow import run_definition_pb2

                self._outputs = outputs
                # NOTE(review): the placeholder identifier uses the same "dry-run" name as the remote dry-run path.
                super().__init__(
                    pb2=run_definition_pb2.Run(
                        action=run_definition_pb2.Action(
                            id=identifier_pb2.ActionIdentifier(
                                name="a0",
                                run=identifier_pb2.RunIdentifier(name="dry-run"),
                            )
                        )
                    )
                )

            @property
            def url(self) -> str:
                return str(metadata_dir)

            def wait(
                self,
                quiet: bool = False,
                wait_for: Literal["terminal", "running"] = "terminal",
            ):
                # Nothing to wait for: the outputs were computed before this object was created.
                pass

            def outputs(self) -> R:
                return cast(R, self._outputs)

        return _LocalRun(outputs)

    @syncify
    async def run(
        self,
        task: TaskTemplate[P, Union[R, Run], F] | LazyEntity,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Union[R, Run]:
        """
        Run an async `@env.task` or `TaskTemplate` instance. The existing async context will be used.

        Example:
        ```python
        import flyte
        env = flyte.TaskEnvironment("example")

        @env.task
        async def example_task(x: int, y: str) -> str:
            return f"{x} {y}"

        if __name__ == "__main__":
            flyte.run(example_task, 1, y="hello")
        ```

        :param task: TaskTemplate instance `@env.task` or `TaskTemplate`
        :param args: Arguments to pass to the Task
        :param kwargs: Keyword arguments to pass to the Task
        :return: Run instance or the result of the task
        :raises ValueError: If a `LazyEntity` is given while the mode is not "remote".
        :raises TypeError: If `task` is neither a `TaskTemplate` nor a `LazyEntity`.
        """
        from flyte.remote._task import LazyEntity

        if isinstance(task, LazyEntity) and self._mode != "remote":
            raise ValueError("Remote task can only be run in remote mode.")

        if not isinstance(task, (TaskTemplate, LazyEntity)):
            raise TypeError(f"Only Flyte tasks can be run, not generic functions or methods '{type(task)}'.")

        if self._mode == "remote":
            return await self._run_remote(task, *args, **kwargs)
        task = cast(TaskTemplate, task)
        if self._mode == "hybrid":
            return await self._run_hybrid(task, *args, **kwargs)

        # TODO We could use this for remote as well and users could simply pass flyte:// or s3:// or file://
        with internal_ctx().new_raw_data_path(
            raw_data_path=RawDataPath.from_local_folder(local_folder=self._raw_data_path)
        ):
            return await self._run_local(task, *args, **kwargs)
603
+
604
+
605
def with_runcontext(
    mode: Mode | None = None,
    *,
    name: Optional[str] = None,
    service_account: Optional[str] = None,
    version: Optional[str] = None,
    copy_style: CopyFiles = "loaded_modules",
    dry_run: bool = False,
    copy_bundle_to: pathlib.Path | None = None,
    interactive_mode: bool | None = None,
    raw_data_path: str | None = None,
    run_base_dir: str | None = None,
    overwrite_cache: bool = False,
    project: str | None = None,
    domain: str | None = None,
    env_vars: Dict[str, str] | None = None,
    labels: Dict[str, str] | None = None,
    annotations: Dict[str, str] | None = None,
    interruptible: bool | None = None,
    log_level: int | None = None,
    log_format: LogFormat = "console",
    disable_run_cache: bool = False,
    queue: Optional[str] = None,
    custom_context: Dict[str, str] | None = None,
    cache_lookup_scope: CacheLookupScope = "global",
) -> _Runner:
    """
    Create a runner configured with the given parameters as the run context.

    This function only validates the arguments and builds the runner; call ``.run(...)`` on the returned
    object to actually execute a task.

    Example:
    ```python
    import flyte
    env = flyte.TaskEnvironment("example")

    @env.task
    async def example_task(x: int, y: str) -> str:
        return f"{x} {y}"

    if __name__ == "__main__":
        flyte.with_runcontext(name="example_run_id").run(example_task, 1, y="hello")
    ```

    :param mode: Optional The mode to use for the run, if not provided, it will be computed from flyte.init
    :param name: Optional The name to use for the run
    :param service_account: Optional The service account to use for the run context
    :param version: Optional The version to use for the run, if not provided, it will be computed from the code bundle
    :param copy_style: Optional The copy style to use for the run context
    :param dry_run: Optional If true, the run will not be executed, but the bundle will be created
    :param copy_bundle_to: When dry_run is True, the bundle will be copied to this location if specified
    :param interactive_mode: Optional, can be forced to True or False.
        If not provided, it will be set based on the current environment. For example Jupyter notebooks are considered
        interactive mode, while scripts are not. This is used to determine how the code bundle is created.
    :param raw_data_path: Use this path to store the raw data for the run for local and remote, and can be used to
        store raw data in specific locations.
    :param run_base_dir: Optional The base directory to use for the run. This is used to store the metadata for the run,
        that is passed between tasks.
    :param overwrite_cache: Optional If true, the cache will be overwritten for the run
    :param project: Optional The project to use for the run
    :param domain: Optional The domain to use for the run
    :param env_vars: Optional Environment variables to set for the run
    :param labels: Optional Labels to set for the run
    :param annotations: Optional Annotations to set for the run
    :param interruptible: Optional If true, the run can be scheduled on interruptible instances and false implies
        that all tasks in the run should only be scheduled on non-interruptible instances. If not specified the
        original setting on all tasks is retained.
    :param log_level: Optional Log level to set for the run. If not provided, it will be set to the default log level
        set using `flyte.init()`
    :param log_format: Optional Log format to set for the run. If not provided, it will be set to the default log format
    :param disable_run_cache: Optional If true, the run cache will be disabled. This is useful for testing purposes.
    :param queue: Optional The queue to use for the run. This is used to specify the cluster to use for the run.
    :param custom_context: Optional global input context to pass to the task. This will be available via
        get_custom_context() within the task and will automatically propagate to sub-tasks.
        Acts as base/default values that can be overridden by context managers in the code.
    :param cache_lookup_scope: Optional Scope to use for the run. This is used to specify the scope to use for cache
        lookups. If not specified, it will be set to the default scope (global unless overridden at the system level).

    :raises ValueError: if ``mode`` is explicitly "hybrid" and either ``name`` or ``run_base_dir`` is missing,
        or if ``copy_style`` is "none" and no ``version`` is given.
    :return: runner
    """
    # Hybrid mode needs BOTH a run name and a run base dir (as the message states), so fail fast when
    # either one is missing rather than only when both are absent.
    if mode == "hybrid" and not (name and run_base_dir):
        raise ValueError("Run name and run base dir are required for hybrid mode")
    # The docstring above says the version is otherwise computed from the code bundle; with copy_style
    # "none" it has to be supplied explicitly.
    if copy_style == "none" and not version:
        raise ValueError("Version is required when copy_style is 'none'")
    return _Runner(
        force_mode=mode,
        name=name,
        service_account=service_account,
        version=version,
        copy_style=copy_style,
        dry_run=dry_run,
        copy_bundle_to=copy_bundle_to,
        interactive_mode=interactive_mode,
        raw_data_path=raw_data_path,
        run_base_dir=run_base_dir,
        overwrite_cache=overwrite_cache,
        env_vars=env_vars,
        labels=labels,
        annotations=annotations,
        interruptible=interruptible,
        project=project,
        domain=domain,
        log_level=log_level,
        log_format=log_format,
        disable_run_cache=disable_run_cache,
        queue=queue,
        custom_context=custom_context,
        cache_lookup_scope=cache_lookup_scope,
    )
712
+
713
+
714
@syncify
async def run(task: TaskTemplate[P, R, F], *args: P.args, **kwargs: P.kwargs) -> Union[R, Run]:
    """
    Execute a task through a runner created with default settings.

    :param task: the task to execute
    :param args: positional arguments forwarded to the task
    :param kwargs: keyword arguments forwarded to the task
    :return: Run | Result of the task
    """
    # Await the runner's `.aio` entry point directly; going through the syncer here causes problems.
    default_runner = _Runner()
    outcome = await default_runner.run.aio(task, *args, **kwargs)  # type: ignore
    return outcome