flyteplugins-sglang 2.0.0b45 (flyteplugins_sglang-2.0.0b45-py3-none-any.whl)
- flyteplugins/__init__.py +0 -0
- flyteplugins/sglang/__init__.py +3 -0
- flyteplugins/sglang/_app_environment.py +212 -0
- flyteplugins/sglang/_constants.py +2 -0
- flyteplugins/sglang/_model_loader/__init__.py +0 -0
- flyteplugins/sglang/_model_loader/shim.py +171 -0
- flyteplugins_sglang-2.0.0b45.dist-info/METADATA +69 -0
- flyteplugins_sglang-2.0.0b45.dist-info/RECORD +11 -0
- flyteplugins_sglang-2.0.0b45.dist-info/WHEEL +5 -0
- flyteplugins_sglang-2.0.0b45.dist-info/entry_points.txt +2 -0
- flyteplugins_sglang-2.0.0b45.dist-info/top_level.txt +1 -0
flyteplugins/__init__.py
ADDED
File without changes
flyteplugins/sglang/_app_environment.py
ADDED
@@ -0,0 +1,212 @@
from __future__ import annotations

import shlex
from dataclasses import dataclass, field, replace
from typing import Any, Literal, Optional, Union

import flyte.app
import rich.repr
from flyte import Environment, Image, Resources, SecretRequest
from flyte.app import Parameter, RunOutput
from flyte.app._types import Port
from flyte.models import SerializationContext

from flyteplugins.sglang._constants import SGLANG_MIN_VERSION_STR

DEFAULT_SGLANG_IMAGE = (
    flyte.Image.from_debian_base(name="sglang-app-image")
    # install system dependencies, including the CUDA toolkit, which sglang needs for compiling the model
    .with_apt_packages("libnuma-dev", "wget")
    .with_commands(
        [
            "wget https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb",
            "dpkg -i cuda-keyring_1.1-1_all.deb",
            "apt-get update",
            "apt-get install -y cuda-toolkit-12-8",
        ]
    )
    # install flashinfer
    .with_pip_packages("flashinfer-python", "flashinfer-cubin")
    .with_pip_packages("flashinfer-jit-cache", index_url="https://flashinfer.ai/whl/cu128")
    .with_pip_packages("flyteplugins-sglang", pre=True)
    .with_pip_packages(f"sglang>={SGLANG_MIN_VERSION_STR}")
    .with_env_vars({"CUDA_HOME": "/usr/local/cuda-12.8"})
)


@rich.repr.auto
@dataclass(kw_only=True, repr=True)
class SGLangAppEnvironment(flyte.app.AppEnvironment):
    """
    App environment backed by SGLang for serving large language models.

    This environment sets up an SGLang server with the specified model and configuration.

    :param name: The name of the application.
    :param container_image: The container image to use for the application.
    :param port: Port the application listens on. Defaults to 8080 for SGLang.
    :param requests: Compute resource requests for the application.
    :param secrets: Secrets that are requested for the application.
    :param limits: Compute resource limits for the application.
    :param env_vars: Environment variables to set for the application.
    :param scaling: Scaling configuration for the app environment.
    :param domain: Domain to use for the app.
    :param cluster_pool: The target cluster_pool where the app should be deployed.
    :param requires_auth: Whether the public URL requires authentication.
    :param type: Type of app.
    :param extra_args: Extra args to pass to `python -m sglang.launch_server`. See
        https://docs.sglang.io/advanced_features/server_arguments.html for details.
    :param model_path: Remote path to the model (e.g., s3://bucket/path/to/model).
    :param model_hf_path: Hugging Face path to the model (e.g., Qwen/Qwen3-0.6B).
    :param model_id: Model id that is exposed by SGLang.
    :param stream_model: Set to True to stream the model from the blob store directly to the GPU.
        If False, the model is downloaded to the local file system first and then loaded
        into the GPU.
    """

    port: int | Port = 8080
    type: str = "SGLang"
    extra_args: str | list[str] = ""
    model_path: str | RunOutput = ""
    model_hf_path: str = ""
    model_id: str = ""
    stream_model: bool = True
    image: str | Image | Literal["auto"] = DEFAULT_SGLANG_IMAGE
    _model_mount_path: str = field(default="/root/flyte", init=False)

    def __post_init__(self):
        if self.env_vars is None:
            self.env_vars = {}

        if self._server is not None:
            raise ValueError("server function cannot be set for SGLangAppEnvironment")

        if self._on_startup is not None:
            raise ValueError("on_startup function cannot be set for SGLangAppEnvironment")

        if self._on_shutdown is not None:
            raise ValueError("on_shutdown function cannot be set for SGLangAppEnvironment")

        if self.model_id == "":
            raise ValueError("model_id must be defined")

        if self.model_path == "" and self.model_hf_path == "":
            raise ValueError("model_path or model_hf_path must be defined")
        if self.model_path != "" and self.model_hf_path != "":
            raise ValueError("model_path and model_hf_path cannot be set at the same time")

        if self.model_hf_path:
            self._model_mount_path = self.model_hf_path

        if self.args:
            raise ValueError("args cannot be set for SGLangAppEnvironment. Use `extra_args` to add extra arguments.")

        if isinstance(self.extra_args, str):
            extra_args = shlex.split(self.extra_args)
        else:
            extra_args = self.extra_args

        self.args = [
            "sglang-fserve",
            "--model-path",
            self._model_mount_path,
            "--served-model-name",
            self.model_id,
            "--port",
            str(self.get_port().port),
            *extra_args,
        ]

        if self.parameters:
            raise ValueError("parameters cannot be set for SGLangAppEnvironment")

        input_kwargs = {}
        if self.stream_model:
            self.env_vars["FLYTE_MODEL_LOADER_STREAM_SAFETENSORS"] = "true"
            input_kwargs["env_var"] = "FLYTE_MODEL_LOADER_REMOTE_MODEL_PATH"
            input_kwargs["download"] = False
        else:
            self.env_vars["FLYTE_MODEL_LOADER_STREAM_SAFETENSORS"] = "false"
            input_kwargs["download"] = True
            input_kwargs["mount"] = self._model_mount_path

        if self.model_path:
            self.parameters = [Parameter(name="model_path", value=self.model_path, **input_kwargs)]

        self.env_vars["FLYTE_MODEL_LOADER_LOCAL_MODEL_PATH"] = self._model_mount_path
        self.links = [flyte.app.Link(path="/docs", title="SGLang OpenAPI Docs", is_relative=True)]

        if self.image is None or self.image == "auto":
            self.image = DEFAULT_SGLANG_IMAGE

        super().__post_init__()

    def container_args(self, serialization_context: SerializationContext) -> list[str]:
        """Return the container arguments for SGLang."""
        if isinstance(self.args, str):
            return shlex.split(self.args)
        return self.args or []

    def clone_with(
        self,
        name: str,
        image: Optional[Union[str, Image, Literal["auto"]]] = None,
        resources: Optional[Resources] = None,
        env_vars: Optional[dict[str, str]] = None,
        secrets: Optional[SecretRequest] = None,
        depends_on: Optional[list[Environment]] = None,
        description: Optional[str] = None,
        interruptible: Optional[bool] = None,
        **kwargs: Any,
    ) -> SGLangAppEnvironment:
        port = kwargs.pop("port", None)
        extra_args = kwargs.pop("extra_args", None)
        if "model_path" in kwargs:
            set_model_path = True
            model_path = kwargs.pop("model_path", "") or ""
        else:
            set_model_path = False
            model_path = self.model_path
        if "model_hf_path" in kwargs:
            set_model_hf_path = True
            model_hf_path = kwargs.pop("model_hf_path", "") or ""
        else:
            set_model_hf_path = False
            model_hf_path = self.model_hf_path
        model_id = kwargs.pop("model_id", None)
        stream_model = kwargs.pop("stream_model", None)

        if kwargs:
            raise TypeError(f"Unexpected keyword arguments: {list(kwargs.keys())}")

        kwargs = self._get_kwargs()
        kwargs["name"] = name
        kwargs["args"] = None
        kwargs["parameters"] = None
        if image is not None:
            kwargs["image"] = image
        if resources is not None:
            kwargs["resources"] = resources
        if env_vars is not None:
            kwargs["env_vars"] = env_vars
        if secrets is not None:
            kwargs["secrets"] = secrets
        if depends_on is not None:
            kwargs["depends_on"] = depends_on
        if description is not None:
            kwargs["description"] = description
        if interruptible is not None:
            kwargs["interruptible"] = interruptible
        if port is not None:
            kwargs["port"] = port
        if extra_args is not None:
            kwargs["extra_args"] = extra_args
        if set_model_path:
            kwargs["model_path"] = model_path
        if set_model_hf_path:
            kwargs["model_hf_path"] = model_hf_path
        if model_id is not None:
            kwargs["model_id"] = model_id
        if stream_model is not None:
            kwargs["stream_model"] = stream_model
        return replace(self, **kwargs)
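The `clone_with` method above resets `args` and `parameters` so that `__post_init__` can rebuild the `sglang-fserve` command line for the new model settings. A minimal usage sketch (all names, model paths, and ids below are illustrative, not part of the package):

```python
from flyteplugins.sglang import SGLangAppEnvironment

base = SGLangAppEnvironment(
    name="qwen-small",
    model_hf_path="Qwen/Qwen3-0.6B",  # mutually exclusive with model_path
    model_id="qwen3-0.6b",
)

# Derive a second environment; clone_with clears args/parameters before
# __post_init__ reassembles them from the overridden model settings.
larger = base.clone_with(
    name="qwen-large",
    model_hf_path="Qwen/Qwen3-8B",
    model_id="qwen3-8b",
)
```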
flyteplugins/sglang/_model_loader/__init__.py
ADDED
File without changes
flyteplugins/sglang/_model_loader/shim.py
ADDED
@@ -0,0 +1,171 @@
import logging
import os
import sys
from typing import Generator

import torch
from flyte.app.extras._model_loader.config import (
    LOCAL_MODEL_PATH,
    REMOTE_MODEL_PATH,
    STREAM_SAFETENSORS,
)
from flyte.app.extras._model_loader.loader import SafeTensorsStreamer, prefetch

from flyteplugins.sglang._constants import SGLANG_MIN_VERSION, SGLANG_MIN_VERSION_STR

try:
    import sglang
except ImportError:
    raise ImportError(
        f"sglang is not installed. Please install 'sglang>={SGLANG_MIN_VERSION_STR}' to use the model loader."
    )

if tuple([int(part) for part in sglang.__version__.split(".") if part.isdigit()]) < SGLANG_MIN_VERSION:
    raise ImportError(
        f"sglang version >={SGLANG_MIN_VERSION_STR} required, but found {sglang.__version__}. Please upgrade sglang."
    )

import sglang.srt.model_loader.loader
from sglang.srt.configs.device_config import DeviceConfig
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.server_args import prepare_server_args
from sglang.srt.utils import kill_process_tree

try:
    from sglang.srt.server import launch_server
except (ModuleNotFoundError, AttributeError, ImportError):
    try:
        from sglang.launch_server import launch_server
    except (ModuleNotFoundError, AttributeError, ImportError):
        from sglang.launch_server import run_server as launch_server


logger = logging.getLogger(__name__)

_OrigDefaultModelLoader = sglang.srt.model_loader.loader.DefaultModelLoader


class FlyteModelLoader(_OrigDefaultModelLoader):
    """Custom model loader for streaming model weights from object storage."""

    def _get_weights_iterator(self, source) -> Generator[tuple[str, torch.Tensor], None, None]:
        # Try to load weights using the Flyte SafeTensorsStreamer. Fall back to the default loader otherwise.
        try:
            streamer = SafeTensorsStreamer(REMOTE_MODEL_PATH, LOCAL_MODEL_PATH)
        except ValueError:
            yield from super()._get_weights_iterator(source)
        else:
            for name, tensor in streamer.get_tensors():
                yield source.prefix + name, tensor

    def download_model(self, model_config: ModelConfig) -> None:
        # This model loader supports streaming only
        pass

    def _load_sharded_model(self, *, model_config: ModelConfig, device_config: DeviceConfig) -> torch.nn.Module:
        # Forked from: https://github.com/sgl-project/sglang/blob/1c4e0d2445311f2e635e9dab5a660d982731ad20/python/sglang/srt/model_loader/loader.py#L564
        from sglang.srt.distributed import (
            get_tensor_model_parallel_rank,
            get_tensor_model_parallel_world_size,
            model_parallel_is_initialized,
        )
        from sglang.srt.model_loader.loader import (
            ShardedStateLoader,
            _initialize_model,
        )
        from sglang.srt.model_loader.utils import set_default_torch_dtype

        # Sanity checks
        if model_parallel_is_initialized():
            tensor_parallel_size = get_tensor_model_parallel_world_size()
            rank = get_tensor_model_parallel_rank()
        else:
            tensor_parallel_size = 1
            rank = 0
        if rank >= tensor_parallel_size:
            raise ValueError(f"Invalid rank {rank} for tensor parallel size {tensor_parallel_size}")
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config)
                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        quant_method.process_weights_after_loading(module)
            state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
            streamer = SafeTensorsStreamer(
                REMOTE_MODEL_PATH,
                LOCAL_MODEL_PATH,
                rank=rank,
                tensor_parallel_size=tensor_parallel_size,
            )
            for name, tensor in streamer.get_tensors():
                # If loading with LoRA enabled, additional padding may
                # be added to certain parameters. We only load into a
                # narrowed view of the parameter data.
                param_data = state_dict[name].data
                param_shape = state_dict[name].shape
                for dim, size in enumerate(tensor.shape):
                    if size < param_shape[dim]:
                        param_data = param_data.narrow(dim, 0, size)
                if tensor.shape != param_shape:
                    logger.warning(
                        "loading tensor of shape %s into parameter '%s' of shape %s",
                        tensor.shape,
                        name,
                        param_shape,
                    )
                param_data.copy_(tensor)
                state_dict.pop(name)
            if state_dict:
                raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
            return model.eval()

    def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig) -> torch.nn.Module:
        from sglang.srt.distributed import get_tensor_model_parallel_world_size

        logger.info("Loading model with FlyteModelLoader")
        if get_tensor_model_parallel_world_size() > 1:
            return self._load_sharded_model(model_config=model_config, device_config=device_config)
        else:
            return super().load_model(model_config=model_config, device_config=device_config)


# Monkeypatch the default model loader when streaming is enabled
if REMOTE_MODEL_PATH and STREAM_SAFETENSORS:
    sglang.srt.model_loader.loader.DefaultModelLoader = FlyteModelLoader


async def _get_model_files():
    import flyte.storage as storage

    if not await storage.exists(REMOTE_MODEL_PATH):
        raise FileNotFoundError(f"Model path not found: {REMOTE_MODEL_PATH}")

    await prefetch(
        REMOTE_MODEL_PATH,
        LOCAL_MODEL_PATH,
        exclude_safetensors=STREAM_SAFETENSORS,
    )


def main():
    import asyncio

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    logger.info(f"REMOTE_MODEL_PATH: {REMOTE_MODEL_PATH}")
    logger.info(f"LOCAL_MODEL_PATH: {LOCAL_MODEL_PATH}")
    logger.info(f"STREAM_SAFETENSORS: {STREAM_SAFETENSORS}")

    # Prefetch the model
    if REMOTE_MODEL_PATH:
        logger.info("Prefetching model files...")
        asyncio.run(_get_model_files())

    server_args = prepare_server_args(sys.argv[1:])
    try:
        launch_server(server_args)
    finally:
        kill_process_tree(os.getpid(), include_parent=False)
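The shim is wired to the app environment through environment variables: `SGLangAppEnvironment.__post_init__` sets the `FLYTE_MODEL_LOADER_*` variables and builds a `sglang-fserve ...` command line, and `main()` above reads those settings (via `flyte.app.extras._model_loader.config`) before launching the server. A rough sketch of that handoff; the mapping of the `sglang-fserve` console script to `shim:main` is assumed from `entry_points.txt`, which is not shown in this diff, and the paths and model id are illustrative:

```python
# Sketch only: reproduce the container invocation that SGLangAppEnvironment builds.
import os
import subprocess

env = dict(os.environ)
# Values set by SGLangAppEnvironment.__post_init__ when stream_model=True:
env["FLYTE_MODEL_LOADER_STREAM_SAFETENSORS"] = "true"
env["FLYTE_MODEL_LOADER_REMOTE_MODEL_PATH"] = "s3://your-bucket/models/your-model"  # illustrative
env["FLYTE_MODEL_LOADER_LOCAL_MODEL_PATH"] = "/root/flyte"

# Mirrors the args list assembled in __post_init__.
subprocess.run(
    [
        "sglang-fserve",
        "--model-path", "/root/flyte",
        "--served-model-name", "your-model-id",
        "--port", "8080",
    ],
    env=env,
    check=True,
)
```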
flyteplugins_sglang-2.0.0b45.dist-info/METADATA
ADDED
@@ -0,0 +1,69 @@
Metadata-Version: 2.4
Name: flyteplugins-sglang
Version: 2.0.0b45
Summary: SGLang plugin for flyte
Author-email: Niels Bantilan <cosmicbboy@users.noreply.github.com>
Requires-Python: >=3.10
Description-Content-Type: text/markdown
Requires-Dist: flyte>=2.0.0b43

# Flyte SGLang Plugin

Serve large language models using SGLang with Flyte Apps.

This plugin provides the `SGLangAppEnvironment` class for deploying and serving LLMs using [SGLang](https://docs.sglang.ai/).

## Installation

```bash
pip install --pre flyteplugins-sglang
```

## Usage

```python
import flyte
import flyte.app
from flyteplugins.sglang import SGLangAppEnvironment

# Define the SGLang app environment
sglang_app = SGLangAppEnvironment(
    name="my-llm-app",
    model_path="s3://your-bucket/models/your-model",
    model_id="your-model-id",
    resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1"),
    stream_model=True,  # Stream model directly from blob store to GPU
    scaling=flyte.app.Scaling(
        replicas=(0, 1),
        scaledown_after=300,
    ),
)

if __name__ == "__main__":
    flyte.init_from_config()
    app = flyte.serve(sglang_app)
    print(f"Deployed SGLang app: {app.url}")
```

## Features

- **Streaming Model Loading**: Stream model weights directly from object storage to GPU memory, reducing startup time and disk requirements.
- **OpenAI-Compatible API**: The deployed app exposes an OpenAI-compatible API for chat completions (see the sketch after this list).
- **Auto-scaling**: Configure scaling policies to scale up or down based on traffic.
- **Tensor Parallelism**: Support for distributed inference across multiple GPUs.
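Because the server exposes an OpenAI-compatible API, a deployed app can be queried with any OpenAI client. A minimal sketch, assuming the standard `/v1` route prefix used by SGLang's OpenAI-compatible server; the URL and model id below are placeholders:

```python
from openai import OpenAI

client = OpenAI(
    base_url="https://<your-app-url>/v1",  # use the app.url printed after flyte.serve()
    api_key="EMPTY",  # or a real token if requires_auth is enabled
)

response = client.chat.completions.create(
    model="your-model-id",  # must match the model_id configured on the environment
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```
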
## Extra Arguments

You can pass additional arguments to the SGLang server using the `extra_args` parameter:

```python
sglang_app = SGLangAppEnvironment(
    name="my-llm-app",
    model_path="s3://your-bucket/models/your-model",
    model_id="your-model-id",
    extra_args="--max-model-len 8192 --enforce-eager",
)
```

See the [SGLang server arguments documentation](https://docs.sglang.ai/backend/server_arguments.html) for available options.
flyteplugins_sglang-2.0.0b45.dist-info/RECORD
ADDED
@@ -0,0 +1,11 @@
flyteplugins/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
flyteplugins/sglang/__init__.py,sha256=xK4xPyWoYdYr8tuIhusQh-nAXfvwW5cOO0pggY0jvZ8,106
flyteplugins/sglang/_app_environment.py,sha256=pc6nuNpXXIBYuEwPVaA_ed_MIeuaFFoeOqglL-_tzfY,8510
flyteplugins/sglang/_constants.py,sha256=gCHEQipHhHY7J8L829C6e3XZ8sxX3ae0F7AJEcVzWBg,95
flyteplugins/sglang/_model_loader/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
flyteplugins/sglang/_model_loader/shim.py,sha256=U6mHvQBQ3bRjWsVZQG9IxmikMeSTWlJCus6hrydw2zo,6775
flyteplugins_sglang-2.0.0b45.dist-info/METADATA,sha256=5H4qZPEEufNajOceJcDIfHTRD9Av_yt77eg4_YlfARM,2054
flyteplugins_sglang-2.0.0b45.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
flyteplugins_sglang-2.0.0b45.dist-info/entry_points.txt,sha256=6eEFwDOLxQIAgFc03TStKyreECu3gycDlBRCTQ0pn0Y,78
flyteplugins_sglang-2.0.0b45.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
flyteplugins_sglang-2.0.0b45.dist-info/RECORD,,
flyteplugins_sglang-2.0.0b45.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
flyteplugins