flyte 0.0.1b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (390)
  1. flyte/__init__.py +62 -0
  2. flyte/_api_commons.py +3 -0
  3. flyte/_bin/__init__.py +0 -0
  4. flyte/_bin/runtime.py +126 -0
  5. flyte/_build.py +25 -0
  6. flyte/_cache/__init__.py +12 -0
  7. flyte/_cache/cache.py +146 -0
  8. flyte/_cache/defaults.py +9 -0
  9. flyte/_cache/policy_function_body.py +42 -0
  10. flyte/_cli/__init__.py +0 -0
  11. flyte/_cli/_common.py +287 -0
  12. flyte/_cli/_create.py +42 -0
  13. flyte/_cli/_delete.py +23 -0
  14. flyte/_cli/_deploy.py +140 -0
  15. flyte/_cli/_get.py +235 -0
  16. flyte/_cli/_run.py +152 -0
  17. flyte/_cli/main.py +72 -0
  18. flyte/_code_bundle/__init__.py +8 -0
  19. flyte/_code_bundle/_ignore.py +113 -0
  20. flyte/_code_bundle/_packaging.py +187 -0
  21. flyte/_code_bundle/_utils.py +339 -0
  22. flyte/_code_bundle/bundle.py +178 -0
  23. flyte/_context.py +146 -0
  24. flyte/_datastructures.py +342 -0
  25. flyte/_deploy.py +202 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +43 -0
  29. flyte/_group.py +31 -0
  30. flyte/_hash.py +23 -0
  31. flyte/_image.py +760 -0
  32. flyte/_initialize.py +634 -0
  33. flyte/_interface.py +84 -0
  34. flyte/_internal/__init__.py +3 -0
  35. flyte/_internal/controllers/__init__.py +115 -0
  36. flyte/_internal/controllers/_local_controller.py +118 -0
  37. flyte/_internal/controllers/_trace.py +40 -0
  38. flyte/_internal/controllers/pbhash.py +39 -0
  39. flyte/_internal/controllers/remote/__init__.py +40 -0
  40. flyte/_internal/controllers/remote/_action.py +141 -0
  41. flyte/_internal/controllers/remote/_client.py +43 -0
  42. flyte/_internal/controllers/remote/_controller.py +361 -0
  43. flyte/_internal/controllers/remote/_core.py +402 -0
  44. flyte/_internal/controllers/remote/_informer.py +361 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +11 -0
  47. flyte/_internal/imagebuild/docker_builder.py +416 -0
  48. flyte/_internal/imagebuild/image_builder.py +241 -0
  49. flyte/_internal/imagebuild/remote_builder.py +0 -0
  50. flyte/_internal/resolvers/__init__.py +0 -0
  51. flyte/_internal/resolvers/_task_module.py +54 -0
  52. flyte/_internal/resolvers/common.py +31 -0
  53. flyte/_internal/resolvers/default.py +28 -0
  54. flyte/_internal/runtime/__init__.py +0 -0
  55. flyte/_internal/runtime/convert.py +199 -0
  56. flyte/_internal/runtime/entrypoints.py +135 -0
  57. flyte/_internal/runtime/io.py +136 -0
  58. flyte/_internal/runtime/resources_serde.py +138 -0
  59. flyte/_internal/runtime/task_serde.py +210 -0
  60. flyte/_internal/runtime/taskrunner.py +190 -0
  61. flyte/_internal/runtime/types_serde.py +54 -0
  62. flyte/_logging.py +124 -0
  63. flyte/_protos/__init__.py +0 -0
  64. flyte/_protos/common/authorization_pb2.py +66 -0
  65. flyte/_protos/common/authorization_pb2.pyi +108 -0
  66. flyte/_protos/common/authorization_pb2_grpc.py +4 -0
  67. flyte/_protos/common/identifier_pb2.py +71 -0
  68. flyte/_protos/common/identifier_pb2.pyi +82 -0
  69. flyte/_protos/common/identifier_pb2_grpc.py +4 -0
  70. flyte/_protos/common/identity_pb2.py +48 -0
  71. flyte/_protos/common/identity_pb2.pyi +72 -0
  72. flyte/_protos/common/identity_pb2_grpc.py +4 -0
  73. flyte/_protos/common/list_pb2.py +36 -0
  74. flyte/_protos/common/list_pb2.pyi +69 -0
  75. flyte/_protos/common/list_pb2_grpc.py +4 -0
  76. flyte/_protos/common/policy_pb2.py +37 -0
  77. flyte/_protos/common/policy_pb2.pyi +27 -0
  78. flyte/_protos/common/policy_pb2_grpc.py +4 -0
  79. flyte/_protos/common/role_pb2.py +37 -0
  80. flyte/_protos/common/role_pb2.pyi +53 -0
  81. flyte/_protos/common/role_pb2_grpc.py +4 -0
  82. flyte/_protos/common/runtime_version_pb2.py +28 -0
  83. flyte/_protos/common/runtime_version_pb2.pyi +24 -0
  84. flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
  85. flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
  86. flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
  87. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  88. flyte/_protos/secret/definition_pb2.py +49 -0
  89. flyte/_protos/secret/definition_pb2.pyi +93 -0
  90. flyte/_protos/secret/definition_pb2_grpc.py +4 -0
  91. flyte/_protos/secret/payload_pb2.py +62 -0
  92. flyte/_protos/secret/payload_pb2.pyi +94 -0
  93. flyte/_protos/secret/payload_pb2_grpc.py +4 -0
  94. flyte/_protos/secret/secret_pb2.py +38 -0
  95. flyte/_protos/secret/secret_pb2.pyi +6 -0
  96. flyte/_protos/secret/secret_pb2_grpc.py +198 -0
  97. flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
  98. flyte/_protos/validate/validate/validate_pb2.py +76 -0
  99. flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
  100. flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  101. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  102. flyte/_protos/workflow/queue_service_pb2.py +106 -0
  103. flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
  104. flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  105. flyte/_protos/workflow/run_definition_pb2.py +128 -0
  106. flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
  107. flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  108. flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
  109. flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  110. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  111. flyte/_protos/workflow/run_service_pb2.py +133 -0
  112. flyte/_protos/workflow/run_service_pb2.pyi +175 -0
  113. flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
  114. flyte/_protos/workflow/state_service_pb2.py +58 -0
  115. flyte/_protos/workflow/state_service_pb2.pyi +71 -0
  116. flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
  117. flyte/_protos/workflow/task_definition_pb2.py +72 -0
  118. flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
  119. flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  120. flyte/_protos/workflow/task_service_pb2.py +44 -0
  121. flyte/_protos/workflow/task_service_pb2.pyi +31 -0
  122. flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
  123. flyte/_resources.py +226 -0
  124. flyte/_retry.py +32 -0
  125. flyte/_reusable_environment.py +25 -0
  126. flyte/_run.py +411 -0
  127. flyte/_secret.py +61 -0
  128. flyte/_task.py +367 -0
  129. flyte/_task_environment.py +200 -0
  130. flyte/_timeout.py +47 -0
  131. flyte/_tools.py +27 -0
  132. flyte/_trace.py +128 -0
  133. flyte/_utils/__init__.py +20 -0
  134. flyte/_utils/asyn.py +119 -0
  135. flyte/_utils/coro_management.py +25 -0
  136. flyte/_utils/file_handling.py +72 -0
  137. flyte/_utils/helpers.py +108 -0
  138. flyte/_utils/lazy_module.py +54 -0
  139. flyte/_utils/uv_script_parser.py +49 -0
  140. flyte/_version.py +21 -0
  141. flyte/connectors/__init__.py +0 -0
  142. flyte/errors.py +143 -0
  143. flyte/extras/__init__.py +5 -0
  144. flyte/extras/_container.py +273 -0
  145. flyte/io/__init__.py +11 -0
  146. flyte/io/_dataframe.py +0 -0
  147. flyte/io/_dir.py +448 -0
  148. flyte/io/_file.py +468 -0
  149. flyte/io/pickle/__init__.py +0 -0
  150. flyte/io/pickle/transformer.py +117 -0
  151. flyte/io/structured_dataset/__init__.py +129 -0
  152. flyte/io/structured_dataset/basic_dfs.py +219 -0
  153. flyte/io/structured_dataset/structured_dataset.py +1061 -0
  154. flyte/py.typed +0 -0
  155. flyte/remote/__init__.py +25 -0
  156. flyte/remote/_client/__init__.py +0 -0
  157. flyte/remote/_client/_protocols.py +131 -0
  158. flyte/remote/_client/auth/__init__.py +12 -0
  159. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  160. flyte/remote/_client/auth/_authenticators/base.py +397 -0
  161. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  162. flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
  163. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  164. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  165. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  166. flyte/remote/_client/auth/_channel.py +184 -0
  167. flyte/remote/_client/auth/_client_config.py +83 -0
  168. flyte/remote/_client/auth/_default_html.py +32 -0
  169. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  170. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  171. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  172. flyte/remote/_client/auth/_keyring.py +143 -0
  173. flyte/remote/_client/auth/_token_client.py +260 -0
  174. flyte/remote/_client/auth/errors.py +16 -0
  175. flyte/remote/_client/controlplane.py +95 -0
  176. flyte/remote/_console.py +18 -0
  177. flyte/remote/_data.py +155 -0
  178. flyte/remote/_logs.py +116 -0
  179. flyte/remote/_project.py +86 -0
  180. flyte/remote/_run.py +873 -0
  181. flyte/remote/_secret.py +132 -0
  182. flyte/remote/_task.py +227 -0
  183. flyte/report/__init__.py +3 -0
  184. flyte/report/_report.py +178 -0
  185. flyte/report/_template.html +124 -0
  186. flyte/storage/__init__.py +24 -0
  187. flyte/storage/_remote_fs.py +34 -0
  188. flyte/storage/_storage.py +251 -0
  189. flyte/storage/_utils.py +5 -0
  190. flyte/types/__init__.py +13 -0
  191. flyte/types/_interface.py +25 -0
  192. flyte/types/_renderer.py +162 -0
  193. flyte/types/_string_literals.py +120 -0
  194. flyte/types/_type_engine.py +2210 -0
  195. flyte/types/_utils.py +80 -0
  196. flyte-0.0.1b0.dist-info/METADATA +179 -0
  197. flyte-0.0.1b0.dist-info/RECORD +390 -0
  198. flyte-0.0.1b0.dist-info/WHEEL +5 -0
  199. flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
  200. flyte-0.0.1b0.dist-info/top_level.txt +1 -0
  201. union/__init__.py +54 -0
  202. union/_api_commons.py +3 -0
  203. union/_bin/__init__.py +0 -0
  204. union/_bin/runtime.py +113 -0
  205. union/_build.py +25 -0
  206. union/_cache/__init__.py +12 -0
  207. union/_cache/cache.py +141 -0
  208. union/_cache/defaults.py +9 -0
  209. union/_cache/policy_function_body.py +42 -0
  210. union/_cli/__init__.py +0 -0
  211. union/_cli/_common.py +263 -0
  212. union/_cli/_create.py +40 -0
  213. union/_cli/_delete.py +23 -0
  214. union/_cli/_deploy.py +120 -0
  215. union/_cli/_get.py +162 -0
  216. union/_cli/_params.py +579 -0
  217. union/_cli/_run.py +150 -0
  218. union/_cli/main.py +72 -0
  219. union/_code_bundle/__init__.py +8 -0
  220. union/_code_bundle/_ignore.py +113 -0
  221. union/_code_bundle/_packaging.py +187 -0
  222. union/_code_bundle/_utils.py +342 -0
  223. union/_code_bundle/bundle.py +176 -0
  224. union/_context.py +146 -0
  225. union/_datastructures.py +295 -0
  226. union/_deploy.py +185 -0
  227. union/_doc.py +29 -0
  228. union/_docstring.py +26 -0
  229. union/_environment.py +43 -0
  230. union/_group.py +31 -0
  231. union/_hash.py +23 -0
  232. union/_image.py +760 -0
  233. union/_initialize.py +585 -0
  234. union/_interface.py +84 -0
  235. union/_internal/__init__.py +3 -0
  236. union/_internal/controllers/__init__.py +77 -0
  237. union/_internal/controllers/_local_controller.py +77 -0
  238. union/_internal/controllers/pbhash.py +39 -0
  239. union/_internal/controllers/remote/__init__.py +40 -0
  240. union/_internal/controllers/remote/_action.py +131 -0
  241. union/_internal/controllers/remote/_client.py +43 -0
  242. union/_internal/controllers/remote/_controller.py +169 -0
  243. union/_internal/controllers/remote/_core.py +341 -0
  244. union/_internal/controllers/remote/_informer.py +260 -0
  245. union/_internal/controllers/remote/_service_protocol.py +44 -0
  246. union/_internal/imagebuild/__init__.py +11 -0
  247. union/_internal/imagebuild/docker_builder.py +416 -0
  248. union/_internal/imagebuild/image_builder.py +243 -0
  249. union/_internal/imagebuild/remote_builder.py +0 -0
  250. union/_internal/resolvers/__init__.py +0 -0
  251. union/_internal/resolvers/_task_module.py +31 -0
  252. union/_internal/resolvers/common.py +24 -0
  253. union/_internal/resolvers/default.py +27 -0
  254. union/_internal/runtime/__init__.py +0 -0
  255. union/_internal/runtime/convert.py +163 -0
  256. union/_internal/runtime/entrypoints.py +121 -0
  257. union/_internal/runtime/io.py +136 -0
  258. union/_internal/runtime/resources_serde.py +134 -0
  259. union/_internal/runtime/task_serde.py +202 -0
  260. union/_internal/runtime/taskrunner.py +179 -0
  261. union/_internal/runtime/types_serde.py +53 -0
  262. union/_logging.py +124 -0
  263. union/_protos/__init__.py +0 -0
  264. union/_protos/common/authorization_pb2.py +66 -0
  265. union/_protos/common/authorization_pb2.pyi +106 -0
  266. union/_protos/common/authorization_pb2_grpc.py +4 -0
  267. union/_protos/common/identifier_pb2.py +71 -0
  268. union/_protos/common/identifier_pb2.pyi +82 -0
  269. union/_protos/common/identifier_pb2_grpc.py +4 -0
  270. union/_protos/common/identity_pb2.py +48 -0
  271. union/_protos/common/identity_pb2.pyi +72 -0
  272. union/_protos/common/identity_pb2_grpc.py +4 -0
  273. union/_protos/common/list_pb2.py +36 -0
  274. union/_protos/common/list_pb2.pyi +69 -0
  275. union/_protos/common/list_pb2_grpc.py +4 -0
  276. union/_protos/common/policy_pb2.py +37 -0
  277. union/_protos/common/policy_pb2.pyi +27 -0
  278. union/_protos/common/policy_pb2_grpc.py +4 -0
  279. union/_protos/common/role_pb2.py +37 -0
  280. union/_protos/common/role_pb2.pyi +51 -0
  281. union/_protos/common/role_pb2_grpc.py +4 -0
  282. union/_protos/common/runtime_version_pb2.py +28 -0
  283. union/_protos/common/runtime_version_pb2.pyi +24 -0
  284. union/_protos/common/runtime_version_pb2_grpc.py +4 -0
  285. union/_protos/logs/dataplane/payload_pb2.py +96 -0
  286. union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
  287. union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  288. union/_protos/secret/definition_pb2.py +49 -0
  289. union/_protos/secret/definition_pb2.pyi +93 -0
  290. union/_protos/secret/definition_pb2_grpc.py +4 -0
  291. union/_protos/secret/payload_pb2.py +62 -0
  292. union/_protos/secret/payload_pb2.pyi +94 -0
  293. union/_protos/secret/payload_pb2_grpc.py +4 -0
  294. union/_protos/secret/secret_pb2.py +38 -0
  295. union/_protos/secret/secret_pb2.pyi +6 -0
  296. union/_protos/secret/secret_pb2_grpc.py +198 -0
  297. union/_protos/validate/validate/validate_pb2.py +76 -0
  298. union/_protos/workflow/node_execution_service_pb2.py +26 -0
  299. union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  300. union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  301. union/_protos/workflow/queue_service_pb2.py +75 -0
  302. union/_protos/workflow/queue_service_pb2.pyi +103 -0
  303. union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  304. union/_protos/workflow/run_definition_pb2.py +100 -0
  305. union/_protos/workflow/run_definition_pb2.pyi +256 -0
  306. union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  307. union/_protos/workflow/run_logs_service_pb2.py +41 -0
  308. union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  309. union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  310. union/_protos/workflow/run_service_pb2.py +133 -0
  311. union/_protos/workflow/run_service_pb2.pyi +173 -0
  312. union/_protos/workflow/run_service_pb2_grpc.py +412 -0
  313. union/_protos/workflow/state_service_pb2.py +58 -0
  314. union/_protos/workflow/state_service_pb2.pyi +69 -0
  315. union/_protos/workflow/state_service_pb2_grpc.py +138 -0
  316. union/_protos/workflow/task_definition_pb2.py +72 -0
  317. union/_protos/workflow/task_definition_pb2.pyi +65 -0
  318. union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  319. union/_protos/workflow/task_service_pb2.py +44 -0
  320. union/_protos/workflow/task_service_pb2.pyi +31 -0
  321. union/_protos/workflow/task_service_pb2_grpc.py +104 -0
  322. union/_resources.py +226 -0
  323. union/_retry.py +32 -0
  324. union/_reusable_environment.py +25 -0
  325. union/_run.py +374 -0
  326. union/_secret.py +61 -0
  327. union/_task.py +354 -0
  328. union/_task_environment.py +186 -0
  329. union/_timeout.py +47 -0
  330. union/_tools.py +27 -0
  331. union/_utils/__init__.py +11 -0
  332. union/_utils/asyn.py +119 -0
  333. union/_utils/file_handling.py +71 -0
  334. union/_utils/helpers.py +46 -0
  335. union/_utils/lazy_module.py +54 -0
  336. union/_utils/uv_script_parser.py +49 -0
  337. union/_version.py +21 -0
  338. union/connectors/__init__.py +0 -0
  339. union/errors.py +128 -0
  340. union/extras/__init__.py +5 -0
  341. union/extras/_container.py +263 -0
  342. union/io/__init__.py +11 -0
  343. union/io/_dataframe.py +0 -0
  344. union/io/_dir.py +425 -0
  345. union/io/_file.py +418 -0
  346. union/io/pickle/__init__.py +0 -0
  347. union/io/pickle/transformer.py +117 -0
  348. union/io/structured_dataset/__init__.py +122 -0
  349. union/io/structured_dataset/basic_dfs.py +219 -0
  350. union/io/structured_dataset/structured_dataset.py +1057 -0
  351. union/py.typed +0 -0
  352. union/remote/__init__.py +23 -0
  353. union/remote/_client/__init__.py +0 -0
  354. union/remote/_client/_protocols.py +129 -0
  355. union/remote/_client/auth/__init__.py +12 -0
  356. union/remote/_client/auth/_authenticators/__init__.py +0 -0
  357. union/remote/_client/auth/_authenticators/base.py +391 -0
  358. union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  359. union/remote/_client/auth/_authenticators/device_code.py +120 -0
  360. union/remote/_client/auth/_authenticators/external_command.py +77 -0
  361. union/remote/_client/auth/_authenticators/factory.py +200 -0
  362. union/remote/_client/auth/_authenticators/pkce.py +515 -0
  363. union/remote/_client/auth/_channel.py +184 -0
  364. union/remote/_client/auth/_client_config.py +83 -0
  365. union/remote/_client/auth/_default_html.py +32 -0
  366. union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  367. union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
  368. union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
  369. union/remote/_client/auth/_keyring.py +154 -0
  370. union/remote/_client/auth/_token_client.py +258 -0
  371. union/remote/_client/auth/errors.py +16 -0
  372. union/remote/_client/controlplane.py +86 -0
  373. union/remote/_data.py +149 -0
  374. union/remote/_logs.py +74 -0
  375. union/remote/_project.py +86 -0
  376. union/remote/_run.py +820 -0
  377. union/remote/_secret.py +132 -0
  378. union/remote/_task.py +193 -0
  379. union/report/__init__.py +3 -0
  380. union/report/_report.py +178 -0
  381. union/report/_template.html +124 -0
  382. union/storage/__init__.py +24 -0
  383. union/storage/_remote_fs.py +34 -0
  384. union/storage/_storage.py +247 -0
  385. union/storage/_utils.py +5 -0
  386. union/types/__init__.py +11 -0
  387. union/types/_renderer.py +162 -0
  388. union/types/_string_literals.py +120 -0
  389. union/types/_type_engine.py +2131 -0
  390. union/types/_utils.py +80 -0
@@ -0,0 +1,154 @@
1
+ import hashlib # Added import for hashing
2
+ import typing
3
+ from urllib.parse import urlparse # Added import
4
+
5
+ import keyring
6
+ import pydantic
7
+ from keyring.errors import NoKeyringError, PasswordDeleteError
8
+
9
+ from union._logging import logger
10
+
11
+
12
def strip_scheme(url: str) -> str:
    """
    Return ``url`` without its URL scheme.

    Examples:
      - dns:///foo.com -> foo.com
      - https://foo.com -> foo.com
      - https://foo.com/blah -> foo.com/blah

    When the URL has a network location, only that location and the path are kept, so the query and fragment
    are dropped. Inputs without a network location (e.g. ``foo.com``) are returned unchanged.
    """
    parts = urlparse(url)
    if parts.scheme == "dns":
        # For dns:///host targets the host ends up in the path, after the leading slashes.
        return parts.path.lstrip("/")
    if not parts.netloc:
        return url
    return parts.netloc + parts.path
24
+
25
+
26
class Credentials(pydantic.BaseModel):
    """
    Bundle of OAuth tokens issued for a single endpoint.

    ``for_endpoint`` is normalized with ``strip_scheme`` during validation. ``id`` is derived from the tokens
    by ``compute_id``.
    """

    access_token: str
    for_endpoint: str = "flyte-default"
    # Derived by compute_id from access_token (falling back to id_token); a caller-supplied value is
    # overwritten whenever either token is non-empty.
    id: str | None = ""
    refresh_token: str | None = None
    expires_in: int | None = None
    id_token: str | None = None

    @pydantic.field_validator("for_endpoint", mode="after")
    @classmethod
    def validate_endpoint(cls, v: str) -> str:
        """Normalize the endpoint by stripping any URL scheme (e.g. ``dns:///`` or ``https://``)."""
        return strip_scheme(v)

    @pydantic.model_validator(mode="after")
    def compute_id(self) -> "Credentials":
        """
        Compute ``id`` as the MD5 hex digest of the access token, or of the ID token if the access token is
        empty. ``id`` is left untouched when both are empty.

        :return: This instance, as pydantic "after" model validators must return the model
        """
        token = self.access_token or self.id_token
        if token:
            # The digest is only an identifier, not a security control. Declaring that keeps MD5 usable on
            # FIPS-enabled Python builds, where a plain hashlib.md5() call raises.
            self.id = hashlib.md5(token.encode(), usedforsecurity=False).hexdigest()
        return self
51
+
52
+
53
class KeyringStore:
    """
    Static helpers for persisting OAuth tokens in the system keyring.

    Each token type is a separate keyring entry. The entry is keyed by the endpoint (URL scheme stripped) as
    the service name, together with one of the ``_*_key`` names below.
    """

    _access_token_key = "access_token"
    _refresh_token_key = "refresh_token"
    _id_token_key = "id_token"

    @staticmethod
    def store(credentials: Credentials) -> Credentials:
        """
        Stores the provided credentials in the system keyring.

        The access token is always written. The refresh token and ID token are written only when present.
        NOTE(review): an absent refresh/ID token does not clear a value stored earlier for the same endpoint.

        :param credentials: The credentials object containing tokens to store
        :return: The same credentials object that was passed in
        :raises: Nothing for a missing keyring backend; NoKeyringError is logged at debug level and ignored
        """
        try:
            if credentials.refresh_token:
                keyring.set_password(
                    credentials.for_endpoint,
                    KeyringStore._refresh_token_key,
                    credentials.refresh_token,
                )
            keyring.set_password(
                credentials.for_endpoint,
                KeyringStore._access_token_key,
                credentials.access_token,
            )
            if credentials.id_token:
                keyring.set_password(
                    credentials.for_endpoint,
                    KeyringStore._id_token_key,
                    credentials.id_token,
                )
        except NoKeyringError as e:
            logger.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
        return credentials

    @staticmethod
    def retrieve(for_endpoint: str) -> typing.Optional[Credentials]:
        """
        Retrieves stored credentials from the system keyring for the specified endpoint.

        The endpoint URL scheme is stripped before lookup, matching how ``store`` keys the entries.
        Credentials are returned when at least an access token or an ID token is cached.

        :param for_endpoint: The endpoint URL to retrieve credentials for
        :return: A Credentials object containing the retrieved tokens. Returns None if neither an access token
            nor an ID token is stored, or if the system keyring is not available.
        """
        for_endpoint = strip_scheme(for_endpoint)
        try:
            refresh_token = keyring.get_password(for_endpoint, KeyringStore._refresh_token_key)
            access_token = keyring.get_password(for_endpoint, KeyringStore._access_token_key)
            id_token = keyring.get_password(for_endpoint, KeyringStore._id_token_key)
        except NoKeyringError as e:
            logger.debug(f"KeyRing not available, tokens will not be cached. Error: {e}")
            return None

        if not access_token and not id_token:
            return None
        return Credentials(
            # Credentials.access_token is a required str. When only an ID token is cached, pass "" rather than
            # None, otherwise model validation fails instead of returning the ID-token credentials.
            access_token=access_token or "",
            refresh_token=refresh_token,
            for_endpoint=for_endpoint,
            id_token=id_token,
            expires_in=None,
        )

    @staticmethod
    def delete(for_endpoint: str):
        """
        Deletes all stored credentials for the specified endpoint from the system keyring.

        The endpoint URL scheme is stripped before lookup. A missing entry or an unavailable keyring is
        logged at debug level and otherwise ignored, so deleting is best-effort.

        :param for_endpoint: The endpoint URL to delete credentials for
        """
        for_endpoint = strip_scheme(for_endpoint)

        def _delete_key(key):
            """
            Delete one token entry for ``for_endpoint``, ignoring missing entries and a missing keyring.

            :param key: The key name to delete
            """
            try:
                keyring.delete_password(for_endpoint, key)
            except PasswordDeleteError as e:
                logger.debug(f"Key {key} not found in key store, Ignoring. Error: {e}")
            except NoKeyringError as e:
                logger.debug(f"KeyRing not available, Key {key} deletion failed. Error: {e}")

        _delete_key(KeyringStore._access_token_key)
        _delete_key(KeyringStore._refresh_token_key)
        _delete_key(KeyringStore._id_token_key)
@@ -0,0 +1,258 @@
1
+ import asyncio
2
+ import base64
3
+ import enum
4
+ import typing
5
+ import urllib.parse
6
+ from datetime import datetime, timedelta
7
+
8
+ import httpx
9
+ import pydantic
10
+
11
+ from union._logging import logger
12
+ from union.remote._client.auth.errors import AuthenticationError, AuthenticationPending
13
+
14
# Text encoding used by get_basic_authorization_header when base64-encoding the client credentials.
utf_8 = "utf-8"

# Error codes the token endpoint returns while a device-code authorization is still in progress;
# get_token maps both of them to AuthenticationPending instead of AuthenticationError.
error_slow_down = "slow_down"
error_auth_pending = "authorization_pending"
19
+
20
+
21
# Grant Types
class GrantType(str, enum.Enum):
    """OAuth2 ``grant_type`` values; get_token sends the member's ``.value`` as the ``grant_type`` form field."""

    CLIENT_CREDS = "client_credentials"
    DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"
    REFRESH_TOKEN = "refresh_token"
26
+
27
+
28
class DeviceCodeResponse(pydantic.BaseModel):
    """
    Response from the device authorization endpoint, e.g.
    {
        'device_code': 'code',
        'user_code': 'BNDJJFXL',
        'verification_uri': 'url',
        'expires_in': 600,
        'interval': 5
    }

    Attributes:
        device_code (str): The device verification code.
        user_code (str): The user-facing code that should be entered on the verification page.
        verification_uri (str): The URL where the user should enter the user_code.
        expires_in (int): The lifetime in seconds of the device code and user code.
        interval (int): The minimum amount of time in seconds to wait between polling requests. It is optional
            in the server response and defaults to 5 seconds (RFC 8628, section 3.2).
    """

    device_code: str
    user_code: str
    verification_uri: str
    expires_in: int
    interval: int = 5

    @classmethod
    def from_json_response(cls, j: typing.Dict) -> "DeviceCodeResponse":
        """
        Create a DeviceCodeResponse instance from a JSON response dictionary.

        :param j: The JSON response dictionary containing device code information
        :return: A new instance with values from the JSON response
        :raises KeyError: If a required field (any field other than ``interval``) is missing
        """
        return cls(
            device_code=j["device_code"],
            user_code=j["user_code"],
            verification_uri=j["verification_uri"],
            expires_in=j["expires_in"],
            # RFC 8628 makes `interval` optional; clients must fall back to 5 seconds when it is absent.
            interval=j.get("interval", 5),
        )
68
+
69
+
70
def get_basic_authorization_header(client_id: str, client_secret: str) -> str:
    """
    Build an HTTP Basic ``Authorization`` header value from a client id and secret.

    The secret (but not the client id) is URL-encoded with ``quote_plus`` to escape illegal characters.
    ``<client_id>:<encoded secret>`` is then base64-encoded and prefixed with ``Basic ``.

    :param client_id: The client ID for authentication
    :param client_secret: The client secret for authentication
    :return: The header value, e.g. ``"Basic aWQ6c2VjcmV0"`` for ``("id", "secret")``
    """
    pair = f"{client_id}:{urllib.parse.quote_plus(client_secret)}"
    token = base64.b64encode(pair.encode(utf_8)).decode(utf_8)
    return f"Basic {token}"
83
+
84
+
85
async def get_token(
    token_endpoint: str,
    http_session: httpx.AsyncClient,
    scopes: typing.Optional[typing.List[str]] = None,
    authorization_header: typing.Optional[str] = None,
    client_id: typing.Optional[str] = None,
    device_code: typing.Optional[str] = None,
    audience: typing.Optional[str] = None,
    grant_type: GrantType = GrantType.CLIENT_CREDS,
    http_proxy_url: typing.Optional[str] = None,
    verify: typing.Optional[typing.Union[bool, str]] = None,
    refresh_token: typing.Optional[str] = None,
) -> typing.Tuple[str, typing.Optional[str], int]:
    """
    Retrieves an access token from the specified token endpoint.

    The request is a form-encoded POST. Optional fields (``client_id``, ``device_code``, ``scope``,
    ``audience``, ``refresh_token``) are only included when provided.

    :param token_endpoint: The URL of the token endpoint
    :param http_session: HTTP client used to POST the token request
    :param scopes: Optional list of scopes; sent space-separated as ``scope``
    :param authorization_header: Optional ``Authorization`` header value (e.g. basic auth for client credentials)
    :param client_id: Optional client ID for authentication
    :param device_code: Optional device code for the device code flow
    :param audience: Optional audience for the token
    :param grant_type: The OAuth grant type to use (default: CLIENT_CREDS)
    :param http_proxy_url: Accepted for API compatibility but not used here; proxying is whatever
        ``http_session`` was configured with
    :param verify: Accepted for API compatibility but not used here; TLS verification is whatever
        ``http_session`` was configured with
    :param refresh_token: Optional refresh token for the refresh token flow
    :return: A tuple of (access_token, refresh_token or None if the response has none, expires_in)

    Raises:
        AuthenticationPending: When the endpoint reports ``authorization_pending`` or ``slow_down``
            (device code flow).
        AuthenticationError: When the endpoint returns any other non-success response.
    """
    headers = {
        "Cache-Control": "no-cache",
        "Accept": "application/json",
        "Content-Type": "application/x-www-form-urlencoded",
    }
    if authorization_header:
        headers["Authorization"] = authorization_header
    body = {
        "grant_type": grant_type.value,
    }
    if client_id:
        body["client_id"] = client_id
    if device_code:
        body["device_code"] = device_code
    if scopes is not None:
        body["scope"] = " ".join(s.strip("' ") for s in scopes).strip("[]'")
    if audience:
        body["audience"] = audience
    if refresh_token:
        body["refresh_token"] = refresh_token

    response = await http_session.post(token_endpoint, data=body, headers=headers)

    if not response.is_success:
        try:
            j = response.json()
        except ValueError:
            # Error bodies are not guaranteed to be JSON (e.g. an HTML error page from a proxy). Fall through
            # to the documented AuthenticationError instead of surfacing a JSON decode error.
            j = None
        err = j.get("error") if isinstance(j, dict) else None
        if err == error_auth_pending or err == error_slow_down:
            raise AuthenticationPending(f"Token not yet available, try again in some time {err}")
        logger.error("Status Code ({}) received from IDP: {}".format(response.status_code, response.text))
        raise AuthenticationError("Status Code ({}) received from IDP: {}".format(response.status_code, response.text))

    j = response.json()
    # The response may omit refresh_token, in which case None is returned for it.
    # NOTE(review): RFC 6749 section 5.1 only RECOMMENDS expires_in; a response without it raises KeyError here.
    return j["access_token"], j.get("refresh_token"), j["expires_in"]
169
+
170
+
171
async def get_device_code(
    device_auth_endpoint: str,
    client_id: str,
    http_session: httpx.AsyncClient,
    *,
    audience: typing.Optional[str] = None,
    scopes: typing.Optional[typing.List[str]] = None,
) -> DeviceCodeResponse:
    """
    Retrieves a device authentication code. The user can then authenticate the request from a browser on a
    separate device.

    :param device_auth_endpoint: The URL of the device authorization endpoint
    :param client_id: The client ID to use for authentication
    :param http_session: HTTP client used to POST the request
    :param audience: The audience value to request
    :param scopes: List of scopes to request; sent space-separated as ``scope`` (empty string when None)
    :return: An object containing the device code and related information
    :raises AuthenticationError: When the endpoint returns a non-success status
    """
    _scope = " ".join(s.strip("' ") for s in scopes).strip("[]'") if scopes is not None else ""
    # NOTE(review): audience is always included here, even when None, whereas get_token omits an empty
    # audience. Confirm the identity provider tolerates that.
    payload = {"client_id": client_id, "scope": _scope, "audience": audience}
    resp = await http_session.post(device_auth_endpoint, data=payload)
    if not resp.is_success:
        try:
            reason = resp.json()
        except ValueError:
            # The error body is not guaranteed to be JSON. Fall back to the raw text so the caller still gets an
            # AuthenticationError rather than a JSON decode error raised while building the message.
            reason = resp.text
        raise AuthenticationError(
            f"Unable to retrieve Device Authentication Code for {payload},"
            f" Status Code {resp.status_code} Reason {reason}"
        )
    return DeviceCodeResponse.from_json_response(resp.json())
202
+
203
+
204
async def poll_token_endpoint(
    resp: DeviceCodeResponse,
    *,
    token_endpoint: str,
    client_id: str,
    http_session: httpx.AsyncClient,
    audience: typing.Optional[str] = None,
    scopes: typing.Optional[typing.List[str]] = None,
    http_proxy_url: typing.Optional[str] = None,
    verify: typing.Optional[typing.Union[bool, str]] = None,
) -> typing.Tuple[str, typing.Optional[str], int]:
    """
    Polls the token endpoint until authentication is complete or times out.

    The token endpoint is called every ``resp.interval`` seconds until either:
    1. Authentication is successful and a token is returned
    2. ``resp.expires_in`` seconds have elapsed since polling started (the device code has expired)

    :param resp: The device code response from a previous call to get_device_code
    :param token_endpoint: The URL of the token endpoint
    :param client_id: The client ID to use for authentication
    :param http_session: HTTP client used for the token requests
    :param audience: The audience value to request
    :param scopes: List of scopes to request
    :param http_proxy_url: Passed through to get_token
    :param verify: Passed through to get_token
    :return: A tuple containing (access_token, refresh_token or None, expires_in)
    :raises AuthenticationError: When the device code expires before authentication completes
    :raises Exception: Any other error from get_token is logged as a warning and re-raised unchanged
    """
    loop = asyncio.get_running_loop()
    # Measure elapsed time on the event loop's monotonic clock. Time spent waiting on the token endpoint then
    # counts toward the expiry, and wall-clock adjustments cannot stretch or shrink the polling window.
    deadline = loop.time() + resp.expires_in
    while loop.time() < deadline:
        try:
            access_token, refresh_token, expires_in = await get_token(
                token_endpoint,
                grant_type=GrantType.DEVICE_CODE,
                client_id=client_id,
                audience=audience,
                scopes=scopes,
                device_code=resp.device_code,
                http_proxy_url=http_proxy_url,
                verify=verify,
                http_session=http_session,
            )
        except AuthenticationPending:
            logger.debug(f"Authentication pending, ..., waiting for {resp.interval} seconds")
            await asyncio.sleep(resp.interval)
            continue
        except Exception as e:
            logger.warning(f"Authentication failed, reason {e}")
            raise
        logger.debug(f"Authentication successful, access token received, expires in {expires_in} seconds")
        return access_token, refresh_token, expires_in
    raise AuthenticationError("Authentication failed!")
@@ -0,0 +1,16 @@
1
class AccessTokenNotFoundError(RuntimeError):
    """
    Raised when no access token can be found, or when refreshing the token fails.
    """
5
+
6
+
7
class AuthenticationError(RuntimeError):
    """
    Raised when authentication does not succeed. For example, the device-code token
    poller raises it once the device code expires without a token having been issued.
    """
11
+
12
+
13
class AuthenticationPending(RuntimeError):
    """
    Raised when the token endpoint reports that authentication is still pending.
    The device-code poller treats it as non-fatal: it catches it, sleeps for the
    polling interval, and tries again.
    """
@@ -0,0 +1,86 @@
1
+ from __future__ import annotations
2
+
3
+ import grpc
4
+ from flyteidl.service import admin_pb2_grpc, dataproxy_pb2_grpc
5
+
6
+ from union._protos.secret import secret_pb2_grpc
7
+ from union._protos.workflow import run_logs_service_pb2_grpc, run_service_pb2_grpc, task_service_pb2_grpc
8
+
9
+ from ._protocols import (
10
+ DataProxyService,
11
+ MetadataServiceProtocol,
12
+ ProjectDomainService,
13
+ RunLogsService,
14
+ RunService,
15
+ SecretService,
16
+ TaskService,
17
+ )
18
+ from .auth import create_channel
19
+
20
+
21
class ClientSet:
    """
    Bundle of gRPC service stubs that all share one ``grpc.aio.Channel``.

    Build one with :meth:`for_endpoint`; call :meth:`close` to close the shared channel.
    """

    # Options removed from ``kwargs`` before an insecure channel is created.
    # NOTE(review): presumably the insecure branch of ``create_channel`` cannot accept
    # these auth/transport options — confirm against ``create_channel``.
    _INSECURE_DROPPED_KWARGS = frozenset(
        {
            "api_key",
            "auth_type",
            "headless",
            "command",
            "client_id",
            "client_credentials_secret",
            "client_config",
            "rpc_retries",
            "http_proxy_url",
        }
    )

    def __init__(self, channel: grpc.aio.Channel, data_proxy_channel: grpc.aio.Channel | None = None):
        """
        :param channel: Channel that every service stub is bound to.
        :param data_proxy_channel: NOTE(review): currently unused. The data proxy stub is
            also bound to ``channel``. The parameter is kept for signature compatibility.
        """
        self._channel = channel
        self._admin_client = admin_pb2_grpc.AdminServiceStub(channel=channel)
        self._task_service = task_service_pb2_grpc.TaskServiceStub(channel=channel)
        self._run_service = run_service_pb2_grpc.RunServiceStub(channel=channel)
        self._dataproxy = dataproxy_pb2_grpc.DataProxyServiceStub(channel=channel)
        self._log_service = run_logs_service_pb2_grpc.RunLogsServiceStub(channel=channel)
        self._secrets_service = secret_pb2_grpc.SecretServiceStub(channel=channel)

    @classmethod
    async def for_endpoint(cls, endpoint: str, *, insecure: bool = False, **kwargs) -> ClientSet:
        """
        Open a channel to ``endpoint`` with ``create_channel`` and wrap it in a ClientSet.

        :param endpoint: Endpoint passed to ``create_channel``.
        :param insecure: Forwarded to ``create_channel``. When True, any options listed in
            ``_INSECURE_DROPPED_KWARGS`` that are present in ``kwargs`` are discarded first.
            Options that are absent are simply ignored instead of raising ``KeyError``.
        :param kwargs: Extra options forwarded to ``create_channel``.
        :return: A ClientSet bound to the new channel.
        """
        if insecure:
            kwargs = cls._drop_insecure_kwargs(kwargs)
        return cls(await create_channel(endpoint, insecure=insecure, **kwargs))

    @classmethod
    def _drop_insecure_kwargs(cls, kwargs: dict) -> dict:
        """Return a copy of ``kwargs`` without any key listed in ``_INSECURE_DROPPED_KWARGS``."""
        return {key: value for key, value in kwargs.items() if key not in cls._INSECURE_DROPPED_KWARGS}

    @classmethod
    async def for_api_key(cls, api_key: str, **kwargs) -> ClientSet:
        """Not implemented yet."""
        raise NotImplementedError

    @classmethod
    async def for_serverless(cls) -> ClientSet:
        """Not implemented yet."""
        raise NotImplementedError

    @classmethod
    async def from_env(cls) -> ClientSet:
        """Not implemented yet."""
        raise NotImplementedError

    @property
    def metadata_service(self) -> MetadataServiceProtocol:
        """The admin service stub (also returned by ``project_domain_service``)."""
        return self._admin_client

    @property
    def project_domain_service(self) -> ProjectDomainService:
        """The admin service stub (also returned by ``metadata_service``)."""
        return self._admin_client

    @property
    def task_service(self) -> TaskService:
        """Stub for the task service."""
        return self._task_service

    @property
    def run_service(self) -> RunService:
        """Stub for the run service."""
        return self._run_service

    @property
    def dataproxy_service(self) -> DataProxyService:
        """Stub for the data proxy service (bound to the shared channel)."""
        return self._dataproxy

    @property
    def logs_service(self) -> RunLogsService:
        """Stub for the run logs service."""
        return self._log_service

    @property
    def secrets_service(self) -> SecretService:
        """Stub for the secret service."""
        return self._secrets_service

    async def close(self, grace: float | None = None):
        """Close the shared channel; ``grace`` is forwarded to ``Channel.close``."""
        return await self._channel.close(grace=grace)
union/remote/_data.py ADDED
@@ -0,0 +1,149 @@
1
+ import asyncio
2
+ import hashlib
3
+ import os
4
+ import typing
5
+ import uuid
6
+ from base64 import b64encode
7
+ from datetime import timedelta
8
+ from functools import lru_cache
9
+ from pathlib import Path
10
+ from typing import Tuple
11
+
12
+ import aiofiles
13
+ import grpc
14
+ import httpx
15
+ from flyteidl.service import dataproxy_pb2
16
+ from google.protobuf import duration_pb2
17
+
18
+ from union._initialize import CommonInit, get_client, get_common_config, requires_client
19
+ from union.errors import RuntimeSystemError
20
+
21
# Validity window requested for the signed upload URL (sent as ``expires_in``).
_UPLOAD_EXPIRES_IN = timedelta(minutes=1)
22
+
23
+
24
def get_extra_headers_for_protocol(native_url: str) -> typing.Dict[str, str]:
    """
    Return storage-specific HTTP headers for a signed-URL upload.

    Azure Blob Storage (``abfs://`` native URLs) needs ``x-ms-blob-type`` on the request;
    every other backend gets no extra headers.

    :param native_url: Native storage URL of the upload target.
    :return: A freshly built dict on every call, so callers may mutate it.
    """
    is_azure_blob = native_url.startswith("abfs://")
    return {"x-ms-blob-type": "BlockBlob"} if is_azure_blob else {}
34
+
35
+
36
# Read size for hashing; much larger than md5's 64-byte block_size so the Python loop stays short.
_HASH_CHUNK_SIZE = 1024 * 1024


def hash_file(file_path: typing.Union[os.PathLike, str]) -> typing.Tuple[bytes, str, int]:
    """
    Hash a file and produce a digest to be used as a version.

    Results are memoised. The cache key includes the file's modification time and size,
    so a file rewritten in place is re-hashed instead of returning a stale digest.

    :param file_path: Path of the file to hash.
    :return: ``(md5 digest bytes, md5 hex digest, size in bytes)``.
    :raises OSError: If the file cannot be stat'ed or read.
    """
    stat = os.stat(file_path)
    return _hash_file_cached(os.fspath(file_path), stat.st_mtime_ns, stat.st_size)


@lru_cache(maxsize=128)
def _hash_file_cached(path: str, mtime_ns: int, size: int) -> typing.Tuple[bytes, str, int]:
    """Compute the MD5 of ``path``; ``mtime_ns`` and ``size`` only serve as cache-key parts."""
    # MD5 serves as a content checksum (e.g. the Content-MD5 upload header), not for security.
    # Declaring that lets it run on FIPS-restricted Python builds.
    h = hashlib.md5(usedforsecurity=False)
    total = 0
    with open(path, "rb") as file:
        while chunk := file.read(_HASH_CHUNK_SIZE):
            h.update(chunk)
            total += len(chunk)
    return h.digest(), h.hexdigest(), total
54
+
55
+
56
async def _upload_single_file(
    cfg: CommonInit, fp: Path, verify: bool = True, basedir: str | None = None
) -> Tuple[str, str]:
    """
    Upload one local file through a signed URL obtained from the data proxy service.

    :param cfg: Supplies the project and domain the upload location is created for.
    :param fp: Local file to upload.
    :param verify: Passed to ``httpx.AsyncClient(verify=...)`` for the PUT request.
    :param basedir: Optional ``filename_root`` for the upload location request.
    :return: ``(md5 hex digest, native URL of the uploaded object)``.
    :raises RuntimeSystemError: If no signed URL can be obtained, or the PUT to it fails.
    """
    md5_bytes, str_digest, _ = hash_file(fp)

    try:
        expires_in_pb = duration_pb2.Duration()
        expires_in_pb.FromTimedelta(_UPLOAD_EXPIRES_IN)
        resp = await get_client().dataproxy_service.CreateUploadLocation(
            dataproxy_pb2.CreateUploadLocationRequest(
                project=cfg.project,
                domain=cfg.domain,
                content_md5=md5_bytes,
                filename=fp.name,
                expires_in=expires_in_pb,
                filename_root=basedir,
                add_content_md5_metadata=True,
            )
        )
    except grpc.aio.AioRpcError as e:
        if e.code() == grpc.StatusCode.NOT_FOUND:
            raise RuntimeSystemError(
                "NotFound", f"Failed to get signed url for {fp}, please check your project and domain."
            ) from e
        if e.code() == grpc.StatusCode.PERMISSION_DENIED:
            raise RuntimeSystemError(
                "PermissionDenied", f"Failed to get signed url for {fp}, please check your permissions."
            ) from e
        # StatusCode.value is an (int, str) tuple; the enum name (e.g. "UNAVAILABLE") is the
        # string code, matching the string codes used above.
        raise RuntimeSystemError(e.code().name, f"Failed to get signed url for {fp}.") from e
    except Exception as e:
        raise RuntimeSystemError(type(e).__name__, f"Failed to get signed url for {fp}.") from e

    extra_headers = get_extra_headers_for_protocol(resp.native_url)
    extra_headers.update(resp.headers)
    extra_headers.update({"Content-Length": str(fp.stat().st_size), "Content-MD5": b64encode(md5_bytes)})

    async with aiofiles.open(str(fp), "rb") as file:
        async with httpx.AsyncClient(verify=verify) as http_client:
            response = await http_client.put(resp.signed_url, headers=extra_headers, content=file)
            # TODO in old code we did this
            # if self._config.platform.insecure_skip_verify is True
            # else self._config.platform.ca_cert_file_path,

    # The signed-URL PUT can fail (e.g. expired URL, MD5 mismatch). Surface that failure
    # instead of returning a URL that points at nothing.
    if not response.is_success:
        raise RuntimeSystemError(
            "UploadFailed",
            f"Failed to upload {fp} to signed url, status code {response.status_code}: {response.text}",
        )

    return str_digest, resp.native_url
104
+
105
+
106
@requires_client
async def upload_file(fp: Path, verify: bool = True) -> Tuple[str, str]:
    """
    Upload a single local file and report where it was stored.

    :param fp: Path of the file to upload; it must be a regular file.
    :param verify: Whether to verify the certificate for HTTPS requests.
    :return: ``(md5 hex digest, remote URI)`` of the uploaded file.
    :raises ValueError: If ``fp`` is not a file.
    """
    common_cfg = get_common_config()
    if fp.is_file():
        return await _upload_single_file(common_cfg, fp, verify=verify)
    raise ValueError(f"{fp} is not a single file, upload arg must be a single file.")
120
+
121
+
122
def _dir_uri_from_native_url(native_url: str, prefix: str) -> str:
    """
    Return ``native_url`` truncated right after the first occurrence of ``prefix``.

    E.g. ``s3://bucket/proj/dev/<prefix>/source/empty.md`` -> ``s3://bucket/proj/dev/<prefix>``.
    """
    return native_url[: native_url.index(prefix) + len(prefix)]


@requires_client
async def upload_dir(dir_path: Path, verify: bool = True) -> str:
    """
    Upload every file under a directory (recursively) and return the remote directory URI.

    :param dir_path: The directory path to upload.
    :param verify: Whether to verify the certificate for HTTPS requests.
    :return: The remote URI of the uploaded directory.
    :raises ValueError: If ``dir_path`` is not a directory, or contains no files.
    """
    cfg = get_common_config()
    if not dir_path.is_dir():
        raise ValueError(f"{dir_path} is not a directory, upload arg must be a directory.")

    # Random root shared by every file of this upload (sent as ``filename_root``).
    prefix = uuid.uuid4().hex

    files = [path for path in dir_path.rglob("*") if path.is_file()]
    if not files:
        # Without at least one upload there is no native URL to derive the directory from.
        raise ValueError(f"{dir_path} does not contain any files, nothing to upload.")

    # NOTE(review): _upload_single_file sends only ``fp.name`` as the filename, so nested
    # structure may be flattened and same-named files in sub-directories may collide.
    results = await asyncio.gather(
        *(_upload_single_file(cfg, path, verify=verify, basedir=prefix) for path in files)
    )
    # All files share ``prefix``, so any native URL identifies the directory. Slicing at the
    # prefix avoids the "//" that ``head + "/" + prefix`` produced when head already ended in "/".
    return _dir_uri_from_native_url(results[0][1], prefix)