flyte 0.0.1b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic; see the package registry's page for this release for more details.

Files changed (390)
  1. flyte/__init__.py +62 -0
  2. flyte/_api_commons.py +3 -0
  3. flyte/_bin/__init__.py +0 -0
  4. flyte/_bin/runtime.py +126 -0
  5. flyte/_build.py +25 -0
  6. flyte/_cache/__init__.py +12 -0
  7. flyte/_cache/cache.py +146 -0
  8. flyte/_cache/defaults.py +9 -0
  9. flyte/_cache/policy_function_body.py +42 -0
  10. flyte/_cli/__init__.py +0 -0
  11. flyte/_cli/_common.py +287 -0
  12. flyte/_cli/_create.py +42 -0
  13. flyte/_cli/_delete.py +23 -0
  14. flyte/_cli/_deploy.py +140 -0
  15. flyte/_cli/_get.py +235 -0
  16. flyte/_cli/_run.py +152 -0
  17. flyte/_cli/main.py +72 -0
  18. flyte/_code_bundle/__init__.py +8 -0
  19. flyte/_code_bundle/_ignore.py +113 -0
  20. flyte/_code_bundle/_packaging.py +187 -0
  21. flyte/_code_bundle/_utils.py +339 -0
  22. flyte/_code_bundle/bundle.py +178 -0
  23. flyte/_context.py +146 -0
  24. flyte/_datastructures.py +342 -0
  25. flyte/_deploy.py +202 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +43 -0
  29. flyte/_group.py +31 -0
  30. flyte/_hash.py +23 -0
  31. flyte/_image.py +760 -0
  32. flyte/_initialize.py +634 -0
  33. flyte/_interface.py +84 -0
  34. flyte/_internal/__init__.py +3 -0
  35. flyte/_internal/controllers/__init__.py +115 -0
  36. flyte/_internal/controllers/_local_controller.py +118 -0
  37. flyte/_internal/controllers/_trace.py +40 -0
  38. flyte/_internal/controllers/pbhash.py +39 -0
  39. flyte/_internal/controllers/remote/__init__.py +40 -0
  40. flyte/_internal/controllers/remote/_action.py +141 -0
  41. flyte/_internal/controllers/remote/_client.py +43 -0
  42. flyte/_internal/controllers/remote/_controller.py +361 -0
  43. flyte/_internal/controllers/remote/_core.py +402 -0
  44. flyte/_internal/controllers/remote/_informer.py +361 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +11 -0
  47. flyte/_internal/imagebuild/docker_builder.py +416 -0
  48. flyte/_internal/imagebuild/image_builder.py +241 -0
  49. flyte/_internal/imagebuild/remote_builder.py +0 -0
  50. flyte/_internal/resolvers/__init__.py +0 -0
  51. flyte/_internal/resolvers/_task_module.py +54 -0
  52. flyte/_internal/resolvers/common.py +31 -0
  53. flyte/_internal/resolvers/default.py +28 -0
  54. flyte/_internal/runtime/__init__.py +0 -0
  55. flyte/_internal/runtime/convert.py +199 -0
  56. flyte/_internal/runtime/entrypoints.py +135 -0
  57. flyte/_internal/runtime/io.py +136 -0
  58. flyte/_internal/runtime/resources_serde.py +138 -0
  59. flyte/_internal/runtime/task_serde.py +210 -0
  60. flyte/_internal/runtime/taskrunner.py +190 -0
  61. flyte/_internal/runtime/types_serde.py +54 -0
  62. flyte/_logging.py +124 -0
  63. flyte/_protos/__init__.py +0 -0
  64. flyte/_protos/common/authorization_pb2.py +66 -0
  65. flyte/_protos/common/authorization_pb2.pyi +108 -0
  66. flyte/_protos/common/authorization_pb2_grpc.py +4 -0
  67. flyte/_protos/common/identifier_pb2.py +71 -0
  68. flyte/_protos/common/identifier_pb2.pyi +82 -0
  69. flyte/_protos/common/identifier_pb2_grpc.py +4 -0
  70. flyte/_protos/common/identity_pb2.py +48 -0
  71. flyte/_protos/common/identity_pb2.pyi +72 -0
  72. flyte/_protos/common/identity_pb2_grpc.py +4 -0
  73. flyte/_protos/common/list_pb2.py +36 -0
  74. flyte/_protos/common/list_pb2.pyi +69 -0
  75. flyte/_protos/common/list_pb2_grpc.py +4 -0
  76. flyte/_protos/common/policy_pb2.py +37 -0
  77. flyte/_protos/common/policy_pb2.pyi +27 -0
  78. flyte/_protos/common/policy_pb2_grpc.py +4 -0
  79. flyte/_protos/common/role_pb2.py +37 -0
  80. flyte/_protos/common/role_pb2.pyi +53 -0
  81. flyte/_protos/common/role_pb2_grpc.py +4 -0
  82. flyte/_protos/common/runtime_version_pb2.py +28 -0
  83. flyte/_protos/common/runtime_version_pb2.pyi +24 -0
  84. flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
  85. flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
  86. flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
  87. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  88. flyte/_protos/secret/definition_pb2.py +49 -0
  89. flyte/_protos/secret/definition_pb2.pyi +93 -0
  90. flyte/_protos/secret/definition_pb2_grpc.py +4 -0
  91. flyte/_protos/secret/payload_pb2.py +62 -0
  92. flyte/_protos/secret/payload_pb2.pyi +94 -0
  93. flyte/_protos/secret/payload_pb2_grpc.py +4 -0
  94. flyte/_protos/secret/secret_pb2.py +38 -0
  95. flyte/_protos/secret/secret_pb2.pyi +6 -0
  96. flyte/_protos/secret/secret_pb2_grpc.py +198 -0
  97. flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
  98. flyte/_protos/validate/validate/validate_pb2.py +76 -0
  99. flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
  100. flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  101. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  102. flyte/_protos/workflow/queue_service_pb2.py +106 -0
  103. flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
  104. flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  105. flyte/_protos/workflow/run_definition_pb2.py +128 -0
  106. flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
  107. flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  108. flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
  109. flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  110. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  111. flyte/_protos/workflow/run_service_pb2.py +133 -0
  112. flyte/_protos/workflow/run_service_pb2.pyi +175 -0
  113. flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
  114. flyte/_protos/workflow/state_service_pb2.py +58 -0
  115. flyte/_protos/workflow/state_service_pb2.pyi +71 -0
  116. flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
  117. flyte/_protos/workflow/task_definition_pb2.py +72 -0
  118. flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
  119. flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  120. flyte/_protos/workflow/task_service_pb2.py +44 -0
  121. flyte/_protos/workflow/task_service_pb2.pyi +31 -0
  122. flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
  123. flyte/_resources.py +226 -0
  124. flyte/_retry.py +32 -0
  125. flyte/_reusable_environment.py +25 -0
  126. flyte/_run.py +411 -0
  127. flyte/_secret.py +61 -0
  128. flyte/_task.py +367 -0
  129. flyte/_task_environment.py +200 -0
  130. flyte/_timeout.py +47 -0
  131. flyte/_tools.py +27 -0
  132. flyte/_trace.py +128 -0
  133. flyte/_utils/__init__.py +20 -0
  134. flyte/_utils/asyn.py +119 -0
  135. flyte/_utils/coro_management.py +25 -0
  136. flyte/_utils/file_handling.py +72 -0
  137. flyte/_utils/helpers.py +108 -0
  138. flyte/_utils/lazy_module.py +54 -0
  139. flyte/_utils/uv_script_parser.py +49 -0
  140. flyte/_version.py +21 -0
  141. flyte/connectors/__init__.py +0 -0
  142. flyte/errors.py +143 -0
  143. flyte/extras/__init__.py +5 -0
  144. flyte/extras/_container.py +273 -0
  145. flyte/io/__init__.py +11 -0
  146. flyte/io/_dataframe.py +0 -0
  147. flyte/io/_dir.py +448 -0
  148. flyte/io/_file.py +468 -0
  149. flyte/io/pickle/__init__.py +0 -0
  150. flyte/io/pickle/transformer.py +117 -0
  151. flyte/io/structured_dataset/__init__.py +129 -0
  152. flyte/io/structured_dataset/basic_dfs.py +219 -0
  153. flyte/io/structured_dataset/structured_dataset.py +1061 -0
  154. flyte/py.typed +0 -0
  155. flyte/remote/__init__.py +25 -0
  156. flyte/remote/_client/__init__.py +0 -0
  157. flyte/remote/_client/_protocols.py +131 -0
  158. flyte/remote/_client/auth/__init__.py +12 -0
  159. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  160. flyte/remote/_client/auth/_authenticators/base.py +397 -0
  161. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  162. flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
  163. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  164. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  165. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  166. flyte/remote/_client/auth/_channel.py +184 -0
  167. flyte/remote/_client/auth/_client_config.py +83 -0
  168. flyte/remote/_client/auth/_default_html.py +32 -0
  169. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  170. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  171. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  172. flyte/remote/_client/auth/_keyring.py +143 -0
  173. flyte/remote/_client/auth/_token_client.py +260 -0
  174. flyte/remote/_client/auth/errors.py +16 -0
  175. flyte/remote/_client/controlplane.py +95 -0
  176. flyte/remote/_console.py +18 -0
  177. flyte/remote/_data.py +155 -0
  178. flyte/remote/_logs.py +116 -0
  179. flyte/remote/_project.py +86 -0
  180. flyte/remote/_run.py +873 -0
  181. flyte/remote/_secret.py +132 -0
  182. flyte/remote/_task.py +227 -0
  183. flyte/report/__init__.py +3 -0
  184. flyte/report/_report.py +178 -0
  185. flyte/report/_template.html +124 -0
  186. flyte/storage/__init__.py +24 -0
  187. flyte/storage/_remote_fs.py +34 -0
  188. flyte/storage/_storage.py +251 -0
  189. flyte/storage/_utils.py +5 -0
  190. flyte/types/__init__.py +13 -0
  191. flyte/types/_interface.py +25 -0
  192. flyte/types/_renderer.py +162 -0
  193. flyte/types/_string_literals.py +120 -0
  194. flyte/types/_type_engine.py +2210 -0
  195. flyte/types/_utils.py +80 -0
  196. flyte-0.0.1b0.dist-info/METADATA +179 -0
  197. flyte-0.0.1b0.dist-info/RECORD +390 -0
  198. flyte-0.0.1b0.dist-info/WHEEL +5 -0
  199. flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
  200. flyte-0.0.1b0.dist-info/top_level.txt +1 -0
  201. union/__init__.py +54 -0
  202. union/_api_commons.py +3 -0
  203. union/_bin/__init__.py +0 -0
  204. union/_bin/runtime.py +113 -0
  205. union/_build.py +25 -0
  206. union/_cache/__init__.py +12 -0
  207. union/_cache/cache.py +141 -0
  208. union/_cache/defaults.py +9 -0
  209. union/_cache/policy_function_body.py +42 -0
  210. union/_cli/__init__.py +0 -0
  211. union/_cli/_common.py +263 -0
  212. union/_cli/_create.py +40 -0
  213. union/_cli/_delete.py +23 -0
  214. union/_cli/_deploy.py +120 -0
  215. union/_cli/_get.py +162 -0
  216. union/_cli/_params.py +579 -0
  217. union/_cli/_run.py +150 -0
  218. union/_cli/main.py +72 -0
  219. union/_code_bundle/__init__.py +8 -0
  220. union/_code_bundle/_ignore.py +113 -0
  221. union/_code_bundle/_packaging.py +187 -0
  222. union/_code_bundle/_utils.py +342 -0
  223. union/_code_bundle/bundle.py +176 -0
  224. union/_context.py +146 -0
  225. union/_datastructures.py +295 -0
  226. union/_deploy.py +185 -0
  227. union/_doc.py +29 -0
  228. union/_docstring.py +26 -0
  229. union/_environment.py +43 -0
  230. union/_group.py +31 -0
  231. union/_hash.py +23 -0
  232. union/_image.py +760 -0
  233. union/_initialize.py +585 -0
  234. union/_interface.py +84 -0
  235. union/_internal/__init__.py +3 -0
  236. union/_internal/controllers/__init__.py +77 -0
  237. union/_internal/controllers/_local_controller.py +77 -0
  238. union/_internal/controllers/pbhash.py +39 -0
  239. union/_internal/controllers/remote/__init__.py +40 -0
  240. union/_internal/controllers/remote/_action.py +131 -0
  241. union/_internal/controllers/remote/_client.py +43 -0
  242. union/_internal/controllers/remote/_controller.py +169 -0
  243. union/_internal/controllers/remote/_core.py +341 -0
  244. union/_internal/controllers/remote/_informer.py +260 -0
  245. union/_internal/controllers/remote/_service_protocol.py +44 -0
  246. union/_internal/imagebuild/__init__.py +11 -0
  247. union/_internal/imagebuild/docker_builder.py +416 -0
  248. union/_internal/imagebuild/image_builder.py +243 -0
  249. union/_internal/imagebuild/remote_builder.py +0 -0
  250. union/_internal/resolvers/__init__.py +0 -0
  251. union/_internal/resolvers/_task_module.py +31 -0
  252. union/_internal/resolvers/common.py +24 -0
  253. union/_internal/resolvers/default.py +27 -0
  254. union/_internal/runtime/__init__.py +0 -0
  255. union/_internal/runtime/convert.py +163 -0
  256. union/_internal/runtime/entrypoints.py +121 -0
  257. union/_internal/runtime/io.py +136 -0
  258. union/_internal/runtime/resources_serde.py +134 -0
  259. union/_internal/runtime/task_serde.py +202 -0
  260. union/_internal/runtime/taskrunner.py +179 -0
  261. union/_internal/runtime/types_serde.py +53 -0
  262. union/_logging.py +124 -0
  263. union/_protos/__init__.py +0 -0
  264. union/_protos/common/authorization_pb2.py +66 -0
  265. union/_protos/common/authorization_pb2.pyi +106 -0
  266. union/_protos/common/authorization_pb2_grpc.py +4 -0
  267. union/_protos/common/identifier_pb2.py +71 -0
  268. union/_protos/common/identifier_pb2.pyi +82 -0
  269. union/_protos/common/identifier_pb2_grpc.py +4 -0
  270. union/_protos/common/identity_pb2.py +48 -0
  271. union/_protos/common/identity_pb2.pyi +72 -0
  272. union/_protos/common/identity_pb2_grpc.py +4 -0
  273. union/_protos/common/list_pb2.py +36 -0
  274. union/_protos/common/list_pb2.pyi +69 -0
  275. union/_protos/common/list_pb2_grpc.py +4 -0
  276. union/_protos/common/policy_pb2.py +37 -0
  277. union/_protos/common/policy_pb2.pyi +27 -0
  278. union/_protos/common/policy_pb2_grpc.py +4 -0
  279. union/_protos/common/role_pb2.py +37 -0
  280. union/_protos/common/role_pb2.pyi +51 -0
  281. union/_protos/common/role_pb2_grpc.py +4 -0
  282. union/_protos/common/runtime_version_pb2.py +28 -0
  283. union/_protos/common/runtime_version_pb2.pyi +24 -0
  284. union/_protos/common/runtime_version_pb2_grpc.py +4 -0
  285. union/_protos/logs/dataplane/payload_pb2.py +96 -0
  286. union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
  287. union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  288. union/_protos/secret/definition_pb2.py +49 -0
  289. union/_protos/secret/definition_pb2.pyi +93 -0
  290. union/_protos/secret/definition_pb2_grpc.py +4 -0
  291. union/_protos/secret/payload_pb2.py +62 -0
  292. union/_protos/secret/payload_pb2.pyi +94 -0
  293. union/_protos/secret/payload_pb2_grpc.py +4 -0
  294. union/_protos/secret/secret_pb2.py +38 -0
  295. union/_protos/secret/secret_pb2.pyi +6 -0
  296. union/_protos/secret/secret_pb2_grpc.py +198 -0
  297. union/_protos/validate/validate/validate_pb2.py +76 -0
  298. union/_protos/workflow/node_execution_service_pb2.py +26 -0
  299. union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  300. union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  301. union/_protos/workflow/queue_service_pb2.py +75 -0
  302. union/_protos/workflow/queue_service_pb2.pyi +103 -0
  303. union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  304. union/_protos/workflow/run_definition_pb2.py +100 -0
  305. union/_protos/workflow/run_definition_pb2.pyi +256 -0
  306. union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  307. union/_protos/workflow/run_logs_service_pb2.py +41 -0
  308. union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  309. union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  310. union/_protos/workflow/run_service_pb2.py +133 -0
  311. union/_protos/workflow/run_service_pb2.pyi +173 -0
  312. union/_protos/workflow/run_service_pb2_grpc.py +412 -0
  313. union/_protos/workflow/state_service_pb2.py +58 -0
  314. union/_protos/workflow/state_service_pb2.pyi +69 -0
  315. union/_protos/workflow/state_service_pb2_grpc.py +138 -0
  316. union/_protos/workflow/task_definition_pb2.py +72 -0
  317. union/_protos/workflow/task_definition_pb2.pyi +65 -0
  318. union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  319. union/_protos/workflow/task_service_pb2.py +44 -0
  320. union/_protos/workflow/task_service_pb2.pyi +31 -0
  321. union/_protos/workflow/task_service_pb2_grpc.py +104 -0
  322. union/_resources.py +226 -0
  323. union/_retry.py +32 -0
  324. union/_reusable_environment.py +25 -0
  325. union/_run.py +374 -0
  326. union/_secret.py +61 -0
  327. union/_task.py +354 -0
  328. union/_task_environment.py +186 -0
  329. union/_timeout.py +47 -0
  330. union/_tools.py +27 -0
  331. union/_utils/__init__.py +11 -0
  332. union/_utils/asyn.py +119 -0
  333. union/_utils/file_handling.py +71 -0
  334. union/_utils/helpers.py +46 -0
  335. union/_utils/lazy_module.py +54 -0
  336. union/_utils/uv_script_parser.py +49 -0
  337. union/_version.py +21 -0
  338. union/connectors/__init__.py +0 -0
  339. union/errors.py +128 -0
  340. union/extras/__init__.py +5 -0
  341. union/extras/_container.py +263 -0
  342. union/io/__init__.py +11 -0
  343. union/io/_dataframe.py +0 -0
  344. union/io/_dir.py +425 -0
  345. union/io/_file.py +418 -0
  346. union/io/pickle/__init__.py +0 -0
  347. union/io/pickle/transformer.py +117 -0
  348. union/io/structured_dataset/__init__.py +122 -0
  349. union/io/structured_dataset/basic_dfs.py +219 -0
  350. union/io/structured_dataset/structured_dataset.py +1057 -0
  351. union/py.typed +0 -0
  352. union/remote/__init__.py +23 -0
  353. union/remote/_client/__init__.py +0 -0
  354. union/remote/_client/_protocols.py +129 -0
  355. union/remote/_client/auth/__init__.py +12 -0
  356. union/remote/_client/auth/_authenticators/__init__.py +0 -0
  357. union/remote/_client/auth/_authenticators/base.py +391 -0
  358. union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  359. union/remote/_client/auth/_authenticators/device_code.py +120 -0
  360. union/remote/_client/auth/_authenticators/external_command.py +77 -0
  361. union/remote/_client/auth/_authenticators/factory.py +200 -0
  362. union/remote/_client/auth/_authenticators/pkce.py +515 -0
  363. union/remote/_client/auth/_channel.py +184 -0
  364. union/remote/_client/auth/_client_config.py +83 -0
  365. union/remote/_client/auth/_default_html.py +32 -0
  366. union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  367. union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
  368. union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
  369. union/remote/_client/auth/_keyring.py +154 -0
  370. union/remote/_client/auth/_token_client.py +258 -0
  371. union/remote/_client/auth/errors.py +16 -0
  372. union/remote/_client/controlplane.py +86 -0
  373. union/remote/_data.py +149 -0
  374. union/remote/_logs.py +74 -0
  375. union/remote/_project.py +86 -0
  376. union/remote/_run.py +820 -0
  377. union/remote/_secret.py +132 -0
  378. union/remote/_task.py +193 -0
  379. union/report/__init__.py +3 -0
  380. union/report/_report.py +178 -0
  381. union/report/_template.html +124 -0
  382. union/storage/__init__.py +24 -0
  383. union/storage/_remote_fs.py +34 -0
  384. union/storage/_storage.py +247 -0
  385. union/storage/_utils.py +5 -0
  386. union/types/__init__.py +11 -0
  387. union/types/_renderer.py +162 -0
  388. union/types/_string_literals.py +120 -0
  389. union/types/_type_engine.py +2131 -0
  390. union/types/_utils.py +80 -0
@@ -0,0 +1,184 @@
1
+ import ssl
2
+ import typing
3
+
4
+ import grpc
5
+ import grpc.aio
6
+ import httpx
7
+ from grpc.experimental.aio import init_grpc_aio
8
+
9
+ from flyte._logging import logger
10
+
11
+ from ._authenticators.base import get_async_session
12
+ from ._authenticators.factory import (
13
+ create_auth_interceptors,
14
+ create_proxy_auth_interceptors,
15
+ get_async_proxy_authenticator,
16
+ )
17
+
18
# Initialize gRPC AIO when this module is imported, i.e. early enough that it can be used in the main thread.
init_grpc_aio()
20
+
21
+
22
def bootstrap_ssl_from_server(endpoint: str) -> grpc.ChannelCredentials:
    """
    Fetch the TLS certificate presented by ``endpoint`` and build gRPC channel credentials that trust it.

    Only meant for insecure-skip-verify mode: the certificate is retrieved *without* validation and is then
    installed as the channel's trusted root certificate.

    ``endpoint`` may be ``host``, ``host:port`` or ``[ipv6-literal]:port``. When no numeric port can be parsed,
    a warning is logged and port 443 is used.

    This performs blocking socket I/O; async callers should run it in a worker thread.

    :param endpoint: Server address, optionally including a port number.
    :return: gRPC channel credentials whose root certificate is the server's own certificate.
    """
    host, sep, port = endpoint.rpartition(":")
    # isdecimal() (unlike isdigit()) only accepts characters that int() can parse.
    if sep and port.isdecimal():
        port_number = int(port)
    else:
        logger.warning(f"Unrecognized port in endpoint [{endpoint}], defaulting to 443.")
        host, port_number = endpoint, 443

    # gRPC targets write IPv6 literals in brackets ("[::1]:8443"); the socket layer needs the bare address.
    if host.startswith("[") and host.endswith("]"):
        host = host[1:-1]

    # Without ca_certs, ssl.get_server_certificate does not validate the certificate it retrieves.
    cert = ssl.get_server_certificate((host, port_number))
    return grpc.ssl_channel_credentials(cert.encode())
44
+
45
+
46
def _open_channel(
    endpoint: str,
    insecure: typing.Optional[bool],
    ssl_credentials: typing.Optional[grpc.ChannelCredentials],
    grpc_options: typing.Optional[typing.Sequence[typing.Tuple[str, typing.Any]]],
    compression: typing.Optional[grpc.Compression],
    interceptors: typing.Optional[typing.Sequence[grpc.aio.ClientInterceptor]] = None,
) -> grpc.aio.Channel:
    """
    Open a plaintext (``insecure``) or TLS grpc.aio channel to ``endpoint``.

    Both modes receive the same channel options, compression and interceptors.
    """
    if insecure:
        return grpc.aio.insecure_channel(
            endpoint,
            options=grpc_options,
            compression=compression,
            interceptors=interceptors,
        )
    return grpc.aio.secure_channel(
        target=endpoint,
        credentials=ssl_credentials,
        options=grpc_options,
        compression=compression,
        interceptors=interceptors,
    )


async def create_channel(
    endpoint: str,
    insecure: typing.Optional[bool] = None,
    insecure_skip_verify: typing.Optional[bool] = False,
    ca_cert_file_path: typing.Optional[str] = None,
    ssl_credentials: typing.Optional[grpc.ChannelCredentials] = None,
    grpc_options: typing.Optional[typing.Sequence[typing.Tuple[str, typing.Any]]] = None,
    compression: typing.Optional[grpc.Compression] = None,
    http_session: httpx.AsyncClient | None = None,
    proxy_command: typing.List[str] | None = None,
    **kwargs,
) -> grpc.aio.Channel:
    """
    Create a grpc.aio channel to ``endpoint`` with default-metadata, proxy-auth and auth interceptors attached.

    Unless ``insecure`` is set and ``ssl_credentials`` is not given, TLS credentials are resolved in this order:

    1. bootstrapped from the server's own certificate if ``insecure_skip_verify`` is set;
    2. otherwise read from ``ca_cert_file_path`` if given;
    3. otherwise gRPC's default root certificates.

    An unauthenticated channel is created first and handed to the auth interceptor factory, which uses it to
    fetch server metadata. The function is async so that reading the CA file and fetching the server
    certificate do not block the event loop.

    :param endpoint: Target address for the channel (``host`` or ``host:port``).
    :param insecure: Use a plaintext channel (no TLS). TLS credentials are not resolved in this mode.
    :param insecure_skip_verify: Fetch the server's certificate without validating it and use it as the trusted
        root.
    :param ca_cert_file_path: Path to a PEM CA certificate file used as the trusted root. It is also forwarded to
        the HTTP session and auth interceptor factories.
    :param ssl_credentials: Pre-built channel credentials; when given, no credentials are resolved here.
    :param grpc_options: gRPC channel options, applied to both the unauthenticated and the returned channel.
    :param compression: Compression method, applied to both channels.
    :param http_session: Shared HTTP client for auth/proxy requests; created with ``get_async_session`` if
        omitted.
    :param proxy_command: Command used to build a proxy authenticator when a new HTTP session is created.
    :param kwargs: Forwarded unchanged to ``get_async_proxy_authenticator``, ``get_async_session``,
        ``create_proxy_auth_interceptors`` and ``create_auth_interceptors``. They are *not* passed to the gRPC
        channel constructors; use ``grpc_options`` / ``compression`` for those. Recognised keys include:

        - For proxy configuration:
            - proxy_env: Dict of environment variables for proxy
            - proxy_timeout: Timeout for proxy connection
        - For authentication interceptors:
            - auth_type: The authentication type to use ("Pkce", "ClientSecret", "ExternalCommand", "DeviceFlow")
            - command: Command to execute for ExternalCommand authentication
            - client_id: Client ID for ClientSecret authentication
            - client_secret: Client secret for ClientSecret authentication
            - client_credentials_secret: Client secret for ClientSecret authentication (alias)
            - scopes: List of scopes to request during authentication
            - audience: Audience for the token
            - http_proxy_url: HTTP proxy URL
            - verify: Whether to verify SSL certificates
            - ca_cert_path: Optional path to CA certificate file
            - header_key: Header key to use for authentication
            - redirect_uri: OAuth2 redirect URI for PKCE authentication
            - add_request_auth_code_params_to_request_access_token_params: Whether to add auth code params to
              token request
            - request_auth_code_params: Parameters to add to login URI opened in browser
            - request_access_token_params: Parameters to add when exchanging auth code for access token
            - refresh_access_token_params: Parameters to add when refreshing access token
    :return: grpc.aio.Channel with authentication interceptors configured
    """
    import asyncio

    # TLS credentials are only used by secure channels. Skipping them for insecure endpoints also avoids a TLS
    # handshake against a plaintext server when insecure_skip_verify happens to be set.
    if not insecure and not ssl_credentials:
        if insecure_skip_verify:
            # ssl.get_server_certificate does blocking socket I/O; keep it off the event loop.
            ssl_credentials = await asyncio.to_thread(bootstrap_ssl_from_server, endpoint)
        elif ca_cert_file_path:
            import aiofiles

            async with aiofiles.open(ca_cert_file_path, "rb") as f:
                # aiofiles file reads are coroutines; they must be awaited to obtain the bytes.
                st_cert = await f.read()
            ssl_credentials = grpc.ssl_channel_credentials(st_cert)
        else:
            ssl_credentials = grpc.ssl_channel_credentials()

    # Create an unauthenticated channel first to use to get the server metadata
    unauthenticated_channel = _open_channel(endpoint, insecure, ssl_credentials, grpc_options, compression)

    from ._grpc_utils.default_metadata_interceptor import (
        DefaultMetadataStreamStreamInterceptor,
        DefaultMetadataStreamUnaryInterceptor,
        DefaultMetadataUnaryStreamInterceptor,
        DefaultMetadataUnaryUnaryInterceptor,
    )

    # Add all types of default metadata interceptors
    interceptors: typing.List[grpc.aio.ClientInterceptor] = [
        DefaultMetadataUnaryUnaryInterceptor(),
        DefaultMetadataUnaryStreamInterceptor(),
        DefaultMetadataStreamUnaryInterceptor(),
        DefaultMetadataStreamStreamInterceptor(),
    ]

    # Create an HTTP session if not provided so we share the same http client across the stack
    if http_session is None:
        proxy_authenticator = None
        if proxy_command:
            proxy_authenticator = get_async_proxy_authenticator(
                endpoint=endpoint, proxy_command=proxy_command, **kwargs
            )

        http_session = get_async_session(
            ca_cert_file_path=ca_cert_file_path, proxy_authenticator=proxy_authenticator, **kwargs
        )

    # Proxy auth interceptors run before the regular auth interceptors.
    interceptors.extend(create_proxy_auth_interceptors(endpoint, http_session=http_session, **kwargs))

    interceptors.extend(
        create_auth_interceptors(
            endpoint=endpoint,
            in_channel=unauthenticated_channel,
            insecure=insecure,
            insecure_skip_verify=insecure_skip_verify,
            ca_cert_file_path=ca_cert_file_path,
            http_session=http_session,
            **kwargs,
        )
    )

    return _open_channel(
        endpoint, insecure, ssl_credentials, grpc_options, compression, interceptors=interceptors
    )
@@ -0,0 +1,83 @@
1
import typing
from abc import ABC, abstractmethod

import grpc.aio
import pydantic
from flyteidl.service.auth_pb2 import OAuth2MetadataRequest, PublicClientAuthConfigRequest
from flyteidl.service.auth_pb2_grpc import AuthMetadataServiceStub
8
+
9
# Accepted values for the ``auth_type`` authentication setting.
AuthType = typing.Literal["ClientSecret", "Pkce", "ExternalCommand", "DeviceFlow"]
10
+
11
+
12
class ClientConfig(pydantic.BaseModel):
    """
    Authentication settings consumed by the authenticators.

    ``token_endpoint``, ``authorization_endpoint``, ``redirect_uri`` and ``client_id`` are required; every other
    field has a default.
    """

    token_endpoint: str
    authorization_endpoint: str
    redirect_uri: str
    client_id: str
    device_authorization_endpoint: typing.Optional[str] = None
    scopes: typing.Optional[typing.List[str]] = None
    header_key: str = "authorization"
    audience: typing.Optional[str] = None

    def with_override(self, other: "ClientConfig") -> "ClientConfig":
        """
        Merge ``other`` on top of this config and return the result as a new ``ClientConfig``.

        Each field is taken from ``other`` when its value there is truthy; otherwise the value from this instance is
        kept. Neither input is modified.
        """
        merged = {}
        for field_name in (
            "token_endpoint",
            "authorization_endpoint",
            "redirect_uri",
            "client_id",
            "device_authorization_endpoint",
            "scopes",
            "header_key",
            "audience",
        ):
            merged[field_name] = getattr(other, field_name) or getattr(self, field_name)
        return ClientConfig(**merged)
40
+
41
+
42
class ClientConfigStore(ABC):
    """
    Source of the ClientConfig used by the authenticators.

    Implementations decide where the configuration comes from, e.g. a fixed value or the server's
    AuthMetadataService. Inheriting from ABC makes ``@abstractmethod`` effective: subclasses that do not implement
    ``get_client_config`` cannot be instantiated.
    """

    @abstractmethod
    async def get_client_config(self) -> ClientConfig:
        """Return the ClientConfig provided by this store."""
        ...
49
+
50
+
51
class StaticClientConfigStore(ClientConfigStore):
    """
    ClientConfigStore that always hands back the ClientConfig it was constructed with.
    """

    def __init__(self, cfg: ClientConfig):
        # Kept as-is: the very same instance is returned by every get_client_config call.
        self._cfg = cfg

    async def get_client_config(self) -> ClientConfig:
        """Return the configuration passed to the constructor."""
        stored = self._cfg
        return stored
57
+
58
+
59
class RemoteClientConfigStore(ClientConfigStore):
    """
    ClientConfigStore backed by the Flyte server's AuthMetadataService.

    The configuration is fetched over an unauthenticated channel, i.e. the channel used to obtain server metadata
    before authentication is set up.
    """

    def __init__(self, unauthenticated_channel: grpc.aio.Channel):
        self._unauthenticated_channel = unauthenticated_channel

    async def get_client_config(self) -> ClientConfig:
        """
        Build a ClientConfig from the server's public client config and OAuth2 metadata.

        Each call issues two RPCs, GetPublicClientConfig and GetOAuth2Metadata, on the unauthenticated channel;
        results are not cached here.
        """
        metadata_service = AuthMetadataServiceStub(self._unauthenticated_channel)
        public_client_config = await metadata_service.GetPublicClientConfig(PublicClientAuthConfigRequest())
        oauth2_metadata = await metadata_service.GetOAuth2Metadata(OAuth2MetadataRequest())
        return ClientConfig(
            token_endpoint=oauth2_metadata.token_endpoint,
            authorization_endpoint=oauth2_metadata.authorization_endpoint,
            redirect_uri=public_client_config.redirect_uri,
            client_id=public_client_config.client_id,
            # Copy the protobuf repeated field into a plain list: it is not a `list`, and some pydantic versions
            # (v1's list validator) reject non-list sequence types for List[str] fields.
            scopes=list(public_client_config.scopes),
            header_key=public_client_config.authorization_metadata_key,
            device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint,
            audience=public_client_config.audience,
        )
+ )
@@ -0,0 +1,32 @@
1
+ import textwrap
2
+
3
+
4
def get_default_success_html(endpoint: str) -> str:
    """
    Return the HTML page shown in the browser after a successful OAuth2 login.

    :param endpoint: Not used by the current implementation; the same page is returned for every endpoint.
    :return: A static, dedented HTML document.
    """
    return textwrap.dedent(
"""
<html>
<head>
<title>OAuth2 Authentication to Union Successful</title>
</head>
<body style="background:white;font-family:Arial">
<div style="position: absolute;top:40%;left:50%;transform: translate(-50%, -50%);text-align:center;">
<div style="margin:auto">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 300 65" fill="currentColor"
style="color:#fdb51e;width:360px;">
<title>Union.ai</title>
<path d="M32,64.8C14.4,64.8,0,51.5,0,34V3.6h17.6v41.3c0,1.9,1.1,3,3,3h23c1.9,0,3-1.1,3-3V3.6H64V34
C64,51.5,49.6,64.8,32,64.8z M69.9,30.9v30.4h17.6V20c0-1.9,1.1-3,3-3h23c1.9,0,3,1.1,3,3v41.3H134V30.9c0-17.5-14.4-30.8-32.1-30.8
S69.9,13.5,69.9,30.9z M236,30.9v30.4h17.6V20c0-1.9,1.1-3,3-3h23c1.9,0,3,1.1,3,3v41.3H300V30.9c0-17.5-14.4-30.8-32-30.8
S236,13.5,236,30.9L236,30.9z M230.1,32.4c0,18.2-14.2,32.5-32.2,32.5s-32-14.3-32-32.5s14-32.1,32-32.1S230.1,14.3,230.1,32.4
L230.1,32.4z M213.5,20.2c0-1.9-1.1-3-3-3h-24.8c-1.9,0-3,1.1-3,3v24.5c0,1.9,1.1,3,3,3h24.8c1.9,0,3-1.1,3-3V20.2z M158.9,3.6
h-17.6v57.8h17.6V3.6z"></path>
</svg>
<h2>You've successfully authenticated to Union!</h2>
<p style="font-size:20px;">Return to your terminal for next steps</p>
</div>
</div>
</body>
</html>
"""  # noqa: E501
    )
File without changes
@@ -0,0 +1,288 @@
1
+ import typing
2
+ from typing import AsyncIterator, Optional, Union
3
+
4
+ import grpc.aio
5
+ from grpc.aio import ClientCallDetails, Metadata
6
+ from grpc.aio._typing import DoneCallbackType, EOFType, RequestType, ResponseType
7
+
8
+ from flyte.remote._client.auth._authenticators.base import Authenticator
9
+ from flyte.remote._client.auth._grpc_utils.default_metadata_interceptor import with_metadata
10
+
11
+
12
class _BaseAuthInterceptor:
    """
    Shared plumbing for the auth interceptors below.

    Holds a factory for the ``Authenticator`` (invoked lazily on first use) and knows how to decorate
    ``ClientCallDetails`` with the authenticator's gRPC metadata.
    """

    def __init__(self, get_authenticator: typing.Callable[[], Authenticator]):
        # Only the factory is stored here; the authenticator itself is built on first access.
        self._get_authenticator = get_authenticator
        self._authenticator: typing.Optional[Authenticator] = None

    @property
    def authenticator(self) -> Authenticator:
        """The ``Authenticator`` produced by the factory, created on first access and cached afterwards."""
        cached = self._authenticator
        if cached is None:
            cached = self._get_authenticator()
            self._authenticator = cached
        return cached

    async def call_details_with_auth_metadata(
        self, client_call_details: grpc.aio.ClientCallDetails
    ) -> typing.Tuple[grpc.aio.ClientCallDetails, str]:
        """
        Attach the authenticator's current auth metadata to the given call details.

        :param client_call_details: Call details (method, timeout, metadata, credentials, wait_for_ready) to decorate
        :return: A ``(call_details, creds_id)`` tuple. When the authenticator yields metadata, ``call_details`` is the
            result of ``with_metadata`` applied with its ``pairs`` and ``creds_id`` is its ``creds_id``. When it yields
            nothing (falsy), the input details are returned unchanged together with an empty ``creds_id``.
        """
        auth_metadata = await self.authenticator.get_grpc_call_auth_metadata()
        if not auth_metadata:
            return client_call_details, ""
        return with_metadata(client_call_details, auth_metadata.pairs), auth_metadata.creds_id
46
+
47
+
48
class AuthUnaryUnaryInterceptor(_BaseAuthInterceptor, grpc.aio.UnaryUnaryClientInterceptor):
    """
    Unary-unary client interceptor that attaches auth metadata and retries once after refreshing credentials.
    """

    async def intercept_unary_unary(
        self,
        continuation: typing.Callable,
        client_call_details: ClientCallDetails,
        request: typing.Any,
    ):
        """
        Run a unary-unary RPC with auth metadata, retrying once when the server rejects it.

        The first attempt uses the authenticator's current metadata. If awaiting the call raises
        ``AioRpcError`` with status UNAUTHENTICATED or UNKNOWN, the credentials identified by the first
        attempt's ``creds_id`` are refreshed. The RPC is then re-issued with metadata rebuilt from the
        original ``client_call_details``. Any other status is re-raised untouched.

        :param continuation: Next step of the interceptor chain, called with (call_details, request)
        :param client_call_details: Call details as received (method, timeout, metadata, credentials, wait_for_ready)
        :param request: The request message for the RPC
        :return: The RPC's response message
        :raises: grpc.aio.AioRpcError for non-auth failures, or if the retried call fails as well
        """
        authed_details, used_creds_id = await self.call_details_with_auth_metadata(client_call_details)
        try:
            pending_call = await continuation(authed_details, request)
            return await pending_call
        except grpc.aio.AioRpcError as rpc_error:
            if rpc_error.code() not in (grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.UNKNOWN):
                raise rpc_error
            await self.authenticator.refresh_credentials(creds_id=used_creds_id)
            # Rebuild from the original details so only the fresh auth metadata is attached.
            retry_details, _ = await self.call_details_with_auth_metadata(client_call_details)
            retry_call = await continuation(retry_details, request)
            return await retry_call
84
+
85
+
86
class UnaryStreamCall(grpc.aio.UnaryStreamCall):
    """
    Lazily-started response-streaming call that adds auth metadata and retries once after a credential refresh.

    Returned by both the unary-stream and the stream-stream auth interceptors. The underlying RPC is not
    started on construction. It starts when this object is iterated (``__aiter__`` -> ``response_iterator``).
    Until then, every ``grpc.aio.Call`` accessor below returns a neutral default instead of delegating.
    """

    def __init__(
        self,
        parent_interceptor: _BaseAuthInterceptor,
        authenticator: Authenticator,
        continuation: typing.Callable,
        call_details: grpc.aio.ClientCallDetails,
        request: RequestType,
    ):
        """
        :param parent_interceptor: Interceptor whose ``call_details_with_auth_metadata`` builds authenticated details
        :param authenticator: Authenticator used to refresh credentials when the call is rejected
        :param continuation: Next step of the interceptor chain, called with (call_details, request)
        :param call_details: Call details as received by the interceptor, before auth metadata is added
        :param request: Request message (or request iterator, for stream-stream) forwarded to ``continuation``
        """
        super().__init__()
        self._continuation = continuation
        self._call_details = call_details
        self._request = request
        self._authenticator = authenticator
        self._parent_interceptor = parent_interceptor
        # The live underlying call; stays None until iteration starts the RPC.
        self._call: (
            Union[
                grpc.aio.UnaryStreamCall[RequestType, ResponseType],
                grpc.aio.StreamStreamCall[RequestType, ResponseType],
            ]
            | None
        ) = None

    async def response_iterator(self) -> typing.AsyncIterator[ResponseType]:
        """
        Start the RPC with auth metadata and yield its responses.

        If iteration raises ``AioRpcError`` with status UNAUTHENTICATED or UNKNOWN, the credentials identified
        by the first attempt's ``creds_id`` are refreshed. The RPC is then restarted once, with auth metadata
        rebuilt from the original call details. Other statuses are re-raised.

        NOTE(review): the restart begins a brand-new stream. Any responses already yielded before the error are
        yielded again. For stream-stream use, the same ``self._request`` object is passed again, so an
        already-consumed request iterator would supply nothing to the retry. Confirm both are acceptable.
        """
        call_details, creds_id = await self._parent_interceptor.call_details_with_auth_metadata(self._call_details)
        self._call = await self._continuation(call_details, self._request)
        try:
            async for response in self._call:
                yield response
        except grpc.aio.AioRpcError as e:
            if e.code() not in (grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.UNKNOWN):
                raise
            await self._authenticator.refresh_credentials(creds_id=creds_id)
            # Rebuild from the ORIGINAL details (as the unary-unary interceptor does). ``call_details`` already
            # carries the rejected auth metadata; decorating it again would layer fresh metadata on top of the
            # stale one (presumably duplicating the auth header; depends on ``with_metadata``'s merge semantics).
            updated_call_details, _ = await self._parent_interceptor.call_details_with_auth_metadata(
                self._call_details
            )
            self._call = await self._continuation(updated_call_details, self._request)
            async for response in self._call:
                yield response

    def __aiter__(self) -> AsyncIterator[ResponseType]:
        """Start the RPC (see ``response_iterator``) and return its response iterator."""
        return self.response_iterator()

    async def read(self) -> Union[EOFType, ResponseType]:
        """Read the next response from the underlying call, or ``grpc.aio.EOF`` if the call has not started."""
        if self._call is not None:
            return await self._call.read()
        # Return the module-level sentinel: end-of-stream is signalled by ``grpc.aio.EOF`` itself, and a freshly
        # constructed ``EOFType()`` instance would not be identical to it.
        return grpc.aio.EOF

    async def initial_metadata(self) -> Metadata:
        """Initial metadata of the underlying call; empty ``Metadata`` if the call has not started."""
        if self._call is not None:
            return await self._call.initial_metadata()
        return Metadata()

    async def trailing_metadata(self) -> Metadata:
        """Trailing metadata of the underlying call; empty ``Metadata`` if the call has not started."""
        if self._call is not None:
            return await self._call.trailing_metadata()
        return Metadata()

    async def code(self) -> grpc.StatusCode:
        """Status code of the underlying call; ``StatusCode.OK`` if the call has not started."""
        if self._call is not None:
            return await self._call.code()
        return grpc.StatusCode.OK

    async def details(self) -> str:
        """Status details of the underlying call; an empty string if the call has not started."""
        if self._call is not None:
            return await self._call.details()
        return ""

    async def wait_for_connection(self) -> None:
        """Delegate to the underlying call; a no-op if the call has not started."""
        if self._call is not None:
            await self._call.wait_for_connection()
        return None

    def cancelled(self) -> bool:
        """Whether the underlying call was cancelled; False if the call has not started."""
        if self._call is not None:
            return self._call.cancelled()
        return False

    def done(self) -> bool:
        """Whether the underlying call is done; False if the call has not started."""
        if self._call is not None:
            return self._call.done()
        return False

    def time_remaining(self) -> Optional[float]:
        """Remaining deadline of the underlying call; None if the call has not started."""
        if self._call is not None:
            return self._call.time_remaining()
        return None

    def cancel(self) -> bool:
        """Cancel the underlying call; returns False (nothing to cancel) if the call has not started."""
        if self._call is not None:
            return self._call.cancel()
        return False

    def add_done_callback(self, callback: DoneCallbackType) -> None:
        """Register ``callback`` on the underlying call; silently ignored if the call has not started."""
        if self._call is not None:
            self._call.add_done_callback(callback=callback)
        return None
182
+
183
+
184
class AuthUnaryStreamInterceptor(_BaseAuthInterceptor, grpc.aio.UnaryStreamClientInterceptor):
    """
    Unary-stream client interceptor that defers the RPC to a lazily-started ``UnaryStreamCall`` wrapper.
    """

    async def intercept_unary_stream(
        self, continuation: typing.Callable, client_call_details: grpc.aio.ClientCallDetails, request: typing.Any
    ):
        """
        Wrap a unary-stream RPC so auth metadata is attached when the response stream is consumed.

        No RPC is issued here. The returned ``UnaryStreamCall`` adds auth metadata and invokes ``continuation``
        once it is iterated. It retries once, after a credential refresh, on UNAUTHENTICATED or UNKNOWN.

        :param continuation: Next step of the interceptor chain, called with (call_details, request)
        :param client_call_details: Call details as received (method, timeout, metadata, credentials, wait_for_ready)
        :param request: The request message for the RPC
        :return: A ``UnaryStreamCall`` that yields the RPC's responses when iterated
        :raises: grpc.aio.AioRpcError (during iteration) for non-auth failures
        """
        current_authenticator = self.authenticator
        wrapped_call = UnaryStreamCall(
            parent_interceptor=self,
            authenticator=current_authenticator,
            call_details=client_call_details,
            continuation=continuation,
            request=request,
        )
        return wrapped_call
214
+
215
+
216
class AuthStreamUnaryInterceptor(_BaseAuthInterceptor, grpc.aio.StreamUnaryClientInterceptor):
    """
    Stream-unary client interceptor that attaches auth metadata and retries once after refreshing credentials.
    """

    async def intercept_stream_unary(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: typing.Any,
    ):
        """
        Run a stream-unary RPC with auth metadata, retrying once when the server rejects it.

        If awaiting the call raises ``AioRpcError`` with status UNAUTHENTICATED or UNKNOWN, the credentials
        identified by the first attempt's ``creds_id`` are refreshed. The RPC is then re-issued with metadata
        rebuilt from the original ``client_call_details``. Any other status is re-raised.

        NOTE(review): the retry passes the very same ``request_iterator`` object to ``continuation`` again. If the
        first attempt already consumed it, the retried call may receive no (or only the remaining) requests.
        Confirm this is acceptable for the RPCs using this interceptor.

        :param continuation: Next step of the interceptor chain, called with (call_details, request_iterator)
        :param client_call_details: Call details as received (method, timeout, metadata, credentials, wait_for_ready)
        :param request_iterator: Iterator of request messages for the RPC
        :return: The RPC's response message
        :raises: grpc.aio.AioRpcError for non-auth failures, or if the retried call fails as well
        """
        first_details, first_creds_id = await self.call_details_with_auth_metadata(client_call_details)
        try:
            first_call = await continuation(first_details, request_iterator)
            return await first_call
        except grpc.aio.AioRpcError as rpc_error:
            if rpc_error.code() not in (grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.UNKNOWN):
                raise rpc_error
            await self.authenticator.refresh_credentials(creds_id=first_creds_id)
            retry_details, _ = await self.call_details_with_auth_metadata(client_call_details)
            retry_call = await continuation(retry_details, request_iterator)
            return await retry_call
252
+
253
+
254
class AuthStreamStreamInterceptor(_BaseAuthInterceptor, grpc.aio.StreamStreamClientInterceptor):
    """
    Stream-stream client interceptor that defers the RPC to a lazily-started ``UnaryStreamCall`` wrapper.

    NOTE(review): the returned wrapper subclasses ``grpc.aio.UnaryStreamCall`` and defines no ``write`` /
    ``done_writing``. Requests can therefore only be supplied up front through ``request_iterator``. Presumably
    grpc.aio will not treat it as a ``StreamStreamCall``; verify this is intended.
    """

    async def intercept_stream_stream(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: typing.Any,
    ):
        """
        Wrap a stream-stream RPC so auth metadata is attached when the response stream is consumed.

        No RPC is issued here. The returned ``UnaryStreamCall`` adds auth metadata and invokes ``continuation``
        with ``request_iterator`` once it is iterated. It retries once, after a credential refresh, on
        UNAUTHENTICATED or UNKNOWN.

        :param continuation: Next step of the interceptor chain, called with (call_details, request_iterator)
        :param client_call_details: Call details as received (method, timeout, metadata, credentials, wait_for_ready)
        :param request_iterator: Iterator of request messages for the RPC
        :return: A ``UnaryStreamCall`` that yields the RPC's responses when iterated
        :raises: grpc.aio.AioRpcError (during iteration) for non-auth failures
        """
        current_authenticator = self.authenticator
        wrapped_call = UnaryStreamCall(
            parent_interceptor=self,
            authenticator=current_authenticator,
            call_details=client_call_details,
            continuation=continuation,
            request=request_iterator,
        )
        return wrapped_call
285
+
286
+
287
# Backward-compatibility alias: the older name ``AuthUnaryInterceptor`` is bound to the very same class object as
# ``AuthUnaryUnaryInterceptor`` (a plain assignment, not a subclass), so either name can be used interchangeably.
AuthUnaryInterceptor = AuthUnaryUnaryInterceptor