sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - sglang/bench_serving.py +56 -12
 - sglang/launch_server.py +2 -0
 - sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
 - sglang/srt/compilation/backend.py +1 -1
 - sglang/srt/configs/model_config.py +5 -5
 - sglang/srt/distributed/parallel_state.py +0 -7
 - sglang/srt/entrypoints/engine.py +18 -15
 - sglang/srt/entrypoints/grpc_server.py +0 -1
 - sglang/srt/entrypoints/http_server.py +75 -94
 - sglang/srt/environ.py +16 -2
 - sglang/srt/eplb/expert_distribution.py +30 -0
 - sglang/srt/function_call/function_call_parser.py +2 -0
 - sglang/srt/function_call/minimax_m2.py +367 -0
 - sglang/srt/layers/activation.py +6 -0
 - sglang/srt/layers/attention/flashattention_backend.py +12 -2
 - sglang/srt/layers/attention/flashinfer_backend.py +10 -1
 - sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
 - sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
 - sglang/srt/layers/attention/utils.py +78 -0
 - sglang/srt/layers/communicator.py +1 -0
 - sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
 - sglang/srt/layers/layernorm.py +19 -4
 - sglang/srt/layers/logits_processor.py +5 -0
 - sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
 - sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
 - sglang/srt/layers/moe/ep_moe/layer.py +79 -272
 - sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
 - sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
 - sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
 - sglang/srt/layers/moe/moe_runner/runner.py +3 -0
 - sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
 - sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
 - sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
 - sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
 - sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
 - sglang/srt/layers/moe/topk.py +4 -4
 - sglang/srt/layers/moe/utils.py +3 -4
 - sglang/srt/layers/quantization/__init__.py +3 -5
 - sglang/srt/layers/quantization/awq.py +0 -3
 - sglang/srt/layers/quantization/base_config.py +7 -0
 - sglang/srt/layers/quantization/fp8.py +68 -63
 - sglang/srt/layers/quantization/gguf.py +566 -0
 - sglang/srt/layers/quantization/mxfp4.py +30 -38
 - sglang/srt/layers/quantization/unquant.py +23 -45
 - sglang/srt/layers/quantization/w4afp8.py +38 -2
 - sglang/srt/layers/radix_attention.py +5 -2
 - sglang/srt/layers/rotary_embedding.py +13 -1
 - sglang/srt/layers/sampler.py +12 -1
 - sglang/srt/managers/io_struct.py +3 -0
 - sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
 - sglang/srt/managers/scheduler.py +21 -15
 - sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
 - sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
 - sglang/srt/managers/tokenizer_manager.py +11 -19
 - sglang/srt/mem_cache/hicache_storage.py +7 -1
 - sglang/srt/mem_cache/memory_pool.py +82 -0
 - sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
 - sglang/srt/model_executor/forward_batch_info.py +44 -3
 - sglang/srt/model_executor/model_runner.py +1 -149
 - sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
 - sglang/srt/models/deepseek_v2.py +147 -44
 - sglang/srt/models/glm4_moe.py +322 -354
 - sglang/srt/models/glm4_moe_nextn.py +4 -14
 - sglang/srt/models/glm4v_moe.py +29 -196
 - sglang/srt/models/minimax_m2.py +922 -0
 - sglang/srt/models/nvila.py +355 -0
 - sglang/srt/models/nvila_lite.py +184 -0
 - sglang/srt/models/qwen2.py +22 -1
 - sglang/srt/models/qwen3.py +34 -4
 - sglang/srt/models/qwen3_moe.py +2 -4
 - sglang/srt/multimodal/processors/base_processor.py +1 -0
 - sglang/srt/multimodal/processors/glm4v.py +1 -1
 - sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
 - sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
 - sglang/srt/parser/reasoning_parser.py +28 -1
 - sglang/srt/server_args.py +365 -186
 - sglang/srt/single_batch_overlap.py +2 -7
 - sglang/srt/utils/common.py +87 -42
 - sglang/srt/utils/hf_transformers_utils.py +7 -3
 - sglang/test/test_deterministic.py +235 -12
 - sglang/test/test_deterministic_utils.py +2 -1
 - sglang/version.py +1 -1
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
 - sglang/srt/models/vila.py +0 -306
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
 - {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
 
    
sglang/bench_serving.py
CHANGED

@@ -88,6 +88,7 @@ class RequestFuncOutput:
     latency: float = 0.0
     ttft: float = 0.0  # Time to first token
     itl: List[float] = field(default_factory=list)  # List of inter-token latencies
+    text_chunks: List[str] = field(default_factory=list)
     prompt_len: int = 0
     error: str = ""
     output_len: int = 0

@@ -258,6 +259,9 @@ async def async_request_openai_completions(

                             # Decoding phase
                             else:
+                                output.text_chunks.append(
+                                    data["choices"][0]["text"]
+                                )
                                 output.itl.append(timestamp - most_recent_timestamp)

                             most_recent_timestamp = timestamp

@@ -574,9 +578,8 @@ async def async_request_sglang_generate(
                                 num_new_tokens = output_len - last_output_len
                                 if num_new_tokens == 0:
                                     continue
-                                adjust_itl = (
-                                    timestamp - most_recent_timestamp
-                                ) / num_new_tokens
+                                chunk_gap = timestamp - most_recent_timestamp
+                                adjust_itl = chunk_gap / num_new_tokens
                                 output.itl.extend([adjust_itl] * num_new_tokens)

                             most_recent_timestamp = timestamp

@@ -764,6 +767,7 @@ def get_dataset(args, tokenizer, model_id=None):
             image_content=args.image_content,
             image_format=args.image_format,
             image_resolution=args.image_resolution,
+            backend=args.backend,
         )
     elif args.dataset_name == "generated-shared-prefix":
         assert not tokenize_prompt

@@ -781,6 +785,7 @@ def get_dataset(args, tokenizer, model_id=None):
         input_requests = sample_mmmu_requests(
             num_requests=args.num_prompts,
             processor=processor,
+            backend=args.backend,
             fixed_output_len=args.random_output_len,
             random_sample=True,
         )

@@ -1009,6 +1014,7 @@ async def get_mooncake_request_over_time(
 def sample_mmmu_requests(
     num_requests: int,
     processor: AutoProcessor | AutoTokenizer,
+    backend: str,
     fixed_output_len: Optional[int] = None,
     random_sample: bool = True,
 ) -> List[DatasetRow]:

@@ -1081,7 +1087,7 @@ def sample_mmmu_requests(
                 text_prompt = f"Question: {question}\n\nAnswer: "
                 output_len = fixed_output_len if fixed_output_len is not None else 256
                 data_row = create_mm_data_row(
-                    text_prompt, [image], [image_data], output_len, processor
+                    text_prompt, [image], [image_data], output_len, processor, backend
                 )
                 filtered_dataset.append(data_row)

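The `text_chunks` field and the chunk-gap rewrite above serve the same goal: when one streamed chunk delivers several tokens, its arrival gap should be spread over every token in the chunk rather than counted as a single inter-token latency. A minimal self-contained sketch of that adjustment, using hypothetical `gaps` (seconds between chunks) and `counts` (new tokens per chunk):

    # Hypothetical data: three stream chunks and the tokens each carried.
    gaps = [0.030, 0.021, 0.045]
    counts = [1, 2, 3]

    itl = []
    for chunk_gap, num_new_tokens in zip(gaps, counts):
        if num_new_tokens == 0:
            continue  # heartbeat chunk with no new tokens
        adjust_itl = chunk_gap / num_new_tokens
        itl.extend([adjust_itl] * num_new_tokens)

    print(itl)  # [0.03, 0.0105, 0.0105, 0.015, 0.015, 0.015]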
@@ -1316,13 +1322,19 @@ def parse_image_resolution(image_resolution: str) -> Tuple[int, int]:
     )


-def create_mm_data_row(text_prompt, images: list, images_base64, output_len, processor):
+def create_mm_data_row(
+    text_prompt, images: list, images_base64, output_len, processor, backend
+):
     try:
-        content_items = [
-            {"type": "image", "image": {"url": image_base64}}
-            for image_base64 in images_base64
-        ]
-        content_items.append({"type": "text", "text": text_prompt})
+        if type(processor).__name__ == "Phi4MMProcessor":
+            # <|endoftext10|> is the image token used in the phi-4-multimodal model.
+            content_items = text_prompt.replace("image 1", "<|endoftext10|>")
+        else:
+            content_items = [
+                {"type": "image", "image": {"url": image_base64}}
+                for image_base64 in images_base64
+            ]
+            content_items.append({"type": "text", "text": text_prompt})
         prompt_str = processor.apply_chat_template(
             [{"role": "user", "content": content_items}],
             add_generation_prompt=True,

@@ -1362,8 +1374,16 @@ def create_mm_data_row(text_prompt, images: list, images_base64, output_len, processor):
     # Vision tokens = total tokens - text tokens
     vision_prompt_len = prompt_len - text_prompt_len

+    use_raw_prompt = backend in [
+        "sglang-oai",
+        "sglang-oai-chat",
+        "vllm",
+        "vllm-chat",
+        "lmdeploy",
+        "lmdeploy-chat",
+    ]
     return DatasetRow(
-        prompt=text_prompt,
+        prompt=text_prompt if use_raw_prompt else prompt_str,
         prompt_len=prompt_len,
         output_len=output_len,
         text_prompt_len=text_prompt_len,

@@ -1382,6 +1402,7 @@ def sample_image_requests(
     image_content: str,
     image_format: str,
     image_resolution: str,
+    backend: str,
 ) -> List[DatasetRow]:
     """Generate requests with images.


@@ -1447,6 +1468,7 @@ def sample_image_requests(
             list(images_base64),
             int(output_lens[i]),
             processor,
+            backend,
         )

         dataset.append(data_row)

@@ -1607,6 +1629,7 @@ def calculate_metrics(
     dur_s: float,
     tokenizer: PreTrainedTokenizerBase,
     backend: str,
+    accept_length: Optional[float] = None,
 ) -> Tuple[BenchmarkMetrics, List[int]]:
     output_lens: List[int] = []
     retokenized_output_lens: List[int] = []

@@ -1618,6 +1641,14 @@ def calculate_metrics(
     tpots: List[float] = []
     ttfts: List[float] = []
     e2e_latencies: List[float] = []
+    retokenized_itls: List[float] = []
+
+    use_retokenized_itl = (
+        accept_length is not None
+        and accept_length > 0
+        and backend in ("sglang-oai", "sglang-oai-chat")
+    )
+
     for i in range(len(outputs)):
         if outputs[i].success:
             output_len = outputs[i].output_len

@@ -1631,7 +1662,17 @@ def calculate_metrics(
             total_input_vision += input_requests[i].vision_prompt_len
             if output_len > 1:
                 tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
-            itls += outputs[i].itl
+            if use_retokenized_itl:
+                for k, itl in enumerate(outputs[i].itl):
+                    num_tokens = len(
+                        tokenizer.encode(
+                            outputs[i].text_chunks[k], add_special_tokens=False
+                        )
+                    )
+                    adjusted_itl = itl / num_tokens
+                    retokenized_itls.extend([adjusted_itl] * num_tokens)
+            else:
+                itls += outputs[i].itl
             ttfts.append(outputs[i].ttft)

             e2e_latencies.append(outputs[i].latency)

@@ -1647,6 +1688,8 @@ def calculate_metrics(
             "on the benchmark arguments.",
             stacklevel=2,
         )
+
+    itls = retokenized_itls if use_retokenized_itl else itls
     metrics = BenchmarkMetrics(
         completed=completed,
         total_input=total_input,

@@ -1910,6 +1953,7 @@ async def benchmark(
         dur_s=benchmark_duration,
         tokenizer=tokenizer,
         backend=backend,
+        accept_length=accept_length,
     )

     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
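The OpenAI-compatible endpoints do not report per-chunk token counts, so when speculative decoding is active (`accept_length > 0`) `calculate_metrics` re-tokenizes each recorded text chunk and spreads its ITL over the tokens the chunk actually contained. A runnable sketch of that loop, with a hypothetical chunk stream and any Hugging Face tokenizer standing in for the benchmark's:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer

    # Hypothetical stream: chunk texts and the gap (seconds) before each arrived.
    text_chunks = ["Hello", " world, how", " are you"]
    itl = [0.030, 0.060, 0.045]

    retokenized_itls = []
    for chunk, gap in zip(text_chunks, itl):
        num_tokens = len(tokenizer.encode(chunk, add_special_tokens=False))
        # Charge the chunk's gap equally to every token it contained.
        retokenized_itls.extend([gap / num_tokens] * num_tokens)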
    
sglang/launch_server.py
CHANGED

@@ -12,10 +12,12 @@ if __name__ == "__main__":

     try:
         if server_args.grpc_mode:
+            # Handle gRPC server
             from sglang.srt.entrypoints.grpc_server import serve_grpc

             asyncio.run(serve_grpc(server_args))
         else:
+            # Handle HTTP server
             from sglang.srt.entrypoints.http_server import launch_server

             launch_server(server_args)
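Both branches consume the same `ServerArgs`. A hedged sketch of driving them programmatically (model path hypothetical; assumes default `ServerArgs` fields suffice):

    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(model_path="Qwen/Qwen2.5-7B-Instruct")

    if server_args.grpc_mode:
        import asyncio
        from sglang.srt.entrypoints.grpc_server import serve_grpc
        asyncio.run(serve_grpc(server_args))
    else:
        from sglang.srt.entrypoints.http_server import launch_server
        launch_server(server_args)  # blocks until shutdown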
sglang/srt/batch_invariant_ops/batch_invariant_ops.py
CHANGED

@@ -9,6 +9,22 @@ import torch
 import triton
 import triton.language as tl

+from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
+from sglang.srt.utils.common import calc_diff, get_bool_env_var
+
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm
+
+_ENABLE_MM_DEEPGEMM = get_bool_env_var(
+    "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_DEEPGEMM", "1"
+)
+_ENABLE_MM_COMPARISON_TEST = get_bool_env_var(
+    "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_COMPARISON_TEST"
+)
+
+if not _ENABLE_MM_DEEPGEMM:
+    print("Disable DeepGEMM in batch invariant ops. Performance may be suboptimal.")
+
 __all__ = [
     "set_batch_invariant_mode",
     "is_batch_invariant_mode_enabled",

@@ -140,7 +156,7 @@ def matmul_kernel_persistent(
         tl.store(c_ptrs, c, mask=c_mask)


-def matmul_persistent(
+def _matmul_persistent_triton(
     a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
 ):
     # Check constraints.

@@ -217,6 +233,54 @@ def matmul_persistent(
     return c


+def _matmul_persistent_deepgemm(
+    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
+):
+    M, K = a.shape
+    K, N = b.shape
+    dtype = a.dtype
+    out = torch.empty((M, N), device=a.device, dtype=dtype)
+
+    deep_gemm.bf16_gemm_nn(a, b, out)
+
+    # TODO can this be put in DeepGEMM's `c`?
+    if bias is not None:
+        out += bias
+
+    return out
+
+
+def matmul_persistent(
+    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
+):
+    if (
+        _ENABLE_MM_DEEPGEMM
+        and ENABLE_JIT_DEEPGEMM
+        and (a.dtype == torch.bfloat16)
+        and (b.dtype == torch.bfloat16)
+        and a.is_contiguous()
+        and b.transpose(0, 1).is_contiguous()
+    ):
+        if _ENABLE_MM_COMPARISON_TEST:
+            out_triton = _matmul_persistent_triton(a=a, b=b, bias=bias)
+            out_deepgemm = _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
+            diff = calc_diff(out_triton, out_deepgemm)
+            assert diff < 0.0001, f"{diff=} {out_triton=} {out_deepgemm=}"
+            # can be enabled for debugging
+            # print(
+            #     f"{diff=} "
+            #     f"{(out_triton - out_deepgemm).abs().mean()=} "
+            #     f"{(out_triton - out_deepgemm).abs().sum()=} "
+            #     f"{torch.sum(out_triton != out_deepgemm)=} "
+            # )
+            # print(f"{a=} {b=} {bias=} {out_triton=} {out_deepgemm=}")
+            return out_deepgemm
+
+        return _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
+
+    return _matmul_persistent_triton(a=a, b=b, bias=bias)
+
+
 @triton.jit
 def _log_softmax_kernel(
     input_ptr,
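The two `SGLANG_BATCH_INVARIANT_OPS_*` flags above are read once, at module import, so they must be set before the module is first imported. A hypothetical toggle:

    import os

    # Force the Triton fallback (skip DeepGEMM) ...
    os.environ["SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_DEEPGEMM"] = "0"
    # ... or keep DeepGEMM on and assert both backends agree on every matmul:
    # os.environ["SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_COMPARISON_TEST"] = "1"

    from sglang.srt.batch_invariant_ops import batch_invariant_ops  # flags take effect here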
@@ -495,16 +559,39 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
         return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems


+def bmm_batch_invariant(a, b, *, out=None):
+    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
+    # Process each batch separately with our persistent kernel
+    if a.ndim == 3 and b.ndim == 3:
+        results = []
+        for i in range(a.shape[0]):
+            results.append(matmul_persistent(a[i], b[i]))
+        result = torch.stack(results, dim=0)
+
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+    else:
+        raise ValueError(
+            f"bmm_batch_invariant expects 3D tensors, "
+            f"got shapes {a.shape} and {b.shape}"
+        )
+
+
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None
+_original_torch_bmm = None


 def is_batch_invariant_mode_enabled():
     return _batch_invariant_MODE


-def enable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+def enable_batch_invariant_mode(
+    enable_bmm: bool = True,
+):
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     if _batch_invariant_MODE:
         return

@@ -517,11 +604,21 @@ def enable_batch_invariant_mode():
     )
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")

+    if enable_bmm:
+        _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
+
+        # Also monkeypatch torch.bmm directly as a fallback
+        _original_torch_bmm = torch.bmm
+        torch.bmm = bmm_batch_invariant
+

 def disable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     if _batch_invariant_LIB is not None:
         _batch_invariant_LIB._destroy()
+    if _original_torch_bmm is not None:
+        torch.bmm = _original_torch_bmm
+        _original_torch_bmm = None
     _batch_invariant_MODE = False
     _batch_invariant_LIB = None

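With `enable_bmm=True` (the default), batch-invariant mode now routes both the `aten::bmm` dispatch and the `torch.bmm` Python binding through `bmm_batch_invariant`, and `disable_batch_invariant_mode` restores the original binding. A usage sketch (assumes a CUDA device; contiguous bfloat16 inputs can additionally take the DeepGEMM path):

    import torch

    from sglang.srt.batch_invariant_ops.batch_invariant_ops import (
        disable_batch_invariant_mode,
        enable_batch_invariant_mode,
    )

    enable_batch_invariant_mode(enable_bmm=True)
    a = torch.randn(4, 8, 16, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(4, 16, 32, device="cuda", dtype=torch.bfloat16)
    c = torch.bmm(a, b)  # one matmul_persistent call per batch element
    disable_batch_invariant_mode()  # torch.bmm is the original again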
sglang/srt/compilation/backend.py
CHANGED

@@ -392,7 +392,7 @@ class SGLangBackend:
         self.configure_post_pass()

         self.split_gm, self.piecewise_graphs = split_graph(
-            graph, ["sglang.unified_attention_with_output"]
+            graph, ["sglang.unified_attention_with_output", "sglang.inplace_all_reduce"]
         )

         from torch._dynamo.utils import lazy_format_graph_code
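`split_graph` now also cuts the graph at `sglang.inplace_all_reduce`, so in-place all-reduce ops land on piece boundaries instead of inside a compiled piece. A minimal sketch of the idea in plain `torch.fx` (not sglang's `split_graph`): partition a traced graph so a designated "barrier" op gets its own piece:

    import torch
    import torch.fx as fx
    from torch.fx.passes.split_module import split_module

    def barrier(x):  # stand-in for an op that must not live inside a piece
        return x + 1

    fx.wrap("barrier")  # keep barrier as an opaque call when tracing

    def f(x):
        x = x * 2         # piece 0
        x = barrier(x)    # its own piece
        return x - 3      # piece 2

    gm = fx.symbolic_trace(f)

    state = {"part": 0}

    def callback(node):
        if node.op == "call_function" and getattr(node.target, "__name__", "") == "barrier":
            state["part"] += 1        # a barrier opens a fresh partition
            this_part = state["part"]
            state["part"] += 1        # and the next node starts yet another
            return this_part
        return state["part"]

    pieces = split_module(gm, None, callback)
    print([name for name, _ in pieces.named_children()])  # ['submod_0', 'submod_1', 'submod_2']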
sglang/srt/configs/model_config.py
CHANGED

@@ -535,7 +535,7 @@ class ModelConfig:
             quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
         return quant_cfg

-    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> dict:
+    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> Optional[dict]:
         """Parse ModelOpt quantization config and return the appropriate quant_method."""
         json_quant_configs = quant_config_dict["quantization"]
         quant_algo = json_quant_configs.get("quant_algo", None)

@@ -547,8 +547,7 @@ class ModelConfig:
         elif quant_algo and "FP8" in quant_algo:
             return {"quant_method": "modelopt_fp8"}
         else:
-
-            return {"quant_method": "modelopt_fp8"}
+            return None

     def _is_already_quantized(self) -> bool:
         """Check if the model is already quantized based on config files."""

@@ -806,7 +805,7 @@ def _get_and_verify_dtype(
 ) -> torch.dtype:
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, "torch_dtype", None)
+    config_dtype = getattr(config, "dtype", None)
     if isinstance(config_dtype, str):
         config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
     if config_dtype is None:

@@ -915,12 +914,13 @@ multimodal_model_archs = [
     "InternVLChatModel",
     "InternS1ForConditionalGeneration",
     "Phi4MMForCausalLM",
-    "VILAForConditionalGeneration",
     "Step3VLForConditionalGeneration",
     "POINTSV15ChatModel",
     "DotsVLMForCausalLM",
     "DotsOCRForCausalLM",
     "Sarashina2VisionForCausalLM",
+    "NVILAForConditionalGeneration",
+    "NVILALiteForConditionalGeneration",
     "DeepseekOCRForCausalLM",
 ]

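`_get_and_verify_dtype` now reads `config.dtype`, matching newer `transformers` releases that renamed the config attribute from `torch_dtype` to `dtype`. A hedged sketch (hypothetical helper, not sglang's code) of resolving either spelling:

    import torch
    from types import SimpleNamespace

    _STR_DTYPE_TO_TORCH_DTYPE = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }

    def resolve_config_dtype(config, default=torch.float16):
        dtype = getattr(config, "dtype", None)            # new attribute name
        if dtype is None:
            dtype = getattr(config, "torch_dtype", None)  # legacy fallback
        if isinstance(dtype, str):
            dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(dtype)
        return dtype if dtype is not None else default

    print(resolve_config_dtype(SimpleNamespace(dtype="bfloat16")))  # torch.bfloat16
    print(resolve_config_dtype(SimpleNamespace()))                  # torch.float16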
sglang/srt/distributed/parallel_state.py
CHANGED

@@ -340,17 +340,10 @@ class GroupCoordinator:
         self.qr_comm: Optional[QuickAllReduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
-            if torch_compile is not None and torch_compile:
-                # For piecewise CUDA graph, the requirement for custom allreduce is larger to
-                # avoid illegal cuda memory access.
-                ca_max_size = 256 * 1024 * 1024
-            else:
-                ca_max_size = 8 * 1024 * 1024
             try:
                 self.ca_comm = CustomAllreduce(
                     group=self.cpu_group,
                     device=self.device,
-                    max_size=ca_max_size,
                 )
             except Exception as e:
                 logger.warning(
        sglang/srt/entrypoints/engine.py
    CHANGED
    
    | 
         @@ -101,7 +101,7 @@ class Engine(EngineBase): 
     | 
|
| 
       101 
101 
     | 
    
         | 
| 
       102 
102 
     | 
    
         
             
                Note:
         
     | 
| 
       103 
103 
     | 
    
         
             
                1. The HTTP server, Engine, and TokenizerManager all run in the main process.
         
     | 
| 
       104 
     | 
    
         
            -
                2. Inter-process communication  
     | 
| 
      
 104 
     | 
    
         
            +
                2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
         
     | 
| 
       105 
105 
     | 
    
         
             
                """
         
     | 
| 
       106 
106 
     | 
    
         | 
| 
       107 
107 
     | 
    
         
             
                def __init__(self, **kwargs):
         
     | 
| 
         @@ -109,6 +109,8 @@ class Engine(EngineBase): 
     | 
|
| 
       109 
109 
     | 
    
         
             
                    The arguments of this function is the same as `sglang/srt/server_args.py::ServerArgs`.
         
     | 
| 
       110 
110 
     | 
    
         
             
                    Please refer to `ServerArgs` for the documentation.
         
     | 
| 
       111 
111 
     | 
    
         
             
                    """
         
     | 
| 
      
 112 
     | 
    
         
            +
             
     | 
| 
      
 113 
     | 
    
         
            +
                    # Parse server_args
         
     | 
| 
       112 
114 
     | 
    
         
             
                    if "server_args" in kwargs:
         
     | 
| 
       113 
115 
     | 
    
         
             
                        # Directly load server_args
         
     | 
| 
       114 
116 
     | 
    
         
             
                        server_args = kwargs["server_args"]
         
     | 
| 
@@ -118,29 +120,28 @@ class Engine(EngineBase):
                 # Do not print logs by default
                 kwargs["log_level"] = "error"
             server_args = ServerArgs(**kwargs)
+        self.server_args = server_args
+        logger.info(f"{server_args=}")

         # Shutdown the subprocesses automatically when the program exits
         atexit.register(self.shutdown)

-        # Allocate ports for inter-process communications
-        self.port_args = PortArgs.init_new(server_args)
-        logger.info(f"{server_args=}")
-
         # Launch subprocesses
-        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
-            server_args=server_args,
-            port_args=self.port_args,
+        tokenizer_manager, template_manager, scheduler_info, port_args = (
+            _launch_subprocesses(server_args=server_args)
         )
-        self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.template_manager = template_manager
         self.scheduler_info = scheduler_info
+        self.port_args = port_args

+        # Initialize ZMQ sockets
         context = zmq.Context(2)
         self.send_to_rpc = get_zmq_socket(
             context, zmq.DEALER, self.port_args.rpc_ipc_name, True
         )

+        # Enable tracing
         if server_args.enable_trace:
             process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
             if server_args.disaggregation_mode == "null":
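What this hunk amounts to: PortArgs is no longer allocated by Engine.__init__ and threaded into _launch_subprocesses; the launcher now allocates the ports itself and returns them as a fourth element, which the engine stores before opening its ZMQ DEALER socket on port_args.rpc_ipc_name. A self-contained sketch of the new shape (FakePortArgs and launch_subprocesses_sketch are illustrative stand-ins, not sglang APIs):

    from dataclasses import dataclass

    # Illustrative stand-ins only; the real ServerArgs/PortArgs/managers live in sglang.srt.
    @dataclass
    class FakePortArgs:
        rpc_ipc_name: str

    def launch_subprocesses_sketch(model_path: str):
        # Ports are allocated inside the launcher now and handed back to the caller,
        # instead of being created by Engine.__init__ and passed down.
        port_args = FakePortArgs(rpc_ipc_name="ipc:///tmp/sglang_rpc_demo")
        tokenizer_manager, template_manager, scheduler_info = object(), object(), {}
        return tokenizer_manager, template_manager, scheduler_info, port_args

    # Caller side, mirroring the new unpacking in Engine.__init__:
    tokenizer_manager, template_manager, scheduler_info, port_args = (
        launch_subprocesses_sketch("/path/to/model")
    )
    print(port_args.rpc_ipc_name)  # the endpoint the DEALER RPC socket later connects to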
@@ -672,15 +673,17 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
     if not server_args.enable_symm_mem:
         os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
-    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
-
+
     if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
+        # flashinfer uses this environment variable for various kernels from MoE to quant kernels
         os.environ["TRTLLM_ENABLE_PDL"] = "1"

     if os.environ.get("CUTE_DSL_LOG_LEVEL") is None:
         # Default to warning level, to avoid too many logs
         os.environ["CUTE_DSL_LOG_LEVEL"] = "30"
+
     if os.environ.get("CUTE_DSL_LOG_TO_CONSOLE") is None:
         # Need to set log to console, otherwise the log level won't take effect
         os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1"
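The CUTE_DSL_* and TRTLLM_ENABLE_PDL settings all follow a set-a-default-only-if-unset pattern; the PDL flag additionally honours an explicit opt-out. A minimal standalone sketch of the same behaviour (the real code uses the explicit get() checks shown above):

    import os

    # Apply defaults only when the user has not set anything themselves.
    os.environ.setdefault("CUTE_DSL_LOG_LEVEL", "30")       # 30 == logging.WARNING
    os.environ.setdefault("CUTE_DSL_LOG_TO_CONSOLE", "1")

    # PDL keeps an explicit check so that TRTLLM_ENABLE_PDL=0 still opts out.
    if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
        os.environ["TRTLLM_ENABLE_PDL"] = "1"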
@@ -709,7 +712,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.16.
+            "0.3.16.post4",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )

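The version bump keeps the same assert_pkg_version call shape: package name, minimum version, and a hint printed on failure. A rough standalone equivalent, assuming the packaging library is available (the real helper lives elsewhere in sglang and may differ):

    from importlib.metadata import PackageNotFoundError, version
    from packaging.version import Version  # assumes `packaging` is installed

    def check_min_pkg_version(pkg: str, minimum: str, hint: str) -> None:
        # Rough stand-in for sglang's assert_pkg_version helper.
        try:
            installed = Version(version(pkg))
        except PackageNotFoundError as exc:
            raise RuntimeError(f"{pkg} is not installed. {hint}") from exc
        if installed < Version(minimum):
            raise RuntimeError(f"{pkg} {installed} < required {minimum}. {hint}")

    # Raises if sgl-kernel is missing or older than the pinned version.
    check_min_pkg_version(
        "sgl-kernel",
        "0.3.16.post4",
        "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
    )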
@@ -840,7 +843,7 @@ def _launch_subprocesses(

         if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
             # When using `Engine` as a Python API, we don't want to block here.
-            return None, None, None
+            return None, None, None, port_args

         launch_dummy_health_check_server(
             server_args.host, server_args.port, server_args.enable_metrics
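Both early-return paths have to grow a fourth element because the caller in Engine.__init__ now unpacks exactly four values; a 3-tuple there would fail with ValueError. A tiny illustration:

    # Why the early returns grow a fourth element: the caller unpacks exactly four values.
    def old_style_launch():
        return None, None, None             # pre-change early return

    def new_style_launch(port_args="PortArgs(...)"):
        return None, None, None, port_args  # post-change early return

    try:
        a, b, c, d = old_style_launch()     # mirrors Engine.__init__'s unpacking
    except ValueError as err:
        print(err)                          # not enough values to unpack (expected 4, got 3)

    a, b, c, d = new_style_launch()         # unpacks cleanly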
@@ -851,7 +854,7 @@ def _launch_subprocesses(
             logger.error(
                 f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
             )
-        return None, None, None
+        return None, None, None, port_args

     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -897,4 +900,4 @@ def _launch_subprocesses(

     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]

-    return tokenizer_manager, template_manager, scheduler_info
+    return tokenizer_manager, template_manager, scheduler_info, port_args
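If the widened contract were documented at the signature level, a hypothetical type alias (not part of this diff) could make the arity explicit:

    from typing import Any, Dict, Optional, Tuple

    # Hypothetical annotation for the widened return contract; the diff itself does not add one.
    LaunchResult = Tuple[
        Optional[Any],             # tokenizer_manager (None on non-zero-rank nodes)
        Optional[Any],             # template_manager
        Optional[Dict[str, Any]],  # scheduler_info
        Any,                       # port_args, now always returned so the caller can open ZMQ sockets
    ]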
@@ -999,7 +999,6 @@ def _wait_and_warmup_grpc(
     # Mark health service as SERVING after warmup completes
     if health_servicer:
         health_servicer.set_serving()
-        logger.info("Health service marked as SERVING")

     logger.info("The server is fired up and ready to roll!")
