flyte 2.0.0b32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (204) hide show
  1. flyte/__init__.py +108 -0
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +195 -0
  5. flyte/_bin/serve.py +178 -0
  6. flyte/_build.py +26 -0
  7. flyte/_cache/__init__.py +12 -0
  8. flyte/_cache/cache.py +147 -0
  9. flyte/_cache/defaults.py +9 -0
  10. flyte/_cache/local_cache.py +216 -0
  11. flyte/_cache/policy_function_body.py +42 -0
  12. flyte/_code_bundle/__init__.py +8 -0
  13. flyte/_code_bundle/_ignore.py +121 -0
  14. flyte/_code_bundle/_packaging.py +218 -0
  15. flyte/_code_bundle/_utils.py +347 -0
  16. flyte/_code_bundle/bundle.py +266 -0
  17. flyte/_constants.py +1 -0
  18. flyte/_context.py +155 -0
  19. flyte/_custom_context.py +73 -0
  20. flyte/_debug/__init__.py +0 -0
  21. flyte/_debug/constants.py +38 -0
  22. flyte/_debug/utils.py +17 -0
  23. flyte/_debug/vscode.py +307 -0
  24. flyte/_deploy.py +408 -0
  25. flyte/_deployer.py +109 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +122 -0
  29. flyte/_excepthook.py +37 -0
  30. flyte/_group.py +32 -0
  31. flyte/_hash.py +8 -0
  32. flyte/_image.py +1055 -0
  33. flyte/_initialize.py +628 -0
  34. flyte/_interface.py +119 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +129 -0
  37. flyte/_internal/controllers/_local_controller.py +239 -0
  38. flyte/_internal/controllers/_trace.py +48 -0
  39. flyte/_internal/controllers/remote/__init__.py +58 -0
  40. flyte/_internal/controllers/remote/_action.py +211 -0
  41. flyte/_internal/controllers/remote/_client.py +47 -0
  42. flyte/_internal/controllers/remote/_controller.py +583 -0
  43. flyte/_internal/controllers/remote/_core.py +465 -0
  44. flyte/_internal/controllers/remote/_informer.py +381 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +3 -0
  47. flyte/_internal/imagebuild/docker_builder.py +706 -0
  48. flyte/_internal/imagebuild/image_builder.py +277 -0
  49. flyte/_internal/imagebuild/remote_builder.py +386 -0
  50. flyte/_internal/imagebuild/utils.py +78 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +21 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +486 -0
  57. flyte/_internal/runtime/entrypoints.py +204 -0
  58. flyte/_internal/runtime/io.py +188 -0
  59. flyte/_internal/runtime/resources_serde.py +152 -0
  60. flyte/_internal/runtime/reuse.py +125 -0
  61. flyte/_internal/runtime/rusty.py +193 -0
  62. flyte/_internal/runtime/task_serde.py +362 -0
  63. flyte/_internal/runtime/taskrunner.py +209 -0
  64. flyte/_internal/runtime/trigger_serde.py +160 -0
  65. flyte/_internal/runtime/types_serde.py +54 -0
  66. flyte/_keyring/__init__.py +0 -0
  67. flyte/_keyring/file.py +115 -0
  68. flyte/_logging.py +300 -0
  69. flyte/_map.py +312 -0
  70. flyte/_module.py +72 -0
  71. flyte/_pod.py +30 -0
  72. flyte/_resources.py +473 -0
  73. flyte/_retry.py +32 -0
  74. flyte/_reusable_environment.py +102 -0
  75. flyte/_run.py +724 -0
  76. flyte/_secret.py +96 -0
  77. flyte/_task.py +550 -0
  78. flyte/_task_environment.py +316 -0
  79. flyte/_task_plugins.py +47 -0
  80. flyte/_timeout.py +47 -0
  81. flyte/_tools.py +27 -0
  82. flyte/_trace.py +119 -0
  83. flyte/_trigger.py +1000 -0
  84. flyte/_utils/__init__.py +30 -0
  85. flyte/_utils/asyn.py +121 -0
  86. flyte/_utils/async_cache.py +139 -0
  87. flyte/_utils/coro_management.py +27 -0
  88. flyte/_utils/docker_credentials.py +173 -0
  89. flyte/_utils/file_handling.py +72 -0
  90. flyte/_utils/helpers.py +134 -0
  91. flyte/_utils/lazy_module.py +54 -0
  92. flyte/_utils/module_loader.py +104 -0
  93. flyte/_utils/org_discovery.py +57 -0
  94. flyte/_utils/uv_script_parser.py +49 -0
  95. flyte/_version.py +34 -0
  96. flyte/app/__init__.py +22 -0
  97. flyte/app/_app_environment.py +157 -0
  98. flyte/app/_deploy.py +125 -0
  99. flyte/app/_input.py +160 -0
  100. flyte/app/_runtime/__init__.py +3 -0
  101. flyte/app/_runtime/app_serde.py +347 -0
  102. flyte/app/_types.py +101 -0
  103. flyte/app/extras/__init__.py +3 -0
  104. flyte/app/extras/_fastapi.py +151 -0
  105. flyte/cli/__init__.py +12 -0
  106. flyte/cli/_abort.py +28 -0
  107. flyte/cli/_build.py +114 -0
  108. flyte/cli/_common.py +468 -0
  109. flyte/cli/_create.py +371 -0
  110. flyte/cli/_delete.py +45 -0
  111. flyte/cli/_deploy.py +293 -0
  112. flyte/cli/_gen.py +176 -0
  113. flyte/cli/_get.py +370 -0
  114. flyte/cli/_option.py +33 -0
  115. flyte/cli/_params.py +554 -0
  116. flyte/cli/_plugins.py +209 -0
  117. flyte/cli/_run.py +597 -0
  118. flyte/cli/_serve.py +64 -0
  119. flyte/cli/_update.py +37 -0
  120. flyte/cli/_user.py +17 -0
  121. flyte/cli/main.py +221 -0
  122. flyte/config/__init__.py +3 -0
  123. flyte/config/_config.py +248 -0
  124. flyte/config/_internal.py +73 -0
  125. flyte/config/_reader.py +225 -0
  126. flyte/connectors/__init__.py +11 -0
  127. flyte/connectors/_connector.py +270 -0
  128. flyte/connectors/_server.py +197 -0
  129. flyte/connectors/utils.py +135 -0
  130. flyte/errors.py +243 -0
  131. flyte/extend.py +19 -0
  132. flyte/extras/__init__.py +5 -0
  133. flyte/extras/_container.py +286 -0
  134. flyte/git/__init__.py +3 -0
  135. flyte/git/_config.py +21 -0
  136. flyte/io/__init__.py +29 -0
  137. flyte/io/_dataframe/__init__.py +131 -0
  138. flyte/io/_dataframe/basic_dfs.py +223 -0
  139. flyte/io/_dataframe/dataframe.py +1026 -0
  140. flyte/io/_dir.py +910 -0
  141. flyte/io/_file.py +914 -0
  142. flyte/io/_hashing_io.py +342 -0
  143. flyte/models.py +479 -0
  144. flyte/py.typed +0 -0
  145. flyte/remote/__init__.py +35 -0
  146. flyte/remote/_action.py +738 -0
  147. flyte/remote/_app.py +57 -0
  148. flyte/remote/_client/__init__.py +0 -0
  149. flyte/remote/_client/_protocols.py +189 -0
  150. flyte/remote/_client/auth/__init__.py +12 -0
  151. flyte/remote/_client/auth/_auth_utils.py +14 -0
  152. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  153. flyte/remote/_client/auth/_authenticators/base.py +403 -0
  154. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  155. flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
  156. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  157. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  158. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  159. flyte/remote/_client/auth/_channel.py +213 -0
  160. flyte/remote/_client/auth/_client_config.py +85 -0
  161. flyte/remote/_client/auth/_default_html.py +32 -0
  162. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  163. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  164. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  165. flyte/remote/_client/auth/_keyring.py +152 -0
  166. flyte/remote/_client/auth/_token_client.py +260 -0
  167. flyte/remote/_client/auth/errors.py +16 -0
  168. flyte/remote/_client/controlplane.py +128 -0
  169. flyte/remote/_common.py +30 -0
  170. flyte/remote/_console.py +19 -0
  171. flyte/remote/_data.py +161 -0
  172. flyte/remote/_logs.py +185 -0
  173. flyte/remote/_project.py +88 -0
  174. flyte/remote/_run.py +386 -0
  175. flyte/remote/_secret.py +142 -0
  176. flyte/remote/_task.py +527 -0
  177. flyte/remote/_trigger.py +306 -0
  178. flyte/remote/_user.py +33 -0
  179. flyte/report/__init__.py +3 -0
  180. flyte/report/_report.py +182 -0
  181. flyte/report/_template.html +124 -0
  182. flyte/storage/__init__.py +36 -0
  183. flyte/storage/_config.py +237 -0
  184. flyte/storage/_parallel_reader.py +274 -0
  185. flyte/storage/_remote_fs.py +34 -0
  186. flyte/storage/_storage.py +456 -0
  187. flyte/storage/_utils.py +5 -0
  188. flyte/syncify/__init__.py +56 -0
  189. flyte/syncify/_api.py +375 -0
  190. flyte/types/__init__.py +52 -0
  191. flyte/types/_interface.py +40 -0
  192. flyte/types/_pickle.py +145 -0
  193. flyte/types/_renderer.py +162 -0
  194. flyte/types/_string_literals.py +119 -0
  195. flyte/types/_type_engine.py +2254 -0
  196. flyte/types/_utils.py +80 -0
  197. flyte-2.0.0b32.data/scripts/debug.py +38 -0
  198. flyte-2.0.0b32.data/scripts/runtime.py +195 -0
  199. flyte-2.0.0b32.dist-info/METADATA +351 -0
  200. flyte-2.0.0b32.dist-info/RECORD +204 -0
  201. flyte-2.0.0b32.dist-info/WHEEL +5 -0
  202. flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
  203. flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
  204. flyte-2.0.0b32.dist-info/top_level.txt +1 -0
flyte/_pod.py ADDED
@@ -0,0 +1,30 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import TYPE_CHECKING, Dict, Optional
3
+
4
+ if TYPE_CHECKING:
5
+ from flyteidl2.core.tasks_pb2 import K8sPod
6
+ from kubernetes.client import V1PodSpec
7
+
8
+
9
+ _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
10
+ _PRIMARY_CONTAINER_DEFAULT_NAME = "primary"
11
+
12
+
13
+ @dataclass(init=True, repr=True, eq=True, frozen=False)
14
+ class PodTemplate(object):
15
+ """Custom PodTemplate specification for a Task."""
16
+
17
+ pod_spec: Optional["V1PodSpec"] = field(default_factory=lambda: V1PodSpec())
18
+ primary_container_name: str = _PRIMARY_CONTAINER_DEFAULT_NAME
19
+ labels: Optional[Dict[str, str]] = None
20
+ annotations: Optional[Dict[str, str]] = None
21
+
22
+ def to_k8s_pod(self) -> "K8sPod":
23
+ from flyteidl2.core.tasks_pb2 import K8sObjectMetadata, K8sPod
24
+ from kubernetes.client import ApiClient
25
+
26
+ return K8sPod(
27
+ metadata=K8sObjectMetadata(labels=self.labels, annotations=self.annotations),
28
+ pod_spec=ApiClient().sanitize_for_serialization(self.pod_spec),
29
+ primary_container_name=self.primary_container_name,
30
+ )
flyte/_resources.py ADDED
@@ -0,0 +1,473 @@
1
+ import typing
2
+ from dataclasses import dataclass, fields
3
+ from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union, get_args
4
+
5
+ import rich.repr
6
+
7
+ from flyte._pod import _PRIMARY_CONTAINER_DEFAULT_NAME
8
+
9
+ if TYPE_CHECKING:
10
+ from kubernetes.client import V1PodSpec
11
+
12
PRIMARY_CONTAINER_DEFAULT_NAME = "primary"

# NVIDIA GPU models. Every model that appears in `Accelerators` must also be listed here:
# `Resources.get_device()` resolves the device class of an accelerator string by looking the
# model name up in these Literal types, and raises for names it cannot find.
GPUType = Literal[
    "A10", "A10G", "A100", "A100 80G", "B200", "H100", "H200", "L4", "L40s", "T4", "V100", "RTX PRO 6000"
]
# Number of GPUs that can be requested through `GPU()`.
GPUQuantity = Literal[1, 2, 3, 4, 5, 6, 7, 8]
A100Parts = Literal["1g.5gb", "2g.10gb", "3g.20gb", "4g.20gb", "7g.40gb"]
"""
Partitions for NVIDIA A100 GPU.
"""

A100_80GBParts = Literal["1g.10gb", "2g.20gb", "3g.40gb", "4g.40gb", "7g.80gb"]
"""
Partitions for NVIDIA A100 80GB GPU.
"""

# NOTE(review): "V5E" is not listed here although `V5EParts` exists and `TPU()` has a V5E
# branch -- confirm whether V5E is meant to be accepted.
TPUType = Literal["V5P", "V6E"]
# Partition names for V5E, used by the V5E branch of `TPU()`.
V5EParts = Literal["1x1", "2x2", "2x4", "4x4", "4x8", "8x8", "8x16", "16x16"]

V5PParts = Literal[
    "2x2x1", "2x2x2", "2x4x4", "4x4x4", "4x4x8", "4x8x8", "8x8x8", "8x8x16", "8x16x16", "16x16x16", "16x16x24"
]
"""
Slices for Google Cloud TPU v5p.
"""

V6EParts = Literal["1x1", "2x2", "2x4", "4x4", "4x8", "8x8", "8x16", "16x16"]
"""
Slices for Google Cloud TPU v6e.
"""

# Device names accepted by `Neuron()`.
NeuronType = Literal["Inf1", "Inf2", "Trn1", "Trn1n", "Trn2", "Trn2u"]

# Device names accepted by `AMD_GPU()`.
AMD_GPUType = Literal["MI100", "MI210", "MI250", "MI250X", "MI300A", "MI300X", "MI325X", "MI350X", "MI355X"]

# Device names accepted by `HABANA_GAUDI()`.
HABANA_GAUDIType = Literal["Gaudi1"]
46
+
47
# Every accepted "<device>:<quantity>" accelerator string, grouped by device model.
Accelerators = Literal[
    # A10
    "A10:1", "A10:2", "A10:3", "A10:4", "A10:5", "A10:6", "A10:7", "A10:8",
    # A10G
    "A10G:1", "A10G:2", "A10G:3", "A10G:4", "A10G:5", "A10G:6", "A10G:7", "A10G:8",
    # A100
    "A100:1", "A100:2", "A100:3", "A100:4", "A100:5", "A100:6", "A100:7", "A100:8",
    # A100 80G
    "A100 80G:1", "A100 80G:2", "A100 80G:3", "A100 80G:4",
    "A100 80G:5", "A100 80G:6", "A100 80G:7", "A100 80G:8",
    # B200
    "B200:1", "B200:2", "B200:3", "B200:4", "B200:5", "B200:6", "B200:7", "B200:8",
    # H100
    "H100:1", "H100:2", "H100:3", "H100:4", "H100:5", "H100:6", "H100:7", "H100:8",
    # H200
    "H200:1", "H200:2", "H200:3", "H200:4", "H200:5", "H200:6", "H200:7", "H200:8",
    # L4
    "L4:1", "L4:2", "L4:3", "L4:4", "L4:5", "L4:6", "L4:7", "L4:8",
    # L40s
    "L40s:1", "L40s:2", "L40s:3", "L40s:4", "L40s:5", "L40s:6", "L40s:7", "L40s:8",
    # V100
    "V100:1", "V100:2", "V100:3", "V100:4", "V100:5", "V100:6", "V100:7", "V100:8",
    # RTX PRO 6000
    "RTX PRO 6000:1",
    # T4
    "T4:1", "T4:2", "T4:3", "T4:4", "T4:5", "T4:6", "T4:7", "T4:8",
    # Trn1
    "Trn1:1", "Trn1:4", "Trn1:8", "Trn1:16",
    # Trn1n
    "Trn1n:1", "Trn1n:4", "Trn1n:8", "Trn1n:16",
    # Trn2
    "Trn2:1", "Trn2:4", "Trn2:8", "Trn2:16",
    # Trn2u
    "Trn2u:1", "Trn2u:4", "Trn2u:8", "Trn2u:16",
    # Inf1
    "Inf1:1", "Inf1:2", "Inf1:3", "Inf1:4", "Inf1:5", "Inf1:6", "Inf1:7", "Inf1:8",
    "Inf1:9", "Inf1:10", "Inf1:11", "Inf1:12", "Inf1:13", "Inf1:14", "Inf1:15", "Inf1:16",
    # Inf2
    "Inf2:1", "Inf2:2", "Inf2:3", "Inf2:4", "Inf2:5", "Inf2:6",
    "Inf2:7", "Inf2:8", "Inf2:9", "Inf2:10", "Inf2:11", "Inf2:12",
    # AMD Instinct: MI100, MI210, MI250, MI250X, MI300A, MI300X, MI325X, MI350X, MI355X
    "MI100:1", "MI210:1", "MI250:1", "MI250X:1", "MI300A:1",
    "MI300X:1", "MI325X:1", "MI350X:1", "MI355X:1",
    # Habana Gaudi
    "Gaudi1:1",
]
220
+
221
# Labels for the supported families of accelerator devices.
DeviceClass = Literal["GPU", "TPU", "NEURON", "AMD_GPU", "HABANA_GAUDI"]

# Maps each device-name Literal type to the label of the device class it belongs to.
# Insertion order is the order in which the classes are searched.
_DeviceClassType: Dict[typing.Any, str] = dict(
    (
        (GPUType, "GPU"),
        (TPUType, "TPU"),
        (NeuronType, "NEURON"),
        (AMD_GPUType, "AMD_GPU"),
        (HABANA_GAUDIType, "HABANA_GAUDI"),
    )
)
230
+
231
+
232
@rich.repr.auto
@dataclass(frozen=True, slots=True)
class Device:
    """
    Represents a device type, its quantity and partition if applicable.
    :param quantity: The number of devices of this type. Must be at least 1.
    :param device_class: The class of device ("GPU", "TPU", "NEURON", "AMD_GPU" or "HABANA_GAUDI").
    :param device: The type of device (e.g., "T4", "A100").
    :param partition: The partition of the device (e.g., "1g.5gb", "2g.10gb" for gpus) or ("1x1", ... for tpus).
    """

    quantity: int
    device_class: DeviceClass
    device: str | None = None
    partition: str | None = None

    def __post_init__(self):
        """Reject quantities below 1.

        :raises ValueError: If ``quantity`` is less than 1.
        """
        if self.quantity < 1:
            # Name the actual device class; for GPUs the text is unchanged
            # ("GPU quantity must be at least 1").
            raise ValueError(f"{self.device_class} quantity must be at least 1")
250
+
251
+
252
def GPU(device: GPUType, quantity: GPUQuantity, partition: A100Parts | A100_80GBParts | None = None) -> Device:
    """
    Build a `Device` of class "GPU".

    The partition is only checked for the "A100" and "A100 80G" devices; for any other device it is
    passed through unchecked.

    :param device: The type of GPU (e.g., "T4", "A100").
    :param quantity: The number of GPUs of this type (at least 1).
    :param partition: Optional partition of the GPU (e.g., "1g.5gb", "2g.10gb").
    :return: Device instance.
    :raises ValueError: If the quantity is below 1, the device is unknown, or the partition is not valid
        for an A100 / A100 80G device.
    """
    if quantity < 1:
        raise ValueError("GPU quantity must be at least 1")

    known_gpus = get_args(GPUType)
    if device not in known_gpus:
        raise ValueError(f"Invalid GPU type: {device}. Must be one of {known_gpus}")

    if partition is not None:
        if device == "A100":
            a100_parts = get_args(A100Parts)
            if partition not in a100_parts:
                raise ValueError(f"Invalid partition for A100: {partition}. Must be one of {a100_parts}")
        elif device == "A100 80G":
            a100_80g_parts = get_args(A100_80GBParts)
            if partition not in a100_80g_parts:
                raise ValueError(f"Invalid partition for A100 80G: {partition}. Must be one of {a100_80g_parts}")

    # Positional order of Device: quantity, device_class, device, partition.
    return Device(quantity, "GPU", device, partition)
271
+
272
+
273
def TPU(device: TPUType, partition: V5PParts | V6EParts | None = None) -> Device:
    """
    Create a TPU device instance. The quantity is always 1.

    :param device: Device type (e.g., "V5P", "V6E").
    :param partition: Partition of the TPU (e.g., "1x1", "2x2", ...).
    :return: Device instance.
    :raises ValueError: If the device is not a known TPU type or the partition is not valid for it.
    """
    if device not in get_args(TPUType):
        raise ValueError(f"Invalid TPU type: {device}. Must be one of {get_args(TPUType)}")

    # Allowed partitions per device. "V5E" is kept from the original implementation, but it is
    # not part of TPUType, so the check above rejects it before this table is consulted.
    partitions_by_device = {"V5P": V5PParts, "V6E": V6EParts, "V5E": V5EParts}
    allowed = partitions_by_device.get(device)
    if partition is not None and allowed is not None and partition not in get_args(allowed):
        raise ValueError(f"Invalid partition for {device}: {partition}. Must be one of {get_args(allowed)}")

    return Device(1, "TPU", device, partition)
292
+
293
+
294
def Neuron(device: NeuronType) -> Device:
    """
    Build a `Device` of class "NEURON" with a quantity of 1.

    :param device: Device type (e.g., "Inf1", "Inf2", "Trn1", "Trn1n", "Trn2", "Trn2u").
    :return: Device instance.
    :raises ValueError: If the device is not one of the known Neuron types.
    """
    supported = get_args(NeuronType)
    if device in supported:
        return Device(quantity=1, device_class="NEURON", device=device)
    raise ValueError(f"Invalid Neuron type: {device}. Must be one of {supported}")
304
+
305
+
306
def AMD_GPU(device: AMD_GPUType) -> Device:
    """
    Build a `Device` of class "AMD_GPU" with a quantity of 1.

    :param device: Device type (e.g., "MI100", "MI210", "MI250", "MI250X", "MI300A", "MI300X", "MI325X", "MI350X",
        "MI355X").
    :return: Device instance.
    :raises ValueError: If the device is not one of the known AMD GPU types.
    """
    supported = get_args(AMD_GPUType)
    if device in supported:
        return Device(quantity=1, device_class="AMD_GPU", device=device)
    raise ValueError(f"Invalid AMD GPU type: {device}. Must be one of {supported}")
316
+
317
+
318
def HABANA_GAUDI(device: HABANA_GAUDIType) -> Device:
    """
    Build a `Device` of class "HABANA_GAUDI" with a quantity of 1.

    :param device: Device type (e.g., "Gaudi1").
    :return: Device instance.
    :raises ValueError: If the device is not one of the known Habana Gaudi types.
    """
    supported = get_args(HABANA_GAUDIType)
    if device in supported:
        return Device(quantity=1, device_class="HABANA_GAUDI", device=device)
    raise ValueError(f"Invalid Habana Gaudi type: {device}. Must be one of {supported}")
327
+
328
+
329
# Scalar forms accepted for a single CPU amount; `Resources.cpu` takes one of these or a pair of them.
CPUBaseType = int | float | str
330
+
331
+
332
@dataclass
class Resources:
    """
    Resources such as CPU, Memory, and GPU that can be allocated to a task.

    Example:
    - Single CPU, 1GiB of memory, and 1 T4 GPU:
    ```python
    @task(resources=Resources(cpu=1, memory="1GiB", gpu="T4:1"))
    def my_task() -> int:
        return 42
    ```
    - 1CPU with limit upto 2CPU, 2GiB of memory, and 8 A100 GPUs and 10GiB of disk:
    ```python
    @task(resources=Resources(cpu=(1, 2), memory="2GiB", gpu="A100:8", disk="10GiB"))
    def my_task() -> int:
        return 42
    ```

    :param cpu: The amount of CPU to allocate to the task. This can be an int, float or string, or a tuple of
        two such values.
    :param memory: The amount of memory to allocate to the task. This can be a string or a tuple of two strings.
    :param gpu: The GPU to allocate to the task. This can be an accelerator string such as "T4:1", an int
        (number of GPUs, 0 meaning none), a Device, or None.
    :param disk: The amount of disk to allocate to the task. This is a string of the form "10GiB".
    :param shm: The amount of shared memory to allocate to the task: a size string, or "auto".
    """

    cpu: Union[CPUBaseType, Tuple[CPUBaseType, CPUBaseType], None] = None
    memory: Union[str, Tuple[str, str], None] = None
    gpu: Union[Accelerators, int, Device, None] = None
    disk: Union[str, None] = None
    shm: Union[str, Literal["auto"], None] = None

    def __post_init__(self):
        """Validate the shape and range of the cpu, memory and gpu values.

        :raises ValueError: If a cpu/memory tuple does not have exactly two elements, a numeric cpu or gpu is
            negative, or a gpu string is not one of the known accelerator strings.
        """
        if isinstance(self.cpu, tuple):
            if len(self.cpu) != 2:
                raise ValueError("cpu tuple must have exactly two elements")
        if isinstance(self.memory, tuple):
            if len(self.memory) != 2:
                raise ValueError("memory tuple must have exactly two elements")
        if isinstance(self.cpu, (int, float)):
            if self.cpu < 0:
                raise ValueError("cpu must be greater than or equal to 0")
        if self.gpu is not None:
            if isinstance(self.gpu, int):
                if self.gpu < 0:
                    raise ValueError("gpu must be greater than or equal to 0")
            elif isinstance(self.gpu, str):
                if self.gpu not in get_args(Accelerators):
                    raise ValueError(f"gpu must be one of {Accelerators}, got {self.gpu}")

    def get_device(self) -> Optional[Device]:
        """
        Get the requested accelerator as a `Device`.

        :return: None when no GPU is requested (``gpu`` is None or 0). For an int, a "GPU" class device with
            that quantity and no device name. For an accelerator string "<device>:<quantity>", a device whose
            class is looked up from the device name. A `Device` value is returned unchanged.
        :raises ValueError: If the device name of an accelerator string belongs to no known device class.
        """
        if self.gpu is None:
            return None
        if isinstance(self.gpu, int):
            # 0 is a valid value for `gpu` (see __post_init__) but Device requires a quantity of
            # at least 1, so a request for zero GPUs means "no device" instead of an error.
            if self.gpu == 0:
                return None
            return Device(quantity=self.gpu, device_class="GPU")
        if isinstance(self.gpu, str):
            device, portion = self.gpu.split(":")
            for cls, cls_name in _DeviceClassType.items():
                if device in get_args(cls):
                    device_class = cls_name
                    break
            else:
                raise ValueError(f"Invalid device type: {device}. Must be one of {list(_DeviceClassType.keys())}")
            return Device(device=device, device_class=device_class, quantity=int(portion))  # type: ignore
        return self.gpu

    def get_shared_memory(self) -> Optional[str]:
        """
        Get the shared memory string for the task.

        :return: None when ``shm`` is not set, an empty string when it is "auto", otherwise the value of ``shm``.
        """
        if self.shm is None:
            return None
        if self.shm == "auto":
            return ""
        return self.shm
416
+
417
+
418
def _check_resource_is_singular(resource: Resources):
    """
    Ensure that no field of the given resource holds a tuple or a list.

    :param resource: The resources to inspect.
    :return: The same ``resource`` object when every field is a single value.
    :raises ValueError: On the first field whose value is a tuple or a list.
    """
    for spec in fields(resource):
        current = getattr(resource, spec.name)
        if not isinstance(current, (tuple, list)):
            continue
        raise ValueError(f"{current} can not be a list or tuple")
    return resource
427
+
428
+
429
def pod_spec_from_resources(
    primary_container_name: str = _PRIMARY_CONTAINER_DEFAULT_NAME,
    requests: Optional[Resources] = None,
    limits: Optional[Resources] = None,
    k8s_gpu_resource_key: str = "nvidia.com/gpu",
) -> "V1PodSpec":
    """
    Build a pod spec with a single container carrying the given resource requests and limits.

    When only one of ``requests`` / ``limits`` yields any resources, it is used for both.

    :param primary_container_name: Name given to the container in the pod spec.
    :param requests: Resources to request. Every field must be a single value, not a tuple.
    :param limits: Resource limits. Every field must be a single value, not a tuple.
    :param k8s_gpu_resource_key: Kubernetes resource name under which the GPU quantity is placed.
    :return: A ``V1PodSpec`` containing one ``V1Container`` with the resource requirements.
    :raises ValueError: If any field of ``requests`` or ``limits`` is a tuple or a list.
    """
    from kubernetes.client import V1Container, V1PodSpec, V1ResourceRequirements

    def _construct_k8s_pods_resources(resources: Optional[Resources], k8s_gpu_resource_key: str):
        """Translate a Resources object into a {kubernetes resource name: value} dict (None for None)."""
        if resources is None:
            return None

        # Resources field name -> Kubernetes resource name. Keys must match the dataclass field
        # names, so the disk amount is looked up under "disk".
        resources_map = {
            "cpu": "cpu",
            "memory": "memory",
            "gpu": k8s_gpu_resource_key,
            "disk": "ephemeral-storage",
        }

        _check_resource_is_singular(resources)

        k8s_pod_resources = {}
        for resource in fields(resources):
            k8s_name = resources_map.get(resource.name)
            if k8s_name is None:
                # Fields with no entry in the map (currently `shm`) are skipped instead of
                # failing with a KeyError.
                continue
            resource_value = getattr(resources, resource.name)
            if resource_value is None:
                continue
            if resource.name == "gpu" and not isinstance(resource_value, int):
                # An accelerator string ("T4:1") or a Device is reduced to its quantity; the
                # device name itself is not expressed in the container resources.
                resource_value = resources.get_device().quantity
            k8s_pod_resources[k8s_name] = resource_value

        return k8s_pod_resources

    requests = _construct_k8s_pods_resources(resources=requests, k8s_gpu_resource_key=k8s_gpu_resource_key)
    limits = _construct_k8s_pods_resources(resources=limits, k8s_gpu_resource_key=k8s_gpu_resource_key)
    # Fall back to the other side when one of them is missing or empty.
    requests = requests or limits
    limits = limits or requests

    return V1PodSpec(
        containers=[
            V1Container(
                name=primary_container_name,
                resources=V1ResourceRequirements(
                    requests=requests,
                    limits=limits,
                ),
            )
        ]
    )
flyte/_retry.py ADDED
@@ -0,0 +1,32 @@
1
+ from dataclasses import dataclass
2
+ from datetime import timedelta
3
+ from typing import Union
4
+
5
+
6
@dataclass
class RetryStrategy:
    """
    Retry strategy for the task or task environment. Retry strategy is optional or can be a simple number of retries.

    Example:
    - This will retry the task 5 times.
    ```
    @task(retries=5)
    def my_task():
        pass
    ```
    - This will retry the task 5 times with a maximum backoff of 10 seconds and a backoff factor of 2.
    ```
    @task(retries=RetryStrategy(count=5, backoff=10, backoff_factor=2))
    def my_task():
        pass
    ```

    :param count: The number of retries.
    :param backoff: The maximum backoff time for retries. This can be a float or a timedelta.
    :param backoff_factor: The backoff exponential factor. This can be an integer or a float.
    """

    count: int
    backoff: Union[float, timedelta, None] = None
    backoff_factor: Union[int, float, None] = None
@@ -0,0 +1,102 @@
1
+ from dataclasses import dataclass
2
+ from datetime import timedelta
3
+ from typing import Tuple, Union
4
+
5
+ from flyte._logging import logger
6
+
7
+
8
+ @dataclass
9
+ class ReusePolicy:
10
+ """
11
+ ReusePolicy can be used to configure a task to reuse the environment. This is useful when the environment creation
12
+ is expensive and the runtime of the task is short. The environment will be reused for the next invocation of the
13
+ task, even the python process maybe be reused by subsequent task invocations. A good mental model is to think of
14
+ the environment as a container that is reused for multiple tasks, more like a long-running service.
15
+
16
+ Caution: It is important to note that the environment is shared, so managing memory and resources is important.
17
+
18
+ :param replicas: Either a single int representing number of replicas or a tuple of two ints representing
19
+ the min and max.
20
+ :param idle_ttl: The maximum idle duration for an environment, specified as either seconds (int) or a
21
+ timedelta, after which all replicas in the environment are shutdown.
22
+ If not set, the default is configured in the backend (can be as low as 90s).
23
+ When a replica remains idle — meaning no tasks are running — for this duration, it will be automatically
24
+ terminated, also referred to as environment idle timeout.
25
+ :param concurrency: The maximum number of tasks that can run concurrently in one instance of the environment.
26
+ Concurrency of greater than 1 is only supported for `async` tasks.
27
+ :param scaledown_ttl: The minimum time to wait before scaling down each replica, specified as either seconds (int)
28
+ or a timedelta. This is useful to prevent rapid scaling down of replicas when tasks are running
29
+ frequently. If not set, the default is configured in the backend.
30
+ """
31
+
32
+ replicas: Union[int, Tuple[int, int]] = 2
33
+ idle_ttl: Union[int, timedelta] = 30 # seconds
34
+ concurrency: int = 1
35
+ scaledown_ttl: Union[int, timedelta] = 30 # seconds
36
+
37
+ def __post_init__(self):
38
+ if self.replicas is None:
39
+ raise ValueError("replicas cannot be None")
40
+ if isinstance(self.replicas, int):
41
+ self.replicas = (self.replicas, self.replicas)
42
+ elif not isinstance(self.replicas, tuple):
43
+ raise ValueError("replicas must be an int or a tuple of two ints")
44
+ elif len(self.replicas) != 2:
45
+ raise ValueError("replicas must be an int or a tuple of two ints")
46
+
47
+ if isinstance(self.idle_ttl, int):
48
+ self.idle_ttl = timedelta(seconds=int(self.idle_ttl))
49
+ elif not isinstance(self.idle_ttl, timedelta):
50
+ raise ValueError("idle_ttl must be an int (seconds) or a timedelta")
51
+ if self.idle_ttl.total_seconds() < 30:
52
+ raise ValueError("idle_ttl must be at least 30 seconds")
53
+
54
+ if self.replicas[1] == 1 and self.concurrency == 1:
55
+ logger.warning(
56
+ "It is recommended to use a minimum of 2 replicas, to avoid starvation. "
57
+ "Starvation can occur if a task is running and no other replicas are available to handle new tasks."
58
+ "Options, increase concurrency, increase replicas or turn-off reuse for the parent task, "
59
+ "that runs child tasks."
60
+ )
61
+
62
+ if isinstance(self.scaledown_ttl, int):
63
+ self.scaledown_ttl = timedelta(seconds=int(self.scaledown_ttl))
64
+ elif not isinstance(self.scaledown_ttl, timedelta):
65
+ raise ValueError("scaledown_ttl must be an int (seconds) or a timedelta")
66
+ if self.scaledown_ttl.total_seconds() < 30:
67
+ raise ValueError("scaledown_ttl must be at least 30 seconds")
68
+
69
+ @property
70
+ def ttl(self) -> timedelta | None:
71
+ """
72
+ Returns the idle TTL as a timedelta. If idle_ttl is not set, returns the global default.
73
+ """
74
+ if self.idle_ttl is None:
75
+ return None
76
+ if isinstance(self.idle_ttl, timedelta):
77
+ return self.idle_ttl
78
+ return timedelta(seconds=self.idle_ttl)
79
+
80
+ @property
81
+ def min_replicas(self) -> int:
82
+ """
83
+ Returns the minimum number of replicas.
84
+ """
85
+ return self.replicas[0] if isinstance(self.replicas, tuple) else self.replicas
86
+
87
+ def get_scaledown_ttl(self) -> timedelta | None:
88
+ """
89
+ Returns the scaledown TTL as a timedelta. If scaledown_ttl is not set, returns None.
90
+ """
91
+ if self.scaledown_ttl is None:
92
+ return None
93
+ if isinstance(self.scaledown_ttl, timedelta):
94
+ return self.scaledown_ttl
95
+ return timedelta(seconds=int(self.scaledown_ttl))
96
+
97
+ @property
98
+ def max_replicas(self) -> int:
99
+ """
100
+ Returns the maximum number of replicas.
101
+ """
102
+ return self.replicas[1] if isinstance(self.replicas, tuple) else self.replicas