sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (150)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +168 -22
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +49 -0
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +35 -0
  8. sglang/srt/custom_op.py +7 -1
  9. sglang/srt/disaggregation/base/conn.py +2 -0
  10. sglang/srt/disaggregation/decode.py +22 -6
  11. sglang/srt/disaggregation/mooncake/conn.py +289 -48
  12. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  13. sglang/srt/disaggregation/nixl/conn.py +100 -52
  14. sglang/srt/disaggregation/prefill.py +5 -4
  15. sglang/srt/disaggregation/utils.py +13 -12
  16. sglang/srt/distributed/parallel_state.py +44 -17
  17. sglang/srt/entrypoints/EngineBase.py +8 -0
  18. sglang/srt/entrypoints/engine.py +45 -9
  19. sglang/srt/entrypoints/http_server.py +111 -24
  20. sglang/srt/entrypoints/openai/protocol.py +51 -6
  21. sglang/srt/entrypoints/openai/serving_chat.py +52 -76
  22. sglang/srt/entrypoints/openai/serving_completions.py +1 -0
  23. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  24. sglang/srt/eplb/__init__.py +0 -0
  25. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  26. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  27. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  28. sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
  29. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  30. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  31. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  32. sglang/srt/hf_transformers_utils.py +2 -1
  33. sglang/srt/layers/activation.py +7 -0
  34. sglang/srt/layers/amx_utils.py +86 -0
  35. sglang/srt/layers/attention/ascend_backend.py +219 -0
  36. sglang/srt/layers/attention/flashattention_backend.py +56 -23
  37. sglang/srt/layers/attention/tbo_backend.py +37 -9
  38. sglang/srt/layers/communicator.py +18 -2
  39. sglang/srt/layers/dp_attention.py +9 -3
  40. sglang/srt/layers/elementwise.py +76 -12
  41. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  42. sglang/srt/layers/layernorm.py +41 -0
  43. sglang/srt/layers/linear.py +99 -12
  44. sglang/srt/layers/logits_processor.py +15 -6
  45. sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
  46. sglang/srt/layers/moe/ep_moe/layer.py +115 -25
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  49. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
  50. sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
  51. sglang/srt/layers/moe/router.py +60 -22
  52. sglang/srt/layers/moe/topk.py +36 -28
  53. sglang/srt/layers/parameter.py +67 -7
  54. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  55. sglang/srt/layers/quantization/fp8.py +44 -0
  56. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  57. sglang/srt/layers/quantization/fp8_utils.py +6 -6
  58. sglang/srt/layers/quantization/gptq.py +5 -1
  59. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  60. sglang/srt/layers/quantization/quant_utils.py +166 -0
  61. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  62. sglang/srt/layers/rotary_embedding.py +105 -13
  63. sglang/srt/layers/vocab_parallel_embedding.py +19 -2
  64. sglang/srt/lora/lora.py +4 -5
  65. sglang/srt/lora/lora_manager.py +73 -20
  66. sglang/srt/managers/configure_logging.py +1 -1
  67. sglang/srt/managers/io_struct.py +60 -15
  68. sglang/srt/managers/mm_utils.py +73 -59
  69. sglang/srt/managers/multimodal_processor.py +2 -6
  70. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  71. sglang/srt/managers/schedule_batch.py +80 -79
  72. sglang/srt/managers/scheduler.py +153 -63
  73. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  74. sglang/srt/managers/session_controller.py +12 -3
  75. sglang/srt/managers/tokenizer_manager.py +314 -103
  76. sglang/srt/managers/tp_worker.py +13 -1
  77. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  78. sglang/srt/mem_cache/allocator.py +290 -0
  79. sglang/srt/mem_cache/chunk_cache.py +34 -2
  80. sglang/srt/mem_cache/memory_pool.py +289 -3
  81. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  82. sglang/srt/model_executor/cuda_graph_runner.py +3 -2
  83. sglang/srt/model_executor/forward_batch_info.py +17 -4
  84. sglang/srt/model_executor/model_runner.py +302 -58
  85. sglang/srt/model_loader/loader.py +86 -10
  86. sglang/srt/model_loader/weight_utils.py +160 -3
  87. sglang/srt/models/deepseek_nextn.py +5 -4
  88. sglang/srt/models/deepseek_v2.py +305 -26
  89. sglang/srt/models/deepseek_vl2.py +3 -5
  90. sglang/srt/models/gemma3_causal.py +1 -2
  91. sglang/srt/models/gemma3n_audio.py +949 -0
  92. sglang/srt/models/gemma3n_causal.py +1010 -0
  93. sglang/srt/models/gemma3n_mm.py +495 -0
  94. sglang/srt/models/hunyuan.py +771 -0
  95. sglang/srt/models/kimi_vl.py +1 -2
  96. sglang/srt/models/llama.py +10 -4
  97. sglang/srt/models/llama4.py +32 -45
  98. sglang/srt/models/llama_eagle3.py +61 -11
  99. sglang/srt/models/llava.py +5 -5
  100. sglang/srt/models/minicpmo.py +2 -2
  101. sglang/srt/models/mistral.py +1 -1
  102. sglang/srt/models/mllama4.py +43 -11
  103. sglang/srt/models/phi4mm.py +1 -3
  104. sglang/srt/models/pixtral.py +3 -7
  105. sglang/srt/models/qwen2.py +31 -3
  106. sglang/srt/models/qwen2_5_vl.py +1 -3
  107. sglang/srt/models/qwen2_audio.py +200 -0
  108. sglang/srt/models/qwen2_moe.py +32 -6
  109. sglang/srt/models/qwen2_vl.py +1 -4
  110. sglang/srt/models/qwen3.py +94 -25
  111. sglang/srt/models/qwen3_moe.py +68 -21
  112. sglang/srt/models/vila.py +3 -8
  113. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
  114. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  115. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  116. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  117. sglang/srt/multimodal/processors/gemma3n.py +82 -0
  118. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  119. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  129. sglang/srt/operations_strategy.py +6 -2
  130. sglang/srt/reasoning_parser.py +26 -0
  131. sglang/srt/sampling/sampling_batch_info.py +39 -1
  132. sglang/srt/server_args.py +85 -24
  133. sglang/srt/speculative/build_eagle_tree.py +57 -18
  134. sglang/srt/speculative/eagle_worker.py +6 -4
  135. sglang/srt/two_batch_overlap.py +204 -28
  136. sglang/srt/utils.py +369 -138
  137. sglang/srt/warmup.py +12 -3
  138. sglang/test/runners.py +10 -1
  139. sglang/test/test_utils.py +15 -3
  140. sglang/version.py +1 -1
  141. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
  142. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
  143. sglang/math_utils.py +0 -8
  144. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  145. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  146. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  147. /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
  148. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
  149. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
@@ -27,10 +27,11 @@ KVCache actually holds the physical kv cache.
 import abc
 import logging
 from contextlib import nullcontext
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
+import torch.distributed as dist
 import triton
 import triton.language as tl
 
@@ -66,6 +67,7 @@ class ReqToTokenPool:
         self.req_to_token = torch.zeros(
             (size, max_context_len), dtype=torch.int32, device=device
         )
+
         self.free_slots = list(range(size))
 
     def write(self, indices, values):
@@ -121,6 +123,7 @@ class KVCache(abc.ABC):
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )
+        self.mem_usage = 0
 
        # used for chunked cpu-offloading
        self.cpu_offloading_chunk_size = 8192
@@ -191,7 +194,6 @@ class MHATokenToKVPool(KVCache):
             start_layer,
             end_layer,
         )
-
         self.head_num = head_num
         self.head_dim = head_dim
 
@@ -218,6 +220,7 @@ class MHATokenToKVPool(KVCache):
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB"
         )
+        self.mem_usage = (k_size + v_size) / GB
 
     def _create_buffers(self):
         with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
@@ -392,10 +395,14 @@ class MHATokenToKVPool(KVCache):
         cache_v: torch.Tensor,
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
+        layer_id_override: Optional[int] = None,
     ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
-        layer_id = layer.layer_id
+        if layer_id_override is not None:
+            layer_id = layer_id_override
+        else:
+            layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
                 cache_k.div_(k_scale)
@@ -431,6 +438,206 @@ class MHATokenToKVPool(KVCache):
         )
 
 
+class SWAKVPool(KVCache):
+    """KV cache with separate pools for full and SWA attention layers."""
+
+    def __init__(
+        self,
+        size: int,
+        size_swa: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        swa_attention_layer_ids: List[int],
+        full_attention_layer_ids: List[int],
+        enable_kvcache_transpose: bool,
+        device: str,
+    ):
+        self.size = size
+        self.size_swa = size_swa
+        self.dtype = dtype
+        self.device = device
+        self.swa_layer_nums = len(swa_attention_layer_ids)
+        self.full_layer_nums = len(full_attention_layer_ids)
+        self.page_size = 1
+        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
+        assert not enable_kvcache_transpose
+        TokenToKVPoolClass = MHATokenToKVPool
+        self.swa_kv_pool = TokenToKVPoolClass(
+            size=size_swa,
+            page_size=self.page_size,
+            dtype=dtype,
+            head_num=head_num,
+            head_dim=head_dim,
+            layer_num=self.swa_layer_nums,
+            device=device,
+            enable_memory_saver=False,
+        )
+        self.full_kv_pool = TokenToKVPoolClass(
+            size=size,
+            page_size=self.page_size,
+            dtype=dtype,
+            head_num=head_num,
+            head_dim=head_dim,
+            layer_num=self.full_layer_nums,
+            device=device,
+            enable_memory_saver=False,
+        )
+        self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
+        for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
+            self.layers_mapping[global_layer_id] = (full_attn_layer_id, False)
+        for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids):
+            self.layers_mapping[global_layer_id] = (swa_layer_id, True)
+        self.full_to_swa_index_mapping: Optional[torch.Tensor] = None
+
+    def get_kv_size_bytes(self):
+        raise NotImplementedError
+
+    def get_contiguous_buf_infos(self):
+        full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
+            self.full_kv_pool.get_contiguous_buf_infos()
+        )
+        swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
+            self.swa_kv_pool.get_contiguous_buf_infos()
+        )
+
+        kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs
+        kv_data_lens = full_kv_data_lens + swa_kv_data_lens
+        kv_item_lens = full_kv_item_lens + swa_kv_item_lens
+
+        return kv_data_ptrs, kv_data_lens, kv_item_lens
+
+    def get_key_buffer(self, layer_id: int):
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            return self.swa_kv_pool.get_key_buffer(layer_id_pool)
+        else:
+            return self.full_kv_pool.get_key_buffer(layer_id_pool)
+
+    def get_value_buffer(self, layer_id: int):
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            return self.swa_kv_pool.get_value_buffer(layer_id_pool)
+        else:
+            return self.full_kv_pool.get_value_buffer(layer_id_pool)
+
+    def get_kv_buffer(self, layer_id: int):
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            return self.swa_kv_pool.get_kv_buffer(layer_id_pool)
+        else:
+            return self.full_kv_pool.get_kv_buffer(layer_id_pool)
+
+    def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
+        assert self.full_to_swa_index_mapping is not None
+        return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        k_scale: float = 1.0,
+        v_scale: float = 1.0,
+    ):
+
+        layer_id = layer.layer_id
+        layer_id_pool, is_swa = self.layers_mapping[layer_id]
+        if is_swa:
+            if self.full_to_swa_index_mapping is not None:
+                loc = self.translate_loc_from_full_to_swa(loc)
+            self.swa_kv_pool.set_kv_buffer(
+                None,
+                loc,
+                cache_k,
+                cache_v,
+                k_scale,
+                v_scale,
+                layer_id_override=layer_id_pool,
+            )
+        else:
+            self.full_kv_pool.set_kv_buffer(
+                None,
+                loc,
+                cache_k,
+                cache_v,
+                k_scale,
+                v_scale,
+                layer_id_override=layer_id_pool,
+            )
+
+
+class AscendTokenToKVPool(MHATokenToKVPool):
+
+    def _create_buffers(self):
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            # [size, head_num, head_dim] for each layer
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.k_buffer = [
+                torch.zeros(
+                    (
+                        self.size // self.page_size + 1,
+                        self.page_size,
+                        self.head_num,
+                        self.head_dim,
+                    ),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+            self.v_buffer = [
+                torch.zeros(
+                    (
+                        self.size // self.page_size + 1,
+                        self.page_size,
+                        self.head_num,
+                        self.head_dim,
+                    ),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(self.layer_num)
+            ]
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        k_scale: Optional[float] = None,
+        v_scale: Optional[float] = None,
+    ):
+        layer_id = layer.layer_id
+        if cache_k.dtype != self.dtype:
+            if k_scale is not None:
+                cache_k.div_(k_scale)
+            if v_scale is not None:
+                cache_v.div_(v_scale)
+            cache_k = cache_k.to(self.dtype)
+            cache_v = cache_v.to(self.dtype)
+
+        if self.store_dtype != self.dtype:
+            cache_k = cache_k.view(self.store_dtype)
+            cache_v = cache_v.view(self.store_dtype)
+
+        import torch_npu
+
+        torch_npu._npu_reshape_and_cache(
+            key=cache_k,
+            value=cache_v,
+            key_cache=self.k_buffer[layer_id].view(
+                -1, self.page_size, self.head_num, self.head_dim
+            ),
+            value_cache=self.v_buffer[layer_id].view(
+                -1, self.page_size, self.head_num, self.head_dim
+            ),
+            slot_indices=loc,
+        )
+
+
 @triton.jit
 def set_mla_kv_buffer_kernel(
     kv_buffer_ptr,
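
For orientation, here is a minimal standalone sketch (not taken from the diff; the layer-id lists are invented) of how the layers_mapping built in SWAKVPool.__init__ routes a global layer id to a pool-local id, and why MHATokenToKVPool.set_kv_buffer gained the layer_id_override parameter in the earlier hunk:

```python
# Hypothetical hybrid layout: 6 layers; layers 2 and 5 use full attention, the rest SWA.
full_attention_layer_ids = [2, 5]
swa_attention_layer_ids = [0, 1, 3, 4]

# Same construction as SWAKVPool.__init__: global layer id -> (pool-local id, is_swa).
layers_mapping = {}
for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
    layers_mapping[global_layer_id] = (full_attn_layer_id, False)
for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids):
    layers_mapping[global_layer_id] = (swa_layer_id, True)

assert layers_mapping == {
    2: (0, False), 5: (1, False),                            # full pool holds 2 layers
    0: (0, True), 1: (1, True), 3: (2, True), 4: (3, True),  # SWA pool holds 4 layers
}

# A write to global layer 4 lands in the SWA sub-pool at local index 3, which is what
# SWAKVPool.set_kv_buffer forwards as layer_id_override to the shared MHATokenToKVPool
# implementation (the RadixAttention layer object itself still reports id 4).
assert layers_mapping[4] == (3, True)
```

The full_to_swa_index_mapping tensor plays the analogous role for token slots: full-pool indices are translated via translate_loc_from_full_to_swa before writing into the smaller SWA pool.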
@@ -560,6 +767,7 @@ class MLATokenToKVPool(KVCache):
         logger.info(
             f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
         )
+        self.mem_usage = kv_size / GB
 
     def get_kv_size_bytes(self):
         assert hasattr(self, "kv_buffer")
@@ -682,6 +890,84 @@ class MLATokenToKVPool(KVCache):
         torch.cuda.synchronize()
 
 
+class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+        layer_num: int,
+        device: str,
+        enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
+    ):
+        super(MLATokenToKVPool, self).__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
+        )
+
+        self.kv_lora_rank = kv_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+
+        self.custom_mem_pool = None
+
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
+            # The padded slot 0 is used for writing dummy outputs from padded tokens.
+            self.kv_buffer = [
+                torch.zeros(
+                    (
+                        self.size // self.page_size + 1,
+                        self.page_size,
+                        self.kv_lora_rank + self.qk_rope_head_dim,
+                    ),
+                    dtype=self.store_dtype,
+                    device=self.device,
+                )
+                for _ in range(layer_num)
+            ]
+
+        self.layer_transfer_counter = None
+
+        kv_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
+        )
+        self.mem_usage = kv_size / GB
+
+    def set_kv_buffer(
+        self,
+        layer: RadixAttention,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ):
+        layer_id = layer.layer_id
+        if cache_k.dtype != self.dtype:
+            cache_k = cache_k.to(self.dtype)
+
+        if self.store_dtype != self.dtype:
+            cache_k = cache_k.view(self.store_dtype)
+
+        import torch_npu
+
+        torch_npu._npu_reshape_and_cache_siso(
+            key=cache_k.view(-1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
+            key_cache=self.kv_buffer[layer_id - self.start_layer].view(
+                -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim
+            ),
+            slot_indices=loc,
+        )
+
+
 class DoubleSparseTokenToKVPool(KVCache):
     def __init__(
         self,
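
As a rough sense of scale for the mem_usage figures logged above, a back-of-the-envelope sketch; the dimensions below are illustrative DeepSeek-style MLA values, not read from this diff:

```python
# Illustrative MLA dimensions (assumed, not from the diff).
kv_lora_rank = 512
qk_rope_head_dim = 64
layer_num = 61
bytes_per_elem = 2  # bfloat16

# An MLA pool stores kv_lora_rank + qk_rope_head_dim values per token per layer.
per_token_per_layer = (kv_lora_rank + qk_rope_head_dim) * bytes_per_elem  # 1152 bytes
per_token_total = per_token_per_layer * layer_num                          # 70272 bytes

GB = 1024**3
tokens = 1_000_000
print(f"{tokens} tokens -> {tokens * per_token_total / GB:.2f} GB")  # ~65.45 GB
```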
@@ -24,6 +24,9 @@ class MultiModalCache:
         self.current_size += data_size
         return True
 
+    def has(self, mm_hash: int) -> bool:
+        return mm_hash in self.mm_cache
+
     def get(self, mm_hash: int) -> torch.Tensor:
         return self.mm_cache.get(mm_hash)
 
@@ -168,7 +168,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         capture_bs += [model_runner.req_to_token_pool.size]
 
     if server_args.enable_two_batch_overlap:
-        capture_bs = [bs for bs in capture_bs if bs >= 2]
+        capture_bs = [bs for bs in capture_bs if bs % 2 == 0]
 
     if server_args.cuda_graph_max_bs:
         capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
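
A quick illustration of the behavioral change in the hunk above (the batch sizes are arbitrary): with two-batch overlap enabled, captured batch sizes must now be even rather than merely at least 2, so each of the two micro-batches receives an equal half.

```python
capture_bs = [1, 2, 3, 4, 6, 8]

old = [bs for bs in capture_bs if bs >= 2]      # [2, 3, 4, 6, 8]  odd sizes allowed
new = [bs for bs in capture_bs if bs % 2 == 0]  # [2, 4, 6, 8]     splits evenly in two
print(old, new)
```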
@@ -421,7 +421,7 @@ class CudaGraphRunner:
                     empty_cache=False,
                 )
                 capture_range.set_description(
-                    f"Capturing batches ({avail_mem=:.2f} GB)"
+                    f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
                 )
 
                 with patch_model(
@@ -679,6 +679,7 @@ class CudaGraphRunner:
             forward_mode=self.capture_forward_mode,
             bs=bs,
             num_token_non_padded=len(forward_batch.input_ids),
+            spec_info=forward_batch.spec_info,
         )
         if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
             forward_batch.spec_info.custom_mask = self.custom_mask
@@ -39,7 +39,12 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import flatten_nested_list, get_compiler_backend, support_triton
+from sglang.srt.utils import (
+    flatten_nested_list,
+    get_compiler_backend,
+    is_npu,
+    support_triton,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -50,6 +55,8 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 
+_is_npu = is_npu()
+
 
 class ForwardMode(IntEnum):
     # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
@@ -247,6 +254,7 @@ class ForwardBatch:
     dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
     gathered_buffer: Optional[torch.Tensor] = None
+    is_extend_in_batch: bool = False
     can_run_dp_cuda_graph: bool = False
     global_forward_mode: Optional[ForwardMode] = None
 
@@ -292,6 +300,7 @@ class ForwardBatch:
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
+            is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
             lora_paths=batch.lora_paths,
@@ -352,7 +361,9 @@ class ForwardBatch:
 
         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), device=device)
-            TboForwardBatchPreparer.prepare(ret)
+            TboForwardBatchPreparer.prepare(
+                ret, is_draft_worker=model_runner.is_draft_worker
+            )
             return ret
 
         # Override the positions with spec_info
@@ -397,7 +408,9 @@ class ForwardBatch:
         if model_runner.server_args.lora_paths is not None:
             model_runner.lora_manager.prepare_lora_batch(ret)
 
-        TboForwardBatchPreparer.prepare(ret)
+        TboForwardBatchPreparer.prepare(
+            ret, is_draft_worker=model_runner.is_draft_worker
+        )
 
         return ret
 
@@ -735,7 +748,7 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc
 
 
-@torch.compile(dynamic=True, backend=get_compiler_backend())
+@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
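
For context on the last hunk: torch.compile accepts a disable flag that turns the decorator into a no-op, so the same function definition can run eagerly on back ends (here NPU) where compilation is not wanted. A minimal standalone sketch of the pattern, independent of SGLang:

```python
import torch

_is_npu = False  # stand-in for sglang.srt.utils.is_npu(); assumes a non-NPU machine

@torch.compile(dynamic=True, disable=_is_npu)
def clamp_position(seq_lens: torch.Tensor) -> torch.Tensor:
    # Index of the last token in each sequence, clamped to 0 for empty sequences.
    return torch.clamp(seq_lens - 1, min=0).to(torch.int64)

print(clamp_position(torch.tensor([0, 1, 5])))  # tensor([0, 0, 4])
```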