sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
The expanded hunks below cover three of the changed files: sglang/srt/mem_cache/memory_pool.py, sglang/srt/model_executor/cuda_graph_runner.py, and sglang/srt/model_executor/forward_batch_info.py.

sglang/srt/mem_cache/memory_pool.py

```diff
@@ -27,10 +27,11 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
+import torch.distributed as dist
 import triton
 import triton.language as tl
 
@@ -66,6 +67,7 @@ class ReqToTokenPool:
         self.req_to_token = torch.zeros(
             (size, max_context_len), dtype=torch.int32, device=device
         )
+
         self.free_slots = list(range(size))
 
     def write(self, indices, values):
@@ -121,6 +123,7 @@ class KVCache(abc.ABC):
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=enable_memory_saver
         )
+        self.mem_usage = 0
 
         # used for chunked cpu-offloading
         self.cpu_offloading_chunk_size = 8192
@@ -191,7 +194,6 @@ class MHATokenToKVPool(KVCache):
             start_layer,
             end_layer,
         )
-
         self.head_num = head_num
         self.head_dim = head_dim
 
@@ -218,6 +220,7 @@
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
         )
+        self.mem_usage = (k_size + v_size) / GB
 
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
```
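The new `mem_usage` field on `KVCache` records, in GB, the same figure already printed by the log line above. A back-of-envelope sketch of the arithmetic for the MHA pool (all dimensions below are assumed examples, not values from this diff):

```python
# K and V each hold num_tokens * head_num * head_dim elements per layer.
num_tokens, head_num, head_dim, layer_num = 100_000, 8, 128, 32  # assumed dims
bytes_per_elem = 2  # e.g. bf16
GB = 1 << 30

k_size = v_size = num_tokens * head_num * head_dim * layer_num * bytes_per_elem
mem_usage = (k_size + v_size) / GB
print(f"{mem_usage:.2f} GB")  # ~12.21 GB
```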
```diff
@@ -392,10 +395,14 @@
         cache_v: torch.Tensor,
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
+        layer_id_override: Optional[int] = None,
     ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
-
+        if layer_id_override is not None:
+            layer_id = layer_id_override
+        else:
+            layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
                 cache_k.div_(k_scale)
@@ -431,6 +438,206 @@
         )
 
 
+class SWAKVPool(KVCache):
+    """KV cache with separate pools for full and SWA attention layers."""
+
+    def __init__(
+        self,
+        size: int,
+        size_swa: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        swa_attention_layer_ids: List[int],
+        full_attention_layer_ids: List[int],
+        enable_kvcache_transpose: bool,
+        device: str,
+    ):
+        self.size = size
+        self.size_swa = size_swa
+        self.dtype = dtype
+        self.device = device
+        self.swa_layer_nums = len(swa_attention_layer_ids)
+        self.full_layer_nums = len(full_attention_layer_ids)
+        self.page_size = 1
+        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
+        assert not enable_kvcache_transpose
+        TokenToKVPoolClass = MHATokenToKVPool
+        self.swa_kv_pool = TokenToKVPoolClass(
+            size=size_swa,
+            page_size=self.page_size,
+            dtype=dtype,
+            head_num=head_num,
+            head_dim=head_dim,
+            layer_num=self.swa_layer_nums,
+            device=device,
+            enable_memory_saver=False,
+        )
+        self.full_kv_pool = TokenToKVPoolClass(
+            size=size,
+            page_size=self.page_size,
+            dtype=dtype,
+            head_num=head_num,
+            head_dim=head_dim,
+            layer_num=self.full_layer_nums,
+            device=device,
+            enable_memory_saver=False,
+        )
+        self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
+        for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
+            self.layers_mapping[global_layer_id] = (full_attn_layer_id, False)
+        for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids):
+            self.layers_mapping[global_layer_id] = (swa_layer_id, True)
+        self.full_to_swa_index_mapping: Optional[torch.Tensor] = None
+
+    def get_kv_size_bytes(self):
+        raise NotImplementedError
+
+    def get_contiguous_buf_infos(self):
+        full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
+            self.full_kv_pool.get_contiguous_buf_infos()
+        )
+        swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
+            self.swa_kv_pool.get_contiguous_buf_infos()
+        )
+
+        kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs
+        kv_data_lens = full_kv_data_lens + swa_kv_data_lens
+        kv_item_lens = full_kv_item_lens + swa_kv_item_lens
+
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
+    def get_key_buffer(self, layer_id: int):
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            return self.swa_kv_pool.get_key_buffer(layer_id_pool)
+        else:
+            return self.full_kv_pool.get_key_buffer(layer_id_pool)
+
+    def get_value_buffer(self, layer_id: int):
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            return self.swa_kv_pool.get_value_buffer(layer_id_pool)
+        else:
+            return self.full_kv_pool.get_value_buffer(layer_id_pool)
+
+    def get_kv_buffer(self, layer_id: int):
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            return self.swa_kv_pool.get_kv_buffer(layer_id_pool)
+        else:
+            return self.full_kv_pool.get_kv_buffer(layer_id_pool)
+
+    def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
+        assert self.full_to_swa_index_mapping is not None
+        return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+    ):
+
+        layer_id = layer.layer_id
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            if self.full_to_swa_index_mapping is not None:
+                loc = self.translate_loc_from_full_to_swa(loc)
+            self.swa_kv_pool.set_kv_buffer(
+                None,
+                loc,
+                cache_k,
+                cache_v,
+                k_scale,
+                v_scale,
+                layer_id_override=layer_id_pool,
+            )
+        else:
+            self.full_kv_pool.set_kv_buffer(
+                None,
+                loc,
+                cache_k,
+                cache_v,
+                k_scale,
+                v_scale,
+                layer_id_override=layer_id_pool,
+            )
+
+
+class AscendTokenToKVPool(MHATokenToKVPool):
+
+    def _create_buffers(self):
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            # [size, head_num, head_dim] for each layer
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.k_buffer = [
+                torch.zeros(
+                    (
+                        self.size // self.page_size + 1,
+                        self.page_size,
+                        self.head_num,
+                        self.head_dim,
+                    ),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+            self.v_buffer = [
+                torch.zeros(
+                    (
+                        self.size // self.page_size + 1,
+                        self.page_size,
+                        self.head_num,
+                        self.head_dim,
+                    ),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        k_scale: Optional[float] = None,
+        v_scale: Optional[float] = None,
+    ):
+        layer_id = layer.layer_id
+        if cache_k.dtype != self.dtype:
+            if k_scale is not None:
+                cache_k.div_(k_scale)
+            if v_scale is not None:
+                cache_v.div_(v_scale)
+            cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
+
+        if self.store_dtype != self.dtype:
+            cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
+
+        import torch_npu
+
+        torch_npu._npu_reshape_and_cache(
+            key=cache_k,
+            value=cache_v,
+            key_cache=self.k_buffer[layer_id].view(
+                -1, self.page_size, self.head_num, self.head_dim
+            ),
+            value_cache=self.v_buffer[layer_id].view(
+                -1, self.page_size, self.head_num, self.head_dim
+            ),
+            slot_indices=loc,
+        )
+
+
 @triton.jit
 def set_mla_kv_buffer_kernel(
     kv_buffer_ptr,
```
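The new `SWAKVPool` wraps two `MHATokenToKVPool` instances (one for full-attention layers, one for sliding-window layers) and routes every access through `layers_mapping`, which translates a global layer id into an index local to the owning sub-pool; the `layer_id_override` parameter added to `MHATokenToKVPool.set_kv_buffer` above exists so the wrapper can pass that local index down. A standalone sketch of the mapping (the layer layout below is an assumed example):

```python
# layers_mapping: global layer id -> (index inside its sub-pool, is_swa)
full_attention_layer_ids = [0, 5, 10]               # assumed global ids
swa_attention_layer_ids = [1, 2, 3, 4, 6, 7, 8, 9]  # assumed global ids

layers_mapping = {}
for full_idx, gid in enumerate(full_attention_layer_ids):
    layers_mapping[gid] = (full_idx, False)
for swa_idx, gid in enumerate(swa_attention_layer_ids):
    layers_mapping[gid] = (swa_idx, True)

assert layers_mapping[5] == (1, False)  # 2nd layer of the full-attention pool
assert layers_mapping[6] == (4, True)   # 5th layer of the SWA pool
```

Each `get_*`/`set_kv_buffer` accessor performs this lookup and forwards to the matching sub-pool, passing the local index via `layer_id_override`.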
```diff
@@ -560,6 +767,7 @@ class MLATokenToKVPool(KVCache):
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
         )
+        self.mem_usage = kv_size / GB
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
@@ -682,6 +890,84 @@
         torch.cuda.synchronize()
 
 
+class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+        layer_num: int,
+        device: str,
+        enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
+    ):
+        super(MLATokenToKVPool, self).__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
+        )
+
+        self.kv_lora_rank = kv_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+
+        self.custom_mem_pool = None
+
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.kv_buffer = [
+                torch.zeros(
+                    (
+                        self.size // self.page_size + 1,
+                        self.page_size,
+                        self.kv_lora_rank + self.qk_rope_head_dim,
+                    ),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(layer_num)
+            ]
+
+        self.layer_transfer_counter = None
+
+        kv_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
+        )
+        self.mem_usage = kv_size / GB
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        layer_id = layer.layer_id
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+
+        if self.store_dtype != self.dtype:
+            cache_k = cache_k.view(self.store_dtype)
+
+        import torch_npu
+
+        torch_npu._npu_reshape_and_cache_siso(
+            key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
+            key_cache=self.kv_buffer[layer_id - self.start_layer].view(
+                -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
+            ),
+            slot_indices=loc,
+        )
+
+
 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
         self,
```
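`AscendMLAPagedTokenToKVPool` keeps one combined buffer per layer of shape `(size // page_size + 1, page_size, kv_lora_rank + qk_rope_head_dim)`: the extra page holds the padded slot 0 used for dummy writes, and MLA stores a single latent vector per token rather than separate K and V. Rough sizing under assumed DeepSeek-style dimensions (not values from this diff):

```python
kv_lora_rank, qk_rope_head_dim, layer_num = 512, 64, 61  # assumed dims
size, page_size, bytes_per_elem = 100_000, 128, 2        # assumed; bf16

pages = size // page_size + 1  # +1 page for the padded slot 0
per_layer_shape = (pages, page_size, kv_lora_rank + qk_rope_head_dim)
kv_bytes = pages * page_size * (kv_lora_rank + qk_rope_head_dim) * layer_num * bytes_per_elem
print(per_layer_shape, f"{kv_bytes / (1 << 30):.2f} GB")  # (782, 128, 576) ~6.55 GB
```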
sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -168,7 +168,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         capture_bs += [model_runner.req_to_token_pool.size]
 
     if server_args.enable_two_batch_overlap:
-        capture_bs = [bs for bs in capture_bs if bs
+        capture_bs = [bs for bs in capture_bs if bs % 2 == 0]
 
     if server_args.cuda_graph_max_bs:
         capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
```
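The `get_batch_sizes_to_capture` change restricts captured CUDA graph sizes to even batch sizes when two-batch overlap is enabled, since that mode splits each batch into two micro-batches. Illustrative effect (the capture list below is an assumed example):

```python
capture_bs = [1, 2, 4, 8, 12, 16, 24, 32, 48, 64]  # assumed capture list
enable_two_batch_overlap = True

if enable_two_batch_overlap:
    # Two-batch overlap runs each graph as two micro-batches, so odd sizes are dropped.
    capture_bs = [bs for bs in capture_bs if bs % 2 == 0]

print(capture_bs)  # [2, 4, 8, 12, 16, 24, 32, 48, 64]
```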
```diff
@@ -421,7 +421,7 @@ class CudaGraphRunner:
                 empty_cache=False,
             )
             capture_range.set_description(
-                f"Capturing batches ({avail_mem=:.2f} GB)"
+                f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
             )
 
             with patch_model(
@@ -679,6 +679,7 @@
             forward_mode=self.capture_forward_mode,
             bs=bs,
             num_token_non_padded=len(forward_batch.input_ids),
+            spec_info=forward_batch.spec_info,
         )
         if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
             forward_batch.spec_info.custom_mask = self.custom_mask
```
sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -39,7 +39,12 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    flatten_nested_list,
+    get_compiler_backend,
+    is_npu,
+    support_triton,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -50,6 +55,8 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 
+_is_npu = is_npu()
+
 
 class ForwardMode(IntEnum):
     # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
@@ -247,6 +254,7 @@ class ForwardBatch:
     dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
     gathered_buffer: Optional[torch.Tensor] = None
+    is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None
 
@@ -292,6 +300,7 @@
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
+            is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
             lora_paths=batch.lora_paths,
@@ -352,7 +361,9 @@
 
         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), device=device)
-            TboForwardBatchPreparer.prepare(
+            TboForwardBatchPreparer.prepare(
+                ret, is_draft_worker=model_runner.is_draft_worker
+            )
             return ret
 
         # Override the positions with spec_info
@@ -397,7 +408,9 @@
         if model_runner.server_args.lora_paths is not None:
             model_runner.lora_manager.prepare_lora_batch(ret)
 
-        TboForwardBatchPreparer.prepare(
+        TboForwardBatchPreparer.prepare(
+            ret, is_draft_worker=model_runner.is_draft_worker
+        )
 
         return ret
 
@@ -735,7 +748,7 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc
 
 
-@torch.compile(dynamic=True, backend=get_compiler_backend())
+@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
 
```
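The `clamp_position` change uses the `disable=` keyword of `torch.compile`, so the decorator can stay in place and become a no-op on NPU, where the compiler backend is not used. A minimal standalone sketch (plain `torch.compile` defaults here, not sglang's `get_compiler_backend()`):

```python
import torch

_is_npu = False  # stand-in for sglang.srt.utils.is_npu()

@torch.compile(dynamic=True, disable=_is_npu)
def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
    # Position of the last token in each sequence, clamped at 0.
    return torch.clamp(seq_lens - 1, min=0).to(torch.int64)

print(clamp_position(torch.tensor([0, 3, 7])))  # tensor([0, 2, 6])
```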