flyte 2.0.0b32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (204) hide show
  1. flyte/__init__.py +108 -0
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +195 -0
  5. flyte/_bin/serve.py +178 -0
  6. flyte/_build.py +26 -0
  7. flyte/_cache/__init__.py +12 -0
  8. flyte/_cache/cache.py +147 -0
  9. flyte/_cache/defaults.py +9 -0
  10. flyte/_cache/local_cache.py +216 -0
  11. flyte/_cache/policy_function_body.py +42 -0
  12. flyte/_code_bundle/__init__.py +8 -0
  13. flyte/_code_bundle/_ignore.py +121 -0
  14. flyte/_code_bundle/_packaging.py +218 -0
  15. flyte/_code_bundle/_utils.py +347 -0
  16. flyte/_code_bundle/bundle.py +266 -0
  17. flyte/_constants.py +1 -0
  18. flyte/_context.py +155 -0
  19. flyte/_custom_context.py +73 -0
  20. flyte/_debug/__init__.py +0 -0
  21. flyte/_debug/constants.py +38 -0
  22. flyte/_debug/utils.py +17 -0
  23. flyte/_debug/vscode.py +307 -0
  24. flyte/_deploy.py +408 -0
  25. flyte/_deployer.py +109 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +122 -0
  29. flyte/_excepthook.py +37 -0
  30. flyte/_group.py +32 -0
  31. flyte/_hash.py +8 -0
  32. flyte/_image.py +1055 -0
  33. flyte/_initialize.py +628 -0
  34. flyte/_interface.py +119 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +129 -0
  37. flyte/_internal/controllers/_local_controller.py +239 -0
  38. flyte/_internal/controllers/_trace.py +48 -0
  39. flyte/_internal/controllers/remote/__init__.py +58 -0
  40. flyte/_internal/controllers/remote/_action.py +211 -0
  41. flyte/_internal/controllers/remote/_client.py +47 -0
  42. flyte/_internal/controllers/remote/_controller.py +583 -0
  43. flyte/_internal/controllers/remote/_core.py +465 -0
  44. flyte/_internal/controllers/remote/_informer.py +381 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +3 -0
  47. flyte/_internal/imagebuild/docker_builder.py +706 -0
  48. flyte/_internal/imagebuild/image_builder.py +277 -0
  49. flyte/_internal/imagebuild/remote_builder.py +386 -0
  50. flyte/_internal/imagebuild/utils.py +78 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +21 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +486 -0
  57. flyte/_internal/runtime/entrypoints.py +204 -0
  58. flyte/_internal/runtime/io.py +188 -0
  59. flyte/_internal/runtime/resources_serde.py +152 -0
  60. flyte/_internal/runtime/reuse.py +125 -0
  61. flyte/_internal/runtime/rusty.py +193 -0
  62. flyte/_internal/runtime/task_serde.py +362 -0
  63. flyte/_internal/runtime/taskrunner.py +209 -0
  64. flyte/_internal/runtime/trigger_serde.py +160 -0
  65. flyte/_internal/runtime/types_serde.py +54 -0
  66. flyte/_keyring/__init__.py +0 -0
  67. flyte/_keyring/file.py +115 -0
  68. flyte/_logging.py +300 -0
  69. flyte/_map.py +312 -0
  70. flyte/_module.py +72 -0
  71. flyte/_pod.py +30 -0
  72. flyte/_resources.py +473 -0
  73. flyte/_retry.py +32 -0
  74. flyte/_reusable_environment.py +102 -0
  75. flyte/_run.py +724 -0
  76. flyte/_secret.py +96 -0
  77. flyte/_task.py +550 -0
  78. flyte/_task_environment.py +316 -0
  79. flyte/_task_plugins.py +47 -0
  80. flyte/_timeout.py +47 -0
  81. flyte/_tools.py +27 -0
  82. flyte/_trace.py +119 -0
  83. flyte/_trigger.py +1000 -0
  84. flyte/_utils/__init__.py +30 -0
  85. flyte/_utils/asyn.py +121 -0
  86. flyte/_utils/async_cache.py +139 -0
  87. flyte/_utils/coro_management.py +27 -0
  88. flyte/_utils/docker_credentials.py +173 -0
  89. flyte/_utils/file_handling.py +72 -0
  90. flyte/_utils/helpers.py +134 -0
  91. flyte/_utils/lazy_module.py +54 -0
  92. flyte/_utils/module_loader.py +104 -0
  93. flyte/_utils/org_discovery.py +57 -0
  94. flyte/_utils/uv_script_parser.py +49 -0
  95. flyte/_version.py +34 -0
  96. flyte/app/__init__.py +22 -0
  97. flyte/app/_app_environment.py +157 -0
  98. flyte/app/_deploy.py +125 -0
  99. flyte/app/_input.py +160 -0
  100. flyte/app/_runtime/__init__.py +3 -0
  101. flyte/app/_runtime/app_serde.py +347 -0
  102. flyte/app/_types.py +101 -0
  103. flyte/app/extras/__init__.py +3 -0
  104. flyte/app/extras/_fastapi.py +151 -0
  105. flyte/cli/__init__.py +12 -0
  106. flyte/cli/_abort.py +28 -0
  107. flyte/cli/_build.py +114 -0
  108. flyte/cli/_common.py +468 -0
  109. flyte/cli/_create.py +371 -0
  110. flyte/cli/_delete.py +45 -0
  111. flyte/cli/_deploy.py +293 -0
  112. flyte/cli/_gen.py +176 -0
  113. flyte/cli/_get.py +370 -0
  114. flyte/cli/_option.py +33 -0
  115. flyte/cli/_params.py +554 -0
  116. flyte/cli/_plugins.py +209 -0
  117. flyte/cli/_run.py +597 -0
  118. flyte/cli/_serve.py +64 -0
  119. flyte/cli/_update.py +37 -0
  120. flyte/cli/_user.py +17 -0
  121. flyte/cli/main.py +221 -0
  122. flyte/config/__init__.py +3 -0
  123. flyte/config/_config.py +248 -0
  124. flyte/config/_internal.py +73 -0
  125. flyte/config/_reader.py +225 -0
  126. flyte/connectors/__init__.py +11 -0
  127. flyte/connectors/_connector.py +270 -0
  128. flyte/connectors/_server.py +197 -0
  129. flyte/connectors/utils.py +135 -0
  130. flyte/errors.py +243 -0
  131. flyte/extend.py +19 -0
  132. flyte/extras/__init__.py +5 -0
  133. flyte/extras/_container.py +286 -0
  134. flyte/git/__init__.py +3 -0
  135. flyte/git/_config.py +21 -0
  136. flyte/io/__init__.py +29 -0
  137. flyte/io/_dataframe/__init__.py +131 -0
  138. flyte/io/_dataframe/basic_dfs.py +223 -0
  139. flyte/io/_dataframe/dataframe.py +1026 -0
  140. flyte/io/_dir.py +910 -0
  141. flyte/io/_file.py +914 -0
  142. flyte/io/_hashing_io.py +342 -0
  143. flyte/models.py +479 -0
  144. flyte/py.typed +0 -0
  145. flyte/remote/__init__.py +35 -0
  146. flyte/remote/_action.py +738 -0
  147. flyte/remote/_app.py +57 -0
  148. flyte/remote/_client/__init__.py +0 -0
  149. flyte/remote/_client/_protocols.py +189 -0
  150. flyte/remote/_client/auth/__init__.py +12 -0
  151. flyte/remote/_client/auth/_auth_utils.py +14 -0
  152. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  153. flyte/remote/_client/auth/_authenticators/base.py +403 -0
  154. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  155. flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
  156. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  157. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  158. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  159. flyte/remote/_client/auth/_channel.py +213 -0
  160. flyte/remote/_client/auth/_client_config.py +85 -0
  161. flyte/remote/_client/auth/_default_html.py +32 -0
  162. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  163. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  164. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  165. flyte/remote/_client/auth/_keyring.py +152 -0
  166. flyte/remote/_client/auth/_token_client.py +260 -0
  167. flyte/remote/_client/auth/errors.py +16 -0
  168. flyte/remote/_client/controlplane.py +128 -0
  169. flyte/remote/_common.py +30 -0
  170. flyte/remote/_console.py +19 -0
  171. flyte/remote/_data.py +161 -0
  172. flyte/remote/_logs.py +185 -0
  173. flyte/remote/_project.py +88 -0
  174. flyte/remote/_run.py +386 -0
  175. flyte/remote/_secret.py +142 -0
  176. flyte/remote/_task.py +527 -0
  177. flyte/remote/_trigger.py +306 -0
  178. flyte/remote/_user.py +33 -0
  179. flyte/report/__init__.py +3 -0
  180. flyte/report/_report.py +182 -0
  181. flyte/report/_template.html +124 -0
  182. flyte/storage/__init__.py +36 -0
  183. flyte/storage/_config.py +237 -0
  184. flyte/storage/_parallel_reader.py +274 -0
  185. flyte/storage/_remote_fs.py +34 -0
  186. flyte/storage/_storage.py +456 -0
  187. flyte/storage/_utils.py +5 -0
  188. flyte/syncify/__init__.py +56 -0
  189. flyte/syncify/_api.py +375 -0
  190. flyte/types/__init__.py +52 -0
  191. flyte/types/_interface.py +40 -0
  192. flyte/types/_pickle.py +145 -0
  193. flyte/types/_renderer.py +162 -0
  194. flyte/types/_string_literals.py +119 -0
  195. flyte/types/_type_engine.py +2254 -0
  196. flyte/types/_utils.py +80 -0
  197. flyte-2.0.0b32.data/scripts/debug.py +38 -0
  198. flyte-2.0.0b32.data/scripts/runtime.py +195 -0
  199. flyte-2.0.0b32.dist-info/METADATA +351 -0
  200. flyte-2.0.0b32.dist-info/RECORD +204 -0
  201. flyte-2.0.0b32.dist-info/WHEEL +5 -0
  202. flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
  203. flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
  204. flyte-2.0.0b32.dist-info/top_level.txt +1 -0
flyte/_secret.py ADDED
@@ -0,0 +1,96 @@
1
+ import pathlib
2
+ import re
3
+ from dataclasses import dataclass
4
+ from typing import List, Optional, Union
5
+
6
+
7
+ @dataclass
8
+ class Secret:
9
+ """
10
+ Secrets are used to inject sensitive information into tasks or image build context.
11
+ Secrets can be mounted as environment variables or files.
12
+ The secret key is the name of the secret in the secret store. The group is optional and maybe used with some
13
+ secret stores to organize secrets. The secret_mount is used to specify how the secret should be mounted. If the
14
+ secret_mount is set to "env" the secret will be mounted as an environment variable. If the secret_mount is set to
15
+ "file" the secret will be mounted as a file. The as_env_var is an optional parameter that can be used to specify the
16
+ name of the environment variable that the secret should be mounted as.
17
+
18
+ Example:
19
+ ```python
20
+ @task(secrets="my-secret")
21
+ async def my_task():
22
+ # This will be set to the value of the secret. Note: The env var is always uppercase, and - is replaced with _.
23
+ os.environ["MY_SECRET"]
24
+
25
+ @task(secrets=Secret("my-openai-api-key", as_env_var="OPENAI_API_KEY"))
26
+ async def my_task2():
27
+ os.environ["OPENAI_API_KEY"]
28
+ ```
29
+
30
+ TODO: Add support for secret versioning (some stores) and secret groups (some stores) and mounting as files.
31
+
32
+ :param key: The name of the secret in the secret store.
33
+ :param group: The group of the secret in the secret store.
34
+ :param mount: Use this to specify the path where the secret should be mounted.
35
+ TODO: support arbitrary mount paths. Today only "/etc/flyte/secrets" is supported
36
+ :param as_env_var: The name of the environment variable that the secret should be mounted as.
37
+ """
38
+
39
+ key: str
40
+ group: Optional[str] = None
41
+ mount: pathlib.Path | None = None
42
+ as_env_var: Optional[str] = None
43
+
44
+ def __post_init__(self):
45
+ if not self.mount and not self.as_env_var:
46
+ self.as_env_var = f"{self.group}_{self.key}" if self.group else self.key
47
+ self.as_env_var = self.as_env_var.replace("-", "_").upper()
48
+ if self.as_env_var is not None:
49
+ pattern = r"^[A-Z_][A-Z0-9_]*$"
50
+ if not re.match(pattern, self.as_env_var):
51
+ raise ValueError(f"Invalid environment variable name: {self.as_env_var}, must match {pattern}")
52
+
53
+ def stable_hash(self) -> str:
54
+ """
55
+ Deterministic, process-independent hash (as hex string).
56
+ """
57
+ import hashlib
58
+
59
+ data = (
60
+ self.key,
61
+ self.group or "",
62
+ str(self.mount) if self.mount else "",
63
+ self.as_env_var or "",
64
+ )
65
+ joined = "|".join(data)
66
+ return hashlib.sha256(joined.encode("utf-8")).hexdigest()
67
+
68
+ def __hash__(self) -> int:
69
+ """
70
+ Deterministic hash function for the Secret class.
71
+ """
72
+ return int(self.stable_hash()[:16], 16)
73
+
74
+
75
# Accepted shapes for requesting secrets: a bare key, a Secret, or a list mixing both.
SecretRequest = Union[str, Secret, List[str | Secret]]


def secrets_from_request(secrets: SecretRequest) -> List[Secret]:
    """
    Normalize a secret request into a list of Secret objects.

    A single string becomes ``[Secret(key=...)]``, a single Secret is wrapped in a list, and any other
    value is iterated, turning each string item into a Secret and keeping the remaining items as they are.
    """
    if isinstance(secrets, Secret):
        return [secrets]
    if isinstance(secrets, str):
        return [Secret(key=secrets)]

    normalized = []
    for entry in secrets:
        normalized.append(Secret(key=entry) if isinstance(entry, str) else entry)
    return normalized
88
+
89
+
90
if __name__ == "__main__":
    # Manual smoke check: build one secret with both a mount path and an explicit env var name,
    # and one with only a key, then print their hashes.
    mounted_secret = Secret(
        key="MY_SECRET",
        mount=pathlib.Path("/path/to/secret"),
        as_env_var="MY_SECRET_ENV",
    )
    key_only_secret = Secret(key="ANOTHER_SECRET")
    print(hash(mounted_secret), hash(key_only_secret))
flyte/_task.py ADDED
@@ -0,0 +1,550 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import weakref
5
+ from dataclasses import dataclass, field, replace
6
+ from inspect import iscoroutinefunction
7
+ from typing import (
8
+ TYPE_CHECKING,
9
+ Any,
10
+ Callable,
11
+ Coroutine,
12
+ Dict,
13
+ Generic,
14
+ List,
15
+ Literal,
16
+ Optional,
17
+ ParamSpec,
18
+ Tuple,
19
+ TypeAlias,
20
+ TypeVar,
21
+ Union,
22
+ cast,
23
+ overload,
24
+ )
25
+
26
+ from flyte._pod import PodTemplate
27
+ from flyte.errors import RuntimeSystemError, RuntimeUserError
28
+
29
+ from ._cache import Cache, CacheRequest
30
+ from ._context import internal_ctx
31
+ from ._doc import Documentation
32
+ from ._image import Image
33
+ from ._resources import Resources
34
+ from ._retry import RetryStrategy
35
+ from ._reusable_environment import ReusePolicy
36
+ from ._secret import SecretRequest
37
+ from ._timeout import TimeoutType
38
+ from ._trigger import Trigger
39
+ from .models import MAX_INLINE_IO_BYTES, NativeInterface, SerializationContext
40
+
41
+ if TYPE_CHECKING:
42
+ from flyteidl2.core.tasks_pb2 import DataLoadingConfig
43
+
44
+ from ._task_environment import TaskEnvironment
45
+
46
+ P = ParamSpec("P") # capture the function's parameters
47
+ R = TypeVar("R") # return type
48
+
49
+ AsyncFunctionType: TypeAlias = Callable[P, Coroutine[Any, Any, R]]
50
+ SyncFunctionType: TypeAlias = Callable[P, R]
51
+ FunctionTypes: TypeAlias = AsyncFunctionType | SyncFunctionType
52
+ F = TypeVar("F", bound=FunctionTypes)
53
+
54
+
55
@dataclass(kw_only=True)
class TaskTemplate(Generic[P, R, F]):
    """
    Task template is a template for a task that can be executed. It defines various parameters for the task, which
    can be defined statically at the time of task definition or dynamically at the time of task invocation using
    the override method.

    Example usage:
    ```python
    @task(name="my_task", image="my_image", resources=Resources(cpu="1", memory="1Gi"))
    def my_task():
        pass
    ```

    :param name: Optional The name of the task (defaults to the function name)
    :param task_type: Router type for the task, this is used to determine how the task will be executed.
        This is usually set to match with the execution plugin.
    :param image: Optional The image to use for the task, if set to "auto" will use the default image for the python
        version with flyte installed
    :param resources: Optional The resources to use for the task
    :param cache: Optional The cache policy for the task, defaults to "disable". The strings "auto", "override" and
        "disable" are converted to a Cache object with that behavior.
    :param interruptible: Optional The interruptible policy for the task, defaults to False, which means the task
        will not be scheduled on interruptible nodes. If set to True, the task will be scheduled on interruptible nodes,
        and the code should handle interruptions and resumptions.
    :param retries: Optional The number of retries for the task, defaults to 0, which means no retries.
    :param reusable: Optional The reusability policy for the task, defaults to None, which means the task environment
        will not be reused across task invocations.
    :param docs: Optional The documentation for the task, if not provided the function docstring will be used.
    :param env_vars: Optional The environment variables to set for the task.
    :param secrets: Optional The secrets that will be injected into the task at runtime.
    :param timeout: Optional The timeout for the task.
    :param max_inline_io_bytes: Maximum allowed size (in bytes) for all inputs and outputs passed directly to the task
        (e.g., primitives, strings, dicts). Does not apply to files, directories, or dataframes.
    :param pod_template: Optional The pod template to use for the task.
    :param report: Optional Whether to report the task execution to the Flyte console, defaults to False.
    :param queue: Optional The queue to use for the task. If not provided, the default queue will be used.
    :param debuggable: Optional Whether the task supports debugging capabilities, defaults to False.
    """

    name: str
    interface: NativeInterface
    short_name: str = ""
    task_type: str = "python"
    task_type_version: int = 0
    image: Union[str, Image, Literal["auto"]] = "auto"
    resources: Optional[Resources] = None
    cache: CacheRequest = "disable"
    interruptible: bool = False
    retries: Union[int, RetryStrategy] = 0
    reusable: Union[ReusePolicy, None] = None
    docs: Optional[Documentation] = None
    env_vars: Optional[Dict[str, str]] = None
    secrets: Optional[SecretRequest] = None
    timeout: Optional[TimeoutType] = None
    pod_template: Optional[Union[str, PodTemplate]] = None
    report: bool = False
    queue: Optional[str] = None
    debuggable: bool = False

    parent_env: Optional[weakref.ReferenceType[TaskEnvironment]] = None
    parent_env_name: Optional[str] = None
    ref: bool = field(default=False, init=False, repr=False, compare=False)
    max_inline_io_bytes: int = MAX_INLINE_IO_BYTES
    triggers: Tuple[Trigger, ...] = field(default_factory=tuple)

    # Only used in python 3.10 and 3.11, where we cannot use markcoroutinefunction
    _call_as_synchronous: bool = False

    def __post_init__(self):
        """Normalize the shorthand values accepted for image, cache, retries and short_name."""
        # Auto set the image based on the image request
        if self.image == "auto":
            self.image = Image.from_debian_base()
        elif isinstance(self.image, str):
            self.image = Image.from_base(str(self.image))

        # Auto set cache based on the cache request; any other string is left untouched.
        if isinstance(self.cache, str) and self.cache in ("auto", "override", "disable"):
            self.cache = Cache(behavior=self.cache)

        # if retries is set to int, convert to RetryStrategy
        if isinstance(self.retries, int):
            self.retries = RetryStrategy(count=self.retries)

        if self.short_name == "":
            # If short_name is not set, use the name of the task
            self.short_name = self.name

    def __getstate__(self):
        """
        This method is called when the object is pickled. We need to remove the parent_env reference
        to avoid circular references.
        """
        state = self.__dict__.copy()
        state.pop("parent_env", None)
        return state

    def __setstate__(self, state):
        """
        This method is called when the object is unpickled. The parent_env reference is not part of the
        pickled state (see __getstate__), so it is reset to None here.
        """
        self.__dict__.update(state)
        self.parent_env = None

    @property
    def source_file(self) -> Optional[str]:
        """Source file of the task, if known. The base template has none and returns None."""
        return None

    async def pre(self, *args, **kwargs) -> Dict[str, Any]:
        """
        This is the preexecute function that will be
        called before the task is executed
        """
        return {}

    async def execute(self, *args, **kwargs) -> Any:
        """
        This is the pure python function that will be executed when the task is called.
        """
        raise NotImplementedError

    async def post(self, return_vals: Any) -> Any:
        """
        This is the postexecute function that will be
        called after the task is executed
        """
        return return_vals

    # ---- Extension points ----
    def config(self, sctx: SerializationContext) -> Dict[str, str]:
        """
        Returns additional configuration for the task. This is a set of key-value pairs that can be used to
        configure the task execution environment at runtime. This is usually used by plugins.
        """
        return {}

    def custom_config(self, sctx: SerializationContext) -> Dict[str, str]:
        """
        Returns additional configuration for the task. This is a set of key-value pairs that can be used to
        configure the task execution environment at runtime. This is usually used by plugins.
        """
        return {}

    def data_loading_config(self, sctx: SerializationContext) -> Optional[DataLoadingConfig]:
        """
        This configuration allows executing raw containers in Flyte using the Flyte CoPilot system
        Flyte CoPilot, eliminates the needs of sdk inside the container. Any inputs required by the users container
        are side-loaded in the input_path
        Any outputs generated by the user container - within output_path are automatically uploaded

        The base template has no data loading configuration and returns None.
        """
        return None

    def container_args(self, sctx: SerializationContext) -> List[str]:
        """
        Returns the container args for the task, as a list of strings. The base template has none.
        This is usually used by plugins.
        """
        return []

    def sql(self, sctx: SerializationContext) -> Optional[str]:
        """
        Returns the SQL for the task, or None if the task has no SQL. This is usually used by plugins.
        """
        return None

    # ---- Extension points ----

    @property
    def native_interface(self) -> NativeInterface:
        """Alias for the ``interface`` field."""
        return self.interface

    @overload
    async def aio(self: TaskTemplate[P, R, SyncFunctionType], *args: P.args, **kwargs: P.kwargs) -> R: ...

    @overload
    async def aio(
        self: TaskTemplate[P, R, AsyncFunctionType], *args: P.args, **kwargs: P.kwargs
    ) -> Coroutine[Any, Any, R]: ...

    async def aio(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R] | R:
        """
        The aio function allows executing "sync" tasks, in an async context. This helps with migrating v1 defined sync
        tasks to be used within an asyncio parent task.
        This function will also re-raise exceptions from the underlying task.

        Example:
        ```python
        @env.task
        def my_legacy_task(x: int) -> int:
            return x

        @env.task
        async def my_new_parent_task(n: int) -> List[int]:
            collect = []
            for x in range(n):
                collect.append(my_legacy_task.aio(x))
            return await asyncio.gather(*collect)
        ```
        :param args: Positional arguments for the task.
        :param kwargs: Keyword arguments for the task.
        :return: Inside a task context, the awaited result of submitting the task to the controller. Outside a
            task context, whatever ``forward`` returns.
        :raises RuntimeSystemError: If called inside a task context while no controller is initialized.
        """
        ctx = internal_ctx()
        if not ctx.is_task_context():
            from flyte._logging import logger

            logger.warning(f"Task {self.name} running aio outside of a task context.")
            # Local execute, just stay out of the way, but because .aio is used, we want to return an awaitable,
            # even for synchronous tasks. This is to support migration.
            # NOTE(review): the value of forward() is returned without being awaited, so if forward() returns a
            # coroutine the caller receives that coroutine and has to await it again - confirm this is intended.
            return self.forward(*args, **kwargs)

        from ._internal.controllers import get_controller

        # If we are in a task context, that implies we are executing a Run.
        # In this scenario, we should submit the task to the controller.
        controller = get_controller()
        if not controller:
            raise RuntimeSystemError("BadContext", "Controller is not initialized.")

        if self._call_as_synchronous:
            fut = controller.submit_sync(self, *args, **kwargs)
            # Wrap the future to make it awaitable
            return await asyncio.wrap_future(fut)
        return await controller.submit(self, *args, **kwargs)

    @overload
    def __call__(self: TaskTemplate[P, R, SyncFunctionType], *args: P.args, **kwargs: P.kwargs) -> R: ...

    @overload
    def __call__(
        self: TaskTemplate[P, R, AsyncFunctionType], *args: P.args, **kwargs: P.kwargs
    ) -> Coroutine[Any, Any, R]: ...

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R] | R:
        """
        This is the entrypoint for an async function task at runtime. It will be called during an execution.
        Please do not override this method, if you simply want to modify the execution behavior, override the
        execute method.

        This needs to be overridable to maybe be async.
        The returned thing from here needs to be an awaitable if the underlying task is async, and a regular object
        if the task is not.

        :raises RuntimeSystemError: If called inside a task context while no controller is initialized.
        :raises RuntimeUserError: Any other exception raised here is wrapped in a RuntimeUserError carrying the
            original exception's type name and message.
        """
        try:
            ctx = internal_ctx()
            if not ctx.is_task_context():
                # If not in task context, purely function run, stay out of the way
                return self.forward(*args, **kwargs)

            # If we are in a task context, that implies we are executing a Run.
            # In this scenario, we should submit the task to the controller.
            # We will also check if we are not initialized, It is not expected to be not initialized
            from ._internal.controllers import get_controller

            controller = get_controller()
            if not controller:
                raise RuntimeSystemError("BadContext", "Controller is not initialized.")

            if self._call_as_synchronous:
                # Block the calling thread until the submitted task has finished.
                fut = controller.submit_sync(self, *args, **kwargs)
                return fut.result(None)
            return controller.submit(self, *args, **kwargs)
        except (RuntimeSystemError, RuntimeUserError):
            # Already one of our own error types: propagate unchanged.
            raise
        except Exception as e:
            raise RuntimeUserError(type(e).__name__, str(e)) from e

    def forward(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R] | R:
        """
        Think of this as a local execute method for your task. This function will be invoked by the __call__ method
        when not in a Flyte task execution context. Subclasses are expected to implement it.

        :param args: Positional arguments for the task.
        :param kwargs: Keyword arguments for the task.
        :return: The result of the local execution.
        """
        raise NotImplementedError

    def override(
        self,
        *,
        short_name: Optional[str] = None,
        resources: Optional[Resources] = None,
        cache: Optional[CacheRequest] = None,
        retries: Union[int, RetryStrategy] = 0,
        timeout: Optional[TimeoutType] = None,
        reusable: Union[ReusePolicy, Literal["off"], None] = None,
        env_vars: Optional[Dict[str, str]] = None,
        secrets: Optional[SecretRequest] = None,
        max_inline_io_bytes: int | None = None,
        pod_template: Optional[Union[str, PodTemplate]] = None,
        queue: Optional[str] = None,
        interruptible: Optional[bool] = None,
        **kwargs: Any,
    ) -> TaskTemplate:
        """
        Override various parameters of the task template. This allows for dynamic configuration of the task
        when it is called, such as changing the resources, cache policy, etc.

        Any option that is not supplied keeps the current value of the template. Note that most options fall
        back on falsiness, so e.g. ``retries=0`` keeps the current retry strategy rather than disabling retries.

        :param short_name: Optional override for the short name of the task.
        :param resources: Optional override for the resources to use for the task.
        :param cache: Optional override for the cache policy for the task.
        :param retries: Optional override for the number of retries for the task.
        :param timeout: Optional override for the timeout for the task.
        :param reusable: Optional override for the reusability policy for the task. Pass "off" to disable reuse.
        :param env_vars: Optional override for the environment variables to set for the task.
        :param secrets: Optional override for the secrets that will be injected into the task at runtime.
        :param max_inline_io_bytes: Optional override for the maximum allowed size (in bytes) for all inputs and outputs
            passed directly to the task.
        :param pod_template: Optional override for the pod template to use for the task.
        :param queue: Optional override for the queue to use for the task.
        :param interruptible: Optional override for the interruptible policy for the task.
        :param kwargs: Additional keyword arguments for further overrides. Some fields like name, image, docs,
            and interface cannot be overridden.

        :return: A new TaskTemplate instance with the overridden parameters.
        :raises ValueError: If resources, env_vars or secrets are overridden while the task is reusable, or if
            name, image, docs or interface is passed in kwargs.
        """
        cache = cache or self.cache
        retries = retries or self.retries
        timeout = timeout or self.timeout
        max_inline_io_bytes = max_inline_io_bytes or self.max_inline_io_bytes

        reusable = reusable or self.reusable
        if reusable == "off":
            reusable = None

        if reusable is not None:
            if resources is not None:
                raise ValueError(
                    "Cannot override resources when reusable is set."
                    " Reusable tasks will use the parent env's resources. You can disable reusability and"
                    " override resources if needed. (set reusable='off')"
                )
            if env_vars is not None:
                raise ValueError(
                    "Cannot override env when reusable is set."
                    " Reusable tasks will use the parent env's env. You can disable reusability and "
                    "override env if needed. (set reusable='off')"
                )
            if secrets is not None:
                raise ValueError(
                    "Cannot override secrets when reusable is set."
                    " Reusable tasks will use the parent env's secrets. You can disable reusability and "
                    "override secrets if needed. (set reusable='off')"
                )

        resources = resources or self.resources
        env_vars = env_vars or self.env_vars
        secrets = secrets or self.secrets

        interruptible = interruptible if interruptible is not None else self.interruptible
        # Keep the current pod template unless a new one is given; passing the raw argument through would
        # silently reset it to None on every override that does not mention pod_template.
        pod_template = pod_template if pod_template is not None else self.pod_template

        # Identity-defining fields must not be changed through the generic kwargs escape hatch.
        protected = {"name": "Name", "image": "Image", "docs": "Docs", "interface": "Interface"}
        for key in kwargs:
            if key in protected:
                raise ValueError(f"{protected[key]} cannot be overridden")

        return replace(
            self,
            short_name=short_name or self.short_name,
            resources=resources,
            cache=cache,
            retries=retries,
            timeout=timeout,
            reusable=cast(Optional[ReusePolicy], reusable),
            env_vars=env_vars,
            secrets=secrets,
            max_inline_io_bytes=max_inline_io_bytes,
            pod_template=pod_template,
            interruptible=interruptible,
            queue=queue or self.queue,
            **kwargs,
        )
443
+
444
+
445
@dataclass(kw_only=True)
class AsyncFunctionTaskTemplate(TaskTemplate[P, R, F]):
    """
    Task template wrapping a Python function. It is created automatically when a function is decorated with
    the task decorator. Functions that are not coroutine functions are flagged to be called synchronously.
    """

    func: F
    plugin_config: Optional[Any] = None  # This is used to pass plugin specific configuration
    debuggable: bool = True

    def __post_init__(self):
        """Run the base normalization, then flag non-coroutine functions as synchronous."""
        super().__post_init__()
        if iscoroutinefunction(self.func):
            return
        self._call_as_synchronous = True

    @property
    def source_file(self) -> Optional[str]:
        """
        Returns the source file of the function, if available. This is useful for debugging and tracing.
        """
        code = getattr(self.func, "__code__", None)
        if not code:
            return None
        return code.co_filename

    def forward(self, *args: P.args, **kwargs: P.kwargs) -> Coroutine[Any, Any, R] | R:
        """
        Local execution: call the wrapped function directly.

        Nothing is awaited here. For a coroutine function the coroutine itself is returned, and the ``await``
        the caller puts in front of the task invocation drives it.
        """
        result = self.func(*args, **kwargs)
        return result

    async def execute(self, *args: P.args, **kwargs: P.kwargs) -> R:
        """
        This is the execute method that will be called when the task is invoked. It will call the actual function.
        # TODO We may need to keep this as the bare func execute, and need a pre and post execute some other func.
        """
        ctx = internal_ctx()
        assert ctx.data.task_context is not None, "Function should have already returned if not in a task context"

        # The dict returned by pre() becomes the `data` of the task context used while the function runs.
        pre_data = await self.pre(*args, **kwargs)
        task_ctx = ctx.data.task_context.replace(data=pre_data)
        with ctx.replace_task_context(task_ctx):
            v = self.func(*args, **kwargs) if not iscoroutinefunction(self.func) else await self.func(*args, **kwargs)
            await self.post(v)
        return v

    def container_args(self, serialize_context: SerializationContext) -> List[str]:
        """
        Build the container command-line arguments for this task.

        Starts from the fixed runtime arguments, then appends the image cache, the code bundle location and,
        unless a pickled code bundle is used, the default task resolver with its loader arguments.

        :raises RuntimeSystemError: If the default resolver is needed but the serialization context has no root dir.
        """
        args = [
            "a0",
            "--inputs",
            serialize_context.input_path,
            "--outputs-path",
            serialize_context.output_path,
            "--version",
            serialize_context.version,  # pr: should this be serialize_context.version or code_bundle.version?
            "--raw-data-path",
            "{{.rawOutputDataPrefix}}",
            "--checkpoint-path",
            "{{.checkpointOutputPrefix}}",
            "--prev-checkpoint",
            "{{.prevCheckpointPrefix}}",
            "--run-name",
            "{{.runName}}",
            "--name",
            "{{.actionName}}",
        ]

        # Add on all the known images
        if serialize_context.image_cache and serialize_context.image_cache.serialized_form:
            args.extend(["--image-cache", serialize_context.image_cache.serialized_form])
        elif serialize_context.image_cache:
            args.extend(["--image-cache", serialize_context.image_cache.to_transport])

        if serialize_context.code_bundle:
            if serialize_context.code_bundle.tgz:
                args.extend(["--tgz", f"{serialize_context.code_bundle.tgz}"])
            elif serialize_context.code_bundle.pkl:
                args.extend(["--pkl", f"{serialize_context.code_bundle.pkl}"])
            args.extend(["--dest", f"{serialize_context.code_bundle.destination or '.'}"])

        if not serialize_context.code_bundle or not serialize_context.code_bundle.pkl:
            # If we do not have a code bundle, or if we have one, but it is not a pkl, we need to add the resolver

            from flyte._internal.resolvers.default import DefaultTaskResolver

            if not serialize_context.root_dir:
                raise RuntimeSystemError(
                    "SerializationError",
                    "Root dir is required for default task resolver when no code bundle is provided.",
                )
            _task_resolver = DefaultTaskResolver()
            args.append("--resolver")
            args.append(_task_resolver.import_path)
            args.extend(_task_resolver.loader_args(task=self, root_dir=serialize_context.root_dir))

        assert all(isinstance(item, str) for item in args), f"All args should be strings, non string item = {args}"

        return args