sglang 0.3.6.post1__py3-none-any.whl → 0.3.6.post3__py3-none-any.whl
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +4 -8
- sglang/bench_one_batch_server.py +6 -5
- sglang/check_env.py +7 -1
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +2 -4
- sglang/srt/configs/model_config.py +2 -6
- sglang/srt/layers/attention/flashinfer_backend.py +3 -3
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -6
- sglang/srt/managers/image_processor.py +7 -10
- sglang/srt/managers/io_struct.py +0 -10
- sglang/srt/managers/schedule_batch.py +51 -13
- sglang/srt/managers/scheduler.py +41 -29
- sglang/srt/managers/session_controller.py +15 -7
- sglang/srt/managers/tokenizer_manager.py +4 -33
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -2
- sglang/srt/models/grok.py +11 -48
- sglang/srt/models/llava.py +16 -9
- sglang/srt/models/olmo2.py +392 -0
- sglang/srt/models/qwen2_vl.py +10 -3
- sglang/srt/openai_api/adapter.py +1 -1
- sglang/srt/server.py +48 -45
- sglang/srt/server_args.py +1 -1
- sglang/srt/utils.py +22 -24
- sglang/test/test_utils.py +21 -8
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/METADATA +4 -2
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/RECORD +34 -36
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/top_level.txt +0 -0
    
        sglang/srt/server.py
    CHANGED
    
@@ -23,6 +23,7 @@ import json
 import logging
 import multiprocessing as mp
 import os
+import signal
 import threading
 import time
 from http import HTTPStatus
@@ -79,19 +80,20 @@ from sglang.srt.utils import (
     configure_logger,
     delete_directory,
     is_port_available,
-    kill_child_process,
+    kill_process_tree,
     maybe_set_triton_cache_manager,
     prepare_model_and_tokenizer,
     set_prometheus_multiproc_dir,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
+from sglang.version import __version__

 logger = logging.getLogger(__name__)

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

-
+# Fast API
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,
@@ -102,7 +104,7 @@ app.add_middleware(
 )

 tokenizer_manager: TokenizerManager = None
-_max_total_num_tokens = None
+scheduler_info: Dict = None

 ##### Native API endpoints #####

@@ -170,7 +172,7 @@ async def flush_cache():

 @app.get("/start_profile")
 @app.post("/start_profile")
-async def start_profile():
+async def start_profile_async():
     """Start profiling."""
     tokenizer_manager.start_profile()
     return Response(
@@ -181,7 +183,7 @@ async def start_profile():

 @app.get("/stop_profile")
 @app.post("/stop_profile")
-async def stop_profile():
+async def stop_profile_async():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
     return Response(
@@ -232,6 +234,8 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         )


+# fastapi implicitly converts json in the request to obj (dataclass)
+@app.api_route("/generate", methods=["POST", "PUT"])
 @time_func_latency
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
@@ -265,11 +269,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         )


-
-app.post("/generate")(generate_request)
-app.put("/generate")(generate_request)
-
-
+@app.api_route("/encode", methods=["POST", "PUT"])
 @time_func_latency
 async def encode_request(obj: EmbeddingReqInput, request: Request):
     """Handle an embedding request."""
@@ -282,10 +282,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
         )


-app.post("/encode")(encode_request)
-app.put("/encode")(encode_request)
-
-
+@app.api_route("/encode", methods=["POST", "PUT"])
 @time_func_latency
 async def classify_request(obj: EmbeddingReqInput, request: Request):
     """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
@@ -298,10 +295,6 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         )


-app.post("/classify")(classify_request)
-app.put("/classify")(classify_request)
-
-
 ##### OpenAI-compatible API endpoints #####


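The native routes move from post-hoc registration (`app.post(path)(handler)`) to `@app.api_route` decorators at the definition site. A minimal standalone sketch of the two equivalent FastAPI styles (hypothetical `/echo` routes, not sglang code):

    from fastapi import FastAPI, Request

    app = FastAPI()


    async def echo(request: Request):
        return {"method": request.method}

    # Style removed in this release: register an existing handler on two methods.
    app.post("/echo_old")(echo)
    app.put("/echo_old")(echo)


    # Style introduced in this release: one decorator registers both methods
    # where the handler is defined.
    @app.api_route("/echo_new", methods=["POST", "PUT"])
    async def echo_new(request: Request):
        return {"method": request.method}

Note that, as rendered above, the decorator on classify_request registers it on /encode as well; the /classify path no longer appears in this file.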
@@ -379,11 +372,11 @@ def launch_engine(
     server_args: ServerArgs,
 ):
     """
-    Launch the Tokenizer Manager in the main process, the Scheduler in a subprocess, and the Detokenizer Manager in another subprocess.
+    Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """

     global tokenizer_manager
-    global _max_total_num_tokens
+    global scheduler_info

     # Configure global environment
     configure_logger(server_args)
@@ -449,20 +442,19 @@ def launch_engine(
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)

-    # Wait for model to finish loading & get max token nums
-    scheduler_info = []
+    # Wait for model to finish loading
+    scheduler_infos = []
     for i in range(len(scheduler_pipe_readers)):
         data = scheduler_pipe_readers[i].recv()

         if data["status"] != "ready":
-            self.shutdown()
             raise RuntimeError(
                 "Initialization failed. Please see the error messages above."
             )
-        scheduler_info.append(data)
+        scheduler_infos.append(data)

     # Assume all schedulers have same max_total_num_tokens
-    _max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
+    scheduler_info = scheduler_infos[0]


 def launch_server(
@@ -476,12 +468,12 @@ def launch_server(

     1. HTTP server: A FastAPI server that routes requests to the engine.
     2. SRT engine:
-        1. Tokenizer Manager: Tokenizes the requests and sends them to the scheduler.
+        1. TokenizerManager: Tokenizes the requests and sends them to the scheduler.
         2. Scheduler (subprocess): Receives requests from the Tokenizer Manager, schedules batches, forwards them, and sends the output tokens to the Detokenizer Manager.
-        3. Detokenizer Manager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.
+        3. DetokenizerManager (subprocess): Detokenizes the output tokens and sends the result back to the Tokenizer Manager.

     Note:
-    1. The HTTP server and Tokenizer Manager both run in the main process.
+    1. The HTTP server and TokenizerManager both run in the main process.
     2. Inter-process communication is done through ICP (each process uses a different port) via the ZMQ library.
     """
     launch_engine(server_args=server_args)
@@ -490,7 +482,7 @@ def launch_server(
     if server_args.api_key:
         add_api_key_middleware(app, server_args.api_key)

-    #
+    # Add prometheus middleware
     if server_args.enable_metrics:
         add_prometheus_middleware(app)
         enable_func_timer()
@@ -502,7 +494,7 @@ def launch_server(
     t.start()

     try:
-        #
+        # Update logging configs
         LOGGING_CONFIG["formatters"]["default"][
             "fmt"
         ] = "[%(asctime)s] %(levelprefix)s %(message)s"
@@ -511,6 +503,8 @@ def launch_server(
             "fmt"
         ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
         LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+
+        # Listen for HTTP requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -526,8 +520,8 @@ def launch_server(
 async def _get_server_info():
     return {
         **dataclasses.asdict(tokenizer_manager.server_args),  # server args
-        "memory_pool_size": await tokenizer_manager.get_memory_pool_size(),  # memory pool size
-        "max_total_num_tokens": _max_total_num_tokens,  # max total num tokens
+        **scheduler_info,
+        "version": __version__,
     }


@@ -561,6 +555,15 @@ def _set_envs_and_config(server_args: ServerArgs):
             "at https://docs.flashinfer.ai/installation.html.",
         )

+    # Register the signal handler.
+    # The child processes will send SIGQUIT to this process when any error happens
+    # This process then clean up the whole process tree
+    def sigquit_handler(signum, frame):
+        kill_process_tree(os.getpid())
+
+    signal.signal(signal.SIGQUIT, sigquit_handler)
+
+    # Set mp start method
     mp.set_start_method("spawn", force=True)


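This handler gives the server a crash-stop path: any child process that hits a fatal error sends SIGQUIT to the main process, whose handler then kills the entire process tree. A minimal, POSIX-only sketch of the pattern (toy processes, not sglang's actual scheduler wiring; SIGQUIT does not exist on Windows):

    import multiprocessing as mp
    import os
    import signal

    import psutil


    def kill_tree(parent_pid: int) -> None:
        # Same shape as kill_process_tree in sglang/srt/utils.py below:
        # SIGKILL every child, then the root process itself.
        parent = psutil.Process(parent_pid)
        for child in parent.children(recursive=True):
            try:
                child.kill()
            except psutil.NoSuchProcess:
                pass
        parent.kill()


    def worker() -> None:
        try:
            raise RuntimeError("fatal error in the child")
        except RuntimeError:
            # A failing child asks the parent to tear the whole tree down.
            os.kill(os.getppid(), signal.SIGQUIT)


    if __name__ == "__main__":
        signal.signal(signal.SIGQUIT, lambda signum, frame: kill_tree(os.getpid()))
        mp.set_start_method("spawn", force=True)
        proc = mp.Process(target=worker)
        proc.start()
        proc.join()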
@@ -587,7 +590,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_child_process(include_self=True)
+        kill_process_tree(os.getpid())
         return

     model_info = res.json()
@@ -620,9 +623,10 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
         logger.error(f"Initialization failed. warmup error: {last_traceback}")
-        kill_child_process(include_self=True)
+        kill_process_tree(os.getpid())
         return

+    # Debug print
     # logger.info(f"{res.json()=}")

     logger.info("The server is fired up and ready to roll!")
@@ -689,7 +693,7 @@ class Runtime:

     def shutdown(self):
         if self.pid is not None:
-            kill_child_process(self.pid, include_self=True)
+            kill_process_tree(self.pid)
             self.pid = None

     def cache_prefix(self, prefix: str):
@@ -799,18 +803,11 @@ class Engine:
     launching the HTTP server adds unnecessary complexity or overhead,
     """

-    def __init__(self, *args, **kwargs):
-
+    def __init__(self, log_level: str = "error", *args, **kwargs):
         # before python program terminates, call shutdown implicitly. Therefore, users don't have to explicitly call .shutdown()
         atexit.register(self.shutdown)

-        # runtime server default log level is log
-        # offline engine works in scripts, so we set it to error
-
-        if "log_level" not in kwargs:
-            kwargs["log_level"] = "error"
-
-        server_args = ServerArgs(*args, **kwargs)
+        server_args = ServerArgs(*args, log_level=log_level, **kwargs)
         launch_engine(server_args=server_args)

     def generate(
@@ -913,7 +910,7 @@ class Engine:
             return ret

     def shutdown(self):
-        kill_child_process()
+        kill_process_tree(os.getpid(), include_parent=False)

     def get_tokenizer(self):
         global tokenizer_manager
@@ -933,5 +930,11 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(encode_request(obj, None))

+    def start_profile(self):
+        tokenizer_manager.start_profile()
+
+    def stop_profile(self):
+        tokenizer_manager.stop_profile()
+
     async def get_server_info(self):
         return await _get_server_info()
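Taken together, the Engine changes make log_level an explicit keyword (still defaulting to "error" for script use) and expose the profiler toggles that were previously HTTP-only. A usage sketch; the model path is a placeholder, and generate's full signature and return shape are not shown in this diff:

    from sglang.srt.server import Engine

    # log_level is now a named parameter instead of being injected into **kwargs.
    engine = Engine(model_path="meta-llama/Llama-3.1-8B-Instruct", log_level="info")

    engine.start_profile()                 # new in this release
    out = engine.generate("The capital of France is")
    engine.stop_profile()                  # new in this release

    print(out)
    engine.shutdown()  # optional; also registered via atexit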
    
        sglang/srt/server_args.py
    CHANGED
    
@@ -144,7 +144,7 @@ class ServerArgs:
         if self.served_model_name is None:
             self.served_model_name = self.model_path

-        if self.chunked_prefill_size <= 0:
+        if self.chunked_prefill_size is not None and self.chunked_prefill_size <= 0:
             # Disable chunked prefill
             self.chunked_prefill_size = None

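The one-line fix matters because chunked_prefill_size may already be None (chunked prefill disabled), and ordering comparisons against None raise TypeError on Python 3. A quick repro of the failure mode the guard avoids:

    chunked_prefill_size = None  # e.g. chunked prefill already disabled

    try:
        if chunked_prefill_size <= 0:  # the old, unguarded check
            chunked_prefill_size = None
    except TypeError as e:
        print(e)  # "'<=' not supported between instances of 'NoneType' and 'int'"

    # The new check short-circuits before the comparison ever happens:
    if chunked_prefill_size is not None and chunked_prefill_size <= 0:
        chunked_prefill_size = None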
    
        sglang/srt/utils.py
    CHANGED
    
@@ -72,7 +72,7 @@ def is_flashinfer_available():
     Check whether flashinfer is available.
     As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
     """
-    if
+    if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
         return False
     return torch.cuda.is_available() and not is_hip()

@@ -443,26 +443,14 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
         )


-def kill_parent_process():
-    """Kill the parent process and all children of the parent process."""
-    current_process = psutil.Process()
-    parent_process = current_process.parent()
-    kill_child_process(
-        parent_process.pid, include_self=True, skip_pid=current_process.pid
-    )
-    try:
-        current_process.kill()
-    except psutil.NoSuchProcess:
-        pass
-
-
-def kill_child_process(pid=None, include_self=False, skip_pid=None):
-    """Kill the process and all its children process."""
-    if pid is None:
-        pid = os.getpid()
+def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
+    """Kill the process and all its child processes."""
+    if parent_pid is None:
+        parent_pid = os.getpid()
+        include_parent = False

     try:
-        itself = psutil.Process(pid)
+        itself = psutil.Process(parent_pid)
     except psutil.NoSuchProcess:
         return

@@ -475,13 +463,13 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
         except psutil.NoSuchProcess:
             pass

-    if include_self:
+    if include_parent:
         try:
             itself.kill()

             # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
             # so we send an additional signal to kill them.
-            itself.send_signal(signal.
+            itself.send_signal(signal.SIGQUIT)
         except psutil.NoSuchProcess:
             pass

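The two old helpers collapse into a single tree-killer whose root pid is always explicit. The call shapes used at the updated call sites in this diff, gathered in one place (guarded, since actually running them terminates real processes):

    import os

    from sglang.srt.utils import kill_process_tree


    def shutdown_demo(server_pid: int, actually_kill: bool = False) -> None:
        if not actually_kill:
            return  # illustration only

        # Runtime.shutdown: kill a launched server process and everything under it.
        kill_process_tree(server_pid)

        # Engine.shutdown: kill only the children, keep the current process alive.
        kill_process_tree(os.getpid(), include_parent=False)

        # sigquit_handler / warmup failure: take down the whole tree, self included.
        kill_process_tree(os.getpid())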
@@ -517,6 +505,11 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):

     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)

+    # Suppress the warnings from this delete function when using sglang.bench_one_batch
+    from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
+
+    setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)
+

 vllm_all_gather_backup = None

@@ -626,7 +619,7 @@ def add_api_key_middleware(app, api_key: str):


 def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
-    if "SGLANG_USE_MODELSCOPE" in os.environ:
+    if get_bool_env_var("SGLANG_USE_MODELSCOPE"):
         if not os.path.exists(model_path):
             from modelscope import snapshot_download

@@ -931,7 +924,7 @@ def get_nvgpu_memory_capacity():

 def crash_on_warnings():
     # Crash on warning if we are running CI tests
-    return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+    return get_bool_env_var("SGLANG_IS_IN_CI")


 def get_device_name(device_id: int = 0) -> str:
@@ -990,7 +983,7 @@ def direct_register_custom_op(
     my_lib._register_fake(op_name, fake_impl)


-def gpu_proc_affinity(
+def set_gpu_proc_affinity(
     tp_size: int,
     nnodes: int,
     gpu_id: int,
@@ -1022,3 +1015,8 @@
     # set cpu_affinity to current process
     p.cpu_affinity(bind_cpu_ids)
     logger.info(f"Process {pid} gpu_id {gpu_id} is running on CPUs: {p.cpu_affinity()}")
+
+
+def get_bool_env_var(name: str, default: str = "false") -> bool:
+    value = os.getenv(name, default)
+    return value.lower() in ("true", "1")
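The new helper centralizes the scattered os.environ string checks; only "true" and "1" (case-insensitive) count as truthy. A few behavioral examples:

    import os

    from sglang.srt.utils import get_bool_env_var

    os.environ["SGLANG_IS_IN_CI"] = "TRUE"
    assert get_bool_env_var("SGLANG_IS_IN_CI")          # case-insensitive match

    os.environ["SGLANG_IS_IN_CI"] = "yes"
    assert not get_bool_env_var("SGLANG_IS_IN_CI")      # anything else is false

    del os.environ["SGLANG_IS_IN_CI"]
    assert not get_bool_env_var("SGLANG_IS_IN_CI")      # default is "false"
    assert get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true")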
    
        sglang/test/test_utils.py
    CHANGED
    
@@ -22,7 +22,7 @@ from sglang.bench_serving import run_benchmark
 from sglang.global_config import global_config
 from sglang.lang.backend.openai import OpenAI
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
-from sglang.srt.utils import kill_child_process
+from sglang.srt.utils import get_bool_env_var, kill_process_tree
 from sglang.test.run_eval import run_eval
 from sglang.utils import get_exception_traceback

@@ -44,7 +44,7 @@ DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_QUANT_TP1 = "hugging-quants/Meta-Llama-3.1-8

 def is_in_ci():
     """Return whether it is in CI runner."""
-    return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+    return get_bool_env_var("SGLANG_IS_IN_CI")


 if is_in_ci():
@@ -504,7 +504,7 @@ def run_unittest_files(files: List[str], timeout_per_file: float):
             )
             assert ret_code == 0
         except TimeoutError:
-            kill_child_process(process.pid, include_self=True)
+            kill_process_tree(process.pid)
             time.sleep(5)
             print(
                 f"\nTimeout after {timeout_per_file} seconds when running {filename}\n",
@@ -578,7 +578,7 @@ def run_bench_serving(
             run_benchmark(warmup_args)
         res = run_benchmark(args)
     finally:
-        kill_child_process(process.pid, include_self=True)
+        kill_process_tree(process.pid)

     assert res["completed"] == num_prompts
     return res
@@ -611,7 +611,7 @@ def run_bench_one_batch(model, other_args):
         lastline = output.split("\n")[-3]
         output_throughput = float(lastline.split(" ")[-2])
     finally:
-        kill_child_process(process.pid, include_self=True)
+        kill_process_tree(process.pid)

     return output_throughput

@@ -677,8 +677,14 @@ def run_and_check_memory_leak(
     enable_mixed_chunk,
     disable_overlap,
     chunked_prefill_size,
+    assert_has_abort,
 ):
-    other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
+    other_args = [
+        "--chunked-prefill-size",
+        str(chunked_prefill_size),
+        "--log-level",
+        "debug",
+    ]
     if disable_radix_cache:
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
@@ -710,8 +716,8 @@ def run_and_check_memory_leak(
     workload_func(base_url, model)

     # Clean up everything
-    kill_child_process(process.pid, include_self=True)
-    kill_child_process(process.pid, include_self=True)
+    kill_process_tree(process.pid)
+    kill_process_tree(process.pid)
     stdout.close()
     stderr.close()
     if os.path.exists(STDOUT_FILENAME):
@@ -723,14 +729,19 @@ def run_and_check_memory_leak(
     # Assert success
     has_new_server = False
     has_leak = False
+    has_abort = False
     for line in output_lines:
         if "The server is fired" in line:
             has_new_server = True
         if "leak" in line:
             has_leak = True
+        if "Abort" in line:
+            has_abort = True

     assert has_new_server
     assert not has_leak
+    if assert_has_abort:
+        assert has_abort


 def run_mmlu_test(
@@ -761,6 +772,7 @@ def run_mmlu_test(
         enable_mixed_chunk,
         disable_overlap,
         chunked_prefill_size,
+        assert_has_abort=False,
     )


@@ -800,4 +812,5 @@ def run_mulit_request_test(
         enable_mixed_chunk,
         enable_overlap,
         chunked_prefill_size,
+        assert_has_abort=False,
     )
        sglang/utils.py
    CHANGED
    
@@ -348,9 +348,9 @@ def wait_for_server(base_url: str, timeout: int = None) -> None:


 def terminate_process(process):
-    from sglang.srt.utils import kill_child_process
+    from sglang.srt.utils import kill_process_tree

-    kill_child_process(process.pid, include_self=True)
+    kill_process_tree(process.pid)


 def print_highlight(html_content: str):
    
        sglang/version.py
    CHANGED
    
@@ -1 +1 @@
-__version__ = "0.3.6.post1"
+__version__ = "0.3.6.post3"

        {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/METADATA
    CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: sglang
-Version: 0.3.6.post1
+Version: 0.3.6.post3
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License:                                  Apache License
                                    Version 2.0, January 2004
@@ -240,6 +240,8 @@ Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
 Requires-Dist: torch; extra == "srt"
 Requires-Dist: vllm>=0.6.3.post1; extra == "srt"
+Requires-Dist: cuda-python; extra == "srt"
+Requires-Dist: flashinfer>=0.1.6; extra == "srt"
 Provides-Extra: srt-hip
 Requires-Dist: sglang[runtime_common]; extra == "srt-hip"
 Requires-Dist: torch; extra == "srt-hip"
@@ -350,7 +352,7 @@ Learn more in our release blogs: [v0.2 blog](https://lmsys.org/blog/2024-07-25-s
 [Development Roadmap (2024 Q4)](https://github.com/sgl-project/sglang/issues/1487)

 ## Adoption and Sponsorship
-The project is supported by (alphabetically): AMD, Baseten, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, NVIDIA, RunPod, Stanford, UC Berkeley, and xAI.
+The project is supported by (alphabetically): AMD, Baseten, Etched, Hyperbolic, Jam & Tea Studios, LinkedIn, NVIDIA, RunPod, Stanford, UC Berkeley, xAI and 01.AI.

 ## Acknowledgment and Citation
 We learned from the design and reused code from the following projects: [Guidance](https://github.com/guidance-ai/guidance), [vLLM](https://github.com/vllm-project/vllm), [LightLLM](https://github.com/ModelTC/lightllm), [FlashInfer](https://github.com/flashinfer-ai/flashinfer), [Outlines](https://github.com/outlines-dev/outlines), and [LMQL](https://github.com/eth-sri/lmql).