flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/debug.py +38 -0
  3. flyte/_bin/runtime.py +62 -8
  4. flyte/_cache/cache.py +4 -2
  5. flyte/_cache/local_cache.py +216 -0
  6. flyte/_code_bundle/_ignore.py +12 -4
  7. flyte/_code_bundle/_packaging.py +13 -9
  8. flyte/_code_bundle/_utils.py +18 -10
  9. flyte/_code_bundle/bundle.py +17 -9
  10. flyte/_constants.py +1 -0
  11. flyte/_context.py +4 -1
  12. flyte/_custom_context.py +73 -0
  13. flyte/_debug/constants.py +38 -0
  14. flyte/_debug/utils.py +17 -0
  15. flyte/_debug/vscode.py +307 -0
  16. flyte/_deploy.py +235 -61
  17. flyte/_environment.py +20 -6
  18. flyte/_excepthook.py +1 -1
  19. flyte/_hash.py +1 -16
  20. flyte/_image.py +178 -81
  21. flyte/_initialize.py +132 -51
  22. flyte/_interface.py +39 -2
  23. flyte/_internal/controllers/__init__.py +4 -5
  24. flyte/_internal/controllers/_local_controller.py +70 -29
  25. flyte/_internal/controllers/_trace.py +1 -1
  26. flyte/_internal/controllers/remote/__init__.py +0 -2
  27. flyte/_internal/controllers/remote/_action.py +14 -16
  28. flyte/_internal/controllers/remote/_client.py +1 -1
  29. flyte/_internal/controllers/remote/_controller.py +68 -70
  30. flyte/_internal/controllers/remote/_core.py +127 -99
  31. flyte/_internal/controllers/remote/_informer.py +19 -10
  32. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  33. flyte/_internal/imagebuild/docker_builder.py +181 -69
  34. flyte/_internal/imagebuild/image_builder.py +0 -5
  35. flyte/_internal/imagebuild/remote_builder.py +155 -64
  36. flyte/_internal/imagebuild/utils.py +51 -2
  37. flyte/_internal/resolvers/_task_module.py +5 -38
  38. flyte/_internal/resolvers/default.py +2 -2
  39. flyte/_internal/runtime/convert.py +110 -21
  40. flyte/_internal/runtime/entrypoints.py +27 -1
  41. flyte/_internal/runtime/io.py +21 -8
  42. flyte/_internal/runtime/resources_serde.py +20 -6
  43. flyte/_internal/runtime/reuse.py +1 -1
  44. flyte/_internal/runtime/rusty.py +20 -5
  45. flyte/_internal/runtime/task_serde.py +34 -19
  46. flyte/_internal/runtime/taskrunner.py +22 -4
  47. flyte/_internal/runtime/trigger_serde.py +160 -0
  48. flyte/_internal/runtime/types_serde.py +1 -1
  49. flyte/_keyring/__init__.py +0 -0
  50. flyte/_keyring/file.py +115 -0
  51. flyte/_logging.py +201 -39
  52. flyte/_map.py +111 -14
  53. flyte/_module.py +70 -0
  54. flyte/_pod.py +4 -3
  55. flyte/_resources.py +213 -31
  56. flyte/_run.py +110 -39
  57. flyte/_task.py +75 -16
  58. flyte/_task_environment.py +105 -29
  59. flyte/_task_plugins.py +4 -2
  60. flyte/_trace.py +5 -0
  61. flyte/_trigger.py +1000 -0
  62. flyte/_utils/__init__.py +2 -1
  63. flyte/_utils/asyn.py +3 -1
  64. flyte/_utils/coro_management.py +2 -1
  65. flyte/_utils/docker_credentials.py +173 -0
  66. flyte/_utils/module_loader.py +17 -2
  67. flyte/_version.py +3 -3
  68. flyte/cli/_abort.py +3 -3
  69. flyte/cli/_build.py +3 -6
  70. flyte/cli/_common.py +78 -7
  71. flyte/cli/_create.py +182 -4
  72. flyte/cli/_delete.py +23 -1
  73. flyte/cli/_deploy.py +63 -16
  74. flyte/cli/_get.py +79 -34
  75. flyte/cli/_params.py +26 -10
  76. flyte/cli/_plugins.py +209 -0
  77. flyte/cli/_run.py +151 -26
  78. flyte/cli/_serve.py +64 -0
  79. flyte/cli/_update.py +37 -0
  80. flyte/cli/_user.py +17 -0
  81. flyte/cli/main.py +30 -4
  82. flyte/config/_config.py +10 -6
  83. flyte/config/_internal.py +1 -0
  84. flyte/config/_reader.py +29 -8
  85. flyte/connectors/__init__.py +11 -0
  86. flyte/connectors/_connector.py +270 -0
  87. flyte/connectors/_server.py +197 -0
  88. flyte/connectors/utils.py +135 -0
  89. flyte/errors.py +22 -2
  90. flyte/extend.py +8 -1
  91. flyte/extras/_container.py +6 -1
  92. flyte/git/__init__.py +3 -0
  93. flyte/git/_config.py +21 -0
  94. flyte/io/__init__.py +2 -0
  95. flyte/io/_dataframe/__init__.py +2 -0
  96. flyte/io/_dataframe/basic_dfs.py +17 -8
  97. flyte/io/_dataframe/dataframe.py +98 -132
  98. flyte/io/_dir.py +575 -113
  99. flyte/io/_file.py +582 -139
  100. flyte/io/_hashing_io.py +342 -0
  101. flyte/models.py +74 -15
  102. flyte/remote/__init__.py +6 -1
  103. flyte/remote/_action.py +34 -26
  104. flyte/remote/_client/_protocols.py +39 -4
  105. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  106. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  107. flyte/remote/_client/auth/_channel.py +10 -6
  108. flyte/remote/_client/controlplane.py +17 -5
  109. flyte/remote/_console.py +3 -2
  110. flyte/remote/_data.py +6 -6
  111. flyte/remote/_logs.py +3 -3
  112. flyte/remote/_run.py +64 -8
  113. flyte/remote/_secret.py +26 -17
  114. flyte/remote/_task.py +75 -33
  115. flyte/remote/_trigger.py +306 -0
  116. flyte/remote/_user.py +33 -0
  117. flyte/report/_report.py +1 -1
  118. flyte/storage/__init__.py +6 -1
  119. flyte/storage/_config.py +5 -1
  120. flyte/storage/_parallel_reader.py +274 -0
  121. flyte/storage/_storage.py +200 -103
  122. flyte/types/__init__.py +16 -0
  123. flyte/types/_interface.py +2 -2
  124. flyte/types/_pickle.py +35 -8
  125. flyte/types/_string_literals.py +8 -9
  126. flyte/types/_type_engine.py +40 -70
  127. flyte/types/_utils.py +1 -1
  128. flyte-2.0.0b30.data/scripts/debug.py +38 -0
  129. {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
  130. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
  131. flyte-2.0.0b30.dist-info/RECORD +192 -0
  132. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
  133. flyte/_protos/common/authorization_pb2.py +0 -66
  134. flyte/_protos/common/authorization_pb2.pyi +0 -108
  135. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  136. flyte/_protos/common/identifier_pb2.py +0 -93
  137. flyte/_protos/common/identifier_pb2.pyi +0 -110
  138. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  139. flyte/_protos/common/identity_pb2.py +0 -48
  140. flyte/_protos/common/identity_pb2.pyi +0 -72
  141. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  142. flyte/_protos/common/list_pb2.py +0 -36
  143. flyte/_protos/common/list_pb2.pyi +0 -71
  144. flyte/_protos/common/list_pb2_grpc.py +0 -4
  145. flyte/_protos/common/policy_pb2.py +0 -37
  146. flyte/_protos/common/policy_pb2.pyi +0 -27
  147. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  148. flyte/_protos/common/role_pb2.py +0 -37
  149. flyte/_protos/common/role_pb2.pyi +0 -53
  150. flyte/_protos/common/role_pb2_grpc.py +0 -4
  151. flyte/_protos/common/runtime_version_pb2.py +0 -28
  152. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  153. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  154. flyte/_protos/imagebuilder/definition_pb2.py +0 -59
  155. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
  156. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  157. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  158. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  159. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  160. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  161. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  162. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  163. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  164. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  165. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  166. flyte/_protos/secret/definition_pb2.py +0 -49
  167. flyte/_protos/secret/definition_pb2.pyi +0 -93
  168. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  169. flyte/_protos/secret/payload_pb2.py +0 -62
  170. flyte/_protos/secret/payload_pb2.pyi +0 -94
  171. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  172. flyte/_protos/secret/secret_pb2.py +0 -38
  173. flyte/_protos/secret/secret_pb2.pyi +0 -6
  174. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  175. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  176. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  177. flyte/_protos/workflow/common_pb2.py +0 -27
  178. flyte/_protos/workflow/common_pb2.pyi +0 -14
  179. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  180. flyte/_protos/workflow/environment_pb2.py +0 -29
  181. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  182. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  183. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  184. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  185. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  186. flyte/_protos/workflow/queue_service_pb2.py +0 -109
  187. flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
  188. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  189. flyte/_protos/workflow/run_definition_pb2.py +0 -121
  190. flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
  191. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  192. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  193. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  194. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  195. flyte/_protos/workflow/run_service_pb2.py +0 -137
  196. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  197. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  198. flyte/_protos/workflow/state_service_pb2.py +0 -67
  199. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  200. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  201. flyte/_protos/workflow/task_definition_pb2.py +0 -79
  202. flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
  203. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  204. flyte/_protos/workflow/task_service_pb2.py +0 -60
  205. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  206. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  207. flyte-2.0.0b13.dist-info/RECORD +0 -239
  208. /flyte/{_protos → _debug}/__init__.py +0 -0
  209. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  210. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  211. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/_run.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
2
 
3
3
  import asyncio
4
4
  import pathlib
5
+ import sys
5
6
  import uuid
6
7
  from dataclasses import dataclass
7
8
  from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast
@@ -12,13 +13,13 @@ from flyte._environment import Environment
12
13
  from flyte._initialize import (
13
14
  _get_init_config,
14
15
  get_client,
15
- get_common_config,
16
+ get_init_config,
16
17
  get_storage,
17
18
  requires_initialization,
18
19
  requires_storage,
19
20
  )
20
- from flyte._logging import logger
21
- from flyte._task import P, R, TaskTemplate
21
+ from flyte._logging import LogFormat, logger
22
+ from flyte._task import F, P, R, TaskTemplate
22
23
  from flyte.models import (
23
24
  ActionID,
24
25
  Checkpoints,
@@ -29,6 +30,8 @@ from flyte.models import (
29
30
  )
30
31
  from flyte.syncify import syncify
31
32
 
33
+ from ._constants import FLYTE_SYS_PATH
34
+
32
35
  if TYPE_CHECKING:
33
36
  from flyte.remote import Run
34
37
  from flyte.remote._task import LazyEntity
@@ -90,9 +93,12 @@ class _Runner:
90
93
  env_vars: Dict[str, str] | None = None,
91
94
  labels: Dict[str, str] | None = None,
92
95
  annotations: Dict[str, str] | None = None,
93
- interruptible: bool = False,
96
+ interruptible: bool | None = None,
94
97
  log_level: int | None = None,
98
+ log_format: LogFormat = "console",
95
99
  disable_run_cache: bool = False,
100
+ queue: Optional[str] = None,
101
+ custom_context: Dict[str, str] | None = None,
96
102
  ):
97
103
  from flyte._tools import ipython_check
98
104
 
@@ -111,8 +117,8 @@ class _Runner:
111
117
  self._copy_bundle_to = copy_bundle_to
112
118
  self._interactive_mode = interactive_mode if interactive_mode else ipython_check()
113
119
  self._raw_data_path = raw_data_path
114
- self._metadata_path = metadata_path or "/tmp"
115
- self._run_base_dir = run_base_dir or "/tmp/base"
120
+ self._metadata_path = metadata_path
121
+ self._run_base_dir = run_base_dir
116
122
  self._overwrite_cache = overwrite_cache
117
123
  self._project = project
118
124
  self._domain = domain
@@ -121,12 +127,18 @@ class _Runner:
121
127
  self._annotations = annotations
122
128
  self._interruptible = interruptible
123
129
  self._log_level = log_level
130
+ self._log_format = log_format
124
131
  self._disable_run_cache = disable_run_cache
132
+ self._queue = queue
133
+ self._custom_context = custom_context or {}
125
134
 
126
135
  @requires_initialization
127
- async def _run_remote(self, obj: TaskTemplate[P, R] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run:
136
+ async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run:
128
137
  import grpc
129
- from flyteidl.core import literals_pb2
138
+ from flyteidl2.common import identifier_pb2
139
+ from flyteidl2.core import literals_pb2, security_pb2
140
+ from flyteidl2.task import run_pb2
141
+ from flyteidl2.workflow import run_definition_pb2, run_service_pb2
130
142
  from google.protobuf import wrappers_pb2
131
143
 
132
144
  from flyte.remote import Run
@@ -136,20 +148,21 @@ class _Runner:
136
148
  from ._deploy import build_images
137
149
  from ._internal.runtime.convert import convert_from_native_to_inputs
138
150
  from ._internal.runtime.task_serde import translate_task_to_wire
139
- from ._protos.common import identifier_pb2
140
- from ._protos.workflow import run_definition_pb2, run_service_pb2
141
151
 
142
- cfg = get_common_config()
152
+ cfg = get_init_config()
143
153
  project = self._project or cfg.project
144
154
  domain = self._domain or cfg.domain
145
155
 
146
156
  if isinstance(obj, LazyEntity):
147
157
  task = await obj.fetch.aio()
148
158
  task_spec = task.pb2.spec
149
- inputs = await convert_from_native_to_inputs(task.interface, *args, **kwargs)
159
+ inputs = await convert_from_native_to_inputs(
160
+ task.interface, *args, custom_context=self._custom_context, **kwargs
161
+ )
150
162
  version = task.pb2.task_id.version
151
163
  code_bundle = None
152
164
  else:
165
+ task = cast(TaskTemplate[P, R, F], obj)
153
166
  if obj.parent_env is None:
154
167
  raise ValueError("Task is not attached to an environment. Please attach the task to an environment")
155
168
 
@@ -161,7 +174,10 @@ class _Runner:
161
174
  code_bundle = cached_value.code_bundle
162
175
  image_cache = cached_value.image_cache
163
176
  else:
164
- image_cache = await build_images.aio(cast(Environment, obj.parent_env()))
177
+ if not self._dry_run:
178
+ image_cache = await build_images.aio(cast(Environment, obj.parent_env()))
179
+ else:
180
+ image_cache = None
165
181
 
166
182
  if self._interactive_mode:
167
183
  code_bundle = await build_pkl_bundle(
@@ -196,13 +212,23 @@ class _Runner:
196
212
  root_dir=cfg.root_dir,
197
213
  )
198
214
  task_spec = translate_task_to_wire(obj, s_ctx)
199
- inputs = await convert_from_native_to_inputs(obj.native_interface, *args, **kwargs)
215
+ inputs = await convert_from_native_to_inputs(
216
+ obj.native_interface, *args, custom_context=self._custom_context, **kwargs
217
+ )
200
218
 
201
219
  env = self._env_vars or {}
202
- if self._log_level:
203
- env["LOG_LEVEL"] = str(self._log_level)
204
- else:
205
- env["LOG_LEVEL"] = str(logger.getEffectiveLevel())
220
+ if env.get("LOG_LEVEL") is None:
221
+ if self._log_level:
222
+ env["LOG_LEVEL"] = str(self._log_level)
223
+ else:
224
+ env["LOG_LEVEL"] = str(logger.getEffectiveLevel())
225
+ env["LOG_FORMAT"] = self._log_format
226
+
227
+ # These paths will be appended to sys.path at runtime.
228
+ if cfg.sync_local_sys_paths:
229
+ env[FLYTE_SYS_PATH] = ":".join(
230
+ f"./{pathlib.Path(p).relative_to(cfg.root_dir)}" for p in sys.path if p.startswith(str(cfg.root_dir))
231
+ )
206
232
 
207
233
  if not self._dry_run:
208
234
  if get_client() is None:
@@ -245,9 +271,17 @@ class _Runner:
245
271
  raise ValueError(f"Environment variable {k} must be a string, got {type(v)}")
246
272
  kv_pairs.append(literals_pb2.KeyValuePair(key=k, value=v))
247
273
 
248
- env_kv = run_definition_pb2.Envs(values=kv_pairs)
249
- annotations = run_definition_pb2.Annotations(values=self._annotations)
250
- labels = run_definition_pb2.Labels(values=self._labels)
274
+ env_kv = run_pb2.Envs(values=kv_pairs)
275
+ annotations = run_pb2.Annotations(values=self._annotations)
276
+ labels = run_pb2.Labels(values=self._labels)
277
+ raw_data_storage = (
278
+ run_pb2.RawDataStorage(raw_data_prefix=self._raw_data_path) if self._raw_data_path else None
279
+ )
280
+ security_context = (
281
+ security_pb2.SecurityContext(run_as=security_pb2.Identity(k8s_service_account=self._service_account))
282
+ if self._service_account
283
+ else None
284
+ )
251
285
 
252
286
  try:
253
287
  resp = await get_client().run_service.CreateRun(
@@ -256,12 +290,17 @@ class _Runner:
256
290
  project_id=project_id,
257
291
  task_spec=task_spec,
258
292
  inputs=inputs.proto_inputs,
259
- run_spec=run_definition_pb2.RunSpec(
293
+ run_spec=run_pb2.RunSpec(
260
294
  overwrite_cache=self._overwrite_cache,
261
- interruptible=wrappers_pb2.BoolValue(value=self._interruptible),
295
+ interruptible=wrappers_pb2.BoolValue(value=self._interruptible)
296
+ if self._interruptible is not None
297
+ else None,
262
298
  annotations=annotations,
263
299
  labels=labels,
264
300
  envs=env_kv,
301
+ cluster=self._queue or task.queue,
302
+ raw_data_storage=raw_data_storage,
303
+ security_context=security_context,
265
304
  ),
266
305
  ),
267
306
  )
@@ -306,7 +345,7 @@ class _Runner:
306
345
 
307
346
  @requires_storage
308
347
  @requires_initialization
309
- async def _run_hybrid(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
348
+ async def _run_hybrid(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: P.kwargs) -> R:
310
349
  """
311
350
  Run a task in hybrid mode. This means that the parent action will be run locally, but the child actions will be
312
351
  run in the cluster remotely. This is currently only used for testing,
@@ -321,7 +360,7 @@ class _Runner:
321
360
  from ._internal import create_controller
322
361
  from ._internal.runtime.taskrunner import run_task
323
362
 
324
- cfg = get_common_config()
363
+ cfg = get_init_config()
325
364
 
326
365
  if obj.parent_env is None:
327
366
  raise ValueError("Task is not attached to an environment. Please attach the task to an environment.")
@@ -380,6 +419,7 @@ class _Runner:
380
419
  " flyte.with_runcontext(run_base_dir='s3://bucket/metadata/outputs')",
381
420
  )
382
421
  output_path = self._run_base_dir
422
+ run_base_dir = self._run_base_dir
383
423
  raw_data_path = f"{output_path}/rd/{random_id}"
384
424
  raw_data_path_obj = RawDataPath(path=raw_data_path)
385
425
  checkpoint_path = f"{raw_data_path}/checkpoint"
@@ -396,8 +436,9 @@ class _Runner:
396
436
  version=version if version else "na",
397
437
  raw_data_path=raw_data_path_obj,
398
438
  compiled_image_cache=image_cache,
399
- run_base_dir=self._run_base_dir,
439
+ run_base_dir=run_base_dir,
400
440
  report=flyte.report.Report(name=action.name),
441
+ custom_context=self._custom_context,
401
442
  )
402
443
  async with ctx.replace_task_context(tctx):
403
444
  return await run_task(tctx=tctx, controller=controller, task=obj, inputs=inputs)
@@ -407,10 +448,11 @@ class _Runner:
407
448
  raise err
408
449
  return outputs
409
450
 
410
- async def _run_local(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> Run:
451
+ async def _run_local(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: P.kwargs) -> Run:
452
+ from flyteidl2.common import identifier_pb2
453
+
411
454
  from flyte._internal.controllers import create_controller
412
455
  from flyte._internal.controllers._local_controller import LocalController
413
- from flyte._protos.common import identifier_pb2
414
456
  from flyte.remote import Run
415
457
  from flyte.report import Report
416
458
 
@@ -421,6 +463,18 @@ class _Runner:
421
463
  else:
422
464
  action = ActionID(name=self._name)
423
465
 
466
+ metadata_path = self._metadata_path
467
+ if metadata_path is None:
468
+ metadata_path = pathlib.Path("/") / "tmp" / "flyte" / "metadata" / action.name
469
+ else:
470
+ metadata_path = pathlib.Path(metadata_path) / action.name
471
+ output_path = metadata_path / "a0"
472
+ if self._raw_data_path is None:
473
+ path = pathlib.Path("/") / "tmp" / "flyte" / "raw_data" / action.name
474
+ raw_data_path = RawDataPath(path=str(path))
475
+ else:
476
+ raw_data_path = RawDataPath(path=self._raw_data_path)
477
+
424
478
  ctx = internal_ctx()
425
479
  tctx = TaskContext(
426
480
  action=action,
@@ -429,14 +483,16 @@ class _Runner:
429
483
  checkpoint_path=internal_ctx().raw_data.path,
430
484
  ),
431
485
  code_bundle=None,
432
- output_path=self._metadata_path,
433
- run_base_dir=self._metadata_path,
486
+ output_path=str(output_path),
487
+ run_base_dir=str(metadata_path),
434
488
  version="na",
435
- raw_data_path=internal_ctx().raw_data,
489
+ raw_data_path=raw_data_path,
436
490
  compiled_image_cache=None,
437
491
  report=Report(name=action.name),
438
492
  mode="local",
493
+ custom_context=self._custom_context,
439
494
  )
495
+
440
496
  with ctx.replace_task_context(tctx):
441
497
  # make the local version always runs on a different thread, returns a wrapped future.
442
498
  if obj._call_as_synchronous:
@@ -448,7 +504,7 @@ class _Runner:
448
504
 
449
505
  class _LocalRun(Run):
450
506
  def __init__(self, outputs: Tuple[Any, ...] | Any):
451
- from flyte._protos.workflow import run_definition_pb2
507
+ from flyteidl2.workflow import run_definition_pb2
452
508
 
453
509
  self._outputs = outputs
454
510
  super().__init__(
@@ -464,7 +520,7 @@ class _Runner:
464
520
 
465
521
  @property
466
522
  def url(self) -> str:
467
- return "local-run"
523
+ return str(metadata_path)
468
524
 
469
525
  def wait(
470
526
  self,
@@ -481,7 +537,7 @@ class _Runner:
481
537
  @syncify
482
538
  async def run(
483
539
  self,
484
- task: TaskTemplate[P, Union[R, Run]] | LazyEntity,
540
+ task: TaskTemplate[P, Union[R, Run], F] | LazyEntity,
485
541
  *args: P.args,
486
542
  **kwargs: P.kwargs,
487
543
  ) -> Union[R, Run]:
@@ -545,9 +601,12 @@ def with_runcontext(
545
601
  env_vars: Dict[str, str] | None = None,
546
602
  labels: Dict[str, str] | None = None,
547
603
  annotations: Dict[str, str] | None = None,
548
- interruptible: bool = False,
604
+ interruptible: bool | None = None,
549
605
  log_level: int | None = None,
606
+ log_format: LogFormat = "console",
550
607
  disable_run_cache: bool = False,
608
+ queue: Optional[str] = None,
609
+ custom_context: Dict[str, str] | None = None,
551
610
  ) -> _Runner:
552
611
  """
553
612
  Launch a new run with the given parameters as the context.
@@ -575,8 +634,8 @@ def with_runcontext(
575
634
  :param interactive_mode: Optional, can be forced to True or False.
576
635
  If not provided, it will be set based on the current environment. For example Jupyter notebooks are considered
577
636
  interactive mode, while scripts are not. This is used to determine how the code bundle is created.
578
- :param raw_data_path: Use this path to store the raw data for the run. Currently only supported for local runs,
579
- and can be used to store raw data in specific locations. TODO coming soon for remote runs as well.
637
+ :param raw_data_path: Use this path to store the raw data for the run for local and remote, and can be used to
638
+ store raw data in specific locations.
580
639
  :param run_base_dir: Optional The base directory to use for the run. This is used to store the metadata for the run,
581
640
  that is passed between tasks.
582
641
  :param overwrite_cache: Optional If true, the cache will be overwritten for the run
@@ -585,15 +644,24 @@ def with_runcontext(
585
644
  :param env_vars: Optional Environment variables to set for the run
586
645
  :param labels: Optional Labels to set for the run
587
646
  :param annotations: Optional Annotations to set for the run
588
- :param interruptible: Optional If true, the run can be interrupted by the user.
647
+ :param interruptible: Optional If true, the run can be scheduled on interruptible instances and false implies
648
+ that all tasks in the run should only be scheduled on non-interruptible instances. If not specified the
649
+ original setting on all tasks is retained.
589
650
  :param log_level: Optional Log level to set for the run. If not provided, it will be set to the default log level
590
651
  set using `flyte.init()`
652
+ :param log_format: Optional Log format to set for the run. If not provided, it will be set to the default log format
591
653
  :param disable_run_cache: Optional If true, the run cache will be disabled. This is useful for testing purposes.
654
+ :param queue: Optional The queue to use for the run. This is used to specify the cluster to use for the run.
655
+ :param custom_context: Optional global input context to pass to the task. This will be available via
656
+ get_custom_context() within the task and will automatically propagate to sub-tasks.
657
+ Acts as base/default values that can be overridden by context managers in the code.
592
658
 
593
659
  :return: runner
594
660
  """
595
661
  if mode == "hybrid" and not name and not run_base_dir:
596
662
  raise ValueError("Run name and run base dir are required for hybrid mode")
663
+ if copy_style == "none" and not version:
664
+ raise ValueError("Version is required when copy_style is 'none'")
597
665
  return _Runner(
598
666
  force_mode=mode,
599
667
  name=name,
@@ -613,12 +681,15 @@ def with_runcontext(
613
681
  project=project,
614
682
  domain=domain,
615
683
  log_level=log_level,
684
+ log_format=log_format,
616
685
  disable_run_cache=disable_run_cache,
686
+ queue=queue,
687
+ custom_context=custom_context,
617
688
  )
618
689
 
619
690
 
620
691
  @syncify
621
- async def run(task: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> Union[R, Run]:
692
+ async def run(task: TaskTemplate[P, R, F], *args: P.args, **kwargs: P.kwargs) -> Union[R, Run]:
622
693
  """
623
694
  Run a task with the given parameters
624
695
  :param task: task to run
flyte/_task.py CHANGED
@@ -15,10 +15,12 @@ from typing import (
15
15
  Literal,
16
16
  Optional,
17
17
  ParamSpec,
18
+ Tuple,
18
19
  TypeAlias,
19
20
  TypeVar,
20
21
  Union,
21
22
  cast,
23
+ overload,
22
24
  )
23
25
 
24
26
  from flyte._pod import PodTemplate
@@ -33,10 +35,11 @@ from ._retry import RetryStrategy
33
35
  from ._reusable_environment import ReusePolicy
34
36
  from ._secret import SecretRequest
35
37
  from ._timeout import TimeoutType
38
+ from ._trigger import Trigger
36
39
  from .models import MAX_INLINE_IO_BYTES, NativeInterface, SerializationContext
37
40
 
38
41
  if TYPE_CHECKING:
39
- from flyteidl.core.tasks_pb2 import DataLoadingConfig
42
+ from flyteidl2.core.tasks_pb2 import DataLoadingConfig
40
43
 
41
44
  from ._task_environment import TaskEnvironment
42
45
 
@@ -45,11 +48,12 @@ R = TypeVar("R") # return type
45
48
 
46
49
  AsyncFunctionType: TypeAlias = Callable[P, Coroutine[Any, Any, R]]
47
50
  SyncFunctionType: TypeAlias = Callable[P, R]
48
- FunctionTypes: TypeAlias = Union[AsyncFunctionType, SyncFunctionType]
51
+ FunctionTypes: TypeAlias = AsyncFunctionType | SyncFunctionType
52
+ F = TypeVar("F", bound=FunctionTypes)
49
53
 
50
54
 
51
55
  @dataclass(kw_only=True)
52
- class TaskTemplate(Generic[P, R]):
56
+ class TaskTemplate(Generic[P, R, F]):
53
57
  """
54
58
  Task template is a template for a task that can be executed. It defines various parameters for the task, which
55
59
  can be defined statically at the time of task definition or dynamically at the time of task invocation using
@@ -69,8 +73,8 @@ class TaskTemplate(Generic[P, R]):
69
73
  version with flyte installed
70
74
  :param resources: Optional The resources to use for the task
71
75
  :param cache: Optional The cache policy for the task, defaults to auto, which will cache the results of the task.
72
- :param interruptable: Optional The interruptable policy for the task, defaults to False, which means the task
73
- will not be scheduled on interruptable nodes. If set to True, the task will be scheduled on interruptable nodes,
76
+ :param interruptible: Optional The interruptible policy for the task, defaults to False, which means the task
77
+ will not be scheduled on interruptible nodes. If set to True, the task will be scheduled on interruptible nodes,
74
78
  and the code should handle interruptions and resumptions.
75
79
  :param retries: Optional The number of retries for the task, defaults to 0, which means no retries.
76
80
  :param reusable: Optional The reusability policy for the task, defaults to None, which means the task environment
@@ -81,17 +85,21 @@ class TaskTemplate(Generic[P, R]):
81
85
  :param timeout: Optional The timeout for the task.
82
86
  :param max_inline_io_bytes: Maximum allowed size (in bytes) for all inputs and outputs passed directly to the task
83
87
  (e.g., primitives, strings, dicts). Does not apply to files, directories, or dataframes.
88
+ :param pod_template: Optional The pod template to use for the task.
89
+ :param report: Optional Whether to report the task execution to the Flyte console, defaults to False.
90
+ :param queue: Optional The queue to use for the task. If not provided, the default queue will be used.
91
+ :param debuggable: Optional Whether the task supports debugging capabilities, defaults to False.
84
92
  """
85
93
 
86
94
  name: str
87
95
  interface: NativeInterface
88
- friendly_name: str = ""
96
+ short_name: str = ""
89
97
  task_type: str = "python"
90
98
  task_type_version: int = 0
91
99
  image: Union[str, Image, Literal["auto"]] = "auto"
92
100
  resources: Optional[Resources] = None
93
- cache: CacheRequest = "auto"
94
- interruptable: bool = False
101
+ cache: CacheRequest = "disable"
102
+ interruptible: bool = False
95
103
  retries: Union[int, RetryStrategy] = 0
96
104
  reusable: Union[ReusePolicy, None] = None
97
105
  docs: Optional[Documentation] = None
@@ -100,10 +108,14 @@ class TaskTemplate(Generic[P, R]):
100
108
  timeout: Optional[TimeoutType] = None
101
109
  pod_template: Optional[Union[str, PodTemplate]] = None
102
110
  report: bool = False
111
+ queue: Optional[str] = None
112
+ debuggable: bool = False
103
113
 
104
114
  parent_env: Optional[weakref.ReferenceType[TaskEnvironment]] = None
115
+ parent_env_name: Optional[str] = None
105
116
  ref: bool = field(default=False, init=False, repr=False, compare=False)
106
117
  max_inline_io_bytes: int = MAX_INLINE_IO_BYTES
118
+ triggers: Tuple[Trigger, ...] = field(default_factory=tuple)
107
119
 
108
120
  # Only used in python 3.10 and 3.11, where we cannot use markcoroutinefunction
109
121
  _call_as_synchronous: bool = False
@@ -129,9 +141,9 @@ class TaskTemplate(Generic[P, R]):
129
141
  if isinstance(self.retries, int):
130
142
  self.retries = RetryStrategy(count=self.retries)
131
143
 
132
- if self.friendly_name == "":
133
- # If friendly_name is not set, use the name of the task
134
- self.friendly_name = self.name
144
+ if self.short_name == "":
145
+ # If short_name is not set, use the name of the task
146
+ self.short_name = self.name
135
147
 
136
148
  def __getstate__(self):
137
149
  """
@@ -217,6 +229,14 @@ class TaskTemplate(Generic[P, R]):
217
229
  def native_interface(self) -> NativeInterface:
218
230
  return self.interface
219
231
 
232
+ @overload
233
+ async def aio(self: TaskTemplate[P, R, SyncFunctionType], *args: P.args, **kwargs: P.kwargs) -> R: ...
234
+
235
+ @overload
236
+ async def aio(
237
+ self: TaskTemplate[P, R, AsyncFunctionType], *args: P.args, **kwargs: P.kwargs
238
+ ) -> Coroutine[Any, Any, R]: ...
239
+
220
240
  async def aio(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R] | R:
221
241
  """
222
242
  The aio function allows executing "sync" tasks in an async context. This helps with migrating v1 defined sync
@@ -240,7 +260,6 @@ class TaskTemplate(Generic[P, R]):
240
260
  :param kwargs:
241
261
  :return:
242
262
  """
243
-
244
263
  ctx = internal_ctx()
245
264
  if ctx.is_task_context():
246
265
  from ._internal.controllers import get_controller
@@ -258,10 +277,21 @@ class TaskTemplate(Generic[P, R]):
258
277
  else:
259
278
  raise RuntimeSystemError("BadContext", "Controller is not initialized.")
260
279
  else:
280
+ from flyte._logging import logger
281
+
282
+ logger.warning(f"Task {self.name} running aio outside of a task context.")
261
283
  # Local execute, just stay out of the way, but because .aio is used, we want to return an awaitable,
262
284
  # even for synchronous tasks. This is to support migration.
263
285
  return self.forward(*args, **kwargs)
264
286
 
287
+ @overload
288
+ def __call__(self: TaskTemplate[P, R, SyncFunctionType], *args: P.args, **kwargs: P.kwargs) -> R: ...
289
+
290
+ @overload
291
+ def __call__(
292
+ self: TaskTemplate[P, R, AsyncFunctionType], *args: P.args, **kwargs: P.kwargs
293
+ ) -> Coroutine[Any, Any, R]: ...
294
+
265
295
  def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R] | R:
266
296
  """
267
297
  This is the entrypoint for an async function task at runtime. It will be called during an execution.
@@ -314,7 +344,7 @@ class TaskTemplate(Generic[P, R]):
314
344
  def override(
315
345
  self,
316
346
  *,
317
- friendly_name: Optional[str] = None,
347
+ short_name: Optional[str] = None,
318
348
  resources: Optional[Resources] = None,
319
349
  cache: Optional[CacheRequest] = None,
320
350
  retries: Union[int, RetryStrategy] = 0,
@@ -324,11 +354,30 @@ class TaskTemplate(Generic[P, R]):
324
354
  secrets: Optional[SecretRequest] = None,
325
355
  max_inline_io_bytes: int | None = None,
326
356
  pod_template: Optional[Union[str, PodTemplate]] = None,
357
+ queue: Optional[str] = None,
358
+ interruptible: Optional[bool] = None,
327
359
  **kwargs: Any,
328
360
  ) -> TaskTemplate:
329
361
  """
330
362
  Override various parameters of the task template. This allows for dynamic configuration of the task
331
363
  when it is called, such as changing the image, resources, cache policy, etc.
364
+
365
+ :param short_name: Optional override for the short name of the task.
366
+ :param resources: Optional override for the resources to use for the task.
367
+ :param cache: Optional override for the cache policy for the task.
368
+ :param retries: Optional override for the number of retries for the task.
369
+ :param timeout: Optional override for the timeout for the task.
370
+ :param reusable: Optional override for the reusability policy for the task.
371
+ :param env_vars: Optional override for the environment variables to set for the task.
372
+ :param secrets: Optional override for the secrets that will be injected into the task at runtime.
373
+ :param max_inline_io_bytes: Optional override for the maximum allowed size (in bytes) for all inputs and outputs
374
+ passed directly to the task.
375
+ :param pod_template: Optional override for the pod template to use for the task.
376
+ :param queue: Optional override for the queue to use for the task.
377
+ :param kwargs: Additional keyword arguments for further overrides. Some fields like name, image, docs,
378
+ and interface cannot be overridden.
379
+
380
+ :return: A new TaskTemplate instance with the overridden parameters.
332
381
  """
333
382
  cache = cache or self.cache
334
383
  retries = retries or self.retries
@@ -363,6 +412,8 @@ class TaskTemplate(Generic[P, R]):
363
412
  env_vars = env_vars or self.env_vars
364
413
  secrets = secrets or self.secrets
365
414
 
415
+ interruptible = interruptible if interruptible is not None else self.interruptible
416
+
366
417
  for k, v in kwargs.items():
367
418
  if k == "name":
368
419
  raise ValueError("Name cannot be overridden")
@@ -375,7 +426,7 @@ class TaskTemplate(Generic[P, R]):
375
426
 
376
427
  return replace(
377
428
  self,
378
- friendly_name=friendly_name or self.friendly_name,
429
+ short_name=short_name or self.short_name,
379
430
  resources=resources,
380
431
  cache=cache,
381
432
  retries=retries,
@@ -385,19 +436,22 @@ class TaskTemplate(Generic[P, R]):
385
436
  secrets=secrets,
386
437
  max_inline_io_bytes=max_inline_io_bytes,
387
438
  pod_template=pod_template,
439
+ interruptible=interruptible,
440
+ queue=queue or self.queue,
388
441
  **kwargs,
389
442
  )
390
443
 
391
444
 
392
445
  @dataclass(kw_only=True)
393
- class AsyncFunctionTaskTemplate(TaskTemplate[P, R]):
446
+ class AsyncFunctionTaskTemplate(TaskTemplate[P, R, F]):
394
447
  """
395
448
  A task template that wraps an asynchronous function. This is automatically created when an asynchronous function
396
449
  is decorated with the task decorator.
397
450
  """
398
451
 
399
- func: FunctionTypes
452
+ func: F
400
453
  plugin_config: Optional[Any] = None # This is used to pass plugin specific configuration
454
+ debuggable: bool = True
401
455
 
402
456
  def __post_init__(self):
403
457
  super().__post_init__()
@@ -476,6 +530,11 @@ class AsyncFunctionTaskTemplate(TaskTemplate[P, R]):
476
530
 
477
531
  from flyte._internal.resolvers.default import DefaultTaskResolver
478
532
 
533
+ if not serialize_context.root_dir:
534
+ raise RuntimeSystemError(
535
+ "SerializationError",
536
+ "Root dir is required for default task resolver when no code bundle is provided.",
537
+ )
479
538
  _task_resolver = DefaultTaskResolver()
480
539
  args = [
481
540
  *args,