flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/debug.py +38 -0
  3. flyte/_bin/runtime.py +62 -8
  4. flyte/_cache/cache.py +4 -2
  5. flyte/_cache/local_cache.py +216 -0
  6. flyte/_code_bundle/_ignore.py +12 -4
  7. flyte/_code_bundle/_packaging.py +13 -9
  8. flyte/_code_bundle/_utils.py +18 -10
  9. flyte/_code_bundle/bundle.py +17 -9
  10. flyte/_constants.py +1 -0
  11. flyte/_context.py +4 -1
  12. flyte/_custom_context.py +73 -0
  13. flyte/_debug/constants.py +38 -0
  14. flyte/_debug/utils.py +17 -0
  15. flyte/_debug/vscode.py +307 -0
  16. flyte/_deploy.py +235 -61
  17. flyte/_environment.py +20 -6
  18. flyte/_excepthook.py +1 -1
  19. flyte/_hash.py +1 -16
  20. flyte/_image.py +178 -81
  21. flyte/_initialize.py +132 -51
  22. flyte/_interface.py +39 -2
  23. flyte/_internal/controllers/__init__.py +4 -5
  24. flyte/_internal/controllers/_local_controller.py +70 -29
  25. flyte/_internal/controllers/_trace.py +1 -1
  26. flyte/_internal/controllers/remote/__init__.py +0 -2
  27. flyte/_internal/controllers/remote/_action.py +14 -16
  28. flyte/_internal/controllers/remote/_client.py +1 -1
  29. flyte/_internal/controllers/remote/_controller.py +68 -70
  30. flyte/_internal/controllers/remote/_core.py +127 -99
  31. flyte/_internal/controllers/remote/_informer.py +19 -10
  32. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  33. flyte/_internal/imagebuild/docker_builder.py +181 -69
  34. flyte/_internal/imagebuild/image_builder.py +0 -5
  35. flyte/_internal/imagebuild/remote_builder.py +155 -64
  36. flyte/_internal/imagebuild/utils.py +51 -2
  37. flyte/_internal/resolvers/_task_module.py +5 -38
  38. flyte/_internal/resolvers/default.py +2 -2
  39. flyte/_internal/runtime/convert.py +110 -21
  40. flyte/_internal/runtime/entrypoints.py +27 -1
  41. flyte/_internal/runtime/io.py +21 -8
  42. flyte/_internal/runtime/resources_serde.py +20 -6
  43. flyte/_internal/runtime/reuse.py +1 -1
  44. flyte/_internal/runtime/rusty.py +20 -5
  45. flyte/_internal/runtime/task_serde.py +34 -19
  46. flyte/_internal/runtime/taskrunner.py +22 -4
  47. flyte/_internal/runtime/trigger_serde.py +160 -0
  48. flyte/_internal/runtime/types_serde.py +1 -1
  49. flyte/_keyring/__init__.py +0 -0
  50. flyte/_keyring/file.py +115 -0
  51. flyte/_logging.py +201 -39
  52. flyte/_map.py +111 -14
  53. flyte/_module.py +70 -0
  54. flyte/_pod.py +4 -3
  55. flyte/_resources.py +213 -31
  56. flyte/_run.py +110 -39
  57. flyte/_task.py +75 -16
  58. flyte/_task_environment.py +105 -29
  59. flyte/_task_plugins.py +4 -2
  60. flyte/_trace.py +5 -0
  61. flyte/_trigger.py +1000 -0
  62. flyte/_utils/__init__.py +2 -1
  63. flyte/_utils/asyn.py +3 -1
  64. flyte/_utils/coro_management.py +2 -1
  65. flyte/_utils/docker_credentials.py +173 -0
  66. flyte/_utils/module_loader.py +17 -2
  67. flyte/_version.py +3 -3
  68. flyte/cli/_abort.py +3 -3
  69. flyte/cli/_build.py +3 -6
  70. flyte/cli/_common.py +78 -7
  71. flyte/cli/_create.py +182 -4
  72. flyte/cli/_delete.py +23 -1
  73. flyte/cli/_deploy.py +63 -16
  74. flyte/cli/_get.py +79 -34
  75. flyte/cli/_params.py +26 -10
  76. flyte/cli/_plugins.py +209 -0
  77. flyte/cli/_run.py +151 -26
  78. flyte/cli/_serve.py +64 -0
  79. flyte/cli/_update.py +37 -0
  80. flyte/cli/_user.py +17 -0
  81. flyte/cli/main.py +30 -4
  82. flyte/config/_config.py +10 -6
  83. flyte/config/_internal.py +1 -0
  84. flyte/config/_reader.py +29 -8
  85. flyte/connectors/__init__.py +11 -0
  86. flyte/connectors/_connector.py +270 -0
  87. flyte/connectors/_server.py +197 -0
  88. flyte/connectors/utils.py +135 -0
  89. flyte/errors.py +22 -2
  90. flyte/extend.py +8 -1
  91. flyte/extras/_container.py +6 -1
  92. flyte/git/__init__.py +3 -0
  93. flyte/git/_config.py +21 -0
  94. flyte/io/__init__.py +2 -0
  95. flyte/io/_dataframe/__init__.py +2 -0
  96. flyte/io/_dataframe/basic_dfs.py +17 -8
  97. flyte/io/_dataframe/dataframe.py +98 -132
  98. flyte/io/_dir.py +575 -113
  99. flyte/io/_file.py +582 -139
  100. flyte/io/_hashing_io.py +342 -0
  101. flyte/models.py +74 -15
  102. flyte/remote/__init__.py +6 -1
  103. flyte/remote/_action.py +34 -26
  104. flyte/remote/_client/_protocols.py +39 -4
  105. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  106. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  107. flyte/remote/_client/auth/_channel.py +10 -6
  108. flyte/remote/_client/controlplane.py +17 -5
  109. flyte/remote/_console.py +3 -2
  110. flyte/remote/_data.py +6 -6
  111. flyte/remote/_logs.py +3 -3
  112. flyte/remote/_run.py +64 -8
  113. flyte/remote/_secret.py +26 -17
  114. flyte/remote/_task.py +75 -33
  115. flyte/remote/_trigger.py +306 -0
  116. flyte/remote/_user.py +33 -0
  117. flyte/report/_report.py +1 -1
  118. flyte/storage/__init__.py +6 -1
  119. flyte/storage/_config.py +5 -1
  120. flyte/storage/_parallel_reader.py +274 -0
  121. flyte/storage/_storage.py +200 -103
  122. flyte/types/__init__.py +16 -0
  123. flyte/types/_interface.py +2 -2
  124. flyte/types/_pickle.py +35 -8
  125. flyte/types/_string_literals.py +8 -9
  126. flyte/types/_type_engine.py +40 -70
  127. flyte/types/_utils.py +1 -1
  128. flyte-2.0.0b30.data/scripts/debug.py +38 -0
  129. {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
  130. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
  131. flyte-2.0.0b30.dist-info/RECORD +192 -0
  132. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
  133. flyte/_protos/common/authorization_pb2.py +0 -66
  134. flyte/_protos/common/authorization_pb2.pyi +0 -108
  135. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  136. flyte/_protos/common/identifier_pb2.py +0 -93
  137. flyte/_protos/common/identifier_pb2.pyi +0 -110
  138. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  139. flyte/_protos/common/identity_pb2.py +0 -48
  140. flyte/_protos/common/identity_pb2.pyi +0 -72
  141. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  142. flyte/_protos/common/list_pb2.py +0 -36
  143. flyte/_protos/common/list_pb2.pyi +0 -71
  144. flyte/_protos/common/list_pb2_grpc.py +0 -4
  145. flyte/_protos/common/policy_pb2.py +0 -37
  146. flyte/_protos/common/policy_pb2.pyi +0 -27
  147. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  148. flyte/_protos/common/role_pb2.py +0 -37
  149. flyte/_protos/common/role_pb2.pyi +0 -53
  150. flyte/_protos/common/role_pb2_grpc.py +0 -4
  151. flyte/_protos/common/runtime_version_pb2.py +0 -28
  152. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  153. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  154. flyte/_protos/imagebuilder/definition_pb2.py +0 -59
  155. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
  156. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  157. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  158. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  159. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  160. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  161. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  162. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  163. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  164. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  165. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  166. flyte/_protos/secret/definition_pb2.py +0 -49
  167. flyte/_protos/secret/definition_pb2.pyi +0 -93
  168. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  169. flyte/_protos/secret/payload_pb2.py +0 -62
  170. flyte/_protos/secret/payload_pb2.pyi +0 -94
  171. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  172. flyte/_protos/secret/secret_pb2.py +0 -38
  173. flyte/_protos/secret/secret_pb2.pyi +0 -6
  174. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  175. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  176. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  177. flyte/_protos/workflow/common_pb2.py +0 -27
  178. flyte/_protos/workflow/common_pb2.pyi +0 -14
  179. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  180. flyte/_protos/workflow/environment_pb2.py +0 -29
  181. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  182. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  183. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  184. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  185. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  186. flyte/_protos/workflow/queue_service_pb2.py +0 -109
  187. flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
  188. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  189. flyte/_protos/workflow/run_definition_pb2.py +0 -121
  190. flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
  191. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  192. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  193. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  194. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  195. flyte/_protos/workflow/run_service_pb2.py +0 -137
  196. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  197. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  198. flyte/_protos/workflow/state_service_pb2.py +0 -67
  199. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  200. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  201. flyte/_protos/workflow/task_definition_pb2.py +0 -79
  202. flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
  203. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  204. flyte/_protos/workflow/task_service_pb2.py +0 -60
  205. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  206. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  207. flyte-2.0.0b13.dist-info/RECORD +0 -239
  208. /flyte/{_protos → _debug}/__init__.py +0 -0
  209. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  210. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  211. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
@@ -12,27 +12,29 @@ from typing import (
12
12
  List,
13
13
  Literal,
14
14
  Optional,
15
+ Tuple,
15
16
  Union,
16
17
  cast,
18
+ overload,
17
19
  )
18
20
 
19
21
  import rich.repr
20
22
 
21
- from ._cache import CacheRequest
23
+ from ._cache import Cache, CacheRequest
22
24
  from ._doc import Documentation
23
25
  from ._environment import Environment
24
26
  from ._image import Image
27
+ from ._pod import PodTemplate
25
28
  from ._resources import Resources
26
29
  from ._retry import RetryStrategy
27
30
  from ._reusable_environment import ReusePolicy
28
31
  from ._secret import SecretRequest
29
32
  from ._task import AsyncFunctionTaskTemplate, TaskTemplate
33
+ from ._trigger import Trigger
30
34
  from .models import MAX_INLINE_IO_BYTES, NativeInterface
31
35
 
32
36
  if TYPE_CHECKING:
33
- from kubernetes.client import V1PodTemplate
34
-
35
- from ._task import FunctionTypes, P, R
37
+ from ._task import F, P, R
36
38
 
37
39
 
38
40
  @rich.repr.auto
@@ -60,13 +62,18 @@ class TaskEnvironment(Environment):
60
62
  that depend on each other.
61
63
  :param cache: Cache policy for the environment.
62
64
  :param reusable: Reuse policy for the environment, if set, a python process may be reused for multiple tasks.
65
+ :param plugin_config: Optional plugin configuration for custom task types.
66
+ If set, all tasks in this environment will use the specified plugin configuration.
67
+ :param queue: Optional queue name to use for tasks in this environment.
68
+ If not set, the default queue will be used.
69
+ :param pod_template: Optional pod template to use for tasks in this environment.
70
+ If not set, the default pod template will be used.
63
71
  """
64
72
 
65
73
  cache: CacheRequest = "disable"
66
74
  reusable: ReusePolicy | None = None
67
75
  plugin_config: Optional[Any] = None
68
- # TODO Shall we make this union of string or env? This way we can lookup the env by module/file:name
69
- # TODO also we could add list of files that are used by this environment
76
+ queue: Optional[str] = None
70
77
 
71
78
  _tasks: Dict[str, TaskTemplate] = field(default_factory=dict, init=False)
72
79
 
@@ -74,6 +81,10 @@ class TaskEnvironment(Environment):
74
81
  super().__post_init__()
75
82
  if self.reusable is not None and self.plugin_config is not None:
76
83
  raise ValueError("Cannot set plugin_config when environment is reusable.")
84
+ if self.reusable and not isinstance(self.reusable, ReusePolicy):
85
+ raise TypeError(f"Expected reusable to be of type ReusePolicy, got {type(self.reusable)}")
86
+ if self.cache and not isinstance(self.cache, (str, Cache)):
87
+ raise TypeError(f"Expected cache to be of type str or Cache, got {type(self.cache)}")
77
88
 
78
89
  def clone_with(
79
90
  self,
@@ -83,6 +94,8 @@ class TaskEnvironment(Environment):
83
94
  env_vars: Optional[Dict[str, str]] = None,
84
95
  secrets: Optional[SecretRequest] = None,
85
96
  depends_on: Optional[List[Environment]] = None,
97
+ description: Optional[str] = None,
98
+ interruptible: Optional[bool] = None,
86
99
  **kwargs: Any,
87
100
  ) -> TaskEnvironment:
88
101
  """
@@ -98,6 +111,11 @@ class TaskEnvironment(Environment):
98
111
  :param depends_on: The environment dependencies to hint, so when you deploy the environment,
99
112
  the dependencies are also deployed. This is useful when you have a set of environments
100
113
  that depend on each other.
114
+ :param queue: The queue name to use for tasks in this environment.
115
+ :param pod_template: The pod template to use for tasks in this environment.
116
+ :param description: The description of the environment.
117
+ :param interruptible: Whether the environment is interruptible and can be scheduled on spot/preemptible
118
+ instances.
101
119
  :param kwargs: Additional parameters to override the environment (e.g., cache, reusable, plugin_config).
102
120
  """
103
121
  cache = kwargs.pop("cache", None)
@@ -127,27 +145,58 @@ class TaskEnvironment(Environment):
127
145
  kwargs["secrets"] = secrets
128
146
  if depends_on is not None:
129
147
  kwargs["depends_on"] = depends_on
148
+ if description is not None:
149
+ kwargs["description"] = description
150
+ if interruptible is not None:
151
+ kwargs["interruptible"] = interruptible
130
152
  return replace(self, **kwargs)
131
153
 
154
+ @overload
155
+ def task(
156
+ self,
157
+ *,
158
+ short_name: Optional[str] = None,
159
+ cache: CacheRequest | None = None,
160
+ retries: Union[int, RetryStrategy] = 0,
161
+ timeout: Union[timedelta, int] = 0,
162
+ docs: Optional[Documentation] = None,
163
+ pod_template: Optional[Union[str, PodTemplate]] = None,
164
+ report: bool = False,
165
+ interruptible: bool | None = None,
166
+ max_inline_io_bytes: int = MAX_INLINE_IO_BYTES,
167
+ queue: Optional[str] = None,
168
+ triggers: Tuple[Trigger, ...] | Trigger = (),
169
+ ) -> Callable[[Callable[P, R]], AsyncFunctionTaskTemplate[P, R, Callable[P, R]]]: ...
170
+
171
+ @overload
172
+ def task(
173
+ self,
174
+ _func: Callable[P, R],
175
+ /,
176
+ ) -> AsyncFunctionTaskTemplate[P, R, Callable[P, R]]: ...
177
+
132
178
  def task(
133
179
  self,
134
- _func=None,
180
+ _func: F | None = None,
135
181
  *,
136
- name: Optional[str] = None,
182
+ short_name: Optional[str] = None,
137
183
  cache: CacheRequest | None = None,
138
184
  retries: Union[int, RetryStrategy] = 0,
139
185
  timeout: Union[timedelta, int] = 0,
140
186
  docs: Optional[Documentation] = None,
141
- pod_template: Optional[Union[str, "V1PodTemplate"]] = None,
187
+ pod_template: Optional[Union[str, PodTemplate]] = None,
142
188
  report: bool = False,
189
+ interruptible: bool | None = None,
143
190
  max_inline_io_bytes: int = MAX_INLINE_IO_BYTES,
144
- ) -> Union[AsyncFunctionTaskTemplate, Callable[P, R]]:
191
+ queue: Optional[str] = None,
192
+ triggers: Tuple[Trigger, ...] | Trigger = (),
193
+ ) -> Callable[[F], AsyncFunctionTaskTemplate[P, R, F]] | AsyncFunctionTaskTemplate[P, R, F]:
145
194
  """
146
195
  Decorate a function to be a task.
147
196
 
148
197
  :param _func: Optional The function to decorate. If not provided, the decorator will return a callable that
149
198
  accepts a function to be decorated.
150
- :param name: Optional A friendly name for the task (defaults to the function name)
199
+ :param short_name: Optional A friendly name for the task (defaults to the function name)
151
200
  :param cache: Optional The cache policy for the task, defaults to auto, which will cache the results of the
152
201
  task.
153
202
  :param retries: Optional The number of retries for the task, defaults to 0, which means no retries.
@@ -158,15 +207,21 @@ class TaskEnvironment(Environment):
158
207
  :param report: Optional Whether to generate the html report for the task, defaults to False.
159
208
  :param max_inline_io_bytes: Maximum allowed size (in bytes) for all inputs and outputs passed directly to the
160
209
  task (e.g., primitives, strings, dicts). Does not apply to files, directories, or dataframes.
210
+ :param triggers: Optional A tuple of triggers to associate with the task. This allows the task to be run on a
211
+ schedule or in response to events. Triggers can be defined using the `flyte.trigger` module.
212
+ :param interruptible: Optional Whether the task is interruptible, defaults to environment setting.
213
+ :param queue: Optional queue name to use for this task. If not set, the environment's queue will be used.
214
+
215
+ :return: A TaskTemplate that can be used to deploy the task.
161
216
  """
162
- from ._task import P, R
217
+ from ._task import F, P, R
163
218
 
164
219
  if self.reusable is not None:
165
220
  if pod_template is not None:
166
221
  raise ValueError("Cannot set pod_template when environment is reusable.")
167
222
 
168
- def decorator(func: FunctionTypes) -> AsyncFunctionTaskTemplate[P, R]:
169
- friendly_name = name or func.__name__
223
+ def decorator(func: F) -> AsyncFunctionTaskTemplate[P, R, F]:
224
+ short = short_name or func.__name__
170
225
  task_name = self.name + "." + func.__name__
171
226
 
172
227
  if not inspect.iscoroutinefunction(func) and self.reusable is not None:
@@ -179,7 +234,7 @@ class TaskEnvironment(Environment):
179
234
  if self.plugin_config is not None:
180
235
  from flyte.extend import TaskPluginRegistry
181
236
 
182
- task_template_class: type[AsyncFunctionTaskTemplate[P, R]] | None = TaskPluginRegistry.find(
237
+ task_template_class: type[AsyncFunctionTaskTemplate[P, R, F]] | None = TaskPluginRegistry.find(
183
238
  config_type=type(self.plugin_config)
184
239
  )
185
240
  if task_template_class is None:
@@ -188,9 +243,9 @@ class TaskEnvironment(Environment):
188
243
  f"Please register a plugin using flyte.extend.TaskPluginRegistry.register() api."
189
244
  )
190
245
  else:
191
- task_template_class = AsyncFunctionTaskTemplate[P, R]
246
+ task_template_class = AsyncFunctionTaskTemplate[P, R, F]
192
247
 
193
- task_template_class = cast(type[AsyncFunctionTaskTemplate[P, R]], task_template_class)
248
+ task_template_class = cast(type[AsyncFunctionTaskTemplate[P, R, F]], task_template_class)
194
249
  tmpl = task_template_class(
195
250
  func=func,
196
251
  name=task_name,
@@ -205,18 +260,22 @@ class TaskEnvironment(Environment):
205
260
  secrets=self.secrets,
206
261
  pod_template=pod_template or self.pod_template,
207
262
  parent_env=weakref.ref(self),
263
+ parent_env_name=self.name,
208
264
  interface=NativeInterface.from_callable(func),
209
265
  report=report,
210
- friendly_name=friendly_name,
266
+ short_name=short,
211
267
  plugin_config=self.plugin_config,
212
268
  max_inline_io_bytes=max_inline_io_bytes,
269
+ queue=queue or self.queue,
270
+ interruptible=interruptible if interruptible is not None else self.interruptible,
271
+ triggers=triggers if isinstance(triggers, tuple) else (triggers,),
213
272
  )
214
273
  self._tasks[task_name] = tmpl
215
274
  return tmpl
216
275
 
217
276
  if _func is None:
218
- return cast(AsyncFunctionTaskTemplate, decorator)
219
- return cast(AsyncFunctionTaskTemplate, decorator(_func))
277
+ return cast(Callable[[F], AsyncFunctionTaskTemplate[P, R, F]], decorator)
278
+ return cast(AsyncFunctionTaskTemplate[P, R, F], decorator(_func))
220
279
 
221
280
  @property
222
281
  def tasks(self) -> Dict[str, TaskTemplate]:
@@ -225,16 +284,33 @@ class TaskEnvironment(Environment):
225
284
  """
226
285
  return self._tasks
227
286
 
228
- def add_task(self, task: TaskTemplate) -> TaskTemplate:
287
+ @classmethod
288
+ def from_task(cls, name: str, *tasks: TaskTemplate) -> TaskEnvironment:
229
289
  """
230
- Add a task to the environment.
290
+ Create a TaskEnvironment from a list of tasks. All tasks should have the same image or no Image defined.
291
+ Similarity of Image is determined by the python reference, not by value.
292
+
293
+ If images are different, an error is raised. If no image is defined, the image is set to "auto".
231
294
 
232
- Useful when you want to add a task to an environment that is not defined using the `task` decorator.
295
+ For any other tasks that need to be use these tasks, the returned environment can be used in the `depends_on`
296
+ attribute of the other TaskEnvironment.
297
+
298
+ :param name: The name of the environment.
299
+ :param tasks: The list of tasks to create the environment from.
233
300
 
234
- :param task: The TaskTemplate to add to this environment.
301
+ :raises ValueError: If tasks are assigned to multiple environments or have different images.
302
+ :return: The created TaskEnvironment.
235
303
  """
236
- if task.name in self._tasks:
237
- raise ValueError(f"Task {task.name} already exists in the environment. Task names should be unique.")
238
- self._tasks[task.name] = task
239
- task.parent_env = weakref.ref(self)
240
- return task
304
+ envs = [t.parent_env() for t in tasks if t.parent_env and t.parent_env() is not None]
305
+ if envs:
306
+ raise ValueError("Tasks cannot assigned to multiple environments.")
307
+ images = {t.image for t in tasks}
308
+ if len(images) > 1:
309
+ raise ValueError("Tasks must have the same image to be in the same environment.")
310
+ image: Union[str, Image] = images.pop() if images else "auto"
311
+ env = cls(name, image=image)
312
+ for t in tasks:
313
+ env._tasks[t.name] = t
314
+ t.parent_env = weakref.ref(env)
315
+ t.parent_env_name = name
316
+ return env
flyte/_task_plugins.py CHANGED
@@ -1,3 +1,5 @@
1
+ from __future__ import annotations
2
+
1
3
  import typing
2
4
  from typing import Type
3
5
 
@@ -14,8 +16,8 @@ class _Registry:
14
16
  A registry for task plugins.
15
17
  """
16
18
 
17
- def __init__(self):
18
- self._plugins: typing.Dict[Type, Type[T]] = {}
19
+ def __init__(self: _Registry):
20
+ self._plugins: typing.Dict[Type, Type[typing.Any]] = {}
19
21
 
20
22
  def register(self, config_type: Type, plugin: Type[T]):
21
23
  """
flyte/_trace.py CHANGED
@@ -3,6 +3,7 @@ import inspect
3
3
  import time
4
4
  from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Callable, TypeGuard, TypeVar, Union, cast
5
5
 
6
+ from flyte._logging import logger
6
7
  from flyte.models import NativeInterface
7
8
 
8
9
  T = TypeVar("T")
@@ -33,10 +34,13 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
33
34
  iface = NativeInterface.from_callable(func)
34
35
  info, ok = await controller.get_action_outputs(iface, func, *args, **kwargs)
35
36
  if ok:
37
+ logger.info(f"Found existing trace info for {func}, {info}")
36
38
  if info.output:
37
39
  return info.output
38
40
  elif info.error:
39
41
  raise info.error
42
+ else:
43
+ logger.debug(f"No existing trace info found for {func}, proceeding to execute.")
40
44
  start_time = time.time()
41
45
  try:
42
46
  # Cast to Awaitable to satisfy mypy
@@ -44,6 +48,7 @@ def trace(func: Callable[..., T]) -> Callable[..., T]:
44
48
  results = await coroutine_result
45
49
  info.add_outputs(results, start_time=start_time, end_time=time.time())
46
50
  await controller.record_trace(info)
51
+ logger.debug(f"Finished trace for {func}, {info}")
47
52
  return results
48
53
  except Exception as e:
49
54
  # If there is an error, we need to record it