flyteplugins-sglang 2.0.0b45__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,3 @@
1
"""Flyte plugin for serving large language models with SGLang."""

from flyteplugins.sglang._app_environment import SGLangAppEnvironment

# Public API of the package.
__all__ = ["SGLangAppEnvironment"]
@@ -0,0 +1,212 @@
1
+ from __future__ import annotations
2
+
3
+ import shlex
4
+ from dataclasses import dataclass, field, replace
5
+ from typing import Any, Literal, Optional, Union
6
+
7
+ import flyte.app
8
+ import rich.repr
9
+ from flyte import Environment, Image, Resources, SecretRequest
10
+ from flyte.app import Parameter, RunOutput
11
+ from flyte.app._types import Port
12
+ from flyte.models import SerializationContext
13
+
14
+ from flyteplugins.sglang._constants import SGLANG_MIN_VERSION_STR
15
+
16
# Default container image for SGLangAppEnvironment: a Debian base image with the
# CUDA 12.8 toolkit, FlashInfer, this plugin, and sglang installed. The builder
# calls below are applied in order, so their sequence matters.
DEFAULT_SGLANG_IMAGE = (
    flyte.Image.from_debian_base(name="sglang-app-image")
    # install system dependencies; wget is used below to fetch NVIDIA's CUDA keyring
    .with_apt_packages("libnuma-dev", "wget")
    # install the CUDA 12.8 toolkit from NVIDIA's Debian 12 repository, which is needed
    # by sglang for compiling the model
    .with_commands(
        [
            "wget https://developer.download.nvidia.com/compute/cuda/repos/debian12/x86_64/cuda-keyring_1.1-1_all.deb",
            "dpkg -i cuda-keyring_1.1-1_all.deb",
            "apt-get update",
            "apt-get install -y cuda-toolkit-12-8",
        ]
    )
    # install flash-infer; the JIT cache comes from the cu128 wheel index, presumably
    # chosen to match the CUDA 12.8 toolkit installed above
    .with_pip_packages("flashinfer-python", "flashinfer-cubin")
    .with_pip_packages("flashinfer-jit-cache", index_url="https://flashinfer.ai/whl/cu128")
    # NOTE(review): unpinned, pre-release-enabled install of this plugin; the image may
    # pick up a different plugin version than the one building it — confirm intended.
    .with_pip_packages("flyteplugins-sglang", pre=True)
    .with_pip_packages(f"sglang>={SGLANG_MIN_VERSION_STR}")
    # assumes cuda-toolkit-12-8 installs under /usr/local/cuda-12.8 — TODO confirm
    .with_env_vars({"CUDA_HOME": "/usr/local/cuda-12.8"})
)
35
+
36
+
37
@rich.repr.auto
@dataclass(kw_only=True, repr=True)
class SGLangAppEnvironment(flyte.app.AppEnvironment):
    """
    App environment backed by SGLang for serving large language models.

    This environment sets up an SGLang server with the specified model and configuration.
    The container runs the ``sglang-fserve`` entry point with ``--model-path``,
    ``--served-model-name`` and ``--port`` derived from the fields below, followed by
    ``extra_args``.

    :param name: The name of the application.
    :param image: The container image to use for the application. ``None`` or ``"auto"``
        selects ``DEFAULT_SGLANG_IMAGE``.
    :param port: Port application listens to. Defaults to 8080 for SGLang.
    :param resources: Compute resources for the application.
    :param secrets: Secrets that are requested for application.
    :param env_vars: Environment variables to set for the application. The mapping is
        copied, so the caller's dict is never modified.
    :param scaling: Scaling configuration for the app environment.
    :param domain: Domain to use for the app.
    :param cluster_pool: The target cluster_pool where the app should be deployed.
    :param requires_auth: Whether the public URL requires authentication.
    :param type: Type of app.
    :param extra_args: Extra SGLang server arguments appended to the ``sglang-fserve``
        command line, either as a shell-style string (split with ``shlex``) or as a list.
        See https://docs.sglang.io/advanced_features/server_arguments.html for details.
    :param model_path: Remote path to model (e.g., s3://bucket/path/to/model).
    :param model_hf_path: Hugging Face path to model (e.g., Qwen/Qwen3-0.6B).
    :param model_id: Model id that is exposed by SGLang.
    :param stream_model: Set to True to stream model from blob store to the GPU directly.
        If False, the model will be downloaded to the local file system first and then loaded
        into the GPU.
    :raises ValueError: On construction if ``model_id`` is empty, if neither or both of
        ``model_path`` / ``model_hf_path`` are set, or if ``args``, ``parameters``, or a
        server / on_startup / on_shutdown function is supplied.
    """

    port: int | Port = 8080
    type: str = "SGLang"
    extra_args: str | list[str] = ""
    model_path: str | RunOutput = ""
    model_hf_path: str = ""
    model_id: str = ""
    stream_model: bool = True
    image: str | Image | Literal["auto"] = DEFAULT_SGLANG_IMAGE
    # Path passed to ``--model-path``; replaced by ``model_hf_path`` when that is set.
    _model_mount_path: str = field(default="/root/flyte", init=False)

    def __post_init__(self):
        # Work on a private copy. The keys written below must not leak into a dict owned
        # by the caller, nor into the environment this one was cloned from:
        # ``dataclasses.replace`` (used by ``clone_with``) reuses field values by reference.
        self.env_vars = dict(self.env_vars) if self.env_vars else {}

        if self._server is not None:
            raise ValueError("server function cannot be set for SGLangAppEnvironment")

        if self._on_startup is not None:
            raise ValueError("on_startup function cannot be set for SGLangAppEnvironment")

        if self._on_shutdown is not None:
            raise ValueError("on_shutdown function cannot be set for SGLangAppEnvironment")

        if self.model_id == "":
            raise ValueError("model_id must be defined")

        # Exactly one model source must be configured.
        if self.model_path == "" and self.model_hf_path == "":
            raise ValueError("model_path or model_hf_path must be defined")
        if self.model_path != "" and self.model_hf_path != "":
            raise ValueError("model_path and model_hf_path cannot be set at the same time")

        # A Hugging Face id is handed straight to ``--model-path``.
        if self.model_hf_path:
            self._model_mount_path = self.model_hf_path

        if self.args:
            raise ValueError("args cannot be set for SGLangAppEnvironment. Use `extra_args` to add extra arguments.")

        if isinstance(self.extra_args, str):
            extra_args = shlex.split(self.extra_args)
        else:
            extra_args = self.extra_args

        self.args = [
            "sglang-fserve",
            "--model-path",
            self._model_mount_path,
            "--served-model-name",
            self.model_id,
            "--port",
            str(self.get_port().port),
            *extra_args,
        ]

        if self.parameters:
            raise ValueError("parameters cannot be set for SGLangAppEnvironment")

        # Streaming exposes the remote path through an env var for the model loader;
        # otherwise the model is downloaded and mounted at the local mount path.
        input_kwargs = {}
        if self.stream_model:
            self.env_vars["FLYTE_MODEL_LOADER_STREAM_SAFETENSORS"] = "true"
            input_kwargs["env_var"] = "FLYTE_MODEL_LOADER_REMOTE_MODEL_PATH"
            input_kwargs["download"] = False
        else:
            self.env_vars["FLYTE_MODEL_LOADER_STREAM_SAFETENSORS"] = "false"
            input_kwargs["download"] = True
            input_kwargs["mount"] = self._model_mount_path

        if self.model_path:
            self.parameters = [Parameter(name="model_path", value=self.model_path, **input_kwargs)]

        self.env_vars["FLYTE_MODEL_LOADER_LOCAL_MODEL_PATH"] = self._model_mount_path
        self.links = [flyte.app.Link(path="/docs", title="SGLang OpenAPI Docs", is_relative=True)]

        if self.image is None or self.image == "auto":
            self.image = DEFAULT_SGLANG_IMAGE

        super().__post_init__()

    def container_args(self, serialization_context: SerializationContext) -> list[str]:
        """Return the container arguments for SGLang."""
        if isinstance(self.args, str):
            return shlex.split(self.args)
        return self.args or []

    def clone_with(
        self,
        name: str,
        image: Optional[Union[str, Image, Literal["auto"]]] = None,
        resources: Optional[Resources] = None,
        env_vars: Optional[dict[str, str]] = None,
        secrets: Optional[SecretRequest] = None,
        depends_on: Optional[list[Environment]] = None,
        description: Optional[str] = None,
        interruptible: Optional[bool] = None,
        **kwargs: Any,
    ) -> SGLangAppEnvironment:
        """Return a copy of this environment with the given fields overridden.

        Besides the named parameters, the SGLang-specific fields ``port``, ``extra_args``,
        ``model_path``, ``model_hf_path``, ``model_id`` and ``stream_model`` may be passed
        as keyword arguments. ``model_path`` and ``model_hf_path`` are applied whenever
        present (a falsy value clears them); every other override is ignored when ``None``.
        ``args`` and ``parameters`` are reset so ``__post_init__`` regenerates them.

        :raises TypeError: If any other keyword argument is passed.
        """
        port = kwargs.pop("port", None)
        extra_args = kwargs.pop("extra_args", None)
        # Presence (not value) decides whether the model source is overridden, so that
        # callers can clear one source while setting the other.
        if "model_path" in kwargs:
            set_model_path = True
            model_path = kwargs.pop("model_path", "") or ""
        else:
            set_model_path = False
            model_path = self.model_path
        if "model_hf_path" in kwargs:
            set_model_hf_path = True
            model_hf_path = kwargs.pop("model_hf_path", "") or ""
        else:
            set_model_hf_path = False
            model_hf_path = self.model_hf_path
        model_id = kwargs.pop("model_id", None)
        stream_model = kwargs.pop("stream_model", None)

        if kwargs:
            raise TypeError(f"Unexpected keyword arguments: {list(kwargs.keys())}")

        kwargs = self._get_kwargs()
        kwargs["name"] = name
        # Derived fields: __post_init__ rejects user-supplied values and rebuilds them.
        kwargs["args"] = None
        kwargs["parameters"] = None
        if image is not None:
            kwargs["image"] = image
        if resources is not None:
            kwargs["resources"] = resources
        if env_vars is not None:
            kwargs["env_vars"] = env_vars
        if secrets is not None:
            kwargs["secrets"] = secrets
        if depends_on is not None:
            kwargs["depends_on"] = depends_on
        if description is not None:
            kwargs["description"] = description
        if interruptible is not None:
            kwargs["interruptible"] = interruptible
        if port is not None:
            kwargs["port"] = port
        if extra_args is not None:
            kwargs["extra_args"] = extra_args
        if set_model_path:
            kwargs["model_path"] = model_path
        if set_model_hf_path:
            kwargs["model_hf_path"] = model_hf_path
        if model_id is not None:
            kwargs["model_id"] = model_id
        if stream_model is not None:
            kwargs["stream_model"] = stream_model
        return replace(self, **kwargs)
@@ -0,0 +1,2 @@
1
# Minimum supported sglang release as a comparable (major, minor, patch) tuple.
SGLANG_MIN_VERSION = (0, 5, 2)
# The same version rendered as a dotted string, e.g. for ``sglang>=...`` pip specifiers.
SGLANG_MIN_VERSION_STR = ".".join(str(component) for component in SGLANG_MIN_VERSION)
File without changes
@@ -0,0 +1,171 @@
1
+ import logging
2
+ import os
3
+ import sys
4
+ from typing import Generator
5
+
6
+ import torch
7
+ from flyte.app.extras._model_loader.config import (
8
+ LOCAL_MODEL_PATH,
9
+ REMOTE_MODEL_PATH,
10
+ STREAM_SAFETENSORS,
11
+ )
12
+ from flyte.app.extras._model_loader.loader import SafeTensorsStreamer, prefetch
13
+
14
+ from flyteplugins.sglang._constants import SGLANG_MIN_VERSION, SGLANG_MIN_VERSION_STR
15
+
16
def _parse_version(version: str) -> tuple[int, ...]:
    """Return the numeric release components at the start of ``version``.

    Each dot-separated component contributes its leading digits, and parsing stops
    at the first component that does not start with a digit. Pre-/post-release
    suffixes are therefore ignored: ``"0.5.3rc0"`` -> ``(0, 5, 3)`` and
    ``"0.5.2.post1"`` -> ``(0, 5, 2)``.
    """
    import re

    release: list[int] = []
    for component in version.split("."):
        digits = re.match(r"\d+", component)
        if digits is None:
            break
        release.append(int(digits.group()))
    return tuple(release)


try:
    import sglang
except ImportError as err:
    raise ImportError(
        f"sglang is not installed. Please install 'sglang>={SGLANG_MIN_VERSION_STR}', to use the model loader."
    ) from err

# Compare only the numeric release part. Filtering out non-numeric components (as a
# plain ``isdigit`` check would) drops e.g. "3rc0" entirely and makes 0.5.3rc0 look
# older than 0.5.2.
if _parse_version(sglang.__version__) < SGLANG_MIN_VERSION:
    raise ImportError(
        f"sglang version >={SGLANG_MIN_VERSION_STR} required, but found {sglang.__version__}. Please upgrade sglang."
    )
27
+
28
+ import sglang.srt.model_loader.loader
29
+ from sglang.srt.configs.device_config import DeviceConfig
30
+ from sglang.srt.configs.model_config import ModelConfig
31
+ from sglang.srt.server_args import prepare_server_args
32
+ from sglang.srt.utils import kill_process_tree
33
+
34
+ try:
35
+ from sglang.srt.server import launch_server
36
+ except (ModuleNotFoundError, AttributeError, ImportError):
37
+ try:
38
+ from sglang.launch_server import launch_server
39
+ except (ModuleNotFoundError, AttributeError, ImportError):
40
+ from sglang.launch_server import run_server as launch_server
41
+
42
+
43
# Bind sglang's original DefaultModelLoader now, before this module rebinds that
# attribute on sglang's loader module, so FlyteModelLoader subclasses (and can fall
# back to) the real implementation rather than itself.
from sglang.srt.model_loader.loader import DefaultModelLoader as _OrigDefaultModelLoader  # noqa: E402

logger = logging.getLogger(__name__)
46
+
47
+
48
class FlyteModelLoader(_OrigDefaultModelLoader):
    """Custom model loader for streaming model weights from object storage.

    Extends sglang's original ``DefaultModelLoader`` so that weights are read through
    Flyte's ``SafeTensorsStreamer`` (``REMOTE_MODEL_PATH`` / ``LOCAL_MODEL_PATH``)
    instead of from files downloaded to local disk.
    """

    def _get_weights_iterator(self, source) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Yield ``(name, tensor)`` pairs for the weights described by ``source``.

        Names are prefixed with ``source.prefix``. If ``SafeTensorsStreamer`` cannot be
        created (it raises ``ValueError``), the parent loader's iterator is used instead.
        """
        # Try to load weights using the Flyte SafeTensorsStreamer. Fallback to the default loader otherwise.
        try:
            streamer = SafeTensorsStreamer(REMOTE_MODEL_PATH, LOCAL_MODEL_PATH)
        except ValueError:
            streamer = None

        if streamer is None:
            # This method is a generator (it contains ``yield``), so ``return <iterator>``
            # would silently end iteration with no weights. Delegate with ``yield from``.
            yield from super()._get_weights_iterator(source)
            return

        for name, tensor in streamer.get_tensors():
            yield source.prefix + name, tensor

    def download_model(self, model_config: ModelConfig) -> None:
        """No-op: this loader streams weights instead of downloading them."""
        # This model loader supports streaming only
        pass

    def _load_sharded_model(self, *, model_config: ModelConfig, device_config: DeviceConfig) -> torch.nn.Module:
        """Build the model and fill its state dict from this rank's streamed weights.

        :raises ValueError: If the rank is out of range for the tensor-parallel size, or if
            any state-dict entry was not provided by the streamer.
        """
        # Forked from: https://github.com/sgl-project/sglang/blob/1c4e0d2445311f2e635e9dab5a660d982731ad20/python/sglang/srt/model_loader/loader.py#L564
        from sglang.srt.distributed import (
            get_tensor_model_parallel_rank,
            get_tensor_model_parallel_world_size,
            model_parallel_is_initialized,
        )
        from sglang.srt.model_loader.loader import (
            ShardedStateLoader,
            _initialize_model,
        )
        from sglang.srt.model_loader.utils import set_default_torch_dtype

        # Sanity checks
        if model_parallel_is_initialized():
            tensor_parallel_size = get_tensor_model_parallel_world_size()
            rank = get_tensor_model_parallel_rank()
        else:
            tensor_parallel_size = 1
            rank = 0
        if rank >= tensor_parallel_size:
            raise ValueError(f"Invalid rank {rank} for tensor parallel size {tensor_parallel_size}")
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config)
                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is not None:
                        quant_method.process_weights_after_loading(module)
            state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
            streamer = SafeTensorsStreamer(
                REMOTE_MODEL_PATH,
                LOCAL_MODEL_PATH,
                rank=rank,
                tensor_parallel_size=tensor_parallel_size,
            )
            for name, tensor in streamer.get_tensors():
                # If loading with LoRA enabled, additional padding may
                # be added to certain parameters. We only load into a
                # narrowed view of the parameter data.
                param_data = state_dict[name].data
                param_shape = state_dict[name].shape
                for dim, size in enumerate(tensor.shape):
                    if size < param_shape[dim]:
                        param_data = param_data.narrow(dim, 0, size)
                if tensor.shape != param_shape:
                    logger.warning(
                        "loading tensor of shape %s into parameter '%s' of shape %s",
                        tensor.shape,
                        name,
                        param_shape,
                    )
                param_data.copy_(tensor)
                # Whatever remains afterwards was never provided by the streamer.
                state_dict.pop(name)
            if state_dict:
                raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
        return model.eval()

    def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig) -> torch.nn.Module:
        """Load the model, using the per-rank sharded path when tensor parallelism is > 1."""
        from sglang.srt.distributed import (
            get_tensor_model_parallel_world_size,
            model_parallel_is_initialized,
        )

        logger.info("Loading model with FlyteModelLoader")
        # Same guard as _load_sharded_model: only query the tensor-parallel world size once
        # model parallelism is initialized; otherwise take the single-process path.
        if model_parallel_is_initialized() and get_tensor_model_parallel_world_size() > 1:
            return self._load_sharded_model(model_config=model_config, device_config=device_config)
        return super().load_model(model_config=model_config, device_config=device_config)
131
+
132
+
133
# Monkeypatch the default model loader when streaming is enabled. Rebinding the
# module attribute means any later lookup of ``DefaultModelLoader`` on sglang's
# loader module resolves to FlyteModelLoader.
if STREAM_SAFETENSORS and REMOTE_MODEL_PATH:
    sglang.srt.model_loader.loader.DefaultModelLoader = FlyteModelLoader
136
+
137
+
138
async def _get_model_files():
    """Check that the remote model exists, then prefetch it into ``LOCAL_MODEL_PATH``.

    Safetensors files are excluded from the prefetch when ``STREAM_SAFETENSORS`` is
    enabled, since in that mode the weights are read through the streaming loader.

    :raises FileNotFoundError: If ``REMOTE_MODEL_PATH`` does not exist.
    """
    import flyte.storage as storage

    found = await storage.exists(REMOTE_MODEL_PATH)
    if not found:
        raise FileNotFoundError(f"Model path not found: {REMOTE_MODEL_PATH}")

    await prefetch(REMOTE_MODEL_PATH, LOCAL_MODEL_PATH, exclude_safetensors=STREAM_SAFETENSORS)
149
+
150
+
151
def main():
    """Console entry point (``sglang-fserve``): prefetch the model, then run the SGLang server.

    Command-line arguments after the program name are parsed by sglang's
    ``prepare_server_args``. When ``REMOTE_MODEL_PATH`` is set, model files are
    prefetched to ``LOCAL_MODEL_PATH`` before the server starts.
    """
    import asyncio

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    # Lazy %-style arguments: the message is only formatted if the record is emitted.
    logger.info("REMOTE_MODEL_PATH: %s", REMOTE_MODEL_PATH)
    logger.info("LOCAL_MODEL_PATH: %s", LOCAL_MODEL_PATH)
    logger.info("STREAM_SAFETENSORS: %s", STREAM_SAFETENSORS)

    # Prefetch the model
    if REMOTE_MODEL_PATH:
        logger.info("Prefetching model files...")
        asyncio.run(_get_model_files())

    server_args = prepare_server_args(sys.argv[1:])
    try:
        launch_server(server_args)
    finally:
        # Presumably terminates any processes the server spawned; include_parent=False
        # leaves this process itself alive.
        kill_process_tree(os.getpid(), include_parent=False)
@@ -0,0 +1,69 @@
1
+ Metadata-Version: 2.4
2
+ Name: flyteplugins-sglang
3
+ Version: 2.0.0b45
4
+ Summary: SGLang plugin for flyte
5
+ Author-email: Niels Bantilan <cosmicbboy@users.noreply.github.com>
6
+ Requires-Python: >=3.10
7
+ Description-Content-Type: text/markdown
8
+ Requires-Dist: flyte>=2.0.0b43
9
+
10
+ # Flyte SGLang Plugin
11
+
12
+ Serve large language models using SGLang with Flyte Apps.
13
+
14
+ This plugin provides the `SGLangAppEnvironment` class for deploying and serving LLMs using [SGLang](https://docs.sglang.ai/).
15
+
16
+ ## Installation
17
+
18
+ ```bash
19
+ pip install --pre flyteplugins-sglang
20
+ ```
21
+
22
+ ## Usage
23
+
24
+ ```python
25
+ import flyte
26
+ import flyte.app
27
+ from flyteplugins.sglang import SGLangAppEnvironment
28
+
29
+ # Define the SGLang app environment
30
+ sglang_app = SGLangAppEnvironment(
31
+ name="my-llm-app",
32
+ model_path="s3://your-bucket/models/your-model",
33
+ model_id="your-model-id",
34
+ resources=flyte.Resources(cpu="4", memory="16Gi", gpu="L40s:1"),
35
+ stream_model=True, # Stream model directly from blob store to GPU
36
+ scaling=flyte.app.Scaling(
37
+ replicas=(0, 1),
38
+ scaledown_after=300,
39
+ ),
40
+ )
41
+
42
+ if __name__ == "__main__":
43
+ flyte.init_from_config()
44
+ app = flyte.serve(sglang_app)
45
+ print(f"Deployed SGLang app: {app.url}")
46
+ ```
47
+
48
+ ## Features
49
+
50
+ - **Streaming Model Loading**: Stream model weights directly from object storage to GPU memory, reducing startup time and disk requirements.
51
+ - **OpenAI-Compatible API**: The deployed app exposes an OpenAI-compatible API for chat completions.
52
+ - **Auto-scaling**: Configure scaling policies to scale up/down based on traffic.
53
+ - **Tensor Parallelism**: Support for distributed inference across multiple GPUs.
54
+
55
+ ## Extra Arguments
56
+
57
+ You can pass additional arguments to the SGLang server using the `extra_args` parameter:
58
+
59
+ ```python
60
+ sglang_app = SGLangAppEnvironment(
61
+ name="my-llm-app",
62
+ model_path="s3://your-bucket/models/your-model",
63
+ model_id="your-model-id",
64
+ extra_args="--context-length 8192 --disable-cuda-graph",
65
+ )
66
+ ```
67
+
68
+ See the [SGLang server arguments documentation](https://docs.sglang.ai/backend/server_arguments.html) for available options.
69
+
@@ -0,0 +1,11 @@
1
+ flyteplugins/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
+ flyteplugins/sglang/__init__.py,sha256=xK4xPyWoYdYr8tuIhusQh-nAXfvwW5cOO0pggY0jvZ8,106
3
+ flyteplugins/sglang/_app_environment.py,sha256=pc6nuNpXXIBYuEwPVaA_ed_MIeuaFFoeOqglL-_tzfY,8510
4
+ flyteplugins/sglang/_constants.py,sha256=gCHEQipHhHY7J8L829C6e3XZ8sxX3ae0F7AJEcVzWBg,95
5
+ flyteplugins/sglang/_model_loader/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ flyteplugins/sglang/_model_loader/shim.py,sha256=U6mHvQBQ3bRjWsVZQG9IxmikMeSTWlJCus6hrydw2zo,6775
7
+ flyteplugins_sglang-2.0.0b45.dist-info/METADATA,sha256=5H4qZPEEufNajOceJcDIfHTRD9Av_yt77eg4_YlfARM,2054
8
+ flyteplugins_sglang-2.0.0b45.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
9
+ flyteplugins_sglang-2.0.0b45.dist-info/entry_points.txt,sha256=6eEFwDOLxQIAgFc03TStKyreECu3gycDlBRCTQ0pn0Y,78
10
+ flyteplugins_sglang-2.0.0b45.dist-info/top_level.txt,sha256=cgd779rPu9EsvdtuYgUxNHHgElaQvPn74KhB5XSeMBE,13
11
+ flyteplugins_sglang-2.0.0b45.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ sglang-fserve = flyteplugins.sglang._model_loader.shim:main
@@ -0,0 +1 @@
1
+ flyteplugins