sglang 0.3.4.post1__py3-none-any.whl → 0.3.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_latency.py +3 -3
- sglang/bench_server_latency.py +2 -3
- sglang/bench_serving.py +92 -0
- sglang/global_config.py +9 -3
- sglang/lang/chat_template.py +50 -25
- sglang/lang/interpreter.py +9 -1
- sglang/lang/ir.py +11 -2
- sglang/launch_server.py +1 -1
- sglang/srt/configs/model_config.py +76 -15
- sglang/srt/constrained/__init__.py +18 -0
- sglang/srt/constrained/bnf_cache.py +61 -0
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/constrained/grammar.py +190 -0
- sglang/srt/hf_transformers_utils.py +20 -5
- sglang/srt/layers/attention/flashinfer_backend.py +5 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
- sglang/srt/layers/fused_moe/fused_moe.py +4 -3
- sglang/srt/layers/fused_moe/layer.py +28 -0
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/quantization/base_config.py +16 -1
- sglang/srt/layers/rotary_embedding.py +15 -48
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/layers/vocab_parallel_embedding.py +486 -0
- sglang/srt/managers/data_parallel_controller.py +8 -7
- sglang/srt/managers/detokenizer_manager.py +11 -9
- sglang/srt/managers/image_processor.py +4 -3
- sglang/srt/managers/io_struct.py +80 -78
- sglang/srt/managers/schedule_batch.py +46 -52
- sglang/srt/managers/schedule_policy.py +24 -13
- sglang/srt/managers/scheduler.py +145 -82
- sglang/srt/managers/tokenizer_manager.py +236 -334
- sglang/srt/managers/tp_worker.py +5 -5
- sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
- sglang/srt/mem_cache/flush_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +10 -3
- sglang/srt/model_executor/cuda_graph_runner.py +34 -23
- sglang/srt/model_executor/forward_batch_info.py +6 -9
- sglang/srt/model_executor/model_runner.py +10 -19
- sglang/srt/models/baichuan.py +4 -4
- sglang/srt/models/chatglm.py +4 -4
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +5 -5
- sglang/srt/models/deepseek.py +4 -4
- sglang/srt/models/deepseek_v2.py +4 -4
- sglang/srt/models/exaone.py +4 -4
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt2.py +287 -0
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +4 -4
- sglang/srt/models/internlm2.py +4 -4
- sglang/srt/models/llama.py +15 -7
- sglang/srt/models/llama_embedding.py +2 -10
- sglang/srt/models/llama_reward.py +5 -0
- sglang/srt/models/minicpm.py +4 -4
- sglang/srt/models/minicpm3.py +4 -4
- sglang/srt/models/mixtral.py +7 -5
- sglang/srt/models/mixtral_quant.py +4 -4
- sglang/srt/models/mllama.py +5 -5
- sglang/srt/models/olmo.py +4 -4
- sglang/srt/models/olmoe.py +4 -4
- sglang/srt/models/qwen.py +4 -4
- sglang/srt/models/qwen2.py +4 -4
- sglang/srt/models/qwen2_moe.py +4 -4
- sglang/srt/models/qwen2_vl.py +4 -8
- sglang/srt/models/stablelm.py +4 -4
- sglang/srt/models/torch_native_llama.py +4 -4
- sglang/srt/models/xverse.py +4 -4
- sglang/srt/models/xverse_moe.py +4 -4
- sglang/srt/openai_api/adapter.py +52 -66
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -13
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +41 -33
- sglang/srt/server_args.py +34 -5
- sglang/srt/utils.py +40 -56
- sglang/test/run_eval.py +2 -0
- sglang/test/runners.py +2 -1
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +151 -6
- sglang/utils.py +62 -1
- sglang/version.py +1 -1
- sglang-0.3.5.dist-info/METADATA +344 -0
- sglang-0.3.5.dist-info/RECORD +152 -0
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
- sglang-0.3.4.post1.dist-info/METADATA +0 -900
- sglang-0.3.4.post1.dist-info/RECORD +0 -148
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -29,10 +29,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
@@ -47,6 +43,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

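The same import relocation recurs across most model files in this release: ParallelLMHead and VocabParallelEmbedding now come from sglang's own layers package instead of vllm. A minimal illustration of the new import path introduced by the hunk above (the vllm path shown in comments is the one being removed):

# New in 0.3.5: sglang ships its own vocab-parallel embedding layers.
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)

# Previously (0.3.4.post1) the same classes came from vllm:
# from vllm.model_executor.layers.vocab_parallel_embedding import (
#     ParallelLMHead,
#     VocabParallelEmbedding,
# )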
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import lru_cache, partial
-from typing import Iterable, List,
+from typing import Iterable, List, Optional, Tuple, Type, TypedDict

 import numpy as np
 import torch
@@ -35,9 +35,7 @@ from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import SupportsMultiModal

 from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
 from sglang.srt.hf_transformers_utils import get_processor
@@ -47,6 +45,7 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.qwen2 import Qwen2Model
@@ -486,7 +485,7 @@ class Qwen2VisionTransformer(nn.Module):
 cached_get_processor = lru_cache(get_processor)


-class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
+class Qwen2VLForConditionalGeneration(nn.Module):
     def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
         processor = cached_get_processor(self.config._name_or_path)
         grid_t, grid_h, grid_w = image_grid_thw
@@ -536,15 +535,12 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
     def __init__(
         self,
         config: Qwen2VLConfig,
-        multimodal_config: MultiModalConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()

         self.config = config
-        self.multimodal_config = multimodal_config
-
         self.visual = Qwen2VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
@@ -622,7 +618,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal):
         extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
         prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
         for i, image in enumerate(forward_batch.image_inputs):
-            if image
+            if image is None:
                 continue
             start_idx = extend_start_loc_cpu[i]
             prefix_len = prefix_lens_cpu[i]
sglang/srt/models/stablelm.py
CHANGED
@@ -24,10 +24,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
@@ -39,6 +35,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch


sglang/srt/models/torch_native_llama.py
CHANGED
@@ -26,10 +26,6 @@ from torch.nn.parameter import Parameter
 from transformers import LlamaConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
@@ -38,6 +34,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

sglang/srt/models/xverse.py
CHANGED
@@ -31,15 +31,15 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.model_runner import ForwardBatch


sglang/srt/models/xverse_moe.py
CHANGED
@@ -34,15 +34,15 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch


sglang/srt/openai_api/adapter.py
CHANGED
@@ -71,6 +71,7 @@ from sglang.srt.openai_api.protocol import (
     TopLogprob,
     UsageInfo,
 )
+from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)

@@ -314,6 +315,8 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest
         )

     except Exception as e:
+        logger.error(f"error: {get_exception_traceback()}")
+        responses = []
         error_json = {
             "id": f"batch_req_{uuid.uuid4()}",
             "custom_id": request_data.get("custom_id"),
@@ -363,7 +366,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRequest
         }

     except Exception as e:
-        logger.error("error
+        logger.error(f"error: {e}")
         # Update batch status to "failed"
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.status = "failed"
@@ -469,80 +472,67 @@ async def v1_retrieve_file_content(file_id: str):
 def v1_generate_request(
     all_requests: List[CompletionRequest], request_ids: List[str] = None
 ):
+    if len(all_requests) > 1:
+        first_prompt_type = type(all_requests[0].prompt)
+        for request in all_requests:
+            assert (
+                type(request.prompt) is first_prompt_type
+            ), "All prompts must be of the same type in file input settings"
+            if request.n > 1:
+                raise ValueError(
+                    "Parallel sampling is not supported for completions from files"
+                )
+
     prompts = []
     sampling_params_list = []
     return_logprobs = []
     logprob_start_lens = []
     top_logprobs_nums = []

-    # NOTE: with openai API, the prompt's logprobs are always not computed
-    first_prompt_type = type(all_requests[0].prompt)
     for request in all_requests:
-
-            type(request.prompt) is first_prompt_type
-        ), "All prompts must be of the same type in file input settings"
-        if len(all_requests) > 1 and request.n > 1:
-            raise ValueError(
-                "Parallel sampling is not supported for completions from files"
-            )
+        # NOTE: with openai API, the prompt's logprobs are always not computed
         if request.echo and request.logprobs:
             logger.warning(
                 "Echo is not compatible with logprobs. "
-                "To compute logprobs of input prompt, please use
+                "To compute logprobs of input prompt, please use the native /generate API."
             )

-    for request in all_requests:
         prompts.append(request.prompt)
+        sampling_params_list.append(
+            {
+                "temperature": request.temperature,
+                "max_new_tokens": request.max_tokens,
+                "min_new_tokens": request.min_tokens,
+                "stop": request.stop,
+                "stop_token_ids": request.stop_token_ids,
+                "top_p": request.top_p,
+                "presence_penalty": request.presence_penalty,
+                "frequency_penalty": request.frequency_penalty,
+                "repetition_penalty": request.repetition_penalty,
+                "regex": request.regex,
+                "json_schema": request.json_schema,
+                "n": request.n,
+                "ignore_eos": request.ignore_eos,
+                "no_stop_trim": request.no_stop_trim,
+            }
+        )
         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
         logprob_start_lens.append(-1)
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
-        sampling_params = []
-        if isinstance(request.no_stop_trim, list):
-            num_reqs = len(request.prompt)
-        else:
-            num_reqs = 1
-        for i in range(num_reqs):
-            sampling_params.append(
-                {
-                    "temperature": request.temperature,
-                    "max_new_tokens": request.max_tokens,
-                    "min_new_tokens": request.min_tokens,
-                    "stop": request.stop,
-                    "stop_token_ids": request.stop_token_ids,
-                    "top_p": request.top_p,
-                    "presence_penalty": request.presence_penalty,
-                    "frequency_penalty": request.frequency_penalty,
-                    "repetition_penalty": request.repetition_penalty,
-                    "regex": request.regex,
-                    "json_schema": request.json_schema,
-                    "n": request.n,
-                    "ignore_eos": request.ignore_eos,
-                    "no_stop_trim": (
-                        request.no_stop_trim
-                        if not isinstance(request.no_stop_trim, list)
-                        else request.no_stop_trim[i]
-                    ),
-                }
-            )
-        if num_reqs == 1:
-            sampling_params_list.append(sampling_params[0])
-        else:
-            sampling_params_list.append(sampling_params)

     if len(all_requests) == 1:
-
+        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
+            prompt_kwargs = {"text": prompts[0]}
+        else:
+            prompt_kwargs = {"input_ids": prompts[0]}
         sampling_params_list = sampling_params_list[0]
-        logprob_start_lens = logprob_start_lens[0]
         return_logprobs = return_logprobs[0]
+        logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
-        if isinstance(prompt, str) or isinstance(prompt[0], str):
-            prompt_kwargs = {"text": prompt}
-        else:
-            prompt_kwargs = {"input_ids": prompt}
     else:
-        if isinstance(prompts[0], str):
+        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
@@ -558,9 +548,7 @@ def v1_generate_request(
         rid=request_ids,
     )

-    if len(all_requests) == 1:
-        return adapted_request, all_requests[0]
-    return adapted_request, all_requests
+    return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


 def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
@@ -595,7 +583,7 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
         if isinstance(request, list) and request[idx].echo:
             echo = True
             text = request[idx].prompt + text
-        if
+        if echo and not isinstance(request, list):
             prompt_index = idx // request.n
             text = prompts[prompt_index] + text

@@ -709,7 +697,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
             async for content in tokenizer_manager.generate_request(
                 adapted_request, raw_request
             ):
-                index = content
+                index = content.get("index", 0)

                 stream_buffer = stream_buffers.get(index, "")
                 n_prev_token = n_prev_tokens.get(index, 0)
@@ -945,19 +933,18 @@ def v1_chat_generate_request(
         sampling_params_list.append(sampling_params)

         image_data_list.append(image_data)
-        modalities_list.
+        modalities_list.append(modalities)
     if len(all_requests) == 1:
-
-
-            prompt_kwargs = {"text": input_ids}
+        if isinstance(input_ids[0], str):
+            prompt_kwargs = {"text": input_ids[0]}
         else:
-            prompt_kwargs = {"input_ids": input_ids}
+            prompt_kwargs = {"input_ids": input_ids[0]}
         sampling_params_list = sampling_params_list[0]
         image_data_list = image_data_list[0]
         return_logprobs = return_logprobs[0]
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
-        modalities_list = modalities_list[
+        modalities_list = modalities_list[0]
     else:
         if isinstance(input_ids[0], str):
             prompt_kwargs = {"text": input_ids}
@@ -976,9 +963,8 @@ def v1_chat_generate_request(
         rid=request_ids,
         modalities=modalities_list,
     )
-
-
-    return adapted_request, all_requests
+
+    return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


 def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
@@ -1116,7 +1102,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             async for content in tokenizer_manager.generate_request(
                 adapted_request, raw_request
             ):
-                index = content
+                index = content.get("index", 0)

                 is_first = is_firsts.get(index, True)
                 stream_buffer = stream_buffers.get(index, "")
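For reference, a minimal stand-alone sketch of the per-request mapping that the reworked v1_generate_request builds above, assuming a CompletionRequest-like object with the same field names. This is an illustration of the 0.3.5 behavior (one sampling-params dict per request, with the old per-prompt no_stop_trim list handling removed), not the adapter code itself:

def completion_request_to_sampling_params(request) -> dict:
    # Mirrors the dict constructed once per request in v1_generate_request (0.3.5).
    return {
        "temperature": request.temperature,
        "max_new_tokens": request.max_tokens,
        "min_new_tokens": request.min_tokens,
        "stop": request.stop,
        "stop_token_ids": request.stop_token_ids,
        "top_p": request.top_p,
        "presence_penalty": request.presence_penalty,
        "frequency_penalty": request.frequency_penalty,
        "repetition_penalty": request.repetition_penalty,
        "regex": request.regex,
        "json_schema": request.json_schema,
        "n": request.n,
        "ignore_eos": request.ignore_eos,
        "no_stop_trim": request.no_stop_trim,
    }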
sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py
CHANGED
@@ -31,9 +31,12 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         padded_stop_token_ids = torch.nn.utils.rnn.pad_sequence(
             sequences=[
                 torch.tensor(
-                    data=
-
-
+                    data=(
+                        list(
+                            (req.sampling_params.stop_token_ids or set())
+                            | (req.tokenizer.additional_stop_token_ids or set())
+                            | {req.tokenizer.eos_token_id}
+                        )
                     ),
                     dtype=torch.int64,
                     device=self.orchestrator.device,
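A small, self-contained sketch of the padding pattern used above: per-request stop-token sets are unioned and then padded into one rectangular int64 tensor with torch.nn.utils.rnn.pad_sequence. The concrete token ids below are made up for illustration.

import torch
from torch.nn.utils.rnn import pad_sequence

# Stand-ins for sampling_params.stop_token_ids, tokenizer.additional_stop_token_ids,
# and tokenizer.eos_token_id for three hypothetical requests.
per_request_stop_ids = [
    {2} | {32000, 32001},   # request 0: eos plus two extra stop tokens
    {2},                    # request 1: eos only
    {2} | {7},              # request 2: eos plus one custom stop token
]

padded_stop_token_ids = pad_sequence(
    sequences=[torch.tensor(sorted(ids), dtype=torch.int64) for ids in per_request_stop_ids],
    batch_first=True,
    padding_value=0,
)
print(padded_stop_token_ids)
# tensor([[    2, 32000, 32001],
#         [    2,     0,     0],
#         [    2,     7,     0]])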
sglang/srt/sampling/sampling_batch_info.py
CHANGED
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List, Optional
 import torch

 import sglang.srt.sampling.penaltylib as penaltylib
-from sglang.srt.constrained import
+from sglang.srt.constrained.grammar import Grammar

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -29,11 +29,9 @@ class SamplingBatchInfo:
     # Bias Tensors
     vocab_size: int
     logit_bias: torch.Tensor = None
-    vocab_mask: torch.Tensor = None
+    vocab_mask: Optional[torch.Tensor] = None

-
-    regex_fsms: List[RegexGuide] = None
-    regex_fsm_states: List[int] = None
+    grammars: Optional[List[Optional[Grammar]]] = None

     # Penalizer
     penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -136,8 +134,7 @@ class SamplingBatchInfo:
             self.linear_penalties = penalizer.apply(self.linear_penalties)

     def update_regex_vocab_mask(self):
-
-        if not has_regex:
+        if not self.grammars or not any(grammar for grammar in self.grammars):
             self.vocab_mask = None
             return

@@ -147,12 +144,9 @@ class SamplingBatchInfo:
             dtype=torch.bool,
             device=self.device,
         )
-        for i,
-            if
-                self.vocab_mask[i].
-                self.vocab_mask[i][
-                    regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
-                ] = 0
+        for i, grammar in enumerate(self.grammars):
+            if grammar is not None:
+                grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         if self.penalizer_orchestrator:
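The refactor above replaces per-request outlines RegexGuide state with a single Grammar abstraction that writes into a boolean vocab mask. A toy sketch of how such a mask can be applied to logits before sampling; the mask semantics here (True means "disallowed") are an assumption for illustration, not necessarily sglang's exact convention:

import torch

vocab_size = 8
logits = torch.randn(2, vocab_size)

# Hypothetical per-request masks: True marks tokens the grammar currently forbids.
vocab_mask = torch.zeros(2, vocab_size, dtype=torch.bool)
vocab_mask[0, 3:] = True     # request 0 may only emit tokens 0-2
vocab_mask[1, :5] = True     # request 1 may only emit tokens 5-7

masked_logits = logits.masked_fill(vocab_mask, float("-inf"))
next_tokens = masked_logits.argmax(dim=-1)
print(next_tokens)  # request 0 picks from {0, 1, 2}, request 1 from {5, 6, 7}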
sglang/srt/sampling/sampling_params.py
CHANGED
@@ -50,9 +50,10 @@ class SamplingParams:
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
         self.stop_strs = stop
-        if stop_token_ids
-            stop_token_ids =
-
+        if stop_token_ids:
+            self.stop_token_ids = set(stop_token_ids)
+        else:
+            self.stop_token_ids = None
         self.max_new_tokens = max_new_tokens
         self.min_new_tokens = min_new_tokens
         self.ignore_eos = ignore_eos
@@ -119,10 +120,7 @@
         # Process stop strings
         if self.stop_strs is None:
             self.stop_strs = []
-
-                self.stop_str_max_len = 0
-            else:
-                self.stop_str_max_len = 1
+            self.stop_str_max_len = 0
         else:
             if isinstance(self.stop_strs, str):
                 self.stop_strs = [self.stop_strs]
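A tiny stand-alone illustration of the normalization above: stop_token_ids now becomes a set when non-empty and None otherwise, so membership checks are cheap and "no stop tokens" is represented uniformly. The helper name below is hypothetical, not part of sglang:

def normalize_stop_token_ids(stop_token_ids):
    # Empty list/tuple and None both collapse to None; otherwise keep a set for O(1) lookups.
    return set(stop_token_ids) if stop_token_ids else None

assert normalize_stop_token_ids(None) is None
assert normalize_stop_token_ids([]) is None
assert normalize_stop_token_ids([2, 2, 32000]) == {2, 32000}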
sglang/srt/server.py
CHANGED
@@ -53,7 +53,6 @@ from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
-    RewardReqInput,
     UpdateWeightReqInput,
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
@@ -91,7 +90,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())


 app = FastAPI()
-tokenizer_manager = None
+tokenizer_manager: TokenizerManager = None

 app.add_middleware(
     CORSMiddleware,
@@ -139,7 +138,7 @@ async def get_server_args():
     return dataclasses.asdict(tokenizer_manager.server_args)


-@app.
+@app.post("/flush_cache")
 async def flush_cache():
     """Flush the radix cache."""
     tokenizer_manager.flush_cache()
@@ -172,6 +171,19 @@ async def stop_profile():
     )


+@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
+async def get_memory_pool_size():
+    """Get the memory pool size in number of tokens"""
+    try:
+        ret = await tokenizer_manager.get_memory_pool_size()
+
+        return ret
+    except Exception as e:
+        return ORJSONResponse(
+            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+        )
+
+
 @app.post("/update_weights")
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
     """Update the weights inplace without re-launching the server."""
@@ -241,8 +253,8 @@ app.post("/encode")(encode_request)
 app.put("/encode")(encode_request)


-async def judge_request(obj:
-    """Handle a reward model request."""
+async def judge_request(obj: EmbeddingReqInput, request: Request):
+    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
     try:
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
@@ -429,7 +441,7 @@ def launch_server(

     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
     )
     t.start()

@@ -484,7 +496,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)


-def _wait_and_warmup(server_args, pipe_finish_writer, pid):
+def _wait_and_warmup(server_args, pipe_finish_writer):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -507,7 +519,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_child_process(
+        kill_child_process(include_self=True)
         return

     model_info = res.json()
@@ -539,7 +551,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_child_process(
+        kill_child_process(include_self=True)
         return

     # logger.info(f"{res.json()=}")
@@ -605,7 +617,7 @@ class Runtime:

     def shutdown(self):
         if self.pid is not None:
-            kill_child_process(self.pid)
+            kill_child_process(self.pid, include_self=True)
             self.pid = None

     def cache_prefix(self, prefix: str):
@@ -684,24 +696,8 @@ class Runtime:
         self,
         prompt: Union[str, List[str], List[Dict], List[List[Dict]]],
     ):
-
-
-            json_data = {
-                "text": prompt,
-            }
-            response = requests.post(
-                self.url + "/encode",
-                json=json_data,
-            )
-        else:
-            # reward
-            json_data = {
-                "conv": prompt,
-            }
-            response = requests.post(
-                self.url + "/judge",
-                json=json_data,
-            )
+        json_data = {"text": prompt}
+        response = requests.post(self.url + "/encode", json=json_data)
         return json.dumps(response.json())

     def __del__(self):
@@ -724,24 +720,32 @@ class Engine:

         # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
         atexit.register(self.shutdown)
+
+        # runtime server default log level is log
+        # offline engine works in scripts, so we set it to error
+
+        if 'log_level' not in kwargs:
+            kwargs['log_level'] = 'error'

         server_args = ServerArgs(*args, **kwargs)
         launch_engine(server_args=server_args)

     def generate(
         self,
-        prompt
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
         sampling_params: Optional[Dict] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
         lora_path: Optional[List[Optional[str]]] = None,
         stream: bool = False,
     ):
-        # TODO (ByronHsu): refactor to reduce the duplicated code
-
         obj = GenerateReqInput(
             text=prompt,
+            input_ids=input_ids,
             sampling_params=sampling_params,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
@@ -779,8 +783,11 @@ class Engine:

     async def async_generate(
         self,
-        prompt
+        # The input prompt. It can be a single prompt or a batch of prompts.
+        prompt: Optional[Union[List[str], str]] = None,
         sampling_params: Optional[Dict] = None,
+        # The token ids for text; one can either specify text or input_ids.
+        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -789,6 +796,7 @@ class Engine:
     ):
         obj = GenerateReqInput(
             text=prompt,
+            input_ids=input_ids,
             sampling_params=sampling_params,
             return_logprob=return_logprob,
             logprob_start_len=logprob_start_len,
@@ -822,7 +830,7 @@ class Engine:
         return ret

     def shutdown(self):
-        kill_child_process(
+        kill_child_process()

     def get_tokenizer(self):
         global tokenizer_manager