flyte 0.2.0b1__py3-none-any.whl → 2.0.0b46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyte/__init__.py +83 -30
- flyte/_bin/connect.py +61 -0
- flyte/_bin/debug.py +38 -0
- flyte/_bin/runtime.py +87 -19
- flyte/_bin/serve.py +351 -0
- flyte/_build.py +3 -2
- flyte/_cache/cache.py +6 -5
- flyte/_cache/local_cache.py +216 -0
- flyte/_code_bundle/_ignore.py +31 -5
- flyte/_code_bundle/_packaging.py +42 -11
- flyte/_code_bundle/_utils.py +57 -34
- flyte/_code_bundle/bundle.py +130 -27
- flyte/_constants.py +1 -0
- flyte/_context.py +21 -5
- flyte/_custom_context.py +73 -0
- flyte/_debug/constants.py +37 -0
- flyte/_debug/utils.py +17 -0
- flyte/_debug/vscode.py +315 -0
- flyte/_deploy.py +396 -75
- flyte/_deployer.py +109 -0
- flyte/_environment.py +94 -11
- flyte/_excepthook.py +37 -0
- flyte/_group.py +2 -1
- flyte/_hash.py +1 -16
- flyte/_image.py +544 -231
- flyte/_initialize.py +456 -316
- flyte/_interface.py +40 -5
- flyte/_internal/controllers/__init__.py +22 -8
- flyte/_internal/controllers/_local_controller.py +159 -35
- flyte/_internal/controllers/_trace.py +18 -10
- flyte/_internal/controllers/remote/__init__.py +38 -9
- flyte/_internal/controllers/remote/_action.py +82 -12
- flyte/_internal/controllers/remote/_client.py +6 -2
- flyte/_internal/controllers/remote/_controller.py +290 -64
- flyte/_internal/controllers/remote/_core.py +155 -95
- flyte/_internal/controllers/remote/_informer.py +40 -20
- flyte/_internal/controllers/remote/_service_protocol.py +2 -2
- flyte/_internal/imagebuild/__init__.py +2 -10
- flyte/_internal/imagebuild/docker_builder.py +391 -84
- flyte/_internal/imagebuild/image_builder.py +111 -55
- flyte/_internal/imagebuild/remote_builder.py +409 -0
- flyte/_internal/imagebuild/utils.py +79 -0
- flyte/_internal/resolvers/_app_env_module.py +92 -0
- flyte/_internal/resolvers/_task_module.py +5 -38
- flyte/_internal/resolvers/app_env.py +26 -0
- flyte/_internal/resolvers/common.py +8 -1
- flyte/_internal/resolvers/default.py +2 -2
- flyte/_internal/runtime/convert.py +319 -36
- flyte/_internal/runtime/entrypoints.py +106 -18
- flyte/_internal/runtime/io.py +71 -23
- flyte/_internal/runtime/resources_serde.py +21 -7
- flyte/_internal/runtime/reuse.py +125 -0
- flyte/_internal/runtime/rusty.py +196 -0
- flyte/_internal/runtime/task_serde.py +239 -66
- flyte/_internal/runtime/taskrunner.py +48 -8
- flyte/_internal/runtime/trigger_serde.py +162 -0
- flyte/_internal/runtime/types_serde.py +7 -16
- flyte/_keyring/file.py +115 -0
- flyte/_link.py +30 -0
- flyte/_logging.py +241 -42
- flyte/_map.py +312 -0
- flyte/_metrics.py +59 -0
- flyte/_module.py +74 -0
- flyte/_pod.py +30 -0
- flyte/_resources.py +296 -33
- flyte/_retry.py +1 -7
- flyte/_reusable_environment.py +72 -7
- flyte/_run.py +462 -132
- flyte/_secret.py +47 -11
- flyte/_serve.py +333 -0
- flyte/_task.py +245 -56
- flyte/_task_environment.py +219 -97
- flyte/_task_plugins.py +47 -0
- flyte/_tools.py +8 -8
- flyte/_trace.py +15 -24
- flyte/_trigger.py +1027 -0
- flyte/_utils/__init__.py +12 -1
- flyte/_utils/asyn.py +3 -1
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +5 -4
- flyte/_utils/description_parser.py +19 -0
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/helpers.py +45 -19
- flyte/_utils/module_loader.py +123 -0
- flyte/_utils/org_discovery.py +57 -0
- flyte/_utils/uv_script_parser.py +8 -1
- flyte/_version.py +16 -3
- flyte/app/__init__.py +27 -0
- flyte/app/_app_environment.py +362 -0
- flyte/app/_connector_environment.py +40 -0
- flyte/app/_deploy.py +130 -0
- flyte/app/_parameter.py +343 -0
- flyte/app/_runtime/__init__.py +3 -0
- flyte/app/_runtime/app_serde.py +383 -0
- flyte/app/_types.py +113 -0
- flyte/app/extras/__init__.py +9 -0
- flyte/app/extras/_auth_middleware.py +217 -0
- flyte/app/extras/_fastapi.py +93 -0
- flyte/app/extras/_model_loader/__init__.py +3 -0
- flyte/app/extras/_model_loader/config.py +7 -0
- flyte/app/extras/_model_loader/loader.py +288 -0
- flyte/cli/__init__.py +12 -0
- flyte/cli/_abort.py +28 -0
- flyte/cli/_build.py +114 -0
- flyte/cli/_common.py +493 -0
- flyte/cli/_create.py +371 -0
- flyte/cli/_delete.py +45 -0
- flyte/cli/_deploy.py +401 -0
- flyte/cli/_gen.py +316 -0
- flyte/cli/_get.py +446 -0
- flyte/cli/_option.py +33 -0
- flyte/{_cli → cli}/_params.py +57 -17
- flyte/cli/_plugins.py +209 -0
- flyte/cli/_prefetch.py +292 -0
- flyte/cli/_run.py +690 -0
- flyte/cli/_serve.py +338 -0
- flyte/cli/_update.py +86 -0
- flyte/cli/_user.py +20 -0
- flyte/cli/main.py +246 -0
- flyte/config/__init__.py +2 -167
- flyte/config/_config.py +215 -163
- flyte/config/_internal.py +10 -1
- flyte/config/_reader.py +225 -0
- flyte/connectors/__init__.py +11 -0
- flyte/connectors/_connector.py +330 -0
- flyte/connectors/_server.py +194 -0
- flyte/connectors/utils.py +159 -0
- flyte/errors.py +134 -2
- flyte/extend.py +24 -0
- flyte/extras/_container.py +69 -56
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +279 -0
- flyte/io/__init__.py +8 -1
- flyte/io/{structured_dataset → _dataframe}/__init__.py +32 -30
- flyte/io/{structured_dataset → _dataframe}/basic_dfs.py +75 -68
- flyte/io/{structured_dataset/structured_dataset.py → _dataframe/dataframe.py} +207 -242
- flyte/io/_dir.py +575 -113
- flyte/io/_file.py +587 -141
- flyte/io/_hashing_io.py +342 -0
- flyte/io/extend.py +7 -0
- flyte/models.py +635 -0
- flyte/prefetch/__init__.py +22 -0
- flyte/prefetch/_hf_model.py +563 -0
- flyte/remote/__init__.py +14 -3
- flyte/remote/_action.py +879 -0
- flyte/remote/_app.py +346 -0
- flyte/remote/_auth_metadata.py +42 -0
- flyte/remote/_client/_protocols.py +62 -4
- flyte/remote/_client/auth/_auth_utils.py +19 -0
- flyte/remote/_client/auth/_authenticators/base.py +8 -2
- flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
- flyte/remote/_client/auth/_authenticators/factory.py +4 -0
- flyte/remote/_client/auth/_authenticators/passthrough.py +79 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +17 -18
- flyte/remote/_client/auth/_channel.py +47 -18
- flyte/remote/_client/auth/_client_config.py +5 -3
- flyte/remote/_client/auth/_keyring.py +15 -2
- flyte/remote/_client/auth/_token_client.py +3 -3
- flyte/remote/_client/controlplane.py +206 -18
- flyte/remote/_common.py +66 -0
- flyte/remote/_data.py +107 -22
- flyte/remote/_logs.py +116 -33
- flyte/remote/_project.py +21 -19
- flyte/remote/_run.py +164 -631
- flyte/remote/_secret.py +72 -29
- flyte/remote/_task.py +387 -46
- flyte/remote/_trigger.py +368 -0
- flyte/remote/_user.py +43 -0
- flyte/report/_report.py +10 -6
- flyte/storage/__init__.py +13 -1
- flyte/storage/_config.py +237 -0
- flyte/storage/_parallel_reader.py +289 -0
- flyte/storage/_storage.py +268 -59
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +414 -0
- flyte/types/__init__.py +39 -0
- flyte/types/_interface.py +22 -7
- flyte/{io/pickle/transformer.py → types/_pickle.py} +37 -9
- flyte/types/_string_literals.py +8 -9
- flyte/types/_type_engine.py +226 -126
- flyte/types/_utils.py +1 -1
- flyte-2.0.0b46.data/scripts/debug.py +38 -0
- flyte-2.0.0b46.data/scripts/runtime.py +194 -0
- flyte-2.0.0b46.dist-info/METADATA +352 -0
- flyte-2.0.0b46.dist-info/RECORD +221 -0
- flyte-2.0.0b46.dist-info/entry_points.txt +8 -0
- flyte-2.0.0b46.dist-info/licenses/LICENSE +201 -0
- flyte/_api_commons.py +0 -3
- flyte/_cli/_common.py +0 -299
- flyte/_cli/_create.py +0 -42
- flyte/_cli/_delete.py +0 -23
- flyte/_cli/_deploy.py +0 -140
- flyte/_cli/_get.py +0 -235
- flyte/_cli/_run.py +0 -174
- flyte/_cli/main.py +0 -98
- flyte/_datastructures.py +0 -342
- flyte/_internal/controllers/pbhash.py +0 -39
- flyte/_protos/common/authorization_pb2.py +0 -66
- flyte/_protos/common/authorization_pb2.pyi +0 -108
- flyte/_protos/common/authorization_pb2_grpc.py +0 -4
- flyte/_protos/common/identifier_pb2.py +0 -71
- flyte/_protos/common/identifier_pb2.pyi +0 -82
- flyte/_protos/common/identifier_pb2_grpc.py +0 -4
- flyte/_protos/common/identity_pb2.py +0 -48
- flyte/_protos/common/identity_pb2.pyi +0 -72
- flyte/_protos/common/identity_pb2_grpc.py +0 -4
- flyte/_protos/common/list_pb2.py +0 -36
- flyte/_protos/common/list_pb2.pyi +0 -69
- flyte/_protos/common/list_pb2_grpc.py +0 -4
- flyte/_protos/common/policy_pb2.py +0 -37
- flyte/_protos/common/policy_pb2.pyi +0 -27
- flyte/_protos/common/policy_pb2_grpc.py +0 -4
- flyte/_protos/common/role_pb2.py +0 -37
- flyte/_protos/common/role_pb2.pyi +0 -53
- flyte/_protos/common/role_pb2_grpc.py +0 -4
- flyte/_protos/common/runtime_version_pb2.py +0 -28
- flyte/_protos/common/runtime_version_pb2.pyi +0 -24
- flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
- flyte/_protos/logs/dataplane/payload_pb2.py +0 -96
- flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -168
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/definition_pb2.py +0 -49
- flyte/_protos/secret/definition_pb2.pyi +0 -93
- flyte/_protos/secret/definition_pb2_grpc.py +0 -4
- flyte/_protos/secret/payload_pb2.py +0 -62
- flyte/_protos/secret/payload_pb2.pyi +0 -94
- flyte/_protos/secret/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/secret_pb2.py +0 -38
- flyte/_protos/secret/secret_pb2.pyi +0 -6
- flyte/_protos/secret/secret_pb2_grpc.py +0 -198
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
- flyte/_protos/validate/validate/validate_pb2.py +0 -76
- flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
- flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- flyte/_protos/workflow/queue_service_pb2.py +0 -106
- flyte/_protos/workflow/queue_service_pb2.pyi +0 -141
- flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
- flyte/_protos/workflow/run_definition_pb2.py +0 -128
- flyte/_protos/workflow/run_definition_pb2.pyi +0 -310
- flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
- flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- flyte/_protos/workflow/run_service_pb2.py +0 -133
- flyte/_protos/workflow/run_service_pb2.pyi +0 -175
- flyte/_protos/workflow/run_service_pb2_grpc.py +0 -412
- flyte/_protos/workflow/state_service_pb2.py +0 -58
- flyte/_protos/workflow/state_service_pb2.pyi +0 -71
- flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
- flyte/_protos/workflow/task_definition_pb2.py +0 -72
- flyte/_protos/workflow/task_definition_pb2.pyi +0 -65
- flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/task_service_pb2.py +0 -44
- flyte/_protos/workflow/task_service_pb2.pyi +0 -31
- flyte/_protos/workflow/task_service_pb2_grpc.py +0 -104
- flyte/io/_dataframe.py +0 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/remote/_console.py +0 -18
- flyte-0.2.0b1.dist-info/METADATA +0 -179
- flyte-0.2.0b1.dist-info/RECORD +0 -204
- flyte-0.2.0b1.dist-info/entry_points.txt +0 -3
- /flyte/{_cli → _debug}/__init__.py +0 -0
- /flyte/{_protos → _keyring}/__init__.py +0 -0
- {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/WHEEL +0 -0
- {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
"""
|
|
2
|
+
FastAPI middleware for automatic Flyte authentication passthrough.
|
|
3
|
+
|
|
4
|
+
This module provides middleware that automatically extracts authentication headers
|
|
5
|
+
from incoming requests and sets them in the Flyte context, eliminating the need
|
|
6
|
+
for manual auth_metadata() wrapping in every endpoint.
|
|
7
|
+
|
|
8
|
+
Example:
|
|
9
|
+
Basic usage with default extractors (Authorization + Cookie headers)::
|
|
10
|
+
|
|
11
|
+
from fastapi import FastAPI
|
|
12
|
+
from flyte.app.extras import FastAPIAuthMiddleware
|
|
13
|
+
|
|
14
|
+
app = FastAPI()
|
|
15
|
+
app.add_middleware(FastAPIAuthMiddleware, excluded_paths={"/health"})
|
|
16
|
+
|
|
17
|
+
@app.get("/me")
|
|
18
|
+
async def get_current_user():
|
|
19
|
+
# Auth metadata automatically set from request headers
|
|
20
|
+
user = await remote.User.get.aio()
|
|
21
|
+
return {"subject": user.subject()}
|
|
22
|
+
|
|
23
|
+
Advanced usage with custom extractors and path exclusions::
|
|
24
|
+
|
|
25
|
+
from flyte.app.extras import FastAPIAuthMiddleware
|
|
26
|
+
|
|
27
|
+
app.add_middleware(
|
|
28
|
+
FastAPIAuthMiddleware,
|
|
29
|
+
header_extractors=[
|
|
30
|
+
FastAPIAuthMiddleware.extract_authorization_header,
|
|
31
|
+
FastAPIAuthMiddleware.extract_custom_header("x-api-key"),
|
|
32
|
+
],
|
|
33
|
+
excluded_paths={"/health", "/metrics"},
|
|
34
|
+
)
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
from __future__ import annotations
|
|
38
|
+
|
|
39
|
+
import logging
|
|
40
|
+
from typing import TYPE_CHECKING, Callable
|
|
41
|
+
|
|
42
|
+
if TYPE_CHECKING:
|
|
43
|
+
from fastapi import Request
|
|
44
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
45
|
+
from starlette.responses import Response
|
|
46
|
+
else:
|
|
47
|
+
try:
|
|
48
|
+
from starlette.middleware.base import BaseHTTPMiddleware
|
|
49
|
+
except ImportError:
|
|
50
|
+
|
|
51
|
+
class BaseHTTPMiddleware:
|
|
52
|
+
pass
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
logger = logging.getLogger(__name__)
|
|
56
|
+
|
|
57
|
+
# Header extractor type: takes a Request, returns (key, value) tuple or None
|
|
58
|
+
HeaderExtractor = Callable[["Request"], tuple[str, str] | None]
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class FastAPIPassthroughAuthMiddleware(BaseHTTPMiddleware):
    """
    FastAPI middleware that automatically sets Flyte auth metadata from request headers.

    This middleware extracts authentication headers from incoming HTTP requests and
    sets them in the Flyte context using the auth_metadata() context manager. This
    eliminates the need to manually wrap endpoint handlers with auth_metadata().

    Configuration:
    - Custom header extractors can be provided (default: Authorization and Cookie).
    - Specific paths can be excluded. Excluded paths are matched exactly against
      ``request.url.path`` and bypass extraction entirely.
    - On every other path, auth is required. If no extractor yields a header, the
      request is rejected with ``401`` and ``WWW-Authenticate: Bearer`` before it
      reaches the route handler.

    Attributes:
        header_extractors: Functions used to extract ``(key, value)`` headers from requests.
        excluded_paths: Set of URL paths that bypass auth extraction.

    Concurrency:
        Extracted headers are applied only while ``call_next`` runs inside an
        ``auth_metadata(...)`` block. Isolation between concurrent requests relies on
        ``auth_metadata`` scoping its values per request (presumably via contextvars;
        confirm in ``flyte.remote``).
    """

    def __init__(
        self,
        app,
        header_extractors: list[HeaderExtractor] | None = None,
        excluded_paths: set[str] | None = None,
    ):
        """
        Initialize the Flyte authentication middleware.

        Args:
            app: The FastAPI/Starlette application (mandatory framework parameter).
            header_extractors: List of functions to extract headers. Each function
                takes a Request and returns a (key, value) tuple or None.
                Defaults to [extract_authorization_header, extract_cookie_header].
            excluded_paths: Set of URL paths to exclude from auth extraction.
                Requests to these paths proceed without setting auth context.
        """
        super().__init__(app)

        if header_extractors is None:
            self.header_extractors: list[HeaderExtractor] = [
                self.extract_authorization_header,
                self.extract_cookie_header,
            ]
        else:
            self.header_extractors = header_extractors

        self.excluded_paths = excluded_paths or set()

    async def dispatch(self, request: "Request", call_next) -> "Response":
        """
        Process each request, extracting auth headers and setting Flyte context.

        Args:
            request: The incoming HTTP request.
            call_next: The next middleware or route handler to call.

        Returns:
            The handler's response, or a 401 JSON response when no auth header could
            be extracted on a non-excluded path.
        """
        # Imported lazily so the module stays importable without starlette.
        from starlette.responses import JSONResponse

        # Skip auth extraction for excluded paths
        if request.url.path in self.excluded_paths:
            return await call_next(request)

        # Extract auth headers using all configured extractors
        auth_tuples = []
        for extractor in self.header_extractors:
            try:
                result = extractor(request)
            except Exception as e:
                # Best effort: a broken extractor must not fail the request. Not every
                # callable has __name__ (functools.partial, callable instances), so
                # fall back to repr() rather than raising from inside the handler.
                extractor_name = getattr(extractor, "__name__", repr(extractor))
                logger.warning("Header extractor %s failed: %s", extractor_name, e)
                continue
            if result is not None:
                auth_tuples.append(result)

        # Require auth headers
        if not auth_tuples:
            logger.info("No auth tuples found")
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication credentials required"},
                headers={"WWW-Authenticate": "Bearer"},
            )

        # Set auth metadata in Flyte context for the duration of this request
        from flyte.remote import auth_metadata

        with auth_metadata(*auth_tuples):
            return await call_next(request)

    @staticmethod
    def extract_authorization_header(request: "Request") -> tuple[str, str] | None:
        """
        Extract the Authorization header from the request.

        Args:
            request: The FastAPI/Starlette request object.

        Returns:
            Tuple of ("authorization", header_value) if present and non-empty, None otherwise.
        """
        auth_header = request.headers.get("authorization")
        if auth_header:
            return "authorization", auth_header
        return None

    @staticmethod
    def extract_cookie_header(request: "Request") -> tuple[str, str] | None:
        """
        Extract the Cookie header from the request.

        Args:
            request: The FastAPI/Starlette request object.

        Returns:
            Tuple of ("cookie", header_value) if present and non-empty, None otherwise.
        """
        cookie_header = request.headers.get("cookie")
        if cookie_header:
            return "cookie", cookie_header
        return None

    @staticmethod
    def extract_custom_header(header_name: str) -> HeaderExtractor:
        """
        Create a header extractor for a custom header name.

        Args:
            header_name: The name of the header to extract. It is lower-cased both for
                the lookup and for the returned key.

        Returns:
            A header extractor function that extracts the specified header.

        Example::

            # Create extractor for X-API-Key header
            api_key_extractor = FastAPIPassthroughAuthMiddleware.extract_custom_header("x-api-key")

            app.add_middleware(
                FastAPIPassthroughAuthMiddleware,
                header_extractors=[api_key_extractor],
            )
        """
        key = header_name.lower()

        def extractor(request: "Request") -> tuple[str, str] | None:
            header_value = request.headers.get(key)
            if header_value:
                return key, header_value
            return None

        # Gives the extractor a readable name in dispatch()'s failure log.
        extractor.__name__ = f"extract_{header_name.replace('-', '_')}_header"
        return extractor
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import inspect
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
import rich.repr
|
|
8
|
+
|
|
9
|
+
import flyte.app
|
|
10
|
+
from flyte.models import SerializationContext
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
import fastapi
|
|
14
|
+
import uvicorn
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@rich.repr.auto
@dataclass(kw_only=True, repr=True)
class FastAPIAppEnvironment(flyte.app.AppEnvironment):
    """
    App environment that serves a ``fastapi.FastAPI`` application with uvicorn.

    On construction it:
    - validates ``app``,
    - makes ``app.state`` picklable,
    - prepends a relative link to the ``/docs`` page,
    - installs ``_fastapi_app_server`` as the environment's server.

    Attributes:
        app: The FastAPI application to serve.
        type: Environment type label (``"FastAPI"``).
        uvicorn_config: Optional uvicorn configuration. When omitted, one is built on
            the environment's port when the server starts.
        _caller_frame: Frame info for the code that instantiated this environment,
            kept to help find the module where the app variable is defined.
    """

    app: fastapi.FastAPI
    type: str = "FastAPI"
    uvicorn_config: uvicorn.Config | None = None
    _caller_frame: inspect.FrameInfo | None = None

    def __post_init__(self):
        """Validate the app, patch its state for pickling, and register links/server.

        Raises:
            ModuleNotFoundError: if fastapi (or starlette) is not installed.
            ValueError: if ``app`` is None.
            TypeError: if ``app`` is not a ``fastapi.FastAPI`` instance.
        """
        try:
            import fastapi
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "fastapi is not installed. Please install 'fastapi' to use FastAPIAppEnvironment."
            )

        # Validate before touching self.app: patching .state on None (or a foreign
        # object) would otherwise fail first with an unrelated error.
        if self.app is None:
            raise ValueError("app cannot be None for FastAPIAppEnvironment")
        if not isinstance(self.app, fastapi.FastAPI):
            raise TypeError(f"app must be of type fastapi.FastAPI, got {type(self.app)}")

        # starlette is a dependency of fastapi, so if fastapi is installed, starlette is also installed.
        try:
            from starlette.datastructures import State
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "starlette is not installed. Please install 'starlette' to use FastAPIAppEnvironment."
            )

        class PicklableState(State):
            def __getstate__(self):
                state = self.__dict__.copy()
                # Drop the stored values; they are not guaranteed to be picklable.
                state["_state"] = {}
                return state

            def __setstate__(self, state):
                self.__dict__.update(state)
                # Reset to an empty store. Write __dict__ directly: State routes normal
                # attribute assignment into its store, so `self.x = ...` would add a key
                # instead of replacing the store.
                self.__dict__["_state"] = {}

        # NOTE: since FastAPI cannot be pickled (because starlette.datastructures.State cannot be pickled due to
        # circular references), we need to patch the state object to make it picklable.
        self.app.state = PicklableState()

        super().__post_init__()

        self.links = [flyte.app.Link(path="/docs", title="FastAPI OpenAPI Docs", is_relative=True), *self.links]
        self._server = self._fastapi_app_server

        # Capture the frame where this environment was instantiated.
        # This helps us find the module where the app variable is defined.
        frame = inspect.currentframe()
        try:
            # frame.f_back is the dataclass-generated __init__; one more hop reaches the caller.
            init_frame = frame.f_back if frame is not None else None
            user_frame = init_frame.f_back if init_frame is not None else None
            if user_frame is not None:
                self._caller_frame = inspect.getframeinfo(user_frame)
        finally:
            # Release the frame reference promptly (frame -> locals -> frame cycle).
            del frame

    async def _fastapi_app_server(self):
        """Serve ``self.app`` with uvicorn until the server exits.

        Builds a ``uvicorn.Config`` on ``self.port.port`` when none was supplied. If a
        supplied config has ``port`` set to None, the environment's port is filled in.

        NOTE(review): uvicorn.Config defaults ``port`` to 8000 and ``host`` to
        127.0.0.1. A user config built without an explicit port therefore keeps 8000
        rather than the env port. Confirm both defaults are intended.
        """
        try:
            import uvicorn
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "uvicorn is not installed. Please install 'uvicorn' to use FastAPIAppEnvironment."
            )

        if self.uvicorn_config is None:
            self.uvicorn_config = uvicorn.Config(self.app, port=self.port.port)
        elif self.uvicorn_config.port is None:
            self.uvicorn_config.port = self.port.port

        await uvicorn.Server(self.uvicorn_config).serve()

    def container_command(self, serialization_context: SerializationContext) -> list[str]:
        """Return the container command for this environment: always an empty list."""
        return []
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
import os
|
|
2
|
+
|
|
3
|
+
REMOTE_MODEL_PATH = os.getenv("FLYTE_MODEL_LOADER_REMOTE_MODEL_PATH")
|
|
4
|
+
LOCAL_MODEL_PATH = os.getenv("FLYTE_MODEL_LOADER_LOCAL_MODEL_PATH", "/srv/model")
|
|
5
|
+
CHUNK_SIZE = int(os.getenv("FLYTE_MODEL_LOADER_CHUNK_SIZE", str(16 * 1024 * 1024)))
|
|
6
|
+
MAX_CONCURRENCY = int(os.getenv("FLYTE_MODEL_LOADER_MAX_CONCURRENCY", str(32)))
|
|
7
|
+
STREAM_SAFETENSORS = os.getenv("FLYTE_MODEL_LOADER_STREAM_SAFETENSORS", "false").lower() == "true"
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import pathlib
|
|
5
|
+
import struct
|
|
6
|
+
import time
|
|
7
|
+
import typing
|
|
8
|
+
from collections import defaultdict
|
|
9
|
+
|
|
10
|
+
import obstore
|
|
11
|
+
import pydantic
|
|
12
|
+
from obstore.store import ObjectStore
|
|
13
|
+
from typing_extensions import Annotated
|
|
14
|
+
|
|
15
|
+
from flyte.storage._parallel_reader import (
|
|
16
|
+
BufferProtocol,
|
|
17
|
+
Chunk,
|
|
18
|
+
DownloadQueueEmpty,
|
|
19
|
+
DownloadTask,
|
|
20
|
+
ObstoreParallelReader,
|
|
21
|
+
Source,
|
|
22
|
+
)
|
|
23
|
+
from flyte.storage._storage import get_underlying_filesystem
|
|
24
|
+
|
|
25
|
+
try:
|
|
26
|
+
import torch
|
|
27
|
+
except ModuleNotFoundError:
|
|
28
|
+
raise ModuleNotFoundError("torch is not installed. Please install 'torch', to use the model loader.")
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
from flyte.app.extras._model_loader.config import (
|
|
32
|
+
CHUNK_SIZE,
|
|
33
|
+
MAX_CONCURRENCY,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
logger = logging.getLogger(__name__)
|
|
37
|
+
|
|
38
|
+
LITTLE_ENDIAN_LONG_LONG_STRUCT_FORMAT = "<Q"
|
|
39
|
+
|
|
40
|
+
SAFETENSORS_FORMAT_KEY = "format"
|
|
41
|
+
SAFETENSORS_FORMAT_VALUE = "pt"
|
|
42
|
+
SAFETENSORS_SUFFIX = ".safetensors"
|
|
43
|
+
SAFETENSORS_DEFAULT_PATTERN = f"*{SAFETENSORS_SUFFIX}"
|
|
44
|
+
SAFETENSORS_SHARDED_PATTERN = f"model-rank-{{rank}}-part-*{SAFETENSORS_SUFFIX}"
|
|
45
|
+
SAFETENSORS_INTERNAL_METADATA_KEY = "__metadata__"
|
|
46
|
+
SAFETENSORS_INDEX_PATH = "model.safetensors.index.json"
|
|
47
|
+
SAFETENSORS_HEADER_BUFFER_SIZE = 8
|
|
48
|
+
SAFETENSORS_TO_TORCH_DTYPE = {
|
|
49
|
+
"F64": torch.float64,
|
|
50
|
+
"F32": torch.float32,
|
|
51
|
+
"F16": torch.float16,
|
|
52
|
+
"BF16": torch.bfloat16,
|
|
53
|
+
"I64": torch.int64,
|
|
54
|
+
"I32": torch.int32,
|
|
55
|
+
"I16": torch.int16,
|
|
56
|
+
"I8": torch.int8,
|
|
57
|
+
"U8": torch.uint8,
|
|
58
|
+
"BOOL": torch.bool,
|
|
59
|
+
"F8_E5M2": torch.float8_e5m2,
|
|
60
|
+
"F8_E4M3": torch.float8_e4m3fn,
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
async def prefetch(remote_model_path, local_model_path, exclude_safetensors=True):
    """
    Recursively download model artifacts from ``remote_model_path`` to ``local_model_path``.

    Args:
        remote_model_path: Remote location to copy from.
        local_model_path: Local destination directory.
        exclude_safetensors: When true, files matching ``*.safetensors`` are skipped
            (their download is deferred to the streaming loader).

    An empty download queue (``DownloadQueueEmpty``) is logged as a warning, not raised.
    """
    from flyte.storage._storage import _get_obstore_bypass

    logger.info(f"Pre-fetching model artifacts from {remote_model_path} to {local_model_path}...")
    exclude_patterns = None
    if exclude_safetensors:
        logger.info(f"Deferring download of safetensor files from {remote_model_path}")
        exclude_patterns = [SAFETENSORS_DEFAULT_PATTERN]
    started_at = time.perf_counter()

    try:
        # Exclude safetensors if model streaming is enabled, which will be handled by the flyte vllm model loader
        await _get_obstore_bypass(
            remote_model_path,
            local_model_path,
            recursive=True,
            exclude=exclude_patterns,
        )
    except DownloadQueueEmpty:
        logger.warning("No model artifacts found to pre-fetch.")
        return

    elapsed = time.perf_counter() - started_at
    logger.info(f"Pre-fetched model artifacts in {elapsed:.2f}s")
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _dtype_to_torch_dtype(dtype: str) -> torch.dtype:
|
|
87
|
+
try:
|
|
88
|
+
return SAFETENSORS_TO_TORCH_DTYPE[dtype]
|
|
89
|
+
except KeyError:
|
|
90
|
+
raise ValueError(f"Unsupported dtype: {dtype}")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
class TensorMetadata(pydantic.BaseModel):
    """One tensor entry from a safetensors header, plus derived size and element count."""

    model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

    name: str
    shape: list[int]
    # The header carries a dtype code such as "F32"; it is converted to torch.dtype on validation.
    dtype: Annotated[torch.dtype, pydantic.BeforeValidator(_dtype_to_torch_dtype)]
    # (start, end) byte offsets of the tensor data. Presumably relative to the start of the
    # data section (SafeTensorsMetadata.data_start); confirm against the reader.
    data_offsets: tuple[int, int]

    @pydantic.computed_field  # type: ignore[prop-decorator]
    @property
    def size(self) -> int:
        """Number of bytes spanned by ``data_offsets`` (end - start)."""
        return self.data_offsets[1] - self.data_offsets[0]

    @pydantic.computed_field  # type: ignore[prop-decorator]
    @property
    def length(self) -> int:
        """Number of elements: the product of all dimensions (1 for an empty shape)."""
        elements = 1
        for extent in self.shape:
            elements = elements * extent
        return elements

    def __len__(self):
        """Alias for ``length``."""
        return self.length
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class SafeTensorsMetadata(pydantic.BaseModel):
    """Parsed header of a single safetensors file."""

    # Object-store key of the file.
    path: str
    # Byte offset where tensor data begins: the 8-byte length prefix plus the JSON header size.
    data_start: int
    # Tensor entries from the header; loaders may filter this list after parsing.
    tensors: list[TensorMetadata]
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class SafeTensorsStreamer:
|
|
126
|
+
def __init__(
    self,
    remote_path,
    local_path,
    chunk_size=CHUNK_SIZE,
    max_concurrency=MAX_CONCURRENCY,
    rank=0,
    tensor_parallel_size=1,
    store_kwargs=None,
):
    """
    Set up streaming of safetensors files stored under ``remote_path``.

    Args:
        remote_path: Remote directory holding the model files. It is split into a
            bucket and a key prefix.
        local_path: Local directory path (stored as a ``pathlib.Path``).
        chunk_size: Chunk size passed to the parallel reader.
        max_concurrency: Concurrency limit passed to the parallel reader.
        rank: This worker's rank; selects the shard pattern when tensor-parallel.
        tensor_parallel_size: Values > 1 switch metadata discovery to the per-rank pattern.
        store_kwargs: NOTE(review): accepted but not used by this method. Confirm
            whether it should be forwarded when the store is constructed.
    """
    filesystem = get_underlying_filesystem(path=remote_path)
    bucket, key_prefix = filesystem._split_path(remote_path)  # pylint: disable=W0212

    object_store: ObjectStore = filesystem._construct_store(bucket)
    self._store = object_store
    self._bucket = bucket
    self._prefix: pathlib.Path = pathlib.Path(key_prefix)
    self._local_path = pathlib.Path(local_path)
    self._reader = ObstoreParallelReader(object_store, chunk_size=chunk_size, max_concurrency=max_concurrency)
    self._rank = rank
    self._tensor_parallel_size = tensor_parallel_size
|
|
146
|
+
|
|
147
|
+
async def _parse_safetensors_metadata(self, path):
    """
    Read and parse the header of the safetensors file at ``path`` in the object store.

    Two ranged reads are issued:
    1. the 8-byte little-endian header length;
    2. the JSON header that follows it.

    Args:
        path: Object-store key (str or path-like) of the file.

    Returns:
        SafeTensorsMetadata whose ``data_start`` is 8 + header length, with one
        TensorMetadata per header entry (the ``__metadata__`` entry excluded).

    Raises:
        ValueError: if ``__metadata__`` declares a format other than "pt".
    """
    key = str(path)
    length_prefix = await obstore.get_range_async(
        self._store, key, start=0, end=SAFETENSORS_HEADER_BUFFER_SIZE
    )
    (header_size,) = struct.unpack(LITTLE_ENDIAN_LONG_LONG_STRUCT_FORMAT, length_prefix)
    data_start = SAFETENSORS_HEADER_BUFFER_SIZE + header_size

    header_bytes = await obstore.get_range_async(
        self._store, key, start=SAFETENSORS_HEADER_BUFFER_SIZE, end=data_start
    )
    header_data = json.loads(header_bytes.to_bytes())

    # "__metadata__" is not a tensor entry; remove it before building TensorMetadata.
    # Named file_format to avoid shadowing the builtin format().
    file_format = header_data.pop(SAFETENSORS_INTERNAL_METADATA_KEY, {}).get(SAFETENSORS_FORMAT_KEY)
    if file_format and file_format != SAFETENSORS_FORMAT_VALUE:
        raise ValueError(f"Unsupported format: {file_format}")

    return SafeTensorsMetadata(
        path=key,
        data_start=data_start,
        tensors=[TensorMetadata.model_validate({"name": k, **v}) for k, v in header_data.items()],
    )
|
|
172
|
+
|
|
173
|
+
async def _list_safetensors_files_with_index(self):
    """
    Group the tensors listed in ``model.safetensors.index.json`` by the file storing them.

    Returns:
        An items view of a mapping from file path (as written in the index, relative to
        ``self._prefix``) to the set of tensor names in that file.

    Raises:
        Whatever the store raises when the index is absent (the caller treats
        FileNotFoundError as "no index"), ``json.JSONDecodeError`` for invalid JSON,
        and KeyError when ``"weight_map"`` is missing.
    """
    index_key = str(self._prefix / SAFETENSORS_INDEX_PATH)
    response = await obstore.get_async(self._store, index_key)
    raw_index = bytes(await response.bytes_async())
    weight_map = json.loads(raw_index)["weight_map"]

    files_to_tensors = {}
    for tensor_name, file_path in weight_map.items():
        files_to_tensors.setdefault(file_path, set()).add(tensor_name)
    return files_to_tensors.items()
|
|
185
|
+
|
|
186
|
+
async def _load_safetensors_metadata_from_index(self):
    """
    Yield metadata for each file named in the index, keeping only the tensors the
    index assigns to it.

    Files are parsed one after another, in index order.

    Raises:
        ValueError: if a file lacks any tensor the index says it contains. The missing
            names are listed in sorted order so the message is reproducible.
    """
    for path, expected in await self._list_safetensors_files_with_index():
        stm = await self._parse_safetensors_metadata(self._prefix / path)
        # Keep only the tensors we expect (should already be deduplicated)
        keep = {tm.name: tm for tm in stm.tensors if tm.name in expected}
        # We have missing tensors at the path. Bail out!
        missing = expected - keep.keys()
        if missing:
            raise ValueError(f"Missing {len(missing)} tensors at {path!r}: {' '.join(sorted(missing))}")
        stm.tensors = list(keep.values())
        yield stm
|
|
196
|
+
|
|
197
|
+
async def _list_safetensors_files_with_pattern(self, pattern):
    """Return the paths of listed objects under ``self._prefix`` that match ``pattern``.

    Objects come from ``obstore.list_with_delimiter_async``. Each object path is
    wrapped in ``pathlib.Path`` and tested with ``Path.match`` (glob matched from the
    right-hand end of the path).

    Returns:
        A ``set`` of ``pathlib.Path`` objects.

    Raises:
        ValueError: if no listed object matches ``pattern``.
    """
    listing = await obstore.list_with_delimiter_async(self._store, prefix=str(self._prefix))
    matches = {
        candidate
        for candidate in (pathlib.Path(entry["path"]) for entry in listing["objects"])
        if candidate.match(pattern)
    }
    if matches:
        return matches
    raise ValueError(f"No files found matching pattern: {pattern}")
|
|
207
|
+
|
|
208
|
+
async def _load_safetensors_metadata_with_pattern(self, pattern):
    """Yield metadata for every file matching ``pattern``, de-duplicating tensor names.

    All matching headers are parsed concurrently. When the same tensor name appears in
    several files, only its first occurrence (in sorted path order) is kept.

    Args:
        pattern: glob passed to ``_list_safetensors_files_with_pattern``.

    Yields:
        One metadata object per matching file, with ``tensors`` reduced to names not
        already yielded by an earlier file.
    """
    # The listing is a set, whose iteration order follows (per-process randomized)
    # string hashing. Sort it so the yield order -- and therefore which file "wins"
    # a duplicated tensor name -- is deterministic.
    paths = sorted(await self._list_safetensors_files_with_pattern(pattern))
    stms = await asyncio.gather(*(self._parse_safetensors_metadata(path) for path in paths))

    seen = set()
    for stm in stms:
        unique = []
        for tm in stm.tensors:
            if tm.name not in seen:
                seen.add(tm.name)
                unique.append(tm)
        stm.tensors = unique
        yield stm
|
|
224
|
+
|
|
225
|
+
async def _load_safetensors_metadata(self):
    """Yield safetensors file metadata for this worker.

    With tensor parallelism (``self._tensor_parallel_size > 1``) only files matching
    this rank's ``SAFETENSORS_SHARDED_PATTERN`` are used. Otherwise the index is tried
    first, falling back to ``SAFETENSORS_DEFAULT_PATTERN`` when the index-based loader
    raises ``JSONDecodeError``, ``FileNotFoundError`` or ``KeyError``.

    The fallback is only taken if the index-based loader fails before yielding anything.
    Falling back after a partial yield would emit the already-yielded tensors a second
    time, so such a failure is re-raised instead.
    """
    # When using tensor parallelism, we can't rely on the index. Fallback to using a pattern.
    if self._tensor_parallel_size > 1:
        async for stm in self._load_safetensors_metadata_with_pattern(
            SAFETENSORS_SHARDED_PATTERN.format(rank=self._rank)
        ):
            yield stm
        return

    # No tensor parallelism. Try to use the index first, then fallback to a pattern.
    yielded_any = False
    try:
        async for stm in self._load_safetensors_metadata_from_index():
            yielded_any = True
            yield stm
    except (
        json.decoder.JSONDecodeError,
        FileNotFoundError,
        KeyError,
    ):
        if yielded_any:
            # Some files were already streamed; re-yielding via the pattern would duplicate them.
            raise
        async for stm in self._load_safetensors_metadata_with_pattern(SAFETENSORS_DEFAULT_PATTERN):
            yield stm
|
|
245
|
+
|
|
246
|
+
async def _get_tensors_async(self) -> typing.AsyncGenerator[tuple[str, torch.Tensor], None]:
    """Stream the reader's results for every tensor described by the safetensors metadata.

    One ``DownloadTask`` is produced per chunk (as split by ``self._reader._chunks``) of
    each tensor. The tasks are handed to ``self._reader._as_completed`` with
    ``_decode`` as the transformer, and whatever that yields is re-yielded here.
    """

    async def _decode(buf: BufferProtocol, source: Source) -> torch.Tensor:
        # Reinterpret the downloaded bytes as a flat tensor of the recorded dtype and
        # element count, then view it with the recorded shape.
        meta = source.metadata
        assert isinstance(meta, TensorMetadata)
        raw = await buf.read()
        flat = torch.frombuffer(raw, dtype=meta.dtype, count=len(meta), offset=0)
        return flat.view(meta.shape)

    async def _download_tasks() -> typing.AsyncGenerator[DownloadTask, None]:
        async for file_meta in self._load_safetensors_metadata():
            for tensor_meta in file_meta.tensors:
                src = Source(
                    id=tensor_meta.name,
                    path=file_meta.path,
                    length=tensor_meta.size,
                    # data_offsets are relative to the start of the file's data section.
                    offset=file_meta.data_start + tensor_meta.data_offsets[0],
                    metadata=tensor_meta,
                )
                for chunk_offset, chunk_length in self._reader._chunks(tensor_meta.size):
                    yield DownloadTask(source=src, chunk=Chunk(chunk_offset, chunk_length))

    # Re-yield results in the order the reader completes them.
    async for item in self._reader._as_completed(_download_tasks(), transformer=_decode):
        yield item
|
|
275
|
+
|
|
276
|
+
def get_tensors(self) -> typing.Generator[tuple[str, torch.Tensor], None, None]:
|
|
277
|
+
logger.info("Streaming tensors...")
|
|
278
|
+
start = time.perf_counter()
|
|
279
|
+
counter = 0
|
|
280
|
+
gen = self._get_tensors_async()
|
|
281
|
+
with asyncio.Runner() as runner:
|
|
282
|
+
try:
|
|
283
|
+
while True:
|
|
284
|
+
yield runner.run(gen.__anext__())
|
|
285
|
+
counter += 1
|
|
286
|
+
except StopAsyncIteration:
|
|
287
|
+
pass
|
|
288
|
+
logger.info(f"Streamed {counter} tensors in {time.perf_counter() - start:.2f}s")
|
flyte/cli/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
import os

# Set GRPC_VERBOSITY to NONE if not already set to silence unwanted output.
# This addresses the issue with grpcio >=1.68.0 causing unwanted output
# https://github.com/flyteorg/flyte/issues/6082
# Done before importing the CLI entry point so the setting is already in place if
# that import (transitively) loads grpc.
os.environ.setdefault("GRPC_VERBOSITY", "NONE")

from flyte.cli.main import main  # noqa: E402

__all__ = ["main"]
|
flyte/cli/_abort.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import rich_click as click
|
|
2
|
+
|
|
3
|
+
from flyte.cli import _common as common
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Click command group for the ``abort`` subcommands; commands register on it via
# ``@abort.command`` (e.g. ``run`` below). click derives the group's help text from the
# docstring, so its wording is user-facing.
@click.group(name="abort")
def abort():
    """
    Abort an ongoing process.
    """
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@abort.command(cls=common.CommandBase)
@click.argument("run-name", type=str, required=True)
@click.pass_obj
def run(cfg: common.CLIConfig, run_name: str, project: str | None = None, domain: str | None = None):
    """
    Abort a run.
    """
    # Imported lazily, inside the command body.
    from flyte.remote import Run

    cfg.init(project=project, domain=domain)
    target_run = Run.get(name=run_name)
    # Nothing is aborted or printed when the lookup result is falsy.
    if not target_run:
        return

    console = common.get_console()
    with console.status(f"Aborting run '{run_name}'...", spinner="dots"):
        target_run.abort()
    console.print(f"Run '{run_name}' has been aborted.")
|