flyte 0.0.1b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (390) hide show
  1. flyte/__init__.py +62 -0
  2. flyte/_api_commons.py +3 -0
  3. flyte/_bin/__init__.py +0 -0
  4. flyte/_bin/runtime.py +126 -0
  5. flyte/_build.py +25 -0
  6. flyte/_cache/__init__.py +12 -0
  7. flyte/_cache/cache.py +146 -0
  8. flyte/_cache/defaults.py +9 -0
  9. flyte/_cache/policy_function_body.py +42 -0
  10. flyte/_cli/__init__.py +0 -0
  11. flyte/_cli/_common.py +287 -0
  12. flyte/_cli/_create.py +42 -0
  13. flyte/_cli/_delete.py +23 -0
  14. flyte/_cli/_deploy.py +140 -0
  15. flyte/_cli/_get.py +235 -0
  16. flyte/_cli/_run.py +152 -0
  17. flyte/_cli/main.py +72 -0
  18. flyte/_code_bundle/__init__.py +8 -0
  19. flyte/_code_bundle/_ignore.py +113 -0
  20. flyte/_code_bundle/_packaging.py +187 -0
  21. flyte/_code_bundle/_utils.py +339 -0
  22. flyte/_code_bundle/bundle.py +178 -0
  23. flyte/_context.py +146 -0
  24. flyte/_datastructures.py +342 -0
  25. flyte/_deploy.py +202 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +43 -0
  29. flyte/_group.py +31 -0
  30. flyte/_hash.py +23 -0
  31. flyte/_image.py +760 -0
  32. flyte/_initialize.py +634 -0
  33. flyte/_interface.py +84 -0
  34. flyte/_internal/__init__.py +3 -0
  35. flyte/_internal/controllers/__init__.py +115 -0
  36. flyte/_internal/controllers/_local_controller.py +118 -0
  37. flyte/_internal/controllers/_trace.py +40 -0
  38. flyte/_internal/controllers/pbhash.py +39 -0
  39. flyte/_internal/controllers/remote/__init__.py +40 -0
  40. flyte/_internal/controllers/remote/_action.py +141 -0
  41. flyte/_internal/controllers/remote/_client.py +43 -0
  42. flyte/_internal/controllers/remote/_controller.py +361 -0
  43. flyte/_internal/controllers/remote/_core.py +402 -0
  44. flyte/_internal/controllers/remote/_informer.py +361 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +11 -0
  47. flyte/_internal/imagebuild/docker_builder.py +416 -0
  48. flyte/_internal/imagebuild/image_builder.py +241 -0
  49. flyte/_internal/imagebuild/remote_builder.py +0 -0
  50. flyte/_internal/resolvers/__init__.py +0 -0
  51. flyte/_internal/resolvers/_task_module.py +54 -0
  52. flyte/_internal/resolvers/common.py +31 -0
  53. flyte/_internal/resolvers/default.py +28 -0
  54. flyte/_internal/runtime/__init__.py +0 -0
  55. flyte/_internal/runtime/convert.py +199 -0
  56. flyte/_internal/runtime/entrypoints.py +135 -0
  57. flyte/_internal/runtime/io.py +136 -0
  58. flyte/_internal/runtime/resources_serde.py +138 -0
  59. flyte/_internal/runtime/task_serde.py +210 -0
  60. flyte/_internal/runtime/taskrunner.py +190 -0
  61. flyte/_internal/runtime/types_serde.py +54 -0
  62. flyte/_logging.py +124 -0
  63. flyte/_protos/__init__.py +0 -0
  64. flyte/_protos/common/authorization_pb2.py +66 -0
  65. flyte/_protos/common/authorization_pb2.pyi +108 -0
  66. flyte/_protos/common/authorization_pb2_grpc.py +4 -0
  67. flyte/_protos/common/identifier_pb2.py +71 -0
  68. flyte/_protos/common/identifier_pb2.pyi +82 -0
  69. flyte/_protos/common/identifier_pb2_grpc.py +4 -0
  70. flyte/_protos/common/identity_pb2.py +48 -0
  71. flyte/_protos/common/identity_pb2.pyi +72 -0
  72. flyte/_protos/common/identity_pb2_grpc.py +4 -0
  73. flyte/_protos/common/list_pb2.py +36 -0
  74. flyte/_protos/common/list_pb2.pyi +69 -0
  75. flyte/_protos/common/list_pb2_grpc.py +4 -0
  76. flyte/_protos/common/policy_pb2.py +37 -0
  77. flyte/_protos/common/policy_pb2.pyi +27 -0
  78. flyte/_protos/common/policy_pb2_grpc.py +4 -0
  79. flyte/_protos/common/role_pb2.py +37 -0
  80. flyte/_protos/common/role_pb2.pyi +53 -0
  81. flyte/_protos/common/role_pb2_grpc.py +4 -0
  82. flyte/_protos/common/runtime_version_pb2.py +28 -0
  83. flyte/_protos/common/runtime_version_pb2.pyi +24 -0
  84. flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
  85. flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
  86. flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
  87. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  88. flyte/_protos/secret/definition_pb2.py +49 -0
  89. flyte/_protos/secret/definition_pb2.pyi +93 -0
  90. flyte/_protos/secret/definition_pb2_grpc.py +4 -0
  91. flyte/_protos/secret/payload_pb2.py +62 -0
  92. flyte/_protos/secret/payload_pb2.pyi +94 -0
  93. flyte/_protos/secret/payload_pb2_grpc.py +4 -0
  94. flyte/_protos/secret/secret_pb2.py +38 -0
  95. flyte/_protos/secret/secret_pb2.pyi +6 -0
  96. flyte/_protos/secret/secret_pb2_grpc.py +198 -0
  97. flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
  98. flyte/_protos/validate/validate/validate_pb2.py +76 -0
  99. flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
  100. flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  101. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  102. flyte/_protos/workflow/queue_service_pb2.py +106 -0
  103. flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
  104. flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  105. flyte/_protos/workflow/run_definition_pb2.py +128 -0
  106. flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
  107. flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  108. flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
  109. flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  110. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  111. flyte/_protos/workflow/run_service_pb2.py +133 -0
  112. flyte/_protos/workflow/run_service_pb2.pyi +175 -0
  113. flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
  114. flyte/_protos/workflow/state_service_pb2.py +58 -0
  115. flyte/_protos/workflow/state_service_pb2.pyi +71 -0
  116. flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
  117. flyte/_protos/workflow/task_definition_pb2.py +72 -0
  118. flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
  119. flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  120. flyte/_protos/workflow/task_service_pb2.py +44 -0
  121. flyte/_protos/workflow/task_service_pb2.pyi +31 -0
  122. flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
  123. flyte/_resources.py +226 -0
  124. flyte/_retry.py +32 -0
  125. flyte/_reusable_environment.py +25 -0
  126. flyte/_run.py +411 -0
  127. flyte/_secret.py +61 -0
  128. flyte/_task.py +367 -0
  129. flyte/_task_environment.py +200 -0
  130. flyte/_timeout.py +47 -0
  131. flyte/_tools.py +27 -0
  132. flyte/_trace.py +128 -0
  133. flyte/_utils/__init__.py +20 -0
  134. flyte/_utils/asyn.py +119 -0
  135. flyte/_utils/coro_management.py +25 -0
  136. flyte/_utils/file_handling.py +72 -0
  137. flyte/_utils/helpers.py +108 -0
  138. flyte/_utils/lazy_module.py +54 -0
  139. flyte/_utils/uv_script_parser.py +49 -0
  140. flyte/_version.py +21 -0
  141. flyte/connectors/__init__.py +0 -0
  142. flyte/errors.py +143 -0
  143. flyte/extras/__init__.py +5 -0
  144. flyte/extras/_container.py +273 -0
  145. flyte/io/__init__.py +11 -0
  146. flyte/io/_dataframe.py +0 -0
  147. flyte/io/_dir.py +448 -0
  148. flyte/io/_file.py +468 -0
  149. flyte/io/pickle/__init__.py +0 -0
  150. flyte/io/pickle/transformer.py +117 -0
  151. flyte/io/structured_dataset/__init__.py +129 -0
  152. flyte/io/structured_dataset/basic_dfs.py +219 -0
  153. flyte/io/structured_dataset/structured_dataset.py +1061 -0
  154. flyte/py.typed +0 -0
  155. flyte/remote/__init__.py +25 -0
  156. flyte/remote/_client/__init__.py +0 -0
  157. flyte/remote/_client/_protocols.py +131 -0
  158. flyte/remote/_client/auth/__init__.py +12 -0
  159. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  160. flyte/remote/_client/auth/_authenticators/base.py +397 -0
  161. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  162. flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
  163. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  164. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  165. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  166. flyte/remote/_client/auth/_channel.py +184 -0
  167. flyte/remote/_client/auth/_client_config.py +83 -0
  168. flyte/remote/_client/auth/_default_html.py +32 -0
  169. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  170. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  171. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  172. flyte/remote/_client/auth/_keyring.py +143 -0
  173. flyte/remote/_client/auth/_token_client.py +260 -0
  174. flyte/remote/_client/auth/errors.py +16 -0
  175. flyte/remote/_client/controlplane.py +95 -0
  176. flyte/remote/_console.py +18 -0
  177. flyte/remote/_data.py +155 -0
  178. flyte/remote/_logs.py +116 -0
  179. flyte/remote/_project.py +86 -0
  180. flyte/remote/_run.py +873 -0
  181. flyte/remote/_secret.py +132 -0
  182. flyte/remote/_task.py +227 -0
  183. flyte/report/__init__.py +3 -0
  184. flyte/report/_report.py +178 -0
  185. flyte/report/_template.html +124 -0
  186. flyte/storage/__init__.py +24 -0
  187. flyte/storage/_remote_fs.py +34 -0
  188. flyte/storage/_storage.py +251 -0
  189. flyte/storage/_utils.py +5 -0
  190. flyte/types/__init__.py +13 -0
  191. flyte/types/_interface.py +25 -0
  192. flyte/types/_renderer.py +162 -0
  193. flyte/types/_string_literals.py +120 -0
  194. flyte/types/_type_engine.py +2210 -0
  195. flyte/types/_utils.py +80 -0
  196. flyte-0.0.1b0.dist-info/METADATA +179 -0
  197. flyte-0.0.1b0.dist-info/RECORD +390 -0
  198. flyte-0.0.1b0.dist-info/WHEEL +5 -0
  199. flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
  200. flyte-0.0.1b0.dist-info/top_level.txt +1 -0
  201. union/__init__.py +54 -0
  202. union/_api_commons.py +3 -0
  203. union/_bin/__init__.py +0 -0
  204. union/_bin/runtime.py +113 -0
  205. union/_build.py +25 -0
  206. union/_cache/__init__.py +12 -0
  207. union/_cache/cache.py +141 -0
  208. union/_cache/defaults.py +9 -0
  209. union/_cache/policy_function_body.py +42 -0
  210. union/_cli/__init__.py +0 -0
  211. union/_cli/_common.py +263 -0
  212. union/_cli/_create.py +40 -0
  213. union/_cli/_delete.py +23 -0
  214. union/_cli/_deploy.py +120 -0
  215. union/_cli/_get.py +162 -0
  216. union/_cli/_params.py +579 -0
  217. union/_cli/_run.py +150 -0
  218. union/_cli/main.py +72 -0
  219. union/_code_bundle/__init__.py +8 -0
  220. union/_code_bundle/_ignore.py +113 -0
  221. union/_code_bundle/_packaging.py +187 -0
  222. union/_code_bundle/_utils.py +342 -0
  223. union/_code_bundle/bundle.py +176 -0
  224. union/_context.py +146 -0
  225. union/_datastructures.py +295 -0
  226. union/_deploy.py +185 -0
  227. union/_doc.py +29 -0
  228. union/_docstring.py +26 -0
  229. union/_environment.py +43 -0
  230. union/_group.py +31 -0
  231. union/_hash.py +23 -0
  232. union/_image.py +760 -0
  233. union/_initialize.py +585 -0
  234. union/_interface.py +84 -0
  235. union/_internal/__init__.py +3 -0
  236. union/_internal/controllers/__init__.py +77 -0
  237. union/_internal/controllers/_local_controller.py +77 -0
  238. union/_internal/controllers/pbhash.py +39 -0
  239. union/_internal/controllers/remote/__init__.py +40 -0
  240. union/_internal/controllers/remote/_action.py +131 -0
  241. union/_internal/controllers/remote/_client.py +43 -0
  242. union/_internal/controllers/remote/_controller.py +169 -0
  243. union/_internal/controllers/remote/_core.py +341 -0
  244. union/_internal/controllers/remote/_informer.py +260 -0
  245. union/_internal/controllers/remote/_service_protocol.py +44 -0
  246. union/_internal/imagebuild/__init__.py +11 -0
  247. union/_internal/imagebuild/docker_builder.py +416 -0
  248. union/_internal/imagebuild/image_builder.py +243 -0
  249. union/_internal/imagebuild/remote_builder.py +0 -0
  250. union/_internal/resolvers/__init__.py +0 -0
  251. union/_internal/resolvers/_task_module.py +31 -0
  252. union/_internal/resolvers/common.py +24 -0
  253. union/_internal/resolvers/default.py +27 -0
  254. union/_internal/runtime/__init__.py +0 -0
  255. union/_internal/runtime/convert.py +163 -0
  256. union/_internal/runtime/entrypoints.py +121 -0
  257. union/_internal/runtime/io.py +136 -0
  258. union/_internal/runtime/resources_serde.py +134 -0
  259. union/_internal/runtime/task_serde.py +202 -0
  260. union/_internal/runtime/taskrunner.py +179 -0
  261. union/_internal/runtime/types_serde.py +53 -0
  262. union/_logging.py +124 -0
  263. union/_protos/__init__.py +0 -0
  264. union/_protos/common/authorization_pb2.py +66 -0
  265. union/_protos/common/authorization_pb2.pyi +106 -0
  266. union/_protos/common/authorization_pb2_grpc.py +4 -0
  267. union/_protos/common/identifier_pb2.py +71 -0
  268. union/_protos/common/identifier_pb2.pyi +82 -0
  269. union/_protos/common/identifier_pb2_grpc.py +4 -0
  270. union/_protos/common/identity_pb2.py +48 -0
  271. union/_protos/common/identity_pb2.pyi +72 -0
  272. union/_protos/common/identity_pb2_grpc.py +4 -0
  273. union/_protos/common/list_pb2.py +36 -0
  274. union/_protos/common/list_pb2.pyi +69 -0
  275. union/_protos/common/list_pb2_grpc.py +4 -0
  276. union/_protos/common/policy_pb2.py +37 -0
  277. union/_protos/common/policy_pb2.pyi +27 -0
  278. union/_protos/common/policy_pb2_grpc.py +4 -0
  279. union/_protos/common/role_pb2.py +37 -0
  280. union/_protos/common/role_pb2.pyi +51 -0
  281. union/_protos/common/role_pb2_grpc.py +4 -0
  282. union/_protos/common/runtime_version_pb2.py +28 -0
  283. union/_protos/common/runtime_version_pb2.pyi +24 -0
  284. union/_protos/common/runtime_version_pb2_grpc.py +4 -0
  285. union/_protos/logs/dataplane/payload_pb2.py +96 -0
  286. union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
  287. union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  288. union/_protos/secret/definition_pb2.py +49 -0
  289. union/_protos/secret/definition_pb2.pyi +93 -0
  290. union/_protos/secret/definition_pb2_grpc.py +4 -0
  291. union/_protos/secret/payload_pb2.py +62 -0
  292. union/_protos/secret/payload_pb2.pyi +94 -0
  293. union/_protos/secret/payload_pb2_grpc.py +4 -0
  294. union/_protos/secret/secret_pb2.py +38 -0
  295. union/_protos/secret/secret_pb2.pyi +6 -0
  296. union/_protos/secret/secret_pb2_grpc.py +198 -0
  297. union/_protos/validate/validate/validate_pb2.py +76 -0
  298. union/_protos/workflow/node_execution_service_pb2.py +26 -0
  299. union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  300. union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  301. union/_protos/workflow/queue_service_pb2.py +75 -0
  302. union/_protos/workflow/queue_service_pb2.pyi +103 -0
  303. union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  304. union/_protos/workflow/run_definition_pb2.py +100 -0
  305. union/_protos/workflow/run_definition_pb2.pyi +256 -0
  306. union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  307. union/_protos/workflow/run_logs_service_pb2.py +41 -0
  308. union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  309. union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  310. union/_protos/workflow/run_service_pb2.py +133 -0
  311. union/_protos/workflow/run_service_pb2.pyi +173 -0
  312. union/_protos/workflow/run_service_pb2_grpc.py +412 -0
  313. union/_protos/workflow/state_service_pb2.py +58 -0
  314. union/_protos/workflow/state_service_pb2.pyi +69 -0
  315. union/_protos/workflow/state_service_pb2_grpc.py +138 -0
  316. union/_protos/workflow/task_definition_pb2.py +72 -0
  317. union/_protos/workflow/task_definition_pb2.pyi +65 -0
  318. union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  319. union/_protos/workflow/task_service_pb2.py +44 -0
  320. union/_protos/workflow/task_service_pb2.pyi +31 -0
  321. union/_protos/workflow/task_service_pb2_grpc.py +104 -0
  322. union/_resources.py +226 -0
  323. union/_retry.py +32 -0
  324. union/_reusable_environment.py +25 -0
  325. union/_run.py +374 -0
  326. union/_secret.py +61 -0
  327. union/_task.py +354 -0
  328. union/_task_environment.py +186 -0
  329. union/_timeout.py +47 -0
  330. union/_tools.py +27 -0
  331. union/_utils/__init__.py +11 -0
  332. union/_utils/asyn.py +119 -0
  333. union/_utils/file_handling.py +71 -0
  334. union/_utils/helpers.py +46 -0
  335. union/_utils/lazy_module.py +54 -0
  336. union/_utils/uv_script_parser.py +49 -0
  337. union/_version.py +21 -0
  338. union/connectors/__init__.py +0 -0
  339. union/errors.py +128 -0
  340. union/extras/__init__.py +5 -0
  341. union/extras/_container.py +263 -0
  342. union/io/__init__.py +11 -0
  343. union/io/_dataframe.py +0 -0
  344. union/io/_dir.py +425 -0
  345. union/io/_file.py +418 -0
  346. union/io/pickle/__init__.py +0 -0
  347. union/io/pickle/transformer.py +117 -0
  348. union/io/structured_dataset/__init__.py +122 -0
  349. union/io/structured_dataset/basic_dfs.py +219 -0
  350. union/io/structured_dataset/structured_dataset.py +1057 -0
  351. union/py.typed +0 -0
  352. union/remote/__init__.py +23 -0
  353. union/remote/_client/__init__.py +0 -0
  354. union/remote/_client/_protocols.py +129 -0
  355. union/remote/_client/auth/__init__.py +12 -0
  356. union/remote/_client/auth/_authenticators/__init__.py +0 -0
  357. union/remote/_client/auth/_authenticators/base.py +391 -0
  358. union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  359. union/remote/_client/auth/_authenticators/device_code.py +120 -0
  360. union/remote/_client/auth/_authenticators/external_command.py +77 -0
  361. union/remote/_client/auth/_authenticators/factory.py +200 -0
  362. union/remote/_client/auth/_authenticators/pkce.py +515 -0
  363. union/remote/_client/auth/_channel.py +184 -0
  364. union/remote/_client/auth/_client_config.py +83 -0
  365. union/remote/_client/auth/_default_html.py +32 -0
  366. union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  367. union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
  368. union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
  369. union/remote/_client/auth/_keyring.py +154 -0
  370. union/remote/_client/auth/_token_client.py +258 -0
  371. union/remote/_client/auth/errors.py +16 -0
  372. union/remote/_client/controlplane.py +86 -0
  373. union/remote/_data.py +149 -0
  374. union/remote/_logs.py +74 -0
  375. union/remote/_project.py +86 -0
  376. union/remote/_run.py +820 -0
  377. union/remote/_secret.py +132 -0
  378. union/remote/_task.py +193 -0
  379. union/report/__init__.py +3 -0
  380. union/report/_report.py +178 -0
  381. union/report/_template.html +124 -0
  382. union/storage/__init__.py +24 -0
  383. union/storage/_remote_fs.py +34 -0
  384. union/storage/_storage.py +247 -0
  385. union/storage/_utils.py +5 -0
  386. union/types/__init__.py +11 -0
  387. union/types/_renderer.py +162 -0
  388. union/types/_string_literals.py +120 -0
  389. union/types/_type_engine.py +2131 -0
  390. union/types/_utils.py +80 -0
@@ -0,0 +1,184 @@
1
+ import ssl
2
+ import typing
3
+
4
+ import grpc
5
+ import grpc.aio
6
+ import httpx
7
+ from grpc.experimental.aio import init_grpc_aio
8
+
9
+ from union._logging import logger
10
+
11
+ from ._authenticators.base import get_async_session
12
+ from ._authenticators.factory import (
13
+ create_auth_interceptors,
14
+ create_proxy_auth_interceptors,
15
+ get_async_proxy_authenticator,
16
+ )
17
+
18
+ # Initialize gRPC AIO early enough so it can be used in the main thread
19
+ init_grpc_aio()
20
+
21
+
22
def bootstrap_ssl_from_server(endpoint: str) -> grpc.ChannelCredentials:
    """
    Build gRPC channel credentials from the certificate presented by the remote server.

    This function should be used only when insecure-skip-verify is enabled. The endpoint is split into
    a host and a port (443 when no numeric port suffix is present), the server certificate is fetched
    from that address, and channel credentials are created from the fetched certificate.

    :param endpoint: The endpoint to retrieve the SSL certificate from, may include a ``:<port>`` suffix
    :return: gRPC channel credentials created from the retrieved certificate
    """
    # Split off a trailing numeric port; anything else falls back to 443
    host, separator, port = endpoint.rpartition(":")
    if separator and port.isdigit():
        server_address = (host, int(port))
    else:
        logger.warning(f"Unrecognized port in endpoint [{endpoint}], defaulting to 443.")
        server_address = (endpoint, 443)

    # NOTE: this retrieval is a blocking call made directly on the calling thread
    pem_certificate = ssl.get_server_certificate(server_address)
    return grpc.ssl_channel_credentials(pem_certificate.encode())
44
+
45
+
46
async def create_channel(
    endpoint: str,
    *,
    insecure: typing.Optional[bool],
    insecure_skip_verify: typing.Optional[bool] = False,
    ca_cert_file_path: typing.Optional[str] = None,
    ssl_credentials: typing.Optional[grpc.ChannelCredentials] = None,
    grpc_options: typing.Optional[typing.Sequence[typing.Tuple[str, typing.Any]]] = None,
    compression: typing.Optional[grpc.Compression] = None,
    http_session: httpx.AsyncClient | None = None,
    proxy_command: typing.List[str] | None = None,
    **kwargs,
) -> grpc.aio.Channel:
    """
    Creates a new gRPC channel with appropriate authentication interceptors.

    This function creates either a secure or insecure gRPC channel based on the provided parameters,
    and adds authentication interceptors to the channel. If SSL credentials are not provided,
    they are created based on the insecure_skip_verify and ca_cert_file_path parameters.

    The function is async because it may need to read certificate files asynchronously
    and create authentication interceptors that perform async operations.

    :param endpoint: The endpoint URL for the gRPC channel
    :param insecure: Whether to use an insecure channel (no SSL)
    :param insecure_skip_verify: Whether to skip SSL certificate verification
    :param ca_cert_file_path: Path to CA certificate file for SSL verification
    :param ssl_credentials: Pre-configured SSL credentials for the channel
    :param grpc_options: Additional gRPC channel options
    :param compression: Compression method for the channel
    :param http_session: Pre-configured HTTP session to use for requests
    :param proxy_command: List of strings for proxy command configuration
    :param kwargs: Additional arguments passed to various functions:
        - For grpc.aio.insecure_channel/secure_channel:
            - root_certificates: Root certificates for SSL credentials
            - private_key: Private key for SSL credentials
            - certificate_chain: Certificate chain for SSL credentials
            - options: gRPC channel options
            - compression: gRPC compression method
        - For proxy configuration:
            - proxy_env: Dict of environment variables for proxy
            - proxy_timeout: Timeout for proxy connection
        - For authentication interceptors (passed to create_auth_interceptors and create_proxy_auth_interceptors):
            - auth_type: The authentication type to use ("Pkce", "ClientSecret", "ExternalCommand", "DeviceFlow")
            - command: Command to execute for ExternalCommand authentication
            - client_id: Client ID for ClientSecret authentication
            - client_secret: Client secret for ClientSecret authentication
            - client_credentials_secret: Client secret for ClientSecret authentication (alias)
            - scopes: List of scopes to request during authentication
            - audience: Audience for the token
            - http_proxy_url: HTTP proxy URL
            - verify: Whether to verify SSL certificates
            - ca_cert_path: Optional path to CA certificate file
            - header_key: Header key to use for authentication
            - redirect_uri: OAuth2 redirect URI for PKCE authentication
            - add_request_auth_code_params_to_request_access_token_params: Whether to add auth code params to token
              request
            - request_auth_code_params: Parameters to add to login URI opened in browser
            - request_access_token_params: Parameters to add when exchanging auth code for access token
            - refresh_access_token_params: Parameters to add when refreshing access token
    :return: grpc.aio.Channel with authentication interceptors configured
    """

    if not ssl_credentials:
        if insecure_skip_verify:
            ssl_credentials = bootstrap_ssl_from_server(endpoint)
        elif ca_cert_file_path:
            import aiofiles

            async with aiofiles.open(ca_cert_file_path, "rb") as f:
                # aiofiles' read() is a coroutine: it must be awaited to obtain the certificate bytes,
                # otherwise a coroutine object would be handed to grpc.ssl_channel_credentials
                st_cert = await f.read()
            ssl_credentials = grpc.ssl_channel_credentials(st_cert)
        else:
            ssl_credentials = grpc.ssl_channel_credentials()

    # Create an unauthenticated channel first to use to get the server metadata
    if insecure:
        unauthenticated_channel = grpc.aio.insecure_channel(endpoint, **kwargs)
    else:
        unauthenticated_channel = grpc.aio.secure_channel(
            target=endpoint,
            credentials=ssl_credentials,
            options=grpc_options,
            compression=compression,
        )

    from ._grpc_utils.default_metadata_interceptor import (
        DefaultMetadataStreamStreamInterceptor,
        DefaultMetadataStreamUnaryInterceptor,
        DefaultMetadataUnaryStreamInterceptor,
        DefaultMetadataUnaryUnaryInterceptor,
    )

    # Add all types of default metadata interceptors
    interceptors = [
        DefaultMetadataUnaryUnaryInterceptor(),
        DefaultMetadataUnaryStreamInterceptor(),
        DefaultMetadataStreamUnaryInterceptor(),
        DefaultMetadataStreamStreamInterceptor(),
    ]

    # Create an HTTP session if not provided so we share the same http client across the stack
    if not http_session:
        proxy_authenticator = None
        if proxy_command:
            proxy_authenticator = get_async_proxy_authenticator(
                endpoint=endpoint, proxy_command=proxy_command, **kwargs
            )

        http_session = get_async_session(
            ca_cert_file_path=ca_cert_file_path, proxy_authenticator=proxy_authenticator, **kwargs
        )

    # Get proxy auth interceptors
    proxy_auth_interceptors = create_proxy_auth_interceptors(endpoint, http_session=http_session, **kwargs)
    interceptors.extend(proxy_auth_interceptors)

    # Get auth interceptors
    auth_interceptors = create_auth_interceptors(
        endpoint=endpoint,
        in_channel=unauthenticated_channel,
        insecure=insecure,
        insecure_skip_verify=insecure_skip_verify,
        ca_cert_file_path=ca_cert_file_path,
        http_session=http_session,
        **kwargs,
    )
    interceptors.extend(auth_interceptors)

    if insecure:
        return grpc.aio.insecure_channel(endpoint, interceptors=interceptors, **kwargs)

    return grpc.aio.secure_channel(
        target=endpoint,
        credentials=ssl_credentials,
        options=grpc_options,
        compression=compression,
        interceptors=interceptors,
    )
@@ -0,0 +1,83 @@
1
+ import typing
2
+ from abc import abstractmethod
3
+
4
+ import grpc.aio
5
+ import pydantic
6
+ from flyteidl.service.auth_pb2 import OAuth2MetadataRequest, PublicClientAuthConfigRequest
7
+ from flyteidl.service.auth_pb2_grpc import AuthMetadataServiceStub
8
+
9
+ AuthType = typing.Literal["ClientSecret", "Pkce", "ExternalCommand", "DeviceFlow"]
10
+
11
+
12
class ClientConfig(pydantic.BaseModel):
    """
    Client Configuration that is needed by the authenticator
    """

    token_endpoint: str
    authorization_endpoint: str
    redirect_uri: str
    client_id: str
    device_authorization_endpoint: typing.Optional[str] = None
    # Optional so that the ``None`` default is a valid value for the field; ``with_override`` passes
    # ``scopes`` explicitly and may hand ``None`` back to the constructor when neither side sets it
    scopes: typing.Optional[typing.List[str]] = None
    header_key: str = "authorization"
    audience: typing.Optional[str] = None

    def with_override(self, other: "ClientConfig") -> "ClientConfig":
        """
        Returns a new ClientConfig instance with the values from the other instance overriding the current instance.

        A field of ``other`` only takes precedence when it is truthy; empty or ``None`` values fall back
        to the value held by this instance.

        :param other: The configuration whose set values take precedence
        :return: A new ClientConfig combining both instances
        """
        return ClientConfig(
            token_endpoint=other.token_endpoint or self.token_endpoint,
            authorization_endpoint=other.authorization_endpoint or self.authorization_endpoint,
            redirect_uri=other.redirect_uri or self.redirect_uri,
            client_id=other.client_id or self.client_id,
            device_authorization_endpoint=other.device_authorization_endpoint or self.device_authorization_endpoint,
            scopes=other.scopes or self.scopes,
            header_key=other.header_key or self.header_key,
            audience=other.audience or self.audience,
        )
40
+
41
+
42
class ClientConfigStore:
    """
    Source of the ClientConfig; concrete stores decide how the configuration is retrieved.
    """

    @abstractmethod
    async def get_client_config(self) -> ClientConfig:
        """Return the client configuration provided by this store."""
49
+
50
+
51
class StaticClientConfigStore(ClientConfigStore):
    """
    ClientConfigStore that hands back the ClientConfig it was constructed with.
    """

    def __init__(self, cfg: ClientConfig):
        # Keep the supplied configuration for later retrieval
        self._cfg = cfg

    async def get_client_config(self) -> ClientConfig:
        """Return the stored configuration unchanged."""
        return self._cfg
57
+
58
+
59
class RemoteClientConfigStore(ClientConfigStore):
    """
    This class implements the ClientConfigStore that is served by the Flyte Server, that implements AuthMetadataService
    """

    def __init__(self, unauthenticated_channel: grpc.aio.Channel):
        self._unauthenticated_channel = unauthenticated_channel

    async def get_client_config(self) -> ClientConfig:
        """
        Retrieves the ClientConfig from the given grpc.Channel assuming AuthMetadataService is available

        :return: ClientConfig assembled from the server's public client config and OAuth2 metadata
        """
        metadata_service = AuthMetadataServiceStub(self._unauthenticated_channel)
        public_client_config = await metadata_service.GetPublicClientConfig(PublicClientAuthConfigRequest())
        oauth2_metadata = await metadata_service.GetOAuth2Metadata(OAuth2MetadataRequest())

        # Only pass header_key when the server supplies one. ClientConfig.header_key is a plain ``str``
        # field, so passing ``None`` would fail validation; omitting it lets the model default apply.
        optional_fields = {}
        if public_client_config.authorization_metadata_key:
            optional_fields["header_key"] = public_client_config.authorization_metadata_key

        return ClientConfig(
            token_endpoint=oauth2_metadata.token_endpoint,
            authorization_endpoint=oauth2_metadata.authorization_endpoint,
            redirect_uri=public_client_config.redirect_uri,
            client_id=public_client_config.client_id,
            scopes=public_client_config.scopes,
            device_authorization_endpoint=oauth2_metadata.device_authorization_endpoint,
            audience=public_client_config.audience,
            **optional_fields,
        )
@@ -0,0 +1,32 @@
1
+ import textwrap
2
+
3
+
4
def get_default_success_html(endpoint: str) -> str:
    """
    Return the default HTML page reporting a successful OAuth2 authentication.

    :param endpoint: Accepted for the caller's convenience; the returned page does not depend on it
    :return: The HTML document as a string
    """
    success_page = textwrap.dedent(
        """
        <html>
        <head>
        <title>OAuth2 Authentication to Union Successful</title>
        </head>
        <body style="background:white;font-family:Arial">
        <div style="position: absolute;top:40%;left:50%;transform: translate(-50%, -50%);text-align:center;">
        <div style="margin:auto">
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 300 65" fill="currentColor"
        style="color:#fdb51e;width:360px;">
        <title>Union.ai</title>
        <path d="M32,64.8C14.4,64.8,0,51.5,0,34V3.6h17.6v41.3c0,1.9,1.1,3,3,3h23c1.9,0,3-1.1,3-3V3.6H64V34
        C64,51.5,49.6,64.8,32,64.8z M69.9,30.9v30.4h17.6V20c0-1.9,1.1-3,3-3h23c1.9,0,3,1.1,3,3v41.3H134V30.9c0-17.5-14.4-30.8-32.1-30.8
        S69.9,13.5,69.9,30.9z M236,30.9v30.4h17.6V20c0-1.9,1.1-3,3-3h23c1.9,0,3,1.1,3,3v41.3H300V30.9c0-17.5-14.4-30.8-32-30.8
        S236,13.5,236,30.9L236,30.9z M230.1,32.4c0,18.2-14.2,32.5-32.2,32.5s-32-14.3-32-32.5s14-32.1,32-32.1S230.1,14.3,230.1,32.4
        L230.1,32.4z M213.5,20.2c0-1.9-1.1-3-3-3h-24.8c-1.9,0-3,1.1-3,3v24.5c0,1.9,1.1,3,3,3h24.8c1.9,0,3-1.1,3-3V20.2z M158.9,3.6
        h-17.6v57.8h17.6V3.6z"></path>
        </svg>
        <h2>You've successfully authenticated to Union!</h2>
        <p style="font-size:20px;">Return to your terminal for next steps</p>
        </div>
        </div>
        </body>
        </html>
        """  # noqa: E501
    )
    return success_page
File without changes
@@ -0,0 +1,204 @@
1
+ import typing
2
+
3
+ import grpc.aio
4
+ from grpc import RpcError
5
+ from grpc.aio import ClientCallDetails, Metadata
6
+
7
+ from union.remote._client.auth._authenticators.base import Authenticator
8
+
9
+
10
class _BaseAuthInterceptor:
    """
    Base class for all auth interceptors that provides common authentication functionality.

    The authenticator is built lazily: the factory passed to the constructor is only called the
    first time the ``authenticator`` property is read, and its result is cached afterwards.
    """

    def __init__(self, get_authenticator: typing.Callable[[], Authenticator]):
        """
        :param get_authenticator: Zero-argument factory that returns the authenticator to use.
        """
        self._get_authenticator = get_authenticator
        self._authenticator: typing.Optional[Authenticator] = None

    @property
    def authenticator(self) -> Authenticator:
        """Return the authenticator, creating and caching it on first access."""
        if self._authenticator is None:
            self._authenticator = self._get_authenticator()
        return self._authenticator

    async def _call_details_with_auth_metadata(
        self, client_call_details: grpc.aio.ClientCallDetails
    ) -> typing.Tuple[grpc.aio.ClientCallDetails, typing.Optional[str]]:
        """
        Returns new ClientCallDetails with authentication metadata added.

        This method retrieves authentication metadata from the authenticator and adds it to a copy of
        the client call details' metadata. If no authentication metadata is available, the original
        client call details are returned unchanged.

        :param client_call_details: The original client call details containing method, timeout, metadata,
                                   credentials, and wait_for_ready settings
        :return: A tuple of (call details carrying the authentication metadata, ``creds_id`` of the
                 credentials that were used or ``None`` when no authentication metadata was available)
        """
        auth_metadata = await self.authenticator.get_grpc_call_auth_metadata()
        if not auth_metadata:
            return client_call_details, None

        # Build a fresh Metadata instead of calling .add() on the caller's object: the interceptors
        # invoke this method a second time with the same client_call_details when retrying, and an
        # in-place add would leave the stale auth header next to the refreshed one.
        metadata = Metadata(*(client_call_details.metadata or ()))
        for k, v in auth_metadata.pairs.items():
            metadata.add(k, v)

        return client_call_details._replace(metadata=metadata), auth_metadata.creds_id
47
+
48
+
49
class AuthUnaryUnaryInterceptor(_BaseAuthInterceptor, grpc.aio.UnaryUnaryClientInterceptor):
    """
    Auth interceptor for unary-unary RPCs: attaches authentication metadata to every call.
    """

    async def intercept_unary_unary(
        self,
        continuation: typing.Callable,
        client_call_details: ClientCallDetails,
        request: typing.Any,
    ):
        """
        Run a unary-unary call with auth metadata attached, retrying once after a credential refresh.

        The call is first attempted with the currently available authentication metadata. When it fails
        with an UNAUTHENTICATED or UNKNOWN status code, the credentials identified by the ``creds_id``
        of the first attempt are refreshed and the call is issued a second time with new metadata.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
                                   and wait_for_ready
        :param request: The request message to be sent to the server
        :return: The response from the RPC call
        :raises: grpc.aio.AioRpcError if the call fails with any other status code, or if the retry fails
        """
        retryable_codes = (grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.UNKNOWN)

        authed_details, creds_id = await self._call_details_with_auth_metadata(client_call_details)
        try:
            first_call = await continuation(authed_details, request)
            return await first_call
        except grpc.aio.AioRpcError as rpc_error:
            if rpc_error.code() not in retryable_codes:
                raise

            # Auth-related failure: refresh the credentials used above and try exactly once more.
            await self.authenticator.refresh_credentials(creds_id=creds_id)
            authed_details, _ = await self._call_details_with_auth_metadata(client_call_details)
            second_call = await continuation(authed_details, request)
            return await second_call
84
+
85
+
86
class AuthUnaryStreamInterceptor(_BaseAuthInterceptor, grpc.aio.UnaryStreamClientInterceptor):
    """
    Interceptor for unary-stream RPC calls that adds authentication metadata.
    """

    async def intercept_unary_stream(
        self, continuation: typing.Callable, client_call_details: grpc.aio.ClientCallDetails, request: typing.Any
    ):
        """
        Intercepts unary-stream calls and adds auth metadata if available.

        This method first adds authentication metadata to the client call details, then returns an async
        iterator that performs the RPC call when iterated. If iteration fails with an UNAUTHENTICATED or
        UNKNOWN status code, it refreshes the credentials and restarts the call with the new
        authentication metadata.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
                                   and wait_for_ready
        :param request: The request message to be sent to the server
        :return: A stream of responses from the RPC call after successful authentication
        :raises: grpc.aio.AioRpcError if the call fails for reasons other than authentication
        """
        call_details, creds_id = await self._call_details_with_auth_metadata(client_call_details)

        async def response_iterator() -> typing.AsyncIterator[typing.Any]:
            call = await continuation(call_details, request)
            try:
                async for response in call:
                    yield response
            except grpc.aio.AioRpcError as e:
                # Only auth-related failures are retried; anything else propagates untouched.
                if e.code() not in (grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.UNKNOWN):
                    raise
                await self.authenticator.refresh_credentials(creds_id=creds_id)
                updated_call_details, _ = await self._call_details_with_auth_metadata(client_call_details)
                # The call is restarted from scratch; once this loop completes the generator simply
                # ends, so the original error is not re-raised after a successful retry.
                async for response in await continuation(updated_call_details, request):
                    yield response

        return response_iterator()
124
+
125
+
126
class AuthStreamUnaryInterceptor(_BaseAuthInterceptor, grpc.aio.StreamUnaryClientInterceptor):
    """
    Auth interceptor for stream-unary RPCs: attaches authentication metadata to every call.
    """

    async def intercept_stream_unary(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: typing.Any,
    ):
        """
        Run a stream-unary call with auth metadata attached, retrying once after a credential refresh.

        The call is first attempted with the currently available authentication metadata. When it fails
        with an UNAUTHENTICATED or UNKNOWN status code, the credentials identified by the ``creds_id``
        of the first attempt are refreshed and the call is issued again. The retry is handed the same
        ``request_iterator`` object as the first attempt.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
                                   and wait_for_ready
        :param request_iterator: An iterator of request messages to be sent to the server
        :return: The response from the RPC call
        :raises: grpc.aio.AioRpcError if the call fails with any other status code, or if the retry fails
        """
        retryable_codes = (grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.UNKNOWN)

        authed_details, creds_id = await self._call_details_with_auth_metadata(client_call_details)
        try:
            return await (await continuation(authed_details, request_iterator))
        except grpc.aio.AioRpcError as rpc_error:
            if rpc_error.code() not in retryable_codes:
                raise

            # Auth-related failure: refresh the credentials used above and try exactly once more.
            await self.authenticator.refresh_credentials(creds_id=creds_id)
            authed_details, _ = await self._call_details_with_auth_metadata(client_call_details)
            return await (await continuation(authed_details, request_iterator))
162
+
163
+
164
class AuthStreamStreamInterceptor(_BaseAuthInterceptor, grpc.aio.StreamStreamClientInterceptor):
    """
    Interceptor for stream-stream RPC calls that adds authentication metadata.
    """

    async def intercept_stream_stream(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: typing.Any,
    ):
        """
        Intercepts stream-stream calls and adds auth metadata if available.

        This method first adds authentication metadata to the client call details and starts the RPC call,
        then returns an async iterator over its responses. If iteration fails with an UNAUTHENTICATED or
        UNKNOWN status code, it refreshes the credentials and restarts the call with the new
        authentication metadata. The retry is handed the same ``request_iterator`` object as the first
        attempt.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
                                   and wait_for_ready
        :param request_iterator: An iterator of request messages to be sent to the server
        :return: A stream of responses from the RPC call after successful authentication
        :raises: grpc.aio.AioRpcError if the call fails for reasons other than authentication
        """
        call_details, creds_id = await self._call_details_with_auth_metadata(client_call_details)
        # A streaming call object is iterated, not awaited, so only the continuation itself is awaited.
        call = await continuation(call_details, request_iterator)

        async def response_iterator() -> typing.AsyncIterator[typing.Any]:
            try:
                async for response in call:
                    yield response
            except grpc.aio.AioRpcError as e:
                # RPC errors of a streaming call surface while iterating, so the retry lives here.
                if e.code() not in (grpc.StatusCode.UNAUTHENTICATED, grpc.StatusCode.UNKNOWN):
                    raise
                await self.authenticator.refresh_credentials(creds_id=creds_id)
                updated_call_details, _ = await self._call_details_with_auth_metadata(client_call_details)
                async for response in await continuation(updated_call_details, request_iterator):
                    yield response

        return response_iterator()
201
+
202
+
203
# For backward compatibility, keep the original class name available as an alias of the
# unary-unary interceptor.
AuthUnaryInterceptor = AuthUnaryUnaryInterceptor
@@ -0,0 +1,144 @@
1
+ import typing
2
+
3
+ import grpc.aio
4
+ from grpc.aio import ClientCallDetails, Metadata
5
+
6
+ _default_metadata = {"accept": "application/grpc"}
7
+
8
+
9
class _BaseDefaultMetadataInterceptor:
    """
    Base class for all default metadata interceptors that provides common functionality.
    """

    async def _inject_default_metadata(
        self, call_details: grpc.aio.ClientCallDetails
    ) -> grpc.aio.ClientCallDetails:
        """
        Injects default metadata into the client call details.

        This method adds all key-value pairs from the default metadata dictionary to a copy of the
        client call details metadata. If the client call details don't have metadata, a new Metadata
        object is created. The metadata object passed in by the caller is never modified.

        :param call_details: The client call details to inject metadata into
        :return: A new ClientCallDetails object with the injected metadata
        """
        # Copy instead of calling .add() on the caller's object, so a Metadata instance that is
        # reused across calls does not collect one more set of default entries per call.
        metadata = Metadata(*(call_details.metadata or ()))
        for k, v in _default_metadata.items():
            metadata.add(k, v)

        return ClientCallDetails(
            method=call_details.method,
            timeout=call_details.timeout,
            metadata=metadata,
            credentials=call_details.credentials,
            wait_for_ready=call_details.wait_for_ready,
        )
37
+
38
+
39
class DefaultMetadataUnaryUnaryInterceptor(_BaseDefaultMetadataInterceptor, grpc.aio.UnaryUnaryClientInterceptor):
    """
    Default-metadata interceptor for unary-unary RPCs.
    """

    async def intercept_unary_unary(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request: typing.Any,
    ):
        """
        Add the default metadata to a unary-unary call, run it, and return its response.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
                                   and wait_for_ready
        :param request: The request message to be sent to the server
        :return: The response from the RPC call
        """
        details_with_defaults = await self._inject_default_metadata(client_call_details)
        call = await continuation(details_with_defaults, request)
        # Awaiting the call object yields the single response message.
        return await call
63
+
64
+
65
class DefaultMetadataUnaryStreamInterceptor(_BaseDefaultMetadataInterceptor, grpc.aio.UnaryStreamClientInterceptor):
    """
    Default-metadata interceptor for unary-stream RPCs.
    """

    async def intercept_unary_stream(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request: typing.Any,
    ):
        """
        Add the default metadata to a unary-stream call and hand back whatever the continuation returns.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
                                   and wait_for_ready
        :param request: The request message to be sent to the server
        :return: A stream of responses from the RPC call
        """
        details_with_defaults = await self._inject_default_metadata(client_call_details)
        response_stream = await continuation(details_with_defaults, request)
        return response_stream
89
+
90
+
91
class DefaultMetadataStreamUnaryInterceptor(_BaseDefaultMetadataInterceptor, grpc.aio.StreamUnaryClientInterceptor):
    """
    Default-metadata interceptor for stream-unary RPCs.
    """

    async def intercept_stream_unary(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: typing.Any,
    ):
        """
        Add the default metadata to a stream-unary call, run it, and return its response.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials,
                                   and wait_for_ready
        :param request_iterator: An iterator of request messages to be sent to the server
        :return: The response from the RPC call
        """
        details_with_defaults = await self._inject_default_metadata(client_call_details)
        call = await continuation(details_with_defaults, request_iterator)
        # Awaiting the call object yields the single response message.
        return await call
115
+
116
+
117
class DefaultMetadataStreamStreamInterceptor(_BaseDefaultMetadataInterceptor, grpc.aio.StreamStreamClientInterceptor):
    """
    Interceptor for stream-stream RPC calls that adds default metadata.
    """

    async def intercept_stream_stream(
        self,
        continuation: typing.Callable,
        client_call_details: grpc.aio.ClientCallDetails,
        request_iterator: typing.Any,
    ):
        """
        Intercepts stream-stream calls and injects default metadata.

        This method adds default metadata to the client call details before continuing the RPC call chain.

        :param continuation: Function to continue the RPC call chain with the updated call details
        :param client_call_details: Details about the RPC call including method, timeout, metadata, credentials, and
                                   wait_for_ready
        :param request_iterator: An iterator of request messages to be sent to the server
        :return: A stream of responses from the RPC call
        """
        updated_call_details = await self._inject_default_metadata(client_call_details)
        # The continuation result is a response stream to iterate, not an awaitable; return it as-is,
        # the same way the unary-stream interceptor does.
        return await continuation(updated_call_details, request_iterator)
141
+
142
+
143
# For backward compatibility, keep the original class name available as an alias of the
# unary-unary interceptor.
DefaultMetadataInterceptor = DefaultMetadataUnaryUnaryInterceptor