sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
Files changed (99)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -8
  3. sglang/compile_deep_gemm.py +177 -0
  4. sglang/lang/backend/openai.py +5 -1
  5. sglang/lang/backend/runtime_endpoint.py +5 -1
  6. sglang/srt/code_completion_parser.py +1 -1
  7. sglang/srt/configs/deepseekvl2.py +1 -1
  8. sglang/srt/configs/model_config.py +11 -2
  9. sglang/srt/constrained/llguidance_backend.py +78 -61
  10. sglang/srt/constrained/xgrammar_backend.py +1 -0
  11. sglang/srt/conversation.py +34 -1
  12. sglang/srt/disaggregation/decode.py +96 -5
  13. sglang/srt/disaggregation/mini_lb.py +113 -15
  14. sglang/srt/disaggregation/mooncake/conn.py +199 -32
  15. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  16. sglang/srt/disaggregation/nixl/conn.py +622 -0
  17. sglang/srt/disaggregation/prefill.py +119 -20
  18. sglang/srt/disaggregation/utils.py +17 -0
  19. sglang/srt/entrypoints/engine.py +4 -0
  20. sglang/srt/entrypoints/http_server.py +11 -9
  21. sglang/srt/function_call_parser.py +132 -0
  22. sglang/srt/layers/activation.py +2 -2
  23. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +809 -160
  25. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  26. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  28. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  29. sglang/srt/layers/attention/vision.py +2 -0
  30. sglang/srt/layers/dp_attention.py +1 -1
  31. sglang/srt/layers/layernorm.py +42 -5
  32. sglang/srt/layers/logits_processor.py +2 -2
  33. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  34. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  38. sglang/srt/layers/pooler.py +6 -0
  39. sglang/srt/layers/quantization/awq.py +5 -1
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  41. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  42. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  43. sglang/srt/layers/quantization/deep_gemm.py +385 -0
  44. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/quantization/gptq.py +13 -7
  47. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  48. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  49. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  50. sglang/srt/layers/radix_attention.py +13 -3
  51. sglang/srt/layers/rotary_embedding.py +176 -132
  52. sglang/srt/layers/sampler.py +2 -2
  53. sglang/srt/managers/data_parallel_controller.py +17 -4
  54. sglang/srt/managers/io_struct.py +21 -3
  55. sglang/srt/managers/mm_utils.py +85 -28
  56. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  57. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  58. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  59. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  60. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  61. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  62. sglang/srt/managers/schedule_batch.py +42 -12
  63. sglang/srt/managers/scheduler.py +47 -26
  64. sglang/srt/managers/tokenizer_manager.py +120 -30
  65. sglang/srt/managers/tp_worker.py +1 -0
  66. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  67. sglang/srt/mem_cache/memory_pool.py +118 -13
  68. sglang/srt/model_executor/cuda_graph_runner.py +16 -10
  69. sglang/srt/model_executor/forward_batch_info.py +51 -95
  70. sglang/srt/model_executor/model_runner.py +29 -27
  71. sglang/srt/models/deepseek.py +12 -2
  72. sglang/srt/models/deepseek_nextn.py +101 -6
  73. sglang/srt/models/deepseek_v2.py +153 -76
  74. sglang/srt/models/deepseek_vl2.py +9 -4
  75. sglang/srt/models/gemma3_causal.py +1 -1
  76. sglang/srt/models/llama4.py +0 -1
  77. sglang/srt/models/minicpm3.py +2 -2
  78. sglang/srt/models/minicpmo.py +22 -7
  79. sglang/srt/models/mllama4.py +2 -2
  80. sglang/srt/models/qwen2_5_vl.py +3 -6
  81. sglang/srt/models/qwen2_vl.py +3 -7
  82. sglang/srt/models/roberta.py +178 -0
  83. sglang/srt/openai_api/adapter.py +87 -10
  84. sglang/srt/openai_api/protocol.py +6 -1
  85. sglang/srt/server_args.py +65 -60
  86. sglang/srt/speculative/build_eagle_tree.py +2 -2
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +2 -2
  89. sglang/srt/speculative/eagle_worker.py +2 -7
  90. sglang/srt/torch_memory_saver_adapter.py +10 -1
  91. sglang/srt/utils.py +48 -6
  92. sglang/test/runners.py +6 -13
  93. sglang/test/test_utils.py +39 -19
  94. sglang/version.py +1 -1
  95. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
  96. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
  97. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  98. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/fp8_kernel.py +7 -38

@@ -16,19 +16,17 @@ import functools
  import json
  import logging
  import os
- from contextlib import contextmanager
  from typing import Any, Dict, List, Optional, Tuple

  import torch
  import triton
  import triton.language as tl

+ from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
  from sglang.srt.utils import (
  direct_register_custom_op,
- get_bool_env_var,
  get_device_core_count,
  get_device_name,
- get_device_sm,
  is_cuda,
  is_hip,
  supports_custom_op,
@@ -43,22 +41,16 @@ else:
  fp8_max = torch.finfo(_fp8_type).max
  fp8_min = -fp8_max

- _enable_jit_deepgemm = False
- _enable_jit_deepgemm_bmm = False
  if _is_cuda:
- import deep_gemm
  from sgl_kernel import (
  sgl_per_tensor_quant_fp8,
  sgl_per_token_group_quant_fp8,
  sgl_per_token_quant_fp8,
  )

- sm_version = get_device_sm()
- if sm_version == 90:
- if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
- _enable_jit_deepgemm = True
- if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
- _enable_jit_deepgemm_bmm = True
+ from sglang.srt.layers.quantization.deep_gemm import (
+ gemm_nt_f8f8bf16 as deep_gemm_gemm_nt_f8f8bf16,
+ )

  logger = logging.getLogger(__name__)

@@ -71,10 +63,7 @@ if supports_custom_op():
  Bs: torch.Tensor,
  C: torch.Tensor,
  ) -> None:
- M, K = A.shape
- N, _ = B.shape
- with _log_jit_build(M, N, K):
- deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+ deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)

  def deep_gemm_fp8_fp8_bf16_nt_fake(
  A: torch.Tensor,
@@ -715,25 +704,6 @@ def get_w8a8_block_fp8_configs(
  return None


- @contextmanager
- def _log_jit_build(M: int, N: int, K: int):
- from deep_gemm.jit.runtime import RuntimeCache
-
- origin_func = RuntimeCache.__getitem__
-
- def __patched_func(self, *args, **kwargs):
- ret = origin_func(self, *args, **kwargs)
- if ret is None:
- logger.warning(
- f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
- )
- return ret
-
- RuntimeCache.__getitem__ = __patched_func
- yield
- RuntimeCache.__getitem__ = origin_func
-
-
  def w8a8_block_fp8_matmul(
  A: torch.Tensor,
  B: torch.Tensor,
@@ -804,12 +774,11 @@ def w8a8_block_fp8_matmul(
  )

  # deepgemm only support bf16
- if C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
+ if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
  if supports_custom_op():
  torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
  else:
- with _log_jit_build(M, N, K):
- deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+ deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
  else:
  kernel = (
  _w8a8_block_fp8_matmul_unrolledx4
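
Note: the DeepGEMM JIT switch now lives in the new sglang/srt/layers/quantization/deep_gemm.py module instead of being derived from environment variables inside fp8_kernel.py, and the ad-hoc _log_jit_build context manager is dropped. A minimal sketch of how the flag drives the GEMM dispatch after this change; the helper function below is illustrative only and not part of the package:

```python
import torch

# _ENABLE_JIT_DEEPGEMM is the flag imported in the hunks above; how it gets set
# (SM version, env vars) now lives in the new deep_gemm module.
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM


def pick_block_fp8_gemm_backend(out_dtype: torch.dtype) -> str:
    # Mirrors the dispatch in w8a8_block_fp8_matmul: only bf16 outputs may take
    # the DeepGEMM JIT path; everything else stays on the Triton kernels.
    if out_dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
        return "deep_gemm"
    return "triton"
```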

sglang/srt/layers/quantization/fp8_utils.py +2 -2

@@ -12,8 +12,8 @@ try:
  except ImportError:
  VLLM_AVAILABLE = False

+ from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
  from sglang.srt.layers.quantization.fp8_kernel import (
- _enable_jit_deepgemm,
  per_token_group_quant_fp8,
  scaled_fp8_quant,
  sglang_per_token_quant_fp8,
@@ -143,7 +143,7 @@ def apply_w8a8_block_fp8_linear(
  )
  gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
  else:
- if _enable_jit_deepgemm:
+ if _ENABLE_JIT_DEEPGEMM:
  q_input, x_scale = sglang_per_token_group_quant_fp8(
  input_2d,
  block_size[1],

sglang/srt/layers/quantization/gptq.py +13 -7

@@ -37,6 +37,14 @@ except ImportError:
  logger = logging.getLogger(__name__)


+ def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
+ # compat: gptqmodel and autogptq (eol) main use checkpoint_format: str
+ # compat: autogptq <=0.7.1 is_marlin_format: bool
+ return hf_quant_cfg.get("checkpoint_format") == "marlin" or hf_quant_cfg.get(
+ "is_marlin_format", False
+ )
+
+
  class GPTQConfig(QuantizationConfig):
  """Config class for GPTQ.

@@ -262,13 +270,15 @@ class GPTQMarlinConfig(QuantizationConfig):

  @classmethod
  def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+ is_marlin_format = check_marlin_format(hf_quant_cfg)
+
  can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

  is_valid_user_quant = (
  user_quant is None or user_quant == "marlin" or user_quant == "gptq_marlin"
  )

- if can_convert and is_valid_user_quant:
+ if not is_marlin_format and can_convert and is_valid_user_quant:
  msg = (
  "The model is convertible to {} during runtime."
  " Using {} kernel.".format(cls.get_name(), cls.get_name())
@@ -276,7 +286,7 @@ class GPTQMarlinConfig(QuantizationConfig):
  logger.info(msg)
  return cls.get_name()

- if can_convert and user_quant == "gptq":
+ if not is_marlin_format and can_convert and user_quant == "gptq":
  logger.info(
  "Detected that the model can run with gptq_marlin"
  ", however you specified quantization=gptq explicitly,"
@@ -401,11 +411,7 @@ class MarlinConfig(QuantizationConfig):

  @classmethod
  def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
- # compat: autogptq >=0.8.0 use checkpoint_format: str
- # compat: autogptq <=0.7.1 is_marlin_format: bool
- is_marlin_format = hf_quant_cfg.get(
- "checkpoint_format"
- ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
+ is_marlin_format = check_marlin_format(hf_quant_cfg)

  is_valid_user_quant = (
  user_quant is None or user_quant == "gptq" or user_quant == "marlin"
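
Note: the Marlin-format check that MarlinConfig used to do inline is factored into a shared check_marlin_format helper, and GPTQMarlinConfig now consults it so checkpoints already serialized in Marlin format are not re-converted. A small self-contained sketch of the helper's behavior; the config dicts are made-up examples:

```python
from typing import Any, Dict


def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
    # Body copied from the hunk above: new-style checkpoint_format string,
    # or the legacy autogptq <= 0.7.1 boolean flag.
    return hf_quant_cfg.get("checkpoint_format") == "marlin" or hf_quant_cfg.get(
        "is_marlin_format", False
    )


print(check_marlin_format({"checkpoint_format": "marlin"}))  # True
print(check_marlin_format({"is_marlin_format": True}))       # True (legacy field)
print(check_marlin_format({"bits": 4, "group_size": 128}))   # False -> may still convert to gptq_marlin at runtime
```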

sglang/srt/layers/quantization/int8_kernel.py +32 -1

@@ -8,7 +8,11 @@ import torch
  import triton
  import triton.language as tl

- from sglang.srt.utils import get_device_name
+ from sglang.srt.utils import get_device_name, is_cuda
+
+ _is_cuda = is_cuda()
+ if _is_cuda:
+ from sgl_kernel import sgl_per_token_group_quant_int8

  logger = logging.getLogger(__name__)

@@ -165,6 +169,33 @@ def per_token_group_quant_int8(
  return x_q, x_s


+ def sglang_per_token_group_quant_int8(
+ x: torch.Tensor,
+ group_size: int,
+ eps: float = 1e-10,
+ dtype: torch.dtype = torch.int8,
+ ):
+ assert (
+ x.shape[-1] % group_size == 0
+ ), "the last dimension of `x` cannot be divisible by `group_size`"
+ assert x.is_contiguous(), "`x` is not contiguous"
+
+ iinfo = torch.iinfo(dtype)
+ int8_max = iinfo.max
+ int8_min = iinfo.min
+
+ x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+ x_s = torch.empty(
+ x.shape[:-1] + (x.shape[-1] // group_size,),
+ device=x.device,
+ dtype=torch.float32,
+ )
+
+ sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+
+ return x_q, x_s
+
+
  @triton.jit
  def _w8a8_block_int8_matmul(
  # Pointers to inputs and output
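
Note: the new sglang_per_token_group_quant_int8 wrapper calls the sgl_kernel per-token-group INT8 kernel and returns the quantized tensor together with one float32 scale per group. A hypothetical call showing the expected shapes (assumes a CUDA device and the sgl_kernel extension are available):

```python
import torch

from sglang.srt.layers.quantization.int8_kernel import sglang_per_token_group_quant_int8

# 4 tokens, hidden size 256, quantized in groups of 128 along the last dim.
x = torch.randn(4, 256, device="cuda", dtype=torch.float16)
x_q, x_s = sglang_per_token_group_quant_int8(x, group_size=128)

print(x_q.shape, x_q.dtype)  # torch.Size([4, 256]) torch.int8
print(x_s.shape, x_s.dtype)  # torch.Size([4, 2]) torch.float32 (one scale per 128-wide group)
```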

sglang/srt/layers/quantization/modelopt_quant.py +2 -2

@@ -22,9 +22,9 @@ from sglang.srt.layers.quantization.utils import (
  requantize_with_max_scale,
  )
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.utils import is_cuda_available
+ from sglang.srt.utils import is_cuda

- if is_cuda_available():
+ if is_cuda():
  from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant

  # Initialize logger for the module

sglang/srt/layers/quantization/w8a8_int8.py +3 -3

@@ -11,10 +11,10 @@ from sglang.srt.layers.quantization.base_config import (
  QuantizeMethodBase,
  )
  from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
- from sglang.srt.utils import is_cuda_available, set_weight_attrs
+ from sglang.srt.utils import is_cuda, set_weight_attrs

- is_cuda = is_cuda_available()
- if is_cuda:
+ _is_cuda = is_cuda()
+ if _is_cuda:
  from sgl_kernel import int8_scaled_mm


sglang/srt/layers/radix_attention.py +13 -3

@@ -87,13 +87,23 @@ class RadixAttention(nn.Module):
  v,
  forward_batch: ForwardBatch,
  save_kv_cache: bool = True,
+ **kwargs,
  ):
  if k is not None:
  # For cross-layer sharing, kv can be None
  assert v is not None
- k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
- v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+ if "k_rope" not in kwargs:
+ k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+ v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+ else:
+ k = k.view(-1, self.tp_k_head_num, self.v_head_dim)

  return forward_batch.attn_backend.forward(
- q, k, v, self, forward_batch, save_kv_cache
+ q,
+ k,
+ v,
+ self,
+ forward_batch,
+ save_kv_cache,
+ **kwargs,
  )
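
Note: RadixAttention.forward now forwards extra keyword arguments to the attention backend, and when the rope part of k is supplied separately (a k_rope kwarg), k is reshaped with v_head_dim rather than qk_head_dim. A self-contained toy sketch of that reshape rule; the function, names, and shapes below are illustrative, not the package's API:

```python
import torch


def reshape_kv(k, v, tp_k_head_num, tp_v_head_num, qk_head_dim, v_head_dim, **kwargs):
    # Mirrors the new branch above: with a separate k_rope, k only carries the
    # "nope" part (v_head_dim per head) and v is left untouched by this path.
    if "k_rope" not in kwargs:
        k = k.view(-1, tp_k_head_num, qk_head_dim)
        v = v.view(-1, tp_v_head_num, v_head_dim)
    else:
        k = k.view(-1, tp_k_head_num, v_head_dim)
    return k, v


# Hypothetical MLA-style shapes: 8 tokens, 1 kv head, 512 "nope" dims + 64 rope dims.
k_nope = torch.randn(8, 512)
v = torch.randn(8, 512)
k_out, v_out = reshape_kv(k_nope, v, 1, 1, 576, 512, k_rope=torch.randn(8, 1, 64))
print(k_out.shape)  # torch.Size([8, 1, 512])
```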

sglang/srt/layers/rotary_embedding.py +176 -132

@@ -8,14 +8,12 @@ import torch
  import torch.nn as nn

  from sglang.srt.custom_op import CustomOp
- from sglang.srt.utils import is_cuda_available
+ from sglang.srt.utils import is_cuda

- _is_cuda_available = is_cuda_available()
+ _is_cuda = is_cuda()

- if _is_cuda_available:
+ if _is_cuda:
  from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
- else:
- from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding


  def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -82,8 +80,14 @@ class RotaryEmbedding(CustomOp):

  cache = self._compute_cos_sin_cache()
  # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
- if not _is_cuda_available:
+ if not _is_cuda:
  cache = cache.to(dtype)
+
+ if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
+ from vllm._custom_ops import rotary_embedding
+
+ self.vllm_rotary_embedding = rotary_embedding
+
  self.cos_sin_cache: torch.Tensor
  self.register_buffer("cos_sin_cache", cache, persistent=False)

@@ -149,7 +153,7 @@ class RotaryEmbedding(CustomOp):
  key: torch.Tensor,
  offsets: Optional[torch.Tensor] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
- if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
+ if _is_cuda and (self.head_size in [64, 128, 256, 512]):
  apply_rope_with_cos_sin_cache_inplace(
  positions=positions,
  query=query,
@@ -160,7 +164,7 @@ class RotaryEmbedding(CustomOp):
  )
  else:
  self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
- vllm_rotary_embedding(
+ self.vllm_rotary_embedding(
  positions,
  query,
  key,
@@ -652,7 +656,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
  def forward(self, *args, **kwargs):
  if torch.compiler.is_compiling():
  return self.forward_native(*args, **kwargs)
- if _is_cuda_available:
+ if _is_cuda:
  return self.forward_cuda(*args, **kwargs)
  else:
  return self.forward_native(*args, **kwargs)
@@ -665,6 +669,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
  offsets: Optional[torch.Tensor] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
  """PyTorch-native implementation equivalent to forward()."""
+ dtype = query.dtype
  query_rot = query[..., : self.rotary_dim]
  key_rot = key[..., : self.rotary_dim]
  if self.rotary_dim < self.head_size:
@@ -695,7 +700,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
  else:
  query = query_rot
  key = key_rot
- return query, key
+ return query.to(dtype), key.to(dtype)


  class Llama3RotaryEmbedding(RotaryEmbedding):
@@ -876,142 +881,181 @@ class MRotaryEmbedding(RotaryEmbedding):
  key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
  return query, key

+ # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
  @staticmethod
- def get_input_positions(
- input_tokens: List[int],
- image_grid_thw: Union[List[List[int]], torch.Tensor],
- video_grid_thw: Union[List[List[int]], torch.Tensor],
+ def get_rope_index(
+ spatial_merge_size: int,
  image_token_id: int,
  video_token_id: int,
  vision_start_token_id: int,
- vision_end_token_id: int,
- spatial_merge_size: int,
- context_len: int = 0,
- seq_len: Optional[int] = None,
- second_per_grid_ts: Optional[torch.Tensor] = None,
+ model_type: str,
  tokens_per_second: Optional[int] = None,
- ) -> Tuple[List[List[int]], int]:
- """
- Get mrope input positions and delta value.
-
- :arg
- second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
- The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
-
- """
-
- if isinstance(image_grid_thw, torch.Tensor):
- image_grid_thw = image_grid_thw.tolist()
- if isinstance(video_grid_thw, torch.Tensor):
- video_grid_thw = video_grid_thw.tolist()
-
- input_tokens_tensor = torch.tensor(input_tokens)
- vision_start_indices = torch.argwhere(
- input_tokens_tensor == vision_start_token_id
- ).squeeze(1)
- vision_tokens = input_tokens_tensor[vision_start_indices + 1]
- image_nums = (vision_tokens == image_token_id).sum()
- video_nums = (vision_tokens == video_token_id).sum()
- llm_pos_ids_list: list = []
-
- st = 0
- remain_images, remain_videos = image_nums, video_nums
-
- image_index, video_index = 0, 0
- for _ in range(image_nums + video_nums):
- if image_token_id in input_tokens and remain_images > 0:
- ed_image = input_tokens.index(image_token_id, st)
- else:
- ed_image = len(input_tokens) + 1
- if video_token_id in input_tokens and remain_videos > 0:
- ed_video = input_tokens.index(video_token_id, st)
- else:
- ed_video = len(input_tokens) + 1
- if ed_image < ed_video:
- t, h, w = (
- image_grid_thw[image_index][0],
- image_grid_thw[image_index][1],
- image_grid_thw[image_index][2],
- )
- image_index += 1
- remain_images -= 1
- second_per_grid_t = 0
- ed = ed_image
- else:
- t, h, w = (
- video_grid_thw[video_index][0],
- video_grid_thw[video_index][1],
- video_grid_thw[video_index][2],
- )
- if second_per_grid_ts is not None:
- second_per_grid_t = second_per_grid_ts[video_index]
- else:
- second_per_grid_t = 1.0
- video_index += 1
- remain_videos -= 1
- ed = ed_video
- llm_grid_t, llm_grid_h, llm_grid_w = (
- t,
- h // spatial_merge_size,
- w // spatial_merge_size,
- )
- text_len = ed - st
-
- st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
- llm_pos_ids_list.append(
- torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
- )
-
- t_index = (
- torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
- * second_per_grid_t
- * tokens_per_second
- ).flatten()
-
- h_index = (
- torch.arange(llm_grid_h)
- .view(1, -1, 1)
- .expand(llm_grid_t, -1, llm_grid_w)
- .flatten()
- )
- w_index = (
- torch.arange(llm_grid_w)
- .view(1, 1, -1)
- .expand(llm_grid_t, llm_grid_h, -1)
- .flatten()
- )
- llm_pos_ids_list.append(
- torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+ input_ids: Optional[torch.LongTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ second_per_grid_ts: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ mrope_position_deltas = []
+ if input_ids is not None and (
+ image_grid_thw is not None or video_grid_thw is not None
+ ):
+ total_input_ids = input_ids
+ position_ids = torch.ones(
+ 3,
+ input_ids.shape[0],
+ input_ids.shape[1],
+ dtype=input_ids.dtype,
+ device=input_ids.device,
  )
- st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-
- if st < len(input_tokens):
- st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
- text_len = len(input_tokens) - st
- llm_pos_ids_list.append(
- torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+ image_index, video_index = 0, 0
+ for i, input_ids in enumerate(total_input_ids):
+ image_nums, video_nums = 0, 0
+ vision_start_indices = torch.argwhere(
+ input_ids == vision_start_token_id
+ ).squeeze(1)
+ vision_tokens = input_ids[vision_start_indices + 1]
+ image_nums = (vision_tokens == image_token_id).sum()
+ video_nums = (vision_tokens == video_token_id).sum()
+ input_tokens = input_ids.tolist()
+ llm_pos_ids_list: list = []
+ st = 0
+ remain_images, remain_videos = image_nums, video_nums
+ for _ in range(image_nums + video_nums):
+ if image_token_id in input_tokens and remain_images > 0:
+ ed_image = input_tokens.index(image_token_id, st)
+ else:
+ ed_image = len(input_tokens) + 1
+ if video_token_id in input_tokens and remain_videos > 0:
+ ed_video = input_tokens.index(video_token_id, st)
+ else:
+ ed_video = len(input_tokens) + 1
+ if ed_image < ed_video:
+ t, h, w = (
+ image_grid_thw[image_index][0],
+ image_grid_thw[image_index][1],
+ image_grid_thw[image_index][2],
+ )
+ second_per_grid_t = 0
+ image_index += 1
+ remain_images -= 1
+ ed = ed_image
+ else:
+ t, h, w = (
+ video_grid_thw[video_index][0],
+ video_grid_thw[video_index][1],
+ video_grid_thw[video_index][2],
+ )
+ if second_per_grid_ts is not None:
+ second_per_grid_t = second_per_grid_ts[video_index]
+ else:
+ second_per_grid_t = 1.0
+ video_index += 1
+ remain_videos -= 1
+ ed = ed_video
+ llm_grid_t, llm_grid_h, llm_grid_w = (
+ t.item(),
+ h.item() // spatial_merge_size,
+ w.item() // spatial_merge_size,
+ )
+ text_len = ed - st
+
+ st_idx = (
+ llm_pos_ids_list[-1].max() + 1
+ if len(llm_pos_ids_list) > 0
+ else 0
+ )
+ llm_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+ )
+
+ if model_type == "qwen2_5_vl":
+ range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+ expanded_range = range_tensor.expand(
+ -1, llm_grid_h * llm_grid_w
+ )
+
+ time_tensor = (
+ expanded_range * second_per_grid_t * tokens_per_second
+ )
+
+ time_tensor_long = time_tensor.long()
+ t_index = time_tensor_long.flatten()
+ elif model_type == "qwen2_vl":
+ t_index = (
+ torch.arange(llm_grid_t)
+ .view(-1, 1)
+ .expand(-1, llm_grid_h * llm_grid_w)
+ .flatten()
+ )
+ else:
+ raise RuntimeError("Unimplemented")
+ h_index = (
+ torch.arange(llm_grid_h)
+ .view(1, -1, 1)
+ .expand(llm_grid_t, -1, llm_grid_w)
+ .flatten()
+ )
+ w_index = (
+ torch.arange(llm_grid_w)
+ .view(1, 1, -1)
+ .expand(llm_grid_t, llm_grid_h, -1)
+ .flatten()
+ )
+ llm_pos_ids_list.append(
+ torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+ )
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+ if st < len(input_tokens):
+ st_idx = (
+ llm_pos_ids_list[-1].max() + 1
+ if len(llm_pos_ids_list) > 0
+ else 0
+ )
+ text_len = len(input_tokens) - st
+ llm_pos_ids_list.append(
+ torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+ )
+
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+ position_ids[..., i, :] = llm_positions.to(position_ids.device)
+ mrope_position_deltas.append(
+ llm_positions.max() + 1 - len(total_input_ids[i])
+ )
+ mrope_position_deltas = torch.tensor(
+ mrope_position_deltas, device=input_ids.device
+ ).unsqueeze(1)
+ return position_ids, mrope_position_deltas
+ else:
+ s = input_ids.shape[1]
+ position_ids = torch.arange(s)
+ position_ids = (
+ position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
  )
-
- llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
- mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
- llm_positions = llm_positions[:, context_len:seq_len]
-
- return llm_positions.tolist(), mrope_position_delta
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+ -1, keepdim=True
+ )[0]
+ mrope_position_deltas = max_position_ids + 1 - s
+ return position_ids, mrope_position_deltas

  @staticmethod
  def get_next_input_positions(
  mrope_position_delta: int,
  context_len: int,
  seq_len: int,
- ) -> List[List[int]]:
- return [
- list(
- range(
- context_len + mrope_position_delta, seq_len + mrope_position_delta
+ ) -> torch.Tensor:
+ return torch.tensor(
+ [
+ list(
+ range(
+ context_len + mrope_position_delta,
+ seq_len + mrope_position_delta,
+ )
  )
- )
- for _ in range(3)
- ]
+ for _ in range(3)
+ ]
+ )


  _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
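
Note: get_input_positions is replaced by get_rope_index, which follows the referenced HuggingFace Qwen2-VL implementation, works on batched torch tensors instead of Python lists, and returns (position_ids, mrope_position_deltas). A hypothetical call with made-up token ids and a single 4x4 image grid (merged 2x2 by spatial_merge_size):

```python
import torch

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

# One sequence: 2 text tokens, a vision-start marker, 4 image-pad tokens, 1 text token.
input_ids = torch.tensor([[11, 22, 90000, 90001, 90001, 90001, 90001, 33]])

position_ids, deltas = MRotaryEmbedding.get_rope_index(
    spatial_merge_size=2,
    image_token_id=90001,
    video_token_id=90002,
    vision_start_token_id=90000,
    model_type="qwen2_vl",
    tokens_per_second=None,
    input_ids=input_ids,
    image_grid_thw=torch.tensor([[1, 4, 4]]),  # t=1, h=4, w=4 -> 2x2 merged patches
)

print(position_ids.shape)  # torch.Size([3, 1, 8]) -- (t/h/w, batch, seq)
print(deltas.shape)        # torch.Size([1, 1]) -- per-sequence position delta
```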

sglang/srt/layers/sampler.py +2 -2

@@ -10,9 +10,9 @@ from sglang.srt.layers.dp_attention import get_attention_tp_group
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
- from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
+ from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda

- if is_cuda_available():
+ if is_cuda():
  from sgl_kernel import (
  min_p_sampling_from_probs,
  top_k_renorm_prob,