flyte 2.0.0b32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (204) hide show
  1. flyte/__init__.py +108 -0
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +195 -0
  5. flyte/_bin/serve.py +178 -0
  6. flyte/_build.py +26 -0
  7. flyte/_cache/__init__.py +12 -0
  8. flyte/_cache/cache.py +147 -0
  9. flyte/_cache/defaults.py +9 -0
  10. flyte/_cache/local_cache.py +216 -0
  11. flyte/_cache/policy_function_body.py +42 -0
  12. flyte/_code_bundle/__init__.py +8 -0
  13. flyte/_code_bundle/_ignore.py +121 -0
  14. flyte/_code_bundle/_packaging.py +218 -0
  15. flyte/_code_bundle/_utils.py +347 -0
  16. flyte/_code_bundle/bundle.py +266 -0
  17. flyte/_constants.py +1 -0
  18. flyte/_context.py +155 -0
  19. flyte/_custom_context.py +73 -0
  20. flyte/_debug/__init__.py +0 -0
  21. flyte/_debug/constants.py +38 -0
  22. flyte/_debug/utils.py +17 -0
  23. flyte/_debug/vscode.py +307 -0
  24. flyte/_deploy.py +408 -0
  25. flyte/_deployer.py +109 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +122 -0
  29. flyte/_excepthook.py +37 -0
  30. flyte/_group.py +32 -0
  31. flyte/_hash.py +8 -0
  32. flyte/_image.py +1055 -0
  33. flyte/_initialize.py +628 -0
  34. flyte/_interface.py +119 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +129 -0
  37. flyte/_internal/controllers/_local_controller.py +239 -0
  38. flyte/_internal/controllers/_trace.py +48 -0
  39. flyte/_internal/controllers/remote/__init__.py +58 -0
  40. flyte/_internal/controllers/remote/_action.py +211 -0
  41. flyte/_internal/controllers/remote/_client.py +47 -0
  42. flyte/_internal/controllers/remote/_controller.py +583 -0
  43. flyte/_internal/controllers/remote/_core.py +465 -0
  44. flyte/_internal/controllers/remote/_informer.py +381 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +3 -0
  47. flyte/_internal/imagebuild/docker_builder.py +706 -0
  48. flyte/_internal/imagebuild/image_builder.py +277 -0
  49. flyte/_internal/imagebuild/remote_builder.py +386 -0
  50. flyte/_internal/imagebuild/utils.py +78 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +21 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +486 -0
  57. flyte/_internal/runtime/entrypoints.py +204 -0
  58. flyte/_internal/runtime/io.py +188 -0
  59. flyte/_internal/runtime/resources_serde.py +152 -0
  60. flyte/_internal/runtime/reuse.py +125 -0
  61. flyte/_internal/runtime/rusty.py +193 -0
  62. flyte/_internal/runtime/task_serde.py +362 -0
  63. flyte/_internal/runtime/taskrunner.py +209 -0
  64. flyte/_internal/runtime/trigger_serde.py +160 -0
  65. flyte/_internal/runtime/types_serde.py +54 -0
  66. flyte/_keyring/__init__.py +0 -0
  67. flyte/_keyring/file.py +115 -0
  68. flyte/_logging.py +300 -0
  69. flyte/_map.py +312 -0
  70. flyte/_module.py +72 -0
  71. flyte/_pod.py +30 -0
  72. flyte/_resources.py +473 -0
  73. flyte/_retry.py +32 -0
  74. flyte/_reusable_environment.py +102 -0
  75. flyte/_run.py +724 -0
  76. flyte/_secret.py +96 -0
  77. flyte/_task.py +550 -0
  78. flyte/_task_environment.py +316 -0
  79. flyte/_task_plugins.py +47 -0
  80. flyte/_timeout.py +47 -0
  81. flyte/_tools.py +27 -0
  82. flyte/_trace.py +119 -0
  83. flyte/_trigger.py +1000 -0
  84. flyte/_utils/__init__.py +30 -0
  85. flyte/_utils/asyn.py +121 -0
  86. flyte/_utils/async_cache.py +139 -0
  87. flyte/_utils/coro_management.py +27 -0
  88. flyte/_utils/docker_credentials.py +173 -0
  89. flyte/_utils/file_handling.py +72 -0
  90. flyte/_utils/helpers.py +134 -0
  91. flyte/_utils/lazy_module.py +54 -0
  92. flyte/_utils/module_loader.py +104 -0
  93. flyte/_utils/org_discovery.py +57 -0
  94. flyte/_utils/uv_script_parser.py +49 -0
  95. flyte/_version.py +34 -0
  96. flyte/app/__init__.py +22 -0
  97. flyte/app/_app_environment.py +157 -0
  98. flyte/app/_deploy.py +125 -0
  99. flyte/app/_input.py +160 -0
  100. flyte/app/_runtime/__init__.py +3 -0
  101. flyte/app/_runtime/app_serde.py +347 -0
  102. flyte/app/_types.py +101 -0
  103. flyte/app/extras/__init__.py +3 -0
  104. flyte/app/extras/_fastapi.py +151 -0
  105. flyte/cli/__init__.py +12 -0
  106. flyte/cli/_abort.py +28 -0
  107. flyte/cli/_build.py +114 -0
  108. flyte/cli/_common.py +468 -0
  109. flyte/cli/_create.py +371 -0
  110. flyte/cli/_delete.py +45 -0
  111. flyte/cli/_deploy.py +293 -0
  112. flyte/cli/_gen.py +176 -0
  113. flyte/cli/_get.py +370 -0
  114. flyte/cli/_option.py +33 -0
  115. flyte/cli/_params.py +554 -0
  116. flyte/cli/_plugins.py +209 -0
  117. flyte/cli/_run.py +597 -0
  118. flyte/cli/_serve.py +64 -0
  119. flyte/cli/_update.py +37 -0
  120. flyte/cli/_user.py +17 -0
  121. flyte/cli/main.py +221 -0
  122. flyte/config/__init__.py +3 -0
  123. flyte/config/_config.py +248 -0
  124. flyte/config/_internal.py +73 -0
  125. flyte/config/_reader.py +225 -0
  126. flyte/connectors/__init__.py +11 -0
  127. flyte/connectors/_connector.py +270 -0
  128. flyte/connectors/_server.py +197 -0
  129. flyte/connectors/utils.py +135 -0
  130. flyte/errors.py +243 -0
  131. flyte/extend.py +19 -0
  132. flyte/extras/__init__.py +5 -0
  133. flyte/extras/_container.py +286 -0
  134. flyte/git/__init__.py +3 -0
  135. flyte/git/_config.py +21 -0
  136. flyte/io/__init__.py +29 -0
  137. flyte/io/_dataframe/__init__.py +131 -0
  138. flyte/io/_dataframe/basic_dfs.py +223 -0
  139. flyte/io/_dataframe/dataframe.py +1026 -0
  140. flyte/io/_dir.py +910 -0
  141. flyte/io/_file.py +914 -0
  142. flyte/io/_hashing_io.py +342 -0
  143. flyte/models.py +479 -0
  144. flyte/py.typed +0 -0
  145. flyte/remote/__init__.py +35 -0
  146. flyte/remote/_action.py +738 -0
  147. flyte/remote/_app.py +57 -0
  148. flyte/remote/_client/__init__.py +0 -0
  149. flyte/remote/_client/_protocols.py +189 -0
  150. flyte/remote/_client/auth/__init__.py +12 -0
  151. flyte/remote/_client/auth/_auth_utils.py +14 -0
  152. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  153. flyte/remote/_client/auth/_authenticators/base.py +403 -0
  154. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  155. flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
  156. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  157. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  158. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  159. flyte/remote/_client/auth/_channel.py +213 -0
  160. flyte/remote/_client/auth/_client_config.py +85 -0
  161. flyte/remote/_client/auth/_default_html.py +32 -0
  162. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  163. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  164. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  165. flyte/remote/_client/auth/_keyring.py +152 -0
  166. flyte/remote/_client/auth/_token_client.py +260 -0
  167. flyte/remote/_client/auth/errors.py +16 -0
  168. flyte/remote/_client/controlplane.py +128 -0
  169. flyte/remote/_common.py +30 -0
  170. flyte/remote/_console.py +19 -0
  171. flyte/remote/_data.py +161 -0
  172. flyte/remote/_logs.py +185 -0
  173. flyte/remote/_project.py +88 -0
  174. flyte/remote/_run.py +386 -0
  175. flyte/remote/_secret.py +142 -0
  176. flyte/remote/_task.py +527 -0
  177. flyte/remote/_trigger.py +306 -0
  178. flyte/remote/_user.py +33 -0
  179. flyte/report/__init__.py +3 -0
  180. flyte/report/_report.py +182 -0
  181. flyte/report/_template.html +124 -0
  182. flyte/storage/__init__.py +36 -0
  183. flyte/storage/_config.py +237 -0
  184. flyte/storage/_parallel_reader.py +274 -0
  185. flyte/storage/_remote_fs.py +34 -0
  186. flyte/storage/_storage.py +456 -0
  187. flyte/storage/_utils.py +5 -0
  188. flyte/syncify/__init__.py +56 -0
  189. flyte/syncify/_api.py +375 -0
  190. flyte/types/__init__.py +52 -0
  191. flyte/types/_interface.py +40 -0
  192. flyte/types/_pickle.py +145 -0
  193. flyte/types/_renderer.py +162 -0
  194. flyte/types/_string_literals.py +119 -0
  195. flyte/types/_type_engine.py +2254 -0
  196. flyte/types/_utils.py +80 -0
  197. flyte-2.0.0b32.data/scripts/debug.py +38 -0
  198. flyte-2.0.0b32.data/scripts/runtime.py +195 -0
  199. flyte-2.0.0b32.dist-info/METADATA +351 -0
  200. flyte-2.0.0b32.dist-info/RECORD +204 -0
  201. flyte-2.0.0b32.dist-info/WHEEL +5 -0
  202. flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
  203. flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
  204. flyte-2.0.0b32.dist-info/top_level.txt +1 -0
flyte/remote/_task.py ADDED
@@ -0,0 +1,527 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import functools
5
+ from dataclasses import dataclass
6
+ from typing import Any, AsyncIterator, Callable, Coroutine, Dict, Iterator, Literal, Optional, Tuple, Union, cast
7
+
8
+ import rich.repr
9
+ from flyteidl2.common import identifier_pb2, list_pb2
10
+ from flyteidl2.core import literals_pb2
11
+ from flyteidl2.task import task_definition_pb2, task_service_pb2
12
+
13
+ import flyte
14
+ import flyte.errors
15
+ from flyte._cache.cache import CacheBehavior
16
+ from flyte._context import internal_ctx
17
+ from flyte._initialize import ensure_client, get_client, get_init_config
18
+ from flyte._internal.runtime.resources_serde import get_proto_resources
19
+ from flyte._internal.runtime.task_serde import get_proto_retry_strategy, get_proto_timeout, get_security_context
20
+ from flyte._logging import logger
21
+ from flyte.models import NativeInterface
22
+ from flyte.syncify import syncify
23
+
24
+ from ._common import ToJSONMixin
25
+
26
+
27
+ def _repr_task_metadata(metadata: task_definition_pb2.TaskMetadata) -> rich.repr.Result:
28
+ """
29
+ Rich representation of the task metadata.
30
+ """
31
+ if metadata.deployed_by:
32
+ if metadata.deployed_by.user:
33
+ yield "deployed_by", f"User: {metadata.deployed_by.user.spec.email}"
34
+ else:
35
+ yield "deployed_by", f"App: {metadata.deployed_by.application.spec.name}"
36
+ yield "short_name", metadata.short_name
37
+ yield "deployed_at", metadata.deployed_at.ToDatetime()
38
+ yield "environment_name", metadata.environment_name
39
+
40
+
41
class LazyEntity:
    """
    A deferred handle to a remote task.

    Nothing is downloaded when the handle is created. The supplied getter is awaited the first time the
    entity is fetched or called, and the resulting task details are kept on the handle for later use.
    """

    def __init__(self, name: str, getter: Callable[..., Coroutine[Any, Any, TaskDetails]], *args, **kwargs):
        # Extra positional/keyword arguments are accepted but not used.
        self._name = name
        self._getter = getter
        self._task: Optional[TaskDetails] = None
        # Guards the fetch so the check-and-populate of ``_task`` runs under a lock.
        self._mutex = asyncio.Lock()

    @property
    def name(self) -> str:
        """The name this entity was created with."""
        return self._name

    @syncify
    async def fetch(self) -> TaskDetails:
        """
        Return the task details, awaiting the getter if they have not been retrieved yet.

        :raises RuntimeError: If the getter completes but produces ``None``.
        """
        async with self._mutex:
            if self._task is not None:
                return self._task
            details = await self._getter()
            self._task = details
            if details is None:
                raise RuntimeError(f"Error downloading the task {self._name}, (check original exception...)")
            return details

    @syncify
    async def override(
        self,
        **kwargs: Any,
    ) -> LazyEntity:
        """
        Fetch the task and return a new entity wrapping its overridden details.

        :param kwargs: Forwarded unchanged to ``TaskDetails.override``.
        :return: A new ``LazyEntity`` with the same name and getter, pre-populated with the overridden task.
        """
        current = cast(TaskDetails, await self.fetch.aio())
        overridden_details = current.override(**kwargs)
        clone = LazyEntity(self._name, self._getter)
        clone._task = overridden_details
        return clone

    async def __call__(self, *args, **kwargs):
        """
        Fetch the task if needed, then invoke it with the given arguments.
        """
        task = await self.fetch.aio()
        result = await task(*args, **kwargs)
        return result

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        return f"Future for task with name {self._name}"
92
+
93
+
94
# Accepted values for the ``auto_version`` argument of ``TaskDetails.get`` / ``Task.get``.
AutoVersioning = Literal["latest", "current"]
95
+
96
+
97
@dataclass(frozen=True)
class TaskDetails(ToJSONMixin):
    """
    Immutable wrapper around a ``task_definition_pb2.TaskDetails`` message for a remote (reference) task.

    Exposes convenience accessors over the protobuf, allows producing overridden copies via ``override``,
    and can be awaited as a callable to submit the task through the controller.
    """

    pb2: task_definition_pb2.TaskDetails
    max_inline_io_bytes: int = 10 * 1024 * 1024  # 10 MB
    overriden_queue: Optional[str] = None

    @classmethod
    def get(
        cls,
        name: str,
        project: str | None,
        domain: str | None,
        version: str | None = None,
        auto_version: AutoVersioning | None = None,
    ) -> LazyEntity:
        """
        Create a lazy reference to a remote task. Nothing is fetched until the returned entity is
        fetched or called.

        Either version or auto_version are required parameters. If both are given, version is used.

        :param name: The name of the task.
        :param project: The project of the task. Falls back to the configured project when empty.
        :param domain: The domain of the task. Falls back to the configured domain when empty.
        :param version: The version of the task.
        :param auto_version: If set to "latest", the most recently created version of the task will be used.
         If set to "current", the version will be taken from the calling task's context. This is useful if you are
         deploying all environments with the same version. If auto_version is current, you can only access the task
         from within a task context.
        :raises ValueError: If neither version nor a valid auto_version is provided.
        """

        if version is None and auto_version is None:
            raise ValueError("Either version or auto_version must be provided.")

        if version is None and auto_version not in ["latest", "current"]:
            raise ValueError("auto_version must be either 'latest' or 'current'.")

        async def deferred_get(_version: str | None, _auto_version: AutoVersioning | None) -> TaskDetails:
            if _version is None:
                if _auto_version == "latest":
                    # Ask for a single task, newest first, to discover the latest version.
                    tasks = []
                    async for x in Task.listall.aio(
                        by_task_name=name,
                        project=project,
                        domain=domain,
                        sort_by=("created_at", "desc"),
                        limit=1,
                    ):
                        tasks.append(x)
                    if not tasks:
                        raise flyte.errors.ReferenceTaskError(f"Task {name} not found.")
                    _version = tasks[0].version
                elif _auto_version == "current":
                    ctx = flyte.ctx()
                    if ctx is None:
                        raise ValueError("auto_version=current can only be used within a task context.")
                    _version = ctx.version
            cfg = get_init_config()
            task_id = task_definition_pb2.TaskIdentifier(
                org=cfg.org,
                project=project or cfg.project,
                domain=domain or cfg.domain,
                name=name,
                version=_version,
            )
            resp = await get_client().task_service.GetTaskDetails(
                task_service_pb2.GetTaskDetailsRequest(
                    task_id=task_id,
                )
            )
            return cls(resp.details)

        return LazyEntity(
            name=name, getter=functools.partial(deferred_get, _version=version, _auto_version=auto_version)
        )

    @classmethod
    async def fetch(
        cls,
        name: str,
        project: str | None = None,
        domain: str | None = None,
        version: str | None = None,
        auto_version: AutoVersioning | None = None,
    ) -> TaskDetails:
        """
        Resolve and download the task details immediately.

        Takes the same arguments as ``get`` and awaits the resulting lazy entity.
        """
        lazy = TaskDetails.get(name, project=project, domain=domain, version=version, auto_version=auto_version)
        return await lazy.fetch.aio()

    @property
    def name(self) -> str:
        """
        The name of the task.
        """
        return self.pb2.task_id.name

    @property
    def version(self) -> str:
        """
        The version of the task.
        """
        return self.pb2.task_id.version

    @property
    def task_type(self) -> str:
        """
        The type of the task.
        """
        return self.pb2.spec.task_template.type

    @property
    def default_input_args(self) -> Tuple[str, ...]:
        """
        The names of the inputs that have a default value.
        """
        return tuple(x.name for x in self.pb2.spec.default_inputs)

    @property
    def required_args(self) -> Tuple[str, ...]:
        """
        The names of the inputs that have no default value, in interface order.
        """
        # Build the lookup once instead of recomputing the defaults tuple for every input.
        defaults = set(self.default_input_args)
        return tuple(name for name in self.interface.inputs if name not in defaults)

    @functools.cached_property
    def interface(self) -> NativeInterface:
        """
        The interface of the task.
        """
        import flyte.types as types

        return types.guess_interface(self.pb2.spec.task_template.interface, default_inputs=self.pb2.spec.default_inputs)

    @property
    def cache(self) -> flyte.Cache:
        """
        The cache policy of the task, reconstructed from the task template metadata.
        """
        metadata = self.pb2.spec.task_template.metadata
        behavior: CacheBehavior
        if not metadata.discoverable:
            behavior = "disable"
        elif metadata.discovery_version:
            behavior = "override"
        else:
            behavior = "auto"

        return flyte.Cache(
            behavior=behavior,
            version_override=metadata.discovery_version if metadata.discovery_version else None,
            serialize=metadata.cache_serializable,
            ignored_inputs=tuple(metadata.cache_ignore_input_vars),
        )

    @property
    def secrets(self):
        """
        The keys of the secrets requested by the task.
        """
        return [s.key for s in self.pb2.spec.task_template.security_context.secrets]

    @property
    def resources(self):
        """
        The ``(requests, limits)`` resources of the task's container, or ``()`` when the task has no container.
        """
        template = self.pb2.spec.task_template
        # A protobuf sub-message is never ``None``; presence has to be checked with HasField().
        if not template.HasField("container"):
            return ()
        return (
            template.container.resources.requests,
            template.container.resources.limits,
        )

    async def __call__(self, *args, **kwargs):
        """
        Submit this reference task to the controller with the given keyword arguments.

        :raises flyte.errors.ReferenceTaskError: If positional arguments are passed, or if the call is made
            outside a task context or no controller is available.
        :raises ValueError: If fewer arguments are supplied than the task requires.
        """
        # TODO support kwargs, for this we need ordered inputs to be stored in the task spec.
        if len(args) > 0:
            raise flyte.errors.ReferenceTaskError(
                f"Reference task {self.name} does not support positional arguments "
                f"currently. Please use keyword arguments."
            )

        ctx = internal_ctx()
        if ctx.is_task_context():
            # If we are in a task context, that implies we are executing a Run.
            # In this scenario, we should submit the task to the controller.
            # We will also check if we are not initialized, It is not expected to be not initialized
            from flyte._internal.controllers import get_controller

            controller = get_controller()
            if len(self.required_args) > 0:
                if len(args) + len(kwargs) < len(self.required_args):
                    raise ValueError(
                        f"Task {self.name} requires at least {self.required_args} arguments, "
                        f"but only received args:{args} kwargs{kwargs}."
                    )
            if controller:
                return await controller.submit_task_ref(self, *args, **kwargs)
        raise flyte.errors.ReferenceTaskError(
            f"Reference tasks [{self.name}] cannot be executed locally, only remotely."
        )

    @property
    def queue(self) -> Optional[str]:
        """
        The queue to use for the task.
        """
        return self.overriden_queue

    def override(
        self,
        *,
        short_name: Optional[str] = None,
        resources: Optional[flyte.Resources] = None,
        retries: Union[int, flyte.RetryStrategy] = 0,
        timeout: Optional[flyte.TimeoutType] = None,
        env_vars: Optional[Dict[str, str]] = None,
        secrets: Optional[flyte.SecretRequest] = None,
        max_inline_io_bytes: Optional[int] = None,
        cache: Optional[flyte.Cache] = None,
        queue: Optional[str] = None,
        **kwargs: Any,
    ) -> TaskDetails:
        """
        Return a copy of this task with the given settings overridden. The original is left untouched.

        Arguments that are left at their (falsy) defaults keep the task's current value.

        :param short_name: New short name, written to the task metadata.
        :param resources: New container resources. Only applied when the task template has a container.
        :param retries: New retry strategy. A value of 0 leaves the existing strategy unchanged.
        :param timeout: New timeout.
        :param env_vars: Environment variables that replace (not merge with) the container's existing ones.
            Only applied when the task template has a container.
        :param secrets: New secrets, replacing the task's security context.
        :param max_inline_io_bytes: New inline IO limit.
        :param cache: New cache policy. Only the "disable" and "override" behaviors are supported.
        :param queue: Queue to run the task on. When omitted, a previously overridden queue is kept.
        :param kwargs: Not supported; any extra keyword raises.
        :raises ValueError: If unknown keyword arguments are passed or the cache policy is unsupported.
        """
        if len(kwargs) > 0:
            raise ValueError(
                f"ReferenceTasks [{self.name}] do not support overriding with kwargs: {kwargs}, "
                f"Check the parameters for override method."
            )
        # Work on a copy so that this (frozen) instance's protobuf is never mutated.
        pb2 = task_definition_pb2.TaskDetails()
        pb2.CopyFrom(self.pb2)

        if short_name:
            pb2.metadata.short_name = short_name

        template = pb2.spec.task_template
        if secrets:
            template.security_context.CopyFrom(get_security_context(secrets))

        if template.HasField("container"):
            if env_vars:
                template.container.env.clear()
                template.container.env.extend([literals_pb2.KeyValuePair(key=k, value=v) for k, v in env_vars.items()])
            if resources:
                template.container.resources.CopyFrom(get_proto_resources(resources))

        md = template.metadata
        if retries:
            md.retries.CopyFrom(get_proto_retry_strategy(retries))

        if timeout:
            md.timeout.CopyFrom(get_proto_timeout(timeout))

        if cache:
            if cache.behavior == "disable":
                md.discoverable = False
                md.discovery_version = ""
            elif cache.behavior == "override":
                md.discoverable = True
                if not cache.version_override:
                    raise ValueError("cache.version_override must be set when cache.behavior is 'override'")
                md.discovery_version = cache.version_override
            else:
                if cache.behavior == "auto":
                    raise ValueError("cache.behavior must be 'disable' or 'override' for reference tasks")
                raise ValueError(f"Invalid cache behavior: {cache.behavior}.")
            md.cache_serializable = cache.serialize
            md.cache_ignore_input_vars[:] = list(cache.ignored_inputs or ())

        return TaskDetails(
            pb2,
            max_inline_io_bytes=max_inline_io_bytes or self.max_inline_io_bytes,
            # Keep an earlier queue override when this call does not specify one.
            overriden_queue=queue if queue is not None else self.overriden_queue,
        )

    def __rich_repr__(self) -> rich.repr.Result:
        """
        Rich representation of the task.
        """
        # NOTE(review): this reads spec.short_name while override() writes metadata.short_name - confirm
        # which field is intended to be displayed after an override.
        yield "short_name", self.pb2.spec.short_name
        yield "environment", self.pb2.spec.environment
        yield "default_inputs_keys", self.default_input_args
        yield "required_args", self.required_args
        yield "raw_default_inputs", [str(x) for x in self.pb2.spec.default_inputs]
        yield "project", self.pb2.task_id.project
        yield "domain", self.pb2.task_id.domain
        yield "name", self.name
        yield "version", self.version
        yield "task_type", self.task_type
        yield "cache", self.cache
        yield "interface", self.name + str(self.interface)
        yield "secrets", self.secrets
        yield "resources", self.resources
389
+
390
+
391
@dataclass
class Task(ToJSONMixin):
    """
    Lightweight wrapper around a ``task_definition_pb2.Task`` message, as returned when listing tasks.
    """

    pb2: task_definition_pb2.Task

    def __init__(self, pb2: task_definition_pb2.Task):
        self.pb2 = pb2

    @property
    def name(self) -> str:
        """
        The name of the task.
        """
        return self.pb2.task_id.name

    @property
    def version(self) -> str:
        """
        The version of the task.
        """
        return self.pb2.task_id.version

    @classmethod
    def get(
        cls,
        name: str,
        project: str | None = None,
        domain: str | None = None,
        version: str | None = None,
        auto_version: AutoVersioning | None = None,
    ) -> LazyEntity:
        """
        Create a lazy reference to a remote task. Delegates to ``TaskDetails.get``.

        Either version or auto_version are required parameters.

        :param name: The name of the task.
        :param project: The project of the task.
        :param domain: The domain of the task.
        :param version: The version of the task.
        :param auto_version: If set to "latest", the most recently created version of the task will be used.
         If set to "current", the version will be taken from the calling task's context. This is useful if you are
         deploying all environments with the same version. If auto_version is current, you can only access the task
         from within a task context.
        """
        return TaskDetails.get(name, project=project, domain=domain, version=version, auto_version=auto_version)

    @syncify
    @classmethod
    async def listall(
        cls,
        by_task_name: str | None = None,
        by_task_env: str | None = None,
        project: str | None = None,
        domain: str | None = None,
        sort_by: Tuple[str, Literal["asc", "desc"]] | None = None,
        limit: int = 100,
    ) -> Union[AsyncIterator[Task], Iterator[Task]]:
        """
        List tasks for the given (or currently configured) project and domain.

        :param by_task_name: If provided, only tasks with this name will be returned.
        :param by_task_env: If provided, only tasks whose name contains "<by_task_env>." will be returned.
        :param project: The project to filter tasks by. If None, the current project will be used.
        :param domain: The domain to filter tasks by. If None, the current domain will be used.
        :param sort_by: The sorting criteria for the task list, in the format (field, order).
            Defaults to ("created_at", "asc").
        :param limit: The maximum number of tasks to return.
        :return: An iterator of tasks.
        """
        ensure_client()
        token = None
        sort_by = sort_by or ("created_at", "asc")
        sort_pb2 = list_pb2.Sort(
            key=sort_by[0], direction=list_pb2.Sort.ASCENDING if sort_by[1] == "asc" else list_pb2.Sort.DESCENDING
        )
        cfg = get_init_config()
        filters = []
        if by_task_name:
            filters.append(
                list_pb2.Filter(
                    function=list_pb2.Filter.Function.EQUAL,
                    field="name",
                    values=[by_task_name],
                )
            )
        if by_task_env:
            # ideally we should have a STARTS_WITH filter, but it is not supported yet
            filters.append(
                list_pb2.Filter(
                    function=list_pb2.Filter.Function.CONTAINS,
                    field="name",
                    values=[f"{by_task_env}."],
                )
            )
        # Each request asks for at most one batch; the overall cap is enforced while yielding.
        page_size = min(limit, cfg.batch_size)
        retrieved = 0
        while True:
            resp = await get_client().task_service.ListTasks(
                task_service_pb2.ListTasksRequest(
                    org=cfg.org,
                    project_id=identifier_pb2.ProjectIdentifier(
                        organization=cfg.org,
                        domain=domain or cfg.domain,
                        name=project or cfg.project,
                    ),
                    request=list_pb2.ListRequest(
                        sort_by=sort_pb2,
                        filters=filters,
                        limit=page_size,
                        token=token,
                    ),
                )
            )
            token = resp.token
            for t in resp.tasks:
                # Stop mid-page so that never more than ``limit`` tasks are yielded in total.
                if retrieved >= limit:
                    break
                retrieved += 1
                yield cls(t)
            if not token or retrieved >= limit:
                logger.debug(f"Retrieved {retrieved} tasks, stopping iteration.")
                break

    def __rich_repr__(self) -> rich.repr.Result:
        """
        Rich representation of the task.
        """
        yield "project", self.pb2.task_id.project
        yield "domain", self.pb2.task_id.domain
        yield "name", self.pb2.task_id.name
        yield "version", self.pb2.task_id.version
        # short_name is emitted by _repr_task_metadata together with the other metadata fields,
        # so it is not yielded separately here (that would list it twice).
        yield from _repr_task_metadata(self.pb2.metadata)
524
+
525
+
526
if __name__ == "__main__":
    # Task.get raises ValueError unless a version or auto_version is supplied,
    # so ask for the most recently created version of the example task.
    tk = Task.get(name="example_task", auto_version="latest")