flyte-2.0.0b22-py3-none-any.whl → flyte-2.0.0b30-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/runtime.py +43 -5
  3. flyte/_cache/cache.py +4 -2
  4. flyte/_cache/local_cache.py +216 -0
  5. flyte/_code_bundle/_ignore.py +1 -1
  6. flyte/_code_bundle/_packaging.py +4 -4
  7. flyte/_code_bundle/_utils.py +14 -8
  8. flyte/_code_bundle/bundle.py +13 -5
  9. flyte/_constants.py +1 -0
  10. flyte/_context.py +4 -1
  11. flyte/_custom_context.py +73 -0
  12. flyte/_debug/constants.py +0 -1
  13. flyte/_debug/vscode.py +6 -1
  14. flyte/_deploy.py +223 -59
  15. flyte/_environment.py +5 -0
  16. flyte/_excepthook.py +1 -1
  17. flyte/_image.py +144 -82
  18. flyte/_initialize.py +95 -12
  19. flyte/_interface.py +2 -0
  20. flyte/_internal/controllers/_local_controller.py +65 -24
  21. flyte/_internal/controllers/_trace.py +1 -1
  22. flyte/_internal/controllers/remote/_action.py +13 -11
  23. flyte/_internal/controllers/remote/_client.py +1 -1
  24. flyte/_internal/controllers/remote/_controller.py +9 -4
  25. flyte/_internal/controllers/remote/_core.py +16 -16
  26. flyte/_internal/controllers/remote/_informer.py +4 -4
  27. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  28. flyte/_internal/imagebuild/docker_builder.py +139 -84
  29. flyte/_internal/imagebuild/image_builder.py +7 -13
  30. flyte/_internal/imagebuild/remote_builder.py +65 -13
  31. flyte/_internal/imagebuild/utils.py +51 -3
  32. flyte/_internal/resolvers/_task_module.py +5 -38
  33. flyte/_internal/resolvers/default.py +2 -2
  34. flyte/_internal/runtime/convert.py +42 -20
  35. flyte/_internal/runtime/entrypoints.py +24 -1
  36. flyte/_internal/runtime/io.py +21 -8
  37. flyte/_internal/runtime/resources_serde.py +20 -6
  38. flyte/_internal/runtime/reuse.py +1 -1
  39. flyte/_internal/runtime/rusty.py +20 -5
  40. flyte/_internal/runtime/task_serde.py +33 -27
  41. flyte/_internal/runtime/taskrunner.py +10 -1
  42. flyte/_internal/runtime/trigger_serde.py +160 -0
  43. flyte/_internal/runtime/types_serde.py +1 -1
  44. flyte/_keyring/file.py +39 -9
  45. flyte/_logging.py +79 -12
  46. flyte/_map.py +31 -12
  47. flyte/_module.py +70 -0
  48. flyte/_pod.py +2 -2
  49. flyte/_resources.py +213 -31
  50. flyte/_run.py +107 -41
  51. flyte/_task.py +66 -10
  52. flyte/_task_environment.py +96 -24
  53. flyte/_task_plugins.py +4 -2
  54. flyte/_trigger.py +1000 -0
  55. flyte/_utils/__init__.py +2 -1
  56. flyte/_utils/asyn.py +3 -1
  57. flyte/_utils/docker_credentials.py +173 -0
  58. flyte/_utils/module_loader.py +17 -2
  59. flyte/_version.py +3 -3
  60. flyte/cli/_abort.py +3 -3
  61. flyte/cli/_build.py +1 -3
  62. flyte/cli/_common.py +78 -7
  63. flyte/cli/_create.py +178 -3
  64. flyte/cli/_delete.py +23 -1
  65. flyte/cli/_deploy.py +49 -11
  66. flyte/cli/_get.py +79 -34
  67. flyte/cli/_params.py +8 -6
  68. flyte/cli/_plugins.py +209 -0
  69. flyte/cli/_run.py +127 -11
  70. flyte/cli/_serve.py +64 -0
  71. flyte/cli/_update.py +37 -0
  72. flyte/cli/_user.py +17 -0
  73. flyte/cli/main.py +30 -4
  74. flyte/config/_config.py +2 -0
  75. flyte/config/_internal.py +1 -0
  76. flyte/config/_reader.py +3 -3
  77. flyte/connectors/__init__.py +11 -0
  78. flyte/connectors/_connector.py +270 -0
  79. flyte/connectors/_server.py +197 -0
  80. flyte/connectors/utils.py +135 -0
  81. flyte/errors.py +10 -1
  82. flyte/extend.py +8 -1
  83. flyte/extras/_container.py +6 -1
  84. flyte/git/_config.py +11 -9
  85. flyte/io/__init__.py +2 -0
  86. flyte/io/_dataframe/__init__.py +2 -0
  87. flyte/io/_dataframe/basic_dfs.py +1 -1
  88. flyte/io/_dataframe/dataframe.py +12 -8
  89. flyte/io/_dir.py +551 -120
  90. flyte/io/_file.py +538 -141
  91. flyte/models.py +57 -12
  92. flyte/remote/__init__.py +6 -1
  93. flyte/remote/_action.py +18 -16
  94. flyte/remote/_client/_protocols.py +39 -4
  95. flyte/remote/_client/auth/_channel.py +10 -6
  96. flyte/remote/_client/controlplane.py +17 -5
  97. flyte/remote/_console.py +3 -2
  98. flyte/remote/_data.py +4 -3
  99. flyte/remote/_logs.py +3 -3
  100. flyte/remote/_run.py +47 -7
  101. flyte/remote/_secret.py +26 -17
  102. flyte/remote/_task.py +21 -9
  103. flyte/remote/_trigger.py +306 -0
  104. flyte/remote/_user.py +33 -0
  105. flyte/storage/__init__.py +6 -1
  106. flyte/storage/_parallel_reader.py +274 -0
  107. flyte/storage/_storage.py +185 -103
  108. flyte/types/__init__.py +16 -0
  109. flyte/types/_interface.py +2 -2
  110. flyte/types/_pickle.py +17 -4
  111. flyte/types/_string_literals.py +8 -9
  112. flyte/types/_type_engine.py +26 -19
  113. flyte/types/_utils.py +1 -1
  114. {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/runtime.py +43 -5
  115. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/METADATA +8 -1
  116. flyte-2.0.0b30.dist-info/RECORD +192 -0
  117. flyte/_protos/__init__.py +0 -0
  118. flyte/_protos/common/authorization_pb2.py +0 -66
  119. flyte/_protos/common/authorization_pb2.pyi +0 -108
  120. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  121. flyte/_protos/common/identifier_pb2.py +0 -99
  122. flyte/_protos/common/identifier_pb2.pyi +0 -120
  123. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  124. flyte/_protos/common/identity_pb2.py +0 -48
  125. flyte/_protos/common/identity_pb2.pyi +0 -72
  126. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  127. flyte/_protos/common/list_pb2.py +0 -36
  128. flyte/_protos/common/list_pb2.pyi +0 -71
  129. flyte/_protos/common/list_pb2_grpc.py +0 -4
  130. flyte/_protos/common/policy_pb2.py +0 -37
  131. flyte/_protos/common/policy_pb2.pyi +0 -27
  132. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  133. flyte/_protos/common/role_pb2.py +0 -37
  134. flyte/_protos/common/role_pb2.pyi +0 -53
  135. flyte/_protos/common/role_pb2_grpc.py +0 -4
  136. flyte/_protos/common/runtime_version_pb2.py +0 -28
  137. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  138. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  139. flyte/_protos/imagebuilder/definition_pb2.py +0 -60
  140. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -153
  141. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  142. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  143. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  144. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  145. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  146. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  147. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  148. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  149. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  150. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  151. flyte/_protos/secret/definition_pb2.py +0 -49
  152. flyte/_protos/secret/definition_pb2.pyi +0 -93
  153. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  154. flyte/_protos/secret/payload_pb2.py +0 -62
  155. flyte/_protos/secret/payload_pb2.pyi +0 -94
  156. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  157. flyte/_protos/secret/secret_pb2.py +0 -38
  158. flyte/_protos/secret/secret_pb2.pyi +0 -6
  159. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  160. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  161. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  162. flyte/_protos/workflow/common_pb2.py +0 -27
  163. flyte/_protos/workflow/common_pb2.pyi +0 -14
  164. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  165. flyte/_protos/workflow/environment_pb2.py +0 -29
  166. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  167. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  168. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  169. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  170. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  171. flyte/_protos/workflow/queue_service_pb2.py +0 -111
  172. flyte/_protos/workflow/queue_service_pb2.pyi +0 -168
  173. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  174. flyte/_protos/workflow/run_definition_pb2.py +0 -123
  175. flyte/_protos/workflow/run_definition_pb2.pyi +0 -352
  176. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  177. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  178. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  179. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  180. flyte/_protos/workflow/run_service_pb2.py +0 -137
  181. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  182. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  183. flyte/_protos/workflow/state_service_pb2.py +0 -67
  184. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  185. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  186. flyte/_protos/workflow/task_definition_pb2.py +0 -82
  187. flyte/_protos/workflow/task_definition_pb2.pyi +0 -88
  188. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  189. flyte/_protos/workflow/task_service_pb2.py +0 -60
  190. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  191. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  192. flyte-2.0.0b22.dist-info/RECORD +0 -250
  193. {flyte-2.0.0b22.data → flyte-2.0.0b30.data}/scripts/debug.py +0 -0
  194. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  195. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +0 -0
  196. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  197. {flyte-2.0.0b22.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/_task.py CHANGED
@@ -15,10 +15,12 @@ from typing import (
15
15
  Literal,
16
16
  Optional,
17
17
  ParamSpec,
18
+ Tuple,
18
19
  TypeAlias,
19
20
  TypeVar,
20
21
  Union,
21
22
  cast,
23
+ overload,
22
24
  )
23
25
 
24
26
  from flyte._pod import PodTemplate
@@ -33,10 +35,11 @@ from ._retry import RetryStrategy
33
35
  from ._reusable_environment import ReusePolicy
34
36
  from ._secret import SecretRequest
35
37
  from ._timeout import TimeoutType
38
+ from ._trigger import Trigger
36
39
  from .models import MAX_INLINE_IO_BYTES, NativeInterface, SerializationContext
37
40
 
38
41
  if TYPE_CHECKING:
39
- from flyteidl.core.tasks_pb2 import DataLoadingConfig
42
+ from flyteidl2.core.tasks_pb2 import DataLoadingConfig
40
43
 
41
44
  from ._task_environment import TaskEnvironment
42
45
 
@@ -45,11 +48,12 @@ R = TypeVar("R") # return type
45
48
 
46
49
  AsyncFunctionType: TypeAlias = Callable[P, Coroutine[Any, Any, R]]
47
50
  SyncFunctionType: TypeAlias = Callable[P, R]
48
- FunctionTypes: TypeAlias = Union[AsyncFunctionType, SyncFunctionType]
51
+ FunctionTypes: TypeAlias = AsyncFunctionType | SyncFunctionType
52
+ F = TypeVar("F", bound=FunctionTypes)
49
53
 
50
54
 
51
55
  @dataclass(kw_only=True)
52
- class TaskTemplate(Generic[P, R]):
56
+ class TaskTemplate(Generic[P, R, F]):
53
57
  """
54
58
  Task template is a template for a task that can be executed. It defines various parameters for the task, which
55
59
  can be defined statically at the time of task definition or dynamically at the time of task invocation using
@@ -69,8 +73,8 @@ class TaskTemplate(Generic[P, R]):
69
73
  version with flyte installed
70
74
  :param resources: Optional The resources to use for the task
71
75
  :param cache: Optional The cache policy for the task, defaults to auto, which will cache the results of the task.
72
- :param interruptable: Optional The interruptable policy for the task, defaults to False, which means the task
73
- will not be scheduled on interruptable nodes. If set to True, the task will be scheduled on interruptable nodes,
76
+ :param interruptible: Optional The interruptible policy for the task, defaults to False, which means the task
77
+ will not be scheduled on interruptible nodes. If set to True, the task will be scheduled on interruptible nodes,
74
78
  and the code should handle interruptions and resumptions.
75
79
  :param retries: Optional The number of retries for the task, defaults to 0, which means no retries.
76
80
  :param reusable: Optional The reusability policy for the task, defaults to None, which means the task environment
@@ -81,6 +85,10 @@ class TaskTemplate(Generic[P, R]):
81
85
  :param timeout: Optional The timeout for the task.
82
86
  :param max_inline_io_bytes: Maximum allowed size (in bytes) for all inputs and outputs passed directly to the task
83
87
  (e.g., primitives, strings, dicts). Does not apply to files, directories, or dataframes.
88
+ :param pod_template: Optional The pod template to use for the task.
89
+ :param report: Optional Whether to report the task execution to the Flyte console, defaults to False.
90
+ :param queue: Optional The queue to use for the task. If not provided, the default queue will be used.
91
+ :param debuggable: Optional Whether the task supports debugging capabilities, defaults to False.
84
92
  """
85
93
 
86
94
  name: str
@@ -90,8 +98,8 @@ class TaskTemplate(Generic[P, R]):
90
98
  task_type_version: int = 0
91
99
  image: Union[str, Image, Literal["auto"]] = "auto"
92
100
  resources: Optional[Resources] = None
93
- cache: CacheRequest = "auto"
94
- interruptable: bool = False
101
+ cache: CacheRequest = "disable"
102
+ interruptible: bool = False
95
103
  retries: Union[int, RetryStrategy] = 0
96
104
  reusable: Union[ReusePolicy, None] = None
97
105
  docs: Optional[Documentation] = None
@@ -100,10 +108,14 @@ class TaskTemplate(Generic[P, R]):
100
108
  timeout: Optional[TimeoutType] = None
101
109
  pod_template: Optional[Union[str, PodTemplate]] = None
102
110
  report: bool = False
111
+ queue: Optional[str] = None
112
+ debuggable: bool = False
103
113
 
104
114
  parent_env: Optional[weakref.ReferenceType[TaskEnvironment]] = None
115
+ parent_env_name: Optional[str] = None
105
116
  ref: bool = field(default=False, init=False, repr=False, compare=False)
106
117
  max_inline_io_bytes: int = MAX_INLINE_IO_BYTES
118
+ triggers: Tuple[Trigger, ...] = field(default_factory=tuple)
107
119
 
108
120
  # Only used in python 3.10 and 3.11, where we cannot use markcoroutinefunction
109
121
  _call_as_synchronous: bool = False
@@ -217,6 +229,14 @@ class TaskTemplate(Generic[P, R]):
217
229
  def native_interface(self) -> NativeInterface:
218
230
  return self.interface
219
231
 
232
+ @overload
233
+ async def aio(self: TaskTemplate[P, R, SyncFunctionType], *args: P.args, **kwargs: P.kwargs) -> R: ...
234
+
235
+ @overload
236
+ async def aio(
237
+ self: TaskTemplate[P, R, AsyncFunctionType], *args: P.args, **kwargs: P.kwargs
238
+ ) -> Coroutine[Any, Any, R]: ...
239
+
220
240
  async def aio(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R] | R:
221
241
  """
222
242
  The aio function allows executing "sync" tasks, in an async context. This helps with migrating v1 defined sync
@@ -240,7 +260,6 @@ class TaskTemplate(Generic[P, R]):
240
260
  :param kwargs:
241
261
  :return:
242
262
  """
243
-
244
263
  ctx = internal_ctx()
245
264
  if ctx.is_task_context():
246
265
  from ._internal.controllers import get_controller
@@ -265,6 +284,14 @@ class TaskTemplate(Generic[P, R]):
265
284
  # even for synchronous tasks. This is to support migration.
266
285
  return self.forward(*args, **kwargs)
267
286
 
287
+ @overload
288
+ def __call__(self: TaskTemplate[P, R, SyncFunctionType], *args: P.args, **kwargs: P.kwargs) -> R: ...
289
+
290
+ @overload
291
+ def __call__(
292
+ self: TaskTemplate[P, R, AsyncFunctionType], *args: P.args, **kwargs: P.kwargs
293
+ ) -> Coroutine[Any, Any, R]: ...
294
+
268
295
  def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R] | R:
269
296
  """
270
297
  This is the entrypoint for an async function task at runtime. It will be called during an execution.
@@ -327,11 +354,30 @@ class TaskTemplate(Generic[P, R]):
327
354
  secrets: Optional[SecretRequest] = None,
328
355
  max_inline_io_bytes: int | None = None,
329
356
  pod_template: Optional[Union[str, PodTemplate]] = None,
357
+ queue: Optional[str] = None,
358
+ interruptible: Optional[bool] = None,
330
359
  **kwargs: Any,
331
360
  ) -> TaskTemplate:
332
361
  """
333
362
  Override various parameters of the task template. This allows for dynamic configuration of the task
334
363
  when it is called, such as changing the image, resources, cache policy, etc.
364
+
365
+ :param short_name: Optional override for the short name of the task.
366
+ :param resources: Optional override for the resources to use for the task.
367
+ :param cache: Optional override for the cache policy for the task.
368
+ :param retries: Optional override for the number of retries for the task.
369
+ :param timeout: Optional override for the timeout for the task.
370
+ :param reusable: Optional override for the reusability policy for the task.
371
+ :param env_vars: Optional override for the environment variables to set for the task.
372
+ :param secrets: Optional override for the secrets that will be injected into the task at runtime.
373
+ :param max_inline_io_bytes: Optional override for the maximum allowed size (in bytes) for all inputs and outputs
374
+ passed directly to the task.
375
+ :param pod_template: Optional override for the pod template to use for the task.
376
+ :param queue: Optional override for the queue to use for the task.
377
+ :param kwargs: Additional keyword arguments for further overrides. Some fields like name, image, docs,
378
+ and interface cannot be overridden.
379
+
380
+ :return: A new TaskTemplate instance with the overridden parameters.
335
381
  """
336
382
  cache = cache or self.cache
337
383
  retries = retries or self.retries
@@ -366,6 +412,8 @@ class TaskTemplate(Generic[P, R]):
366
412
  env_vars = env_vars or self.env_vars
367
413
  secrets = secrets or self.secrets
368
414
 
415
+ interruptible = interruptible if interruptible is not None else self.interruptible
416
+
369
417
  for k, v in kwargs.items():
370
418
  if k == "name":
371
419
  raise ValueError("Name cannot be overridden")
@@ -388,19 +436,22 @@ class TaskTemplate(Generic[P, R]):
388
436
  secrets=secrets,
389
437
  max_inline_io_bytes=max_inline_io_bytes,
390
438
  pod_template=pod_template,
439
+ interruptible=interruptible,
440
+ queue=queue or self.queue,
391
441
  **kwargs,
392
442
  )
393
443
 
394
444
 
395
445
  @dataclass(kw_only=True)
396
- class AsyncFunctionTaskTemplate(TaskTemplate[P, R]):
446
+ class AsyncFunctionTaskTemplate(TaskTemplate[P, R, F]):
397
447
  """
398
448
  A task template that wraps an asynchronous functions. This is automatically created when an asynchronous function
399
449
  is decorated with the task decorator.
400
450
  """
401
451
 
402
- func: FunctionTypes
452
+ func: F
403
453
  plugin_config: Optional[Any] = None # This is used to pass plugin specific configuration
454
+ debuggable: bool = True
404
455
 
405
456
  def __post_init__(self):
406
457
  super().__post_init__()
@@ -479,6 +530,11 @@ class AsyncFunctionTaskTemplate(TaskTemplate[P, R]):
479
530
 
480
531
  from flyte._internal.resolvers.default import DefaultTaskResolver
481
532
 
533
+ if not serialize_context.root_dir:
534
+ raise RuntimeSystemError(
535
+ "SerializationError",
536
+ "Root dir is required for default task resolver when no code bundle is provided.",
537
+ )
482
538
  _task_resolver = DefaultTaskResolver()
483
539
  args = [
484
540
  *args,
@@ -12,8 +12,10 @@ from typing import (
12
12
  List,
13
13
  Literal,
14
14
  Optional,
15
+ Tuple,
15
16
  Union,
16
17
  cast,
18
+ overload,
17
19
  )
18
20
 
19
21
  import rich.repr
@@ -22,17 +24,17 @@ from ._cache import Cache, CacheRequest
22
24
  from ._doc import Documentation
23
25
  from ._environment import Environment
24
26
  from ._image import Image
27
+ from ._pod import PodTemplate
25
28
  from ._resources import Resources
26
29
  from ._retry import RetryStrategy
27
30
  from ._reusable_environment import ReusePolicy
28
31
  from ._secret import SecretRequest
29
32
  from ._task import AsyncFunctionTaskTemplate, TaskTemplate
33
+ from ._trigger import Trigger
30
34
  from .models import MAX_INLINE_IO_BYTES, NativeInterface
31
35
 
32
36
  if TYPE_CHECKING:
33
- from kubernetes.client import V1PodTemplate
34
-
35
- from ._task import FunctionTypes, P, R
37
+ from ._task import F, P, R
36
38
 
37
39
 
38
40
  @rich.repr.auto
@@ -60,13 +62,18 @@ class TaskEnvironment(Environment):
60
62
  that depend on each other.
61
63
  :param cache: Cache policy for the environment.
62
64
  :param reusable: Reuse policy for the environment, if set, a python process may be reused for multiple tasks.
65
+ :param plugin_config: Optional plugin configuration for custom task types.
66
+ If set, all tasks in this environment will use the specified plugin configuration.
67
+ :param queue: Optional queue name to use for tasks in this environment.
68
+ If not set, the default queue will be used.
69
+ :param pod_template: Optional pod template to use for tasks in this environment.
70
+ If not set, the default pod template will be used.
63
71
  """
64
72
 
65
73
  cache: CacheRequest = "disable"
66
74
  reusable: ReusePolicy | None = None
67
75
  plugin_config: Optional[Any] = None
68
- # TODO Shall we make this union of string or env? This way we can lookup the env by module/file:name
69
- # TODO also we could add list of files that are used by this environment
76
+ queue: Optional[str] = None
70
77
 
71
78
  _tasks: Dict[str, TaskTemplate] = field(default_factory=dict, init=False)
72
79
 
@@ -87,6 +94,8 @@ class TaskEnvironment(Environment):
87
94
  env_vars: Optional[Dict[str, str]] = None,
88
95
  secrets: Optional[SecretRequest] = None,
89
96
  depends_on: Optional[List[Environment]] = None,
97
+ description: Optional[str] = None,
98
+ interruptible: Optional[bool] = None,
90
99
  **kwargs: Any,
91
100
  ) -> TaskEnvironment:
92
101
  """
@@ -102,6 +111,11 @@ class TaskEnvironment(Environment):
102
111
  :param depends_on: The environment dependencies to hint, so when you deploy the environment,
103
112
  the dependencies are also deployed. This is useful when you have a set of environments
104
113
  that depend on each other.
114
+ :param queue: The queue name to use for tasks in this environment.
115
+ :param pod_template: The pod template to use for tasks in this environment.
116
+ :param description: The description of the environment.
117
+ :param interruptible: Whether the environment is interruptible and can be scheduled on spot/preemptible
118
+ instances.
105
119
  :param kwargs: Additional parameters to override the environment (e.g., cache, reusable, plugin_config).
106
120
  """
107
121
  cache = kwargs.pop("cache", None)
@@ -131,21 +145,52 @@ class TaskEnvironment(Environment):
131
145
  kwargs["secrets"] = secrets
132
146
  if depends_on is not None:
133
147
  kwargs["depends_on"] = depends_on
148
+ if description is not None:
149
+ kwargs["description"] = description
150
+ if interruptible is not None:
151
+ kwargs["interruptible"] = interruptible
134
152
  return replace(self, **kwargs)
135
153
 
154
+ @overload
136
155
  def task(
137
156
  self,
138
- _func: Callable[P, R] | None = None,
139
157
  *,
140
158
  short_name: Optional[str] = None,
141
159
  cache: CacheRequest | None = None,
142
160
  retries: Union[int, RetryStrategy] = 0,
143
161
  timeout: Union[timedelta, int] = 0,
144
162
  docs: Optional[Documentation] = None,
145
- pod_template: Optional[Union[str, "V1PodTemplate"]] = None,
163
+ pod_template: Optional[Union[str, PodTemplate]] = None,
146
164
  report: bool = False,
165
+ interruptible: bool | None = None,
147
166
  max_inline_io_bytes: int = MAX_INLINE_IO_BYTES,
148
- ) -> Union[AsyncFunctionTaskTemplate, Callable[P, R]]:
167
+ queue: Optional[str] = None,
168
+ triggers: Tuple[Trigger, ...] | Trigger = (),
169
+ ) -> Callable[[Callable[P, R]], AsyncFunctionTaskTemplate[P, R, Callable[P, R]]]: ...
170
+
171
+ @overload
172
+ def task(
173
+ self,
174
+ _func: Callable[P, R],
175
+ /,
176
+ ) -> AsyncFunctionTaskTemplate[P, R, Callable[P, R]]: ...
177
+
178
+ def task(
179
+ self,
180
+ _func: F | None = None,
181
+ *,
182
+ short_name: Optional[str] = None,
183
+ cache: CacheRequest | None = None,
184
+ retries: Union[int, RetryStrategy] = 0,
185
+ timeout: Union[timedelta, int] = 0,
186
+ docs: Optional[Documentation] = None,
187
+ pod_template: Optional[Union[str, PodTemplate]] = None,
188
+ report: bool = False,
189
+ interruptible: bool | None = None,
190
+ max_inline_io_bytes: int = MAX_INLINE_IO_BYTES,
191
+ queue: Optional[str] = None,
192
+ triggers: Tuple[Trigger, ...] | Trigger = (),
193
+ ) -> Callable[[F], AsyncFunctionTaskTemplate[P, R, F]] | AsyncFunctionTaskTemplate[P, R, F]:
149
194
  """
150
195
  Decorate a function to be a task.
151
196
 
@@ -162,14 +207,20 @@ class TaskEnvironment(Environment):
162
207
  :param report: Optional Whether to generate the html report for the task, defaults to False.
163
208
  :param max_inline_io_bytes: Maximum allowed size (in bytes) for all inputs and outputs passed directly to the
164
209
  task (e.g., primitives, strings, dicts). Does not apply to files, directories, or dataframes.
210
+ :param triggers: Optional A tuple of triggers to associate with the task. This allows the task to be run on a
211
+ schedule or in response to events. Triggers can be defined using the `flyte.trigger` module.
212
+ :param interruptible: Optional Whether the task is interruptible, defaults to environment setting.
213
+ :param queue: Optional queue name to use for this task. If not set, the environment's queue will be used.
214
+
215
+ :return: A TaskTemplate that can be used to deploy the task.
165
216
  """
166
- from ._task import P, R
217
+ from ._task import F, P, R
167
218
 
168
219
  if self.reusable is not None:
169
220
  if pod_template is not None:
170
221
  raise ValueError("Cannot set pod_template when environment is reusable.")
171
222
 
172
- def decorator(func: FunctionTypes) -> AsyncFunctionTaskTemplate[P, R]:
223
+ def decorator(func: F) -> AsyncFunctionTaskTemplate[P, R, F]:
173
224
  short = short_name or func.__name__
174
225
  task_name = self.name + "." + func.__name__
175
226
 
@@ -183,7 +234,7 @@ class TaskEnvironment(Environment):
183
234
  if self.plugin_config is not None:
184
235
  from flyte.extend import TaskPluginRegistry
185
236
 
186
- task_template_class: type[AsyncFunctionTaskTemplate[P, R]] | None = TaskPluginRegistry.find(
237
+ task_template_class: type[AsyncFunctionTaskTemplate[P, R, F]] | None = TaskPluginRegistry.find(
187
238
  config_type=type(self.plugin_config)
188
239
  )
189
240
  if task_template_class is None:
@@ -192,9 +243,9 @@ class TaskEnvironment(Environment):
192
243
  f"Please register a plugin using flyte.extend.TaskPluginRegistry.register() api."
193
244
  )
194
245
  else:
195
- task_template_class = AsyncFunctionTaskTemplate[P, R]
246
+ task_template_class = AsyncFunctionTaskTemplate[P, R, F]
196
247
 
197
- task_template_class = cast(type[AsyncFunctionTaskTemplate[P, R]], task_template_class)
248
+ task_template_class = cast(type[AsyncFunctionTaskTemplate[P, R, F]], task_template_class)
198
249
  tmpl = task_template_class(
199
250
  func=func,
200
251
  name=task_name,
@@ -209,18 +260,22 @@ class TaskEnvironment(Environment):
209
260
  secrets=self.secrets,
210
261
  pod_template=pod_template or self.pod_template,
211
262
  parent_env=weakref.ref(self),
263
+ parent_env_name=self.name,
212
264
  interface=NativeInterface.from_callable(func),
213
265
  report=report,
214
266
  short_name=short,
215
267
  plugin_config=self.plugin_config,
216
268
  max_inline_io_bytes=max_inline_io_bytes,
269
+ queue=queue or self.queue,
270
+ interruptible=interruptible if interruptible is not None else self.interruptible,
271
+ triggers=triggers if isinstance(triggers, tuple) else (triggers,),
217
272
  )
218
273
  self._tasks[task_name] = tmpl
219
274
  return tmpl
220
275
 
221
276
  if _func is None:
222
- return cast(AsyncFunctionTaskTemplate, decorator)
223
- return cast(AsyncFunctionTaskTemplate, decorator(_func))
277
+ return cast(Callable[[F], AsyncFunctionTaskTemplate[P, R, F]], decorator)
278
+ return cast(AsyncFunctionTaskTemplate[P, R, F], decorator(_func))
224
279
 
225
280
  @property
226
281
  def tasks(self) -> Dict[str, TaskTemplate]:
@@ -229,16 +284,33 @@ class TaskEnvironment(Environment):
229
284
  """
230
285
  return self._tasks
231
286
 
232
- def add_task(self, task: TaskTemplate) -> TaskTemplate:
287
+ @classmethod
288
+ def from_task(cls, name: str, *tasks: TaskTemplate) -> TaskEnvironment:
233
289
  """
234
- Add a task to the environment.
290
+ Create a TaskEnvironment from a list of tasks. All tasks should have the same image or no Image defined.
291
+ Similarity of Image is determined by the python reference, not by value.
292
+
293
+ If images are different, an error is raised. If no image is defined, the image is set to "auto".
235
294
 
236
- Useful when you want to add a task to an environment that is not defined using the `task` decorator.
295
+ For any other tasks that need to be use these tasks, the returned environment can be used in the `depends_on`
296
+ attribute of the other TaskEnvironment.
297
+
298
+ :param name: The name of the environment.
299
+ :param tasks: The list of tasks to create the environment from.
237
300
 
238
- :param task: The TaskTemplate to add to this environment.
301
+ :raises ValueError: If tasks are assigned to multiple environments or have different images.
302
+ :return: The created TaskEnvironment.
239
303
  """
240
- if task.name in self._tasks:
241
- raise ValueError(f"Task {task.name} already exists in the environment. Task names should be unique.")
242
- self._tasks[task.name] = task
243
- task.parent_env = weakref.ref(self)
244
- return task
304
+ envs = [t.parent_env() for t in tasks if t.parent_env and t.parent_env() is not None]
305
+ if envs:
306
+ raise ValueError("Tasks cannot assigned to multiple environments.")
307
+ images = {t.image for t in tasks}
308
+ if len(images) > 1:
309
+ raise ValueError("Tasks must have the same image to be in the same environment.")
310
+ image: Union[str, Image] = images.pop() if images else "auto"
311
+ env = cls(name, image=image)
312
+ for t in tasks:
313
+ env._tasks[t.name] = t
314
+ t.parent_env = weakref.ref(env)
315
+ t.parent_env_name = name
316
+ return env
flyte/_task_plugins.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import typing
2
4
  from typing import Type
3
5
 
@@ -14,8 +16,8 @@ class _Registry:
14
16
  A registry for task plugins.
15
17
  """
16
18
 
17
- def __init__(self):
18
- self._plugins: typing.Dict[Type, Type[T]] = {}
19
+ def __init__(self: _Registry):
20
+ self._plugins: typing.Dict[Type, Type[typing.Any]] = {}
19
21
 
20
22
  def register(self, config_type: Type, plugin: Type[T]):
21
23
  """