flyte 2.0.0b13__py3-none-any.whl → 2.0.0b30__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211)
  1. flyte/__init__.py +18 -2
  2. flyte/_bin/debug.py +38 -0
  3. flyte/_bin/runtime.py +62 -8
  4. flyte/_cache/cache.py +4 -2
  5. flyte/_cache/local_cache.py +216 -0
  6. flyte/_code_bundle/_ignore.py +12 -4
  7. flyte/_code_bundle/_packaging.py +13 -9
  8. flyte/_code_bundle/_utils.py +18 -10
  9. flyte/_code_bundle/bundle.py +17 -9
  10. flyte/_constants.py +1 -0
  11. flyte/_context.py +4 -1
  12. flyte/_custom_context.py +73 -0
  13. flyte/_debug/constants.py +38 -0
  14. flyte/_debug/utils.py +17 -0
  15. flyte/_debug/vscode.py +307 -0
  16. flyte/_deploy.py +235 -61
  17. flyte/_environment.py +20 -6
  18. flyte/_excepthook.py +1 -1
  19. flyte/_hash.py +1 -16
  20. flyte/_image.py +178 -81
  21. flyte/_initialize.py +132 -51
  22. flyte/_interface.py +39 -2
  23. flyte/_internal/controllers/__init__.py +4 -5
  24. flyte/_internal/controllers/_local_controller.py +70 -29
  25. flyte/_internal/controllers/_trace.py +1 -1
  26. flyte/_internal/controllers/remote/__init__.py +0 -2
  27. flyte/_internal/controllers/remote/_action.py +14 -16
  28. flyte/_internal/controllers/remote/_client.py +1 -1
  29. flyte/_internal/controllers/remote/_controller.py +68 -70
  30. flyte/_internal/controllers/remote/_core.py +127 -99
  31. flyte/_internal/controllers/remote/_informer.py +19 -10
  32. flyte/_internal/controllers/remote/_service_protocol.py +7 -7
  33. flyte/_internal/imagebuild/docker_builder.py +181 -69
  34. flyte/_internal/imagebuild/image_builder.py +0 -5
  35. flyte/_internal/imagebuild/remote_builder.py +155 -64
  36. flyte/_internal/imagebuild/utils.py +51 -2
  37. flyte/_internal/resolvers/_task_module.py +5 -38
  38. flyte/_internal/resolvers/default.py +2 -2
  39. flyte/_internal/runtime/convert.py +110 -21
  40. flyte/_internal/runtime/entrypoints.py +27 -1
  41. flyte/_internal/runtime/io.py +21 -8
  42. flyte/_internal/runtime/resources_serde.py +20 -6
  43. flyte/_internal/runtime/reuse.py +1 -1
  44. flyte/_internal/runtime/rusty.py +20 -5
  45. flyte/_internal/runtime/task_serde.py +34 -19
  46. flyte/_internal/runtime/taskrunner.py +22 -4
  47. flyte/_internal/runtime/trigger_serde.py +160 -0
  48. flyte/_internal/runtime/types_serde.py +1 -1
  49. flyte/_keyring/__init__.py +0 -0
  50. flyte/_keyring/file.py +115 -0
  51. flyte/_logging.py +201 -39
  52. flyte/_map.py +111 -14
  53. flyte/_module.py +70 -0
  54. flyte/_pod.py +4 -3
  55. flyte/_resources.py +213 -31
  56. flyte/_run.py +110 -39
  57. flyte/_task.py +75 -16
  58. flyte/_task_environment.py +105 -29
  59. flyte/_task_plugins.py +4 -2
  60. flyte/_trace.py +5 -0
  61. flyte/_trigger.py +1000 -0
  62. flyte/_utils/__init__.py +2 -1
  63. flyte/_utils/asyn.py +3 -1
  64. flyte/_utils/coro_management.py +2 -1
  65. flyte/_utils/docker_credentials.py +173 -0
  66. flyte/_utils/module_loader.py +17 -2
  67. flyte/_version.py +3 -3
  68. flyte/cli/_abort.py +3 -3
  69. flyte/cli/_build.py +3 -6
  70. flyte/cli/_common.py +78 -7
  71. flyte/cli/_create.py +182 -4
  72. flyte/cli/_delete.py +23 -1
  73. flyte/cli/_deploy.py +63 -16
  74. flyte/cli/_get.py +79 -34
  75. flyte/cli/_params.py +26 -10
  76. flyte/cli/_plugins.py +209 -0
  77. flyte/cli/_run.py +151 -26
  78. flyte/cli/_serve.py +64 -0
  79. flyte/cli/_update.py +37 -0
  80. flyte/cli/_user.py +17 -0
  81. flyte/cli/main.py +30 -4
  82. flyte/config/_config.py +10 -6
  83. flyte/config/_internal.py +1 -0
  84. flyte/config/_reader.py +29 -8
  85. flyte/connectors/__init__.py +11 -0
  86. flyte/connectors/_connector.py +270 -0
  87. flyte/connectors/_server.py +197 -0
  88. flyte/connectors/utils.py +135 -0
  89. flyte/errors.py +22 -2
  90. flyte/extend.py +8 -1
  91. flyte/extras/_container.py +6 -1
  92. flyte/git/__init__.py +3 -0
  93. flyte/git/_config.py +21 -0
  94. flyte/io/__init__.py +2 -0
  95. flyte/io/_dataframe/__init__.py +2 -0
  96. flyte/io/_dataframe/basic_dfs.py +17 -8
  97. flyte/io/_dataframe/dataframe.py +98 -132
  98. flyte/io/_dir.py +575 -113
  99. flyte/io/_file.py +582 -139
  100. flyte/io/_hashing_io.py +342 -0
  101. flyte/models.py +74 -15
  102. flyte/remote/__init__.py +6 -1
  103. flyte/remote/_action.py +34 -26
  104. flyte/remote/_client/_protocols.py +39 -4
  105. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  106. flyte/remote/_client/auth/_authenticators/pkce.py +1 -1
  107. flyte/remote/_client/auth/_channel.py +10 -6
  108. flyte/remote/_client/controlplane.py +17 -5
  109. flyte/remote/_console.py +3 -2
  110. flyte/remote/_data.py +6 -6
  111. flyte/remote/_logs.py +3 -3
  112. flyte/remote/_run.py +64 -8
  113. flyte/remote/_secret.py +26 -17
  114. flyte/remote/_task.py +75 -33
  115. flyte/remote/_trigger.py +306 -0
  116. flyte/remote/_user.py +33 -0
  117. flyte/report/_report.py +1 -1
  118. flyte/storage/__init__.py +6 -1
  119. flyte/storage/_config.py +5 -1
  120. flyte/storage/_parallel_reader.py +274 -0
  121. flyte/storage/_storage.py +200 -103
  122. flyte/types/__init__.py +16 -0
  123. flyte/types/_interface.py +2 -2
  124. flyte/types/_pickle.py +35 -8
  125. flyte/types/_string_literals.py +8 -9
  126. flyte/types/_type_engine.py +40 -70
  127. flyte/types/_utils.py +1 -1
  128. flyte-2.0.0b30.data/scripts/debug.py +38 -0
  129. {flyte-2.0.0b13.data → flyte-2.0.0b30.data}/scripts/runtime.py +62 -8
  130. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/METADATA +11 -3
  131. flyte-2.0.0b30.dist-info/RECORD +192 -0
  132. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/entry_points.txt +3 -0
  133. flyte/_protos/common/authorization_pb2.py +0 -66
  134. flyte/_protos/common/authorization_pb2.pyi +0 -108
  135. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  136. flyte/_protos/common/identifier_pb2.py +0 -93
  137. flyte/_protos/common/identifier_pb2.pyi +0 -110
  138. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  139. flyte/_protos/common/identity_pb2.py +0 -48
  140. flyte/_protos/common/identity_pb2.pyi +0 -72
  141. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  142. flyte/_protos/common/list_pb2.py +0 -36
  143. flyte/_protos/common/list_pb2.pyi +0 -71
  144. flyte/_protos/common/list_pb2_grpc.py +0 -4
  145. flyte/_protos/common/policy_pb2.py +0 -37
  146. flyte/_protos/common/policy_pb2.pyi +0 -27
  147. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  148. flyte/_protos/common/role_pb2.py +0 -37
  149. flyte/_protos/common/role_pb2.pyi +0 -53
  150. flyte/_protos/common/role_pb2_grpc.py +0 -4
  151. flyte/_protos/common/runtime_version_pb2.py +0 -28
  152. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  153. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  154. flyte/_protos/imagebuilder/definition_pb2.py +0 -59
  155. flyte/_protos/imagebuilder/definition_pb2.pyi +0 -140
  156. flyte/_protos/imagebuilder/definition_pb2_grpc.py +0 -4
  157. flyte/_protos/imagebuilder/payload_pb2.py +0 -32
  158. flyte/_protos/imagebuilder/payload_pb2.pyi +0 -21
  159. flyte/_protos/imagebuilder/payload_pb2_grpc.py +0 -4
  160. flyte/_protos/imagebuilder/service_pb2.py +0 -29
  161. flyte/_protos/imagebuilder/service_pb2.pyi +0 -5
  162. flyte/_protos/imagebuilder/service_pb2_grpc.py +0 -66
  163. flyte/_protos/logs/dataplane/payload_pb2.py +0 -100
  164. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -177
  165. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  166. flyte/_protos/secret/definition_pb2.py +0 -49
  167. flyte/_protos/secret/definition_pb2.pyi +0 -93
  168. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  169. flyte/_protos/secret/payload_pb2.py +0 -62
  170. flyte/_protos/secret/payload_pb2.pyi +0 -94
  171. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  172. flyte/_protos/secret/secret_pb2.py +0 -38
  173. flyte/_protos/secret/secret_pb2.pyi +0 -6
  174. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  175. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  176. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  177. flyte/_protos/workflow/common_pb2.py +0 -27
  178. flyte/_protos/workflow/common_pb2.pyi +0 -14
  179. flyte/_protos/workflow/common_pb2_grpc.py +0 -4
  180. flyte/_protos/workflow/environment_pb2.py +0 -29
  181. flyte/_protos/workflow/environment_pb2.pyi +0 -12
  182. flyte/_protos/workflow/environment_pb2_grpc.py +0 -4
  183. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  184. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  185. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  186. flyte/_protos/workflow/queue_service_pb2.py +0 -109
  187. flyte/_protos/workflow/queue_service_pb2.pyi +0 -166
  188. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  189. flyte/_protos/workflow/run_definition_pb2.py +0 -121
  190. flyte/_protos/workflow/run_definition_pb2.pyi +0 -327
  191. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  192. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  193. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  194. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  195. flyte/_protos/workflow/run_service_pb2.py +0 -137
  196. flyte/_protos/workflow/run_service_pb2.pyi +0 -185
  197. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -446
  198. flyte/_protos/workflow/state_service_pb2.py +0 -67
  199. flyte/_protos/workflow/state_service_pb2.pyi +0 -76
  200. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  201. flyte/_protos/workflow/task_definition_pb2.py +0 -79
  202. flyte/_protos/workflow/task_definition_pb2.pyi +0 -81
  203. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  204. flyte/_protos/workflow/task_service_pb2.py +0 -60
  205. flyte/_protos/workflow/task_service_pb2.pyi +0 -59
  206. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -138
  207. flyte-2.0.0b13.dist-info/RECORD +0 -239
  208. /flyte/{_protos → _debug}/__init__.py +0 -0
  209. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/WHEEL +0 -0
  210. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/licenses/LICENSE +0 -0
  211. {flyte-2.0.0b13.dist-info → flyte-2.0.0b30.dist-info}/top_level.txt +0 -0
flyte/_map.py CHANGED
@@ -1,17 +1,26 @@
1
1
  import asyncio
2
- from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast
2
+ import functools
3
+ import logging
4
+ from typing import Any, AsyncGenerator, AsyncIterator, Generic, Iterable, Iterator, List, Union, cast, overload
3
5
 
4
6
  from flyte.syncify import syncify
5
7
 
6
8
  from ._group import group
7
9
  from ._logging import logger
8
- from ._task import P, R, TaskTemplate
10
+ from ._task import AsyncFunctionTaskTemplate, F, P, R
9
11
 
10
12
 
11
13
  class MapAsyncIterator(Generic[P, R]):
12
14
  """AsyncIterator implementation for the map function results"""
13
15
 
14
- def __init__(self, func: TaskTemplate[P, R], args: tuple, name: str, concurrency: int, return_exceptions: bool):
16
+ def __init__(
17
+ self,
18
+ func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
19
+ args: tuple,
20
+ name: str,
21
+ concurrency: int,
22
+ return_exceptions: bool,
23
+ ):
15
24
  self.func = func
16
25
  self.args = args
17
26
  self.name = name
@@ -49,13 +58,16 @@ class MapAsyncIterator(Generic[P, R]):
49
58
  return result
50
59
  except Exception as e:
51
60
  self._exception_count += 1
52
- logger.debug(f"Task {self._current_index - 1} failed with exception: {e}")
61
+ logger.debug(
62
+ f"Task {self._current_index - 1} failed with exception: {e}, return_exceptions={self.return_exceptions}"
63
+ )
53
64
  if self.return_exceptions:
54
65
  return e
55
66
  else:
56
67
  # Cancel remaining tasks
57
68
  for remaining_task in self._tasks[self._current_index + 1 :]:
58
69
  remaining_task.cancel()
70
+ logger.warning("Exception raising is `ON`, raising exception and cancelling remaining tasks")
59
71
  raise e
60
72
 
61
73
  async def _initialize(self):
@@ -64,10 +76,26 @@ class MapAsyncIterator(Generic[P, R]):
64
76
  tasks = []
65
77
  task_count = 0
66
78
 
67
- for arg_tuple in zip(*self.args):
68
- task = asyncio.create_task(self.func.aio(*arg_tuple))
69
- tasks.append(task)
70
- task_count += 1
79
+ if isinstance(self.func, functools.partial):
80
+ # Handle partial functions by merging bound args/kwargs with mapped args
81
+ base_func = cast(AsyncFunctionTaskTemplate, self.func.func)
82
+ bound_args = self.func.args
83
+ bound_kwargs = self.func.keywords or {}
84
+
85
+ for arg_tuple in zip(*self.args):
86
+ # Merge bound positional args with mapped args
87
+ merged_args = bound_args + arg_tuple
88
+ if logger.isEnabledFor(logging.DEBUG):
89
+ logger.debug(f"Running {base_func.name} with args: {merged_args} and kwargs: {bound_kwargs}")
90
+ task = asyncio.create_task(base_func.aio(*merged_args, **bound_kwargs))
91
+ tasks.append(task)
92
+ task_count += 1
93
+ else:
94
+ # Handle regular TaskTemplate functions
95
+ for arg_tuple in zip(*self.args):
96
+ task = asyncio.create_task(self.func.aio(*arg_tuple))
97
+ tasks.append(task)
98
+ task_count += 1
71
99
 
72
100
  if task_count == 0:
73
101
  logger.info(f"Group '{self.name}' has no tasks to process")
@@ -107,9 +135,65 @@ class _Mapper(Generic[P, R]):
107
135
  """Get the name of the group, defaulting to 'map' if not provided."""
108
136
  return f"{task_name}_{group_name or 'map'}"
109
137
 
138
@staticmethod
def validate_partial(func: functools.partial[R]):
    """
    Validate that a partial function can be mapped.

    For a partial, the mapper calls the wrapped task as ``task(*func.args, item, **func.keywords)`` for every mapped
    ``item``. Mapping therefore requires that exactly one task input is left unbound. That input must be the one
    immediately after the positionally bound inputs, which is the slot the mapped item lands in.

    :param func: partial function to validate
    :raises TypeError: if the partial binds a keyword that is not a task input, leaves anything other than exactly
        one input unbound, or binds by keyword an input that is also filled positionally (including the mapped slot)
    """
    f = cast(AsyncFunctionTaskTemplate, func.func)
    inputs = f.native_interface.inputs
    params = list(inputs.keys())
    provided_args = len(func.args)
    bound_kwargs = func.keywords or {}

    # An unknown keyword would otherwise be counted as "specified" below. That would let a genuinely unbound
    # input slip through validation and fail later, once per mapped item.
    unknown_kwargs = [name for name in bound_kwargs if name not in inputs]
    if unknown_kwargs:
        raise TypeError(
            f"Partial function {f.name} binds keyword argument(s) {unknown_kwargs} that are not parameters, "
            f"params: {params}"
        )

    # Exactly one parameter should be left for mapping
    unspecified_count = len(params) - provided_args - len(bound_kwargs)
    if unspecified_count != 1:
        raise TypeError(
            f"Partial function must leave exactly one parameter unspecified for mapping. "
            f"Found {unspecified_count} unspecified parameters in {f.name}, "
            f"params: {params}"
        )

    # params[:provided_args] are filled by the bound positional args and params[provided_args] by the mapped
    # item, so none of them may also be bound by keyword.
    for position, arg_name in enumerate(params[: provided_args + 1]):
        if arg_name not in bound_kwargs:
            continue
        if position < provided_args:
            raise TypeError(
                f"Parameter '{arg_name}' is provided both as positional argument and keyword argument "
                f"in partial function {f.name}."
            )
        raise TypeError(
            f"Parameter '{arg_name}' is bound by keyword in partial function {f.name}, but the mapped value is "
            f"passed positionally right after the bound positional arguments and would also be assigned to it."
        )
174
+
175
@overload
def __call__(
    self,
    func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
    *args: Iterable[Any],
    group_name: str | None = None,
    concurrency: int = 0,
) -> Iterator[R]:
    """
    Typing overload for calls that do not pass ``return_exceptions``; yields mapped results.

    NOTE(review): the other overload declares ``return_exceptions: bool = True``. If the implementation shares that
    default, calls that omit the argument can still yield exceptions at runtime while this overload promises
    ``Iterator[R]``. Consider ``return_exceptions: Literal[False]`` here; confirm against the implementation.
    """
    ...

@overload
def __call__(
    self,
    func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
    *args: Iterable[Any],
    group_name: str | None = None,
    concurrency: int = 0,
    return_exceptions: bool = True,
) -> Iterator[Union[R, Exception]]:
    """Typing overload for calls that pass ``return_exceptions``; each item is a result or an exception."""
    ...
193
+
194
+ def __call__(
195
+ self,
196
+ func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
113
197
  *args: Iterable[Any],
114
198
  group_name: str | None = None,
115
199
  concurrency: int = 0,
@@ -128,7 +212,13 @@ class _Mapper(Generic[P, R]):
128
212
  if not args:
129
213
  return
130
214
 
131
- name = self._get_name(func.name, group_name)
215
+ if isinstance(func, functools.partial):
216
+ f = cast(AsyncFunctionTaskTemplate, func.func)
217
+ self.validate_partial(func)
218
+ else:
219
+ f = cast(AsyncFunctionTaskTemplate, func)
220
+
221
+ name = self._get_name(f.name, group_name)
132
222
  logger.debug(f"Blocking Map for {name}")
133
223
  with group(name):
134
224
  import flyte
@@ -154,7 +244,7 @@ class _Mapper(Generic[P, R]):
154
244
  *args,
155
245
  name=name,
156
246
  concurrency=concurrency,
157
- return_exceptions=True,
247
+ return_exceptions=return_exceptions,
158
248
  ),
159
249
  ):
160
250
  logger.debug(f"Mapped {x}, task {i}")
@@ -163,7 +253,7 @@ class _Mapper(Generic[P, R]):
163
253
 
164
254
  async def aio(
165
255
  self,
166
- func: TaskTemplate[P, R],
256
+ func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
167
257
  *args: Iterable[Any],
168
258
  group_name: str | None = None,
169
259
  concurrency: int = 0,
@@ -171,7 +261,14 @@ class _Mapper(Generic[P, R]):
171
261
  ) -> AsyncGenerator[Union[R, Exception], None]:
172
262
  if not args:
173
263
  return
174
- name = self._get_name(func.name, group_name)
264
+
265
+ if isinstance(func, functools.partial):
266
+ f = cast(AsyncFunctionTaskTemplate, func.func)
267
+ self.validate_partial(func)
268
+ else:
269
+ f = cast(AsyncFunctionTaskTemplate, func)
270
+
271
+ name = self._get_name(f.name, group_name)
175
272
  with group(name):
176
273
  import flyte
177
274
 
@@ -199,7 +296,7 @@ class _Mapper(Generic[P, R]):
199
296
 
200
297
  @syncify
201
298
  async def _map(
202
- func: TaskTemplate[P, R],
299
+ func: AsyncFunctionTaskTemplate[P, R, F] | functools.partial[R],
203
300
  *args: Iterable[Any],
204
301
  name: str = "map",
205
302
  concurrency: int = 0,
flyte/_module.py ADDED
@@ -0,0 +1,70 @@
1
+ import inspect
2
+ import os
3
+ import pathlib
4
+ import sys
5
+
6
+
7
def extract_obj_module(obj: object, /, source_dir: pathlib.Path) -> str:
    """
    Extract the importable module name of the given object, relative to ``source_dir``.

    Resolution order:

    * If the object's module file lives under ``source_dir``, the name is derived from that file's path relative to
      ``source_dir``: the extension is removed and path separators become dots. A file whose parent directory is
      ``_internal/resolvers`` (relative to ``source_dir``) uses the module's own ``__name__`` instead.
    * Otherwise, if the file lives in a ``site-packages`` or ``dist-packages`` directory, the module's ``__name__``
      is used, since that is importable via ``importlib.import_module()``.
    * Otherwise a ``ValueError`` is raised.

    If the resolved name is ``"__main__"``, the stem of the ``__main__`` module's file is returned instead.

    Args:
        obj: The object to extract the module from.
        source_dir: The source directory to use for relative paths. Must not be None.

    Returns:
        The module name as a string.

    Raises:
        ValueError: If ``source_dir`` is None, the object has no module, the module has no file, or the file is
            neither under ``source_dir`` nor in an installed package.
    """
    if source_dir is None:
        raise ValueError("extract_obj_module: source_dir cannot be None - specify root-dir")

    def _obj_name() -> str:
        return getattr(obj, "__name__", str(obj))

    # Get the module containing the object
    entity_module = inspect.getmodule(obj)
    if entity_module is None:
        raise ValueError(f"Object {_obj_name()} has no module.")

    # Built-in modules have no ``__file__`` attribute at all; treat that the same as ``__file__ = None`` instead
    # of leaking an AttributeError.
    fp = getattr(entity_module, "__file__", None)
    if fp is None:
        raise ValueError(f"Object {_obj_name()} has no module.")

    file_path = pathlib.Path(fp)
    try:
        # Raises ValueError if the file is not inside the source directory
        relative_path = file_path.relative_to(pathlib.Path(source_dir).absolute())
    except ValueError:
        relative_path = None

    if relative_path is None:
        # File is not relative to source_dir; only accept it if it belongs to an installed package
        file_path_str = str(file_path)
        if "site-packages" not in file_path_str and "dist-packages" not in file_path_str:
            raise ValueError(
                f"Object {_obj_name()} module file {file_path} is not relative to "
                f"source directory {source_dir} and is not an installed package."
            )
        # Installed package: __name__ is importable via importlib.import_module()
        entity_module_name = entity_module.__name__
    elif relative_path.parent == pathlib.Path("_internal", "resolvers"):
        # ``relative_path`` names a file (e.g. ``_internal/resolvers/default.py``), so its parent directory is what
        # must be compared. Comparing the file path itself against the directory could never match.
        entity_module_name = entity_module.__name__
    else:
        # Replace file separators with dots and remove the file extension
        entity_module_name = os.path.splitext(str(relative_path))[0].replace(os.sep, ".")

    if entity_module_name == "__main__":
        # The object comes from the script being run directly; use that script's file stem instead.
        main_fp = getattr(sys.modules["__main__"], "__file__", None)
        if main_fp is None:
            raise ValueError(f"Object {_obj_name()} has no module.")
        entity_module_name = pathlib.Path(main_fp).stem

    return entity_module_name
flyte/_pod.py CHANGED
@@ -2,8 +2,8 @@ from dataclasses import dataclass, field
2
2
  from typing import TYPE_CHECKING, Dict, Optional
3
3
 
4
4
  if TYPE_CHECKING:
5
- from flyteidl.core.tasks_pb2 import K8sPod
6
- from kubernetes.client import ApiClient, V1PodSpec
5
+ from flyteidl2.core.tasks_pb2 import K8sPod
6
+ from kubernetes.client import V1PodSpec
7
7
 
8
8
 
9
9
  _PRIMARY_CONTAINER_NAME_FIELD = "primary_container_name"
@@ -20,7 +20,8 @@ class PodTemplate(object):
20
20
  annotations: Optional[Dict[str, str]] = None
21
21
 
22
22
  def to_k8s_pod(self) -> "K8sPod":
23
- from flyteidl.core.tasks_pb2 import K8sObjectMetadata, K8sPod
23
+ from flyteidl2.core.tasks_pb2 import K8sObjectMetadata, K8sPod
24
+ from kubernetes.client import ApiClient
24
25
 
25
26
  return K8sPod(
26
27
  metadata=K8sObjectMetadata(labels=self.labels, annotations=self.annotations),
flyte/_resources.py CHANGED
@@ -1,5 +1,6 @@
1
+ import typing
1
2
  from dataclasses import dataclass, fields
2
- from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union, get_args
3
+ from typing import TYPE_CHECKING, Dict, Literal, Optional, Tuple, Union, get_args
3
4
 
4
5
  import rich.repr
5
6
 
@@ -10,7 +11,7 @@ if TYPE_CHECKING:
10
11
 
11
12
  PRIMARY_CONTAINER_DEFAULT_NAME = "primary"
12
13
 
13
# GPU models accepted by ``GPU()``. Every "<model>:<count>" entry in ``Accelerators`` must name a model listed in
# one of the device-class Literals; otherwise ``Resources.get_device`` cannot resolve its device class.
# "H200" is included because ``Accelerators`` offers "H200:1".."H200:8".
GPUType = Literal[
    "A10",
    "A10G",
    "A100",
    "A100 80G",
    "B200",
    "H100",
    "H200",
    "L4",
    "L40s",
    "T4",
    "V100",
    "RTX PRO 6000",
]
14
15
  GPUQuantity = Literal[1, 2, 3, 4, 5, 6, 7, 8]
15
16
  A100Parts = Literal["1g.5gb", "2g.10gb", "3g.20gb", "4g.20gb", "7g.40gb"]
16
17
  """
@@ -37,31 +38,32 @@ V6EParts = Literal["1x1", "2x2", "2x4", "4x4", "4x8", "8x8", "8x16", "16x16"]
37
38
  Slices for Google Cloud TPU v6e.
38
39
  """
39
40
 
41
# Neuron device families accepted by ``Neuron()``.
NeuronType = Literal[
    "Inf1",
    "Inf2",
    "Trn1",
    "Trn1n",
    "Trn2",
    "Trn2u",
]

# AMD GPU models accepted by ``AMD_GPU()``.
AMD_GPUType = Literal[
    "MI100",
    "MI210",
    "MI250",
    "MI250X",
    "MI300A",
    "MI300X",
    "MI325X",
    "MI350X",
    "MI355X",
]

# Habana Gaudi models accepted by ``HABANA_GAUDI()``.
HABANA_GAUDIType = Literal["Gaudi1"]
46
+
40
47
  Accelerators = Literal[
41
- "T4:1",
42
- "T4:2",
43
- "T4:3",
44
- "T4:4",
45
- "T4:5",
46
- "T4:6",
47
- "T4:7",
48
- "T4:8",
49
- "L4:1",
50
- "L4:2",
51
- "L4:3",
52
- "L4:4",
53
- "L4:5",
54
- "L4:6",
55
- "L4:7",
56
- "L4:8",
57
- "L40s:1",
58
- "L40s:2",
59
- "L40s:3",
60
- "L40s:4",
61
- "L40s:5",
62
- "L40s:6",
63
- "L40s:7",
64
- "L40s:8",
48
+ # A10
49
+ "A10:1",
50
+ "A10:2",
51
+ "A10:3",
52
+ "A10:4",
53
+ "A10:5",
54
+ "A10:6",
55
+ "A10:7",
56
+ "A10:8",
57
+ # A10G
58
+ "A10G:1",
59
+ "A10G:2",
60
+ "A10G:3",
61
+ "A10G:4",
62
+ "A10G:5",
63
+ "A10G:6",
64
+ "A10G:7",
65
+ "A10G:8",
66
+ # A100
65
67
  "A100:1",
66
68
  "A100:2",
67
69
  "A100:3",
@@ -70,6 +72,7 @@ Accelerators = Literal[
70
72
  "A100:6",
71
73
  "A100:7",
72
74
  "A100:8",
75
+ # A100 80G
73
76
  "A100 80G:1",
74
77
  "A100 80G:2",
75
78
  "A100 80G:3",
@@ -78,6 +81,16 @@ Accelerators = Literal[
78
81
  "A100 80G:6",
79
82
  "A100 80G:7",
80
83
  "A100 80G:8",
84
+ # B200
85
+ "B200:1",
86
+ "B200:2",
87
+ "B200:3",
88
+ "B200:4",
89
+ "B200:5",
90
+ "B200:6",
91
+ "B200:7",
92
+ "B200:8",
93
+ # H100
81
94
  "H100:1",
82
95
  "H100:2",
83
96
  "H100:3",
@@ -86,8 +99,135 @@ Accelerators = Literal[
86
99
  "H100:6",
87
100
  "H100:7",
88
101
  "H100:8",
102
+ # H200
103
+ "H200:1",
104
+ "H200:2",
105
+ "H200:3",
106
+ "H200:4",
107
+ "H200:5",
108
+ "H200:6",
109
+ "H200:7",
110
+ "H200:8",
111
+ # L4
112
+ "L4:1",
113
+ "L4:2",
114
+ "L4:3",
115
+ "L4:4",
116
+ "L4:5",
117
+ "L4:6",
118
+ "L4:7",
119
+ "L4:8",
120
+ # L40s
121
+ "L40s:1",
122
+ "L40s:2",
123
+ "L40s:3",
124
+ "L40s:4",
125
+ "L40s:5",
126
+ "L40s:6",
127
+ "L40s:7",
128
+ "L40s:8",
129
+ # V100
130
+ "V100:1",
131
+ "V100:2",
132
+ "V100:3",
133
+ "V100:4",
134
+ "V100:5",
135
+ "V100:6",
136
+ "V100:7",
137
+ "V100:8",
138
+ # RTX 6000
139
+ "RTX PRO 6000:1",
140
+ # T4
141
+ "T4:1",
142
+ "T4:2",
143
+ "T4:3",
144
+ "T4:4",
145
+ "T4:5",
146
+ "T4:6",
147
+ "T4:7",
148
+ "T4:8",
149
+ # Trn1
150
+ "Trn1:1",
151
+ "Trn1:4",
152
+ "Trn1:8",
153
+ "Trn1:16",
154
+ # Trn1n
155
+ "Trn1n:1",
156
+ "Trn1n:4",
157
+ "Trn1n:8",
158
+ "Trn1n:16",
159
+ # Trn2
160
+ "Trn2:1",
161
+ "Trn2:4",
162
+ "Trn2:8",
163
+ "Trn2:16",
164
+ # Trn2u
165
+ "Trn2u:1",
166
+ "Trn2u:4",
167
+ "Trn2u:8",
168
+ "Trn2u:16",
169
+ # Inf1
170
+ "Inf1:1",
171
+ "Inf1:2",
172
+ "Inf1:3",
173
+ "Inf1:4",
174
+ "Inf1:5",
175
+ "Inf1:6",
176
+ "Inf1:7",
177
+ "Inf1:8",
178
+ "Inf1:9",
179
+ "Inf1:10",
180
+ "Inf1:11",
181
+ "Inf1:12",
182
+ "Inf1:13",
183
+ "Inf1:14",
184
+ "Inf1:15",
185
+ "Inf1:16",
186
+ # Inf2
187
+ "Inf2:1",
188
+ "Inf2:2",
189
+ "Inf2:3",
190
+ "Inf2:4",
191
+ "Inf2:5",
192
+ "Inf2:6",
193
+ "Inf2:7",
194
+ "Inf2:8",
195
+ "Inf2:9",
196
+ "Inf2:10",
197
+ "Inf2:11",
198
+ "Inf2:12",
199
+ # MI100
200
+ "MI100:1",
201
+ # MI210
202
+ "MI210:1",
203
+ # MI250
204
+ "MI250:1",
205
+ # MI250X
206
+ "MI250X:1",
207
+ # MI300A
208
+ "MI300A:1",
209
+ # MI300X
210
+ "MI300X:1",
211
+ # MI325X
212
+ "MI325X:1",
213
+ # MI350X
214
+ "MI350X:1",
215
+ # MI355X
216
+ "MI355X:1",
217
+ # Habana Gaudi
218
+ "Gaudi1:1",
89
219
  ]
90
220
 
221
# Device classes a ``Device`` can carry.
DeviceClass = Literal[
    "GPU",
    "TPU",
    "NEURON",
    "AMD_GPU",
    "HABANA_GAUDI",
]

# Maps each accelerator-name Literal to its device-class name. ``Resources.get_device`` walks this mapping in
# insertion order and stops at the first Literal containing the device name, so the order below is significant.
_DeviceClassType: Dict[typing.Any, str] = dict(
    (
        (GPUType, "GPU"),
        (TPUType, "TPU"),
        (NeuronType, "NEURON"),
        (AMD_GPUType, "AMD_GPU"),
        (HABANA_GAUDIType, "HABANA_GAUDI"),
    )
)
230
+
91
231
 
92
232
  @rich.repr.auto
93
233
  @dataclass(frozen=True, slots=True)
@@ -100,6 +240,7 @@ class Device:
100
240
  """
101
241
 
102
242
  quantity: int
243
+ device_class: DeviceClass
103
244
  device: str | None = None
104
245
  partition: str | None = None
105
246
 
@@ -126,7 +267,7 @@ def GPU(device: GPUType, quantity: GPUQuantity, partition: A100Parts | A100_80GB
126
267
  elif partition is not None and device == "A100 80G":
127
268
  if partition not in get_args(A100_80GBParts):
128
269
  raise ValueError(f"Invalid partition for A100 80G: {partition}. Must be one of {get_args(A100_80GBParts)}")
129
- return Device(device=device, quantity=quantity, partition=partition)
270
+ return Device(device=device, quantity=quantity, partition=partition, device_class="GPU")
130
271
 
131
272
 
132
273
  def TPU(device: TPUType, partition: V5PParts | V6EParts | None = None):
@@ -147,7 +288,42 @@ def TPU(device: TPUType, partition: V5PParts | V6EParts | None = None):
147
288
  elif partition is not None and device == "V5E":
148
289
  if partition not in get_args(V5EParts):
149
290
  raise ValueError(f"Invalid partition for V5E: {partition}. Must be one of {get_args(V5EParts)}")
150
- return Device(1, device, partition)
291
+ return Device(1, "TPU", device, partition)
292
+
293
+
294
def Neuron(device: NeuronType) -> Device:
    """
    Create a Neuron device instance.

    The returned device always has ``quantity=1``; this factory takes no quantity argument.

    :param device: Device type (e.g., "Inf1", "Inf2", "Trn1", "Trn1n", "Trn2", "Trn2u").
    :raises ValueError: if ``device`` is not one of the ``NeuronType`` values.
    :return: Device instance with ``device_class="NEURON"``.
    """
    if device not in get_args(NeuronType):
        raise ValueError(f"Invalid Neuron type: {device}. Must be one of {get_args(NeuronType)}")
    return Device(device=device, quantity=1, device_class="NEURON")
304
+
305
+
306
def AMD_GPU(device: AMD_GPUType) -> Device:
    """
    Build a single AMD GPU device.

    :param device: AMD GPU model; one of the ``AMD_GPUType`` values ("MI100", "MI210", "MI250", "MI250X",
        "MI300A", "MI300X", "MI325X", "MI350X", "MI355X").
    :raises ValueError: if ``device`` is not an accepted AMD GPU model.
    :return: Device with ``quantity=1`` and ``device_class="AMD_GPU"``.
    """
    allowed_models = get_args(AMD_GPUType)
    if device in allowed_models:
        return Device(device=device, quantity=1, device_class="AMD_GPU")
    raise ValueError(f"Invalid AMD GPU type: {device}. Must be one of {allowed_models}")
316
+
317
+
318
def HABANA_GAUDI(device: HABANA_GAUDIType) -> Device:
    """
    Create a Habana Gaudi device instance.

    :param device: Device type; currently the only accepted value is "Gaudi1".
    :raises ValueError: if ``device`` is not one of the ``HABANA_GAUDIType`` values.
    :return: Device instance with ``quantity=1`` and ``device_class="HABANA_GAUDI"``.
    """
    if device not in get_args(HABANA_GAUDIType):
        raise ValueError(f"Invalid Habana Gaudi type: {device}. Must be one of {get_args(HABANA_GAUDIType)}")
    return Device(device=device, quantity=1, device_class="HABANA_GAUDI")
151
327
 
152
328
 
153
329
  CPUBaseType = int | float | str
@@ -202,7 +378,7 @@ class Resources:
202
378
  raise ValueError("gpu must be greater than or equal to 0")
203
379
  elif isinstance(self.gpu, str):
204
380
  if self.gpu not in get_args(Accelerators):
205
- raise ValueError(f"gpu must be one of {Accelerators}")
381
+ raise ValueError(f"gpu must be one of {Accelerators}, got {self.gpu}")
206
382
 
207
383
  def get_device(self) -> Optional[Device]:
208
384
  """
@@ -214,10 +390,16 @@ class Resources:
214
390
  if self.gpu is None:
215
391
  return None
216
392
  if isinstance(self.gpu, int):
217
- return Device(quantity=self.gpu)
393
+ return Device(quantity=self.gpu, device_class="GPU")
218
394
  if isinstance(self.gpu, str):
219
395
  device, portion = self.gpu.split(":")
220
- return Device(device=device, quantity=int(portion))
396
+ for cls, cls_name in _DeviceClassType.items():
397
+ if device in get_args(cls):
398
+ device_class = cls_name
399
+ break
400
+ else:
401
+ raise ValueError(f"Invalid device type: {device}. Must be one of {list(_DeviceClassType.keys())}")
402
+ return Device(device=device, device_class=device_class, quantity=int(portion)) # type: ignore
221
403
  return self.gpu
222
404
 
223
405
  def get_shared_memory(self) -> Optional[str]: