sglang 0.4.5.post3__py3-none-any.whl → 0.4.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -9
  3. sglang/compile_deep_gemm.py +45 -4
  4. sglang/srt/code_completion_parser.py +1 -1
  5. sglang/srt/configs/deepseekvl2.py +1 -1
  6. sglang/srt/configs/model_config.py +9 -3
  7. sglang/srt/constrained/llguidance_backend.py +78 -61
  8. sglang/srt/conversation.py +34 -1
  9. sglang/srt/disaggregation/decode.py +67 -13
  10. sglang/srt/disaggregation/fake/__init__.py +1 -0
  11. sglang/srt/disaggregation/fake/conn.py +88 -0
  12. sglang/srt/disaggregation/mini_lb.py +45 -8
  13. sglang/srt/disaggregation/mooncake/conn.py +198 -31
  14. sglang/srt/disaggregation/prefill.py +36 -12
  15. sglang/srt/disaggregation/utils.py +16 -2
  16. sglang/srt/entrypoints/engine.py +9 -0
  17. sglang/srt/entrypoints/http_server.py +35 -4
  18. sglang/srt/function_call_parser.py +77 -5
  19. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  20. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  21. sglang/srt/layers/attention/flashattention_backend.py +28 -10
  22. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  23. sglang/srt/layers/attention/utils.py +1 -1
  24. sglang/srt/layers/attention/vision.py +2 -0
  25. sglang/srt/layers/layernorm.py +38 -16
  26. sglang/srt/layers/logits_processor.py +2 -2
  27. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +20 -17
  43. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  44. sglang/srt/layers/pooler.py +6 -0
  45. sglang/srt/layers/quantization/awq.py +5 -1
  46. sglang/srt/layers/quantization/deep_gemm.py +17 -10
  47. sglang/srt/layers/quantization/fp8.py +20 -22
  48. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  49. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  50. sglang/srt/layers/radix_attention.py +13 -3
  51. sglang/srt/layers/rotary_embedding.py +170 -126
  52. sglang/srt/managers/data_parallel_controller.py +10 -3
  53. sglang/srt/managers/io_struct.py +7 -0
  54. sglang/srt/managers/mm_utils.py +85 -28
  55. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  56. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  57. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  58. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  59. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  61. sglang/srt/managers/schedule_batch.py +38 -12
  62. sglang/srt/managers/scheduler.py +41 -28
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +25 -9
  64. sglang/srt/managers/tokenizer_manager.py +5 -1
  65. sglang/srt/managers/tp_worker.py +3 -3
  66. sglang/srt/managers/tp_worker_overlap_thread.py +9 -4
  67. sglang/srt/mem_cache/memory_pool.py +87 -0
  68. sglang/srt/model_executor/cuda_graph_runner.py +4 -3
  69. sglang/srt/model_executor/forward_batch_info.py +51 -95
  70. sglang/srt/model_executor/model_runner.py +19 -25
  71. sglang/srt/models/deepseek.py +12 -2
  72. sglang/srt/models/deepseek_nextn.py +101 -6
  73. sglang/srt/models/deepseek_v2.py +144 -70
  74. sglang/srt/models/deepseek_vl2.py +9 -4
  75. sglang/srt/models/gemma3_causal.py +1 -1
  76. sglang/srt/models/llama4.py +0 -1
  77. sglang/srt/models/minicpmo.py +5 -1
  78. sglang/srt/models/mllama4.py +2 -2
  79. sglang/srt/models/qwen2_5_vl.py +3 -6
  80. sglang/srt/models/qwen2_vl.py +3 -7
  81. sglang/srt/models/roberta.py +178 -0
  82. sglang/srt/openai_api/adapter.py +50 -11
  83. sglang/srt/openai_api/protocol.py +2 -0
  84. sglang/srt/reasoning_parser.py +25 -1
  85. sglang/srt/server_args.py +31 -24
  86. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  87. sglang/srt/torch_memory_saver_adapter.py +10 -1
  88. sglang/srt/utils.py +5 -1
  89. sglang/test/runners.py +6 -13
  90. sglang/test/send_one.py +84 -28
  91. sglang/test/test_utils.py +74 -18
  92. sglang/version.py +1 -1
  93. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/METADATA +5 -6
  94. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/RECORD +97 -80
  95. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/WHEEL +1 -1
  96. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/licenses/LICENSE +0 -0
  97. {sglang-0.4.5.post3.dist-info → sglang-0.4.6.post1.dist-info}/top_level.txt +0 -0
@@ -25,7 +25,7 @@ if is_cuda():
25
25
 
26
26
  sm_version = get_device_sm()
27
27
  if sm_version == 90:
28
- if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
28
+ if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
29
29
  _ENABLE_JIT_DEEPGEMM = True
30
30
 
31
31
  logger = logging.getLogger(__name__)
@@ -34,9 +34,10 @@ _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
34
34
  _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
35
35
  "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
36
36
  )
37
- _DO_COMPILE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
37
+ _DO_COMPILE_ALL = True
38
+ _IS_FIRST_RANK_ON_NODE = get_bool_env_var("SGL_IS_FIRST_RANK_ON_NODE", "true")
38
39
  _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
39
- _IN_PRE_COMPILE_STAGE = get_bool_env_var("SGL_IN_DEEP_GEMM_PRE_COMPILE_STAGE", "false")
40
+ _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
40
41
 
41
42
  # Force redirect deep_gemm cache_dir
42
43
  os.environ["DG_CACHE_DIR"] = os.getenv(
@@ -46,7 +47,8 @@ os.environ["DG_CACHE_DIR"] = os.getenv(
46
47
 
47
48
  def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
48
49
  global _BUILTIN_M_LIST
49
- global _DO_COMPILE
50
+ global _DO_COMPILE_ALL
51
+ global _IS_FIRST_RANK_ON_NODE
50
52
 
51
53
  # Generate m_max
52
54
  m_max = 1024 * 16
@@ -57,8 +59,13 @@ def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
57
59
  m_max = min(1024 * 128, m_max)
58
60
  _BUILTIN_M_LIST = list(range(1, m_max + 1))
59
61
 
60
- # Check if is the first rank on node
61
- _DO_COMPILE = ServerArgs.base_gpu_id == gpu_id
62
+ _IS_FIRST_RANK_ON_NODE = ServerArgs.base_gpu_id == gpu_id
63
+
64
+ # Check if is the first rank on node.
65
+ # Default each rank will try compile all Ms to
66
+ # load all symbols at the launch stages.
67
+ # Avoid loading symbols at the serving stages.
68
+ _DO_COMPILE_ALL = _IS_FIRST_RANK_ON_NODE or not _IN_PRECOMPILE_STAGE
62
69
 
63
70
 
64
71
  class DeepGemmKernelType(IntEnum):
@@ -89,7 +96,7 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic
89
96
 
90
97
 
91
98
  def _compile_warning_1():
92
- if not _IN_PRE_COMPILE_STAGE:
99
+ if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
93
100
  logger.warning(
94
101
  "Entering DeepGEMM JIT Pre-Complie session. "
95
102
  "And it may takes a long time(Typically 10-20 mins) "
@@ -276,7 +283,7 @@ def _maybe_compile_deep_gemm_one_type_all(
276
283
  query_key = (kernel_type, n, k, num_groups)
277
284
  if (
278
285
  _ENABLE_JIT_DEEPGEMM_PRECOMPILE
279
- and _DO_COMPILE
286
+ and _DO_COMPILE_ALL
280
287
  and _INITIALIZATION_DICT.get(query_key) is None
281
288
  ):
282
289
  _INITIALIZATION_DICT[query_key] = True
@@ -286,7 +293,7 @@ def _maybe_compile_deep_gemm_one_type_all(
286
293
  logger.info(
287
294
  f"Try DeepGEMM JIT Compiling for "
288
295
  f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
289
- f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRE_COMPILE_STAGE else ''}"
296
+ f"{' It only takes a litte time(Typically 1 sec) if you have run `sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
290
297
  )
291
298
 
292
299
  # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
@@ -355,7 +362,7 @@ def gemm_nt_f8f8bf16(
355
362
 
356
363
  @contextmanager
357
364
  def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
358
- if _IN_PRE_COMPILE_STAGE:
365
+ if _IN_PRECOMPILE_STAGE:
359
366
  yield
360
367
  return
361
368
 
@@ -72,8 +72,8 @@ _is_hip = is_hip()
72
72
  _is_cuda = is_cuda()
73
73
 
74
74
  if _is_hip:
75
- from aiter import ActivationType
76
- from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
75
+ from aiter import ActivationType, QuantType
76
+ from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
77
77
  from aiter.ops.shuffle import shuffle_weight
78
78
 
79
79
  if not _is_cuda:
@@ -484,7 +484,7 @@ class Fp8MoEMethod:
484
484
  if self.quant_config.is_checkpoint_fp8_serialized:
485
485
  params_dtype = (
486
486
  torch.uint32
487
- if get_bool_env_var("USE_INT4_WEIGHT")
487
+ if get_bool_env_var("SGLANG_INT4_WEIGHT")
488
488
  else torch.float8_e4m3fn
489
489
  )
490
490
  tp_size = get_tensor_model_parallel_world_size()
@@ -511,7 +511,7 @@ class Fp8MoEMethod:
511
511
  )
512
512
 
513
513
  # WEIGHTS
514
- if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
514
+ if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
515
515
  # INT4 MoE weight - INT32 packed
516
516
  w13_weight = torch.nn.Parameter(
517
517
  torch.empty(
@@ -585,7 +585,7 @@ class Fp8MoEMethod:
585
585
 
586
586
  if (
587
587
  _is_hip
588
- ): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
588
+ ): # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
589
589
  # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
590
590
  w13_weight_scale1 = torch.nn.Parameter(
591
591
  torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -612,7 +612,7 @@ class Fp8MoEMethod:
612
612
  set_weight_attrs(w13_weight_scale, extra_weight_attrs)
613
613
  set_weight_attrs(w2_weight_scale, extra_weight_attrs)
614
614
 
615
- if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
615
+ if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
616
616
  extra_weight_attrs.update(
617
617
  {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
618
618
  )
@@ -644,7 +644,7 @@ class Fp8MoEMethod:
644
644
  layer.w2_input_scale = None
645
645
 
646
646
  def process_weights_after_loading(self, layer: Module) -> None:
647
- if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
647
+ if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
648
648
  self.process_weights_hip_int4(layer)
649
649
  return
650
650
 
@@ -675,7 +675,7 @@ class Fp8MoEMethod:
675
675
  )
676
676
  layer.w2_input_scale = None
677
677
 
678
- if get_bool_env_var("CK_MOE"):
678
+ if get_bool_env_var("SGLANG_AITER_MOE"):
679
679
  # Pre-shuffle weights
680
680
  layer.w13_weight.data = shuffle_weight(
681
681
  layer.w13_weight.contiguous(), (16, 16)
@@ -798,17 +798,15 @@ class Fp8MoEMethod:
798
798
  return
799
799
 
800
800
  def process_weights_hip_int4(self, layer: Module):
801
- # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
801
+ # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
802
802
  # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
803
803
  # Weight Permutation
804
804
  layer.w13_weight = torch.nn.Parameter(
805
- # permute_weight(layer.w13_weight.data),
806
805
  shuffle_weight(layer.w13_weight.data, (16, 16)),
807
806
  requires_grad=False,
808
807
  )
809
808
  torch.cuda.empty_cache()
810
809
  layer.w2_weight = torch.nn.Parameter(
811
- # permute_weight(layer.w2_weight.data),
812
810
  shuffle_weight(layer.w2_weight.data, (16, 16)),
813
811
  requires_grad=False,
814
812
  )
@@ -847,23 +845,21 @@ class Fp8MoEMethod:
847
845
  padding_size, # Avoid circular import
848
846
  )
849
847
 
850
- if get_bool_env_var("CK_MOE"):
848
+ if get_bool_env_var("SGLANG_AITER_MOE"):
851
849
  layer.w13_weight = torch.nn.Parameter(
852
- # permute_weight(layer.w13_weight.data),
853
850
  shuffle_weight(layer.w13_weight.data, (16, 16)),
854
851
  requires_grad=False,
855
852
  )
856
853
  torch.cuda.empty_cache()
857
854
  layer.w2_weight = torch.nn.Parameter(
858
- # permute_weight(layer.w2_weight.data),
859
855
  shuffle_weight(layer.w2_weight.data, (16, 16)),
860
856
  requires_grad=False,
861
857
  )
862
858
  torch.cuda.empty_cache()
863
- # ROCm (CK_MOE): using column-wise scaling
859
+ # ROCm (SGLANG_AITER_MOE): using column-wise scaling
864
860
  layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
865
861
  layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
866
- elif get_bool_env_var("MOE_PADDING"):
862
+ elif get_bool_env_var("SGLANG_MOE_PADDING"):
867
863
  # If ROCm, apply weight padding (min. Mem channel contention) only if set
868
864
  layer.w13_weight = torch.nn.Parameter(
869
865
  F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
@@ -912,15 +908,16 @@ class Fp8MoEMethod:
912
908
  )
913
909
 
914
910
  if _is_hip:
915
- if get_bool_env_var("USE_INT4_WEIGHT"):
916
- # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
911
+ if get_bool_env_var("SGLANG_INT4_WEIGHT"):
912
+ # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
917
913
  assert not no_combine, f"{no_combine=} is not supported."
918
- return ck_moe_2stages_win4(
914
+ return ck_moe_2stages(
919
915
  x,
920
916
  layer.w13_weight,
921
917
  layer.w2_weight,
922
918
  topk_weights,
923
919
  topk_ids,
920
+ QuantType.per_Token,
924
921
  layer.w13_weight_scale1,
925
922
  layer.w2_weight_scale1,
926
923
  activation=(
@@ -930,13 +927,13 @@ class Fp8MoEMethod:
930
927
  ),
931
928
  )
932
929
 
933
- if get_bool_env_var("CK_MOE"):
930
+ if get_bool_env_var("SGLANG_AITER_MOE"):
934
931
  assert not no_combine, f"{no_combine=} is not supported."
935
932
  if self.block_quant:
936
- # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
933
+ # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
937
934
  assert (
938
935
  activation == "silu"
939
- ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
936
+ ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
940
937
  return asm_moe(
941
938
  x,
942
939
  layer.w13_weight,
@@ -955,6 +952,7 @@ class Fp8MoEMethod:
955
952
  layer.w2_weight,
956
953
  topk_weights,
957
954
  topk_ids,
955
+ QuantType.per_Token,
958
956
  layer.w13_weight_scale1,
959
957
  layer.w2_weight_scale1,
960
958
  activation=(
@@ -31,7 +31,7 @@ from sglang.srt.utils import (
31
31
  _is_hip = is_hip()
32
32
  _is_cuda = is_cuda()
33
33
 
34
- if _is_hip and get_bool_env_var("CK_MOE"):
34
+ if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
35
35
  from aiter import gemm_a8w8_blockscale
36
36
 
37
37
  if _is_cuda:
@@ -132,7 +132,7 @@ def apply_w8a8_block_fp8_linear(
132
132
  output = fp8_blockwise_scaled_mm(
133
133
  q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
134
134
  )
135
- elif _is_hip and get_bool_env_var("CK_MOE"):
135
+ elif _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
136
136
  q_input, x_scale = per_token_group_quant_fp8(
137
137
  input_2d, block_size[1], column_major_scales=False
138
138
  )
@@ -8,7 +8,11 @@ import torch
8
8
  import triton
9
9
  import triton.language as tl
10
10
 
11
- from sglang.srt.utils import get_device_name
11
+ from sglang.srt.utils import get_device_name, is_cuda
12
+
13
+ _is_cuda = is_cuda()
14
+ if _is_cuda:
15
+ from sgl_kernel import sgl_per_token_group_quant_int8
12
16
 
13
17
  logger = logging.getLogger(__name__)
14
18
 
@@ -165,6 +169,33 @@ def per_token_group_quant_int8(
165
169
  return x_q, x_s
166
170
 
167
171
 
172
+ def sglang_per_token_group_quant_int8(
173
+ x: torch.Tensor,
174
+ group_size: int,
175
+ eps: float = 1e-10,
176
+ dtype: torch.dtype = torch.int8,
177
+ ):
178
+ assert (
179
+ x.shape[-1] % group_size == 0
180
+ ), "the last dimension of `x` cannot be divisible by `group_size`"
181
+ assert x.is_contiguous(), "`x` is not contiguous"
182
+
183
+ iinfo = torch.iinfo(dtype)
184
+ int8_max = iinfo.max
185
+ int8_min = iinfo.min
186
+
187
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
188
+ x_s = torch.empty(
189
+ x.shape[:-1] + (x.shape[-1] // group_size,),
190
+ device=x.device,
191
+ dtype=torch.float32,
192
+ )
193
+
194
+ sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
195
+
196
+ return x_q, x_s
197
+
198
+
168
199
  @triton.jit
169
200
  def _w8a8_block_int8_matmul(
170
201
  # Pointers to inputs and output
@@ -87,13 +87,23 @@ class RadixAttention(nn.Module):
87
87
  v,
88
88
  forward_batch: ForwardBatch,
89
89
  save_kv_cache: bool = True,
90
+ **kwargs,
90
91
  ):
91
92
  if k is not None:
92
93
  # For cross-layer sharing, kv can be None
93
94
  assert v is not None
94
- k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
95
- v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
95
+ if "k_rope" not in kwargs:
96
+ k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
97
+ v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
98
+ else:
99
+ k = k.view(-1, self.tp_k_head_num, self.v_head_dim)
96
100
 
97
101
  return forward_batch.attn_backend.forward(
98
- q, k, v, self, forward_batch, save_kv_cache
102
+ q,
103
+ k,
104
+ v,
105
+ self,
106
+ forward_batch,
107
+ save_kv_cache,
108
+ **kwargs,
99
109
  )
@@ -14,8 +14,6 @@ _is_cuda = is_cuda()
14
14
 
15
15
  if _is_cuda:
16
16
  from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
17
- else:
18
- from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
19
17
 
20
18
 
21
19
  def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -84,6 +82,12 @@ class RotaryEmbedding(CustomOp):
84
82
  # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
85
83
  if not _is_cuda:
86
84
  cache = cache.to(dtype)
85
+
86
+ if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
87
+ from vllm._custom_ops import rotary_embedding
88
+
89
+ self.vllm_rotary_embedding = rotary_embedding
90
+
87
91
  self.cos_sin_cache: torch.Tensor
88
92
  self.register_buffer("cos_sin_cache", cache, persistent=False)
89
93
 
@@ -160,7 +164,7 @@ class RotaryEmbedding(CustomOp):
160
164
  )
161
165
  else:
162
166
  self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
163
- vllm_rotary_embedding(
167
+ self.vllm_rotary_embedding(
164
168
  positions,
165
169
  query,
166
170
  key,
@@ -665,6 +669,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
665
669
  offsets: Optional[torch.Tensor] = None,
666
670
  ) -> Tuple[torch.Tensor, torch.Tensor]:
667
671
  """PyTorch-native implementation equivalent to forward()."""
672
+ dtype = query.dtype
668
673
  query_rot = query[..., : self.rotary_dim]
669
674
  key_rot = key[..., : self.rotary_dim]
670
675
  if self.rotary_dim < self.head_size:
@@ -695,7 +700,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
695
700
  else:
696
701
  query = query_rot
697
702
  key = key_rot
698
- return query, key
703
+ return query.to(dtype), key.to(dtype)
699
704
 
700
705
 
701
706
  class Llama3RotaryEmbedding(RotaryEmbedding):
@@ -876,142 +881,181 @@ class MRotaryEmbedding(RotaryEmbedding):
876
881
  key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
877
882
  return query, key
878
883
 
884
+ # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
879
885
  @staticmethod
880
- def get_input_positions(
881
- input_tokens: List[int],
882
- image_grid_thw: Union[List[List[int]], torch.Tensor],
883
- video_grid_thw: Union[List[List[int]], torch.Tensor],
886
+ def get_rope_index(
887
+ spatial_merge_size: int,
884
888
  image_token_id: int,
885
889
  video_token_id: int,
886
890
  vision_start_token_id: int,
887
- vision_end_token_id: int,
888
- spatial_merge_size: int,
889
- context_len: int = 0,
890
- seq_len: Optional[int] = None,
891
- second_per_grid_ts: Optional[torch.Tensor] = None,
891
+ model_type: str,
892
892
  tokens_per_second: Optional[int] = None,
893
- ) -> Tuple[List[List[int]], int]:
894
- """
895
- Get mrope input positions and delta value.
896
-
897
- :arg
898
- second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
899
- The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
900
-
901
- """
902
-
903
- if isinstance(image_grid_thw, torch.Tensor):
904
- image_grid_thw = image_grid_thw.tolist()
905
- if isinstance(video_grid_thw, torch.Tensor):
906
- video_grid_thw = video_grid_thw.tolist()
907
-
908
- input_tokens_tensor = torch.tensor(input_tokens)
909
- vision_start_indices = torch.argwhere(
910
- input_tokens_tensor == vision_start_token_id
911
- ).squeeze(1)
912
- vision_tokens = input_tokens_tensor[vision_start_indices + 1]
913
- image_nums = (vision_tokens == image_token_id).sum()
914
- video_nums = (vision_tokens == video_token_id).sum()
915
- llm_pos_ids_list: list = []
916
-
917
- st = 0
918
- remain_images, remain_videos = image_nums, video_nums
919
-
920
- image_index, video_index = 0, 0
921
- for _ in range(image_nums + video_nums):
922
- if image_token_id in input_tokens and remain_images > 0:
923
- ed_image = input_tokens.index(image_token_id, st)
924
- else:
925
- ed_image = len(input_tokens) + 1
926
- if video_token_id in input_tokens and remain_videos > 0:
927
- ed_video = input_tokens.index(video_token_id, st)
928
- else:
929
- ed_video = len(input_tokens) + 1
930
- if ed_image < ed_video:
931
- t, h, w = (
932
- image_grid_thw[image_index][0],
933
- image_grid_thw[image_index][1],
934
- image_grid_thw[image_index][2],
935
- )
936
- image_index += 1
937
- remain_images -= 1
938
- second_per_grid_t = 0
939
- ed = ed_image
940
- else:
941
- t, h, w = (
942
- video_grid_thw[video_index][0],
943
- video_grid_thw[video_index][1],
944
- video_grid_thw[video_index][2],
945
- )
946
- if second_per_grid_ts is not None:
947
- second_per_grid_t = second_per_grid_ts[video_index]
948
- else:
949
- second_per_grid_t = 1.0
950
- video_index += 1
951
- remain_videos -= 1
952
- ed = ed_video
953
- llm_grid_t, llm_grid_h, llm_grid_w = (
954
- t,
955
- h // spatial_merge_size,
956
- w // spatial_merge_size,
957
- )
958
- text_len = ed - st
959
-
960
- st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
961
- llm_pos_ids_list.append(
962
- torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
963
- )
964
-
965
- t_index = (
966
- torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
967
- * second_per_grid_t
968
- * tokens_per_second
969
- ).flatten()
970
-
971
- h_index = (
972
- torch.arange(llm_grid_h)
973
- .view(1, -1, 1)
974
- .expand(llm_grid_t, -1, llm_grid_w)
975
- .flatten()
976
- )
977
- w_index = (
978
- torch.arange(llm_grid_w)
979
- .view(1, 1, -1)
980
- .expand(llm_grid_t, llm_grid_h, -1)
981
- .flatten()
982
- )
983
- llm_pos_ids_list.append(
984
- torch.stack([t_index, h_index, w_index]) + text_len + st_idx
893
+ input_ids: Optional[torch.LongTensor] = None,
894
+ image_grid_thw: Optional[torch.LongTensor] = None,
895
+ video_grid_thw: Optional[torch.LongTensor] = None,
896
+ second_per_grid_ts: Optional[torch.Tensor] = None,
897
+ **kwargs,
898
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
899
+ mrope_position_deltas = []
900
+ if input_ids is not None and (
901
+ image_grid_thw is not None or video_grid_thw is not None
902
+ ):
903
+ total_input_ids = input_ids
904
+ position_ids = torch.ones(
905
+ 3,
906
+ input_ids.shape[0],
907
+ input_ids.shape[1],
908
+ dtype=input_ids.dtype,
909
+ device=input_ids.device,
985
910
  )
986
- st = ed + llm_grid_t * llm_grid_h * llm_grid_w
987
-
988
- if st < len(input_tokens):
989
- st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
990
- text_len = len(input_tokens) - st
991
- llm_pos_ids_list.append(
992
- torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
911
+ image_index, video_index = 0, 0
912
+ for i, input_ids in enumerate(total_input_ids):
913
+ image_nums, video_nums = 0, 0
914
+ vision_start_indices = torch.argwhere(
915
+ input_ids == vision_start_token_id
916
+ ).squeeze(1)
917
+ vision_tokens = input_ids[vision_start_indices + 1]
918
+ image_nums = (vision_tokens == image_token_id).sum()
919
+ video_nums = (vision_tokens == video_token_id).sum()
920
+ input_tokens = input_ids.tolist()
921
+ llm_pos_ids_list: list = []
922
+ st = 0
923
+ remain_images, remain_videos = image_nums, video_nums
924
+ for _ in range(image_nums + video_nums):
925
+ if image_token_id in input_tokens and remain_images > 0:
926
+ ed_image = input_tokens.index(image_token_id, st)
927
+ else:
928
+ ed_image = len(input_tokens) + 1
929
+ if video_token_id in input_tokens and remain_videos > 0:
930
+ ed_video = input_tokens.index(video_token_id, st)
931
+ else:
932
+ ed_video = len(input_tokens) + 1
933
+ if ed_image < ed_video:
934
+ t, h, w = (
935
+ image_grid_thw[image_index][0],
936
+ image_grid_thw[image_index][1],
937
+ image_grid_thw[image_index][2],
938
+ )
939
+ second_per_grid_t = 0
940
+ image_index += 1
941
+ remain_images -= 1
942
+ ed = ed_image
943
+ else:
944
+ t, h, w = (
945
+ video_grid_thw[video_index][0],
946
+ video_grid_thw[video_index][1],
947
+ video_grid_thw[video_index][2],
948
+ )
949
+ if second_per_grid_ts is not None:
950
+ second_per_grid_t = second_per_grid_ts[video_index]
951
+ else:
952
+ second_per_grid_t = 1.0
953
+ video_index += 1
954
+ remain_videos -= 1
955
+ ed = ed_video
956
+ llm_grid_t, llm_grid_h, llm_grid_w = (
957
+ t.item(),
958
+ h.item() // spatial_merge_size,
959
+ w.item() // spatial_merge_size,
960
+ )
961
+ text_len = ed - st
962
+
963
+ st_idx = (
964
+ llm_pos_ids_list[-1].max() + 1
965
+ if len(llm_pos_ids_list) > 0
966
+ else 0
967
+ )
968
+ llm_pos_ids_list.append(
969
+ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
970
+ )
971
+
972
+ if model_type == "qwen2_5_vl":
973
+ range_tensor = torch.arange(llm_grid_t).view(-1, 1)
974
+ expanded_range = range_tensor.expand(
975
+ -1, llm_grid_h * llm_grid_w
976
+ )
977
+
978
+ time_tensor = (
979
+ expanded_range * second_per_grid_t * tokens_per_second
980
+ )
981
+
982
+ time_tensor_long = time_tensor.long()
983
+ t_index = time_tensor_long.flatten()
984
+ elif model_type == "qwen2_vl":
985
+ t_index = (
986
+ torch.arange(llm_grid_t)
987
+ .view(-1, 1)
988
+ .expand(-1, llm_grid_h * llm_grid_w)
989
+ .flatten()
990
+ )
991
+ else:
992
+ raise RuntimeError("Unimplemented")
993
+ h_index = (
994
+ torch.arange(llm_grid_h)
995
+ .view(1, -1, 1)
996
+ .expand(llm_grid_t, -1, llm_grid_w)
997
+ .flatten()
998
+ )
999
+ w_index = (
1000
+ torch.arange(llm_grid_w)
1001
+ .view(1, 1, -1)
1002
+ .expand(llm_grid_t, llm_grid_h, -1)
1003
+ .flatten()
1004
+ )
1005
+ llm_pos_ids_list.append(
1006
+ torch.stack([t_index, h_index, w_index]) + text_len + st_idx
1007
+ )
1008
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
1009
+
1010
+ if st < len(input_tokens):
1011
+ st_idx = (
1012
+ llm_pos_ids_list[-1].max() + 1
1013
+ if len(llm_pos_ids_list) > 0
1014
+ else 0
1015
+ )
1016
+ text_len = len(input_tokens) - st
1017
+ llm_pos_ids_list.append(
1018
+ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
1019
+ )
1020
+
1021
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
1022
+ position_ids[..., i, :] = llm_positions.to(position_ids.device)
1023
+ mrope_position_deltas.append(
1024
+ llm_positions.max() + 1 - len(total_input_ids[i])
1025
+ )
1026
+ mrope_position_deltas = torch.tensor(
1027
+ mrope_position_deltas, device=input_ids.device
1028
+ ).unsqueeze(1)
1029
+ return position_ids, mrope_position_deltas
1030
+ else:
1031
+ s = input_ids.shape[1]
1032
+ position_ids = torch.arange(s)
1033
+ position_ids = (
1034
+ position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
993
1035
  )
994
-
995
- llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
996
- mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
997
- llm_positions = llm_positions[:, context_len:seq_len]
998
-
999
- return llm_positions.tolist(), mrope_position_delta
1036
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(
1037
+ -1, keepdim=True
1038
+ )[0]
1039
+ mrope_position_deltas = max_position_ids + 1 - s
1040
+ return position_ids, mrope_position_deltas
1000
1041
 
1001
1042
  @staticmethod
1002
1043
  def get_next_input_positions(
1003
1044
  mrope_position_delta: int,
1004
1045
  context_len: int,
1005
1046
  seq_len: int,
1006
- ) -> List[List[int]]:
1007
- return [
1008
- list(
1009
- range(
1010
- context_len + mrope_position_delta, seq_len + mrope_position_delta
1047
+ ) -> torch.Tensor:
1048
+ return torch.tensor(
1049
+ [
1050
+ list(
1051
+ range(
1052
+ context_len + mrope_position_delta,
1053
+ seq_len + mrope_position_delta,
1054
+ )
1011
1055
  )
1012
- )
1013
- for _ in range(3)
1014
- ]
1056
+ for _ in range(3)
1057
+ ]
1058
+ )
1015
1059
 
1016
1060
 
1017
1061
  _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}