sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
@@ -31,7 +31,7 @@ from __future__ import annotations
|
|
31
31
|
|
32
32
|
from dataclasses import dataclass
|
33
33
|
from enum import IntEnum, auto
|
34
|
-
from typing import TYPE_CHECKING, List, Optional
|
34
|
+
from typing import TYPE_CHECKING, List, Optional, Union
|
35
35
|
|
36
36
|
import torch
|
37
37
|
import triton
|
@@ -41,12 +41,13 @@ from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
|
41
41
|
from sglang.srt.utils import get_compiler_backend
|
42
42
|
|
43
43
|
if TYPE_CHECKING:
|
44
|
-
from sglang.srt.layers.attention import AttentionBackend
|
44
|
+
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
45
45
|
from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
|
46
46
|
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
47
47
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
48
48
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
49
|
-
from sglang.srt.speculative.
|
49
|
+
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
50
|
+
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
50
51
|
|
51
52
|
|
52
53
|
class ForwardMode(IntEnum):
|
@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):
|
|
112
113
|
|
113
114
|
class CaptureHiddenMode(IntEnum):
|
114
115
|
NULL = auto()
|
116
|
+
# Capture hidden states of all tokens.
|
115
117
|
FULL = auto()
|
118
|
+
# Capture a hidden state of the last token.
|
116
119
|
LAST = auto()
|
117
120
|
|
118
121
|
def need_capture(self):
|
@@ -148,10 +151,14 @@ class ForwardBatch:
|
|
148
151
|
# For logprob
|
149
152
|
return_logprob: bool = False
|
150
153
|
top_logprobs_nums: Optional[List[int]] = None
|
154
|
+
token_ids_logprobs: Optional[List[List[int]]] = None
|
151
155
|
|
152
156
|
# Position information
|
153
157
|
positions: torch.Tensor = None
|
154
158
|
|
159
|
+
# For decode
|
160
|
+
decode_seq_lens_cpu: Optional[torch.Tensor] = None
|
161
|
+
|
155
162
|
# For extend
|
156
163
|
extend_num_tokens: Optional[int] = None
|
157
164
|
extend_seq_lens: Optional[torch.Tensor] = None
|
@@ -160,6 +167,7 @@ class ForwardBatch:
|
|
160
167
|
extend_prefix_lens_cpu: Optional[List[int]] = None
|
161
168
|
extend_seq_lens_cpu: Optional[List[int]] = None
|
162
169
|
extend_logprob_start_lens_cpu: Optional[List[int]] = None
|
170
|
+
extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
|
163
171
|
|
164
172
|
# For multimodal
|
165
173
|
image_inputs: Optional[List[ImageInputs]] = None
|
@@ -185,15 +193,27 @@ class ForwardBatch:
|
|
185
193
|
attn_backend: AttentionBackend = None
|
186
194
|
|
187
195
|
# For DP attention
|
188
|
-
|
196
|
+
global_num_tokens_cpu: Optional[List[int]] = None
|
197
|
+
global_num_tokens_gpu: Optional[torch.Tensor] = None
|
198
|
+
# Has to be None when cuda graph is captured.
|
199
|
+
global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
|
200
|
+
global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
|
201
|
+
# for extend, local start pos and num tokens is different in logits processor
|
202
|
+
# this will be computed in get_dp_local_info
|
203
|
+
# this will be recomputed in LogitsMetadata.from_forward_batch
|
204
|
+
dp_local_start_pos: Optional[torch.Tensor] = None # cached info at runtime
|
205
|
+
dp_local_num_tokens: Optional[torch.Tensor] = None # cached info at runtime
|
189
206
|
gathered_buffer: Optional[torch.Tensor] = None
|
190
207
|
can_run_dp_cuda_graph: bool = False
|
191
208
|
|
192
209
|
# Speculative decoding
|
193
|
-
spec_info:
|
210
|
+
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
|
194
211
|
spec_algorithm: SpeculativeAlgorithm = None
|
195
212
|
capture_hidden_mode: CaptureHiddenMode = None
|
196
213
|
|
214
|
+
# For padding
|
215
|
+
padded_static_len: int = -1 # -1 if not padded
|
216
|
+
|
197
217
|
# For Qwen2-VL
|
198
218
|
mrope_positions: torch.Tensor = None
|
199
219
|
|
@@ -203,8 +223,13 @@ class ForwardBatch:
|
|
203
223
|
batch: ModelWorkerBatch,
|
204
224
|
model_runner: ModelRunner,
|
205
225
|
):
|
206
|
-
|
207
226
|
device = model_runner.device
|
227
|
+
extend_input_logprob_token_ids_gpu = None
|
228
|
+
if batch.extend_input_logprob_token_ids is not None:
|
229
|
+
extend_input_logprob_token_ids_gpu = (
|
230
|
+
batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
|
231
|
+
)
|
232
|
+
|
208
233
|
ret = cls(
|
209
234
|
forward_mode=batch.forward_mode,
|
210
235
|
batch_size=len(batch.seq_lens),
|
@@ -220,7 +245,7 @@ class ForwardBatch:
|
|
220
245
|
seq_lens_sum=batch.seq_lens_sum,
|
221
246
|
return_logprob=batch.return_logprob,
|
222
247
|
top_logprobs_nums=batch.top_logprobs_nums,
|
223
|
-
|
248
|
+
token_ids_logprobs=batch.token_ids_logprobs,
|
224
249
|
can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
|
225
250
|
lora_paths=batch.lora_paths,
|
226
251
|
sampling_info=batch.sampling_info,
|
@@ -231,10 +256,12 @@ class ForwardBatch:
|
|
231
256
|
spec_info=batch.spec_info,
|
232
257
|
capture_hidden_mode=batch.capture_hidden_mode,
|
233
258
|
input_embeds=batch.input_embeds,
|
259
|
+
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
|
234
260
|
)
|
235
261
|
|
236
|
-
if
|
237
|
-
|
262
|
+
if batch.global_num_tokens is not None:
|
263
|
+
ret.global_num_tokens_cpu = batch.global_num_tokens
|
264
|
+
max_len = max(ret.global_num_tokens_cpu)
|
238
265
|
ret.gathered_buffer = torch.zeros(
|
239
266
|
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
|
240
267
|
dtype=model_runner.dtype,
|
@@ -256,6 +283,8 @@ class ForwardBatch:
|
|
256
283
|
if ret.forward_mode.is_decode():
|
257
284
|
if ret.positions is None:
|
258
285
|
ret.positions = clamp_position(batch.seq_lens)
|
286
|
+
if ret.decode_seq_lens_cpu is None:
|
287
|
+
ret.decode_seq_lens_cpu = batch.decode_seq_lens
|
259
288
|
else:
|
260
289
|
ret.extend_seq_lens = torch.tensor(
|
261
290
|
batch.extend_seq_lens, dtype=torch.int32
|
@@ -263,13 +292,12 @@ class ForwardBatch:
|
|
263
292
|
ret.extend_prefix_lens = torch.tensor(
|
264
293
|
batch.extend_prefix_lens, dtype=torch.int32
|
265
294
|
).to(device, non_blocking=True)
|
266
|
-
if
|
267
|
-
model_runner.server_args.attention_backend != "torch_native"
|
268
|
-
and model_runner.server_args.speculative_algorithm != "NEXTN"
|
269
|
-
):
|
295
|
+
if model_runner.server_args.attention_backend != "torch_native":
|
270
296
|
ret.extend_num_tokens = batch.extend_num_tokens
|
271
297
|
positions, ret.extend_start_loc = compute_position_triton(
|
272
|
-
ret.extend_prefix_lens,
|
298
|
+
ret.extend_prefix_lens,
|
299
|
+
ret.extend_seq_lens,
|
300
|
+
ret.extend_num_tokens,
|
273
301
|
)
|
274
302
|
else:
|
275
303
|
positions, ret.extend_start_loc = compute_position_torch(
|
@@ -341,6 +369,7 @@ class ForwardBatch:
|
|
341
369
|
)
|
342
370
|
batch.image_inputs[i].mrope_position_delta = mrope_position_delta
|
343
371
|
mrope_positions_list[i] = mrope_positions
|
372
|
+
|
344
373
|
self.mrope_positions = torch.concat(
|
345
374
|
[torch.tensor(pos, device=device) for pos in mrope_positions_list],
|
346
375
|
axis=1,
|
@@ -353,6 +382,8 @@ def compute_position_triton(
|
|
353
382
|
):
|
354
383
|
"""Compute positions. It is a fused version of `compute_position_torch`."""
|
355
384
|
batch_size = extend_seq_lens.shape[0]
|
385
|
+
has_prefix = extend_prefix_lens.shape[0] == batch_size
|
386
|
+
|
356
387
|
positions = torch.empty(
|
357
388
|
extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
|
358
389
|
)
|
@@ -366,6 +397,7 @@ def compute_position_triton(
|
|
366
397
|
extend_start_loc,
|
367
398
|
extend_prefix_lens,
|
368
399
|
extend_seq_lens,
|
400
|
+
has_prefix,
|
369
401
|
)
|
370
402
|
|
371
403
|
return positions, extend_start_loc
|
@@ -377,11 +409,12 @@ def compute_position_kernel(
|
|
377
409
|
extend_start_loc,
|
378
410
|
extend_prefix_lens,
|
379
411
|
extend_seq_lens,
|
412
|
+
has_prefix: tl.constexpr,
|
380
413
|
):
|
381
414
|
BLOCK_SIZE: tl.constexpr = 512
|
382
|
-
pid = tl.program_id(0)
|
415
|
+
pid = tl.program_id(0).to(tl.int64)
|
383
416
|
|
384
|
-
prefix_len = tl.load(extend_prefix_lens + pid)
|
417
|
+
prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
|
385
418
|
seq_len = tl.load(extend_seq_lens + pid)
|
386
419
|
|
387
420
|
# TODO: optimize this?
|