flyte 0.0.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +62 -0
- flyte/_api_commons.py +3 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +126 -0
- flyte/_build.py +25 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +146 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_common.py +287 -0
- flyte/_cli/_create.py +42 -0
- flyte/_cli/_delete.py +23 -0
- flyte/_cli/_deploy.py +140 -0
- flyte/_cli/_get.py +235 -0
- flyte/_cli/_run.py +152 -0
- flyte/_cli/main.py +72 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +339 -0
- flyte/_code_bundle/bundle.py +178 -0
- flyte/_context.py +146 -0
- flyte/_datastructures.py +342 -0
- flyte/_deploy.py +202 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +43 -0
- flyte/_group.py +31 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +760 -0
- flyte/_initialize.py +634 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +115 -0
- flyte/_internal/controllers/_local_controller.py +118 -0
- flyte/_internal/controllers/_trace.py +40 -0
- flyte/_internal/controllers/pbhash.py +39 -0
- flyte/_internal/controllers/remote/__init__.py +40 -0
- flyte/_internal/controllers/remote/_action.py +141 -0
- flyte/_internal/controllers/remote/_client.py +43 -0
- flyte/_internal/controllers/remote/_controller.py +361 -0
- flyte/_internal/controllers/remote/_core.py +402 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +416 -0
- flyte/_internal/imagebuild/image_builder.py +241 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +199 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +210 -0
- flyte/_internal/runtime/taskrunner.py +190 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +124 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +69 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +106 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +133 -0
- flyte/_protos/workflow/run_service_pb2.pyi +175 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +58 -0
- flyte/_protos/workflow/state_service_pb2.pyi +71 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +72 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +44 -0
- flyte/_protos/workflow/task_service_pb2.pyi +31 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +411 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +367 -0
- flyte/_task_environment.py +200 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +128 -0
- flyte/_utils/__init__.py +20 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/coro_management.py +25 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +108 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +143 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +273 -0
- flyte/io/__init__.py +11 -0
- flyte/io/_dataframe.py +0 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +468 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/io/pickle/transformer.py +117 -0
- flyte/io/structured_dataset/__init__.py +129 -0
- flyte/io/structured_dataset/basic_dfs.py +219 -0
- flyte/io/structured_dataset/structured_dataset.py +1061 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +25 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +131 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +184 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +155 -0
- flyte/remote/_logs.py +116 -0
- flyte/remote/_project.py +86 -0
- flyte/remote/_run.py +873 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +227 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +24 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +251 -0
- flyte/storage/_utils.py +5 -0
- flyte/types/__init__.py +13 -0
- flyte/types/_interface.py +25 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2210 -0
- flyte/types/_utils.py +80 -0
- flyte-0.0.1b0.dist-info/METADATA +179 -0
- flyte-0.0.1b0.dist-info/RECORD +390 -0
- flyte-0.0.1b0.dist-info/WHEEL +5 -0
- flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
- flyte-0.0.1b0.dist-info/top_level.txt +1 -0
- union/__init__.py +54 -0
- union/_api_commons.py +3 -0
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +113 -0
- union/_build.py +25 -0
- union/_cache/__init__.py +12 -0
- union/_cache/cache.py +141 -0
- union/_cache/defaults.py +9 -0
- union/_cache/policy_function_body.py +42 -0
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +263 -0
- union/_cli/_create.py +40 -0
- union/_cli/_delete.py +23 -0
- union/_cli/_deploy.py +120 -0
- union/_cli/_get.py +162 -0
- union/_cli/_params.py +579 -0
- union/_cli/_run.py +150 -0
- union/_cli/main.py +72 -0
- union/_code_bundle/__init__.py +8 -0
- union/_code_bundle/_ignore.py +113 -0
- union/_code_bundle/_packaging.py +187 -0
- union/_code_bundle/_utils.py +342 -0
- union/_code_bundle/bundle.py +176 -0
- union/_context.py +146 -0
- union/_datastructures.py +295 -0
- union/_deploy.py +185 -0
- union/_doc.py +29 -0
- union/_docstring.py +26 -0
- union/_environment.py +43 -0
- union/_group.py +31 -0
- union/_hash.py +23 -0
- union/_image.py +760 -0
- union/_initialize.py +585 -0
- union/_interface.py +84 -0
- union/_internal/__init__.py +3 -0
- union/_internal/controllers/__init__.py +77 -0
- union/_internal/controllers/_local_controller.py +77 -0
- union/_internal/controllers/pbhash.py +39 -0
- union/_internal/controllers/remote/__init__.py +40 -0
- union/_internal/controllers/remote/_action.py +131 -0
- union/_internal/controllers/remote/_client.py +43 -0
- union/_internal/controllers/remote/_controller.py +169 -0
- union/_internal/controllers/remote/_core.py +341 -0
- union/_internal/controllers/remote/_informer.py +260 -0
- union/_internal/controllers/remote/_service_protocol.py +44 -0
- union/_internal/imagebuild/__init__.py +11 -0
- union/_internal/imagebuild/docker_builder.py +416 -0
- union/_internal/imagebuild/image_builder.py +243 -0
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +31 -0
- union/_internal/resolvers/common.py +24 -0
- union/_internal/resolvers/default.py +27 -0
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +163 -0
- union/_internal/runtime/entrypoints.py +121 -0
- union/_internal/runtime/io.py +136 -0
- union/_internal/runtime/resources_serde.py +134 -0
- union/_internal/runtime/task_serde.py +202 -0
- union/_internal/runtime/taskrunner.py +179 -0
- union/_internal/runtime/types_serde.py +53 -0
- union/_logging.py +124 -0
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +66 -0
- union/_protos/common/authorization_pb2.pyi +106 -0
- union/_protos/common/authorization_pb2_grpc.py +4 -0
- union/_protos/common/identifier_pb2.py +71 -0
- union/_protos/common/identifier_pb2.pyi +82 -0
- union/_protos/common/identifier_pb2_grpc.py +4 -0
- union/_protos/common/identity_pb2.py +48 -0
- union/_protos/common/identity_pb2.pyi +72 -0
- union/_protos/common/identity_pb2_grpc.py +4 -0
- union/_protos/common/list_pb2.py +36 -0
- union/_protos/common/list_pb2.pyi +69 -0
- union/_protos/common/list_pb2_grpc.py +4 -0
- union/_protos/common/policy_pb2.py +37 -0
- union/_protos/common/policy_pb2.pyi +27 -0
- union/_protos/common/policy_pb2_grpc.py +4 -0
- union/_protos/common/role_pb2.py +37 -0
- union/_protos/common/role_pb2.pyi +51 -0
- union/_protos/common/role_pb2_grpc.py +4 -0
- union/_protos/common/runtime_version_pb2.py +28 -0
- union/_protos/common/runtime_version_pb2.pyi +24 -0
- union/_protos/common/runtime_version_pb2_grpc.py +4 -0
- union/_protos/logs/dataplane/payload_pb2.py +96 -0
- union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- union/_protos/secret/definition_pb2.py +49 -0
- union/_protos/secret/definition_pb2.pyi +93 -0
- union/_protos/secret/definition_pb2_grpc.py +4 -0
- union/_protos/secret/payload_pb2.py +62 -0
- union/_protos/secret/payload_pb2.pyi +94 -0
- union/_protos/secret/payload_pb2_grpc.py +4 -0
- union/_protos/secret/secret_pb2.py +38 -0
- union/_protos/secret/secret_pb2.pyi +6 -0
- union/_protos/secret/secret_pb2_grpc.py +198 -0
- union/_protos/validate/validate/validate_pb2.py +76 -0
- union/_protos/workflow/node_execution_service_pb2.py +26 -0
- union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- union/_protos/workflow/queue_service_pb2.py +75 -0
- union/_protos/workflow/queue_service_pb2.pyi +103 -0
- union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- union/_protos/workflow/run_definition_pb2.py +100 -0
- union/_protos/workflow/run_definition_pb2.pyi +256 -0
- union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/run_logs_service_pb2.py +41 -0
- union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- union/_protos/workflow/run_service_pb2.py +133 -0
- union/_protos/workflow/run_service_pb2.pyi +173 -0
- union/_protos/workflow/run_service_pb2_grpc.py +412 -0
- union/_protos/workflow/state_service_pb2.py +58 -0
- union/_protos/workflow/state_service_pb2.pyi +69 -0
- union/_protos/workflow/state_service_pb2_grpc.py +138 -0
- union/_protos/workflow/task_definition_pb2.py +72 -0
- union/_protos/workflow/task_definition_pb2.pyi +65 -0
- union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/task_service_pb2.py +44 -0
- union/_protos/workflow/task_service_pb2.pyi +31 -0
- union/_protos/workflow/task_service_pb2_grpc.py +104 -0
- union/_resources.py +226 -0
- union/_retry.py +32 -0
- union/_reusable_environment.py +25 -0
- union/_run.py +374 -0
- union/_secret.py +61 -0
- union/_task.py +354 -0
- union/_task_environment.py +186 -0
- union/_timeout.py +47 -0
- union/_tools.py +27 -0
- union/_utils/__init__.py +11 -0
- union/_utils/asyn.py +119 -0
- union/_utils/file_handling.py +71 -0
- union/_utils/helpers.py +46 -0
- union/_utils/lazy_module.py +54 -0
- union/_utils/uv_script_parser.py +49 -0
- union/_version.py +21 -0
- union/connectors/__init__.py +0 -0
- union/errors.py +128 -0
- union/extras/__init__.py +5 -0
- union/extras/_container.py +263 -0
- union/io/__init__.py +11 -0
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +425 -0
- union/io/_file.py +418 -0
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +117 -0
- union/io/structured_dataset/__init__.py +122 -0
- union/io/structured_dataset/basic_dfs.py +219 -0
- union/io/structured_dataset/structured_dataset.py +1057 -0
- union/py.typed +0 -0
- union/remote/__init__.py +23 -0
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +129 -0
- union/remote/_client/auth/__init__.py +12 -0
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +391 -0
- union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- union/remote/_client/auth/_authenticators/device_code.py +120 -0
- union/remote/_client/auth/_authenticators/external_command.py +77 -0
- union/remote/_client/auth/_authenticators/factory.py +200 -0
- union/remote/_client/auth/_authenticators/pkce.py +515 -0
- union/remote/_client/auth/_channel.py +184 -0
- union/remote/_client/auth/_client_config.py +83 -0
- union/remote/_client/auth/_default_html.py +32 -0
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
- union/remote/_client/auth/_keyring.py +154 -0
- union/remote/_client/auth/_token_client.py +258 -0
- union/remote/_client/auth/errors.py +16 -0
- union/remote/_client/controlplane.py +86 -0
- union/remote/_data.py +149 -0
- union/remote/_logs.py +74 -0
- union/remote/_project.py +86 -0
- union/remote/_run.py +820 -0
- union/remote/_secret.py +132 -0
- union/remote/_task.py +193 -0
- union/report/__init__.py +3 -0
- union/report/_report.py +178 -0
- union/report/_template.html +124 -0
- union/storage/__init__.py +24 -0
- union/storage/_remote_fs.py +34 -0
- union/storage/_storage.py +247 -0
- union/storage/_utils.py +5 -0
- union/types/__init__.py +11 -0
- union/types/_renderer.py +162 -0
- union/types/_string_literals.py +120 -0
- union/types/_type_engine.py +2131 -0
- union/types/_utils.py +80 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import typing
|
|
2
|
+
|
|
3
|
+
import grpc.aio
|
|
4
|
+
from grpc.aio import ClientCallDetails, Metadata
|
|
5
|
+
|
|
6
|
+
# Metadata merged into every outgoing call by _BaseDefaultMetadataInterceptor._inject_default_metadata.
_default_metadata = Metadata(("accept", "application/grpc"))
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def with_metadata(call_details: ClientCallDetails, new_metadata: Metadata) -> ClientCallDetails:
    """
    Return a copy of ``call_details`` whose metadata is the existing metadata followed by ``new_metadata``.

    Existing entries are kept and new entries are appended (duplicate keys are not replaced).
    ``call_details.metadata`` may be ``None`` (no metadata set yet) or any iterable of ``(key, value)`` pairs,
    such as a ``grpc.aio.Metadata`` or a tuple of tuples; ``new_metadata`` is treated the same way.

    :param call_details: The client call details to copy
    :param new_metadata: Extra ``(key, value)`` pairs to append; ``None`` adds nothing
    :return: A new ClientCallDetails with the merged metadata; every other field is copied unchanged
    """
    metadata = Metadata()
    # Iterating a grpc.aio.Metadata yields (key, value) tuples, and the same loop also accepts plain
    # tuple-of-pairs metadata. A missing (None) metadata is treated as empty instead of raising.
    for source in (call_details.metadata, new_metadata):
        for key, value in source or ():
            metadata.add(key=key, value=value)

    return ClientCallDetails(
        method=call_details.method,
        timeout=call_details.timeout,
        metadata=metadata,
        credentials=call_details.credentials,
        wait_for_ready=call_details.wait_for_ready,
    )
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class _BaseDefaultMetadataInterceptor:
    """
    Mixin shared by the per-RPC-type default-metadata interceptors in this module.

    Provides one helper that merges the module-level ``_default_metadata`` into outgoing call details.
    """

    async def _inject_default_metadata(self, call_details: grpc.aio.ClientCallDetails):
        """
        Build call details carrying the default metadata in addition to the call's own metadata.

        The merge itself is delegated to ``with_metadata``.

        :param call_details: Call details of the outgoing RPC
        :return: The new ClientCallDetails produced by ``with_metadata``
        """
        enriched_details = with_metadata(call_details, _default_metadata)
        return enriched_details
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class DefaultMetadataUnaryUnaryInterceptor(_BaseDefaultMetadataInterceptor, grpc.aio.UnaryUnaryClientInterceptor):
    """
    Unary-unary client interceptor that attaches the default metadata to every call.
    """

    async def intercept_unary_unary(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request: typing.Any,
    ):
        """
        Add the default metadata, forward the call down the chain, and wait for its result.

        :param continuation: Next step of the interceptor chain; invoked with the updated details and the request
        :param client_call_details: Method, timeout, metadata, credentials and wait_for_ready of the outgoing call
        :param request: Request message for the server
        :return: The RPC response message. The call object returned by ``continuation`` is awaited here
            rather than being handed back to the caller.
        """
        details = await self._inject_default_metadata(client_call_details)
        call = await continuation(details, request)
        response = await call
        return response
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class DefaultMetadataUnaryStreamInterceptor(_BaseDefaultMetadataInterceptor, grpc.aio.UnaryStreamClientInterceptor):
    """
    Unary-stream client interceptor that attaches the default metadata to every call.
    """

    async def intercept_unary_stream(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request: typing.Any,
    ):
        """
        Add the default metadata and forward the call down the chain.

        :param continuation: Next step of the interceptor chain; invoked with the updated details and the request
        :param client_call_details: Method, timeout, metadata, credentials and wait_for_ready of the outgoing call
        :param request: Request message for the server
        :return: Whatever ``continuation`` resolves to (the streaming call)
        """
        details = await self._inject_default_metadata(client_call_details)
        call = await continuation(details, request)
        return call
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class DefaultMetadataStreamUnaryInterceptor(_BaseDefaultMetadataInterceptor, grpc.aio.StreamUnaryClientInterceptor):
    """
    Stream-unary client interceptor that attaches the default metadata to every call.
    """

    async def intercept_stream_unary(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: typing.Any,
    ):
        """
        Add the default metadata and forward the call down the chain.

        :param continuation: Next step of the interceptor chain; invoked with the updated details and the
            request iterator
        :param client_call_details: Method, timeout, metadata, credentials and wait_for_ready of the outgoing call
        :param request_iterator: Iterator of request messages for the server
        :return: Whatever ``continuation`` resolves to (the call object)
        """
        details = await self._inject_default_metadata(client_call_details)
        call = await continuation(details, request_iterator)
        return call
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
class DefaultMetadataStreamStreamInterceptor(_BaseDefaultMetadataInterceptor, grpc.aio.StreamStreamClientInterceptor):
    """
    Stream-stream client interceptor that attaches the default metadata to every call.
    """

    async def intercept_stream_stream(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: typing.Any,
    ):
        """
        Add the default metadata and forward the call down the chain.

        :param continuation: Next step of the interceptor chain; invoked with the updated details and the
            request iterator
        :param client_call_details: Method, timeout, metadata, credentials and wait_for_ready of the outgoing call
        :param request_iterator: Iterator of request messages for the server
        :return: Whatever ``continuation`` resolves to (the streaming call)
        """
        details = await self._inject_default_metadata(client_call_details)
        call = await continuation(details, request_iterator)
        return call
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
# Backward-compatible alias: code that still imports ``DefaultMetadataInterceptor`` gets the
# unary-unary interceptor class (this is a plain name binding, not a typing alias).
DefaultMetadataInterceptor = DefaultMetadataUnaryUnaryInterceptor
|
|
@@ -0,0 +1,143 @@
|
|
|
1
|
+
import hashlib # Added import for hashing
|
|
2
|
+
import typing
|
|
3
|
+
from urllib.parse import urlparse # Added import
|
|
4
|
+
|
|
5
|
+
import keyring
|
|
6
|
+
import pydantic
|
|
7
|
+
from keyring.errors import NoKeyringError, PasswordDeleteError
|
|
8
|
+
|
|
9
|
+
from flyte._logging import logger
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def strip_scheme(url: str) -> str:
    """
    Remove the URL scheme, leaving host (plus port and path) for use as a keyring service name.

    Examples:
        - dns:///foo.com -> foo.com
        - https://foo.com -> foo.com
        - https://foo.com/blah -> foo.com/blah

    For a ``dns`` URL the path is returned with its leading slashes removed. For other URLs that have a
    network location, that location is returned followed by the path; the query string and fragment are
    dropped. Strings without a network location (for example a bare ``foo.com``) are returned unchanged.
    """
    parts = urlparse(url)
    if parts.scheme != "dns":
        if not parts.netloc:
            return url
        return parts.netloc + parts.path
    return parts.path.lstrip("/")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Credentials(pydantic.BaseModel):
    """
    Token bundle for a single endpoint, as written to and read from the keyring.

    When ``for_endpoint`` is supplied it is normalised with ``strip_scheme``. After validation, ``id`` is
    derived from the access token.
    """

    access_token: str
    # Endpoint the tokens belong to; KeyringStore uses it as the keyring service name.
    for_endpoint: str = "flyte-default"
    # Fingerprint of ``access_token`` (MD5 hex digest), filled in by ``compute_id``.
    id: str = ""
    refresh_token: str | None = None
    expires_in: int | None = None

    @pydantic.field_validator("for_endpoint", mode="after")
    @classmethod
    def validate_endpoint(cls, v: str) -> str:
        """Strip the URL scheme so the same endpoint always maps to the same keyring entry."""
        return strip_scheme(v)

    @pydantic.model_validator(mode="after")
    def compute_id(self) -> "Credentials":
        """
        Set ``id`` to the MD5 hex digest of ``access_token`` when a token is present.

        MD5 only fingerprints the token here; it is not suitable for anything security-sensitive. It is
        therefore computed with ``usedforsecurity=False``, without which ``hashlib.md5`` raises on FIPS-enabled
        Python builds. The digest value is identical either way.
        """
        if self.access_token:
            self.id = hashlib.md5(self.access_token.encode(), usedforsecurity=False).hexdigest()
        return self
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class KeyringStore:
    """
    Static helpers that persist Credentials in the system keyring.

    Entries are keyed by the scheme-stripped endpoint, which serves as the keyring service name. There is
    one entry for the access token and one for the refresh token; no ID token is stored. A missing keyring
    backend (``NoKeyringError``) is logged at debug level and otherwise tolerated.
    """

    _access_token_key = "access_token"
    _refresh_token_key = "refresh_token"

    @staticmethod
    def store(credentials: Credentials) -> Credentials:
        """
        Save the access token, and the refresh token when present, of ``credentials`` to the keyring.

        If no keyring backend is available, the failure is logged at debug level and nothing is cached.

        :param credentials: Credentials to persist; ``credentials.for_endpoint`` is used as the service name
        :return: ``credentials`` itself, unchanged
        """
        service = credentials.for_endpoint
        try:
            refresh = credentials.refresh_token
            if refresh:
                keyring.set_password(service, KeyringStore._refresh_token_key, refresh)
            keyring.set_password(service, KeyringStore._access_token_key, credentials.access_token)
        except NoKeyringError as e:
            logger.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
        return credentials

    @staticmethod
    def retrieve(for_endpoint: str) -> typing.Optional[Credentials]:
        """
        Load cached tokens for an endpoint from the keyring.

        The scheme is stripped from ``for_endpoint`` before lookup.

        :param for_endpoint: Endpoint URL (with or without scheme) whose tokens should be loaded
        :return: Credentials built from the stored access token and the (possibly ``None``) refresh token,
            with ``expires_in`` unset. Returns ``None`` when no access token is stored or no keyring backend
            is available.
        """
        service = strip_scheme(for_endpoint)
        try:
            refresh = keyring.get_password(service, KeyringStore._refresh_token_key)
            access = keyring.get_password(service, KeyringStore._access_token_key)
        except NoKeyringError as e:
            logger.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
            return None

        if access:
            return Credentials(
                access_token=access,
                refresh_token=refresh,
                for_endpoint=service,
                expires_in=None,
            )
        logger.debug("No access token found in keyring.")
        return None

    @staticmethod
    def delete(for_endpoint: str):
        """
        Remove the cached access and refresh tokens for an endpoint from the keyring.

        A missing entry (``PasswordDeleteError``) or a missing keyring backend (``NoKeyringError``) is logged
        at debug level and ignored, separately for each key.

        :param for_endpoint: Endpoint URL (with or without scheme) whose tokens should be removed
        """
        service = strip_scheme(for_endpoint)

        for key in (KeyringStore._access_token_key, KeyringStore._refresh_token_key):
            try:
                keyring.delete_password(service, key)
            except PasswordDeleteError as e:
                logger.debug(f"Key {key} not found in key store, Ignoring. Error: {e}")
            except NoKeyringError as e:
                logger.debug(f"KeyRing not available, Key {key} deletion failed. Error: {e}")
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import base64
|
|
3
|
+
import enum
|
|
4
|
+
import typing
|
|
5
|
+
import urllib.parse
|
|
6
|
+
from datetime import datetime, timedelta
|
|
7
|
+
|
|
8
|
+
import httpx
|
|
9
|
+
import pydantic
|
|
10
|
+
|
|
11
|
+
from flyte._logging import logger
|
|
12
|
+
from flyte.remote._client.auth.errors import AuthenticationError, AuthenticationPending
|
|
13
|
+
|
|
14
|
+
# Text encoding used by get_basic_authorization_header when building the HTTP basic-auth header.
utf_8 = "utf-8"

# "error" codes the token endpoint may return.
# NOTE(review): presumably compared against the token response while polling in the device-code flow
# (RFC 8628 section 3.5); the code that uses them is not visible in this chunk.
error_slow_down = "slow_down"
error_auth_pending = "authorization_pending"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# Grant Types
|
|
22
|
+
class GrantType(str, enum.Enum):
    """
    OAuth 2.0 ``grant_type`` values accepted by ``get_token``.

    Members subclass ``str``, so each compares equal to its raw string value.
    """

    CLIENT_CREDS = "client_credentials"
    DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"
    REFRESH_TOKEN = "refresh_token"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class DeviceCodeResponse(pydantic.BaseModel):
    """
    Response from the device authorization endpoint (OAuth 2.0 device flow, RFC 8628 section 3.2), e.g.::

        {
            'device_code': 'code',
            'user_code': 'BNDJJFXL',
            'verification_uri': 'url',
            'expires_in': 600,
            'interval': 5
        }

    Attributes:
        device_code (str): The device verification code.
        user_code (str): The user-facing code that should be entered on the verification page.
        verification_uri (str): The URL where the user should enter the user_code.
        expires_in (int): The lifetime in seconds of the device code and user code.
        interval (int): The minimum amount of time in seconds to wait between polling requests. RFC 8628 makes
            this field optional and requires clients to assume 5 seconds when it is absent.
    """

    device_code: str
    user_code: str
    verification_uri: str
    expires_in: int
    interval: int = 5

    @classmethod
    def from_json_response(cls, j: typing.Dict) -> "DeviceCodeResponse":
        """
        Create a DeviceCodeResponse instance from a JSON response dictionary.

        :param j: The parsed JSON body of the device authorization response
        :return: A new instance with values from the JSON response. ``interval`` defaults to 5 seconds when
            the server omits it.
        :raises KeyError: If device_code, user_code, verification_uri or expires_in is missing
        """
        return cls(
            device_code=j["device_code"],
            user_code=j["user_code"],
            verification_uri=j["verification_uri"],
            expires_in=j["expires_in"],
            # "interval" is OPTIONAL per RFC 8628; fall back to the 5 seconds the spec mandates.
            interval=j.get("interval", 5),
        )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def get_basic_authorization_header(client_id: str, client_secret: str) -> str:
    """
    Build an HTTP Basic ``Authorization`` header value from OAuth client credentials.

    The secret is first form-URL-encoded with ``urllib.parse.quote_plus`` so characters such as
    ``:`` or spaces cannot corrupt the ``id:secret`` pair; the client id is used as-is. The pair is
    base64-encoded and prefixed with ``Basic ``.

    :param client_id: The client ID for authentication
    :param client_secret: The client secret for authentication
    :return: The header value, e.g. ``"Basic aWQ6c2VjcmV0"`` for id ``id`` and secret ``secret``
    :rtype: str
    """
    credentials = f"{client_id}:{urllib.parse.quote_plus(client_secret)}"
    token = base64.b64encode(credentials.encode(utf_8)).decode(utf_8)
    return f"Basic {token}"
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
async def get_token(
    token_endpoint: str,
    http_session: httpx.AsyncClient,
    scopes: typing.Optional[typing.List[str]] = None,
    authorization_header: typing.Optional[str] = None,
    client_id: typing.Optional[str] = None,
    device_code: typing.Optional[str] = None,
    audience: typing.Optional[str] = None,
    grant_type: GrantType = GrantType.CLIENT_CREDS,
    http_proxy_url: typing.Optional[str] = None,
    verify: typing.Optional[typing.Union[bool, str]] = None,
    refresh_token: typing.Optional[str] = None,
) -> typing.Tuple[str, typing.Optional[str], int]:
    """
    Retrieve an access token from the specified OAuth2 token endpoint.

    The request is a form-encoded POST sent through ``http_session``; optional fields are only
    included in the body when a value is supplied.

    :param token_endpoint: The URL of the token endpoint
    :param http_session: HTTP client session used to send the request
    :param scopes: Optional list of scopes; joined with spaces into the ``scope`` form field
    :param authorization_header: Optional ``Authorization`` header value (e.g. a Basic header built
        from client credentials)
    :param client_id: Optional client ID, sent as the ``client_id`` form field
    :param device_code: Device code, sent when using the device code flow
    :param audience: Optional audience value to request
    :param grant_type: The OAuth grant type to use (default: CLIENT_CREDS)
    :param http_proxy_url: Accepted for API compatibility; not used by this function
    :param verify: Accepted for API compatibility; not used by this function
    :param refresh_token: Refresh token, sent when using the refresh token flow
    :return: A tuple of (access_token, refresh_token, expires_in). ``refresh_token`` is ``None`` when the
        endpoint did not issue one, which is the normal case for the client credentials grant.
    :raises AuthenticationPending: When the device code flow is still waiting for the user
        (the endpoint returned ``authorization_pending`` or ``slow_down``).
    :raises AuthenticationError: When the token endpoint rejects the request.
    """
    headers = {
        "Cache-Control": "no-cache",
        "Accept": "application/json",
        "Content-Type": "application/x-www-form-urlencoded",
    }
    if authorization_header:
        headers["Authorization"] = authorization_header

    body = {"grant_type": grant_type.value}
    if client_id:
        body["client_id"] = client_id
    if device_code:
        body["device_code"] = device_code
    if scopes is not None:
        body["scope"] = " ".join(s.strip("' ") for s in scopes).strip("[]'")
    if audience:
        body["audience"] = audience
    if refresh_token:
        body["refresh_token"] = refresh_token

    response = await http_session.post(token_endpoint, data=body, headers=headers)

    if not response.is_success:
        # Error bodies are normally JSON, but proxies / load balancers may answer with HTML or plain
        # text; a decode failure must not mask the real HTTP error reported below.
        try:
            j = response.json()
        except ValueError:
            j = None
        if isinstance(j, dict) and "error" in j:
            err = j["error"]
            if err == error_auth_pending or err == error_slow_down:
                raise AuthenticationPending(f"Token not yet available, try again in some time {err}")
        logger.error("Status Code ({}) received from IDP: {}".format(response.status_code, response.text))
        raise AuthenticationError("Status Code ({}) received from IDP: {}".format(response.status_code, response.text))

    j = response.json()
    new_refresh_token = j.get("refresh_token")
    if new_refresh_token is None:
        # Not an error: the client credentials grant SHOULD NOT issue refresh tokens (RFC 6749 §4.4.3),
        # so treating a missing refresh token as a failure would break that flow entirely.
        logger.debug("No refresh token received from the token endpoint")

    return j["access_token"], new_refresh_token, j["expires_in"]
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
async def get_device_code(
    device_auth_endpoint: str,
    client_id: str,
    http_session: httpx.AsyncClient,
    *,
    audience: typing.Optional[str] = None,
    scopes: typing.Optional[typing.List[str]] = None,
) -> DeviceCodeResponse:
    """
    Retrieve a device authentication code that can be used to authenticate the request using a browser
    on a separate device.

    :param device_auth_endpoint: The URL of the device authorization endpoint
    :param client_id: The client ID to use for authentication
    :param http_session: HTTP client session used to send the request
    :param audience: The audience value to request; omitted from the request when not set
    :param scopes: List of scopes to request; joined with spaces (an empty ``scope`` is sent when None)
    :return: An object containing the device code and related information
    :raises AuthenticationError: When device code retrieval fails
    """
    _scope = " ".join(s.strip("' ") for s in scopes).strip("[]'") if scopes is not None else ""
    payload = {"client_id": client_id, "scope": _scope}
    # httpx form-encodes a None value as an empty string ("audience="), which some IdPs reject as an
    # invalid audience; only send it when configured, matching get_token.
    if audience:
        payload["audience"] = audience
    resp = await http_session.post(device_auth_endpoint, data=payload)
    if not resp.is_success:
        # The error body may not be JSON (e.g. an HTML page from a proxy); fall back to raw text.
        try:
            reason = resp.json()
        except ValueError:
            reason = resp.text
        raise AuthenticationError(
            f"Unable to retrieve Device Authentication Code for {payload},"
            f" Status Code {resp.status_code} Reason {reason}"
        )
    return DeviceCodeResponse.from_json_response(resp.json())
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
async def poll_token_endpoint(
    resp: DeviceCodeResponse,
    *,
    token_endpoint: str,
    client_id: str,
    http_session: httpx.AsyncClient,
    audience: typing.Optional[str] = None,
    scopes: typing.Optional[typing.List[str]] = None,
    http_proxy_url: typing.Optional[str] = None,
    verify: typing.Optional[typing.Union[bool, str]] = None,
) -> typing.Tuple[str, typing.Optional[str], int]:
    """
    Poll the token endpoint until device authentication completes or the device code expires.

    Each attempt calls ``get_token`` with the device code grant. While the endpoint reports that
    authorization is pending, the function sleeps ``resp.interval`` seconds and retries; a ``slow_down``
    response permanently adds 5 seconds to that interval (RFC 8628 §3.5). Any other error is logged and
    re-raised. The expiry deadline is measured on the event loop's monotonic clock from when polling
    starts, so request latency counts against it.

    :param resp: The device code response from a previous call to get_device_code
    :param token_endpoint: The URL of the token endpoint
    :param client_id: The client ID to use for authentication
    :param http_session: HTTP client session used for every polling request
    :param audience: The audience value to request
    :param scopes: List of scopes to request
    :param http_proxy_url: Forwarded to get_token
    :param verify: Forwarded to get_token
    :return: A tuple containing (access_token, refresh_token, expires_in); refresh_token may be None
    :raises AuthenticationError: When authentication fails or the device code expires
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + resp.expires_in
    interval = resp.interval
    while loop.time() < deadline:
        try:
            access_token, refresh_token, expires_in = await get_token(
                token_endpoint,
                grant_type=GrantType.DEVICE_CODE,
                client_id=client_id,
                audience=audience,
                scopes=scopes,
                device_code=resp.device_code,
                http_proxy_url=http_proxy_url,
                verify=verify,
                http_session=http_session,
            )
            logger.debug(f"Authentication successful, access token received, expires in {expires_in} seconds")
            return access_token, refresh_token, expires_in
        except AuthenticationPending as e:
            # get_token only surfaces the endpoint's error code in the exception message. On
            # "slow_down" the client MUST increase its polling interval by 5 seconds for this and all
            # subsequent requests (RFC 8628 §3.5).
            if error_slow_down in str(e):
                interval += 5
        except Exception as e:
            logger.warning(f"Authentication failed, reason {e}")
            raise
        logger.debug(f"Authentication pending, ..., waiting for {interval} seconds")
        await asyncio.sleep(interval)
    raise AuthenticationError("Authentication failed!")
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
class AccessTokenNotFoundError(RuntimeError):
    """
    Raised when no access token is found or when refreshing the access token fails.
    """
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AuthenticationError(RuntimeError):
    """
    Generic failure while authenticating against the identity provider or the Flyte service.
    """
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AuthenticationPending(RuntimeError):
    """
    Signals that the token endpoint reported the authorization as still pending, so the caller
    should wait and try again.
    """
|
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import grpc
|
|
4
|
+
from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc
|
|
5
|
+
|
|
6
|
+
from flyte._protos.secret import secret_pb2_grpc
|
|
7
|
+
from flyte._protos.workflow import run_logs_service_pb2_grpc, run_service_pb2_grpc, task_service_pb2_grpc
|
|
8
|
+
|
|
9
|
+
from ._protocols import (
|
|
10
|
+
DataProxyService,
|
|
11
|
+
MetadataServiceProtocol,
|
|
12
|
+
ProjectDomainService,
|
|
13
|
+
RunLogsService,
|
|
14
|
+
RunService,
|
|
15
|
+
SecretService,
|
|
16
|
+
TaskService,
|
|
17
|
+
)
|
|
18
|
+
from .auth import create_channel
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ClientSet:
    """
    A bundle of gRPC service stubs that all share one async channel.

    Use :meth:`for_endpoint` to create the channel and the client set in one step, and :meth:`close`
    to close the shared channel.
    """

    def __init__(
        self,
        channel: grpc.aio.Channel,
        endpoint: str,
        insecure: bool = False,
        data_proxy_channel: grpc.aio.Channel | None = None,
        **kwargs,
    ):
        """
        :param channel: An already-created async gRPC channel; every stub is bound to it.
        :param endpoint: The endpoint the channel targets, stored as ``self.endpoint``.
        :param insecure: Whether the channel is insecure, stored as ``self.insecure``.
        :param data_proxy_channel: Currently unused; the data proxy stub is created on ``channel``.
        :param kwargs: Extra options; accepted and ignored here.
        """
        self.endpoint = endpoint
        self.insecure = insecure
        self._channel = channel
        # NOTE(review): data_proxy_channel is accepted but never used -- confirm whether the
        # DataProxyServiceStub below should be bound to it when provided.
        self._admin_client = admin_pb2_grpc.AdminServiceStub(channel=channel)
        self._task_service = task_service_pb2_grpc.TaskServiceStub(channel=channel)
        self._run_service = run_service_pb2_grpc.RunServiceStub(channel=channel)
        self._dataproxy = dataproxy_pb2_grpc.DataProxyServiceStub(channel=channel)
        self._log_service = run_logs_service_pb2_grpc.RunLogsServiceStub(channel=channel)
        self._secrets_service = secret_pb2_grpc.SecretServiceStub(channel=channel)

    @classmethod
    async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet:
        """
        Create a channel to ``endpoint`` via ``create_channel`` and wrap it in a ClientSet.

        When ``insecure`` is True, the authentication / proxy / retry options listed below are
        removed from ``kwargs`` before the channel is created; any of them may be absent.

        :param endpoint: The endpoint to connect to.
        :param insecure: Whether to create an insecure channel.
        :param kwargs: Options forwarded to ``create_channel`` and to the constructor.
        :return: A new ClientSet bound to the created channel.
        """
        if insecure:
            # pop(..., None) rather than del: callers should not have to supply every option just to
            # have it discarded (del raised KeyError for any key that was not passed).
            for key in (
                "api_key",
                "auth_type",
                "headless",
                "command",
                "client_id",
                "client_credentials_secret",
                "client_config",
                "rpc_retries",
                "http_proxy_url",
            ):
                kwargs.pop(key, None)
        return cls(await create_channel(endpoint, insecure=insecure, **kwargs), endpoint, insecure=insecure, **kwargs)

    @classmethod
    async def for_api_key(cls, api_key: str, **kwargs) -> ClientSet:
        """Not implemented yet."""
        raise NotImplementedError

    @classmethod
    async def for_serverless(cls) -> ClientSet:
        """Not implemented yet."""
        raise NotImplementedError

    @classmethod
    async def from_env(cls) -> ClientSet:
        """Not implemented yet."""
        raise NotImplementedError

    @property
    def metadata_service(self) -> MetadataServiceProtocol:
        # Served by the same AdminService stub as project_domain_service.
        return self._admin_client

    @property
    def project_domain_service(self) -> ProjectDomainService:
        # Served by the same AdminService stub as metadata_service.
        return self._admin_client

    @property
    def task_service(self) -> TaskService:
        return self._task_service

    @property
    def run_service(self) -> RunService:
        return self._run_service

    @property
    def dataproxy_service(self) -> DataProxyService:
        return self._dataproxy

    @property
    def logs_service(self) -> RunLogsService:
        return self._log_service

    @property
    def secrets_service(self) -> SecretService:
        return self._secrets_service

    async def close(self, grace: float | None = None):
        """Close the shared gRPC channel, passing ``grace`` through to ``Channel.close``."""
        return await self._channel.close(grace=grace)
|
flyte/remote/_console.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from urllib.parse import urlparse
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def _get_http_domain(endpoint: str, insecure: bool) -> str:
    """
    Derive the console base URL (``scheme://host[:port]``) from an API endpoint.

    Accepted endpoint forms:
      - ``dns:///host[:port]`` (gRPC-style target): the host is taken from the URL path.
      - ``scheme://host[:port]``: the host is taken from the network location.
      - bare ``host[:port]``: used as-is.

    Any ``localhost`` endpoint is mapped to ``localhost:8080`` regardless of its port.

    :param endpoint: The configured API endpoint.
    :param insecure: If True the URL uses ``http``, otherwise ``https``.
    :return: The console base URL, without a trailing slash.
    """
    scheme = "http" if insecure else "https"
    if "://" not in endpoint and not endpoint.startswith("dns:"):
        # urlparse only recognises a netloc after "//": for "host:port" it reads "host" as the
        # scheme, and for a bare "host" it puts everything in .path -- either way .netloc is empty.
        endpoint = f"//{endpoint}"
    parsed = urlparse(endpoint)
    if parsed.scheme == "dns":
        domain = parsed.path.lstrip("/")
    else:
        domain = parsed.netloc
    # TODO: make console url configurable
    if domain.split(":")[0] == "localhost":
        domain = "localhost:8080"
    return f"{scheme}://{domain}"
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_run_url(endpoint: str, insecure: bool, project: str, domain: str, run_name: str) -> str:
    """
    Build the console URL for a single run.

    :param endpoint: The configured API endpoint; converted to the console base URL by ``_get_http_domain``.
    :param insecure: Whether to use ``http`` instead of ``https``.
    :param project: Project the run belongs to.
    :param domain: Domain the run belongs to.
    :param run_name: Name of the run.
    :return: ``<base>/v2/runs/project/<project>/domain/<domain>/<run_name>``
    """
    base_url = _get_http_domain(endpoint, insecure)
    run_path = f"/v2/runs/project/{project}/domain/{domain}/{run_name}"
    return base_url + run_path
|