sglang 0.2.5.tar.gz → 0.2.6.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sglang-0.2.5/sglang.egg-info → sglang-0.2.6}/PKG-INFO +9 -7
- {sglang-0.2.5 → sglang-0.2.6}/README.md +8 -6
- {sglang-0.2.5 → sglang-0.2.6}/pyproject.toml +1 -1
- {sglang-0.2.5 → sglang-0.2.6}/sglang/lang/backend/runtime_endpoint.py +4 -4
- {sglang-0.2.5 → sglang-0.2.6}/sglang/lang/interpreter.py +4 -4
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/constrained/fsm_cache.py +21 -1
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/hf_transformers_utils.py +3 -1
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/layers/logits_processor.py +70 -61
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/layers/radix_attention.py +5 -2
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/layers/token_attention.py +1 -1
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/managers/controller/cuda_graph_runner.py +26 -17
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/managers/controller/infer_batch.py +54 -13
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/managers/controller/model_runner.py +22 -7
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/managers/controller/tp_worker.py +47 -41
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/managers/io_struct.py +2 -2
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/managers/tokenizer_manager.py +62 -43
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/model_config.py +5 -0
- sglang-0.2.6/sglang/srt/models/deepseek_v2.py +517 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/llama_classification.py +3 -3
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/openai_api/adapter.py +33 -33
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/openai_api/protocol.py +1 -1
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/sampling_params.py +5 -4
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/server.py +2 -15
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/server_args.py +28 -7
- {sglang-0.2.5 → sglang-0.2.6}/sglang/test/test_programs.py +5 -1
- sglang-0.2.6/sglang/version.py +1 -0
- {sglang-0.2.5 → sglang-0.2.6/sglang.egg-info}/PKG-INFO +9 -7
- {sglang-0.2.5 → sglang-0.2.6}/sglang.egg-info/SOURCES.txt +1 -0
- sglang-0.2.5/sglang/version.py +0 -1
- {sglang-0.2.5 → sglang-0.2.6}/LICENSE +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/setup.cfg +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/__init__.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/api.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/bench_latency.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/bench_serving.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/check_env.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/global_config.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/lang/__init__.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/lang/backend/__init__.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/lang/backend/anthropic.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/lang/backend/base_backend.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/lang/backend/litellm.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/lang/backend/openai.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/lang/backend/vertexai.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/lang/chat_template.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/lang/compiler.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/lang/ir.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/lang/tracer.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/launch_server.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/launch_server_llavavid.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/constrained/__init__.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/constrained/base_cache.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/constrained/jump_forward.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/conversation.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/flush_cache.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/layers/context_flashattention_nopad.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/layers/extend_attention.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/layers/fused_moe.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/layers/linear.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/layers/quantization/__init__.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/layers/quantization/fp8.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/managers/controller/manager_multi.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/managers/controller/manager_single.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/managers/controller/radix_cache.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/managers/controller/schedule_heuristic.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/managers/detokenizer_manager.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/memory_pool.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/mm_utils.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/model_loader/model_loader.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/model_loader/utils.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/chatglm.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/commandr.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/dbrx.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/deepseek.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/gemma.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/gemma2.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/gpt_bigcode.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/grok.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/internlm2.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/llama2.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/llava.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/llavavid.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/minicpm.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/mistral.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/mixtral.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/mixtral_quant.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/qwen.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/qwen2.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/qwen2_moe.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/stablelm.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/models/yivl.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/srt/utils.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/test/test_conversation.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/test/test_openai_protocol.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/test/test_utils.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang/utils.py +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang.egg-info/dependency_links.txt +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang.egg-info/requires.txt +0 -0
- {sglang-0.2.5 → sglang-0.2.6}/sglang.egg-info/top_level.txt +0 -0

{sglang-0.2.5/sglang.egg-info → sglang-0.2.6}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sglang
-Version: 0.2.5
+Version: 0.2.6
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
                                 Version 2.0, January 2004
@@ -249,7 +249,7 @@ Requires-Dist: sglang[litellm]; extra == "all"

 --------------------------------------------------------------------------------

-| [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/) | [**Paper**](https://arxiv.org/abs/2312.07104) |
+| [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/) | [**Paper**](https://arxiv.org/abs/2312.07104) | [**Slack**](https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2ngly9muu-t37XiH87qvD~6rVBTkTEHw) |

 SGLang is a fast serving framework for large language models and vision language models.
 It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language.
@@ -404,16 +404,17 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 ### Run Llama 3.1 405B

 ```bash
-
+## Run 405B (fp8) on a single node
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8
+
+## Run 405B (fp16) on two nodes
 # replace the `172.16.4.52:20000` with your own first node ip address and port, disable CUDA Graph temporarily
+
 # on the first node
 GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0 --disable-cuda-graph --mem-frac 0.75

 # on the second
 GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1 --disable-cuda-graph --mem-frac 0.75
-
-# single node run 405B fp8
-python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8
 ```

 ### Supported Models
@@ -422,6 +423,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instr
 - Mistral / Mixtral
 - Gemma / Gemma 2
 - Qwen / Qwen 2 / Qwen 2 MoE
+- DeepSeek / DeepSeek 2
 - LLaVA 1.5 / 1.6
   - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
   - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
@@ -442,7 +444,7 @@ Instructions for supporting a new model are [here](https://github.com/sgl-projec

 ### Benchmark Performance

-- Benchmark a single static batch
+- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as those for `launch_server.py`. This is not a dynamic batching server, so it may run out of memory for a batch size that can run successfully with a real server. This is because a real server will truncate the prefill into several batches/chunks, while this unit test does not do this.
 ```
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 32 --input-len 256 --output-len 32
 ```

{sglang-0.2.5 → sglang-0.2.6}/README.md

@@ -4,7 +4,7 @@

 --------------------------------------------------------------------------------

-| [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/) | [**Paper**](https://arxiv.org/abs/2312.07104) |
+| [**Blog**](https://lmsys.org/blog/2024-07-25-sglang-llama3/) | [**Paper**](https://arxiv.org/abs/2312.07104) | [**Slack**](https://join.slack.com/t/sgl-fru7574/shared_invite/zt-2ngly9muu-t37XiH87qvD~6rVBTkTEHw) |

 SGLang is a fast serving framework for large language models and vision language models.
 It makes your interaction with models faster and more controllable by co-designing the backend runtime and frontend language.
@@ -159,16 +159,17 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 ### Run Llama 3.1 405B

 ```bash
-
+## Run 405B (fp8) on a single node
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8
+
+## Run 405B (fp16) on two nodes
 # replace the `172.16.4.52:20000` with your own first node ip address and port, disable CUDA Graph temporarily
+
 # on the first node
 GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0 --disable-cuda-graph --mem-frac 0.75

 # on the second
 GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 1 --disable-cuda-graph --mem-frac 0.75
-
-# single node run 405B fp8
-python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct-FP8 --tp 8
 ```

 ### Supported Models
@@ -177,6 +178,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instr
 - Mistral / Mixtral
 - Gemma / Gemma 2
 - Qwen / Qwen 2 / Qwen 2 MoE
+- DeepSeek / DeepSeek 2
 - LLaVA 1.5 / 1.6
   - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.5-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
   - `python -m sglang.launch_server --model-path liuhaotian/llava-v1.6-vicuna-7b --tokenizer-path llava-hf/llava-1.5-7b-hf --chat-template vicuna_v1.1 --port 30000`
@@ -197,7 +199,7 @@ Instructions for supporting a new model are [here](https://github.com/sgl-projec

 ### Benchmark Performance

-- Benchmark a single static batch
+- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as those for `launch_server.py`. This is not a dynamic batching server, so it may run out of memory for a batch size that can run successfully with a real server. This is because a real server will truncate the prefill into several batches/chunks, while this unit test does not do this.
 ```
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 32 --input-len 256 --output-len 32
 ```

{sglang-0.2.5 → sglang-0.2.6}/pyproject.toml

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "sglang"
-version = "0.2.5"
+version = "0.2.6"
 description = "SGLang is yet another fast serving framework for large language models and vision language models."
 readme = "README.md"
 requires-python = ">=3.8"

{sglang-0.2.5 → sglang-0.2.6}/sglang/lang/backend/runtime_endpoint.py

@@ -253,14 +253,14 @@ class RuntimeEndpoint(BaseBackend):
            r["meta_info"]["normalized_prompt_logprob"] for r in obj
        ]
        decision = choices[np.argmax(normalized_prompt_logprobs)]
-        prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj]
-        decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj]
+        input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
+        output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]

        return (
            decision,
            normalized_prompt_logprobs,
-            prefill_token_logprobs,
-            decode_token_logprobs,
+            input_token_logprobs,
+            output_token_logprobs,
        )

    def concatenate_and_append(self, src_rids: List[str], dst_rid: str):

{sglang-0.2.5 → sglang-0.2.6}/sglang/lang/interpreter.py

@@ -541,16 +541,16 @@ class StreamExecutor:
        (
            decision,
            normalized_prompt_logprobs,
-            prefill_token_logprobs,
-            decode_token_logprobs,
+            input_token_logprobs,
+            output_token_logprobs,
        ) = self.backend.select(self, expr.choices, expr.temperature)
        if expr.name is not None:
            name = expr.name
            self.variables[name] = decision
            self.meta_info[name] = {
                "normalized_prompt_logprobs": normalized_prompt_logprobs,
-                "prefill_token_logprobs": prefill_token_logprobs,
-                "decode_token_logprobs": decode_token_logprobs,
+                "input_token_logprobs": input_token_logprobs,
+                "output_token_logprobs": output_token_logprobs,
            }
            self.variable_event[name].set()
        self.text_ += decision

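The two frontend hunks above only rename the logprob fields that `select` records in `meta_info`. Below is a hedged sketch of how the new key names surface to user code, assuming a locally running server and that the program state still exposes this metadata through `get_meta_info` as in earlier releases (the example program itself is illustrative, not taken from the diff):

```python
import sglang as sgl

# Assumption: an sglang server is already running locally on port 30000.
sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))


@sgl.function
def pick_tool(s, question):
    s += question
    s += sgl.select("tool", choices=["calculator", "search"])


state = pick_tool.run(question="What is 2 + 2? The best tool is: ")
info = state.get_meta_info("tool")  # assumed accessor, unchanged by this release
print(info["normalized_prompt_logprobs"])
print(info["input_token_logprobs"])   # named "prefill_token_logprobs" before 0.2.6
print(info["output_token_logprobs"])  # named "decode_token_logprobs" before 0.2.6
```
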
{sglang-0.2.5 → sglang-0.2.6}/sglang/srt/constrained/fsm_cache.py

@@ -21,7 +21,27 @@ class FSMCache(BaseCache):
            tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_path, **tokenizer_args_dict
            )
-            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+            try:
+                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+            except AttributeError:
+                # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
+                origin_pad_token_id = tokenizer.pad_token_id
+
+                def fset(self, value):
+                    self._value = value
+
+                type(tokenizer).pad_token_id = property(
+                    fget=type(tokenizer).pad_token_id.fget, fset=fset
+                )
+                self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+                self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
+                self.outlines_tokenizer.pad_token_id = origin_pad_token_id
+                self.outlines_tokenizer.pad_token = (
+                    self.outlines_tokenizer.tokenizer.pad_token
+                )
+                self.outlines_tokenizer.vocabulary = (
+                    self.outlines_tokenizer.tokenizer.get_vocab()
+                )
        else:
            self.outlines_tokenizer = TransformerTokenizer(
                tokenizer_path, **tokenizer_args_dict

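The `except AttributeError` branch above works around tokenizers whose `pad_token_id` is exposed as a read-only property, which outlines' `TransformerTokenizer` then tries to assign to. A minimal standalone sketch of that property-setter trick on a toy class (the class below is illustrative only, not the Hugging Face tokenizer API):

```python
class ToyTokenizer:
    """Stand-in for a tokenizer whose pad_token_id is a read-only property."""

    @property
    def pad_token_id(self):
        return getattr(self, "_value", 0)


tok = ToyTokenizer()
# tok.pad_token_id = 2  # would raise AttributeError: can't set attribute


def fset(self, value):
    self._value = value


# Re-install the property with the original getter plus a setter, as the hunk does.
type(tok).pad_token_id = property(fget=type(tok).pad_token_id.fget, fset=fset)
tok.pad_token_id = 2
print(tok.pad_token_id)  # 2
```
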
{sglang-0.2.5 → sglang-0.2.6}/sglang/srt/hf_transformers_utils.py

@@ -73,7 +73,9 @@ def get_context_length(config):
    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling:
        rope_scaling_factor = config.rope_scaling["factor"]
-        if
+        if "original_max_position_embeddings" in rope_scaling:
+            rope_scaling_factor = 1
+        if config.rope_scaling.get("rope_type", None) == "llama3":
            rope_scaling_factor = 1
    else:
        rope_scaling_factor = 1

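The hunk above only decides the effective rope scaling factor; how that factor feeds into the reported context length is sketched below as an assumption-level illustration (this is not the sglang function itself, just the factor logic from the hunk applied to `max_position_embeddings`):

```python
def sketch_context_length(max_position_embeddings, rope_scaling):
    # Mirror of the factor selection in the hunk above.
    rope_scaling_factor = 1
    if rope_scaling:
        rope_scaling_factor = rope_scaling["factor"]
        if "original_max_position_embeddings" in rope_scaling:
            rope_scaling_factor = 1
        if rope_scaling.get("rope_type", None) == "llama3":
            rope_scaling_factor = 1
    # Assumption: the factor multiplies the model's max_position_embeddings.
    return int(rope_scaling_factor * max_position_embeddings)


# Llama 3.1 style config: the factor is ignored, so the result stays at 131072.
print(sketch_context_length(131072, {"rope_type": "llama3", "factor": 8.0}))
# Yarn-style config carrying original_max_position_embeddings: factor ignored as well.
print(sketch_context_length(163840, {"factor": 40, "original_max_position_embeddings": 4096}))
```
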
{sglang-0.2.5 → sglang-0.2.6}/sglang/srt/layers/logits_processor.py

@@ -1,7 +1,7 @@
 """Logits processing."""

 import dataclasses
-from typing import List, Union
+from typing import List, Optional, Union

 import torch
 from torch import nn
@@ -22,23 +22,23 @@ class LogitProcessorOutput:

    # The normlaized logprobs of prompts. shape: [#seq]
    normalized_prompt_logprobs: torch.Tensor
-    # The logprobs of prefill tokens. shape: [#token, vocab_size]
-    prefill_token_logprobs: torch.Tensor
+    # The logprobs of input tokens. shape: [#token, vocab_size]
+    input_token_logprobs: torch.Tensor

-    # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    prefill_top_logprobs: List
-    # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
-    decode_top_logprobs: List
+    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    input_top_logprobs: List
+    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id)
+    output_top_logprobs: List


 @dataclasses.dataclass
 class LogitsMetadata:
    forward_mode: ForwardMode
-    return_logprob: bool
+    return_logprob: bool = False

-    extend_seq_lens: torch.Tensor = None
-    extend_start_loc: torch.Tensor = None
-    top_logprobs_nums: List[int] = None
+    extend_seq_lens: Optional[torch.Tensor] = None
+    extend_start_loc: Optional[torch.Tensor] = None
+    top_logprobs_nums: Optional[List[int]] = None

    @classmethod
    def from_input_metadata(cls, input_metadata: InputMetadata):
@@ -58,20 +58,16 @@ class LogitsProcessor(nn.Module):
        self.tp_size = get_tensor_model_parallel_world_size()

    def _get_normalized_prompt_logprobs(
-        self, prefill_token_logprobs, logits_metadata: LogitsMetadata
+        self, input_token_logprobs, logits_metadata: LogitsMetadata
    ):
-        logprobs_cumsum = torch.cumsum(
-            prefill_token_logprobs, dim=0, dtype=torch.float32
-        )
+        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)

        start = logits_metadata.extend_start_loc.clone()
        end = start + logits_metadata.extend_seq_lens - 2
-        start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
-        end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
+        end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
        sum_logp = (
-            logprobs_cumsum[end]
-            - logprobs_cumsum[start]
-            + prefill_token_logprobs[start]
+            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
        )
        normalized_prompt_logprobs = sum_logp / (
            (logits_metadata.extend_seq_lens - 1).clamp(min=1)
@@ -79,37 +75,38 @@ class LogitsProcessor(nn.Module):

        return normalized_prompt_logprobs

-    def _get_top_logprobs(self, all_logprobs, logits_metadata: LogitsMetadata):
+    @staticmethod
+    def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata):
        # TODO: vectorize the code below
        if logits_metadata.forward_mode == ForwardMode.DECODE:
-            decode_top_logprobs = []
+            output_top_logprobs = []
            for i in range(all_logprobs.shape[0]):
                k = logits_metadata.top_logprobs_nums[i]
                t = all_logprobs[i].topk(k)
                v_cpu = t.values.tolist()
                p_cpu = t.indices.tolist()
-                decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
-            return None, decode_top_logprobs
+                output_top_logprobs.append(list(zip(v_cpu, p_cpu)))
+            return None, output_top_logprobs
        else:
-            prefill_top_logprobs, decode_top_logprobs = [], []
+            input_top_logprobs, output_top_logprobs = [], []
            pt = 0
            extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist()
            for i, extend_seq_len in enumerate(extend_seq_lens_cpu):
                if extend_seq_len == 0:
-                    prefill_top_logprobs.append([])
-                    decode_top_logprobs.append([])
+                    input_top_logprobs.append([])
+                    output_top_logprobs.append([])
                    continue
                k = logits_metadata.top_logprobs_nums[i]
                t = all_logprobs[pt : pt + extend_seq_len].topk(k)
                vs_cpu = t.values.tolist()
                ps_cpu = t.indices.tolist()
-                prefill_top_logprobs.append(
+                input_top_logprobs.append(
                    [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
                )
-                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
+                output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
                pt += extend_seq_len

-            return prefill_top_logprobs, decode_top_logprobs
+            return input_top_logprobs, output_top_logprobs

    def forward(
        self,
@@ -136,7 +133,7 @@ class LogitsProcessor(nn.Module):
        last_logits = torch.matmul(last_hidden, weight.T)
        if self.tp_size > 1:
            last_logits = tensor_model_parallel_all_gather(last_logits)
-        last_logits = last_logits[:, : self.config.vocab_size]
+        last_logits = last_logits[:, : self.config.vocab_size].float()

        if hasattr(self.config, "final_logit_softcapping"):
            last_logits /= self.config.final_logit_softcapping
@@ -149,63 +146,75 @@ class LogitsProcessor(nn.Module):
                next_token_logits=last_logits,
                next_token_logprobs=None,
                normalized_prompt_logprobs=None,
-                prefill_token_logprobs=None,
-                prefill_top_logprobs=None,
-                decode_top_logprobs=None,
+                input_token_logprobs=None,
+                input_top_logprobs=None,
+                output_top_logprobs=None,
            )
        else:
            # When logprob is requested, compute the logits for all tokens.
            if logits_metadata.forward_mode == ForwardMode.DECODE:
-                all_logits = last_logits
-            else:
-                all_logits = torch.matmul(hidden_states, weight.T)
-                if self.tp_size > 1:
-                    all_logits = tensor_model_parallel_all_gather(all_logits)
-                all_logits = all_logits[:, : self.config.vocab_size]
-
-            all_logprobs = all_logits.float()
-            del all_logits
-            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+                last_logprobs = torch.nn.functional.log_softmax(last_logits, dim=-1)

-            # Get the logprob of top-k tokens
-            return_top_logprob = any(x > 0 for x in logits_metadata.top_logprobs_nums)
-            if return_top_logprob:
-                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
-                    all_logprobs, logits_metadata
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
                )
-            else:
-                prefill_top_logprobs = decode_top_logprobs = None
+                if return_top_logprob:
+                    output_top_logprobs = self.get_top_logprobs(
+                        last_logprobs, logits_metadata
+                    )[1]
+                else:
+                    output_top_logprobs = None

-            if logits_metadata.forward_mode == ForwardMode.DECODE:
                return LogitProcessorOutput(
                    next_token_logits=last_logits,
-                    next_token_logprobs=
+                    next_token_logprobs=last_logprobs,
                    normalized_prompt_logprobs=None,
-                    prefill_token_logprobs=None,
-                    prefill_top_logprobs=None,
-                    decode_top_logprobs=decode_top_logprobs,
+                    input_token_logprobs=None,
+                    input_top_logprobs=None,
+                    output_top_logprobs=output_top_logprobs,
                )
            else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size].float()
+
+                all_logprobs = all_logits
+                del all_logits
+                all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+                # Get the logprob of top-k tokens
+                return_top_logprob = any(
+                    x > 0 for x in logits_metadata.top_logprobs_nums
+                )
+                if return_top_logprob:
+                    input_top_logprobs, output_top_logprobs = self.get_top_logprobs(
+                        all_logprobs, logits_metadata
+                    )
+                else:
+                    input_top_logprobs = output_top_logprobs = None
+
                last_logprobs = all_logprobs[last_index]

                # Compute the logprobs and normalized logprobs for the prefill tokens.
                # Note that we pad a zero at the end of each sequence for easy computation.
-                prefill_token_logprobs = all_logprobs[
+                input_token_logprobs = all_logprobs[
                    torch.arange(all_logprobs.shape[0], device="cuda"),
                    torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
                ]

                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                    prefill_token_logprobs, logits_metadata
+                    input_token_logprobs, logits_metadata
                )

                return LogitProcessorOutput(
                    next_token_logits=last_logits,
                    next_token_logprobs=last_logprobs,
                    normalized_prompt_logprobs=normalized_prompt_logprobs,
-                    prefill_token_logprobs=prefill_token_logprobs,
-                    prefill_top_logprobs=prefill_top_logprobs,
-                    decode_top_logprobs=decode_top_logprobs,
+                    input_token_logprobs=input_token_logprobs,
+                    input_top_logprobs=input_top_logprobs,
+                    output_top_logprobs=output_top_logprobs,
                )

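For reference, the cumulative-sum arithmetic used by `_get_normalized_prompt_logprobs` above can be checked in isolation. The snippet below is a self-contained illustration with made-up values (two requests of three prompt tokens each), not sglang code:

```python
import torch

# Per-token logprobs for two concatenated requests, flattened to shape [#token].
input_token_logprobs = torch.tensor([-0.5, -1.0, -0.25, -2.0, -0.75, -1.5])
extend_start_loc = torch.tensor([0, 3])  # where each request starts
extend_seq_lens = torch.tensor([3, 3])   # prompt length of each request

logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
start = extend_start_loc.clone()
end = start + extend_seq_lens - 2
start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)
end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1)

# Sum of each request's logprobs over [start, end] without a Python loop.
sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
normalized = sum_logp / (extend_seq_lens - 1).clamp(min=1)
print(normalized)  # tensor([-0.7500, -1.3750])
```
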
{sglang-0.2.5 → sglang-0.2.6}/sglang/srt/layers/radix_attention.py

@@ -7,8 +7,11 @@ from torch import nn
 from sglang.global_config import global_config
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetadata
-
+from sglang.srt.managers.controller.model_runner import (
+    ForwardMode,
+    InputMetadata,
+    global_server_args_dict,
+)


 class RadixAttention(nn.Module):

{sglang-0.2.5 → sglang-0.2.6}/sglang/srt/layers/token_attention.py

@@ -5,7 +5,7 @@ import torch
 import triton
 import triton.language as tl

-from sglang.srt.
+from sglang.srt.managers.controller.infer_batch import global_server_args_dict

 if global_server_args_dict.get("attention_reduce_in_fp32", False):
    REDUCE_TRITON_TYPE = tl.float32

{sglang-0.2.5 → sglang-0.2.6}/sglang/srt/managers/controller/cuda_graph_runner.py

@@ -9,7 +9,11 @@ from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

-from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.layers.logits_processor import (
+    LogitProcessorOutput,
+    LogitsMetadata,
+    LogitsProcessor,
+)
 from sglang.srt.managers.controller.infer_batch import (
    Batch,
    ForwardMode,
@@ -185,7 +189,6 @@ class CudaGraphRunner:

    def replay(self, batch: Batch):
        assert batch.out_cache_loc is not None
-        assert not batch.return_logprob
        raw_bs = len(batch.reqs)

        # Pad
@@ -218,23 +221,29 @@ class CudaGraphRunner:
        output = self.output_buffers[bs]

        # Unpad
-        if bs == raw_bs:
-            return output
-        else:
+        if bs != raw_bs:
            output = LogitProcessorOutput(
                next_token_logits=output.next_token_logits[:raw_bs],
-                next_token_logprobs=(
-                    output.next_token_logprobs[:raw_bs]
-                    if output.next_token_logprobs is not None
-                    else None
-                ),
+                next_token_logprobs=None,
                normalized_prompt_logprobs=None,
-                prefill_token_logprobs=None,
-                prefill_top_logprobs=None,
-                decode_top_logprobs=(
-                    output.decode_top_logprobs[:raw_bs]
-                    if output.decode_top_logprobs is not None
-                    else None
-                ),
+                input_token_logprobs=None,
+                input_top_logprobs=None,
+                output_top_logprobs=None,
            )
+
+        # Extract logprobs
+        if batch.return_logprob:
+            output.next_token_logprobs = torch.nn.functional.log_softmax(
+                output.next_token_logits, dim=-1
+            )
+            return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
+            if return_top_logprob:
+                logits_metadata = LogitsMetadata(
+                    forward_mode=ForwardMode.DECODE,
+                    top_logprobs_nums=batch.top_logprobs_nums,
+                )
+                output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    output.next_token_logprobs, logits_metadata
+                )[1]
+
        return output

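Because the captured CUDA graph itself only produces next-token logits, the `replay` changes above recompute logprobs outside the graph when a request asks for them. A rough standalone illustration of that post-replay step (tensor shapes and the `top_logprobs_nums` values are made up):

```python
import torch

next_token_logits = torch.randn(4, 32000)  # stand-in for the replayed [batch, vocab] logits
next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)

top_logprobs_nums = [0, 5, 2, 0]  # per-request top-k; 0 means not requested
output_top_logprobs = []
for i, k in enumerate(top_logprobs_nums):
    if k > 0:
        t = next_token_logprobs[i].topk(k)
        output_top_logprobs.append(list(zip(t.values.tolist(), t.indices.tolist())))
    else:
        output_top_logprobs.append([])
```
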
{sglang-0.2.5 → sglang-0.2.6}/sglang/srt/managers/controller/infer_batch.py

@@ -17,6 +17,13 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

+# Put some global args for easy access
+global_server_args_dict = {
+    "disable_flashinfer": False,
+    "disable_flashinfer_sampling": False,
+    "attention_reduce_in_fp32": False,
+}
+

 class ForwardMode(IntEnum):
    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -124,10 +131,10 @@ class Req:
        self.logprob_start_len = 0
        self.top_logprobs_num = 0
        self.normalized_prompt_logprob = None
-        self.prefill_token_logprobs = None
-        self.prefill_top_logprobs = None
-        self.decode_token_logprobs = []
-        self.decode_top_logprobs = []
+        self.input_token_logprobs = None
+        self.input_top_logprobs = None
+        self.output_token_logprobs = []
+        self.output_top_logprobs = []
        # The tokens is prefilled but need to be considered as decode tokens
        # and should be updated for the decode logprobs
        self.last_update_decode_tokens = 0
@@ -244,8 +251,8 @@ class Req:
                k = k + 1
            else:
                break
-        self.decode_token_logprobs = self.decode_token_logprobs[:k]
-        self.decode_top_logprobs = self.decode_top_logprobs[:k]
+        self.output_token_logprobs = self.output_token_logprobs[:k]
+        self.output_top_logprobs = self.output_top_logprobs[:k]
        self.logprob_start_len = prompt_tokens + k
        self.last_update_decode_tokens = len(self.output_ids) - k

@@ -376,7 +383,7 @@ class Batch:
                    logit_bias = torch.zeros(
                        (bs, vocab_size), dtype=torch.float32, device=device
                    )
-                logit_bias[i] = int_token_logit_bias
+                logit_bias[i][: len(int_token_logit_bias)] = int_token_logit_bias

        # Set fields
        self.input_ids = torch.tensor(
@@ -687,13 +694,21 @@ class Batch:
        # TODO(lmzheng): apply penalty
        probs = torch.softmax(logits, dim=-1)

-        max_top_k_round, batch_size = 32, probs.shape[0]
-        uniform_samples = torch.rand((max_top_k_round, batch_size), device=probs.device)
-        batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
-            probs, uniform_samples, self.top_ks, self.top_ps
-        )
+        if not global_server_args_dict["disable_flashinfer_sampling"]:
+            max_top_k_round, batch_size = 32, probs.shape[0]
+            uniform_samples = torch.rand(
+                (max_top_k_round, batch_size), device=probs.device
+            )
+            batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                probs, uniform_samples, self.top_ks, self.top_ps
+            )
+        else:
+            # Here we provide a slower fallback implementation.
+            batch_next_token_ids, success = top_k_top_p_sampling_from_probs_torch(
+                probs, self.top_ks, self.top_ps
+            )

-        if torch.
+        if not torch.all(success):
            warnings.warn("Sampling failed, fallback to top_k=1 strategy")
            probs = probs.masked_fill(torch.isnan(probs), 0.0)
            argmax_ids = torch.argmax(probs, dim=-1)
@@ -933,3 +948,29 @@ def init_triton_args(forward_mode, seq_lens, prefix_lens):
    max_extend_len = int(torch.max(extend_seq_lens))

    return max_seq_len, max_extend_len, start_loc, prefix_lens
+
+
+def top_k_top_p_sampling_from_probs_torch(
+    probs: torch.Tensor, top_ks: torch.Tensor, top_ps: torch.Tensor
+):
+    """A top-k and top-k sampling implementation with native pytorch operations."""
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort[
+        torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
+        >= top_ks.view(-1, 1)
+    ] = 0.0
+    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    try:
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    except RuntimeError:
+        batch_next_token_ids = torch.zeros(
+            (probs_sort.shape[0],), dtype=torch.int64, device=probs.device
+        )
+        success = torch.zeros(probs.shape[0], dtype=torch.bool, device=probs.device)
+        return batch_next_token_ids, success
+
+    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
+    success = torch.ones(probs.shape[0], dtype=torch.bool, device=probs.device)
+    return batch_next_token_ids, success
