flyte 0.0.1b0__py3-none-any.whl → 2.0.0b46__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flyte/__init__.py +83 -30
- flyte/_bin/connect.py +61 -0
- flyte/_bin/debug.py +38 -0
- flyte/_bin/runtime.py +87 -19
- flyte/_bin/serve.py +351 -0
- flyte/_build.py +3 -2
- flyte/_cache/cache.py +6 -5
- flyte/_cache/local_cache.py +216 -0
- flyte/_code_bundle/_ignore.py +31 -5
- flyte/_code_bundle/_packaging.py +42 -11
- flyte/_code_bundle/_utils.py +57 -34
- flyte/_code_bundle/bundle.py +130 -27
- flyte/_constants.py +1 -0
- flyte/_context.py +21 -5
- flyte/_custom_context.py +73 -0
- flyte/_debug/constants.py +37 -0
- flyte/_debug/utils.py +17 -0
- flyte/_debug/vscode.py +315 -0
- flyte/_deploy.py +396 -75
- flyte/_deployer.py +109 -0
- flyte/_environment.py +94 -11
- flyte/_excepthook.py +37 -0
- flyte/_group.py +2 -1
- flyte/_hash.py +1 -16
- flyte/_image.py +544 -234
- flyte/_initialize.py +443 -294
- flyte/_interface.py +40 -5
- flyte/_internal/controllers/__init__.py +22 -8
- flyte/_internal/controllers/_local_controller.py +159 -35
- flyte/_internal/controllers/_trace.py +18 -10
- flyte/_internal/controllers/remote/__init__.py +38 -9
- flyte/_internal/controllers/remote/_action.py +82 -12
- flyte/_internal/controllers/remote/_client.py +6 -2
- flyte/_internal/controllers/remote/_controller.py +290 -64
- flyte/_internal/controllers/remote/_core.py +155 -95
- flyte/_internal/controllers/remote/_informer.py +40 -20
- flyte/_internal/controllers/remote/_service_protocol.py +2 -2
- flyte/_internal/imagebuild/__init__.py +2 -10
- flyte/_internal/imagebuild/docker_builder.py +391 -84
- flyte/_internal/imagebuild/image_builder.py +111 -55
- flyte/_internal/imagebuild/remote_builder.py +409 -0
- flyte/_internal/imagebuild/utils.py +79 -0
- flyte/_internal/resolvers/_app_env_module.py +92 -0
- flyte/_internal/resolvers/_task_module.py +5 -38
- flyte/_internal/resolvers/app_env.py +26 -0
- flyte/_internal/resolvers/common.py +8 -1
- flyte/_internal/resolvers/default.py +2 -2
- flyte/_internal/runtime/convert.py +322 -33
- flyte/_internal/runtime/entrypoints.py +106 -18
- flyte/_internal/runtime/io.py +71 -23
- flyte/_internal/runtime/resources_serde.py +21 -7
- flyte/_internal/runtime/reuse.py +125 -0
- flyte/_internal/runtime/rusty.py +196 -0
- flyte/_internal/runtime/task_serde.py +239 -66
- flyte/_internal/runtime/taskrunner.py +48 -8
- flyte/_internal/runtime/trigger_serde.py +162 -0
- flyte/_internal/runtime/types_serde.py +7 -16
- flyte/_keyring/file.py +115 -0
- flyte/_link.py +30 -0
- flyte/_logging.py +241 -42
- flyte/_map.py +312 -0
- flyte/_metrics.py +59 -0
- flyte/_module.py +74 -0
- flyte/_pod.py +30 -0
- flyte/_resources.py +296 -33
- flyte/_retry.py +1 -7
- flyte/_reusable_environment.py +72 -7
- flyte/_run.py +461 -132
- flyte/_secret.py +47 -11
- flyte/_serve.py +333 -0
- flyte/_task.py +245 -56
- flyte/_task_environment.py +219 -97
- flyte/_task_plugins.py +47 -0
- flyte/_tools.py +8 -8
- flyte/_trace.py +15 -24
- flyte/_trigger.py +1027 -0
- flyte/_utils/__init__.py +12 -1
- flyte/_utils/asyn.py +3 -1
- flyte/_utils/async_cache.py +139 -0
- flyte/_utils/coro_management.py +5 -4
- flyte/_utils/description_parser.py +19 -0
- flyte/_utils/docker_credentials.py +173 -0
- flyte/_utils/helpers.py +45 -19
- flyte/_utils/module_loader.py +123 -0
- flyte/_utils/org_discovery.py +57 -0
- flyte/_utils/uv_script_parser.py +8 -1
- flyte/_version.py +16 -3
- flyte/app/__init__.py +27 -0
- flyte/app/_app_environment.py +362 -0
- flyte/app/_connector_environment.py +40 -0
- flyte/app/_deploy.py +130 -0
- flyte/app/_parameter.py +343 -0
- flyte/app/_runtime/__init__.py +3 -0
- flyte/app/_runtime/app_serde.py +383 -0
- flyte/app/_types.py +113 -0
- flyte/app/extras/__init__.py +9 -0
- flyte/app/extras/_auth_middleware.py +217 -0
- flyte/app/extras/_fastapi.py +93 -0
- flyte/app/extras/_model_loader/__init__.py +3 -0
- flyte/app/extras/_model_loader/config.py +7 -0
- flyte/app/extras/_model_loader/loader.py +288 -0
- flyte/cli/__init__.py +12 -0
- flyte/cli/_abort.py +28 -0
- flyte/cli/_build.py +114 -0
- flyte/cli/_common.py +493 -0
- flyte/cli/_create.py +371 -0
- flyte/cli/_delete.py +45 -0
- flyte/cli/_deploy.py +401 -0
- flyte/cli/_gen.py +316 -0
- flyte/cli/_get.py +446 -0
- flyte/cli/_option.py +33 -0
- {union/_cli → flyte/cli}/_params.py +152 -153
- flyte/cli/_plugins.py +209 -0
- flyte/cli/_prefetch.py +292 -0
- flyte/cli/_run.py +690 -0
- flyte/cli/_serve.py +338 -0
- flyte/cli/_update.py +86 -0
- flyte/cli/_user.py +20 -0
- flyte/cli/main.py +246 -0
- flyte/config/__init__.py +3 -0
- flyte/config/_config.py +248 -0
- flyte/config/_internal.py +73 -0
- flyte/config/_reader.py +225 -0
- flyte/connectors/__init__.py +11 -0
- flyte/connectors/_connector.py +330 -0
- flyte/connectors/_server.py +194 -0
- flyte/connectors/utils.py +159 -0
- flyte/errors.py +134 -2
- flyte/extend.py +24 -0
- flyte/extras/_container.py +69 -56
- flyte/git/__init__.py +3 -0
- flyte/git/_config.py +279 -0
- flyte/io/__init__.py +8 -1
- flyte/io/{structured_dataset → _dataframe}/__init__.py +32 -30
- flyte/io/{structured_dataset → _dataframe}/basic_dfs.py +75 -68
- flyte/io/{structured_dataset/structured_dataset.py → _dataframe/dataframe.py} +207 -242
- flyte/io/_dir.py +575 -113
- flyte/io/_file.py +587 -141
- flyte/io/_hashing_io.py +342 -0
- flyte/io/extend.py +7 -0
- flyte/models.py +635 -0
- flyte/prefetch/__init__.py +22 -0
- flyte/prefetch/_hf_model.py +563 -0
- flyte/remote/__init__.py +14 -3
- flyte/remote/_action.py +879 -0
- flyte/remote/_app.py +346 -0
- flyte/remote/_auth_metadata.py +42 -0
- flyte/remote/_client/_protocols.py +62 -4
- flyte/remote/_client/auth/_auth_utils.py +19 -0
- flyte/remote/_client/auth/_authenticators/base.py +8 -2
- flyte/remote/_client/auth/_authenticators/device_code.py +4 -5
- flyte/remote/_client/auth/_authenticators/factory.py +4 -0
- flyte/remote/_client/auth/_authenticators/passthrough.py +79 -0
- flyte/remote/_client/auth/_authenticators/pkce.py +17 -18
- flyte/remote/_client/auth/_channel.py +47 -18
- flyte/remote/_client/auth/_client_config.py +5 -3
- flyte/remote/_client/auth/_keyring.py +15 -2
- flyte/remote/_client/auth/_token_client.py +3 -3
- flyte/remote/_client/controlplane.py +206 -18
- flyte/remote/_common.py +66 -0
- flyte/remote/_data.py +107 -22
- flyte/remote/_logs.py +116 -33
- flyte/remote/_project.py +21 -19
- flyte/remote/_run.py +164 -631
- flyte/remote/_secret.py +72 -29
- flyte/remote/_task.py +387 -46
- flyte/remote/_trigger.py +368 -0
- flyte/remote/_user.py +43 -0
- flyte/report/_report.py +10 -6
- flyte/storage/__init__.py +13 -1
- flyte/storage/_config.py +237 -0
- flyte/storage/_parallel_reader.py +289 -0
- flyte/storage/_storage.py +268 -59
- flyte/syncify/__init__.py +56 -0
- flyte/syncify/_api.py +414 -0
- flyte/types/__init__.py +39 -0
- flyte/types/_interface.py +22 -7
- flyte/{io/pickle/transformer.py → types/_pickle.py} +37 -9
- flyte/types/_string_literals.py +8 -9
- flyte/types/_type_engine.py +230 -129
- flyte/types/_utils.py +1 -1
- flyte-2.0.0b46.data/scripts/debug.py +38 -0
- flyte-2.0.0b46.data/scripts/runtime.py +194 -0
- flyte-2.0.0b46.dist-info/METADATA +352 -0
- flyte-2.0.0b46.dist-info/RECORD +221 -0
- flyte-2.0.0b46.dist-info/entry_points.txt +8 -0
- flyte-2.0.0b46.dist-info/licenses/LICENSE +201 -0
- flyte/_api_commons.py +0 -3
- flyte/_cli/_common.py +0 -287
- flyte/_cli/_create.py +0 -42
- flyte/_cli/_delete.py +0 -23
- flyte/_cli/_deploy.py +0 -140
- flyte/_cli/_get.py +0 -235
- flyte/_cli/_run.py +0 -152
- flyte/_cli/main.py +0 -72
- flyte/_datastructures.py +0 -342
- flyte/_internal/controllers/pbhash.py +0 -39
- flyte/_protos/common/authorization_pb2.py +0 -66
- flyte/_protos/common/authorization_pb2.pyi +0 -108
- flyte/_protos/common/authorization_pb2_grpc.py +0 -4
- flyte/_protos/common/identifier_pb2.py +0 -71
- flyte/_protos/common/identifier_pb2.pyi +0 -82
- flyte/_protos/common/identifier_pb2_grpc.py +0 -4
- flyte/_protos/common/identity_pb2.py +0 -48
- flyte/_protos/common/identity_pb2.pyi +0 -72
- flyte/_protos/common/identity_pb2_grpc.py +0 -4
- flyte/_protos/common/list_pb2.py +0 -36
- flyte/_protos/common/list_pb2.pyi +0 -69
- flyte/_protos/common/list_pb2_grpc.py +0 -4
- flyte/_protos/common/policy_pb2.py +0 -37
- flyte/_protos/common/policy_pb2.pyi +0 -27
- flyte/_protos/common/policy_pb2_grpc.py +0 -4
- flyte/_protos/common/role_pb2.py +0 -37
- flyte/_protos/common/role_pb2.pyi +0 -53
- flyte/_protos/common/role_pb2_grpc.py +0 -4
- flyte/_protos/common/runtime_version_pb2.py +0 -28
- flyte/_protos/common/runtime_version_pb2.pyi +0 -24
- flyte/_protos/common/runtime_version_pb2_grpc.py +0 -4
- flyte/_protos/logs/dataplane/payload_pb2.py +0 -96
- flyte/_protos/logs/dataplane/payload_pb2.pyi +0 -168
- flyte/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/definition_pb2.py +0 -49
- flyte/_protos/secret/definition_pb2.pyi +0 -93
- flyte/_protos/secret/definition_pb2_grpc.py +0 -4
- flyte/_protos/secret/payload_pb2.py +0 -62
- flyte/_protos/secret/payload_pb2.pyi +0 -94
- flyte/_protos/secret/payload_pb2_grpc.py +0 -4
- flyte/_protos/secret/secret_pb2.py +0 -38
- flyte/_protos/secret/secret_pb2.pyi +0 -6
- flyte/_protos/secret/secret_pb2_grpc.py +0 -198
- flyte/_protos/secret/secret_pb2_grpc_grpc.py +0 -198
- flyte/_protos/validate/validate/validate_pb2.py +0 -76
- flyte/_protos/workflow/node_execution_service_pb2.py +0 -26
- flyte/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- flyte/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- flyte/_protos/workflow/queue_service_pb2.py +0 -106
- flyte/_protos/workflow/queue_service_pb2.pyi +0 -141
- flyte/_protos/workflow/queue_service_pb2_grpc.py +0 -172
- flyte/_protos/workflow/run_definition_pb2.py +0 -128
- flyte/_protos/workflow/run_definition_pb2.pyi +0 -310
- flyte/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/run_logs_service_pb2.py +0 -41
- flyte/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- flyte/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- flyte/_protos/workflow/run_service_pb2.py +0 -133
- flyte/_protos/workflow/run_service_pb2.pyi +0 -175
- flyte/_protos/workflow/run_service_pb2_grpc.py +0 -412
- flyte/_protos/workflow/state_service_pb2.py +0 -58
- flyte/_protos/workflow/state_service_pb2.pyi +0 -71
- flyte/_protos/workflow/state_service_pb2_grpc.py +0 -138
- flyte/_protos/workflow/task_definition_pb2.py +0 -72
- flyte/_protos/workflow/task_definition_pb2.pyi +0 -65
- flyte/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- flyte/_protos/workflow/task_service_pb2.py +0 -44
- flyte/_protos/workflow/task_service_pb2.pyi +0 -31
- flyte/_protos/workflow/task_service_pb2_grpc.py +0 -104
- flyte/io/_dataframe.py +0 -0
- flyte/io/pickle/__init__.py +0 -0
- flyte/remote/_console.py +0 -18
- flyte-0.0.1b0.dist-info/METADATA +0 -179
- flyte-0.0.1b0.dist-info/RECORD +0 -390
- flyte-0.0.1b0.dist-info/entry_points.txt +0 -3
- union/__init__.py +0 -54
- union/_api_commons.py +0 -3
- union/_bin/__init__.py +0 -0
- union/_bin/runtime.py +0 -113
- union/_build.py +0 -25
- union/_cache/__init__.py +0 -12
- union/_cache/cache.py +0 -141
- union/_cache/defaults.py +0 -9
- union/_cache/policy_function_body.py +0 -42
- union/_cli/__init__.py +0 -0
- union/_cli/_common.py +0 -263
- union/_cli/_create.py +0 -40
- union/_cli/_delete.py +0 -23
- union/_cli/_deploy.py +0 -120
- union/_cli/_get.py +0 -162
- union/_cli/_run.py +0 -150
- union/_cli/main.py +0 -72
- union/_code_bundle/__init__.py +0 -8
- union/_code_bundle/_ignore.py +0 -113
- union/_code_bundle/_packaging.py +0 -187
- union/_code_bundle/_utils.py +0 -342
- union/_code_bundle/bundle.py +0 -176
- union/_context.py +0 -146
- union/_datastructures.py +0 -295
- union/_deploy.py +0 -185
- union/_doc.py +0 -29
- union/_docstring.py +0 -26
- union/_environment.py +0 -43
- union/_group.py +0 -31
- union/_hash.py +0 -23
- union/_image.py +0 -760
- union/_initialize.py +0 -585
- union/_interface.py +0 -84
- union/_internal/__init__.py +0 -3
- union/_internal/controllers/__init__.py +0 -77
- union/_internal/controllers/_local_controller.py +0 -77
- union/_internal/controllers/pbhash.py +0 -39
- union/_internal/controllers/remote/__init__.py +0 -40
- union/_internal/controllers/remote/_action.py +0 -131
- union/_internal/controllers/remote/_client.py +0 -43
- union/_internal/controllers/remote/_controller.py +0 -169
- union/_internal/controllers/remote/_core.py +0 -341
- union/_internal/controllers/remote/_informer.py +0 -260
- union/_internal/controllers/remote/_service_protocol.py +0 -44
- union/_internal/imagebuild/__init__.py +0 -11
- union/_internal/imagebuild/docker_builder.py +0 -416
- union/_internal/imagebuild/image_builder.py +0 -243
- union/_internal/imagebuild/remote_builder.py +0 -0
- union/_internal/resolvers/__init__.py +0 -0
- union/_internal/resolvers/_task_module.py +0 -31
- union/_internal/resolvers/common.py +0 -24
- union/_internal/resolvers/default.py +0 -27
- union/_internal/runtime/__init__.py +0 -0
- union/_internal/runtime/convert.py +0 -163
- union/_internal/runtime/entrypoints.py +0 -121
- union/_internal/runtime/io.py +0 -136
- union/_internal/runtime/resources_serde.py +0 -134
- union/_internal/runtime/task_serde.py +0 -202
- union/_internal/runtime/taskrunner.py +0 -179
- union/_internal/runtime/types_serde.py +0 -53
- union/_logging.py +0 -124
- union/_protos/__init__.py +0 -0
- union/_protos/common/authorization_pb2.py +0 -66
- union/_protos/common/authorization_pb2.pyi +0 -106
- union/_protos/common/authorization_pb2_grpc.py +0 -4
- union/_protos/common/identifier_pb2.py +0 -71
- union/_protos/common/identifier_pb2.pyi +0 -82
- union/_protos/common/identifier_pb2_grpc.py +0 -4
- union/_protos/common/identity_pb2.py +0 -48
- union/_protos/common/identity_pb2.pyi +0 -72
- union/_protos/common/identity_pb2_grpc.py +0 -4
- union/_protos/common/list_pb2.py +0 -36
- union/_protos/common/list_pb2.pyi +0 -69
- union/_protos/common/list_pb2_grpc.py +0 -4
- union/_protos/common/policy_pb2.py +0 -37
- union/_protos/common/policy_pb2.pyi +0 -27
- union/_protos/common/policy_pb2_grpc.py +0 -4
- union/_protos/common/role_pb2.py +0 -37
- union/_protos/common/role_pb2.pyi +0 -51
- union/_protos/common/role_pb2_grpc.py +0 -4
- union/_protos/common/runtime_version_pb2.py +0 -28
- union/_protos/common/runtime_version_pb2.pyi +0 -24
- union/_protos/common/runtime_version_pb2_grpc.py +0 -4
- union/_protos/logs/dataplane/payload_pb2.py +0 -96
- union/_protos/logs/dataplane/payload_pb2.pyi +0 -168
- union/_protos/logs/dataplane/payload_pb2_grpc.py +0 -4
- union/_protos/secret/definition_pb2.py +0 -49
- union/_protos/secret/definition_pb2.pyi +0 -93
- union/_protos/secret/definition_pb2_grpc.py +0 -4
- union/_protos/secret/payload_pb2.py +0 -62
- union/_protos/secret/payload_pb2.pyi +0 -94
- union/_protos/secret/payload_pb2_grpc.py +0 -4
- union/_protos/secret/secret_pb2.py +0 -38
- union/_protos/secret/secret_pb2.pyi +0 -6
- union/_protos/secret/secret_pb2_grpc.py +0 -198
- union/_protos/validate/validate/validate_pb2.py +0 -76
- union/_protos/workflow/node_execution_service_pb2.py +0 -26
- union/_protos/workflow/node_execution_service_pb2.pyi +0 -4
- union/_protos/workflow/node_execution_service_pb2_grpc.py +0 -32
- union/_protos/workflow/queue_service_pb2.py +0 -75
- union/_protos/workflow/queue_service_pb2.pyi +0 -103
- union/_protos/workflow/queue_service_pb2_grpc.py +0 -172
- union/_protos/workflow/run_definition_pb2.py +0 -100
- union/_protos/workflow/run_definition_pb2.pyi +0 -256
- union/_protos/workflow/run_definition_pb2_grpc.py +0 -4
- union/_protos/workflow/run_logs_service_pb2.py +0 -41
- union/_protos/workflow/run_logs_service_pb2.pyi +0 -28
- union/_protos/workflow/run_logs_service_pb2_grpc.py +0 -69
- union/_protos/workflow/run_service_pb2.py +0 -133
- union/_protos/workflow/run_service_pb2.pyi +0 -173
- union/_protos/workflow/run_service_pb2_grpc.py +0 -412
- union/_protos/workflow/state_service_pb2.py +0 -58
- union/_protos/workflow/state_service_pb2.pyi +0 -69
- union/_protos/workflow/state_service_pb2_grpc.py +0 -138
- union/_protos/workflow/task_definition_pb2.py +0 -72
- union/_protos/workflow/task_definition_pb2.pyi +0 -65
- union/_protos/workflow/task_definition_pb2_grpc.py +0 -4
- union/_protos/workflow/task_service_pb2.py +0 -44
- union/_protos/workflow/task_service_pb2.pyi +0 -31
- union/_protos/workflow/task_service_pb2_grpc.py +0 -104
- union/_resources.py +0 -226
- union/_retry.py +0 -32
- union/_reusable_environment.py +0 -25
- union/_run.py +0 -374
- union/_secret.py +0 -61
- union/_task.py +0 -354
- union/_task_environment.py +0 -186
- union/_timeout.py +0 -47
- union/_tools.py +0 -27
- union/_utils/__init__.py +0 -11
- union/_utils/asyn.py +0 -119
- union/_utils/file_handling.py +0 -71
- union/_utils/helpers.py +0 -46
- union/_utils/lazy_module.py +0 -54
- union/_utils/uv_script_parser.py +0 -49
- union/_version.py +0 -21
- union/connectors/__init__.py +0 -0
- union/errors.py +0 -128
- union/extras/__init__.py +0 -5
- union/extras/_container.py +0 -263
- union/io/__init__.py +0 -11
- union/io/_dataframe.py +0 -0
- union/io/_dir.py +0 -425
- union/io/_file.py +0 -418
- union/io/pickle/__init__.py +0 -0
- union/io/pickle/transformer.py +0 -117
- union/io/structured_dataset/__init__.py +0 -122
- union/io/structured_dataset/basic_dfs.py +0 -219
- union/io/structured_dataset/structured_dataset.py +0 -1057
- union/py.typed +0 -0
- union/remote/__init__.py +0 -23
- union/remote/_client/__init__.py +0 -0
- union/remote/_client/_protocols.py +0 -129
- union/remote/_client/auth/__init__.py +0 -12
- union/remote/_client/auth/_authenticators/__init__.py +0 -0
- union/remote/_client/auth/_authenticators/base.py +0 -391
- union/remote/_client/auth/_authenticators/client_credentials.py +0 -73
- union/remote/_client/auth/_authenticators/device_code.py +0 -120
- union/remote/_client/auth/_authenticators/external_command.py +0 -77
- union/remote/_client/auth/_authenticators/factory.py +0 -200
- union/remote/_client/auth/_authenticators/pkce.py +0 -515
- union/remote/_client/auth/_channel.py +0 -184
- union/remote/_client/auth/_client_config.py +0 -83
- union/remote/_client/auth/_default_html.py +0 -32
- union/remote/_client/auth/_grpc_utils/__init__.py +0 -0
- union/remote/_client/auth/_grpc_utils/auth_interceptor.py +0 -204
- union/remote/_client/auth/_grpc_utils/default_metadata_interceptor.py +0 -144
- union/remote/_client/auth/_keyring.py +0 -154
- union/remote/_client/auth/_token_client.py +0 -258
- union/remote/_client/auth/errors.py +0 -16
- union/remote/_client/controlplane.py +0 -86
- union/remote/_data.py +0 -149
- union/remote/_logs.py +0 -74
- union/remote/_project.py +0 -86
- union/remote/_run.py +0 -820
- union/remote/_secret.py +0 -132
- union/remote/_task.py +0 -193
- union/report/__init__.py +0 -3
- union/report/_report.py +0 -178
- union/report/_template.html +0 -124
- union/storage/__init__.py +0 -24
- union/storage/_remote_fs.py +0 -34
- union/storage/_storage.py +0 -247
- union/storage/_utils.py +0 -5
- union/types/__init__.py +0 -11
- union/types/_renderer.py +0 -162
- union/types/_string_literals.py +0 -120
- union/types/_type_engine.py +0 -2131
- union/types/_utils.py +0 -80
- /flyte/{_cli → _debug}/__init__.py +0 -0
- /flyte/{_protos → _keyring}/__init__.py +0 -0
- {flyte-0.0.1b0.dist-info → flyte-2.0.0b46.dist-info}/WHEEL +0 -0
- {flyte-0.0.1b0.dist-info → flyte-2.0.0b46.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,563 @@
|
|
|
1
|
+
"""
|
|
2
|
+
HuggingFace model prefetch utilities for Flyte.
|
|
3
|
+
|
|
4
|
+
This module provides functionality to prefetch HuggingFace models to remote storage,
|
|
5
|
+
with support for optional sharding using vLLM.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import os
|
|
11
|
+
import re
|
|
12
|
+
import shutil
|
|
13
|
+
import tempfile
|
|
14
|
+
import typing
|
|
15
|
+
from typing import TYPE_CHECKING, Any, Literal
|
|
16
|
+
|
|
17
|
+
from pydantic import BaseModel, Field
|
|
18
|
+
|
|
19
|
+
from flyte._logging import logger
|
|
20
|
+
from flyte._resources import Resources
|
|
21
|
+
from flyte._task_environment import TaskEnvironment
|
|
22
|
+
from flyte.io import Dir
|
|
23
|
+
|
|
24
|
+
if TYPE_CHECKING:
|
|
25
|
+
from flyte.remote import Run
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
DEFAULT_SHARD_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class VLLMShardArgs(BaseModel):
    """
    Settings used when a model is sharded with vLLM.

    :param tensor_parallel_size: Number of tensor parallel workers.
    :param dtype: Data type for the model weights.
    :param trust_remote_code: Whether remote code from HuggingFace is trusted.
    :param max_model_len: Maximum model context length; left out of the vLLM
        arguments when ``None``.
    :param file_pattern: Filename pattern for the sharded weight files.
    :param max_file_size: Maximum size, in bytes, of each sharded file.
    """

    tensor_parallel_size: int = 1
    dtype: str = "auto"
    trust_remote_code: bool = True
    max_model_len: int | None = None
    file_pattern: str | None = DEFAULT_SHARD_PATTERN
    max_file_size: int = 5 * 1024**3  # 5GB default

    def get_vllm_args(self, model_path: str) -> dict[str, Any]:
        """
        Build the keyword arguments for the vLLM ``LLM`` constructor.

        :param model_path: Value placed under the ``"model"`` key.
        :return: Dict with ``model``, ``tensor_parallel_size``, ``dtype`` and
            ``trust_remote_code``; ``max_model_len`` is added only when it is set.
        """
        llm_kwargs: dict[str, Any] = dict(
            model=model_path,
            tensor_parallel_size=self.tensor_parallel_size,
            dtype=self.dtype,
            trust_remote_code=self.trust_remote_code,
        )
        if self.max_model_len is None:
            return llm_kwargs
        llm_kwargs["max_model_len"] = self.max_model_len
        return llm_kwargs
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
class ShardConfig(BaseModel):
    """
    Describes how a model should be sharded.

    :param engine: Name of the sharding engine; ``"vllm"`` is the only accepted value.
    :param args: Engine-specific arguments; a default ``VLLMShardArgs`` is created
        when none is given.
    """

    # The Literal type restricts the field to the single supported engine.
    engine: Literal["vllm"] = "vllm"
    # default_factory gives every ShardConfig its own VLLMShardArgs instance.
    args: VLLMShardArgs = Field(default_factory=VLLMShardArgs)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class HuggingFaceModelInfo(BaseModel):
    """
    Description of a HuggingFace model that should be stored.

    :param repo: HuggingFace repository ID (e.g., 'meta-llama/Llama-2-7b-hf').
    :param artifact_name: Name for the stored artifact. When omitted, the repo
        name is used with every '.' replaced by '-'.
    :param architecture: Model architecture as found in the HuggingFace config.json.
    :param task: Model task (e.g., 'generate', 'classify', 'embed').
    :param modality: Modalities the model supports (e.g., 'text', 'image').
    :param serial_format: Serialization format of the model (e.g., 'safetensors', 'onnx').
    :param model_type: Type of the model (e.g., 'transformer', 'custom').
    :param short_description: Brief description of the model.
    :param shard_config: Sharding configuration; ``None`` means no sharding.
    """

    # Required: every other field has a default.
    repo: str

    # Optional descriptive metadata.
    artifact_name: str | None = None
    architecture: str | None = None
    task: str = "auto"
    modality: tuple[str, ...] = ("text",)
    serial_format: str | None = None
    model_type: str | None = None
    short_description: str | None = None

    # Optional sharding request.
    shard_config: ShardConfig | None = None
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
class StoredModelInfo(BaseModel):
    """
    Record describing a model after it has been stored.

    :param artifact_name: Name under which the artifact was stored.
    :param path: Location of the stored model directory.
    :param metadata: String-to-string metadata describing the stored model.
    """

    # All three fields are required; none has a default.
    artifact_name: str
    path: str
    metadata: dict[str, str]
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
# Image definitions for the store task
|
|
117
|
+
# Packages listed for the plain HuggingFace download image.
HF_DOWNLOAD_IMAGE_PACKAGES = ["huggingface-hub>=0.27.0", "hf-transfer>=0.1.8", "markdown>=3.10"]

# The sharding image needs everything the download image has, plus vLLM.
# Concatenation builds a new list, so the download list is not aliased.
VLLM_SHARDING_IMAGE_PACKAGES = HF_DOWNLOAD_IMAGE_PACKAGES + ["vllm>=0.11.0"]
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def _validate_artifact_name(name: str | None) -> None:
|
|
130
|
+
"""Validate that artifact name contains only allowed characters."""
|
|
131
|
+
if name is not None and not re.match(r"^[a-zA-Z0-9_-]+$", name):
|
|
132
|
+
raise ValueError(f"Artifact name '{name}' must only contain alphanumeric characters, underscores, and hyphens")
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _lookup_huggingface_model_info(model_repo: str, commit: str, token: str | None) -> tuple[str | None, str | None]:
    """
    Look up the model type and architecture from a repo's ``config.json``.

    The ``architecture`` key is preferred. When it is absent, ``architectures`` is
    used instead: a list is joined with commas, a plain string is returned as is.

    :param model_repo: The model repository ID.
    :param commit: The commit ID (revision) to download ``config.json`` from.
    :param token: HuggingFace token for private models.
    :return: Tuple of (model_type, architecture); either element is ``None`` when
        the corresponding key is missing (or, for the architecture, empty).
    """
    import json

    import huggingface_hub

    config_file = huggingface_hub.hf_hub_download(
        repo_id=model_repo, filename="config.json", revision=commit, token=token
    )
    # JSON is read explicitly as UTF-8 rather than with the locale-dependent default.
    with open(config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    arch = config.get("architecture")
    if arch is None:
        candidates = config.get("architectures")
        if candidates:
            # Joining a bare string would insert a comma between every character.
            arch = candidates if isinstance(candidates, str) else ",".join(candidates)

    return config.get("model_type"), arch
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _stream_to_remote_dir(
    repo_id: str,
    commit: str,
    token: str | None,
    remote_dir_path: str,
) -> tuple[str, str | None]:
    """
    Copy files from a HuggingFace repo revision straight into a remote directory.

    The README.md is read first, when the repo has one, so its text can be returned.
    The repo root is then listed and every entry of type ``"file"`` is copied to
    ``remote_dir_path`` under its base name; other entries are skipped.

    :param repo_id: The HuggingFace repository ID.
    :param commit: The commit ID (revision) to read from.
    :param token: HuggingFace token.
    :param remote_dir_path: Path to the remote directory.
    :return: Tuple of (remote_dir_path, readme_content); readme_content is ``None``
        when no README.md was found.
    """
    import huggingface_hub

    import flyte.storage as storage

    hf_fs = huggingface_hub.HfFileSystem(token=token)
    target_fs = storage.get_underlying_filesystem(path=remote_dir_path)
    readme_text = None

    # Fetch the model card, if there is one.
    try:
        readme_path = hf_fs.info(f"{repo_id}/README.md", revision=commit)["name"]
        with tempfile.NamedTemporaryFile() as scratch:
            hf_fs.download(readme_path, scratch.name, revision=commit)
            with open(scratch.name, "r") as handle:
                readme_text = handle.read()
    except FileNotFoundError:
        logger.info("No README.md file found")

    # Enumerate the repo root.
    entries = hf_fs.ls(repo_id, revision=commit, detail=True)
    logger.info(f"Streaming {len(entries)} files to {remote_dir_path}")

    copy_chunk_bytes = 64 * 1024 * 1024  # 64MB chunks

    for entry in entries:
        if isinstance(entry, str):
            logger.info(f" Skipping {entry}...")
            continue
        if entry["type"] != "file":
            continue

        base_name = entry["name"].split("/")[-1]
        destination = f"{remote_dir_path}/{base_name}"
        logger.info(f" Streaming {base_name}...")

        # Copy chunk by chunk so a whole file is never held in memory at once.
        with hf_fs.open(entry["name"], "rb", revision=commit) as src, target_fs.open(destination, "wb") as dst:
            while chunk := src.read(copy_chunk_bytes):
                dst.write(chunk)

    return remote_dir_path, readme_text
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def _download_snapshot_to_local(
    repo_id: str,
    commit: str,
    token: str | None,
    local_dir: str,
) -> tuple[str, str | None]:
    """
    Download a snapshot of a HuggingFace repo revision into a local directory.

    The README.md is read first, when the repo has one, so its text can be returned
    alongside the download location.

    :param repo_id: The HuggingFace repository ID.
    :param commit: The commit ID (revision) to download.
    :param token: HuggingFace token.
    :param local_dir: Local directory to download to.
    :return: Tuple of (local_dir, readme_content); readme_content is ``None`` when
        no README.md was found.
    """
    import huggingface_hub

    hf_fs = huggingface_hub.HfFileSystem(token=token)
    readme_text = None

    # Fetch the model card, if there is one.
    try:
        readme_path = hf_fs.info(f"{repo_id}/README.md", revision=commit)["name"]
        with tempfile.NamedTemporaryFile() as scratch:
            hf_fs.download(readme_path, scratch.name, revision=commit)
            with open(scratch.name, "r") as handle:
                readme_text = handle.read()
    except FileNotFoundError:
        logger.info("No README.md file found")

    logger.info(f"Downloading model from {repo_id} to {local_dir}")
    snapshot_kwargs = {
        "repo_id": repo_id,
        "revision": commit,
        "local_dir": local_dir,
        "token": token,
    }
    huggingface_hub.snapshot_download(**snapshot_kwargs)

    return local_dir, readme_text
|
|
265
|
+
|
|
266
|
+
|
|
267
|
+
def _shard_model(
    repo: str,
    commit: str,
    shard_config: ShardConfig,
    token: str,
    model_path: str,
    output_dir: str,
) -> tuple[str, str | None]:
    """
    Download a HuggingFace model and shard it using vLLM.

    The repo snapshot is downloaded to ``model_path``, loaded into a vLLM ``LLM``,
    and saved as sharded state under ``output_dir``. Every entry in ``model_path``
    whose extension is not ``.bin``, ``.pt`` or ``.safetensors`` is then copied next
    to the shards.

    :param repo: The HuggingFace repository ID.
    :param commit: The commit ID (revision) to download.
    :param shard_config: Sharding configuration.
    :param token: HuggingFace token.
    :param model_path: Local directory the model snapshot is downloaded to.
    :param output_dir: Directory to save the sharded model.
    :return: Tuple of (output_dir, readme_content); readme_content is ``None`` when
        the repo has no README.md.
    """
    import huggingface_hub
    import vllm

    assert shard_config.engine == "vllm", "'vllm' is the only supported sharding engine for now"

    # Bound up front: without this, a repo with no README.md left ``card`` undefined
    # and the final return raised UnboundLocalError.
    card: str | None = None

    # Try to get README
    hfs = huggingface_hub.HfFileSystem(token=token)
    try:
        readme_info = hfs.info(f"{repo}/README.md", revision=commit)
        with tempfile.NamedTemporaryFile() as temp_file:
            hfs.download(readme_info["name"], temp_file.name, revision=commit)
            with open(temp_file.name, "r") as f:
                card = f.read()
    except FileNotFoundError:
        logger.warning("No README.md found")

    # Download snapshot
    logger.info(f"Downloading model to {model_path}")
    huggingface_hub.snapshot_download(
        repo_id=repo,
        revision=commit,
        local_dir=model_path,
        token=token,
    )

    # Create LLM instance
    llm = vllm.LLM(**shard_config.args.get_vllm_args(model_path))
    logger.info(f"LLM initialized: {llm}")

    llm.llm_engine.engine_core.save_sharded_state(
        path=output_dir,
        pattern=shard_config.args.file_pattern,
        max_size=shard_config.args.max_file_size,
    )

    # Copy metadata files to output directory; weight files are left out because
    # the sharded state written above takes their place.
    logger.info(f"Copying metadata files to {output_dir}")
    for file in os.listdir(model_path):
        if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
            src_path = os.path.join(model_path, file)
            dst_path = os.path.join(output_dir, file)
            if os.path.isdir(src_path):
                shutil.copytree(src_path, dst_path)
            else:
                shutil.copy(src_path, dst_path)

    return output_dir, card
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
# NOTE: the info argument is a json string instead of a HuggingFaceModelInfo
|
|
332
|
+
# object because the type engine cannot handle nested pydantic or dataclass
|
|
333
|
+
# objects when run in interactive mode.
|
|
334
|
+
def store_hf_model_task(info: str, raw_data_path: str | None = None) -> Dir:
    """Store a HuggingFace model in remote storage and return its directory.

    The task resolves the first commit returned by
    ``huggingface_hub.list_repo_commits`` for the repository, fills in the
    model type and architecture when they were not supplied, and then stores
    the model using one of three paths:

    1. Sharding configured: download locally, shard via ``_shard_model`` and
       upload the sharded output.
    2. No sharding: stream the repository directly to remote storage.
    3. Streaming raised: download a snapshot locally and upload it.

    If a model card (README) was retrieved, it is logged to the task report,
    rendered through the optional ``markdown`` package when it can be imported.

    :param info: JSON string that validates as a ``HuggingFaceModelInfo``.
    :param raw_data_path: Remote destination for the model. When ``None``, the
        streaming path asks the task context for a random remote path based on
        the artifact name, and the upload paths pass ``None`` through as the
        ``remote_destination``.
    :return: A ``Dir`` referencing the stored model.
    :raises AssertionError: If the ``HF_TOKEN`` environment variable is not set.
    :raises ValueError: If the repository does not exist on HuggingFace.
    """

    import huggingface_hub

    import flyte.report

    # Get HF token from secrets
    token = os.environ.get("HF_TOKEN")
    assert token is not None, "HF_TOKEN environment variable is not set"

    # Validate repo exists and get latest commit
    _info: HuggingFaceModelInfo = HuggingFaceModelInfo.model_validate_json(info)
    if not huggingface_hub.repo_exists(_info.repo, token=token):
        raise ValueError(f"Repository {_info.repo} does not exist in HuggingFace.")

    commit = huggingface_hub.list_repo_commits(_info.repo, token=token)[0].commit_id
    logger.info(f"Latest commit: {commit}")

    # Lookup model info if not provided
    if not _info.model_type or not _info.architecture:
        logger.info("Looking up HuggingFace model info...")
        try:
            _info.model_type, _info.architecture = _lookup_huggingface_model_info(_info.repo, commit, token)
        except Exception as e:
            # Best effort: an unknown model is still stored, just labelled "custom".
            logger.warning(f"Warning: Could not lookup model info: {e}")
            _info.model_type = "custom"
            _info.architecture = "custom"

    logger.info(f"Model type: {_info.model_type}, architecture: {_info.architecture}")

    # Determine artifact name
    if _info.artifact_name is None:
        artifact_name = _info.repo.split("/")[-1].replace(".", "-")
    else:
        artifact_name = _info.artifact_name

    card = None
    result_dir: Dir

    # If sharding is needed, we must download locally first
    if _info.shard_config is not None:
        logger.info(f"Sharding requested with {_info.shard_config.engine} engine")

        # The sharded output lives in a managed temporary directory so that it is
        # removed after the upload (and on failure) instead of being left behind
        # on the task's local disk.
        with tempfile.TemporaryDirectory() as sharded_dir:
            # The raw download is only needed while sharding, so it is released
            # before the upload starts.
            with tempfile.TemporaryDirectory() as local_model_dir:
                sharded_output, card = _shard_model(
                    _info.repo, commit, _info.shard_config, token, local_model_dir, sharded_dir
                )

            # Upload sharded model
            logger.info("Uploading sharded model...")
            result_dir = Dir.from_local_sync(sharded_output, remote_destination=raw_data_path)

    else:
        # Try direct streaming first
        try:
            logger.info("Attempting direct streaming to remote storage...")

            if raw_data_path is not None:
                remote_path = raw_data_path
            else:
                remote_path = flyte.ctx().raw_data_path.get_random_remote_path(artifact_name)  # type: ignore [union-attr]

            remote_path, card = _stream_to_remote_dir(_info.repo, commit, token, remote_path)
            result_dir = Dir.from_existing_remote(remote_path)
            logger.info(f"Direct streaming completed to {remote_path}")

        except Exception as e:
            logger.error(f"Direct streaming failed: {e}")
            logger.error("Falling back to snapshot download...")

            # Fallback: download snapshot and upload
            with tempfile.TemporaryDirectory() as local_model_dir:
                _local_model_dir, card = _download_snapshot_to_local(_info.repo, commit, token, local_model_dir)
                result_dir = Dir.from_local_sync(_local_model_dir, remote_destination=raw_data_path)

    # create report from the markdown `card`
    if card:
        # Try to convert markdown to HTML for richer presentation, fallback to plain text
        try:
            # `markdown` is an optional dependency; use it only if it is importable.
            import markdown

            report = markdown.markdown(card)
        except Exception:
            report = card  # fallback to plain markdown content
        flyte.report.log(report)
        flyte.report.flush()

    logger.info(f"Model stored successfully at {result_dir.path}")
    return result_dir
|
|
427
|
+
|
|
428
|
+
|
|
429
|
+
def hf_model(
    repo: str,
    *,
    raw_data_path: str | None = None,
    artifact_name: str | None = None,
    architecture: str | None = None,
    task: str = "auto",
    modality: tuple[str, ...] = ("text",),
    serial_format: str | None = None,
    model_type: str | None = None,
    short_description: str | None = None,
    shard_config: ShardConfig | None = None,
    hf_token_key: str = "HF_TOKEN",
    resources: Resources = Resources(cpu="2", memory="8Gi", disk="50Gi"),
    force: int = 0,
) -> Run:
    """
    Store a HuggingFace model to remote storage.

    This function launches a run of ``store_hf_model_task``, which downloads a
    model from the HuggingFace Hub and prefetches it to remote storage. It
    supports optional sharding using vLLM for large models.

    The prefetch behavior follows this priority:
    1. If the model isn't being sharded, stream files directly to remote storage.
    2. If streaming fails, fall back to downloading a snapshot and uploading.
    3. If sharding is configured, download locally, shard with vLLM, then upload.

    Example usage:

    ```python
    import flyte

    flyte.init(endpoint="my-flyte-endpoint")

    # Store a model without sharding
    run = flyte.prefetch.hf_model(
        repo="meta-llama/Llama-2-7b-hf",
        hf_token_key="HF_TOKEN",
    )
    run.wait()

    # Prefetch and shard a model
    from flyte.prefetch import ShardConfig, VLLMShardArgs

    run = flyte.prefetch.hf_model(
        repo="meta-llama/Llama-2-70b-hf",
        shard_config=ShardConfig(
            engine="vllm",
            args=VLLMShardArgs(tensor_parallel_size=8),
        ),
        resources=flyte.Resources(cpu="8", memory="64Gi", disk="500Gi", gpu="A100:8"),
        hf_token_key="HF_TOKEN",
    )
    run.wait()
    ```

    :param repo: The HuggingFace repository ID (e.g., 'meta-llama/Llama-2-7b-hf').
    :param raw_data_path: Optional remote destination for the stored model. It is
        forwarded unchanged to the prefetch task. Default: None.
    :param artifact_name: Optional name for the stored artifact. If not provided,
        the repo name will be used (with '.' replaced by '-').
    :param architecture: Model architecture from HuggingFace config.json.
    :param task: Model task (e.g., 'generate', 'classify', 'embed'). Default: 'auto'.
    :param modality: Modalities supported by the model. Default: ('text',).
    :param serial_format: Model serialization format (e.g., 'safetensors', 'onnx').
    :param model_type: Model type (e.g., 'transformer', 'custom').
    :param short_description: Short description of the model.
    :param shard_config: Optional configuration for model sharding with vLLM. When
        set, the task image additionally installs the CUDA toolkit and the vLLM
        sharding packages.
    :param hf_token_key: Name of the secret containing the HuggingFace token. The
        secret is exposed to the task as the ``HF_TOKEN`` environment variable.
        Default: 'HF_TOKEN'.
    :param resources: Resources requested for the prefetch task.
        Default: ``Resources(cpu="2", memory="8Gi", disk="50Gi")``.
    :param force: Any value greater than 0 disables the run cache for this run,
        forcing a new prefetch. Default: 0.

    :return: A Run object representing the prefetch task execution.
    """
    import flyte
    from flyte import Secret
    from flyte.remote import Run

    _validate_artifact_name(artifact_name)

    info = HuggingFaceModelInfo(
        repo=repo,
        artifact_name=artifact_name,
        architecture=architecture,
        task=task,
        modality=modality,
        serial_format=serial_format,
        model_type=model_type,
        short_description=short_description,
        shard_config=shard_config,
    )

    # Select image based on whether sharding is needed
    if shard_config is not None:
        image = (
            flyte.Image.from_debian_base(name="prefetch-hf-model-image")
            .with_apt_packages("gcc", "wget")
            .with_commands(
                [
                    "wget https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb",
                    "dpkg -i cuda-keyring_1.1-1_all.deb",
                    "apt-get update",
                    "apt-get install -y cuda-toolkit-12-9",
                ]
            )
            .with_env_vars(
                {
                    "CUDA_HOME": "/usr/local/cuda-12.9",
                    "LD_LIBRARY_PATH": "/usr/local/cuda-12.9/lib64/stubs",
                    "VLLM_USE_V1": "1",
                }
            )
            .with_pip_packages(*VLLM_SHARDING_IMAGE_PACKAGES)
        )
    else:
        image = flyte.Image.from_debian_base(name="prefetch-hf-model-image").with_pip_packages(
            *HF_DOWNLOAD_IMAGE_PACKAGES
        )

    # Create a task from the module-level function with the configured environment
    disable_run_cache = force > 0
    env = TaskEnvironment(
        name="prefetch-hf-model",
        image=image,
        resources=resources,
        secrets=[Secret(key=hf_token_key, as_env_var="HF_TOKEN")],
    )
    # Use a distinct local name so the `task` string parameter is not shadowed.
    prefetch_task = env.task(report=True)(store_hf_model_task)
    # The model info travels as a JSON string; the task re-validates it on its side.
    run = flyte.with_runcontext(interactive_mode=True, disable_run_cache=disable_run_cache).run(
        prefetch_task, info.model_dump_json(), raw_data_path
    )
    return typing.cast(Run, run)
|
flyte/remote/__init__.py
CHANGED
|
@@ -7,19 +7,30 @@ __all__ = [
|
|
|
7
7
|
"ActionDetails",
|
|
8
8
|
"ActionInputs",
|
|
9
9
|
"ActionOutputs",
|
|
10
|
+
"App",
|
|
10
11
|
"Project",
|
|
11
12
|
"Run",
|
|
12
13
|
"RunDetails",
|
|
13
14
|
"Secret",
|
|
15
|
+
"SecretTypes",
|
|
14
16
|
"Task",
|
|
17
|
+
"TaskDetails",
|
|
18
|
+
"Trigger",
|
|
19
|
+
"User",
|
|
20
|
+
"auth_metadata",
|
|
15
21
|
"create_channel",
|
|
16
22
|
"upload_dir",
|
|
17
23
|
"upload_file",
|
|
18
24
|
]
|
|
19
25
|
|
|
26
|
+
from ._action import Action, ActionDetails, ActionInputs, ActionOutputs
|
|
27
|
+
from ._app import App
|
|
28
|
+
from ._auth_metadata import auth_metadata
|
|
20
29
|
from ._client.auth import create_channel
|
|
21
30
|
from ._data import upload_dir, upload_file
|
|
22
31
|
from ._project import Project
|
|
23
|
-
from ._run import
|
|
24
|
-
from ._secret import Secret
|
|
25
|
-
from ._task import Task
|
|
32
|
+
from ._run import Run, RunDetails
|
|
33
|
+
from ._secret import Secret, SecretTypes
|
|
34
|
+
from ._task import Task, TaskDetails
|
|
35
|
+
from ._trigger import Trigger
|
|
36
|
+
from ._user import User
|