flyte 0.0.1b3-py3-none-any.whl → 0.2.0a0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic; see the registry's advisory for this release for more details.

Files changed (319)
  1. flyte/__init__.py +20 -4
  2. flyte/_bin/runtime.py +33 -7
  3. flyte/_build.py +3 -2
  4. flyte/_cache/cache.py +1 -2
  5. flyte/_code_bundle/_packaging.py +1 -1
  6. flyte/_code_bundle/_utils.py +0 -16
  7. flyte/_code_bundle/bundle.py +43 -12
  8. flyte/_context.py +8 -2
  9. flyte/_deploy.py +56 -15
  10. flyte/_environment.py +45 -4
  11. flyte/_excepthook.py +37 -0
  12. flyte/_group.py +2 -1
  13. flyte/_image.py +8 -4
  14. flyte/_initialize.py +112 -254
  15. flyte/_interface.py +3 -3
  16. flyte/_internal/controllers/__init__.py +19 -6
  17. flyte/_internal/controllers/_local_controller.py +83 -8
  18. flyte/_internal/controllers/_trace.py +2 -1
  19. flyte/_internal/controllers/remote/__init__.py +27 -7
  20. flyte/_internal/controllers/remote/_action.py +7 -2
  21. flyte/_internal/controllers/remote/_client.py +5 -1
  22. flyte/_internal/controllers/remote/_controller.py +159 -26
  23. flyte/_internal/controllers/remote/_core.py +13 -5
  24. flyte/_internal/controllers/remote/_informer.py +4 -4
  25. flyte/_internal/controllers/remote/_service_protocol.py +6 -6
  26. flyte/_internal/imagebuild/docker_builder.py +12 -1
  27. flyte/_internal/imagebuild/image_builder.py +16 -11
  28. flyte/_internal/runtime/convert.py +164 -21
  29. flyte/_internal/runtime/entrypoints.py +1 -1
  30. flyte/_internal/runtime/io.py +3 -3
  31. flyte/_internal/runtime/task_serde.py +140 -20
  32. flyte/_internal/runtime/taskrunner.py +4 -3
  33. flyte/_internal/runtime/types_serde.py +1 -1
  34. flyte/_logging.py +12 -1
  35. flyte/_map.py +215 -0
  36. flyte/_pod.py +19 -0
  37. flyte/_protos/common/list_pb2.py +3 -3
  38. flyte/_protos/common/list_pb2.pyi +2 -0
  39. flyte/_protos/logs/dataplane/payload_pb2.py +28 -24
  40. flyte/_protos/logs/dataplane/payload_pb2.pyi +11 -2
  41. flyte/_protos/workflow/common_pb2.py +27 -0
  42. flyte/_protos/workflow/common_pb2.pyi +14 -0
  43. flyte/_protos/workflow/environment_pb2.py +29 -0
  44. flyte/_protos/workflow/environment_pb2.pyi +12 -0
  45. flyte/_protos/workflow/queue_service_pb2.py +40 -41
  46. flyte/_protos/workflow/queue_service_pb2.pyi +35 -30
  47. flyte/_protos/workflow/queue_service_pb2_grpc.py +15 -15
  48. flyte/_protos/workflow/run_definition_pb2.py +61 -61
  49. flyte/_protos/workflow/run_definition_pb2.pyi +8 -4
  50. flyte/_protos/workflow/run_service_pb2.py +20 -24
  51. flyte/_protos/workflow/run_service_pb2.pyi +2 -6
  52. flyte/_protos/workflow/state_service_pb2.py +36 -28
  53. flyte/_protos/workflow/state_service_pb2.pyi +19 -15
  54. flyte/_protos/workflow/state_service_pb2_grpc.py +28 -28
  55. flyte/_protos/workflow/task_definition_pb2.py +29 -22
  56. flyte/_protos/workflow/task_definition_pb2.pyi +21 -5
  57. flyte/_protos/workflow/task_service_pb2.py +27 -11
  58. flyte/_protos/workflow/task_service_pb2.pyi +29 -1
  59. flyte/_protos/workflow/task_service_pb2_grpc.py +34 -0
  60. flyte/_run.py +166 -95
  61. flyte/_task.py +110 -28
  62. flyte/_task_environment.py +55 -72
  63. flyte/_trace.py +6 -14
  64. flyte/_utils/__init__.py +6 -0
  65. flyte/_utils/async_cache.py +139 -0
  66. flyte/_utils/coro_management.py +0 -2
  67. flyte/_utils/helpers.py +45 -19
  68. flyte/_utils/org_discovery.py +57 -0
  69. flyte/_version.py +2 -2
  70. flyte/cli/__init__.py +3 -0
  71. flyte/cli/_abort.py +28 -0
  72. flyte/{_cli → cli}/_common.py +73 -23
  73. flyte/cli/_create.py +145 -0
  74. flyte/{_cli → cli}/_delete.py +4 -4
  75. flyte/{_cli → cli}/_deploy.py +26 -14
  76. flyte/cli/_gen.py +163 -0
  77. flyte/{_cli → cli}/_get.py +98 -23
  78. {union/_cli → flyte/cli}/_params.py +106 -147
  79. flyte/{_cli → cli}/_run.py +99 -20
  80. flyte/cli/main.py +166 -0
  81. flyte/config/__init__.py +3 -0
  82. flyte/config/_config.py +216 -0
  83. flyte/config/_internal.py +64 -0
  84. flyte/config/_reader.py +207 -0
  85. flyte/errors.py +29 -0
  86. flyte/extras/_container.py +33 -43
  87. flyte/io/__init__.py +17 -1
  88. flyte/io/_dir.py +2 -2
  89. flyte/io/_file.py +3 -4
  90. flyte/io/{structured_dataset → _structured_dataset}/basic_dfs.py +1 -1
  91. flyte/io/{structured_dataset → _structured_dataset}/structured_dataset.py +1 -1
  92. flyte/{_datastructures.py → models.py} +56 -7
  93. flyte/remote/__init__.py +2 -1
  94. flyte/remote/_client/_protocols.py +2 -0
  95. flyte/remote/_client/auth/_auth_utils.py +14 -0
  96. flyte/remote/_client/auth/_channel.py +34 -3
  97. flyte/remote/_client/auth/_token_client.py +3 -3
  98. flyte/remote/_client/controlplane.py +13 -13
  99. flyte/remote/_console.py +1 -1
  100. flyte/remote/_data.py +10 -6
  101. flyte/remote/_logs.py +89 -29
  102. flyte/remote/_project.py +8 -9
  103. flyte/remote/_run.py +228 -131
  104. flyte/remote/_secret.py +12 -12
  105. flyte/remote/_task.py +179 -15
  106. flyte/report/_report.py +4 -4
  107. flyte/storage/__init__.py +5 -0
  108. flyte/storage/_config.py +233 -0
  109. flyte/storage/_storage.py +23 -3
  110. flyte/syncify/__init__.py +56 -0
  111. flyte/syncify/_api.py +371 -0
  112. flyte/types/__init__.py +23 -0
  113. flyte/types/_interface.py +22 -7
  114. flyte/{io/pickle/transformer.py → types/_pickle.py} +2 -1
  115. flyte/types/_type_engine.py +95 -18
  116. flyte-0.2.0a0.dist-info/METADATA +249 -0
  117. flyte-0.2.0a0.dist-info/RECORD +218 -0
  118. {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/entry_points.txt +1 -1
  119. flyte/_api_commons.py +0 -3
  120. flyte/_cli/__init__.py +0 -0
  121. flyte/_cli/_create.py +0 -42
  122. flyte/_cli/main.py +0 -72
  123. flyte/_internal/controllers/pbhash.py +0 -39
  124. flyte/io/_dataframe.py +0 -0
  125. flyte/io/pickle/__init__.py +0 -0
  126. flyte-0.0.1b3.dist-info/METADATA +0 -179
  127. flyte-0.0.1b3.dist-info/RECORD +0 -390
  128. union/__init__.py +0 -54
  129. union/_api_commons.py +0 -3
  130. union/_bin/__init__.py +0 -0
  131. union/_bin/runtime.py +0 -113
  132. union/_build.py +0 -25
  133. union/_cache/__init__.py +0 -12
  134. union/_cache/cache.py +0 -141
  135. union/_cache/defaults.py +0 -9
  136. union/_cache/policy_function_body.py +0 -42
  137. union/_cli/__init__.py +0 -0
  138. union/_cli/_common.py +0 -263
  139. union/_cli/_create.py +0 -40
  140. union/_cli/_delete.py +0 -23
  141. union/_cli/_deploy.py +0 -120
  142. union/_cli/_get.py +0 -162
  143. union/_cli/_run.py +0 -150
  144. union/_cli/main.py +0 -72
  145. union/_code_bundle/__init__.py +0 -8
  146. union/_code_bundle/_ignore.py +0 -113
  147. union/_code_bundle/_packaging.py +0 -187
  148. union/_code_bundle/_utils.py +0 -342
  149. union/_code_bundle/bundle.py +0 -176
  150. union/_context.py +0 -146
  151. union/_datastructures.py +0 -295
  152. union/_deploy.py +0 -185
  153. union/_doc.py +0 -29
  154. union/_docstring.py +0 -26
  155. union/_environment.py +0 -43
  156. union/_group.py +0 -31
  157. union/_hash.py +0 -23
  158. union/_image.py +0 -760
  159. union/_initialize.py +0 -585
  160. union/_interface.py +0 -84
  161. union/_internal/__init__.py +0 -3
  162. union/_internal/controllers/__init__.py +0 -77
  163. union/_internal/controllers/_local_controller.py +0 -77
  164. union/_internal/controllers/pbhash.py +0 -39
  165. union/_internal/controllers/remote/__init__.py +0 -40
  166. union/_internal/controllers/remote/_action.py +0 -131
  167. union/_internal/controllers/remote/_client.py +0 -43
  168. union/_internal/controllers/remote/_controller.py +0 -169
  169. union/_internal/controllers/remote/_core.py +0 -341
  170. union/_internal/controllers/remote/_informer.py +0 -260
  171. union/_internal/controllers/remote/_service_protocol.py +0 -44
  172. union/_internal/imagebuild/__init__.py +0 -11
  173. union/_internal/imagebuild/docker_builder.py +0 -416
  174. union/_internal/imagebuild/image_builder.py +0 -243
  175. union/_internal/imagebuild/remote_builder.py +0 -0
  176. union/_internal/resolvers/__init__.py +0 -0
  177. union/_internal/resolvers/_task_module.py +0 -31
  178. union/_internal/resolvers/common.py +0 -24
  179. union/_internal/resolvers/default.py +0 -27
  180. union/_internal/runtime/__init__.py +0 -0
  181. union/_internal/runtime/convert.py +0 -163
  182. union/_internal/runtime/entrypoints.py +0 -121
  183. union/_internal/runtime/io.py +0 -136
  184. union/_internal/runtime/resources_serde.py +0 -134
  185. union/_internal/runtime/task_serde.py +0 -202
  186. union/_internal/runtime/taskrunner.py +0 -179
  187. union/_internal/runtime/types_serde.py +0 -53
  188. union/_logging.py +0 -124
  189. union/_protos/__init__.py +0 -0
  190. union/_protos/common/authorization_pb2.py +0 -66
  191. union/_protos/common/authorization_pb2.pyi +0 -106
  192. union/_protos/common/identifier_pb2.py +0 -71
  193. union/_protos/common/identifier_pb2.pyi +0 -82
  194. union/_protos/common/identity_pb2.py +0 -48
  195. union/_protos/common/identity_pb2.pyi +0 -72
  196. union/_protos/common/identity_pb2_grpc.py +0 -4
  197. union/_protos/common/list_pb2.py +0 -36
  198. union/_protos/common/list_pb2.pyi +0 -69
  199. union/_protos/common/list_pb2_grpc.py +0 -4
  200. union/_protos/common/policy_pb2.py +0 -37
  201. union/_protos/common/policy_pb2.pyi +0 -27
  202. union/_protos/common/policy_pb2_grpc.py +0 -4
  203. union/_protos/common/role_pb2.py +0 -37
  204. union/_protos/common/role_pb2.pyi +0 -51
  205. union/_protos/common/role_pb2_grpc.py +0 -4
  206. union/_protos/common/runtime_version_pb2.py +0 -28
  207. union/_protos/common/runtime_version_pb2.pyi +0 -24
  208. union/_protos/common/runtime_version_pb2_grpc.py +0 -4
  209. union/_protos/logs/dataplane/payload_pb2.py +0 -96
  210. union/_protos/logs/dataplane/payload_pb2.pyi +0 -168
  211. union/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  212. union/_protos/secret/definition_pb2.py +0 -49
  213. union/_protos/secret/definition_pb2.pyi +0 -93
  214. union/_protos/secret/definition_pb2_grpc.py +0 -4
  215. union/_protos/secret/payload_pb2.py +0 -62
  216. union/_protos/secret/payload_pb2.pyi +0 -94
  217. union/_protos/secret/payload_pb2_grpc.py +0 -4
  218. union/_protos/secret/secret_pb2.py +0 -38
  219. union/_protos/secret/secret_pb2.pyi +0 -6
  220. union/_protos/secret/secret_pb2_grpc.py +0 -198
  221. union/_protos/validate/validate/validate_pb2.py +0 -76
  222. union/_protos/workflow/node_execution_service_pb2.py +0 -26
  223. union/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  224. union/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  225. union/_protos/workflow/queue_service_pb2.py +0 -75
  226. union/_protos/workflow/queue_service_pb2.pyi +0 -103
  227. union/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  228. union/_protos/workflow/run_definition_pb2.py +0 -100
  229. union/_protos/workflow/run_definition_pb2.pyi +0 -256
  230. union/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  231. union/_protos/workflow/run_logs_service_pb2.py +0 -41
  232. union/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  233. union/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  234. union/_protos/workflow/run_service_pb2.py +0 -133
  235. union/_protos/workflow/run_service_pb2.pyi +0 -173
  236. union/_protos/workflow/run_service_pb2_grpc.py +0 -412
  237. union/_protos/workflow/state_service_pb2.py +0 -58
  238. union/_protos/workflow/state_service_pb2.pyi +0 -69
  239. union/_protos/workflow/state_service_pb2_grpc.py +0 -138
  240. union/_protos/workflow/task_definition_pb2.py +0 -72
  241. union/_protos/workflow/task_definition_pb2.pyi +0 -65
  242. union/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  243. union/_protos/workflow/task_service_pb2.py +0 -44
  244. union/_protos/workflow/task_service_pb2.pyi +0 -31
  245. union/_protos/workflow/task_service_pb2_grpc.py +0 -104
  246. union/_resources.py +0 -226
  247. union/_retry.py +0 -32
  248. union/_reusable_environment.py +0 -25
  249. union/_run.py +0 -374
  250. union/_secret.py +0 -61
  251. union/_task.py +0 -354
  252. union/_task_environment.py +0 -186
  253. union/_timeout.py +0 -47
  254. union/_tools.py +0 -27
  255. union/_utils/__init__.py +0 -11
  256. union/_utils/asyn.py +0 -119
  257. union/_utils/file_handling.py +0 -71
  258. union/_utils/helpers.py +0 -46
  259. union/_utils/lazy_module.py +0 -54
  260. union/_utils/uv_script_parser.py +0 -49
  261. union/_version.py +0 -21
  262. union/connectors/__init__.py +0 -0
  263. union/errors.py +0 -128
  264. union/extras/__init__.py +0 -5
  265. union/extras/_container.py +0 -263
  266. union/io/__init__.py +0 -11
  267. union/io/_dataframe.py +0 -0
  268. union/io/_dir.py +0 -425
  269. union/io/_file.py +0 -418
  270. union/io/pickle/__init__.py +0 -0
  271. union/io/pickle/transformer.py +0 -117
  272. union/io/structured_dataset/__init__.py +0 -122
  273. union/io/structured_dataset/basic_dfs.py +0 -219
  274. union/io/structured_dataset/structured_dataset.py +0 -1057
  275. union/py.typed +0 -0
  276. union/remote/__init__.py +0 -23
  277. union/remote/_client/__init__.py +0 -0
  278. union/remote/_client/_protocols.py +0 -129
  279. union/remote/_client/auth/__init__.py +0 -12
  280. union/remote/_client/auth/_authenticators/__init__.py +0 -0
  281. union/remote/_client/auth/_authenticators/base.py +0 -391
  282. union/remote/_client/auth/_authenticators/client_credentials.py +0 -73
  283. union/remote/_client/auth/_authenticators/device_code.py +0 -120
  284. union/remote/_client/auth/_authenticators/external_command.py +0 -77
  285. union/remote/_client/auth/_authenticators/factory.py +0 -200
  286. union/remote/_client/auth/_authenticators/pkce.py +0 -515
  287. union/remote/_client/auth/_channel.py +0 -184
  288. union/remote/_client/auth/_client_config.py +0 -83
  289. union/remote/_client/auth/_default_html.py +0 -32
  290. union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  291. union/remote/_client/auth/_grpc_utils/auth_interceptor.py +0 -204
  292. union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +0 -144
  293. union/remote/_client/auth/_keyring.py +0 -154
  294. union/remote/_client/auth/_token_client.py +0 -258
  295. union/remote/_client/auth/errors.py +0 -16
  296. union/remote/_client/controlplane.py +0 -86
  297. union/remote/_data.py +0 -149
  298. union/remote/_logs.py +0 -74
  299. union/remote/_project.py +0 -86
  300. union/remote/_run.py +0 -820
  301. union/remote/_secret.py +0 -132
  302. union/remote/_task.py +0 -193
  303. union/report/__init__.py +0 -3
  304. union/report/_report.py +0 -178
  305. union/report/_template.html +0 -124
  306. union/storage/__init__.py +0 -24
  307. union/storage/_remote_fs.py +0 -34
  308. union/storage/_storage.py +0 -247
  309. union/storage/_utils.py +0 -5
  310. union/types/__init__.py +0 -11
  311. union/types/_renderer.py +0 -162
  312. union/types/_string_literals.py +0 -120
  313. union/types/_type_engine.py +0 -2131
  314. union/types/_utils.py +0 -80
  315. /union/_protos/common/authorization_pb2_grpc.py → /flyte/_protos/workflow/common_pb2_grpc.py +0 -0
  316. /union/_protos/common/identifier_pb2_grpc.py → /flyte/_protos/workflow/environment_pb2_grpc.py +0 -0
  317. /flyte/io/{structured_dataset → _structured_dataset}/__init__.py +0 -0
  318. {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/WHEEL +0 -0
  319. {flyte-0.0.1b3.dist-info → flyte-0.2.0a0.dist-info}/top_level.txt +0 -0
@@ -38,11 +38,11 @@ class ActionCache:
38
38
  """
39
39
  Add an action to the cache if it doesn't exist. This is invoked by the watch.
40
40
  """
41
- logger.info(f"Observing phase {run_definition_pb2.Phase.Name(state.phase)} for {state.action_id.name}")
41
+ logger.debug(f"Observing phase {run_definition_pb2.Phase.Name(state.phase)} for {state.action_id.name}")
42
42
  if state.output_uri:
43
- logger.info(f"Output URI: {state.output_uri}")
43
+ logger.debug(f"Output URI: {state.output_uri}")
44
44
  else:
45
- logger.info(f"{state.action_id.name} has no output URI")
45
+ logger.warning(f"{state.action_id.name} has no output URI")
46
46
  if state.phase == run_definition_pb2.Phase.PHASE_FAILED:
47
47
  logger.error(
48
48
  f"Action {state.action_id.name} failed with error (msg):"
@@ -235,7 +235,7 @@ class Informer:
235
235
  await self._shared_queue.put(node)
236
236
  # hack to work in the absence of sentinel
237
237
  except asyncio.CancelledError:
238
- logger.warning(f"Watch cancelled: {self.name}")
238
+ logger.info(f"Watch cancelled: {self.name}")
239
239
  return
240
240
  except asyncio.TimeoutError as e:
241
241
  logger.error(f"Watch timeout: {self.name}", exc_info=e)
@@ -28,12 +28,12 @@ class QueueService(Protocol):
28
28
  ) -> queue_service_pb2.EnqueueActionResponse:
29
29
  """Enqueue a task"""
30
30
 
31
- async def AbortQueuedAction(
32
- self,
33
- req: queue_service_pb2.AbortQueuedActionRequest,
34
- **kwargs,
35
- ) -> queue_service_pb2.AbortQueuedActionResponse:
36
- """Dequeue a task"""
31
+ # async def AbortQueuedAction(
32
+ # self,
33
+ # req: queue_service_pb2.AbortQueuedActionRequest,
34
+ # **kwargs,
35
+ # ) -> queue_service_pb2.AbortQueuedActionResponse:
36
+ # """Dequeue a task"""
37
37
 
38
38
 
39
39
  class ClientSet(Protocol):
@@ -1,4 +1,5 @@
1
1
  import asyncio
2
+ import os
2
3
  import shutil
3
4
  import subprocess
4
5
  import tempfile
@@ -25,6 +26,8 @@ from flyte._image import (
25
26
  from flyte._logging import logger
26
27
 
27
28
  _F_IMG_ID = "_F_IMG_ID"
29
+ FLYTE_DOCKER_BUILDER_CACHE_FROM = "FLYTE_DOCKER_BUILDER_CACHE_FROM"
30
+ FLYTE_DOCKER_BUILDER_CACHE_TO = "FLYTE_DOCKER_BUILDER_CACHE_TO"
28
31
 
29
32
  UV_LOCK_INSTALL_TEMPLATE = Template("""\
30
33
  WORKDIR /root
@@ -193,7 +196,7 @@ class CopyConfigHandler:
193
196
  shutil.copy(abs_path, dest_path)
194
197
  elif layer.context_source.is_dir():
195
198
  # Copy the entire directory
196
- shutil.copytree(abs_path, dest_path)
199
+ shutil.copytree(abs_path, dest_path, dirs_exist_ok=True)
197
200
  else:
198
201
  raise ValueError(f"Source path is neither file nor directory: {layer.context_source}")
199
202
 
@@ -396,6 +399,14 @@ class DockerImageBuilder:
396
399
  "--push" if push else "--load",
397
400
  ]
398
401
 
402
+ cache_from = os.getenv(FLYTE_DOCKER_BUILDER_CACHE_FROM)
403
+ cache_to = os.getenv(FLYTE_DOCKER_BUILDER_CACHE_TO)
404
+ if cache_from and cache_to:
405
+ command[3:3] = [
406
+ f"--cache-from={cache_from}",
407
+ f"--cache-to={cache_to}",
408
+ ]
409
+
399
410
  if image.registry and push:
400
411
  command.append("--push")
401
412
  command.append(tmp_dir)
@@ -32,13 +32,15 @@ class DockerAPIImageChecker(ImageChecker):
32
32
  """
33
33
 
34
34
  @classmethod
35
- async def image_exists(cls, repository: str, tag: str, arch: Tuple[Architecture, ...] = ("linux/amd64",)) -> bool:
35
+ async def image_exists(
36
+ cls,
37
+ repository: str,
38
+ tag: str,
39
+ arch: Tuple[Architecture, ...] = ("linux/amd64",)
40
+ ) -> bool:
36
41
  import httpx
37
42
 
38
- if "/" in repository:
39
- if not repository.startswith("library/"):
40
- raise ValueError("This checker only works with Docker Hub")
41
- else:
43
+ if "/" not in repository:
42
44
  repository = f"library/{repository}"
43
45
 
44
46
  auth_url = "https://auth.docker.io/token"
@@ -50,6 +52,7 @@ class DockerAPIImageChecker(ImageChecker):
50
52
  auth_response = await client.get(auth_url, params={"service": service, "scope": scope})
51
53
  if auth_response.status_code != 200:
52
54
  raise Exception(f"Failed to get auth token: {auth_response.status_code}")
55
+
53
56
  token = auth_response.json()["token"]
54
57
 
55
58
  manifest_url = f"https://registry-1.docker.io/v2/{repository}/manifests/{tag}"
@@ -60,18 +63,20 @@ class DockerAPIImageChecker(ImageChecker):
60
63
  "application/vnd.docker.distribution.manifest.list.v2+json"
61
64
  ),
62
65
  }
63
- manifest_response = await client.get(manifest_url, headers=headers)
64
66
 
67
+ manifest_response = await client.get(manifest_url, headers=headers)
65
68
  if manifest_response.status_code != 200:
66
- raise Exception(f"Failed to get manifest: {manifest_response.status_code}")
69
+ logger.warning(f"Image not found: {repository}:{tag} (HTTP {manifest_response.status_code})")
70
+ return False
71
+
67
72
  manifest_list = manifest_response.json()["manifests"]
68
- architectures = [f"{x['platform']['os']}/{x['platform']['architecture']}" for x in manifest_list]
73
+ architectures = [f"{m['platform']['os']}/{m['platform']['architecture']}" for m in manifest_list]
69
74
 
70
- if set(architectures) >= set(arch):
71
- logger.debug(f"Image {repository}:{tag} found for architecture(s) {arch}, has {architectures}")
75
+ if set(arch).issubset(set(architectures)):
76
+ logger.debug(f"Image {repository}:{tag} found with arch {architectures}")
72
77
  return True
73
78
  else:
74
- logger.debug(f"Image {repository}:{tag} not found for architecture(s) {arch}, only has {architectures}")
79
+ logger.debug(f"Image {repository}:{tag} has {architectures}, but missing {arch}")
75
80
  return False
76
81
 
77
82
 
@@ -1,16 +1,20 @@
1
1
  from __future__ import annotations
2
2
 
3
+ import asyncio
4
+ import base64
5
+ import hashlib
6
+ import inspect
3
7
  from dataclasses import dataclass
4
- from typing import Any, Dict, Tuple, Union
8
+ from types import NoneType
9
+ from typing import Any, Dict, List, Tuple, Union, get_args
5
10
 
6
- from flyteidl.core import execution_pb2, literals_pb2
11
+ from flyteidl.core import execution_pb2, interface_pb2, literals_pb2
7
12
 
8
13
  import flyte.errors
9
14
  import flyte.storage as storage
10
- from flyte._datastructures import ActionID, NativeInterface, TaskContext
11
- from flyte._internal.controllers import pbhash
12
- from flyte._protos.workflow import run_definition_pb2
13
- from flyte.types import TypeEngine
15
+ from flyte._protos.workflow import common_pb2, run_definition_pb2, task_definition_pb2
16
+ from flyte.models import ActionID, NativeInterface, TaskContext
17
+ from flyte.types import TypeEngine, TypeTransformerFailedError
14
18
 
15
19
 
16
20
  @dataclass(frozen=True)
@@ -56,24 +60,97 @@ async def convert_inputs_to_native(inputs: Inputs, python_interface: NativeInter
56
60
  return native_vals
57
61
 
58
62
 
63
+ async def convert_upload_default_inputs(interface: NativeInterface) -> List[common_pb2.NamedParameter]:
64
+ """
65
+ Converts the default inputs of a NativeInterface to a list of NamedParameters for upload.
66
+ This is used to upload default inputs to the Flyte backend.
67
+ """
68
+ if not interface.inputs:
69
+ return []
70
+
71
+ vars = []
72
+ literal_coros = []
73
+ for input_name, (input_type, default_value) in interface.inputs.items():
74
+ if default_value is not inspect.Parameter.empty:
75
+ lt = TypeEngine.to_literal_type(input_type)
76
+ literal_coros.append(TypeEngine.to_literal(default_value, input_type, lt))
77
+ vars.append((input_name, lt))
78
+
79
+ literals: List[literals_pb2.Literal] = await asyncio.gather(*literal_coros)
80
+ named_params = []
81
+ for (name, lt), literal in zip(vars, literals):
82
+ param = interface_pb2.Parameter(
83
+ var=interface_pb2.Variable(
84
+ type=lt,
85
+ ),
86
+ default=literal,
87
+ )
88
+ named_params.append(
89
+ common_pb2.NamedParameter(
90
+ name=name,
91
+ parameter=param,
92
+ ),
93
+ )
94
+ return named_params
95
+
96
+
97
+ def is_optional_type(tp) -> bool:
98
+ """
99
+ True if the *annotation* `tp` is equivalent to Optional[…].
100
+ Works for Optional[T], Union[T, None], and T | None.
101
+ """
102
+ return NoneType in get_args(tp) # fastest check
103
+
104
+
59
105
  async def convert_from_native_to_inputs(interface: NativeInterface, *args, **kwargs) -> Inputs:
60
106
  kwargs = interface.convert_to_kwargs(*args, **kwargs)
61
- if len(kwargs) == 0:
62
- return Inputs.empty()
63
- if len(kwargs) < len(interface.inputs):
107
+
108
+ if len(kwargs) < interface.num_required_inputs():
64
109
  raise ValueError(
65
- f"Received {len(kwargs)} inputs but interface has {len(interface.inputs)}. "
66
- f"Please provide all required inputs."
110
+ f"Received {len(kwargs)} inputs but interface has {interface.num_required_inputs()} required inputs. "
111
+ f"Please provide all required inputs. Inputs received: {kwargs}, interface: {interface}"
67
112
  )
68
- literal_map = await TypeEngine.dict_to_literal_map(kwargs, interface.get_input_types())
113
+
114
+ if len(interface.inputs) == 0:
115
+ return Inputs.empty()
116
+
117
+ # fill in defaults if missing
118
+ type_hints: Dict[str, type] = {}
119
+ already_converted_kwargs: Dict[str, literals_pb2.Literal] = {}
120
+ for input_name, (input_type, default_value) in interface.inputs.items():
121
+ if input_name in kwargs:
122
+ type_hints[input_name] = input_type
123
+ elif (default_value is not None and default_value is not inspect.Signature.empty) or (
124
+ default_value is None and is_optional_type(input_type)
125
+ ):
126
+ if default_value == NativeInterface.has_default:
127
+ if interface._remote_defaults is None or input_name not in interface._remote_defaults:
128
+ raise ValueError(f"Input '{input_name}' has a default value but it is not set in the interface.")
129
+ already_converted_kwargs[input_name] = interface._remote_defaults[input_name]
130
+ else:
131
+ kwargs[input_name] = default_value
132
+ type_hints[input_name] = input_type
133
+
134
+ literal_map = await TypeEngine.dict_to_literal_map(kwargs, type_hints)
135
+ if len(already_converted_kwargs) > 0:
136
+ copied_literals: Dict[str, literals_pb2.Literal] = {}
137
+ for k, v in literal_map.literals.items():
138
+ copied_literals[k] = v
139
+ # Add the already converted kwargs to the literal map
140
+ for k, v in already_converted_kwargs.items():
141
+ copied_literals[k] = v
142
+ literal_map = literals_pb2.LiteralMap(literals=copied_literals)
143
+ # Make sure we the interface, not literal_map or kwargs, because those may have a different order
69
144
  return Inputs(
70
145
  proto_inputs=run_definition_pb2.Inputs(
71
- literals=[run_definition_pb2.NamedLiteral(name=k, value=v) for k, v in literal_map.literals.items()]
146
+ literals=[
147
+ run_definition_pb2.NamedLiteral(name=k, value=literal_map.literals[k]) for k in interface.inputs.keys()
148
+ ]
72
149
  )
73
150
  )
74
151
 
75
152
 
76
- async def convert_from_native_to_outputs(o: Any, interface: NativeInterface) -> Outputs:
153
+ async def convert_from_native_to_outputs(o: Any, interface: NativeInterface, task_name: str = "") -> Outputs:
77
154
  # Always make it a tuple even if it's just one item to simplify logic below
78
155
  if not isinstance(o, tuple):
79
156
  o = (o,)
@@ -83,8 +160,11 @@ async def convert_from_native_to_outputs(o: Any, interface: NativeInterface) ->
83
160
  )
84
161
  named = []
85
162
  for (output_name, python_type), v in zip(interface.outputs.items(), o):
86
- lit = await TypeEngine.to_literal(v, python_type, TypeEngine.to_literal_type(python_type))
87
- named.append(run_definition_pb2.NamedLiteral(name=output_name, value=lit))
163
+ try:
164
+ lit = await TypeEngine.to_literal(v, python_type, TypeEngine.to_literal_type(python_type))
165
+ named.append(run_definition_pb2.NamedLiteral(name=output_name, value=lit))
166
+ except TypeTransformerFailedError as e:
167
+ raise flyte.errors.RuntimeDataValidationError(output_name, e, task_name)
88
168
 
89
169
  return Outputs(proto_outputs=run_definition_pb2.Outputs(literals=named))
90
170
 
@@ -179,21 +259,84 @@ def convert_from_native_to_error(err: BaseException) -> Error:
179
259
  )
180
260
 
181
261
 
182
- def generate_sub_action_id_and_output_path(tctx: TaskContext, task_name: str, inputs: Inputs) -> Tuple[ActionID, str]:
262
+ def hash_data(data: Union[str, bytes]) -> str:
263
+ """
264
+ Generate a hash for the given data. If the data is a string, it will be encoded to bytes before hashing.
265
+ :param data: The data to hash, can be a string or bytes.
266
+ :return: A hexadecimal string representation of the hash.
267
+ """
268
+ if isinstance(data, str):
269
+ data = data.encode("utf-8")
270
+ digest = hashlib.sha256(data).digest()
271
+ return base64.b64encode(digest).decode("utf-8")
272
+
273
+
274
+ def generate_inputs_hash(serialized_inputs: str | bytes) -> str:
275
+ """
276
+ Generate a hash for the inputs. This is used to uniquely identify the inputs for a task.
277
+ :return: A hexadecimal string representation of the hash.
278
+ """
279
+ return hash_data(serialized_inputs)
280
+
281
+
282
+ def generate_cache_key_hash(
283
+ task_name: str,
284
+ inputs_hash: str,
285
+ task_interface: interface_pb2.TypedInterface,
286
+ cache_version: str,
287
+ ignored_input_vars: List[str],
288
+ proto_inputs: run_definition_pb2.Inputs,
289
+ ) -> str:
290
+ """
291
+ Generate a cache key hash based on the inputs hash, task name, task interface, and cache version.
292
+ This is used to uniquely identify the cache key for a task.
293
+
294
+ :param task_name: The name of the task.
295
+ :param inputs_hash: The hash of the inputs.
296
+ :param task_interface: The interface of the task.
297
+ :param cache_version: The version of the cache.
298
+ :param ignored_input_vars: A list of input variable names to ignore when generating the cache key.
299
+ :param proto_inputs: The proto inputs for the task, only used if there are ignored inputs.
300
+ :return: A hexadecimal string representation of the cache key hash.
301
+ """
302
+ if ignored_input_vars:
303
+ filtered = [named_lit for named_lit in proto_inputs.literals if named_lit.name not in ignored_input_vars]
304
+ final = run_definition_pb2.Inputs(literals=filtered)
305
+ final_inputs = final.SerializeToString(deterministic=True)
306
+ else:
307
+ final_inputs = inputs_hash
308
+ data = f"{final_inputs}{task_name}{task_interface.SerializeToString(deterministic=True)}{cache_version}"
309
+ return hash_data(data)
310
+
311
+
312
+ def generate_sub_action_id_and_output_path(
313
+ tctx: TaskContext,
314
+ task_spec_or_name: task_definition_pb2.TaskSpec | str,
315
+ inputs_hash: str,
316
+ invoke_seq: int,
317
+ ) -> Tuple[ActionID, str]:
183
318
  """
184
319
  Generate a sub-action ID and output path based on the current task context, task name, and inputs.
320
+
321
+ action name = current action name + task name + input hash + group name (if available)
185
322
  :param tctx:
186
- :param task_name:
187
- :param inputs:
323
+ :param task_spec_or_name: task specification or task name. Task name is only used in case of trace actions.
324
+ :param inputs_hash: Consistent hash string of the inputs
325
+ :param invoke_seq: The sequence number of the invocation, used to differentiate between multiple invocations.
188
326
  :return:
189
327
  """
190
328
  current_action_id = tctx.action
191
329
  current_output_path = tctx.run_base_dir
192
- inputs_hash = pbhash.compute_hash_string(inputs.proto_inputs)
330
+ if isinstance(task_spec_or_name, task_definition_pb2.TaskSpec):
331
+ task_spec_or_name.task_template.interface
332
+ task_hash = hash_data(task_spec_or_name.SerializeToString(deterministic=True))
333
+ else:
334
+ task_hash = task_spec_or_name
193
335
  sub_action_id = current_action_id.new_sub_action_from(
194
- task_name=task_name,
336
+ task_hash=task_hash,
195
337
  input_hash=inputs_hash,
196
338
  group=tctx.group_data.name if tctx.group_data else None,
339
+ task_call_seq=invoke_seq,
197
340
  )
198
341
  sub_run_output_path = storage.join(current_output_path, sub_action_id.name)
199
342
  return sub_action_id, sub_run_output_path
@@ -3,11 +3,11 @@ from typing import List, Optional, Tuple
3
3
  import flyte.errors
4
4
  from flyte._code_bundle import download_bundle
5
5
  from flyte._context import contextual_run
6
- from flyte._datastructures import ActionID, Checkpoints, CodeBundle, RawDataPath
7
6
  from flyte._internal import Controller
8
7
  from flyte._internal.imagebuild.image_builder import ImageCache
9
8
  from flyte._logging import log, logger
10
9
  from flyte._task import TaskTemplate
10
+ from flyte.models import ActionID, Checkpoints, CodeBundle, RawDataPath
11
11
 
12
12
  from .convert import Error, Inputs, Outputs
13
13
  from .task_serde import load_task
@@ -21,11 +21,11 @@ _OUTPUTS_FILE_NAME = "outputs.pb"
21
21
  _CHECKPOINT_FILE_NAME = "_flytecheckpoints"
22
22
  _ERROR_FILE_NAME = "error.pb"
23
23
  _REPORT_FILE_NAME = "report.html"
24
- _PKL_FILE_NAME = "code_bundle.pkl.gz"
24
+ _PKL_EXT = ".pkl.gz"
25
25
 
26
26
 
27
def pkl_path(base_path: str, pkl_name: str) -> str:
    """
    Build the storage path of a pickled code bundle.

    :param base_path: Base directory that holds the bundle.
    :param pkl_name: File name of the bundle without the ``_PKL_EXT`` suffix.
    :return: ``base_path`` joined with ``<pkl_name><_PKL_EXT>``.
    """
    file_name = f"{pkl_name}{_PKL_EXT}"
    return storage.join(base_path, file_name)
29
29
 
30
30
 
31
31
  def inputs_path(base_path: str) -> str:
@@ -3,26 +3,32 @@ This module provides functionality to serialize and deserialize tasks to and fro
3
3
  It includes a Resolver interface for loading tasks, and functions to load classes and tasks.
4
4
  """
5
5
 
6
+ import copy
6
7
  import importlib
8
+ import typing
7
9
  from datetime import timedelta
8
- from typing import Optional, Type
10
+ from typing import Optional, Type, cast
9
11
 
10
12
  from flyteidl.core import identifier_pb2, literals_pb2, security_pb2, tasks_pb2
11
13
  from google.protobuf import duration_pb2, wrappers_pb2
12
14
 
13
15
  import flyte.errors
14
16
  from flyte._cache.cache import VersionParameters, cache_from_request
15
- from flyte._datastructures import SerializationContext
16
17
  from flyte._logging import logger
17
- from flyte._protos.workflow import task_definition_pb2
18
+ from flyte._pod import _PRIMARY_CONTAINER_NAME_FIELD, PodTemplate
19
+ from flyte._protos.workflow import common_pb2, environment_pb2, task_definition_pb2
18
20
  from flyte._secret import SecretRequest, secrets_from_request
19
21
  from flyte._task import AsyncFunctionTaskTemplate, TaskTemplate
22
+ from flyte.models import CodeBundle, SerializationContext
20
23
 
21
24
  from ..._retry import RetryStrategy
22
25
  from ..._timeout import TimeoutType, timeout_from_request
23
26
  from .resources_serde import get_proto_extended_resources, get_proto_resources
24
27
  from .types_serde import transform_native_to_typed_interface
25
28
 
29
+ _MAX_ENV_NAME_LENGTH = 63 # Maximum length for environment names
30
+ _MAX_TASK_SHORT_NAME_LENGTH = 63 # Maximum length for task short names
31
+
26
32
 
27
33
  def load_class(qualified_name) -> Type:
28
34
  """
@@ -49,17 +55,31 @@ def load_task(resolver: str, *resolver_args: str) -> TaskTemplate:
49
55
 
50
56
 
51
57
  def translate_task_to_wire(
52
- task: TaskTemplate, serialization_context: SerializationContext
58
+ task: TaskTemplate,
59
+ serialization_context: SerializationContext,
60
+ default_inputs: Optional[typing.List[common_pb2.NamedParameter]] = None,
53
61
  ) -> task_definition_pb2.TaskSpec:
54
62
  """
55
63
  Translate a task to a wire format. This is a placeholder function.
56
64
 
57
65
  :param task: The task to translate.
58
66
  :param serialization_context: The serialization context to use for the translation.
67
+ :param default_inputs: Optional list of default inputs for the task.
59
68
 
60
69
  :return: The translated task.
61
70
  """
62
- return get_proto_task(task, serialization_context)
71
+ tt = get_proto_task(task, serialization_context)
72
+ env: environment_pb2.Environment | None = None
73
+ if task.parent_env and task.parent_env():
74
+ _env = task.parent_env()
75
+ if _env:
76
+ env = environment_pb2.Environment(name=_env.name[:_MAX_ENV_NAME_LENGTH])
77
+ return task_definition_pb2.TaskSpec(
78
+ task_template=tt,
79
+ default_inputs=default_inputs,
80
+ short_name=task.friendly_name[:_MAX_TASK_SHORT_NAME_LENGTH],
81
+ environment=env,
82
+ )
63
83
 
64
84
 
65
85
  def get_security_context(secrets: Optional[SecretRequest]) -> Optional[security_pb2.SecurityContext]:
@@ -108,7 +128,7 @@ def get_proto_timeout(timeout: TimeoutType | None) -> Optional[duration_pb2.Dura
108
128
  return duration_pb2.Duration(seconds=max_runtime_timeout.seconds)
109
129
 
110
130
 
111
- def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext) -> task_definition_pb2.TaskSpec:
131
+ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext) -> tasks_pb2.TaskTemplate:
112
132
  task_id = identifier_pb2.Identifier(
113
133
  resource_type=identifier_pb2.ResourceType.TASK,
114
134
  project=serialize_context.project,
@@ -121,17 +141,18 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
121
141
  # if task.parent_env is None:
122
142
  # raise ValueError(f"Task {task.name} must have a parent environment")
123
143
 
124
- #
125
- # This pod will be incorrect when doing fast serialize
126
- #
127
- container = _get_urun_container(serialize_context, task)
144
+ # TODO Add support for SQL, extra_config, custom
145
+ extra_config: typing.Dict[str, str] = {}
146
+ custom = {} # type: ignore
128
147
 
129
- # TODO Add support for SQL, Pod, extra_config, custom
130
- pod = None
131
148
  sql = None
132
- # pod = task.get_k8s_pod(serialize_context)
133
- extra_config = {} # type: ignore
134
- custom = {} # type: ignore
149
+ if task.pod_template and not isinstance(task.pod_template, str):
150
+ container = None
151
+ pod = _get_k8s_pod(_get_urun_container(serialize_context, task), task.pod_template)
152
+ extra_config[_PRIMARY_CONTAINER_NAME_FIELD] = task.pod_template.primary_container_name
153
+ else:
154
+ container = _get_urun_container(serialize_context, task)
155
+ pod = None
135
156
 
136
157
  # -------------- CACHE HANDLING ----------------------
137
158
  task_cache = cache_from_request(task.cache)
@@ -154,20 +175,24 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
154
175
  else:
155
176
  logger.debug(f"Cache disabled for task {task.name}")
156
177
 
157
- tt = tasks_pb2.TaskTemplate(
178
+ return tasks_pb2.TaskTemplate(
158
179
  id=task_id,
159
180
  type=task.task_type,
160
181
  metadata=tasks_pb2.TaskMetadata(
161
182
  discoverable=cache_enabled,
162
183
  discovery_version=cache_version,
163
184
  cache_serializable=task_cache.serialize,
164
- cache_ignore_input_vars=task_cache.ignored_inputs,
165
- runtime=tasks_pb2.RuntimeMetadata(),
185
+ cache_ignore_input_vars=task_cache.get_ignored_inputs() if cache_enabled else None,
186
+ runtime=tasks_pb2.RuntimeMetadata(
187
+ version=flyte.version(),
188
+ type=tasks_pb2.RuntimeMetadata.RuntimeType.FLYTE_SDK,
189
+ flavor="python",
190
+ ),
166
191
  retries=get_proto_retry_strategy(task.retries),
167
192
  timeout=get_proto_timeout(task.timeout),
168
193
  pod_template_name=task.pod_template if task.pod_template and isinstance(task.pod_template, str) else None,
169
194
  interruptible=task.interruptable,
170
- generates_deck=wrappers_pb2.BoolValue(value=False), # TODO add support for reports
195
+ generates_deck=wrappers_pb2.BoolValue(value=task.report),
171
196
  ),
172
197
  interface=transform_native_to_typed_interface(task.native_interface),
173
198
  custom=custom,
@@ -179,7 +204,6 @@ def get_proto_task(task: TaskTemplate, serialize_context: SerializationContext)
179
204
  sql=sql,
180
205
  extended_resources=get_proto_extended_resources(task.resources),
181
206
  )
182
- return task_definition_pb2.TaskSpec(task_template=tt)
183
207
 
184
208
 
185
209
  def _get_urun_container(
@@ -208,3 +232,99 @@ def _get_urun_container(
208
232
  data_config=task_template.data_loading_config(serialize_context),
209
233
  config=task_template.config(serialize_context),
210
234
  )
235
+
236
+
237
def _sanitize_resource_name(resource: tasks_pb2.Resources.ResourceEntry) -> str:
    """
    Convert a resource entry's enum name into the lowercase, dash-separated form used as a k8s resource key.

    For example, an enum name shaped like ``SOME_RESOURCE`` becomes ``some-resource``.
    """
    enum_name = tasks_pb2.Resources.ResourceName.Name(resource.name)
    return enum_name.lower().replace("_", "-")
239
+
240
+
241
def _get_k8s_pod(primary_container: tasks_pb2.Container, pod_template: PodTemplate) -> tasks_pb2.K8sPod:
    """
    Build the K8sPod representation of a task from its pod template.

    The pod template is deep-copied, so the caller's object is not mutated. Every container whose name matches
    ``pod_template.primary_container_name`` is overlaid with the settings of ``primary_container``:

    * ``image`` is copied only if the pod spec container does not already specify one,
    * ``command`` and ``args`` are always replaced,
    * resource limits/requests replace the pod spec container's resources only when at least one is set,
    * ``env`` becomes the task's variables followed by those declared on the pod spec container.

    Other containers are passed through unchanged.

    :param primary_container: The container generated for the task (image, command, args, resources, env).
    :param pod_template: The pod template whose pod spec must contain the primary container.
    :return: The K8sPod holding the serialized pod spec plus the template's labels and annotations as metadata.
    :raises ValueError: If no container in the pod spec is named ``pod_template.primary_container_name``.
    """
    from kubernetes.client import ApiClient, V1PodSpec
    from kubernetes.client.models import V1EnvVar, V1ResourceRequirements

    # Work on a deep copy so the overlay below never mutates the caller's PodTemplate.
    pod_template = copy.deepcopy(pod_template)
    pod_spec = cast(V1PodSpec, pod_template.pod_spec)
    primary_name = pod_template.primary_container_name

    if not any(container.name == primary_name for container in pod_spec.containers):
        raise ValueError(
            "No primary container defined in the pod spec."
            f" You must define a primary container with the name '{primary_name}'."
        )

    # Containers are modified in place; their order and membership in the pod spec are unchanged.
    for container in pod_spec.containers:
        if container.name != primary_name:
            continue

        if container.image is None:
            # Copy the image from primary_container only if the image is not specified in the pod spec.
            container.image = primary_container.image

        container.command = list(primary_container.command)
        container.args = list(primary_container.args)

        limits = {_sanitize_resource_name(r): r.value for r in primary_container.resources.limits}
        requests = {_sanitize_resource_name(r): r.value for r in primary_container.resources.requests}
        if limits or requests:
            # Important! Only overwrite the pod spec's resource requirements if the task sets any.
            container.resources = V1ResourceRequirements(limits=limits, requests=requests)

        if primary_container.env is not None:
            # Union of env vars: the task's variables first, then those declared on the pod spec container.
            container.env = [V1EnvVar(name=e.key, value=e.value) for e in primary_container.env] + (
                container.env or []
            )

    serialized_pod_spec = ApiClient().sanitize_for_serialization(pod_spec)
    metadata = tasks_pb2.K8sObjectMetadata(labels=pod_template.labels, annotations=pod_template.annotations)
    return tasks_pb2.K8sPod(pod_spec=serialized_pod_spec, metadata=metadata)
301
+
302
+
303
def extract_code_bundle(task_spec: task_definition_pb2.TaskSpec) -> Optional[CodeBundle]:
    """
    Recover the code bundle referenced by a task spec's container arguments.

    The container args are scanned for ``--pkl``, ``--tgz``, ``--dest`` and ``--version``. Each flag takes the
    argument that immediately follows it. If a flag is the last argument, it falls back to its default: no path
    for ``--pkl``/``--tgz``, ``"."`` for ``--dest``, and ``""`` for ``--version``. A repeated flag keeps its
    last value.

    :param task_spec: The task spec to extract the code bundle from.
    :return: The extracted code bundle, or None when the container has no args or neither a pkl nor a tgz path
        is present.
    """
    container = task_spec.task_template.container
    if not (container and container.args):
        return None

    args = list(container.args)
    defaults = {"--pkl": None, "--tgz": None, "--dest": ".", "--version": ""}
    values = dict(defaults)
    for idx, arg in enumerate(args):
        if arg in defaults:
            has_value = idx + 1 < len(args)
            values[arg] = args[idx + 1] if has_value else defaults[arg]

    if not (values["--pkl"] or values["--tgz"]):
        return None
    return CodeBundle(
        destination=values["--dest"],
        tgz=values["--tgz"],
        pkl=values["--pkl"],
        computed_version=values["--version"],
    )