sglang 0.4.3.post3__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +95 -49
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +5 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +72 -8
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +33 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +212 -117
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +258 -782
- sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
- sglang/srt/managers/tokenizer_manager.py +7 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +63 -34
- sglang/srt/mem_cache/memory_pool.py +112 -46
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/metrics/collector.py +8 -0
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +12 -8
- sglang/srt/model_executor/model_runner.py +153 -134
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +25 -19
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +37 -15
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +19 -20
- sglang/srt/speculative/build_eagle_tree.py +6 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
- sglang/srt/speculative/eagle_utils.py +2 -1
- sglang/srt/speculative/eagle_worker.py +109 -38
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -9
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/RECORD +128 -83
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py
CHANGED
@@ -38,6 +38,7 @@ from sglang.srt.conversation import (
     SeparatorStyle,
     chat_template_exists,
     generate_chat_conv,
+    generate_embedding_convs,
     register_conv_template,
 )
 from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
@@ -68,6 +69,7 @@ from sglang.srt.openai_api.protocol import (
     FileResponse,
     FunctionResponse,
     LogProbs,
+    MultimodalEmbeddingInput,
     ToolCall,
     TopLogprob,
     UsageInfo,
@@ -282,11 +284,11 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
     file_request_list = []
     all_requests = []
     request_ids = []
-    for line in lines:
+    for line_id, line in enumerate(lines):
         request_data = json.loads(line)
         file_request_list.append(request_data)
         body = request_data["body"]
-        request_ids.append(
+        request_ids.append(f"{batch_id}-req_{line_id}")

     # Although streaming is supported for standalone completions, it is not supported in
     # batch mode (multiple completions in single request).
@@ -436,15 +438,9 @@ async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
         with open(input_file_path, "r", encoding="utf-8") as f:
             lines = f.readlines()

-        file_request_list = []
-        request_ids = []
-        for line in lines:
-            request_data = json.loads(line)
-            file_request_list.append(request_data)
-            request_ids.append(request_data["custom_id"])
-
         # Cancel requests by request_ids
-        for
+        for line_id in range(len(lines)):
+            rid = f"{batch_id}-req_{line_id}"
             tokenizer_manager.abort_request(rid=rid)

         retrieve_batch = batch_storage[batch_id]
@@ -824,13 +820,13 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                 )

                 final_usage_chunk = CompletionStreamResponse(
-                    id=
+                    id=content["meta_info"]["id"],
                     choices=[],
                     model=request.model,
                     usage=usage,
                 )
                 final_usage_data = final_usage_chunk.model_dump_json(
-
+                    exclude_none=True
                 )
                 yield f"data: {final_usage_data}\n\n"
         except ValueError as e:
@@ -1151,7 +1147,7 @@ def v1_chat_generate_response(
                 "tool_calls": tool_calls,
                 "reasoning_content": reasoning_text,
             },
-            "logprobs": choice_logprobs,
+            "logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
             "finish_reason": (finish_reason["type"] if finish_reason else ""),
             "matched_stop": (
                 finish_reason["matched"]
@@ -1499,13 +1495,13 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 )

                 final_usage_chunk = ChatCompletionStreamResponse(
-                    id=
+                    id=content["meta_info"]["id"],
                     choices=[],
                     model=request.model,
                     usage=usage,
                 )
                 final_usage_data = final_usage_chunk.model_dump_json(
-
+                    exclude_none=True
                 )
                 yield f"data: {final_usage_data}\n\n"
         except ValueError as e:
@@ -1556,11 +1552,37 @@ def v1_embedding_request(all_requests, tokenizer_manager):
         prompt = prompts[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
             prompt_kwargs = {"text": prompt}
+        elif isinstance(prompt, list) and isinstance(
+            prompt[0], MultimodalEmbeddingInput
+        ):
+            assert (
+                chat_template_name is not None
+            ), "chat_template_name is required for multimodal inputs"
+            texts = []
+            images = []
+            for item in prompt:
+                texts.append(item.text if item.text is not None else None)
+                images.append(item.image if item.image is not None else None)
+            convs = generate_embedding_convs(texts, images, chat_template_name)
+            generate_prompts = []
+            for conv in convs:
+                generate_prompts.append(conv.get_prompt())
+            if len(generate_prompts) == 1:
+                prompt_kwargs = {"text": generate_prompts[0], "image_data": images[0]}
+            else:
+                prompt_kwargs = {"text": generate_prompts, "image_data": images}
         else:
             prompt_kwargs = {"input_ids": prompt}
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
+        elif isinstance(prompts[0], list) and isinstance(
+            prompts[0][0], MultimodalEmbeddingInput
+        ):
+            # TODO: multiple requests
+            raise NotImplementedError(
+                "Multiple requests with multimodal inputs are not supported yet"
+            )
         else:
             prompt_kwargs = {"input_ids": prompts}
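The process_batch and cancel_batch hunks above replace custom_id-based bookkeeping with deterministic per-line request ids, so cancellation can rebuild every id from the batch id and the line count alone. A minimal illustrative sketch of the id scheme (names mirror the diff; the surrounding adapter code is omitted and the batch id is a made-up value):

    batch_id = "batch_abc123"  # hypothetical batch id
    lines = ['{"custom_id": "a"}', '{"custom_id": "b"}']

    # process_batch: one deterministic id per input line
    request_ids = [f"{batch_id}-req_{line_id}" for line_id, _ in enumerate(lines)]

    # cancel_batch: the same ids are reconstructed from the line count alone,
    # without re-parsing the JSONL file for custom_id fields
    assert request_ids == [f"{batch_id}-req_{i}" for i in range(len(lines))]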
sglang/srt/openai_api/protocol.py
CHANGED
@@ -403,10 +403,17 @@ class ChatCompletionStreamResponse(BaseModel):
     usage: Optional[UsageInfo] = None


+class MultimodalEmbeddingInput(BaseModel):
+    text: Optional[str] = None
+    image: Optional[str] = None
+
+
 class EmbeddingRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings/create
-    input: Union[
+    input: Union[
+        List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
+    ]
     model: str
     encoding_format: str = "float"
     dimensions: int = None
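With MultimodalEmbeddingInput added to the protocol, EmbeddingRequest.input can now be a list of text/image items, which v1_embedding_request (above) turns into chat-template prompts plus image_data. A hedged sketch of what such a request body could look like; the model name and image encoding are assumptions, only the {"text", "image"} field shape comes from this diff:

    import json

    payload = {
        "model": "my-multimodal-embedding-model",  # hypothetical model name
        "input": [
            {"text": "a photo of a cat", "image": None},
            {"text": None, "image": "data:image/png;base64,iVBORw0K..."},  # truncated example value
        ],
    }
    # This only shows the payload shape accepted by the new protocol model;
    # it would be POSTed to the OpenAI-compatible embeddings route.
    print(json.dumps(payload, indent=2))

Per the adapter change above, only a single multimodal request per call is handled; batching several multimodal requests raises NotImplementedError.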
sglang/srt/sampling/penaltylib/frequency_penalty.py
CHANGED
@@ -56,7 +56,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
         ]

     def _merge(self, their: "BatchedFrequencyPenalizer"):
-        print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}")
         self.frequency_penalties = torch.cat(
             [self.frequency_penalties, their.frequency_penalties], dim=0
         )
sglang/srt/sampling/penaltylib/presence_penalty.py
CHANGED
@@ -56,7 +56,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
         ]

     def _merge(self, their: "BatchedPresencePenalizer"):
-        print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}")
         self.presence_penalties = torch.cat(
             [self.presence_penalties, their.presence_penalties], dim=0
         )
sglang/srt/server_args.py
CHANGED
@@ -20,14 +20,13 @@ import random
 import tempfile
 from typing import List, Optional

-import torch
-
 from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
+    is_cuda,
     is_flashinfer_available,
     is_hip,
     is_port_available,
@@ -71,7 +70,7 @@ class ServerArgs:
     schedule_policy: str = "fcfs"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
-
+    page_size: int = 1

     # Other runtime options
     tp_size: int = 1
@@ -191,10 +190,10 @@ class ServerArgs:
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)

-        if
-            gpu_mem = get_amdgpu_memory_capacity()
-        elif torch.cuda.is_available():
+        if is_cuda():
             gpu_mem = get_nvgpu_memory_capacity()
+        elif is_hip():
+            gpu_mem = get_amdgpu_memory_capacity()
         elif self.device == "hpu":
             gpu_mem = get_hpu_memory_capacity()
         else:
@@ -221,6 +220,8 @@ class ServerArgs:
         else:
             self.chunked_prefill_size = 8192

+        assert self.chunked_prefill_size % self.page_size == 0
+
         # Set cuda graph max batch size
         if self.cuda_graph_max_bs is None:
             # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
@@ -259,7 +260,7 @@ class ServerArgs:
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )

-        #
+        # Data parallelism attention
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
             assert self.tp_size % self.dp_size == 0
@@ -277,19 +278,17 @@ class ServerArgs:
             self.speculative_algorithm = "EAGLE"

         if self.speculative_algorithm == "EAGLE":
-            self.disable_overlap_schedule = True
-            self.prefill_only_one_req = True
-            self.disable_cuda_graph_padding = True
             if self.max_running_requests is None:
                 self.max_running_requests = 32
+            self.disable_cuda_graph_padding = True
+            self.disable_overlap_schedule = True
             logger.info(
-                "Overlap scheduler
+                "Overlap scheduler is disabled because of using "
                 "eagle speculative decoding."
-                "Max running request set to 32 because of using eagle speculative decoding."
             )
             # The token generated from the verify step is counted.
             # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
-            assert self.speculative_num_steps < self.speculative_num_draft_tokens
+            # assert self.speculative_num_steps < self.speculative_num_draft_tokens

         # GGUF
         if (
@@ -408,6 +407,7 @@ class ServerArgs:
                 "gguf",
                 "modelopt",
                 "w8a8_int8",
+                "w8a8_fp8",
             ],
             help="The quantization method.",
         )
@@ -482,7 +482,7 @@ class ServerArgs:
             "--chunked-prefill-size",
             type=int,
             default=ServerArgs.chunked_prefill_size,
-            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
+            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
         )
         parser.add_argument(
             "--max-prefill-tokens",
@@ -507,13 +507,13 @@ class ServerArgs:
             "--cpu-offload-gb",
             type=int,
             default=ServerArgs.cpu_offload_gb,
-            help="How many GBs of RAM to reserve for CPU offloading",
+            help="How many GBs of RAM to reserve for CPU offloading.",
         )
         parser.add_argument(
-            "--
-            type=
-
-
+            "--page-size",
+            type=int,
+            default=ServerArgs.page_size,
+            help="The number of tokens in a page.",
         )

         # Other runtime options
@@ -773,7 +773,6 @@ class ServerArgs:
             "--speculative-eagle-topk",
             type=int,
             help="The number of tokens sampled from the draft model in eagle2 each step.",
-            choices=[1, 2, 4, 8],
             default=ServerArgs.speculative_eagle_topk,
         )
         parser.add_argument(
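The server now exposes --page-size (tokens per KV-cache page, default 1) and accepts w8a8_fp8 as a quantization method, and __post_init__ requires chunked_prefill_size to be divisible by page_size. A hedged sketch of how the new options might be used; the launch command and model path are assumptions, only the flag names and the divisibility check come from this diff:

    # Assumed launch form (flags from this diff, model path is a placeholder):
    #   python -m sglang.launch_server --model-path <model> \
    #       --page-size 1 --quantization w8a8_fp8

    # The constraint added in ServerArgs.__post_init__:
    page_size = 1
    chunked_prefill_size = 8192  # the default picked when chunked prefill stays enabled
    assert chunked_prefill_size % page_size == 0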
sglang/srt/speculative/build_eagle_tree.py
CHANGED
@@ -26,7 +26,12 @@ def build_tree_kernel_efficient_preprocess(

     draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
     draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
-
+
+    if len(parents_list) > 1:
+        parent_list = torch.cat(parents_list[:-1], dim=1)
+    else:
+        batch_size = parents_list[0].shape[0]
+        parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)

     return parent_list, top_scores_index, draft_tokens
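The new branch handles a parents_list with a single entry (a one-step draft tree) by returning an empty (batch_size, 0) parent tensor instead of concatenating an empty slice. A toy reproduction of the two branches, with made-up shapes:

    import torch

    def build_parent_list(parents_list):
        if len(parents_list) > 1:
            return torch.cat(parents_list[:-1], dim=1)
        # single entry: torch.cat over an empty list would raise, so return (B, 0)
        batch_size = parents_list[0].shape[0]
        return torch.empty(batch_size, 0, device=parents_list[0].device)

    multi = [torch.zeros(2, 3, dtype=torch.long) for _ in range(3)]
    single = [torch.zeros(2, 3, dtype=torch.long)]
    print(build_parent_list(multi).shape)   # torch.Size([2, 6])
    print(build_parent_list(single).shape)  # torch.Size([2, 0])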
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
CHANGED
@@ -1,7 +1,6 @@
 from __future__ import annotations

 import bisect
-import time
 from typing import TYPE_CHECKING, Callable

 import torch
@@ -162,20 +161,11 @@ class EAGLEDraftCudaGraphRunner:

         run_once()

-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         with torch.cuda.graph(
             graph, pool=get_global_graph_memory_pool(), stream=stream
         ):
             out = run_once()

-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         set_global_graph_memory_pool(graph.pool())
         return graph, out

@@ -204,7 +194,7 @@ class EAGLEDraftCudaGraphRunner:

         # Attention backend
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
-            forward_batch
+            forward_batch, forward_batch.batch_size
         )

         # Replay
sglang/srt/speculative/eagle_utils.py
CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, List

 import torch
 import torch.nn.functional as F
@@ -62,6 +62,7 @@ class EagleDraftInput:
             batch.input_ids[pt : pt + extend_len] = torch.concat(
                 (input_ids[1:], self.verified_id[i].reshape(1))
             )
+            pt += extend_len

     def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
         assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
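The added pt += extend_len advances the write pointer after each request's shifted input_ids are copied, so later requests in the batch no longer overwrite the first request's slot. A toy sketch of the pointer arithmetic (tensors and lengths are invented):

    import torch

    extend_lens = [3, 2, 4]
    segments = [torch.arange(1, n + 1) for n in extend_lens]
    out = torch.zeros(sum(extend_lens), dtype=torch.long)

    pt = 0
    for seg, extend_len in zip(segments, extend_lens):
        out[pt : pt + extend_len] = seg
        pt += extend_len  # the line this diff adds
    print(out.tolist())   # [1, 2, 3, 1, 2, 1, 2, 3, 4]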
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -1,20 +1,20 @@
 import logging
 import os
 import time
-from typing import
+from typing import List, Optional, Tuple

 import torch
 from huggingface_hub import snapshot_download

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.
+from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
@@ -27,7 +27,6 @@ from sglang.srt.speculative.eagle_utils import (
     fast_topk,
     select_top_k_tokens,
 )
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import get_available_gpu_memory

 logger = logging.getLogger(__name__)
@@ -44,16 +43,30 @@ class EAGLEWorker(TpModelWorker):
         nccl_port: int,
         target_worker: TpModelWorker,
     ):
+        # Parse arguments
+        self.server_args = server_args
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.padded_static_len = self.speculative_num_steps + 1
+        self.enable_nan_detection = server_args.enable_nan_detection
+        self.gpu_id = gpu_id
+        self.device = server_args.device
+        self.target_worker = target_worker
+
         # Override context length with target model's context length
         server_args.context_length = target_worker.model_runner.model_config.context_len
-        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"

         # Do not capture cuda graph in `super().__init__()`
-        #
+        # It will be captured later.
         backup_disable_cuda_graph = server_args.disable_cuda_graph
         server_args.disable_cuda_graph = True
+        # Share the allocator with a target worker.
+        # Draft and target worker own their own KV cache pools.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )

-        #
+        # Load hot token ids
         if server_args.speculative_token_map is not None:
             self.hot_token_id = load_token_map(server_args.speculative_token_map)
             server_args.json_model_override_args = (
@@ -62,13 +75,7 @@ class EAGLEWorker(TpModelWorker):
         else:
             self.hot_token_id = None

-        #
-        # owns its own KV cache.
-        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
-            target_worker.get_memory_pool()
-        )
-
-        # Init target worker
+        # Init draft worker
         super().__init__(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -79,18 +86,6 @@ class EAGLEWorker(TpModelWorker):
             req_to_token_pool=self.req_to_token_pool,
             token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
         )
-        self.target_worker = target_worker
-
-        # Parse arguments
-        self.topk = server_args.speculative_eagle_topk
-        self.speculative_num_steps = server_args.speculative_num_steps
-        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
-            server_args.speculative_algorithm
-        )
-        self.server_args = server_args
-        self.use_nan_detection = self.server_args.enable_nan_detection
-        self.device = self.model_runner.device
-        self.gpu_id = self.model_runner.gpu_id

         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -103,8 +98,12 @@ class EAGLEWorker(TpModelWorker):
             backup_disable_cuda_graph
         )

+        self.init_attention_backend()
+        self.init_cuda_graphs()
+
+    def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
-        if server_args.attention_backend == "flashinfer":
+        if self.server_args.attention_backend == "flashinfer":
             from sglang.srt.layers.attention.flashinfer_backend import (
                 FlashInferMultiStepDraftBackend,
             )
@@ -114,7 +113,7 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-        elif server_args.attention_backend == "triton":
+        elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
                 TritonMultiStepDraftBackend,
             )
@@ -124,13 +123,21 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
+        elif self.server_args.attention_backend == "flashinfer_mla":
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
+                self.model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
         else:
             raise ValueError(
-                f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
+                f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
             )
-
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
-        self.init_cuda_graphs()

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
@@ -306,13 +313,10 @@ class EAGLEWorker(TpModelWorker):

            # Set inputs
            forward_batch.input_ids = input_ids
+            out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
            forward_batch.out_cache_loc = out_cache_loc[
-
-
-                * i : forward_batch.batch_size
-                * self.topk
-                * (i + 1)
-            ]
+                :, self.topk * i : self.topk * (i + 1)
+            ].flatten()
            forward_batch.positions.add_(1)
            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
            spec_info.hidden_states = hidden_states
@@ -356,8 +360,71 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input

+        if batch.return_logprob:
+            self.add_logprob_values(batch, res, logits_output)
+
         return logits_output, res, model_worker_batch

+    def add_logprob_values(
+        self,
+        batch: ScheduleBatch,
+        res: EagleVerifyOutput,
+        logits_output: LogitsProcessorOutput,
+    ):
+        # Extract args
+        logits_output = res.logits_output
+        top_logprobs_nums = batch.top_logprobs_nums
+        token_ids_logprobs = batch.token_ids_logprobs
+        logprobs = torch.nn.functional.log_softmax(
+            logits_output.next_token_logits, dim=-1
+        )
+        batch_next_token_ids = res.verified_id
+        num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
+
+        # We should repeat top_logprobs_nums to match num_tokens_per_req.
+        top_logprobs_nums_repeat_interleaved = []
+        token_ids_logprobs_repeat_interleaved = []
+        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
+            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
+        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
+            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
+
+        # Extract logprobs
+        if any(x > 0 for x in top_logprobs_nums):
+            (
+                logits_output.next_token_top_logprobs_val,
+                logits_output.next_token_top_logprobs_idx,
+            ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)
+
+        if any(x is not None for x in token_ids_logprobs):
+            (
+                logits_output.next_token_token_ids_logprobs_val,
+                logits_output.next_token_token_ids_logprobs_idx,
+            ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)
+
+        logits_output.next_token_logprobs = logprobs[
+            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
+            batch_next_token_ids,
+        ]
+
+        # Add output logprobs to the request.
+        pt = 0
+        next_token_logprobs = logits_output.next_token_logprobs.tolist()
+        verified_ids = batch_next_token_ids.tolist()
+        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+            for _ in range(num_tokens):
+                if req.return_logprob:
+                    req.output_token_logprobs_val.append(next_token_logprobs[pt])
+                    req.output_token_logprobs_idx.append(verified_ids[pt])
+                    if req.top_logprobs_num > 0:
+                        req.output_top_logprobs_val.append(
+                            res.logits_output.next_token_top_logprobs_val[pt]
+                        )
+                        req.output_top_logprobs_idx.append(
+                            res.logits_output.next_token_top_logprobs_idx[pt]
+                        )
+                pt += 1
+
     def forward_draft_extend(
         self,
         batch: ScheduleBatch,
@@ -381,6 +448,7 @@ class EAGLEWorker(TpModelWorker):
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
+        forward_batch.return_logprob = False
         logits_output = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
@@ -393,6 +461,8 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         # We don't need logprob for this extend.
+        original_return_logprob = batch.return_logprob
+        batch.return_logprob = False
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -404,6 +474,7 @@ class EAGLEWorker(TpModelWorker):

         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+        batch.return_logprob = original_return_logprob
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup

@@ -415,7 +486,7 @@ class EAGLEWorker(TpModelWorker):
         draft_input.hidden_states = logits_output.hidden_states

     def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
-        if self.
+        if self.enable_nan_detection:
             logits = logits_output.next_token_logits
             if torch.any(torch.isnan(logits)):
                 logger.warning("Detected errors during sampling! NaN in the logits.")
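In add_logprob_values, each request contributes a variable number of output tokens (the accepted draft tokens plus the verified token), so per-request settings such as top_logprobs_nums must be repeated once per emitted token before logprobs are gathered. A toy illustration of that repeat-interleave step, with invented numbers:

    accept_length_per_req_cpu = [2, 0, 1]   # accepted draft tokens per request
    num_tokens_per_req = [a + 1 for a in accept_length_per_req_cpu]  # + verified token
    top_logprobs_nums = [5, 0, 3]           # per-request top-k logprob settings

    top_logprobs_nums_repeat_interleaved = []
    for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
        top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
    print(top_logprobs_nums_repeat_interleaved)  # [5, 5, 5, 0, 3, 3]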