sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +5 -1
- sglang/api.py +8 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +11 -1
- sglang/lang/chat_template.py +9 -2
- sglang/lang/interpreter.py +161 -81
- sglang/lang/ir.py +29 -11
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -2
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +83 -2
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +26 -10
- sglang/srt/managers/router/infer_batch.py +130 -74
- sglang/srt/managers/router/manager.py +7 -9
- sglang/srt/managers/router/model_rpc.py +224 -135
- sglang/srt/managers/router/model_runner.py +94 -107
- sglang/srt/managers/router/radix_cache.py +54 -18
- sglang/srt/managers/router/scheduler.py +23 -34
- sglang/srt/managers/tokenizer_manager.py +183 -88
- sglang/srt/model_config.py +5 -2
- sglang/srt/models/commandr.py +15 -22
- sglang/srt/models/dbrx.py +22 -29
- sglang/srt/models/gemma.py +14 -24
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +24 -23
- sglang/srt/models/llava.py +85 -25
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/mixtral.py +254 -130
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +28 -25
- sglang/srt/models/qwen2.py +17 -22
- sglang/srt/models/stablelm.py +21 -26
- sglang/srt/models/yivl.py +17 -25
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +101 -52
- sglang/srt/server_args.py +59 -11
- sglang/srt/utils.py +242 -75
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +95 -26
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -402
- sglang-0.1.15.dist-info/RECORD +0 -69
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma.py
CHANGED
@@ -1,12 +1,13 @@
 # Adapted from:
-# https://github.com/vllm-project/vllm/blob/
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/gemma.py#L1
 """Inference-only Gemma model compatible with HuggingFace weights."""
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import LoRAConfig
+from vllm.config import LoRAConfig, CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
@@ -14,21 +15,14 @@ from vllm.model_executor.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class GemmaMLP(nn.Module):
@@ -46,7 +40,10 @@ class GemmaMLP(nn.Module):
             quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = GeluAndMul()
 
@@ -267,6 +264,7 @@ class GemmaForCausalLM(nn.Module):
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         del lora_config  # Unused.
         super().__init__()
@@ -288,13 +286,7 @@ class GemmaForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -305,9 +297,7 @@ class GemmaForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue