flyte 0.0.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +62 -0
- flyte/_api_commons.py +3 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +126 -0
- flyte/_build.py +25 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +146 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_common.py +287 -0
- flyte/_cli/_create.py +42 -0
- flyte/_cli/_delete.py +23 -0
- flyte/_cli/_deploy.py +140 -0
- flyte/_cli/_get.py +235 -0
- flyte/_cli/_run.py +152 -0
- flyte/_cli/main.py +72 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +339 -0
- flyte/_code_bundle/bundle.py +178 -0
- flyte/_context.py +146 -0
- flyte/_datastructures.py +342 -0
- flyte/_deploy.py +202 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +43 -0
- flyte/_group.py +31 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +760 -0
- flyte/_initialize.py +634 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +115 -0
- flyte/_internal/controllers/_local_controller.py +118 -0
- flyte/_internal/controllers/_trace.py +40 -0
- flyte/_internal/controllers/pbhash.py +39 -0
- flyte/_internal/controllers/remote/__init__.py +40 -0
- flyte/_internal/controllers/remote/_action.py +141 -0
- flyte/_internal/controllers/remote/_client.py +43 -0
- flyte/_internal/controllers/remote/_controller.py +361 -0
- flyte/_internal/controllers/remote/_core.py +402 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +416 -0
- flyte/_internal/imagebuild/image_builder.py +241 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +199 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +210 -0
- flyte/_internal/runtime/taskrunner.py +190 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +124 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +69 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +106 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +133 -0
- flyte/_protos/workflow/run_service_pb2.pyi +175 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +58 -0
- flyte/_protos/workflow/state_service_pb2.pyi +71 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +72 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +44 -0
- flyte/_protos/workflow/task_service_pb2.pyi +31 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +411 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +367 -0
- flyte/_task_environment.py +200 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +128 -0
- flyte/_utils/__init__.py +20 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/coro_management.py +25 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +108 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +143 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +273 -0
- flyte/io/__init__.py +11 -0
- flyte/io/_dataframe.py +0 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +468 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/io/pickle/transformer.py +117 -0
- flyte/io/structured_dataset/__init__.py +129 -0
- flyte/io/structured_dataset/basic_dfs.py +219 -0
- flyte/io/structured_dataset/structured_dataset.py +1061 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +25 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +131 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +184 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +155 -0
- flyte/remote/_logs.py +116 -0
- flyte/remote/_project.py +86 -0
- flyte/remote/_run.py +873 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +227 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +24 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +251 -0
- flyte/storage/_utils.py +5 -0
- flyte/types/__init__.py +13 -0
- flyte/types/_interface.py +25 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2210 -0
- flyte/types/_utils.py +80 -0
- flyte-0.0.1b0.dist-info/METADATA +179 -0
- flyte-0.0.1b0.dist-info/RECORD +390 -0
- flyte-0.0.1b0.dist-info/WHEEL +5 -0
- flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
- flyte-0.0.1b0.dist-info/top_level.txt +1 -0
- union/__init__.py +54 -0
- union/_api_commons.py +3 -0
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +113 -0
- union/_build.py +25 -0
- union/_cache/__init__.py +12 -0
- union/_cache/cache.py +141 -0
- union/_cache/defaults.py +9 -0
- union/_cache/policy_function_body.py +42 -0
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +263 -0
- union/_cli/_create.py +40 -0
- union/_cli/_delete.py +23 -0
- union/_cli/_deploy.py +120 -0
- union/_cli/_get.py +162 -0
- union/_cli/_params.py +579 -0
- union/_cli/_run.py +150 -0
- union/_cli/main.py +72 -0
- union/_code_bundle/__init__.py +8 -0
- union/_code_bundle/_ignore.py +113 -0
- union/_code_bundle/_packaging.py +187 -0
- union/_code_bundle/_utils.py +342 -0
- union/_code_bundle/bundle.py +176 -0
- union/_context.py +146 -0
- union/_datastructures.py +295 -0
- union/_deploy.py +185 -0
- union/_doc.py +29 -0
- union/_docstring.py +26 -0
- union/_environment.py +43 -0
- union/_group.py +31 -0
- union/_hash.py +23 -0
- union/_image.py +760 -0
- union/_initialize.py +585 -0
- union/_interface.py +84 -0
- union/_internal/__init__.py +3 -0
- union/_internal/controllers/__init__.py +77 -0
- union/_internal/controllers/_local_controller.py +77 -0
- union/_internal/controllers/pbhash.py +39 -0
- union/_internal/controllers/remote/__init__.py +40 -0
- union/_internal/controllers/remote/_action.py +131 -0
- union/_internal/controllers/remote/_client.py +43 -0
- union/_internal/controllers/remote/_controller.py +169 -0
- union/_internal/controllers/remote/_core.py +341 -0
- union/_internal/controllers/remote/_informer.py +260 -0
- union/_internal/controllers/remote/_service_protocol.py +44 -0
- union/_internal/imagebuild/__init__.py +11 -0
- union/_internal/imagebuild/docker_builder.py +416 -0
- union/_internal/imagebuild/image_builder.py +243 -0
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +31 -0
- union/_internal/resolvers/common.py +24 -0
- union/_internal/resolvers/default.py +27 -0
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +163 -0
- union/_internal/runtime/entrypoints.py +121 -0
- union/_internal/runtime/io.py +136 -0
- union/_internal/runtime/resources_serde.py +134 -0
- union/_internal/runtime/task_serde.py +202 -0
- union/_internal/runtime/taskrunner.py +179 -0
- union/_internal/runtime/types_serde.py +53 -0
- union/_logging.py +124 -0
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +66 -0
- union/_protos/common/authorization_pb2.pyi +106 -0
- union/_protos/common/authorization_pb2_grpc.py +4 -0
- union/_protos/common/identifier_pb2.py +71 -0
- union/_protos/common/identifier_pb2.pyi +82 -0
- union/_protos/common/identifier_pb2_grpc.py +4 -0
- union/_protos/common/identity_pb2.py +48 -0
- union/_protos/common/identity_pb2.pyi +72 -0
- union/_protos/common/identity_pb2_grpc.py +4 -0
- union/_protos/common/list_pb2.py +36 -0
- union/_protos/common/list_pb2.pyi +69 -0
- union/_protos/common/list_pb2_grpc.py +4 -0
- union/_protos/common/policy_pb2.py +37 -0
- union/_protos/common/policy_pb2.pyi +27 -0
- union/_protos/common/policy_pb2_grpc.py +4 -0
- union/_protos/common/role_pb2.py +37 -0
- union/_protos/common/role_pb2.pyi +51 -0
- union/_protos/common/role_pb2_grpc.py +4 -0
- union/_protos/common/runtime_version_pb2.py +28 -0
- union/_protos/common/runtime_version_pb2.pyi +24 -0
- union/_protos/common/runtime_version_pb2_grpc.py +4 -0
- union/_protos/logs/dataplane/payload_pb2.py +96 -0
- union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- union/_protos/secret/definition_pb2.py +49 -0
- union/_protos/secret/definition_pb2.pyi +93 -0
- union/_protos/secret/definition_pb2_grpc.py +4 -0
- union/_protos/secret/payload_pb2.py +62 -0
- union/_protos/secret/payload_pb2.pyi +94 -0
- union/_protos/secret/payload_pb2_grpc.py +4 -0
- union/_protos/secret/secret_pb2.py +38 -0
- union/_protos/secret/secret_pb2.pyi +6 -0
- union/_protos/secret/secret_pb2_grpc.py +198 -0
- union/_protos/validate/validate/validate_pb2.py +76 -0
- union/_protos/workflow/node_execution_service_pb2.py +26 -0
- union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- union/_protos/workflow/queue_service_pb2.py +75 -0
- union/_protos/workflow/queue_service_pb2.pyi +103 -0
- union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- union/_protos/workflow/run_definition_pb2.py +100 -0
- union/_protos/workflow/run_definition_pb2.pyi +256 -0
- union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/run_logs_service_pb2.py +41 -0
- union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- union/_protos/workflow/run_service_pb2.py +133 -0
- union/_protos/workflow/run_service_pb2.pyi +173 -0
- union/_protos/workflow/run_service_pb2_grpc.py +412 -0
- union/_protos/workflow/state_service_pb2.py +58 -0
- union/_protos/workflow/state_service_pb2.pyi +69 -0
- union/_protos/workflow/state_service_pb2_grpc.py +138 -0
- union/_protos/workflow/task_definition_pb2.py +72 -0
- union/_protos/workflow/task_definition_pb2.pyi +65 -0
- union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/task_service_pb2.py +44 -0
- union/_protos/workflow/task_service_pb2.pyi +31 -0
- union/_protos/workflow/task_service_pb2_grpc.py +104 -0
- union/_resources.py +226 -0
- union/_retry.py +32 -0
- union/_reusable_environment.py +25 -0
- union/_run.py +374 -0
- union/_secret.py +61 -0
- union/_task.py +354 -0
- union/_task_environment.py +186 -0
- union/_timeout.py +47 -0
- union/_tools.py +27 -0
- union/_utils/__init__.py +11 -0
- union/_utils/asyn.py +119 -0
- union/_utils/file_handling.py +71 -0
- union/_utils/helpers.py +46 -0
- union/_utils/lazy_module.py +54 -0
- union/_utils/uv_script_parser.py +49 -0
- union/_version.py +21 -0
- union/connectors/__init__.py +0 -0
- union/errors.py +128 -0
- union/extras/__init__.py +5 -0
- union/extras/_container.py +263 -0
- union/io/__init__.py +11 -0
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +425 -0
- union/io/_file.py +418 -0
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +117 -0
- union/io/structured_dataset/__init__.py +122 -0
- union/io/structured_dataset/basic_dfs.py +219 -0
- union/io/structured_dataset/structured_dataset.py +1057 -0
- union/py.typed +0 -0
- union/remote/__init__.py +23 -0
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +129 -0
- union/remote/_client/auth/__init__.py +12 -0
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +391 -0
- union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- union/remote/_client/auth/_authenticators/device_code.py +120 -0
- union/remote/_client/auth/_authenticators/external_command.py +77 -0
- union/remote/_client/auth/_authenticators/factory.py +200 -0
- union/remote/_client/auth/_authenticators/pkce.py +515 -0
- union/remote/_client/auth/_channel.py +184 -0
- union/remote/_client/auth/_client_config.py +83 -0
- union/remote/_client/auth/_default_html.py +32 -0
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
- union/remote/_client/auth/_keyring.py +154 -0
- union/remote/_client/auth/_token_client.py +258 -0
- union/remote/_client/auth/errors.py +16 -0
- union/remote/_client/controlplane.py +86 -0
- union/remote/_data.py +149 -0
- union/remote/_logs.py +74 -0
- union/remote/_project.py +86 -0
- union/remote/_run.py +820 -0
- union/remote/_secret.py +132 -0
- union/remote/_task.py +193 -0
- union/report/__init__.py +3 -0
- union/report/_report.py +178 -0
- union/report/_template.html +124 -0
- union/storage/__init__.py +24 -0
- union/storage/_remote_fs.py +34 -0
- union/storage/_storage.py +247 -0
- union/storage/_utils.py +5 -0
- union/types/__init__.py +11 -0
- union/types/_renderer.py +162 -0
- union/types/_string_literals.py +120 -0
- union/types/_type_engine.py +2131 -0
- union/types/_utils.py +80 -0
|
@@ -0,0 +1,515 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import base64
|
|
5
|
+
import hashlib
|
|
6
|
+
import os
|
|
7
|
+
import re
|
|
8
|
+
import threading
|
|
9
|
+
import time
|
|
10
|
+
import typing
|
|
11
|
+
import webbrowser
|
|
12
|
+
from http import HTTPStatus as _StatusCodes
|
|
13
|
+
from queue import Queue
|
|
14
|
+
from urllib import parse as _urlparse
|
|
15
|
+
from urllib.parse import urlencode as _urlencode
|
|
16
|
+
|
|
17
|
+
import click
|
|
18
|
+
import httpx
|
|
19
|
+
import pydantic
|
|
20
|
+
|
|
21
|
+
from union._logging import logger
|
|
22
|
+
from union.remote._client.auth._authenticators.base import Authenticator
|
|
23
|
+
from union.remote._client.auth._default_html import get_default_success_html
|
|
24
|
+
from union.remote._client.auth._keyring import Credentials
|
|
25
|
+
from union.remote._client.auth.errors import AccessTokenNotFoundError
|
|
26
|
+
|
|
27
|
+
# Name of the text encoding; presumably used by the PKCE helper functions defined
# further down in this module (not visible here) — confirm.
_utf_8: str = "utf-8"
# Length setting for the PKCE code verifier; presumably consumed by _generate_code_verifier — confirm.
_code_verifier_length: int = 64
# Length setting for random seed material; presumably used when generating state/verifier values — confirm.
_random_seed_length: int = 40
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class PKCEAuthenticator(Authenticator):
    """
    This Authenticator encapsulates the entire PKCE flow and automatically opens a browser window for login.

    For Auth0 - you will need to manually configure your config.yaml to include a scopes list of the syntax:
    admin.scopes: ["offline_access", "offline", "all", "openid"] and/or similar scopes in order to get the refresh
    token + caching. Otherwise, it will just receive the access token alone. Your FlyteCTL Helm config however should
    only contain ["offline", "all"] - as OIDC scopes are not grantable in Auth0 customer APIs. They are simply requested
    for in the POST request during the token caching process.
    """

    def __init__(
        self,
        **kwargs,
    ):
        """
        Initialize with default creds from KeyStore using the endpoint name

        :param kwargs: Keyword arguments passed to the base Authenticator

        **Keyword Arguments passed to base Authenticator**:
        :param endpoint: The endpoint URL for authentication (required)
        :param cfg_store: Optional client configuration store for retrieving remote configuration
        :param client_config: Optional client configuration containing authentication settings
        :param credentials: Optional credentials to use for authentication
        :param http_session: Optional HTTP session to use for requests
        :param http_proxy_url: Optional HTTP proxy URL
        :param verify: Whether to verify SSL certificates (default: True)
        :param ca_cert_path: Optional path to CA certificate file
        :param client_id: Client ID for authentication
        :param scopes: List of scopes to request during authentication
        :param audience: Audience for the token
        :param redirect_uri: OAuth2 redirect URI for authentication
        :param authorization_endpoint: Authorization endpoint for OAuth2 flow
        :param token_endpoint: Token endpoint for OAuth2 flow
        :param add_request_auth_code_params_to_request_access_token_params: Whether to add auth code params to token
            request
        :param request_auth_code_params: Parameters to add to login URI opened in browser
        :param request_access_token_params: Parameters to add when exchanging auth code for access token
        :param refresh_access_token_params: Parameters to add when refreshing access token
        """
        super().__init__(**kwargs)
        # Created lazily on first use by _initialize_auth_client().
        self._auth_client = None

    async def _initialize_auth_client(self):
        """
        Lazily create the AuthorizationClient that drives the PKCE flow.

        On the first call this generates a code verifier and its code challenge (sent with method "S256"),
        resolves the OAuth2 settings via ``self._resolve_config()`` and builds the client. Once the client
        exists, later calls do nothing.
        """
        if self._auth_client:
            return

        code_verifier = await _generate_code_verifier()
        code_challenge = await _create_code_challenge(code_verifier)

        cfg = await self._resolve_config()
        # NOTE(review): AuthorizationClient accepts ca_cert_path, but it is not forwarded here; only
        # verify/http_session are passed. Confirm whether that is intentional.
        self._auth_client = AuthorizationClient(
            endpoint=self._endpoint,
            redirect_uri=cfg.redirect_uri,
            client_id=cfg.client_id,
            # Audience only needed for Auth0 - Taken from client config
            audience=cfg.audience,
            # cfg.scopes refers to PublicClientConfig scopes (can be defined in Helm deployments)
            scopes=cfg.scopes,
            auth_endpoint=cfg.authorization_endpoint,
            token_endpoint=cfg.token_endpoint,
            verify=self._verify,
            http_session=self._http_session,
            # The challenge goes to the authorization endpoint ...
            request_auth_code_params={
                "code_challenge": code_challenge,
                "code_challenge_method": "S256",
            },
            # ... and the matching verifier is sent when exchanging the code for a token.
            request_access_token_params={
                "code_verifier": code_verifier,
            },
            refresh_access_token_params={},
            add_request_auth_code_params_to_request_access_token_params=True,
        )

    async def _do_refresh_credentials(self) -> Credentials:
        """
        Refresh the authentication credentials using the PKCE flow.

        The auth client is initialized on demand first. If credentials already exist, a refresh via
        the auth client is attempted. If that raises AccessTokenNotFoundError, or if there are no
        credentials yet, the full PKCE authorization flow is run instead. That flow typically opens
        a browser for user authentication.

        :return: The refreshed or newly acquired credentials.
        :raises: May raise authentication-related exceptions if the refresh or the authorization flow fails
        """
        await self._initialize_auth_client()
        if self._creds:
            # We already have an access token, so try to refresh it before falling back to a full login.
            try:
                return await self._auth_client.refresh_access_token(self._creds)
            except AccessTokenNotFoundError:
                logger.warning("Failed to refresh token. Kicking off a full authorization flow.")

        return await self._auth_client.get_creds_from_remote()
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class AuthorizationClient(object):
|
|
131
|
+
"""
|
|
132
|
+
Authorization client that stores the credentials in keyring and uses oauth2 standard flow to retrieve the
|
|
133
|
+
credentials. NOTE: This will open an web browser to retrieve the credentials.
|
|
134
|
+
"""
|
|
135
|
+
|
|
136
|
+
def __init__(
    self,
    endpoint: str,
    auth_endpoint: str,
    token_endpoint: str,
    http_session: httpx.AsyncClient,
    audience: typing.Optional[str] = None,
    scopes: typing.Optional[typing.List[str]] = None,
    client_id: typing.Optional[str] = None,
    redirect_uri: typing.Optional[str] = None,
    endpoint_metadata: typing.Optional[EndpointMetadata] = None,
    verify: bool = True,
    ca_cert_path: typing.Optional[str] = None,
    request_auth_code_params: typing.Optional[typing.Dict[str, str]] = None,
    request_access_token_params: typing.Optional[typing.Dict[str, str]] = None,
    refresh_access_token_params: typing.Optional[typing.Dict[str, str]] = None,
    add_request_auth_code_params_to_request_access_token_params: typing.Optional[bool] = False,
):
    """
    Create new AuthorizationClient

    :param endpoint: The endpoint URL to connect to
    :param auth_endpoint: The OAuth2 authorization endpoint URL. It is opened in the browser with the
        auth-code request parameters appended to its query string.
    :param token_endpoint: The endpoint URL to retrieve token from
    :param http_session: A custom httpx.AsyncClient object to use for making HTTP requests
    :param audience: Audience parameter for Auth0 (optional)
    :param scopes: List of OAuth2 scopes to request during authentication
    :param client_id: OAuth2 client ID for authentication
    :param redirect_uri: OAuth2 redirect URI for authentication callback
    :param endpoint_metadata: EndpointMetadata object to control the rendering of the page on login successful or
        failure. Defaults to one built from the host name of ``auth_endpoint``.
    :param verify: A boolean that controls whether to verify the server's TLS certificate.
        Defaults to ``True``. When set to ``False``, requests will accept any TLS certificate
        presented by the server, and will ignore hostname mismatches and/or expired certificates,
        which will make your application vulnerable to man-in-the-middle (MitM) attacks.
        Setting verify to ``False`` may be useful during local development or testing.
    :param ca_cert_path: Path to a certificate chain file for SSL verification (optional)
    :param request_auth_code_params: Dictionary of parameters to add to login URI opened in the browser (optional)
    :param request_access_token_params: Dictionary of parameters to add when exchanging the auth code for the
        access token (optional)
    :param refresh_access_token_params: Dictionary of parameters to add when refreshing the access token (optional)
    :param add_request_auth_code_params_to_request_access_token_params: Whether to add the
        `request_auth_code_params` to the parameters sent when exchanging the auth code for the access token.
        Defaults to False. Required for the PKCE flow with the backend. Not required for the standard OAuth2 flow
        on GCP.

    The dictionaries passed in are copied, so the caller's objects are never mutated by this client.
    """
    self._endpoint = endpoint
    self._auth_endpoint = auth_endpoint
    if endpoint_metadata is None:
        # Fall back to metadata derived from the authorization endpoint's host name.
        remote_url = _urlparse.urlparse(self._auth_endpoint)
        self._remote = EndpointMetadata(endpoint=remote_url.hostname)
    else:
        self._remote = endpoint_metadata
    self._token_endpoint = token_endpoint
    self._client_id = client_id
    self._audience = audience
    self._scopes = scopes or []
    self._redirect_uri = redirect_uri
    # Sent with the authorization request; _request_access_token() rejects callbacks whose state differs.
    state = _generate_state_parameter()
    self._state = state
    self._verify = verify
    self._ca_cert_path = ca_cert_path
    self._headers = {"content-type": "application/x-www-form-urlencoded"}
    # NOTE(review): get_creds_from_remote() is async but enters this threading.Lock with a plain `with`.
    # If it awaits while holding the lock, a second coroutine on the same event loop would block the
    # whole loop. Consider asyncio.Lock; confirm.
    self._lock = threading.Lock()
    # Credentials cached by get_creds_from_remote(); the timestamp is compared against time.monotonic().
    self._cached_credentials = None
    self._cached_credentials_ts = None
    self._http_session = http_session

    self._request_auth_code_params = {
        "client_id": client_id,  # This must match the Client ID of the OAuth application.
        "response_type": "code",  # Indicates the authorization code grant
        # Strip stray quotes/brackets, presumably so scopes supplied as a stringified list still work.
        "scope": " ".join(s.strip("' ") for s in self._scopes).strip(
            "[]'"
        ),  # ensures that the /token endpoint returns an ID and refresh token
        # callback location where the user-agent will be directed to.
        "redirect_uri": self._redirect_uri,
        "state": state,
    }

    # Conditionally add audience param if provided - value is not None
    if self._audience:
        self._request_auth_code_params["audience"] = self._audience

    if request_auth_code_params:
        # Allow adding additional parameters to the request_auth_code_params
        self._request_auth_code_params.update(request_auth_code_params)

    # Copy the caller's dicts: the update() below would otherwise mutate the object the caller passed in.
    self._request_access_token_params = dict(request_access_token_params or {})
    self._refresh_access_token_params = dict(refresh_access_token_params or {})

    if add_request_auth_code_params_to_request_access_token_params:
        self._request_access_token_params.update(self._request_auth_code_params)
|
|
228
|
+
|
|
229
|
+
def __repr__(self):
    """Return a debug string with the endpoints, client id, scopes and redirect URI of this client."""
    parts = (
        self._auth_endpoint,
        self._token_endpoint,
        self._client_id,
        self._scopes,
        self._redirect_uri,
    )
    return "AuthorizationClient(" + ", ".join(f"{part}" for part in parts) + ")"
|
|
234
|
+
|
|
235
|
+
async def _create_callback_server(self):
    """
    Start a local asyncio server that receives the OAuth2 redirect callback.

    The listening host and port come from ``self._redirect_uri``. Requests are handled by an
    OAuthCallbackHandler bound to the redirect path.

    NOTE(review): assumes the redirect URI carries an explicit host and port; ``port`` is None otherwise.

    :return: A ``(server, queue, handler)`` tuple. The queue is the one handed to the handler.
    """
    redirect = _urlparse.urlparse(self._redirect_uri)
    callback_queue = Queue()
    callback_handler = OAuthCallbackHandler(callback_queue, self._remote, redirect.path)
    listener = await asyncio.start_server(callback_handler.handle, redirect.hostname, redirect.port)
    return listener, callback_queue, callback_handler
|
|
242
|
+
|
|
243
|
+
async def _request_authorization_code(self):
    """
    Send the user to the authorization endpoint to start the OAuth2 authorization-code flow.

    The login URL is ``self._auth_endpoint`` with ``self._request_auth_code_params`` url-encoded
    into its query string. Any query component already present on the endpoint is kept, as
    RFC 6749 section 3.1 requires; a fragment, if any, is dropped. If no browser can be opened,
    the URL is printed so the user can open it manually.
    """
    parsed = _urlparse.urlparse(self._auth_endpoint)
    query = _urlencode(self._request_auth_code_params)
    if parsed.query:
        # RFC 6749 §3.1: an existing query component MUST be retained when adding parameters.
        query = f"{parsed.query}&{query}"
    endpoint = _urlparse.urlunparse((parsed.scheme, parsed.netloc, parsed.path, parsed.params, query, ""))
    logger.debug(f"Requesting authorization code through {endpoint}")

    success = webbrowser.open_new_tab(endpoint)  # type: ignore
    if not success:
        click.secho(f"Please open the following link in your browser to authenticate: {endpoint}")
|
|
252
|
+
|
|
253
|
+
async def _credentials_from_response(self, auth_token_resp) -> Credentials:
    """
    Build a Credentials object from a token-endpoint response.

    The JSON body must contain "access_token". "refresh_token", "expires_in" and "id_token" are
    optional and default to None when absent. Example body::

        {
            "access_token": "foo",
            "refresh_token": "bar",
            "token_type": "Bearer"
        }

    :param auth_token_resp: The HTTP response whose ``.json()`` holds the token information
    :return: Credentials for ``self._endpoint`` built from the response
    :raises ValueError: If the response does not contain an access token
    """
    payload = auth_token_resp.json()
    if "access_token" not in payload:
        raise ValueError('Expected "access_token" in response from oauth server')

    return Credentials(
        access_token=payload["access_token"],
        refresh_token=payload.get("refresh_token"),
        for_endpoint=self._endpoint,
        expires_in=payload.get("expires_in"),
        id_token=payload.get("id_token"),
    )
|
|
291
|
+
|
|
292
|
+
async def _request_access_token(self, auth_code) -> Credentials:
    """
    Exchange an authorization code for credentials at the token endpoint.

    :param auth_code: AuthorizationCode received by the local callback server
    :return: Credentials parsed from the token endpoint response
    :raises ValueError: If the callback's state differs from ``self._state``
    :raises RuntimeError: If the token endpoint does not answer with HTTP 200
    """
    # Reject callbacks whose state does not match the value this client sent.
    if self._state != auth_code.state:
        raise ValueError(f"Unexpected state parameter [{auth_code.state}] passed")

    token_request = dict(code=auth_code.code, grant_type="authorization_code")
    # Caller-configured extras may add to or override the base form fields.
    token_request.update(self._request_access_token_params)

    resp = await self._http_session.post(
        url=self._token_endpoint,
        data=token_request,
        headers=self._headers,
        follow_redirects=False,
    )

    if resp.status_code == _StatusCodes.OK:
        return await self._credentials_from_response(resp)
    raise RuntimeError(f"Failed to request access token with response: [{resp.status_code}] {resp.content}")
|
|
316
|
+
|
|
317
|
+
async def get_creds_from_remote(self) -> Credentials:
    """
    This is the entrypoint method. It will kickoff the full authentication
    flow and trigger a web-browser to retrieve credentials. Because this
    needs to open a port on localhost and may be called from a
    multithreaded context (e.g. pyflyte register), this call may block
    multiple threads and return a cached result for up to 60 seconds.

    :return: Credentials obtained from the authentication flow
    :raises: May raise authentication-related exceptions if the flow fails
    """
    # In the absence of globally-set token values, initiate the token request flow.
    # NOTE(review): if self._lock is a threading.Lock, holding it across the awaits below
    # blocks the event-loop thread for any other coroutine that calls this method -- confirm
    # the lock type.
    with self._lock:
        # Clear cache if it's been more than 60 seconds since the last check
        cache_ttl_s = 60
        if self._cached_credentials_ts is not None and self._cached_credentials_ts + cache_ttl_s < time.monotonic():
            self._cached_credentials = None

        if self._cached_credentials is not None:
            return self._cached_credentials

        server, queue, handler = await self._create_callback_server()
        async with server:
            await self._request_authorization_code()
            # The handler's event may fire for requests that carry no authorization code
            # (e.g. a browser fetching /favicon.ico). Keep waiting until a code has really
            # been queued, otherwise the queue.get() below would block on an empty queue.
            # assumes a queue.Queue-like object (sync put/get/empty) -- TODO confirm.
            while queue.empty():
                handler.response_received.clear()
                await handler.response_received.wait()

        auth_code = queue.get()
        self._cached_credentials = await self._request_access_token(auth_code)
        self._cached_credentials_ts = time.monotonic()
        return self._cached_credentials
|
|
348
|
+
|
|
349
|
+
async def refresh_access_token(self, credentials: Credentials) -> Credentials:
    """
    Refreshes the access token using the refresh token from the provided credentials.

    :param credentials: The credentials containing the refresh token to use
    :return: Updated credentials with a new access token
    :raises AccessTokenNotFoundError: If no refresh token is available in the credentials,
        or if the token endpoint does not answer with HTTP 200
    """
    if credentials.refresh_token is None:
        raise AccessTokenNotFoundError("no refresh token available with which to refresh authorization credentials")

    data = {
        "refresh_token": credentials.refresh_token,
        "grant_type": "refresh_token",
        "client_id": self._client_id,
    }
    # Caller-configured extras may add to or override the base form fields.
    data.update(self._refresh_access_token_params)

    # Use the session exactly like _request_access_token does: `post` is awaited and returns
    # the response (httpx-style `follow_redirects` / `status_code`). The coroutine it returns
    # is not an async context manager, so `async with` here would fail before sending anything.
    resp = await self._http_session.post(
        url=self._token_endpoint,
        data=data,
        headers=self._headers,
        follow_redirects=False,
    )
    if resp.status_code != _StatusCodes.OK:
        raise AccessTokenNotFoundError(f"Non-200 returned from refresh token endpoint {resp.status_code}")

    return await self._credentials_from_response(resp)
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
class OAuthCallbackHandler:
    """
    Handles OAuth2 callback requests during the authentication flow.

    This class implements an HTTP request handler that processes the callback from the OAuth2 provider,
    extracts the authorization code, and passes it to the authentication flow.
    """

    def __init__(self, queue: Queue, remote_metadata: EndpointMetadata, redirect_path: str):
        """
        Initialize the OAuth callback handler.

        :param queue: Queue to put the authorization code into when received
        :param remote_metadata: Metadata about the remote endpoint for rendering success/failure pages
        :param redirect_path: The path component of the redirect URI to match incoming requests against
        """
        self.queue = queue
        self.remote_metadata = remote_metadata
        self.redirect_path = redirect_path
        # Set only after an authorization code has been put on the queue, so a waiter can
        # read the queue immediately after the event fires.
        self.response_received = asyncio.Event()

    @staticmethod
    async def _send(writer, status, body: bytes = b"", content_type=None):
        """
        Write a minimal HTTP/1.1 response and flush it.

        :param writer: The StreamWriter to write to
        :param status: A ``_StatusCodes`` member providing ``value`` and ``phrase``
        :param body: Raw response body bytes
        :param content_type: Optional Content-Type header value; omitted when None
        """
        head = f"HTTP/1.1 {status.value} {status.phrase}\r\n"
        if content_type is not None:
            head += f"Content-Type: {content_type}\r\n"
        head += "\r\n"
        writer.write(head.encode(_utf_8) + body)
        await writer.drain()

    async def handle(self, reader, writer):
        """
        Handles an incoming HTTP request during the OAuth2 callback.

        Reads the request line. If the path matches the expected redirect path and the query
        carries both ``code`` and ``state``, the code is queued, ``response_received`` is set,
        and a success page is returned. Requests for other paths get a 404. Malformed requests
        and callbacks without a code (for example ``?error=access_denied``) get a 400. None of
        these non-success cases set ``response_received``.

        :param reader: The StreamReader for reading the incoming request
        :param writer: The StreamWriter for writing the response
        """
        try:
            data = await reader.read(1024)
            request_line = data.decode(_utf_8, errors="replace").split("\r\n", 1)[0]
            parts = request_line.split(" ")
            if len(parts) < 2:
                # Empty or malformed request line (e.g. a connection opened and closed).
                await self._send(writer, _StatusCodes.BAD_REQUEST)
                return

            url = _urlparse.urlparse(parts[1])
            if url.path.strip("/") != self.redirect_path.strip("/"):
                # Unrelated request (such as /favicon.ico): answer it, keep waiting for the callback.
                await self._send(writer, _StatusCodes.NOT_FOUND)
                return

            params = dict(_urlparse.parse_qsl(url.query))
            if "code" not in params or "state" not in params:
                # The provider redirected without an authorization code.
                body = self.remote_metadata.failure_html
                content_type = "text/html"
                if body is None:
                    # Plain text, so the echoed query value cannot inject markup.
                    reason = params.get("error", "no authorization code received")
                    body = f"Authentication failed: {reason}".encode(_utf_8)
                    content_type = "text/plain"
                await self._send(writer, _StatusCodes.BAD_REQUEST, body, content_type)
                return

            self.handle_login(params)
            # Signal as soon as the code is queued, before writing the page, so that a failure
            # while responding to the browser cannot leave the flow waiting forever.
            self.response_received.set()

            body = self.remote_metadata.success_html
            if body is None:
                body = get_default_success_html(self.remote_metadata.endpoint).encode(_utf_8)
            await self._send(writer, _StatusCodes.OK, body, "text/html")
        finally:
            writer.close()

    def handle_login(self, data: dict):
        """
        Processes the login data from the OAuth2 callback.

        Extracts the authorization code and state from the query parameters and puts them in the queue
        for the authentication flow to process.

        :param data: Dictionary containing the query parameters from the callback URL
        :raises KeyError: If ``data`` lacks "code" or "state"
        """
        self.queue.put(AuthorizationCode(code=data["code"], state=data["state"]))
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
class EndpointMetadata(pydantic.BaseModel):
    """
    Describes the remote service and, optionally, custom pages for the browser once login finishes.

    :param endpoint: URL or hostname of the remote service
    :param success_html: Optional HTML content for the success page; when None the callback
        handler falls back to a default page
    :param failure_html: Optional HTML content for the failure page
    """

    endpoint: str
    success_html: typing.Union[bytes, None] = None
    failure_html: typing.Union[bytes, None] = None
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
class AuthorizationCode(pydantic.BaseModel):
    """
    Authorization code and echoed state, as delivered to the local OAuth2 redirect endpoint.

    :param code: Authorization code issued by the OAuth2 provider
    :param state: State value echoed back by the provider, to be compared with the value
        sent in the original authorization request
    """

    # Exchanged at the token endpoint for credentials.
    code: str
    # Checked against the client's own state before the code is exchanged.
    state: str
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
async def _create_code_challenge(code_verifier):
    """
    Derive the PKCE code challenge from a code verifier.

    Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py.

    :param str code_verifier: Verifier string produced by _generate_code_verifier()
    :return str: SHA-256 digest of the verifier, urlsafe-base64 encoded with "=" padding removed
    """
    digest = hashlib.sha256(code_verifier.encode(_utf_8)).digest()
    encoded = base64.urlsafe_b64encode(digest).decode(_utf_8)
    # "=" only ever appears as trailing base64 padding, so trimming the end removes all of it.
    return encoded.rstrip("=")
|
|
484
|
+
|
|
485
|
+
|
|
486
|
+
def _generate_state_parameter():
    """
    Produce a random OAuth2 ``state`` value.

    The value is sent with the authorization request and must come back unchanged on the
    callback. This ties the callback to this request and guards against cross-site request
    forgery.

    :return: Random, urlsafe string to use as the state parameter
    """
    random_bytes = os.urandom(_random_seed_length)
    encoded = base64.urlsafe_b64encode(random_bytes).decode(_utf_8)
    # Drop every character outside the allowed set (for base64 output this is the "=" padding).
    return re.sub("[^a-zA-Z0-9-_.,]+", "", encoded)
|
|
499
|
+
|
|
500
|
+
|
|
501
|
+
async def _generate_code_verifier():
    """
    Generates a 'code_verifier' for PKCE OAuth2 flow as described in RFC 7636 section 4.1.
    Adapted from https://github.com/openstack/deb-python-oauth2client/blob/master/oauth2client/_pkce.py.

    The verifier is ``_code_verifier_length`` random bytes, urlsafe-base64 encoded with the
    padding stripped. The length check enforces 43-128 characters. After stripping padding,
    that range corresponds to 32-96 random bytes.

    :return str: A random string to use as the code verifier
    :raises ValueError: If ``_code_verifier_length`` yields a verifier outside 43-128 characters
    """
    code_verifier = base64.urlsafe_b64encode(os.urandom(_code_verifier_length)).decode(_utf_8)
    # Keep only unreserved characters; for base64 output this removes the "=" padding.
    code_verifier = re.sub(r"[^a-zA-Z0-9_\-.~]+", "", code_verifier)
    if len(code_verifier) < 43:
        # 31 bytes encode to 42 characters, so at least 32 bytes are required.
        raise ValueError("Verifier too short. number of bytes must be >= 32.")
    elif len(code_verifier) > 128:
        # 96 bytes encode to exactly 128 characters; 97 bytes give 130.
        raise ValueError("Verifier too long. number of bytes must be <= 96.")
    return code_verifier
|