flyte 2.0.0b32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic; see the package's registry page for more details.

Files changed (204)
  1. flyte/__init__.py +108 -0
  2. flyte/_bin/__init__.py +0 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +195 -0
  5. flyte/_bin/serve.py +178 -0
  6. flyte/_build.py +26 -0
  7. flyte/_cache/__init__.py +12 -0
  8. flyte/_cache/cache.py +147 -0
  9. flyte/_cache/defaults.py +9 -0
  10. flyte/_cache/local_cache.py +216 -0
  11. flyte/_cache/policy_function_body.py +42 -0
  12. flyte/_code_bundle/__init__.py +8 -0
  13. flyte/_code_bundle/_ignore.py +121 -0
  14. flyte/_code_bundle/_packaging.py +218 -0
  15. flyte/_code_bundle/_utils.py +347 -0
  16. flyte/_code_bundle/bundle.py +266 -0
  17. flyte/_constants.py +1 -0
  18. flyte/_context.py +155 -0
  19. flyte/_custom_context.py +73 -0
  20. flyte/_debug/__init__.py +0 -0
  21. flyte/_debug/constants.py +38 -0
  22. flyte/_debug/utils.py +17 -0
  23. flyte/_debug/vscode.py +307 -0
  24. flyte/_deploy.py +408 -0
  25. flyte/_deployer.py +109 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +122 -0
  29. flyte/_excepthook.py +37 -0
  30. flyte/_group.py +32 -0
  31. flyte/_hash.py +8 -0
  32. flyte/_image.py +1055 -0
  33. flyte/_initialize.py +628 -0
  34. flyte/_interface.py +119 -0
  35. flyte/_internal/__init__.py +3 -0
  36. flyte/_internal/controllers/__init__.py +129 -0
  37. flyte/_internal/controllers/_local_controller.py +239 -0
  38. flyte/_internal/controllers/_trace.py +48 -0
  39. flyte/_internal/controllers/remote/__init__.py +58 -0
  40. flyte/_internal/controllers/remote/_action.py +211 -0
  41. flyte/_internal/controllers/remote/_client.py +47 -0
  42. flyte/_internal/controllers/remote/_controller.py +583 -0
  43. flyte/_internal/controllers/remote/_core.py +465 -0
  44. flyte/_internal/controllers/remote/_informer.py +381 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +3 -0
  47. flyte/_internal/imagebuild/docker_builder.py +706 -0
  48. flyte/_internal/imagebuild/image_builder.py +277 -0
  49. flyte/_internal/imagebuild/remote_builder.py +386 -0
  50. flyte/_internal/imagebuild/utils.py +78 -0
  51. flyte/_internal/resolvers/__init__.py +0 -0
  52. flyte/_internal/resolvers/_task_module.py +21 -0
  53. flyte/_internal/resolvers/common.py +31 -0
  54. flyte/_internal/resolvers/default.py +28 -0
  55. flyte/_internal/runtime/__init__.py +0 -0
  56. flyte/_internal/runtime/convert.py +486 -0
  57. flyte/_internal/runtime/entrypoints.py +204 -0
  58. flyte/_internal/runtime/io.py +188 -0
  59. flyte/_internal/runtime/resources_serde.py +152 -0
  60. flyte/_internal/runtime/reuse.py +125 -0
  61. flyte/_internal/runtime/rusty.py +193 -0
  62. flyte/_internal/runtime/task_serde.py +362 -0
  63. flyte/_internal/runtime/taskrunner.py +209 -0
  64. flyte/_internal/runtime/trigger_serde.py +160 -0
  65. flyte/_internal/runtime/types_serde.py +54 -0
  66. flyte/_keyring/__init__.py +0 -0
  67. flyte/_keyring/file.py +115 -0
  68. flyte/_logging.py +300 -0
  69. flyte/_map.py +312 -0
  70. flyte/_module.py +72 -0
  71. flyte/_pod.py +30 -0
  72. flyte/_resources.py +473 -0
  73. flyte/_retry.py +32 -0
  74. flyte/_reusable_environment.py +102 -0
  75. flyte/_run.py +724 -0
  76. flyte/_secret.py +96 -0
  77. flyte/_task.py +550 -0
  78. flyte/_task_environment.py +316 -0
  79. flyte/_task_plugins.py +47 -0
  80. flyte/_timeout.py +47 -0
  81. flyte/_tools.py +27 -0
  82. flyte/_trace.py +119 -0
  83. flyte/_trigger.py +1000 -0
  84. flyte/_utils/__init__.py +30 -0
  85. flyte/_utils/asyn.py +121 -0
  86. flyte/_utils/async_cache.py +139 -0
  87. flyte/_utils/coro_management.py +27 -0
  88. flyte/_utils/docker_credentials.py +173 -0
  89. flyte/_utils/file_handling.py +72 -0
  90. flyte/_utils/helpers.py +134 -0
  91. flyte/_utils/lazy_module.py +54 -0
  92. flyte/_utils/module_loader.py +104 -0
  93. flyte/_utils/org_discovery.py +57 -0
  94. flyte/_utils/uv_script_parser.py +49 -0
  95. flyte/_version.py +34 -0
  96. flyte/app/__init__.py +22 -0
  97. flyte/app/_app_environment.py +157 -0
  98. flyte/app/_deploy.py +125 -0
  99. flyte/app/_input.py +160 -0
  100. flyte/app/_runtime/__init__.py +3 -0
  101. flyte/app/_runtime/app_serde.py +347 -0
  102. flyte/app/_types.py +101 -0
  103. flyte/app/extras/__init__.py +3 -0
  104. flyte/app/extras/_fastapi.py +151 -0
  105. flyte/cli/__init__.py +12 -0
  106. flyte/cli/_abort.py +28 -0
  107. flyte/cli/_build.py +114 -0
  108. flyte/cli/_common.py +468 -0
  109. flyte/cli/_create.py +371 -0
  110. flyte/cli/_delete.py +45 -0
  111. flyte/cli/_deploy.py +293 -0
  112. flyte/cli/_gen.py +176 -0
  113. flyte/cli/_get.py +370 -0
  114. flyte/cli/_option.py +33 -0
  115. flyte/cli/_params.py +554 -0
  116. flyte/cli/_plugins.py +209 -0
  117. flyte/cli/_run.py +597 -0
  118. flyte/cli/_serve.py +64 -0
  119. flyte/cli/_update.py +37 -0
  120. flyte/cli/_user.py +17 -0
  121. flyte/cli/main.py +221 -0
  122. flyte/config/__init__.py +3 -0
  123. flyte/config/_config.py +248 -0
  124. flyte/config/_internal.py +73 -0
  125. flyte/config/_reader.py +225 -0
  126. flyte/connectors/__init__.py +11 -0
  127. flyte/connectors/_connector.py +270 -0
  128. flyte/connectors/_server.py +197 -0
  129. flyte/connectors/utils.py +135 -0
  130. flyte/errors.py +243 -0
  131. flyte/extend.py +19 -0
  132. flyte/extras/__init__.py +5 -0
  133. flyte/extras/_container.py +286 -0
  134. flyte/git/__init__.py +3 -0
  135. flyte/git/_config.py +21 -0
  136. flyte/io/__init__.py +29 -0
  137. flyte/io/_dataframe/__init__.py +131 -0
  138. flyte/io/_dataframe/basic_dfs.py +223 -0
  139. flyte/io/_dataframe/dataframe.py +1026 -0
  140. flyte/io/_dir.py +910 -0
  141. flyte/io/_file.py +914 -0
  142. flyte/io/_hashing_io.py +342 -0
  143. flyte/models.py +479 -0
  144. flyte/py.typed +0 -0
  145. flyte/remote/__init__.py +35 -0
  146. flyte/remote/_action.py +738 -0
  147. flyte/remote/_app.py +57 -0
  148. flyte/remote/_client/__init__.py +0 -0
  149. flyte/remote/_client/_protocols.py +189 -0
  150. flyte/remote/_client/auth/__init__.py +12 -0
  151. flyte/remote/_client/auth/_auth_utils.py +14 -0
  152. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  153. flyte/remote/_client/auth/_authenticators/base.py +403 -0
  154. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  155. flyte/remote/_client/auth/_authenticators/device_code.py +117 -0
  156. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  157. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  158. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  159. flyte/remote/_client/auth/_channel.py +213 -0
  160. flyte/remote/_client/auth/_client_config.py +85 -0
  161. flyte/remote/_client/auth/_default_html.py +32 -0
  162. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  163. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  164. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  165. flyte/remote/_client/auth/_keyring.py +152 -0
  166. flyte/remote/_client/auth/_token_client.py +260 -0
  167. flyte/remote/_client/auth/errors.py +16 -0
  168. flyte/remote/_client/controlplane.py +128 -0
  169. flyte/remote/_common.py +30 -0
  170. flyte/remote/_console.py +19 -0
  171. flyte/remote/_data.py +161 -0
  172. flyte/remote/_logs.py +185 -0
  173. flyte/remote/_project.py +88 -0
  174. flyte/remote/_run.py +386 -0
  175. flyte/remote/_secret.py +142 -0
  176. flyte/remote/_task.py +527 -0
  177. flyte/remote/_trigger.py +306 -0
  178. flyte/remote/_user.py +33 -0
  179. flyte/report/__init__.py +3 -0
  180. flyte/report/_report.py +182 -0
  181. flyte/report/_template.html +124 -0
  182. flyte/storage/__init__.py +36 -0
  183. flyte/storage/_config.py +237 -0
  184. flyte/storage/_parallel_reader.py +274 -0
  185. flyte/storage/_remote_fs.py +34 -0
  186. flyte/storage/_storage.py +456 -0
  187. flyte/storage/_utils.py +5 -0
  188. flyte/syncify/__init__.py +56 -0
  189. flyte/syncify/_api.py +375 -0
  190. flyte/types/__init__.py +52 -0
  191. flyte/types/_interface.py +40 -0
  192. flyte/types/_pickle.py +145 -0
  193. flyte/types/_renderer.py +162 -0
  194. flyte/types/_string_literals.py +119 -0
  195. flyte/types/_type_engine.py +2254 -0
  196. flyte/types/_utils.py +80 -0
  197. flyte-2.0.0b32.data/scripts/debug.py +38 -0
  198. flyte-2.0.0b32.data/scripts/runtime.py +195 -0
  199. flyte-2.0.0b32.dist-info/METADATA +351 -0
  200. flyte-2.0.0b32.dist-info/RECORD +204 -0
  201. flyte-2.0.0b32.dist-info/WHEEL +5 -0
  202. flyte-2.0.0b32.dist-info/entry_points.txt +7 -0
  203. flyte-2.0.0b32.dist-info/licenses/LICENSE +201 -0
  204. flyte-2.0.0b32.dist-info/top_level.txt +1 -0
flyte/remote/_app.py ADDED
@@ -0,0 +1,57 @@
1
+ from __future__ import annotations
2
+
3
+ from flyteidl2.app import app_definition_pb2, app_payload_pb2
4
+
5
+ from flyte._initialize import ensure_client, get_client, get_init_config
6
+ from flyte.syncify import syncify
7
+
8
+ from ._common import ToJSONMixin
9
+
10
+
11
class App(ToJSONMixin):
    """Remote handle for a deployed app, wrapping its ``app_definition_pb2.App`` message."""

    pb2: app_definition_pb2.App

    def __init__(self, pb2: app_definition_pb2.App):
        self.pb2 = pb2

    @property
    def name(self) -> str:
        """The app's name, read from ``metadata.id.name``."""
        metadata = self.pb2.metadata
        return metadata.id.name

    @property
    def revision(self) -> int:
        """The revision recorded in the app's metadata."""
        metadata = self.pb2.metadata
        return metadata.revision

    @property
    def endpoint(self) -> str:
        """The public URL reported by the app's ingress status."""
        ingress = self.pb2.status.ingress
        return ingress.public_url

    @classmethod
    @syncify
    async def get(
        cls,
        name: str,
        project: str | None = None,
        domain: str | None = None,
    ) -> App:
        """
        Fetch an app by name.

        ``project`` and ``domain`` fall back to the values from the active init config
        when they are not provided (or are empty).

        :param name: The name of the app.
        :param project: The project of the app.
        :param domain: The domain of the app.
        :return: The app remote object.
        """
        ensure_client()
        init_cfg = get_init_config()
        app_service = get_client().app_service
        app_identifier = app_definition_pb2.Identifier(
            org=init_cfg.org,
            project=project if project else init_cfg.project,
            domain=domain if domain else init_cfg.domain,
            name=name,
        )
        get_request = app_payload_pb2.GetRequest(app_id=app_identifier)
        response = await app_service.Get(request=get_request)
        return cls(pb2=response.app)
File without changes
@@ -0,0 +1,189 @@
1
+ from typing import AsyncIterator, Protocol
2
+
3
+ from flyteidl.admin import project_attributes_pb2, project_pb2, version_pb2
4
+ from flyteidl.service import dataproxy_pb2, identity_pb2
5
+ from flyteidl2.app import app_payload_pb2
6
+ from flyteidl2.secret import payload_pb2
7
+ from flyteidl2.task import task_service_pb2
8
+ from flyteidl2.trigger import trigger_service_pb2
9
+ from flyteidl2.workflow import run_logs_service_pb2, run_service_pb2
10
+ from grpc.aio import UnaryStreamCall
11
+ from grpc.aio._typing import RequestType
12
+
13
+
14
class MetadataServiceProtocol(Protocol):
    """Protocol for the metadata service stub: exposes the server version lookup."""

    async def GetVersion(
        self,
        request: version_pb2.GetVersionRequest,
    ) -> version_pb2.GetVersionResponse: ...
16
+
17
+
18
class ProjectDomainService(Protocol):
    """Protocol for the project/domain service stub: project CRUD, domains, and attributes."""

    # --- projects ---------------------------------------------------------------

    async def RegisterProject(
        self,
        request: project_pb2.ProjectRegisterRequest,
    ) -> project_pb2.ProjectRegisterResponse: ...

    async def UpdateProject(
        self,
        request: project_pb2.Project,
    ) -> project_pb2.ProjectUpdateResponse: ...

    async def GetProject(
        self,
        request: project_pb2.ProjectGetRequest,
    ) -> project_pb2.Project: ...

    async def ListProjects(
        self,
        request: project_pb2.ProjectListRequest,
    ) -> project_pb2.Projects: ...

    async def GetDomains(
        self,
        request: project_pb2.GetDomainRequest,
    ) -> project_pb2.GetDomainsResponse: ...

    # --- project-domain attributes ----------------------------------------------
    # NOTE(review): these reuse the project-level ``project_attributes_pb2`` request types,
    # and UpdateProjectDomainAttributes is annotated with ``project_pb2.ProjectUpdateResponse``.
    # Confirm against the service definition; domain-level RPCs may use dedicated message types.

    async def UpdateProjectDomainAttributes(
        self,
        request: project_attributes_pb2.ProjectAttributesUpdateRequest,
    ) -> project_pb2.ProjectUpdateResponse: ...

    async def GetProjectDomainAttributes(
        self,
        request: project_attributes_pb2.ProjectAttributesGetRequest,
    ) -> project_attributes_pb2.ProjectAttributes: ...

    async def DeleteProjectDomainAttributes(
        self,
        request: project_attributes_pb2.ProjectAttributesDeleteRequest,
    ) -> project_attributes_pb2.ProjectAttributesDeleteResponse: ...

    # --- project attributes -----------------------------------------------------

    async def UpdateProjectAttributes(
        self,
        request: project_attributes_pb2.ProjectAttributesUpdateRequest,
    ) -> project_attributes_pb2.ProjectAttributesUpdateResponse: ...

    async def GetProjectAttributes(
        self,
        request: project_attributes_pb2.ProjectAttributesGetRequest,
    ) -> project_attributes_pb2.ProjectAttributes: ...

    async def DeleteProjectAttributes(
        self,
        request: project_attributes_pb2.ProjectAttributesDeleteRequest,
    ) -> project_attributes_pb2.ProjectAttributesDeleteResponse: ...
54
+
55
+
56
class TaskService(Protocol):
    """Protocol for the task service stub: deploy, inspect, and list tasks."""

    async def DeployTask(
        self,
        request: task_service_pb2.DeployTaskRequest,
    ) -> task_service_pb2.DeployTaskResponse: ...

    async def GetTaskDetails(
        self,
        request: task_service_pb2.GetTaskDetailsRequest,
    ) -> task_service_pb2.GetTaskDetailsResponse: ...

    async def ListTasks(
        self,
        request: task_service_pb2.ListTasksRequest,
    ) -> task_service_pb2.ListTasksResponse: ...
64
+
65
+
66
class AppService(Protocol):
    """Protocol for the app service stub: app lifecycle, status, listing, and watching."""

    async def Create(
        self,
        request: app_payload_pb2.CreateRequest,
    ) -> app_payload_pb2.CreateResponse: ...

    async def Get(
        self,
        request: app_payload_pb2.GetRequest,
    ) -> app_payload_pb2.GetResponse: ...

    async def Update(
        self,
        request: app_payload_pb2.UpdateRequest,
    ) -> app_payload_pb2.UpdateResponse: ...

    async def UpdateStatus(
        self,
        request: app_payload_pb2.UpdateStatusRequest,
    ) -> app_payload_pb2.UpdateStatusResponse: ...

    async def Delete(
        self,
        request: app_payload_pb2.DeleteRequest,
    ) -> app_payload_pb2.DeleteResponse: ...

    async def List(
        self,
        request: app_payload_pb2.ListRequest,
    ) -> app_payload_pb2.ListResponse: ...

    # NOTE(review): Watch (and possibly Lease) may be server-streaming RPCs. If so, they should be
    # typed as plain ``def`` returning an async iterator, like the streaming stubs elsewhere —
    # confirm against the app service definition.
    async def Watch(
        self,
        request: app_payload_pb2.WatchRequest,
    ) -> app_payload_pb2.WatchResponse: ...

    async def Lease(
        self,
        request: app_payload_pb2.LeaseRequest,
    ) -> app_payload_pb2.LeaseResponse: ...
84
+
85
+
86
class RunService(Protocol):
    """
    Protocol for the run service stub: create/abort runs, and fetch, list or watch runs and actions.

    The server-streaming ``Watch*`` methods are declared with plain ``def`` (not ``async def``).
    Calling such a stub returns an async iterator directly rather than a coroutine, so callers
    consume it with ``async for`` without awaiting first. Declaring them ``async def`` without a
    ``yield`` would instead type them as coroutines that *return* an iterator.
    """

    async def CreateRun(
        self,
        request: run_service_pb2.CreateRunRequest,
    ) -> run_service_pb2.CreateRunResponse: ...

    async def AbortRun(
        self,
        request: run_service_pb2.AbortRunRequest,
    ) -> run_service_pb2.AbortRunResponse: ...

    async def GetRunDetails(
        self,
        request: run_service_pb2.GetRunDetailsRequest,
    ) -> run_service_pb2.GetRunDetailsResponse: ...

    def WatchRunDetails(
        self,
        request: run_service_pb2.WatchRunDetailsRequest,
    ) -> AsyncIterator[run_service_pb2.WatchRunDetailsResponse]: ...

    async def GetActionDetails(
        self,
        request: run_service_pb2.GetActionDetailsRequest,
    ) -> run_service_pb2.GetActionDetailsResponse: ...

    def WatchActionDetails(
        self,
        request: run_service_pb2.WatchActionDetailsRequest,
    ) -> AsyncIterator[run_service_pb2.WatchActionDetailsResponse]: ...

    async def GetActionData(
        self,
        request: run_service_pb2.GetActionDataRequest,
    ) -> run_service_pb2.GetActionDataResponse: ...

    async def ListRuns(
        self,
        request: run_service_pb2.ListRunsRequest,
    ) -> run_service_pb2.ListRunsResponse: ...

    def WatchRuns(
        self,
        request: run_service_pb2.WatchRunsRequest,
    ) -> AsyncIterator[run_service_pb2.WatchRunsResponse]: ...

    async def ListActions(
        self,
        request: run_service_pb2.ListActionsRequest,
    ) -> run_service_pb2.ListActionsResponse: ...

    def WatchActions(
        self,
        request: run_service_pb2.WatchActionsRequest,
    ) -> AsyncIterator[run_service_pb2.WatchActionsResponse]: ...
122
+
123
+
124
class DataProxyService(Protocol):
    """Protocol for the data proxy stub: upload/download locations, download links, and data fetch."""

    async def CreateUploadLocation(
        self,
        request: dataproxy_pb2.CreateUploadLocationRequest,
    ) -> dataproxy_pb2.CreateUploadLocationResponse: ...

    async def CreateDownloadLocation(
        self,
        request: dataproxy_pb2.CreateDownloadLocationRequest,
    ) -> dataproxy_pb2.CreateDownloadLocationResponse: ...

    async def CreateDownloadLink(
        self,
        request: dataproxy_pb2.CreateDownloadLinkRequest,
    ) -> dataproxy_pb2.CreateDownloadLinkResponse: ...

    async def GetData(
        self,
        request: dataproxy_pb2.GetDataRequest,
    ) -> dataproxy_pb2.GetDataResponse: ...
138
+
139
+
140
class RunLogsService(Protocol):
    """Protocol for the run logs stub; TailLogs is a server-streaming call (not awaited)."""

    def TailLogs(
        self,
        request: run_logs_service_pb2.TailLogsRequest,
    ) -> UnaryStreamCall[RequestType, run_logs_service_pb2.TailLogsResponse]: ...
144
+
145
+
146
class SecretService(Protocol):
    """Protocol for the secret service stub: create, update, get, list, and delete secrets."""

    async def CreateSecret(
        self,
        request: payload_pb2.CreateSecretRequest,
    ) -> payload_pb2.CreateSecretResponse: ...

    async def UpdateSecret(
        self,
        request: payload_pb2.UpdateSecretRequest,
    ) -> payload_pb2.UpdateSecretResponse: ...

    async def GetSecret(
        self,
        request: payload_pb2.GetSecretRequest,
    ) -> payload_pb2.GetSecretResponse: ...

    async def ListSecrets(
        self,
        request: payload_pb2.ListSecretsRequest,
    ) -> payload_pb2.ListSecretsResponse: ...

    async def DeleteSecret(
        self,
        request: payload_pb2.DeleteSecretRequest,
    ) -> payload_pb2.DeleteSecretResponse: ...
156
+
157
+
158
class IdentityService(Protocol):
    """Protocol for the identity service stub: information about the calling user."""

    async def UserInfo(
        self,
        request: identity_pb2.UserInfoRequest,
    ) -> identity_pb2.UserInfoResponse: ...
160
+
161
+
162
class TriggerService(Protocol):
    """Protocol for the trigger service stub: deploy, inspect, list, update, and delete triggers."""

    async def DeployTrigger(
        self,
        request: trigger_service_pb2.DeployTriggerRequest,
    ) -> trigger_service_pb2.DeployTriggerResponse: ...

    async def GetTriggerDetails(
        self,
        request: trigger_service_pb2.GetTriggerDetailsRequest,
    ) -> trigger_service_pb2.GetTriggerDetailsResponse: ...

    async def GetTriggerRevisionDetails(
        self,
        request: trigger_service_pb2.GetTriggerRevisionDetailsRequest,
    ) -> trigger_service_pb2.GetTriggerRevisionDetailsResponse: ...

    async def ListTriggers(
        self,
        request: trigger_service_pb2.ListTriggersRequest,
    ) -> trigger_service_pb2.ListTriggersResponse: ...

    async def GetTriggerRevisionHistory(
        self,
        request: trigger_service_pb2.GetTriggerRevisionHistoryRequest,
    ) -> trigger_service_pb2.GetTriggerRevisionHistoryResponse: ...

    async def UpdateTriggers(
        self,
        request: trigger_service_pb2.UpdateTriggersRequest,
    ) -> trigger_service_pb2.UpdateTriggersResponse: ...

    async def DeleteTriggers(
        self,
        request: trigger_service_pb2.DeleteTriggersRequest,
    ) -> trigger_service_pb2.DeleteTriggersResponse: ...
@@ -0,0 +1,12 @@
1
+ from flyte.remote._client.auth._channel import create_channel
2
+ from flyte.remote._client.auth._client_config import AuthType, ClientConfig
3
+ from flyte.remote._client.auth.errors import AccessTokenNotFoundError, AuthenticationError, AuthenticationPending
4
+
5
# Public surface of the auth package: error types, auth/client configuration, and the channel factory.
__all__ = [
    # errors
    "AccessTokenNotFoundError",
    # configuration
    "AuthType",
    # errors
    "AuthenticationError",
    "AuthenticationPending",
    # configuration
    "ClientConfig",
    # channel factory
    "create_channel",
]
@@ -0,0 +1,14 @@
1
+ from __future__ import annotations
2
+
3
+ import base64
4
+ from typing import Literal
5
+
6
+
7
+ def decode_api_key(encoded_str: str) -> tuple[str, str, str, str | Literal["None"]]:
8
+ """Decode encoded base64 string into app credentials. endpoint, client_id, client_secret, org"""
9
+ endpoint, client_id, client_secret, org = base64.b64decode(encoded_str.encode("utf-8")).decode("utf-8").split(":")
10
+ # For consistency, let's make sure org is always a non-empty string
11
+ if not org:
12
+ org = "None"
13
+
14
+ return endpoint, client_id, client_secret, org
File without changes
@@ -0,0 +1,403 @@
1
+ import asyncio
2
+ import dataclasses
3
+ import ssl
4
+ import typing
5
+ from abc import abstractmethod
6
+ from http import HTTPStatus
7
+
8
+ import httpx
9
+ from grpc.aio import Metadata
10
+
11
+ from flyte.remote._client.auth._client_config import ClientConfig, ClientConfigStore
12
+ from flyte.remote._client.auth._keyring import Credentials, KeyringStore
13
+
14
+
15
@dataclasses.dataclass
class GrpcAuthMetadata:
    """gRPC auth metadata for one call, tagged with the id of the credentials it was built from."""

    # Id of the credentials behind ``pairs``; handed back to refresh_credentials after an auth failure.
    creds_id: str
    # gRPC metadata (header key/value pairs) to attach to the call.
    pairs: Metadata
19
+
20
+
21
class Authenticator(object):
    """
    Base authenticator for all authentication flows.

    Holds the current credentials (given explicitly or loaded from the keyring), an HTTP
    session, and the local/remote client configuration. Subclasses implement
    ``_do_refresh_credentials`` to obtain new credentials for their specific flow.

    NOTE: this class does not use ``abc.ABCMeta``, so ``@abstractmethod`` on
    ``_do_refresh_credentials`` is not enforced when the class is instantiated.
    """

    def __init__(
        self,
        endpoint: str,
        *,
        cfg_store: typing.Optional[ClientConfigStore] = None,
        client_config: typing.Optional[ClientConfig] = None,
        credentials: typing.Optional[Credentials] = None,
        http_session: typing.Optional[httpx.AsyncClient] = None,
        http_proxy_url: typing.Optional[str] = None,
        verify: bool = True,
        ca_cert_path: typing.Optional[str] = None,
        default_header_key: str = "authorization",
        **kwargs,
    ):
        """
        Initialize the base authenticator.

        :param endpoint: The endpoint URL for authentication; also the keyring lookup key
        :param cfg_store: Optional client configuration store for retrieving remote configuration
        :param client_config: Optional client configuration containing authentication settings
        :param credentials: Optional credentials to use; if not given, credentials stored in the
            keyring for ``endpoint`` are used
        :param http_session: Optional HTTP session to use for requests
        :param http_proxy_url: Optional HTTP proxy URL
        :param verify: Whether to verify SSL certificates
        :param ca_cert_path: Optional path to CA certificate file
        :param default_header_key: Header name used for the bearer token until the remote
            configuration has been resolved
        :param kwargs: Additional keyword arguments passed to get_async_session (only used when
            ``http_session`` is not given), which may include:
            - auth: Authentication implementation to use
            - params: Query parameters to include in request URLs
            - headers: HTTP headers to include in requests
            - cookies: Cookies to include in requests
            - cert: SSL client certificate (path or tuple)
            - http1: Whether to enable HTTP/1.1 support
            - http2: Whether to enable HTTP/2 support
            - proxies: Proxy configuration mapping
            - mounts: Mounted transports for specific URL patterns
            - timeout: Request timeout configuration
            - follow_redirects: Whether to follow redirects
            - limits: Connection pool limits
            - max_redirects: Maximum number of redirects to follow
            - event_hooks: Event hooks for request/response lifecycle
            - base_url: Base URL to join with relative URLs
            - transport: Transport implementation to use
            - app: ASGI application to handle requests

        NOTE(review): ``verify`` and ``ca_cert_path`` are stored but not forwarded to
        get_async_session when the default session is created; confirm whether that session
        should honor them.
        """
        self._endpoint = endpoint
        # Fall back to credentials previously persisted in the keyring for this endpoint
        self._creds = credentials or KeyringStore.retrieve(endpoint)
        self._http_proxy_url = http_proxy_url
        self._verify = verify
        self._ca_cert_path = ca_cert_path
        self._client_config = client_config
        self._cfg_store = cfg_store
        # Will be populated lazily by _resolve_config
        self._resolved_config: ClientConfig | None = None
        # Serializes credential refreshes across coroutines (asyncio.Lock is not thread-safe)
        self._async_lock = asyncio.Lock()
        self._http_session = http_session or get_async_session(**kwargs)
        # Id of the current credentials; compared in refresh_credentials to skip redundant refreshes
        self._creds_id = self._creds.id if self._creds else None
        self._default_header_key = default_header_key

    async def _resolve_config(self) -> ClientConfig:
        """
        Resolve the client configuration by merging the remote configuration with local settings.

        Fetches the remote configuration from ``cfg_store`` and, when a local ``client_config``
        was given, applies it on top via ``with_override`` (prioritizing local settings over
        remote ones). The result is cached on the instance and returned by later calls.

        NOTE: no lock is taken here. Concurrent first calls may each fetch from ``cfg_store``;
        the last one to finish is the value that stays cached.

        :return: The resolved ClientConfig
        :raises ValueError: If no ClientConfigStore was provided
        """
        # Fast path: already resolved
        if self._resolved_config is not None:
            return self._resolved_config

        if self._cfg_store is None:
            raise ValueError("ClientConfigStore is not set. Cannot resolve configuration.")

        remote_config = await self._cfg_store.get_client_config()
        self._resolved_config = (
            remote_config.with_override(self._client_config) if self._client_config else remote_config
        )

        return self._resolved_config

    def get_credentials(self) -> typing.Optional[Credentials]:
        """
        Get the current credentials.

        :return: The current credentials or None if not set
        """
        return self._creds

    def _set_credentials(self, creds: Credentials):
        """
        Set the credentials.

        Only ``_creds`` is replaced; ``_creds_id`` is not updated here.

        :param creds: The credentials to set
        """
        self._creds = creds

    async def get_grpc_call_auth_metadata(self) -> typing.Optional[GrpcAuthMetadata]:
        """
        Build the authentication metadata for gRPC calls from the current credentials.

        The header key is the resolved config's ``header_key`` once ``_resolve_config`` has run,
        and ``default_header_key`` otherwise.

        :return: A GrpcAuthMetadata holding the credentials id and a single
            ``(header_key, "Bearer <access_token>")`` pair, or None if no credentials are available
        """
        creds = self.get_credentials()
        if creds:
            header_key = self._default_header_key
            if self._resolved_config is not None:
                # We only resolve the config during authentication flow, to avoid unnecessary network calls
                # and usually the header_key is consistent.
                header_key = self._resolved_config.header_key
            return GrpcAuthMetadata(
                creds_id=creds.id,
                pairs=Metadata((header_key, f"Bearer {creds.access_token}")),
            )
        return None

    async def refresh_credentials(self, creds_id: str | None = None):
        """
        Refresh the credentials, serialized across coroutines by an ``asyncio.Lock``.

        Callers should capture the credentials id (e.g. ``GrpcAuthMetadata.creds_id``) before
        using the credentials. If that use fails, pass the captured id here. If it still equals
        the current ``_creds_id``, the credentials are refreshed. If it differs, another coroutine
        has already refreshed them and this call returns without doing anything.

        On success the new credentials are stored in the keyring and ``_creds_id`` is updated.
        If the refresh raises, the keyring entry for this endpoint is deleted and the exception
        is re-raised.

        NOTE: ``asyncio.Lock`` guards coroutines on one event loop; it is not a thread lock.

        :param creds_id: The id of the credentials when they were last accessed by the caller.
            If None (or empty), force a refresh regardless of id.
        :raises: May raise authentication-related exceptions if the refresh fails
        """
        # If creds_id is None, force refresh
        # If creds_id matches current value, credentials need refresh
        # If creds_id doesn't match, another coroutine already refreshed credentials
        if creds_id and creds_id != self._creds_id:
            # Credentials have been refreshed by another coroutine since caller read them
            return

        # Use the async lock to ensure coroutine safety
        async with self._async_lock:
            # Re-check after acquiring the lock to avoid unnecessary work
            if creds_id and creds_id != self._creds_id:
                # Another coroutine refreshed credentials while we were waiting for the lock
                return

            # Perform the actual credential refresh
            try:
                self._creds = await self._do_refresh_credentials()
                KeyringStore.store(self._creds)
            except Exception:
                # Drop persisted credentials that could not be refreshed, then propagate
                KeyringStore.delete(self._endpoint)
                raise

            # Record the new id so callers holding the old id skip redundant refreshes
            self._creds_id = self._creds.id

    @abstractmethod
    async def _do_refresh_credentials(self) -> Credentials:
        """
        Obtain fresh credentials for this authentication flow.

        This method must be implemented by subclasses to handle the specific authentication flow.
        It must return the new Credentials. ``refresh_credentials`` assigns them to ``_creds``,
        persists them to the keyring, and updates ``_creds_id``.

        Implementations typically use the resolved configuration from _resolve_config() to
        determine authentication endpoints, scopes, audience, and other parameters needed for
        the specific authentication flow.

        :return: The newly obtained credentials
        :raises: May raise authentication-related exceptions if the refresh fails
        """
        ...
204
+
205
+
206
class AsyncAuthenticatedClient(httpx.AsyncClient):
    """
    An httpx.AsyncClient that automatically adds authentication headers to requests.

    If a request is rejected with HTTP 401, the credentials are refreshed and the request is
    retried once with the new headers.
    """

    def __init__(self, authenticator: Authenticator, **kwargs):
        """
        Initialize the authenticated client.

        :param authenticator: The authenticator to use for authentication
        :param kwargs: Additional arguments passed to the httpx.AsyncClient constructor
        """
        super().__init__(**kwargs)
        self.auth_adapter = AsyncAuthenticationHTTPAdapter(authenticator)
        self.authenticator = authenticator

    async def send(self, request: httpx.Request, **kwargs) -> httpx.Response:
        """
        Send the request with authentication headers added.

        If the response is 401 Unauthorized, the rejected response is closed, the credentials
        are refreshed (keyed on the credentials id used for the first attempt), and the request
        is sent once more with the refreshed headers.

        NOTE(review): the retry re-sends the same ``request`` object; a request with a
        one-shot streaming body presumably cannot be replayed — confirm callers only send
        replayable bodies.

        :param request: The request object to send.
        :param kwargs: Additional keyword arguments passed to the parent httpx.AsyncClient.send method, which may
            include:
            - auth: Authentication implementation to use for this request
            - follow_redirects: Whether to follow redirects for this request
            - timeout: Request timeout configuration for this request
        :return: The response object.
        """
        creds_id = await self.auth_adapter.add_auth_header(request)
        response = await super().send(request, **kwargs)

        if response.status_code == HTTPStatus.UNAUTHORIZED:
            # Release the rejected response before retrying. With stream=True its connection
            # would otherwise stay checked out. aclose() is a no-op on an already-closed response.
            await response.aclose()
            await self.authenticator.refresh_credentials(creds_id=creds_id)
            await self.auth_adapter.add_auth_header(request)
            response = await super().send(request, **kwargs)

        return response
247
+
248
+
249
class AsyncAuthenticationHTTPAdapter:
    """
    Async helper that stamps the authenticator's auth headers onto httpx requests.

    This plays the role that AuthenticationHTTPAdapter plays for ``requests``, for httpx.AsyncClient.
    """

    def __init__(self, authenticator: Authenticator):
        """
        Create the adapter.

        :param authenticator: Source of credentials and auth metadata
        """
        self.authenticator = authenticator

    async def add_auth_header(self, request: httpx.Request) -> typing.Optional[str]:
        """
        Set authentication headers on ``request`` in place.

        When the authenticator holds no credentials yet, a (forced) refresh runs first.

        :param request: The request to add headers to.
        :return: The credentials id the headers were built from, or None when the
            authenticator produced no auth metadata
        """
        auth = self.authenticator
        if auth.get_credentials() is None:
            await auth.refresh_credentials()

        metadata = await auth.get_grpc_call_auth_metadata()
        if metadata is not None:
            # Iterating a grpc.aio.Metadata yields (key, value) pairs; its keys() view iterates the
            # same way, so this matches looping over ``metadata.pairs.keys()``.
            for header_name, header_value in metadata.pairs:
                request.headers[header_name] = header_value
            return metadata.creds_id
        return None
280
+
281
+
282
def upgrade_async_session_to_proxy_authenticated(
    http_session: httpx.AsyncClient, proxy_authenticator: typing.Optional[Authenticator] = None, **kwargs
) -> httpx.AsyncClient:
    """
    Return a session that authenticates against a proxy in front of Flyte, if one is configured.

    Without a proxy authenticator, ``http_session`` is returned unchanged. With one, a new
    AsyncAuthenticatedClient (which uses AsyncAuthenticationHTTPAdapter) is built from
    ``kwargs``. In that case ``http_session`` itself is not reused.

    :param http_session: httpx.AsyncClient Precreated session
    :param proxy_authenticator: Optional authenticator for proxy authentication
    :param kwargs: Additional arguments passed to AsyncAuthenticatedClient, which may include:
        - auth: Authentication implementation to use
        - params: Query parameters to include in request URLs
        - headers: HTTP headers to include in requests
        - cookies: Cookies to include in requests
        - verify: SSL verification mode (True/False/path to certificate)
        - cert: SSL client certificate (path or tuple)
        - http1: Whether to enable HTTP/1.1 support
        - http2: Whether to enable HTTP/2 support
        - proxies: Proxy configuration mapping
        - mounts: Mounted transports for specific URL patterns
        - timeout: Request timeout configuration
        - follow_redirects: Whether to follow redirects
        - limits: Connection pool limits
        - max_redirects: Maximum number of redirects to follow
        - event_hooks: Event hooks for request/response lifecycle
        - base_url: Base URL to join with relative URLs
        - transport: Transport implementation to use
        - app: ASGI application to handle requests
    :return: httpx.AsyncClient with authentication
    """
    if not proxy_authenticator:
        # No proxy auth configured: keep the caller's session as-is
        return http_session
    return AsyncAuthenticatedClient(proxy_authenticator, **kwargs)
316
+
317
+
318
def get_async_session(
    proxy_authenticator: Authenticator | None = None,
    ca_cert_path: str | None = None,
    verify: bool | None = None,
    **kwargs,
) -> httpx.AsyncClient:
    """
    Create a new httpx.AsyncClient, proxy-authenticated when ``proxy_authenticator`` is provided.

    Only recognised httpx.AsyncClient options are taken from ``kwargs``. Any other key (for example
    authentication settings such as ``endpoint``, ``client_id``, ``client_secret``, ``scopes``,
    ``audience``, ``proxy_env``, ``proxy_timeout``, ``header_key``, ``http_proxy_url``) is not
    forwarded to the HTTP client constructor. The same filtered options, including the TLS
    settings, are used whether or not a proxy authenticator is given.

    :param proxy_authenticator: Optional authenticator for a proxy in front of Flyte. When given, an
        ``AsyncAuthenticatedClient`` is returned instead of a plain httpx.AsyncClient.
    :param ca_cert_path: Optional path used to verify the server certificate. It may be a CA bundle
        file or a directory of hashed CA certificates. When set it takes precedence over ``verify``
        and is passed to httpx as an ``ssl.SSLContext``.
    :param verify: Optional SSL verification flag passed to httpx as ``verify``.
    :param kwargs: Additional keyword arguments. Recognised httpx.AsyncClient options: auth, params,
        headers, cookies, cert, http1, http2, proxies, mounts, timeout, follow_redirects, limits,
        max_redirects, event_hooks, base_url, transport, app.
    :return: An httpx.AsyncClient (or AsyncAuthenticatedClient) configured with the requested options.
    :raises OSError: (including ``ssl.SSLError``) if ``ca_cert_path`` cannot be loaded as CA certificates.
    """
    import os

    # ``verify`` is a named parameter, so it can never appear in ``kwargs``; it is applied separately below.
    httpx_client_options = frozenset(
        {
            "auth",
            "params",
            "headers",
            "cookies",
            "cert",
            "http1",
            "http2",
            "proxies",
            "mounts",
            "timeout",
            "follow_redirects",
            "limits",
            "max_redirects",
            "event_hooks",
            "base_url",
            "transport",
            "app",
        }
    )
    client_kwargs = {k: v for k, v in kwargs.items() if k in httpx_client_options}

    tls_verify: bool | ssl.SSLContext | None = verify
    if ca_cert_path:
        # The context itself must be handed to httpx; merely setting verify=True would fall back
        # to the default trust store and silently ignore the custom CA.
        if os.path.isdir(ca_cert_path):
            tls_verify = ssl.create_default_context(capath=ca_cert_path)
        else:
            tls_verify = ssl.create_default_context(cafile=ca_cert_path)

    if tls_verify is not None:
        client_kwargs["verify"] = tls_verify

    if proxy_authenticator:
        # Build the proxy-authenticated client directly from the filtered options so it gets the same
        # TLS configuration, without constructing (and leaking) a throwaway plain client first.
        return AsyncAuthenticatedClient(proxy_authenticator, **client_kwargs)
    return httpx.AsyncClient(**client_kwargs)