flyte 0.2.0b1__py3-none-any.whl → 2.0.0b46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266)
  1. flyte/__init__.py +83 -30
  2. flyte/_bin/connect.py +61 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +87 -19
  5. flyte/_bin/serve.py +351 -0
  6. flyte/_build.py +3 -2
  7. flyte/_cache/cache.py +6 -5
  8. flyte/_cache/local_cache.py +216 -0
  9. flyte/_code_bundle/_ignore.py +31 -5
  10. flyte/_code_bundle/_packaging.py +42 -11
  11. flyte/_code_bundle/_utils.py +57 -34
  12. flyte/_code_bundle/bundle.py +130 -27
  13. flyte/_constants.py +1 -0
  14. flyte/_context.py +21 -5
  15. flyte/_custom_context.py +73 -0
  16. flyte/_debug/constants.py +37 -0
  17. flyte/_debug/utils.py +17 -0
  18. flyte/_debug/vscode.py +315 -0
  19. flyte/_deploy.py +396 -75
  20. flyte/_deployer.py +109 -0
  21. flyte/_environment.py +94 -11
  22. flyte/_excepthook.py +37 -0
  23. flyte/_group.py +2 -1
  24. flyte/_hash.py +1 -16
  25. flyte/_image.py +544 -231
  26. flyte/_initialize.py +456 -316
  27. flyte/_interface.py +40 -5
  28. flyte/_internal/controllers/__init__.py +22 -8
  29. flyte/_internal/controllers/_local_controller.py +159 -35
  30. flyte/_internal/controllers/_trace.py +18 -10
  31. flyte/_internal/controllers/remote/__init__.py +38 -9
  32. flyte/_internal/controllers/remote/_action.py +82 -12
  33. flyte/_internal/controllers/remote/_client.py +6 -2
  34. flyte/_internal/controllers/remote/_controller.py +290 -64
  35. flyte/_internal/controllers/remote/_core.py +155 -95
  36. flyte/_internal/controllers/remote/_informer.py +40 -20
  37. flyte/_internal/controllers/remote/_service_protocol.py +2 -2
  38. flyte/_internal/imagebuild/__init__.py +2 -10
  39. flyte/_internal/imagebuild/docker_builder.py +391 -84
  40. flyte/_internal/imagebuild/image_builder.py +111 -55
  41. flyte/_internal/imagebuild/remote_builder.py +409 -0
  42. flyte/_internal/imagebuild/utils.py +79 -0
  43. flyte/_internal/resolvers/_app_env_module.py +92 -0
  44. flyte/_internal/resolvers/_task_module.py +5 -38
  45. flyte/_internal/resolvers/app_env.py +26 -0
  46. flyte/_internal/resolvers/common.py +8 -1
  47. flyte/_internal/resolvers/default.py +2 -2
  48. flyte/_internal/runtime/convert.py +319 -36
  49. flyte/_internal/runtime/entrypoints.py +106 -18
  50. flyte/_internal/runtime/io.py +71 -23
  51. flyte/_internal/runtime/resources_serde.py +21 -7
  52. flyte/_internal/runtime/reuse.py +125 -0
  53. flyte/_internal/runtime/rusty.py +196 -0
  54. flyte/_internal/runtime/task_serde.py +239 -66
  55. flyte/_internal/runtime/taskrunner.py +48 -8
  56. flyte/_internal/runtime/trigger_serde.py +162 -0
  57. flyte/_internal/runtime/types_serde.py +7 -16
  58. flyte/_keyring/file.py +115 -0
  59. flyte/_link.py +30 -0
  60. flyte/_logging.py +241 -42
  61. flyte/_map.py +312 -0
  62. flyte/_metrics.py +59 -0
  63. flyte/_module.py +74 -0
  64. flyte/_pod.py +30 -0
  65. flyte/_resources.py +296 -33
  66. flyte/_retry.py +1 -7
  67. flyte/_reusable_environment.py +72 -7
  68. flyte/_run.py +462 -132
  69. flyte/_secret.py +47 -11
  70. flyte/_serve.py +333 -0
  71. flyte/_task.py +245 -56
  72. flyte/_task_environment.py +219 -97
  73. flyte/_task_plugins.py +47 -0
  74. flyte/_tools.py +8 -8
  75. flyte/_trace.py +15 -24
  76. flyte/_trigger.py +1027 -0
  77. flyte/_utils/__init__.py +12 -1
  78. flyte/_utils/asyn.py +3 -1
  79. flyte/_utils/async_cache.py +139 -0
  80. flyte/_utils/coro_management.py +5 -4
  81. flyte/_utils/description_parser.py +19 -0
  82. flyte/_utils/docker_credentials.py +173 -0
  83. flyte/_utils/helpers.py +45 -19
  84. flyte/_utils/module_loader.py +123 -0
  85. flyte/_utils/org_discovery.py +57 -0
  86. flyte/_utils/uv_script_parser.py +8 -1
  87. flyte/_version.py +16 -3
  88. flyte/app/__init__.py +27 -0
  89. flyte/app/_app_environment.py +362 -0
  90. flyte/app/_connector_environment.py +40 -0
  91. flyte/app/_deploy.py +130 -0
  92. flyte/app/_parameter.py +343 -0
  93. flyte/app/_runtime/__init__.py +3 -0
  94. flyte/app/_runtime/app_serde.py +383 -0
  95. flyte/app/_types.py +113 -0
  96. flyte/app/extras/__init__.py +9 -0
  97. flyte/app/extras/_auth_middleware.py +217 -0
  98. flyte/app/extras/_fastapi.py +93 -0
  99. flyte/app/extras/_model_loader/__init__.py +3 -0
  100. flyte/app/extras/_model_loader/config.py +7 -0
  101. flyte/app/extras/_model_loader/loader.py +288 -0
  102. flyte/cli/__init__.py +12 -0
  103. flyte/cli/_abort.py +28 -0
  104. flyte/cli/_build.py +114 -0
  105. flyte/cli/_common.py +493 -0
  106. flyte/cli/_create.py +371 -0
  107. flyte/cli/_delete.py +45 -0
  108. flyte/cli/_deploy.py +401 -0
  109. flyte/cli/_gen.py +316 -0
  110. flyte/cli/_get.py +446 -0
  111. flyte/cli/_option.py +33 -0
  112. flyte/{_cli → cli}/_params.py +57 -17
  113. flyte/cli/_plugins.py +209 -0
  114. flyte/cli/_prefetch.py +292 -0
  115. flyte/cli/_run.py +690 -0
  116. flyte/cli/_serve.py +338 -0
  117. flyte/cli/_update.py +86 -0
  118. flyte/cli/_user.py +20 -0
  119. flyte/cli/main.py +246 -0
  120. flyte/config/__init__.py +2 -167
  121. flyte/config/_config.py +215 -163
  122. flyte/config/_internal.py +10 -1
  123. flyte/config/_reader.py +225 -0
  124. flyte/connectors/__init__.py +11 -0
  125. flyte/connectors/_connector.py +330 -0
  126. flyte/connectors/_server.py +194 -0
  127. flyte/connectors/utils.py +159 -0
  128. flyte/errors.py +134 -2
  129. flyte/extend.py +24 -0
  130. flyte/extras/_container.py +69 -56
  131. flyte/git/__init__.py +3 -0
  132. flyte/git/_config.py +279 -0
  133. flyte/io/__init__.py +8 -1
  134. flyte/io/{structured_dataset → _dataframe}/__init__.py +32 -30
  135. flyte/io/{structured_dataset → _dataframe}/basic_dfs.py +75 -68
  136. flyte/io/{structured_dataset/structured_dataset.py → _dataframe/dataframe.py} +207 -242
  137. flyte/io/_dir.py +575 -113
  138. flyte/io/_file.py +587 -141
  139. flyte/io/_hashing_io.py +342 -0
  140. flyte/io/extend.py +7 -0
  141. flyte/models.py +635 -0
  142. flyte/prefetch/__init__.py +22 -0
  143. flyte/prefetch/_hf_model.py +563 -0
  144. flyte/remote/__init__.py +14 -3
  145. flyte/remote/_action.py +879 -0
  146. flyte/remote/_app.py +346 -0
  147. flyte/remote/_auth_metadata.py +42 -0
  148. flyte/remote/_client/_protocols.py +62 -4
  149. flyte/remote/_client/auth/_auth_utils.py +19 -0
  150. flyte/remote/_client/auth/_authenticators/base.py +8 -2
  151. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  152. flyte/remote/_client/auth/_authenticators/factory.py +4 -0
  153. flyte/remote/_client/auth/_authenticators/passthrough.py +79 -0
  154. flyte/remote/_client/auth/_authenticators/pkce.py +17 -18
  155. flyte/remote/_client/auth/_channel.py +47 -18
  156. flyte/remote/_client/auth/_client_config.py +5 -3
  157. flyte/remote/_client/auth/_keyring.py +15 -2
  158. flyte/remote/_client/auth/_token_client.py +3 -3
  159. flyte/remote/_client/controlplane.py +206 -18
  160. flyte/remote/_common.py +66 -0
  161. flyte/remote/_data.py +107 -22
  162. flyte/remote/_logs.py +116 -33
  163. flyte/remote/_project.py +21 -19
  164. flyte/remote/_run.py +164 -631
  165. flyte/remote/_secret.py +72 -29
  166. flyte/remote/_task.py +387 -46
  167. flyte/remote/_trigger.py +368 -0
  168. flyte/remote/_user.py +43 -0
  169. flyte/report/_report.py +10 -6
  170. flyte/storage/__init__.py +13 -1
  171. flyte/storage/_config.py +237 -0
  172. flyte/storage/_parallel_reader.py +289 -0
  173. flyte/storage/_storage.py +268 -59
  174. flyte/syncify/__init__.py +56 -0
  175. flyte/syncify/_api.py +414 -0
  176. flyte/types/__init__.py +39 -0
  177. flyte/types/_interface.py +22 -7
  178. flyte/{io/pickle/transformer.py → types/_pickle.py} +37 -9
  179. flyte/types/_string_literals.py +8 -9
  180. flyte/types/_type_engine.py +226 -126
  181. flyte/types/_utils.py +1 -1
  182. flyte-2.0.0b46.data/scripts/debug.py +38 -0
  183. flyte-2.0.0b46.data/scripts/runtime.py +194 -0
  184. flyte-2.0.0b46.dist-info/METADATA +352 -0
  185. flyte-2.0.0b46.dist-info/RECORD +221 -0
  186. flyte-2.0.0b46.dist-info/entry_points.txt +8 -0
  187. flyte-2.0.0b46.dist-info/licenses/LICENSE +201 -0
  188. flyte/_api_commons.py +0 -3
  189. flyte/_cli/_common.py +0 -299
  190. flyte/_cli/_create.py +0 -42
  191. flyte/_cli/_delete.py +0 -23
  192. flyte/_cli/_deploy.py +0 -140
  193. flyte/_cli/_get.py +0 -235
  194. flyte/_cli/_run.py +0 -174
  195. flyte/_cli/main.py +0 -98
  196. flyte/_datastructures.py +0 -342
  197. flyte/_internal/controllers/pbhash.py +0 -39
  198. flyte/_protos/common/authorization_pb2.py +0 -66
  199. flyte/_protos/common/authorization_pb2.pyi +0 -108
  200. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  201. flyte/_protos/common/identifier_pb2.py +0 -71
  202. flyte/_protos/common/identifier_pb2.pyi +0 -82
  203. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  204. flyte/_protos/common/identity_pb2.py +0 -48
  205. flyte/_protos/common/identity_pb2.pyi +0 -72
  206. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  207. flyte/_protos/common/list_pb2.py +0 -36
  208. flyte/_protos/common/list_pb2.pyi +0 -69
  209. flyte/_protos/common/list_pb2_grpc.py +0 -4
  210. flyte/_protos/common/policy_pb2.py +0 -37
  211. flyte/_protos/common/policy_pb2.pyi +0 -27
  212. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  213. flyte/_protos/common/role_pb2.py +0 -37
  214. flyte/_protos/common/role_pb2.pyi +0 -53
  215. flyte/_protos/common/role_pb2_grpc.py +0 -4
  216. flyte/_protos/common/runtime_version_pb2.py +0 -28
  217. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  218. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  219. flyte/_protos/logs/dataplane/payload_pb2.py +0 -96
  220. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -168
  221. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  222. flyte/_protos/secret/definition_pb2.py +0 -49
  223. flyte/_protos/secret/definition_pb2.pyi +0 -93
  224. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  225. flyte/_protos/secret/payload_pb2.py +0 -62
  226. flyte/_protos/secret/payload_pb2.pyi +0 -94
  227. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  228. flyte/_protos/secret/secret_pb2.py +0 -38
  229. flyte/_protos/secret/secret_pb2.pyi +0 -6
  230. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  231. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  232. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  233. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  234. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  235. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  236. flyte/_protos/workflow/queue_service_pb2.py +0 -106
  237. flyte/_protos/workflow/queue_service_pb2.pyi +0 -141
  238. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  239. flyte/_protos/workflow/run_definition_pb2.py +0 -128
  240. flyte/_protos/workflow/run_definition_pb2.pyi +0 -310
  241. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  242. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  243. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  244. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  245. flyte/_protos/workflow/run_service_pb2.py +0 -133
  246. flyte/_protos/workflow/run_service_pb2.pyi +0 -175
  247. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -412
  248. flyte/_protos/workflow/state_service_pb2.py +0 -58
  249. flyte/_protos/workflow/state_service_pb2.pyi +0 -71
  250. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  251. flyte/_protos/workflow/task_definition_pb2.py +0 -72
  252. flyte/_protos/workflow/task_definition_pb2.pyi +0 -65
  253. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  254. flyte/_protos/workflow/task_service_pb2.py +0 -44
  255. flyte/_protos/workflow/task_service_pb2.pyi +0 -31
  256. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -104
  257. flyte/io/_dataframe.py +0 -0
  258. flyte/io/pickle/__init__.py +0 -0
  259. flyte/remote/_console.py +0 -18
  260. flyte-0.2.0b1.dist-info/METADATA +0 -179
  261. flyte-0.2.0b1.dist-info/RECORD +0 -204
  262. flyte-0.2.0b1.dist-info/entry_points.txt +0 -3
  263. /flyte/{_cli → _debug}/__init__.py +0 -0
  264. /flyte/{_protos → _keyring}/__init__.py +0 -0
  265. {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/WHEEL +0 -0
  266. {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,217 @@
1
+ """
2
+ FastAPI middleware for automatic Flyte authentication passthrough.
3
+
4
+ This module provides middleware that automatically extracts authentication headers
5
+ from incoming requests and sets them in the Flyte context, eliminating the need
6
+ for manual auth_metadata() wrapping in every endpoint.
7
+
8
+ Example:
9
+ Basic usage with default extractors (Authorization + Cookie headers)::
10
+
11
+ from fastapi import FastAPI
12
+ from flyte.app.extras import FastAPIAuthMiddleware
13
+
14
+ app = FastAPI()
15
+ app.add_middleware(FastAPIAuthMiddleware, excluded_paths={"/health"})
16
+
17
+ @app.get("/me")
18
+ async def get_current_user():
19
+ # Auth metadata automatically set from request headers
20
+ user = await remote.User.get.aio()
21
+ return {"subject": user.subject()}
22
+
23
+ Advanced usage with custom extractors and path exclusions::
24
+
25
+ from flyte.app.extras import FastAPIAuthMiddleware
26
+
27
+ app.add_middleware(
28
+ FastAPIAuthMiddleware,
29
+ header_extractors=[
30
+ FastAPIAuthMiddleware.extract_authorization_header,
31
+ FastAPIAuthMiddleware.extract_custom_header("x-api-key"),
32
+ ],
33
+ excluded_paths={"/health", "/metrics"},
34
+ )
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ import logging
40
+ from typing import TYPE_CHECKING, Callable
41
+
42
+ if TYPE_CHECKING:
43
+ from fastapi import Request
44
+ from starlette.middleware.base import BaseHTTPMiddleware
45
+ from starlette.responses import Response
46
+ else:
47
+ try:
48
+ from starlette.middleware.base import BaseHTTPMiddleware
49
+ except ImportError:
50
+
51
+ class BaseHTTPMiddleware:
52
+ pass
53
+
54
+
55
logger = logging.getLogger(__name__)

# Header extractor type: takes a Request and returns a (key, value) tuple, or None when
# the header it is responsible for is absent.
HeaderExtractor = Callable[["Request"], tuple[str, str] | None]


class FastAPIPassthroughAuthMiddleware(BaseHTTPMiddleware):
    """
    FastAPI middleware that automatically sets Flyte auth metadata from request headers.

    This middleware extracts authentication headers from incoming HTTP requests and
    sets them in the Flyte context using the ``auth_metadata()`` context manager. This
    eliminates the need to manually wrap endpoint handlers with ``auth_metadata()``.

    The middleware is configurable:
    - Custom header extractors can be provided.
    - Specific paths can be excluded from auth extraction.
    - Auth can be required (the default: requests without any extracted credentials
      receive a 401) or optional (``require_auth=False``: such requests are passed
      through without auth metadata).

    Attributes:
        header_extractors: Functions used to extract (key, value) headers from requests.
        excluded_paths: URL paths that bypass auth extraction entirely.
        require_auth: Whether requests without any extracted credentials are rejected.

    Concurrency:
        Auth metadata is only active inside the ``auth_metadata()`` block that wraps
        ``call_next``. Per-request isolation relies on ``flyte.remote.auth_metadata``
        scoping its state per task (presumably via contextvars — confirm there).
        NOTE(review): the block exits as soon as ``call_next`` returns, so code that runs
        while a streaming response body is being sent may not see the metadata.
    """

    def __init__(
        self,
        app,
        header_extractors: list[HeaderExtractor] | None = None,
        excluded_paths: set[str] | None = None,
        require_auth: bool = True,
    ):
        """
        Initialize the Flyte authentication middleware.

        Args:
            app: The FastAPI/Starlette application.
            header_extractors: Functions that extract headers. Each one takes a Request and
                returns a (key, value) tuple or None.
                Defaults to [extract_authorization_header, extract_cookie_header].
            excluded_paths: URL paths to exclude from auth extraction. Requests to these
                paths proceed without setting auth context.
            require_auth: When True (default), requests for which no extractor produced a
                header are rejected with 401. When False, they are forwarded without auth
                metadata.
        """
        super().__init__(app)

        if header_extractors is None:
            self.header_extractors: list[HeaderExtractor] = [
                self.extract_authorization_header,
                self.extract_cookie_header,
            ]
        else:
            self.header_extractors = header_extractors

        # Normalise to a set so any iterable works and membership checks are O(1).
        self.excluded_paths: set[str] = set(excluded_paths or ())
        self.require_auth = require_auth

    def _extract_auth_tuples(self, request: "Request") -> list[tuple[str, str]]:
        """Run every configured extractor and collect the non-None results, logging (not raising) failures."""
        auth_tuples: list[tuple[str, str]] = []
        for extractor in self.header_extractors:
            try:
                result = extractor(request)
            except Exception as e:
                # Extractors may be partials or callable objects without __name__; reading it
                # unguarded would raise a new AttributeError from inside this handler.
                name = getattr(extractor, "__name__", repr(extractor))
                logger.warning(f"Header extractor {name} failed: {e}")
                continue
            if result is not None:
                auth_tuples.append(result)
        return auth_tuples

    async def dispatch(self, request: "Request", call_next) -> "Response":
        """
        Process each request, extracting auth headers and setting Flyte context.

        Args:
            request: The incoming HTTP request.
            call_next: The next middleware or route handler to call.

        Returns:
            The HTTP response from the handler, or a 401 JSON response when auth is
            required and no credentials were extracted.
        """
        from starlette.responses import JSONResponse

        # Skip auth extraction for excluded paths.
        if request.url.path in self.excluded_paths:
            return await call_next(request)

        auth_tuples = self._extract_auth_tuples(request)

        if not auth_tuples:
            if not self.require_auth:
                # Optional auth: forward the request without any auth context.
                return await call_next(request)
            logger.info("No auth tuples found")
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication credentials required"},
                headers={"WWW-Authenticate": "Bearer"},
            )

        # Set auth metadata in Flyte context for the duration of this request.
        from flyte.remote import auth_metadata

        with auth_metadata(*auth_tuples):
            return await call_next(request)

    @staticmethod
    def extract_authorization_header(request: "Request") -> tuple[str, str] | None:
        """
        Extract the Authorization header from the request.

        Args:
            request: The FastAPI/Starlette request object.

        Returns:
            Tuple of ("authorization", header_value) if present and non-empty, None otherwise.
        """
        auth_header = request.headers.get("authorization")
        if auth_header:
            return "authorization", auth_header
        return None

    @staticmethod
    def extract_cookie_header(request: "Request") -> tuple[str, str] | None:
        """
        Extract the Cookie header from the request.

        Args:
            request: The FastAPI/Starlette request object.

        Returns:
            Tuple of ("cookie", header_value) if present and non-empty, None otherwise.
        """
        cookie_header = request.headers.get("cookie")
        if cookie_header:
            return "cookie", cookie_header
        return None

    @staticmethod
    def extract_custom_header(header_name: str) -> HeaderExtractor:
        """
        Create a header extractor for a custom header name.

        Args:
            header_name: The name of the header to extract. It is lowercased both for the
                lookup and for the returned key.

        Returns:
            A header extractor function that returns (lowercased header_name, value) when
            the header is present and non-empty, None otherwise.

        Example::

            # Create extractor for X-API-Key header
            api_key_extractor = FastAPIPassthroughAuthMiddleware.extract_custom_header("x-api-key")

            app.add_middleware(
                FastAPIPassthroughAuthMiddleware,
                header_extractors=[api_key_extractor],
            )
        """
        key = header_name.lower()

        def extractor(request: "Request") -> tuple[str, str] | None:
            header_value = request.headers.get(key)
            if header_value:
                return key, header_value
            return None

        extractor.__name__ = f"extract_{header_name.replace('-', '_')}_header"
        return extractor
@@ -0,0 +1,93 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from dataclasses import dataclass
5
+ from typing import TYPE_CHECKING
6
+
7
+ import rich.repr
8
+
9
+ import flyte.app
10
+ from flyte.models import SerializationContext
11
+
12
+ if TYPE_CHECKING:
13
+ import fastapi
14
+ import uvicorn
15
+
16
+
17
@rich.repr.auto
@dataclass(kw_only=True, repr=True)
class FastAPIAppEnvironment(flyte.app.AppEnvironment):
    """
    App environment that serves a ``fastapi.FastAPI`` application with uvicorn.

    Attributes:
        app: The FastAPI application to serve (must be a ``fastapi.FastAPI`` instance).
        type: App type label, ``"FastAPI"``.
        uvicorn_config: Optional uvicorn configuration. When omitted, one is built from
            ``app`` and this environment's port when the server starts.
    """

    app: fastapi.FastAPI
    type: str = "FastAPI"
    uvicorn_config: uvicorn.Config | None = None
    # Source location (as returned by inspect.getframeinfo) of the code that created this
    # environment; used to find the module where the app variable is defined.
    _caller_frame: inspect.Traceback | None = None

    def __post_init__(self):
        try:
            import fastapi
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "fastapi is not installed. Please install 'fastapi' to use FastAPIAppEnvironment."
            ) from e

        # starlette is a dependency of fastapi, so if fastapi is installed, starlette is also installed.
        try:
            from starlette.datastructures import State
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "starlette is not installed. Please install 'starlette' to use FastAPIAppEnvironment."
            ) from e

        # Validate `app` before touching it. Otherwise patching `self.app.state` below would turn a
        # `None` app into an opaque AttributeError instead of the intended ValueError.
        if self.app is None:
            raise ValueError("app cannot be None for FastAPIAppEnvironment")
        if not isinstance(self.app, fastapi.FastAPI):
            raise TypeError(f"app must be of type fastapi.FastAPI, got {type(self.app)}")

        class PicklableState(State):
            """starlette State whose contents are dropped, rather than pickled, when the app is pickled."""

            def __getstate__(self):
                state = self.__dict__.copy()
                # Replace the (possibly unpicklable) state contents with an empty dict.
                state["_state"] = {}
                return state

            def __setstate__(self, state):
                # Restores `_state` as the fresh empty dict produced by __getstate__. Do not use
                # attribute assignment here: State.__setattr__ stores values *inside* `_state`, so
                # e.g. `self.state = State()` would leave a spurious nested `state` entry behind.
                self.__dict__.update(state)

        # NOTE: since FastAPI cannot be pickled (because starlette.datastructures.State cannot be pickled due to
        # circular references), we need to patch the state object to make it picklable. Entries the user has
        # already put on `app.state` are kept for this process; they are only dropped when pickling.
        existing = getattr(self.app.state, "_state", None)
        self.app.state = PicklableState(dict(existing) if isinstance(existing, dict) else None)

        super().__post_init__()

        self.links = [flyte.app.Link(path="/docs", title="FastAPI OpenAPI Docs", is_relative=True), *self.links]
        self._server = self._fastapi_app_server

        # Capture the frame where this environment was instantiated.
        # This helps us find the module where the app variable is defined.
        frame = inspect.currentframe()
        try:
            # frame -> this __post_init__, .f_back -> the dataclass-generated __init__,
            # .f_back.f_back -> the code that constructed the environment.
            # NOTE(review): a subclass whose __post_init__ calls super().__post_init__() adds one
            # frame, so the generated __init__ would be captured instead of the user's code.
            caller_frame = frame.f_back if frame else None
            if caller_frame and caller_frame.f_back:
                self._caller_frame = inspect.getframeinfo(caller_frame.f_back)
        finally:
            # Drop the frame reference promptly to avoid a frame <-> locals reference cycle.
            del frame

    async def _fastapi_app_server(self):
        """Serve ``self.app`` with uvicorn, building a default config from this environment's port if needed."""
        try:
            import uvicorn
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "uvicorn is not installed. Please install 'uvicorn' to use FastAPIAppEnvironment."
            ) from e

        if self.uvicorn_config is None:
            self.uvicorn_config = uvicorn.Config(self.app, port=self.port.port)
        elif self.uvicorn_config.port is None:
            # NOTE(review): uvicorn.Config defaults `port` to 8000, so this only applies when the
            # caller explicitly passed port=None; otherwise the config's own port is used.
            self.uvicorn_config.port = self.port.port

        await uvicorn.Server(self.uvicorn_config).serve()

    def container_command(self, serialization_context: SerializationContext) -> list[str]:
        """Return an empty command list; ``serialization_context`` is not used."""
        return []
@@ -0,0 +1,3 @@
1
+ from .loader import SafeTensorsStreamer, prefetch
2
+
3
+ __all__ = ["SafeTensorsStreamer", "prefetch"]
@@ -0,0 +1,7 @@
1
+ import os
2
+
3
# Environment-driven settings for the model loader. Every value below is read once, at import time.

# Fallbacks used when the corresponding environment variable is not set.
_DEFAULT_LOCAL_MODEL_PATH = "/srv/model"
_DEFAULT_CHUNK_SIZE = 16 * 1024 * 1024  # 16 MiB
_DEFAULT_MAX_CONCURRENCY = 32

# Remote location of the model artifacts; None when the variable is unset.
REMOTE_MODEL_PATH = os.getenv("FLYTE_MODEL_LOADER_REMOTE_MODEL_PATH")
# Local directory the model is materialised into.
LOCAL_MODEL_PATH = os.getenv("FLYTE_MODEL_LOADER_LOCAL_MODEL_PATH", _DEFAULT_LOCAL_MODEL_PATH)
# Integers are parsed with int(); a malformed value raises ValueError at import.
CHUNK_SIZE = int(os.getenv("FLYTE_MODEL_LOADER_CHUNK_SIZE", str(_DEFAULT_CHUNK_SIZE)))
MAX_CONCURRENCY = int(os.getenv("FLYTE_MODEL_LOADER_MAX_CONCURRENCY", str(_DEFAULT_MAX_CONCURRENCY)))
# Enabled only by the (case-insensitive) literal "true".
STREAM_SAFETENSORS = os.getenv("FLYTE_MODEL_LOADER_STREAM_SAFETENSORS", "false").lower() == "true"
@@ -0,0 +1,288 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import pathlib
5
+ import struct
6
+ import time
7
+ import typing
8
+ from collections import defaultdict
9
+
10
+ import obstore
11
+ import pydantic
12
+ from obstore.store import ObjectStore
13
+ from typing_extensions import Annotated
14
+
15
+ from flyte.storage._parallel_reader import (
16
+ BufferProtocol,
17
+ Chunk,
18
+ DownloadQueueEmpty,
19
+ DownloadTask,
20
+ ObstoreParallelReader,
21
+ Source,
22
+ )
23
+ from flyte.storage._storage import get_underlying_filesystem
24
+
25
+ try:
26
+ import torch
27
+ except ModuleNotFoundError:
28
+ raise ModuleNotFoundError("torch is not installed. Please install 'torch', to use the model loader.")
29
+
30
+
31
+ from flyte.app.extras._model_loader.config import (
32
+ CHUNK_SIZE,
33
+ MAX_CONCURRENCY,
34
+ )
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+ LITTLE_ENDIAN_LONG_LONG_STRUCT_FORMAT = "<Q"
39
+
40
+ SAFETENSORS_FORMAT_KEY = "format"
41
+ SAFETENSORS_FORMAT_VALUE = "pt"
42
+ SAFETENSORS_SUFFIX = ".safetensors"
43
+ SAFETENSORS_DEFAULT_PATTERN = f"*{SAFETENSORS_SUFFIX}"
44
+ SAFETENSORS_SHARDED_PATTERN = f"model-rank-{{rank}}-part-*{SAFETENSORS_SUFFIX}"
45
+ SAFETENSORS_INTERNAL_METADATA_KEY = "__metadata__"
46
+ SAFETENSORS_INDEX_PATH = "model.safetensors.index.json"
47
+ SAFETENSORS_HEADER_BUFFER_SIZE = 8
48
+ SAFETENSORS_TO_TORCH_DTYPE = {
49
+ "F64": torch.float64,
50
+ "F32": torch.float32,
51
+ "F16": torch.float16,
52
+ "BF16": torch.bfloat16,
53
+ "I64": torch.int64,
54
+ "I32": torch.int32,
55
+ "I16": torch.int16,
56
+ "I8": torch.int8,
57
+ "U8": torch.uint8,
58
+ "BOOL": torch.bool,
59
+ "F8_E5M2": torch.float8_e5m2,
60
+ "F8_E4M3": torch.float8_e4m3fn,
61
+ }
62
+
63
+
64
async def prefetch(remote_model_path, local_model_path, exclude_safetensors=True):
    """
    Download model artifacts from ``remote_model_path`` into ``local_model_path``.

    When ``exclude_safetensors`` is true, files matching ``*.safetensors`` are skipped so they
    can be streamed later instead. If the download reports ``DownloadQueueEmpty``, a warning is
    logged instead of raising.
    """
    from flyte.storage._storage import _get_obstore_bypass

    logger.info(f"Pre-fetching model artifacts from {remote_model_path} to {local_model_path}...")

    # Exclude safetensors if model streaming is enabled, which will be handled by the flyte vllm model loader
    exclude_patterns = None
    if exclude_safetensors:
        logger.info(f"Deferring download of safetensor files from {remote_model_path}")
        exclude_patterns = [SAFETENSORS_DEFAULT_PATTERN]

    started_at = time.perf_counter()
    try:
        await _get_obstore_bypass(
            remote_model_path,
            local_model_path,
            recursive=True,
            exclude=exclude_patterns,
        )
    except DownloadQueueEmpty:
        logger.warning("No model artifacts found to pre-fetch.")
        return

    elapsed = time.perf_counter() - started_at
    logger.info(f"Pre-fetched model artifacts in {elapsed:.2f}s")
84
+
85
+
86
def _dtype_to_torch_dtype(dtype: str | torch.dtype) -> torch.dtype:
    """
    Convert a safetensors dtype code (e.g. ``"BF16"``) to the matching ``torch.dtype``.

    A value that is already a ``torch.dtype`` is returned unchanged, so models can also be
    built directly from torch dtypes.

    Raises:
        ValueError: For unknown or unhashable codes. Pydantic turns this into a
            ValidationError when the function is used as a BeforeValidator.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    try:
        return SAFETENSORS_TO_TORCH_DTYPE[dtype]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs (e.g. a list), which a dict lookup rejects
        # without raising KeyError.
        raise ValueError(f"Unsupported dtype: {dtype}") from None
91
+
92
+
93
class TensorMetadata(pydantic.BaseModel):
    """
    One tensor entry from a safetensors JSON header.

    ``dtype`` arrives as a safetensors dtype code (e.g. ``"BF16"``) and is converted to a
    ``torch.dtype`` by ``_dtype_to_torch_dtype`` before validation.
    """

    # torch.dtype is not a pydantic-native type.
    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

    name: str
    shape: list[int]
    dtype: Annotated[torch.dtype, pydantic.BeforeValidator(_dtype_to_torch_dtype)]
    # (begin, end) byte offsets of the tensor's data within the file's data section.
    data_offsets: tuple[int, int]

    @pydantic.computed_field  # type: ignore[prop-decorator]
    @property
    def size(self) -> int:
        """Number of bytes spanned by ``data_offsets``."""
        return self.data_offsets[1] - self.data_offsets[0]

    @pydantic.computed_field  # type: ignore[prop-decorator]
    @property
    def length(self) -> int:
        """Number of elements: the product of ``shape`` (1 for an empty shape)."""
        elements = 1
        pending = list(self.shape)
        while pending:
            elements *= pending.pop()
        return elements

    def __len__(self):
        # Element count; used as ``count`` when the tensor is rebuilt with torch.frombuffer.
        return self.length
117
+
118
+
119
class SafeTensorsMetadata(pydantic.BaseModel):
    """Parsed header of a single safetensors file."""

    # Object path of the file within the store.
    path: str
    # Absolute byte offset at which the data section begins (header-length field + JSON header).
    # Each tensor's data_offsets are added to this value to locate its bytes.
    data_start: int
    # Tensors described by the header (the internal "__metadata__" entry is not included).
    tensors: list[TensorMetadata]
123
+
124
+
125
class SafeTensorsStreamer:
    """
    Stream tensors from remote ``.safetensors`` files straight into memory.

    Each file's JSON header is read with ranged GETs to locate its tensors. Tensor bytes are
    then fetched in chunks through an ``ObstoreParallelReader`` and converted to
    ``torch.Tensor`` objects.

    File discovery:
    - ``tensor_parallel_size > 1``: files named ``model-rank-{rank}-part-*.safetensors``.
    - otherwise: the files listed in ``model.safetensors.index.json``, or every
      ``*.safetensors`` file under the remote prefix when the index cannot be used.
    """

    def __init__(
        self,
        remote_path,
        local_path,
        chunk_size=CHUNK_SIZE,
        max_concurrency=MAX_CONCURRENCY,
        rank=0,
        tensor_parallel_size=1,
        store_kwargs=None,
    ):
        """
        Args:
            remote_path: Remote path containing the safetensors files. It is split into a
                bucket (used to build the object store) and a prefix.
            local_path: Local model path; stored on the instance but not read by this class.
            chunk_size: Passed through to ObstoreParallelReader.
            max_concurrency: Passed through to ObstoreParallelReader.
            rank: Tensor-parallel rank whose shard files are loaded when ``tensor_parallel_size > 1``.
            tensor_parallel_size: Number of tensor-parallel ranks.
            store_kwargs: NOTE(review): accepted but currently unused (not forwarded to the store).
        """
        fs = get_underlying_filesystem(path=remote_path)
        bucket, prefix = fs._split_path(remote_path)  # pylint: disable=W0212

        self._store: ObjectStore = fs._construct_store(bucket)
        self._bucket = bucket
        self._prefix: pathlib.Path = pathlib.Path(prefix)
        self._local_path = pathlib.Path(local_path)
        self._reader = ObstoreParallelReader(self._store, chunk_size=chunk_size, max_concurrency=max_concurrency)
        self._rank = rank
        self._tensor_parallel_size = tensor_parallel_size

    async def _parse_safetensors_metadata(self, path):
        """
        Read and parse the header of the safetensors file at ``path``.

        The file starts with an 8-byte little-endian unsigned header length, followed by that
        many bytes of JSON. The JSON maps tensor names to their ``dtype``/``shape``/``data_offsets``
        and may include an ``__metadata__`` entry.

        Raises:
            ValueError: If ``__metadata__`` declares a format other than ``"pt"``.
        """
        key = str(path)
        length_bytes = await obstore.get_range_async(self._store, key, start=0, end=SAFETENSORS_HEADER_BUFFER_SIZE)
        (header_size,) = struct.unpack(LITTLE_ENDIAN_LONG_LONG_STRUCT_FORMAT, length_bytes)
        data_start = SAFETENSORS_HEADER_BUFFER_SIZE + header_size

        header_bytes = await obstore.get_range_async(
            self._store,
            key,
            start=SAFETENSORS_HEADER_BUFFER_SIZE,
            end=data_start,
        )
        header_data = json.loads(header_bytes.to_bytes())

        # `or {}` also tolerates an explicit null __metadata__; `fmt` avoids shadowing builtin `format`.
        internal_metadata = header_data.pop(SAFETENSORS_INTERNAL_METADATA_KEY, None) or {}
        fmt = internal_metadata.get(SAFETENSORS_FORMAT_KEY)
        if fmt and fmt != SAFETENSORS_FORMAT_VALUE:
            raise ValueError(f"Unsupported format: {fmt}")

        return SafeTensorsMetadata(
            path=key,
            data_start=data_start,
            tensors=[TensorMetadata.model_validate({"name": k, **v}) for k, v in header_data.items()],
        )

    async def _list_safetensors_files_with_index(self):
        """
        Read ``model.safetensors.index.json`` and group its ``weight_map`` by file.

        Returns:
            An items view of ``{file_path_relative_to_prefix: set_of_tensor_names}``.

        The caller treats FileNotFoundError, json.JSONDecodeError and KeyError raised here
        (missing index, invalid JSON, no ``weight_map``) as "no usable index".
        """
        weight_map_resp = await obstore.get_async(self._store, str(self._prefix / SAFETENSORS_INDEX_PATH))
        weight_map_bytes = bytes(await weight_map_resp.bytes_async())
        tensor_to_path_map = json.loads(weight_map_bytes)["weight_map"]

        # Create index for path -> tensors
        index: defaultdict[str, set[str]] = defaultdict(set)
        for tensor_name, rel_path in tensor_to_path_map.items():
            index[rel_path].add(tensor_name)

        return index.items()

    async def _load_safetensors_metadata_from_index(self):
        """
        Yield metadata for each indexed file, keeping only the tensors the index assigns to it.

        Raises:
            ValueError: If a file lacks any tensor the index says it contains.
        """
        for rel_path, expected in await self._list_safetensors_files_with_index():
            stm = await self._parse_safetensors_metadata(self._prefix / rel_path)
            # Keep only the tensors we expect (should already be deduplicated)
            keep = {tm.name: tm for tm in stm.tensors if tm.name in expected}
            # We have missing tensors at the path. Bail out!
            if missing := expected - keep.keys():
                raise ValueError(f"Missing {len(missing)} tensors at {rel_path!r}: {' '.join(sorted(missing))}")
            stm.tensors = list(keep.values())
            yield stm

    async def _list_safetensors_files_with_pattern(self, pattern):
        """
        Return the paths listed under the prefix whose names match ``pattern``.

        Raises:
            ValueError: If nothing matches.
        """
        list_result = await obstore.list_with_delimiter_async(self._store, prefix=str(self._prefix))
        paths = set()
        for obj in list_result["objects"]:
            path = pathlib.Path(obj["path"])
            if path.match(pattern):
                paths.add(path)
        if not paths:
            raise ValueError(f"No files found matching pattern: {pattern}")
        return paths

    async def _load_safetensors_metadata_with_pattern(self, pattern):
        """
        Yield metadata for every file matching ``pattern``, de-duplicating tensors by name.

        When a tensor name appears in several files, only its first occurrence is kept.
        """
        # Sort so that "first occurrence wins" is deterministic; iterating the raw set would
        # make the source file of a duplicated tensor vary between runs.
        paths = sorted(await self._list_safetensors_files_with_pattern(pattern))
        stms = await asyncio.gather(*(self._parse_safetensors_metadata(path) for path in paths))

        seen: set[str] = set()
        for stm in stms:
            unique: list[TensorMetadata] = []
            for tm in stm.tensors:
                if tm.name not in seen:
                    seen.add(tm.name)
                    unique.append(tm)
            stm.tensors = unique
            yield stm

    async def _load_safetensors_metadata(self):
        """Yield SafeTensorsMetadata for every file this rank should load."""
        # When using tensor parallelism, we can't rely on the index. Fallback to using a pattern.
        if self._tensor_parallel_size > 1:
            async for stm in self._load_safetensors_metadata_with_pattern(
                SAFETENSORS_SHARDED_PATTERN.format(rank=self._rank)
            ):
                yield stm
            return

        # No tensor parallelism. Try to use the index first, then fallback to a pattern.
        yielded_any = False
        try:
            async for stm in self._load_safetensors_metadata_from_index():
                yielded_any = True
                yield stm
        except (
            json.decoder.JSONDecodeError,
            FileNotFoundError,
            KeyError,
        ):
            # Falling back after some files were already yielded would stream their tensors a
            # second time, so at that point the failure is re-raised instead.
            if yielded_any:
                raise
        else:
            return

        async for stm in self._load_safetensors_metadata_with_pattern(SAFETENSORS_DEFAULT_PATTERN):
            yield stm

    async def _get_tensors_async(self) -> typing.AsyncGenerator[tuple[str, torch.Tensor], None]:
        """Download every tensor in chunks and yield the results produced by the reader."""

        async def _to_tensor(buf: BufferProtocol, source: Source) -> torch.Tensor:
            # Rebuild the tensor from its raw bytes, then apply the header's shape.
            assert isinstance(source.metadata, TensorMetadata)
            return torch.frombuffer(
                await buf.read(),
                dtype=source.metadata.dtype,
                count=len(source.metadata),
                offset=0,
            ).view(source.metadata.shape)

        async def _gen() -> typing.AsyncGenerator[DownloadTask, None]:
            # One Source per tensor, split into one DownloadTask per chunk of its byte range.
            async for stm in self._load_safetensors_metadata():
                for tm in stm.tensors:
                    source = Source(
                        id=tm.name,
                        path=stm.path,
                        length=tm.size,
                        offset=stm.data_start + tm.data_offsets[0],
                        metadata=tm,
                    )
                    for offset, length in self._reader._chunks(tm.size):
                        yield DownloadTask(
                            source=source,
                            chunk=Chunk(offset, length),
                        )

        # Yield tensors as they are downloaded
        async for result in self._reader._as_completed(_gen(), transformer=_to_tensor):
            yield result

    def get_tensors(self) -> typing.Generator[tuple[str, torch.Tensor], None, None]:
        """
        Synchronously yield the streamed tensors (presumably in download-completion order).

        This drives ``_get_tensors_async`` on a private ``asyncio.Runner``, so it must not be
        called from a thread that already has a running event loop.
        """
        logger.info("Streaming tensors...")
        start = time.perf_counter()
        counter = 0
        gen = self._get_tensors_async()
        with asyncio.Runner() as runner:
            try:
                while True:
                    yield runner.run(gen.__anext__())
                    counter += 1
            except StopAsyncIteration:
                pass
        logger.info(f"Streamed {counter} tensors in {time.perf_counter() - start:.2f}s")
flyte/cli/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ import os
2
+
3
+ from flyte.cli.main import main
4
+
5
+ __all__ = ["main"]
6
+
7
+
8
+ # Set GRPC_VERBOSITY to NONE if not already set to silence unwanted output
9
+ # This addresses the issue with grpcio >=1.68.0 causing unwanted output
10
+ # https://github.com/flyteorg/flyte/issues/6082
11
+ if "GRPC_VERBOSITY" not in os.environ:
12
+ os.environ["GRPC_VERBOSITY"] = "NONE"
flyte/cli/_abort.py ADDED
@@ -0,0 +1,28 @@
1
+ import rich_click as click
2
+
3
+ from flyte.cli import _common as common
4
+
5
+
6
@click.group(name="abort")
def abort():
    """
    Abort an ongoing process.
    """
    # Command group only; subcommands are attached with @abort.command().


@abort.command(cls=common.CommandBase)
@click.argument("run-name", type=str, required=True)
@click.pass_obj
def run(cfg: common.CLIConfig, run_name: str, project: str | None = None, domain: str | None = None):
    """
    Abort a run.
    """
    from flyte.remote import Run

    cfg.init(project=project, domain=domain)
    target = Run.get(name=run_name)
    # Nothing to do (and nothing printed) when no run is returned.
    if not target:
        return

    console = common.get_console()
    with console.status(f"Aborting run '{run_name}'...", spinner="dots"):
        target.abort()
    console.print(f"Run '{run_name}' has been aborted.")