flyte 0.0.1b0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (390)
  1. flyte/__init__.py +62 -0
  2. flyte/_api_commons.py +3 -0
  3. flyte/_bin/__init__.py +0 -0
  4. flyte/_bin/runtime.py +126 -0
  5. flyte/_build.py +25 -0
  6. flyte/_cache/__init__.py +12 -0
  7. flyte/_cache/cache.py +146 -0
  8. flyte/_cache/defaults.py +9 -0
  9. flyte/_cache/policy_function_body.py +42 -0
  10. flyte/_cli/__init__.py +0 -0
  11. flyte/_cli/_common.py +287 -0
  12. flyte/_cli/_create.py +42 -0
  13. flyte/_cli/_delete.py +23 -0
  14. flyte/_cli/_deploy.py +140 -0
  15. flyte/_cli/_get.py +235 -0
  16. flyte/_cli/_run.py +152 -0
  17. flyte/_cli/main.py +72 -0
  18. flyte/_code_bundle/__init__.py +8 -0
  19. flyte/_code_bundle/_ignore.py +113 -0
  20. flyte/_code_bundle/_packaging.py +187 -0
  21. flyte/_code_bundle/_utils.py +339 -0
  22. flyte/_code_bundle/bundle.py +178 -0
  23. flyte/_context.py +146 -0
  24. flyte/_datastructures.py +342 -0
  25. flyte/_deploy.py +202 -0
  26. flyte/_doc.py +29 -0
  27. flyte/_docstring.py +32 -0
  28. flyte/_environment.py +43 -0
  29. flyte/_group.py +31 -0
  30. flyte/_hash.py +23 -0
  31. flyte/_image.py +760 -0
  32. flyte/_initialize.py +634 -0
  33. flyte/_interface.py +84 -0
  34. flyte/_internal/__init__.py +3 -0
  35. flyte/_internal/controllers/__init__.py +115 -0
  36. flyte/_internal/controllers/_local_controller.py +118 -0
  37. flyte/_internal/controllers/_trace.py +40 -0
  38. flyte/_internal/controllers/pbhash.py +39 -0
  39. flyte/_internal/controllers/remote/__init__.py +40 -0
  40. flyte/_internal/controllers/remote/_action.py +141 -0
  41. flyte/_internal/controllers/remote/_client.py +43 -0
  42. flyte/_internal/controllers/remote/_controller.py +361 -0
  43. flyte/_internal/controllers/remote/_core.py +402 -0
  44. flyte/_internal/controllers/remote/_informer.py +361 -0
  45. flyte/_internal/controllers/remote/_service_protocol.py +50 -0
  46. flyte/_internal/imagebuild/__init__.py +11 -0
  47. flyte/_internal/imagebuild/docker_builder.py +416 -0
  48. flyte/_internal/imagebuild/image_builder.py +241 -0
  49. flyte/_internal/imagebuild/remote_builder.py +0 -0
  50. flyte/_internal/resolvers/__init__.py +0 -0
  51. flyte/_internal/resolvers/_task_module.py +54 -0
  52. flyte/_internal/resolvers/common.py +31 -0
  53. flyte/_internal/resolvers/default.py +28 -0
  54. flyte/_internal/runtime/__init__.py +0 -0
  55. flyte/_internal/runtime/convert.py +199 -0
  56. flyte/_internal/runtime/entrypoints.py +135 -0
  57. flyte/_internal/runtime/io.py +136 -0
  58. flyte/_internal/runtime/resources_serde.py +138 -0
  59. flyte/_internal/runtime/task_serde.py +210 -0
  60. flyte/_internal/runtime/taskrunner.py +190 -0
  61. flyte/_internal/runtime/types_serde.py +54 -0
  62. flyte/_logging.py +124 -0
  63. flyte/_protos/__init__.py +0 -0
  64. flyte/_protos/common/authorization_pb2.py +66 -0
  65. flyte/_protos/common/authorization_pb2.pyi +108 -0
  66. flyte/_protos/common/authorization_pb2_grpc.py +4 -0
  67. flyte/_protos/common/identifier_pb2.py +71 -0
  68. flyte/_protos/common/identifier_pb2.pyi +82 -0
  69. flyte/_protos/common/identifier_pb2_grpc.py +4 -0
  70. flyte/_protos/common/identity_pb2.py +48 -0
  71. flyte/_protos/common/identity_pb2.pyi +72 -0
  72. flyte/_protos/common/identity_pb2_grpc.py +4 -0
  73. flyte/_protos/common/list_pb2.py +36 -0
  74. flyte/_protos/common/list_pb2.pyi +69 -0
  75. flyte/_protos/common/list_pb2_grpc.py +4 -0
  76. flyte/_protos/common/policy_pb2.py +37 -0
  77. flyte/_protos/common/policy_pb2.pyi +27 -0
  78. flyte/_protos/common/policy_pb2_grpc.py +4 -0
  79. flyte/_protos/common/role_pb2.py +37 -0
  80. flyte/_protos/common/role_pb2.pyi +53 -0
  81. flyte/_protos/common/role_pb2_grpc.py +4 -0
  82. flyte/_protos/common/runtime_version_pb2.py +28 -0
  83. flyte/_protos/common/runtime_version_pb2.pyi +24 -0
  84. flyte/_protos/common/runtime_version_pb2_grpc.py +4 -0
  85. flyte/_protos/logs/dataplane/payload_pb2.py +96 -0
  86. flyte/_protos/logs/dataplane/payload_pb2.pyi +168 -0
  87. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  88. flyte/_protos/secret/definition_pb2.py +49 -0
  89. flyte/_protos/secret/definition_pb2.pyi +93 -0
  90. flyte/_protos/secret/definition_pb2_grpc.py +4 -0
  91. flyte/_protos/secret/payload_pb2.py +62 -0
  92. flyte/_protos/secret/payload_pb2.pyi +94 -0
  93. flyte/_protos/secret/payload_pb2_grpc.py +4 -0
  94. flyte/_protos/secret/secret_pb2.py +38 -0
  95. flyte/_protos/secret/secret_pb2.pyi +6 -0
  96. flyte/_protos/secret/secret_pb2_grpc.py +198 -0
  97. flyte/_protos/secret/secret_pb2_grpc_grpc.py +198 -0
  98. flyte/_protos/validate/validate/validate_pb2.py +76 -0
  99. flyte/_protos/workflow/node_execution_service_pb2.py +26 -0
  100. flyte/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  101. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  102. flyte/_protos/workflow/queue_service_pb2.py +106 -0
  103. flyte/_protos/workflow/queue_service_pb2.pyi +141 -0
  104. flyte/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  105. flyte/_protos/workflow/run_definition_pb2.py +128 -0
  106. flyte/_protos/workflow/run_definition_pb2.pyi +310 -0
  107. flyte/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  108. flyte/_protos/workflow/run_logs_service_pb2.py +41 -0
  109. flyte/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  110. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  111. flyte/_protos/workflow/run_service_pb2.py +133 -0
  112. flyte/_protos/workflow/run_service_pb2.pyi +175 -0
  113. flyte/_protos/workflow/run_service_pb2_grpc.py +412 -0
  114. flyte/_protos/workflow/state_service_pb2.py +58 -0
  115. flyte/_protos/workflow/state_service_pb2.pyi +71 -0
  116. flyte/_protos/workflow/state_service_pb2_grpc.py +138 -0
  117. flyte/_protos/workflow/task_definition_pb2.py +72 -0
  118. flyte/_protos/workflow/task_definition_pb2.pyi +65 -0
  119. flyte/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  120. flyte/_protos/workflow/task_service_pb2.py +44 -0
  121. flyte/_protos/workflow/task_service_pb2.pyi +31 -0
  122. flyte/_protos/workflow/task_service_pb2_grpc.py +104 -0
  123. flyte/_resources.py +226 -0
  124. flyte/_retry.py +32 -0
  125. flyte/_reusable_environment.py +25 -0
  126. flyte/_run.py +411 -0
  127. flyte/_secret.py +61 -0
  128. flyte/_task.py +367 -0
  129. flyte/_task_environment.py +200 -0
  130. flyte/_timeout.py +47 -0
  131. flyte/_tools.py +27 -0
  132. flyte/_trace.py +128 -0
  133. flyte/_utils/__init__.py +20 -0
  134. flyte/_utils/asyn.py +119 -0
  135. flyte/_utils/coro_management.py +25 -0
  136. flyte/_utils/file_handling.py +72 -0
  137. flyte/_utils/helpers.py +108 -0
  138. flyte/_utils/lazy_module.py +54 -0
  139. flyte/_utils/uv_script_parser.py +49 -0
  140. flyte/_version.py +21 -0
  141. flyte/connectors/__init__.py +0 -0
  142. flyte/errors.py +143 -0
  143. flyte/extras/__init__.py +5 -0
  144. flyte/extras/_container.py +273 -0
  145. flyte/io/__init__.py +11 -0
  146. flyte/io/_dataframe.py +0 -0
  147. flyte/io/_dir.py +448 -0
  148. flyte/io/_file.py +468 -0
  149. flyte/io/pickle/__init__.py +0 -0
  150. flyte/io/pickle/transformer.py +117 -0
  151. flyte/io/structured_dataset/__init__.py +129 -0
  152. flyte/io/structured_dataset/basic_dfs.py +219 -0
  153. flyte/io/structured_dataset/structured_dataset.py +1061 -0
  154. flyte/py.typed +0 -0
  155. flyte/remote/__init__.py +25 -0
  156. flyte/remote/_client/__init__.py +0 -0
  157. flyte/remote/_client/_protocols.py +131 -0
  158. flyte/remote/_client/auth/__init__.py +12 -0
  159. flyte/remote/_client/auth/_authenticators/__init__.py +0 -0
  160. flyte/remote/_client/auth/_authenticators/base.py +397 -0
  161. flyte/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  162. flyte/remote/_client/auth/_authenticators/device_code.py +118 -0
  163. flyte/remote/_client/auth/_authenticators/external_command.py +79 -0
  164. flyte/remote/_client/auth/_authenticators/factory.py +200 -0
  165. flyte/remote/_client/auth/_authenticators/pkce.py +516 -0
  166. flyte/remote/_client/auth/_channel.py +184 -0
  167. flyte/remote/_client/auth/_client_config.py +83 -0
  168. flyte/remote/_client/auth/_default_html.py +32 -0
  169. flyte/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  170. flyte/remote/_client/auth/_grpc_utils/auth_interceptor.py +288 -0
  171. flyte/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +151 -0
  172. flyte/remote/_client/auth/_keyring.py +143 -0
  173. flyte/remote/_client/auth/_token_client.py +260 -0
  174. flyte/remote/_client/auth/errors.py +16 -0
  175. flyte/remote/_client/controlplane.py +95 -0
  176. flyte/remote/_console.py +18 -0
  177. flyte/remote/_data.py +155 -0
  178. flyte/remote/_logs.py +116 -0
  179. flyte/remote/_project.py +86 -0
  180. flyte/remote/_run.py +873 -0
  181. flyte/remote/_secret.py +132 -0
  182. flyte/remote/_task.py +227 -0
  183. flyte/report/__init__.py +3 -0
  184. flyte/report/_report.py +178 -0
  185. flyte/report/_template.html +124 -0
  186. flyte/storage/__init__.py +24 -0
  187. flyte/storage/_remote_fs.py +34 -0
  188. flyte/storage/_storage.py +251 -0
  189. flyte/storage/_utils.py +5 -0
  190. flyte/types/__init__.py +13 -0
  191. flyte/types/_interface.py +25 -0
  192. flyte/types/_renderer.py +162 -0
  193. flyte/types/_string_literals.py +120 -0
  194. flyte/types/_type_engine.py +2210 -0
  195. flyte/types/_utils.py +80 -0
  196. flyte-0.0.1b0.dist-info/METADATA +179 -0
  197. flyte-0.0.1b0.dist-info/RECORD +390 -0
  198. flyte-0.0.1b0.dist-info/WHEEL +5 -0
  199. flyte-0.0.1b0.dist-info/entry_points.txt +3 -0
  200. flyte-0.0.1b0.dist-info/top_level.txt +1 -0
  201. union/__init__.py +54 -0
  202. union/_api_commons.py +3 -0
  203. union/_bin/__init__.py +0 -0
  204. union/_bin/runtime.py +113 -0
  205. union/_build.py +25 -0
  206. union/_cache/__init__.py +12 -0
  207. union/_cache/cache.py +141 -0
  208. union/_cache/defaults.py +9 -0
  209. union/_cache/policy_function_body.py +42 -0
  210. union/_cli/__init__.py +0 -0
  211. union/_cli/_common.py +263 -0
  212. union/_cli/_create.py +40 -0
  213. union/_cli/_delete.py +23 -0
  214. union/_cli/_deploy.py +120 -0
  215. union/_cli/_get.py +162 -0
  216. union/_cli/_params.py +579 -0
  217. union/_cli/_run.py +150 -0
  218. union/_cli/main.py +72 -0
  219. union/_code_bundle/__init__.py +8 -0
  220. union/_code_bundle/_ignore.py +113 -0
  221. union/_code_bundle/_packaging.py +187 -0
  222. union/_code_bundle/_utils.py +342 -0
  223. union/_code_bundle/bundle.py +176 -0
  224. union/_context.py +146 -0
  225. union/_datastructures.py +295 -0
  226. union/_deploy.py +185 -0
  227. union/_doc.py +29 -0
  228. union/_docstring.py +26 -0
  229. union/_environment.py +43 -0
  230. union/_group.py +31 -0
  231. union/_hash.py +23 -0
  232. union/_image.py +760 -0
  233. union/_initialize.py +585 -0
  234. union/_interface.py +84 -0
  235. union/_internal/__init__.py +3 -0
  236. union/_internal/controllers/__init__.py +77 -0
  237. union/_internal/controllers/_local_controller.py +77 -0
  238. union/_internal/controllers/pbhash.py +39 -0
  239. union/_internal/controllers/remote/__init__.py +40 -0
  240. union/_internal/controllers/remote/_action.py +131 -0
  241. union/_internal/controllers/remote/_client.py +43 -0
  242. union/_internal/controllers/remote/_controller.py +169 -0
  243. union/_internal/controllers/remote/_core.py +341 -0
  244. union/_internal/controllers/remote/_informer.py +260 -0
  245. union/_internal/controllers/remote/_service_protocol.py +44 -0
  246. union/_internal/imagebuild/__init__.py +11 -0
  247. union/_internal/imagebuild/docker_builder.py +416 -0
  248. union/_internal/imagebuild/image_builder.py +243 -0
  249. union/_internal/imagebuild/remote_builder.py +0 -0
  250. union/_internal/resolvers/__init__.py +0 -0
  251. union/_internal/resolvers/_task_module.py +31 -0
  252. union/_internal/resolvers/common.py +24 -0
  253. union/_internal/resolvers/default.py +27 -0
  254. union/_internal/runtime/__init__.py +0 -0
  255. union/_internal/runtime/convert.py +163 -0
  256. union/_internal/runtime/entrypoints.py +121 -0
  257. union/_internal/runtime/io.py +136 -0
  258. union/_internal/runtime/resources_serde.py +134 -0
  259. union/_internal/runtime/task_serde.py +202 -0
  260. union/_internal/runtime/taskrunner.py +179 -0
  261. union/_internal/runtime/types_serde.py +53 -0
  262. union/_logging.py +124 -0
  263. union/_protos/__init__.py +0 -0
  264. union/_protos/common/authorization_pb2.py +66 -0
  265. union/_protos/common/authorization_pb2.pyi +106 -0
  266. union/_protos/common/authorization_pb2_grpc.py +4 -0
  267. union/_protos/common/identifier_pb2.py +71 -0
  268. union/_protos/common/identifier_pb2.pyi +82 -0
  269. union/_protos/common/identifier_pb2_grpc.py +4 -0
  270. union/_protos/common/identity_pb2.py +48 -0
  271. union/_protos/common/identity_pb2.pyi +72 -0
  272. union/_protos/common/identity_pb2_grpc.py +4 -0
  273. union/_protos/common/list_pb2.py +36 -0
  274. union/_protos/common/list_pb2.pyi +69 -0
  275. union/_protos/common/list_pb2_grpc.py +4 -0
  276. union/_protos/common/policy_pb2.py +37 -0
  277. union/_protos/common/policy_pb2.pyi +27 -0
  278. union/_protos/common/policy_pb2_grpc.py +4 -0
  279. union/_protos/common/role_pb2.py +37 -0
  280. union/_protos/common/role_pb2.pyi +51 -0
  281. union/_protos/common/role_pb2_grpc.py +4 -0
  282. union/_protos/common/runtime_version_pb2.py +28 -0
  283. union/_protos/common/runtime_version_pb2.pyi +24 -0
  284. union/_protos/common/runtime_version_pb2_grpc.py +4 -0
  285. union/_protos/logs/dataplane/payload_pb2.py +96 -0
  286. union/_protos/logs/dataplane/payload_pb2.pyi +168 -0
  287. union/_protos/logs/dataplane/payload_pb2_grpc.py +4 -0
  288. union/_protos/secret/definition_pb2.py +49 -0
  289. union/_protos/secret/definition_pb2.pyi +93 -0
  290. union/_protos/secret/definition_pb2_grpc.py +4 -0
  291. union/_protos/secret/payload_pb2.py +62 -0
  292. union/_protos/secret/payload_pb2.pyi +94 -0
  293. union/_protos/secret/payload_pb2_grpc.py +4 -0
  294. union/_protos/secret/secret_pb2.py +38 -0
  295. union/_protos/secret/secret_pb2.pyi +6 -0
  296. union/_protos/secret/secret_pb2_grpc.py +198 -0
  297. union/_protos/validate/validate/validate_pb2.py +76 -0
  298. union/_protos/workflow/node_execution_service_pb2.py +26 -0
  299. union/_protos/workflow/node_execution_service_pb2.pyi +4 -0
  300. union/_protos/workflow/node_execution_service_pb2_grpc.py +32 -0
  301. union/_protos/workflow/queue_service_pb2.py +75 -0
  302. union/_protos/workflow/queue_service_pb2.pyi +103 -0
  303. union/_protos/workflow/queue_service_pb2_grpc.py +172 -0
  304. union/_protos/workflow/run_definition_pb2.py +100 -0
  305. union/_protos/workflow/run_definition_pb2.pyi +256 -0
  306. union/_protos/workflow/run_definition_pb2_grpc.py +4 -0
  307. union/_protos/workflow/run_logs_service_pb2.py +41 -0
  308. union/_protos/workflow/run_logs_service_pb2.pyi +28 -0
  309. union/_protos/workflow/run_logs_service_pb2_grpc.py +69 -0
  310. union/_protos/workflow/run_service_pb2.py +133 -0
  311. union/_protos/workflow/run_service_pb2.pyi +173 -0
  312. union/_protos/workflow/run_service_pb2_grpc.py +412 -0
  313. union/_protos/workflow/state_service_pb2.py +58 -0
  314. union/_protos/workflow/state_service_pb2.pyi +69 -0
  315. union/_protos/workflow/state_service_pb2_grpc.py +138 -0
  316. union/_protos/workflow/task_definition_pb2.py +72 -0
  317. union/_protos/workflow/task_definition_pb2.pyi +65 -0
  318. union/_protos/workflow/task_definition_pb2_grpc.py +4 -0
  319. union/_protos/workflow/task_service_pb2.py +44 -0
  320. union/_protos/workflow/task_service_pb2.pyi +31 -0
  321. union/_protos/workflow/task_service_pb2_grpc.py +104 -0
  322. union/_resources.py +226 -0
  323. union/_retry.py +32 -0
  324. union/_reusable_environment.py +25 -0
  325. union/_run.py +374 -0
  326. union/_secret.py +61 -0
  327. union/_task.py +354 -0
  328. union/_task_environment.py +186 -0
  329. union/_timeout.py +47 -0
  330. union/_tools.py +27 -0
  331. union/_utils/__init__.py +11 -0
  332. union/_utils/asyn.py +119 -0
  333. union/_utils/file_handling.py +71 -0
  334. union/_utils/helpers.py +46 -0
  335. union/_utils/lazy_module.py +54 -0
  336. union/_utils/uv_script_parser.py +49 -0
  337. union/_version.py +21 -0
  338. union/connectors/__init__.py +0 -0
  339. union/errors.py +128 -0
  340. union/extras/__init__.py +5 -0
  341. union/extras/_container.py +263 -0
  342. union/io/__init__.py +11 -0
  343. union/io/_dataframe.py +0 -0
  344. union/io/_dir.py +425 -0
  345. union/io/_file.py +418 -0
  346. union/io/pickle/__init__.py +0 -0
  347. union/io/pickle/transformer.py +117 -0
  348. union/io/structured_dataset/__init__.py +122 -0
  349. union/io/structured_dataset/basic_dfs.py +219 -0
  350. union/io/structured_dataset/structured_dataset.py +1057 -0
  351. union/py.typed +0 -0
  352. union/remote/__init__.py +23 -0
  353. union/remote/_client/__init__.py +0 -0
  354. union/remote/_client/_protocols.py +129 -0
  355. union/remote/_client/auth/__init__.py +12 -0
  356. union/remote/_client/auth/_authenticators/__init__.py +0 -0
  357. union/remote/_client/auth/_authenticators/base.py +391 -0
  358. union/remote/_client/auth/_authenticators/client_credentials.py +73 -0
  359. union/remote/_client/auth/_authenticators/device_code.py +120 -0
  360. union/remote/_client/auth/_authenticators/external_command.py +77 -0
  361. union/remote/_client/auth/_authenticators/factory.py +200 -0
  362. union/remote/_client/auth/_authenticators/pkce.py +515 -0
  363. union/remote/_client/auth/_channel.py +184 -0
  364. union/remote/_client/auth/_client_config.py +83 -0
  365. union/remote/_client/auth/_default_html.py +32 -0
  366. union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  367. union/remote/_client/auth/_grpc_utils/auth_interceptor.py +204 -0
  368. union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +144 -0
  369. union/remote/_client/auth/_keyring.py +154 -0
  370. union/remote/_client/auth/_token_client.py +258 -0
  371. union/remote/_client/auth/errors.py +16 -0
  372. union/remote/_client/controlplane.py +86 -0
  373. union/remote/_data.py +149 -0
  374. union/remote/_logs.py +74 -0
  375. union/remote/_project.py +86 -0
  376. union/remote/_run.py +820 -0
  377. union/remote/_secret.py +132 -0
  378. union/remote/_task.py +193 -0
  379. union/report/__init__.py +3 -0
  380. union/report/_report.py +178 -0
  381. union/report/_template.html +124 -0
  382. union/storage/__init__.py +24 -0
  383. union/storage/_remote_fs.py +34 -0
  384. union/storage/_storage.py +247 -0
  385. union/storage/_utils.py +5 -0
  386. union/types/__init__.py +11 -0
  387. union/types/_renderer.py +162 -0
  388. union/types/_string_literals.py +120 -0
  389. union/types/_type_engine.py +2131 -0
  390. union/types/_utils.py +80 -0
@@ -0,0 +1,1057 @@
+ from __future__ import annotations
+
+ import _datetime
+ import collections
+ import types
+ import typing
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass, field, is_dataclass
+ from typing import ClassVar, Dict, Generic, List, Optional, Type, Union
+
+ import msgpack
+ from flyteidl.core import literals_pb2, types_pb2
+ from fsspec.utils import get_protocol
+ from mashumaro.mixins.json import DataClassJSONMixin
+ from mashumaro.types import SerializableType
+ from pydantic import model_serializer, model_validator
+ from typing_extensions import Annotated, TypeAlias, get_args, get_origin
+
+ import union.storage as storage
+ from union._logging import logger
+ from union._utils import lazy_module
+ from union._utils.asyn import loop_manager
+ from union.types import TypeEngine, TypeTransformer, TypeTransformerFailedError
+ from union.types._renderer import Renderable
+ from union.types._type_engine import modify_literal_uris
+
+ MESSAGEPACK = "msgpack"
+ CSV, PARQUET = "csv", "parquet"
+
+
+ if typing.TYPE_CHECKING:
+     import pandas as pd
+     import pyarrow as pa
+ else:
+     pd = lazy_module("pandas")
+     pa = lazy_module("pyarrow")
+
+ T = typing.TypeVar("T")  # StructuredDataset type or a dataframe type
+ DF = typing.TypeVar("DF")  # Dataframe type
+
+ # For specifying the storage formats of StructuredDatasets. It's just a string, nothing fancy.
+ StructuredDatasetFormat: TypeAlias = str
+
+ # Storage formats
+ PARQUET: StructuredDatasetFormat = "parquet"
+ CSV: StructuredDatasetFormat = "csv"
+ GENERIC_FORMAT: StructuredDatasetFormat = ""
+ GENERIC_PROTOCOL: str = "generic protocol"
+
+
+ @dataclass
+ class StructuredDataset(SerializableType, DataClassJSONMixin):
+     """
+     This is the user facing StructuredDataset class. Please don't confuse it with the literals.StructuredDataset
+     class (that is just a model, a Python class representation of the protobuf).
+     """
+
+     uri: typing.Optional[str] = field(default=None)
+     file_format: typing.Optional[str] = field(default=GENERIC_FORMAT)
+
+     # loop manager is working better than synchronicity for some reason, was getting an error but may be an easy fix
+     def _serialize(self) -> Dict[str, Optional[str]]:
+         # dataclass case
+         lt = TypeEngine.to_literal_type(type(self))
+         engine = StructuredDatasetTransformerEngine()
+         lv = loop_manager.run_sync(engine.to_literal, self, type(self), lt)
+         sd = StructuredDataset(uri=lv.scalar.structured_dataset.uri)
+         sd.file_format = lv.scalar.structured_dataset.metadata.structured_dataset_type.format
+         return {
+             "uri": sd.uri,
+             "file_format": sd.file_format,
+         }
+
+     @classmethod
+     def _deserialize(cls, value) -> "StructuredDataset":
+         uri = value.get("uri", None)
+         file_format = value.get("file_format", None)
+
+         if uri is None:
+             raise ValueError("StructuredDataset's uri and file format should not be None")
+
+         engine = StructuredDatasetTransformerEngine()
+         return loop_manager.run_sync(
+             engine.to_python_value,
+             literals_pb2.Literal(
+                 scalar=literals_pb2.Scalar(
+                     structured_dataset=literals_pb2.StructuredDataset(
+                         metadata=literals_pb2.StructuredDatasetMetadata(
+                             structured_dataset_type=types_pb2.StructuredDatasetType(format=file_format)
+                         ),
+                         uri=uri,
+                     )
+                 )
+             ),
+             cls,
+         )
+
+     @model_serializer
+     def serialize_structured_dataset(self) -> Dict[str, Optional[str]]:
+         lt = TypeEngine.to_literal_type(type(self))
+         sde = StructuredDatasetTransformerEngine()
+         lv = loop_manager.run_sync(sde.to_literal, self, type(self), lt)
+         return {
+             "uri": lv.scalar.structured_dataset.uri,
+             "file_format": lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
+         }
+
+     @model_validator(mode="after")
+     def deserialize_structured_dataset(self, info) -> StructuredDataset:
+         if info.context is None or info.context.get("deserialize") is not True:
+             return self
+
+         engine = StructuredDatasetTransformerEngine()
+         return loop_manager.run_sync(
+             engine.to_python_value,
+             literals_pb2.Literal(
+                 scalar=literals_pb2.Scalar(
+                     structured_dataset=literals_pb2.StructuredDataset(
+                         metadata=literals_pb2.StructuredDatasetMetadata(
+                             structured_dataset_type=types_pb2.StructuredDatasetType(format=self.file_format)
+                         ),
+                         uri=self.uri,
+                     )
+                 )
+             ),
+             type(self),
+         )
+
+     @classmethod
+     def columns(cls) -> typing.Dict[str, typing.Type]:
+         return {}
+
+     @classmethod
+     def column_names(cls) -> typing.List[str]:
+         return [k for k, v in cls.columns().items()]
+
+     def __init__(
+         self,
+         dataframe: typing.Optional[typing.Any] = None,
+         uri: typing.Optional[str] = None,
+         metadata: typing.Optional[literals_pb2.StructuredDatasetMetadata] = None,
+         **kwargs,
+     ):
+         self._dataframe = dataframe
+         # Make these fields public, so that the dataclass transformer can set a value for it
+         # https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298
+         self.uri = uri
+         # When dataclass_json runs from_json, we need to set it here, otherwise the format will be an empty string
+         self.file_format = kwargs["file_format"] if "file_format" in kwargs else GENERIC_FORMAT
+         # This is a special attribute that indicates if the data was either downloaded or uploaded
+         self._metadata = metadata
+         # This is not for users to set, the transformer will set this.
+         self._literal_sd: Optional[literals_pb2.StructuredDataset] = None
+         # Not meant for users to set, will be set by an open() call
+         self._dataframe_type: Optional[DF] = None  # type: ignore
+         self._already_uploaded = False
+
+     @property
+     def dataframe(self) -> Optional[DF]:
+         return self._dataframe
+
+     @property
+     def metadata(self) -> Optional[literals_pb2.StructuredDatasetMetadata]:
+         return self._metadata
+
+     @property
+     def literal(self) -> Optional[literals_pb2.StructuredDataset]:
+         return self._literal_sd
+
+     def open(self, dataframe_type: Type[DF]):
+         from union.io.structured_dataset import lazy_import_structured_dataset_handler
+
+         """
+         Load the handler if needed. For a use case like:
+         @task
+         def t1(sd: StructuredDataset):
+             import pandas as pd
+             sd.open(pd.DataFrame).all()
+
+         pandas is imported inside the task, so the pandas handler won't be loaded during deserialization in the type engine.
+         """
+         lazy_import_structured_dataset_handler()
+         self._dataframe_type = dataframe_type
+         return self
+
+     async def all(self) -> DF:  # type: ignore
+         if self._dataframe_type is None:
+             raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
+
+         if self.uri is not None and self.dataframe is None:
+             expected = TypeEngine.to_literal_type(StructuredDataset)
+             await self._set_literal(expected)
+
+         return await flyte_dataset_transformer.open_as(self.literal, self._dataframe_type, self.metadata)
+
+     async def _set_literal(self, expected: types_pb2.LiteralType) -> None:
+         """
+         Explicitly set the StructuredDataset Literal to handle the following cases:
+
+         1. Read a dataframe from a StructuredDataset with a uri, for example:
+
+         @task
+         def return_sd() -> StructuredDataset:
+             sd = StructuredDataset(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet", file_format="parquet")
+             df = sd.open(pd.DataFrame).all()
+             return df
+
+         For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5954.
+
+         2. Need access to self._literal_sd when converting task output LiteralMap back to flyteidl, please see:
+         https://github.com/flyteorg/flytekit/blob/f938661ff8413219d1bea77f6914a58c302d5c6c/flytekit/bin/entrypoint.py#L326
+
+         For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5956.
+         """
+         to_literal = await flyte_dataset_transformer.to_literal(self, StructuredDataset, expected)
+         self._literal_sd = to_literal.scalar.structured_dataset
+         if self.metadata is None:
+             self._metadata = self._literal_sd.metadata
+
+     async def set_literal(self, expected: types_pb2.LiteralType) -> None:
+         """
+         A public wrapper method to set the StructuredDataset Literal.
+
+         This method provides external access to the internal _set_literal method.
+         """
+         return await self._set_literal(expected)
+
+     async def iter(self) -> typing.AsyncIterator[DF]:
+         if self._dataframe_type is None:
+             raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
+         return await flyte_dataset_transformer.iter_as(
+             self.literal, self._dataframe_type, updated_metadata=self.metadata
+         )
+
+
+ # Flatten the nested column map recursively
+ def flatten_dict(sub_dict: dict, parent_key: str = "") -> typing.Dict:
+     result = {}
+     for key, value in sub_dict.items():
+         current_key = f"{parent_key}.{key}" if parent_key else key
+         if isinstance(value, dict):
+             result.update(flatten_dict(sub_dict=value, parent_key=current_key))
+         elif is_dataclass(value):
+             fields = getattr(value, "__dataclass_fields__")
+             d = {k: v.type for k, v in fields.items()}
+             result.update(flatten_dict(sub_dict=d, parent_key=current_key))
+         else:
+             result[current_key] = value
+     return result
+
+
+ def extract_cols_and_format(
+     t: typing.Any,
+ ) -> typing.Tuple[Type[T], Optional[typing.OrderedDict[str, Type]], Optional[str], Optional["pa.lib.Schema"]]:
+     """
+     Helper function, just used to iterate through Annotations and extract out the following information:
+       - base type, if not Annotated, it will just be the type that was passed in.
+       - column information, as a collections.OrderedDict,
+       - the storage format, as a ``StructuredDatasetFormat`` (str),
+       - pa.lib.Schema
+
+     If more than one of any type of thing is found, an error will be raised.
+     If no instances of a given type are found, then None will be returned.
+
+     If we add more things, we should put all the returned items in a dataclass instead of just a tuple.
+
+     :param t: The incoming type which may or may not be Annotated
+     :return: Tuple representing
+         the original type,
+         optional OrderedDict of columns,
+         optional str for the format,
+         optional pyarrow Schema
+     """
+     fmt = ""
+     ordered_dict_cols = None
+     pa_schema = None
+     if get_origin(t) is Annotated:
+         base_type, *annotate_args = get_args(t)
+         for aa in annotate_args:
+             if hasattr(aa, "__annotations__"):
+                 # handle dataclass argument
+                 d = collections.OrderedDict()
+                 d.update(aa.__annotations__)
+                 ordered_dict_cols = d
+             elif isinstance(aa, dict):
+                 d = collections.OrderedDict()
+                 d.update(aa)
+                 ordered_dict_cols = d
+             elif isinstance(aa, StructuredDatasetFormat):
+                 if fmt != "":
+                     raise ValueError(f"A format was already specified {fmt}, cannot use {aa}")
+                 fmt = aa
+             elif isinstance(aa, collections.OrderedDict):
+                 if ordered_dict_cols is not None:
+                     raise ValueError(f"Column information was already found {ordered_dict_cols}, cannot use {aa}")
+                 ordered_dict_cols = aa
+             elif isinstance(aa, pa.lib.Schema):
+                 if pa_schema is not None:
+                     raise ValueError(f"Arrow schema was already found {pa_schema}, cannot use {aa}")
+                 pa_schema = aa
+         return base_type, ordered_dict_cols, fmt, pa_schema
+
+     # We return the generic "" format instead of parquet or something because the transformer engine may find
+     # a better default for the given dataframe type.
+     return t, ordered_dict_cols, fmt, pa_schema
+
+
+ class StructuredDatasetEncoder(ABC, Generic[T]):
+     def __init__(
+         self,
+         python_type: Type[T],
+         protocol: Optional[str] = None,
+         supported_format: Optional[str] = None,
+     ):
+         """
+         Extend this abstract class, implement the encode function, and register your concrete class with the
+         StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
+         dataframe libraries. This is the encoding interface, meaning it is used when there is a Python value that the
+         flytekit type engine is trying to convert into a Flyte Literal. For the other way, see
+         the StructuredDatasetDecoder.
+
+         :param python_type: The dataframe class in question that you want to register this encoder with
+         :param protocol: A prefix representing the storage driver (e.g. 's3', 'gs', 'bq', etc.). You can use either
+             "s3" or "s3://". They are the same since the "://" will just be stripped by the constructor.
+             If None, this encoder will be registered with all protocols that flytekit's data persistence layer
+             is capable of handling.
+         :param supported_format: Arbitrary string representing the format. If not supplied then an empty string
+             will be used. An empty string implies that the encoder works with any format. If the format being asked
+             for does not exist, the transformer engine will look for the "" encoder instead and write a warning.
+         """
+         self._python_type = python_type
+         self._protocol = protocol.replace("://", "") if protocol else None
+         self._supported_format = supported_format or ""
+
+     @property
+     def python_type(self) -> Type[T]:
+         return self._python_type
+
+     @property
+     def protocol(self) -> Optional[str]:
+         return self._protocol
+
+     @property
+     def supported_format(self) -> str:
+         return self._supported_format
+
+     @abstractmethod
+     async def encode(
+         self,
+         structured_dataset: StructuredDataset,
+         structured_dataset_type: types_pb2.StructuredDatasetType,
+     ) -> literals_pb2.StructuredDataset:
+         """
+         Even if the user code returns a plain dataframe instance, the dataset transformer engine will wrap the
+         incoming dataframe with defaults set for that dataframe type. This simplifies this function's interface,
+         since much of the information the user could otherwise have had to supply will already be filled in.
+         # TODO: Do we need to add a flag to indicate if it was wrapped by the transformer or by the user?
+
+         :param structured_dataset: This is a StructuredDataset wrapper object. See more info above.
+         :param structured_dataset_type: This is the StructuredDatasetType, as found in the LiteralType of the interface
+             of the task that invoked this encoding call. It is passed along to encoders so that authors of encoders
+             can include it in the returned literals.StructuredDataset. See the IDL for more information on why this
+             literal in particular carries the type information along with it. If the encoder doesn't supply it, it will
+             also be filled in after the encoder runs by the transformer engine.
+         :return: This function should return a StructuredDataset literal object. Do not confuse this with the
+             StructuredDataset wrapper class used as input to this function - that is the user-facing Python class.
+             This function needs to return the IDL StructuredDataset.
+         """
+         raise NotImplementedError
+
+
+ class StructuredDatasetDecoder(ABC, Generic[DF]):
+     def __init__(
+         self,
+         python_type: Type[DF],
+         protocol: Optional[str] = None,
+         supported_format: Optional[str] = None,
+         additional_protocols: Optional[List[str]] = None,
+     ):
+         """
+         Extend this abstract class, implement the decode function, and register your concrete class with the
+         StructuredDatasetTransformerEngine class in order for the core flytekit type engine to handle
+         dataframe libraries. This is the decoder interface, meaning it is used when there is a Flyte Literal value,
+         and we have to get a Python value out of it. For the other way, see the StructuredDatasetEncoder.
+
+         :param python_type: The dataframe class in question that you want to register this decoder with
+         :param protocol: A prefix representing the storage driver (e.g. 's3', 'gs', 'bq', etc.). You can use either
+             "s3" or "s3://". They are the same since the "://" will just be stripped by the constructor.
+             If None, this decoder will be registered with all protocols that flytekit's data persistence layer
+             is capable of handling.
+         :param supported_format: Arbitrary string representing the format. If not supplied then an empty string
+             will be used. An empty string implies that the decoder works with any format. If the format being asked
+             for does not exist, the transformer engine will look for the "" decoder instead and write a warning.
+         """
+         self._python_type = python_type
+         self._protocol = protocol.replace("://", "") if protocol else None
+         self._supported_format = supported_format or ""
+
+     @property
+     def python_type(self) -> Type[DF]:
+         return self._python_type
+
+     @property
+     def protocol(self) -> Optional[str]:
+         return self._protocol
+
+     @property
+     def supported_format(self) -> str:
+         return self._supported_format
+
+     @abstractmethod
+     async def decode(
+         self,
+         flyte_value: literals_pb2.StructuredDataset,
+         current_task_metadata: literals_pb2.StructuredDatasetMetadata,
+     ) -> Union[DF, typing.AsyncIterator[DF]]:
+         """
+         This is code that will be called by the dataset transformer engine to ultimately translate from a Flyte Literal
+         value into a Python instance.
+
+         :param flyte_value: This will be a Flyte IDL StructuredDataset Literal - do not confuse this with the
+             StructuredDataset class defined also in this module.
+         :param current_task_metadata: Metadata object containing the type (and columns if any) for the currently
+             executing task. This type may have more or less information than the type information bundled
+             inside the incoming flyte_value.
+         :return: This function can either return an instance of the dataframe that this decoder handles, or an iterator
+             of those dataframes.
+         """
+         raise NotImplementedError
+
+
+ def get_supported_types():
+     import numpy as _np
+
+     _SUPPORTED_TYPES: typing.Dict[Type, types_pb2.LiteralType] = {  # type: ignore
+         _np.int32: types_pb2.LiteralType(simple=types_pb2.SimpleType.INTEGER),
+         _np.int64: types_pb2.LiteralType(simple=types_pb2.SimpleType.INTEGER),
+         _np.uint32: types_pb2.LiteralType(simple=types_pb2.SimpleType.INTEGER),
+         _np.uint64: types_pb2.LiteralType(simple=types_pb2.SimpleType.INTEGER),
+         int: types_pb2.LiteralType(simple=types_pb2.SimpleType.INTEGER),
+         _np.float32: types_pb2.LiteralType(simple=types_pb2.SimpleType.FLOAT),
+         _np.float64: types_pb2.LiteralType(simple=types_pb2.SimpleType.FLOAT),
+         float: types_pb2.LiteralType(simple=types_pb2.SimpleType.FLOAT),
+         _np.bool_: types_pb2.LiteralType(simple=types_pb2.SimpleType.BOOLEAN),  # type: ignore
+         bool: types_pb2.LiteralType(simple=types_pb2.SimpleType.BOOLEAN),
+         _np.datetime64: types_pb2.LiteralType(simple=types_pb2.SimpleType.DATETIME),
+         _datetime.datetime: types_pb2.LiteralType(simple=types_pb2.SimpleType.DATETIME),
+         _np.timedelta64: types_pb2.LiteralType(simple=types_pb2.SimpleType.DURATION),
+         _datetime.timedelta: types_pb2.LiteralType(simple=types_pb2.SimpleType.DURATION),
+         _np.bytes_: types_pb2.LiteralType(simple=types_pb2.SimpleType.STRING),
+         _np.str_: types_pb2.LiteralType(simple=types_pb2.SimpleType.STRING),
+         _np.object_: types_pb2.LiteralType(simple=types_pb2.SimpleType.STRING),
+         str: types_pb2.LiteralType(simple=types_pb2.SimpleType.STRING),
+     }
+     return _SUPPORTED_TYPES
+
+
+ class DuplicateHandlerError(ValueError): ...
+
+
+ class StructuredDatasetTransformerEngine(TypeTransformer[StructuredDataset]):
+     """
+     Think of this transformer as a higher-level meta transformer that is used for all the dataframe types.
+     If you are bringing a custom data frame type, or any data frame type, to flytekit, instead of
+     registering with the main type engine, you should register with this transformer instead.
+     """
+
+     ENCODERS: ClassVar[Dict[Type, Dict[str, Dict[str, StructuredDatasetEncoder]]]] = {}
+     DECODERS: ClassVar[Dict[Type, Dict[str, Dict[str, StructuredDatasetDecoder]]]] = {}
+     DEFAULT_PROTOCOLS: ClassVar[Dict[Type, str]] = {}
+     DEFAULT_FORMATS: ClassVar[Dict[Type, str]] = {}
+
+     Handlers = Union[StructuredDatasetEncoder, StructuredDatasetDecoder]
+     Renderers: ClassVar[Dict[Type, Renderable]] = {}
+
+     @classmethod
+     def _finder(cls, handler_map, df_type: Type, protocol: str, format: str):
+         # If there's an exact match, then we should use it.
+         try:
+             return handler_map[df_type][protocol][format]
+         except KeyError:
+             ...
+
+         fsspec_handler = None
+         protocol_specific_handler = None
+         single_handler = None
+         default_format = cls.DEFAULT_FORMATS.get(df_type, None)
+
+         try:
+             fss_handlers = handler_map[df_type]["fsspec"]
+             if format in fss_handlers:
+                 fsspec_handler = fss_handlers[format]
+             elif GENERIC_FORMAT in fss_handlers:
+                 fsspec_handler = fss_handlers[GENERIC_FORMAT]
+             else:
+                 if default_format and default_format in fss_handlers and format == GENERIC_FORMAT:
+                     fsspec_handler = fss_handlers[default_format]
+                 else:
+                     if len(fss_handlers) == 1 and format == GENERIC_FORMAT:
+                         single_handler = next(iter(fss_handlers.values()))
+                     else:
+                         ...
+         except KeyError:
+             ...
+
+         try:
+             protocol_handlers = handler_map[df_type][protocol]
+             if GENERIC_FORMAT in protocol_handlers:
+                 protocol_specific_handler = protocol_handlers[GENERIC_FORMAT]
+             else:
+                 if default_format and default_format in protocol_handlers:
+                     protocol_specific_handler = protocol_handlers[default_format]
+                 else:
+                     if len(protocol_handlers) == 1:
+                         single_handler = next(iter(protocol_handlers.values()))
+                     else:
+                         ...
+
+         except KeyError:
+             ...
+
+         if protocol_specific_handler or fsspec_handler or single_handler:
+             return protocol_specific_handler or fsspec_handler or single_handler
+         else:
+             raise ValueError(f"Failed to find a handler for {df_type}, protocol [{protocol}], fmt ['{format}']")
+
+     @classmethod
+     def get_encoder(cls, df_type: Type, protocol: str, format: str):
+         return cls._finder(StructuredDatasetTransformerEngine.ENCODERS, df_type, protocol, format)
+
+     @classmethod
+     def get_decoder(cls, df_type: Type, protocol: str, format: str) -> StructuredDatasetDecoder:
+         return cls._finder(StructuredDatasetTransformerEngine.DECODERS, df_type, protocol, format)
+
+     @classmethod
+     def _handler_finder(cls, h: Handlers, protocol: str) -> Dict[str, Handlers]:
+         if isinstance(h, StructuredDatasetEncoder):
+             top_level = cls.ENCODERS
+         elif isinstance(h, StructuredDatasetDecoder):
+             top_level = cls.DECODERS  # type: ignore
+         else:
+             raise TypeError(f"We don't support this type of handler {h}")
+         if h.python_type not in top_level:
+             top_level[h.python_type] = {}
+         if protocol not in top_level[h.python_type]:
+             top_level[h.python_type][protocol] = {}
+         return top_level[h.python_type][protocol]  # type: ignore
+
+     def __init__(self):
+         super().__init__("StructuredDataset Transformer", StructuredDataset)
+         self._type_assertions_enabled = False
+
+     @classmethod
+     def register_renderer(cls, python_type: Type, renderer: Renderable):
+         cls.Renderers[python_type] = renderer
+
+     @classmethod
+     def register(
+         cls,
+         h: Handlers,
+         default_for_type: bool = False,
+         override: bool = False,
+         default_format_for_type: bool = False,
+         default_storage_for_type: bool = False,
+     ):
+         """
+         Call this with any Encoder or Decoder to register it with the flytekit type system. If your handler does not
+         specify a protocol (e.g. s3, gs, etc.) field, then it will be registered with all protocols that flytekit's
+         data persistence layer is capable of handling.
+
+         :param h: The StructuredDatasetEncoder or StructuredDatasetDecoder you wish to register with this transformer.
+         :param default_for_type: If set, when a user returns from a task an instance of the dataframe the handler
+             handles, e.g. ``return pd.DataFrame(...)``, not wrapped around the ``StructuredDataset`` object, we will
+             use this handler's protocol and format as the default, effectively saying that this handler will be called.
+             Note that this shouldn't be set if your handler's protocol is None, because that implies that your handler
+             is capable of handling all the different storage protocols that flytekit's data persistence layer is aware of.
+             In these cases, the protocol is determined by the raw output data prefix set in the active context.
+         :param override: Override any previous registrations. If default_for_type is also set, this will also override
+             the default.
+         :param default_format_for_type: Unlike the default_for_type arg that will set this handler's format and storage
+             as the default, this will only set the format. Error if already set, unless override is specified.
+         :param default_storage_for_type: Same as above but only for the storage protocol. Error if already set,
+             unless override is specified.
+         """
+         if not (isinstance(h, StructuredDatasetEncoder) or isinstance(h, StructuredDatasetDecoder)):
+             raise TypeError(f"We don't support this type of handler {h}")
+
+         if h.protocol is None:
+             if default_for_type:
+                 raise ValueError(f"Registering SD handler {h} with all protocols should never have default specified.")
+             try:
+                 cls.register_for_protocol(
+                     h, "fsspec", False, override, default_format_for_type, default_storage_for_type
+                 )
+             except DuplicateHandlerError:
+                 ...
+
+         elif h.protocol == "":
+             raise ValueError(f"Use None instead of empty string for registering handler {h}")
+         else:
+             cls.register_for_protocol(
+                 h, h.protocol, default_for_type, override, default_format_for_type, default_storage_for_type
+             )
+
+     @classmethod
+     def register_for_protocol(
+         cls,
+         h: Handlers,
+         protocol: str,
+         default_for_type: bool,
+         override: bool,
+         default_format_for_type: bool,
+         default_storage_for_type: bool,
+     ):
+         """
+         See the main register function instead.
+         """
+         if protocol == "/":
+             protocol = "file"
+         lowest_level = cls._handler_finder(h, protocol)
+         if h.supported_format in lowest_level and override is False:
+             raise DuplicateHandlerError(
+                 f"Already registered a handler for {(h.python_type, protocol, h.supported_format)}"
+             )
+         lowest_level[h.supported_format] = h
+         logger.debug(f"Registered {h} as handler for {h.python_type}, protocol {protocol}, fmt {h.supported_format}")
+
+         if (default_format_for_type or default_for_type) and h.supported_format != GENERIC_FORMAT:
+             if h.python_type in cls.DEFAULT_FORMATS and not override:
+                 if cls.DEFAULT_FORMATS[h.python_type] != h.supported_format:
+                     logger.info(
+                         f"Not using handler {h} with format {h.supported_format}"
+                         f" as default for {h.python_type}, {cls.DEFAULT_FORMATS[h.python_type]} already specified."
+                     )
+             else:
+                 logger.debug(f"Use {type(h).__name__} as default handler for {h.python_type}.")
+                 cls.DEFAULT_FORMATS[h.python_type] = h.supported_format
+         if default_storage_for_type or default_for_type:
+             if h.python_type in cls.DEFAULT_PROTOCOLS and not override:
+                 logger.debug(
+                     f"Not using handler {h} with storage protocol {h.protocol}"
+                     f" as default for {h.python_type}, {cls.DEFAULT_PROTOCOLS[h.python_type]} already specified."
+                 )
+             else:
+                 logger.debug(f"Using storage {protocol} for dataframes of type {h.python_type} from handler {h}")
+                 cls.DEFAULT_PROTOCOLS[h.python_type] = protocol
+
+         # Register with the type engine as well
+         # The semantics as of now are such that it doesn't matter which order these transformers are loaded in, as
+         # long as the older Pandas/FlyteSchema transformers do not also specify the override
+         engine = StructuredDatasetTransformerEngine()
+         TypeEngine.register_additional_type(engine, h.python_type, override=True)
+
+     def assert_type(self, t: Type[StructuredDataset], v: typing.Any):
+         return
+
+     async def to_literal(
+         self,
+         python_val: Union[StructuredDataset, typing.Any],
+         python_type: Union[Type[StructuredDataset], Type],
+         expected: types_pb2.LiteralType,
+     ) -> literals_pb2.Literal:
+         # Make a copy in case we need to hand off to encoders, since we can't be sure of mutations.
+         python_type, *attrs = extract_cols_and_format(python_type)
+         sdt = types_pb2.StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT))
+
+         if issubclass(python_type, StructuredDataset) and not isinstance(python_val, StructuredDataset):
+             # Catch a common mistake
+             raise TypeTransformerFailedError(
+                 f"Expected a StructuredDataset instance, but got {type(python_val)} instead."
+                 f" Did you forget to wrap your dataframe in a StructuredDataset instance?"
+             )
+
+         if expected and expected.structured_dataset_type:
+             sdt = types_pb2.StructuredDatasetType(
+                 columns=expected.structured_dataset_type.columns,
+                 format=expected.structured_dataset_type.format,
+                 external_schema_type=expected.structured_dataset_type.external_schema_type,
+                 external_schema_bytes=expected.structured_dataset_type.external_schema_bytes,
+             )
+
+         # If the type signature has the StructuredDataset class, it will, or at least should, also be a
+         # StructuredDataset instance.
+         if isinstance(python_val, StructuredDataset):
+             # There are three cases that we need to take care of here.
+
+             # 1. A task returns a StructuredDataset that was just a passthrough input. If this happens
+             #    then return the original literals.StructuredDataset without invoking any encoder.
+             #
+             # Ex.
+             #   def t1(dataset: Annotated[StructuredDataset, my_cols]) -> Annotated[StructuredDataset, my_cols]:
+             #       return dataset
+             if python_val._literal_sd is not None:
+                 if python_val._already_uploaded:
+                     return literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=python_val._literal_sd))
+                 if python_val.dataframe is not None:
+                     raise ValueError(
+                         f"Shouldn't have specified both literal {python_val._literal_sd}"
+                         f" and dataframe {python_val.dataframe}"
+                     )
+                 return literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=python_val._literal_sd))
+
+             # 2. A task returns a python StructuredDataset with a uri.
+             # Note: this case is also what happens when we start a local execution of a task with a python StructuredDataset.
+             # It gets converted into a literal first, then back into a python StructuredDataset.
+             #
+             # Ex.
+             #   def t2(uri: str) -> Annotated[StructuredDataset, my_cols]
+             #       return StructuredDataset(uri=uri)
+             if python_val.dataframe is None:
+                 uri = python_val.uri
+                 file_format = python_val.file_format
+
+                 # Check the user-specified uri
+                 if not uri:
+                     raise ValueError(f"If dataframe is not specified, then the uri should be specified. {python_val}")
+                 if not storage.is_remote(uri):
+                     uri = await storage.put(uri)
+
+                 # Check the user-specified file_format
+                 # When users specify file_format for a StructuredDataset, the file_format should be retained
+                 # conditionally. For details, please refer to https://github.com/flyteorg/flyte/issues/6096.
+                 # The following illustrates why we can't always copy the user-specified file_format over:
+                 #
+                 # @task
+                 # def modify_format(sd: Annotated[StructuredDataset, {}, "task-format"]) -> StructuredDataset:
+                 #     return sd
+                 #
+                 # sd = StructuredDataset(uri="s3://my-s3-bucket/df.parquet", file_format="user-format")
+                 # sd2 = modify_format(sd=sd)
+                 #
+                 # In this case, we expect sd2.file_format to be task-format (as shown in Annotated), not user-format.
+                 # If we directly copy the user-specified file_format over, the type hint information will be missing.
+                 if sdt.format == GENERIC_FORMAT and file_format != GENERIC_FORMAT:
+                     sdt.format = file_format
+
+                 sd_model = literals_pb2.StructuredDataset(
+                     uri=uri,
+                     metadata=literals_pb2.StructuredDatasetMetadata(structured_dataset_type=sdt),
+                 )
+                 return literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=sd_model))
+
+             # 3. This is the third and probably most common case. The python StructuredDataset object wraps a dataframe
+             #    that we will need to invoke an encoder for. Figure out which encoder to call and invoke it.
+             df_type = type(python_val.dataframe)
+             protocol = self._protocol_from_type_or_prefix(df_type, python_val.uri)
+
+             return await self.encode(
+                 python_val,
+                 df_type,
+                 protocol,
+                 sdt.format,
+                 sdt,
+             )
+
+         # Otherwise assume it's a dataframe instance. Wrap it with some defaults
+         fmt = self.DEFAULT_FORMATS.get(python_type, "")
+         protocol = self._protocol_from_type_or_prefix(python_type)
+         meta = literals_pb2.StructuredDatasetMetadata(
+             structured_dataset_type=expected.structured_dataset_type if expected else None
+         )
+
+         sd = StructuredDataset(dataframe=python_val, metadata=meta)
+         return await self.encode(sd, python_type, protocol, fmt, sdt)
+
+     def _protocol_from_type_or_prefix(self, df_type: Type, uri: Optional[str] = None) -> str:
+         """
+         Get the protocol from the default; if missing, then look it up from the uri if provided; if not, then look
+         it up from the provided context's file access.
+         """
+         if df_type in self.DEFAULT_PROTOCOLS:
+             return self.DEFAULT_PROTOCOLS[df_type]
+         else:
+             from union._context import internal_ctx
+
+             ctx = internal_ctx()
+             protocol = get_protocol(uri or ctx.raw_data.path)
+             logger.debug(
+                 f"No default protocol for type {df_type} found, using {protocol} from output prefix {ctx.raw_data.path}"
+             )
+             return protocol
+
+     async def encode(
+         self,
+         sd: StructuredDataset,
+         df_type: Type,
+         protocol: str,
+         format: str,
+         structured_literal_type: types_pb2.StructuredDatasetType,
+     ) -> literals_pb2.Literal:
+         handler: StructuredDatasetEncoder
+         handler = self.get_encoder(df_type, protocol, format)
+
+         sd_model = await handler.encode(sd, structured_literal_type)
+         # This block is here in case the encoder did not set the type information in the metadata. Since this literal
+         # is special in that it carries around the type itself, we want to make sure the type info therein is at
+         # least as good as the type of the interface.
+
+         if sd_model.metadata is None:
+             sd_model.metadata = literals_pb2.StructuredDatasetMetadata(structured_dataset_type=structured_literal_type)
+         if sd_model.metadata and sd_model.metadata.structured_dataset_type is None:
+             sd_model.metadata.structured_dataset_type = structured_literal_type
+         # Always set the format here to the format of the handler.
+         # Note that this will always be the same as the incoming format except for when the fallback handler
+         # with a format of "" is used.
+         sd_model.metadata.structured_dataset_type.format = handler.supported_format
+         lit = literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=sd_model))
+
+         # Because the handler.encode may have uploaded something, and because the sd may end up living inside a
+         # dataclass, we need to modify any uploaded flyte:// urls here.
+         modify_literal_uris(lit)  # todo: verify that this can be removed.
+         sd._literal_sd = sd_model
+         sd._already_uploaded = True
+         return lit
+
+     # pr: han-ru: can this be removed if we make StructuredDataset a pydantic model?
+     def dict_to_structured_dataset(
+         self, dict_obj: typing.Dict[str, str], expected_python_type: Type[T] | StructuredDataset
+     ) -> T | StructuredDataset:
+         uri = dict_obj.get("uri", None)
+         file_format = dict_obj.get("file_format", None)
+
+         if uri is None:
+             raise ValueError("StructuredDataset's uri and file format should not be None")
+
+         # Instead of using the Python-native StructuredDataset, we need to build a literals.StructuredDataset.
+         # The reason is that _literal_sd of the python sd is accessed when the task output LiteralMap is
+         # converted back to flyteidl. Hence, _literal_sd must have a to_flyte_idl method.
+         # See https://github.com/flyteorg/flytekit/blob/f938661ff8413219d1bea77f6914a58c302d5c6c/flytekit/bin/entrypoint.py#L326
+         # For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5956.
+         sdt = types_pb2.StructuredDatasetType(format=file_format)
+         metad = literals_pb2.StructuredDatasetMetadata(structured_dataset_type=sdt)
+         sd_literal = literals_pb2.StructuredDataset(uri=uri, metadata=metad)
+
+         return StructuredDatasetTransformerEngine().to_python_value(
+             literals_pb2.Literal(scalar=literals_pb2.Scalar(structured_dataset=sd_literal)),
+             expected_python_type,
+         )
+
840
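For reference, the dict this helper consumes has the following shape (values illustrative); "uri" is mandatory, while a missing "file_format" simply leaves the format unset on the StructuredDatasetType:

    payload = {"uri": "s3://bucket/df.parquet", "file_format": "parquet"}
    assert payload.get("uri") is not None  # the guard above raises otherwise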
+    def from_binary_idl(
+        self, binary_idl_object: literals_pb2.Binary, expected_python_type: Type[T] | StructuredDataset
+    ) -> T | StructuredDataset:
+        """
+        If the input is from flytekit, the life cycle is as follows:
+
+        Life Cycle:
+            binary IDL -> resolved binary -> bytes -> expected Python object
+            (flytekit customized   (propeller processing)   (flytekit binary IDL)   (flytekit customized
+            serialization)                                                          deserialization)
+
+        Example Code:
+            @dataclass
+            class DC:
+                sd: StructuredDataset
+
+            @workflow
+            def wf(dc: DC):
+                t_sd(dc.sd)
+
+        Note:
+            - The deserialization is the same as putting a structured dataset in a dataclass,
+              which is deserialized via mashumaro's API.
+
+        Related PR:
+            - Title: Override Dataclass Serialization/Deserialization Behavior for FlyteTypes via Mashumaro
+            - Link: https://github.com/flyteorg/flytekit/pull/2554
+        """
+        if binary_idl_object.tag == MESSAGEPACK:
+            python_val = msgpack.loads(binary_idl_object.value)
+            return self.dict_to_structured_dataset(dict_obj=python_val, expected_python_type=expected_python_type)
+        else:
+            raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")
+
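To see how the MESSAGEPACK branch obtains its dict, here is a sketch of the byte-level round trip using only the msgpack library (the surrounding Binary literal and tag handling are as in the method above):

    import msgpack

    # What a serialized StructuredDataset inside a Binary literal decodes to:
    raw = msgpack.dumps({"uri": "s3://bucket/df.parquet", "file_format": "parquet"})
    restored = msgpack.loads(raw)
    assert restored == {"uri": "s3://bucket/df.parquet", "file_format": "parquet"}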
+    async def to_python_value(
+        self, lv: literals_pb2.Literal, expected_python_type: Type[T] | StructuredDataset
+    ) -> T | StructuredDataset:
+        """
+        The only tricky thing with converting a Literal (say, the output of an earlier task) to a Python value
+        at the start of a task execution is the column subsetting behavior. For example, if you have
+
+            def t1() -> Annotated[StructuredDataset, kwtypes(col_a=int, col_b=float)]: ...
+            def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...
+
+        where t2(in_a=t1()): when t2 does in_a.open(pd.DataFrame).all(), it should get a DataFrame
+        with only one column.
+
+        +-----------------------------+-----------------------------------------+--------------------------------------+
+        |                             | StructuredDatasetType of the incoming Literal                                  |
+        +-----------------------------+-----------------------------------------+--------------------------------------+
+        | StructuredDatasetType       | Has columns defined                     | [] columns or None                   |
+        | of currently running task   |                                         |                                      |
+        +=============================+=========================================+======================================+
+        | Has columns                 | The StructuredDatasetType passed to the decoder will have the columns          |
+        | defined                     | as defined by the type annotation of the currently running task.               |
+        |                             |                                                                                 |
+        |                             | Decoders **should** then subset the incoming data to the columns requested.    |
+        +-----------------------------+-----------------------------------------+--------------------------------------+
+        | [] columns or None          | StructuredDatasetType passed to decoder | StructuredDatasetType passed to the  |
+        |                             | will have the columns from the incoming | decoder will have an empty list of   |
+        |                             | Literal. This is the scenario where     | columns.                             |
+        |                             | the Literal returned by the running     |                                      |
+        |                             | task will have more information than    |                                      |
+        |                             | the running task's signature.           |                                      |
+        +-----------------------------+-----------------------------------------+--------------------------------------+
+        """
+        # Handle dataclass attribute access
+        if lv.HasField("scalar") and lv.scalar.HasField("binary"):
+            return self.from_binary_idl(lv.scalar.binary, expected_python_type)
+
+        # Detect annotations and extract out all the relevant information that the user might supply
+        expected_python_type, column_dict, storage_fmt, pa_schema = extract_cols_and_format(expected_python_type)
+
+        # Start handling for StructuredDataset scalars, first look at the columns
+        incoming_columns = lv.scalar.structured_dataset.metadata.structured_dataset_type.columns
+
+        # If the incoming literal also doesn't have columns, we end up with an empty list, so initialize here
+        final_dataset_columns = []
+        # If the current running task's input does not have columns defined, or has an empty list of columns
+        if column_dict is None or len(column_dict) == 0:
+            # but the incoming literal does, then we just copy those over
+            if incoming_columns:
+                final_dataset_columns = incoming_columns[:]
+        # If the current running task's input does have columns defined
+        else:
+            final_dataset_columns = self._convert_ordered_dict_of_columns_to_list(column_dict)
+
+        new_sdt = types_pb2.StructuredDatasetType(
+            columns=final_dataset_columns,
+            format=lv.scalar.structured_dataset.metadata.structured_dataset_type.format,
+            external_schema_type=lv.scalar.structured_dataset.metadata.structured_dataset_type.external_schema_type,
+            external_schema_bytes=lv.scalar.structured_dataset.metadata.structured_dataset_type.external_schema_bytes,
+        )
+        metad = literals_pb2.StructuredDatasetMetadata(structured_dataset_type=new_sdt)
+
+        # A StructuredDataset type, for example
+        #   t1(input_a: StructuredDataset)  # or
+        #   t1(input_a: Annotated[StructuredDataset, my_cols])
+        if issubclass(expected_python_type, StructuredDataset):
+            sd = expected_python_type(
+                dataframe=None,
+                # Note here that the metadata being passed in carries the updated (subsetted) type
+                metadata=metad,
+            )
+            sd._literal_sd = lv.scalar.structured_dataset
+            sd.file_format = metad.structured_dataset_type.format
+            return sd
+
+        # If the requested type was not a StructuredDataset, then it means it was a plain dataframe type, which
+        # means we should do the opening/downloading and whatever else it might entail right now. No iteration
+        # option here.
+        return await self.open_as(lv.scalar.structured_dataset, df_type=expected_python_type, updated_metadata=metad)
+
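The column-resolution table in the docstring above can be restated as a small pure-python function; this sketch is illustrative only and not part of the engine:

    def resolve_columns(task_columns, incoming_columns):
        # The task annotation wins when it declares columns; decoders then subset.
        if task_columns:
            return list(task_columns)
        # Otherwise fall back to whatever the incoming literal carried, if anything.
        return list(incoming_columns) if incoming_columns else []

    assert resolve_columns(["col_b"], ["col_a", "col_b"]) == ["col_b"]
    assert resolve_columns(None, ["col_a", "col_b"]) == ["col_a", "col_b"]
    assert resolve_columns(None, None) == []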
+    def to_html(self, python_val: typing.Any, expected_python_type: Type[T]) -> str:
+        if isinstance(python_val, StructuredDataset):
+            if python_val.dataframe is not None:
+                df = python_val.dataframe
+            else:
+                # Here we only render column information by default instead of opening the structured dataset.
+                col = typing.cast(StructuredDataset, python_val).columns()
+                df = pd.DataFrame(col, ["column type"])
+                return df.to_html()  # type: ignore
+        else:
+            df = python_val
+
+        if type(df) in self.Renderers:
+            return self.Renderers[type(df)].to_html(df)
+        else:
+            raise NotImplementedError(f"Could not find a renderer for {type(df)} in {self.Renderers}")
+
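Note that the renderer lookup above dispatches on the exact type, so a DataFrame subclass would not match. A sketch of an analogous registry (PandasRenderer is hypothetical; only its to_html shape mirrors what the engine expects):

    import pandas as pd

    class PandasRenderer:
        def to_html(self, df: pd.DataFrame) -> str:
            return df.to_html()

    renderers = {pd.DataFrame: PandasRenderer()}
    df = pd.DataFrame({"a": [1, 2]})
    if type(df) in renderers:  # exact-type match, no subclass fallback
        html = renderers[type(df)].to_html(df)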
+    async def open_as(
+        self,
+        sd: literals_pb2.StructuredDataset,
+        df_type: Type[DF],
+        updated_metadata: literals_pb2.StructuredDatasetMetadata,
+    ) -> DF:
+        """
+        :param sd: The StructuredDataset literal to open.
+        :param df_type: The dataframe type to decode into.
+        :param updated_metadata: New metadata type, since it might be different from the metadata in the literal.
+        :return: dataframe. It could be a pandas dataframe, an arrow table, etc.
+        """
+        protocol = get_protocol(sd.uri)
+        decoder = self.get_decoder(df_type, protocol, sd.metadata.structured_dataset_type.format)
+        result = await decoder.decode(sd, updated_metadata)
+        return result
+
+    async def iter_as(
+        self,
+        sd: literals_pb2.StructuredDataset,
+        df_type: Type[DF],
+        updated_metadata: literals_pb2.StructuredDatasetMetadata,
+    ) -> typing.AsyncIterator[DF]:
+        protocol = get_protocol(sd.uri)
+        decoder = self.DECODERS[df_type][protocol][sd.metadata.structured_dataset_type.format]
+        result: Union[DF, typing.AsyncIterator[DF]] = decoder.decode(sd, updated_metadata)
+        if not isinstance(result, types.AsyncGeneratorType):
+            raise ValueError(f"Decoder {decoder} was expected to return an async iterator for {sd}, but got {result}")
+        return result
+
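The isinstance check in iter_as works because calling an async generator function returns an async generator synchronously, with no await; a minimal demonstration:

    import asyncio
    import types

    async def decode_stream():
        yield "chunk-1"
        yield "chunk-2"

    gen = decode_stream()  # no await needed to obtain the generator itself
    assert isinstance(gen, types.AsyncGeneratorType)

    async def consume():
        return [chunk async for chunk in gen]

    assert asyncio.run(consume()) == ["chunk-1", "chunk-2"]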
+    def _get_dataset_column_literal_type(self, t: Type) -> types_pb2.LiteralType:
+        if t in get_supported_types():
+            return get_supported_types()[t]
+        if hasattr(t, "__origin__") and t.__origin__ is list:
+            return types_pb2.LiteralType(collection_type=self._get_dataset_column_literal_type(t.__args__[0]))
+        if hasattr(t, "__origin__") and t.__origin__ is dict:
+            return types_pb2.LiteralType(map_value_type=self._get_dataset_column_literal_type(t.__args__[1]))
+        raise AssertionError(f"type {t} is currently not supported by StructuredDataset")
+
+    def _convert_ordered_dict_of_columns_to_list(
+        self, column_map: typing.Optional[typing.OrderedDict[str, Type]]
+    ) -> typing.List[types_pb2.StructuredDatasetType.DatasetColumn]:
+        converted_cols: typing.List[types_pb2.StructuredDatasetType.DatasetColumn] = []
+        if column_map is None or len(column_map) == 0:
+            return converted_cols
+        flat_column_map = flatten_dict(column_map)
+        for k, v in flat_column_map.items():
+            lt = self._get_dataset_column_literal_type(v)
+            converted_cols.append(types_pb2.StructuredDatasetType.DatasetColumn(name=k, literal_type=lt))
+        return converted_cols
+
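To illustrate the flattening step, here is a hypothetical stand-in for flatten_dict: nested column maps collapse to dotted names (the real helper's separator and semantics may differ):

    from collections import OrderedDict

    def flatten_sketch(d, prefix=""):
        out = OrderedDict()
        for k, v in d.items():
            key = f"{prefix}{k}"
            if isinstance(v, dict):  # OrderedDict is a dict subclass
                out.update(flatten_sketch(v, key + "."))
            else:
                out[key] = v
        return out

    cols = OrderedDict(user_id=int, metrics=OrderedDict(score=float))
    assert list(flatten_sketch(cols)) == ["user_id", "metrics.score"]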
+    def _get_dataset_type(
+        self, t: typing.Union[Type[StructuredDataset], typing.Any]
+    ) -> types_pb2.StructuredDatasetType:
+        original_python_type, column_map, storage_format, pa_schema = extract_cols_and_format(t)  # type: ignore
+
+        # Get the column information
+        converted_cols: typing.List[types_pb2.StructuredDatasetType.DatasetColumn] = (
+            self._convert_ordered_dict_of_columns_to_list(column_map)
+        )
+
+        return types_pb2.StructuredDatasetType(
+            columns=converted_cols,
+            format=storage_format,
+            external_schema_type="arrow" if pa_schema else None,
+            external_schema_bytes=typing.cast(pa.lib.Schema, pa_schema).to_string().encode() if pa_schema else None,
+        )
+
+    def get_literal_type(self, t: typing.Union[Type[StructuredDataset], typing.Any]) -> types_pb2.LiteralType:
+        """
+        Provide a concrete implementation so that writers of custom dataframe handlers don't have to provide
+        one, since there's nothing special about the literal type: any dataframe type will always be associated
+        with the structured dataset type. The other aspects of it, such as columns and external schema type,
+        can be read from the associated metadata.
+
+        :param t: The python dataframe type, which is mostly ignored.
+        """
+        return types_pb2.LiteralType(structured_dataset_type=self._get_dataset_type(t))
+
+    def guess_python_type(self, literal_type: types_pb2.LiteralType) -> Type[StructuredDataset]:
+        # todo: technically we should return the dataframe type specified in the constructor, but to do that,
+        #   we'd have to store that, which we don't do today. See possibly #1363
+        if literal_type.HasField("structured_dataset_type"):
+            return StructuredDataset
+        raise ValueError(f"StructuredDatasetTransformerEngine cannot reverse {literal_type}")
+
+
+flyte_dataset_transformer = StructuredDatasetTransformerEngine()
+TypeEngine.register(flyte_dataset_transformer)
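Once the transformer is registered, dataframe-typed task signatures resolve to a structured dataset literal type. A usage sketch, assuming the names defined in this module are importable:

    lt = flyte_dataset_transformer.get_literal_type(StructuredDataset)
    assert lt.HasField("structured_dataset_type")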