flyte-2.0.0b32-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (204)
  1. flyte/__init__.py +108 -0
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +195 -0
  5. flyte/_bin/serve.py +178 -0
  6. flyte/_build.py +26 -0
  7. flyte/_cache/__init__.py +12 -0
  8. flyte/_cache/cache.py +147 -0
  9. flyte/_cache/defaults.py +9 -0
  10. flyte/_cache/local_cache.py +216 -0
  11. flyte/_cache/policy_function_body.py +42 -0
  12. flyte/_code_bundle/__init__.py +8 -0
  13. flyte/_code_bundle/_ignore.py +121 -0
  14. flyte/_code_bundle/_packaging.py +218 -0
  15. flyte/_code_bundle/_utils.py +347 -0
  16. flyte/_code_bundle/bundle.py +266 -0
  17. flyte/_constants.py +1 -0
  18. flyte/_context.py +155 -0
  19. flyte/_custom_context.py +73 -0
  20. flyte/_debug/__init__.py +0 -0
  21. flyte/_debug/constants.py +38 -0
  22. flyte/_debug/utils.py +17 -0
  23. flyte/_debug/vscode.py +307 -0
  24. flyte/_deploy.py +408 -0
  25. flyte/_deployer.py +109 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +122 -0
  29. flyte/_excepthook.py +37 -0
  30. flyte/_group.py +32 -0
  31. flyte/_hash.py +8 -0
  32. flyte/_image.py +1055 -0
  33. flyte/_initialize.py +628 -0
  34. flyte/_interface.py +119 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +129 -0
  37. flyte/_internal/controllers/_local_controller.py +239 -0
  38. flyte/_internal/controllers/_trace.py +48 -0
  39. flyte/_internal/controllers/remote/__init__.py +58 -0
  40. flyte/_internal/controllers/remote/_action.py +211 -0
  41. flyte/_internal/controllers/remote/_client.py +47 -0
  42. flyte/_internal/controllers/remote/_controller.py +583 -0
  43. flyte/_internal/controllers/remote/_core.py +465 -0
  44. flyte/_internal/controllers/remote/_informer.py +381 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +3 -0
  47. flyte/_internal/imagebuild/docker_builder.py +706 -0
  48. flyte/_internal/imagebuild/image_builder.py +277 -0
  49. flyte/_internal/imagebuild/remote_builder.py +386 -0
  50. flyte/_internal/imagebuild/utils.py +78 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +21 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +486 -0
  57. flyte/_internal/runtime/entrypoints.py +204 -0
  58. flyte/_internal/runtime/io.py +188 -0
  59. flyte/_internal/runtime/resources_serde.py +152 -0
  60. flyte/_internal/runtime/reuse.py +125 -0
  61. flyte/_internal/runtime/rusty.py +193 -0
  62. flyte/_internal/runtime/task_serde.py +362 -0
  63. flyte/_internal/runtime/taskrunner.py +209 -0
  64. flyte/_internal/runtime/trigger_serde.py +160 -0
  65. flyte/_internal/runtime/types_serde.py +54 -0
  66. flyte/_keyring/__init__.py +0 -0
  67. flyte/_keyring/file.py +115 -0
  68. flyte/_logging.py +300 -0
  69. flyte/_map.py +312 -0
  70. flyte/_module.py +72 -0
  71. flyte/_pod.py +30 -0
  72. flyte/_resources.py +473 -0
  73. flyte/_retry.py +32 -0
  74. flyte/_reusable_environment.py +102 -0
  75. flyte/_run.py +724 -0
  76. flyte/_secret.py +96 -0
  77. flyte/_task.py +550 -0
  78. flyte/_task_environment.py +316 -0
  79. flyte/_task_plugins.py +47 -0
  80. flyte/_timeout.py +47 -0
  81. flyte/_tools.py +27 -0
  82. flyte/_trace.py +119 -0
  83. flyte/_trigger.py +1000 -0
  84. flyte/_utils/__init__.py +30 -0
  85. flyte/_utils/asyn.py +121 -0
  86. flyte/_utils/async_cache.py +139 -0
  87. flyte/_utils/coro_management.py +27 -0
  88. flyte/_utils/docker_credentials.py +173 -0
  89. flyte/_utils/file_handling.py +72 -0
  90. flyte/_utils/helpers.py +134 -0
  91. flyte/_utils/lazy_module.py +54 -0
  92. flyte/_utils/module_loader.py +104 -0
  93. flyte/_utils/org_discovery.py +57 -0
  94. flyte/_utils/uv_script_parser.py +49 -0
  95. flyte/_version.py +34 -0
  96. flyte/app/__init__.py +22 -0
  97. flyte/app/_app_environment.py +157 -0
  98. flyte/app/_deploy.py +125 -0
  99. flyte/app/_input.py +160 -0
  100. flyte/app/_runtime/__init__.py +3 -0
  101. flyte/app/_runtime/app_serde.py +347 -0
  102. flyte/app/_types.py +101 -0
  103. flyte/app/extras/__init__.py +3 -0
  104. flyte/app/extras/_fastapi.py +151 -0
  105. flyte/cli/__init__.py +12 -0
  106. flyte/cli/_abort.py +28 -0
  107. flyte/cli/_build.py +114 -0
  108. flyte/cli/_common.py +468 -0
  109. flyte/cli/_create.py +371 -0
  110. flyte/cli/_delete.py +45 -0
  111. flyte/cli/_deploy.py +293 -0
  112. flyte/cli/_gen.py +176 -0
  113. flyte/cli/_get.py +370 -0
  114. flyte/cli/_option.py +33 -0
  115. flyte/cli/_params.py +554 -0
  116. flyte/cli/_plugins.py +209 -0
  117. flyte/cli/_run.py +597 -0
  118. flyte/cli/_serve.py +64 -0
  119. flyte/cli/_update.py +37 -0
  120. flyte/cli/_user.py +17 -0
  121. flyte/cli/main.py +221 -0
  122. flyte/config/__init__.py +3 -0
  123. flyte/config/_config.py +248 -0
  124. flyte/config/_internal.py +73 -0
  125. flyte/config/_reader.py +225 -0
  126. flyte/connectors/__init__.py +11 -0
  127. flyte/connectors/_connector.py +270 -0
  128. flyte/connectors/_server.py +197 -0
  129. flyte/connectors/utils.py +135 -0
  130. flyte/errors.py +243 -0
  131. flyte/extend.py +19 -0
  132. flyte/extras/__init__.py +5 -0
  133. flyte/extras/_container.py +286 -0
  134. flyte/git/__init__.py +3 -0
  135. flyte/git/_config.py +21 -0
  136. flyte/io/__init__.py +29 -0
  137. flyte/io/_dataframe/__init__.py +131 -0
  138. flyte/io/_dataframe/basic_dfs.py +223 -0
  139. flyte/io/_dataframe/dataframe.py +1026 -0
  140. flyte/io/_dir.py +910 -0
  141. flyte/io/_file.py +914 -0
  142. flyte/io/_hashing_io.py +342 -0
  143. flyte/models.py +479 -0
  144. flyte/py.typed +0 -0
  145. flyte/remote/__init__.py +35 -0
  146. flyte/remote/_action.py +738 -0
  147. flyte/remote/_app.py +57 -0
  148. flyte/remote/_client/__init__.py +0 -0
  149. flyte/remote/_client/_protocols.py +189 -0
  150. flyte/remote/_client/auth/__init__.py +12 -0
  151. flyte/remote/_client/auth/_auth_utils.py +14 -0
  152. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  153. flyte/remote/_client/auth/_authenticators/base.py +403 -0
  154. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  155. flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
  156. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  157. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  158. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  159. flyte/remote/_client/auth/_channel.py +213 -0
  160. flyte/remote/_client/auth/_client_config.py +85 -0
  161. flyte/remote/_client/auth/_default_html.py +32 -0
  162. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  163. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  164. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  165. flyte/remote/_client/auth/_keyring.py +152 -0
  166. flyte/remote/_client/auth/_token_client.py +260 -0
  167. flyte/remote/_client/auth/errors.py +16 -0
  168. flyte/remote/_client/controlplane.py +128 -0
  169. flyte/remote/_common.py +30 -0
  170. flyte/remote/_console.py +19 -0
  171. flyte/remote/_data.py +161 -0
  172. flyte/remote/_logs.py +185 -0
  173. flyte/remote/_project.py +88 -0
  174. flyte/remote/_run.py +386 -0
  175. flyte/remote/_secret.py +142 -0
  176. flyte/remote/_task.py +527 -0
  177. flyte/remote/_trigger.py +306 -0
  178. flyte/remote/_user.py +33 -0
  179. flyte/report/__init__.py +3 -0
  180. flyte/report/_report.py +182 -0
  181. flyte/report/_template.html +124 -0
  182. flyte/storage/__init__.py +36 -0
  183. flyte/storage/_config.py +237 -0
  184. flyte/storage/_parallel_reader.py +274 -0
  185. flyte/storage/_remote_fs.py +34 -0
  186. flyte/storage/_storage.py +456 -0
  187. flyte/storage/_utils.py +5 -0
  188. flyte/syncify/__init__.py +56 -0
  189. flyte/syncify/_api.py +375 -0
  190. flyte/types/__init__.py +52 -0
  191. flyte/types/_interface.py +40 -0
  192. flyte/types/_pickle.py +145 -0
  193. flyte/types/_renderer.py +162 -0
  194. flyte/types/_string_literals.py +119 -0
  195. flyte/types/_type_engine.py +2254 -0
  196. flyte/types/_utils.py +80 -0
  197. flyte-2.0.0b32.data/scripts/debug.py +38 -0
  198. flyte-2.0.0b32.data/scripts/runtime.py +195 -0
  199. flyte-2.0.0b32.dist-info/METADATA +351 -0
  200. flyte-2.0.0b32.dist-info/RECORD +204 -0
  201. flyte-2.0.0b32.dist-info/WHEEL +5 -0
  202. flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
  203. flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
  204. flyte-2.0.0b32.dist-info/top_level.txt +1 -0
@@ -0,0 +1,193 @@
1
+ import asyncio
2
+ import time
3
+ from typing import List, Tuple
4
+
5
+ from flyte._context import contextual_run
6
+ from flyte._internal.controllers import Controller
7
+ from flyte._internal.controllers import create_controller as _create_controller
8
+ from flyte._internal.imagebuild.image_builder import ImageCache
9
+ from flyte._internal.runtime.entrypoints import download_code_bundle, load_pkl_task, load_task
10
+ from flyte._internal.runtime.taskrunner import extract_download_run_upload
11
+ from flyte._logging import logger
12
+ from flyte._task import TaskTemplate
13
+ from flyte._utils import adjust_sys_path
14
+ from flyte.models import ActionID, Checkpoints, CodeBundle, PathRewrite, RawDataPath
15
+
16
+
17
async def download_tgz(destination: str, version: str, tgz: str) -> CodeBundle:
    """
    Download a tar.gz code bundle (no task is loaded here).

    ``adjust_sys_path()`` is called before the download starts.

    :param destination: The path the bundle is downloaded to.
    :param version: The computed version recorded on the bundle.
    :param tgz: The location of the tar.gz code bundle.
    :return: The CodeBundle returned by ``download_code_bundle``.
    """
    logger.info(f"[rusty] Downloading tgz code bundle from {tgz} to {destination} with version {version}")
    adjust_sys_path()

    requested = CodeBundle(tgz=tgz, destination=destination, computed_version=version)
    downloaded = await download_code_bundle(requested)
    return downloaded
34
+
35
+
36
async def download_load_pkl(destination: str, version: str, pkl: str) -> Tuple[CodeBundle, TaskTemplate]:
    """
    Download a pickled code bundle and load the task it contains.

    ``adjust_sys_path()`` is called before the download starts.

    :param destination: The path the bundle is downloaded to.
    :param version: The computed version recorded on the bundle.
    :param pkl: The location of the pickled task bundle.
    :return: A ``(code_bundle, task)`` tuple: the downloaded CodeBundle and the task
        produced by ``load_pkl_task`` from it.
    """
    logger.info(f"[rusty] Downloading pkl code bundle from {pkl} to {destination} with version {version}")
    adjust_sys_path()

    downloaded = await download_code_bundle(
        CodeBundle(pkl=pkl, destination=destination, computed_version=version)
    )
    loaded_task = load_pkl_task(downloaded)
    return downloaded, loaded_task
54
+
55
+
56
def load_task_from_code_bundle(resolver: str, resolver_args: List[str]) -> TaskTemplate:
    """
    Load a task through the named resolver.

    :param resolver: The resolver used to load the task.
    :param resolver_args: Arguments forwarded positionally to the resolver via ``load_task``.
    :return: The loaded task template.
    """
    logger.debug(f"[rusty] Loading task from code bundle {resolver} with args: {resolver_args}")
    loaded = load_task(resolver, *resolver_args)
    return loaded
65
+
66
+
67
async def create_controller(
    endpoint: str = "host.docker.internal:8090",
    insecure: bool = False,
    api_key: str | None = None,
) -> Controller:
    """
    Create a remote controller for a task executed by the Rusty runtime.

    Installs ``flyte.errors.silence_grpc_polling_error`` as the running event loop's exception
    handler. It then initializes the client with ``init_in_cluster`` and builds a ``"remote"``
    controller from the keyword arguments that call returns.

    :param endpoint: Endpoint passed to ``init_in_cluster``.
    :param insecure: Passed to ``init_in_cluster`` (presumably disables TLS; confirm there).
    :param api_key: Optional API key passed to ``init_in_cluster``.
    :return: The remote controller.
    """
    logger.info(f"[rusty] Creating controller with endpoint {endpoint}")
    import flyte.errors
    from flyte._initialize import init_in_cluster

    # We are inside a coroutine, so a loop is guaranteed to be running. get_running_loop() is the
    # documented way to obtain it; get_event_loop() is discouraged in coroutines.
    loop = asyncio.get_running_loop()
    loop.set_exception_handler(flyte.errors.silence_grpc_polling_error)

    # TODO Currently reference tasks are not supported in Rusty.
    controller_kwargs = await init_in_cluster.aio(api_key=api_key, endpoint=endpoint, insecure=insecure)
    return _create_controller(ct="remote", **controller_kwargs)
89
+
90
+
91
async def run_task(
    task: TaskTemplate,
    controller: Controller,
    org: str,
    project: str,
    domain: str,
    run_name: str,
    name: str,
    raw_data_path: str,
    output_path: str,
    run_base_dir: str,
    version: str,
    image_cache: str | None = None,
    checkpoint_path: str | None = None,
    prev_checkpoint: str | None = None,
    code_bundle: CodeBundle | None = None,
    input_path: str | None = None,
    path_rewrite_cfg: str | None = None,
):
    """
    Run ``task`` for one action by invoking ``extract_download_run_upload`` through ``contextual_run``.

    If ``path_rewrite_cfg`` is given, it is parsed with ``PathRewrite.from_str``. The rewrite is
    applied to the raw data path only if its ``new_prefix`` exists in storage; otherwise an error
    is logged and no rewrite is used. Start and end of the execution, and the elapsed time, are
    always logged, even when the task fails.

    :param task: The task template to run.
    :param controller: The controller to use for the task.
    :param org: Organization to run the task in.
    :param project: Project to run the task in.
    :param domain: Domain to run the task in.
    :param run_name: Parent run name to use for the task.
    :param name: Action name to run.
    :param raw_data_path: The path to the raw data.
    :param output_path: The path to save the output.
    :param run_base_dir: The base directory for the run.
    :param version: The version of the task to run.
    :param image_cache: Optional serialized image cache, decoded with ``ImageCache.from_transport``.
    :param checkpoint_path: Checkpoint path to save the current state.
    :param prev_checkpoint: Previous checkpoint path to resume from.
    :param code_bundle: Optional code bundle for the task.
    :param input_path: Optional input path for the task.
    :param path_rewrite_cfg: Optional path rewrite configuration string.
    :return: None.
    :raises Exception: Any exception raised while running the task is logged and re-raised.
    """
    # Wall-clock time is used only for the human-readable timestamps in the logs. The duration is
    # measured with a monotonic clock so system clock adjustments cannot skew it or make it negative.
    start_time = time.time()
    start_tick = time.monotonic()
    action_id = f"{org}/{project}/{domain}/{run_name}/{name}"

    logger.info(
        f"[rusty] Running task '{task.name}' (action: {action_id})"
        f" at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))}"
    )

    path_rewrite = PathRewrite.from_str(path_rewrite_cfg) if path_rewrite_cfg else None
    if path_rewrite:
        import flyte.storage as storage

        if not await storage.exists(path_rewrite.new_prefix):
            logger.error(
                f"[rusty] Path rewrite failed for path {path_rewrite.new_prefix}, "
                f"not found, reverting to original path {path_rewrite.old_prefix}"
            )
            path_rewrite = None
        else:
            logger.info(f"[rusty] Using path rewrite: {path_rewrite}")

    try:
        await contextual_run(
            extract_download_run_upload,
            task,
            action=ActionID(name=name, org=org, project=project, domain=domain, run_name=run_name),
            version=version,
            controller=controller,
            raw_data_path=RawDataPath(path=raw_data_path, path_rewrite=path_rewrite),
            output_path=output_path,
            run_base_dir=run_base_dir,
            checkpoints=Checkpoints(prev_checkpoint_path=prev_checkpoint, checkpoint_path=checkpoint_path),
            code_bundle=code_bundle,
            input_path=input_path,
            image_cache=ImageCache.from_transport(image_cache) if image_cache else None,
        )
    except Exception as e:
        logger.error(f"[rusty] Task failed: {e!s}")
        raise
    finally:
        end_time = time.time()
        duration = time.monotonic() - start_tick
        logger.info(
            f"[rusty] TASK_EXECUTION_END: Task '{task.name}' (action: {action_id})"
            f" done after {duration:.2f}s at {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(end_time))}"
        )
177
+
178
+
179
async def ping(name: str) -> str:
    """
    Health-check entrypoint for the Rusty runtime.

    Prints a notice that ``name`` pinged and answers with a pong message.

    :param name: Identifier of the caller.
    :return: The pong message addressed to ``name``.
    """
    notice = f"Received ping request from {name} in Rusty!"
    print(notice)
    reply = f"pong from Rusty to {name}!"
    return reply
185
+
186
+
187
async def hello(name: str):
    """
    Smoke-test entrypoint for the Rusty runtime.

    :param name: The name to greet.
    :return: None; the greeting request is only printed.
    """
    greeting = f"Received hello request in Rusty with name: {name}!"
    print(greeting)
@@ -0,0 +1,362 @@
1
+ """
2
+ This module provides functionality to serialize and deserialize tasks to and from the wire format.
3
+ It includes a Resolver interface for loading tasks, and functions to load classes and tasks.
4
+ """
5
+
6
+ import copy
7
+ import typing
8
+ from datetime import timedelta
9
+ from typing import Optional, cast
10
+
11
+ from flyteidl2.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
12
+ from flyteidl2.task import common_pb2, environment_pb2, task_definition_pb2
13
+ from google.protobuf import duration_pb2, wrappers_pb2
14
+
15
+ import flyte.errors
16
+ from flyte._cache.cache import VersionParameters, cache_from_request
17
+ from flyte._logging import logger
18
+ from flyte._pod import _PRIMARY_CONTAINER_NAME_FIELD, PodTemplate
19
+ from flyte._secret import SecretRequest, secrets_from_request
20
+ from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
21
+ from flyte.models import CodeBundle, SerializationContext
22
+
23
+ from ... import ReusePolicy
24
+ from ..._retry import RetryStrategy
25
+ from ..._timeout import TimeoutType, timeout_from_request
26
+ from .resources_serde import get_proto_extended_resources, get_proto_resources
27
+ from .reuse import add_reusable
28
+ from .types_serde import transform_native_to_typed_interface
29
+
30
+ _MAX_ENV_NAME_LENGTH = 63 # Maximum length for environment names
31
+ _MAX_TASK_SHORT_NAME_LENGTH = 63 # Maximum length for task short names
32
+
33
+
34
def translate_task_to_wire(
    task: TaskTemplate,
    serialization_context: SerializationContext,
    default_inputs: Optional[typing.List[common_pb2.NamedParameter]] = None,
) -> task_definition_pb2.TaskSpec:
    """
    Serialize ``task`` into a ``TaskSpec`` proto.

    The task template comes from ``get_proto_task``. The task's short name is truncated to
    ``_MAX_TASK_SHORT_NAME_LENGTH`` characters. When ``task.parent_env`` is set and calling it
    returns an environment, that environment's name (truncated to ``_MAX_ENV_NAME_LENGTH``) is
    attached as well.

    :param task: The task to translate.
    :param serialization_context: The serialization context to use for the translation.
    :param default_inputs: Optional list of default inputs for the task.
    :return: The translated task spec.
    """
    template = get_proto_task(task, serialization_context)

    parent = task.parent_env() if task.parent_env else None
    environment = None
    if parent:
        environment = environment_pb2.Environment(name=parent.name[:_MAX_ENV_NAME_LENGTH])

    return task_definition_pb2.TaskSpec(
        task_template=template,
        default_inputs=default_inputs,
        short_name=task.short_name[:_MAX_TASK_SHORT_NAME_LENGTH],
        environment=environment,
    )
60
+
61
+
62
def get_security_context(
    secrets: Optional[SecretRequest],
) -> Optional[security_pb2.SecurityContext]:
    """
    Build a ``SecurityContext`` proto from a task's secret request.

    Each secret is mounted as an environment variable when ``as_env_var`` is set, and as a file
    otherwise.

    :param secrets: The task's secret request, or None.
    :return: The security context, or None when no secrets were requested.
    """
    if secrets is None:
        return None

    proto_secrets = []
    for requested in secrets_from_request(secrets):
        mount = security_pb2.Secret.MountType.FILE
        if requested.as_env_var:
            mount = security_pb2.Secret.MountType.ENV_VAR
        proto_secrets.append(
            security_pb2.Secret(
                group=requested.group,
                key=requested.key,
                mount_requirement=mount,
                env_var=requested.as_env_var,
            )
        )
    return security_pb2.SecurityContext(secrets=proto_secrets)
89
+
90
+
91
+ def get_proto_retry_strategy(
92
+ retries: RetryStrategy | int | None,
93
+ ) -> Optional[literals_pb2.RetryStrategy]:
94
+ if retries is None:
95
+ return None
96
+
97
+ if isinstance(retries, int):
98
+ raise AssertionError("Retries should be an instance of RetryStrategy, not int")
99
+
100
+ return literals_pb2.RetryStrategy(retries=retries.count)
101
+
102
+
103
def get_proto_timeout(timeout: TimeoutType | None) -> Optional[duration_pb2.Duration]:
    """
    Convert a task timeout into the protobuf ``Duration`` used for ``TaskMetadata.timeout``.

    Only the ``max_runtime`` of the normalized timeout (via ``timeout_from_request``) is used. A
    numeric ``max_runtime`` is interpreted as seconds.

    :param timeout: The task's timeout specification, or None.
    :return: A Duration holding the whole seconds of ``max_runtime``, or None when no timeout is set.
    """
    if timeout is None:
        return None
    max_runtime_timeout = timeout_from_request(timeout).max_runtime
    if isinstance(max_runtime_timeout, (int, float)):
        max_runtime_timeout = timedelta(seconds=max_runtime_timeout)
    # timedelta.seconds is only the 0..86399 seconds component and silently drops whole days
    # (timedelta(days=1, hours=2).seconds == 7200), so convert the complete span instead.
    total_seconds = max_runtime_timeout // timedelta(seconds=1)
    return duration_pb2.Duration(seconds=total_seconds)
110
+
111
+
112
def _resolve_cache_version(task: TaskTemplate, serialize_context: SerializationContext, task_cache):
    """
    Return the discovery (cache) version for ``task``, or None when caching is disabled.

    For pkl code bundles the bundle's computed version is used. Otherwise the version comes from
    ``task_cache.get_version``, based on the task's function (function tasks only) and image.
    """
    if not task_cache.is_enabled():
        logger.debug(f"Cache disabled for task {task.name}")
        return None

    logger.debug(f"Cache enabled for task {task.name}")
    bundle = serialize_context.code_bundle
    if bundle and bundle.pkl:
        logger.debug(f"Detected pkl bundle for task {task.name}, using computed version as cache version")
        return bundle.computed_version

    func = task.func if isinstance(task, AsyncFunctionTaskTemplate) else None
    version = task_cache.get_version(VersionParameters(func=func, image=task.image))
    logger.debug(f"Cache version for task {task.name} is {version}")
    return version


def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext) -> tasks_pb2.TaskTemplate:
    """
    Build the ``TaskTemplate`` proto for ``task``.

    The pod template setting decides how the task is packaged:
    - An inline (non-string) pod template produces a ``k8s_pod`` with the task's container
      overlaid, and records the primary container name in the template config.
    - Otherwise a plain container is emitted. A string pod template is passed through as
      ``pod_template_name``.

    Reusable tasks are further wrapped by ``add_reusable``.

    :param task: The task to serialize.
    :param serialize_context: Project/domain/org/version and bundle information.
    :return: The task template proto.
    :raises flyte.errors.RuntimeUserError: If ``task.reusable`` is set but is not a ReusePolicy.
    """
    identifier = identifier_pb2.Identifier(
        resource_type=identifier_pb2.ResourceType.TASK,
        project=serialize_context.project,
        domain=serialize_context.domain,
        org=serialize_context.org,
        name=task.name,
        version=serialize_context.version,
    )

    # TODO Add support for extra_config, custom
    config_map: typing.Dict[str, str] = {}

    pod_template = task.pod_template
    if not pod_template or isinstance(pod_template, str):
        primary = _get_urun_container(serialize_context, task)
        k8s_pod = None
    else:
        k8s_pod = _get_k8s_pod(_get_urun_container(serialize_context, task), pod_template)
        config_map[_PRIMARY_CONTAINER_NAME_FIELD] = pod_template.primary_container_name
        primary = None
    named_pod_template = pod_template if pod_template and isinstance(pod_template, str) else None

    custom = task.custom_config(serialize_context)
    sql = task.sql(serialize_context)

    # -------------- CACHE HANDLING ----------------------
    task_cache = cache_from_request(task.cache)
    cache_enabled = task_cache.is_enabled()
    cache_version = _resolve_cache_version(task, serialize_context, task_cache)

    metadata = tasks_pb2.TaskMetadata(
        discoverable=cache_enabled,
        discovery_version=cache_version,
        cache_serializable=task_cache.serialize,
        cache_ignore_input_vars=task_cache.get_ignored_inputs() if cache_enabled else None,
        runtime=tasks_pb2.RuntimeMetadata(
            version=flyte.version(),
            type=tasks_pb2.RuntimeMetadata.RuntimeType.FLYTE_SDK,
            flavor="python",
        ),
        retries=get_proto_retry_strategy(task.retries),
        timeout=get_proto_timeout(task.timeout),
        pod_template_name=named_pod_template,
        interruptible=task.interruptible,
        generates_deck=wrappers_pb2.BoolValue(value=task.report),
        debuggable=task.debuggable,
    )

    task_template = tasks_pb2.TaskTemplate(
        id=identifier,
        type=task.task_type,
        metadata=metadata,
        interface=transform_native_to_typed_interface(task.native_interface),
        custom=custom if len(custom) else None,
        container=primary,
        task_type_version=task.task_type_version,
        security_context=get_security_context(task.secrets),
        config=config_map,
        k8s_pod=k8s_pod,
        sql=sql,
        extended_resources=get_proto_extended_resources(task.resources),
    )

    reuse_policy = task.reusable
    if reuse_policy is None:
        return task_template

    if not isinstance(reuse_policy, ReusePolicy):
        raise flyte.errors.RuntimeUserError(
            "BadConfig", f"Expected ReusePolicy, got {type(reuse_policy)} for task {task.name}"
        )
    parent = task.parent_env() if task.parent_env is not None else None
    env_name = parent.name if parent is not None else None
    return add_reusable(task_template, reuse_policy, serialize_context.code_bundle, env_name)
201
+
202
+
203
def lookup_image_in_cache(serialize_context: SerializationContext, env_name: str, image: flyte.Image) -> str:
    """
    Resolve the image URI to use for environment ``env_name``.

    Without an image cache on the serialization context, the URI is taken from ``image.uri``.
    With a cache, ``env_name`` must be present in its ``image_lookup``.

    :param serialize_context: The serialization context that may carry an image cache.
    :param env_name: The task's parent environment name.
    :param image: The task's image, used only when there is no image cache.
    :return: The image URI.
    :raises flyte.errors.RuntimeUserError: If an image cache exists but lacks ``env_name``.
    """
    image_cache = serialize_context.image_cache
    if not image_cache:
        # This computes the image uri, computing hashes as necessary so can fail if done remotely.
        return image.uri

    if env_name in image_cache.image_lookup:
        return image_cache.image_lookup[env_name]

    raise flyte.errors.RuntimeUserError(
        "MissingEnvironment",
        f"Environment '{env_name}' not found in image cache.\n\n"
        "💡 To fix this:\n"
        " 1. If your parent environment calls a task in another environment,"
        " declare that dependency using 'depends_on=[...]'.\n"
        " Example:\n"
        " env1 = flyte.TaskEnvironment(\n"
        " name='outer',\n"
        " image=flyte.Image.from_debian_base().with_pip_packages('requests'),\n"
        " depends_on=[env2, env3],\n"
        " )\n"
        " 2. If you're using os.getenv() to set the environment name,"
        " make sure the runtime environment has the same environment variable defined.\n"
        " Example:\n"
        " env = flyte.TaskEnvironment(\n"
        ' name=os.getenv("my-name"),\n'
        ' env_vars={"my-name": os.getenv("my-name")},\n'
        " )\n",
    )
229
+
230
+
231
def _get_urun_container(
    serialize_context: SerializationContext, task_template: TaskTemplate
) -> Optional[tasks_pb2.Container]:
    """
    Build the container proto that runs ``task_template``.

    :param serialize_context: The serialization context (image cache, bundle info, ...).
    :param task_template: The task to build a container for.
    :return: The container with image, args, resources, env vars, data config and config filled in.
    :raises flyte.errors.RuntimeSystemError: If the image is a plain string or the task has no
        parent environment name.
    """
    env_vars = task_template.env_vars
    env = None
    if env_vars:
        env = [literals_pb2.KeyValuePair(key=key, value=value) for key, value in env_vars.items()]
    resources = get_proto_resources(task_template.resources)

    image = task_template.image
    if isinstance(image, str):
        raise flyte.errors.RuntimeSystemError("BadConfig", "Image is not a valid image")

    parent_env_name = task_template.parent_env_name
    if parent_env_name is None:
        raise flyte.errors.RuntimeSystemError("BadConfig", f"Task {task_template.name} has no parent environment name")

    return tasks_pb2.Container(
        image=lookup_image_in_cache(serialize_context, parent_env_name, image),
        command=[],
        args=task_template.container_args(serialize_context),
        resources=resources,
        env=env,
        data_config=task_template.data_loading_config(serialize_context),
        config=task_template.config(serialize_context),
    )
260
+
261
+
262
def _sanitize_resource_name(resource: tasks_pb2.Resources.ResourceEntry) -> str:
    """
    Turn a resource entry's enum name into a lowercase, dash-separated key.

    For example, an enum name such as ``EPHEMERAL_STORAGE`` becomes ``ephemeral-storage``.
    """
    enum_name = tasks_pb2.Resources.ResourceName.Name(resource.name)
    return enum_name.lower().replace("_", "-")
264
+
265
+
266
def _get_k8s_pod(primary_container: tasks_pb2.Container, pod_template: PodTemplate) -> Optional[tasks_pb2.K8sPod]:
    """
    Build a K8sPod proto from a user pod template, overlaying the task's container on it.

    The pod template is deep-copied first, so the caller's object is not mutated. Every container
    whose name equals ``pod_template.primary_container_name`` receives the following from
    ``primary_container``:
    - its image, only if the pod container has no image;
    - its command and args;
    - its resources, only if the task specifies any limits or requests;
    - its env vars, placed before the pod container's own.

    :param primary_container: The container proto generated for the task.
    :param pod_template: The user-supplied pod template.
    :return: The K8sPod representation of the task.
    :raises ValueError: If the pod spec has no container named ``primary_container_name``.
    """
    from kubernetes.client import ApiClient, V1PodSpec
    from kubernetes.client.models import V1EnvVar, V1ResourceRequirements

    # Work on a copy so the user's PodTemplate is never mutated.
    pod_template = copy.deepcopy(pod_template)
    pod_spec = cast(V1PodSpec, pod_template.pod_spec)
    primary_name = pod_template.primary_container_name
    # A pod spec without containers falls through to the explicit error below instead of a TypeError.
    containers = pod_spec.containers or []

    if not any(container.name == primary_name for container in containers):
        raise ValueError(
            "No primary container defined in the pod spec."
            f" You must define a primary container with the name '{pod_template.primary_container_name}'."
        )

    for container in containers:
        if container.name != primary_name:
            # Non-primary containers are left untouched.
            continue

        if container.image is None:
            # Copy the image from primary_container only if the image is not specified in the pod spec.
            container.image = primary_container.image

        container.command = list(primary_container.command)
        container.args = list(primary_container.args)

        limits = {_sanitize_resource_name(r): r.value for r in primary_container.resources.limits}
        requests = {_sanitize_resource_name(r): r.value for r in primary_container.resources.requests}
        if limits or requests:
            # Important! Only copy over resource requirements if they are non-empty.
            container.resources = V1ResourceRequirements(limits=limits, requests=requests)

        if primary_container.env is not None:
            # Env vars are unioned: the task's vars first, then the pod container's own.
            container.env = [V1EnvVar(name=e.key, value=e.value) for e in primary_container.env] + (
                container.env or []
            )

    serialized_spec = ApiClient().sanitize_for_serialization(pod_spec)

    metadata = tasks_pb2.K8sObjectMetadata(labels=pod_template.labels, annotations=pod_template.annotations)
    return tasks_pb2.K8sPod(pod_spec=serialized_spec, metadata=metadata)
326
+
327
+
328
def extract_code_bundle(
    task_spec: task_definition_pb2.TaskSpec,
) -> Optional[CodeBundle]:
    """
    Recover the code bundle encoded in a task spec's container arguments.

    The arguments are scanned for ``--pkl``, ``--tgz``, ``--dest`` and ``--version``; each flag
    takes the argument that follows it. If a flag occurs more than once, the last occurrence
    wins. A flag in the final position resets that field to its default: ``None`` for the bundle
    paths, ``"."`` for the destination, ``""`` for the version.

    :param task_spec: The task spec to extract the code bundle from.
    :return: The extracted code bundle, or None if no pkl/tgz path is present.
    """
    container = task_spec.task_template.container
    if not (container and container.args):
        return None

    args = container.args
    defaults = {"--pkl": None, "--tgz": None, "--dest": ".", "--version": ""}
    found = dict(defaults)
    last_index = len(args) - 1
    for index, arg in enumerate(args):
        if arg in defaults:
            found[arg] = args[index + 1] if index < last_index else defaults[arg]

    if not (found["--pkl"] or found["--tgz"]):
        return None
    return CodeBundle(
        destination=found["--dest"],
        tgz=found["--tgz"],
        pkl=found["--pkl"],
        computed_version=found["--version"],
    )