flyte 2.0.0b23__py3-none-any.whl → 2.0.0b25__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic; the original page linked to further details.

Files changed (162):
  1. flyte/__init__.py +11 -2
  2. flyte/_cache/local_cache.py +4 -3
  3. flyte/_code_bundle/_utils.py +3 -3
  4. flyte/_code_bundle/bundle.py +12 -5
  5. flyte/_context.py +4 -1
  6. flyte/_custom_context.py +73 -0
  7. flyte/_deploy.py +31 -7
  8. flyte/_image.py +48 -16
  9. flyte/_initialize.py +69 -26
  10. flyte/_internal/controllers/_local_controller.py +1 -0
  11. flyte/_internal/controllers/_trace.py +1 -1
  12. flyte/_internal/controllers/remote/_action.py +9 -10
  13. flyte/_internal/controllers/remote/_client.py +1 -1
  14. flyte/_internal/controllers/remote/_controller.py +4 -2
  15. flyte/_internal/controllers/remote/_core.py +10 -13
  16. flyte/_internal/controllers/remote/_informer.py +3 -3
  17. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  18. flyte/_internal/imagebuild/docker_builder.py +45 -59
  19. flyte/_internal/imagebuild/remote_builder.py +51 -11
  20. flyte/_internal/imagebuild/utils.py +51 -3
  21. flyte/_internal/runtime/convert.py +39 -18
  22. flyte/_internal/runtime/io.py +8 -7
  23. flyte/_internal/runtime/resources_serde.py +20 -6
  24. flyte/_internal/runtime/reuse.py +1 -1
  25. flyte/_internal/runtime/task_serde.py +7 -10
  26. flyte/_internal/runtime/taskrunner.py +10 -1
  27. flyte/_internal/runtime/trigger_serde.py +13 -13
  28. flyte/_internal/runtime/types_serde.py +1 -1
  29. flyte/_keyring/file.py +2 -2
  30. flyte/_map.py +65 -13
  31. flyte/_pod.py +2 -2
  32. flyte/_resources.py +175 -31
  33. flyte/_run.py +37 -21
  34. flyte/_task.py +27 -6
  35. flyte/_task_environment.py +37 -10
  36. flyte/_utils/module_loader.py +2 -2
  37. flyte/_version.py +3 -3
  38. flyte/cli/_common.py +47 -5
  39. flyte/cli/_create.py +4 -0
  40. flyte/cli/_deploy.py +8 -0
  41. flyte/cli/_get.py +4 -0
  42. flyte/cli/_params.py +4 -4
  43. flyte/cli/_run.py +50 -7
  44. flyte/cli/_update.py +4 -3
  45. flyte/config/_config.py +2 -0
  46. flyte/config/_internal.py +1 -0
  47. flyte/config/_reader.py +3 -3
  48. flyte/errors.py +1 -1
  49. flyte/extend.py +4 -0
  50. flyte/extras/_container.py +6 -1
  51. flyte/git/_config.py +11 -9
  52. flyte/io/_dataframe/basic_dfs.py +1 -1
  53. flyte/io/_dataframe/dataframe.py +12 -8
  54. flyte/io/_dir.py +48 -15
  55. flyte/io/_file.py +48 -11
  56. flyte/models.py +12 -8
  57. flyte/remote/_action.py +18 -16
  58. flyte/remote/_client/_protocols.py +4 -3
  59. flyte/remote/_client/auth/_channel.py +1 -1
  60. flyte/remote/_client/controlplane.py +4 -8
  61. flyte/remote/_data.py +4 -3
  62. flyte/remote/_logs.py +3 -3
  63. flyte/remote/_run.py +5 -5
  64. flyte/remote/_secret.py +20 -13
  65. flyte/remote/_task.py +7 -8
  66. flyte/remote/_trigger.py +25 -27
  67. flyte/storage/_parallel_reader.py +274 -0
  68. flyte/storage/_storage.py +66 -2
  69. flyte/types/_interface.py +2 -2
  70. flyte/types/_pickle.py +1 -1
  71. flyte/types/_string_literals.py +8 -9
  72. flyte/types/_type_engine.py +25 -17
  73. flyte/types/_utils.py +1 -1
  74. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/METADATA +2 -1
  75. flyte-2.0.0b25.dist-info/RECORD +184 -0
  76. flyte/_protos/__init__.py +0 -0
  77. flyte/_protos/common/authorization_pb2.py +0 -66
  78. flyte/_protos/common/authorization_pb2.pyi +0 -108
  79. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  80. flyte/_protos/common/identifier_pb2.py +0 -117
  81. flyte/_protos/common/identifier_pb2.pyi +0 -142
  82. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  83. flyte/_protos/common/identity_pb2.py +0 -48
  84. flyte/_protos/common/identity_pb2.pyi +0 -72
  85. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  86. flyte/_protos/common/list_pb2.py +0 -36
  87. flyte/_protos/common/list_pb2.pyi +0 -71
  88. flyte/_protos/common/list_pb2_grpc.py +0 -4
  89. flyte/_protos/common/policy_pb2.py +0 -37
  90. flyte/_protos/common/policy_pb2.pyi +0 -27
  91. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  92. flyte/_protos/common/role_pb2.py +0 -37
  93. flyte/_protos/common/role_pb2.pyi +0 -53
  94. flyte/_protos/common/role_pb2_grpc.py +0 -4
  95. flyte/_protos/common/runtime_version_pb2.py +0 -28
  96. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  97. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  98. flyte/_protos/imagebuilder/definition_pb2.py +0 -60
  99. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -153
  100. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  101. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  102. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  103. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  104. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  105. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  106. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  107. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  108. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  109. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  110. flyte/_protos/secret/definition_pb2.py +0 -49
  111. flyte/_protos/secret/definition_pb2.pyi +0 -93
  112. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  113. flyte/_protos/secret/payload_pb2.py +0 -62
  114. flyte/_protos/secret/payload_pb2.pyi +0 -94
  115. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  116. flyte/_protos/secret/secret_pb2.py +0 -38
  117. flyte/_protos/secret/secret_pb2.pyi +0 -6
  118. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  119. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  120. flyte/_protos/workflow/common_pb2.py +0 -38
  121. flyte/_protos/workflow/common_pb2.pyi +0 -63
  122. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  123. flyte/_protos/workflow/environment_pb2.py +0 -29
  124. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  125. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  126. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  127. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  128. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  129. flyte/_protos/workflow/queue_service_pb2.py +0 -117
  130. flyte/_protos/workflow/queue_service_pb2.pyi +0 -182
  131. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -206
  132. flyte/_protos/workflow/run_definition_pb2.py +0 -123
  133. flyte/_protos/workflow/run_definition_pb2.pyi +0 -354
  134. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  135. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  136. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  137. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  138. flyte/_protos/workflow/run_service_pb2.py +0 -147
  139. flyte/_protos/workflow/run_service_pb2.pyi +0 -203
  140. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -480
  141. flyte/_protos/workflow/state_service_pb2.py +0 -67
  142. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  143. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  144. flyte/_protos/workflow/task_definition_pb2.py +0 -86
  145. flyte/_protos/workflow/task_definition_pb2.pyi +0 -105
  146. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  147. flyte/_protos/workflow/task_service_pb2.py +0 -61
  148. flyte/_protos/workflow/task_service_pb2.pyi +0 -62
  149. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  150. flyte/_protos/workflow/trigger_definition_pb2.py +0 -66
  151. flyte/_protos/workflow/trigger_definition_pb2.pyi +0 -117
  152. flyte/_protos/workflow/trigger_definition_pb2_grpc.py +0 -4
  153. flyte/_protos/workflow/trigger_service_pb2.py +0 -96
  154. flyte/_protos/workflow/trigger_service_pb2.pyi +0 -110
  155. flyte/_protos/workflow/trigger_service_pb2_grpc.py +0 -281
  156. flyte-2.0.0b23.dist-info/RECORD +0 -262
  157. {flyte-2.0.0b23.data → flyte-2.0.0b25.data}/scripts/debug.py +0 -0
  158. {flyte-2.0.0b23.data → flyte-2.0.0b25.data}/scripts/runtime.py +0 -0
  159. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/WHEEL +0 -0
  160. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/entry_points.txt +0 -0
  161. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/licenses/LICENSE +0 -0
  162. {flyte-2.0.0b23.dist-info → flyte-2.0.0b25.dist-info}/top_level.txt +0 -0
flyte/_map.py CHANGED
@@ -1,13 +1,13 @@
1
1
  import asyncio
2
2
  import functools
3
3
  import logging
4
- from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast
4
+ from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast, overload
5
5
 
6
6
  from flyte.syncify import syncify
7
7
 
8
8
  from ._group import group
9
9
  from ._logging import logger
10
- from ._task import P, R, TaskTemplate
10
+ from ._task import AsyncFunctionTaskTemplate, F, P, R
11
11
 
12
12
 
13
13
  class MapAsyncIterator(Generic[P, R]):
@@ -15,7 +15,7 @@ class MapAsyncIterator(Generic[P, R]):
15
15
 
16
16
  def __init__(
17
17
  self,
18
- func: TaskTemplate[P, R] | functools.partial[R],
18
+ func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
19
19
  args: tuple,
20
20
  name: str,
21
21
  concurrency: int,
@@ -78,7 +78,7 @@ class MapAsyncIterator(Generic[P, R]):
78
78
 
79
79
  if isinstance(self.func, functools.partial):
80
80
  # Handle partial functions by merging bound args/kwargs with mapped args
81
- base_func = cast(TaskTemplate, self.func.func)
81
+ base_func = cast(AsyncFunctionTaskTemplate, self.func.func)
82
82
  bound_args = self.func.args
83
83
  bound_kwargs = self.func.keywords or {}
84
84
 
@@ -144,7 +144,7 @@ class _Mapper(Generic[P, R]):
144
144
  :param func: partial function to validate
145
145
  :raises TypeError: if the partial function is not valid for mapping
146
146
  """
147
- f = cast(TaskTemplate, func.func)
147
+ f = cast(AsyncFunctionTaskTemplate, func.func)
148
148
  inputs = f.native_interface.inputs
149
149
  params = list(inputs.keys())
150
150
  total_params = len(params)
@@ -172,9 +172,28 @@ class _Mapper(Generic[P, R]):
172
172
  f"in partial function {f.name}."
173
173
  )
174
174
 
175
+ @overload
175
176
  def __call__(
176
177
  self,
177
- func: TaskTemplate[P, R] | functools.partial[R],
178
+ func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
179
+ *args: Iterable[Any],
180
+ group_name: str | None = None,
181
+ concurrency: int = 0,
182
+ ) -> Iterator[R]: ...
183
+
184
+ @overload
185
+ def __call__(
186
+ self,
187
+ func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
188
+ *args: Iterable[Any],
189
+ group_name: str | None = None,
190
+ concurrency: int = 0,
191
+ return_exceptions: bool = True,
192
+ ) -> Iterator[R]: ...
193
+
194
+ def __call__(
195
+ self,
196
+ func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
178
197
  *args: Iterable[Any],
179
198
  group_name: str | None = None,
180
199
  concurrency: int = 0,
@@ -194,10 +213,10 @@ class _Mapper(Generic[P, R]):
194
213
  return
195
214
 
196
215
  if isinstance(func, functools.partial):
197
- f = cast(TaskTemplate, func.func)
216
+ f = cast(AsyncFunctionTaskTemplate, func.func)
198
217
  self.validate_partial(func)
199
218
  else:
200
- f = cast(TaskTemplate, func)
219
+ f = cast(AsyncFunctionTaskTemplate, func)
201
220
 
202
221
  name = self._get_name(f.name, group_name)
203
222
  logger.debug(f"Blocking Map for {name}")
@@ -234,7 +253,7 @@ class _Mapper(Generic[P, R]):
234
253
 
235
254
  async def aio(
236
255
  self,
237
- func: TaskTemplate[P, R] | functools.partial[R],
256
+ func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
238
257
  *args: Iterable[Any],
239
258
  group_name: str | None = None,
240
259
  concurrency: int = 0,
@@ -244,10 +263,10 @@ class _Mapper(Generic[P, R]):
244
263
  return
245
264
 
246
265
  if isinstance(func, functools.partial):
247
- f = cast(TaskTemplate, func.func)
266
+ f = cast(AsyncFunctionTaskTemplate, func.func)
248
267
  self.validate_partial(func)
249
268
  else:
250
- f = cast(TaskTemplate, func)
269
+ f = cast(AsyncFunctionTaskTemplate, func)
251
270
 
252
271
  name = self._get_name(f.name, group_name)
253
272
  with group(name):
@@ -277,7 +296,7 @@ class _Mapper(Generic[P, R]):
277
296
 
278
297
  @syncify
279
298
  async def _map(
280
- func: TaskTemplate[P, R] | functools.partial[R],
299
+ func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
281
300
  *args: Iterable[Any],
282
301
  name: str = "map",
283
302
  concurrency: int = 0,
@@ -290,4 +309,37 @@ async def _map(
290
309
  yield result
291
310
 
292
311
 
293
- map: _Mapper = _Mapper()
312
+ @overload
313
+ def map(
314
+ func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
315
+ *args: Iterable[Any],
316
+ group_name: str | None = None,
317
+ concurrency: int = 0,
318
+ ) -> Iterator[R]: ...
319
+
320
+
321
+ @overload
322
+ def map(
323
+ func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
324
+ *args: Iterable[Any],
325
+ group_name: str | None = None,
326
+ concurrency: int = 0,
327
+ return_exceptions: bool = True,
328
+ ) -> Iterator[Union[R, Exception]]: ...
329
+
330
+
331
+ def map(
332
+ func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
333
+ *args: Iterable[Any],
334
+ group_name: str | None = None,
335
+ concurrency: int = 0,
336
+ return_exceptions: bool = True,
337
+ ) -> Iterator[Union[R, Exception]]:
338
+ map: _Mapper = _Mapper()
339
+ return map(
340
+ func,
341
+ *args,
342
+ group_name=group_name,
343
+ concurrency=concurrency,
344
+ return_exceptions=return_exceptions,
345
+ )
flyte/_pod.py CHANGED
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field
2
2
  from typing import TYPE_CHECKING, Dict, Optional
3
3
 
4
4
  if TYPE_CHECKING:
5
- from flyteidl.core.tasks_pb2 import K8sPod
5
+ from flyteidl2.core.tasks_pb2 import K8sPod
6
6
  from kubernetes.client import V1PodSpec
7
7
 
8
8
 
@@ -20,7 +20,7 @@ class PodTemplate(object):
20
20
  annotations: Optional[Dict[str, str]] = None
21
21
 
22
22
  def to_k8s_pod(self) -> "K8sPod":
23
- from flyteidl.core.tasks_pb2 import K8sObjectMetadata, K8sPod
23
+ from flyteidl2.core.tasks_pb2 import K8sObjectMetadata, K8sPod
24
24
  from kubernetes.client import ApiClient
25
25
 
26
26
  return K8sPod(
flyte/_resources.py CHANGED
@@ -1,5 +1,6 @@
1
+ import typing
1
2
  from dataclasses import dataclass, fields
2
- from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union, get_args
3
+ from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union, get_args
3
4
 
4
5
  import rich.repr
5
6
 
@@ -10,7 +11,7 @@ if TYPE_CHECKING:
10
11
 
11
12
  PRIMARY_CONTAINER_DEFAULT_NAME = "primary"
12
13
 
13
- GPUType = Literal["T4", "A100", "A100 80G", "H100", "L4", "L40s"]
14
# GPU model names used to type ``GPU(device=...)``. Every device prefix that appears in an
# ``Accelerators`` string (e.g. "H200:1") must be listed here: ``Resources.get_device``
# classifies the prefix by membership in this Literal and raises ValueError otherwise.
GPUType = Literal[
    "A10",
    "A10G",
    "A100",
    "A100 80G",
    "B200",
    "H100",
    "H200",
    "L4",
    "L40s",
    "T4",
    "V100",
    "RTX PRO 6000",
]
14
15
  GPUQuantity = Literal[1, 2, 3, 4, 5, 6, 7, 8]
15
16
  A100Parts = Literal["1g.5gb", "2g.10gb", "3g.20gb", "4g.20gb", "7g.40gb"]
16
17
  """
@@ -37,31 +38,32 @@ V6EParts = Literal["1x1", "2x2", "2x4", "4x4", "4x8", "8x8", "8x16", "16x16"]
37
38
  Slices for Google Cloud TPU v6e.
38
39
  """
39
40
 
41
# Neuron device families accepted by ``Neuron()``.
NeuronType = Literal[
    "Inf1",
    "Inf2",
    "Trn1",
    "Trn1n",
    "Trn2",
    "Trn2u",
]

# AMD GPU models accepted by ``AMD_GPU()``.
AMD_GPUType = Literal[
    "MI100",
    "MI210",
    "MI250",
    "MI250X",
    "MI300A",
    "MI300X",
    "MI325X",
    "MI350X",
    "MI355X",
]

# Habana Gaudi models accepted by ``HABANA_GAUDI()``.
HABANA_GAUDIType = Literal["Gaudi1"]
46
+
40
47
  Accelerators = Literal[
41
- "T4:1",
42
- "T4:2",
43
- "T4:3",
44
- "T4:4",
45
- "T4:5",
46
- "T4:6",
47
- "T4:7",
48
- "T4:8",
49
- "L4:1",
50
- "L4:2",
51
- "L4:3",
52
- "L4:4",
53
- "L4:5",
54
- "L4:6",
55
- "L4:7",
56
- "L4:8",
57
- "L40s:1",
58
- "L40s:2",
59
- "L40s:3",
60
- "L40s:4",
61
- "L40s:5",
62
- "L40s:6",
63
- "L40s:7",
64
- "L40s:8",
48
+ # A10
49
+ "A10:1",
50
+ "A10:2",
51
+ "A10:3",
52
+ "A10:4",
53
+ "A10:5",
54
+ "A10:6",
55
+ "A10:7",
56
+ "A10:8",
57
+ # A10G
58
+ "A10G:1",
59
+ "A10G:2",
60
+ "A10G:3",
61
+ "A10G:4",
62
+ "A10G:5",
63
+ "A10G:6",
64
+ "A10G:7",
65
+ "A10G:8",
66
+ # A100
65
67
  "A100:1",
66
68
  "A100:2",
67
69
  "A100:3",
@@ -70,6 +72,7 @@ Accelerators = Literal[
70
72
  "A100:6",
71
73
  "A100:7",
72
74
  "A100:8",
75
+ # A100 80G
73
76
  "A100 80G:1",
74
77
  "A100 80G:2",
75
78
  "A100 80G:3",
@@ -78,6 +81,16 @@ Accelerators = Literal[
78
81
  "A100 80G:6",
79
82
  "A100 80G:7",
80
83
  "A100 80G:8",
84
+ # B200
85
+ "B200:1",
86
+ "B200:2",
87
+ "B200:3",
88
+ "B200:4",
89
+ "B200:5",
90
+ "B200:6",
91
+ "B200:7",
92
+ "B200:8",
93
+ # H100
81
94
  "H100:1",
82
95
  "H100:2",
83
96
  "H100:3",
@@ -86,8 +99,97 @@ Accelerators = Literal[
86
99
  "H100:6",
87
100
  "H100:7",
88
101
  "H100:8",
102
+ # H200
103
+ "H200:1",
104
+ "H200:2",
105
+ "H200:3",
106
+ "H200:4",
107
+ "H200:5",
108
+ "H200:6",
109
+ "H200:7",
110
+ "H200:8",
111
+ # L4
112
+ "L4:1",
113
+ "L4:2",
114
+ "L4:3",
115
+ "L4:4",
116
+ "L4:5",
117
+ "L4:6",
118
+ "L4:7",
119
+ "L4:8",
120
+ # L40s
121
+ "L40s:1",
122
+ "L40s:2",
123
+ "L40s:3",
124
+ "L40s:4",
125
+ "L40s:5",
126
+ "L40s:6",
127
+ "L40s:7",
128
+ "L40s:8",
129
+ # V100
130
+ "V100:1",
131
+ "V100:2",
132
+ "V100:3",
133
+ "V100:4",
134
+ "V100:5",
135
+ "V100:6",
136
+ "V100:7",
137
+ "V100:8",
138
+ # RTX 6000
139
+ "RTX PRO 6000:1",
140
+ # T4
141
+ "T4:1",
142
+ "T4:2",
143
+ "T4:3",
144
+ "T4:4",
145
+ "T4:5",
146
+ "T4:6",
147
+ "T4:7",
148
+ "T4:8",
149
+ # Trn1
150
+ "Trn1:1",
151
+ # Trn1n
152
+ "Trn1n:1",
153
+ # Trn2
154
+ "Trn2:1",
155
+ # Trn2u
156
+ "Trn2u:1",
157
+ # Inf1
158
+ "Inf1:1",
159
+ # Inf2
160
+ "Inf2:1",
161
+ # MI100
162
+ "MI100:1",
163
+ # MI210
164
+ "MI210:1",
165
+ # MI250
166
+ "MI250:1",
167
+ # MI250X
168
+ "MI250X:1",
169
+ # MI300A
170
+ "MI300A:1",
171
+ # MI300X
172
+ "MI300X:1",
173
+ # MI325X
174
+ "MI325X:1",
175
+ # MI350X
176
+ "MI350X:1",
177
+ # MI355X
178
+ "MI355X:1",
179
+ # Habana Gaudi
180
+ "Gaudi1:1",
89
181
  ]
90
182
 
183
# Coarse accelerator categories a ``Device`` can be tagged with.
DeviceClass = Literal["GPU", "TPU", "NEURON", "AMD_GPU", "HABANA_GAUDI"]

# Device-name Literal -> DeviceClass label. ``Resources.get_device`` walks this mapping in
# insertion order and picks the first Literal that contains the device prefix, so the
# order of the pairs below is significant.
_DeviceClassType: Dict[typing.Any, str] = dict(
    [
        (GPUType, "GPU"),
        (TPUType, "TPU"),
        (NeuronType, "NEURON"),
        (AMD_GPUType, "AMD_GPU"),
        (HABANA_GAUDIType, "HABANA_GAUDI"),
    ]
)
192
+
91
193
 
92
194
  @rich.repr.auto
93
195
  @dataclass(frozen=True, slots=True)
@@ -100,6 +202,7 @@ class Device:
100
202
  """
101
203
 
102
204
  quantity: int
205
+ device_class: DeviceClass
103
206
  device: str | None = None
104
207
  partition: str | None = None
105
208
 
@@ -126,7 +229,7 @@ def GPU(device: GPUType, quantity: GPUQuantity, partition: A100Parts | A100_80GB
126
229
  elif partition is not None and device == "A100 80G":
127
230
  if partition not in get_args(A100_80GBParts):
128
231
  raise ValueError(f"Invalid partition for A100 80G: {partition}. Must be one of {get_args(A100_80GBParts)}")
129
- return Device(device=device, quantity=quantity, partition=partition)
232
+ return Device(device=device, quantity=quantity, partition=partition, device_class="GPU")
130
233
 
131
234
 
132
235
  def TPU(device: TPUType, partition: V5PParts | V6EParts | None = None):
@@ -147,7 +250,42 @@ def TPU(device: TPUType, partition: V5PParts | V6EParts | None = None):
147
250
  elif partition is not None and device == "V5E":
148
251
  if partition not in get_args(V5EParts):
149
252
  raise ValueError(f"Invalid partition for V5E: {partition}. Must be one of {get_args(V5EParts)}")
150
- return Device(1, device, partition)
253
+ return Device(1, "TPU", device, partition)
254
+
255
+
256
def Neuron(device: NeuronType) -> Device:
    """
    Create a Neuron device instance.

    The quantity is fixed at one device; this helper takes no quantity argument.

    :param device: Device type (e.g., "Inf1", "Inf2", "Trn1", "Trn1n", "Trn2", "Trn2u").
    :return: Device instance with ``quantity=1`` and ``device_class="NEURON"``.
    :raises ValueError: If ``device`` is not one of the values in ``NeuronType``.
    """
    # Literal annotations are not enforced at runtime, so validate explicitly.
    if device not in get_args(NeuronType):
        raise ValueError(f"Invalid Neuron type: {device}. Must be one of {get_args(NeuronType)}")
    return Device(device=device, quantity=1, device_class="NEURON")
266
+
267
+
268
def AMD_GPU(device: AMD_GPUType) -> Device:
    """
    Build a Device describing a single AMD GPU.

    :param device: One of the models listed in ``AMD_GPUType`` (e.g. "MI300X").
    :return: Device with ``quantity=1`` and ``device_class="AMD_GPU"``.
    :raises ValueError: If ``device`` is not a supported AMD GPU model.
    """
    supported = get_args(AMD_GPUType)
    if device in supported:
        return Device(device=device, quantity=1, device_class="AMD_GPU")
    raise ValueError(f"Invalid AMD GPU type: {device}. Must be one of {supported}")
278
+
279
+
280
def HABANA_GAUDI(device: HABANA_GAUDIType) -> Device:
    """
    Create a Habana Gaudi device instance.

    :param device: Device type; currently the only accepted value is "Gaudi1".
    :return: Device instance with ``quantity=1`` and ``device_class="HABANA_GAUDI"``.
    :raises ValueError: If ``device`` is not one of the values in ``HABANA_GAUDIType``.
    """
    # Literal annotations are not enforced at runtime, so validate explicitly.
    if device not in get_args(HABANA_GAUDIType):
        raise ValueError(f"Invalid Habana Gaudi type: {device}. Must be one of {get_args(HABANA_GAUDIType)}")
    return Device(device=device, quantity=1, device_class="HABANA_GAUDI")
151
289
 
152
290
 
153
291
  CPUBaseType = int | float | str
@@ -202,7 +340,7 @@ class Resources:
202
340
  raise ValueError("gpu must be greater than or equal to 0")
203
341
  elif isinstance(self.gpu, str):
204
342
  if self.gpu not in get_args(Accelerators):
205
- raise ValueError(f"gpu must be one of {Accelerators}")
343
+ raise ValueError(f"gpu must be one of {Accelerators}, got {self.gpu}")
206
344
 
207
345
  def get_device(self) -> Optional[Device]:
208
346
  """
@@ -214,10 +352,16 @@ class Resources:
214
352
  if self.gpu is None:
215
353
  return None
216
354
  if isinstance(self.gpu, int):
217
- return Device(quantity=self.gpu)
355
+ return Device(quantity=self.gpu, device_class="GPU")
218
356
  if isinstance(self.gpu, str):
219
357
  device, portion = self.gpu.split(":")
220
- return Device(device=device, quantity=int(portion))
358
+ for cls, cls_name in _DeviceClassType.items():
359
+ if device in get_args(cls):
360
+ device_class = cls_name
361
+ break
362
+ else:
363
+ raise ValueError(f"Invalid device type: {device}. Must be one of {list(_DeviceClassType.keys())}")
364
+ return Device(device=device, device_class=device_class, quantity=int(portion)) # type: ignore
221
365
  return self.gpu
222
366
 
223
367
  def get_shared_memory(self) -> Optional[str]:
flyte/_run.py CHANGED
@@ -12,13 +12,13 @@ from flyte._environment import Environment
12
12
  from flyte._initialize import (
13
13
  _get_init_config,
14
14
  get_client,
15
- get_common_config,
15
+ get_init_config,
16
16
  get_storage,
17
17
  requires_initialization,
18
18
  requires_storage,
19
19
  )
20
20
  from flyte._logging import logger
21
- from flyte._task import P, R, TaskTemplate
21
+ from flyte._task import F, P, R, TaskTemplate
22
22
  from flyte.models import (
23
23
  ActionID,
24
24
  Checkpoints,
@@ -94,6 +94,7 @@ class _Runner:
94
94
  log_level: int | None = None,
95
95
  disable_run_cache: bool = False,
96
96
  queue: Optional[str] = None,
97
+ custom_context: Dict[str, str] | None = None,
97
98
  ):
98
99
  from flyte._tools import ipython_check
99
100
 
@@ -124,11 +125,15 @@ class _Runner:
124
125
  self._log_level = log_level
125
126
  self._disable_run_cache = disable_run_cache
126
127
  self._queue = queue
128
+ self._custom_context = custom_context or {}
127
129
 
128
130
  @requires_initialization
129
- async def _run_remote(self, obj: TaskTemplate[P, R] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run:
131
+ async def _run_remote(self, obj: TaskTemplate[P, R, F] | LazyEntity, *args: P.args, **kwargs: P.kwargs) -> Run:
130
132
  import grpc
131
- from flyteidl.core import literals_pb2
133
+ from flyteidl2.common import identifier_pb2
134
+ from flyteidl2.core import literals_pb2
135
+ from flyteidl2.task import run_pb2
136
+ from flyteidl2.workflow import run_definition_pb2, run_service_pb2
132
137
  from google.protobuf import wrappers_pb2
133
138
 
134
139
  from flyte.remote import Run
@@ -138,21 +143,21 @@ class _Runner:
138
143
  from ._deploy import build_images
139
144
  from ._internal.runtime.convert import convert_from_native_to_inputs
140
145
  from ._internal.runtime.task_serde import translate_task_to_wire
141
- from ._protos.common import identifier_pb2
142
- from ._protos.workflow import run_definition_pb2, run_service_pb2
143
146
 
144
- cfg = get_common_config()
147
+ cfg = get_init_config()
145
148
  project = self._project or cfg.project
146
149
  domain = self._domain or cfg.domain
147
150
 
148
151
  if isinstance(obj, LazyEntity):
149
152
  task = await obj.fetch.aio()
150
153
  task_spec = task.pb2.spec
151
- inputs = await convert_from_native_to_inputs(task.interface, *args, **kwargs)
154
+ inputs = await convert_from_native_to_inputs(
155
+ task.interface, *args, custom_context=self._custom_context, **kwargs
156
+ )
152
157
  version = task.pb2.task_id.version
153
158
  code_bundle = None
154
159
  else:
155
- task = cast(TaskTemplate[P, R], obj)
160
+ task = cast(TaskTemplate[P, R, F], obj)
156
161
  if obj.parent_env is None:
157
162
  raise ValueError("Task is not attached to an environment. Please attach the task to an environment")
158
163
 
@@ -204,7 +209,9 @@ class _Runner:
204
209
  root_dir=cfg.root_dir,
205
210
  )
206
211
  task_spec = translate_task_to_wire(obj, s_ctx)
207
- inputs = await convert_from_native_to_inputs(obj.native_interface, *args, **kwargs)
212
+ inputs = await convert_from_native_to_inputs(
213
+ obj.native_interface, *args, custom_context=self._custom_context, **kwargs
214
+ )
208
215
 
209
216
  env = self._env_vars or {}
210
217
  if env.get("LOG_LEVEL") is None:
@@ -254,9 +261,9 @@ class _Runner:
254
261
  raise ValueError(f"Environment variable {k} must be a string, got {type(v)}")
255
262
  kv_pairs.append(literals_pb2.KeyValuePair(key=k, value=v))
256
263
 
257
- env_kv = run_definition_pb2.Envs(values=kv_pairs)
258
- annotations = run_definition_pb2.Annotations(values=self._annotations)
259
- labels = run_definition_pb2.Labels(values=self._labels)
264
+ env_kv = run_pb2.Envs(values=kv_pairs)
265
+ annotations = run_pb2.Annotations(values=self._annotations)
266
+ labels = run_pb2.Labels(values=self._labels)
260
267
 
261
268
  try:
262
269
  resp = await get_client().run_service.CreateRun(
@@ -265,7 +272,7 @@ class _Runner:
265
272
  project_id=project_id,
266
273
  task_spec=task_spec,
267
274
  inputs=inputs.proto_inputs,
268
- run_spec=run_definition_pb2.RunSpec(
275
+ run_spec=run_pb2.RunSpec(
269
276
  overwrite_cache=self._overwrite_cache,
270
277
  interruptible=wrappers_pb2.BoolValue(value=self._interruptible)
271
278
  if self._interruptible is not None
@@ -318,7 +325,7 @@ class _Runner:
318
325
 
319
326
  @requires_storage
320
327
  @requires_initialization
321
- async def _run_hybrid(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
328
+ async def _run_hybrid(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: P.kwargs) -> R:
322
329
  """
323
330
  Run a task in hybrid mode. This means that the parent action will be run locally, but the child actions will be
324
331
  run in the cluster remotely. This is currently only used for testing,
@@ -333,7 +340,7 @@ class _Runner:
333
340
  from ._internal import create_controller
334
341
  from ._internal.runtime.taskrunner import run_task
335
342
 
336
- cfg = get_common_config()
343
+ cfg = get_init_config()
337
344
 
338
345
  if obj.parent_env is None:
339
346
  raise ValueError("Task is not attached to an environment. Please attach the task to an environment.")
@@ -411,6 +418,7 @@ class _Runner:
411
418
  compiled_image_cache=image_cache,
412
419
  run_base_dir=run_base_dir,
413
420
  report=flyte.report.Report(name=action.name),
421
+ custom_context=self._custom_context,
414
422
  )
415
423
  async with ctx.replace_task_context(tctx):
416
424
  return await run_task(tctx=tctx, controller=controller, task=obj, inputs=inputs)
@@ -420,10 +428,11 @@ class _Runner:
420
428
  raise err
421
429
  return outputs
422
430
 
423
- async def _run_local(self, obj: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> Run:
431
+ async def _run_local(self, obj: TaskTemplate[P, R, F], *args: P.args, **kwargs: P.kwargs) -> Run:
432
+ from flyteidl2.common import identifier_pb2
433
+
424
434
  from flyte._internal.controllers import create_controller
425
435
  from flyte._internal.controllers._local_controller import LocalController
426
- from flyte._protos.common import identifier_pb2
427
436
  from flyte.remote import Run
428
437
  from flyte.report import Report
429
438
 
@@ -461,7 +470,9 @@ class _Runner:
461
470
  compiled_image_cache=None,
462
471
  report=Report(name=action.name),
463
472
  mode="local",
473
+ custom_context=self._custom_context,
464
474
  )
475
+
465
476
  with ctx.replace_task_context(tctx):
466
477
  # make the local version always runs on a different thread, returns a wrapped future.
467
478
  if obj._call_as_synchronous:
@@ -473,7 +484,7 @@ class _Runner:
473
484
 
474
485
  class _LocalRun(Run):
475
486
  def __init__(self, outputs: Tuple[Any, ...] | Any):
476
- from flyte._protos.workflow import run_definition_pb2
487
+ from flyteidl2.workflow import run_definition_pb2
477
488
 
478
489
  self._outputs = outputs
479
490
  super().__init__(
@@ -506,7 +517,7 @@ class _Runner:
506
517
  @syncify
507
518
  async def run(
508
519
  self,
509
- task: TaskTemplate[P, Union[R, Run]] | LazyEntity,
520
+ task: TaskTemplate[P, Union[R, Run], F] | LazyEntity,
510
521
  *args: P.args,
511
522
  **kwargs: P.kwargs,
512
523
  ) -> Union[R, Run]:
@@ -574,6 +585,7 @@ def with_runcontext(
574
585
  log_level: int | None = None,
575
586
  disable_run_cache: bool = False,
576
587
  queue: Optional[str] = None,
588
+ custom_context: Dict[str, str] | None = None,
577
589
  ) -> _Runner:
578
590
  """
579
591
  Launch a new run with the given parameters as the context.
@@ -618,6 +630,9 @@ def with_runcontext(
618
630
  set using `flyte.init()`
619
631
  :param disable_run_cache: Optional If true, the run cache will be disabled. This is useful for testing purposes.
620
632
  :param queue: Optional The queue to use for the run. This is used to specify the cluster to use for the run.
633
+ :param custom_context: Optional global input context to pass to the task. This will be available via
634
+ get_custom_context() within the task and will automatically propagate to sub-tasks.
635
+ Acts as base/default values that can be overridden by context managers in the code.
621
636
 
622
637
  :return: runner
623
638
  """
@@ -646,11 +661,12 @@ def with_runcontext(
646
661
  log_level=log_level,
647
662
  disable_run_cache=disable_run_cache,
648
663
  queue=queue,
664
+ custom_context=custom_context,
649
665
  )
650
666
 
651
667
 
652
668
  @syncify
653
- async def run(task: TaskTemplate[P, R], *args: P.args, **kwargs: P.kwargs) -> Union[R, Run]:
669
+ async def run(task: TaskTemplate[P, R, F], *args: P.args, **kwargs: P.kwargs) -> Union[R, Run]:
654
670
  """
655
671
  Run a task with the given parameters
656
672
  :param task: task to run