sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +302 -414
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +13 -8
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +144 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +773 -334
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +225 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +68 -37
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +102 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +56 -31
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +280 -81
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +135 -60
  181. sglang/srt/speculative/build_eagle_tree.py +8 -9
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
  183. sglang/srt/speculative/eagle_utils.py +92 -57
  184. sglang/srt/speculative/eagle_worker.py +238 -111
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
@@ -31,7 +31,7 @@ from __future__ import annotations
31
31
 
32
32
  from dataclasses import dataclass
33
33
  from enum import IntEnum, auto
34
- from typing import TYPE_CHECKING, List, Optional
34
+ from typing import TYPE_CHECKING, List, Optional, Union
35
35
 
36
36
  import torch
37
37
  import triton
@@ -41,12 +41,13 @@ from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
41
41
  from sglang.srt.utils import get_compiler_backend
42
42
 
43
43
  if TYPE_CHECKING:
44
- from sglang.srt.layers.attention import AttentionBackend
44
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
45
45
  from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
46
46
  from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
47
47
  from sglang.srt.model_executor.model_runner import ModelRunner
48
48
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
49
- from sglang.srt.speculative.spec_info import SpecInfo, SpeculativeAlgorithm
49
+ from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
50
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
50
51
 
51
52
 
52
53
  class ForwardMode(IntEnum):
@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):
112
113
 
113
114
  class CaptureHiddenMode(IntEnum):
114
115
  NULL = auto()
116
+ # Capture hidden states of all tokens.
115
117
  FULL = auto()
118
+ # Capture a hidden state of the last token.
116
119
  LAST = auto()
117
120
 
118
121
  def need_capture(self):
@@ -148,10 +151,14 @@ class ForwardBatch:
148
151
  # For logprob
149
152
  return_logprob: bool = False
150
153
  top_logprobs_nums: Optional[List[int]] = None
154
+ token_ids_logprobs: Optional[List[List[int]]] = None
151
155
 
152
156
  # Position information
153
157
  positions: torch.Tensor = None
154
158
 
159
+ # For decode
160
+ decode_seq_lens_cpu: Optional[torch.Tensor] = None
161
+
155
162
  # For extend
156
163
  extend_num_tokens: Optional[int] = None
157
164
  extend_seq_lens: Optional[torch.Tensor] = None
@@ -160,6 +167,7 @@ class ForwardBatch:
160
167
  extend_prefix_lens_cpu: Optional[List[int]] = None
161
168
  extend_seq_lens_cpu: Optional[List[int]] = None
162
169
  extend_logprob_start_lens_cpu: Optional[List[int]] = None
170
+ extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
163
171
 
164
172
  # For multimodal
165
173
  image_inputs: Optional[List[ImageInputs]] = None
@@ -185,15 +193,27 @@ class ForwardBatch:
185
193
  attn_backend: AttentionBackend = None
186
194
 
187
195
  # For DP attention
188
- global_num_tokens: Optional[List[int]] = None
196
+ global_num_tokens_cpu: Optional[List[int]] = None
197
+ global_num_tokens_gpu: Optional[torch.Tensor] = None
198
+ # Has to be None when cuda graph is captured.
199
+ global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
200
+ global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
201
+ # for extend, local start pos and num tokens is different in logits processor
202
+ # this will be computed in get_dp_local_info
203
+ # this will be recomputed in LogitsMetadata.from_forward_batch
204
+ dp_local_start_pos: Optional[torch.Tensor] = None # cached info at runtime
205
+ dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
189
206
  gathered_buffer: Optional[torch.Tensor] = None
190
207
  can_run_dp_cuda_graph: bool = False
191
208
 
192
209
  # Speculative decoding
193
- spec_info: SpecInfo = None
210
+ spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
194
211
  spec_algorithm: SpeculativeAlgorithm = None
195
212
  capture_hidden_mode: CaptureHiddenMode = None
196
213
 
214
+ # For padding
215
+ padded_static_len: int = -1 # -1 if not padded
216
+
197
217
  # For Qwen2-VL
198
218
  mrope_positions: torch.Tensor = None
199
219
 
@@ -203,8 +223,13 @@ class ForwardBatch:
203
223
  batch: ModelWorkerBatch,
204
224
  model_runner: ModelRunner,
205
225
  ):
206
-
207
226
  device = model_runner.device
227
+ extend_input_logprob_token_ids_gpu = None
228
+ if batch.extend_input_logprob_token_ids is not None:
229
+ extend_input_logprob_token_ids_gpu = (
230
+ batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
231
+ )
232
+
208
233
  ret = cls(
209
234
  forward_mode=batch.forward_mode,
210
235
  batch_size=len(batch.seq_lens),
@@ -220,7 +245,7 @@ class ForwardBatch:
220
245
  seq_lens_sum=batch.seq_lens_sum,
221
246
  return_logprob=batch.return_logprob,
222
247
  top_logprobs_nums=batch.top_logprobs_nums,
223
- global_num_tokens=batch.global_num_tokens,
248
+ token_ids_logprobs=batch.token_ids_logprobs,
224
249
  can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
225
250
  lora_paths=batch.lora_paths,
226
251
  sampling_info=batch.sampling_info,
@@ -231,10 +256,12 @@ class ForwardBatch:
231
256
  spec_info=batch.spec_info,
232
257
  capture_hidden_mode=batch.capture_hidden_mode,
233
258
  input_embeds=batch.input_embeds,
259
+ extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
234
260
  )
235
261
 
236
- if ret.global_num_tokens is not None:
237
- max_len = max(ret.global_num_tokens)
262
+ if batch.global_num_tokens is not None:
263
+ ret.global_num_tokens_cpu = batch.global_num_tokens
264
+ max_len = max(ret.global_num_tokens_cpu)
238
265
  ret.gathered_buffer = torch.zeros(
239
266
  (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
240
267
  dtype=model_runner.dtype,
@@ -256,6 +283,8 @@ class ForwardBatch:
256
283
  if ret.forward_mode.is_decode():
257
284
  if ret.positions is None:
258
285
  ret.positions = clamp_position(batch.seq_lens)
286
+ if ret.decode_seq_lens_cpu is None:
287
+ ret.decode_seq_lens_cpu = batch.decode_seq_lens
259
288
  else:
260
289
  ret.extend_seq_lens = torch.tensor(
261
290
  batch.extend_seq_lens, dtype=torch.int32
@@ -263,13 +292,12 @@ class ForwardBatch:
263
292
  ret.extend_prefix_lens = torch.tensor(
264
293
  batch.extend_prefix_lens, dtype=torch.int32
265
294
  ).to(device, non_blocking=True)
266
- if (
267
- model_runner.server_args.attention_backend != "torch_native"
268
- and model_runner.server_args.speculative_algorithm != "NEXTN"
269
- ):
295
+ if model_runner.server_args.attention_backend != "torch_native":
270
296
  ret.extend_num_tokens = batch.extend_num_tokens
271
297
  positions, ret.extend_start_loc = compute_position_triton(
272
- ret.extend_prefix_lens, ret.extend_seq_lens, ret.extend_num_tokens
298
+ ret.extend_prefix_lens,
299
+ ret.extend_seq_lens,
300
+ ret.extend_num_tokens,
273
301
  )
274
302
  else:
275
303
  positions, ret.extend_start_loc = compute_position_torch(
@@ -341,6 +369,7 @@ class ForwardBatch:
341
369
  )
342
370
  batch.image_inputs[i].mrope_position_delta = mrope_position_delta
343
371
  mrope_positions_list[i] = mrope_positions
372
+
344
373
  self.mrope_positions = torch.concat(
345
374
  [torch.tensor(pos, device=device) for pos in mrope_positions_list],
346
375
  axis=1,
@@ -353,6 +382,8 @@ def compute_position_triton(
353
382
  ):
354
383
  """Compute positions. It is a fused version of `compute_position_torch`."""
355
384
  batch_size = extend_seq_lens.shape[0]
385
+ has_prefix = extend_prefix_lens.shape[0] == batch_size
386
+
356
387
  positions = torch.empty(
357
388
  extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
358
389
  )
@@ -366,6 +397,7 @@ def compute_position_triton(
366
397
  extend_start_loc,
367
398
  extend_prefix_lens,
368
399
  extend_seq_lens,
400
+ has_prefix,
369
401
  )
370
402
 
371
403
  return positions, extend_start_loc
@@ -377,11 +409,12 @@ def compute_position_kernel(
377
409
  extend_start_loc,
378
410
  extend_prefix_lens,
379
411
  extend_seq_lens,
412
+ has_prefix: tl.constexpr,
380
413
  ):
381
414
  BLOCK_SIZE: tl.constexpr = 512
382
- pid = tl.program_id(0)
415
+ pid = tl.program_id(0).to(tl.int64)
383
416
 
384
- prefix_len = tl.load(extend_prefix_lens + pid)
417
+ prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
385
418
  seq_len = tl.load(extend_seq_lens + pid)
386
419
 
387
420
  # TODO: optimize this?