flyte 0.0.1b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flyte might be problematic. Click here for more details.

Files changed (390)
  1. flyte/__init__.py +62 -0
  2. flyte/_api_commons.py +3 -0
  3. flyte/_bin/__init__.py +0 -0
  4. flyte/_bin/runtime.py +126 -0
  5. flyte/_build.py +25 -0
  6. flyte/_cache/__init__.py +12 -0
  7. flyte/_cache/cache.py +146 -0
  8. flyte/_cache/defaults.py +9 -0
  9. flyte/_cache/policy_function_body.py +42 -0
  10. flyte/_cli/__init__.py +0 -0
  11. flyte/_cli/_common.py +287 -0
  12. flyte/_cli/_create.py +42 -0
  13. flyte/_cli/_delete.py +23 -0
  14. flyte/_cli/_deploy.py +140 -0
  15. flyte/_cli/_get.py +235 -0
  16. flyte/_cli/_run.py +152 -0
  17. flyte/_cli/main.py +72 -0
  18. flyte/_code_bundle/__init__.py +8 -0
  19. flyte/_code_bundle/_ignore.py +113 -0
  20. flyte/_code_bundle/_packaging.py +187 -0
  21. flyte/_code_bundle/_utils.py +339 -0
  22. flyte/_code_bundle/bundle.py +178 -0
  23. flyte/_context.py +146 -0
  24. flyte/_datastructures.py +342 -0
  25. flyte/_deploy.py +202 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +43 -0
  29. flyte/_group.py +31 -0
  30. flyte/_hash.py +23 -0
  31. flyte/_image.py +760 -0
  32. flyte/_initialize.py +634 -0
  33. flyte/_interface.py +84 -0
  34. flyte/_internal/__init__.py +3 -0
  35. flyte/_internal/controllers/__init__.py +115 -0
  36. flyte/_internal/controllers/_local_controller.py +118 -0
  37. flyte/_internal/controllers/_trace.py +40 -0
  38. flyte/_internal/controllers/pbhash.py +39 -0
  39. flyte/_internal/controllers/remote/__init__.py +40 -0
  40. flyte/_internal/controllers/remote/_action.py +141 -0
  41. flyte/_internal/controllers/remote/_client.py +43 -0
  42. flyte/_internal/controllers/remote/_controller.py +361 -0
  43. flyte/_internal/controllers/remote/_core.py +402 -0
  44. flyte/_internal/controllers/remote/_informer.py +361 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +11 -0
  47. flyte/_internal/imagebuild/docker_builder.py +416 -0
  48. flyte/_internal/imagebuild/image_builder.py +241 -0
  49. flyte/_internal/imagebuild/remote_builder.py +0 -0
  50. flyte/_internal/resolvers/__init__.py +0 -0
  51. flyte/_internal/resolvers/_task_module.py +54 -0
  52. flyte/_internal/resolvers/common.py +31 -0
  53. flyte/_internal/resolvers/default.py +28 -0
  54. flyte/_internal/runtime/__init__.py +0 -0
  55. flyte/_internal/runtime/convert.py +199 -0
  56. flyte/_internal/runtime/entrypoints.py +135 -0
  57. flyte/_internal/runtime/io.py +136 -0
  58. flyte/_internal/runtime/resources_serde.py +138 -0
  59. flyte/_internal/runtime/task_serde.py +210 -0
  60. flyte/_internal/runtime/taskrunner.py +190 -0
  61. flyte/_internal/runtime/types_serde.py +54 -0
  62. flyte/_logging.py +124 -0
  63. flyte/_protos/__init__.py +0 -0
  64. flyte/_protos/common/authorization_pb2.py +66 -0
  65. flyte/_protos/common/authorization_pb2.pyi +108 -0
  66. flyte/_protos/common/authorization_pb2_grpc.py +4 -0
  67. flyte/_protos/common/identifier_pb2.py +71 -0
  68. flyte/_protos/common/identifier_pb2.pyi +82 -0
  69. flyte/_protos/common/identifier_pb2_grpc.py +4 -0
  70. flyte/_protos/common/identity_pb2.py +48 -0
  71. flyte/_protos/common/identity_pb2.pyi +72 -0
  72. flyte/_protos/common/identity_pb2_grpc.py +4 -0
  73. flyte/_protos/common/list_pb2.py +36 -0
  74. flyte/_protos/common/list_pb2.pyi +69 -0
  75. flyte/_protos/common/list_pb2_grpc.py +4 -0
  76. flyte/_protos/common/policy_pb2.py +37 -0
  77. flyte/_protos/common/policy_pb2.pyi +27 -0
  78. flyte/_protos/common/policy_pb2_grpc.py +4 -0
  79. flyte/_protos/common/role_pb2.py +37 -0
  80. flyte/_protos/common/role_pb2.pyi +53 -0
  81. flyte/_protos/common/role_pb2_grpc.py +4 -0
  82. flyte/_protos/common/runtime_version_pb2.py +28 -0
  83. flyte/_protos/common/runtime_version_pb2.pyi +24 -0
  84. flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
  85. flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
  86. flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
  87. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  88. flyte/_protos/secret/definition_pb2.py +49 -0
  89. flyte/_protos/secret/definition_pb2.pyi +93 -0
  90. flyte/_protos/secret/definition_pb2_grpc.py +4 -0
  91. flyte/_protos/secret/payload_pb2.py +62 -0
  92. flyte/_protos/secret/payload_pb2.pyi +94 -0
  93. flyte/_protos/secret/payload_pb2_grpc.py +4 -0
  94. flyte/_protos/secret/secret_pb2.py +38 -0
  95. flyte/_protos/secret/secret_pb2.pyi +6 -0
  96. flyte/_protos/secret/secret_pb2_grpc.py +198 -0
  97. flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
  98. flyte/_protos/validate/validate/validate_pb2.py +76 -0
  99. flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
  100. flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  101. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  102. flyte/_protos/workflow/queue_service_pb2.py +106 -0
  103. flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
  104. flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  105. flyte/_protos/workflow/run_definition_pb2.py +128 -0
  106. flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
  107. flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  108. flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
  109. flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  110. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  111. flyte/_protos/workflow/run_service_pb2.py +133 -0
  112. flyte/_protos/workflow/run_service_pb2.pyi +175 -0
  113. flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
  114. flyte/_protos/workflow/state_service_pb2.py +58 -0
  115. flyte/_protos/workflow/state_service_pb2.pyi +71 -0
  116. flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
  117. flyte/_protos/workflow/task_definition_pb2.py +72 -0
  118. flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
  119. flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  120. flyte/_protos/workflow/task_service_pb2.py +44 -0
  121. flyte/_protos/workflow/task_service_pb2.pyi +31 -0
  122. flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
  123. flyte/_resources.py +226 -0
  124. flyte/_retry.py +32 -0
  125. flyte/_reusable_environment.py +25 -0
  126. flyte/_run.py +411 -0
  127. flyte/_secret.py +61 -0
  128. flyte/_task.py +367 -0
  129. flyte/_task_environment.py +200 -0
  130. flyte/_timeout.py +47 -0
  131. flyte/_tools.py +27 -0
  132. flyte/_trace.py +128 -0
  133. flyte/_utils/__init__.py +20 -0
  134. flyte/_utils/asyn.py +119 -0
  135. flyte/_utils/coro_management.py +25 -0
  136. flyte/_utils/file_handling.py +72 -0
  137. flyte/_utils/helpers.py +108 -0
  138. flyte/_utils/lazy_module.py +54 -0
  139. flyte/_utils/uv_script_parser.py +49 -0
  140. flyte/_version.py +21 -0
  141. flyte/connectors/__init__.py +0 -0
  142. flyte/errors.py +143 -0
  143. flyte/extras/__init__.py +5 -0
  144. flyte/extras/_container.py +273 -0
  145. flyte/io/__init__.py +11 -0
  146. flyte/io/_dataframe.py +0 -0
  147. flyte/io/_dir.py +448 -0
  148. flyte/io/_file.py +468 -0
  149. flyte/io/pickle/__init__.py +0 -0
  150. flyte/io/pickle/transformer.py +117 -0
  151. flyte/io/structured_dataset/__init__.py +129 -0
  152. flyte/io/structured_dataset/basic_dfs.py +219 -0
  153. flyte/io/structured_dataset/structured_dataset.py +1061 -0
  154. flyte/py.typed +0 -0
  155. flyte/remote/__init__.py +25 -0
  156. flyte/remote/_client/__init__.py +0 -0
  157. flyte/remote/_client/_protocols.py +131 -0
  158. flyte/remote/_client/auth/__init__.py +12 -0
  159. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  160. flyte/remote/_client/auth/_authenticators/base.py +397 -0
  161. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  162. flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
  163. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  164. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  165. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  166. flyte/remote/_client/auth/_channel.py +184 -0
  167. flyte/remote/_client/auth/_client_config.py +83 -0
  168. flyte/remote/_client/auth/_default_html.py +32 -0
  169. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  170. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  171. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  172. flyte/remote/_client/auth/_keyring.py +143 -0
  173. flyte/remote/_client/auth/_token_client.py +260 -0
  174. flyte/remote/_client/auth/errors.py +16 -0
  175. flyte/remote/_client/controlplane.py +95 -0
  176. flyte/remote/_console.py +18 -0
  177. flyte/remote/_data.py +155 -0
  178. flyte/remote/_logs.py +116 -0
  179. flyte/remote/_project.py +86 -0
  180. flyte/remote/_run.py +873 -0
  181. flyte/remote/_secret.py +132 -0
  182. flyte/remote/_task.py +227 -0
  183. flyte/report/__init__.py +3 -0
  184. flyte/report/_report.py +178 -0
  185. flyte/report/_template.html +124 -0
  186. flyte/storage/__init__.py +24 -0
  187. flyte/storage/_remote_fs.py +34 -0
  188. flyte/storage/_storage.py +251 -0
  189. flyte/storage/_utils.py +5 -0
  190. flyte/types/__init__.py +13 -0
  191. flyte/types/_interface.py +25 -0
  192. flyte/types/_renderer.py +162 -0
  193. flyte/types/_string_literals.py +120 -0
  194. flyte/types/_type_engine.py +2210 -0
  195. flyte/types/_utils.py +80 -0
  196. flyte-0.0.1b0.dist-info/METADATA +179 -0
  197. flyte-0.0.1b0.dist-info/RECORD +390 -0
  198. flyte-0.0.1b0.dist-info/WHEEL +5 -0
  199. flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
  200. flyte-0.0.1b0.dist-info/top_level.txt +1 -0
  201. union/__init__.py +54 -0
  202. union/_api_commons.py +3 -0
  203. union/_bin/__init__.py +0 -0
  204. union/_bin/runtime.py +113 -0
  205. union/_build.py +25 -0
  206. union/_cache/__init__.py +12 -0
  207. union/_cache/cache.py +141 -0
  208. union/_cache/defaults.py +9 -0
  209. union/_cache/policy_function_body.py +42 -0
  210. union/_cli/__init__.py +0 -0
  211. union/_cli/_common.py +263 -0
  212. union/_cli/_create.py +40 -0
  213. union/_cli/_delete.py +23 -0
  214. union/_cli/_deploy.py +120 -0
  215. union/_cli/_get.py +162 -0
  216. union/_cli/_params.py +579 -0
  217. union/_cli/_run.py +150 -0
  218. union/_cli/main.py +72 -0
  219. union/_code_bundle/__init__.py +8 -0
  220. union/_code_bundle/_ignore.py +113 -0
  221. union/_code_bundle/_packaging.py +187 -0
  222. union/_code_bundle/_utils.py +342 -0
  223. union/_code_bundle/bundle.py +176 -0
  224. union/_context.py +146 -0
  225. union/_datastructures.py +295 -0
  226. union/_deploy.py +185 -0
  227. union/_doc.py +29 -0
  228. union/_docstring.py +26 -0
  229. union/_environment.py +43 -0
  230. union/_group.py +31 -0
  231. union/_hash.py +23 -0
  232. union/_image.py +760 -0
  233. union/_initialize.py +585 -0
  234. union/_interface.py +84 -0
  235. union/_internal/__init__.py +3 -0
  236. union/_internal/controllers/__init__.py +77 -0
  237. union/_internal/controllers/_local_controller.py +77 -0
  238. union/_internal/controllers/pbhash.py +39 -0
  239. union/_internal/controllers/remote/__init__.py +40 -0
  240. union/_internal/controllers/remote/_action.py +131 -0
  241. union/_internal/controllers/remote/_client.py +43 -0
  242. union/_internal/controllers/remote/_controller.py +169 -0
  243. union/_internal/controllers/remote/_core.py +341 -0
  244. union/_internal/controllers/remote/_informer.py +260 -0
  245. union/_internal/controllers/remote/_service_protocol.py +44 -0
  246. union/_internal/imagebuild/__init__.py +11 -0
  247. union/_internal/imagebuild/docker_builder.py +416 -0
  248. union/_internal/imagebuild/image_builder.py +243 -0
  249. union/_internal/imagebuild/remote_builder.py +0 -0
  250. union/_internal/resolvers/__init__.py +0 -0
  251. union/_internal/resolvers/_task_module.py +31 -0
  252. union/_internal/resolvers/common.py +24 -0
  253. union/_internal/resolvers/default.py +27 -0
  254. union/_internal/runtime/__init__.py +0 -0
  255. union/_internal/runtime/convert.py +163 -0
  256. union/_internal/runtime/entrypoints.py +121 -0
  257. union/_internal/runtime/io.py +136 -0
  258. union/_internal/runtime/resources_serde.py +134 -0
  259. union/_internal/runtime/task_serde.py +202 -0
  260. union/_internal/runtime/taskrunner.py +179 -0
  261. union/_internal/runtime/types_serde.py +53 -0
  262. union/_logging.py +124 -0
  263. union/_protos/__init__.py +0 -0
  264. union/_protos/common/authorization_pb2.py +66 -0
  265. union/_protos/common/authorization_pb2.pyi +106 -0
  266. union/_protos/common/authorization_pb2_grpc.py +4 -0
  267. union/_protos/common/identifier_pb2.py +71 -0
  268. union/_protos/common/identifier_pb2.pyi +82 -0
  269. union/_protos/common/identifier_pb2_grpc.py +4 -0
  270. union/_protos/common/identity_pb2.py +48 -0
  271. union/_protos/common/identity_pb2.pyi +72 -0
  272. union/_protos/common/identity_pb2_grpc.py +4 -0
  273. union/_protos/common/list_pb2.py +36 -0
  274. union/_protos/common/list_pb2.pyi +69 -0
  275. union/_protos/common/list_pb2_grpc.py +4 -0
  276. union/_protos/common/policy_pb2.py +37 -0
  277. union/_protos/common/policy_pb2.pyi +27 -0
  278. union/_protos/common/policy_pb2_grpc.py +4 -0
  279. union/_protos/common/role_pb2.py +37 -0
  280. union/_protos/common/role_pb2.pyi +51 -0
  281. union/_protos/common/role_pb2_grpc.py +4 -0
  282. union/_protos/common/runtime_version_pb2.py +28 -0
  283. union/_protos/common/runtime_version_pb2.pyi +24 -0
  284. union/_protos/common/runtime_version_pb2_grpc.py +4 -0
  285. union/_protos/logs/dataplane/payload_pb2.py +96 -0
  286. union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
  287. union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  288. union/_protos/secret/definition_pb2.py +49 -0
  289. union/_protos/secret/definition_pb2.pyi +93 -0
  290. union/_protos/secret/definition_pb2_grpc.py +4 -0
  291. union/_protos/secret/payload_pb2.py +62 -0
  292. union/_protos/secret/payload_pb2.pyi +94 -0
  293. union/_protos/secret/payload_pb2_grpc.py +4 -0
  294. union/_protos/secret/secret_pb2.py +38 -0
  295. union/_protos/secret/secret_pb2.pyi +6 -0
  296. union/_protos/secret/secret_pb2_grpc.py +198 -0
  297. union/_protos/validate/validate/validate_pb2.py +76 -0
  298. union/_protos/workflow/node_execution_service_pb2.py +26 -0
  299. union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  300. union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  301. union/_protos/workflow/queue_service_pb2.py +75 -0
  302. union/_protos/workflow/queue_service_pb2.pyi +103 -0
  303. union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  304. union/_protos/workflow/run_definition_pb2.py +100 -0
  305. union/_protos/workflow/run_definition_pb2.pyi +256 -0
  306. union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  307. union/_protos/workflow/run_logs_service_pb2.py +41 -0
  308. union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  309. union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  310. union/_protos/workflow/run_service_pb2.py +133 -0
  311. union/_protos/workflow/run_service_pb2.pyi +173 -0
  312. union/_protos/workflow/run_service_pb2_grpc.py +412 -0
  313. union/_protos/workflow/state_service_pb2.py +58 -0
  314. union/_protos/workflow/state_service_pb2.pyi +69 -0
  315. union/_protos/workflow/state_service_pb2_grpc.py +138 -0
  316. union/_protos/workflow/task_definition_pb2.py +72 -0
  317. union/_protos/workflow/task_definition_pb2.pyi +65 -0
  318. union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  319. union/_protos/workflow/task_service_pb2.py +44 -0
  320. union/_protos/workflow/task_service_pb2.pyi +31 -0
  321. union/_protos/workflow/task_service_pb2_grpc.py +104 -0
  322. union/_resources.py +226 -0
  323. union/_retry.py +32 -0
  324. union/_reusable_environment.py +25 -0
  325. union/_run.py +374 -0
  326. union/_secret.py +61 -0
  327. union/_task.py +354 -0
  328. union/_task_environment.py +186 -0
  329. union/_timeout.py +47 -0
  330. union/_tools.py +27 -0
  331. union/_utils/__init__.py +11 -0
  332. union/_utils/asyn.py +119 -0
  333. union/_utils/file_handling.py +71 -0
  334. union/_utils/helpers.py +46 -0
  335. union/_utils/lazy_module.py +54 -0
  336. union/_utils/uv_script_parser.py +49 -0
  337. union/_version.py +21 -0
  338. union/connectors/__init__.py +0 -0
  339. union/errors.py +128 -0
  340. union/extras/__init__.py +5 -0
  341. union/extras/_container.py +263 -0
  342. union/io/__init__.py +11 -0
  343. union/io/_dataframe.py +0 -0
  344. union/io/_dir.py +425 -0
  345. union/io/_file.py +418 -0
  346. union/io/pickle/__init__.py +0 -0
  347. union/io/pickle/transformer.py +117 -0
  348. union/io/structured_dataset/__init__.py +122 -0
  349. union/io/structured_dataset/basic_dfs.py +219 -0
  350. union/io/structured_dataset/structured_dataset.py +1057 -0
  351. union/py.typed +0 -0
  352. union/remote/__init__.py +23 -0
  353. union/remote/_client/__init__.py +0 -0
  354. union/remote/_client/_protocols.py +129 -0
  355. union/remote/_client/auth/__init__.py +12 -0
  356. union/remote/_client/auth/_authenticators/__init__.py +0 -0
  357. union/remote/_client/auth/_authenticators/base.py +391 -0
  358. union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  359. union/remote/_client/auth/_authenticators/device_code.py +120 -0
  360. union/remote/_client/auth/_authenticators/external_command.py +77 -0
  361. union/remote/_client/auth/_authenticators/factory.py +200 -0
  362. union/remote/_client/auth/_authenticators/pkce.py +515 -0
  363. union/remote/_client/auth/_channel.py +184 -0
  364. union/remote/_client/auth/_client_config.py +83 -0
  365. union/remote/_client/auth/_default_html.py +32 -0
  366. union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  367. union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
  368. union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
  369. union/remote/_client/auth/_keyring.py +154 -0
  370. union/remote/_client/auth/_token_client.py +258 -0
  371. union/remote/_client/auth/errors.py +16 -0
  372. union/remote/_client/controlplane.py +86 -0
  373. union/remote/_data.py +149 -0
  374. union/remote/_logs.py +74 -0
  375. union/remote/_project.py +86 -0
  376. union/remote/_run.py +820 -0
  377. union/remote/_secret.py +132 -0
  378. union/remote/_task.py +193 -0
  379. union/report/__init__.py +3 -0
  380. union/report/_report.py +178 -0
  381. union/report/_template.html +124 -0
  382. union/storage/__init__.py +24 -0
  383. union/storage/_remote_fs.py +34 -0
  384. union/storage/_storage.py +247 -0
  385. union/storage/_utils.py +5 -0
  386. union/types/__init__.py +11 -0
  387. union/types/_renderer.py +162 -0
  388. union/types/_string_literals.py +120 -0
  389. union/types/_type_engine.py +2131 -0
  390. union/types/_utils.py +80 -0
union/_initialize.py ADDED
@@ -0,0 +1,585 @@
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import functools
5
+ import os
6
+ import threading
7
+ import typing
8
+ from abc import abstractmethod
9
+ from dataclasses import dataclass, replace
10
+ from datetime import timedelta
11
+ from pathlib import Path
12
+ from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Literal, Optional, TypeVar
13
+
14
+ from union.errors import InitializationError
15
+
16
+ from ._api_commons import syncer
17
+ from ._logging import initialize_logger
18
+ from ._tools import ipython_check
19
+
20
+ if TYPE_CHECKING:
21
+ from union.remote._client.auth import AuthType, ClientConfig
22
+ from union.remote._client.controlplane import ClientSet
23
+
24
# Execution mode: "local" runs executions locally, "remote" sends them to a remote server.
Mode = typing.Literal["local", "remote"]
25
+
26
+
27
def set_if_exists(d: dict, k: str, val: typing.Any) -> dict:
    """
    Store ``val`` under key ``k`` in ``d`` when ``val`` counts as "set", then return ``d``.

    Booleans always count as set, so an explicit ``False`` is kept. Any other value counts as set
    only when it is truthy, so ``None``, ``""``, ``0`` and empty containers are skipped. ``d`` is
    mutated in place, and the same object is returned so calls can be chained.
    """
    if not isinstance(val, bool) and not val:
        # Not a bool and falsy (None, "", 0, empty container): leave ``d`` untouched.
        return d
    d[k] = val
    return d
36
+
37
+
38
@dataclass(init=True, repr=True, eq=True, frozen=True)
class Storage(object):
    """
    Data storage configuration that applies across any provider.

    Provider-specific subclasses add their own fields and implement ``auto`` and
    ``get_fsspec_kwargs``.

    NOTE: this class does not derive from ``abc.ABC``. The ``@abstractmethod`` markers below are
    therefore not enforced at instantiation, and the base implementations simply return ``None``.
    """

    retries: int = 3
    backoff: datetime.timedelta = datetime.timedelta(seconds=5)
    enable_debug: bool = False
    attach_execution_metadata: bool = True

    # Field name -> environment variable read by ``_auto_as_kwargs``.
    _KEY_ENV_VAR_MAPPING: ClassVar[typing.Dict[str, str]] = {
        "enable_debug": "UNION_STORAGE_DEBUG",
        "retries": "UNION_STORAGE_RETRIES",
        "backoff": "UNION_STORAGE_BACKOFF_SECONDS",
    }

    @abstractmethod
    def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
        """
        Returns the configuration as kwargs for constructing an fsspec filesystem.
        """
        # TODO: Get attach_execution_metadata from config

    @classmethod
    def _auto_as_kwargs(cls) -> Dict[str, Any]:
        """
        Read the provider-independent settings from environment variables.

        Only variables set to a non-blank value are returned. Each value is converted to its
        field's type so the result can be passed straight to the dataclass constructor:

        * ``enable_debug`` -> ``bool`` ("1", "true", "yes" or "on", case-insensitive, mean True)
        * ``retries`` -> ``int``
        * ``backoff`` -> ``datetime.timedelta`` (the variable holds a number of seconds)

        :return: mapping of field name to parsed value, for the variables that are set.
        :raises ValueError: if ``retries`` or ``backoff`` cannot be parsed as a number.
        """
        mapping = cls._KEY_ENV_VAR_MAPPING

        def _read(field_name: str) -> str:
            # Whitespace-only values are treated the same as unset ones.
            return (os.getenv(mapping[field_name]) or "").strip()

        kwargs: Dict[str, Any] = {}

        raw_debug = _read("enable_debug")
        if raw_debug:
            # Parse explicitly: storing the raw string would make "false" truthy.
            kwargs["enable_debug"] = raw_debug.lower() in ("1", "true", "yes", "on")

        # Parsed values are assigned directly rather than through ``set_if_exists``. Otherwise an
        # explicit "0" (e.g. to disable retries) would be dropped as falsy.
        raw_retries = _read("retries")
        if raw_retries:
            try:
                kwargs["retries"] = int(raw_retries)
            except ValueError as e:
                raise ValueError(
                    f"Environment variable {mapping['retries']} must be an integer, got {raw_retries!r}"
                ) from e

        raw_backoff = _read("backoff")
        if raw_backoff:
            try:
                kwargs["backoff"] = datetime.timedelta(seconds=float(raw_backoff))
            except (ValueError, OverflowError) as e:
                raise ValueError(
                    f"Environment variable {mapping['backoff']} must be a number of seconds, got {raw_backoff!r}"
                ) from e

        return kwargs

    @classmethod
    @abstractmethod
    def auto(cls) -> Storage:
        """
        Construct the config object automatically from environment variables.
        """
        ...
81
+
82
+
83
@dataclass(init=True, repr=True, eq=True, frozen=True)
class S3(Storage):
    """
    S3 specific configuration.

    The endpoint and credentials may be given explicitly, or picked up from the ``FLYTE_AWS_*``
    environment variables via :meth:`auto`.
    """

    endpoint: typing.Optional[str] = None
    access_key_id: typing.Optional[str] = None
    secret_access_key: typing.Optional[str] = None

    _KEY_ENV_VAR_MAPPING: ClassVar[typing.Dict[str, str]] = {
        "endpoint": "FLYTE_AWS_ENDPOINT",
        "access_key_id": "FLYTE_AWS_ACCESS_KEY_ID",
        "secret_access_key": "FLYTE_AWS_SECRET_ACCESS_KEY",
    } | Storage._KEY_ENV_VAR_MAPPING

    # Refer to https://github.com/developmentseed/obstore/blob/33654fc37f19a657689eb93327b621e9f9e01494/obstore/python/obstore/store/_aws.pyi#L11
    # for key and secret
    _CONFIG_KEY_FSSPEC_S3_KEY_ID: ClassVar = "access_key_id"
    _CONFIG_KEY_FSSPEC_S3_SECRET: ClassVar = "secret_access_key"
    _CONFIG_KEY_ENDPOINT: ClassVar = "endpoint_url"
    _KEY_SKIP_SIGNATURE: ClassVar = "skip_signature"

    @classmethod
    def auto(cls) -> S3:
        """
        Build an S3 config from environment variables.

        Reads ``FLYTE_AWS_ENDPOINT``, ``FLYTE_AWS_ACCESS_KEY_ID`` and ``FLYTE_AWS_SECRET_ACCESS_KEY``,
        plus the common settings returned by ``_auto_as_kwargs``. Unset values keep the dataclass
        defaults.

        :return: Config
        """
        kwargs = super()._auto_as_kwargs()
        for field_name in ("endpoint", "access_key_id", "secret_access_key"):
            kwargs = set_if_exists(kwargs, field_name, os.getenv(cls._KEY_ENV_VAR_MAPPING[field_name]))
        # ``cls`` (not a hard-coded ``S3``) so subclasses get instances of their own type.
        return cls(**kwargs)

    @classmethod
    def for_sandbox(cls) -> S3:
        """
        Config for a local sandbox object store at ``http://localhost:4566``, using the
        ``minio`` / ``miniostorage`` credentials.

        :return: Config
        """
        return cls(
            endpoint="http://localhost:4566",
            access_key_id="minio",
            secret_access_key="miniostorage",
        )

    def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
        """
        Build the keyword arguments for an (obstore-style) fsspec S3 filesystem.

        Entries for ``access_key_id``, ``secret_access_key``, ``endpoint_url``, ``retries`` and
        ``backoff`` in ``kwargs`` override the values stored on this config. They are removed from
        the top level of the result; all other ``kwargs`` pass through unchanged.

        :param anonymous: when True, set ``skip_signature`` so requests are not signed.
        :return: ``kwargs`` plus a ``config`` entry (only when non-empty), and always
            ``client_options`` and ``retry_config`` entries.
        """
        config: Dict[str, Any] = {}
        if self._CONFIG_KEY_FSSPEC_S3_KEY_ID in kwargs or self.access_key_id:
            config[self._CONFIG_KEY_FSSPEC_S3_KEY_ID] = kwargs.pop(
                self._CONFIG_KEY_FSSPEC_S3_KEY_ID, self.access_key_id
            )
        if self._CONFIG_KEY_FSSPEC_S3_SECRET in kwargs or self.secret_access_key:
            config[self._CONFIG_KEY_FSSPEC_S3_SECRET] = kwargs.pop(
                self._CONFIG_KEY_FSSPEC_S3_SECRET, self.secret_access_key
            )
        if self._CONFIG_KEY_ENDPOINT in kwargs or self.endpoint:
            config[self._CONFIG_KEY_ENDPOINT] = kwargs.pop(self._CONFIG_KEY_ENDPOINT, self.endpoint)

        retries = kwargs.pop("retries", self.retries)
        backoff = kwargs.pop("backoff", self.backoff)

        if anonymous:
            config[self._KEY_SKIP_SIGNATURE] = True

        if config:
            kwargs["config"] = config
        # Both dicts below are always non-empty, so they are attached unconditionally.
        kwargs["client_options"] = {"timeout": "99999s", "allow_http": True}
        kwargs["retry_config"] = {
            "max_retries": retries,
            "backoff": {
                "base": 2,
                "init_backoff": backoff,
                "max_backoff": timedelta(seconds=16),
            },
            "retry_timeout": timedelta(minutes=3),
        }

        return kwargs
172
+
173
+
174
@dataclass(init=True, repr=True, eq=True, frozen=True)
class GCS(Storage):
    """
    Any GCS specific configuration.
    """

    gsutil_parallelism: bool = False

    # Merge in the provider-independent keys. Without them, the inherited ``_auto_as_kwargs``
    # would raise KeyError when looking up "retries" / "backoff" / "enable_debug".
    _KEY_ENV_VAR_MAPPING: ClassVar[dict[str, str]] = {
        "gsutil_parallelism": "GCP_GSUTIL_PARALLELISM",
    } | Storage._KEY_ENV_VAR_MAPPING

    @classmethod
    def auto(cls) -> GCS:
        """
        Build a GCS config from environment variables.

        ``GCP_GSUTIL_PARALLELISM`` is interpreted as a boolean: "1", "true", "yes" or "on"
        (case-insensitive) mean True, and any other non-blank value means False. The common
        settings are read via ``_auto_as_kwargs``, as for the other providers.
        """
        kwargs = super()._auto_as_kwargs()
        raw = (os.getenv(cls._KEY_ENV_VAR_MAPPING["gsutil_parallelism"]) or "").strip()
        if raw:
            # Parse explicitly: storing the raw string would make "false" truthy.
            kwargs["gsutil_parallelism"] = raw.lower() in ("1", "true", "yes", "on")
        return cls(**kwargs)

    def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
        """
        Return ``kwargs`` unchanged. No GCS-specific options are added, and ``anonymous`` is ignored.
        """
        return kwargs
196
+
197
+
198
@dataclass(init=True, repr=True, eq=True, frozen=True)
class ABFS(Storage):
    """
    Any Azure Blob Storage specific configuration.
    """

    account_name: typing.Optional[str] = None
    account_key: typing.Optional[str] = None
    tenant_id: typing.Optional[str] = None
    client_id: typing.Optional[str] = None
    client_secret: typing.Optional[str] = None

    # Merge in the provider-independent keys. Without them, the inherited ``_auto_as_kwargs``
    # would raise KeyError when looking up "retries" / "backoff" / "enable_debug".
    _KEY_ENV_VAR_MAPPING: ClassVar[dict[str, str]] = {
        "account_name": "AZURE_STORAGE_ACCOUNT_NAME",
        "account_key": "AZURE_STORAGE_ACCOUNT_KEY",
        "tenant_id": "AZURE_TENANT_ID",
        "client_id": "AZURE_CLIENT_ID",
        "client_secret": "AZURE_CLIENT_SECRET",
    } | Storage._KEY_ENV_VAR_MAPPING
    _KEY_SKIP_SIGNATURE: ClassVar = "skip_signature"

    @classmethod
    def auto(cls) -> ABFS:
        """
        Build an Azure Blob Storage config from the ``AZURE_*`` environment variables, plus the
        common settings read via ``_auto_as_kwargs``. Unset values keep the dataclass defaults.
        """
        kwargs = super()._auto_as_kwargs()
        for field_name in ("account_name", "account_key", "tenant_id", "client_id", "client_secret"):
            kwargs = set_if_exists(kwargs, field_name, os.getenv(cls._KEY_ENV_VAR_MAPPING[field_name]))
        # ``cls`` (not a hard-coded ``ABFS``) so subclasses get instances of their own type.
        return cls(**kwargs)

    def get_fsspec_kwargs(self, anonymous: bool = False, /, **kwargs) -> Dict[str, Any]:
        """
        Build the keyword arguments for an Azure Blob Storage fsspec filesystem.

        Credential entries in ``kwargs`` take precedence over the values stored on this config.

        :param anonymous: when True, set ``skip_signature`` in the config.
        :return: ``kwargs`` plus a ``config`` entry (only when non-empty) and ``client_options``.
        """
        config: Dict[str, Any] = {}
        for key in ("account_name", "account_key", "client_id", "client_secret", "tenant_id"):
            stored = getattr(self, key)
            if key in kwargs or stored:
                # NOTE(review): read with ``get`` (S3 uses ``pop``), so any caller-supplied value
                # also remains at the top level of the returned kwargs. Confirm the consumer
                # tolerates the duplicate.
                config[key] = kwargs.get(key, stored)

        if anonymous:
            config[self._KEY_SKIP_SIGNATURE] = True

        client_options = {"timeout": "99999s", "allow_http": "true"}

        if config:
            kwargs["config"] = config
        kwargs["client_options"] = client_options

        return kwargs
258
+
259
+
260
+ @dataclass(init=True, repr=True, eq=True, frozen=True, kw_only=True)
261
+ class CommonInit:
262
+ """
263
+ Common initialization configuration for Union.
264
+ """
265
+
266
+ root_dir: Path
267
+ org: str | None = None
268
+ project: str | None = None
269
+ domain: str | None = None
270
+
271
+
272
@dataclass(frozen=True, kw_only=True)
class _InitConfig(CommonInit):
    """
    Full initialization configuration: the common settings, plus an optional client set and an
    optional storage configuration.
    """

    client: Optional[ClientSet] = None
    storage: Optional[Storage] = None

    def replace(self, **kwargs) -> _InitConfig:
        """
        Return a copy of this (frozen) config with the given fields overridden.
        The current instance is left unchanged.
        """
        updated = replace(self, **kwargs)  # dataclasses.replace, imported at module level
        return updated
279
+
280
+
281
+ # Global singleton to store initialization configuration
282
+ _init_config: _InitConfig | None = None
283
+ _init_lock = threading.RLock() # Reentrant lock for thread safety
284
+
285
+
286
async def _initialize_client(
    api_key: str | None = None,
    auth_type: AuthType = "Pkce",
    endpoint: str | None = None,
    client_config: ClientConfig | None = None,
    headless: bool = False,
    insecure: bool = False,
    insecure_skip_verify: bool = False,
    ca_cert_file_path: str | None = None,
    command: List[str] | None = None,
    proxy_command: List[str] | None = None,
    client_id: str | None = None,
    client_credentials_secret: str | None = None,
    rpc_retries: int = 3,
    http_proxy_url: str | None = None,
) -> ClientSet:
    """
    Create the control-plane client set for ``endpoint``.

    Every other argument is forwarded unchanged to ``ClientSet.for_endpoint``.

    :param endpoint: API endpoint to connect to; currently required.
    :return: The client set returned by ``ClientSet.for_endpoint``.
    :raises NotImplementedError: if ``endpoint`` is None, because only endpoint-based clients
        are supported.
    """
    # Imported at call time; at module level ClientSet is only imported under TYPE_CHECKING.
    from union.remote._client.controlplane import ClientSet

    if endpoint is None:
        raise NotImplementedError("Currently only endpoints are supported.")

    return await ClientSet.for_endpoint(
        endpoint,
        api_key=api_key,
        auth_type=auth_type,
        client_config=client_config,
        headless=headless,
        insecure=insecure,
        insecure_skip_verify=insecure_skip_verify,
        ca_cert_file_path=ca_cert_file_path,
        command=command,
        proxy_command=proxy_command,
        client_id=client_id,
        client_credentials_secret=client_credentials_secret,
        rpc_retries=rpc_retries,
        http_proxy_url=http_proxy_url,
    )
326
+
327
+
328
@syncer.wrap
async def init(
    org: str | None = None,
    project: str | None = None,
    domain: str | None = None,
    root_dir: Path | None = None,
    log_level: int | None = None,
    endpoint: str | None = None,
    headless: bool = False,
    insecure: bool = False,
    insecure_skip_verify: bool = False,
    ca_cert_file_path: str | None = None,
    auth_type: AuthType = "Pkce",
    command: List[str] | None = None,
    proxy_command: List[str] | None = None,
    api_key: str | None = None,
    client_id: str | None = None,
    client_credentials_secret: str | None = None,
    auth_client_config: ClientConfig | None = None,
    rpc_retries: int = 3,
    http_proxy_url: str | None = None,
    storage: Storage | None = None,
) -> None:
    """
    Initialize the Union system with the given configuration. This method should be called before any other Union
    remote API methods are called. Thread-safe implementation.

    A remote client is created only when ``endpoint`` or ``api_key`` is given; otherwise the stored client is
    ``None``. Each call replaces the previously stored configuration.

    :param org: Optional organization override for the client. Should be set by auth instead.
    :param project: Optional default project name, stored in the initialization configuration
    :param domain: Optional default domain name, stored in the initialization configuration
    :param root_dir: Optional root directory for the Union system; the current working directory is used if omitted
    :param log_level: Optional logging level for the logger. When set, the logger is re-initialized at this level
        after the default initialization.
    :param endpoint: Optional API endpoint URL. Required for creating a remote client.
    :param headless: Optional Whether to run in headless mode
    :param insecure: insecure flag for the client
    :param insecure_skip_verify: Whether to skip SSL certificate verification
    :param ca_cert_file_path: [optional] str Root Cert to be loaded and used to verify admin
    :param auth_type: The authentication type to use (Pkce, ClientSecret, ExternalCommand, DeviceFlow).
        Defaults to "Pkce".
    :param command: This command is executed to return a token using an external process
    :param proxy_command: This command is executed to return a token for proxy authorization using an external process
    :param api_key: Optional API key for authentication. Note that an ``endpoint`` is currently also required:
        passing only ``api_key`` makes client initialization raise ``NotImplementedError``.
    :param client_id: This is the public identifier for the app which handles authorization for a Flyte deployment.
        More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/.
    :param client_credentials_secret: Used for service auth, which is automatically called during pyflyte. This will
        allow the Flyte engine to read the password directly from the environment variable. Note that this is
        less secure! Please only use this if mounting the secret as a file is impossible
    :param auth_client_config: Optional client configuration for authentication (forwarded as ``client_config``)
    :param rpc_retries: [optional] int Number of times to retry the platform calls (default 3)
    :param http_proxy_url: [optional] HTTP Proxy to be used for OAuth requests
    :param storage: Optional blob store (S3, GCS, Azure) configuration if needed to access (i.e. using Minio)

    :return: None
    """
    interactive_mode = ipython_check()

    initialize_logger(enable_rich=interactive_mode)
    if log_level:
        initialize_logger(log_level=log_level, enable_rich=interactive_mode)

    global _init_config  # noqa: PLW0603

    # Build the client *before* taking the lock. ``_init_lock`` is used with a plain ``with`` (not ``async with``),
    # so it is a synchronous lock: holding it across this ``await`` would block every caller of the lock-guarded
    # getters (e.g. ``_get_init_config``) for the whole client handshake, and if it is a non-reentrant lock, such a
    # call from another coroutine on this same thread would deadlock. Creating the client does not touch
    # ``_init_config``, so only the final assignment needs the lock.
    client = None
    if endpoint or api_key:
        client = await _initialize_client(
            api_key=api_key,
            auth_type=auth_type,
            endpoint=endpoint,
            headless=headless,
            insecure=insecure,
            insecure_skip_verify=insecure_skip_verify,
            ca_cert_file_path=ca_cert_file_path,
            command=command,
            proxy_command=proxy_command,
            client_id=client_id,
            client_credentials_secret=client_credentials_secret,
            client_config=auth_client_config,
            rpc_retries=rpc_retries,
            http_proxy_url=http_proxy_url,
        )

    with _init_lock:
        _init_config = _InitConfig(
            root_dir=root_dir or Path.cwd(),
            project=project,
            domain=domain,
            client=client,
            storage=storage,
            org=org,
        )
421
+
422
+
423
def _get_init_config() -> _InitConfig | None:
    """
    Return the most recently stored initialization configuration, read while holding ``_init_lock``.

    :return: The stored ``_InitConfig``, or ``None`` when nothing has been configured
    """
    with _init_lock:
        current = _init_config
    return current
431
+
432
+
433
def get_common_config() -> Optional[CommonInit]:
    """
    Return the current initialization configuration, viewed as its common (shared) settings. Thread-safe.

    :return: The stored configuration if the system is initialized, ``None`` otherwise
    """
    # Same locked read as the private accessor; the configuration object is returned unchanged.
    return _get_init_config()
441
+
442
+
443
def get_storage() -> Optional[Storage]:
    """
    Return the blob-store configuration held by the current initialization. Thread-safe.

    :return: The stored storage configuration, or ``None`` if the system is not initialized
        (the stored value itself may also be ``None``)
    """
    config = _get_init_config()
    return None if config is None else config.storage
453
+
454
+
455
def get_client() -> Optional[ClientSet]:
    """
    Return the remote client held by the current initialization. Thread-safe.

    :return: The stored client, or ``None`` if the system is not initialized. The stored client is itself
        ``None`` when ``init`` was called without an endpoint or API key.
    """
    config = _get_init_config()
    return None if config is None else config.client
465
+
466
+
467
def is_initialized() -> bool:
    """
    Report whether an initialization configuration is currently stored.

    :return: ``True`` once a configuration has been stored, ``False`` otherwise
    """
    config = _get_init_config()
    return config is not None
474
+
475
+
476
def initialize_in_cluster(storage: Storage | None = None) -> None:
    """
    Initialize the system for in-cluster execution.

    This is not a no-op: it calls ``init`` with only ``storage`` set, which replaces any previously stored
    configuration. Because no endpoint or API key is passed, no remote client is created; project, domain and
    org are left as ``None``, and the root directory defaults to the current working directory.

    :param storage: Optional blob store configuration to store in the initialization configuration
    :return: None
    """
    init(storage=storage)
483
+
484
+
485
# Generic type variable bound to any callable. The decorators in this module are annotated ``(func: T) -> T``,
# so type checkers see the decorated function with the same type as the original.
T = TypeVar("T", bound=Callable)
487
+
488
+
489
def requires_client(func: T) -> T:
    """
    Decorator that checks if the client has been initialized before executing the function.
    Raises InitializationError if the client is not initialized.

    The check runs each time the decorated function is called, not at decoration time.

    :param func: Function to decorate
    :return: Decorated function that checks for initialization
    :raises InitializationError: (from the wrapper) when no configuration is stored, or the stored configuration
        has no client (``init`` was called without an endpoint or API key).
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Read the configuration once: a single lock acquisition, and the None check and the ``.client`` access
        # are guaranteed to look at the same snapshot even if another thread re-initializes concurrently.
        cfg = _get_init_config()
        if cfg is None or cfg.client is None:
            raise InitializationError(
                "ClientNotInitializedError",
                "user",
                f"Function '{func.__name__}' requires client to be initialized. "
                "Call union.init() with a valid endpoint or api-key before using this function.",
            )
        return func(*args, **kwargs)

    return wrapper
510
+
511
+
512
def requires_upload_location(func: T) -> T:
    """
    Decorator that checks if an upload location is configured before executing the function.
    Raises InitializationError if it is not.

    Each time the decorated function is called, the wrapper reads ``raw_data`` from the current internal
    context and fails when that value is falsy.

    :param func: Function to decorate
    :return: Decorated function that checks for a configured upload location
    :raises InitializationError: (from the wrapper) when the current context has no ``raw_data`` configured.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Imported lazily at call time; presumably to avoid a circular import with ``_context`` — TODO confirm.
        from ._context import internal_ctx

        ctx = internal_ctx()
        if not ctx.raw_data:
            raise InitializationError(
                "No upload path configured",
                "user",
                # The check is about the upload location (storage), not the client; the message says so.
                f"Function '{func.__name__}' requires an upload location to be configured. "
                "Call union.init() with storage configuration before using this function.",
            )
        return func(*args, **kwargs)

    return wrapper
536
+
537
+
538
def requires_initialization(func: T) -> T:
    """
    Guard a function so that calling it before the system is initialized fails fast.

    :param func: The function to protect
    :return: A ``functools.wraps``-preserving wrapper that delegates to ``func`` when ``is_initialized()`` is
        true and raises ``InitializationError`` otherwise. The check happens on every call.
    """

    @functools.wraps(func)
    def guarded(*args, **kwargs):
        if is_initialized():
            return func(*args, **kwargs)
        raise InitializationError(
            "NotInitConfiguredError",
            "user",
            f"Function '{func.__name__}' requires initialization. Call union.init() before using this function.",
        )

    return guarded
558
+
559
+
560
async def _init_for_testing(
    project: str | None = None,
    domain: str | None = None,
    root_dir: Path | None = None,
    log_level: int | None = None,
    client: ClientSet | None = None,
):
    """
    Test helper: store an initialization configuration directly, without building a client.

    Unlike ``init``, the given ``client`` (possibly ``None``) is stored as-is, and no storage or org is passed
    to ``_InitConfig``.

    :param project: Project to store in the configuration
    :param domain: Domain to store in the configuration
    :param root_dir: Root directory; the current working directory is used when omitted
    :param log_level: When set, the logger is initialized at this level
    :param client: Pre-built client to store
    """
    global _init_config  # noqa: PLW0603

    if log_level:
        initialize_logger(log_level=log_level)

    effective_root = root_dir or Path.cwd()
    with _init_lock:
        _init_config = _InitConfig(
            root_dir=effective_root,
            project=project,
            domain=domain,
            client=client,
        )
579
+
580
+
581
def replace_client(client):
    """
    Swap the client stored in the current initialization configuration.

    The new configuration is produced by ``_InitConfig.replace(client=...)``, presumably keeping every other
    field unchanged — TODO confirm against ``_InitConfig``.

    :param client: The new client to store
    :raises InitializationError: if no configuration is stored yet, since there is no client to replace.
    """
    global _init_config  # noqa: PLW0603

    with _init_lock:
        if _init_config is None:
            # Without this guard the failure surfaces as an opaque AttributeError on ``None.replace``.
            raise InitializationError(
                "NotInitConfiguredError",
                "user",
                "Function 'replace_client' requires initialization. Call union.init() before using this function.",
            )
        _init_config = _init_config.replace(client=client)
union/_interface.py ADDED
@@ -0,0 +1,84 @@
1
+ from __future__ import annotations
2
+
3
+ import inspect
4
+ from typing import Dict, Generator, Tuple, Type, TypeVar, Union, cast, get_args, get_type_hints
5
+
6
+ from union._logging import logger
7
+
8
+
9
def default_output_name(index: int = 0) -> str:
    """Return the auto-generated name for the output at position *index*: ``"o0"``, ``"o1"``, and so on."""
    return "o{}".format(index)
11
+
12
+
13
def output_name_generator(length: int) -> Generator[str, None, None]:
    """Lazily yield the default output names ``o0`` through ``o{length - 1}`` in order."""
    yield from map(default_output_name, range(length))
16
+
17
+
18
def extract_return_annotation(return_annotation: Union[Type, Tuple, None]) -> Dict[str, Type]:
    """
    The input to this function should be sig.return_annotation where sig = inspect.signature(some_func)
    The purpose of this function is to sort out whether a function is returning one thing, or multiple things, and to
    name the outputs accordingly, either by using our default name function, or from a typing.NamedTuple.

        # Option 1
        nt1 = typing.NamedTuple("NT1", x_str=str, y_int=int)
        def t(a: int, b: str) -> nt1: ...

        # Option 2
        def t(a: int, b: str) -> typing.NamedTuple("NT1", x_str=str, y_int=int): ...

        # Option 3
        def t(a: int, b: str) -> typing.Tuple[int, str]: ...

        # Option 4
        def t(a: int, b: str) -> (int, str): ...

        # Option 5
        def t(a: int, b: str) -> str: ...

        # Option 6
        def t(a: int, b: str) -> None: ...

        # Options 7/8
        def t(a: int, b: str) -> List[int]: ...
        def t(a: int, b: str) -> Dict[str, int]: ...

    Note that Options 1 and 2 are identical, just syntactic sugar. In the NamedTuple case, we'll use the names in the
    definition. In all other cases, we'll automatically generate output names, indexed starting at 0.

    :param return_annotation: The return annotation to inspect
    :return: Mapping from output name to output type; empty when there is no return value (Option 6)
    :raises TypeError: for a one-element tuple (typing or native), or for a variable-length ``Tuple[T, ...]``,
        since neither describes a fixed set of multiple return values.
    """

    # Handle Option 6
    # We can think about whether we should add a default output name with type None in the future.
    if return_annotation in (None, type(None), inspect.Signature.empty):
        return {}

    # This statement results in true for typing.Namedtuple, single and void return types, so this
    # handles Options 1, 2. Even though NamedTuple for us is multi-valued, it's a single value for Python
    if hasattr(return_annotation, "__bases__") and isinstance(return_annotation, (type, TypeVar)):
        # isinstance / issubclass does not work for Namedtuple.
        # Options 1 and 2
        bases = return_annotation.__bases__  # type: ignore
        if len(bases) == 1 and bases[0] is tuple and hasattr(return_annotation, "_fields"):
            logger.debug(f"Task returns named tuple {return_annotation}")
            return dict(get_type_hints(cast(Type, return_annotation), include_extras=True))

    if hasattr(return_annotation, "__origin__") and return_annotation.__origin__ is tuple:  # type: ignore
        # Handle option 3
        if len(return_annotation.__args__) == 1:  # type: ignore
            raise TypeError("Tuples should be used to indicate multiple return values, found only one return variable.")
        ra = get_args(return_annotation)
        # ``Tuple[T, ...]`` has args ``(T, Ellipsis)``; splitting it into ``o0: T`` and ``o1: Ellipsis`` would
        # invent a bogus second output, so reject it explicitly.
        if len(ra) == 2 and ra[1] is Ellipsis:
            raise TypeError(
                f"Variable-length tuples such as {return_annotation} cannot be used to declare multiple return "
                "values; list each output type explicitly, e.g. Tuple[int, str]."
            )
        logger.debug(f"Task returns unnamed typing.Tuple {return_annotation}")
        return dict(zip(output_name_generator(len(ra)), ra))

    elif isinstance(return_annotation, tuple):
        # Handle option 4
        if len(return_annotation) == 1:
            raise TypeError("Please don't use a tuple if you're just returning one thing.")
        logger.debug(f"Task returns unnamed native tuple {return_annotation}")
        return dict(zip(output_name_generator(len(return_annotation)), return_annotation))

    else:
        # Handle all other single return types (Options 5, 7, 8)
        logger.debug(f"Task returns unnamed single value {return_annotation}")
        return {default_output_name(): cast(Type, return_annotation)}
@@ -0,0 +1,3 @@
1
+ from .controllers import Controller, ControllerType, create_controller
2
+
3
# Public names re-exported from the ``controllers`` subpackage.
__all__ = ["Controller", "ControllerType", "create_controller"]