flyteplugins-vllm 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,3 @@
1
+ __all__ = ["DEFAULT_VLLM_IMAGE", "VLLMAppEnvironment"]
2
+
3
+ from flyteplugins.vllm._app_environment import DEFAULT_VLLM_IMAGE, VLLMAppEnvironment
@@ -0,0 +1,209 @@
1
+ from __future__ import annotations
2
+
3
+ import shlex
4
+ from dataclasses import dataclass, field, replace
5
+ from typing import Any, Literal, Optional, Union
6
+
7
+ import flyte.app
8
+ import rich.repr
9
+ from flyte import Environment, Image, Resources, SecretRequest
10
+ from flyte.app import Parameter, RunOutput
11
+ from flyte.app._types import Port
12
+ from flyte.models import SerializationContext
13
+
14
+ from flyteplugins.vllm._constants import VLLM_MIN_VERSION_STR
15
+
16
# Default container image for VLLMAppEnvironment. The pip layers are chained in
# the order below; the final vllm layer is deliberately kept separate (see the
# comment on that layer), so do not merge or reorder the calls.
DEFAULT_VLLM_IMAGE = (
    flyte.Image.from_debian_base(name="vllm-app-image")
    # install the flashinfer packages (vllm itself is installed in the last layer)
    .with_pip_packages("flashinfer-python", "flashinfer-cubin")
    # flashinfer-jit-cache is resolved against the flashinfer cu129 wheel index
    .with_pip_packages("flashinfer-jit-cache", index_url="https://flashinfer.ai/whl/cu129")
    # install the vllm flyte plugin
    .with_pip_packages("flyteplugins-vllm", pre=True)
    # install vllm in a separate layer due to dependency conflict with flyte (protovalidate);
    # the version is pinned to exactly VLLM_MIN_VERSION_STR
    .with_pip_packages(f"vllm=={VLLM_MIN_VERSION_STR}")
)
26
+
27
+
28
@rich.repr.auto
@dataclass(kw_only=True, repr=True)
class VLLMAppEnvironment(flyte.app.AppEnvironment):
    """
    App environment backed by vLLM for serving large language models.

    This environment sets up a vLLM server with the specified model and configuration.
    The ``vllm-fserve serve ...`` command line is generated in ``__post_init__``, so
    ``args`` and ``parameters`` must not be supplied by the caller; use ``extra_args``
    to append additional ``vllm serve`` options.

    :param name: The name of the application.
    :param image: The container image to use for the application. ``None`` or ``"auto"``
        selects ``DEFAULT_VLLM_IMAGE``.
    :param port: Port application listens to. Defaults to 8080; the resolved port is
        passed to vLLM through ``--port``.
    :param resources: Compute resources for the application.
    :param secrets: Secrets that are requested for application.
    :param env_vars: Environment variables to set for the application. The mapping is
        copied before the model-loader variables are added to it.
    :param scaling: Scaling configuration for the app environment.
    :param domain: Domain to use for the app.
    :param cluster_pool: The target cluster_pool where the app should be deployed.
    :param requires_auth: Whether the public URL requires authentication.
    :param type: Type of app.
    :param extra_args: Extra args to pass to `vllm serve`. See
        https://docs.vllm.ai/en/stable/configuration/engine_args
        or run `vllm serve --help` for details. A string is split with ``shlex.split``.
    :param model_path: Remote path to model (e.g., s3://bucket/path/to/model).
    :param model_hf_path: Hugging Face path to model (e.g., Qwen/Qwen3-0.6B).
    :param model_id: Model id that is exposed by vllm.
    :param stream_model: Set to True to stream model from blob store to the GPU directly.
        If False, the model will be downloaded to the local file system first and then loaded
        into the GPU.
    :raises ValueError: If ``model_id`` is empty, if neither or both of ``model_path`` and
        ``model_hf_path`` are given, or if ``args``, ``parameters``, a server function or
        startup/shutdown hooks are set.
    """

    port: int | Port = 8080
    type: str = "vLLM"
    extra_args: str | list[str] = ""
    model_path: str | RunOutput = ""
    model_hf_path: str = ""
    model_id: str = ""
    stream_model: bool = True
    image: str | Image | Literal["auto"] = DEFAULT_VLLM_IMAGE
    # Location handed to `vllm serve` as the model argument; replaced by the
    # Hugging Face id when `model_hf_path` is used. Not settable by callers.
    _model_mount_path: str = field(default="/root/flyte", init=False)

    def __post_init__(self):
        """Validate the configuration and derive args, parameters, env vars and links."""
        # Work on a private copy: the loader variables set below must not leak into a
        # dict the caller (or another environment built from the same kwargs) still holds.
        self.env_vars = {} if self.env_vars is None else dict(self.env_vars)

        if self._server is not None:
            raise ValueError("server function cannot be set for VLLMAppEnvironment")

        if self._on_startup is not None:
            raise ValueError("on_startup function cannot be set for VLLMAppEnvironment")

        if self._on_shutdown is not None:
            raise ValueError("on_shutdown function cannot be set for VLLMAppEnvironment")

        if self.model_id == "":
            raise ValueError("model_id must be defined")

        # Exactly one model source has to be configured.
        if self.model_path == "" and self.model_hf_path == "":
            raise ValueError("model_path or model_hf_path must be defined")
        if self.model_path != "" and self.model_hf_path != "":
            raise ValueError("model_path and model_hf_path cannot be set at the same time")

        if self.model_hf_path:
            # For Hugging Face models the id itself is what `vllm serve` receives.
            self._model_mount_path = self.model_hf_path

        if self.args:
            raise ValueError("args cannot be set for VLLMAppEnvironment. Use `extra_args` to add extra arguments.")

        if isinstance(self.extra_args, str):
            extra_args = shlex.split(self.extra_args)
        else:
            extra_args = self.extra_args

        stream_model_args = []
        if self.stream_model:
            stream_model_args.extend(["--load-format", "flyte-vllm-streaming"])

        self.args = [
            "vllm-fserve",
            "serve",
            self._model_mount_path,
            "--served-model-name",
            self.model_id,
            "--port",
            str(self.get_port().port),
            *stream_model_args,
            *extra_args,
        ]

        if self.parameters:
            raise ValueError("parameters cannot be set for VLLMAppEnvironment")

        # Keyword arguments for the `model_path` Parameter: streaming exposes the remote
        # path through an env var without downloading; otherwise the model is downloaded
        # and mounted at the model mount path.
        input_kwargs = {}
        if self.stream_model:
            self.env_vars["FLYTE_MODEL_LOADER_STREAM_SAFETENSORS"] = "true"
            input_kwargs["env_var"] = "FLYTE_MODEL_LOADER_REMOTE_MODEL_PATH"
            input_kwargs["download"] = False
        else:
            self.env_vars["FLYTE_MODEL_LOADER_STREAM_SAFETENSORS"] = "false"
            input_kwargs["download"] = True
            input_kwargs["mount"] = self._model_mount_path

        if self.model_path:
            self.parameters = [Parameter(name="model_path", value=self.model_path, **input_kwargs)]

        self.env_vars["FLYTE_MODEL_LOADER_LOCAL_MODEL_PATH"] = self._model_mount_path
        self.links = [flyte.app.Link(path="/docs", title="vLLM OpenAPI Docs", is_relative=True)]

        if self.image is None or self.image == "auto":
            self.image = DEFAULT_VLLM_IMAGE

        super().__post_init__()

    def container_args(self, serialization_context: SerializationContext) -> list[str]:
        """Return the container arguments for vLLM."""
        if isinstance(self.args, str):
            return shlex.split(self.args)
        return self.args or []

    def clone_with(
        self,
        name: str,
        image: Optional[Union[str, Image, Literal["auto"]]] = None,
        resources: Optional[Resources] = None,
        env_vars: Optional[dict[str, str]] = None,
        secrets: Optional[SecretRequest] = None,
        depends_on: Optional[list[Environment]] = None,
        description: Optional[str] = None,
        interruptible: Optional[bool] = None,
        **kwargs: Any,
    ) -> VLLMAppEnvironment:
        """
        Return a copy of this environment under a new name with selected overrides.

        Besides the explicit parameters, ``port``, ``extra_args``, ``model_path``,
        ``model_hf_path``, ``model_id`` and ``stream_model`` may be passed as keyword
        arguments. ``model_path`` and ``model_hf_path`` are overridden whenever the
        key is present (``None`` is normalised to ``""``, which clears the value);
        every other override is applied only when it is not ``None``.

        ``args`` and ``parameters`` are reset so that ``__post_init__`` of the clone
        regenerates them from the effective configuration.

        :raises TypeError: If an unsupported keyword argument is passed.
        """
        port = kwargs.pop("port", None)
        extra_args = kwargs.pop("extra_args", None)
        # Presence of the key (not its value) decides whether the model source is
        # overridden, so callers can explicitly clear one source when switching.
        set_model_path = "model_path" in kwargs
        model_path = kwargs.pop("model_path", "") or ""
        set_model_hf_path = "model_hf_path" in kwargs
        model_hf_path = kwargs.pop("model_hf_path", "") or ""
        model_id = kwargs.pop("model_id", None)
        stream_model = kwargs.pop("stream_model", None)

        if kwargs:
            raise TypeError(f"Unexpected keyword arguments: {list(kwargs.keys())}")

        kwargs = self._get_kwargs()
        kwargs["name"] = name
        kwargs["args"] = None
        kwargs["parameters"] = None

        optional_overrides = {
            "image": image,
            "resources": resources,
            "env_vars": env_vars,
            "secrets": secrets,
            "depends_on": depends_on,
            "description": description,
            "interruptible": interruptible,
            "port": port,
            "extra_args": extra_args,
            "model_id": model_id,
            "stream_model": stream_model,
        }
        for key, value in optional_overrides.items():
            if value is not None:
                kwargs[key] = value

        if set_model_path:
            kwargs["model_path"] = model_path
        if set_model_hf_path:
            kwargs["model_hf_path"] = model_hf_path
        return replace(self, **kwargs)
@@ -0,0 +1,2 @@
1
# Minimum supported vllm version as a tuple of integers, suitable for tuple comparison.
VLLM_MIN_VERSION = (0, 11, 0)
# The same version rendered as a dotted string ("0.11.0") for requirement strings and messages.
VLLM_MIN_VERSION_STR = ".".join(str(component) for component in VLLM_MIN_VERSION)
File without changes
@@ -0,0 +1,143 @@
1
+ import logging
2
+ from typing import Generator
3
+
4
+ import torch
5
+ from flyte.app.extras._model_loader.config import (
6
+ LOCAL_MODEL_PATH,
7
+ REMOTE_MODEL_PATH,
8
+ STREAM_SAFETENSORS,
9
+ )
10
+ from flyte.app.extras._model_loader.loader import SafeTensorsStreamer, prefetch
11
+
12
+ from flyteplugins.vllm._constants import VLLM_MIN_VERSION, VLLM_MIN_VERSION_STR
13
+
14
# vllm is an optional dependency of the plugin; fail with an actionable message
# (chained to the original error) when it is missing.
try:
    import vllm
except ImportError as exc:
    raise ImportError(
        f"vllm is not installed. Please install 'vllm>={VLLM_MIN_VERSION_STR}', to use the model loader."
    ) from exc

# Compare only the public part of the version. A local build tag such as
# "0.11.0+cu129" would otherwise produce the component "0+cu129", which the
# isdigit() filter drops, leaving (0, 11) and wrongly rejecting a supported
# release. Components with other non-numeric suffixes (e.g. "0rc1") are still
# skipped, so such versions compare on their leading numeric components only.
if (
    tuple(int(part) for part in vllm.__version__.split("+", 1)[0].split(".") if part.isdigit())
    < VLLM_MIN_VERSION
):
    raise ImportError(
        f"vllm version >={VLLM_MIN_VERSION_STR} required, but found {vllm.__version__}. Please upgrade vllm."
    )
23
+
24
+ import vllm.entrypoints.cli.main
25
+ from vllm.config import ModelConfig, VllmConfig
26
+ from vllm.distributed import get_tensor_model_parallel_rank
27
+ from vllm.model_executor.model_loader import register_model_loader
28
+ from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
29
+ from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
30
+ from vllm.model_executor.model_loader.sharded_state_loader import ShardedStateLoader
31
+
32
+ try:
33
+ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
34
+ except ImportError:
35
+ # vllm 0.13.0 moved the set_default_torch_dtype to vllm.utils.torch_utils
36
+ from vllm.utils.torch_utils import set_default_torch_dtype
37
+
38
+ logger = logging.getLogger(__name__)
39
+
40
+
41
@register_model_loader("flyte-vllm-streaming")
class FlyteModelLoader(DefaultModelLoader):
    """Custom model loader for streaming model weights from object storage.

    Registered with vLLM under the load format ``flyte-vllm-streaming``. With a
    tensor parallel size of 1 the default loading path is used with a streaming
    weights iterator; with a larger size each rank streams tensors into a
    dummy-initialised model.
    """

    def _get_weights_iterator(
        self, source: DefaultModelLoader.Source
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Yield ``(name, tensor)`` pairs, each name prefixed with ``source.prefix``.

        Falls back to the default loader's iterator when the streamer cannot be
        constructed (signalled by ``ValueError``).
        """
        # Try to load weights using the Flyte SafeTensorsLoader. Fallback to the default loader otherwise.
        try:
            streamer = SafeTensorsStreamer(REMOTE_MODEL_PATH, LOCAL_MODEL_PATH)
        except ValueError:
            yield from super()._get_weights_iterator(source)
        else:
            for name, tensor in streamer.get_tensors():
                yield source.prefix + name, tensor

    def download_model(self, model_config: ModelConfig) -> None:
        """Do nothing: weights are streamed at load time rather than downloaded."""
        # This model loader supports streaming only
        pass

    def _load_sharded_model(self, vllm_config: VllmConfig, model_config: ModelConfig) -> torch.nn.Module:
        """Build the model with dummy weights and overwrite them with streamed tensors.

        :raises ValueError: If the tensor parallel rank is not below the tensor parallel
            size, or if parameters remain unfilled after all streamed tensors are consumed.
        :raises KeyError: If a streamed tensor name is absent from the model state dict.
        """
        # Forked from: https://github.com/vllm-project/vllm/blob/99d01a5e3d5278284bad359ac8b87ee7a551afda/vllm/model_executor/model_loader/loader.py#L613
        # Sanity checks
        tensor_parallel_size = vllm_config.parallel_config.tensor_parallel_size
        rank = get_tensor_model_parallel_rank()
        if rank >= tensor_parallel_size:
            raise ValueError(f"Invalid rank {rank} for tensor parallel size {tensor_parallel_size}")
        with set_default_torch_dtype(vllm_config.model_config.dtype):  # type: ignore[arg-type]
            with torch.device(vllm_config.device_config.device):  # type: ignore[arg-type]
                model_loader = DummyModelLoader(load_config=vllm_config.load_config)
                model = model_loader.load_model(vllm_config=vllm_config, model_config=model_config)
                for index, (module_name, module) in enumerate(model.named_modules()):
                    # Diagnostic only; kept at debug level so the full module tree is
                    # not written to stdout on every load.
                    logger.debug("module %d %s: %s", index, module_name, module)
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        quant_method.process_weights_after_loading(module)
            state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
            streamer = SafeTensorsStreamer(
                REMOTE_MODEL_PATH,
                LOCAL_MODEL_PATH,
                rank=rank,
                tensor_parallel_size=tensor_parallel_size,
            )
            for name, tensor in streamer.get_tensors():
                # If loading with LoRA enabled, additional padding may
                # be added to certain parameters. We only load into a
                # narrowed view of the parameter data.
                param_data = state_dict[name].data
                param_shape = state_dict[name].shape
                for dim, size in enumerate(tensor.shape):
                    if size < param_shape[dim]:
                        param_data = param_data.narrow(dim, 0, size)
                if tensor.shape != param_shape:
                    logger.warning(
                        "loading tensor of shape %s into parameter '%s' of shape %s",
                        tensor.shape,
                        name,
                        param_shape,
                    )
                param_data.copy_(tensor)
                # Track which parameters were filled; leftovers are reported below.
                state_dict.pop(name)
            if state_dict:
                raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
        return model.eval()

    def load_model(
        self,
        vllm_config: VllmConfig,
        model_config: ModelConfig,
    ) -> torch.nn.Module:
        """Load the model, using the sharded path when tensor parallel size exceeds 1."""
        logger.info("Loading model with FlyteModelLoader")
        if vllm_config.parallel_config.tensor_parallel_size > 1:
            return self._load_sharded_model(vllm_config, model_config)
        else:
            return super().load_model(vllm_config, model_config)
116
+
117
+
118
async def _get_model_files():
    """Check that the remote model path exists, then prefetch it to the local path.

    :raises FileNotFoundError: If ``REMOTE_MODEL_PATH`` does not exist in storage.
    """
    import flyte.storage as storage

    remote_found = await storage.exists(REMOTE_MODEL_PATH)
    if not remote_found:
        raise FileNotFoundError(f"Model path not found: {REMOTE_MODEL_PATH}")

    # NOTE(review): exclude_safetensors presumably skips the weight files when
    # streaming is enabled, since the loader streams them later; confirm in prefetch.
    await prefetch(REMOTE_MODEL_PATH, LOCAL_MODEL_PATH, exclude_safetensors=STREAM_SAFETENSORS)
129
+
130
+
131
def main():
    """Configure logging, prefetch the model files, then run the vLLM CLI."""
    import asyncio

    # TODO: add CLI here to be able to pass in serialized parameters from AppEnvironment
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)

    # Prefetch the model
    asyncio.run(_get_model_files())

    # Hand control to vLLM's own command-line entry point.
    vllm.entrypoints.cli.main.main()
@@ -0,0 +1,54 @@
1
+ Metadata-Version: 2.4
2
+ Name: flyteplugins-vllm
3
+ Version: 2.0.0
4
+ Summary: vLLM plugin for flyte
5
+ Author-email: Niels Bantilan <cosmicbboy@users.noreply.github.com>
6
+ Requires-Python: >=3.10
7
+ Description-Content-Type: text/markdown
8
+ Requires-Dist: flyte>=2.0.0b43
9
+
10
+ # Union vLLM Plugin
11
+
12
+ Serve large language models using vLLM with Flyte Apps.
13
+
14
+ This plugin provides the `VLLMAppEnvironment` class for deploying and serving LLMs using [vLLM](https://docs.vllm.ai/).
15
+
16
+ ## Installation
17
+
18
+ ```bash
19
+ pip install --pre flyteplugins-vllm
20
+ ```
21
+
22
+ ## Usage
23
+
24
+ ```python
25
+ import flyte
26
+ import flyte.app
27
+ from flyteplugins.vllm import VLLMAppEnvironment
28
+
29
+ # Define the vLLM app environment
30
+ vllm_app = VLLMAppEnvironment(
31
+ name="my-llm-app",
32
+ model_path="s3://your-bucket/models/your-model",
33
+ model_id="your-model-id",
34
+ resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1"),
35
+ stream_model=True, # Stream model directly from blob store to GPU
36
+ scaling=flyte.app.Scaling(
37
+ replicas=(0, 1),
38
+ scaledown_after=300,
39
+ ),
40
+ )
41
+
42
+ if __name__ == "__main__":
43
+ flyte.init_from_config()
44
+ app = flyte.serve(vllm_app)
45
+ print(f"Deployed vLLM app: {app.url}")
46
+ ```
47
+
48
+ ## Features
49
+
50
+ - **Streaming Model Loading**: Stream model weights directly from object storage to GPU memory, reducing startup time and disk requirements.
51
+ - **OpenAI-Compatible API**: The deployed app exposes an OpenAI-compatible API for chat completions.
52
+ - **Auto-scaling**: Configure scaling policies to scale up/down based on traffic.
53
+ - **Tensor Parallelism**: Support for distributed inference across multiple GPUs.
54
+
@@ -0,0 +1,11 @@
1
+ flyteplugins/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ flyteplugins/vllm/__init__.py,sha256=knzAiZG0gNqtgroaB1j5OBxomxBab6JyuWParvYRlLw,142
3
+ flyteplugins/vllm/_app_environment.py,sha256=RE0bzbvaT4DhAVbQGYbqo-KnEYUFHIiuHXXBYhz2JVg,8263
4
+ flyteplugins/vllm/_constants.py,sha256=I8suY7mz05kyvaW0Zskv_Zw0y7Gf0E4wjWm88cdjttU,90
5
+ flyteplugins/vllm/_model_loader/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ flyteplugins/vllm/_model_loader/shim.py,sha256=02Myhx4A5W6wDHIzV_TIUrhPR0foIQU0N809nhIBm6U,5996
7
+ flyteplugins_vllm-2.0.0.dist-info/METADATA,sha256=FnPt9XAlixEBxgsEeWRAT17dpszehISclwqy4SoBf5M,1577
8
+ flyteplugins_vllm-2.0.0.dist-info/WHEEL,sha256=YCfwYGOYMi5Jhw2fU4yNgwErybb2IX5PEwBKV4ZbdBo,91
9
+ flyteplugins_vllm-2.0.0.dist-info/entry_points.txt,sha256=lC-uwvkaytwtzbkJWdS69np63yLAakaDpI4mV1Yp9l8,74
10
+ flyteplugins_vllm-2.0.0.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
11
+ flyteplugins_vllm-2.0.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ vllm-fserve = flyteplugins.vllm._model_loader.shim:main
@@ -0,0 +1 @@
1
+ flyteplugins