sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +28 -10
 - sglang/bench_server_latency.py +21 -10
 - sglang/bench_serving.py +101 -7
 - sglang/global_config.py +0 -1
 - sglang/srt/layers/attention/__init__.py +27 -5
 - sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
 - sglang/srt/layers/attention/flashinfer_backend.py +352 -83
 - sglang/srt/layers/attention/triton_backend.py +6 -4
 - sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
 - sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
 - sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
 - sglang/srt/layers/sampler.py +6 -2
 - sglang/srt/managers/detokenizer_manager.py +31 -10
 - sglang/srt/managers/io_struct.py +4 -0
 - sglang/srt/managers/schedule_batch.py +120 -43
 - sglang/srt/managers/schedule_policy.py +2 -1
 - sglang/srt/managers/scheduler.py +202 -140
 - sglang/srt/managers/tokenizer_manager.py +5 -1
 - sglang/srt/managers/tp_worker.py +111 -1
 - sglang/srt/mem_cache/chunk_cache.py +8 -4
 - sglang/srt/mem_cache/memory_pool.py +77 -4
 - sglang/srt/mem_cache/radix_cache.py +15 -7
 - sglang/srt/model_executor/cuda_graph_runner.py +4 -4
 - sglang/srt/model_executor/forward_batch_info.py +16 -21
 - sglang/srt/model_executor/model_runner.py +60 -1
 - sglang/srt/models/baichuan.py +2 -3
 - sglang/srt/models/chatglm.py +5 -6
 - sglang/srt/models/commandr.py +1 -2
 - sglang/srt/models/dbrx.py +1 -2
 - sglang/srt/models/deepseek.py +4 -5
 - sglang/srt/models/deepseek_v2.py +5 -6
 - sglang/srt/models/exaone.py +1 -2
 - sglang/srt/models/gemma.py +2 -2
 - sglang/srt/models/gemma2.py +5 -5
 - sglang/srt/models/gpt_bigcode.py +5 -5
 - sglang/srt/models/grok.py +1 -2
 - sglang/srt/models/internlm2.py +1 -2
 - sglang/srt/models/llama.py +1 -2
 - sglang/srt/models/llama_classification.py +1 -2
 - sglang/srt/models/llama_reward.py +2 -3
 - sglang/srt/models/llava.py +4 -8
 - sglang/srt/models/llavavid.py +1 -2
 - sglang/srt/models/minicpm.py +1 -2
 - sglang/srt/models/minicpm3.py +5 -6
 - sglang/srt/models/mixtral.py +1 -2
 - sglang/srt/models/mixtral_quant.py +1 -2
 - sglang/srt/models/olmo.py +352 -0
 - sglang/srt/models/olmoe.py +1 -2
 - sglang/srt/models/qwen.py +1 -2
 - sglang/srt/models/qwen2.py +1 -2
 - sglang/srt/models/qwen2_moe.py +4 -5
 - sglang/srt/models/stablelm.py +1 -2
 - sglang/srt/models/torch_native_llama.py +1 -2
 - sglang/srt/models/xverse.py +1 -2
 - sglang/srt/models/xverse_moe.py +4 -5
 - sglang/srt/models/yivl.py +1 -2
 - sglang/srt/openai_api/adapter.py +92 -49
 - sglang/srt/openai_api/protocol.py +10 -2
 - sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
 - sglang/srt/sampling/sampling_batch_info.py +92 -58
 - sglang/srt/sampling/sampling_params.py +2 -0
 - sglang/srt/server.py +116 -17
 - sglang/srt/server_args.py +121 -45
 - sglang/srt/utils.py +11 -3
 - sglang/test/few_shot_gsm8k.py +4 -1
 - sglang/test/few_shot_gsm8k_engine.py +144 -0
 - sglang/test/srt/sampling/penaltylib/utils.py +16 -12
 - sglang/version.py +1 -1
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
 - sglang/srt/layers/attention/flashinfer_utils.py +0 -237
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
 
    
sglang/srt/server.py  CHANGED
```diff
@@ -25,11 +25,12 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import random
 import threading
 import time
 from http import HTTPStatus
-from typing import Dict, List, Optional, Union
+from typing import AsyncIterator, Dict, List, Optional, Union
+
+import orjson
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
```
```diff
@@ -40,7 +41,8 @@ import uvicorn
 import uvloop
 from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import JSONResponse, Response, StreamingResponse
+from fastapi.responses import ORJSONResponse, Response, StreamingResponse
+from uvicorn.config import LOGGING_CONFIG
 
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.hf_transformers_utils import get_tokenizer
```
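A recurring theme in this file is the swap from FastAPI's default JSONResponse (stdlib json) to ORJSONResponse (Rust-backed orjson). A minimal sketch of the practical difference, assuming only that orjson is installed; the payload is a made-up example:

```python
import json

import orjson

payload = {"text": "hello", "meta": {"finish_reason": None}}

# stdlib json produces a str that must be encoded before hitting the wire
stdlib_body = json.dumps(payload).encode("utf-8")

# orjson produces bytes directly; OPT_NON_STR_KEYS (used in the hunks below)
# additionally allows non-string dict keys such as int token ids
orjson_body = orjson.dumps(payload, option=orjson.OPT_NON_STR_KEYS)

assert json.loads(stdlib_body) == orjson.loads(orjson_body)
```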
```diff
@@ -176,12 +178,12 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
     success, message = await tokenizer_manager.update_weights(obj, request)
     content = {"success": success, "message": message}
     if success:
-        return JSONResponse(
+        return ORJSONResponse(
             content,
             status_code=HTTPStatus.OK,
         )
     else:
-        return JSONResponse(
+        return ORJSONResponse(
             content,
             status_code=HTTPStatus.BAD_REQUEST,
         )
```
```diff
@@ -192,14 +194,18 @@ async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
 
-        async def stream_results():
+        async def stream_results() -> AsyncIterator[bytes]:
             try:
                 async for out in tokenizer_manager.generate_request(obj, request):
-                    yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
+                    yield b"data: " + orjson.dumps(
+                        out, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
             except ValueError as e:
                 out = {"error": {"message": str(e)}}
-                yield f"data: {json.dumps(out, ensure_ascii=False)}\n\n"
-            yield "data: [DONE]\n\n"
+                yield b"data: " + orjson.dumps(
+                    out, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            yield b"data: [DONE]\n\n"
 
         return StreamingResponse(
             stream_results(),
```
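With this hunk, /generate streams server-sent events as bytes: each chunk is a `data: <json>` frame followed by a blank line, and the stream ends with an explicit `data: [DONE]` frame. A hedged client sketch for this framing; the host, port, and request fields are illustrative assumptions, and the offset bookkeeping reflects that each chunk's "text" field is cumulative:

```python
import json

import requests

resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "The capital of France is",
        "stream": True,
        "sampling_params": {"max_new_tokens": 32},
    },
    stream=True,
)

offset = 0
for line in resp.iter_lines(decode_unicode=False):
    if not line:
        continue  # blank separator between SSE frames
    if line == b"data: [DONE]":
        break
    data = json.loads(line[len(b"data:"):])
    print(data["text"][offset:], end="", flush=True)  # print only the delta
    offset = len(data["text"])
```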
```diff
@@ -211,7 +217,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return JSONResponse(
+        return ORJSONResponse(
             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
         )
 
```
```diff
@@ -226,7 +232,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return JSONResponse(
+        return ORJSONResponse(
             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
         )
 
```
```diff
@@ -241,7 +247,7 @@ async def judge_request(obj: RewardReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return JSONResponse(
+        return ORJSONResponse(
             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
         )
 
```
```diff
@@ -260,13 +266,13 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
-@app.post("/v1/embeddings")
+@app.post("/v1/embeddings", response_class=ORJSONResponse)
 async def openai_v1_embeddings(raw_request: Request):
     response = await v1_embeddings(tokenizer_manager, raw_request)
     return response
 
 
-@app.get("/v1/models")
+@app.get("/v1/models", response_class=ORJSONResponse)
 def available_models():
     """Show available models."""
     served_model_names = [tokenizer_manager.served_model_name]
```
```diff
@@ -429,6 +435,14 @@ def launch_server(
 
     try:
         # Listen for HTTP requests
+        LOGGING_CONFIG["formatters"]["default"][
+            "fmt"
+        ] = "[%(asctime)s] %(levelprefix)s %(message)s"
+        LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+        LOGGING_CONFIG["formatters"]["access"][
+            "fmt"
+        ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
+        LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
         uvicorn.run(
             app,
             host=server_args.host,
```
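The LOGGING_CONFIG edits prepend a timestamp to uvicorn's default and access log formatters, so server output looks roughly like this (illustrative values):

```
[2024-10-20 12:00:00] INFO:     Uvicorn running on http://127.0.0.1:30000 (Press CTRL+C to quit)
[2024-10-20 12:00:05] INFO:     127.0.0.1:51732 - "POST /generate HTTP/1.1" 200
```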
```diff
@@ -447,7 +461,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
     os.environ["NCCL_NVLS_ENABLE"] = "0"
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
 
     # Set ulimit
     set_ulimit()
```
```diff
@@ -528,6 +542,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         kill_child_process(pid, including_parent=False)
         return
 
+    # logger.info(f"{res.json()=}")
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
```
```diff
@@ -692,6 +708,10 @@ class Runtime:
         self.shutdown()
 
 
+STREAM_END_SYMBOL = b"data: [DONE]"
+STREAM_CHUNK_START_SYMBOL = b"data:"
+
+
 class Engine:
     """
     SRT Engine without an HTTP server layer.
```
```diff
@@ -716,7 +736,10 @@ class Engine:
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
         lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
     ):
+        # TODO (ByronHsu): refactor to reduce the duplicated code
+
         obj = GenerateReqInput(
             text=prompt,
             sampling_params=sampling_params,
```
```diff
@@ -724,13 +747,89 @@ class Engine:
             logprob_start_len=logprob_start_len,
             top_logprobs_num=top_logprobs_num,
             lora_path=lora_path,
+            stream=stream,
         )
 
         # get the current event loop
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(generate_request(obj, None))
+        ret = loop.run_until_complete(generate_request(obj, None))
+
+        if stream is True:
+
+            def generator_wrapper():
+                offset = 0
+                loop = asyncio.get_event_loop()
+                generator = ret.body_iterator
+                while True:
+                    chunk = loop.run_until_complete(generator.__anext__())
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
+            # however, it allows to wrap the generator as a subfunction and return
+            return generator_wrapper()
+        else:
+            return ret
+
+    async def async_generate(
+        self,
+        prompt: Union[str, List[str]],
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        ret = await generate_request(obj, None)
+
+        if stream is True:
+            generator = ret.body_iterator
+
+            async def generator_wrapper():
+
+                offset = 0
+
+                while True:
+                    chunk = await generator.__anext__()
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            return generator_wrapper()
+        else:
+            return ret
 
     def shutdown(self):
         kill_child_process(os.getpid(), including_parent=False)
 
-
+    def get_tokenizer(self):
+        global tokenizer_manager
+
+        if tokenizer_manager is None:
+            raise ReferenceError("Tokenizer Manager is not initialized.")
+        else:
+            return tokenizer_manager.tokenizer
+
+    # TODO (ByronHsu): encode
```
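Together with the STREAM_* constants added above, these hunks give the in-process Engine the same streaming ability as the HTTP layer: generate(stream=True) drives the StreamingResponse.body_iterator on an event loop, strips the `data:` framing, and converts the cumulative "text" field into per-chunk deltas via the offset bookkeeping. A usage sketch, assuming Engine forwards its kwargs to ServerArgs; the model path and prompts are placeholders:

```python
from sglang.srt.server import Engine

engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model

# stream=False (default): returns the full result dict, as before.
print(engine.generate("Say hi.")["text"])

# stream=True: returns a plain generator; each item's "text" holds only the
# newly decoded piece, thanks to the offset slicing in generator_wrapper().
for chunk in engine.generate("Write a haiku.", stream=True):
    print(chunk["text"], end="", flush=True)

engine.shutdown()
```

The new async_generate mirrors this contract: with stream=True it returns an async generator with the same delta semantics, and get_tokenizer() exposes the running TokenizerManager's tokenizer.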
    
sglang/srt/server_args.py  CHANGED
```diff
@@ -35,12 +35,12 @@ class ServerArgs:
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
+    trust_remote_code: bool = True
     dtype: str = "auto"
-    device: str = "cuda"
     kv_cache_dtype: str = "auto"
-    trust_remote_code: bool = True
-    context_length: Optional[int] = None
     quantization: Optional[str] = None
+    context_length: Optional[int] = None
+    device: str = "cuda"
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False
```
```diff
@@ -73,6 +73,7 @@ class ServerArgs:
     # Other
     api_key: Optional[str] = None
     file_storage_pth: str = "SGLang_storage"
+    enable_cache_report: bool = False
 
     # Data parallelism
     dp_size: int = 1
```
```diff
@@ -86,10 +87,23 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    # Kernel backend
+    # Double Sparsity
+    enable_double_sparsity: bool = False
+    ds_channel_config_path: str = None
+    ds_heavy_channel_num: int = 32
+    ds_heavy_token_num: int = 256
+    ds_heavy_channel_type: str = "qk"
+    ds_sparse_decode_threshold: int = 4096
+
+    # LoRA
+    lora_paths: Optional[List[str]] = None
+    max_loras_per_batch: int = 8
+
+    # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
 
+    # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
```
```diff
@@ -99,16 +113,16 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
+    disable_penalizer: bool = False
+    disable_nan_detection: bool = False
+    enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
-
-    # LoRA
-    lora_paths: Optional[List[str]] = None
-    max_loras_per_batch: int = 8
+    num_continuous_decode_steps: int = 1
 
     def __post_init__(self):
         # Set missing default values
```
```diff
@@ -224,6 +238,11 @@ class ServerArgs:
             '"dummy" will initialize the weights with random values, '
             "which is mainly for profiling.",
         )
+        parser.add_argument(
+            "--trust-remote-code",
+            action="store_true",
+            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
+        )
         parser.add_argument(
             "--dtype",
             type=str,
```
```diff
@@ -238,13 +257,6 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
-        parser.add_argument(
-            "--device",
-            type=str,
-            default="cuda",
-            choices=["cuda"],
-            help="The device type.",
-        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,
```
```diff
@@ -252,17 +264,6 @@ class ServerArgs:
             choices=["auto", "fp8_e5m2"],
             help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
         )
-        parser.add_argument(
-            "--trust-remote-code",
-            action="store_true",
-            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
-        )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
-        )
         parser.add_argument(
             "--quantization",
             type=str,
```
```diff
@@ -278,6 +279,19 @@ class ServerArgs:
             ],
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default="cuda",
+            choices=["cuda", "xpu"],
+            help="The device type.",
+        )
         parser.add_argument(
             "--served-model-name",
             type=str,
```
```diff
@@ -398,6 +412,11 @@ class ServerArgs:
             default=ServerArgs.file_storage_pth,
             help="The path of the file storage in backend.",
         )
+        parser.add_argument(
+            "--enable-cache-report",
+            action="store_true",
+            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
+        )
 
         # Data parallelism
         parser.add_argument(
```
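A hedged sketch of the new --enable-cache-report flag from the client side; the base URL, API key, and model name are placeholders:

```python
import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Hello!"}],
)

# If the server was launched with --enable-cache-report, usage should carry
# prompt_tokens_details with the number of prompt tokens served from cache.
print(resp.usage)
```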
```diff
@@ -440,7 +459,60 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )
 
-        # Kernel backend
+        # Double Sparsity
+        parser.add_argument(
+            "--enable-double-sparsity",
+            action="store_true",
+            help="Enable double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-channel-config-path",
+            type=str,
+            default=ServerArgs.ds_channel_config_path,
+            help="The path of the double sparsity channel config",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-num",
+            type=int,
+            default=ServerArgs.ds_heavy_channel_num,
+            help="The number of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-token-num",
+            type=int,
+            default=ServerArgs.ds_heavy_token_num,
+            help="The number of heavy tokens in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-type",
+            type=str,
+            default=ServerArgs.ds_heavy_channel_type,
+            help="The type of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-sparse-decode-threshold",
+            type=int,
+            default=ServerArgs.ds_sparse_decode_threshold,
+            help="The type of heavy channels in double sparsity attention",
+        )
+
+        # LoRA
+        parser.add_argument(
+            "--lora-paths",
+            type=str,
+            nargs="*",
+            default=None,
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+        )
+        parser.add_argument(
+            "--max-loras-per-batch",
+            type=int,
+            default=8,
+            help="Maximum number of adapters for a running batch, include base-only request",
+        )
+
+        # Kernel backend
         parser.add_argument(
             "--attention-backend",
             type=str,
```
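A hypothetical launch line exercising the new Double Sparsity and LoRA flag groups; the model, channel config, and adapter paths are placeholders:

```
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --enable-double-sparsity \
  --ds-channel-config-path ./channel_config.json \
  --ds-heavy-channel-num 32 \
  --ds-heavy-token-num 256 \
  --lora-paths adapter_a=/path/to/lora_a /path/to/lora_b \
  --max-loras-per-batch 8
```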
```diff
@@ -455,6 +527,8 @@ class ServerArgs:
             default=ServerArgs.sampling_backend,
             help="Choose the kernels for sampling layers.",
         )
+
+        # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
```
@@ -501,6 +575,21 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
+        parser.add_argument(
+            "--disable-penalizer",
+            action="store_true",
+            help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
+        )
+        parser.add_argument(
+            "--disable-nan-detection",
+            action="store_true",
+            help="Disable the NaN detection for better performance.",
+        )
+        parser.add_argument(
+            "--enable-overlap-schedule",
+            action="store_true",
+            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
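
Usage note (not part of the diff): all three new switches are plain store_true flags, so they default to False and compose freely on one launch line. A sketch under the same ServerArgs.add_cli_args assumption as above:

    # Sketch only: toggling the new optimization/debug switches.
    import argparse
    from sglang.srt.server_args import ServerArgs

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args([
        "--model-path", "meta-llama/Llama-2-7b-hf",  # placeholder model
        "--disable-penalizer",        # skip frequency/repetition penalty machinery
        "--disable-nan-detection",    # skip NaN checks on logits
        "--enable-overlap-schedule",  # experimental CPU/GPU overlap
    ])
    assert args.disable_penalizer and args.enable_overlap_schedule
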
@@ -535,25 +624,12 @@ class ServerArgs:
             "This only affects Triton attention kernels.",
         )
         parser.add_argument(
-            "--
-            action="store_true",
-            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
-        )
-
-        # LoRA options
-        parser.add_argument(
-            "--lora-paths",
-            type=str,
-            nargs="*",
-            default=None,
-            action=LoRAPathAction,
-            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
-        )
-        parser.add_argument(
-            "--max-loras-per-batch",
+            "--num-continuous-decode-steps",
             type=int,
-            default=8,
-            help="Maximum number of adapters for a running batch, include base-only request",
+            default=ServerArgs.num_continuous_decode_steps,
+            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
+            "This can potentially increase throughput but may also increase time-to-first-token latency. "
+            "The default value is 1, meaning only run one decoding step at a time.",
         )

     @classmethod
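
Back-of-envelope illustration (an approximation, not a measurement): --num-continuous-decode-steps amortizes scheduling overhead by letting the worker run several decode steps per scheduling round, at the cost of reacting to new or finished requests less often.

    # Roughly: scheduling rounds needed to decode 512 tokens per request.
    tokens = 512
    for steps in (1, 2, 4):
        print(f"steps={steps}: ~{tokens // steps} scheduling rounds")
    # steps=1: ~512 rounds; steps=2: ~256; steps=4: ~128
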
    
sglang/srt/utils.py CHANGED
@@ -35,7 +35,7 @@ import psutil
 import requests
 import torch
 import torch.distributed as dist
-from fastapi.responses import
+from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from torch import nn
 from torch.profiler import ProfilerActivity, profile, record_function
@@ -566,7 +566,7 @@ def add_api_key_middleware(app, api_key: str):
         if request.url.path.startswith("/health"):
             return await call_next(request)
         if request.headers.get("Authorization") != "Bearer " + api_key:
-            return
+            return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
         return await call_next(request)

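
For context, a standalone sketch of the middleware pattern this hunk edits; it mirrors the diffed lines inside add_api_key_middleware (the app and key are placeholders, and ORJSONResponse needs the orjson package installed):

    from fastapi import FastAPI, Request
    from fastapi.responses import ORJSONResponse

    app = FastAPI()
    api_key = "example-key"  # placeholder

    @app.middleware("http")
    async def authentication(request: Request, call_next):
        # Health checks bypass auth; everything else needs the bearer token.
        if request.url.path.startswith("/health"):
            return await call_next(request)
        if request.headers.get("Authorization") != "Bearer " + api_key:
            return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
        return await call_next(request)
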
@@ -584,10 +584,11 @@ def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):

 def configure_logger(server_args, prefix: str = ""):
     format = f"[%(asctime)s{prefix}] %(message)s"
+    # format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format=format,
-        datefmt="%H:%M:%S",
+        datefmt="%Y-%m-%d %H:%M:%S",
         force=True,
     )

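
Net effect of the datefmt change: log timestamps now carry the date, e.g. [2024-10-11 14:30:05] instead of [14:30:05]. A minimal standalone reproduction:

    import logging

    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",  # previously "%H:%M:%S"
        force=True,
    )
    logging.info("server started")  # -> [2024-10-11 14:30:05] server started
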
@@ -690,3 +691,10 @@ def pytorch_profile(name, func, *args, data_size=-1):
     prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
     step_counter += 1
     return result
+
+
+def first_rank_print(*args, **kwargs):
+    if torch.cuda.current_device() == 0:
+        print(*args, **kwargs)
+    else:
+        pass
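
Usage note: torch.cuda.current_device() returns this process's current CUDA device index, which under the usual one-GPU-per-process layout is 0 only on the first rank, so the helper prints a message once rather than once per worker. A hypothetical call site:

    from sglang.srt.utils import first_rank_print

    # Prints on the process bound to CUDA device 0; a no-op elsewhere.
    first_rank_print("weights loaded")
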
    
sglang/test/few_shot_gsm8k.py CHANGED
@@ -76,7 +76,9 @@ def run_eval(args):
     def few_shot_gsm8k(s, question):
         s += few_shot_examples + question
         s += sgl.gen(
-            "answer",
+            "answer",
+            max_tokens=args.max_new_tokens,
+            stop=["Question", "Assistant:", "<|separator|>"],
         )

     #####################################
@@ -131,6 +133,7 @@ if __name__ == "__main__":
     parser.add_argument("--num-shots", type=int, default=5)
     parser.add_argument("--data-path", type=str, default="test.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
+    parser.add_argument("--max-new-tokens", type=int, default=512)
     parser.add_argument("--parallel", type=int, default=128)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)
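
A hypothetical driver for the updated script, assuming run_eval consumes exactly the arguments registered above (any fields added elsewhere in the file would have to be supplied as well) and that a sglang server is already listening on the given host/port:

    import argparse
    from sglang.test.few_shot_gsm8k import run_eval

    args = argparse.Namespace(
        num_shots=5,
        data_path="test.jsonl",
        num_questions=20,
        max_new_tokens=256,  # new knob: caps tokens per generated answer
        parallel=16,
        host="http://127.0.0.1",
        port=30000,
    )
    run_eval(args)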