sglang 0.4.9__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +2 -2
  2. sglang/srt/configs/model_config.py +36 -2
  3. sglang/srt/conversation.py +56 -3
  4. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  5. sglang/srt/disaggregation/ascend/conn.py +44 -0
  6. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +50 -18
  8. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  9. sglang/srt/disaggregation/utils.py +25 -3
  10. sglang/srt/entrypoints/engine.py +1 -1
  11. sglang/srt/entrypoints/http_server.py +1 -0
  12. sglang/srt/entrypoints/http_server_engine.py +1 -1
  13. sglang/srt/entrypoints/openai/protocol.py +11 -0
  14. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  15. sglang/srt/function_call/function_call_parser.py +2 -0
  16. sglang/srt/function_call/kimik2_detector.py +220 -0
  17. sglang/srt/hf_transformers_utils.py +18 -0
  18. sglang/srt/jinja_template_utils.py +8 -0
  19. sglang/srt/layers/communicator.py +20 -5
  20. sglang/srt/layers/flashinfer_comm_fusion.py +3 -3
  21. sglang/srt/layers/layernorm.py +2 -2
  22. sglang/srt/layers/linear.py +12 -2
  23. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  24. sglang/srt/layers/moe/ep_moe/kernels.py +60 -1
  25. sglang/srt/layers/moe/ep_moe/layer.py +141 -2
  26. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +2 -0
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +141 -59
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  29. sglang/srt/layers/moe/topk.py +8 -2
  30. sglang/srt/layers/parameter.py +19 -3
  31. sglang/srt/layers/quantization/__init__.py +2 -0
  32. sglang/srt/layers/quantization/fp8.py +28 -7
  33. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  35. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  36. sglang/srt/layers/quantization/w4afp8.py +264 -0
  37. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  38. sglang/srt/layers/vocab_parallel_embedding.py +9 -3
  39. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  40. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  41. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  42. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  43. sglang/srt/managers/cache_controller.py +41 -195
  44. sglang/srt/managers/io_struct.py +35 -3
  45. sglang/srt/managers/mm_utils.py +59 -96
  46. sglang/srt/managers/schedule_batch.py +17 -6
  47. sglang/srt/managers/scheduler.py +38 -6
  48. sglang/srt/managers/tokenizer_manager.py +16 -0
  49. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  50. sglang/srt/mem_cache/memory_pool.py +176 -101
  51. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  52. sglang/srt/mem_cache/radix_cache.py +8 -4
  53. sglang/srt/model_executor/forward_batch_info.py +13 -1
  54. sglang/srt/model_loader/loader.py +23 -12
  55. sglang/srt/models/deepseek_janus_pro.py +1 -1
  56. sglang/srt/models/deepseek_v2.py +78 -19
  57. sglang/srt/models/deepseek_vl2.py +1 -1
  58. sglang/srt/models/gemma3_mm.py +1 -1
  59. sglang/srt/models/gemma3n_mm.py +6 -3
  60. sglang/srt/models/internvl.py +8 -2
  61. sglang/srt/models/kimi_vl.py +8 -2
  62. sglang/srt/models/llama.py +2 -0
  63. sglang/srt/models/llava.py +3 -1
  64. sglang/srt/models/llavavid.py +1 -1
  65. sglang/srt/models/minicpmo.py +1 -2
  66. sglang/srt/models/minicpmv.py +1 -1
  67. sglang/srt/models/mixtral_quant.py +4 -0
  68. sglang/srt/models/mllama4.py +372 -82
  69. sglang/srt/models/phi4mm.py +8 -2
  70. sglang/srt/models/phimoe.py +553 -0
  71. sglang/srt/models/qwen2.py +2 -0
  72. sglang/srt/models/qwen2_5_vl.py +10 -7
  73. sglang/srt/models/qwen2_vl.py +12 -1
  74. sglang/srt/models/vila.py +8 -2
  75. sglang/srt/multimodal/mm_utils.py +2 -2
  76. sglang/srt/multimodal/processors/base_processor.py +197 -137
  77. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  78. sglang/srt/multimodal/processors/gemma3.py +4 -2
  79. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  80. sglang/srt/multimodal/processors/internvl.py +1 -1
  81. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  82. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  83. sglang/srt/multimodal/processors/minicpm.py +4 -3
  84. sglang/srt/multimodal/processors/mllama4.py +63 -61
  85. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  86. sglang/srt/multimodal/processors/pixtral.py +1 -1
  87. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  88. sglang/srt/multimodal/processors/vila.py +1 -1
  89. sglang/srt/server_args.py +26 -4
  90. sglang/srt/two_batch_overlap.py +3 -0
  91. sglang/srt/utils.py +191 -48
  92. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  93. sglang/utils.py +5 -5
  94. sglang/version.py +1 -1
  95. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +6 -4
  96. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +99 -90
  97. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.9.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

@@ -37,12 +37,15 @@ import triton.language as tl
 
  from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.utils import debug_timing, get_bool_env_var, is_cuda, next_power_of_2
+ from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
 
  logger = logging.getLogger(__name__)
 
  GB = 1024 * 1024 * 1024
  _is_cuda = is_cuda()
+ _is_npu = is_npu()
+ if not _is_npu:
+     from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
 
 
  class ReqToTokenPool:
@@ -150,13 +153,16 @@ class KVCache(abc.ABC):
      ) -> None:
          raise NotImplementedError()
 
-     def get_flat_data(self, indices):
-         raise NotImplementedError()
-
-     def transfer(self, indices, flat_data):
+     @abc.abstractmethod
+     def load_from_host_per_layer(
+         self, host_pool, host_indices, device_indices, layer_id, io_backend
+     ):
          raise NotImplementedError()
 
-     def transfer_per_layer(self, indices, flat_data, layer_id):
+     @abc.abstractmethod
+     def backup_to_host_all_layer(
+         self, host_pool, host_indices, device_indices, io_backend
+     ):
          raise NotImplementedError()
 
      def register_layer_transfer_counter(self, layer_transfer_counter):
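
Editorial note: the old `get_flat_data` / `transfer` / `transfer_per_layer` trio is replaced by two abstract methods that move KV data directly between a device pool and its host pool. Roughly, a hierarchical-cache controller is expected to drive them as sketched below; the wrapper names and the way `io_backend` is obtained are illustrative, only the two `KVCache` methods come from this diff.

def prefetch_from_host(device_pool, host_pool, host_indices, device_indices,
                       layer_ids, io_backend):
    # Host -> device runs layer by layer, so attention on layer i can start
    # while layer i + 1 is still being copied in.
    for layer_id in layer_ids:
        device_pool.load_from_host_per_layer(
            host_pool, host_indices, device_indices, layer_id, io_backend
        )

def offload_to_host(device_pool, host_pool, host_indices, device_indices, io_backend):
    # Device -> host walks all layers inside the pool implementation.
    device_pool.backup_to_host_all_layer(
        host_pool, host_indices, device_indices, io_backend
    )
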
@@ -247,7 +253,7 @@ class MHATokenToKVPool(KVCache):
              )
              for _ in range(self.layer_num)
          ]
-
+         self.token_stride = self.head_num * self.head_dim
          self.data_ptrs = torch.tensor(
              [x.data_ptr() for x in self.k_buffer + self.v_buffer],
              dtype=torch.uint64,
@@ -281,24 +287,24 @@
          # layer_num x [seq_len, head_num, head_dim]
          # layer_num x [page_num, page_size, head_num, head_dim]
          kv_data_ptrs = [
-             self.get_key_buffer(i).data_ptr()
+             self._get_key_buffer(i).data_ptr()
              for i in range(self.start_layer, self.start_layer + self.layer_num)
          ] + [
-             self.get_value_buffer(i).data_ptr()
+             self._get_value_buffer(i).data_ptr()
              for i in range(self.start_layer, self.start_layer + self.layer_num)
          ]
          kv_data_lens = [
-             self.get_key_buffer(i).nbytes
+             self._get_key_buffer(i).nbytes
              for i in range(self.start_layer, self.start_layer + self.layer_num)
          ] + [
-             self.get_value_buffer(i).nbytes
+             self._get_value_buffer(i).nbytes
              for i in range(self.start_layer, self.start_layer + self.layer_num)
          ]
          kv_item_lens = [
-             self.get_key_buffer(i)[0].nbytes * self.page_size
+             self._get_key_buffer(i)[0].nbytes * self.page_size
              for i in range(self.start_layer, self.start_layer + self.layer_num)
          ] + [
-             self.get_value_buffer(i)[0].nbytes * self.page_size
+             self._get_value_buffer(i)[0].nbytes * self.page_size
              for i in range(self.start_layer, self.start_layer + self.layer_num)
          ]
          return kv_data_ptrs, kv_data_lens, kv_item_lens
@@ -341,49 +347,73 @@
              self.v_buffer[layer_id][chunk_indices] = v_chunk
          torch.cuda.synchronize()
 
-     # Todo: different memory layout
-     def get_flat_data(self, indices):
-         # prepare a large chunk of contiguous data for efficient transfer
-         flatten = torch.stack(
-             [
-                 torch.stack([self.k_buffer[i][indices] for i in range(self.layer_num)]),
-                 torch.stack([self.v_buffer[i][indices] for i in range(self.layer_num)]),
-             ]
+     def load_from_host_per_layer(
+         self,
+         host_pool,
+         host_indices,
+         device_indices,
+         layer_id,
+         io_backend,
+     ):
+         transfer_kv_per_layer(
+             src_k=host_pool.k_buffer[layer_id],
+             dst_k=self.k_buffer[layer_id],
+             src_v=host_pool.v_buffer[layer_id],
+             dst_v=self.v_buffer[layer_id],
+             src_indices=host_indices,
+             dst_indices=device_indices,
+             io_backend=io_backend,
+             page_size=self.page_size,
+             item_size=self.token_stride,
          )
-         return flatten
-
-     @debug_timing
-     def transfer(self, indices, flat_data):
-         # transfer prepared data from host to device
-         flat_data = flat_data.to(device=self.device, non_blocking=False)
-         k_data, v_data = flat_data[0], flat_data[1]
-         for i in range(self.layer_num):
-             self.k_buffer[i][indices] = k_data[i]
-             self.v_buffer[i][indices] = v_data[i]
-
-     def transfer_per_layer(self, indices, flat_data, layer_id):
-         # transfer prepared data from host to device
-         flat_data = flat_data.to(device=self.device, non_blocking=False)
-         k_data, v_data = flat_data[0], flat_data[1]
-         self.k_buffer[layer_id - self.start_layer][indices] = k_data
-         self.v_buffer[layer_id - self.start_layer][indices] = v_data
 
-     def get_key_buffer(self, layer_id: int):
-         if self.layer_transfer_counter is not None:
-             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+     def backup_to_host_all_layer(
+         self, host_pool, host_indices, device_indices, io_backend
+     ):
+         # todo: specialized all layer kernels for the layer-non-contiguous memory pool
+         for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+             if layer_id - self.start_layer >= len(host_pool.k_buffer):
+                 raise ValueError(
+                     f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                 )
+             transfer_kv_per_layer(
+                 src_k=self.k_buffer[layer_id],
+                 dst_k=host_pool.k_buffer[layer_id],
+                 src_v=self.v_buffer[layer_id],
+                 dst_v=host_pool.v_buffer[layer_id],
+                 src_indices=device_indices,
+                 dst_indices=host_indices,
+                 io_backend=io_backend,
+                 page_size=self.page_size,
+                 item_size=self.token_stride,
+             )
 
+     def _get_key_buffer(self, layer_id: int):
+         # for internal use of referencing
          if self.store_dtype != self.dtype:
              return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
          return self.k_buffer[layer_id - self.start_layer]
 
-     def get_value_buffer(self, layer_id: int):
+     def get_key_buffer(self, layer_id: int):
+         # note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading
+         # it is supposed to be used only by attention backend not for information purpose
+         # same applies to get_value_buffer and get_kv_buffer
          if self.layer_transfer_counter is not None:
              self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
 
+         return self._get_key_buffer(layer_id)
+
+     def _get_value_buffer(self, layer_id: int):
+         # for internal use of referencing
          if self.store_dtype != self.dtype:
              return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
          return self.v_buffer[layer_id - self.start_layer]
 
+     def get_value_buffer(self, layer_id: int):
+         if self.layer_transfer_counter is not None:
+             self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
+         return self._get_value_buffer(layer_id)
+
      def get_kv_buffer(self, layer_id: int):
          return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
 
@@ -574,32 +604,49 @@ class AscendTokenToKVPool(MHATokenToKVPool):
          with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
              # [size, head_num, head_dim] for each layer
              # The padded slot 0 is used for writing dummy outputs from padded tokens.
-             self.k_buffer = [
-                 torch.zeros(
-                     (
-                         self.size // self.page_size + 1,
-                         self.page_size,
-                         self.head_num,
-                         self.head_dim,
-                     ),
-                     dtype=self.store_dtype,
-                     device=self.device,
-                 )
-                 for _ in range(self.layer_num)
-             ]
-             self.v_buffer = [
-                 torch.zeros(
-                     (
-                         self.size // self.page_size + 1,
-                         self.page_size,
-                         self.head_num,
-                         self.head_dim,
-                     ),
-                     dtype=self.store_dtype,
-                     device=self.device,
-                 )
-                 for _ in range(self.layer_num)
-             ]
+             # Continuous memory improves the efficiency of Ascend's transmission backend,
+             # while other backends remain unchanged.
+             self.kv_buffer = torch.zeros(
+                 (
+                     2,
+                     self.layer_num,
+                     self.size // self.page_size + 1,
+                     self.page_size,
+                     self.head_num,
+                     self.head_dim,
+                 ),
+                 dtype=self.store_dtype,
+                 device=self.device,
+             )
+             self.k_buffer = self.kv_buffer[0]
+             self.v_buffer = self.kv_buffer[1]
+
+     # for disagg
+     def get_contiguous_buf_infos(self):
+         # layer_num x [seq_len, head_num, head_dim]
+         # layer_num x [page_num, page_size, head_num, head_dim]
+         kv_data_ptrs = [
+             self.get_key_buffer(i).data_ptr()
+             for i in range(self.start_layer, self.start_layer + self.layer_num)
+         ] + [
+             self.get_value_buffer(i).data_ptr()
+             for i in range(self.start_layer, self.start_layer + self.layer_num)
+         ]
+         kv_data_lens = [
+             self.get_key_buffer(i).nbytes
+             for i in range(self.start_layer, self.start_layer + self.layer_num)
+         ] + [
+             self.get_value_buffer(i).nbytes
+             for i in range(self.start_layer, self.start_layer + self.layer_num)
+         ]
+         kv_item_lens = [
+             self.get_key_buffer(i)[0].nbytes
+             for i in range(self.start_layer, self.start_layer + self.layer_num)
+         ] + [
+             self.get_value_buffer(i)[0].nbytes
+             for i in range(self.start_layer, self.start_layer + self.layer_num)
+         ]
+         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
      def set_kv_buffer(
          self,
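
Worth noting: with the single allocation above, `k_buffer` and `v_buffer` are views into one contiguous region, which is what lets a transfer engine register the whole KV cache once and describe each layer with a plain pointer/length pair (as `get_contiguous_buf_infos` does). A small standalone illustration of that layout; the shapes are made up:

import torch

layer_num, pages, page_size, head_num, head_dim = 2, 4, 16, 8, 64
kv_buffer = torch.zeros((2, layer_num, pages, page_size, head_num, head_dim))

k_buffer = kv_buffer[0]  # view, no copy
v_buffer = kv_buffer[1]  # view, no copy

# Both views share the parent allocation, so the K region and the V region
# are two contiguous halves of a single buffer.
assert k_buffer.data_ptr() == kv_buffer.data_ptr()
assert v_buffer.data_ptr() == kv_buffer.data_ptr() + k_buffer.nbytes
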
@@ -761,6 +808,7 @@ class MLATokenToKVPool(KVCache):
              for _ in range(layer_num)
          ]
 
+         self.token_stride = kv_lora_rank + qk_rope_head_dim
          self.layer_transfer_counter = None
 
          kv_size = self.get_kv_size_bytes()
@@ -846,21 +894,37 @@
              self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
          )
 
-     def get_flat_data(self, indices):
-         # prepare a large chunk of contiguous data for efficient transfer
-         return torch.stack([self.kv_buffer[i][indices] for i in range(self.layer_num)])
-
-     @debug_timing
-     def transfer(self, indices, flat_data):
-         # transfer prepared data from host to device
-         flat_data = flat_data.to(device=self.device, non_blocking=False)
-         for i in range(self.layer_num):
-             self.kv_buffer[i][indices] = flat_data[i]
+     def load_from_host_per_layer(
+         self, host_pool, host_indices, device_indices, layer_id, io_backend
+     ):
+         transfer_kv_per_layer_mla(
+             src=host_pool.kv_buffer[layer_id],
+             dst=self.kv_buffer[layer_id],
+             src_indices=host_indices,
+             dst_indices=device_indices,
+             io_backend=io_backend,
+             page_size=self.page_size,
+             item_size=self.token_stride,
+         )
 
-     def transfer_per_layer(self, indices, flat_data, layer_id):
-         # transfer prepared data from host to device
-         flat_data = flat_data.to(device=self.device, non_blocking=False)
-         self.kv_buffer[layer_id - self.start_layer][indices] = flat_data
+     def backup_to_host_all_layer(
+         self, host_pool, host_indices, device_indices, io_backend
+     ):
+         # todo: specialized all layer kernels for the layer-non-contiguous memory pool
+         for layer_id in range(self.start_layer, self.start_layer + self.layer_num):
+             if layer_id - self.start_layer >= len(host_pool.kv_buffer):
+                 raise ValueError(
+                     f"Layer ID {layer_id} exceeds the number of layers in host pool."
+                 )
+             transfer_kv_per_layer_mla(
+                 src=self.kv_buffer[layer_id],
+                 dst=host_pool.kv_buffer[layer_id],
+                 src_indices=device_indices,
+                 dst_indices=host_indices,
+                 io_backend=io_backend,
+                 page_size=self.page_size,
+                 item_size=self.token_stride,
+             )
 
      def get_cpu_copy(self, indices):
          torch.cuda.synchronize()
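
As a mental model, each per-layer transfer is an indexed copy between the host and device buffers of that layer. A pure-PyTorch sketch of the semantics follows; the real `sgl_kernel.kvcacheio` kernels additionally take `io_backend`, `page_size`, and `item_size` to select and tune the copy path, which this sketch ignores.

import torch

def transfer_kv_per_layer_mla_reference(src, dst, src_indices, dst_indices):
    # Semantics only: copy the selected token slots of one layer's KV buffer
    # from src (e.g. a pinned host tensor) into dst (a device tensor).
    gathered = src.index_select(0, src_indices.to(src.device))
    dst.index_copy_(
        0, dst_indices.to(dst.device), gathered.to(dst.device, non_blocking=True)
    )
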
@@ -922,18 +986,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
 
          with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
              # The padded slot 0 is used for writing dummy outputs from padded tokens.
-             self.kv_buffer = [
-                 torch.zeros(
-                     (
-                         self.size // self.page_size + 1,
-                         self.page_size,
-                         self.kv_lora_rank + self.qk_rope_head_dim,
-                     ),
-                     dtype=self.store_dtype,
-                     device=self.device,
-                 )
-                 for _ in range(layer_num)
-             ]
+             self.kv_buffer = torch.zeros(
+                 (
+                     layer_num,
+                     self.size // self.page_size + 1,
+                     self.page_size,
+                     self.kv_lora_rank + self.qk_rope_head_dim,
+                 ),
+                 dtype=self.store_dtype,
+                 device=self.device,
+             )
 
          self.layer_transfer_counter = None
 
@@ -943,6 +1005,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
          )
          self.mem_usage = kv_size / GB
 
+     # for disagg
+     def get_contiguous_buf_infos(self):
+         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
+         kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
+         kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
+         kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+         return kv_data_ptrs, kv_data_lens, kv_item_lens
+
      def set_kv_buffer(
          self,
          layer: RadixAttention,
@@ -1046,14 +1116,19 @@ class DoubleSparseTokenToKVPool(KVCache):
          self.v_buffer[layer_id - self.start_layer][loc] = cache_v
          self.label_buffer[layer_id - self.start_layer][loc] = cache_label
 
-     def get_flat_data(self, indices):
-         pass
-
-     def transfer(self, indices, flat_data):
-         pass
+     def load_from_host_per_layer(
+         self, host_pool, host_indices, device_indices, layer_id, io_backend
+     ):
+         raise NotImplementedError(
+             "HiCache not supported for DoubleSparseTokenToKVPool."
+         )
 
-     def transfer_per_layer(self, indices, flat_data, layer_id):
-         pass
+     def backup_to_host_all_layer(
+         self, host_pool, host_indices, device_indices, io_backend
+     ):
+         raise NotImplementedError(
+             "HiCache not supported for DoubleSparseTokenToKVPool."
+         )
 
 
  @triton.jit
sglang/srt/mem_cache/memory_pool_host.py

@@ -8,7 +8,6 @@ import psutil
  import torch
 
  from sglang.srt.mem_cache.memory_pool import KVCache, MHATokenToKVPool, MLATokenToKVPool
- from sglang.srt.utils import debug_timing
 
  logger = logging.getLogger(__name__)
 
@@ -99,22 +98,6 @@ class HostKVCache(abc.ABC):
      def init_kv_buffer(self):
          raise NotImplementedError()
 
-     @abc.abstractmethod
-     def transfer(self, indices, flat_data):
-         raise NotImplementedError()
-
-     @abc.abstractmethod
-     def get_flat_data(self, indices):
-         raise NotImplementedError()
-
-     @abc.abstractmethod
-     def get_flat_data_by_layer(self, indices, layer_id):
-         raise NotImplementedError()
-
-     @abc.abstractmethod
-     def assign_flat_data(self, indices, flat_data):
-         raise NotImplementedError()
-
      @synchronized()
      def clear(self):
          # Initialize memory states and tracking structures.
@@ -243,58 +226,13 @@ class MHATokenToKVPoolHost(HostKVCache):
              pin_memory=self.pin_memory,
          )
 
-     @debug_timing
-     def transfer(self, indices, flat_data):
-         # backup prepared data from device to host
-         self.kv_buffer[:, :, indices] = flat_data.to(
-             device=self.device, non_blocking=False
-         )
+     @property
+     def k_buffer(self):
+         return self.kv_buffer[0]
 
-     def get_flat_data(self, indices):
-         return self.kv_buffer[:, :, indices]
-
-     def get_flat_data_by_layer(self, indices, layer_id):
-         return self.kv_buffer[:, layer_id - self.start_layer, indices]
-
-     def assign_flat_data(self, indices, flat_data):
-         self.kv_buffer[:, :, indices] = flat_data
-
-     def write_page_all_layers(self, host_indices, device_indices, device_pool):
-         device_indices_cpu = device_indices[:: self.page_size].cpu()
-         for i in range(len(device_indices_cpu)):
-             h_index = host_indices[i * self.page_size]
-             d_index = device_indices_cpu[i]
-             for j in range(self.layer_num):
-                 self.kv_buffer[0, j, h_index : h_index + self.page_size].copy_(
-                     device_pool.k_buffer[j][d_index : d_index + self.page_size],
-                     non_blocking=True,
-                 )
-                 self.kv_buffer[1, j, h_index : h_index + self.page_size].copy_(
-                     device_pool.v_buffer[j][d_index : d_index + self.page_size],
-                     non_blocking=True,
-                 )
-
-     def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-         device_indices_cpu = device_indices[:: self.page_size].cpu()
-         for i in range(len(device_indices_cpu)):
-             h_index = host_indices[i * self.page_size]
-             d_index = device_indices_cpu[i]
-             device_pool.k_buffer[layer_id - self.start_layer][
-                 d_index : d_index + self.page_size
-             ].copy_(
-                 self.kv_buffer[
-                     0, layer_id - self.start_layer, h_index : h_index + self.page_size
-                 ],
-                 non_blocking=True,
-             )
-             device_pool.v_buffer[layer_id - self.start_layer][
-                 d_index : d_index + self.page_size
-             ].copy_(
-                 self.kv_buffer[
-                     1, layer_id - self.start_layer, h_index : h_index + self.page_size
-                 ],
-                 non_blocking=True,
-             )
+     @property
+     def v_buffer(self):
+         return self.kv_buffer[1]
 
 
  class MLATokenToKVPoolHost(HostKVCache):
@@ -337,44 +275,3 @@ class MLATokenToKVPoolHost(HostKVCache):
              device=self.device,
              pin_memory=self.pin_memory,
          )
-
-     @debug_timing
-     def transfer(self, indices, flat_data):
-         # backup prepared data from device to host
-         self.kv_buffer[:, indices] = flat_data.to(
-             device=self.device, non_blocking=False
-         )
-
-     def get_flat_data(self, indices):
-         return self.kv_buffer[:, indices]
-
-     def get_flat_data_by_layer(self, indices, layer_id):
-         return self.kv_buffer[layer_id - self.start_layer, indices]
-
-     def assign_flat_data(self, indices, flat_data):
-         self.kv_buffer[:, indices] = flat_data
-
-     def write_page_all_layers(self, host_indices, device_indices, device_pool):
-         device_indices_cpu = device_indices[:: self.page_size].cpu()
-         for i in range(len(device_indices_cpu)):
-             h_index = host_indices[i * self.page_size]
-             d_index = device_indices_cpu[i]
-             for j in range(self.layer_num):
-                 self.kv_buffer[j, h_index : h_index + self.page_size].copy_(
-                     device_pool.kv_buffer[j][d_index : d_index + self.page_size],
-                     non_blocking=True,
-                 )
-
-     def load_page_per_layer(self, host_indices, device_indices, device_pool, layer_id):
-         device_indices_cpu = device_indices[:: self.page_size].cpu()
-         for i in range(len(device_indices_cpu)):
-             h_index = host_indices[i * self.page_size]
-             d_index = device_indices_cpu[i]
-             device_pool.kv_buffer[layer_id - self.start_layer][
-                 d_index : d_index + self.page_size
-             ].copy_(
-                 self.kv_buffer[
-                     layer_id - self.start_layer, h_index : h_index + self.page_size
-                 ],
-                 non_blocking=True,
-             )
sglang/srt/mem_cache/radix_cache.py

@@ -196,11 +196,13 @@ class RadixCache(BasePrefixCache):
 
          if self.page_size != 1:
              page_aligned_len = len(kv_indices) // self.page_size * self.page_size
-             page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
+             page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                 dtype=torch.int64, copy=True
+             )
              self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
          else:
              page_aligned_len = len(kv_indices)
-             page_aligned_kv_indices = kv_indices.clone()
+             page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
 
          # Radix Cache takes one ref in memory pool
          new_prefix_len = self.insert(
@@ -226,10 +228,12 @@
 
          if self.page_size != 1:
              page_aligned_len = len(kv_indices) // self.page_size * self.page_size
-             page_aligned_kv_indices = kv_indices[:page_aligned_len].clone()
+             page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
+                 dtype=torch.int64, copy=True
+             )
          else:
              page_aligned_len = len(kv_indices)
-             page_aligned_kv_indices = kv_indices.clone()
+             page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
          page_aligned_token_ids = token_ids[:page_aligned_len]
 
          # Radix Cache takes one ref in memory pool
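
The switch from `.clone()` to `.to(dtype=torch.int64, copy=True)` copies and widens the indices in one step, so downstream consumers always see int64 slot indices even if the allocator handed out a narrower dtype. A minimal standalone check; the int32 source tensor here is only for illustration:

import torch

kv_indices = torch.tensor([3, 4, 5, 6], dtype=torch.int32)

# .clone() keeps the original dtype ...
assert kv_indices.clone().dtype == torch.int32

# ... while .to(dtype=..., copy=True) casts and always returns a fresh tensor,
# even when the source is already int64.
page_aligned = kv_indices.to(dtype=torch.int64, copy=True)
assert page_aligned.dtype == torch.int64
assert page_aligned.data_ptr() != kv_indices.data_ptr()
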
sglang/srt/model_executor/forward_batch_info.py

@@ -453,8 +453,20 @@ class ForwardBatch:
              for mm_input in self.mm_inputs
          )
 
+     def contains_video_inputs(self) -> bool:
+         if self.mm_inputs is None:
+             return False
+         return any(
+             mm_input is not None and mm_input.contains_video_inputs()
+             for mm_input in self.mm_inputs
+         )
+
      def contains_mm_inputs(self) -> bool:
-         return self.contains_audio_inputs() or self.contains_image_inputs()
+         return (
+             self.contains_audio_inputs()
+             or self.contains_video_inputs()
+             or self.contains_image_inputs()
+         )
 
      def _compute_mrope_positions(
          self, model_runner: ModelRunner, batch: ModelWorkerBatch
sglang/srt/model_loader/loader.py

@@ -64,10 +64,13 @@ from sglang.srt.model_loader.weight_utils import (
  from sglang.srt.utils import (
      get_bool_env_var,
      get_device_capability,
+     is_npu,
      is_pin_memory_available,
      set_weight_attrs,
  )
 
+ _is_npu = is_npu()
+
 
  @contextmanager
  def device_loading_context(module: torch.nn.Module, target_device: torch.device):
@@ -127,18 +130,19 @@ def _get_quantization_config(
      # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
      if quant_config is None:
          return None
-     major, minor = get_device_capability()
-
-     if major is not None and minor is not None:
-         assert 0 <= minor < 10
-         capability = major * 10 + minor
-         if capability < quant_config.get_min_capability():
-             raise ValueError(
-                 f"The quantization method {model_config.quantization} "
-                 "is not supported for the current GPU. "
-                 f"Minimum capability: {quant_config.get_min_capability()}. "
-                 f"Current capability: {capability}."
-             )
+     if not _is_npu:
+         major, minor = get_device_capability()
+
+         if major is not None and minor is not None:
+             assert 0 <= minor < 10
+             capability = major * 10 + minor
+             if capability < quant_config.get_min_capability():
+                 raise ValueError(
+                     f"The quantization method {model_config.quantization} "
+                     "is not supported for the current GPU. "
+                     f"Minimum capability: {quant_config.get_min_capability()}. "
+                     f"Current capability: {capability}."
+                 )
      supported_dtypes = quant_config.get_supported_act_dtypes()
      if model_config.dtype not in supported_dtypes:
          raise ValueError(
@@ -157,6 +161,13 @@
      """Initialize a model with the given configurations."""
      model_class, _ = get_model_architecture(model_config)
      packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
+     if _is_npu:
+         packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
+             "q_a_proj",
+             "kv_a_proj_with_mqa",
+         ]
+         packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
+         packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
      quant_config = _get_quantization_config(
          model_config, load_config, packed_modules_mapping
      )
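
For context, `packed_modules_mapping` tells the weight loader and the quantization config which checkpoint sub-modules end up fused into a single packed module at runtime; the NPU branch above registers the usual QKV and gate-up fusions plus an MLA-specific `fused_qkv_a_proj_with_mqa` entry. The helper below is an illustrative sketch of how such a mapping is typically consumed, not the loader code in this diff:

packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
    "fused_qkv_a_proj_with_mqa": ["q_a_proj", "kv_a_proj_with_mqa"],
}

def resolve_packed_target(weight_name: str):
    # Map a checkpoint weight name such as "layers.0.mlp.gate_proj.weight"
    # to (fused_name, shard_index), or None if the weight is not packed.
    for packed_name, sources in packed_modules_mapping.items():
        for shard_index, source in enumerate(sources):
            if source in weight_name:
                return weight_name.replace(source, packed_name), shard_index
    return None

print(resolve_packed_target("layers.0.mlp.gate_proj.weight"))
# ('layers.0.mlp.gate_up_proj.weight', 0)
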
sglang/srt/models/deepseek_janus_pro.py

@@ -1989,7 +1989,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
          hidden_states = general_mm_embed_routine(
              input_ids=input_ids,
              forward_batch=forward_batch,
-             image_data_embedding_func=self.get_image_feature,
+             multimodal_model=self,
              language_model=self.language_model,
              positions=positions,
          )