sglang 0.3.6__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_one_batch.py +2 -4
- sglang/bench_serving.py +75 -26
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +13 -15
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +38 -57
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +13 -13
- sglang/srt/layers/attention/flashinfer_backend.py +13 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +13 -14
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +13 -15
- sglang/srt/layers/logits_processor.py +13 -15
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +25 -19
- sglang/srt/managers/detokenizer_manager.py +13 -16
- sglang/srt/managers/io_struct.py +43 -28
- sglang/srt/managers/schedule_batch.py +55 -26
- sglang/srt/managers/schedule_policy.py +13 -15
- sglang/srt/managers/scheduler.py +89 -70
- sglang/srt/managers/session_controller.py +14 -15
- sglang/srt/managers/tokenizer_manager.py +29 -22
- sglang/srt/managers/tp_worker.py +13 -15
- sglang/srt/managers/tp_worker_overlap_thread.py +13 -15
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +20 -19
- sglang/srt/model_executor/forward_batch_info.py +19 -17
- sglang/srt/model_executor/model_runner.py +42 -30
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +15 -15
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +24 -19
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +13 -15
- sglang/srt/models/llavavid.py +13 -15
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +21 -19
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +15 -17
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +13 -15
- sglang/srt/openai_api/protocol.py +13 -15
- sglang/srt/sampling/sampling_batch_info.py +4 -1
- sglang/srt/sampling/sampling_params.py +13 -15
- sglang/srt/server.py +59 -34
- sglang/srt/server_args.py +22 -22
- sglang/srt/utils.py +196 -17
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +13 -14
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +24 -15
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.6.dist-info/RECORD +0 -161
- /sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +0 -0
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,19 @@
|
|
1
1
|
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
|
2
2
|
|
3
|
-
from typing import Dict, Type
|
3
|
+
from typing import Callable, Dict, Optional, Type
|
4
4
|
|
5
|
+
import torch
|
5
6
|
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
6
7
|
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
7
8
|
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
|
8
9
|
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
9
|
-
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
|
10
|
+
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
|
10
11
|
CompressedTensorsConfig,
|
11
12
|
)
|
12
13
|
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
13
14
|
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
14
15
|
from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
|
15
|
-
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
16
|
+
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
16
17
|
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
17
18
|
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
|
18
19
|
from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
|
@@ -30,8 +31,6 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|
30
31
|
"tpu_int8": Int8TpuConfig,
|
31
32
|
"fp8": Fp8Config,
|
32
33
|
"fbgemm_fp8": FBGEMMFp8Config,
|
33
|
-
# The order of gptq methods is important for config.py iteration over
|
34
|
-
# override_quantization_method(..)
|
35
34
|
"marlin": MarlinConfig,
|
36
35
|
"gguf": GGUFConfig,
|
37
36
|
"gptq_marlin_24": GPTQMarlin24Config,
|
@@ -47,20 +46,68 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
|
47
46
|
|
48
47
|
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
49
48
|
if quantization not in QUANTIZATION_METHODS:
|
50
|
-
raise ValueError(
|
49
|
+
raise ValueError(
|
50
|
+
f"Invalid quantization method: {quantization}. "
|
51
|
+
f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
|
52
|
+
)
|
51
53
|
return QUANTIZATION_METHODS[quantization]
|
52
54
|
|
53
55
|
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
56
|
+
def fp8_moe_apply(
|
57
|
+
self,
|
58
|
+
layer: torch.nn.Module,
|
59
|
+
x: torch.Tensor,
|
60
|
+
router_logits: torch.Tensor,
|
61
|
+
top_k: int,
|
62
|
+
renormalize: bool,
|
63
|
+
use_grouped_topk: bool,
|
64
|
+
topk_group: Optional[int] = None,
|
65
|
+
num_expert_group: Optional[int] = None,
|
66
|
+
custom_routing_function: Optional[Callable] = None,
|
67
|
+
) -> torch.Tensor:
|
68
|
+
"""Enhanced apply method for FP8 MoE."""
|
69
|
+
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
70
|
+
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
|
71
|
+
|
72
|
+
# Expert selection
|
73
|
+
topk_weights, topk_ids = FusedMoE.select_experts(
|
74
|
+
hidden_states=x,
|
75
|
+
router_logits=router_logits,
|
76
|
+
use_grouped_topk=use_grouped_topk,
|
77
|
+
top_k=top_k,
|
78
|
+
renormalize=renormalize,
|
79
|
+
topk_group=topk_group,
|
80
|
+
num_expert_group=num_expert_group,
|
81
|
+
custom_routing_function=custom_routing_function,
|
82
|
+
)
|
83
|
+
|
84
|
+
# Expert fusion with FP8 quantization
|
85
|
+
return fused_experts(
|
86
|
+
x,
|
87
|
+
layer.w13_weight,
|
88
|
+
layer.w2_weight,
|
89
|
+
topk_weights=topk_weights,
|
90
|
+
topk_ids=topk_ids,
|
91
|
+
inplace=True,
|
92
|
+
use_fp8_w8a8=True,
|
93
|
+
w1_scale=layer.w13_weight_scale,
|
94
|
+
w2_scale=layer.w2_weight_scale,
|
95
|
+
a1_scale=layer.w13_input_scale,
|
96
|
+
a2_scale=layer.w2_input_scale,
|
97
|
+
)
|
98
|
+
|
99
|
+
|
100
|
+
def fp8_get_quant_method(self, layer, prefix):
|
101
|
+
"""Enhanced get_quant_method for FP8 config."""
|
102
|
+
from vllm.model_executor.layers.linear import LinearBase
|
103
|
+
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
|
104
|
+
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
105
|
+
is_layer_skipped,
|
106
|
+
)
|
107
|
+
|
108
|
+
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
|
109
|
+
from sglang.srt.layers.linear import UnquantizedLinearMethod
|
59
110
|
|
60
|
-
"""
|
61
|
-
def fp8_get_quant_method(
|
62
|
-
self, layer: torch.nn.Module, prefix: str
|
63
|
-
) -> Optional["QuantizeMethodBase"]:
|
64
111
|
if isinstance(layer, LinearBase):
|
65
112
|
if is_layer_skipped(prefix, self.ignored_layers):
|
66
113
|
return UnquantizedLinearMethod()
|
@@ -70,5 +117,18 @@ def fp8_get_quant_method(
|
|
70
117
|
return None
|
71
118
|
|
72
119
|
|
73
|
-
|
74
|
-
"""
|
120
|
+
def apply_monkey_patches():
|
121
|
+
"""Apply all monkey patches in one place."""
|
122
|
+
setattr(Fp8MoEMethod, "apply", fp8_moe_apply)
|
123
|
+
setattr(Fp8Config, "get_quant_method", fp8_get_quant_method)
|
124
|
+
|
125
|
+
|
126
|
+
# Apply patches when module is imported
|
127
|
+
apply_monkey_patches()
|
128
|
+
|
129
|
+
|
130
|
+
__all__ = [
|
131
|
+
"QuantizationConfig",
|
132
|
+
"get_quantization_config",
|
133
|
+
"QUANTIZATION_METHODS",
|
134
|
+
]
|
@@ -1,18 +1,16 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
15
|
-
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
16
14
|
"""Radix attention."""
|
17
15
|
|
18
16
|
from torch import nn
|
@@ -1,16 +1,16 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
14
|
"""MRotaryEmbedding"""
|
15
15
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
16
16
|
|
sglang/srt/lora/lora.py
CHANGED
@@ -1,17 +1,16 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
15
14
|
|
16
15
|
# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
|
17
16
|
# and "Punica: Multi-Tenant LoRA Serving"
|
sglang/srt/lora/lora_config.py
CHANGED
@@ -1,17 +1,16 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
15
14
|
|
16
15
|
import json
|
17
16
|
import os
|
sglang/srt/lora/lora_manager.py
CHANGED
@@ -1,22 +1,20 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
15
14
|
|
16
15
|
# Integrates "S-LoRA: Serving Thousands of Concurrent LoRA Adapters"
|
17
16
|
# and "Punica: Multi-Tenant LoRA Serving"
|
18
17
|
|
19
|
-
|
20
18
|
import logging
|
21
19
|
import re
|
22
20
|
|
@@ -146,9 +144,9 @@ class LoRAManager:
|
|
146
144
|
}
|
147
145
|
else:
|
148
146
|
logger.warning(
|
149
|
-
|
150
|
-
|
151
|
-
|
147
|
+
"WARNING: get_module_name() is not defined, "
|
148
|
+
"which is used to map config module name to model implementation module name."
|
149
|
+
"Use the default one, but please check if it is correct for your model."
|
152
150
|
)
|
153
151
|
self.target_modules = {
|
154
152
|
get_module_name(module) for module in self.origin_target_modules
|
@@ -194,9 +192,9 @@ class LoRAManager:
|
|
194
192
|
hidden_dim_A, _ = self.base_model.get_hidden_dim(module_A)
|
195
193
|
else:
|
196
194
|
logger.warning(
|
197
|
-
|
198
|
-
|
199
|
-
|
195
|
+
"WARNING: get_hidden_dim() is not defined, "
|
196
|
+
"which is used to get the hidden dim for different lora modules"
|
197
|
+
"Use the default one, but please check if it is correct for your model."
|
200
198
|
)
|
201
199
|
hidden_dim_A, _ = get_hidden_dim(module_A, self.base_hf_config)
|
202
200
|
c = self.loras[-1].get_stacked_multiply(module_A)
|
@@ -218,9 +216,9 @@ class LoRAManager:
|
|
218
216
|
_, hidden_dim_B = self.base_model.get_hidden_dim(module_B)
|
219
217
|
else:
|
220
218
|
logger.warning(
|
221
|
-
|
222
|
-
|
223
|
-
|
219
|
+
"WARNING: get_hidden_dim() is not defined, "
|
220
|
+
"which is used to get the hidden dim for different lora modules"
|
221
|
+
"Use the default one, but please check if it is correct for your model."
|
224
222
|
)
|
225
223
|
_, hidden_dim_B = get_hidden_dim(module_B, self.base_hf_config)
|
226
224
|
c = self.loras[-1].get_stacked_multiply(module_B)
|
@@ -1,18 +1,16 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
15
|
-
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
16
14
|
"""A controller that dispatches requests to multiple data parallel workers."""
|
17
15
|
|
18
16
|
import logging
|
@@ -169,9 +167,12 @@ class DataParallelController:
|
|
169
167
|
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
170
168
|
)
|
171
169
|
|
172
|
-
# Wait for model to finish loading
|
170
|
+
# Wait for model to finish loading and get max token nums
|
171
|
+
scheduler_info = []
|
173
172
|
for i in range(len(scheduler_pipe_readers)):
|
174
|
-
scheduler_pipe_readers[i].recv()
|
173
|
+
scheduler_info.append(scheduler_pipe_readers[i].recv())
|
174
|
+
|
175
|
+
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
|
175
176
|
|
176
177
|
return send_to
|
177
178
|
|
@@ -193,7 +194,10 @@ class DataParallelController:
|
|
193
194
|
send_to = get_zmq_socket(
|
194
195
|
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
195
196
|
)
|
196
|
-
|
197
|
+
|
198
|
+
scheduler_info = reader.recv()
|
199
|
+
self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
|
200
|
+
|
197
201
|
return send_to
|
198
202
|
|
199
203
|
def round_robin_scheduler(self, req):
|
@@ -235,7 +239,9 @@ def run_data_parallel_controller_process(
|
|
235
239
|
|
236
240
|
try:
|
237
241
|
controller = DataParallelController(server_args, port_args)
|
238
|
-
pipe_writer.send(
|
242
|
+
pipe_writer.send(
|
243
|
+
{"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
|
244
|
+
)
|
239
245
|
controller.event_loop()
|
240
246
|
except Exception:
|
241
247
|
msg = get_exception_traceback()
|
@@ -1,18 +1,16 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
15
|
-
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
16
14
|
"""DetokenizerManager is a process that detokenizes the token ids."""
|
17
15
|
|
18
16
|
import dataclasses
|
@@ -175,7 +173,6 @@ class DetokenizerManager:
|
|
175
173
|
output_strs=output_strs,
|
176
174
|
meta_info=recv_obj.meta_info,
|
177
175
|
finished_reason=recv_obj.finished_reason,
|
178
|
-
session_ids=recv_obj.session_ids,
|
179
176
|
)
|
180
177
|
)
|
181
178
|
|
sglang/srt/managers/io_struct.py
CHANGED
@@ -1,18 +1,16 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
15
|
-
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
16
14
|
"""
|
17
15
|
The definition of objects transfered between different
|
18
16
|
processes (TokenizerManager, DetokenizerManager, Controller).
|
@@ -21,7 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
|
|
21
19
|
import uuid
|
22
20
|
from dataclasses import dataclass
|
23
21
|
from enum import Enum
|
24
|
-
from typing import Dict, List, Optional, Union
|
22
|
+
from typing import Dict, List, Optional, Tuple, Union
|
25
23
|
|
26
24
|
from sglang.srt.managers.schedule_batch import BaseFinishReason
|
27
25
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
@@ -31,8 +29,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
|
|
31
29
|
class GenerateReqInput:
|
32
30
|
# The input prompt. It can be a single prompt or a batch of prompts.
|
33
31
|
text: Optional[Union[List[str], str]] = None
|
34
|
-
# The token ids for text; one can either
|
32
|
+
# The token ids for text; one can specify either text or input_ids
|
35
33
|
input_ids: Optional[Union[List[List[int]], List[int]]] = None
|
34
|
+
# The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
|
35
|
+
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
36
36
|
# The image input. It can be a file name, a url, or base64 encoded string.
|
37
37
|
# See also python/sglang/srt/utils.py:load_image.
|
38
38
|
image_data: Optional[Union[List[str], str]] = None
|
@@ -57,14 +57,21 @@ class GenerateReqInput:
|
|
57
57
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
58
58
|
|
59
59
|
# Session id info for continual prompting
|
60
|
-
|
61
|
-
|
60
|
+
session: Optional[
|
61
|
+
Union[List[Tuple[str, Optional[str]]], Tuple[str, Optional[str]]]
|
62
|
+
] = None
|
62
63
|
|
63
64
|
def normalize_batch_and_arguments(self):
|
64
|
-
if (
|
65
|
-
self.text is
|
65
|
+
if (
|
66
|
+
self.text is None and self.input_ids is None and self.input_embeds is None
|
67
|
+
) or (
|
68
|
+
self.text is not None
|
69
|
+
and self.input_ids is not None
|
70
|
+
and self.input_embeds is not None
|
66
71
|
):
|
67
|
-
raise ValueError(
|
72
|
+
raise ValueError(
|
73
|
+
"Either text, input_ids or input_embeds should be provided."
|
74
|
+
)
|
68
75
|
|
69
76
|
# Derive the batch size
|
70
77
|
if self.text is not None:
|
@@ -74,13 +81,21 @@ class GenerateReqInput:
|
|
74
81
|
else:
|
75
82
|
self.is_single = False
|
76
83
|
self.batch_size = len(self.text)
|
77
|
-
|
84
|
+
self.input_embeds = None
|
85
|
+
elif self.input_ids is not None:
|
78
86
|
if isinstance(self.input_ids[0], int):
|
79
87
|
self.is_single = True
|
80
88
|
self.batch_size = 1
|
81
89
|
else:
|
82
90
|
self.is_single = False
|
83
91
|
self.batch_size = len(self.input_ids)
|
92
|
+
self.input_embeds = None
|
93
|
+
else:
|
94
|
+
if isinstance(self.input_embeds[0][0], float):
|
95
|
+
self.is_single = True
|
96
|
+
self.batch_size = 1
|
97
|
+
else:
|
98
|
+
self.batch_size = len(self.input_embeds)
|
84
99
|
|
85
100
|
# Handle parallel sampling
|
86
101
|
# When parallel sampling is used, we always treat the input as a batch.
|
@@ -203,9 +218,11 @@ class TokenizedGenerateReqInput:
|
|
203
218
|
|
204
219
|
# LoRA related
|
205
220
|
lora_path: Optional[str] = None # None means just use the base model
|
221
|
+
# The input embeds
|
222
|
+
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
206
223
|
|
207
224
|
# Session id info for continual prompting
|
208
|
-
session_id: Optional[
|
225
|
+
session_id: Optional[str] = None
|
209
226
|
session_rid: Optional[str] = None
|
210
227
|
|
211
228
|
|
@@ -219,6 +236,8 @@ class EmbeddingReqInput:
|
|
219
236
|
rid: Optional[Union[List[str], str]] = None
|
220
237
|
# Dummy sampling params for compatibility
|
221
238
|
sampling_params: Union[List[Dict], Dict] = None
|
239
|
+
# Dummy input embeds for compatibility
|
240
|
+
input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
|
222
241
|
|
223
242
|
def normalize_batch_and_arguments(self):
|
224
243
|
if (self.text is None and self.input_ids is None) or (
|
@@ -301,8 +320,6 @@ class BatchTokenIDOut:
|
|
301
320
|
meta_info: List[Dict]
|
302
321
|
finished_reason: List[BaseFinishReason]
|
303
322
|
no_stop_trim: List[bool]
|
304
|
-
# The updated session unique id
|
305
|
-
session_ids: List[str]
|
306
323
|
|
307
324
|
|
308
325
|
@dataclass
|
@@ -315,8 +332,6 @@ class BatchStrOut:
|
|
315
332
|
meta_info: List[Dict]
|
316
333
|
# The finish reason
|
317
334
|
finished_reason: List[BaseFinishReason]
|
318
|
-
# The update session unique id
|
319
|
-
session_ids: List[str]
|
320
335
|
|
321
336
|
|
322
337
|
@dataclass
|