sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in the public registry.
- sglang/bench_serving.py +2 -2
- sglang/srt/configs/model_config.py +36 -2
- sglang/srt/conversation.py +56 -3
- sglang/srt/disaggregation/ascend/__init__.py +6 -0
- sglang/srt/disaggregation/ascend/conn.py +44 -0
- sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
- sglang/srt/disaggregation/mooncake/conn.py +50 -18
- sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
- sglang/srt/disaggregation/utils.py +25 -3
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +1 -0
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +11 -0
- sglang/srt/entrypoints/openai/serving_chat.py +7 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/kimik2_detector.py +220 -0
- sglang/srt/hf_transformers_utils.py +18 -0
- sglang/srt/jinja_template_utils.py +8 -0
- sglang/srt/layers/communicator.py +20 -5
- sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
- sglang/srt/layers/layernorm.py +2 -2
- sglang/srt/layers/linear.py +12 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
- sglang/srt/layers/moe/ep_moe/layer.py +141 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/topk.py +8 -2
- sglang/srt/layers/parameter.py +19 -3
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/fp8.py +28 -7
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -2
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +738 -14
- sglang/srt/layers/vocab_parallel_embedding.py +9 -3
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/io_struct.py +35 -3
- sglang/srt/managers/mm_utils.py +59 -96
- sglang/srt/managers/schedule_batch.py +17 -6
- sglang/srt/managers/scheduler.py +38 -6
- sglang/srt/managers/tokenizer_manager.py +16 -0
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +176 -101
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -1
- sglang/srt/model_loader/loader.py +23 -12
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +78 -19
- sglang/srt/models/deepseek_vl2.py +1 -1
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +6 -3
- sglang/srt/models/internvl.py +8 -2
- sglang/srt/models/kimi_vl.py +8 -2
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llava.py +3 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpmo.py +1 -2
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral_quant.py +4 -0
- sglang/srt/models/mllama4.py +372 -82
- sglang/srt/models/phi4mm.py +8 -2
- sglang/srt/models/phimoe.py +553 -0
- sglang/srt/models/qwen2.py +2 -0
- sglang/srt/models/qwen2_5_vl.py +10 -7
- sglang/srt/models/qwen2_vl.py +12 -1
- sglang/srt/models/vila.py +8 -2
- sglang/srt/multimodal/mm_utils.py +2 -2
- sglang/srt/multimodal/processors/base_processor.py +197 -137
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
- sglang/srt/multimodal/processors/gemma3.py +4 -2
- sglang/srt/multimodal/processors/gemma3n.py +1 -1
- sglang/srt/multimodal/processors/internvl.py +1 -1
- sglang/srt/multimodal/processors/janus_pro.py +1 -1
- sglang/srt/multimodal/processors/kimi_vl.py +1 -1
- sglang/srt/multimodal/processors/minicpm.py +4 -3
- sglang/srt/multimodal/processors/mllama4.py +63 -61
- sglang/srt/multimodal/processors/phi4mm.py +1 -1
- sglang/srt/multimodal/processors/pixtral.py +1 -1
- sglang/srt/multimodal/processors/qwen_vl.py +203 -80
- sglang/srt/multimodal/processors/vila.py +1 -1
- sglang/srt/server_args.py +26 -4
- sglang/srt/two_batch_overlap.py +3 -0
- sglang/srt/utils.py +191 -48
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

@@ -37,12 +37,15 @@ import triton.language as tl

 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import
+from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

 logger = logging.getLogger(__name__)

 GB = 1024 * 1024 * 1024
 _is_cuda = is_cuda()
+_is_npu = is_npu()
+if not _is_npu:
+    from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla


 class ReqToTokenPool:
@@ -150,13 +153,16 @@ class KVCache(abc.ABC):
     ) -> None:
         raise NotImplementedError()

-
-
-
-
+    @abc.abstractmethod
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
         raise NotImplementedError()

-
+    @abc.abstractmethod
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
         raise NotImplementedError()

     def register_layer_transfer_counter(self, layer_transfer_counter):
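
The two hooks above replace the old flat-data transfer interface with explicit host-to-device and device-to-host movement. As a rough orientation only (this is not code from the package), a host-offloading controller would be expected to drive them roughly as sketched below; the function names and the `io_backend` value are illustrative assumptions.

```python
# Illustrative sketch only: how a hierarchical-cache controller might drive the
# two new KVCache hooks. Everything except the two hook names is hypothetical.
def prefetch_layer_by_layer(device_pool, host_pool, host_indices, device_indices,
                            num_layers, io_backend="kernel"):
    # Per-layer loads let the transfer of layer i overlap with compute on layer i-1.
    for layer_id in range(num_layers):
        device_pool.load_from_host_per_layer(
            host_pool, host_indices, device_indices, layer_id, io_backend
        )


def offload_finished_request(device_pool, host_pool, host_indices, device_indices,
                             io_backend="kernel"):
    # Write-back goes the other way and covers all layers in one call.
    device_pool.backup_to_host_all_layer(
        host_pool, host_indices, device_indices, io_backend
    )
```
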
@@ -247,7 +253,7 @@ class MHATokenToKVPool(KVCache):
             )
             for _ in range(self.layer_num)
         ]
-
+        self.token_stride = self.head_num * self.head_dim
         self.data_ptrs = torch.tensor(
             [x.data_ptr() for x in self.k_buffer + self.v_buffer],
             dtype=torch.uint64,
@@ -281,24 +287,24 @@
         # layer_num x [seq_len, head_num, head_dim]
         # layer_num x [page_num, page_size, head_num, head_dim]
         kv_data_ptrs = [
-            self.
+            self._get_key_buffer(i).data_ptr()
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.
+            self._get_value_buffer(i).data_ptr()
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         kv_data_lens = [
-            self.
+            self._get_key_buffer(i).nbytes
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.
+            self._get_value_buffer(i).nbytes
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         kv_item_lens = [
-            self.
+            self._get_key_buffer(i)[0].nbytes * self.page_size
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ] + [
-            self.
+            self._get_value_buffer(i)[0].nbytes * self.page_size
             for i in range(self.start_layer, self.start_layer + self.layer_num)
         ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
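
For context, the three lists returned above describe each layer's K and V buffer to the disaggregation transfer engine: a base pointer, the total byte length, and the byte size of one page (the transfer granularity). A toy illustration with made-up shapes:

```python
# Toy illustration of the (ptr, len, item_len) triple for a single layer's K
# buffer; the shapes are invented, not real model dimensions.
import torch

page_size, num_pages, head_num, head_dim = 16, 8, 4, 128
k_buffer = torch.zeros(num_pages * page_size, head_num, head_dim, dtype=torch.bfloat16)

data_ptr = k_buffer.data_ptr()             # base address handed to the transfer engine
data_len = k_buffer.nbytes                 # whole layer: 128 * 4 * 128 * 2 = 131072 bytes
item_len = k_buffer[0].nbytes * page_size  # one page: 4 * 128 * 2 * 16 = 16384 bytes
print(data_ptr, data_len, item_len)
```
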
@@ -341,49 +347,73 @@
                 self.v_buffer[layer_id][chunk_indices] = v_chunk
         torch.cuda.synchronize()

-
-
-
-
-
-
-
-
+    def load_from_host_per_layer(
+        self,
+        host_pool,
+        host_indices,
+        device_indices,
+        layer_id,
+        io_backend,
+    ):
+        transfer_kv_per_layer(
+            src_k=host_pool.k_buffer[layer_id],
+            dst_k=self.k_buffer[layer_id],
+            src_v=host_pool.v_buffer[layer_id],
+            dst_v=self.v_buffer[layer_id],
+            src_indices=host_indices,
+            dst_indices=device_indices,
+            io_backend=io_backend,
+            page_size=self.page_size,
+            item_size=self.token_stride,
         )
-        return flatten
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        k_data, v_data = flat_data[0], flat_data[1]
-        for i in range(self.layer_num):
-            self.k_buffer[i][indices] = k_data[i]
-            self.v_buffer[i][indices] = v_data[i]
-
-    def transfer_per_layer(self, indices, flat_data, layer_id):
-        # transfer prepared data from host to device
-        flat_data = flat_data.to(device=self.device, non_blocking=False)
-        k_data, v_data = flat_data[0], flat_data[1]
-        self.k_buffer[layer_id - self.start_layer][indices] = k_data
-        self.v_buffer[layer_id - self.start_layer][indices] = v_data

-    def
-
-
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
+        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+            if layer_id - self.start_layer >= len(host_pool.k_buffer):
+                raise ValueError(
+                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                )
+            transfer_kv_per_layer(
+                src_k=self.k_buffer[layer_id],
+                dst_k=host_pool.k_buffer[layer_id],
+                src_v=self.v_buffer[layer_id],
+                dst_v=host_pool.v_buffer[layer_id],
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                io_backend=io_backend,
+                page_size=self.page_size,
+                item_size=self.token_stride,
+            )

+    def _get_key_buffer(self, layer_id: int):
+        # for internal use of referencing
         if self.store_dtype != self.dtype:
             return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.k_buffer[layer_id - self.start_layer]

-    def
+    def get_key_buffer(self, layer_id: int):
+        # note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading
+        # it is supposed to be used only by attention backend not for information purpose
+        # same applies to get_value_buffer and get_kv_buffer
         if self.layer_transfer_counter is not None:
             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

+        return self._get_key_buffer(layer_id)
+
+    def _get_value_buffer(self, layer_id: int):
+        # for internal use of referencing
         if self.store_dtype != self.dtype:
             return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
         return self.v_buffer[layer_id - self.start_layer]

+    def get_value_buffer(self, layer_id: int):
+        if self.layer_transfer_counter is not None:
+            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+        return self._get_value_buffer(layer_id)
+
     def get_kv_buffer(self, layer_id: int):
         return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

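
The real copies go through `sgl_kernel.kvcacheio.transfer_kv_per_layer`; as a behavioural reference only (not the kernel implementation), a per-layer transfer amounts to a gather/scatter between paired token indices:

```python
# Behavioural sketch of what a per-layer K/V transfer does, written in plain
# PyTorch. The actual path uses sgl_kernel.kvcacheio with an io_backend argument;
# this reference ignores page_size/item_size and is for illustration only.
import torch


def transfer_kv_per_layer_reference(src_k, dst_k, src_v, dst_v, src_indices, dst_indices):
    # Gather the selected token slots from the source buffers and scatter them
    # into the destination buffers at the paired indices.
    dst_k[dst_indices] = src_k[src_indices].to(dst_k.device, non_blocking=True)
    dst_v[dst_indices] = src_v[src_indices].to(dst_v.device, non_blocking=True)
```
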
@@ -574,32 +604,49 @@ class AscendTokenToKVPool(MHATokenToKVPool):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # [size, head_num, head_dim] for each layer
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-
-
-
-
-
-
-
-
-
-
-            )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # Continuous memory improves the efficiency of Ascend`s transmission backend,
+            # while other backends remain unchanged.
+            self.kv_buffer = torch.zeros(
+                (
+                    2,
+                    self.layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    self.head_num,
+                    self.head_dim,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )
+            self.k_buffer = self.kv_buffer[0]
+            self.v_buffer = self.kv_buffer[1]
+
+    # for disagg
+    def get_contiguous_buf_infos(self):
+        # layer_num x [seq_len, head_num, head_dim]
+        # layer_num x [page_num, page_size, head_num, head_dim]
+        kv_data_ptrs = [
+            self.get_key_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).data_ptr()
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
+        kv_data_lens = [
+            self.get_key_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i).nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
+        kv_item_lens = [
+            self.get_key_buffer(i)[0].nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ] + [
+            self.get_value_buffer(i)[0].nbytes
+            for i in range(self.start_layer, self.start_layer + self.layer_num)
+        ]
+        return kv_data_ptrs, kv_data_lens, kv_item_lens

     def set_kv_buffer(
         self,
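
The point of the new allocation is that K and V for all layers live in one contiguous tensor, which the Ascend transfer backend can register and move as a single region, while `k_buffer` and `v_buffer` remain zero-copy views. A toy-sized check of that layout (dimensions invented):

```python
# Toy-sized check of the contiguous layout; the dimensions are invented.
import torch

layer_num, size, page_size, head_num, head_dim = 2, 32, 16, 4, 64
kv_buffer = torch.zeros(
    (2, layer_num, size // page_size + 1, page_size, head_num, head_dim),
    dtype=torch.bfloat16,
)
k_buffer, v_buffer = kv_buffer[0], kv_buffer[1]

# Both halves are views into the same allocation, laid out back to back.
assert k_buffer.data_ptr() == kv_buffer.data_ptr()
assert v_buffer.data_ptr() == k_buffer.data_ptr() + k_buffer.nbytes
```
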
@@ -761,6 +808,7 @@ class MLATokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]

+        self.token_stride = kv_lora_rank + qk_rope_head_dim
         self.layer_transfer_counter = None

         kv_size = self.get_kv_size_bytes()
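
With MLA there is a single compressed KV entry per token, so the per-token item size handed to the transfer kernels is simply `kv_lora_rank + qk_rope_head_dim`. With DeepSeek-V3-style numbers, used here purely as an illustration:

```python
# Illustrative numbers only (DeepSeek-V3-style MLA dimensions).
kv_lora_rank = 512
qk_rope_head_dim = 64

token_stride = kv_lora_rank + qk_rope_head_dim  # 576 elements per cached token
bytes_per_token = token_stride * 2              # 1152 bytes at a 2-byte storage dtype
print(token_stride, bytes_per_token)
```
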
@@ -846,21 +894,37 @@
             self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
         )

-    def
-
-
-
-
-
-
-
-
-        self.
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        transfer_kv_per_layer_mla(
+            src=host_pool.kv_buffer[layer_id],
+            dst=self.kv_buffer[layer_id],
+            src_indices=host_indices,
+            dst_indices=device_indices,
+            io_backend=io_backend,
+            page_size=self.page_size,
+            item_size=self.token_stride,
+        )

-    def
-
-
-
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        # todo: specialized all layer kernels for the layer-non-contiguous memory pool
+        for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+            if layer_id - self.start_layer >= len(host_pool.kv_buffer):
+                raise ValueError(
+                    f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                )
+            transfer_kv_per_layer_mla(
+                src=self.kv_buffer[layer_id],
+                dst=host_pool.kv_buffer[layer_id],
+                src_indices=device_indices,
+                dst_indices=host_indices,
+                io_backend=io_backend,
+                page_size=self.page_size,
+                item_size=self.token_stride,
+            )

     def get_cpu_copy(self, indices):
         torch.cuda.synchronize()
@@ -922,18 +986,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):

         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
-            self.kv_buffer =
-
-
-
-
-
-
-
-
-
-                for _ in range(layer_num)
-            ]
+            self.kv_buffer = torch.zeros(
+                (
+                    layer_num,
+                    self.size // self.page_size + 1,
+                    self.page_size,
+                    self.kv_lora_rank + self.qk_rope_head_dim,
+                ),
+                dtype=self.store_dtype,
+                device=self.device,
+            )

         self.layer_transfer_counter = None

@@ -943,6 +1005,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
         )
         self.mem_usage = kv_size / GB

+    # for disagg
+    def get_contiguous_buf_infos(self):
+        # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
+        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
+        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
+        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
     def set_kv_buffer(
         self,
         layer: RadixAttention,
@@ -1046,14 +1116,19 @@ class DoubleSparseTokenToKVPool(KVCache):
         self.v_buffer[layer_id - self.start_layer][loc] = cache_v
         self.label_buffer[layer_id - self.start_layer][loc] = cache_label

-    def
-
-
-
-
+    def load_from_host_per_layer(
+        self, host_pool, host_indices, device_indices, layer_id, io_backend
+    ):
+        raise NotImplementedError(
+            "HiCache not supported for DoubleSparseTokenToKVPool."
+        )

-    def
-
+    def backup_to_host_all_layer(
+        self, host_pool, host_indices, device_indices, io_backend
+    ):
+        raise NotImplementedError(
+            "HiCache not supported for DoubleSparseTokenToKVPool."
+        )


 @triton.jit
sglang/srt/mem_cache/memory_pool_host.py

@@ -8,7 +8,6 @@ import psutil
 import torch

 from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
-from sglang.srt.utils import debug_timing

 logger = logging.getLogger(__name__)

@@ -99,22 +98,6 @@ class HostKVCache(abc.ABC):
     def init_kv_buffer(self):
         raise NotImplementedError()

-    @abc.abstractmethod
-    def transfer(self, indices, flat_data):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data(self, indices):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def get_flat_data_by_layer(self, indices, layer_id):
-        raise NotImplementedError()
-
-    @abc.abstractmethod
-    def assign_flat_data(self, indices, flat_data):
-        raise NotImplementedError()
-
     @synchronized()
     def clear(self):
         # Initialize memory states and tracking structures.
@@ -243,58 +226,13 @@ class MHATokenToKVPoolHost(HostKVCache):
             pin_memory=self.pin_memory,
         )

-    @
-    def
-
-        self.kv_buffer[:, :, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
+    @property
+    def k_buffer(self):
+        return self.kv_buffer[0]

-
-
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[:, layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, :, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.k_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-                self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
-                    device_pool.v_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.k_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    0, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
-            device_pool.v_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    1, layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
+    @property
+    def v_buffer(self):
+        return self.kv_buffer[1]


 class MLATokenToKVPoolHost(HostKVCache):
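
The host pool keeps a single pinned allocation shaped `[2, layer_num, tokens, head_num, head_dim]`, with K at index 0 and V at index 1, so the two new properties are plain views rather than copies. A toy-sized sketch of that layout (dimensions invented):

```python
# Toy-sized sketch of the host-side layout behind the new k_buffer / v_buffer
# properties; the dimensions are invented.
import torch

layer_num, tokens, head_num, head_dim = 2, 64, 4, 64
kv_buffer = torch.zeros(2, layer_num, tokens, head_num, head_dim, dtype=torch.bfloat16)

k_buffer = kv_buffer[0]  # what MHATokenToKVPoolHost.k_buffer now returns
v_buffer = kv_buffer[1]  # what MHATokenToKVPoolHost.v_buffer now returns

# k_buffer[layer_id] is the per-layer tensor that the device pool's
# load_from_host_per_layer / backup_to_host_all_layer hand to the transfer kernel.
assert k_buffer[0].shape == (tokens, head_num, head_dim)
assert k_buffer.data_ptr() == kv_buffer.data_ptr()
```
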
@@ -337,44 +275,3 @@ class MLATokenToKVPoolHost(HostKVCache):
             device=self.device,
             pin_memory=self.pin_memory,
         )
-
-    @debug_timing
-    def transfer(self, indices, flat_data):
-        # backup prepared data from device to host
-        self.kv_buffer[:, indices] = flat_data.to(
-            device=self.device, non_blocking=False
-        )
-
-    def get_flat_data(self, indices):
-        return self.kv_buffer[:, indices]
-
-    def get_flat_data_by_layer(self, indices, layer_id):
-        return self.kv_buffer[layer_id - self.start_layer, indices]
-
-    def assign_flat_data(self, indices, flat_data):
-        self.kv_buffer[:, indices] = flat_data
-
-    def write_page_all_layers(self, host_indices, device_indices, device_pool):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            for j in range(self.layer_num):
-                self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
-                    device_pool.kv_buffer[j][d_index : d_index + self.page_size],
-                    non_blocking=True,
-                )
-
-    def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-        device_indices_cpu = device_indices[:: self.page_size].cpu()
-        for i in range(len(device_indices_cpu)):
-            h_index = host_indices[i * self.page_size]
-            d_index = device_indices_cpu[i]
-            device_pool.kv_buffer[layer_id - self.start_layer][
-                d_index : d_index + self.page_size
-            ].copy_(
-                self.kv_buffer[
-                    layer_id - self.start_layer, h_index : h_index + self.page_size
-                ],
-                non_blocking=True,
-            )
sglang/srt/mem_cache/radix_cache.py

@@ -196,11 +196,13 @@ class RadixCache(BasePrefixCache):

         if self.page_size != 1:
             page_aligned_len = len(kv_indices) // self.page_size * self.page_size
-            page_aligned_kv_indices = kv_indices[:page_aligned_len].
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                dtype=torch.int64, copy=True
+            )
             self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
         else:
             page_aligned_len = len(kv_indices)
-            page_aligned_kv_indices = kv_indices.
+            page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)

         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(
@@ -226,10 +228,12 @@

         if self.page_size != 1:
             page_aligned_len = len(kv_indices) // self.page_size * self.page_size
-            page_aligned_kv_indices = kv_indices[:page_aligned_len].
+            page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                dtype=torch.int64, copy=True
+            )
         else:
             page_aligned_len = len(kv_indices)
-            page_aligned_kv_indices = kv_indices.
+            page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
         page_aligned_token_ids = token_ids[:page_aligned_len]

         # Radix Cache takes one ref in memory pool
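
The `.to(dtype=torch.int64, copy=True)` form both casts the indices and, because of `copy=True`, always materialises a fresh tensor even when the dtype already matches, so the indices stored in the radix tree do not alias the request's buffer. A small demonstration of that behaviour:

```python
# copy=True forces a detached copy even if kv_indices were already int64.
import torch

kv_indices = torch.arange(8, dtype=torch.int32)
page_aligned = kv_indices[:4].to(dtype=torch.int64, copy=True)

kv_indices[0] = -1                       # mutate the original buffer
assert page_aligned[0].item() == 0       # the stored copy is unaffected
assert page_aligned.dtype == torch.int64
```
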
sglang/srt/model_executor/forward_batch_info.py

@@ -453,8 +453,20 @@ class ForwardBatch:
             for mm_input in self.mm_inputs
         )

+    def contains_video_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_video_inputs()
+            for mm_input in self.mm_inputs
+        )
+
     def contains_mm_inputs(self) -> bool:
-        return
+        return (
+            self.contains_audio_inputs()
+            or self.contains_video_inputs()
+            or self.contains_image_inputs()
+        )

     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
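
The batch-level predicates just OR together the per-request checks. A minimal stand-alone illustration with a stubbed multimodal-input object (the `MMStub` class is hypothetical, not the real sglang type):

```python
# Minimal stub illustrating the any() aggregation; MMStub is hypothetical.
from dataclasses import dataclass


@dataclass
class MMStub:
    has_video: bool = False

    def contains_video_inputs(self) -> bool:
        return self.has_video


def batch_contains_video(mm_inputs) -> bool:
    if mm_inputs is None:
        return False
    return any(m is not None and m.contains_video_inputs() for m in mm_inputs)


print(batch_contains_video(None))                   # False
print(batch_contains_video([None, MMStub(False)]))  # False
print(batch_contains_video([MMStub(True), None]))   # True
```
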
sglang/srt/model_loader/loader.py

@@ -64,10 +64,13 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device_capability,
+    is_npu,
     is_pin_memory_available,
     set_weight_attrs,
 )

+_is_npu = is_npu()
+

 @contextmanager
 def device_loading_context(module: torch.nn.Module, target_device: torch.device):
@@ -127,18 +130,19 @@ def _get_quantization_config(
     # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
     if quant_config is None:
         return None
-
-
-
-
-
-
-
-
-
-
-
-
+    if not _is_npu:
+        major, minor = get_device_capability()
+
+        if major is not None and minor is not None:
+            assert 0 <= minor < 10
+            capability = major * 10 + minor
+            if capability < quant_config.get_min_capability():
+                raise ValueError(
+                    f"The quantization method {model_config.quantization} "
+                    "is not supported for the current GPU. "
+                    f"Minimum capability: {quant_config.get_min_capability()}. "
+                    f"Current capability: {capability}."
+                )
     supported_dtypes = quant_config.get_supported_act_dtypes()
     if model_config.dtype not in supported_dtypes:
         raise ValueError(
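
The restored guard encodes CUDA compute capability as `major * 10 + minor` and is skipped entirely when running on NPU. A worked example with illustrative numbers:

```python
# Illustrative numbers: an sm_89 (Ada) GPU checked against a quantization
# method that hypothetically requires sm_90 or newer.
major, minor = 8, 9
capability = major * 10 + minor          # 89
min_capability = 90

if capability < min_capability:
    print(f"Minimum capability: {min_capability}. Current capability: {capability}.")
```
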
@@ -157,6 +161,13 @@ def _initialize_model(
     """Initialize a model with the given configurations."""
     model_class, _ = get_model_architecture(model_config)
     packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+    if _is_npu:
+        packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
+            "q_a_proj",
+            "kv_a_proj_with_mqa",
+        ]
+        packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
+        packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
     quant_config = _get_quantization_config(
         model_config, load_config, packed_modules_mapping
     )
sglang/srt/models/deepseek_janus_pro.py

@@ -1989,7 +1989,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-
+            multimodal_model=self,
             language_model=self.language_model,
             positions=positions,
         )