sglang 0.1.20__tar.gz → 0.1.21__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sglang-0.1.20/sglang.egg-info → sglang-0.1.21}/PKG-INFO +9 -1
- {sglang-0.1.20 → sglang-0.1.21}/README.md +8 -0
- {sglang-0.1.20 → sglang-0.1.21}/pyproject.toml +1 -1
- {sglang-0.1.20 → sglang-0.1.21}/sglang/__init__.py +1 -1
- {sglang-0.1.20 → sglang-0.1.21}/sglang/backend/runtime_endpoint.py +14 -4
- {sglang-0.1.20 → sglang-0.1.21}/sglang/bench_latency.py +0 -1
- {sglang-0.1.20 → sglang-0.1.21}/sglang/global_config.py +3 -1
- {sglang-0.1.20 → sglang-0.1.21}/sglang/lang/chat_template.py +2 -2
- {sglang-0.1.20 → sglang-0.1.21}/sglang/lang/ir.py +3 -3
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/cuda_graph_runner.py +36 -12
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/infer_batch.py +32 -26
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/manager_multi.py +6 -2
- sglang-0.1.21/sglang/srt/managers/controller/manager_single.py +177 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/model_runner.py +17 -5
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/radix_cache.py +4 -3
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/schedule_heuristic.py +4 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/tp_worker.py +38 -40
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/memory_pool.py +29 -54
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/minicpm.py +1 -8
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/qwen2_moe.py +126 -107
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/server.py +10 -15
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/server_args.py +4 -2
- {sglang-0.1.20 → sglang-0.1.21/sglang.egg-info}/PKG-INFO +9 -1
- sglang-0.1.20/sglang/srt/managers/controller/manager_single.py +0 -102
- {sglang-0.1.20 → sglang-0.1.21}/LICENSE +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/setup.cfg +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/api.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/backend/__init__.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/backend/anthropic.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/backend/base_backend.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/backend/litellm.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/backend/openai.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/backend/vertexai.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/lang/__init__.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/lang/compiler.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/lang/interpreter.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/lang/tracer.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/launch_server.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/launch_server_llavavid.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/constrained/__init__.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/constrained/base_cache.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/constrained/fsm_cache.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/constrained/jump_forward.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/conversation.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/flush_cache.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/hf_transformers_utils.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/layers/context_flashattention_nopad.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/layers/extend_attention.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/layers/fused_moe.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/layers/logits_processor.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/layers/radix_attention.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/layers/token_attention.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/dp_worker.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/detokenizer_manager.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/io_struct.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/tokenizer_manager.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/mm_utils.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/model_config.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/chatglm.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/commandr.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/dbrx.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/gemma.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/gemma2.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/grok.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/llama2.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/llama_classification.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/llava.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/llavavid.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/mistral.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/mixtral.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/mixtral_quant.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/qwen.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/qwen2.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/stablelm.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/yivl.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/openai_api_adapter.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/openai_protocol.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/sampling_params.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/utils.py +1 -1
- {sglang-0.1.20 → sglang-0.1.21}/sglang/test/test_conversation.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/test/test_openai_protocol.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/test/test_programs.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/test/test_utils.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang/utils.py +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang.egg-info/SOURCES.txt +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang.egg-info/dependency_links.txt +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang.egg-info/requires.txt +0 -0
- {sglang-0.1.20 → sglang-0.1.21}/sglang.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sglang
-Version: 0.1.20
+Version: 0.1.21
 Summary: A structured generation langauge for LLMs.
 License: Apache License
                                Version 2.0, January 2004
@@ -623,6 +623,14 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
 ```
 - See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance.
+- Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-1` be the hostname of the first node and `50000` be an available port.
+```
+# Node 0
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 0
+
+# Node 1
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 1
+```
 
 ### Supported Models
 - Llama
@@ -377,6 +377,14 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
 python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
 ```
 - See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance.
+- Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-1` be the hostname of the first node and `50000` be an available port.
+```
+# Node 0
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 0
+
+# Node 1
+python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 1
+```
 
 ### Supported Models
 - Llama
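Once both node commands are running, a quick smoke test is to hit the HTTP API from any machine that can reach node 0. A minimal sketch (not part of this diff), assuming the server keeps the default `--port 30000` on `sgl-dev-1` and the standard `/generate` endpoint:

```python
# Hypothetical smoke test for the two-node TP=4 launch shown above.
# Assumes node 0 (sgl-dev-1) serves HTTP on the default port 30000.
import requests

resp = requests.post(
    "http://sgl-dev-1:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 8, "temperature": 0},
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json())
```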
@@ -12,7 +12,6 @@ from sglang.utils import http_request
 
 
 class RuntimeEndpoint(BaseBackend):
-
     def __init__(
         self,
         base_url: str,
@@ -38,7 +37,8 @@ class RuntimeEndpoint(BaseBackend):
         self.model_info = res.json()
 
         self.chat_template = get_chat_template_by_model_path(
-            self.model_info["model_path"]
+            self.model_info["model_path"]
+        )
 
     def get_model_name(self):
         return self.model_info["model_path"]
@@ -124,7 +124,12 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
-        for item in [
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
             value = getattr(sampling_params, item, None)
             if value is not None:
                 data[item] = value
@@ -171,7 +176,12 @@ class RuntimeEndpoint(BaseBackend):
         else:
             raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
-        for item in [
+        for item in [
+            "return_logprob",
+            "logprob_start_len",
+            "top_logprobs_num",
+            "return_text_in_logprobs",
+        ]:
             value = getattr(sampling_params, item, None)
             if value is not None:
                 data[item] = value
@@ -25,7 +25,8 @@ class GlobalConfig:
         # This can improve the speed for large batch sizes during prefill.
         self.layer_sync_threshold = 8192
 
-        # Runtime constants:
+        # Runtime constants: others
+        self.num_continue_decode_steps = 10
         self.flashinfer_workspace_size = 192 * 1024 * 1024
 
         # Output tokenization configs
@@ -44,4 +45,5 @@ class GlobalConfig:
         # adjust_cache: Adjust the position embedding of KV cache.
         self.concate_and_append_mode = "no_adjust"
 
+
 global_config = GlobalConfig()
@@ -84,7 +84,7 @@ register_chat_template(
             "system": ("SYSTEM:", "\n"),
             "user": ("USER:", "\n"),
             "assistant": ("ASSISTANT:", "\n"),
-        }
+        },
     )
 )
 
@@ -177,7 +177,7 @@ register_chat_template(
             "assistant": ("", "<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>",)
+        stop_str=("<|im_end|>",),
     )
 )
 
@@ -24,9 +24,9 @@ class SglSamplingParams:
     presence_penalty: float = 0.0
     ignore_eos: bool = False
     return_logprob: Optional[bool] = None
-    logprob_start_len: Optional[int] = None,
-    top_logprobs_num: Optional[int] = None,
-    return_text_in_logprobs: Optional[bool] = None,
+    logprob_start_len: Optional[int] = (None,)
+    top_logprobs_num: Optional[int] = (None,)
+    return_text_in_logprobs: Optional[bool] = (None,)
 
     # for constrained generation, not included in to_xxx_kwargs
     dtype: Optional[str] = None
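The `sglang/lang/ir.py` hunk is easy to misread: the formatter rewrote `= None,` as `= (None,)`. That is not a behavior change; a trailing comma after a default value in a class body already makes the default a one-element tuple, and the new form just spells it out. A minimal sketch (hypothetical class, not sglang code) showing the effect:

```python
# A trailing comma after a default value turns the default into a tuple.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Params:
    logprob_start_len: Optional[int] = None,   # default is actually (None,)
    top_logprobs_num: Optional[int] = (None,)  # the same default, written explicitly


p = Params()
print(p.logprob_start_len)  # (None,)
print(p.top_logprobs_num)   # (None,)
```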
@@ -8,7 +8,10 @@ from vllm.distributed.parallel_state import graph_capture
 from sglang.global_config import global_config
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
-    Batch,
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    init_flashinfer_args,
 )
 
 
@@ -24,18 +27,28 @@ class CudaGraphRunner:
         # Common inputs
         self.max_bs = max_batch_size_to_capture
         self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.req_pool_indices = torch.zeros(
+        self.req_pool_indices = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
         self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.position_ids_offsets = torch.zeros(
-
+        self.position_ids_offsets = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.out_cache_loc = torch.zeros(
+            (self.max_bs,), dtype=torch.int32, device="cuda"
+        )
 
         # FlashInfer inputs
-        self.flashinfer_workspace_buffer =
+        self.flashinfer_workspace_buffer = (
+            self.model_runner.flashinfer_workspace_buffers[0]
+        )
         self.flashinfer_kv_indptr = torch.zeros(
             (self.max_bs + 1,), dtype=torch.int32, device="cuda"
         )
         self.flashinfer_kv_indices = torch.zeros(
-            (self.max_bs * model_runner.model_config.context_len,),
+            (self.max_bs * model_runner.model_config.context_len,),
+            dtype=torch.int32,
+            device="cuda",
         )
         self.flashinfer_kv_last_page_len = torch.ones(
             (self.max_bs,), dtype=torch.int32, device="cuda"
@@ -49,7 +62,12 @@ class CudaGraphRunner:
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
             for bs in batch_size_list:
-
+                (
+                    graph,
+                    input_buffers,
+                    output_buffers,
+                    flashinfer_handler,
+                ) = self.capture_one_batch_size(bs)
                 self.graphs[bs] = graph
                 self.input_buffers[bs] = input_buffers
                 self.output_buffers[bs] = output_buffers
@@ -71,17 +89,19 @@ class CudaGraphRunner:
 
         # FlashInfer inputs
         if not _grouped_size_compiled_for_decode_kernels(
-            self.model_runner.model_config.num_attention_heads
+            self.model_runner.model_config.num_attention_heads
+            // self.model_runner.tp_size,
             self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
         ):
             use_tensor_cores = True
         else:
             use_tensor_cores = False
         flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffer,
+            self.flashinfer_workspace_buffer,
+            "NHD",
             use_cuda_graph=True,
             use_tensor_cores=use_tensor_cores,
-            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1],
+            paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
             paged_kv_indices_buffer=self.flashinfer_kv_indices,
             paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
         )
@@ -163,10 +183,14 @@ class CudaGraphRunner:
         else:
             output = LogitProcessorOutput(
                 next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=output.next_token_logprobs[:raw_bs]
+                next_token_logprobs=output.next_token_logprobs[:raw_bs]
+                if output.next_token_logprobs is not None
+                else None,
                 normalized_prompt_logprobs=None,
                 prefill_token_logprobs=None,
                 prefill_top_logprobs=None,
-                decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
+                decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
+                if output.decode_top_logprobs is not None
+                else None,
             )
         return output
@@ -174,9 +174,6 @@ class Req:
 
         return False, ""
 
-    def max_new_tokens(self):
-        return self.sampling_params.max_new_tokens
-
     def check_finished(self):
         if self.finished():
             return
@@ -352,7 +349,7 @@ class Batch:
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
             out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
 
             if out_cache_loc is None:
@@ -422,7 +419,7 @@ class Batch:
         if self.token_to_kv_pool.available_size() >= bs:
             return True
 
-        self.tree_cache.evict(bs, self.token_to_kv_pool.
+        self.tree_cache.evict(bs, self.token_to_kv_pool.free)
 
         if self.token_to_kv_pool.available_size() >= bs:
             return True
@@ -453,7 +450,7 @@ class Batch:
             token_indices = self.req_to_token_pool.req_to_token[
                 req_pool_indices_cpu[idx]
             ][last_uncached_pos : seq_lens_cpu[idx]]
-            self.token_to_kv_pool.
+            self.token_to_kv_pool.free(token_indices)
 
             # release the last node
             self.tree_cache.dec_lock_ref(req.last_node)
@@ -596,8 +593,7 @@ class Batch:
             "logit_bias",
         ]:
             self_val = getattr(self, item, None)
-            # logit_bias can be None
-            if self_val is not None:
+            if self_val is not None:  # logit_bias can be None
                 setattr(self, item, self_val[new_indices])
 
     def merge(self, other: "Batch"):
@@ -668,7 +664,9 @@ class Batch:
             sampled_index = torch.multinomial(probs_sort, num_samples=1)
         except RuntimeError as e:
             warnings.warn(f"Ignore errors in sampling: {e}")
-            sampled_index = torch.ones(
+            sampled_index = torch.ones(
+                probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
+            )
         batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
             -1
         )
@@ -749,8 +747,14 @@ class InputMetadata:
         skip_flashinfer_init=False,
     ):
         if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            init_flashinfer_args(
-
+            init_flashinfer_args(
+                forward_mode,
+                model_runner,
+                req_pool_indices,
+                seq_lens,
+                prefix_lens,
+                model_runner.flashinfer_decode_wrapper,
+            )
 
         batch_size = len(req_pool_indices)
 
@@ -807,16 +811,24 @@ class InputMetadata:
         )
 
         if model_runner.server_args.disable_flashinfer:
-            (
-
-
-
+            (
+                ret.triton_max_seq_len,
+                ret.triton_max_extend_len,
+                ret.triton_start_loc,
+                ret.triton_prefix_lens,
+            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
 
         return ret
 
 
-def init_flashinfer_args(
-
+def init_flashinfer_args(
+    forward_mode,
+    model_runner,
+    req_pool_indices,
+    seq_lens,
+    prefix_lens,
+    flashinfer_decode_wrapper,
+):
     num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
@@ -827,9 +839,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
     else:
         paged_kernel_lens = prefix_lens
 
-    kv_indptr = torch.zeros(
-        (batch_size + 1,), dtype=torch.int32, device="cuda"
-    )
+    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
     kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
     req_pool_indices_cpu = req_pool_indices.cpu().numpy()
     paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
@@ -842,9 +852,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
         ],
         dim=0,
     ).contiguous()
-    kv_last_page_len = torch.ones(
-        (batch_size,), dtype=torch.int32, device="cuda"
-    )
+    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
     if forward_mode == ForwardMode.DECODE:
         flashinfer_decode_wrapper.end_forward()
@@ -859,9 +867,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
         )
     else:
         # extend part
-        qo_indptr = torch.zeros(
-            (batch_size + 1,), dtype=torch.int32, device="cuda"
-        )
+        qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
         qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
         model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
@@ -42,6 +42,8 @@ class LoadBalanceMethod(Enum):
 
 
 class Controller:
+    """A controller that manages multiple data parallel workers."""
+
     def __init__(
         self,
         load_balance_method: str,
@@ -183,9 +185,11 @@ def start_controller_process(
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
-
     pipe_writer.send("init ok")
-
+
+    loop = asyncio.new_event_loop()
+    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
+
     asyncio.set_event_loop(loop)
     loop.create_task(controller.loop_for_recv_requests())
    loop.run_until_complete(controller.loop_for_forward())
@@ -0,0 +1,177 @@
+"""A controller that manages a group of tensor parallel workers."""
+
+import multiprocessing
+import logging
+import os
+import pickle
+
+import torch
+import torch.distributed as dist
+import zmq
+import zmq.asyncio
+
+from sglang.srt.managers.controller.tp_worker import ModelTpServer
+from sglang.srt.server_args import PortArgs, ServerArgs, ModelPortArgs
+from sglang.srt.utils import kill_parent_process
+from sglang.utils import get_exception_traceback
+
+
+logger = logging.getLogger("srt.controller")
+
+
+def run_tp_server(
+    gpu_id: int,
+    tp_rank: int,
+    server_args: ServerArgs,
+    model_port_args: ModelPortArgs,
+    model_overide_args: dict,
+):
+    """Run a tp server."""
+    try:
+        model_server = ModelTpServer(
+            gpu_id,
+            tp_rank,
+            server_args,
+            model_port_args,
+            model_overide_args,
+        )
+        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+        while True:
+            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+            model_server.exposed_step(recv_reqs)
+    except Exception:
+        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+        raise
+
+
+def launch_tp_servers(gpu_ids, tp_rank_range, server_args,
+                      model_port_args, model_overide_args):
+    """Launch multiple tp servers."""
+    procs = []
+    for i in tp_rank_range:
+        proc = multiprocessing.Process(target=run_tp_server, args=(
+            gpu_ids[i], i, server_args, model_port_args, model_overide_args
+        ))
+        proc.start()
+        procs.append(proc)
+
+    return procs
+
+
+def broadcast_recv_input(data, rank, dist_group):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+        return data
+
+
+class ControllerSingle:
+    """A controller that manages a group of tensor parallel workers."""
+
+    def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
+        # Parse args
+        self.server_args = server_args
+
+        # Init communication
+        context = zmq.Context(2)
+        self.recv_from_tokenizer = context.socket(zmq.PULL)
+        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+
+        self.send_to_detokenizer = context.socket(zmq.PUSH)
+        self.send_to_detokenizer.connect(
+            f"tcp://127.0.0.1:{port_args.detokenizer_port}"
+        )
+
+        # Init model server
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+
+        # Launch other tp ranks
+        if tp_size_local > 1:
+            tp_rank_range = range(1, tp_size_local)
+            self.tp_procs = launch_tp_servers(
+                gpu_ids, tp_rank_range, server_args,
+                port_args.model_port_args[0], model_overide_args)
+
+        # Launch tp rank 0
+        self.tp_server = ModelTpServer(
+            gpu_ids[0],
+            0,
+            server_args,
+            port_args.model_port_args[0],
+            model_overide_args,
+        )
+        self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
+
+    def loop_for_forward(self):
+        while True:
+            recv_reqs = self.recv_requests()
+
+            if self.server_args.tp_size > 1:
+                broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
+
+            out_pyobjs = self.tp_server.exposed_step(recv_reqs)
+
+            for obj in out_pyobjs:
+                self.send_to_detokenizer.send_pyobj(obj)
+
+    def recv_requests(self):
+        recv_reqs = []
+        while True:
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+                recv_reqs.append(recv_req)
+            except zmq.ZMQError:
+                break
+        return recv_reqs
+
+
+def start_controller_process(
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
+):
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format="%(message)s",
+    )
+
+    try:
+        controller = ControllerSingle(server_args, port_args, model_overide_args)
+    except Exception:
+        pipe_writer.send(get_exception_traceback())
+        raise
+
+    pipe_writer.send("init ok")
+
+    try:
+        controller.loop_for_forward()
+    except Exception:
+        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+    finally:
+        for t in controller.tp_procs:
+            os.kill(t.pid, 9)
+        kill_parent_process()
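The new `broadcast_recv_input` helper is the core of the reworked single-node controller: rank 0 pickles the batch of requests, broadcasts the payload size, then broadcasts the bytes, and every other tensor parallel rank blocks on the same two broadcasts. A self-contained sketch (not sglang code) of that size-then-payload pattern, run here with two local CPU processes on the gloo backend:

```python
# Minimal, hypothetical reproduction of the size-then-payload broadcast pattern
# used by broadcast_recv_input above (gloo backend, two local CPU processes).
import os
import pickle

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def broadcast_pyobj(data, rank, group):
    if rank == 0:
        payload = pickle.dumps(data)
        size = torch.tensor([len(payload)], dtype=torch.long)
        dist.broadcast(size, src=0, group=group)
        dist.broadcast(torch.tensor(list(payload), dtype=torch.uint8), src=0, group=group)
        return data
    else:
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0, group=group)
        buf = torch.empty(size.item(), dtype=torch.uint8)
        dist.broadcast(buf, src=0, group=group)
        return pickle.loads(bytes(buf.tolist()))


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    reqs = ["req-1", "req-2"] if rank == 0 else None
    print(f"rank {rank} got {broadcast_pyobj(reqs, rank, dist.group.WORLD)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```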
@@ -11,12 +11,17 @@ import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import init_distributed_environment, initialize_model_parallel
+from vllm.distributed import init_distributed_environment, initialize_model_parallel, get_tp_group
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.managers.controller.infer_batch import
+from sglang.srt.managers.controller.infer_batch import (
+    Batch,
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -70,6 +75,7 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        self.tp_group = get_tp_group()
         total_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
@@ -83,7 +89,9 @@ class ModelRunner:
 
         # Set some global args
         global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
-        global_server_args_dict[
+        global_server_args_dict[
+            "attention_reduce_in_fp32"
+        ] = server_args.attention_reduce_in_fp32
 
         # Load the model and create memory pool
         self.load_model()
@@ -217,7 +225,9 @@ class ModelRunner:
             self.flashinfer_workspace_buffers[1], "NHD"
         )
         self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-            self.flashinfer_workspace_buffers[0],
+            self.flashinfer_workspace_buffers[0],
+            "NHD",
+            use_tensor_cores=use_tensor_cores,
         )
 
     def init_cuda_graphs(self):
@@ -229,7 +239,9 @@ class ModelRunner:
 
         logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
-        self.cuda_graph_runner = CudaGraphRunner(
+        self.cuda_graph_runner = CudaGraphRunner(
+            self, max_batch_size_to_capture=max(batch_size_list)
+        )
         self.cuda_graph_runner.capture(batch_size_list)
 
     @torch.inference_mode()