flyte 0.0.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +62 -0
- flyte/_api_commons.py +3 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +126 -0
- flyte/_build.py +25 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +146 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_common.py +287 -0
- flyte/_cli/_create.py +42 -0
- flyte/_cli/_delete.py +23 -0
- flyte/_cli/_deploy.py +140 -0
- flyte/_cli/_get.py +235 -0
- flyte/_cli/_run.py +152 -0
- flyte/_cli/main.py +72 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +339 -0
- flyte/_code_bundle/bundle.py +178 -0
- flyte/_context.py +146 -0
- flyte/_datastructures.py +342 -0
- flyte/_deploy.py +202 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +43 -0
- flyte/_group.py +31 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +760 -0
- flyte/_initialize.py +634 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +115 -0
- flyte/_internal/controllers/_local_controller.py +118 -0
- flyte/_internal/controllers/_trace.py +40 -0
- flyte/_internal/controllers/pbhash.py +39 -0
- flyte/_internal/controllers/remote/__init__.py +40 -0
- flyte/_internal/controllers/remote/_action.py +141 -0
- flyte/_internal/controllers/remote/_client.py +43 -0
- flyte/_internal/controllers/remote/_controller.py +361 -0
- flyte/_internal/controllers/remote/_core.py +402 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +416 -0
- flyte/_internal/imagebuild/image_builder.py +241 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +199 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +210 -0
- flyte/_internal/runtime/taskrunner.py +190 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +124 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +69 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +106 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +133 -0
- flyte/_protos/workflow/run_service_pb2.pyi +175 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +58 -0
- flyte/_protos/workflow/state_service_pb2.pyi +71 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +72 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +44 -0
- flyte/_protos/workflow/task_service_pb2.pyi +31 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +411 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +367 -0
- flyte/_task_environment.py +200 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +128 -0
- flyte/_utils/__init__.py +20 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/coro_management.py +25 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +108 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +143 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +273 -0
- flyte/io/__init__.py +11 -0
- flyte/io/_dataframe.py +0 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +468 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/io/pickle/transformer.py +117 -0
- flyte/io/structured_dataset/__init__.py +129 -0
- flyte/io/structured_dataset/basic_dfs.py +219 -0
- flyte/io/structured_dataset/structured_dataset.py +1061 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +25 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +131 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +184 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +155 -0
- flyte/remote/_logs.py +116 -0
- flyte/remote/_project.py +86 -0
- flyte/remote/_run.py +873 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +227 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +24 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +251 -0
- flyte/storage/_utils.py +5 -0
- flyte/types/__init__.py +13 -0
- flyte/types/_interface.py +25 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2210 -0
- flyte/types/_utils.py +80 -0
- flyte-0.0.1b0.dist-info/METADATA +179 -0
- flyte-0.0.1b0.dist-info/RECORD +390 -0
- flyte-0.0.1b0.dist-info/WHEEL +5 -0
- flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
- flyte-0.0.1b0.dist-info/top_level.txt +1 -0
- union/__init__.py +54 -0
- union/_api_commons.py +3 -0
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +113 -0
- union/_build.py +25 -0
- union/_cache/__init__.py +12 -0
- union/_cache/cache.py +141 -0
- union/_cache/defaults.py +9 -0
- union/_cache/policy_function_body.py +42 -0
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +263 -0
- union/_cli/_create.py +40 -0
- union/_cli/_delete.py +23 -0
- union/_cli/_deploy.py +120 -0
- union/_cli/_get.py +162 -0
- union/_cli/_params.py +579 -0
- union/_cli/_run.py +150 -0
- union/_cli/main.py +72 -0
- union/_code_bundle/__init__.py +8 -0
- union/_code_bundle/_ignore.py +113 -0
- union/_code_bundle/_packaging.py +187 -0
- union/_code_bundle/_utils.py +342 -0
- union/_code_bundle/bundle.py +176 -0
- union/_context.py +146 -0
- union/_datastructures.py +295 -0
- union/_deploy.py +185 -0
- union/_doc.py +29 -0
- union/_docstring.py +26 -0
- union/_environment.py +43 -0
- union/_group.py +31 -0
- union/_hash.py +23 -0
- union/_image.py +760 -0
- union/_initialize.py +585 -0
- union/_interface.py +84 -0
- union/_internal/__init__.py +3 -0
- union/_internal/controllers/__init__.py +77 -0
- union/_internal/controllers/_local_controller.py +77 -0
- union/_internal/controllers/pbhash.py +39 -0
- union/_internal/controllers/remote/__init__.py +40 -0
- union/_internal/controllers/remote/_action.py +131 -0
- union/_internal/controllers/remote/_client.py +43 -0
- union/_internal/controllers/remote/_controller.py +169 -0
- union/_internal/controllers/remote/_core.py +341 -0
- union/_internal/controllers/remote/_informer.py +260 -0
- union/_internal/controllers/remote/_service_protocol.py +44 -0
- union/_internal/imagebuild/__init__.py +11 -0
- union/_internal/imagebuild/docker_builder.py +416 -0
- union/_internal/imagebuild/image_builder.py +243 -0
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +31 -0
- union/_internal/resolvers/common.py +24 -0
- union/_internal/resolvers/default.py +27 -0
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +163 -0
- union/_internal/runtime/entrypoints.py +121 -0
- union/_internal/runtime/io.py +136 -0
- union/_internal/runtime/resources_serde.py +134 -0
- union/_internal/runtime/task_serde.py +202 -0
- union/_internal/runtime/taskrunner.py +179 -0
- union/_internal/runtime/types_serde.py +53 -0
- union/_logging.py +124 -0
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +66 -0
- union/_protos/common/authorization_pb2.pyi +106 -0
- union/_protos/common/authorization_pb2_grpc.py +4 -0
- union/_protos/common/identifier_pb2.py +71 -0
- union/_protos/common/identifier_pb2.pyi +82 -0
- union/_protos/common/identifier_pb2_grpc.py +4 -0
- union/_protos/common/identity_pb2.py +48 -0
- union/_protos/common/identity_pb2.pyi +72 -0
- union/_protos/common/identity_pb2_grpc.py +4 -0
- union/_protos/common/list_pb2.py +36 -0
- union/_protos/common/list_pb2.pyi +69 -0
- union/_protos/common/list_pb2_grpc.py +4 -0
- union/_protos/common/policy_pb2.py +37 -0
- union/_protos/common/policy_pb2.pyi +27 -0
- union/_protos/common/policy_pb2_grpc.py +4 -0
- union/_protos/common/role_pb2.py +37 -0
- union/_protos/common/role_pb2.pyi +51 -0
- union/_protos/common/role_pb2_grpc.py +4 -0
- union/_protos/common/runtime_version_pb2.py +28 -0
- union/_protos/common/runtime_version_pb2.pyi +24 -0
- union/_protos/common/runtime_version_pb2_grpc.py +4 -0
- union/_protos/logs/dataplane/payload_pb2.py +96 -0
- union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- union/_protos/secret/definition_pb2.py +49 -0
- union/_protos/secret/definition_pb2.pyi +93 -0
- union/_protos/secret/definition_pb2_grpc.py +4 -0
- union/_protos/secret/payload_pb2.py +62 -0
- union/_protos/secret/payload_pb2.pyi +94 -0
- union/_protos/secret/payload_pb2_grpc.py +4 -0
- union/_protos/secret/secret_pb2.py +38 -0
- union/_protos/secret/secret_pb2.pyi +6 -0
- union/_protos/secret/secret_pb2_grpc.py +198 -0
- union/_protos/validate/validate/validate_pb2.py +76 -0
- union/_protos/workflow/node_execution_service_pb2.py +26 -0
- union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- union/_protos/workflow/queue_service_pb2.py +75 -0
- union/_protos/workflow/queue_service_pb2.pyi +103 -0
- union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- union/_protos/workflow/run_definition_pb2.py +100 -0
- union/_protos/workflow/run_definition_pb2.pyi +256 -0
- union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/run_logs_service_pb2.py +41 -0
- union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- union/_protos/workflow/run_service_pb2.py +133 -0
- union/_protos/workflow/run_service_pb2.pyi +173 -0
- union/_protos/workflow/run_service_pb2_grpc.py +412 -0
- union/_protos/workflow/state_service_pb2.py +58 -0
- union/_protos/workflow/state_service_pb2.pyi +69 -0
- union/_protos/workflow/state_service_pb2_grpc.py +138 -0
- union/_protos/workflow/task_definition_pb2.py +72 -0
- union/_protos/workflow/task_definition_pb2.pyi +65 -0
- union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/task_service_pb2.py +44 -0
- union/_protos/workflow/task_service_pb2.pyi +31 -0
- union/_protos/workflow/task_service_pb2_grpc.py +104 -0
- union/_resources.py +226 -0
- union/_retry.py +32 -0
- union/_reusable_environment.py +25 -0
- union/_run.py +374 -0
- union/_secret.py +61 -0
- union/_task.py +354 -0
- union/_task_environment.py +186 -0
- union/_timeout.py +47 -0
- union/_tools.py +27 -0
- union/_utils/__init__.py +11 -0
- union/_utils/asyn.py +119 -0
- union/_utils/file_handling.py +71 -0
- union/_utils/helpers.py +46 -0
- union/_utils/lazy_module.py +54 -0
- union/_utils/uv_script_parser.py +49 -0
- union/_version.py +21 -0
- union/connectors/__init__.py +0 -0
- union/errors.py +128 -0
- union/extras/__init__.py +5 -0
- union/extras/_container.py +263 -0
- union/io/__init__.py +11 -0
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +425 -0
- union/io/_file.py +418 -0
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +117 -0
- union/io/structured_dataset/__init__.py +122 -0
- union/io/structured_dataset/basic_dfs.py +219 -0
- union/io/structured_dataset/structured_dataset.py +1057 -0
- union/py.typed +0 -0
- union/remote/__init__.py +23 -0
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +129 -0
- union/remote/_client/auth/__init__.py +12 -0
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +391 -0
- union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- union/remote/_client/auth/_authenticators/device_code.py +120 -0
- union/remote/_client/auth/_authenticators/external_command.py +77 -0
- union/remote/_client/auth/_authenticators/factory.py +200 -0
- union/remote/_client/auth/_authenticators/pkce.py +515 -0
- union/remote/_client/auth/_channel.py +184 -0
- union/remote/_client/auth/_client_config.py +83 -0
- union/remote/_client/auth/_default_html.py +32 -0
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
- union/remote/_client/auth/_keyring.py +154 -0
- union/remote/_client/auth/_token_client.py +258 -0
- union/remote/_client/auth/errors.py +16 -0
- union/remote/_client/controlplane.py +86 -0
- union/remote/_data.py +149 -0
- union/remote/_logs.py +74 -0
- union/remote/_project.py +86 -0
- union/remote/_run.py +820 -0
- union/remote/_secret.py +132 -0
- union/remote/_task.py +193 -0
- union/report/__init__.py +3 -0
- union/report/_report.py +178 -0
- union/report/_template.html +124 -0
- union/storage/__init__.py +24 -0
- union/storage/_remote_fs.py +34 -0
- union/storage/_storage.py +247 -0
- union/storage/_utils.py +5 -0
- union/types/__init__.py +11 -0
- union/types/_renderer.py +162 -0
- union/types/_string_literals.py +120 -0
- union/types/_type_engine.py +2131 -0
- union/types/_utils.py +80 -0
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
import ssl
|
|
2
|
+
import typing
|
|
3
|
+
|
|
4
|
+
import grpc
|
|
5
|
+
import grpc.aio
|
|
6
|
+
import httpx
|
|
7
|
+
from grpc.experimental.aio import init_grpc_aio
|
|
8
|
+
|
|
9
|
+
from flyte._logging import logger
|
|
10
|
+
|
|
11
|
+
from ._authenticators.base import get_async_session
|
|
12
|
+
from ._authenticators.factory import (
|
|
13
|
+
create_auth_interceptors,
|
|
14
|
+
create_proxy_auth_interceptors,
|
|
15
|
+
get_async_proxy_authenticator,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
# Initialize gRPC AIO early enough so it can be used in the main thread.
# This runs as a module-level side effect, i.e. once when this module is first imported,
# before any channel is created by the functions below.
init_grpc_aio()
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def bootstrap_ssl_from_server(endpoint: str) -> grpc.ChannelCredentials:
    """
    Retrieves the SSL certificate from the remote server and creates gRPC channel credentials.

    This function should be used only when insecure-skip-verify is enabled. It extracts the server address
    and port from the endpoint, retrieves the SSL certificate presented by the server, and creates
    gRPC channel credentials that trust exactly that certificate.

    The certificate retrieval (``ssl.get_server_certificate``) is blocking network I/O and runs on the
    calling thread; async callers should offload it (e.g. with ``asyncio.to_thread``).

    :param endpoint: ``host``, ``host:port`` or ``[ipv6-address]:port``. When no numeric port suffix is
        found, a warning is logged and port 443 is used.
    :return: gRPC channel credentials created from the retrieved certificate
    """
    # Split on the LAST colon so that only a trailing numeric suffix is treated as the port.
    host, separator, port = endpoint.rpartition(":")
    if separator and port.isdigit():
        server_host, server_port = host, int(port)
    else:
        logger.warning(f"Unrecognized port in endpoint [{endpoint}], defaulting to 443.")
        server_host, server_port = endpoint, 443

    # Bracketed IPv6 literals ("[::1]") must be unwrapped before being handed to the socket layer,
    # which expects the bare address.
    if server_host.startswith("[") and server_host.endswith("]"):
        server_host = server_host[1:-1]

    # Blocking: opens a TLS connection to fetch the server's PEM-encoded certificate.
    cert = ssl.get_server_certificate((server_host, server_port))
    return grpc.ssl_channel_credentials(cert.encode())
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
async def create_channel(
    endpoint: str,
    insecure: typing.Optional[bool] = None,
    insecure_skip_verify: typing.Optional[bool] = False,
    ca_cert_file_path: typing.Optional[str] = None,
    ssl_credentials: typing.Optional[grpc.ChannelCredentials] = None,
    grpc_options: typing.Optional[typing.Sequence[typing.Tuple[str, typing.Any]]] = None,
    compression: typing.Optional[grpc.Compression] = None,
    http_session: httpx.AsyncClient | None = None,
    proxy_command: typing.List[str] | None = None,
    **kwargs,
) -> grpc.aio.Channel:
    """
    Creates a new gRPC channel with appropriate authentication interceptors.

    This function creates either a secure or insecure gRPC channel based on the provided parameters,
    and adds metadata, proxy-auth and auth interceptors to the channel. For secure channels, if SSL
    credentials are not provided they are created based on the insecure_skip_verify and
    ca_cert_file_path parameters.

    The function is async because it may read certificate files asynchronously, fetch the server
    certificate off the event loop, and create authentication interceptors that perform async operations.

    :param endpoint: The endpoint URL for the gRPC channel
    :param insecure: Whether to use an insecure channel (no SSL)
    :param insecure_skip_verify: Whether to skip SSL certificate verification (the server's own
        certificate is fetched and trusted). Ignored for insecure channels.
    :param ca_cert_file_path: Path to CA certificate file for SSL verification
    :param ssl_credentials: Pre-configured SSL credentials for the channel
    :param grpc_options: Additional gRPC channel options
    :param compression: Compression method for the channel
    :param http_session: Pre-configured HTTP session to use for requests
    :param proxy_command: List of strings for proxy command configuration
    :param kwargs: Additional arguments, forwarded to get_async_proxy_authenticator (when proxy_command is
        set and no http_session is given), get_async_session (when no http_session is given),
        create_proxy_auth_interceptors and create_auth_interceptors. In addition:
        - options / compression: used for an insecure channel only when grpc_options / compression are
          not given explicitly.
        - For proxy configuration:
            - proxy_env: Dict of environment variables for proxy
            - proxy_timeout: Timeout for proxy connection
        - For authentication interceptors (passed to create_auth_interceptors and create_proxy_auth_interceptors):
            - auth_type: The authentication type to use ("Pkce", "ClientSecret", "ExternalCommand", "DeviceFlow")
            - command: Command to execute for ExternalCommand authentication
            - client_id: Client ID for ClientSecret authentication
            - client_secret: Client secret for ClientSecret authentication
            - client_credentials_secret: Client secret for ClientSecret authentication (alias)
            - scopes: List of scopes to request during authentication
            - audience: Audience for the token
            - http_proxy_url: HTTP proxy URL
            - verify: Whether to verify SSL certificates
            - ca_cert_path: Optional path to CA certificate file
            - header_key: Header key to use for authentication
            - redirect_uri: OAuth2 redirect URI for PKCE authentication
            - add_request_auth_code_params_to_request_access_token_params: Whether to add auth code params to
              token request
            - request_auth_code_params: Parameters to add to login URI opened in browser
            - request_access_token_params: Parameters to add when exchanging auth code for access token
            - refresh_access_token_params: Parameters to add when refreshing access token
    :return: grpc.aio.Channel with authentication interceptors configured
    """
    import asyncio

    # TLS credentials are only needed for a secure channel. Skipping this for insecure channels avoids
    # attempting a TLS handshake against a plaintext server when insecure_skip_verify is also set.
    if not insecure and not ssl_credentials:
        if insecure_skip_verify:
            # bootstrap_ssl_from_server does blocking socket I/O; keep it off the event loop.
            ssl_credentials = await asyncio.to_thread(bootstrap_ssl_from_server, endpoint)
        elif ca_cert_file_path:
            import aiofiles

            async with aiofiles.open(ca_cert_file_path, "rb") as f:
                # aiofiles' read() is a coroutine; it must be awaited to obtain the bytes.
                st_cert = await f.read()
            ssl_credentials = grpc.ssl_channel_credentials(st_cert)
        else:
            ssl_credentials = grpc.ssl_channel_credentials()

    # grpc.aio.insecure_channel only accepts options/compression/interceptors, so the authentication
    # settings carried in **kwargs must not be forwarded to it.
    insecure_channel_kwargs: typing.Dict[str, typing.Any] = {
        "options": grpc_options if grpc_options is not None else kwargs.get("options"),
        "compression": compression if compression is not None else kwargs.get("compression"),
    }

    # Create an unauthenticated channel first to use to get the server metadata
    if insecure:
        unauthenticated_channel = grpc.aio.insecure_channel(endpoint, **insecure_channel_kwargs)
    else:
        unauthenticated_channel = grpc.aio.secure_channel(
            target=endpoint,
            credentials=ssl_credentials,
            options=grpc_options,
            compression=compression,
        )

    from ._grpc_utils.default_metadata_interceptor import (
        DefaultMetadataStreamStreamInterceptor,
        DefaultMetadataStreamUnaryInterceptor,
        DefaultMetadataUnaryStreamInterceptor,
        DefaultMetadataUnaryUnaryInterceptor,
    )

    # Add all types of default metadata interceptors
    interceptors: typing.List[grpc.aio.ClientInterceptor] = [
        DefaultMetadataUnaryUnaryInterceptor(),
        DefaultMetadataUnaryStreamInterceptor(),
        DefaultMetadataStreamUnaryInterceptor(),
        DefaultMetadataStreamStreamInterceptor(),
    ]

    # Create an HTTP session if not provided so we share the same http client across the stack
    if not http_session:
        proxy_authenticator = None
        if proxy_command:
            proxy_authenticator = get_async_proxy_authenticator(
                endpoint=endpoint, proxy_command=proxy_command, **kwargs
            )

        http_session = get_async_session(
            ca_cert_file_path=ca_cert_file_path, proxy_authenticator=proxy_authenticator, **kwargs
        )

    # Get proxy auth interceptors
    proxy_auth_interceptors = create_proxy_auth_interceptors(endpoint, http_session=http_session, **kwargs)
    interceptors.extend(proxy_auth_interceptors)

    # Get auth interceptors
    auth_interceptors = create_auth_interceptors(
        endpoint=endpoint,
        in_channel=unauthenticated_channel,
        insecure=insecure,
        insecure_skip_verify=insecure_skip_verify,
        ca_cert_file_path=ca_cert_file_path,
        http_session=http_session,
        **kwargs,
    )

    interceptors.extend(auth_interceptors)

    if insecure:
        return grpc.aio.insecure_channel(endpoint, interceptors=interceptors, **insecure_channel_kwargs)

    return grpc.aio.secure_channel(
        target=endpoint,
        credentials=ssl_credentials,
        options=grpc_options,
        compression=compression,
        interceptors=interceptors,
    )
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import typing
from abc import ABC, abstractmethod

import grpc.aio
import pydantic
from flyteidl.service.auth_pb2 import OAuth2MetadataRequest, PublicClientAuthConfigRequest
from flyteidl.service.auth_pb2_grpc import AuthMetadataServiceStub
|
|
8
|
+
|
|
9
|
+
# Names of the supported authentication flows. Being a typing.Literal, this is only checked by
# static type checkers, not at runtime.
AuthType = typing.Literal["ClientSecret", "Pkce", "ExternalCommand", "DeviceFlow"]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ClientConfig(pydantic.BaseModel):
    """
    Client configuration consumed by authenticators: OAuth2 endpoints, client identity, requested
    scopes, the metadata header that carries the token, and the token audience.
    """

    token_endpoint: str
    authorization_endpoint: str
    redirect_uri: str
    client_id: str
    device_authorization_endpoint: typing.Optional[str] = None
    scopes: typing.Optional[typing.List[str]] = None
    header_key: str = "authorization"
    audience: typing.Optional[str] = None

    def with_override(self, other: "ClientConfig") -> "ClientConfig":
        """
        Build a merged config in which ``other`` takes precedence.

        For every field, ``other``'s value is used when it is truthy; otherwise this instance's value is
        kept (so None, empty strings and empty lists in ``other`` never override). Neither input is
        modified; a fresh ClientConfig is returned.
        """
        field_names = (
            "token_endpoint",
            "authorization_endpoint",
            "redirect_uri",
            "client_id",
            "device_authorization_endpoint",
            "scopes",
            "header_key",
            "audience",
        )
        merged = {name: getattr(other, name) or getattr(self, name) for name in field_names}
        return ClientConfig(**merged)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ClientConfigStore(ABC):
    """
    Abstract source of ClientConfig. Concrete stores decide where the configuration comes from
    (for example a fixed value, or a server-side metadata service).

    Subclasses must implement ``get_client_config``; because this class derives from ``ABC``, a subclass
    that does not implement it cannot be instantiated.
    """

    @abstractmethod
    async def get_client_config(self) -> ClientConfig:
        """Return the ClientConfig to be used by authenticators."""
        raise NotImplementedError
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class StaticClientConfigStore(ClientConfigStore):
    """
    ClientConfigStore that always yields the ClientConfig it was constructed with.
    """

    def __init__(self, cfg: ClientConfig):
        # Stored by reference (no copy), so callers receive this exact instance back.
        self._cfg = cfg

    async def get_client_config(self) -> ClientConfig:
        """Return the configuration supplied at construction time."""
        stored_config = self._cfg
        return stored_config
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class RemoteClientConfigStore(ClientConfigStore):
    """
    This class implements the ClientConfigStore that is served by the Flyte Server, that implements AuthMetadataService.
    The server is queried over the provided (unauthenticated) channel.
    """

    def __init__(self, unauthenticated_channel: grpc.aio.Channel):
        self._unauthenticated_channel = unauthenticated_channel

    async def get_client_config(self) -> ClientConfig:
        """
        Retrieves the ClientConfig from the given grpc.aio.Channel assuming AuthMetadataService is available.

        Issues GetPublicClientConfig and then GetOAuth2Metadata, and combines both responses into a
        single ClientConfig. RPC errors from either call propagate to the caller.
        """
        metadata_service = AuthMetadataServiceStub(self._unauthenticated_channel)
        public_client_config = await metadata_service.GetPublicClientConfig(PublicClientAuthConfigRequest())
        oauth2_metadata = await metadata_service.GetOAuth2Metadata(OAuth2MetadataRequest())
        return ClientConfig(
            token_endpoint=oauth2_metadata.token_endpoint,
            authorization_endpoint=oauth2_metadata.authorization_endpoint,
            redirect_uri=public_client_config.redirect_uri,
            client_id=public_client_config.client_id,
            # `scopes` is a protobuf repeated field (a container, not a list). Convert it explicitly so the
            # List[str] model field validates under any pydantic version and no longer references the message.
            scopes=list(public_client_config.scopes),
            header_key=public_client_config.authorization_metadata_key,
            device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint,
            audience=public_client_config.audience,
        )
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import textwrap
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def get_default_success_html(endpoint: str) -> str:
    """
    Get default success html: the static page shown in the browser after a successful OAuth2 login.

    :param endpoint: Accepted by the signature but not used in the page; the template below is a plain
        string (not an f-string), so nothing is interpolated into it.
    :return: The HTML page markup, passed through textwrap.dedent.
    """
    return textwrap.dedent(
        """
<html>
<head>
<title>OAuth2 Authentication to Union Successful</title>
</head>
<body style="background:white;font-family:Arial">
<div style="position: absolute;top:40%;left:50%;transform: translate(-50%, -50%);text-align:center;">
<div style="margin:auto">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 300 65" fill="currentColor"
style="color:#fdb51e;width:360px;">
<title>Union.ai</title>
<path d="M32,64.8C14.4,64.8,0,51.5,0,34V3.6h17.6v41.3c0,1.9,1.1,3,3,3h23c1.9,0,3-1.1,3-3V3.6H64V34
C64,51.5,49.6,64.8,32,64.8z M69.9,30.9v30.4h17.6V20c0-1.9,1.1-3,3-3h23c1.9,0,3,1.1,3,3v41.3H134V30.9c0-17.5-14.4-30.8-32.1-30.8
S69.9,13.5,69.9,30.9z M236,30.9v30.4h17.6V20c0-1.9,1.1-3,3-3h23c1.9,0,3,1.1,3,3v41.3H300V30.9c0-17.5-14.4-30.8-32-30.8
S236,13.5,236,30.9L236,30.9z M230.1,32.4c0,18.2-14.2,32.5-32.2,32.5s-32-14.3-32-32.5s14-32.1,32-32.1S230.1,14.3,230.1,32.4
L230.1,32.4z M213.5,20.2c0-1.9-1.1-3-3-3h-24.8c-1.9,0-3,1.1-3,3v24.5c0,1.9,1.1,3,3,3h24.8c1.9,0,3-1.1,3-3V20.2z M158.9,3.6
h-17.6v57.8h17.6V3.6z"></path>
</svg>
<h2>You've successfully authenticated to Union!</h2>
<p style="font-size:20px;">Return to your terminal for next steps</p>
</div>
</div>
</body>
</html>
""" # noqa: E501
    )
|
|
File without changes
|
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
from typing import AsyncIterator, Optional, Union
|
|
3
|
+
|
|
4
|
+
import grpc.aio
|
|
5
|
+
from grpc.aio import ClientCallDetails, Metadata
|
|
6
|
+
from grpc.aio._typing import DoneCallbackType, EOFType, RequestType, ResponseType
|
|
7
|
+
|
|
8
|
+
from flyte.remote._client.auth._authenticators.base import Authenticator
|
|
9
|
+
from flyte.remote._client.auth._grpc_utils.default_metadata_interceptor import with_metadata
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class _BaseAuthInterceptor:
    """
    Base class for all auth interceptors that provides common authentication functionality.

    The :class:`Authenticator` is created lazily: the constructor only stores the factory. The factory is
    called on the first access to :attr:`authenticator`, and the result is cached on the instance.
    """

    def __init__(self, get_authenticator: typing.Callable[[], Authenticator]):
        """
        :param get_authenticator: Zero-argument factory that builds the authenticator on first use
        """
        self._get_authenticator = get_authenticator
        self._authenticator: typing.Optional[Authenticator] = None

    @property
    def authenticator(self) -> Authenticator:
        """The authenticator, built by the factory on first access and reused afterwards."""
        if self._authenticator is None:
            self._authenticator = self._get_authenticator()
        return self._authenticator

    async def call_details_with_auth_metadata(
        self, client_call_details: grpc.aio.ClientCallDetails
    ) -> typing.Tuple[grpc.aio.ClientCallDetails, str]:
        """
        Attach authentication metadata to the call details and report which credentials produced it.

        This method retrieves authentication metadata from the authenticator and merges it into the
        client call details via ``with_metadata``. If no authentication metadata is available, the
        original client call details are returned unchanged.

        :param client_call_details: The original client call details containing method, timeout, metadata,
            credentials, and wait_for_ready settings
        :return: A ``(call_details, creds_id)`` tuple:

            - ``call_details`` is the details with the auth metadata added, or the untouched input when
              the authenticator produced no metadata.
            - ``creds_id`` identifies the credentials that were used, or ``""`` when none were.

            The interceptors pass ``creds_id`` to ``refresh_credentials`` when a call is rejected.
        """
        auth_metadata = await self.authenticator.get_grpc_call_auth_metadata()
        if not auth_metadata:
            return client_call_details, ""
        return with_metadata(client_call_details, auth_metadata.pairs), auth_metadata.creds_id
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class AuthUnaryUnaryInterceptor(_BaseAuthInterceptor, grpc.aio.UnaryUnaryClientInterceptor):
    """
    Unary-unary client interceptor that attaches authentication metadata to every call.
    """

    async def intercept_unary_unary(
        self,
        continuation: typing.Callable,
        client_call_details: ClientCallDetails,
        request: typing.Any,
    ):
        """
        Send a unary-unary call with auth metadata, retrying once with refreshed credentials if it is rejected.

        The call is first made with the current auth metadata. If it fails with UNAUTHENTICATED or UNKNOWN,
        the credentials that were used are refreshed, auth metadata is rebuilt from the original call details,
        and the call is made one more time. Any other failure is re-raised unchanged.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
            and wait_for_ready
        :param request: The request message to be sent to the server
        :return: The response from the RPC call after successful authentication
        :raises: grpc.aio.AioRpcError if the call fails for reasons other than authentication, or if the retry fails
        """
        authed_details, used_creds_id = await self.call_details_with_auth_metadata(client_call_details)
        try:
            first_call = await continuation(authed_details, request)
            return await first_call
        except grpc.aio.AioRpcError as err:
            if err.code() not in (grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.UNKNOWN):
                raise
            await self.authenticator.refresh_credentials(creds_id=used_creds_id)
            # Rebuild from the caller's original details, not from `authed_details`, which already
            # includes the metadata of the rejected credentials.
            retry_details, _ = await self.call_details_with_auth_metadata(client_call_details)
            retry_call = await continuation(retry_details, request)
            return await retry_call
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class UnaryStreamCall(grpc.aio.UnaryStreamCall):
    """
    Call object returned by the streaming auth interceptors.

    The underlying RPC is only started when this object is iterated (``async for``):

    - The first attempt uses the current auth metadata.
    - If the stream fails with UNAUTHENTICATED or UNKNOWN, the credentials are refreshed and the RPC is
      issued once more with freshly built metadata.

    Until iteration starts there is no underlying call, so the accessors return neutral defaults
    (``grpc.aio.EOF``, empty metadata, ``StatusCode.OK``, ``False``/``None``).

    NOTE(review): known limitations of the retry, to confirm against callers:

    - It re-issues the request from scratch, so responses yielded before the failure may be yielded again.
    - A request iterator (stream-stream) that the first attempt partly consumed is reused as-is.
    - This class has no ``write``/``done_writing``, so stream-stream callers must supply a request iterator.
    """

    def __init__(
        self,
        parent_interceptor: _BaseAuthInterceptor,
        authenticator: Authenticator,
        continuation: typing.Callable,
        call_details: grpc.aio.ClientCallDetails,
        request: RequestType,
    ):
        """
        :param parent_interceptor: Interceptor used to attach auth metadata to ``call_details``
        :param authenticator: Authenticator whose credentials are refreshed when the stream is rejected
        :param continuation: Next step of the interceptor chain; awaited to start the underlying call
        :param call_details: Call details as received by the interceptor, before auth metadata is added
        :param request: The request message, or the request iterator for stream-stream calls
        """
        super().__init__()
        self._continuation = continuation
        self._call_details = call_details
        self._request = request
        self._authenticator = authenticator
        self._parent_interceptor = parent_interceptor
        # The live underlying call; stays None until response_iterator() starts it.
        self._call: (
            Union[
                grpc.aio.UnaryStreamCall[RequestType, ResponseType],
                grpc.aio.StreamStreamCall[RequestType, ResponseType],
            ]
            | None
        ) = None

    async def response_iterator(self) -> typing.AsyncIterator[ResponseType]:
        """
        Start the RPC and yield its responses, retrying once with refreshed credentials on auth failure.

        :raises grpc.aio.AioRpcError: if the stream fails with a status other than UNAUTHENTICATED/UNKNOWN,
            or if the retried stream fails
        """
        call_details, creds_id = await self._parent_interceptor.call_details_with_auth_metadata(self._call_details)
        self._call = await self._continuation(call_details, self._request)
        try:
            async for response in self._call:
                yield response
        except grpc.aio.AioRpcError as e:
            if e.code() not in (grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.UNKNOWN):
                raise
            await self._authenticator.refresh_credentials(creds_id=creds_id)
            # Rebuild from the caller's original details: `call_details` already carries the metadata of the
            # rejected credentials, so re-wrapping it would stack the stale metadata onto the retry.
            retry_details, _ = await self._parent_interceptor.call_details_with_auth_metadata(self._call_details)
            self._call = await self._continuation(retry_details, self._request)
            async for response in self._call:
                yield response

    def __aiter__(self) -> AsyncIterator[ResponseType]:
        """Return a fresh response iterator; each iteration starts a new RPC."""
        return self.response_iterator()

    async def read(self) -> Union[EOFType, ResponseType]:
        """Read one response from the underlying call, or ``grpc.aio.EOF`` if no call has been started."""
        if self._call is not None:
            return await self._call.read()
        # gRPC consumers detect end-of-stream by identity (`resp is grpc.aio.EOF`), so the singleton must be
        # returned; a newly constructed EOFType() instance would not match.
        return grpc.aio.EOF

    async def initial_metadata(self) -> Metadata:
        """Initial metadata of the underlying call, or empty metadata if no call has been started."""
        if self._call is not None:
            return await self._call.initial_metadata()
        return Metadata()

    async def trailing_metadata(self) -> Metadata:
        """Trailing metadata of the underlying call, or empty metadata if no call has been started."""
        if self._call is not None:
            return await self._call.trailing_metadata()
        return Metadata()

    async def code(self) -> grpc.StatusCode:
        """Status code of the underlying call, or ``StatusCode.OK`` if no call has been started."""
        if self._call is not None:
            return await self._call.code()
        return grpc.StatusCode.OK

    async def details(self) -> str:
        """Status details of the underlying call, or ``""`` if no call has been started."""
        if self._call is not None:
            return await self._call.details()
        return ""

    async def wait_for_connection(self) -> None:
        """Wait for the underlying call's connection; a no-op if no call has been started."""
        if self._call is not None:
            await self._call.wait_for_connection()
        return None

    def cancelled(self) -> bool:
        """Whether the underlying call was cancelled; ``False`` if no call has been started."""
        if self._call is not None:
            return self._call.cancelled()
        return False

    def done(self) -> bool:
        """Whether the underlying call has finished; ``False`` if no call has been started."""
        if self._call is not None:
            return self._call.done()
        return False

    def time_remaining(self) -> Optional[float]:
        """Time remaining on the underlying call's deadline; ``None`` if no call has been started."""
        if self._call is not None:
            return self._call.time_remaining()
        return None

    def cancel(self) -> bool:
        """Cancel the underlying call; returns ``False`` if no call has been started."""
        if self._call is not None:
            return self._call.cancel()
        return False

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Register a done-callback on the underlying call; ignored if no call has been started."""
        if self._call is not None:
            self._call.add_done_callback(callback=callback)
        return None
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
class AuthUnaryStreamInterceptor(_BaseAuthInterceptor, grpc.aio.UnaryStreamClientInterceptor):
    """
    Unary-stream client interceptor that attaches authentication metadata to every call.
    """

    async def intercept_unary_stream(
        self, continuation: typing.Callable, client_call_details: grpc.aio.ClientCallDetails, request: typing.Any
    ):
        """
        Wrap a unary-stream call so it is sent with auth metadata and retried once on auth failure.

        No RPC is started here. The returned :class:`UnaryStreamCall` starts it when iterated. If the stream
        fails with UNAUTHENTICATED or UNKNOWN, the wrapper refreshes the credentials and re-issues the call.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
            and wait_for_ready
        :param request: The request message to be sent to the server
        :return: A UnaryStreamCall yielding the stream of responses
        :raises: grpc.aio.AioRpcError (during iteration) if the call fails for reasons other than authentication
        """
        wrapped_call = UnaryStreamCall(
            parent_interceptor=self,
            authenticator=self.authenticator,
            call_details=client_call_details,
            continuation=continuation,
            request=request,
        )
        return wrapped_call
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
class AuthStreamUnaryInterceptor(_BaseAuthInterceptor, grpc.aio.StreamUnaryClientInterceptor):
    """
    Stream-unary client interceptor that attaches authentication metadata to every call.
    """

    async def intercept_stream_unary(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: typing.Any,
    ):
        """
        Send a stream-unary call with auth metadata, retrying once with refreshed credentials if it is rejected.

        The call is first made with the current auth metadata. If it fails with UNAUTHENTICATED or UNKNOWN,
        the credentials that were used are refreshed, auth metadata is rebuilt from the original call details,
        and the call is made one more time with the same ``request_iterator``.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
            and wait_for_ready
        :param request_iterator: An iterator of request messages to be sent to the server
        :return: The response from the RPC call after successful authentication
        :raises: grpc.aio.AioRpcError if the call fails for reasons other than authentication, or if the retry fails
        """
        first_details, first_creds_id = await self.call_details_with_auth_metadata(client_call_details)
        try:
            return await (await continuation(first_details, request_iterator))
        except grpc.aio.AioRpcError as exc:
            is_auth_failure = exc.code() in (grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.UNKNOWN)
            if not is_auth_failure:
                raise
            await self.authenticator.refresh_credentials(creds_id=first_creds_id)
            retry_details, _ = await self.call_details_with_auth_metadata(client_call_details)
            # NOTE(review): the same request_iterator is handed to the retry. If the first attempt already
            # consumed part of it, the retried call only sees the remainder. Confirm that callers pass a
            # re-iterable source or that rejection happens before any request is read.
            return await (await continuation(retry_details, request_iterator))
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
class AuthStreamStreamInterceptor(_BaseAuthInterceptor, grpc.aio.StreamStreamClientInterceptor):
    """
    Stream-stream client interceptor that attaches authentication metadata to every call.
    """

    async def intercept_stream_stream(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: typing.Any,
    ):
        """
        Wrap a stream-stream call so it is sent with auth metadata and retried once on auth failure.

        This reuses the :class:`UnaryStreamCall` wrapper, passing ``request_iterator`` as its request. No RPC
        is started here; iterating the returned object starts it. If the stream fails with UNAUTHENTICATED or
        UNKNOWN, the credentials are refreshed and the call is re-issued.

        NOTE(review): the wrapper exposes no ``write``/``done_writing``, so this only supports callers that
        pass a request iterator. The retry also reuses the same, possibly partly consumed, iterator.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
            and wait_for_ready
        :param request_iterator: An iterator of request messages to be sent to the server
        :return: A stream of responses from the RPC call after successful authentication
        """
        wrapped_call = UnaryStreamCall(
            parent_interceptor=self,
            authenticator=self.authenticator,
            call_details=client_call_details,
            continuation=continuation,
            request=request_iterator,
        )
        return wrapped_call
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
# Backward-compatible name: ``AuthUnaryInterceptor`` is bound to the very same class object as
# ``AuthUnaryUnaryInterceptor`` (not a typing alias). Code using the old name keeps working, including
# isinstance checks and subclassing.
AuthUnaryInterceptor = AuthUnaryUnaryInterceptor
|