flyte 0.0.1b0__py3-none-any.whl → 2.0.0b46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (455)
  1. flyte/__init__.py +83 -30
  2. flyte/_bin/connect.py +61 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +87 -19
  5. flyte/_bin/serve.py +351 -0
  6. flyte/_build.py +3 -2
  7. flyte/_cache/cache.py +6 -5
  8. flyte/_cache/local_cache.py +216 -0
  9. flyte/_code_bundle/_ignore.py +31 -5
  10. flyte/_code_bundle/_packaging.py +42 -11
  11. flyte/_code_bundle/_utils.py +57 -34
  12. flyte/_code_bundle/bundle.py +130 -27
  13. flyte/_constants.py +1 -0
  14. flyte/_context.py +21 -5
  15. flyte/_custom_context.py +73 -0
  16. flyte/_debug/constants.py +37 -0
  17. flyte/_debug/utils.py +17 -0
  18. flyte/_debug/vscode.py +315 -0
  19. flyte/_deploy.py +396 -75
  20. flyte/_deployer.py +109 -0
  21. flyte/_environment.py +94 -11
  22. flyte/_excepthook.py +37 -0
  23. flyte/_group.py +2 -1
  24. flyte/_hash.py +1 -16
  25. flyte/_image.py +544 -234
  26. flyte/_initialize.py +443 -294
  27. flyte/_interface.py +40 -5
  28. flyte/_internal/controllers/__init__.py +22 -8
  29. flyte/_internal/controllers/_local_controller.py +159 -35
  30. flyte/_internal/controllers/_trace.py +18 -10
  31. flyte/_internal/controllers/remote/__init__.py +38 -9
  32. flyte/_internal/controllers/remote/_action.py +82 -12
  33. flyte/_internal/controllers/remote/_client.py +6 -2
  34. flyte/_internal/controllers/remote/_controller.py +290 -64
  35. flyte/_internal/controllers/remote/_core.py +155 -95
  36. flyte/_internal/controllers/remote/_informer.py +40 -20
  37. flyte/_internal/controllers/remote/_service_protocol.py +2 -2
  38. flyte/_internal/imagebuild/__init__.py +2 -10
  39. flyte/_internal/imagebuild/docker_builder.py +391 -84
  40. flyte/_internal/imagebuild/image_builder.py +111 -55
  41. flyte/_internal/imagebuild/remote_builder.py +409 -0
  42. flyte/_internal/imagebuild/utils.py +79 -0
  43. flyte/_internal/resolvers/_app_env_module.py +92 -0
  44. flyte/_internal/resolvers/_task_module.py +5 -38
  45. flyte/_internal/resolvers/app_env.py +26 -0
  46. flyte/_internal/resolvers/common.py +8 -1
  47. flyte/_internal/resolvers/default.py +2 -2
  48. flyte/_internal/runtime/convert.py +322 -33
  49. flyte/_internal/runtime/entrypoints.py +106 -18
  50. flyte/_internal/runtime/io.py +71 -23
  51. flyte/_internal/runtime/resources_serde.py +21 -7
  52. flyte/_internal/runtime/reuse.py +125 -0
  53. flyte/_internal/runtime/rusty.py +196 -0
  54. flyte/_internal/runtime/task_serde.py +239 -66
  55. flyte/_internal/runtime/taskrunner.py +48 -8
  56. flyte/_internal/runtime/trigger_serde.py +162 -0
  57. flyte/_internal/runtime/types_serde.py +7 -16
  58. flyte/_keyring/file.py +115 -0
  59. flyte/_link.py +30 -0
  60. flyte/_logging.py +241 -42
  61. flyte/_map.py +312 -0
  62. flyte/_metrics.py +59 -0
  63. flyte/_module.py +74 -0
  64. flyte/_pod.py +30 -0
  65. flyte/_resources.py +296 -33
  66. flyte/_retry.py +1 -7
  67. flyte/_reusable_environment.py +72 -7
  68. flyte/_run.py +461 -132
  69. flyte/_secret.py +47 -11
  70. flyte/_serve.py +333 -0
  71. flyte/_task.py +245 -56
  72. flyte/_task_environment.py +219 -97
  73. flyte/_task_plugins.py +47 -0
  74. flyte/_tools.py +8 -8
  75. flyte/_trace.py +15 -24
  76. flyte/_trigger.py +1027 -0
  77. flyte/_utils/__init__.py +12 -1
  78. flyte/_utils/asyn.py +3 -1
  79. flyte/_utils/async_cache.py +139 -0
  80. flyte/_utils/coro_management.py +5 -4
  81. flyte/_utils/description_parser.py +19 -0
  82. flyte/_utils/docker_credentials.py +173 -0
  83. flyte/_utils/helpers.py +45 -19
  84. flyte/_utils/module_loader.py +123 -0
  85. flyte/_utils/org_discovery.py +57 -0
  86. flyte/_utils/uv_script_parser.py +8 -1
  87. flyte/_version.py +16 -3
  88. flyte/app/__init__.py +27 -0
  89. flyte/app/_app_environment.py +362 -0
  90. flyte/app/_connector_environment.py +40 -0
  91. flyte/app/_deploy.py +130 -0
  92. flyte/app/_parameter.py +343 -0
  93. flyte/app/_runtime/__init__.py +3 -0
  94. flyte/app/_runtime/app_serde.py +383 -0
  95. flyte/app/_types.py +113 -0
  96. flyte/app/extras/__init__.py +9 -0
  97. flyte/app/extras/_auth_middleware.py +217 -0
  98. flyte/app/extras/_fastapi.py +93 -0
  99. flyte/app/extras/_model_loader/__init__.py +3 -0
  100. flyte/app/extras/_model_loader/config.py +7 -0
  101. flyte/app/extras/_model_loader/loader.py +288 -0
  102. flyte/cli/__init__.py +12 -0
  103. flyte/cli/_abort.py +28 -0
  104. flyte/cli/_build.py +114 -0
  105. flyte/cli/_common.py +493 -0
  106. flyte/cli/_create.py +371 -0
  107. flyte/cli/_delete.py +45 -0
  108. flyte/cli/_deploy.py +401 -0
  109. flyte/cli/_gen.py +316 -0
  110. flyte/cli/_get.py +446 -0
  111. flyte/cli/_option.py +33 -0
  112. {union/_cli → flyte/cli}/_params.py +152 -153
  113. flyte/cli/_plugins.py +209 -0
  114. flyte/cli/_prefetch.py +292 -0
  115. flyte/cli/_run.py +690 -0
  116. flyte/cli/_serve.py +338 -0
  117. flyte/cli/_update.py +86 -0
  118. flyte/cli/_user.py +20 -0
  119. flyte/cli/main.py +246 -0
  120. flyte/config/__init__.py +3 -0
  121. flyte/config/_config.py +248 -0
  122. flyte/config/_internal.py +73 -0
  123. flyte/config/_reader.py +225 -0
  124. flyte/connectors/__init__.py +11 -0
  125. flyte/connectors/_connector.py +330 -0
  126. flyte/connectors/_server.py +194 -0
  127. flyte/connectors/utils.py +159 -0
  128. flyte/errors.py +134 -2
  129. flyte/extend.py +24 -0
  130. flyte/extras/_container.py +69 -56
  131. flyte/git/__init__.py +3 -0
  132. flyte/git/_config.py +279 -0
  133. flyte/io/__init__.py +8 -1
  134. flyte/io/{structured_dataset → _dataframe}/__init__.py +32 -30
  135. flyte/io/{structured_dataset → _dataframe}/basic_dfs.py +75 -68
  136. flyte/io/{structured_dataset/structured_dataset.py → _dataframe/dataframe.py} +207 -242
  137. flyte/io/_dir.py +575 -113
  138. flyte/io/_file.py +587 -141
  139. flyte/io/_hashing_io.py +342 -0
  140. flyte/io/extend.py +7 -0
  141. flyte/models.py +635 -0
  142. flyte/prefetch/__init__.py +22 -0
  143. flyte/prefetch/_hf_model.py +563 -0
  144. flyte/remote/__init__.py +14 -3
  145. flyte/remote/_action.py +879 -0
  146. flyte/remote/_app.py +346 -0
  147. flyte/remote/_auth_metadata.py +42 -0
  148. flyte/remote/_client/_protocols.py +62 -4
  149. flyte/remote/_client/auth/_auth_utils.py +19 -0
  150. flyte/remote/_client/auth/_authenticators/base.py +8 -2
  151. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  152. flyte/remote/_client/auth/_authenticators/factory.py +4 -0
  153. flyte/remote/_client/auth/_authenticators/passthrough.py +79 -0
  154. flyte/remote/_client/auth/_authenticators/pkce.py +17 -18
  155. flyte/remote/_client/auth/_channel.py +47 -18
  156. flyte/remote/_client/auth/_client_config.py +5 -3
  157. flyte/remote/_client/auth/_keyring.py +15 -2
  158. flyte/remote/_client/auth/_token_client.py +3 -3
  159. flyte/remote/_client/controlplane.py +206 -18
  160. flyte/remote/_common.py +66 -0
  161. flyte/remote/_data.py +107 -22
  162. flyte/remote/_logs.py +116 -33
  163. flyte/remote/_project.py +21 -19
  164. flyte/remote/_run.py +164 -631
  165. flyte/remote/_secret.py +72 -29
  166. flyte/remote/_task.py +387 -46
  167. flyte/remote/_trigger.py +368 -0
  168. flyte/remote/_user.py +43 -0
  169. flyte/report/_report.py +10 -6
  170. flyte/storage/__init__.py +13 -1
  171. flyte/storage/_config.py +237 -0
  172. flyte/storage/_parallel_reader.py +289 -0
  173. flyte/storage/_storage.py +268 -59
  174. flyte/syncify/__init__.py +56 -0
  175. flyte/syncify/_api.py +414 -0
  176. flyte/types/__init__.py +39 -0
  177. flyte/types/_interface.py +22 -7
  178. flyte/{io/pickle/transformer.py → types/_pickle.py} +37 -9
  179. flyte/types/_string_literals.py +8 -9
  180. flyte/types/_type_engine.py +230 -129
  181. flyte/types/_utils.py +1 -1
  182. flyte-2.0.0b46.data/scripts/debug.py +38 -0
  183. flyte-2.0.0b46.data/scripts/runtime.py +194 -0
  184. flyte-2.0.0b46.dist-info/METADATA +352 -0
  185. flyte-2.0.0b46.dist-info/RECORD +221 -0
  186. flyte-2.0.0b46.dist-info/entry_points.txt +8 -0
  187. flyte-2.0.0b46.dist-info/licenses/LICENSE +201 -0
  188. flyte/_api_commons.py +0 -3
  189. flyte/_cli/_common.py +0 -287
  190. flyte/_cli/_create.py +0 -42
  191. flyte/_cli/_delete.py +0 -23
  192. flyte/_cli/_deploy.py +0 -140
  193. flyte/_cli/_get.py +0 -235
  194. flyte/_cli/_run.py +0 -152
  195. flyte/_cli/main.py +0 -72
  196. flyte/_datastructures.py +0 -342
  197. flyte/_internal/controllers/pbhash.py +0 -39
  198. flyte/_protos/common/authorization_pb2.py +0 -66
  199. flyte/_protos/common/authorization_pb2.pyi +0 -108
  200. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  201. flyte/_protos/common/identifier_pb2.py +0 -71
  202. flyte/_protos/common/identifier_pb2.pyi +0 -82
  203. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  204. flyte/_protos/common/identity_pb2.py +0 -48
  205. flyte/_protos/common/identity_pb2.pyi +0 -72
  206. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  207. flyte/_protos/common/list_pb2.py +0 -36
  208. flyte/_protos/common/list_pb2.pyi +0 -69
  209. flyte/_protos/common/list_pb2_grpc.py +0 -4
  210. flyte/_protos/common/policy_pb2.py +0 -37
  211. flyte/_protos/common/policy_pb2.pyi +0 -27
  212. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  213. flyte/_protos/common/role_pb2.py +0 -37
  214. flyte/_protos/common/role_pb2.pyi +0 -53
  215. flyte/_protos/common/role_pb2_grpc.py +0 -4
  216. flyte/_protos/common/runtime_version_pb2.py +0 -28
  217. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  218. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  219. flyte/_protos/logs/dataplane/payload_pb2.py +0 -96
  220. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -168
  221. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  222. flyte/_protos/secret/definition_pb2.py +0 -49
  223. flyte/_protos/secret/definition_pb2.pyi +0 -93
  224. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  225. flyte/_protos/secret/payload_pb2.py +0 -62
  226. flyte/_protos/secret/payload_pb2.pyi +0 -94
  227. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  228. flyte/_protos/secret/secret_pb2.py +0 -38
  229. flyte/_protos/secret/secret_pb2.pyi +0 -6
  230. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  231. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  232. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  233. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  234. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  235. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  236. flyte/_protos/workflow/queue_service_pb2.py +0 -106
  237. flyte/_protos/workflow/queue_service_pb2.pyi +0 -141
  238. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  239. flyte/_protos/workflow/run_definition_pb2.py +0 -128
  240. flyte/_protos/workflow/run_definition_pb2.pyi +0 -310
  241. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  242. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  243. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  244. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  245. flyte/_protos/workflow/run_service_pb2.py +0 -133
  246. flyte/_protos/workflow/run_service_pb2.pyi +0 -175
  247. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -412
  248. flyte/_protos/workflow/state_service_pb2.py +0 -58
  249. flyte/_protos/workflow/state_service_pb2.pyi +0 -71
  250. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  251. flyte/_protos/workflow/task_definition_pb2.py +0 -72
  252. flyte/_protos/workflow/task_definition_pb2.pyi +0 -65
  253. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  254. flyte/_protos/workflow/task_service_pb2.py +0 -44
  255. flyte/_protos/workflow/task_service_pb2.pyi +0 -31
  256. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -104
  257. flyte/io/_dataframe.py +0 -0
  258. flyte/io/pickle/__init__.py +0 -0
  259. flyte/remote/_console.py +0 -18
  260. flyte-0.0.1b0.dist-info/METADATA +0 -179
  261. flyte-0.0.1b0.dist-info/RECORD +0 -390
  262. flyte-0.0.1b0.dist-info/entry_points.txt +0 -3
  263. union/__init__.py +0 -54
  264. union/_api_commons.py +0 -3
  265. union/_bin/__init__.py +0 -0
  266. union/_bin/runtime.py +0 -113
  267. union/_build.py +0 -25
  268. union/_cache/__init__.py +0 -12
  269. union/_cache/cache.py +0 -141
  270. union/_cache/defaults.py +0 -9
  271. union/_cache/policy_function_body.py +0 -42
  272. union/_cli/__init__.py +0 -0
  273. union/_cli/_common.py +0 -263
  274. union/_cli/_create.py +0 -40
  275. union/_cli/_delete.py +0 -23
  276. union/_cli/_deploy.py +0 -120
  277. union/_cli/_get.py +0 -162
  278. union/_cli/_run.py +0 -150
  279. union/_cli/main.py +0 -72
  280. union/_code_bundle/__init__.py +0 -8
  281. union/_code_bundle/_ignore.py +0 -113
  282. union/_code_bundle/_packaging.py +0 -187
  283. union/_code_bundle/_utils.py +0 -342
  284. union/_code_bundle/bundle.py +0 -176
  285. union/_context.py +0 -146
  286. union/_datastructures.py +0 -295
  287. union/_deploy.py +0 -185
  288. union/_doc.py +0 -29
  289. union/_docstring.py +0 -26
  290. union/_environment.py +0 -43
  291. union/_group.py +0 -31
  292. union/_hash.py +0 -23
  293. union/_image.py +0 -760
  294. union/_initialize.py +0 -585
  295. union/_interface.py +0 -84
  296. union/_internal/__init__.py +0 -3
  297. union/_internal/controllers/__init__.py +0 -77
  298. union/_internal/controllers/_local_controller.py +0 -77
  299. union/_internal/controllers/pbhash.py +0 -39
  300. union/_internal/controllers/remote/__init__.py +0 -40
  301. union/_internal/controllers/remote/_action.py +0 -131
  302. union/_internal/controllers/remote/_client.py +0 -43
  303. union/_internal/controllers/remote/_controller.py +0 -169
  304. union/_internal/controllers/remote/_core.py +0 -341
  305. union/_internal/controllers/remote/_informer.py +0 -260
  306. union/_internal/controllers/remote/_service_protocol.py +0 -44
  307. union/_internal/imagebuild/__init__.py +0 -11
  308. union/_internal/imagebuild/docker_builder.py +0 -416
  309. union/_internal/imagebuild/image_builder.py +0 -243
  310. union/_internal/imagebuild/remote_builder.py +0 -0
  311. union/_internal/resolvers/__init__.py +0 -0
  312. union/_internal/resolvers/_task_module.py +0 -31
  313. union/_internal/resolvers/common.py +0 -24
  314. union/_internal/resolvers/default.py +0 -27
  315. union/_internal/runtime/__init__.py +0 -0
  316. union/_internal/runtime/convert.py +0 -163
  317. union/_internal/runtime/entrypoints.py +0 -121
  318. union/_internal/runtime/io.py +0 -136
  319. union/_internal/runtime/resources_serde.py +0 -134
  320. union/_internal/runtime/task_serde.py +0 -202
  321. union/_internal/runtime/taskrunner.py +0 -179
  322. union/_internal/runtime/types_serde.py +0 -53
  323. union/_logging.py +0 -124
  324. union/_protos/__init__.py +0 -0
  325. union/_protos/common/authorization_pb2.py +0 -66
  326. union/_protos/common/authorization_pb2.pyi +0 -106
  327. union/_protos/common/authorization_pb2_grpc.py +0 -4
  328. union/_protos/common/identifier_pb2.py +0 -71
  329. union/_protos/common/identifier_pb2.pyi +0 -82
  330. union/_protos/common/identifier_pb2_grpc.py +0 -4
  331. union/_protos/common/identity_pb2.py +0 -48
  332. union/_protos/common/identity_pb2.pyi +0 -72
  333. union/_protos/common/identity_pb2_grpc.py +0 -4
  334. union/_protos/common/list_pb2.py +0 -36
  335. union/_protos/common/list_pb2.pyi +0 -69
  336. union/_protos/common/list_pb2_grpc.py +0 -4
  337. union/_protos/common/policy_pb2.py +0 -37
  338. union/_protos/common/policy_pb2.pyi +0 -27
  339. union/_protos/common/policy_pb2_grpc.py +0 -4
  340. union/_protos/common/role_pb2.py +0 -37
  341. union/_protos/common/role_pb2.pyi +0 -51
  342. union/_protos/common/role_pb2_grpc.py +0 -4
  343. union/_protos/common/runtime_version_pb2.py +0 -28
  344. union/_protos/common/runtime_version_pb2.pyi +0 -24
  345. union/_protos/common/runtime_version_pb2_grpc.py +0 -4
  346. union/_protos/logs/dataplane/payload_pb2.py +0 -96
  347. union/_protos/logs/dataplane/payload_pb2.pyi +0 -168
  348. union/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  349. union/_protos/secret/definition_pb2.py +0 -49
  350. union/_protos/secret/definition_pb2.pyi +0 -93
  351. union/_protos/secret/definition_pb2_grpc.py +0 -4
  352. union/_protos/secret/payload_pb2.py +0 -62
  353. union/_protos/secret/payload_pb2.pyi +0 -94
  354. union/_protos/secret/payload_pb2_grpc.py +0 -4
  355. union/_protos/secret/secret_pb2.py +0 -38
  356. union/_protos/secret/secret_pb2.pyi +0 -6
  357. union/_protos/secret/secret_pb2_grpc.py +0 -198
  358. union/_protos/validate/validate/validate_pb2.py +0 -76
  359. union/_protos/workflow/node_execution_service_pb2.py +0 -26
  360. union/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  361. union/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  362. union/_protos/workflow/queue_service_pb2.py +0 -75
  363. union/_protos/workflow/queue_service_pb2.pyi +0 -103
  364. union/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  365. union/_protos/workflow/run_definition_pb2.py +0 -100
  366. union/_protos/workflow/run_definition_pb2.pyi +0 -256
  367. union/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  368. union/_protos/workflow/run_logs_service_pb2.py +0 -41
  369. union/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  370. union/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  371. union/_protos/workflow/run_service_pb2.py +0 -133
  372. union/_protos/workflow/run_service_pb2.pyi +0 -173
  373. union/_protos/workflow/run_service_pb2_grpc.py +0 -412
  374. union/_protos/workflow/state_service_pb2.py +0 -58
  375. union/_protos/workflow/state_service_pb2.pyi +0 -69
  376. union/_protos/workflow/state_service_pb2_grpc.py +0 -138
  377. union/_protos/workflow/task_definition_pb2.py +0 -72
  378. union/_protos/workflow/task_definition_pb2.pyi +0 -65
  379. union/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  380. union/_protos/workflow/task_service_pb2.py +0 -44
  381. union/_protos/workflow/task_service_pb2.pyi +0 -31
  382. union/_protos/workflow/task_service_pb2_grpc.py +0 -104
  383. union/_resources.py +0 -226
  384. union/_retry.py +0 -32
  385. union/_reusable_environment.py +0 -25
  386. union/_run.py +0 -374
  387. union/_secret.py +0 -61
  388. union/_task.py +0 -354
  389. union/_task_environment.py +0 -186
  390. union/_timeout.py +0 -47
  391. union/_tools.py +0 -27
  392. union/_utils/__init__.py +0 -11
  393. union/_utils/asyn.py +0 -119
  394. union/_utils/file_handling.py +0 -71
  395. union/_utils/helpers.py +0 -46
  396. union/_utils/lazy_module.py +0 -54
  397. union/_utils/uv_script_parser.py +0 -49
  398. union/_version.py +0 -21
  399. union/connectors/__init__.py +0 -0
  400. union/errors.py +0 -128
  401. union/extras/__init__.py +0 -5
  402. union/extras/_container.py +0 -263
  403. union/io/__init__.py +0 -11
  404. union/io/_dataframe.py +0 -0
  405. union/io/_dir.py +0 -425
  406. union/io/_file.py +0 -418
  407. union/io/pickle/__init__.py +0 -0
  408. union/io/pickle/transformer.py +0 -117
  409. union/io/structured_dataset/__init__.py +0 -122
  410. union/io/structured_dataset/basic_dfs.py +0 -219
  411. union/io/structured_dataset/structured_dataset.py +0 -1057
  412. union/py.typed +0 -0
  413. union/remote/__init__.py +0 -23
  414. union/remote/_client/__init__.py +0 -0
  415. union/remote/_client/_protocols.py +0 -129
  416. union/remote/_client/auth/__init__.py +0 -12
  417. union/remote/_client/auth/_authenticators/__init__.py +0 -0
  418. union/remote/_client/auth/_authenticators/base.py +0 -391
  419. union/remote/_client/auth/_authenticators/client_credentials.py +0 -73
  420. union/remote/_client/auth/_authenticators/device_code.py +0 -120
  421. union/remote/_client/auth/_authenticators/external_command.py +0 -77
  422. union/remote/_client/auth/_authenticators/factory.py +0 -200
  423. union/remote/_client/auth/_authenticators/pkce.py +0 -515
  424. union/remote/_client/auth/_channel.py +0 -184
  425. union/remote/_client/auth/_client_config.py +0 -83
  426. union/remote/_client/auth/_default_html.py +0 -32
  427. union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
  428. union/remote/_client/auth/_grpc_utils/auth_interceptor.py +0 -204
  429. union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +0 -144
  430. union/remote/_client/auth/_keyring.py +0 -154
  431. union/remote/_client/auth/_token_client.py +0 -258
  432. union/remote/_client/auth/errors.py +0 -16
  433. union/remote/_client/controlplane.py +0 -86
  434. union/remote/_data.py +0 -149
  435. union/remote/_logs.py +0 -74
  436. union/remote/_project.py +0 -86
  437. union/remote/_run.py +0 -820
  438. union/remote/_secret.py +0 -132
  439. union/remote/_task.py +0 -193
  440. union/report/__init__.py +0 -3
  441. union/report/_report.py +0 -178
  442. union/report/_template.html +0 -124
  443. union/storage/__init__.py +0 -24
  444. union/storage/_remote_fs.py +0 -34
  445. union/storage/_storage.py +0 -247
  446. union/storage/_utils.py +0 -5
  447. union/types/__init__.py +0 -11
  448. union/types/_renderer.py +0 -162
  449. union/types/_string_literals.py +0 -120
  450. union/types/_type_engine.py +0 -2131
  451. union/types/_utils.py +0 -80
  452. /flyte/{_cli → _debug}/__init__.py +0 -0
  453. /flyte/{_protos → _keyring}/__init__.py +0 -0
  454. {flyte-0.0.1b0.dist-info → flyte-2.0.0b46.dist-info}/WHEEL +0 -0
  455. {flyte-0.0.1b0.dist-info → flyte-2.0.0b46.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,563 @@
1
+ """
2
+ HuggingFace model prefetch utilities for Flyte.
3
+
4
+ This module provides functionality to prefetch HuggingFace models to remote storage,
5
+ with support for optional sharding using vLLM.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import os
11
+ import re
12
+ import shutil
13
+ import tempfile
14
+ import typing
15
+ from typing import TYPE_CHECKING, Any, Literal
16
+
17
+ from pydantic import BaseModel, Field
18
+
19
+ from flyte._logging import logger
20
+ from flyte._resources import Resources
21
+ from flyte._task_environment import TaskEnvironment
22
+ from flyte.io import Dir
23
+
24
+ if TYPE_CHECKING:
25
+ from flyte.remote import Run
26
+
27
+
28
# Default filename template for sharded weight files. It is passed as the `pattern`
# argument of vLLM's save_sharded_state in _shard_model; {rank} and {part} are placeholders.
DEFAULT_SHARD_PATTERN: str = "model-rank-{rank}-part-{part}.safetensors"
29
+
30
+
31
class VLLMShardArgs(BaseModel):
    """
    Options controlling how vLLM loads a model and writes it back out as shards.

    :param tensor_parallel_size: How many tensor-parallel workers vLLM should use.
    :param dtype: Weight data type handed to vLLM (defaults to "auto").
    :param trust_remote_code: Whether custom code shipped with the HuggingFace repo may be executed.
    :param max_model_len: Optional cap on the model context length; left out of the vLLM
        arguments entirely when None.
    :param file_pattern: Filename template for the sharded weight files.
    :param max_file_size: Upper bound on the size of each sharded file (default 5GB).
    """

    tensor_parallel_size: int = 1
    dtype: str = "auto"
    trust_remote_code: bool = True
    max_model_len: int | None = None
    file_pattern: str | None = DEFAULT_SHARD_PATTERN
    max_file_size: int = 5 * 1024**3  # 5GB

    def get_vllm_args(self, model_path: str) -> dict[str, Any]:
        """Build the keyword arguments for the vLLM ``LLM`` constructor, pointed at ``model_path``."""
        # Only forward max_model_len when the caller actually set it.
        extra = {} if self.max_model_len is None else {"max_model_len": self.max_model_len}
        return {
            "model": model_path,
            "tensor_parallel_size": self.tensor_parallel_size,
            "dtype": self.dtype,
            "trust_remote_code": self.trust_remote_code,
            **extra,
        }
61
+
62
+
63
class ShardConfig(BaseModel):
    """
    Describes how a model should be sharded.

    :param engine: Which sharding backend to use; "vllm" is the only accepted value at the moment.
    :param args: Backend-specific options; a fresh ``VLLMShardArgs()`` is created when omitted.
    """

    engine: Literal["vllm"] = Field(default="vllm")
    args: VLLMShardArgs = Field(default_factory=VLLMShardArgs)
73
+
74
+
75
class HuggingFaceModelInfo(BaseModel):
    """
    Describes a HuggingFace model that should be stored.

    :param repo: HuggingFace repository ID, e.g. 'meta-llama/Llama-2-7b-hf'.
    :param artifact_name: Name to store the artifact under. When not provided, the repo
        name is used (with '.' replaced by '-').
    :param architecture: Model architecture as listed in the repo's config.json.
    :param task: Model task, e.g. 'generate', 'classify', 'embed'. Defaults to 'auto'.
    :param modality: Modalities the model supports, e.g. 'text', 'image'. Defaults to ('text',).
    :param serial_format: Serialization format of the weights, e.g. 'safetensors', 'onnx'.
    :param model_type: Model type, e.g. the ``model_type`` value from config.json such as 'llama'.
    :param short_description: A short human-readable description of the model.
    :param shard_config: Optional sharding configuration.
    """

    # Identity of the model and where to store it.
    repo: str
    artifact_name: str | None = None

    # Descriptive metadata.
    architecture: str | None = None
    task: str = "auto"
    modality: tuple[str, ...] = ("text",)
    serial_format: str | None = None
    model_type: str | None = None
    short_description: str | None = None

    # Optional post-processing.
    shard_config: ShardConfig | None = None
100
+
101
+
102
class StoredModelInfo(BaseModel):
    """
    Summary of a model that has been written to storage.

    :param artifact_name: Name under which the artifact was stored.
    :param path: Location of the stored model directory.
    :param metadata: String key/value metadata describing the stored model.
    """

    artifact_name: str  # name the artifact was stored under
    path: str  # stored model directory
    metadata: dict[str, str]  # free-form string metadata
114
+
115
+
116
# Python packages for the image used by the store task when no sharding is needed.
HF_DOWNLOAD_IMAGE_PACKAGES = ["huggingface-hub>=0.27.0", "hf-transfer>=0.1.8", "markdown>=3.10"]

# The sharding image needs everything above, plus vLLM (a new list, not an alias).
VLLM_SHARDING_IMAGE_PACKAGES = HF_DOWNLOAD_IMAGE_PACKAGES + ["vllm>=0.11.0"]
127
+
128
+
129
+ def _validate_artifact_name(name: str | None) -> None:
130
+ """Validate that artifact name contains only allowed characters."""
131
+ if name is not None and not re.match(r"^[a-zA-Z0-9_-]+$", name):
132
+ raise ValueError(f"Artifact name '{name}' must only contain alphanumeric characters, underscores, and hyphens")
133
+
134
+
135
+ def _lookup_huggingface_model_info(model_repo: str, commit: str, token: str | None) -> tuple[str | None, str | None]:
136
+ """
137
+ Lookup HuggingFace model info from config.json.
138
+
139
+ :param model_repo: The model repository ID.
140
+ :param commit: The commit ID.
141
+ :param token: HuggingFace token for private models.
142
+ :return: Tuple of (model_type, architecture).
143
+ """
144
+ import json
145
+
146
+ import huggingface_hub
147
+
148
+ config_file = huggingface_hub.hf_hub_download(
149
+ repo_id=model_repo, filename="config.json", revision=commit, token=token
150
+ )
151
+ arch = None
152
+ model_type = None
153
+ with open(config_file, "r") as f:
154
+ j = json.load(f)
155
+ arch = j.get("architecture", None)
156
+ if arch is None:
157
+ arch = j.get("architectures", None)
158
+ if arch:
159
+ arch = ",".join(arch)
160
+ model_type = j.get("model_type", None)
161
+ return model_type, arch
162
+
163
+
164
def _stream_to_remote_dir(
    repo_id: str,
    commit: str,
    token: str | None,
    remote_dir_path: str,
) -> tuple[str, str | None]:
    """
    Stream every file of a HuggingFace repo revision directly into a remote directory.

    Files are copied through without being staged on local disk. Their paths relative to
    the repository root are preserved, so files inside repo sub-folders land in the same
    sub-folders under ``remote_dir_path``.

    :param repo_id: The HuggingFace repository ID.
    :param commit: The commit ID (revision) to read.
    :param token: HuggingFace token, or None.
    :param remote_dir_path: Destination directory; each file is written to
        ``{remote_dir_path}/{path_in_repo}``.
    :return: Tuple of (remote_dir_path, readme_content). readme_content is None when the
        repo has no README.md.
    """
    import huggingface_hub

    import flyte.storage as storage

    hfs = huggingface_hub.HfFileSystem(token=token)
    fs = storage.get_underlying_filesystem(path=remote_dir_path)
    card = None

    # Try to get README (decoded explicitly as UTF-8 so the result is locale-independent).
    try:
        readme_name = hfs.info(f"{repo_id}/README.md", revision=commit)["name"]
        with tempfile.NamedTemporaryFile() as temp_file:
            hfs.download(readme_name, temp_file.name, revision=commit)
            with open(temp_file.name, "r", encoding="utf-8") as f:
                card = f.read()
    except FileNotFoundError:
        logger.info("No README.md file found")

    # list_repo_files walks the whole repo tree. HfFileSystem.ls only lists the top level,
    # which silently dropped files living in sub-folders.
    repo_files = huggingface_hub.list_repo_files(repo_id, revision=commit, token=token)

    logger.info(f"Streaming {len(repo_files)} files to {remote_dir_path}")

    # Avoid producing "dir//file" when the caller passes a trailing slash.
    dest_root = remote_dir_path.rstrip("/")
    chunk_size = 64 * 1024 * 1024  # 64MB chunks keep memory bounded for large weight files
    for path_in_repo in repo_files:
        remote_file_path = f"{dest_root}/{path_in_repo}"
        logger.info(f"  Streaming {path_in_repo}...")

        # Stream file content directly to remote
        with hfs.open(f"{repo_id}/{path_in_repo}", "rb", revision=commit) as src:
            with fs.open(remote_file_path, "wb") as dst:
                shutil.copyfileobj(src, dst, chunk_size)

    return remote_dir_path, card
224
+
225
+
226
def _download_snapshot_to_local(
    repo_id: str,
    commit: str,
    token: str | None,
    local_dir: str,
) -> tuple[str, str | None]:
    """
    Download a model snapshot from HuggingFace into a local directory.

    :param repo_id: The HuggingFace repository ID.
    :param commit: The commit ID (revision) to download.
    :param token: HuggingFace token, or None.
    :param local_dir: Local directory to download to.
    :return: Tuple of (local_dir, readme_content). readme_content is None when the repo
        has no README.md.
    """
    import huggingface_hub

    card = None
    hfs = huggingface_hub.HfFileSystem(token=token)

    # Try to get README. Decode explicitly as UTF-8 so reading a model card that contains
    # non-ASCII text does not depend on the container's locale encoding.
    try:
        readme_file_details = hfs.info(f"{repo_id}/README.md", revision=commit)
        readme_name = readme_file_details["name"]
        with tempfile.NamedTemporaryFile() as temp_file:
            hfs.download(readme_name, temp_file.name, revision=commit)
            with open(temp_file.name, "r", encoding="utf-8") as f:
                card = f.read()
    except FileNotFoundError:
        logger.info("No README.md file found")

    logger.info(f"Downloading model from {repo_id} to {local_dir}")
    huggingface_hub.snapshot_download(
        repo_id=repo_id,
        revision=commit,
        local_dir=local_dir,
        token=token,
    )
    return local_dir, card
265
+
266
+
267
def _shard_model(
    repo: str,
    commit: str,
    shard_config: ShardConfig,
    token: str | None,
    model_path: str,
    output_dir: str,
) -> tuple[str, str | None]:
    """
    Download a HuggingFace model snapshot and re-save its weights as shards using vLLM.

    :param repo: The HuggingFace repository ID.
    :param commit: The commit ID (revision) to download.
    :param shard_config: Sharding configuration; ``engine`` must be "vllm".
    :param token: HuggingFace token.
    :param model_path: Local directory the original snapshot is downloaded into.
    :param output_dir: Directory that receives the sharded weights plus all non-weight files.
    :return: Tuple of (output_dir, readme_content). readme_content is None when the repo
        has no README.md.
    :raises AssertionError: If ``shard_config.engine`` is not "vllm".
    """
    import huggingface_hub
    import vllm

    assert shard_config.engine == "vllm", "'vllm' is the only supported sharding engine for now"

    # Must be initialised up front: previously a repo without README.md left `card`
    # unbound and the final `return` raised UnboundLocalError.
    card = None

    # Fetch the model card if present (decoded as UTF-8, independent of locale).
    hfs = huggingface_hub.HfFileSystem(token=token)
    try:
        readme_info = hfs.info(f"{repo}/README.md", revision=commit)
        with tempfile.NamedTemporaryFile() as temp_file:
            hfs.download(readme_info["name"], temp_file.name, revision=commit)
            with open(temp_file.name, "r", encoding="utf-8") as f:
                card = f.read()
    except FileNotFoundError:
        logger.warning("No README.md found")

    # Download snapshot
    logger.info(f"Downloading model to {model_path}")
    huggingface_hub.snapshot_download(
        repo_id=repo,
        revision=commit,
        local_dir=model_path,
        token=token,
    )

    # Create LLM instance
    llm = vllm.LLM(**shard_config.args.get_vllm_args(model_path))
    logger.info(f"LLM initialized: {llm}")

    llm.llm_engine.engine_core.save_sharded_state(
        path=output_dir,
        pattern=shard_config.args.file_pattern,
        max_size=shard_config.args.max_file_size,
    )

    # Copy every entry of the snapshot whose extension is not a raw weight format
    # (.bin/.pt/.safetensors) into output_dir, next to the shards.
    logger.info(f"Copying metadata files to {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    weight_extensions = (".bin", ".pt", ".safetensors")
    for file in os.listdir(model_path):
        if os.path.splitext(file)[1] in weight_extensions:
            continue
        src_path = os.path.join(model_path, file)
        dst_path = os.path.join(output_dir, file)
        if os.path.isdir(src_path):
            # dirs_exist_ok: don't fail if output_dir already contains this sub-directory.
            shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
        else:
            shutil.copy(src_path, dst_path)

    return output_dir, card
329
+
330
+
331
# NOTE: the info argument is a json string instead of a HuggingFaceModelInfo
# object because the type engine cannot handle nested pydantic or dataclass
# objects when run in interactive mode.
def store_hf_model_task(info: str, raw_data_path: str | None = None) -> Dir:
    """
    Task to store a HuggingFace model.

    The model is stored in one of two ways:

    * If ``shard_config`` is set, the snapshot is downloaded locally and sharded
      with vLLM, and the sharded directory is uploaded.
    * Otherwise, the files are streamed directly to remote storage. If streaming
      fails, the task falls back to downloading a snapshot and uploading it.

    If the repository has a README.md, it is rendered into the task report
    (as HTML when the ``markdown`` package is importable, plain text otherwise).

    :param info: JSON-serialized ``HuggingFaceModelInfo`` describing the model.
    :param raw_data_path: Optional remote destination for the stored model.
    :return: A ``Dir`` referencing the stored model.
    :raises ValueError: If HF_TOKEN is unset, the repository does not exist, or
        the repository has no commits.
    """
    import huggingface_hub

    import flyte.report

    # Get HF token from secrets (hf_model() maps the configured secret onto HF_TOKEN).
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    token = os.environ.get("HF_TOKEN")
    if token is None:
        raise ValueError("HF_TOKEN environment variable is not set")

    # Validate repo exists and get latest commit
    _info: HuggingFaceModelInfo = HuggingFaceModelInfo.model_validate_json(info)
    if not huggingface_hub.repo_exists(_info.repo, token=token):
        raise ValueError(f"Repository {_info.repo} does not exist in HuggingFace.")

    commits = huggingface_hub.list_repo_commits(_info.repo, token=token)
    if not commits:
        raise ValueError(f"Repository {_info.repo} has no commits in HuggingFace.")
    commit = commits[0].commit_id
    logger.info(f"Latest commit: {commit}")

    # Lookup model info if not provided
    if not _info.model_type or not _info.architecture:
        logger.info("Looking up HuggingFace model info...")
        try:
            _info.model_type, _info.architecture = _lookup_huggingface_model_info(_info.repo, commit, token)
        except Exception as e:
            # Best effort: the lookup is informational only, so fall back to "custom".
            logger.warning(f"Warning: Could not lookup model info: {e}")
            _info.model_type = "custom"
            _info.architecture = "custom"

    logger.info(f"Model type: {_info.model_type}, architecture: {_info.architecture}")

    # Determine artifact name
    if _info.artifact_name is None:
        artifact_name = _info.repo.split("/")[-1].replace(".", "-")
    else:
        artifact_name = _info.artifact_name

    card = None
    result_dir: Dir

    # If sharding is needed, we must download locally first
    if _info.shard_config is not None:
        logger.info(f"Sharding requested with {_info.shard_config.engine} engine")

        # The sharded output lives in a TemporaryDirectory (it was previously a
        # mkdtemp() that was never removed). As in the snapshot fallback below,
        # it is uploaded before the directory is cleaned up.
        with tempfile.TemporaryDirectory() as sharded_dir:
            # The raw snapshot is only needed while sharding; it is removed
            # before the upload starts to free disk space.
            with tempfile.TemporaryDirectory() as local_model_dir:
                output_dir, card = _shard_model(
                    _info.repo, commit, _info.shard_config, token, local_model_dir, sharded_dir
                )

            # Upload sharded model
            logger.info("Uploading sharded model...")
            result_dir = Dir.from_local_sync(output_dir, remote_destination=raw_data_path)

    else:
        # Try direct streaming first
        try:
            logger.info("Attempting direct streaming to remote storage...")

            if raw_data_path is not None:
                remote_path = raw_data_path
            else:
                remote_path = flyte.ctx().raw_data_path.get_random_remote_path(artifact_name)  # type: ignore [union-attr]

            remote_path, card = _stream_to_remote_dir(_info.repo, commit, token, remote_path)
            result_dir = Dir.from_existing_remote(remote_path)
            logger.info(f"Direct streaming completed to {remote_path}")

        except Exception as e:
            logger.error(f"Direct streaming failed: {e}")
            logger.error("Falling back to snapshot download...")

            # Fallback: download snapshot and upload
            with tempfile.TemporaryDirectory() as local_model_dir:
                _local_model_dir, card = _download_snapshot_to_local(_info.repo, commit, token, local_model_dir)
                result_dir = Dir.from_local_sync(_local_model_dir, remote_destination=raw_data_path)

    # create report from the markdown `card`
    if card:
        # Render to HTML when the optional `markdown` package is available. Any
        # failure (missing package or rendering error) falls back to the raw
        # markdown text, because the report is best-effort.
        try:
            import markdown

            report = markdown.markdown(card)
        except Exception:
            report = card
        flyte.report.log(report)
        flyte.report.flush()

    logger.info(f"Model stored successfully at {result_dir.path}")
    return result_dir
427
+
428
+
429
def hf_model(
    repo: str,
    *,
    raw_data_path: str | None = None,
    artifact_name: str | None = None,
    architecture: str | None = None,
    task: str = "auto",
    modality: tuple[str, ...] = ("text",),
    serial_format: str | None = None,
    model_type: str | None = None,
    short_description: str | None = None,
    shard_config: ShardConfig | None = None,
    hf_token_key: str = "HF_TOKEN",
    resources: Resources = Resources(cpu="2", memory="8Gi", disk="50Gi"),
    force: int = 0,
) -> Run:
    """
    Store a HuggingFace model to remote storage.

    This function downloads a model from the HuggingFace Hub and prefetches it to
    remote storage. It supports optional sharding using vLLM for large models.

    The prefetch behavior follows this priority:
    1. If the model isn't being sharded, stream files directly to remote storage.
    2. If streaming fails, fall back to downloading a snapshot and uploading.
    3. If sharding is configured, download locally, shard with vLLM, then upload.

    Example usage:

    ```python
    import flyte

    flyte.init(endpoint="my-flyte-endpoint")

    # Store a model without sharding
    run = flyte.prefetch.hf_model(
        repo="meta-llama/Llama-2-7b-hf",
        hf_token_key="HF_TOKEN",
    )
    run.wait()

    # Prefetch and shard a model
    from flyte.prefetch import ShardConfig, VLLMShardArgs

    run = flyte.prefetch.hf_model(
        repo="meta-llama/Llama-2-70b-hf",
        shard_config=ShardConfig(
            engine="vllm",
            args=VLLMShardArgs(tensor_parallel_size=8),
        ),
        resources=flyte.Resources(cpu="8", memory="64Gi", disk="500Gi", gpu="A100:8"),
        hf_token_key="HF_TOKEN",
    )
    run.wait()
    ```

    :param repo: The HuggingFace repository ID (e.g., 'meta-llama/Llama-2-7b-hf').
    :param raw_data_path: Optional remote path to store the model at. If not
        provided, the prefetch task picks a location in the run's raw data storage.
    :param artifact_name: Optional name for the stored artifact. If not provided,
        the repo name will be used (with '.' replaced by '-').
    :param architecture: Model architecture from HuggingFace config.json.
    :param task: Model task (e.g., 'generate', 'classify', 'embed'). Default: 'auto'.
    :param modality: Modalities supported by the model. Default: ('text',).
    :param serial_format: Model serialization format (e.g., 'safetensors', 'onnx').
    :param model_type: Model type (e.g., 'transformer', 'custom').
    :param short_description: Short description of the model.
    :param shard_config: Optional configuration for model sharding with vLLM.
    :param hf_token_key: Name of the secret containing the HuggingFace token. Default: 'HF_TOKEN'.
    :param resources: Resources for the prefetch task (CPU, memory, disk and, for
        sharding, GPUs). Default: 2 CPUs, 8Gi memory, 50Gi disk.
    :param force: Force re-prefetch. Any value greater than 0 disables the run
        cache so the prefetch executes again. Default: 0.

    :return: A Run object representing the prefetch task execution.
    """
    import flyte
    from flyte import Secret
    from flyte.remote import Run

    _validate_artifact_name(artifact_name)

    info = HuggingFaceModelInfo(
        repo=repo,
        artifact_name=artifact_name,
        architecture=architecture,
        task=task,
        modality=modality,
        serial_format=serial_format,
        model_type=model_type,
        short_description=short_description,
        shard_config=shard_config,
    )

    # Select image based on whether sharding is needed. Sharding with vLLM needs
    # the CUDA toolkit; plain downloads only need the HF download packages.
    if shard_config is not None:
        image = (
            flyte.Image.from_debian_base(name="prefetch-hf-model-image")
            .with_apt_packages("gcc", "wget")
            .with_commands(
                [
                    "wget https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb",
                    "dpkg -i cuda-keyring_1.1-1_all.deb",
                    "apt-get update",
                    "apt-get install -y cuda-toolkit-12-9",
                ]
            )
            .with_env_vars(
                {
                    "CUDA_HOME": "/usr/local/cuda-12.9",
                    "LD_LIBRARY_PATH": "/usr/local/cuda-12.9/lib64/stubs",
                    "VLLM_USE_V1": "1",
                }
            )
            .with_pip_packages(*VLLM_SHARDING_IMAGE_PACKAGES)
        )
    else:
        image = flyte.Image.from_debian_base(name="prefetch-hf-model-image").with_pip_packages(
            *HF_DOWNLOAD_IMAGE_PACKAGES
        )

    # Create a task from the module-level function with the configured environment.
    # The user's secret is exposed to the task as the HF_TOKEN env var, which is
    # what store_hf_model_task reads.
    disable_run_cache = force > 0
    env = TaskEnvironment(
        name="prefetch-hf-model",
        image=image,
        resources=resources,
        secrets=[Secret(key=hf_token_key, as_env_var="HF_TOKEN")],
    )
    # Bound to `prefetch_task` (not `task`) so the `task` parameter is not shadowed.
    prefetch_task = env.task(report=True)(store_hf_model_task)
    run = flyte.with_runcontext(interactive_mode=True, disable_run_cache=disable_run_cache).run(
        prefetch_task, info.model_dump_json(), raw_data_path
    )
    return typing.cast(Run, run)
flyte/remote/__init__.py CHANGED
@@ -7,19 +7,30 @@ __all__ = [
7
7
  "ActionDetails",
8
8
  "ActionInputs",
9
9
  "ActionOutputs",
10
+ "App",
10
11
  "Project",
11
12
  "Run",
12
13
  "RunDetails",
13
14
  "Secret",
15
+ "SecretTypes",
14
16
  "Task",
17
+ "TaskDetails",
18
+ "Trigger",
19
+ "User",
20
+ "auth_metadata",
15
21
  "create_channel",
16
22
  "upload_dir",
17
23
  "upload_file",
18
24
  ]
19
25
 
26
+ from ._action import Action, ActionDetails, ActionInputs, ActionOutputs
27
+ from ._app import App
28
+ from ._auth_metadata import auth_metadata
20
29
  from ._client.auth import create_channel
21
30
  from ._data import upload_dir, upload_file
22
31
  from ._project import Project
23
- from ._run import Action, ActionDetails, ActionInputs, ActionOutputs, Run, RunDetails
24
- from ._secret import Secret
25
- from ._task import Task
32
+ from ._run import Run, RunDetails
33
+ from ._secret import Secret, SecretTypes
34
+ from ._task import Task, TaskDetails
35
+ from ._trigger import Trigger
36
+ from ._user import User