sglang 0.2.14__tar.gz → 0.2.14.post2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {sglang-0.2.14/sglang.egg-info → sglang-0.2.14.post2}/PKG-INFO +11 -5
- {sglang-0.2.14 → sglang-0.2.14.post2}/README.md +10 -4
- {sglang-0.2.14 → sglang-0.2.14.post2}/pyproject.toml +1 -1
- sglang-0.2.14.post2/sglang/launch_server_llavavid.py +26 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/constrained/fsm_cache.py +11 -2
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/constrained/jump_forward.py +1 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/hf_transformers_utils.py +0 -149
- sglang-0.2.14.post2/sglang/srt/layers/activation.py +137 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/layers/layernorm.py +47 -4
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/layers/logits_processor.py +4 -4
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/layers/sampler.py +15 -68
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/managers/io_struct.py +5 -4
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/managers/schedule_batch.py +20 -25
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/managers/tokenizer_manager.py +74 -61
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/managers/tp_worker.py +49 -43
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/model_executor/cuda_graph_runner.py +17 -31
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/model_executor/forward_batch_info.py +9 -26
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/model_executor/model_runner.py +20 -17
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/chatglm.py +13 -5
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/commandr.py +1 -5
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/dbrx.py +1 -5
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/deepseek.py +1 -5
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/deepseek_v2.py +1 -5
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/gemma.py +3 -7
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/gemma2.py +2 -56
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/gpt_bigcode.py +2 -6
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/grok.py +10 -8
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/internlm2.py +1 -5
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/llama2.py +6 -11
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/llama_classification.py +2 -6
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/llama_embedding.py +3 -4
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/llava.py +69 -91
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/llavavid.py +40 -86
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/minicpm.py +1 -5
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/mixtral.py +1 -5
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/mixtral_quant.py +1 -5
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/qwen.py +2 -5
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/qwen2.py +5 -10
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/qwen2_moe.py +21 -24
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/stablelm.py +1 -5
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/yivl.py +2 -7
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/openai_api/adapter.py +85 -4
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/openai_api/protocol.py +2 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/sampling/sampling_batch_info.py +1 -74
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/sampling/sampling_params.py +4 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/server.py +11 -4
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/utils.py +18 -33
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/test/runners.py +2 -2
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/test/test_layernorm.py +53 -1
- sglang-0.2.14.post2/sglang/version.py +1 -0
- {sglang-0.2.14 → sglang-0.2.14.post2/sglang.egg-info}/PKG-INFO +11 -5
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang.egg-info/SOURCES.txt +1 -0
- sglang-0.2.14/sglang/srt/layers/activation.py +0 -55
- sglang-0.2.14/sglang/version.py +0 -1
- {sglang-0.2.14 → sglang-0.2.14.post2}/LICENSE +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/setup.cfg +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/__init__.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/api.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/bench_latency.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/bench_serving.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/check_env.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/global_config.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/lang/__init__.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/lang/backend/__init__.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/lang/backend/anthropic.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/lang/backend/base_backend.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/lang/backend/litellm.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/lang/backend/openai.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/lang/backend/runtime_endpoint.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/lang/backend/vertexai.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/lang/chat_template.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/lang/choices.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/lang/compiler.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/lang/interpreter.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/lang/ir.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/lang/tracer.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/launch_server.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/constrained/__init__.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/constrained/base_tool_cache.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/conversation.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/layers/decode_attention.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/layers/extend_attention.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/layers/fused_moe/__init__.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/layers/fused_moe/fused_moe.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/layers/fused_moe/layer.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/layers/pooler.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/layers/prefill_attention.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/layers/radix_attention.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/managers/controller_multi.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/managers/controller_single.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/managers/detokenizer_manager.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/managers/policy_scheduler.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/mem_cache/base_prefix_cache.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/mem_cache/chunk_cache.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/mem_cache/flush_cache.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/mem_cache/memory_pool.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/mem_cache/radix_cache.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/mm_utils.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/model_config.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/models/mistral.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/sampling/penaltylib/__init__.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/sampling/penaltylib/orchestrator.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/server_args.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/test/run_eval.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/test/simple_eval_common.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/test/simple_eval_gpqa.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/test/simple_eval_humaneval.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/test/simple_eval_math.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/test/simple_eval_mgsm.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/test/simple_eval_mmlu.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/test/srt/sampling/penaltylib/utils.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/test/test_activation.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/test/test_programs.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/test/test_utils.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang/utils.py +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang.egg-info/dependency_links.txt +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang.egg-info/requires.txt +0 -0
- {sglang-0.2.14 → sglang-0.2.14.post2}/sglang.egg-info/top_level.txt +0 -0
{sglang-0.2.14/sglang.egg-info → sglang-0.2.14.post2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sglang
-Version: 0.2.14
+Version: 0.2.14.post2
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
                            Version 2.0, January 2004
@@ -312,7 +312,7 @@ pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ### Method 2: From source
 ```
 # Use the last release branch
-git clone -b v0.2.14 https://github.com/sgl-project/sglang.git
+git clone -b v0.2.14.post2 https://github.com/sgl-project/sglang.git
 cd sglang
 
 pip install --upgrade pip
@@ -339,6 +339,7 @@ docker run --gpus all \
 ### Method 4: Using docker compose
 
 <details>
+<summary>More</summary>
 
 > This method is recommended if you plan to serve it as a service.
 > A better approach is to use the [k8s-sglang-service.yaml](./docker/k8s-sglang-service.yaml).
@@ -350,6 +351,7 @@ docker run --gpus all \
 ### Method 5: Run on Kubernetes or Clouds with SkyPilot
 
 <details>
+<summary>More</summary>
 
 To deploy on Kubernetes or 12+ clouds, you can use [SkyPilot](https://github.com/skypilot-org/skypilot).
 
@@ -389,7 +391,7 @@ sky status --endpoint 30000 sglang
 
 
 ### Common Notes
-- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang.
+- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue.
 - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
 
 ## Backend: SGLang Runtime (SRT)
@@ -494,7 +496,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 - Qwen / Qwen 2 / Qwen 2 MoE
 - DeepSeek / DeepSeek 2
 - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
-  - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava
+  - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava`
   - Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py)
 - LLaVA 1.5 / 1.6 / NeXT
   - `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --tp-size=1 --chat-template=llava_llama_3`
@@ -518,6 +520,7 @@ Instructions for supporting a new model are [here](https://github.com/sgl-projec
 
 #### Use Models From ModelScope
 <details>
+<summary>More</summary>
 
 To use a model from [ModelScope](https://www.modelscope.cn), set the environment variable SGLANG_USE_MODELSCOPE.
 ```
@@ -532,6 +535,7 @@ SGLANG_USE_MODELSCOPE=true python -m sglang.launch_server --model-path qwen/Qwen
 
 #### Run Llama 3.1 405B
 <details>
+<summary>More</summary>
 
 ```bash
 # Run 405B (fp8) on a single node
@@ -549,7 +553,9 @@ GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/
 
 ### Benchmark Performance
 
-- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as for `launch_server.py`.
+- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as for `launch_server.py`.
+  Note that this is not a dynamic batching server, so it may run out of memory for a batch size that a real server can handle.
+  A real server truncates the prefill into several batches, while this unit test does not. For accurate large batch testing, please use `sglang.bench_serving` instead.
 ```
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 32 --input-len 256 --output-len 32
 ```

{sglang-0.2.14 → sglang-0.2.14.post2}/README.md
@@ -56,7 +56,7 @@ pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
 ### Method 2: From source
 ```
 # Use the last release branch
-git clone -b v0.2.14 https://github.com/sgl-project/sglang.git
+git clone -b v0.2.14.post2 https://github.com/sgl-project/sglang.git
 cd sglang
 
 pip install --upgrade pip
@@ -83,6 +83,7 @@ docker run --gpus all \
 ### Method 4: Using docker compose
 
 <details>
+<summary>More</summary>
 
 > This method is recommended if you plan to serve it as a service.
 > A better approach is to use the [k8s-sglang-service.yaml](./docker/k8s-sglang-service.yaml).
@@ -94,6 +95,7 @@ docker run --gpus all \
 ### Method 5: Run on Kubernetes or Clouds with SkyPilot
 
 <details>
+<summary>More</summary>
 
 To deploy on Kubernetes or 12+ clouds, you can use [SkyPilot](https://github.com/skypilot-org/skypilot).
 
@@ -133,7 +135,7 @@ sky status --endpoint 30000 sglang
 
 
 ### Common Notes
-- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang.
+- [FlashInfer](https://github.com/flashinfer-ai/flashinfer) is currently one of the dependencies that must be installed for SGLang. It only supports sm75 and above. If you encounter any FlashInfer-related issues on sm75+ devices (e.g., T4, A10, A100, L4, L40S, H100), consider using Triton's kernel by `--disable-flashinfer --disable-flashinfer-sampling` and raise an issue.
 - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
 
 ## Backend: SGLang Runtime (SRT)
@@ -238,7 +240,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 - Qwen / Qwen 2 / Qwen 2 MoE
 - DeepSeek / DeepSeek 2
 - [LLaVA-OneVision](https://llava-vl.github.io/blog/2024-08-05-llava-onevision/)
-  - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava
+  - `python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava`
   - Query the server with the [OpenAI Vision API](https://platform.openai.com/docs/guides/vision). See examples at [test/srt/test_vision_openai_server.py](test/srt/test_vision_openai_server.py)
 - LLaVA 1.5 / 1.6 / NeXT
   - `python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000 --tp-size=1 --chat-template=llava_llama_3`
@@ -262,6 +264,7 @@ Instructions for supporting a new model are [here](https://github.com/sgl-projec
 
 #### Use Models From ModelScope
 <details>
+<summary>More</summary>
 
 To use a model from [ModelScope](https://www.modelscope.cn), set the environment variable SGLANG_USE_MODELSCOPE.
 ```
@@ -276,6 +279,7 @@ SGLANG_USE_MODELSCOPE=true python -m sglang.launch_server --model-path qwen/Qwen
 
 #### Run Llama 3.1 405B
 <details>
+<summary>More</summary>
 
 ```bash
 # Run 405B (fp8) on a single node
@@ -293,7 +297,9 @@ GLOO_SOCKET_IFNAME=eth0 python3 -m sglang.launch_server --model-path meta-llama/
 
 ### Benchmark Performance
 
-- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as for `launch_server.py`.
+- Benchmark a single static batch by running the following command without launching a server. The arguments are the same as for `launch_server.py`.
+  Note that this is not a dynamic batching server, so it may run out of memory for a batch size that a real server can handle.
+  A real server truncates the prefill into several batches, while this unit test does not. For accurate large batch testing, please use `sglang.bench_serving` instead.
 ```
 python -m sglang.bench_latency --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 32 --input-len 256 --output-len 32
 ```

{sglang-0.2.14 → sglang-0.2.14.post2}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sglang"
-version = "0.2.14"
+version = "0.2.14.post2"
 description = "SGLang is yet another fast serving framework for large language models and vision language models."
 readme = "README.md"
 requires-python = ">=3.8"

sglang-0.2.14.post2/sglang/launch_server_llavavid.py (new file)
@@ -0,0 +1,26 @@
+"""Launch the inference server for Llava-video model."""
+
+import argparse
+
+from sglang.srt.server import ServerArgs, launch_server
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+
+    model_overide_args = {}
+    model_overide_args["mm_spatial_pool_stride"] = 2
+    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
+    model_overide_args["num_frames"] = 16
+    model_overide_args["model_type"] = "llavavid"
+    if model_overide_args["num_frames"] == 32:
+        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+        model_overide_args["max_sequence_length"] = 4096 * 2
+        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
+        model_overide_args["model_max_length"] = 4096 * 2
+    if "34b" in args.model_path.lower():
+        model_overide_args["image_token_index"] = 64002
+
+    launch_server(server_args, model_overide_args, None)

{sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/constrained/fsm_cache.py
@@ -15,6 +15,8 @@ limitations under the License.
 
 """Cache for the compressed finite state machine."""
 
+from outlines.fsm.json_schema import build_regex_from_schema
+
 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
 
@@ -26,9 +28,12 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
+        json_schema_mode=False,
     ):
         super().__init__(enable=enable)
 
+        self.json_schema_mode = json_schema_mode
+
         if (
             skip_tokenizer_init
             or tokenizer_path.endswith(".json")
@@ -72,5 +77,9 @@ class FSMCache(BaseToolCache):
             tokenizer_path, **tokenizer_args_dict
         )
 
-    def init_value(self, regex):
-        return RegexGuide(regex, self.outlines_tokenizer)
+    def init_value(self, value):
+        if self.json_schema_mode:
+            regex = build_regex_from_schema(value)
+            return RegexGuide(regex, self.outlines_tokenizer), regex
+        else:
+            return RegexGuide(value, self.outlines_tokenizer)
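
The `json_schema_mode` branch added above first converts a JSON schema into a regular expression and only then compiles it into a `RegexGuide`. The sketch below illustrates that first step in isolation, using the same `build_regex_from_schema` helper that the new import pulls from outlines; it is a minimal example with a made-up schema, not code from the package.

```python
import json

# Same helper that fsm_cache.py now imports.
from outlines.fsm.json_schema import build_regex_from_schema

# Hypothetical schema used only for illustration.
schema = json.dumps(
    {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }
)

# The schema becomes a regex string; in json_schema_mode, FSMCache.init_value
# wraps this regex in a RegexGuide and returns the regex alongside the guide.
regex = build_regex_from_schema(schema)
print(regex[:80], "...")
```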

{sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/hf_transformers_utils.py
@@ -119,24 +119,7 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-    if tokenizer_name.endswith(".json"):
-        return TiktokenTokenizer(tokenizer_name)
-
-    if tokenizer_name.endswith(".model"):
-        return SentencePieceTokenizer(tokenizer_name)
-
     """Gets a tokenizer for the given model name via Huggingface."""
-    if is_multimodal_model(tokenizer_name):
-        processor = get_processor(
-            tokenizer_name,
-            *args,
-            trust_remote_code=trust_remote_code,
-            tokenizer_revision=tokenizer_revision,
-            **kwargs,
-        )
-        tokenizer = processor.tokenizer
-        return tokenizer
-
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
@@ -199,135 +182,3 @@ def get_processor(
         **kwargs,
     )
     return processor
-
-
-class TiktokenTokenizer:
-    def __init__(self, tokenizer_path):
-        import tiktoken
-        from jinja2 import Template
-
-        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
-
-        # Read JSON
-        name = "tmp-json"
-        with open(tokenizer_path, "rb") as fin:
-            tok_dict = json.load(fin)
-
-        mergeable_ranks = {
-            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
-        }
-        special_tokens = {
-            bytes(item["bytes"]).decode(): item["token"]
-            for item in tok_dict["special_tokens"]
-        }
-        assert tok_dict["word_split"] == "V1"
-
-        default_allowed_special = None
-
-        kwargs = {
-            "name": name,
-            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
-            "mergeable_ranks": mergeable_ranks,
-            "special_tokens": special_tokens,
-        }
-        if "default_allowed_special" in tok_dict:
-            default_allowed_special = set(
-                [
-                    bytes(bytes_list).decode()
-                    for bytes_list in tok_dict["default_allowed_special"]
-                ]
-            )
-        if "vocab_size" in tok_dict:
-            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
-
-        PAD = "<|pad|>"
-        EOS = "<|eos|>"
-        SEP = "<|separator|>"
-
-        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
-
-        tokenizer = tiktoken.Encoding(**kwargs)
-        tokenizer._default_allowed_special = default_allowed_special or set()
-        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS
-
-        def encode_patched(
-            self,
-            text: str,
-            *,
-            allowed_special: Union[
-                Literal["all"], AbstractSet[str]
-            ] = set(),  # noqa: B006
-            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-        ) -> List[int]:
-            if isinstance(allowed_special, set):
-                allowed_special |= self._default_allowed_special
-            return tiktoken.Encoding.encode(
-                self,
-                text,
-                allowed_special=allowed_special,
-                disallowed_special=(),
-            )
-
-        tokenizer.encode = functools.partial(encode_patched, tokenizer)
-
-        # Convert to HF interface
-        self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer._special_tokens[EOS]
-        self.vocab_size = tokenizer.n_vocab
-        self.chat_template = Template(
-            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
-        )
-
-    def encode(self, x, add_special_tokens=False):
-        return self.tokenizer.encode(x)
-
-    def decode(self, x):
-        return self.tokenizer.decode(x)
-
-    def batch_decode(
-        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
-    ):
-        if isinstance(batch[0], int):
-            batch = [[x] for x in batch]
-        return self.tokenizer.decode_batch(batch)
-
-    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(
-            messages=messages, add_generation_prompt=add_generation_prompt
-        )
-        return self.encode(ret) if tokenize else ret
-
-
-class SentencePieceTokenizer:
-    def __init__(self, tokenizer_path):
-        import sentencepiece as spm
-        from jinja2 import Template
-
-        tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
-
-        # Convert to HF interface
-        self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer.eos_id()
-        self.vocab_size = tokenizer.vocab_size()
-        self.chat_template = Template(
-            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
-        )
-
-    def encode(self, x, add_special_tokens=False):
-        return self.tokenizer.encode(x)
-
-    def decode(self, x):
-        return self.tokenizer.decode(x)
-
-    def batch_decode(
-        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
-    ):
-        if isinstance(batch[0], int):
-            batch = [[x] for x in batch]
-        return self.tokenizer.decode(batch)
-
-    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(
-            messages=messages, add_generation_prompt=add_generation_prompt
-        )
-        return self.encode(ret) if tokenize else ret

sglang-0.2.14.post2/sglang/srt/layers/activation.py (new file)
@@ -0,0 +1,137 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Fused operators for activation layers."""
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
+from vllm.distributed import (
+    divide,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.utils import set_weight_attrs
+
+
+class SiluAndMul(CustomOp):
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.silu(x[..., :d]) * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        silu_and_mul(x, out)
+        return out
+
+
+class GeluAndMul(CustomOp):
+    def __init__(self, approximate="tanh"):
+        super().__init__()
+        self.approximate = approximate
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        if self.approximate == "tanh":
+            gelu_tanh_and_mul(x, out)
+        elif self.approximate == "none":
+            gelu_and_mul(x, out)
+        else:
+            raise RuntimeError("GeluAndMul only support tanh or none")
+        return out
+
+
+class ScaledActivation(nn.Module):
+    """An activation function with post-scale parameters.
+
+    This is used for some quantization methods like AWQ.
+    """
+
+    def __init__(
+        self,
+        act_module: nn.Module,
+        intermediate_size: int,
+        input_is_parallel: bool = True,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.act = act_module
+        self.input_is_parallel = input_is_parallel
+        if input_is_parallel:
+            tp_size = get_tensor_model_parallel_world_size()
+            intermediate_size_per_partition = divide(intermediate_size, tp_size)
+        else:
+            intermediate_size_per_partition = intermediate_size
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.scales = nn.Parameter(
+            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
+        )
+        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.act(x) / self.scales
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        param_data = param.data
+        if self.input_is_parallel:
+            tp_rank = get_tensor_model_parallel_rank()
+            shard_size = param_data.shape[0]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+
+_ACTIVATION_REGISTRY = {
+    "gelu": nn.GELU(),
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+}
+
+
+def get_act_fn(
+    act_fn_name: str,
+    quant_config: Optional[QuantizationConfig] = None,
+    intermediate_size: Optional[int] = None,
+    input_is_parallel: bool = True,
+    params_dtype: Optional[torch.dtype] = None,
+) -> nn.Module:
+    """Get an activation function by name."""
+    act_fn_name = act_fn_name.lower()
+    if act_fn_name not in _ACTIVATION_REGISTRY:
+        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
+
+    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
+    if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names():
+        if intermediate_size is None:
+            raise ValueError(
+                "intermediate_size must be specified for scaled "
+                "activation functions."
+            )
+        return ScaledActivation(
+            act_fn, intermediate_size, input_is_parallel, params_dtype
+        )
+    return act_fn
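
For orientation, both `forward_native` methods above implement the usual gated-activation pattern: split the last dimension in half, apply the activation to the first half, and multiply elementwise by the second half; the `forward_cuda` variants delegate the same computation to flashinfer kernels. A minimal reference in plain PyTorch (the shapes are made up for the example):

```python
import torch
import torch.nn.functional as F

# A fake [num_tokens, 2 * intermediate_size] gate_up projection output.
x = torch.randn(4, 2 * 11008)
d = x.shape[-1] // 2

# Equivalent of SiluAndMul.forward_native.
silu_and_mul_ref = F.silu(x[..., :d]) * x[..., d:]

# Equivalent of GeluAndMul(approximate="tanh").forward_native.
gelu_and_mul_ref = F.gelu(x[..., :d], approximate="tanh") * x[..., d:]

print(silu_and_mul_ref.shape, gelu_and_mul_ref.shape)  # torch.Size([4, 11008]) twice
```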

{sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/layers/layernorm.py
@@ -19,7 +19,12 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+from flashinfer.norm import (
+    fused_add_rmsnorm,
+    gemma_fused_add_rmsnorm,
+    gemma_rmsnorm,
+    rmsnorm,
+)
 from vllm.model_executor.custom_op import CustomOp
 
 
@@ -32,15 +37,12 @@ class RMSNorm(CustomOp):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
-        self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
 
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if self.is_lower_sm80:
-            return self.forward_native(x, residual)
 
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
@@ -66,3 +68,44 @@ class RMSNorm(CustomOp):
             return x
         else:
             return x, residual
+
+
+class GemmaRMSNorm(CustomOp):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x
+
+        x = x.float()
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            gemma_fused_add_rmsnorm(
+                x, residual, self.weight.data, self.variance_epsilon
+            )
+            return x, residual
+        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
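
The new `GemmaRMSNorm` differs from the `RMSNorm` above in two details that are easy to miss in the diff: its weight is initialized to zeros and applied as `(1 + weight)`, and the normalization is computed in float32 before casting back to the input dtype. A plain-PyTorch restatement of the `forward_native` math (a sketch; the tensor shapes are made up):

```python
import torch


def gemma_rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Mirrors GemmaRMSNorm.forward_native: fp32 statistics, (1 + weight) scaling.
    orig_dtype = x.dtype
    xf = x.float()
    variance = xf.pow(2).mean(dim=-1, keepdim=True)
    xf = xf * torch.rsqrt(variance + eps)
    xf = xf * (1.0 + weight.float())
    return xf.to(orig_dtype)


hidden_size = 2048
x = torch.randn(3, hidden_size, dtype=torch.bfloat16)
weight = torch.zeros(hidden_size)  # zero-initialized, so the layer starts as a plain RMSNorm
print(gemma_rmsnorm_ref(x, weight).shape)  # torch.Size([3, 2048])
```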

{sglang-0.2.14 → sglang-0.2.14.post2}/sglang/srt/layers/logits_processor.py
@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 
 
 @dataclasses.dataclass
-class LogitsProcessorOutput:
+class LogitProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
@@ -185,7 +185,7 @@ class LogitsProcessor(nn.Module):
 
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
-            return LogitsProcessorOutput(
+            return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
             else:
                 output_top_logprobs = None
 
-            return LogitsProcessorOutput(
+            return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=None,
@@ -278,7 +278,7 @@ class LogitsProcessor(nn.Module):
             # Remove the last token logprob for the prefill tokens.
             input_token_logprobs = input_token_logprobs[:-1]
 
-            return LogitsProcessorOutput(
+            return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=normalized_prompt_logprobs,