flyteplugins_vllm-2.0.0b40-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
flyteplugins/__init__.py — File without changes
flyteplugins/vllm/__init__.py
@@ -0,0 +1,3 @@
+ __all__ = ["VLLMAppEnvironment"]
+
+ from flyteplugins.vllm._app_environment import VLLMAppEnvironment
flyteplugins/vllm/_app_environment.py
@@ -0,0 +1,196 @@
+ from __future__ import annotations
+
+ import shlex
+ from dataclasses import dataclass, field, replace
+ from typing import Any, Literal, Optional, Union
+
+ import flyte.app
+ import rich.repr
+ from flyte import Environment, Image, Resources, SecretRequest
+ from flyte.app import Input, RunOutput
+ from flyte.app._types import Port
+ from flyte.models import SerializationContext
+
+ DEFAULT_VLLM_IMAGE = (
+     flyte.Image.from_debian_base(name="vllm-app-image", python_version=(3, 12))
+     # install flashinfer (vllm itself is pulled in by the plugin below)
+     .with_pip_packages("flashinfer-python", "flashinfer-cubin")
+     .with_pip_packages("flashinfer-jit-cache", index_url="https://flashinfer.ai/whl/cu129")
+     # install the vllm flyte plugin
+     .with_pip_packages("flyteplugins-vllm", pre=True)
+ )
+
+
+ @rich.repr.auto
+ @dataclass(kw_only=True, repr=True)
+ class VLLMAppEnvironment(flyte.app.AppEnvironment):
+     """
+     App environment backed by vLLM for serving large language models.
+
+     This environment sets up a vLLM server with the specified model and configuration.
+
+     :param name: The name of the application.
+     :param image: The container image to use for the application.
+     :param port: Port the application listens on. Defaults to 8080.
+     :param requests: Compute resource requests for the application.
+     :param secrets: Secrets requested for the application.
+     :param limits: Compute resource limits for the application.
+     :param env_vars: Environment variables to set for the application.
+     :param scaling: Scaling configuration for the app environment.
+     :param domain: Domain to use for the app.
+     :param cluster_pool: The target cluster pool where the app should be deployed.
+     :param requires_auth: Whether the public URL requires authentication.
+     :param type: Type of app.
+     :param extra_args: Extra arguments to pass to `vllm serve`. See
+         https://docs.vllm.ai/en/stable/configuration/engine_args
+         or run `vllm serve --help` for details.
+     :param model_path: Remote path to the model (e.g., s3://bucket/path/to/model).
+     :param model_hf_path: Hugging Face path to the model (e.g., Qwen/Qwen3-0.6B).
+     :param model_id: Model ID exposed by the vLLM server.
+     :param stream_model: Set to True to stream the model from blob storage directly to the
+         GPU. If False, the model is downloaded to the local file system first and then
+         loaded into the GPU.
+     """
+
+     port: int | Port = 8080
+     type: str = "vLLM"
+     extra_args: str | list[str] = ""
+     model_path: str | RunOutput = ""
+     model_hf_path: str = ""
+     model_id: str = ""
+     stream_model: bool = True
+     image: str | Image | Literal["auto"] = DEFAULT_VLLM_IMAGE
+     _model_mount_path: str = field(default="/root/flyte", init=False)
+
+     def __post_init__(self):
+         if self.env_vars is None:
+             self.env_vars = {}
+
+         if self.model_id == "":
+             raise ValueError("model_id must be defined")
+
+         if self.model_path == "" and self.model_hf_path == "":
+             raise ValueError("model_path or model_hf_path must be defined")
+         if self.model_path != "" and self.model_hf_path != "":
+             raise ValueError("model_path and model_hf_path cannot be set at the same time")
+
+         if self.model_hf_path:
+             self._model_mount_path = self.model_hf_path
+
+         if self.args:
+             raise ValueError("args cannot be set for VLLMAppEnvironment. Use `extra_args` to add extra arguments.")
+
+         if isinstance(self.extra_args, str):
+             extra_args = shlex.split(self.extra_args)
+         else:
+             extra_args = self.extra_args
+
+         stream_model_args = []
+         if self.stream_model:
+             stream_model_args.extend(["--load-format", "flyte-vllm-streaming"])
+
+         self.args = [
+             "vllm-fserve",
+             "serve",
+             self._model_mount_path,
+             "--served-model-name",
+             self.model_id,
+             "--port",
+             str(self.get_port().port),
+             *stream_model_args,
+             *extra_args,
+         ]
+
+         if self.inputs:
+             raise ValueError("inputs cannot be set for VLLMAppEnvironment")
+
+         input_kwargs = {}
+         if self.stream_model:
+             self.env_vars["FLYTE_MODEL_LOADER_STREAM_SAFETENSORS"] = "true"
+             input_kwargs["env_var"] = "FLYTE_MODEL_LOADER_REMOTE_MODEL_PATH"
+             input_kwargs["download"] = False
+         else:
+             self.env_vars["FLYTE_MODEL_LOADER_STREAM_SAFETENSORS"] = "false"
+             input_kwargs["download"] = True
+             input_kwargs["mount"] = self._model_mount_path
+
+         if self.model_path:
+             self.inputs = [Input(name="model_path", value=self.model_path, **input_kwargs)]
+
+         self.env_vars["FLYTE_MODEL_LOADER_LOCAL_MODEL_PATH"] = self._model_mount_path
+         self.links = [flyte.app.Link(path="/docs", title="vLLM OpenAPI Docs", is_relative=True)]
+
+         if self.image is None or self.image == "auto":
+             self.image = DEFAULT_VLLM_IMAGE
+
+         super().__post_init__()
+
+     def container_args(self, serialization_context: SerializationContext) -> list[str]:
+         """Return the container arguments for vLLM."""
+         if isinstance(self.args, str):
+             return shlex.split(self.args)
+         return self.args or []
+
+     def clone_with(
+         self,
+         name: str,
+         image: Optional[Union[str, Image, Literal["auto"]]] = None,
+         resources: Optional[Resources] = None,
+         env_vars: Optional[dict[str, str]] = None,
+         secrets: Optional[SecretRequest] = None,
+         depends_on: Optional[list[Environment]] = None,
+         description: Optional[str] = None,
+         interruptible: Optional[bool] = None,
+         **kwargs: Any,
+     ) -> VLLMAppEnvironment:
+         port = kwargs.pop("port", None)
+         extra_args = kwargs.pop("extra_args", None)
+         if "model_path" in kwargs:
+             set_model_path = True
+             model_path = kwargs.pop("model_path", "") or ""
+         else:
+             set_model_path = False
+             model_path = self.model_path
+         if "model_hf_path" in kwargs:
+             set_model_hf_path = True
+             model_hf_path = kwargs.pop("model_hf_path", "") or ""
+         else:
+             set_model_hf_path = False
+             model_hf_path = self.model_hf_path
+         model_id = kwargs.pop("model_id", None)
+         stream_model = kwargs.pop("stream_model", None)
+
+         if kwargs:
+             raise TypeError(f"Unexpected keyword arguments: {list(kwargs.keys())}")
+
+         kwargs = self._get_kwargs()
+         kwargs["name"] = name
+         kwargs["args"] = None
+         kwargs["inputs"] = None
+         if image is not None:
+             kwargs["image"] = image
+         if resources is not None:
+             kwargs["resources"] = resources
+         if env_vars is not None:
+             kwargs["env_vars"] = env_vars
+         if secrets is not None:
+             kwargs["secrets"] = secrets
+         if depends_on is not None:
+             kwargs["depends_on"] = depends_on
+         if description is not None:
+             kwargs["description"] = description
+         if interruptible is not None:
+             kwargs["interruptible"] = interruptible
+         if port is not None:
+             kwargs["port"] = port
+         if extra_args is not None:
+             kwargs["extra_args"] = extra_args
+         if set_model_path:
+             kwargs["model_path"] = model_path
+         if set_model_hf_path:
+             kwargs["model_hf_path"] = model_hf_path
+         if model_id is not None:
+             kwargs["model_id"] = model_id
+         if stream_model is not None:
+             kwargs["stream_model"] = stream_model
+         return replace(self, **kwargs)
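
For reference, `__post_init__` compiles the dataclass fields into the `vllm-fserve serve` command line stored in `self.args`. A minimal sketch of the result, assuming the parent `AppEnvironment` defaults suffice and using illustrative field values:

```python
from flyteplugins.vllm import VLLMAppEnvironment

env = VLLMAppEnvironment(
    name="qwen-demo",                 # illustrative name
    model_hf_path="Qwen/Qwen3-0.6B",  # served straight from the Hugging Face path
    model_id="qwen3-0.6b",
    stream_model=False,               # skip the flyte-vllm-streaming load format
)
print(env.args)
# ['vllm-fserve', 'serve', 'Qwen/Qwen3-0.6B',
#  '--served-model-name', 'qwen3-0.6b', '--port', '8080']
```

With `stream_model=True` (the default), `--load-format flyte-vllm-streaming` is appended and the streaming environment variables are set instead.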
flyteplugins/vllm/_model_loader/__init__.py — File without changes
flyteplugins/vllm/_model_loader/shim.py
@@ -0,0 +1,126 @@
+ import logging
+ from typing import Generator
+
+ import torch
+ import vllm
+ import vllm.entrypoints.cli.main
+ from flyte.app.extras._model_loader.config import (
+     LOCAL_MODEL_PATH,
+     REMOTE_MODEL_PATH,
+     STREAM_SAFETENSORS,
+ )
+ from flyte.app.extras._model_loader.loader import SafeTensorsStreamer, prefetch
+ from vllm.config import ModelConfig, VllmConfig
+ from vllm.distributed import get_tensor_model_parallel_rank
+ from vllm.model_executor.model_loader import register_model_loader
+ from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
+ from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
+ from vllm.model_executor.model_loader.sharded_state_loader import ShardedStateLoader
+ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+
+ logger = logging.getLogger(__name__)
+
+
+ @register_model_loader("flyte-vllm-streaming")
+ class FlyteModelLoader(DefaultModelLoader):
+     """Custom model loader for streaming model weights from object storage."""
+
+     def _get_weights_iterator(
+         self, source: DefaultModelLoader.Source
+     ) -> Generator[tuple[str, torch.Tensor], None, None]:
+         # Try to load weights using the Flyte SafeTensorsStreamer; fall back to the default loader otherwise.
+         try:
+             streamer = SafeTensorsStreamer(REMOTE_MODEL_PATH, LOCAL_MODEL_PATH)
+         except ValueError:
+             yield from super()._get_weights_iterator(source)
+         else:
+             for name, tensor in streamer.get_tensors():
+                 yield source.prefix + name, tensor
+
+     def download_model(self, model_config: ModelConfig) -> None:
+         # This model loader supports streaming only
+         pass
+
+     def _load_sharded_model(self, vllm_config: VllmConfig, model_config: ModelConfig) -> torch.nn.Module:
+         # Forked from: https://github.com/vllm-project/vllm/blob/99d01a5e3d5278284bad359ac8b87ee7a551afda/vllm/model_executor/model_loader/loader.py#L613
+         # Sanity checks
+         tensor_parallel_size = vllm_config.parallel_config.tensor_parallel_size
+         rank = get_tensor_model_parallel_rank()
+         if rank >= tensor_parallel_size:
+             raise ValueError(f"Invalid rank {rank} for tensor parallel size {tensor_parallel_size}")
+         with set_default_torch_dtype(vllm_config.model_config.dtype):  # type: ignore[arg-type]
+             with torch.device(vllm_config.device_config.device):  # type: ignore[arg-type]
+                 # Materialize the model architecture with dummy weights, then stream the real ones in.
+                 model_loader = DummyModelLoader(load_config=vllm_config.load_config)
+                 model = model_loader.load_model(vllm_config=vllm_config, model_config=model_config)
+                 for name, module in model.named_modules():
+                     quant_method = getattr(module, "quant_method", None)
+                     if quant_method is not None:
+                         quant_method.process_weights_after_loading(module)
+                 state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+                 streamer = SafeTensorsStreamer(
+                     REMOTE_MODEL_PATH,
+                     LOCAL_MODEL_PATH,
+                     rank=rank,
+                     tensor_parallel_size=tensor_parallel_size,
+                 )
+                 for name, tensor in streamer.get_tensors():
+                     # If loading with LoRA enabled, additional padding may
+                     # be added to certain parameters. We only load into a
+                     # narrowed view of the parameter data.
+                     param_data = state_dict[name].data
+                     param_shape = state_dict[name].shape
+                     for dim, size in enumerate(tensor.shape):
+                         if size < param_shape[dim]:
+                             param_data = param_data.narrow(dim, 0, size)
+                     if tensor.shape != param_shape:
+                         logger.warning(
+                             "loading tensor of shape %s into parameter '%s' of shape %s",
+                             tensor.shape,
+                             name,
+                             param_shape,
+                         )
+                     param_data.copy_(tensor)
+                     state_dict.pop(name)
+                 if state_dict:
+                     raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+         return model.eval()
+
+     def load_model(
+         self,
+         vllm_config: VllmConfig,
+         model_config: ModelConfig,
+     ) -> torch.nn.Module:
+         logger.info("Loading model with FlyteModelLoader")
+         if vllm_config.parallel_config.tensor_parallel_size > 1:
+             return self._load_sharded_model(vllm_config, model_config)
+         else:
+             return super().load_model(vllm_config, model_config)
+
+
+ async def _get_model_files():
+     import flyte.storage as storage
+
+     if not await storage.exists(REMOTE_MODEL_PATH):
+         raise FileNotFoundError(f"Model path not found: {REMOTE_MODEL_PATH}")
+
+     await prefetch(
+         REMOTE_MODEL_PATH,
+         LOCAL_MODEL_PATH,
+         exclude_safetensors=STREAM_SAFETENSORS,
+     )
+
+
+ def main():
+     import asyncio
+
+     # TODO: add CLI here to be able to pass in serialized inputs from AppEnvironment
+     logging.basicConfig(
+         level=logging.INFO,
+         format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+     )
+
+     # Prefetch the model files (safetensors are excluded when streaming is enabled)
+     asyncio.run(_get_model_files())
+
+     # Hand off to the regular vLLM CLI entry point
+     vllm.entrypoints.cli.main.main()
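
The narrowing loop in `_load_sharded_model` is worth a note: with LoRA enabled, a parameter buffer can be padded beyond the checkpoint tensor's shape, so the copy targets only a leading slice of the buffer. A standalone sketch of that behavior using plain torch tensors (shapes invented for the example):

```python
import torch

param = torch.zeros(10, 8)   # parameter buffer, padded (e.g., for LoRA) to 10 rows
loaded = torch.ones(8, 8)    # smaller tensor streamed from object storage

view = param
for dim, size in enumerate(loaded.shape):
    if size < view.shape[dim]:
        view = view.narrow(dim, 0, size)  # view over the leading 8 rows

view.copy_(loaded)           # in-place: rows 0-7 become ones, the padding rows stay zero
assert param[8:].abs().sum() == 0
```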
flyteplugins_vllm-2.0.0b40.dist-info/METADATA
@@ -0,0 +1,54 @@
+ Metadata-Version: 2.4
+ Name: flyteplugins-vllm
+ Version: 2.0.0b40
+ Summary: vLLM plugin for flyte
+ Author-email: Niels Bantilan <cosmicbboy@users.noreply.github.com>
+ Requires-Python: >=3.10
+ Description-Content-Type: text/markdown
+ Requires-Dist: vllm>=0.11.0
+
+ # Union vLLM Plugin
+
+ Serve large language models using vLLM with Flyte Apps.
+
+ This plugin provides the `VLLMAppEnvironment` class for deploying and serving LLMs using [vLLM](https://docs.vllm.ai/).
+
+ ## Installation
+
+ ```bash
+ pip install --pre flyteplugins-vllm
+ ```
+
+ ## Usage
+
+ ```python
+ import flyte
+ import flyte.app
+ from flyteplugins.vllm import VLLMAppEnvironment
+
+ # Define the vLLM app environment
+ vllm_app = VLLMAppEnvironment(
+     name="my-llm-app",
+     model_path="s3://your-bucket/models/your-model",
+     model_id="your-model-id",
+     resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1"),
+     stream_model=True,  # Stream model directly from blob store to GPU
+     scaling=flyte.app.Scaling(
+         replicas=(0, 1),
+         scaledown_after=300,
+     ),
+ )
+
+ if __name__ == "__main__":
+     flyte.init_from_config()
+     app = flyte.serve(vllm_app)
+     print(f"Deployed vLLM app: {app.url}")
+ ```
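+
+ Once deployed, the app exposes an OpenAI-compatible API, so any OpenAI client can query it. A minimal sketch (the base URL placeholder and API key handling depend on your deployment and whether `requires_auth` is set):
+
+ ```python
+ from openai import OpenAI
+
+ client = OpenAI(base_url="https://<your-app-url>/v1", api_key="EMPTY")
+ response = client.chat.completions.create(
+     model="your-model-id",  # must match the model_id passed to VLLMAppEnvironment
+     messages=[{"role": "user", "content": "Hello!"}],
+ )
+ print(response.choices[0].message.content)
+ ```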
+
+ ## Features
+
+ - **Streaming Model Loading**: Stream model weights directly from object storage to GPU memory, reducing startup time and disk requirements.
+ - **OpenAI-Compatible API**: The deployed app exposes an OpenAI-compatible API for chat completions.
+ - **Auto-scaling**: Configure scaling policies to scale up or down based on traffic.
+ - **Tensor Parallelism**: Support for distributed inference across multiple GPUs, as sketched below.
+
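+ For multi-GPU serving, tensor parallelism can be enabled through `extra_args` (a sketch continuing the example above; `--tensor-parallel-size` is a standard `vllm serve` flag, and the GPU type and count are illustrative):
+
+ ```python
+ vllm_app_tp = VLLMAppEnvironment(
+     name="my-llm-app-tp",
+     model_path="s3://your-bucket/models/your-model",
+     model_id="your-model-id",
+     resources=flyte.Resources(cpu="8", memory="64Gi", gpu="L40s:2"),
+     extra_args="--tensor-parallel-size 2",
+ )
+ ```
+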
flyteplugins_vllm-2.0.0b40.dist-info/RECORD
@@ -0,0 +1,10 @@
+ flyteplugins/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ flyteplugins/vllm/__init__.py,sha256=FhdW2e_f6PsGo4wyV07jradAFbg7WmB0Luz3zHIDd7A,100
+ flyteplugins/vllm/_app_environment.py,sha256=RYB2Oj0aHa_fj9N49_4p89qEYU3GRPuszO5g_U597N4,7664
+ flyteplugins/vllm/_model_loader/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ flyteplugins/vllm/_model_loader/shim.py,sha256=vSb7_r0sGJJVkWVIaqkypry0OTa5PAv-REqdigufYd0,5348
+ flyteplugins_vllm-2.0.0b40.dist-info/METADATA,sha256=c7i4cEbfX-tK4i2rmRhNnIr61bUagNRsHxOZhQHLj4A,1577
+ flyteplugins_vllm-2.0.0b40.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ flyteplugins_vllm-2.0.0b40.dist-info/entry_points.txt,sha256=lC-uwvkaytwtzbkJWdS69np63yLAakaDpI4mV1Yp9l8,74
+ flyteplugins_vllm-2.0.0b40.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
+ flyteplugins_vllm-2.0.0b40.dist-info/RECORD,,
flyteplugins_vllm-2.0.0b40.dist-info/WHEEL
@@ -0,0 +1,5 @@
+ Wheel-Version: 1.0
+ Generator: setuptools (80.9.0)
+ Root-Is-Purelib: true
+ Tag: py3-none-any
+
flyteplugins_vllm-2.0.0b40.dist-info/entry_points.txt
@@ -0,0 +1,2 @@
+ [console_scripts]
+ vllm-fserve = flyteplugins.vllm._model_loader.shim:main
flyteplugins_vllm-2.0.0b40.dist-info/top_level.txt
@@ -0,0 +1 @@
+ flyteplugins