flyte 0.2.0b1__py3-none-any.whl → 2.0.0b46__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266) hide show
  1. flyte/__init__.py +83 -30
  2. flyte/_bin/connect.py +61 -0
  3. flyte/_bin/debug.py +38 -0
  4. flyte/_bin/runtime.py +87 -19
  5. flyte/_bin/serve.py +351 -0
  6. flyte/_build.py +3 -2
  7. flyte/_cache/cache.py +6 -5
  8. flyte/_cache/local_cache.py +216 -0
  9. flyte/_code_bundle/_ignore.py +31 -5
  10. flyte/_code_bundle/_packaging.py +42 -11
  11. flyte/_code_bundle/_utils.py +57 -34
  12. flyte/_code_bundle/bundle.py +130 -27
  13. flyte/_constants.py +1 -0
  14. flyte/_context.py +21 -5
  15. flyte/_custom_context.py +73 -0
  16. flyte/_debug/constants.py +37 -0
  17. flyte/_debug/utils.py +17 -0
  18. flyte/_debug/vscode.py +315 -0
  19. flyte/_deploy.py +396 -75
  20. flyte/_deployer.py +109 -0
  21. flyte/_environment.py +94 -11
  22. flyte/_excepthook.py +37 -0
  23. flyte/_group.py +2 -1
  24. flyte/_hash.py +1 -16
  25. flyte/_image.py +544 -231
  26. flyte/_initialize.py +456 -316
  27. flyte/_interface.py +40 -5
  28. flyte/_internal/controllers/__init__.py +22 -8
  29. flyte/_internal/controllers/_local_controller.py +159 -35
  30. flyte/_internal/controllers/_trace.py +18 -10
  31. flyte/_internal/controllers/remote/__init__.py +38 -9
  32. flyte/_internal/controllers/remote/_action.py +82 -12
  33. flyte/_internal/controllers/remote/_client.py +6 -2
  34. flyte/_internal/controllers/remote/_controller.py +290 -64
  35. flyte/_internal/controllers/remote/_core.py +155 -95
  36. flyte/_internal/controllers/remote/_informer.py +40 -20
  37. flyte/_internal/controllers/remote/_service_protocol.py +2 -2
  38. flyte/_internal/imagebuild/__init__.py +2 -10
  39. flyte/_internal/imagebuild/docker_builder.py +391 -84
  40. flyte/_internal/imagebuild/image_builder.py +111 -55
  41. flyte/_internal/imagebuild/remote_builder.py +409 -0
  42. flyte/_internal/imagebuild/utils.py +79 -0
  43. flyte/_internal/resolvers/_app_env_module.py +92 -0
  44. flyte/_internal/resolvers/_task_module.py +5 -38
  45. flyte/_internal/resolvers/app_env.py +26 -0
  46. flyte/_internal/resolvers/common.py +8 -1
  47. flyte/_internal/resolvers/default.py +2 -2
  48. flyte/_internal/runtime/convert.py +319 -36
  49. flyte/_internal/runtime/entrypoints.py +106 -18
  50. flyte/_internal/runtime/io.py +71 -23
  51. flyte/_internal/runtime/resources_serde.py +21 -7
  52. flyte/_internal/runtime/reuse.py +125 -0
  53. flyte/_internal/runtime/rusty.py +196 -0
  54. flyte/_internal/runtime/task_serde.py +239 -66
  55. flyte/_internal/runtime/taskrunner.py +48 -8
  56. flyte/_internal/runtime/trigger_serde.py +162 -0
  57. flyte/_internal/runtime/types_serde.py +7 -16
  58. flyte/_keyring/file.py +115 -0
  59. flyte/_link.py +30 -0
  60. flyte/_logging.py +241 -42
  61. flyte/_map.py +312 -0
  62. flyte/_metrics.py +59 -0
  63. flyte/_module.py +74 -0
  64. flyte/_pod.py +30 -0
  65. flyte/_resources.py +296 -33
  66. flyte/_retry.py +1 -7
  67. flyte/_reusable_environment.py +72 -7
  68. flyte/_run.py +462 -132
  69. flyte/_secret.py +47 -11
  70. flyte/_serve.py +333 -0
  71. flyte/_task.py +245 -56
  72. flyte/_task_environment.py +219 -97
  73. flyte/_task_plugins.py +47 -0
  74. flyte/_tools.py +8 -8
  75. flyte/_trace.py +15 -24
  76. flyte/_trigger.py +1027 -0
  77. flyte/_utils/__init__.py +12 -1
  78. flyte/_utils/asyn.py +3 -1
  79. flyte/_utils/async_cache.py +139 -0
  80. flyte/_utils/coro_management.py +5 -4
  81. flyte/_utils/description_parser.py +19 -0
  82. flyte/_utils/docker_credentials.py +173 -0
  83. flyte/_utils/helpers.py +45 -19
  84. flyte/_utils/module_loader.py +123 -0
  85. flyte/_utils/org_discovery.py +57 -0
  86. flyte/_utils/uv_script_parser.py +8 -1
  87. flyte/_version.py +16 -3
  88. flyte/app/__init__.py +27 -0
  89. flyte/app/_app_environment.py +362 -0
  90. flyte/app/_connector_environment.py +40 -0
  91. flyte/app/_deploy.py +130 -0
  92. flyte/app/_parameter.py +343 -0
  93. flyte/app/_runtime/__init__.py +3 -0
  94. flyte/app/_runtime/app_serde.py +383 -0
  95. flyte/app/_types.py +113 -0
  96. flyte/app/extras/__init__.py +9 -0
  97. flyte/app/extras/_auth_middleware.py +217 -0
  98. flyte/app/extras/_fastapi.py +93 -0
  99. flyte/app/extras/_model_loader/__init__.py +3 -0
  100. flyte/app/extras/_model_loader/config.py +7 -0
  101. flyte/app/extras/_model_loader/loader.py +288 -0
  102. flyte/cli/__init__.py +12 -0
  103. flyte/cli/_abort.py +28 -0
  104. flyte/cli/_build.py +114 -0
  105. flyte/cli/_common.py +493 -0
  106. flyte/cli/_create.py +371 -0
  107. flyte/cli/_delete.py +45 -0
  108. flyte/cli/_deploy.py +401 -0
  109. flyte/cli/_gen.py +316 -0
  110. flyte/cli/_get.py +446 -0
  111. flyte/cli/_option.py +33 -0
  112. flyte/{_cli → cli}/_params.py +57 -17
  113. flyte/cli/_plugins.py +209 -0
  114. flyte/cli/_prefetch.py +292 -0
  115. flyte/cli/_run.py +690 -0
  116. flyte/cli/_serve.py +338 -0
  117. flyte/cli/_update.py +86 -0
  118. flyte/cli/_user.py +20 -0
  119. flyte/cli/main.py +246 -0
  120. flyte/config/__init__.py +2 -167
  121. flyte/config/_config.py +215 -163
  122. flyte/config/_internal.py +10 -1
  123. flyte/config/_reader.py +225 -0
  124. flyte/connectors/__init__.py +11 -0
  125. flyte/connectors/_connector.py +330 -0
  126. flyte/connectors/_server.py +194 -0
  127. flyte/connectors/utils.py +159 -0
  128. flyte/errors.py +134 -2
  129. flyte/extend.py +24 -0
  130. flyte/extras/_container.py +69 -56
  131. flyte/git/__init__.py +3 -0
  132. flyte/git/_config.py +279 -0
  133. flyte/io/__init__.py +8 -1
  134. flyte/io/{structured_dataset → _dataframe}/__init__.py +32 -30
  135. flyte/io/{structured_dataset → _dataframe}/basic_dfs.py +75 -68
  136. flyte/io/{structured_dataset/structured_dataset.py → _dataframe/dataframe.py} +207 -242
  137. flyte/io/_dir.py +575 -113
  138. flyte/io/_file.py +587 -141
  139. flyte/io/_hashing_io.py +342 -0
  140. flyte/io/extend.py +7 -0
  141. flyte/models.py +635 -0
  142. flyte/prefetch/__init__.py +22 -0
  143. flyte/prefetch/_hf_model.py +563 -0
  144. flyte/remote/__init__.py +14 -3
  145. flyte/remote/_action.py +879 -0
  146. flyte/remote/_app.py +346 -0
  147. flyte/remote/_auth_metadata.py +42 -0
  148. flyte/remote/_client/_protocols.py +62 -4
  149. flyte/remote/_client/auth/_auth_utils.py +19 -0
  150. flyte/remote/_client/auth/_authenticators/base.py +8 -2
  151. flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
  152. flyte/remote/_client/auth/_authenticators/factory.py +4 -0
  153. flyte/remote/_client/auth/_authenticators/passthrough.py +79 -0
  154. flyte/remote/_client/auth/_authenticators/pkce.py +17 -18
  155. flyte/remote/_client/auth/_channel.py +47 -18
  156. flyte/remote/_client/auth/_client_config.py +5 -3
  157. flyte/remote/_client/auth/_keyring.py +15 -2
  158. flyte/remote/_client/auth/_token_client.py +3 -3
  159. flyte/remote/_client/controlplane.py +206 -18
  160. flyte/remote/_common.py +66 -0
  161. flyte/remote/_data.py +107 -22
  162. flyte/remote/_logs.py +116 -33
  163. flyte/remote/_project.py +21 -19
  164. flyte/remote/_run.py +164 -631
  165. flyte/remote/_secret.py +72 -29
  166. flyte/remote/_task.py +387 -46
  167. flyte/remote/_trigger.py +368 -0
  168. flyte/remote/_user.py +43 -0
  169. flyte/report/_report.py +10 -6
  170. flyte/storage/__init__.py +13 -1
  171. flyte/storage/_config.py +237 -0
  172. flyte/storage/_parallel_reader.py +289 -0
  173. flyte/storage/_storage.py +268 -59
  174. flyte/syncify/__init__.py +56 -0
  175. flyte/syncify/_api.py +414 -0
  176. flyte/types/__init__.py +39 -0
  177. flyte/types/_interface.py +22 -7
  178. flyte/{io/pickle/transformer.py → types/_pickle.py} +37 -9
  179. flyte/types/_string_literals.py +8 -9
  180. flyte/types/_type_engine.py +226 -126
  181. flyte/types/_utils.py +1 -1
  182. flyte-2.0.0b46.data/scripts/debug.py +38 -0
  183. flyte-2.0.0b46.data/scripts/runtime.py +194 -0
  184. flyte-2.0.0b46.dist-info/METADATA +352 -0
  185. flyte-2.0.0b46.dist-info/RECORD +221 -0
  186. flyte-2.0.0b46.dist-info/entry_points.txt +8 -0
  187. flyte-2.0.0b46.dist-info/licenses/LICENSE +201 -0
  188. flyte/_api_commons.py +0 -3
  189. flyte/_cli/_common.py +0 -299
  190. flyte/_cli/_create.py +0 -42
  191. flyte/_cli/_delete.py +0 -23
  192. flyte/_cli/_deploy.py +0 -140
  193. flyte/_cli/_get.py +0 -235
  194. flyte/_cli/_run.py +0 -174
  195. flyte/_cli/main.py +0 -98
  196. flyte/_datastructures.py +0 -342
  197. flyte/_internal/controllers/pbhash.py +0 -39
  198. flyte/_protos/common/authorization_pb2.py +0 -66
  199. flyte/_protos/common/authorization_pb2.pyi +0 -108
  200. flyte/_protos/common/authorization_pb2_grpc.py +0 -4
  201. flyte/_protos/common/identifier_pb2.py +0 -71
  202. flyte/_protos/common/identifier_pb2.pyi +0 -82
  203. flyte/_protos/common/identifier_pb2_grpc.py +0 -4
  204. flyte/_protos/common/identity_pb2.py +0 -48
  205. flyte/_protos/common/identity_pb2.pyi +0 -72
  206. flyte/_protos/common/identity_pb2_grpc.py +0 -4
  207. flyte/_protos/common/list_pb2.py +0 -36
  208. flyte/_protos/common/list_pb2.pyi +0 -69
  209. flyte/_protos/common/list_pb2_grpc.py +0 -4
  210. flyte/_protos/common/policy_pb2.py +0 -37
  211. flyte/_protos/common/policy_pb2.pyi +0 -27
  212. flyte/_protos/common/policy_pb2_grpc.py +0 -4
  213. flyte/_protos/common/role_pb2.py +0 -37
  214. flyte/_protos/common/role_pb2.pyi +0 -53
  215. flyte/_protos/common/role_pb2_grpc.py +0 -4
  216. flyte/_protos/common/runtime_version_pb2.py +0 -28
  217. flyte/_protos/common/runtime_version_pb2.pyi +0 -24
  218. flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
  219. flyte/_protos/logs/dataplane/payload_pb2.py +0 -96
  220. flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -168
  221. flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
  222. flyte/_protos/secret/definition_pb2.py +0 -49
  223. flyte/_protos/secret/definition_pb2.pyi +0 -93
  224. flyte/_protos/secret/definition_pb2_grpc.py +0 -4
  225. flyte/_protos/secret/payload_pb2.py +0 -62
  226. flyte/_protos/secret/payload_pb2.pyi +0 -94
  227. flyte/_protos/secret/payload_pb2_grpc.py +0 -4
  228. flyte/_protos/secret/secret_pb2.py +0 -38
  229. flyte/_protos/secret/secret_pb2.pyi +0 -6
  230. flyte/_protos/secret/secret_pb2_grpc.py +0 -198
  231. flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
  232. flyte/_protos/validate/validate/validate_pb2.py +0 -76
  233. flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
  234. flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
  235. flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
  236. flyte/_protos/workflow/queue_service_pb2.py +0 -106
  237. flyte/_protos/workflow/queue_service_pb2.pyi +0 -141
  238. flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
  239. flyte/_protos/workflow/run_definition_pb2.py +0 -128
  240. flyte/_protos/workflow/run_definition_pb2.pyi +0 -310
  241. flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
  242. flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
  243. flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
  244. flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
  245. flyte/_protos/workflow/run_service_pb2.py +0 -133
  246. flyte/_protos/workflow/run_service_pb2.pyi +0 -175
  247. flyte/_protos/workflow/run_service_pb2_grpc.py +0 -412
  248. flyte/_protos/workflow/state_service_pb2.py +0 -58
  249. flyte/_protos/workflow/state_service_pb2.pyi +0 -71
  250. flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
  251. flyte/_protos/workflow/task_definition_pb2.py +0 -72
  252. flyte/_protos/workflow/task_definition_pb2.pyi +0 -65
  253. flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
  254. flyte/_protos/workflow/task_service_pb2.py +0 -44
  255. flyte/_protos/workflow/task_service_pb2.pyi +0 -31
  256. flyte/_protos/workflow/task_service_pb2_grpc.py +0 -104
  257. flyte/io/_dataframe.py +0 -0
  258. flyte/io/pickle/__init__.py +0 -0
  259. flyte/remote/_console.py +0 -18
  260. flyte-0.2.0b1.dist-info/METADATA +0 -179
  261. flyte-0.2.0b1.dist-info/RECORD +0 -204
  262. flyte-0.2.0b1.dist-info/entry_points.txt +0 -3
  263. /flyte/{_cli → _debug}/__init__.py +0 -0
  264. /flyte/{_protos → _keyring}/__init__.py +0 -0
  265. {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/WHEEL +0 -0
  266. {flyte-0.2.0b1.dist-info → flyte-2.0.0b46.dist-info}/top_level.txt +0 -0
flyte/cli/_plugins.py ADDED
@@ -0,0 +1,209 @@
1
+ """CLI Plugin System for Flyte.
2
+
3
+ This module provides a plugin system that allows external packages to:
4
+ 1. Register new top-level CLI commands (e.g., flyte my-command)
5
+ 2. Register new subcommands in existing groups (e.g., flyte get my-object)
6
+ 3. Modify behavior of existing commands via hooks
7
+
8
+ Plugins are discovered via Python entry points.
9
+
10
+ Entry Point Groups:
11
+ - flyte.plugins.cli.commands: Register new commands
12
+ - Entry point name "foo" -> flyte foo (top-level command)
13
+ - Entry point name "get.bar" -> flyte get bar (adds subcommand to get group)
14
+ - Note: At most one dot is supported. For nested groups, register the entire
15
+ group hierarchy as a top-level command (without dots).
16
+
17
+ - flyte.plugins.cli.hooks: Modify existing commands
18
+ - Entry point name "run" -> modifies flyte run
19
+ - Entry point name "get.project" -> modifies flyte get project
20
+ - Note: At most one dot is supported.
21
+
22
+ Example Plugin Package:
23
+ # In your-plugin/pyproject.toml
24
+ [project.entry-points."flyte.plugins.cli.commands"]
25
+ my-command = "your_plugin.cli:my_command"
26
+ get.my-object = "your_plugin.cli:get_my_object"
27
+
28
+ [project.entry-points."flyte.plugins.cli.hooks"]
29
+ run = "your_plugin.hooks:modify_run"
30
+
31
+ # In your-plugin/your_plugin/cli.py
32
+ import rich_click as click
33
+
34
+ @click.command()
35
+ def my_command():
36
+ '''My custom top-level command.'''
37
+ click.echo("Hello from plugin!")
38
+
39
+ @click.command()
40
+ def get_my_object():
41
+ '''Get my custom object.'''
42
+ click.echo("Getting my object...")
43
+
44
+ # In your-plugin/your_plugin/hooks.py
45
+ def modify_run(command):
46
+ '''Add behavior to flyte run command.'''
47
+ # Wrap invoke() instead of callback to ensure Click's full machinery runs
48
+ original_invoke = command.invoke
49
+
50
+ def wrapper(ctx):
51
+ # Do something before
52
+ click.echo("Plugin: Starting task...")
53
+
54
+ result = original_invoke(ctx)
55
+
56
+ # Do something after
57
+ click.echo("Plugin: Task completed!")
58
+ return result
59
+
60
+ command.invoke = wrapper
61
+ return command
62
+ """
63
+
64
+ from importlib.metadata import entry_points
65
+ from typing import Callable
66
+
67
+ import rich_click as click
68
+
69
+ from flyte._logging import logger
70
+
71
# Type alias for command hooks: a hook receives the currently registered
# click.Command and returns the command that should be registered in its place
# (either the same object, modified, or a replacement).
CommandHook = Callable[[click.Command], click.Command]
73
+
74
+
75
def discover_and_register_plugins(root_group: click.Group):
    """
    Find every installed CLI plugin and wire it into the given root group.

    Two passes are made, in this order:
    1. Command plugins are loaded and added to the CLI.
    2. Hook plugins are loaded and applied to commands that already exist.

    Args:
        root_group: The root Click command group (main CLI group)
    """
    # Commands go first; hooks are applied in the second pass.
    for load_plugins in (_load_command_plugins, _load_hook_plugins):
        load_plugins(root_group)
88
+
89
+
90
def _load_command_plugins(root_group: click.Group):
    """Discover ``flyte.plugins.cli.commands`` entry points and register them.

    An entry point name without a dot becomes a top-level command; a name with
    exactly one dot (``group.command``) is added as a subcommand of an existing
    group; names with more dots are rejected with an error log. Any exception
    raised while loading or registering a single plugin is logged and does not
    stop the remaining plugins from loading.
    """
    for ep in entry_points(group="flyte.plugins.cli.commands"):
        try:
            command = ep.load()
            if not isinstance(command, click.Command):
                logger.warning(f"Plugin {ep.name} did not return a click.Command, got {type(command).__name__}")
                continue

            # The number of dot-separated segments decides where the command goes.
            name_parts = ep.name.split(".")
            if len(name_parts) == 1:
                # Top-level command
                root_group.add_command(command, name=ep.name)
                logger.info(f"Registered plugin command: flyte {ep.name}")
            elif len(name_parts) == 2:
                # Exactly one level of nesting: group.command
                group_name, command_name = name_parts
                _add_subcommand_to_group(root_group, group_name, command_name, command)
            else:
                logger.error(
                    f"Plugin {ep.name} uses multiple dots, which is not supported. "
                    f"Use at most one dot (e.g., 'group.command'). "
                    f"For nested groups, register the entire group hierarchy as a top-level command."
                )

        except Exception as e:
            logger.error(f"Failed to load plugin command {ep.name}: {e}")
120
+
121
+
122
def _load_hook_plugins(root_group: click.Group):
    """Discover ``flyte.plugins.cli.hooks`` entry points and apply them.

    An entry point name without a dot targets a top-level command; a name with
    exactly one dot (``group.command``) targets a subcommand; names with more
    dots are rejected with an error log. Any exception raised while loading or
    applying a single hook is logged and does not stop the remaining hooks.
    """
    for ep in entry_points(group="flyte.plugins.cli.hooks"):
        try:
            hook = ep.load()
            if not callable(hook):
                logger.warning(f"Plugin hook {ep.name} is not callable")
                continue

            # The number of dot-separated segments decides which command is hooked.
            name_parts = ep.name.split(".")
            if len(name_parts) == 1:
                # Top-level command hook
                _apply_hook_to_command(root_group, ep.name, hook)
            elif len(name_parts) == 2:
                # Exactly one level of nesting: group.command
                group_name, command_name = name_parts
                _apply_hook_to_subcommand(root_group, group_name, command_name, hook)
            else:
                logger.error(
                    f"Plugin hook {ep.name} uses multiple dots, which is not supported. "
                    f"Use at most one dot (e.g., 'group.command')."
                )

        except Exception as e:
            logger.error(f"Failed to apply hook {ep.name}: {e}")
150
+
151
+
152
def _add_subcommand_to_group(root_group: click.Group, group_name: str, command_name: str, command: click.Command):
    """Register ``command`` as ``command_name`` under an existing group of the root CLI.

    Nothing is registered (a warning is logged instead) when ``group_name`` is
    not a command of ``root_group`` or when it is a plain command rather than
    a ``click.Group``.
    """
    if group_name not in root_group.commands:
        logger.warning(f"Cannot add plugin subcommand '{command_name}' - group '{group_name}' does not exist")
    elif not isinstance(root_group.commands[group_name], click.Group):
        logger.warning(f"Cannot add plugin subcommand '{command_name}' - '{group_name}' is not a command group")
    else:
        root_group.commands[group_name].add_command(command, name=command_name)
        # lower to debug later
        logger.info(f"Registered plugin subcommand: flyte {group_name} {command_name}")
166
+
167
+
168
def _apply_hook_to_command(root_group: click.Group, command_name: str, hook: CommandHook):
    """Run ``hook`` on a top-level command and register whatever it returns.

    A warning is logged and nothing changes when ``command_name`` is not a
    command of ``root_group``. If the hook raises, the error is logged and the
    original command is put back.
    """
    if command_name not in root_group.commands:
        logger.warning(f"Cannot apply hook - command '{command_name}' does not exist")
        return

    original_command = root_group.commands[command_name]
    try:
        # The hook's return value replaces the registered command.
        root_group.commands[command_name] = hook(original_command)
        # lower to debug later
        logger.info(f"Applied hook to command: flyte {command_name}")
    except Exception as e:
        logger.error(f"Hook failed for command {command_name}: {e}")
        # Make sure the untouched command stays registered after a failed hook.
        root_group.commands[command_name] = original_command
183
+
184
+
185
def _apply_hook_to_subcommand(root_group: click.Group, group_name: str, command_name: str, hook: CommandHook):
    """Apply a hook to a subcommand within a group.

    Looks up ``group_name`` on ``root_group`` and ``command_name`` inside that
    group, passes the existing command to ``hook`` and registers the command
    the hook returns in its place.

    The command's callback is never invoked here: hooks are applied while the
    CLI is being assembled, so running the command at this point would execute
    it (or fail on missing arguments) outside of any Click invocation.

    Args:
        root_group: The root Click command group (main CLI group).
        group_name: Name of the existing group that owns the subcommand.
        command_name: Name of the subcommand to hook.
        hook: Callable that receives the command and returns its replacement.

    A warning is logged and nothing changes when the group does not exist, is
    not a ``click.Group``, or does not contain the subcommand. If the hook
    raises, the error is logged and the original command is kept.
    """
    if group_name not in root_group.commands:
        logger.warning(f"Cannot apply hook - group '{group_name}' does not exist")
        return

    group = root_group.commands[group_name]
    if not isinstance(group, click.Group):
        logger.warning(f"Cannot apply hook - '{group_name}' is not a command group")
        return

    if command_name not in group.commands:
        logger.warning(f"Cannot apply hook - subcommand '{command_name}' does not exist in group '{group_name}'")
        return

    original_command = group.commands[command_name]
    try:
        modified_command = hook(original_command)
        group.commands[command_name] = modified_command
        logger.info(f"Applied hook to subcommand: flyte {group_name} {command_name}")
    except Exception as e:
        logger.error(f"Hook failed for subcommand {group_name}.{command_name}: {e}")
        # Keep the untouched command registered after a failed hook.
        group.commands[command_name] = original_command
flyte/cli/_prefetch.py ADDED
@@ -0,0 +1,292 @@
1
+ """
2
+ CLI commands for prefetching artifacts from remote registries.
3
+ """
4
+
5
+ import typing
6
+ from pathlib import Path
7
+
8
+ import rich_click as click
9
+ from rich.console import Console
10
+
11
+ from flyte._resources import Accelerators
12
+ from flyte.cli._common import CommandBase
13
+
14
# Get all valid accelerator choices from the Accelerators literal type.
# The list is used as the allowed values of the --gpu option (click.Choice).
ACCELERATOR_CHOICES = list(typing.get_args(Accelerators))
16
+
17
+
18
# Command group behind `flyte prefetch ...`; subcommands attach to it through
# @prefetch.command. The body is intentionally empty.
# NOTE(review): the docstring is presumably rendered as the group's --help text,
# so it is kept verbatim.
@click.group(name="prefetch")
def prefetch():
    """
    Prefetch artifacts from remote registries.

    These commands help you download and prefetch artifacts like HuggingFace models
    to your Flyte storage for faster access during task execution.
    """
26
+
27
+
28
def _parse_resource_range(value: str | None) -> str | tuple[str, str] | None:
    """Turn a ``"low,high"`` resource string into a ``(low, high)`` tuple.

    A value without a comma (or ``None``) is returned unchanged. Whitespace
    around the comma is ignored, so both ``"2,4"`` (the syntax shown in the
    option help) and ``"2, 4"`` yield ``("2", "4")``. Only the first two
    comma-separated parts are used.
    """
    if value is None:
        return None
    parts = [part.strip() for part in value.split(",")]
    if len(parts) > 1:
        return (parts[0], parts[1])
    return value


@prefetch.command(name="hf-model", cls=CommandBase)
@click.argument("repo", type=str)
@click.option(
    "--raw-data-path",
    type=str,
    required=False,
    default=None,
    help=(
        "Object store path to store the model. If not provided, the model will be stored using the default path "
        "generated by Flyte storage layer."
    ),
)
@click.option(
    "--artifact-name",
    type=str,
    required=False,
    default=None,
    help=(
        "Artifact name to use for the stored model. Must only contain alphanumeric characters, "
        "underscores, and hyphens. If not provided, the repo name will be used (replacing '.' with '-')."
    ),
)
@click.option(
    "--architecture",
    type=str,
    help="Model architecture, as given in HuggingFace config.json.",
)
@click.option(
    "--task",
    default="auto",
    type=str,
    help=(
        "Model task, e.g., 'generate', 'classify', 'embed', 'score', etc. "
        "Refer to vLLM docs. 'auto' will try to discover this automatically."
    ),
)
@click.option(
    "--modality",
    type=str,
    multiple=True,
    default=("text",),
    help="Modalities supported by the model, e.g., 'text', 'image', 'audio', 'video'. Can be specified multiple times.",
)
@click.option(
    "--format",
    "serial_format",
    type=str,
    help="Model serialization format, e.g., safetensors, onnx, torchscript, joblib, etc.",
)
@click.option(
    "--model-type",
    type=str,
    help=(
        "Model type, e.g., 'transformer', 'xgboost', 'custom', etc. "
        "For HuggingFace models, this is auto-determined from config.json['model_type']."
    ),
)
@click.option(
    "--short-description",
    type=str,
    help="Short description of the model.",
)
@click.option(
    "--force",
    type=int,
    default=0,
    help="Force store of the model. Increment value (--force=1, --force=2, ...) to force a new store.",
)
@click.option(
    "--wait",
    is_flag=True,
    help="Wait for the model to be stored before returning.",
)
@click.option(
    "--hf-token-key",
    type=str,
    default="HF_TOKEN",
    help=(
        "Name of the Flyte secret containing your HuggingFace token. "
        "Note: This is not the HuggingFace token itself, but the name of the "
        "secret in the Flyte secret store."
    ),
    show_default=True,
)
@click.option(
    "--cpu",
    type=str,
    default="2",
    help="CPU request for the prefetch task (e.g., '2', '4', '2,4' for 2-4 CPUs).",
)
@click.option(
    "--mem",
    type=str,
    default="8Gi",
    help="Memory request for the prefetch task (e.g., '16Gi', '64Gi', '16Gi,64Gi' for 16-64GB).",
)
@click.option(
    "--gpu",
    type=click.Choice(ACCELERATOR_CHOICES),
    default=None,
    help=(
        "The gpu to use for downloading and (optionally) sharding the model. "
        "Format: '{type}:{quantity}' (e.g., 'A100:8', 'L4:1')."
    ),
)
@click.option(
    "--disk",
    type=str,
    default="50Gi",
    help="Disk storage request for the prefetch task (e.g., '100Gi', '500Gi').",
)
@click.option(
    "--shm",
    type=str,
    default=None,
    help="Shared memory request for the prefetch task (e.g., '100Gi', 'auto').",
)
@click.option(
    "--shard-config",
    type=click.Path(exists=True, path_type=Path),
    help=(
        "Path to a YAML file containing sharding configuration. "
        "The file should have 'engine' (currently only 'vllm') and 'args' keys."
    ),
)
@click.pass_obj
def hf_model(
    cfg,
    repo: str,
    raw_data_path: str | None,
    artifact_name: str | None,
    architecture: str | None,
    task: str,
    modality: tuple[str, ...],
    serial_format: str | None,
    model_type: str | None,
    short_description: str | None,
    force: int,
    wait: bool,
    hf_token_key: str,
    cpu: str | None,
    mem: str | None,
    disk: str | None,
    gpu: Accelerators | None,
    shm: str | None,
    shard_config: Path | None,
    project: str | None,
    domain: str | None,
):
    """
    Prefetch a HuggingFace model to Flyte storage.

    Downloads a model from the HuggingFace Hub and prefetches it to your configured
    Flyte storage backend. This is useful for:

    - Pre-fetching large models before running inference tasks
    - Sharding models for tensor-parallel inference
    - Avoiding repeated downloads during development

    **Basic Usage:**

    ```bash
    $ flyte prefetch hf-model meta-llama/Llama-2-7b-hf --hf-token-key HF_TOKEN
    ```

    **With Sharding:**

    Create a shard config file (shard_config.yaml):

    ```yaml
    engine: vllm
    args:
      tensor_parallel_size: 8
      dtype: auto
      trust_remote_code: true
    ```

    Then run:

    ```bash
    $ flyte prefetch hf-model meta-llama/Llama-2-70b-hf \\
        --shard-config shard_config.yaml \\
        --gpu A100:8 \\
        --hf-token-key HF_TOKEN
    ```

    **Wait for Completion:**

    ```bash
    $ flyte prefetch hf-model meta-llama/Llama-2-7b-hf --wait
    ```
    """
    # NOTE(review): `project` and `domain` have no @click.option here; they are
    # presumably injected by CommandBase - confirm against flyte.cli._common.
    # Heavy / project imports are kept local to the command body.
    import yaml

    from flyte._resources import Resources
    from flyte.cli._run import initialize_config
    from flyte.prefetch import ShardConfig, VLLMShardArgs
    from flyte.prefetch import hf_model as prefetch_hf_model

    # Initialize flyte config; explicit CLI values win over the configured defaults.
    cfg = initialize_config(
        cfg.ctx,
        project or cfg.config.task.project,
        domain or cfg.config.task.domain,
    )

    # Parse shard config if provided
    parsed_shard_config = None
    if shard_config is not None:
        with shard_config.open() as f:
            shard_config_dict = yaml.safe_load(f)
        args_dict = shard_config_dict.get("args", {})
        parsed_shard_config = ShardConfig(
            engine=shard_config_dict.get("engine", "vllm"),
            args=VLLMShardArgs(**args_dict),
        )

    console = Console()

    console.print("[bold green]Starting model prefetch task...")

    # Parse cpu and mem for range syntax (e.g., "2,4" -> ("2", "4")).
    # The split is on a bare comma so the syntax documented in --cpu/--mem works.
    parsed_cpu = _parse_resource_range(cpu)
    parsed_mem = _parse_resource_range(mem)

    run = prefetch_hf_model(
        repo=repo,
        raw_data_path=raw_data_path,
        artifact_name=artifact_name,
        architecture=architecture,
        task=task,
        modality=modality,
        serial_format=serial_format,
        model_type=model_type,
        short_description=short_description,
        shard_config=parsed_shard_config,
        hf_token_key=hf_token_key,
        resources=Resources(cpu=parsed_cpu, memory=parsed_mem, disk=disk, gpu=gpu, shm=shm),
        force=force,
    )

    url = run.url
    console.print(
        f"🔄 Started run {run.name} to prefetch model from HuggingFace repo [bold]{repo}[/bold].\n"
        f" Check the console for status at [link={url}]{url}[/link]"
    )

    if wait:
        run.wait()
        try:
            model_path = run.outputs()[0].path
            console.print("\n✅ Model prefetched successfully!")
            console.print(f"Remote path: [cyan]{model_path}[/cyan]")
        except Exception as e:
            # Best-effort reporting: a missing/failed output is shown, not raised.
            console.print("\n❌ Model prefetch failed!")
            console.print(f"Error: {e}")