flyte 0.0.1b0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of flyte might be problematic. Click here for more details.
- flyte/__init__.py +62 -0
- flyte/_api_commons.py +3 -0
- flyte/_bin/__init__.py +0 -0
- flyte/_bin/runtime.py +126 -0
- flyte/_build.py +25 -0
- flyte/_cache/__init__.py +12 -0
- flyte/_cache/cache.py +146 -0
- flyte/_cache/defaults.py +9 -0
- flyte/_cache/policy_function_body.py +42 -0
- flyte/_cli/__init__.py +0 -0
- flyte/_cli/_common.py +287 -0
- flyte/_cli/_create.py +42 -0
- flyte/_cli/_delete.py +23 -0
- flyte/_cli/_deploy.py +140 -0
- flyte/_cli/_get.py +235 -0
- flyte/_cli/_run.py +152 -0
- flyte/_cli/main.py +72 -0
- flyte/_code_bundle/__init__.py +8 -0
- flyte/_code_bundle/_ignore.py +113 -0
- flyte/_code_bundle/_packaging.py +187 -0
- flyte/_code_bundle/_utils.py +339 -0
- flyte/_code_bundle/bundle.py +178 -0
- flyte/_context.py +146 -0
- flyte/_datastructures.py +342 -0
- flyte/_deploy.py +202 -0
- flyte/_doc.py +29 -0
- flyte/_docstring.py +32 -0
- flyte/_environment.py +43 -0
- flyte/_group.py +31 -0
- flyte/_hash.py +23 -0
- flyte/_image.py +760 -0
- flyte/_initialize.py +634 -0
- flyte/_interface.py +84 -0
- flyte/_internal/__init__.py +3 -0
- flyte/_internal/controllers/__init__.py +115 -0
- flyte/_internal/controllers/_local_controller.py +118 -0
- flyte/_internal/controllers/_trace.py +40 -0
- flyte/_internal/controllers/pbhash.py +39 -0
- flyte/_internal/controllers/remote/__init__.py +40 -0
- flyte/_internal/controllers/remote/_action.py +141 -0
- flyte/_internal/controllers/remote/_client.py +43 -0
- flyte/_internal/controllers/remote/_controller.py +361 -0
- flyte/_internal/controllers/remote/_core.py +402 -0
- flyte/_internal/controllers/remote/_informer.py +361 -0
- flyte/_internal/controllers/remote/_service_protocol.py +50 -0
- flyte/_internal/imagebuild/__init__.py +11 -0
- flyte/_internal/imagebuild/docker_builder.py +416 -0
- flyte/_internal/imagebuild/image_builder.py +241 -0
- flyte/_internal/imagebuild/remote_builder.py +0 -0
- flyte/_internal/resolvers/__init__.py +0 -0
- flyte/_internal/resolvers/_task_module.py +54 -0
- flyte/_internal/resolvers/common.py +31 -0
- flyte/_internal/resolvers/default.py +28 -0
- flyte/_internal/runtime/__init__.py +0 -0
- flyte/_internal/runtime/convert.py +199 -0
- flyte/_internal/runtime/entrypoints.py +135 -0
- flyte/_internal/runtime/io.py +136 -0
- flyte/_internal/runtime/resources_serde.py +138 -0
- flyte/_internal/runtime/task_serde.py +210 -0
- flyte/_internal/runtime/taskrunner.py +190 -0
- flyte/_internal/runtime/types_serde.py +54 -0
- flyte/_logging.py +124 -0
- flyte/_protos/__init__.py +0 -0
- flyte/_protos/common/authorization_pb2.py +66 -0
- flyte/_protos/common/authorization_pb2.pyi +108 -0
- flyte/_protos/common/authorization_pb2_grpc.py +4 -0
- flyte/_protos/common/identifier_pb2.py +71 -0
- flyte/_protos/common/identifier_pb2.pyi +82 -0
- flyte/_protos/common/identifier_pb2_grpc.py +4 -0
- flyte/_protos/common/identity_pb2.py +48 -0
- flyte/_protos/common/identity_pb2.pyi +72 -0
- flyte/_protos/common/identity_pb2_grpc.py +4 -0
- flyte/_protos/common/list_pb2.py +36 -0
- flyte/_protos/common/list_pb2.pyi +69 -0
- flyte/_protos/common/list_pb2_grpc.py +4 -0
- flyte/_protos/common/policy_pb2.py +37 -0
- flyte/_protos/common/policy_pb2.pyi +27 -0
- flyte/_protos/common/policy_pb2_grpc.py +4 -0
- flyte/_protos/common/role_pb2.py +37 -0
- flyte/_protos/common/role_pb2.pyi +53 -0
- flyte/_protos/common/role_pb2_grpc.py +4 -0
- flyte/_protos/common/runtime_version_pb2.py +28 -0
- flyte/_protos/common/runtime_version_pb2.pyi +24 -0
- flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
- flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
- flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/definition_pb2.py +49 -0
- flyte/_protos/secret/definition_pb2.pyi +93 -0
- flyte/_protos/secret/definition_pb2_grpc.py +4 -0
- flyte/_protos/secret/payload_pb2.py +62 -0
- flyte/_protos/secret/payload_pb2.pyi +94 -0
- flyte/_protos/secret/payload_pb2_grpc.py +4 -0
- flyte/_protos/secret/secret_pb2.py +38 -0
- flyte/_protos/secret/secret_pb2.pyi +6 -0
- flyte/_protos/secret/secret_pb2_grpc.py +198 -0
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
- flyte/_protos/validate/validate/validate_pb2.py +76 -0
- flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
- flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- flyte/_protos/workflow/queue_service_pb2.py +106 -0
- flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
- flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- flyte/_protos/workflow/run_definition_pb2.py +128 -0
- flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
- flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
- flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- flyte/_protos/workflow/run_service_pb2.py +133 -0
- flyte/_protos/workflow/run_service_pb2.pyi +175 -0
- flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
- flyte/_protos/workflow/state_service_pb2.py +58 -0
- flyte/_protos/workflow/state_service_pb2.pyi +71 -0
- flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
- flyte/_protos/workflow/task_definition_pb2.py +72 -0
- flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
- flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- flyte/_protos/workflow/task_service_pb2.py +44 -0
- flyte/_protos/workflow/task_service_pb2.pyi +31 -0
- flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
- flyte/_resources.py +226 -0
- flyte/_retry.py +32 -0
- flyte/_reusable_environment.py +25 -0
- flyte/_run.py +411 -0
- flyte/_secret.py +61 -0
- flyte/_task.py +367 -0
- flyte/_task_environment.py +200 -0
- flyte/_timeout.py +47 -0
- flyte/_tools.py +27 -0
- flyte/_trace.py +128 -0
- flyte/_utils/__init__.py +20 -0
- flyte/_utils/asyn.py +119 -0
- flyte/_utils/coro_management.py +25 -0
- flyte/_utils/file_handling.py +72 -0
- flyte/_utils/helpers.py +108 -0
- flyte/_utils/lazy_module.py +54 -0
- flyte/_utils/uv_script_parser.py +49 -0
- flyte/_version.py +21 -0
- flyte/connectors/__init__.py +0 -0
- flyte/errors.py +143 -0
- flyte/extras/__init__.py +5 -0
- flyte/extras/_container.py +273 -0
- flyte/io/__init__.py +11 -0
- flyte/io/_dataframe.py +0 -0
- flyte/io/_dir.py +448 -0
- flyte/io/_file.py +468 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/io/pickle/transformer.py +117 -0
- flyte/io/structured_dataset/__init__.py +129 -0
- flyte/io/structured_dataset/basic_dfs.py +219 -0
- flyte/io/structured_dataset/structured_dataset.py +1061 -0
- flyte/py.typed +0 -0
- flyte/remote/__init__.py +25 -0
- flyte/remote/_client/__init__.py +0 -0
- flyte/remote/_client/_protocols.py +131 -0
- flyte/remote/_client/auth/__init__.py +12 -0
- flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
- flyte/remote/_client/auth/_authenticators/base.py +397 -0
- flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
- flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
- flyte/remote/_client/auth/_authenticators/factory.py +200 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
- flyte/remote/_client/auth/_channel.py +184 -0
- flyte/remote/_client/auth/_client_config.py +83 -0
- flyte/remote/_client/auth/_default_html.py +32 -0
- flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
- flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
- flyte/remote/_client/auth/_keyring.py +143 -0
- flyte/remote/_client/auth/_token_client.py +260 -0
- flyte/remote/_client/auth/errors.py +16 -0
- flyte/remote/_client/controlplane.py +95 -0
- flyte/remote/_console.py +18 -0
- flyte/remote/_data.py +155 -0
- flyte/remote/_logs.py +116 -0
- flyte/remote/_project.py +86 -0
- flyte/remote/_run.py +873 -0
- flyte/remote/_secret.py +132 -0
- flyte/remote/_task.py +227 -0
- flyte/report/__init__.py +3 -0
- flyte/report/_report.py +178 -0
- flyte/report/_template.html +124 -0
- flyte/storage/__init__.py +24 -0
- flyte/storage/_remote_fs.py +34 -0
- flyte/storage/_storage.py +251 -0
- flyte/storage/_utils.py +5 -0
- flyte/types/__init__.py +13 -0
- flyte/types/_interface.py +25 -0
- flyte/types/_renderer.py +162 -0
- flyte/types/_string_literals.py +120 -0
- flyte/types/_type_engine.py +2210 -0
- flyte/types/_utils.py +80 -0
- flyte-0.0.1b0.dist-info/METADATA +179 -0
- flyte-0.0.1b0.dist-info/RECORD +390 -0
- flyte-0.0.1b0.dist-info/WHEEL +5 -0
- flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
- flyte-0.0.1b0.dist-info/top_level.txt +1 -0
- union/__init__.py +54 -0
- union/_api_commons.py +3 -0
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +113 -0
- union/_build.py +25 -0
- union/_cache/__init__.py +12 -0
- union/_cache/cache.py +141 -0
- union/_cache/defaults.py +9 -0
- union/_cache/policy_function_body.py +42 -0
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +263 -0
- union/_cli/_create.py +40 -0
- union/_cli/_delete.py +23 -0
- union/_cli/_deploy.py +120 -0
- union/_cli/_get.py +162 -0
- union/_cli/_params.py +579 -0
- union/_cli/_run.py +150 -0
- union/_cli/main.py +72 -0
- union/_code_bundle/__init__.py +8 -0
- union/_code_bundle/_ignore.py +113 -0
- union/_code_bundle/_packaging.py +187 -0
- union/_code_bundle/_utils.py +342 -0
- union/_code_bundle/bundle.py +176 -0
- union/_context.py +146 -0
- union/_datastructures.py +295 -0
- union/_deploy.py +185 -0
- union/_doc.py +29 -0
- union/_docstring.py +26 -0
- union/_environment.py +43 -0
- union/_group.py +31 -0
- union/_hash.py +23 -0
- union/_image.py +760 -0
- union/_initialize.py +585 -0
- union/_interface.py +84 -0
- union/_internal/__init__.py +3 -0
- union/_internal/controllers/__init__.py +77 -0
- union/_internal/controllers/_local_controller.py +77 -0
- union/_internal/controllers/pbhash.py +39 -0
- union/_internal/controllers/remote/__init__.py +40 -0
- union/_internal/controllers/remote/_action.py +131 -0
- union/_internal/controllers/remote/_client.py +43 -0
- union/_internal/controllers/remote/_controller.py +169 -0
- union/_internal/controllers/remote/_core.py +341 -0
- union/_internal/controllers/remote/_informer.py +260 -0
- union/_internal/controllers/remote/_service_protocol.py +44 -0
- union/_internal/imagebuild/__init__.py +11 -0
- union/_internal/imagebuild/docker_builder.py +416 -0
- union/_internal/imagebuild/image_builder.py +243 -0
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +31 -0
- union/_internal/resolvers/common.py +24 -0
- union/_internal/resolvers/default.py +27 -0
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +163 -0
- union/_internal/runtime/entrypoints.py +121 -0
- union/_internal/runtime/io.py +136 -0
- union/_internal/runtime/resources_serde.py +134 -0
- union/_internal/runtime/task_serde.py +202 -0
- union/_internal/runtime/taskrunner.py +179 -0
- union/_internal/runtime/types_serde.py +53 -0
- union/_logging.py +124 -0
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +66 -0
- union/_protos/common/authorization_pb2.pyi +106 -0
- union/_protos/common/authorization_pb2_grpc.py +4 -0
- union/_protos/common/identifier_pb2.py +71 -0
- union/_protos/common/identifier_pb2.pyi +82 -0
- union/_protos/common/identifier_pb2_grpc.py +4 -0
- union/_protos/common/identity_pb2.py +48 -0
- union/_protos/common/identity_pb2.pyi +72 -0
- union/_protos/common/identity_pb2_grpc.py +4 -0
- union/_protos/common/list_pb2.py +36 -0
- union/_protos/common/list_pb2.pyi +69 -0
- union/_protos/common/list_pb2_grpc.py +4 -0
- union/_protos/common/policy_pb2.py +37 -0
- union/_protos/common/policy_pb2.pyi +27 -0
- union/_protos/common/policy_pb2_grpc.py +4 -0
- union/_protos/common/role_pb2.py +37 -0
- union/_protos/common/role_pb2.pyi +51 -0
- union/_protos/common/role_pb2_grpc.py +4 -0
- union/_protos/common/runtime_version_pb2.py +28 -0
- union/_protos/common/runtime_version_pb2.pyi +24 -0
- union/_protos/common/runtime_version_pb2_grpc.py +4 -0
- union/_protos/logs/dataplane/payload_pb2.py +96 -0
- union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
- union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
- union/_protos/secret/definition_pb2.py +49 -0
- union/_protos/secret/definition_pb2.pyi +93 -0
- union/_protos/secret/definition_pb2_grpc.py +4 -0
- union/_protos/secret/payload_pb2.py +62 -0
- union/_protos/secret/payload_pb2.pyi +94 -0
- union/_protos/secret/payload_pb2_grpc.py +4 -0
- union/_protos/secret/secret_pb2.py +38 -0
- union/_protos/secret/secret_pb2.pyi +6 -0
- union/_protos/secret/secret_pb2_grpc.py +198 -0
- union/_protos/validate/validate/validate_pb2.py +76 -0
- union/_protos/workflow/node_execution_service_pb2.py +26 -0
- union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
- union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
- union/_protos/workflow/queue_service_pb2.py +75 -0
- union/_protos/workflow/queue_service_pb2.pyi +103 -0
- union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
- union/_protos/workflow/run_definition_pb2.py +100 -0
- union/_protos/workflow/run_definition_pb2.pyi +256 -0
- union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/run_logs_service_pb2.py +41 -0
- union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
- union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
- union/_protos/workflow/run_service_pb2.py +133 -0
- union/_protos/workflow/run_service_pb2.pyi +173 -0
- union/_protos/workflow/run_service_pb2_grpc.py +412 -0
- union/_protos/workflow/state_service_pb2.py +58 -0
- union/_protos/workflow/state_service_pb2.pyi +69 -0
- union/_protos/workflow/state_service_pb2_grpc.py +138 -0
- union/_protos/workflow/task_definition_pb2.py +72 -0
- union/_protos/workflow/task_definition_pb2.pyi +65 -0
- union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
- union/_protos/workflow/task_service_pb2.py +44 -0
- union/_protos/workflow/task_service_pb2.pyi +31 -0
- union/_protos/workflow/task_service_pb2_grpc.py +104 -0
- union/_resources.py +226 -0
- union/_retry.py +32 -0
- union/_reusable_environment.py +25 -0
- union/_run.py +374 -0
- union/_secret.py +61 -0
- union/_task.py +354 -0
- union/_task_environment.py +186 -0
- union/_timeout.py +47 -0
- union/_tools.py +27 -0
- union/_utils/__init__.py +11 -0
- union/_utils/asyn.py +119 -0
- union/_utils/file_handling.py +71 -0
- union/_utils/helpers.py +46 -0
- union/_utils/lazy_module.py +54 -0
- union/_utils/uv_script_parser.py +49 -0
- union/_version.py +21 -0
- union/connectors/__init__.py +0 -0
- union/errors.py +128 -0
- union/extras/__init__.py +5 -0
- union/extras/_container.py +263 -0
- union/io/__init__.py +11 -0
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +425 -0
- union/io/_file.py +418 -0
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +117 -0
- union/io/structured_dataset/__init__.py +122 -0
- union/io/structured_dataset/basic_dfs.py +219 -0
- union/io/structured_dataset/structured_dataset.py +1057 -0
- union/py.typed +0 -0
- union/remote/__init__.py +23 -0
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +129 -0
- union/remote/_client/auth/__init__.py +12 -0
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +391 -0
- union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
- union/remote/_client/auth/_authenticators/device_code.py +120 -0
- union/remote/_client/auth/_authenticators/external_command.py +77 -0
- union/remote/_client/auth/_authenticators/factory.py +200 -0
- union/remote/_client/auth/_authenticators/pkce.py +515 -0
- union/remote/_client/auth/_channel.py +184 -0
- union/remote/_client/auth/_client_config.py +83 -0
- union/remote/_client/auth/_default_html.py +32 -0
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
- union/remote/_client/auth/_keyring.py +154 -0
- union/remote/_client/auth/_token_client.py +258 -0
- union/remote/_client/auth/errors.py +16 -0
- union/remote/_client/controlplane.py +86 -0
- union/remote/_data.py +149 -0
- union/remote/_logs.py +74 -0
- union/remote/_project.py +86 -0
- union/remote/_run.py +820 -0
- union/remote/_secret.py +132 -0
- union/remote/_task.py +193 -0
- union/report/__init__.py +3 -0
- union/report/_report.py +178 -0
- union/report/_template.html +124 -0
- union/storage/__init__.py +24 -0
- union/storage/_remote_fs.py +34 -0
- union/storage/_storage.py +247 -0
- union/storage/_utils.py +5 -0
- union/types/__init__.py +11 -0
- union/types/_renderer.py +162 -0
- union/types/_string_literals.py +120 -0
- union/types/_type_engine.py +2131 -0
- union/types/_utils.py +80 -0
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
import hashlib # Added import for hashing
|
|
2
|
+
import typing
|
|
3
|
+
from urllib.parse import urlparse # Added import
|
|
4
|
+
|
|
5
|
+
import keyring
|
|
6
|
+
import pydantic
|
|
7
|
+
from keyring.errors import NoKeyringError, PasswordDeleteError
|
|
8
|
+
|
|
9
|
+
from union._logging import logger
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def strip_scheme(url: str) -> str:
    """
    Return ``url`` with its scheme prefix removed.

    Examples:
        - dns:///foo.com       -> foo.com
        - https://foo.com      -> foo.com
        - https://foo.com/blah -> foo.com/blah

    Any non-``dns`` input that parses without a network location (e.g. a bare
    ``foo.com``) is returned unchanged.
    """
    parts = urlparse(url)
    if parts.scheme != "dns":
        if not parts.netloc:
            return url
        return parts.netloc + parts.path
    # "dns:///host" style targets carry the host in the path component.
    return parts.path.lstrip("/")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class Credentials(pydantic.BaseModel):
    """
    The set of OAuth tokens obtained for a single endpoint, stored together.

    ``for_endpoint`` is normalized with ``strip_scheme`` on validation, and ``id`` is
    recomputed after validation from the access token (or, failing that, the ID token).
    """

    # Required bearer token; may be "" when only an ID token is available.
    access_token: str
    # Endpoint these tokens belong to; the URL scheme is stripped by ``validate_endpoint``.
    for_endpoint: str = "flyte-default"
    # Fingerprint of the token, overwritten by ``compute_id`` whenever a token is present.
    id: str | None = ""
    refresh_token: str | None = None
    expires_in: int | None = None
    id_token: str | None = None

    @pydantic.field_validator("for_endpoint", mode="after")
    @classmethod
    def validate_endpoint(cls, v: str) -> str:
        """Normalize the endpoint by removing its URL scheme (via ``strip_scheme``)."""
        return strip_scheme(v)

    @pydantic.model_validator(mode="after")
    def compute_id(self) -> "Credentials":
        """
        Set ``id`` to the MD5 hex digest of the access token, falling back to the ID token.

        The digest is only a non-secret fingerprint, so ``usedforsecurity=False`` is passed;
        without it, ``hashlib.md5`` may be rejected on FIPS-restricted Python/OpenSSL builds.
        If neither token is set, ``id`` keeps the value the model was constructed with.
        """
        token = self.access_token or self.id_token
        if token:
            self.id = hashlib.md5(token.encode(), usedforsecurity=False).hexdigest()
        return self
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class KeyringStore:
    """
    Static helpers for saving, loading and deleting ``Credentials`` in the system keyring.

    Each token is stored as its own keyring entry. The service name is the endpoint with
    its URL scheme stripped, and the username is one of the ``_*_key`` names below.
    """

    # Keyring username under which each token type is stored for an endpoint.
    _access_token_key = "access_token"
    _refresh_token_key = "refresh_token"
    _id_token_key = "id_token"

    @staticmethod
    def store(credentials: Credentials) -> Credentials:
        """
        Store the provided credentials in the system keyring.

        The access token is always written. The refresh token and ID token are written only
        when set. Tokens that are not set are NOT deleted, so a refresh or ID token stored
        earlier for the same endpoint stays in the keyring.

        :param credentials: The credentials object containing tokens to store
        :return: The same credentials object that was passed in
        :raises: Does not raise ``NoKeyringError``; a missing keyring backend is logged at
            debug level and the tokens are simply not cached
        """
        try:
            if credentials.refresh_token:
                keyring.set_password(
                    credentials.for_endpoint,
                    KeyringStore._refresh_token_key,
                    credentials.refresh_token,
                )
            keyring.set_password(
                credentials.for_endpoint,
                KeyringStore._access_token_key,
                credentials.access_token,
            )
            if credentials.id_token:
                keyring.set_password(
                    credentials.for_endpoint,
                    KeyringStore._id_token_key,
                    credentials.id_token,
                )
        except NoKeyringError as e:
            logger.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
        return credentials

    @staticmethod
    def retrieve(for_endpoint: str) -> typing.Optional[Credentials]:
        """
        Retrieve stored credentials for the specified endpoint from the system keyring.

        The endpoint URL scheme is stripped before lookup, matching how ``store`` keys entries.

        :param for_endpoint: The endpoint URL to retrieve credentials for
        :return: A Credentials object containing the retrieved tokens, or None if neither an
            access token nor an ID token was found, or if the system keyring is not available
        """
        for_endpoint = strip_scheme(for_endpoint)
        try:
            refresh_token = keyring.get_password(for_endpoint, KeyringStore._refresh_token_key)
            access_token = keyring.get_password(for_endpoint, KeyringStore._access_token_key)
            id_token = keyring.get_password(for_endpoint, KeyringStore._id_token_key)
        except NoKeyringError as e:
            logger.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
            return None

        if not access_token and not id_token:
            return None
        return Credentials(
            # ``Credentials.access_token`` is a required ``str``. When only an ID token was
            # found, pass "" instead of None so validation succeeds; ``compute_id`` then
            # derives the id from the ID token.
            access_token=access_token or "",
            refresh_token=refresh_token,
            for_endpoint=for_endpoint,
            id_token=id_token,
            expires_in=None,
        )

    @staticmethod
    def delete(for_endpoint: str):
        """
        Delete all stored credentials for the specified endpoint from the system keyring.

        Deletion is best-effort. Missing entries and a missing keyring backend are logged at
        debug level and ignored. The endpoint URL scheme is stripped before lookup.

        :param for_endpoint: The endpoint URL to delete credentials for
        """
        for_endpoint = strip_scheme(for_endpoint)

        def _delete_key(key):
            """
            Delete one token entry for ``for_endpoint``, ignoring absent entries.

            :param key: The key name to delete
            """
            try:
                keyring.delete_password(for_endpoint, key)
            except PasswordDeleteError as e:
                logger.debug(f"Key {key} not found in key store, Ignoring. Error: {e}")
            except NoKeyringError as e:
                logger.debug(f"KeyRing not available, Key {key} deletion failed. Error: {e}")

        _delete_key(KeyringStore._access_token_key)
        _delete_key(KeyringStore._refresh_token_key)
        _delete_key(KeyringStore._id_token_key)
|
|
@@ -0,0 +1,258 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import base64
|
|
3
|
+
import enum
|
|
4
|
+
import typing
|
|
5
|
+
import urllib.parse
|
|
6
|
+
from datetime import datetime, timedelta
|
|
7
|
+
|
|
8
|
+
import httpx
|
|
9
|
+
import pydantic
|
|
10
|
+
|
|
11
|
+
from union._logging import logger
|
|
12
|
+
from union.remote._client.auth.errors import AuthenticationError, AuthenticationPending
|
|
13
|
+
|
|
14
|
+
# Text encoding used when building HTTP Basic authorization headers.
utf_8 = "utf-8"

# Error codes the token endpoint returns while a device-code grant is not yet complete;
# ``get_token`` turns either one into ``AuthenticationPending``.
error_auth_pending = "authorization_pending"
error_slow_down = "slow_down"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class GrantType(str, enum.Enum):
    """OAuth 2.0 ``grant_type`` values that can be sent to the token endpoint."""

    # Client credentials grant (RFC 6749, section 4.4).
    CLIENT_CREDS = "client_credentials"
    # Device authorization grant (RFC 8628).
    DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"
    # Refresh token grant (RFC 6749, section 6).
    REFRESH_TOKEN = "refresh_token"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class DeviceCodeResponse(pydantic.BaseModel):
    """
    Response from the device authorization endpoint (RFC 8628, section 3.2), e.g.
    {
        'device_code': 'code',
        'user_code': 'BNDJJFXL',
        'verification_uri': 'url',
        'expires_in': 600,
        'interval': 5
    }

    Attributes:
        device_code (str): The device verification code.
        user_code (str): The user-facing code that should be entered on the verification page.
        verification_uri (str): The URL where the user should enter the user_code.
        expires_in (int): The lifetime in seconds of the device code and user code.
        interval (int): The minimum amount of time in seconds to wait between polling requests.
            RFC 8628 makes this field optional with a default of 5 seconds.
    """

    device_code: str
    user_code: str
    verification_uri: str
    expires_in: int
    interval: int = 5

    @classmethod
    def from_json_response(cls, j: typing.Dict) -> "DeviceCodeResponse":
        """
        Create a DeviceCodeResponse instance from a JSON response dictionary.

        :param j: The JSON response dictionary containing device code information
        :return: A new instance with values from the JSON response
        :raises KeyError: If a required field (device_code, user_code, expires_in, or the
            verification URI) is missing
        """
        return cls(
            device_code=j["device_code"],
            user_code=j["user_code"],
            # Some providers spell this field ``verification_url``; accept either.
            verification_uri=j["verification_uri"] if "verification_uri" in j else j["verification_url"],
            expires_in=j["expires_in"],
            # RFC 8628 §3.2: "If no value is provided, clients MUST use 5 as the default."
            interval=j.get("interval", 5),
        )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def get_basic_authorization_header(client_id: str, client_secret: str) -> str:
    """
    Build an HTTP Basic ``Authorization`` header value for an OAuth client.

    The secret is form-URL-encoded with ``urllib.parse.quote_plus`` so characters such as
    ``:`` or spaces cannot break the ``id:secret`` pair. The client id is used verbatim. The
    pair is then base64-encoded and prefixed with ``"Basic "``.

    :param client_id: The client ID for authentication
    :param client_secret: The client secret for authentication
    :return: The header value, e.g. ``"Basic aWQ6c2VjcmV0"`` for ``("id", "secret")``
    :rtype: str
    """
    pair = f"{client_id}:{urllib.parse.quote_plus(client_secret)}"
    token = base64.b64encode(pair.encode(utf_8))
    return "Basic " + token.decode(utf_8)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
async def get_token(
    token_endpoint: str,
    http_session: httpx.AsyncClient,
    scopes: typing.Optional[typing.List[str]] = None,
    authorization_header: typing.Optional[str] = None,
    client_id: typing.Optional[str] = None,
    device_code: typing.Optional[str] = None,
    audience: typing.Optional[str] = None,
    grant_type: GrantType = GrantType.CLIENT_CREDS,
    http_proxy_url: typing.Optional[str] = None,
    verify: typing.Optional[typing.Union[bool, str]] = None,
    refresh_token: typing.Optional[str] = None,
) -> typing.Tuple[str, typing.Optional[str], int]:
    """
    Request an access token from an OAuth 2.0 token endpoint.

    :param token_endpoint: The URL of the token endpoint
    :param http_session: HTTP client the request is sent through
    :param scopes: Scopes to request, sent as one space-separated ``scope`` form field
    :param authorization_header: ``Authorization`` header value (e.g. from
        ``get_basic_authorization_header`` when using client credentials)
    :param client_id: The client ID to include in the request body
    :param device_code: The device code, when using the device code grant
    :param audience: The audience value to request
    :param grant_type: The OAuth grant type to use (default: ``GrantType.CLIENT_CREDS``)
    :param http_proxy_url: Unused by this function; the request goes through ``http_session``
    :param verify: Unused by this function; the request goes through ``http_session``
    :param refresh_token: Refresh token, when using the refresh token grant
    :return: A tuple ``(access_token, refresh_token, expires_in)``. ``refresh_token`` is
        None when the server did not return one.

    Raises:
        AuthenticationPending: The device code grant is still waiting for the user
            (server error ``authorization_pending`` or ``slow_down``).
        AuthenticationError: The server rejected the request, including when the error
            body is not JSON.
    """
    headers = {
        "Cache-Control": "no-cache",
        "Accept": "application/json",
        "Content-Type": "application/x-www-form-urlencoded",
    }
    if authorization_header:
        headers["Authorization"] = authorization_header
    body = {
        "grant_type": grant_type.value,
    }
    if client_id:
        body["client_id"] = client_id
    if device_code:
        body["device_code"] = device_code
    if scopes is not None:
        # Strips quote/bracket characters from entries; presumably tolerates scopes coming
        # from a stringified list. TODO confirm with callers.
        body["scope"] = " ".join(s.strip("' ") for s in scopes).strip("[]'")
    if audience:
        body["audience"] = audience
    if refresh_token:
        body["refresh_token"] = refresh_token

    response = await http_session.post(token_endpoint, data=body, headers=headers)

    if not response.is_success:
        try:
            error_body = response.json()
        except ValueError:
            # Error responses are not guaranteed to be JSON (e.g. an HTML page from a proxy).
            # Fall through to the generic AuthenticationError instead of leaking a decode error.
            error_body = None
        err = error_body.get("error") if isinstance(error_body, dict) else None
        if err == error_auth_pending or err == error_slow_down:
            raise AuthenticationPending(f"Token not yet available, try again in some time {err}")
        logger.error("Status Code ({}) received from IDP: {}".format(response.status_code, response.text))
        raise AuthenticationError("Status Code ({}) received from IDP: {}".format(response.status_code, response.text))

    j = response.json()
    new_refresh_token = j.get("refresh_token")

    # NOTE: ``expires_in`` is only RECOMMENDED by RFC 6749 §5.1; a KeyError is raised if the
    # server omits it.
    return j["access_token"], new_refresh_token, j["expires_in"]
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
async def get_device_code(
    device_auth_endpoint: str,
    client_id: str,
    http_session: httpx.AsyncClient,
    *,
    audience: typing.Optional[str] = None,
    scopes: typing.Optional[typing.List[str]] = None,
) -> DeviceCodeResponse:
    """
    Retrieves the device authentication code that can be used to authenticate the request using a browser on a
    separate device.

    :param device_auth_endpoint: The URL of the device authorization endpoint
    :param client_id: The client ID to use for authentication
    :param http_session: An existing HTTP client session used to send the request
    :param audience: The audience value to request; omitted from the request when not set
    :param scopes: List of scopes to request; an empty scope is sent when not set
    :return: An object containing the device code and related information
    :raises AuthenticationError: When device code retrieval fails
    """
    _scope = " ".join(s.strip("' ") for s in scopes).strip("[]'") if scopes is not None else ""
    payload = {"client_id": client_id, "scope": _scope}
    # httpx form-encodes a None value as an empty string, so only send audience when one was given
    # (the same rule get_token uses when building its request body).
    if audience:
        payload["audience"] = audience
    resp = await http_session.post(device_auth_endpoint, data=payload)
    if not resp.is_success:
        # The error body is not guaranteed to be JSON; fall back to the raw text instead of letting a
        # decode error hide the real failure.
        try:
            reason = resp.json()
        except ValueError:
            reason = resp.text
        raise AuthenticationError(
            f"Unable to retrieve Device Authentication Code for {payload},"
            f" Status Code {resp.status_code} Reason {reason}"
        )
    return DeviceCodeResponse.from_json_response(resp.json())
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
async def poll_token_endpoint(
    resp: DeviceCodeResponse,
    *,
    token_endpoint: str,
    client_id: str,
    http_session: httpx.AsyncClient,
    audience: typing.Optional[str] = None,
    scopes: typing.Optional[typing.List[str]] = None,
    http_proxy_url: typing.Optional[str] = None,
    verify: typing.Optional[typing.Union[bool, str]] = None,
) -> typing.Tuple[str, typing.Optional[str], int]:
    """
    Polls the token endpoint until authentication is complete or times out.

    This function repeatedly calls the token endpoint at the specified interval until either:
    1. Authentication is successful and a token is returned
    2. The device code expires (as specified in the DeviceCodeResponse)

    :param resp: The device code response from a previous call to get_device_code
    :param token_endpoint: The URL of the token endpoint
    :param client_id: The client ID to use for authentication
    :param http_session: An existing HTTP client session used for the token requests
    :param audience: The audience value to request
    :param scopes: List of scopes to request
    :param http_proxy_url: HTTP proxy URL if needed (forwarded to get_token)
    :param verify: SSL verification mode (forwarded to get_token)
    :return: A tuple containing (access_token, refresh_token, expires_in); refresh_token is None when
        the token endpoint did not return one
    :raises AuthenticationError: When authentication fails or times out
    """
    loop = asyncio.get_running_loop()
    interval = timedelta(seconds=resp.interval)
    # Track the deadline on the event loop's monotonic clock and re-read it on every iteration, so the
    # time spent inside the token requests also counts against the device code lifetime.
    deadline = loop.time() + resp.expires_in
    while loop.time() < deadline:
        try:
            access_token, refresh_token, expires_in = await get_token(
                token_endpoint,
                grant_type=GrantType.DEVICE_CODE,
                client_id=client_id,
                audience=audience,
                scopes=scopes,
                device_code=resp.device_code,
                http_proxy_url=http_proxy_url,
                verify=verify,
                http_session=http_session,
            )
            logger.debug(f"Authentication successful, access token received, expires in {expires_in} seconds")
            return access_token, refresh_token, expires_in
        except AuthenticationPending:
            # The user has not completed authentication yet; wait and poll again.
            pass
        except Exception as e:
            logger.warning(f"Authentication failed, reason {e}")
            raise
        logger.debug(f"Authentication pending, ..., waiting for {resp.interval} seconds")
        await asyncio.sleep(interval.total_seconds())
    raise AuthenticationError("Authentication failed!")
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
class AccessTokenNotFoundError(RuntimeError):
    """
    This error is raised when the access token is not found or when refreshing the token fails.
    """
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class AuthenticationError(RuntimeError):
    """Raised when authentication fails, for whatever reason."""
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AuthenticationPending(RuntimeError):
    """Raised when the token endpoint reports that authentication is still pending (or asks the client to slow down)."""
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import grpc
|
|
4
|
+
from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc
|
|
5
|
+
|
|
6
|
+
from union._protos.secret import secret_pb2_grpc
|
|
7
|
+
from union._protos.workflow import run_logs_service_pb2_grpc, run_service_pb2_grpc, task_service_pb2_grpc
|
|
8
|
+
|
|
9
|
+
from ._protocols import (
|
|
10
|
+
DataProxyService,
|
|
11
|
+
MetadataServiceProtocol,
|
|
12
|
+
ProjectDomainService,
|
|
13
|
+
RunLogsService,
|
|
14
|
+
RunService,
|
|
15
|
+
SecretService,
|
|
16
|
+
TaskService,
|
|
17
|
+
)
|
|
18
|
+
from .auth import create_channel
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ClientSet:
    """
    Bundle of gRPC service stubs that all share a single channel.

    :param channel: The channel every stub is bound to; closed by :meth:`close`.
    :param data_proxy_channel: Accepted for interface compatibility but currently unused.
    """

    def __init__(self, channel: grpc.aio.Channel, data_proxy_channel: grpc.aio.Channel | None = None):
        self._channel = channel
        # One admin stub backs both metadata_service and project_domain_service.
        self._admin_client = admin_pb2_grpc.AdminServiceStub(channel=channel)
        self._task_service = task_service_pb2_grpc.TaskServiceStub(channel=channel)
        self._run_service = run_service_pb2_grpc.RunServiceStub(channel=channel)
        # NOTE(review): data_proxy_channel is ignored; the data proxy stub is bound to ``channel`` like
        # every other stub. Confirm whether it was meant to use data_proxy_channel when provided.
        self._dataproxy = dataproxy_pb2_grpc.DataProxyServiceStub(channel=channel)
        self._log_service = run_logs_service_pb2_grpc.RunLogsServiceStub(channel=channel)
        self._secrets_service = secret_pb2_grpc.SecretServiceStub(channel=channel)

    @classmethod
    async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet:
        """
        Create a ClientSet whose stubs are connected to ``endpoint``.

        :param endpoint: Address of the service to connect to.
        :param insecure: Whether to create an insecure channel. When True, the auth, proxy and retry
            options listed in the body are stripped from ``kwargs`` before calling ``create_channel``.
        :param kwargs: Additional options forwarded to ``create_channel``.
        :return: A new ClientSet bound to the created channel.
        """
        if insecure:
            # pop() instead of del: a caller that omits any of these options must not get a KeyError.
            for key in (
                "api_key",
                "auth_type",
                "headless",
                "command",
                "client_id",
                "client_credentials_secret",
                "client_config",
                "rpc_retries",
                "http_proxy_url",
            ):
                kwargs.pop(key, None)
        return cls(await create_channel(endpoint, insecure=insecure, **kwargs))

    @classmethod
    async def for_api_key(cls, api_key: str, **kwargs) -> ClientSet:
        """Not implemented yet."""
        raise NotImplementedError

    @classmethod
    async def for_serverless(cls) -> ClientSet:
        """Not implemented yet."""
        raise NotImplementedError

    @classmethod
    async def from_env(cls) -> ClientSet:
        """Not implemented yet."""
        raise NotImplementedError

    @property
    def metadata_service(self) -> MetadataServiceProtocol:
        """Admin service stub, exposed as the metadata service."""
        return self._admin_client

    @property
    def project_domain_service(self) -> ProjectDomainService:
        """Admin service stub, exposed as the project/domain service."""
        return self._admin_client

    @property
    def task_service(self) -> TaskService:
        """Task service stub."""
        return self._task_service

    @property
    def run_service(self) -> RunService:
        """Run service stub."""
        return self._run_service

    @property
    def dataproxy_service(self) -> DataProxyService:
        """Data proxy service stub."""
        return self._dataproxy

    @property
    def logs_service(self) -> RunLogsService:
        """Run logs service stub."""
        return self._log_service

    @property
    def secrets_service(self) -> SecretService:
        """Secret service stub."""
        return self._secrets_service

    async def close(self, grace: float | None = None):
        """
        Close the shared channel.

        :param grace: Forwarded to the channel's ``close``.
        """
        return await self._channel.close(grace=grace)
|
union/remote/_data.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import hashlib
|
|
3
|
+
import os
|
|
4
|
+
import typing
|
|
5
|
+
import uuid
|
|
6
|
+
from base64 import b64encode
|
|
7
|
+
from datetime import timedelta
|
|
8
|
+
from functools import lru_cache
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Tuple
|
|
11
|
+
|
|
12
|
+
import aiofiles
|
|
13
|
+
import grpc
|
|
14
|
+
import httpx
|
|
15
|
+
from flyteidl.service import dataproxy_pb2
|
|
16
|
+
from google.protobuf import duration_pb2
|
|
17
|
+
|
|
18
|
+
from union._initialize import CommonInit, get_client, get_common_config, requires_client
|
|
19
|
+
from union.errors import RuntimeSystemError
|
|
20
|
+
|
|
21
|
+
# Expiry requested for the signed upload URLs (sent as ``expires_in`` in CreateUploadLocationRequest).
_UPLOAD_EXPIRES_IN = timedelta(seconds=60)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_extra_headers_for_protocol(native_url: str) -> typing.Dict[str, str]:
    """
    Return the extra HTTP headers needed when uploading through a signed URL to the given storage.

    For Azure Blob Storage, PUT requests to a signed URL must carry the ``x-ms-blob-type`` header;
    other storage backends need no extra headers.

    :param native_url: Native storage URI of the upload target (e.g. ``s3://...``, ``abfs://...``).
    :return: A new dict of headers on every call, so callers may mutate it freely.
    """
    # Both the plain and the TLS ("abfss") Azure Blob File System schemes point at Azure storage.
    if native_url.startswith(("abfs://", "abfss://")):
        return {"x-ms-blob-type": "BlockBlob"}
    return {}
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Read files in 1 MiB chunks: md5's block_size is only 64 bytes, and looping in Python over such tiny
# reads makes hashing large files needlessly slow.
_HASH_CHUNK_SIZE = 1024 * 1024


@lru_cache
def _hash_file_cached(path: typing.Union[str, bytes], mtime_ns: int, size: int) -> typing.Tuple[bytes, str, int]:
    """
    Hash the file at ``path``. ``mtime_ns`` and ``size`` are only part of the cache key, so that a
    file modified on disk is hashed again instead of returning a stale cached digest.
    """
    h = hashlib.md5()
    total = 0
    with open(path, "rb") as file:
        while chunk := file.read(_HASH_CHUNK_SIZE):
            h.update(chunk)
            total += len(chunk)
    return h.digest(), h.hexdigest(), total


def hash_file(file_path: typing.Union[os.PathLike, str]) -> typing.Tuple[bytes, str, int]:
    """
    Hash a file and produce a digest to be used as a version.

    Results are cached per (path, modification time, size): repeated calls for an unchanged file do
    not re-read it, while a file that changed on disk is hashed again. A rewrite that keeps both the
    size and the mtime (within the filesystem's timestamp resolution) can still hit the cache.

    :param file_path: Path of the file to hash.
    :return: A tuple of (raw MD5 digest, hex MD5 digest, number of bytes read).
    :raises OSError: If the file cannot be stat'ed or opened.
    """
    st = os.stat(file_path)
    return _hash_file_cached(os.fspath(file_path), st.st_mtime_ns, st.st_size)


# Keep the cache-management API that the previously ``lru_cache``-decorated hash_file exposed.
hash_file.cache_clear = _hash_file_cached.cache_clear
hash_file.cache_info = _hash_file_cached.cache_info
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
async def _upload_single_file(
    cfg: CommonInit, fp: Path, verify: bool = True, basedir: str | None = None
) -> Tuple[str, str]:
    """
    Upload one local file through a signed URL obtained from the data proxy service.

    :param cfg: Common configuration; its ``project`` and ``domain`` scope the upload location.
    :param fp: Path of the local file to upload.
    :param verify: Whether to verify TLS certificates for the HTTP PUT to the signed URL.
    :param basedir: Optional ``filename_root`` under which the file is placed remotely.
    :return: A tuple of (hex MD5 digest of the file, native URL of the uploaded object).
    :raises RuntimeSystemError: If the signed URL cannot be obtained or the upload request is rejected.
    """
    md5_bytes, str_digest, _ = hash_file(fp)

    try:
        expires_in_pb = duration_pb2.Duration()
        expires_in_pb.FromTimedelta(_UPLOAD_EXPIRES_IN)
        client = get_client()
        resp = await client.dataproxy_service.CreateUploadLocation(
            dataproxy_pb2.CreateUploadLocationRequest(
                project=cfg.project,
                domain=cfg.domain,
                content_md5=md5_bytes,
                filename=fp.name,
                expires_in=expires_in_pb,
                filename_root=basedir,
                add_content_md5_metadata=True,
            )
        )
    except grpc.aio.AioRpcError as e:
        # Chain the gRPC error so its details stay visible in the traceback.
        if e.code() == grpc.StatusCode.NOT_FOUND:
            raise RuntimeSystemError(
                "NotFound", f"Failed to get signed url for {fp}, please check your project and domain."
            ) from e
        elif e.code() == grpc.StatusCode.PERMISSION_DENIED:
            raise RuntimeSystemError(
                "PermissionDenied", f"Failed to get signed url for {fp}, please check your permissions."
            ) from e
        else:
            raise RuntimeSystemError(e.code().value, f"Failed to get signed url for {fp}.") from e
    except Exception as e:
        raise RuntimeSystemError(type(e).__name__, f"Failed to get signed url for {fp}.") from e

    extra_headers = get_extra_headers_for_protocol(resp.native_url)
    extra_headers.update(resp.headers)
    encoded_md5 = b64encode(md5_bytes)
    content_length = fp.stat().st_size
    extra_headers.update({"Content-Length": str(content_length), "Content-MD5": encoded_md5})

    # TODO in old code we did this
    # if self._config.platform.insecure_skip_verify is True
    # else self._config.platform.ca_cert_file_path,
    async with aiofiles.open(str(fp), "rb") as file:
        async with httpx.AsyncClient(verify=verify) as http_client:
            put_resp = await http_client.put(resp.signed_url, headers=extra_headers, content=file)

    # A rejected PUT (e.g. expired signed URL or MD5 mismatch) must not be reported as a successful upload.
    if not put_resp.is_success:
        raise RuntimeSystemError(
            "UploadFailed",
            f"Failed to upload {fp} to {resp.native_url}, status code: {put_resp.status_code},"
            f" response: {put_resp.text}",
        )

    return str_digest, resp.native_url
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
@requires_client
async def upload_file(fp: Path, verify: bool = True) -> Tuple[str, str]:
    """
    Upload a single local file and report where it was stored.

    :param fp: Path of the local file to upload; must be a file, not a directory.
    :param verify: Whether to verify the certificate for HTTPS requests.
    :return: A tuple of (MD5 hex digest of the file, remote URI of the uploaded object).
    :raises ValueError: If ``fp`` is not a file.
    """
    common_cfg = get_common_config()
    if fp.is_file():
        return await _upload_single_file(common_cfg, fp, verify=verify)
    raise ValueError(f"{fp} is not a single file, upload arg must be a single file.")
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
@requires_client
async def upload_dir(dir_path: Path, verify: bool = True) -> str:
    """
    Uploads a directory to a remote location and returns the remote URI.

    Every file found recursively under ``dir_path`` is uploaded concurrently under a fresh random
    prefix, and the URI of that shared prefix is returned.

    :param dir_path: The directory path to upload.
    :param verify: Whether to verify the certificate for HTTPS requests.
    :return: The remote URI of the uploaded directory (the shared prefix, without a trailing slash).
    :raises ValueError: If ``dir_path`` is not a directory or contains no files.
    :raises RuntimeSystemError: If an upload fails or the returned location does not contain the prefix.
    """
    cfg = get_common_config()
    if not dir_path.is_dir():
        raise ValueError(f"{dir_path} is not a directory, upload arg must be a directory.")

    prefix = uuid.uuid4().hex

    files = [f for f in dir_path.rglob("*") if f.is_file()]
    if not files:
        # Without at least one uploaded object there is no native URL to derive the prefix URI from.
        raise ValueError(f"{dir_path} contains no files, nothing to upload.")

    # NOTE(review): _upload_single_file names each object by ``fp.name`` only, so files that share a
    # name in different subdirectories appear to map to the same remote key — confirm.
    urls = await asyncio.gather(*(_upload_single_file(cfg, f, verify=verify, basedir=prefix) for f in files))
    native_url = urls[0][1]  # All files are uploaded under the same prefix.
    # native_url is of the form s3://my-s3-bucket/flytesnacks/development/{prefix}/source/empty.md.
    # Keep everything up to and including the prefix; the part before it already ends with "/", so
    # no separator is added (the old split + "/" + prefix produced a doubled slash).
    head, sep, _ = native_url.partition(prefix)
    if not sep:
        raise RuntimeSystemError(
            "UnexpectedUploadLocation", f"Upload location {native_url} does not contain prefix {prefix}."
        )
    return head + prefix
|