flyte 0.0.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +62 -0
- flyte/_api_commons.py +3 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +126 -0
- flyte/_build.py +25 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +146 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_common.py +287 -0
- flyte/_cli/_create.py +42 -0
- flyte/_cli/_delete.py +23 -0
- flyte/_cli/_deploy.py +140 -0
- flyte/_cli/_get.py +235 -0
- flyte/_cli/_run.py +152 -0
- flyte/_cli/main.py +72 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +339 -0
- flyte/_code_bundle/bundle.py +178 -0
- flyte/_context.py +146 -0
- flyte/_datastructures.py +342 -0
- flyte/_deploy.py +202 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +43 -0
- flyte/_group.py +31 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +760 -0
- flyte/_initialize.py +634 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +115 -0
- flyte/_internal/controllers/_local_controller.py +118 -0
- flyte/_internal/controllers/_trace.py +40 -0
- flyte/_internal/controllers/pbhash.py +39 -0
- flyte/_internal/controllers/remote/__init__.py +40 -0
- flyte/_internal/controllers/remote/_action.py +141 -0
- flyte/_internal/controllers/remote/_client.py +43 -0
- flyte/_internal/controllers/remote/_controller.py +361 -0
- flyte/_internal/controllers/remote/_core.py +402 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +416 -0
- flyte/_internal/imagebuild/image_builder.py +241 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +199 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +210 -0
- flyte/_internal/runtime/taskrunner.py +190 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +124 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +69 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +106 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +133 -0
- flyte/_protos/workflow/run_service_pb2.pyi +175 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +58 -0
- flyte/_protos/workflow/state_service_pb2.pyi +71 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +72 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +44 -0
- flyte/_protos/workflow/task_service_pb2.pyi +31 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +411 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +367 -0
- flyte/_task_environment.py +200 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +128 -0
- flyte/_utils/__init__.py +20 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/coro_management.py +25 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +108 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +143 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +273 -0
- flyte/io/__init__.py +11 -0
- flyte/io/_dataframe.py +0 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +468 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/io/pickle/transformer.py +117 -0
- flyte/io/structured_dataset/__init__.py +129 -0
- flyte/io/structured_dataset/basic_dfs.py +219 -0
- flyte/io/structured_dataset/structured_dataset.py +1061 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +25 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +131 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +184 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +155 -0
- flyte/remote/_logs.py +116 -0
- flyte/remote/_project.py +86 -0
- flyte/remote/_run.py +873 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +227 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +24 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +251 -0
- flyte/storage/_utils.py +5 -0
- flyte/types/__init__.py +13 -0
- flyte/types/_interface.py +25 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2210 -0
- flyte/types/_utils.py +80 -0
- flyte-0.0.1b0.dist-info/METADATA +179 -0
- flyte-0.0.1b0.dist-info/RECORD +390 -0
- flyte-0.0.1b0.dist-info/WHEEL +5 -0
- flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
- flyte-0.0.1b0.dist-info/top_level.txt +1 -0
- union/__init__.py +54 -0
- union/_api_commons.py +3 -0
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +113 -0
- union/_build.py +25 -0
- union/_cache/__init__.py +12 -0
- union/_cache/cache.py +141 -0
- union/_cache/defaults.py +9 -0
- union/_cache/policy_function_body.py +42 -0
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +263 -0
- union/_cli/_create.py +40 -0
- union/_cli/_delete.py +23 -0
- union/_cli/_deploy.py +120 -0
- union/_cli/_get.py +162 -0
- union/_cli/_params.py +579 -0
- union/_cli/_run.py +150 -0
- union/_cli/main.py +72 -0
- union/_code_bundle/__init__.py +8 -0
- union/_code_bundle/_ignore.py +113 -0
- union/_code_bundle/_packaging.py +187 -0
- union/_code_bundle/_utils.py +342 -0
- union/_code_bundle/bundle.py +176 -0
- union/_context.py +146 -0
- union/_datastructures.py +295 -0
- union/_deploy.py +185 -0
- union/_doc.py +29 -0
- union/_docstring.py +26 -0
- union/_environment.py +43 -0
- union/_group.py +31 -0
- union/_hash.py +23 -0
- union/_image.py +760 -0
- union/_initialize.py +585 -0
- union/_interface.py +84 -0
- union/_internal/__init__.py +3 -0
- union/_internal/controllers/__init__.py +77 -0
- union/_internal/controllers/_local_controller.py +77 -0
- union/_internal/controllers/pbhash.py +39 -0
- union/_internal/controllers/remote/__init__.py +40 -0
- union/_internal/controllers/remote/_action.py +131 -0
- union/_internal/controllers/remote/_client.py +43 -0
- union/_internal/controllers/remote/_controller.py +169 -0
- union/_internal/controllers/remote/_core.py +341 -0
- union/_internal/controllers/remote/_informer.py +260 -0
- union/_internal/controllers/remote/_service_protocol.py +44 -0
- union/_internal/imagebuild/__init__.py +11 -0
- union/_internal/imagebuild/docker_builder.py +416 -0
- union/_internal/imagebuild/image_builder.py +243 -0
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +31 -0
- union/_internal/resolvers/common.py +24 -0
- union/_internal/resolvers/default.py +27 -0
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +163 -0
- union/_internal/runtime/entrypoints.py +121 -0
- union/_internal/runtime/io.py +136 -0
- union/_internal/runtime/resources_serde.py +134 -0
- union/_internal/runtime/task_serde.py +202 -0
- union/_internal/runtime/taskrunner.py +179 -0
- union/_internal/runtime/types_serde.py +53 -0
- union/_logging.py +124 -0
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +66 -0
- union/_protos/common/authorization_pb2.pyi +106 -0
- union/_protos/common/authorization_pb2_grpc.py +4 -0
- union/_protos/common/identifier_pb2.py +71 -0
- union/_protos/common/identifier_pb2.pyi +82 -0
- union/_protos/common/identifier_pb2_grpc.py +4 -0
- union/_protos/common/identity_pb2.py +48 -0
- union/_protos/common/identity_pb2.pyi +72 -0
- union/_protos/common/identity_pb2_grpc.py +4 -0
- union/_protos/common/list_pb2.py +36 -0
- union/_protos/common/list_pb2.pyi +69 -0
- union/_protos/common/list_pb2_grpc.py +4 -0
- union/_protos/common/policy_pb2.py +37 -0
- union/_protos/common/policy_pb2.pyi +27 -0
- union/_protos/common/policy_pb2_grpc.py +4 -0
- union/_protos/common/role_pb2.py +37 -0
- union/_protos/common/role_pb2.pyi +51 -0
- union/_protos/common/role_pb2_grpc.py +4 -0
- union/_protos/common/runtime_version_pb2.py +28 -0
- union/_protos/common/runtime_version_pb2.pyi +24 -0
- union/_protos/common/runtime_version_pb2_grpc.py +4 -0
- union/_protos/logs/dataplane/payload_pb2.py +96 -0
- union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- union/_protos/secret/definition_pb2.py +49 -0
- union/_protos/secret/definition_pb2.pyi +93 -0
- union/_protos/secret/definition_pb2_grpc.py +4 -0
- union/_protos/secret/payload_pb2.py +62 -0
- union/_protos/secret/payload_pb2.pyi +94 -0
- union/_protos/secret/payload_pb2_grpc.py +4 -0
- union/_protos/secret/secret_pb2.py +38 -0
- union/_protos/secret/secret_pb2.pyi +6 -0
- union/_protos/secret/secret_pb2_grpc.py +198 -0
- union/_protos/validate/validate/validate_pb2.py +76 -0
- union/_protos/workflow/node_execution_service_pb2.py +26 -0
- union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- union/_protos/workflow/queue_service_pb2.py +75 -0
- union/_protos/workflow/queue_service_pb2.pyi +103 -0
- union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- union/_protos/workflow/run_definition_pb2.py +100 -0
- union/_protos/workflow/run_definition_pb2.pyi +256 -0
- union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/run_logs_service_pb2.py +41 -0
- union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- union/_protos/workflow/run_service_pb2.py +133 -0
- union/_protos/workflow/run_service_pb2.pyi +173 -0
- union/_protos/workflow/run_service_pb2_grpc.py +412 -0
- union/_protos/workflow/state_service_pb2.py +58 -0
- union/_protos/workflow/state_service_pb2.pyi +69 -0
- union/_protos/workflow/state_service_pb2_grpc.py +138 -0
- union/_protos/workflow/task_definition_pb2.py +72 -0
- union/_protos/workflow/task_definition_pb2.pyi +65 -0
- union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/task_service_pb2.py +44 -0
- union/_protos/workflow/task_service_pb2.pyi +31 -0
- union/_protos/workflow/task_service_pb2_grpc.py +104 -0
- union/_resources.py +226 -0
- union/_retry.py +32 -0
- union/_reusable_environment.py +25 -0
- union/_run.py +374 -0
- union/_secret.py +61 -0
- union/_task.py +354 -0
- union/_task_environment.py +186 -0
- union/_timeout.py +47 -0
- union/_tools.py +27 -0
- union/_utils/__init__.py +11 -0
- union/_utils/asyn.py +119 -0
- union/_utils/file_handling.py +71 -0
- union/_utils/helpers.py +46 -0
- union/_utils/lazy_module.py +54 -0
- union/_utils/uv_script_parser.py +49 -0
- union/_version.py +21 -0
- union/connectors/__init__.py +0 -0
- union/errors.py +128 -0
- union/extras/__init__.py +5 -0
- union/extras/_container.py +263 -0
- union/io/__init__.py +11 -0
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +425 -0
- union/io/_file.py +418 -0
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +117 -0
- union/io/structured_dataset/__init__.py +122 -0
- union/io/structured_dataset/basic_dfs.py +219 -0
- union/io/structured_dataset/structured_dataset.py +1057 -0
- union/py.typed +0 -0
- union/remote/__init__.py +23 -0
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +129 -0
- union/remote/_client/auth/__init__.py +12 -0
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +391 -0
- union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- union/remote/_client/auth/_authenticators/device_code.py +120 -0
- union/remote/_client/auth/_authenticators/external_command.py +77 -0
- union/remote/_client/auth/_authenticators/factory.py +200 -0
- union/remote/_client/auth/_authenticators/pkce.py +515 -0
- union/remote/_client/auth/_channel.py +184 -0
- union/remote/_client/auth/_client_config.py +83 -0
- union/remote/_client/auth/_default_html.py +32 -0
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
- union/remote/_client/auth/_keyring.py +154 -0
- union/remote/_client/auth/_token_client.py +258 -0
- union/remote/_client/auth/errors.py +16 -0
- union/remote/_client/controlplane.py +86 -0
- union/remote/_data.py +149 -0
- union/remote/_logs.py +74 -0
- union/remote/_project.py +86 -0
- union/remote/_run.py +820 -0
- union/remote/_secret.py +132 -0
- union/remote/_task.py +193 -0
- union/report/__init__.py +3 -0
- union/report/_report.py +178 -0
- union/report/_template.html +124 -0
- union/storage/__init__.py +24 -0
- union/storage/_remote_fs.py +34 -0
- union/storage/_storage.py +247 -0
- union/storage/_utils.py +5 -0
- union/types/__init__.py +11 -0
- union/types/_renderer.py +162 -0
- union/types/_string_literals.py +120 -0
- union/types/_type_engine.py +2131 -0
- union/types/_utils.py +80 -0
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
import ssl
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import grpc
|
|
5
|
+
import grpc.aio
|
|
6
|
+
import httpx
|
|
7
|
+
from grpc.experimental.aio import init_grpc_aio
|
|
8
|
+
|
|
9
|
+
from union._logging import logger
|
|
10
|
+
|
|
11
|
+
from ._authenticators.base import get_async_session
|
|
12
|
+
from ._authenticators.factory import (
|
|
13
|
+
create_auth_interceptors,
|
|
14
|
+
create_proxy_auth_interceptors,
|
|
15
|
+
get_async_proxy_authenticator,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Initialize gRPC AIO early enough so it can be used in the main thread.
# This runs once, as a side effect of importing this module, before any grpc.aio channel is created below.
init_grpc_aio()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def bootstrap_ssl_from_server(endpoint: str) -> grpc.ChannelCredentials:
    """
    Retrieves the SSL certificate from the remote server and creates gRPC channel credentials.

    This function should be used only when insecure-skip-verify is enabled. It extracts the server address
    and port from the endpoint, retrieves the SSL certificate from the server, and creates
    gRPC channel credentials using the certificate.

    The certificate is fetched with ``ssl.get_server_certificate``, which performs a blocking network
    round-trip in the calling thread. Async callers should run this function off the event loop
    (e.g. via ``asyncio.to_thread``).

    :param endpoint: The endpoint to retrieve the SSL certificate from, in ``host`` or ``host:port`` form.
        A bracketed IPv6 literal such as ``[::1]:8443`` is accepted as well. Defaults to port 443 when
        no numeric port is present.
    :return: gRPC channel credentials created from the retrieved certificate
    """
    # Get port from endpoint or use 443
    endpoint_parts = endpoint.rsplit(":", 1)
    # isdecimal() (unlike isdigit()) only accepts characters that int() can parse.
    if len(endpoint_parts) == 2 and endpoint_parts[1].isdecimal():
        host, port = endpoint_parts[0], int(endpoint_parts[1])
    else:
        logger.warning(f"Unrecognized port in endpoint [{endpoint}], defaulting to 443.")
        host, port = endpoint, 443

    # Brackets around an IPv6 literal are URL syntax; socket.getaddrinfo expects the bare address.
    if host.startswith("[") and host.endswith("]"):
        host = host[1:-1]

    # Blocking call: opens a TLS connection to the server and returns its certificate as a PEM string.
    cert = ssl.get_server_certificate((host, port))
    return grpc.ssl_channel_credentials(cert.encode())
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
async def create_channel(
    endpoint: str,
    *,
    insecure: typing.Optional[bool],
    insecure_skip_verify: typing.Optional[bool] = False,
    ca_cert_file_path: typing.Optional[str] = None,
    ssl_credentials: typing.Optional[grpc.ChannelCredentials] = None,
    grpc_options: typing.Optional[typing.Sequence[typing.Tuple[str, typing.Any]]] = None,
    compression: typing.Optional[grpc.Compression] = None,
    http_session: httpx.AsyncClient | None = None,
    proxy_command: typing.List[str] | None = None,
    **kwargs,
) -> grpc.aio.Channel:
    """
    Creates a new gRPC channel with appropriate authentication interceptors.

    This function creates either a secure or insecure gRPC channel based on the provided parameters,
    and adds authentication interceptors to the channel. For secure channels, if SSL credentials are not
    provided, they are created based on the insecure_skip_verify and ca_cert_file_path parameters.
    An unauthenticated channel with the same transport settings is created first and handed to the
    authentication interceptor factory as ``in_channel``.

    The function is async because it may fetch the server certificate (in a worker thread) or read
    certificate files asynchronously.

    :param endpoint: The endpoint URL for the gRPC channel
    :param insecure: Whether to use an insecure channel (no SSL)
    :param insecure_skip_verify: Whether to skip SSL certificate verification (ignored when insecure)
    :param ca_cert_file_path: Path to CA certificate file for SSL verification
    :param ssl_credentials: Pre-configured SSL credentials for the channel (ignored when insecure)
    :param grpc_options: gRPC channel options, applied to both insecure and secure channels
    :param compression: Compression method, applied to both insecure and secure channels
    :param http_session: Pre-configured HTTP session to use for requests
    :param proxy_command: List of strings for proxy command configuration
    :param kwargs: Additional arguments forwarded to get_async_proxy_authenticator, get_async_session,
        create_proxy_auth_interceptors and create_auth_interceptors. They are NOT passed to the gRPC
        channel constructors; use ``grpc_options`` and ``compression`` for channel settings.
        - For proxy configuration:
            - proxy_env: Dict of environment variables for proxy
            - proxy_timeout: Timeout for proxy connection
        - For authentication interceptors (passed to create_auth_interceptors and create_proxy_auth_interceptors):
            - auth_type: The authentication type to use ("Pkce", "ClientSecret", "ExternalCommand", "DeviceFlow")
            - command: Command to execute for ExternalCommand authentication
            - client_id: Client ID for ClientSecret authentication
            - client_secret: Client secret for ClientSecret authentication
            - client_credentials_secret: Client secret for ClientSecret authentication (alias)
            - scopes: List of scopes to request during authentication
            - audience: Audience for the token
            - http_proxy_url: HTTP proxy URL
            - verify: Whether to verify SSL certificates
            - ca_cert_path: Optional path to CA certificate file
            - header_key: Header key to use for authentication
            - redirect_uri: OAuth2 redirect URI for PKCE authentication
            - add_request_auth_code_params_to_request_access_token_params: Whether to add auth code params to token
              request
            - request_auth_code_params: Parameters to add to login URI opened in browser
            - request_access_token_params: Parameters to add when exchanging auth code for access token
            - refresh_access_token_params: Parameters to add when refreshing access token
    :return: grpc.aio.Channel with authentication interceptors configured
    """
    import asyncio

    # SSL credentials are only needed for secure channels; skipping them for insecure channels avoids a
    # pointless certificate fetch (which would fail against a plaintext endpoint) or file read.
    if not insecure and not ssl_credentials:
        if insecure_skip_verify:
            # bootstrap_ssl_from_server blocks on a network round-trip; keep the event loop responsive.
            ssl_credentials = await asyncio.to_thread(bootstrap_ssl_from_server, endpoint)
        elif ca_cert_file_path:
            import aiofiles

            async with aiofiles.open(ca_cert_file_path, "rb") as f:
                # aiofiles file methods are coroutines; without await this would be a coroutine object.
                st_cert = await f.read()
            ssl_credentials = grpc.ssl_channel_credentials(st_cert)
        else:
            ssl_credentials = grpc.ssl_channel_credentials()

    def _open_channel(
        channel_interceptors: typing.Optional[typing.Sequence[typing.Any]] = None,
    ) -> grpc.aio.Channel:
        # Build a channel with this call's transport settings; used for both the bootstrap and final channel.
        if insecure:
            return grpc.aio.insecure_channel(
                endpoint,
                options=grpc_options,
                compression=compression,
                interceptors=channel_interceptors,
            )
        return grpc.aio.secure_channel(
            target=endpoint,
            credentials=ssl_credentials,
            options=grpc_options,
            compression=compression,
            interceptors=channel_interceptors,
        )

    # Create an unauthenticated channel first to use to get the server metadata
    unauthenticated_channel = _open_channel()

    from ._grpc_utils.default_metadata_interceptor import (
        DefaultMetadataStreamStreamInterceptor,
        DefaultMetadataStreamUnaryInterceptor,
        DefaultMetadataUnaryStreamInterceptor,
        DefaultMetadataUnaryUnaryInterceptor,
    )

    # Add all types of default metadata interceptors
    interceptors = [
        DefaultMetadataUnaryUnaryInterceptor(),
        DefaultMetadataUnaryStreamInterceptor(),
        DefaultMetadataStreamUnaryInterceptor(),
        DefaultMetadataStreamStreamInterceptor(),
    ]

    # Create an HTTP session if not provided so we share the same http client across the stack
    if not http_session:
        proxy_authenticator = None
        if proxy_command:
            proxy_authenticator = get_async_proxy_authenticator(
                endpoint=endpoint, proxy_command=proxy_command, **kwargs
            )

        http_session = get_async_session(
            ca_cert_file_path=ca_cert_file_path, proxy_authenticator=proxy_authenticator, **kwargs
        )

    # Get proxy auth interceptors
    proxy_auth_interceptors = create_proxy_auth_interceptors(endpoint, http_session=http_session, **kwargs)
    interceptors.extend(proxy_auth_interceptors)

    # Get auth interceptors
    auth_interceptors = create_auth_interceptors(
        endpoint=endpoint,
        in_channel=unauthenticated_channel,
        insecure=insecure,
        insecure_skip_verify=insecure_skip_verify,
        ca_cert_file_path=ca_cert_file_path,
        http_session=http_session,
        **kwargs,
    )
    interceptors.extend(auth_interceptors)

    return _open_channel(interceptors)
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import typing
from abc import ABC, abstractmethod

import grpc.aio
import pydantic
from flyteidl.service.auth_pb2 import OAuth2MetadataRequest, PublicClientAuthConfigRequest
from flyteidl.service.auth_pb2_grpc import AuthMetadataServiceStub
|
|
8
|
+
|
|
9
|
+
# The supported authentication types, matching the `auth_type` values accepted by the channel/auth factory.
AuthType = typing.Literal["ClientSecret", "Pkce", "ExternalCommand", "DeviceFlow"]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ClientConfig(pydantic.BaseModel):
    """
    Client Configuration that is needed by the authenticator.

    Holds the OAuth2 endpoints and client identity used to obtain tokens, plus the metadata key
    under which credentials are attached to requests.
    """

    token_endpoint: str
    authorization_endpoint: str
    redirect_uri: str
    client_id: str
    device_authorization_endpoint: typing.Optional[str] = None
    # None means "no scopes configured". Declared Optional so an explicit None (e.g. produced by
    # with_override when neither config has scopes) passes validation.
    scopes: typing.Optional[typing.List[str]] = None
    # Metadata/header key used to carry the authentication credential.
    header_key: str = "authorization"
    audience: typing.Optional[str] = None

    def with_override(self, other: "ClientConfig") -> "ClientConfig":
        """
        Returns a new ClientConfig instance with the values from the other instance overriding the current instance.

        A field of ``other`` only wins when it is truthy; falsy values (None, empty string, empty list)
        fall back to this instance's value. Neither instance is modified.

        :param other: The config whose non-empty values take precedence.
        :return: A new, merged ClientConfig.
        """
        return ClientConfig(
            token_endpoint=other.token_endpoint or self.token_endpoint,
            authorization_endpoint=other.authorization_endpoint or self.authorization_endpoint,
            redirect_uri=other.redirect_uri or self.redirect_uri,
            client_id=other.client_id or self.client_id,
            device_authorization_endpoint=other.device_authorization_endpoint or self.device_authorization_endpoint,
            scopes=other.scopes or self.scopes,
            header_key=other.header_key or self.header_key,
            audience=other.audience or self.audience,
        )
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ClientConfigStore(ABC):
    """
    Abstract source of the ClientConfig used by authenticators.

    Implementations may obtain the configuration in different ways, e.g. from a fixed value or from
    the server. Inheriting from ABC makes ``@abstractmethod`` enforced: subclasses that do not
    implement ``get_client_config`` cannot be instantiated.
    """

    @abstractmethod
    async def get_client_config(self) -> ClientConfig:
        """Return the client configuration."""
        ...
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class StaticClientConfigStore(ClientConfigStore):
    """
    ClientConfigStore that always returns the ClientConfig it was constructed with.
    """

    def __init__(self, cfg: ClientConfig):
        """
        :param cfg: The configuration handed back by every call to get_client_config.
        """
        self._cfg = cfg

    async def get_client_config(self) -> ClientConfig:
        """Return the stored configuration (the same object, not a copy)."""
        stored = self._cfg
        return stored
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class RemoteClientConfigStore(ClientConfigStore):
    """
    This class implements the ClientConfigStore that is served by the Flyte Server, that implements AuthMetadataService

    The configuration is assembled from two calls on the given channel: GetPublicClientConfig and
    GetOAuth2Metadata.
    """

    def __init__(self, unauthenticated_channel: grpc.aio.Channel):
        """
        :param unauthenticated_channel: Channel used to call AuthMetadataService (no auth interceptors needed).
        """
        self._unauthenticated_channel = unauthenticated_channel

    async def get_client_config(self) -> ClientConfig:
        """
        Retrieves the ClientConfig from the given grpc.Channel assuming AuthMetadataService is available
        """
        metadata_service = AuthMetadataServiceStub(self._unauthenticated_channel)
        public_client_config = await metadata_service.GetPublicClientConfig(PublicClientAuthConfigRequest())
        oauth2_metadata = await metadata_service.GetOAuth2Metadata(OAuth2MetadataRequest())

        optional_fields: typing.Dict[str, typing.Any] = {}
        # An empty authorization_metadata_key means the server did not set one. Omit it so ClientConfig's
        # default header_key applies; passing None would fail validation of the `str` field.
        if public_client_config.authorization_metadata_key:
            optional_fields["header_key"] = public_client_config.authorization_metadata_key

        return ClientConfig(
            token_endpoint=oauth2_metadata.token_endpoint,
            authorization_endpoint=oauth2_metadata.authorization_endpoint,
            redirect_uri=public_client_config.redirect_uri,
            client_id=public_client_config.client_id,
            # Copy the protobuf repeated field into a plain list so model validation sees a real list.
            scopes=list(public_client_config.scopes),
            device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint,
            audience=public_client_config.audience,
            **optional_fields,
        )
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import textwrap
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def get_default_success_html(endpoint: str) -> str:
    """
    Get default success html.

    Returns a static HTML page (with the Union.ai logo) reporting that OAuth2 authentication to Union
    succeeded and telling the user to return to their terminal.

    :param endpoint: Accepted for interface compatibility; it is not used in the generated page.
    :return: The dedented HTML document as a string.
    """
    return textwrap.dedent(
        """
<html>
<head>
<title>OAuth2 Authentication to Union Successful</title>
</head>
<body style="background:white;font-family:Arial">
<div style="position: absolute;top:40%;left:50%;transform: translate(-50%, -50%);text-align:center;">
<div style="margin:auto">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 300 65" fill="currentColor"
style="color:#fdb51e;width:360px;">
<title>Union.ai</title>
<path d="M32,64.8C14.4,64.8,0,51.5,0,34V3.6h17.6v41.3c0,1.9,1.1,3,3,3h23c1.9,0,3-1.1,3-3V3.6H64V34
C64,51.5,49.6,64.8,32,64.8z M69.9,30.9v30.4h17.6V20c0-1.9,1.1-3,3-3h23c1.9,0,3,1.1,3,3v41.3H134V30.9c0-17.5-14.4-30.8-32.1-30.8
S69.9,13.5,69.9,30.9z M236,30.9v30.4h17.6V20c0-1.9,1.1-3,3-3h23c1.9,0,3,1.1,3,3v41.3H300V30.9c0-17.5-14.4-30.8-32-30.8
S236,13.5,236,30.9L236,30.9z M230.1,32.4c0,18.2-14.2,32.5-32.2,32.5s-32-14.3-32-32.5s14-32.1,32-32.1S230.1,14.3,230.1,32.4
L230.1,32.4z M213.5,20.2c0-1.9-1.1-3-3-3h-24.8c-1.9,0-3,1.1-3,3v24.5c0,1.9,1.1,3,3,3h24.8c1.9,0,3-1.1,3-3V20.2z M158.9,3.6
h-17.6v57.8h17.6V3.6z"></path>
</svg>
<h2>You've successfully authenticated to Union!</h2>
<p style="font-size:20px;">Return to your terminal for next steps</p>
</div>
</div>
</body>
</html>
"""  # noqa: E501
    )
|
|
File without changes
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import grpc.aio
|
|
4
|
+
from grpc import RpcError
|
|
5
|
+
from grpc.aio import ClientCallDetails, Metadata
|
|
6
|
+
|
|
7
|
+
from union.remote._client.auth._authenticators.base import Authenticator
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class _BaseAuthInterceptor:
    """
    Base class for all auth interceptors that provides common authentication functionality.

    The authenticator is obtained lazily from the ``get_authenticator`` factory on first use and then cached on
    the instance.
    """

    def __init__(self, get_authenticator: typing.Callable[[], Authenticator]):
        # Factory invoked on first access of ``authenticator``; the result is cached in ``_authenticator``.
        self._get_authenticator = get_authenticator
        self._authenticator: typing.Optional[Authenticator] = None

    @property
    def authenticator(self) -> Authenticator:
        """Return the cached authenticator, creating it through the factory on first access."""
        if self._authenticator is None:
            self._authenticator = self._get_authenticator()
        return self._authenticator

    async def _call_details_with_auth_metadata(
        self, client_call_details: grpc.aio.ClientCallDetails
    ) -> typing.Tuple[grpc.aio.ClientCallDetails, typing.Optional[str]]:
        """
        Returns new ClientCallDetails with authentication metadata added.

        This method retrieves authentication metadata from the authenticator and adds it to a *copy* of the
        client call details' metadata. If no authentication metadata is available, the original client call
        details are returned unchanged.

        :param client_call_details: The original client call details containing method, timeout, metadata,
            credentials, and wait_for_ready settings
        :return: A tuple of (updated client call details, ``creds_id`` of the auth metadata that was used, or
            ``None`` when no auth metadata was available)
        """
        auth_metadata = await self.authenticator.get_grpc_call_auth_metadata()
        if not auth_metadata:
            return client_call_details, None

        # Build a fresh Metadata object instead of mutating the caller's one: after a credential refresh the
        # interceptors call this method again with the *same* original call details, and in-place mutation
        # would leave the stale auth entries in the request alongside the new ones.
        metadata = Metadata(*(client_call_details.metadata or ()))
        for k, v in auth_metadata.pairs.items():
            metadata.add(k, v)

        return client_call_details._replace(metadata=metadata), auth_metadata.creds_id
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class AuthUnaryUnaryInterceptor(_BaseAuthInterceptor, grpc.aio.UnaryUnaryClientInterceptor):
    """
    Unary-unary client interceptor that attaches authentication metadata to each call.
    """

    async def intercept_unary_unary(
        self,
        continuation: typing.Callable,
        client_call_details: ClientCallDetails,
        request: typing.Any,
    ):
        """
        Attach auth metadata, perform the call, and retry once after a credential refresh on auth failure.

        The first attempt uses auth metadata derived from the authenticator. When that attempt fails with status
        UNAUTHENTICATED or UNKNOWN, the credentials identified by the first attempt's ``creds_id`` are refreshed,
        auth metadata is rebuilt from the original call details, and the call is issued one more time.

        :param continuation: Callable that continues the RPC chain with the given call details and request
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
            and wait_for_ready
        :param request: The request message to be sent to the server
        :return: The response message of the (possibly retried) RPC
        :raises: grpc.aio.AioRpcError if the call fails with any other status, or if the retry itself fails
        """
        call_details, creds_id = await self._call_details_with_auth_metadata(client_call_details)
        try:
            pending = await continuation(call_details, request)
            return await pending
        except grpc.aio.AioRpcError as err:
            status = err.code()
            if status != grpc.StatusCode.UNAUTHENTICATED and status != grpc.StatusCode.UNKNOWN:
                raise err
            await self.authenticator.refresh_credentials(creds_id=creds_id)
            # Rebuild from the untouched original details so only the refreshed auth metadata is attached.
            retry_details, _ = await self._call_details_with_auth_metadata(client_call_details)
            retry_pending = await continuation(retry_details, request)
            return await retry_pending
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class AuthUnaryStreamInterceptor(_BaseAuthInterceptor, grpc.aio.UnaryStreamClientInterceptor):
    """
    Interceptor for unary-stream RPC calls that adds authentication metadata.
    """

    async def intercept_unary_stream(
        self, continuation: typing.Callable, client_call_details: grpc.aio.ClientCallDetails, request: typing.Any
    ):
        """
        Intercepts unary-stream calls and adds auth metadata if available.

        This method first adds authentication metadata to the client call details, then attempts to make the RPC call.
        If the stream fails with an UNAUTHENTICATED or UNKNOWN status code, it refreshes the credentials. The call is
        retried with the new authentication metadata only if no response has been yielded yet. Once responses have
        been delivered, a retry would replay the stream from the beginning and hand duplicates to the consumer, so
        the original error is re-raised instead.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
            and wait_for_ready
        :param request: The request message to be sent to the server
        :return: An async iterator over the responses from the RPC call
        :raises: grpc.aio.AioRpcError (while iterating) if the call fails for reasons other than authentication, or
            if an authentication failure happens after responses were already yielded
        """
        call_details, creds_id = await self._call_details_with_auth_metadata(client_call_details)

        async def response_iterator() -> typing.AsyncIterator[typing.Any]:
            call = await continuation(call_details, request)
            yielded_any = False
            try:
                async for response in call:
                    yielded_any = True
                    yield response
            except grpc.aio.AioRpcError as e:
                if e.code() not in (grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.UNKNOWN):
                    raise
                await self.authenticator.refresh_credentials(creds_id=creds_id)
                if yielded_any:
                    # Retrying would re-deliver responses the consumer has already seen.
                    raise
                updated_call_details, _ = await self._call_details_with_auth_metadata(client_call_details)
                async for response in await continuation(updated_call_details, request):
                    yield response

        return response_iterator()
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class AuthStreamUnaryInterceptor(_BaseAuthInterceptor, grpc.aio.StreamUnaryClientInterceptor):
    """
    Interceptor for stream-unary RPC calls that adds authentication metadata.
    """

    async def intercept_stream_unary(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: typing.Any,
    ):
        """
        Intercepts stream-unary calls and adds auth metadata if available.

        This method first adds authentication metadata to the client call details, then attempts to make the RPC call.
        If the call fails with an UNAUTHENTICATED or UNKNOWN status code, it refreshes the credentials. The call is
        then retried with the new authentication metadata only when ``request_iterator`` is a ``list`` or ``tuple``.
        The first attempt may already have consumed part of a one-shot iterator or generator, and passing it again
        would silently send a truncated request stream. For such sources the original error is re-raised instead.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
            and wait_for_ready
        :param request_iterator: An iterator (or re-iterable sequence) of request messages to be sent to the server
        :return: The response from the RPC call after successful authentication
        :raises: grpc.aio.AioRpcError if the call fails for reasons other than authentication, or if authentication
            fails and the request source cannot be replayed
        """
        updated_call_details, creds_id = await self._call_details_with_auth_metadata(client_call_details)
        try:
            call = await continuation(updated_call_details, request_iterator)
            return await call
        except grpc.aio.AioRpcError as e:
            if e.code() not in (grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.UNKNOWN):
                raise
            # Refresh regardless, so later calls pick up fresh credentials even if this one cannot be retried.
            await self.authenticator.refresh_credentials(creds_id=creds_id)
            if not isinstance(request_iterator, (list, tuple)):
                raise
            updated_call_details, _ = await self._call_details_with_auth_metadata(client_call_details)
            call = await continuation(updated_call_details, request_iterator)
            return await call
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
class AuthStreamStreamInterceptor(_BaseAuthInterceptor, grpc.aio.StreamStreamClientInterceptor):
    """
    Interceptor for stream-stream RPC calls that adds authentication metadata.
    """

    async def intercept_stream_stream(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: typing.Any,
    ):
        """
        Intercepts stream-stream calls and adds auth metadata if available.

        This method first adds authentication metadata to the client call details, then starts the RPC. Responses
        are consumed by iterating the call; a stream-stream call is async-iterable but not awaitable. If the stream
        fails with an UNAUTHENTICATED or UNKNOWN status code, the credentials are refreshed. The call is retried with
        the new authentication metadata only when both of these hold:
        - no response has been yielded yet;
        - ``request_iterator`` is a replayable ``list`` or ``tuple``.
        Otherwise the original error is re-raised, to avoid duplicate responses or a truncated request stream.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
            and wait_for_ready
        :param request_iterator: An iterator (or re-iterable sequence) of request messages to be sent to the server
        :return: An async iterator over the responses from the RPC call
        :raises: grpc.aio.AioRpcError (while iterating) if the call fails for reasons other than authentication, or
            if an authentication failure cannot be retried safely
        """
        updated_call_details, creds_id = await self._call_details_with_auth_metadata(client_call_details)
        # Start the RPC now; only the consumption of its responses is deferred to the generator below.
        call = await continuation(updated_call_details, request_iterator)

        async def response_iterator() -> typing.AsyncIterator[typing.Any]:
            yielded_any = False
            try:
                async for response in call:
                    yielded_any = True
                    yield response
            except grpc.aio.AioRpcError as e:
                if e.code() not in (grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.UNKNOWN):
                    raise
                await self.authenticator.refresh_credentials(creds_id=creds_id)
                if yielded_any or not isinstance(request_iterator, (list, tuple)):
                    raise
                retry_call_details, _ = await self._call_details_with_auth_metadata(client_call_details)
                async for response in await continuation(retry_call_details, request_iterator):
                    yield response

        return response_iterator()
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
# For backward compatibility, keep the original class name importable. This is a plain name binding to the
# unary-unary interceptor class (not a typing alias), so it only intercepts unary-unary calls.
AuthUnaryInterceptor = AuthUnaryUnaryInterceptor
|
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import grpc.aio
|
|
4
|
+
from grpc.aio import ClientCallDetails, Metadata
|
|
5
|
+
|
|
6
|
+
# Key/value pairs that the default-metadata interceptors below attach to every outgoing call.
_default_metadata = {"accept": "application/grpc"}
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class _BaseDefaultMetadataInterceptor:
    """
    Base class for all default metadata interceptors that provides common functionality.
    """

    async def _inject_default_metadata(self, call_details: grpc.aio.ClientCallDetails):
        """
        Injects default metadata into the client call details.

        Builds a new Metadata object containing the call's existing metadata (if any). It then adds each key/value
        pair from ``_default_metadata`` whose key is not already present, so explicitly set values are not
        duplicated. The incoming metadata object is left untouched.

        :param call_details: The client call details to inject metadata into
        :return: A new ClientCallDetails object with the injected metadata
        """
        # Copy instead of mutating the caller's metadata: call details may be reused (e.g. for a retry),
        # and in-place ``add`` would accumulate duplicate entries each time.
        metadata = Metadata(*(call_details.metadata or ()))
        for k, v in _default_metadata.items():
            if k not in metadata:
                metadata.add(k, v)

        return ClientCallDetails(
            method=call_details.method,
            timeout=call_details.timeout,
            metadata=metadata,
            credentials=call_details.credentials,
            wait_for_ready=call_details.wait_for_ready,
        )
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class DefaultMetadataUnaryUnaryInterceptor(_BaseDefaultMetadataInterceptor, grpc.aio.UnaryUnaryClientInterceptor):
    """
    Unary-unary client interceptor that attaches the module's default metadata to each call.
    """

    async def intercept_unary_unary(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request: typing.Any,
    ):
        """
        Attach default metadata, invoke the next handler in the chain and wait for its response.

        :param continuation: Callable that continues the RPC chain with the given call details and request
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
            and wait_for_ready
        :param request: The request message to be sent to the server
        :return: The response message obtained by awaiting the downstream call
        """
        details_with_defaults = await self._inject_default_metadata(client_call_details)
        pending_call = await continuation(details_with_defaults, request)
        response = await pending_call
        return response
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class DefaultMetadataUnaryStreamInterceptor(_BaseDefaultMetadataInterceptor, grpc.aio.UnaryStreamClientInterceptor):
    """
    Unary-stream client interceptor that attaches the module's default metadata to each call.
    """

    async def intercept_unary_stream(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request: typing.Any,
    ):
        """
        Attach default metadata and hand back the downstream streaming call.

        :param continuation: Callable that continues the RPC chain with the given call details and request
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
            and wait_for_ready
        :param request: The request message to be sent to the server
        :return: The call object returned by the continuation, whose responses the caller iterates
        """
        details_with_defaults = await self._inject_default_metadata(client_call_details)
        streaming_call = await continuation(details_with_defaults, request)
        return streaming_call
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class DefaultMetadataStreamUnaryInterceptor(_BaseDefaultMetadataInterceptor, grpc.aio.StreamUnaryClientInterceptor):
    """
    Stream-unary client interceptor that attaches the module's default metadata to each call.
    """

    async def intercept_stream_unary(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: typing.Any,
    ):
        """
        Attach default metadata, start the client-streaming call and wait for its single response.

        :param continuation: Callable that continues the RPC chain with the given call details and request source
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
            and wait_for_ready
        :param request_iterator: An iterator of request messages to be sent to the server
        :return: The response message obtained by awaiting the downstream call
        """
        details_with_defaults = await self._inject_default_metadata(client_call_details)
        pending_call = await continuation(details_with_defaults, request_iterator)
        response = await pending_call
        return response
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class DefaultMetadataStreamStreamInterceptor(_BaseDefaultMetadataInterceptor, grpc.aio.StreamStreamClientInterceptor):
    """
    Interceptor for stream-stream RPC calls that adds default metadata.
    """

    async def intercept_stream_stream(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: typing.Any,
    ):
        """
        Intercepts stream-stream calls and injects default metadata.

        This method adds default metadata to the client call details before continuing the RPC call chain. The
        resulting call object is returned as-is, the same way the unary-stream interceptor does. A stream-stream
        call is consumed by iterating it, not by awaiting it.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials, and
            wait_for_ready
        :param request_iterator: An iterator of request messages to be sent to the server
        :return: The call object returned by the continuation, whose responses the caller iterates
        """
        updated_call_details = await self._inject_default_metadata(client_call_details)
        return await continuation(updated_call_details, request_iterator)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# For backward compatibility, keep the original class name importable. This is a plain name binding to the
# unary-unary default-metadata interceptor class (not a typing alias), so it only intercepts unary-unary calls.
DefaultMetadataInterceptor = DefaultMetadataUnaryUnaryInterceptor
|