sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
 - sglang/bench_latency.py +28 -10
 - sglang/bench_server_latency.py +21 -10
 - sglang/bench_serving.py +101 -7
 - sglang/global_config.py +0 -1
 - sglang/srt/layers/attention/__init__.py +27 -5
 - sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
 - sglang/srt/layers/attention/flashinfer_backend.py +352 -83
 - sglang/srt/layers/attention/triton_backend.py +6 -4
 - sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
 - sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
 - sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
 - sglang/srt/layers/sampler.py +6 -2
 - sglang/srt/managers/detokenizer_manager.py +31 -10
 - sglang/srt/managers/io_struct.py +4 -0
 - sglang/srt/managers/schedule_batch.py +120 -43
 - sglang/srt/managers/schedule_policy.py +2 -1
 - sglang/srt/managers/scheduler.py +202 -140
 - sglang/srt/managers/tokenizer_manager.py +5 -1
 - sglang/srt/managers/tp_worker.py +111 -1
 - sglang/srt/mem_cache/chunk_cache.py +8 -4
 - sglang/srt/mem_cache/memory_pool.py +77 -4
 - sglang/srt/mem_cache/radix_cache.py +15 -7
 - sglang/srt/model_executor/cuda_graph_runner.py +4 -4
 - sglang/srt/model_executor/forward_batch_info.py +16 -21
 - sglang/srt/model_executor/model_runner.py +60 -1
 - sglang/srt/models/baichuan.py +2 -3
 - sglang/srt/models/chatglm.py +5 -6
 - sglang/srt/models/commandr.py +1 -2
 - sglang/srt/models/dbrx.py +1 -2
 - sglang/srt/models/deepseek.py +4 -5
 - sglang/srt/models/deepseek_v2.py +5 -6
 - sglang/srt/models/exaone.py +1 -2
 - sglang/srt/models/gemma.py +2 -2
 - sglang/srt/models/gemma2.py +5 -5
 - sglang/srt/models/gpt_bigcode.py +5 -5
 - sglang/srt/models/grok.py +1 -2
 - sglang/srt/models/internlm2.py +1 -2
 - sglang/srt/models/llama.py +1 -2
 - sglang/srt/models/llama_classification.py +1 -2
 - sglang/srt/models/llama_reward.py +2 -3
 - sglang/srt/models/llava.py +4 -8
 - sglang/srt/models/llavavid.py +1 -2
 - sglang/srt/models/minicpm.py +1 -2
 - sglang/srt/models/minicpm3.py +5 -6
 - sglang/srt/models/mixtral.py +1 -2
 - sglang/srt/models/mixtral_quant.py +1 -2
 - sglang/srt/models/olmo.py +352 -0
 - sglang/srt/models/olmoe.py +1 -2
 - sglang/srt/models/qwen.py +1 -2
 - sglang/srt/models/qwen2.py +1 -2
 - sglang/srt/models/qwen2_moe.py +4 -5
 - sglang/srt/models/stablelm.py +1 -2
 - sglang/srt/models/torch_native_llama.py +1 -2
 - sglang/srt/models/xverse.py +1 -2
 - sglang/srt/models/xverse_moe.py +4 -5
 - sglang/srt/models/yivl.py +1 -2
 - sglang/srt/openai_api/adapter.py +92 -49
 - sglang/srt/openai_api/protocol.py +10 -2
 - sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
 - sglang/srt/sampling/sampling_batch_info.py +92 -58
 - sglang/srt/sampling/sampling_params.py +2 -0
 - sglang/srt/server.py +116 -17
 - sglang/srt/server_args.py +121 -45
 - sglang/srt/utils.py +11 -3
 - sglang/test/few_shot_gsm8k.py +4 -1
 - sglang/test/few_shot_gsm8k_engine.py +144 -0
 - sglang/test/srt/sampling/penaltylib/utils.py +16 -12
 - sglang/version.py +1 -1
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
 - sglang/srt/layers/attention/flashinfer_utils.py +0 -237
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
 - {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
 
    
sglang/bench_latency.py  CHANGED

@@ -232,17 +232,18 @@ def extend(reqs, model_runner):
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits, batch


 @torch.inference_mode()
 def decode(input_token_ids, batch, model_runner):
-    batch.
+    batch.output_ids = input_token_ids
+    batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
-    next_token_ids = model_runner.sample(logits_output, forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch)
     return next_token_ids, logits_output.next_token_logits

@@ -252,6 +253,7 @@ def correctness_test(
     bench_args,
     tp_rank,
 ):
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model

@@ -279,8 +281,9 @@ def correctness_test(
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
     for _ in range(bench_args.output_len[0] - 1):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
+        next_token_ids_list = next_token_ids.tolist()
         for i in range(len(reqs)):
-            output_ids[i].append(
+            output_ids[i].append(next_token_ids_list[i])

     # Print
     for i in range(len(reqs)):

@@ -288,8 +291,15 @@ def correctness_test(
         rank_print(tokenizer.decode(output_ids[i]), "\n")


+def synchronize(device):
+    if device == "cuda":
+        torch.cuda.synchronize()
+    elif device == "xpu":
+        torch.xpu.synchronize()
+
+
 def latency_test_run_once(
-    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len
+    run_name, model_runner, rank_print, reqs, batch_size, input_len, output_len, device
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
     if batch_size > max_batch_size:

@@ -312,10 +322,10 @@ def latency_test_run_once(
     tot_latency = 0

     # Prefill
-
+    synchronize(device)
     tic = time.time()
     next_token_ids, _, batch = extend(reqs, model_runner)
-
+    synchronize(device)
     prefill_latency = time.time() - tic
     tot_latency += prefill_latency
     throughput = input_len * batch_size / prefill_latency

@@ -328,10 +338,10 @@ def latency_test_run_once(
     # Decode
     decode_latencies = []
     for i in range(output_len - 1):
-
+        synchronize(device)
         tic = time.time()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
-
+        synchronize(device)
         latency = time.time() - tic
         tot_latency += latency
         throughput = batch_size / latency

@@ -387,6 +397,7 @@ def latency_test(
         bench_args.batch_size[0],
         bench_args.input_len[0],
         8,  # shorter decoding to speed up the warmup
+        server_args.device,
     )
     rank_print("Benchmark ...")

@@ -397,7 +408,14 @@ def latency_test(
     ):
         reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
         ret = latency_test_run_once(
-            bench_args.run_name,
+            bench_args.run_name,
+            model_runner,
+            rank_print,
+            reqs,
+            bs,
+            il,
+            ol,
+            server_args.device,
         )
         if ret is not None:
             result_list.append(ret)
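A note on the synchronize(device) changes above: accelerator kernels launch asynchronously, so reading time.time() without a barrier measures launch overhead rather than execution time. A minimal self-contained sketch of the same timing pattern (the matmul workload is illustrative, not from the package):

    import time

    import torch


    def synchronize(device: str) -> None:
        # Block the host until all queued kernels on the accelerator finish.
        if device == "cuda":
            torch.cuda.synchronize()
        elif device == "xpu":
            torch.xpu.synchronize()


    def time_op(fn, device: str) -> float:
        synchronize(device)  # drain pending work before starting the clock
        tic = time.time()
        fn()
        synchronize(device)  # wait for fn's kernels before stopping the clock
        return time.time() - tic


    if torch.cuda.is_available():
        a = torch.randn(1024, 1024, device="cuda")
        print(f"matmul latency: {time_op(lambda: a @ a, 'cuda') * 1e3:.2f} ms")

This is why the benchmark now calls synchronize() on both sides of each timed region and threads device through latency_test_run_once.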
    
sglang/bench_server_latency.py  CHANGED

@@ -6,6 +6,8 @@ It accepts arguments similar to those of launch_server.py.
 Usage:

 python3 -m sglang.bench_server_latency --model meta-llama/Meta-Llama-3.1-8B --batch-size 1 16 64 --input-len 1024 --output-len 8
+
+python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 """

 import argparse

@@ -32,6 +34,8 @@ class BenchArgs:
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
     result_filename: str = "result.jsonl"
+    base_url: str = ""
+    skip_warmup: bool = False

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):

@@ -48,6 +52,8 @@ class BenchArgs:
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
+        parser.add_argument("--base-url", type=str, default=BenchArgs.base_url)
+        parser.add_argument("--skip-warmup", action="store_true")

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):

@@ -139,17 +145,21 @@ def run_one_case(


 def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
-
+    if bench_args.base_url:
+        proc, base_url = None, bench_args.base_url
+    else:
+        proc, base_url = launch_server_process(server_args)

     # warmup
-
-
-
-
-
-
-
-
+    if not bench_args.skip_warmup:
+        run_one_case(
+            base_url,
+            batch_size=16,
+            input_len=1024,
+            output_len=16,
+            run_name="",
+            result_filename="",
+        )

     # benchmark
     try:

@@ -165,7 +175,8 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             bench_args.result_filename,
         )
     finally:
-
+        if proc:
+            kill_child_process(proc.pid)

     print(f"\nResults are saved to {bench_args.result_filename}")
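Together, --base-url and --skip-warmup let the script benchmark a server you manage yourself instead of spawning and later killing a child process. A plausible two-step invocation (model path and port are illustrative):

    # Terminal 1: start a server (flags illustrative)
    python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B --port 30000

    # Terminal 2: benchmark it; no child process is spawned or killed
    python3 -m sglang.bench_server_latency --model None --base-url http://localhost:30000 \
        --skip-warmup --batch-size 16 --input-len 1024 --output-len 8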
    
sglang/bench_serving.py  CHANGED

@@ -222,6 +222,85 @@ async def async_request_openai_completions(
     return output


+async def async_request_sglang_generate(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+    prompt = request_func_input.prompt
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        payload = {
+            "text": prompt,
+            "sampling_params": {
+                "temperature": 0.0,
+                "max_new_tokens": request_func_input.output_len,
+                "ignore_eos": not args.disable_ignore_eos,
+            },
+            "stream": not args.disable_stream,
+            **request_func_input.extra_request_body,
+        }
+        headers = {}
+
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        generated_text = ""
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+                        # print(chunk_bytes)
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
+                        if chunk == "[DONE]":
+                            pass
+                        else:
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion API might have a last
+                            # usage summary response without a token so we
+                            # want to check a token was generated
+                            if data["text"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp - most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text = data["text"]
+
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = latency
+                    output.output_len = request_func_input.output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_gserver(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,

@@ -264,7 +343,9 @@ def get_tokenizer(


 ASYNC_REQUEST_FUNCS = {
-    "sglang":
+    "sglang": async_request_sglang_generate,
+    "sglang-native": async_request_sglang_generate,
+    "sglang-oai": async_request_openai_completions,
     "vllm": async_request_openai_completions,
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,

@@ -387,6 +468,8 @@ def sample_sharegpt_requests(
             continue
         filtered_dataset.append((prompt, prompt_len, output_len))

+    print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
+    print(f"#Output tokens: {np.sum([x[2] for x in filtered_dataset])}")
     return filtered_dataset


@@ -587,6 +670,8 @@ async def benchmark(
     else:
         print("Initial test run completed. Starting main benchmark run...")

+    time.sleep(1.5)
+
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))

     benchmark_start_time = time.perf_counter()

@@ -782,24 +867,33 @@ def run_benchmark(args_: argparse.Namespace):
     if args.port is None:
         args.port = {
             "sglang": 30000,
+            "sglang-native": 30000,
+            "sglang-oai": 30000,
             "lmdeploy": 23333,
             "vllm": 8000,
             "trt": 8000,
             "gserver": 9988,
         }.get(args.backend, 30000)

-    api_url = (
-        f"{args.base_url}/v1/completions"
-        if args.base_url
-        else f"http://{args.host}:{args.port}/v1/completions"
-    )
     model_url = (
         f"{args.base_url}/v1/models"
         if args.base_url
         else f"http://{args.host}:{args.port}/v1/models"
     )

-    if args.backend
+    if args.backend in ["sglang", "sglang-native"]:
+        api_url = (
+            f"{args.base_url}/generate"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/generate"
+        )
+    elif args.backend in ["sglang-oai", "vllm", "lmdeploy"]:
+        api_url = (
+            f"{args.base_url}/v1/completions"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/completions"
+        )
+    elif args.backend == "trt":
         api_url = (
             f"{args.base_url}/v2/models/ensemble/generate_stream"
             if args.base_url
     | 
    
sglang/global_config.py  CHANGED

@@ -19,7 +19,6 @@ class GlobalConfig:
         self.new_token_ratio_decay = 0.001

         # Runtime constants: others
-        self.num_continue_decode_steps = 10
         self.retract_decode_steps = 20
         self.flashinfer_workspace_size = os.environ.get(
             "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
sglang/srt/layers/attention/__init__.py  CHANGED

@@ -1,5 +1,6 @@
 from abc import ABC, abstractmethod

+import torch
 from torch import nn

 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -18,13 +19,13 @@ class AttentionBackend(ABC):
         raise NotImplementedError()

     def init_forward_metadata_capture_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()

     def init_forward_metadata_replay_cuda_graph(
-        self, bs: int, req_pool_indices, seq_lens
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()

@@ -33,17 +34,38 @@ class AttentionBackend(ABC):
         """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
         raise NotImplementedError()

-    def forward(
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(q, k, v, layer, forward_batch)
         else:
             return self.forward_extend(q, k, v, layer, forward_batch)

-    def forward_decode(
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run a forward for decode."""
         raise NotImplementedError()

-    def forward_extend(
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: nn.Module,
+        forward_batch: ForwardBatch,
+    ):
         """Run a forward for extend."""
         raise NotImplementedError()
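The fully annotated signatures make the backend contract explicit: a concrete backend implements the decode and extend kernels, and the base forward() dispatches between them based on the batch's forward mode. A skeletal (illustrative) subclass:

    import torch
    from torch import nn

    from sglang.srt.layers.attention import AttentionBackend
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch


    class MyAttnBackend(AttentionBackend):
        """Illustrative skeleton; a real backend wires in actual kernels."""

        def init_forward_metadata(self, forward_batch: ForwardBatch):
            # Precompute whatever indices/workspaces the kernels need.
            self.forward_metadata = None

        def forward_decode(
            self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
            layer: nn.Module, forward_batch: ForwardBatch,
        ):
            raise NotImplementedError("decode kernel goes here")

        def forward_extend(
            self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
            layer: nn.Module, forward_batch: ForwardBatch,
        ):
            raise NotImplementedError("extend (prefill) kernel goes here")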
         @@ -0,0 +1,281 @@ 
     | 
|
| 
      
 1 
     | 
    
         
            +
            from __future__ import annotations
         
     | 
| 
      
 2 
     | 
    
         
            +
             
     | 
| 
      
 3 
     | 
    
         
            +
            from typing import TYPE_CHECKING
         
     | 
| 
      
 4 
     | 
    
         
            +
             
     | 
| 
      
 5 
     | 
    
         
            +
            import torch
         
     | 
| 
      
 6 
     | 
    
         
            +
            import torch.nn as nn
         
     | 
| 
      
 7 
     | 
    
         
            +
             
     | 
| 
      
 8 
     | 
    
         
            +
            from sglang.srt.layers.attention import AttentionBackend
         
     | 
| 
      
 9 
     | 
    
         
            +
            from sglang.srt.managers.schedule_batch import global_server_args_dict
         
     | 
| 
      
 10 
     | 
    
         
            +
            from sglang.srt.model_executor.forward_batch_info import ForwardBatch
         
     | 
| 
      
 11 
     | 
    
         
            +
             
     | 
| 
      
 12 
     | 
    
         
            +
            if TYPE_CHECKING:
         
     | 
| 
      
 13 
     | 
    
         
            +
                from sglang.srt.model_executor.model_runner import ModelRunner
         
     | 
| 
      
 14 
     | 
    
         
            +
             
     | 
| 
      
 15 
     | 
    
         
            +
             
     | 
| 
      
 16 
     | 
    
         
            +
            class DoubleSparseAttnBackend(AttentionBackend):
         
     | 
| 
      
 17 
     | 
    
         
            +
                def __init__(self, model_runner: ModelRunner):
         
     | 
| 
      
 18 
     | 
    
         
            +
                    # Lazy import to avoid the initialization of cuda context
         
     | 
| 
      
 19 
     | 
    
         
            +
                    from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import (
         
     | 
| 
      
 20 
     | 
    
         
            +
                        flash_decode_attention_fwd,
         
     | 
| 
      
 21 
     | 
    
         
            +
                        flash_decode_sparse_attention_fwd,
         
     | 
| 
      
 22 
     | 
    
         
            +
                    )
         
     | 
| 
      
 23 
     | 
    
         
            +
                    from sglang.srt.layers.attention.triton_ops.extend_attention import (
         
     | 
| 
      
 24 
     | 
    
         
            +
                        extend_attention_fwd,
         
     | 
| 
      
 25 
     | 
    
         
            +
                    )
         
     | 
| 
      
 26 
     | 
    
         
            +
             
     | 
| 
      
 27 
     | 
    
         
            +
                    super().__init__()
         
     | 
| 
      
 28 
     | 
    
         
            +
             
     | 
| 
      
 29 
     | 
    
         
            +
                    self.decode_attention_fwd = flash_decode_attention_fwd
         
     | 
| 
      
 30 
     | 
    
         
            +
                    self.decode_sparse_attention_fwd = flash_decode_sparse_attention_fwd
         
     | 
| 
      
 31 
     | 
    
         
            +
                    self.extend_attention_fwd = extend_attention_fwd
         
     | 
| 
      
 32 
     | 
    
         
            +
                    self.num_head = model_runner.model_config.num_attention_heads
         
     | 
| 
      
 33 
     | 
    
         
            +
                    self.head_dim = model_runner.model_config.hidden_size // self.num_head
         
     | 
| 
      
 34 
     | 
    
         
            +
                    self.heavy_token_num = model_runner.server_args.ds_heavy_token_num
         
     | 
| 
      
 35 
     | 
    
         
            +
             
     | 
| 
      
 36 
     | 
    
         
            +
                    self.sorted_channels = model_runner.sorted_channels
         
     | 
| 
      
 37 
     | 
    
         
            +
                    self.sparse_decode_thresold = (
         
     | 
| 
      
 38 
     | 
    
         
            +
                        model_runner.server_args.ds_sparse_decode_threshold
         
     | 
| 
      
 39 
     | 
    
         
            +
                    )
         
     | 
| 
      
 40 
     | 
    
         
            +
                    self.att_out_approx: torch.Tensor = None
         
     | 
| 
      
 41 
     | 
    
         
            +
                    self.mid_out: torch.Tensor = None
         
     | 
| 
      
 42 
     | 
    
         
            +
                    self.mid_o_logexpsum: torch.Tensor = None
         
     | 
| 
      
 43 
     | 
    
         
            +
             
     | 
| 
      
 44 
     | 
    
         
            +
                    # TODO: Change the hard-coded block_seq_num
         
     | 
| 
      
 45 
     | 
    
         
            +
                    self.BLOCK_SEQ = 128
         
     | 
| 
      
 46 
     | 
    
         
            +
             
     | 
| 
      
 47 
     | 
    
         
            +
                    if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
         
     | 
| 
      
 48 
     | 
    
         
            +
                        self.reduce_dtype = torch.float32
         
     | 
| 
      
 49 
     | 
    
         
            +
                    else:
         
     | 
| 
      
 50 
     | 
    
         
            +
                        self.reduce_dtype = torch.float16
         
     | 
| 
      
 51 
     | 
    
         
            +
             
     | 
| 
      
 52 
     | 
    
         
            +
                    self.forward_metadata = None
         
     | 
| 
      
 53 
     | 
    
         
            +
             
     | 
| 
      
 54 
     | 
    
         
            +
                    self.cuda_graph_max_seq_len = model_runner.model_config.context_len
         
     | 
| 
      
 55 
     | 
    
         
            +
             
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init auxiliary variables for triton attention backend."""
+
+        if forward_batch.forward_mode.is_decode():
+            start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
+            start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)
+
+            total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+            attn_logits = torch.empty(
+                (self.num_head, total_num_tokens),
+                dtype=self.reduce_dtype,
+                device="cuda",
+            )
+
+            max_seq_len = torch.max(forward_batch.seq_lens).item()
+            min_seq_len = torch.min(forward_batch.seq_lens).item()
+            max_extend_len = None
+            # NOTE: Align sequence order with req_to_token order
+            ds_req_to_token = forward_batch.req_to_token_pool.req_to_token[
+                forward_batch.req_pool_indices
+            ]
+
+            bsz = forward_batch.seq_lens.shape[0]
+
+            att_out_approx = torch.empty(
+                [self.num_head, bsz, max_seq_len],
+                dtype=self.reduce_dtype,
+                device="cuda",
+            )
+
+            block_seq_num = (
+                self.heavy_token_num + self.BLOCK_SEQ - 1
+            ) // self.BLOCK_SEQ
+
+            mid_out = torch.empty(
+                [bsz, self.num_head, block_seq_num, self.head_dim],
+                dtype=torch.float32,
+                device="cuda",
+            )
+            mid_o_logexpsum = torch.empty(
+                [bsz, self.num_head, block_seq_num], dtype=torch.float32, device="cuda"
+            )
+            self.att_out_approx = att_out_approx
+            self.mid_out = mid_out
+            self.mid_o_logexpsum = mid_o_logexpsum
+
+        else:
+            start_loc = attn_logits = max_seq_len = min_seq_len = None
+            prefix_lens = forward_batch.extend_prefix_lens
+            max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
+            ds_req_to_token = None
+
+        self.forward_metadata = (
+            start_loc,
+            attn_logits,
+            max_seq_len,
+            min_seq_len,
+            max_extend_len,
+            ds_req_to_token,
+        )
+
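Editor's note on the decode branch above: all requests' KV tokens live in one packed buffer, so each request's offset is an exclusive cumulative sum of `seq_lens`, and the sparse kernel's intermediate buffers are sized by a ceiling division of the heavy-token budget. A minimal standalone sketch with hypothetical values (not part of the package):

import torch

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # hypothetical decode batch

# Exclusive cumsum: where each request's tokens start in the packed buffer.
start_loc = torch.zeros_like(seq_lens)
start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
print(start_loc.tolist())          # [0, 3, 8]
print(int(seq_lens.sum().item()))  # 10 == total_num_tokens, sizes attn_logits

# Ceiling division, as used for block_seq_num above (values assumed).
heavy_token_num, BLOCK_SEQ = 512, 128
block_seq_num = (heavy_token_num + BLOCK_SEQ - 1) // BLOCK_SEQ
print(block_seq_num)               # 4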
             
+    def init_cuda_graph_state(self, max_bs: int):
+        # TODO(Andy): Support CUDA graph for double sparse attention
+        raise ValueError(
+            "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
+
+        self.cuda_graph_start_loc = torch.zeros(
+            (max_bs,), dtype=torch.int32, device="cuda"
+        )
+        self.cuda_graph_attn_logits = torch.empty(
+            (
+                self.num_head,
+                self.cuda_graph_max_total_num_tokens,
+            ),
+            dtype=self.reduce_dtype,
+            device="cuda",
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+    ):
+        self.forward_metadata = (
+            self.cuda_graph_start_loc,
+            self.cuda_graph_attn_logits,
+            self.cuda_graph_max_seq_len,
+            None,
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, bs: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor
+    ):
+        self.cuda_graph_start_loc.zero_()
+        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1
+
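Editor's note: `init_cuda_graph_state` raises unconditionally, so the buffer allocations after the `raise` are dead code kept for a future CUDA-graph path. In this release a server using the double sparsity backend therefore has to run with CUDA graphs off, e.g. `python -m sglang.launch_server --model-path <model> --disable-cuda-graph` (the `--disable-cuda-graph` flag is the one named by the error string itself; the exact launch incantation is an assumption based on sglang's CLI of this era).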
             
+    def forward_extend(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        k_label = torch.gather(
+            k,
+            2,
+            self.sorted_channels[layer.layer_id]
+            .unsqueeze(0)
+            .expand(k.shape[0], -1, -1),
+        )
+
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
+        )
+
+        (
+            start_loc,
+            attn_logits,
+            max_seq_len,
+            min_seq_len,
+            max_extend_len,
+            ds_req_to_token,
+        ) = self.forward_metadata
+        self.extend_attention_fwd(
+            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+            k.contiguous(),
+            v.contiguous(),
+            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_seq_lens,
+            forward_batch.extend_start_loc,
+            max_extend_len,
+            layer.scaling,
+            layer.logit_cap,
+        )
+        return o
+
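Editor's note: the `k_label` gather above is the heart of the double-sparsity scheme: only a handful of "heavy" channels of each key head are cached as labels, which the decode path later uses for approximate scoring. A standalone sketch of that indexing with made-up shapes (`sorted_channels` here is a random stand-in for the offline per-layer channel ranking the backend loads):

import torch

tokens, heads, head_dim, heavy = 4, 2, 8, 3
k = torch.randn(tokens, heads, head_dim)

# Stand-in for self.sorted_channels[layer.layer_id]: per-head indices of the
# heavy channels, normally produced from an offline channel-config file.
sorted_channels = torch.stack(
    [torch.randperm(head_dim)[:heavy] for _ in range(heads)]
)

# Same gather as forward_extend/forward_decode: keep only the heavy channels.
k_label = torch.gather(k, 2, sorted_channels.unsqueeze(0).expand(tokens, -1, -1))
print(k_label.shape)  # torch.Size([4, 2, 3])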
             
+    def forward_decode(self, q, k, v, layer: nn.Module, forward_batch: ForwardBatch):
+        # During torch.compile, there is a bug in rotary_emb that causes the
+        # output value to have a 3D tensor shape. This reshapes the output correctly.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        # TODO: reuse the buffer across layers
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        # TODO: Add min seqlen
+        (
+            start_loc,
+            attn_logits,
+            max_seq_len,
+            min_seq_len,
+            max_extend_len,
+            ds_req_to_token,
+        ) = self.forward_metadata
+
+        k_label = torch.gather(
+            k,
+            2,
+            self.sorted_channels[layer.layer_id]
+            .unsqueeze(0)
+            .expand(k.shape[0], -1, -1),
+        )
+
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer.layer_id, forward_batch.out_cache_loc, k, v, k_label
+        )
+
+        # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
+        #            and set a minimum value for sparse_decode
+        if (
+            min_seq_len < self.heavy_token_num
+            or max_seq_len < self.sparse_decode_thresold
+        ):
+            self.decode_attention_fwd(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                forward_batch.req_to_token_pool.req_to_token,
+                forward_batch.req_pool_indices,
+                start_loc,
+                forward_batch.seq_lens,
+                attn_logits,
+                max_seq_len,
+                layer.scaling,
+                layer.logit_cap,
+            )
+        else:
+            # TODO(Andy): indexing with torch.gather or torch.index_select or customized kernel
+            q_label = torch.gather(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                2,
+                self.sorted_channels[layer.layer_id]
+                .unsqueeze(0)
+                .expand(q.shape[0], -1, -1),
+            )
+            self.decode_sparse_attention_fwd(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+                o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                q_label,
+                forward_batch.token_to_kv_pool.get_label_buffer(layer.layer_id),
+                ds_req_to_token,
+                forward_batch.seq_lens,
+                max_seq_len,
+                layer.scaling,
+                layer.logit_cap,
+                self.heavy_token_num,
+                self.att_out_approx,
+                self.mid_out,
+                self.mid_o_logexpsum,
+                self.BLOCK_SEQ,
+            )
+
+        return o
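Editor's note on the decode dispatch: the dense Triton kernel is taken whenever the shortest sequence holds fewer tokens than the heavy-token budget or the longest sequence is below the sparse-decode threshold (the `sparse_decode_thresold` spelling is the source's own); only then does the sparse path run its approximate scoring over the cached labels followed by exact attention on the selected tokens. A toy version of that predicate, with assumed values:

# Assumed config values, for illustration only.
heavy_token_num, sparse_decode_thresold = 256, 1024
min_seq_len, max_seq_len = 300, 4096

use_dense = min_seq_len < heavy_token_num or max_seq_len < sparse_decode_thresold
print("dense decode" if use_dense else "sparse decode")  # sparse decode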