flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211) hide show
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/debug.py +38 -0
  3. flyte/_bin/runtime.py +62 -8
  4. flyte/_cache/cache.py +4 -2
  5. flyte/_cache/local_cache.py +216 -0
  6. flyte/_code_bundle/_ignore.py +12 -4
  7. flyte/_code_bundle/_packaging.py +13 -9
  8. flyte/_code_bundle/_utils.py +18 -10
  9. flyte/_code_bundle/bundle.py +17 -9
  10. flyte/_constants.py +1 -0
  11. flyte/_context.py +4 -1
  12. flyte/_custom_context.py +73 -0
  13. flyte/_debug/constants.py +38 -0
  14. flyte/_debug/utils.py +17 -0
  15. flyte/_debug/vscode.py +307 -0
  16. flyte/_deploy.py +235 -61
  17. flyte/_environment.py +20 -6
  18. flyte/_excepthook.py +1 -1
  19. flyte/_hash.py +1 -16
  20. flyte/_image.py +178 -81
  21. flyte/_initialize.py +132 -51
  22. flyte/_interface.py +39 -2
  23. flyte/_internal/controllers/__init__.py +4 -5
  24. flyte/_internal/controllers/_local_controller.py +70 -29
  25. flyte/_internal/controllers/_trace.py +1 -1
  26. flyte/_internal/controllers/remote/__init__.py +0 -2
  27. flyte/_internal/controllers/remote/_action.py +14 -16
  28. flyte/_internal/controllers/remote/_client.py +1 -1
  29. flyte/_internal/controllers/remote/_controller.py +68 -70
  30. flyte/_internal/controllers/remote/_core.py +127 -99
  31. flyte/_internal/controllers/remote/_informer.py +19 -10
  32. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  33. flyte/_internal/imagebuild/docker_builder.py +181 -69
  34. flyte/_internal/imagebuild/image_builder.py +0 -5
  35. flyte/_internal/imagebuild/remote_builder.py +155 -64
  36. flyte/_internal/imagebuild/utils.py +51 -2
  37. flyte/_internal/resolvers/_task_module.py +5 -38
  38. flyte/_internal/resolvers/default.py +2 -2
  39. flyte/_internal/runtime/convert.py +110 -21
  40. flyte/_internal/runtime/entrypoints.py +27 -1
  41. flyte/_internal/runtime/io.py +21 -8
  42. flyte/_internal/runtime/resources_serde.py +20 -6
  43. flyte/_internal/runtime/reuse.py +1 -1
  44. flyte/_internal/runtime/rusty.py +20 -5
  45. flyte/_internal/runtime/task_serde.py +34 -19
  46. flyte/_internal/runtime/taskrunner.py +22 -4
  47. flyte/_internal/runtime/trigger_serde.py +160 -0
  48. flyte/_internal/runtime/types_serde.py +1 -1
  49. flyte/_keyring/__init__.py +0 -0
  50. flyte/_keyring/file.py +115 -0
  51. flyte/_logging.py +201 -39
  52. flyte/_map.py +111 -14
  53. flyte/_module.py +70 -0
  54. flyte/_pod.py +4 -3
  55. flyte/_resources.py +213 -31
  56. flyte/_run.py +110 -39
  57. flyte/_task.py +75 -16
  58. flyte/_task_environment.py +105 -29
  59. flyte/_task_plugins.py +4 -2
  60. flyte/_trace.py +5 -0
  61. flyte/_trigger.py +1000 -0
  62. flyte/_utils/__init__.py +2 -1
  63. flyte/_utils/asyn.py +3 -1
  64. flyte/_utils/coro_management.py +2 -1
  65. flyte/_utils/docker_credentials.py +173 -0
  66. flyte/_utils/module_loader.py +17 -2
  67. flyte/_version.py +3 -3
  68. flyte/cli/_abort.py +3 -3
  69. flyte/cli/_build.py +3 -6
  70. flyte/cli/_common.py +78 -7
  71. flyte/cli/_create.py +182 -4
  72. flyte/cli/_delete.py +23 -1
  73. flyte/cli/_deploy.py +63 -16
  74. flyte/cli/_get.py +79 -34
  75. flyte/cli/_params.py +26 -10
  76. flyte/cli/_plugins.py +209 -0
  77. flyte/cli/_run.py +151 -26
  78. flyte/cli/_serve.py +64 -0
  79. flyte/cli/_update.py +37 -0
  80. flyte/cli/_user.py +17 -0
  81. flyte/cli/main.py +30 -4
  82. flyte/config/_config.py +10 -6
  83. flyte/config/_internal.py +1 -0
  84. flyte/config/_reader.py +29 -8
  85. flyte/connectors/__init__.py +11 -0
  86. flyte/connectors/_connector.py +270 -0
  87. flyte/connectors/_server.py +197 -0
  88. flyte/connectors/utils.py +135 -0
  89. flyte/errors.py +22 -2
  90. flyte/extend.py +8 -1
  91. flyte/extras/_container.py +6 -1
  92. flyte/git/__init__.py +3 -0
  93. flyte/git/_config.py +21 -0
  94. flyte/io/__init__.py +2 -0
  95. flyte/io/_dataframe/__init__.py +2 -0
  96. flyte/io/_dataframe/basic_dfs.py +17 -8
  97. flyte/io/_dataframe/dataframe.py +98 -132
  98. flyte/io/_dir.py +575 -113
  99. flyte/io/_file.py +582 -139
  100. flyte/io/_hashing_io.py +342 -0
  101. flyte/models.py +74 -15
  102. flyte/remote/__init__.py +6 -1
  103. flyte/remote/_action.py +34 -26
  104. flyte/remote/_client/_protocols.py +39 -4
  105. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  106. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  107. flyte/remote/_client/auth/_channel.py +10 -6
  108. flyte/remote/_client/controlplane.py +17 -5
  109. flyte/remote/_console.py +3 -2
  110. flyte/remote/_data.py +6 -6
  111. flyte/remote/_logs.py +3 -3
  112. flyte/remote/_run.py +64 -8
  113. flyte/remote/_secret.py +26 -17
  114. flyte/remote/_task.py +75 -33
  115. flyte/remote/_trigger.py +306 -0
  116. flyte/remote/_user.py +33 -0
  117. flyte/report/_report.py +1 -1
  118. flyte/storage/__init__.py +6 -1
  119. flyte/storage/_config.py +5 -1
  120. flyte/storage/_parallel_reader.py +274 -0
  121. flyte/storage/_storage.py +200 -103
  122. flyte/types/__init__.py +16 -0
  123. flyte/types/_interface.py +2 -2
  124. flyte/types/_pickle.py +35 -8
  125. flyte/types/_string_literals.py +8 -9
  126. flyte/types/_type_engine.py +40 -70
  127. flyte/types/_utils.py +1 -1
  128. flyte-2.0.0b30.data/scripts/debug.py +38 -0
  129. {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
  130. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
  131. flyte-2.0.0b30.dist-info/RECORD +192 -0
  132. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
  133. flyte/_protos/common/authorization_pb2.py +0 -66
  134. flyte/_protos/common/authorization_pb2.pyi +0 -108
  135. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  136. flyte/_protos/common/identifier_pb2.py +0 -93
  137. flyte/_protos/common/identifier_pb2.pyi +0 -110
  138. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  139. flyte/_protos/common/identity_pb2.py +0 -48
  140. flyte/_protos/common/identity_pb2.pyi +0 -72
  141. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  142. flyte/_protos/common/list_pb2.py +0 -36
  143. flyte/_protos/common/list_pb2.pyi +0 -71
  144. flyte/_protos/common/list_pb2_grpc.py +0 -4
  145. flyte/_protos/common/policy_pb2.py +0 -37
  146. flyte/_protos/common/policy_pb2.pyi +0 -27
  147. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  148. flyte/_protos/common/role_pb2.py +0 -37
  149. flyte/_protos/common/role_pb2.pyi +0 -53
  150. flyte/_protos/common/role_pb2_grpc.py +0 -4
  151. flyte/_protos/common/runtime_version_pb2.py +0 -28
  152. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  153. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  154. flyte/_protos/imagebuilder/definition_pb2.py +0 -59
  155. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
  156. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  157. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  158. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  159. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  160. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  161. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  162. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  163. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  164. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  165. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  166. flyte/_protos/secret/definition_pb2.py +0 -49
  167. flyte/_protos/secret/definition_pb2.pyi +0 -93
  168. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  169. flyte/_protos/secret/payload_pb2.py +0 -62
  170. flyte/_protos/secret/payload_pb2.pyi +0 -94
  171. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  172. flyte/_protos/secret/secret_pb2.py +0 -38
  173. flyte/_protos/secret/secret_pb2.pyi +0 -6
  174. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  175. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  176. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  177. flyte/_protos/workflow/common_pb2.py +0 -27
  178. flyte/_protos/workflow/common_pb2.pyi +0 -14
  179. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  180. flyte/_protos/workflow/environment_pb2.py +0 -29
  181. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  182. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  183. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  184. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  185. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  186. flyte/_protos/workflow/queue_service_pb2.py +0 -109
  187. flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
  188. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  189. flyte/_protos/workflow/run_definition_pb2.py +0 -121
  190. flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
  191. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  192. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  193. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  194. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  195. flyte/_protos/workflow/run_service_pb2.py +0 -137
  196. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  197. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  198. flyte/_protos/workflow/state_service_pb2.py +0 -67
  199. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  200. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  201. flyte/_protos/workflow/task_definition_pb2.py +0 -79
  202. flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
  203. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  204. flyte/_protos/workflow/task_service_pb2.py +0 -60
  205. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  206. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  207. flyte-2.0.0b13.dist-info/RECORD +0 -239
  208. /flyte/{_protos → _debug}/__init__.py +0 -0
  209. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  210. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  211. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/_initialize.py CHANGED
@@ -3,14 +3,14 @@ from __future__ import annotations
3
3
  import functools
4
4
  import threading
5
5
  import typing
6
- from dataclasses import dataclass, replace
6
+ from dataclasses import dataclass, field, replace
7
7
  from pathlib import Path
8
8
  from typing import TYPE_CHECKING, Callable, List, Literal, Optional, TypeVar
9
9
 
10
10
  from flyte.errors import InitializationError
11
11
  from flyte.syncify import syncify
12
12
 
13
- from ._logging import initialize_logger, logger
13
+ from ._logging import LogFormat, initialize_logger, logger
14
14
 
15
15
  if TYPE_CHECKING:
16
16
  from flyte._internal.imagebuild import ImageBuildEngine
@@ -33,6 +33,8 @@ class CommonInit:
33
33
  project: str | None = None
34
34
  domain: str | None = None
35
35
  batch_size: int = 1000
36
+ source_config_path: Optional[Path] = None # Only used for documentation
37
+ sync_local_sys_paths: bool = True
36
38
 
37
39
 
38
40
  @dataclass(init=True, kw_only=True, repr=True, eq=True, frozen=True)
@@ -40,6 +42,7 @@ class _InitConfig(CommonInit):
40
42
  client: Optional[ClientSet] = None
41
43
  storage: Optional[Storage] = None
42
44
  image_builder: "ImageBuildEngine.ImageBuilderType" = "local"
45
+ images: typing.Dict[str, str] = field(default_factory=dict)
43
46
 
44
47
  def replace(self, **kwargs) -> _InitConfig:
45
48
  return replace(self, **kwargs)
@@ -110,6 +113,10 @@ async def _initialize_client(
110
113
  )
111
114
 
112
115
 
116
+ def _initialize_logger(log_level: int | None = None, log_format: LogFormat | None = None) -> None:
117
+ initialize_logger(log_level=log_level, log_format=log_format, enable_rich=True)
118
+
119
+
113
120
  @syncify
114
121
  async def init(
115
122
  org: str | None = None,
@@ -117,6 +124,7 @@ async def init(
117
124
  domain: str | None = None,
118
125
  root_dir: Path | None = None,
119
126
  log_level: int | None = None,
127
+ log_format: LogFormat | None = None,
120
128
  endpoint: str | None = None,
121
129
  headless: bool = False,
122
130
  insecure: bool = False,
@@ -134,6 +142,10 @@ async def init(
134
142
  storage: Storage | None = None,
135
143
  batch_size: int = 1000,
136
144
  image_builder: ImageBuildEngine.ImageBuilderType = "local",
145
+ images: typing.Dict[str, str] | None = None,
146
+ source_config_path: Optional[Path] = None,
147
+ sync_local_sys_paths: bool = True,
148
+ load_plugin_type_transformers: bool = True,
137
149
  ) -> None:
138
150
  """
139
151
  Initialize the Flyte system with the given configuration. This method should be called before any other Flyte
@@ -146,6 +158,7 @@ async def init(
146
158
  also use to determine all the code that needs to be copied to the remote location.
147
159
  defaults to the editable install directory if the cwd is in a Python editable install, else just the cwd.
148
160
  :param log_level: Optional logging level for the logger, default is set using the default initialization policies
161
+ :param log_format: Optional logging format for the logger, default is "console"
149
162
  :param api_key: Optional API key for authentication
150
163
  :param endpoint: Optional API endpoint URL
151
164
  :param headless: Optional Whether to run in headless mode
@@ -162,24 +175,26 @@ async def init(
162
175
  :param ca_cert_file_path: [optional] str Root Cert to be loaded and used to verify admin
163
176
  :param http_proxy_url: [optional] HTTP Proxy to be used for OAuth requests
164
177
  :param rpc_retries: [optional] int Number of times to retry the platform calls
165
- :param audience: oauth2 audience for the token request. This is used to validate the token
166
178
  :param insecure: insecure flag for the client
167
179
  :param storage: Optional blob store (S3, GCS, Azure) configuration if needed to access (i.e. using Minio)
168
180
  :param org: Optional organization override for the client. Should be set by auth instead.
169
181
  :param batch_size: Optional batch size for operations that use listings, defaults to 1000, so limit larger than
170
182
  batch_size will be split into multiple requests.
171
183
  :param image_builder: Optional image builder configuration, if not provided, the default image builder will be used.
172
-
184
+ :param images: Optional dict of images that can be used by referencing the image name.
185
+ :param source_config_path: Optional path to the source configuration file (This is only used for documentation)
186
+ :param sync_local_sys_paths: Whether to include and synchronize local sys.path entries under the root directory
187
+ into the remote container (default: True).
188
+ :param load_plugin_type_transformers: If enabled (default True), load the type transformer plugins registered under
189
+ the "flyte.plugins.types" entry point group.
173
190
  :return: None
174
191
  """
175
- from flyte._tools import ipython_check
176
192
  from flyte._utils import get_cwd_editable_install, org_from_endpoint, sanitize_endpoint
193
+ from flyte.types import _load_custom_type_transformers
177
194
 
178
- interactive_mode = ipython_check()
179
-
180
- initialize_logger(enable_rich=interactive_mode)
181
- if log_level:
182
- initialize_logger(log_level=log_level, enable_rich=interactive_mode)
195
+ _initialize_logger(log_level=log_level, log_format=log_format)
196
+ if load_plugin_type_transformers:
197
+ _load_custom_type_transformers()
183
198
 
184
199
  global _init_config # noqa: PLW0603
185
200
 
@@ -205,7 +220,15 @@ async def init(
205
220
  http_proxy_url=http_proxy_url,
206
221
  )
207
222
 
208
- root_dir = root_dir or get_cwd_editable_install() or Path.cwd()
223
+ if not root_dir:
224
+ editable_root = get_cwd_editable_install()
225
+ if editable_root:
226
+ logger.info(f"Using editable install as root directory: {editable_root}")
227
+ root_dir = editable_root
228
+ else:
229
+ logger.info("No editable install found, using current working directory as root directory.")
230
+ root_dir = Path.cwd()
231
+
209
232
  _init_config = _InitConfig(
210
233
  root_dir=root_dir,
211
234
  project=project,
@@ -215,14 +238,21 @@ async def init(
215
238
  org=org or org_from_endpoint(endpoint),
216
239
  batch_size=batch_size,
217
240
  image_builder=image_builder,
241
+ images=images or {},
242
+ source_config_path=source_config_path,
243
+ sync_local_sys_paths=sync_local_sys_paths,
218
244
  )
219
245
 
220
246
 
221
247
  @syncify
222
248
  async def init_from_config(
223
- path_or_config: str | Config | None = None,
249
+ path_or_config: str | Path | Config | None = None,
224
250
  root_dir: Path | None = None,
225
251
  log_level: int | None = None,
252
+ log_format: LogFormat = "console",
253
+ storage: Storage | None = None,
254
+ images: tuple[str, ...] | None = None,
255
+ sync_local_sys_paths: bool = True,
226
256
  ) -> None:
227
257
  """
228
258
  Initialize the Flyte system using a configuration file or Config object. This method should be called before any
@@ -235,29 +265,43 @@ async def init_from_config(
235
265
  if not available, the current working directory.
236
266
  :param log_level: Optional logging level for the framework logger,
237
267
  default is set using the default initialization policies
268
+ :param log_format: Optional logging format for the logger, default is "console"
269
+ :param storage: Optional blob store (S3, GCS, Azure) configuration if needed to access (i.e. using Minio)
270
+ :param images: List of image strings in format "imagename=imageuri" or just "imageuri".
271
+ :param sync_local_sys_paths: Whether to include and synchronize local sys.path entries under the root directory
272
+ into the remote container (default: True).
238
273
  :return: None
239
274
  """
275
+ from rich.highlighter import ReprHighlighter
276
+
240
277
  import flyte.config as config
278
+ from flyte.cli._common import parse_images
241
279
 
242
280
  cfg: config.Config
243
- if path_or_config is None or isinstance(path_or_config, str):
244
- # If a string is passed, treat it as a path to the config file
245
- if path_or_config:
246
- if not Path(path_or_config).exists():
247
- raise InitializationError(
248
- "ConfigFileNotFoundError",
249
- "user",
250
- f"Configuration file '{path_or_config}' does not exist., current working directory is {Path.cwd()}",
251
- )
252
- if root_dir and path_or_config:
253
- cfg = config.auto(str(root_dir / path_or_config))
281
+ cfg_path: Optional[Path] = None
282
+ if path_or_config is None:
283
+ # If no path is provided, use the default config file
284
+ cfg = config.auto()
285
+ elif isinstance(path_or_config, (str, Path)):
286
+ if root_dir:
287
+ cfg_path = root_dir.expanduser() / path_or_config
254
288
  else:
255
- cfg = config.auto(path_or_config)
289
+ cfg_path = Path(path_or_config).expanduser()
290
+ if not Path(cfg_path).exists():
291
+ raise InitializationError(
292
+ "ConfigFileNotFoundError",
293
+ "user",
294
+ f"Configuration file '{cfg_path}' does not exist., current working directory is {Path.cwd()}",
295
+ )
296
+ cfg = config.auto(cfg_path)
256
297
  else:
257
- # If a Config object is passed, use it directly
258
298
  cfg = path_or_config
259
299
 
260
- logger.debug(f"Flyte config initialized as {cfg}")
300
+ logger.info(f"Flyte config initialized as {cfg}", extra={"highlighter": ReprHighlighter()})
301
+
302
+ # parse image, this will overwrite the image_refs set in the config file
303
+ parse_images(cfg, images)
304
+
261
305
  await init.aio(
262
306
  org=cfg.task.org,
263
307
  project=cfg.task.project,
@@ -273,7 +317,12 @@ async def init_from_config(
273
317
  client_credentials_secret=cfg.platform.client_credentials_secret,
274
318
  root_dir=root_dir,
275
319
  log_level=log_level,
320
+ log_format=log_format,
276
321
  image_builder=cfg.image.builder,
322
+ images=cfg.image.image_refs,
323
+ storage=storage,
324
+ source_config_path=cfg_path,
325
+ sync_local_sys_paths=sync_local_sys_paths,
277
326
  )
278
327
 
279
328
 
@@ -287,7 +336,7 @@ def _get_init_config() -> Optional[_InitConfig]:
287
336
  return _init_config
288
337
 
289
338
 
290
- def get_common_config() -> CommonInit:
339
+ def get_init_config() -> _InitConfig:
291
340
  """
292
341
  Get the current initialization configuration. Thread-safe implementation.
293
342
 
@@ -374,30 +423,6 @@ def ensure_client():
374
423
  )
375
424
 
376
425
 
377
- def requires_client(func: T) -> T:
378
- """
379
- Decorator that checks if the client has been initialized before executing the function.
380
- Raises InitializationError if the client is not initialized.
381
-
382
- :param func: Function to decorate
383
- :return: Decorated function that checks for initialization
384
- """
385
-
386
- @functools.wraps(func)
387
- async def wrapper(*args, **kwargs) -> T:
388
- init_config = _get_init_config()
389
- if init_config is None or init_config.client is None:
390
- raise InitializationError(
391
- "ClientNotInitializedError",
392
- "user",
393
- f"Function '{func.__name__}' requires client to be initialized. "
394
- f"Call flyte.init() with a valid endpoint or api-key before using this function.",
395
- )
396
- return func(*args, **kwargs)
397
-
398
- return typing.cast(T, wrapper)
399
-
400
-
401
426
  def requires_storage(func: T) -> T:
402
427
  """
403
428
  Decorator that checks if the storage has been initialized before executing the function.
@@ -469,6 +494,34 @@ def requires_initialization(func: T) -> T:
469
494
  return typing.cast(T, wrapper)
470
495
 
471
496
 
497
+ def require_project_and_domain(func):
498
+ """
499
+ Decorator that ensures the current Flyte configuration defines
500
+ both 'project' and 'domain'. Raises a clear error if not found.
501
+ """
502
+
503
+ @functools.wraps(func)
504
+ def wrapper(*args, **kwargs):
505
+ cfg = get_init_config()
506
+ if cfg.project is None:
507
+ raise ValueError(
508
+ "Project must be provided to initialize the client. "
509
+ "Please set 'project' in the 'task' section of your config file, "
510
+ "or pass it directly to flyte.init(project='your-project-name')."
511
+ )
512
+
513
+ if cfg.domain is None:
514
+ raise ValueError(
515
+ "Domain must be provided to initialize the client. "
516
+ "Please set 'domain' in the 'task' section of your config file, "
517
+ "or pass it directly to flyte.init(domain='your-domain-name')."
518
+ )
519
+
520
+ return func(*args, **kwargs)
521
+
522
+ return wrapper
523
+
524
+
472
525
  async def _init_for_testing(
473
526
  project: str | None = None,
474
527
  domain: str | None = None,
@@ -498,3 +551,31 @@ def replace_client(client):
498
551
 
499
552
  with _init_lock:
500
553
  _init_config = _init_config.replace(client=client)
554
+
555
+
556
+ def current_domain() -> str:
557
+ """
558
+ Returns the current domain from Runtime environment (on the cluster) or from the initialized configuration.
559
+ This is safe to be used during `deploy`, `run` and within `task` code.
560
+
561
+ NOTE: This will not work if you deploy a task to a domain and then run it in another domain.
562
+
563
+ Raises InitializationError if the configuration is not initialized or domain is not set.
564
+ :return: The current domain
565
+ """
566
+ from ._context import ctx
567
+
568
+ tctx = ctx()
569
+ if tctx is not None:
570
+ domain = tctx.action.domain
571
+ if domain is not None:
572
+ return domain
573
+
574
+ cfg = _get_init_config()
575
+ if cfg is None or cfg.domain is None:
576
+ raise InitializationError(
577
+ "DomainNotInitializedError",
578
+ "user",
579
+ "Domain has not been initialized. Call flyte.init() with a valid domain before using this function.",
580
+ )
581
+ return cfg.domain
flyte/_interface.py CHANGED
@@ -1,7 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- from typing import Dict, Generator, Tuple, Type, TypeVar, Union, cast, get_args, get_type_hints
4
+ import typing
5
+ from enum import Enum
6
+ from typing import Dict, Generator, Literal, Tuple, Type, TypeVar, Union, cast, get_args, get_origin, get_type_hints
7
+
8
+ from flyte._logging import logger
9
+
10
+ LITERAL_ENUM = "LiteralEnum"
5
11
 
6
12
 
7
13
  def default_output_name(index: int = 0) -> str:
@@ -45,6 +51,8 @@ def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Di
45
51
  Note that Options 1 and 2 are identical, just syntactic sugar. In the NamedTuple case, we'll use the names in the
46
52
  definition. In all other cases, we'll automatically generate output names, indexed starting at 0.
47
53
  """
54
+ if isinstance(return_annotation, str):
55
+ raise TypeError("String return annotations are not supported.")
48
56
 
49
57
  # Handle Option 6
50
58
  # We can think about whether we should add a default output name with type None in the future.
@@ -69,7 +77,15 @@ def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Di
69
77
  if len(return_annotation.__args__) == 1: # type: ignore
70
78
  raise TypeError("Tuples should be used to indicate multiple return values, found only one return variable.")
71
79
  ra = get_args(return_annotation)
72
- return dict(zip(list(output_name_generator(len(ra))), ra))
80
+ annotations = {}
81
+ for i, r in enumerate(ra):
82
+ if r is Ellipsis:
83
+ raise TypeError("Variable length tuples are not supported as return types.")
84
+ if get_origin(r) is Literal:
85
+ annotations[default_output_name(i)] = literal_to_enum(cast(Type, r))
86
+ else:
87
+ annotations[default_output_name(i)] = r
88
+ return annotations
73
89
 
74
90
  elif isinstance(return_annotation, tuple):
75
91
  if len(return_annotation) == 1:
@@ -79,4 +95,25 @@ def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Di
79
95
  else:
80
96
  # Handle all other single return types
81
97
  # Task returns unnamed native tuple
98
+ if get_origin(return_annotation) is Literal:
99
+ return {default_output_name(): literal_to_enum(cast(Type, return_annotation))}
82
100
  return {default_output_name(): cast(Type, return_annotation)}
101
+
102
+
103
+ def literal_to_enum(literal_type: Type) -> Type[Enum | typing.Any]:
104
+ """Convert a Literal[...] into Union[str, Enum]."""
105
+
106
+ if get_origin(literal_type) is not Literal:
107
+ raise TypeError(f"{literal_type} is not a Literal")
108
+
109
+ values = get_args(literal_type)
110
+ if not all(isinstance(v, str) for v in values):
111
+ logger.warning(f"Literal type {literal_type} contains non-string values, using Any instead of Enum")
112
+ return typing.Any
113
+ # Deduplicate & keep order
114
+ enum_dict = {str(v).upper(): v for v in values}
115
+
116
+ # Dynamically create an Enum
117
+ literal_enum = Enum(LITERAL_ENUM, enum_dict) # type: ignore
118
+
119
+ return literal_enum # type: ignore
@@ -5,12 +5,13 @@ from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Protocol, Tu
5
5
  from flyte._task import TaskTemplate
6
6
  from flyte.models import ActionID, NativeInterface
7
7
 
8
+ if TYPE_CHECKING:
9
+ from flyte.remote._task import TaskDetails
10
+
8
11
  from ._trace import TraceInfo
9
12
 
10
13
  __all__ = ["Controller", "ControllerType", "TraceInfo", "create_controller", "get_controller"]
11
14
 
12
- from ..._protos.workflow import task_definition_pb2
13
-
14
15
  if TYPE_CHECKING:
15
16
  import concurrent.futures
16
17
 
@@ -41,9 +42,7 @@ class Controller(Protocol):
41
42
  """
42
43
  ...
43
44
 
44
- async def submit_task_ref(
45
- self, _task: task_definition_pb2.TaskDetails, max_inline_io_bytes: int, *args, **kwargs
46
- ) -> Any:
45
+ async def submit_task_ref(self, _task: "TaskDetails", *args, **kwargs) -> Any:
47
46
  """
48
47
  Submit a task reference to the controller asynchronously and wait for the result. This is async and will block
49
48
  the current coroutine until the result is available.
@@ -2,19 +2,23 @@ import asyncio
2
2
  import atexit
3
3
  import concurrent.futures
4
4
  import os
5
+ import pathlib
5
6
  import threading
6
7
  from typing import Any, Callable, Tuple, TypeVar
7
8
 
8
9
  import flyte.errors
10
+ from flyte._cache.cache import VersionParameters, cache_from_request
11
+ from flyte._cache.local_cache import LocalTaskCache
9
12
  from flyte._context import internal_ctx
10
13
  from flyte._internal.controllers import TraceInfo
11
14
  from flyte._internal.runtime import convert
12
15
  from flyte._internal.runtime.entrypoints import direct_dispatch
16
+ from flyte._internal.runtime.types_serde import transform_native_to_typed_interface
13
17
  from flyte._logging import log, logger
14
- from flyte._protos.workflow import task_definition_pb2
15
- from flyte._task import TaskTemplate
18
+ from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
16
19
  from flyte._utils.helpers import _selector_policy
17
20
  from flyte.models import ActionID, NativeInterface
21
+ from flyte.remote._task import TaskDetails
18
22
 
19
23
  R = TypeVar("R")
20
24
 
@@ -81,31 +85,67 @@ class LocalController:
81
85
  raise flyte.errors.RuntimeSystemError("BadContext", "Task context not initialized")
82
86
 
83
87
  inputs = await convert.convert_from_native_to_inputs(_task.native_interface, *args, **kwargs)
84
- serialized_inputs = inputs.proto_inputs.SerializeToString(deterministic=True)
88
+ inputs_hash = convert.generate_inputs_hash_from_proto(inputs.proto_inputs)
89
+ task_interface = transform_native_to_typed_interface(_task.interface)
85
90
 
86
91
  sub_action_id, sub_action_output_path = convert.generate_sub_action_id_and_output_path(
87
- tctx, _task.name, serialized_inputs, 0
92
+ tctx, _task.name, inputs_hash, 0
88
93
  )
89
94
  sub_action_raw_data_path = tctx.raw_data_path
90
-
91
- out, err = await direct_dispatch(
92
- _task,
93
- controller=self,
94
- action=sub_action_id,
95
- raw_data_path=sub_action_raw_data_path,
96
- inputs=inputs,
97
- version=tctx.version,
98
- checkpoints=tctx.checkpoints,
99
- code_bundle=tctx.code_bundle,
100
- output_path=sub_action_output_path,
101
- run_base_dir=tctx.run_base_dir,
95
+ # Make sure the output path exists
96
+ pathlib.Path(sub_action_output_path).mkdir(parents=True, exist_ok=True)
97
+ pathlib.Path(sub_action_raw_data_path.path).mkdir(parents=True, exist_ok=True)
98
+
99
+ task_cache = cache_from_request(_task.cache)
100
+ cache_enabled = task_cache.is_enabled()
101
+ if isinstance(_task, AsyncFunctionTaskTemplate):
102
+ version_parameters = VersionParameters(func=_task.func, image=_task.image)
103
+ else:
104
+ version_parameters = VersionParameters(func=None, image=_task.image)
105
+ cache_version = task_cache.get_version(version_parameters)
106
+ cache_key = convert.generate_cache_key_hash(
107
+ _task.name,
108
+ inputs_hash,
109
+ task_interface,
110
+ cache_version,
111
+ list(task_cache.get_ignored_inputs()),
112
+ inputs.proto_inputs,
102
113
  )
103
- if err:
104
- exc = convert.convert_error_to_native(err)
105
- if exc:
106
- raise exc
107
- else:
108
- raise flyte.errors.RuntimeSystemError("BadError", "Unknown error")
114
+
115
+ out = None
116
+ # We only get output from cache if the cache behavior is set to auto
117
+ if task_cache.behavior == "auto":
118
+ out = await LocalTaskCache.get(cache_key)
119
+ if out is not None:
120
+ logger.info(
121
+ f"Cache hit for task '{_task.name}' (version: {cache_version}), getting result from cache..."
122
+ )
123
+
124
+ if out is None:
125
+ out, err = await direct_dispatch(
126
+ _task,
127
+ controller=self,
128
+ action=sub_action_id,
129
+ raw_data_path=sub_action_raw_data_path,
130
+ inputs=inputs,
131
+ version=cache_version,
132
+ checkpoints=tctx.checkpoints,
133
+ code_bundle=tctx.code_bundle,
134
+ output_path=sub_action_output_path,
135
+ run_base_dir=tctx.run_base_dir,
136
+ )
137
+
138
+ if err:
139
+ exc = convert.convert_error_to_native(err)
140
+ if exc:
141
+ raise exc
142
+ else:
143
+ raise flyte.errors.RuntimeSystemError("BadError", "Unknown error")
144
+
145
+ # store into cache
146
+ if cache_enabled and out is not None:
147
+ await LocalTaskCache.set(cache_key, out)
148
+
109
149
  if _task.native_interface.outputs:
110
150
  if out is None:
111
151
  raise flyte.errors.RuntimeSystemError("BadOutput", "Task output not captured.")
@@ -129,7 +169,7 @@ class LocalController:
129
169
  pass
130
170
 
131
171
  async def stop(self):
132
- pass
172
+ await LocalTaskCache.close()
133
173
 
134
174
  async def watch_for_errors(self):
135
175
  pass
@@ -146,16 +186,17 @@ class LocalController:
146
186
  tctx = ctx.data.task_context
147
187
  if not tctx:
148
188
  raise flyte.errors.NotInTaskContextError("BadContext", "Task context not initialized")
189
+
149
190
  converted_inputs = convert.Inputs.empty()
150
191
  if _interface.inputs:
151
192
  converted_inputs = await convert.convert_from_native_to_inputs(_interface, *args, **kwargs)
152
193
  assert converted_inputs
153
194
 
154
- serialized_inputs = converted_inputs.proto_inputs.SerializeToString(deterministic=True)
195
+ inputs_hash = convert.generate_inputs_hash_from_proto(converted_inputs.proto_inputs)
155
196
  action_id, action_output_path = convert.generate_sub_action_id_and_output_path(
156
197
  tctx,
157
198
  _func.__name__,
158
- serialized_inputs,
199
+ inputs_hash,
159
200
  0,
160
201
  )
161
202
  assert action_output_path
@@ -192,7 +233,7 @@ class LocalController:
192
233
  assert info.start_time
193
234
  assert info.end_time
194
235
 
195
- async def submit_task_ref(
196
- self, _task: task_definition_pb2.TaskDetails, max_inline_io_bytes: int, *args, **kwargs
197
- ) -> Any:
198
- raise flyte.errors.ReferenceTaskError("Reference tasks cannot be executed locally, only remotely.")
236
+ async def submit_task_ref(self, _task: TaskDetails, max_inline_io_bytes: int, *args, **kwargs) -> Any:
237
+ raise flyte.errors.ReferenceTaskError(
238
+ f"Reference tasks cannot be executed locally, only remotely. Found remote task {_task.name}"
239
+ )
@@ -1,7 +1,7 @@
1
1
  from dataclasses import dataclass, field
2
2
  from typing import Any, Optional
3
3
 
4
- from flyteidl.core import interface_pb2
4
+ from flyteidl2.core import interface_pb2
5
5
 
6
6
  from flyte.models import ActionID, NativeInterface
7
7
 
@@ -54,7 +54,5 @@ def create_remote_controller(
54
54
 
55
55
  controller = RemoteController(
56
56
  client_coro=client_coro,
57
- workers=10,
58
- max_system_retries=5,
59
57
  )
60
58
  return controller
@@ -1,18 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from dataclasses import dataclass
4
- from typing import Literal
4
+ from typing import Literal, Optional
5
5
 
6
- from flyteidl.core import execution_pb2, interface_pb2, tasks_pb2
7
- from google.protobuf import timestamp_pb2
8
-
9
- from flyte._protos.common import identifier_pb2
10
- from flyte._protos.workflow import (
11
- queue_service_pb2,
6
+ from flyteidl2.common import identifier_pb2
7
+ from flyteidl2.core import execution_pb2, interface_pb2
8
+ from flyteidl2.task import common_pb2, task_definition_pb2
9
+ from flyteidl2.workflow import (
12
10
  run_definition_pb2,
13
11
  state_service_pb2,
14
- task_definition_pb2,
15
12
  )
13
+ from google.protobuf import timestamp_pb2
14
+
16
15
  from flyte.models import GroupData
17
16
 
18
17
  ActionType = Literal["task", "trace"]
@@ -31,7 +30,7 @@ class Action:
31
30
  friendly_name: str | None = None
32
31
  group: GroupData | None = None
33
32
  task: task_definition_pb2.TaskSpec | None = None
34
- trace: queue_service_pb2.TraceAction | None = None
33
+ trace: run_definition_pb2.TraceAction | None = None
35
34
  inputs_uri: str | None = None
36
35
  run_output_base: str | None = None
37
36
  realized_outputs_uri: str | None = None
@@ -39,6 +38,7 @@ class Action:
39
38
  phase: run_definition_pb2.Phase | None = None
40
39
  started: bool = False
41
40
  retries: int = 0
41
+ queue: Optional[str] = None # The queue to which this action was submitted.
42
42
  client_err: Exception | None = None # This error is set when something goes wrong in the controller.
43
43
  cache_key: str | None = None # None means no caching, otherwise it is the version of the cache.
44
44
 
@@ -122,6 +122,7 @@ class Action:
122
122
  inputs_uri: str,
123
123
  run_output_base: str,
124
124
  cache_key: str | None = None,
125
+ queue: Optional[str] = None,
125
126
  ) -> Action:
126
127
  return cls(
127
128
  action_id=sub_action_id,
@@ -132,6 +133,7 @@ class Action:
132
133
  inputs_uri=inputs_uri,
133
134
  run_output_base=run_output_base,
134
135
  cache_key=cache_key,
136
+ queue=queue,
135
137
  )
136
138
 
137
139
  @classmethod
@@ -183,11 +185,7 @@ class Action:
183
185
  et.FromSeconds(int(end_time))
184
186
  et.nanos = int((end_time % 1) * 1e9)
185
187
 
186
- spec = (
187
- task_definition_pb2.TaskSpec(task_template=tasks_pb2.TaskTemplate(interface=typed_interface))
188
- if typed_interface
189
- else None
190
- )
188
+ spec = task_definition_pb2.TraceSpec(interface=typed_interface) if typed_interface else None
191
189
 
192
190
  return cls(
193
191
  action_id=action_id,
@@ -199,12 +197,12 @@ class Action:
199
197
  realized_outputs_uri=outputs_uri,
200
198
  phase=run_definition_pb2.Phase.PHASE_SUCCEEDED,
201
199
  run_output_base=run_output_base,
202
- trace=queue_service_pb2.TraceAction(
200
+ trace=run_definition_pb2.TraceAction(
203
201
  name=friendly_name,
204
202
  phase=run_definition_pb2.Phase.PHASE_SUCCEEDED,
205
203
  start_time=st,
206
204
  end_time=et,
207
- outputs=run_definition_pb2.OutputReferences(
205
+ outputs=common_pb2.OutputReferences(
208
206
  output_uri=outputs_uri,
209
207
  report_uri=report_uri,
210
208
  ),
@@ -1,8 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import grpc.aio
4
+ from flyteidl2.workflow import queue_service_pb2_grpc, state_service_pb2_grpc
4
5
 
5
- from flyte._protos.workflow import queue_service_pb2_grpc, state_service_pb2_grpc
6
6
  from flyte.remote import create_channel
7
7
 
8
8
  from ._service_protocol import QueueService, StateService