sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py
CHANGED
@@ -38,6 +38,7 @@ from sglang.srt.conversation import (
     SeparatorStyle,
     chat_template_exists,
     generate_chat_conv,
+    generate_embedding_convs,
     register_conv_template,
 )
 from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
@@ -68,6 +69,7 @@ from sglang.srt.openai_api.protocol import (
     FileResponse,
     FunctionResponse,
     LogProbs,
+    MultimodalEmbeddingInput,
     ToolCall,
     TopLogprob,
     UsageInfo,
@@ -282,11 +284,11 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
     file_request_list = []
     all_requests = []
     request_ids = []
-    for line in lines:
+    for line_id, line in enumerate(lines):
         request_data = json.loads(line)
         file_request_list.append(request_data)
         body = request_data["body"]
-        request_ids.append(
+        request_ids.append(f"{batch_id}-req_{line_id}")

         # Although streaming is supported for standalone completions, it is not supported in
         # batch mode (multiple completions in single request).
@@ -436,15 +438,9 @@ async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
     with open(input_file_path, "r", encoding="utf-8") as f:
         lines = f.readlines()

-    file_request_list = []
-    request_ids = []
-    for line in lines:
-        request_data = json.loads(line)
-        file_request_list.append(request_data)
-        request_ids.append(request_data["custom_id"])
-
     # Cancel requests by request_ids
-    for
+    for line_id in range(len(lines)):
+        rid = f"{batch_id}-req_{line_id}"
         tokenizer_manager.abort_request(rid=rid)

     retrieve_batch = batch_storage[batch_id]
@@ -824,13 +820,13 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     )

                 final_usage_chunk = CompletionStreamResponse(
-                    id=
+                    id=content["meta_info"]["id"],
                     choices=[],
                     model=request.model,
                     usage=usage,
                 )
                 final_usage_data = final_usage_chunk.model_dump_json(
-
+                    exclude_none=True
                 )
                 yield f"data: {final_usage_data}\n\n"
        except ValueError as e:
@@ -1119,27 +1115,29 @@ def v1_chat_generate_response(
         else:
             reasoning_text = None

-        if tool_choice != "none" and
-
-
-
-
-
-
-
-
-
-
-
+        if tool_choice != "none" and tools:
+            parser = FunctionCallParser(tools, tool_call_parser)
+            if parser.has_tool_call(text):
+                if finish_reason["type"] == "stop":
+                    finish_reason["type"] = "tool_calls"
+                    finish_reason["matched"] = None
+                try:
+                    full_normal_text, call_info_list = parser.parse_non_stream(text)
+                    tool_calls = [
+                        ToolCall(
+                            id=str(call_info.tool_index),
+                            function=FunctionResponse(
+                                name=call_info.name, arguments=call_info.parameters
+                            ),
+                        )
+                        for call_info in call_info_list
+                    ]
+                except Exception as e:
+                    logger.error(f"Exception: {e}")
+                    return create_error_response(
+                        HTTPStatus.BAD_REQUEST,
+                        "Failed to parse fc related info to json format!",
                     )
-                for call_info in call_info_list
-            ]
-        except Exception as e:
-            logger.error(f"Exception: {e}")
-            return create_error_response(
-                HTTPStatus.BAD_REQUEST,
-                "Failed to parse fc related info to json format!",
-            )

         if to_file:
             # to make the choice data json serializable
@@ -1151,7 +1149,7 @@ def v1_chat_generate_response(
                 "tool_calls": tool_calls,
                 "reasoning_content": reasoning_text,
             },
-            "logprobs": choice_logprobs,
+            "logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
             "finish_reason": (finish_reason["type"] if finish_reason else ""),
             "matched_stop": (
                 finish_reason["matched"]
@@ -1499,13 +1497,13 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     )

                 final_usage_chunk = ChatCompletionStreamResponse(
-                    id=
+                    id=content["meta_info"]["id"],
                     choices=[],
                     model=request.model,
                     usage=usage,
                 )
                 final_usage_data = final_usage_chunk.model_dump_json(
-
+                    exclude_none=True
                 )
                 yield f"data: {final_usage_data}\n\n"
        except ValueError as e:
@@ -1556,11 +1554,37 @@ def v1_embedding_request(all_requests, tokenizer_manager):
         prompt = prompts[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
             prompt_kwargs = {"text": prompt}
+        elif isinstance(prompt, list) and isinstance(
+            prompt[0], MultimodalEmbeddingInput
+        ):
+            assert (
+                chat_template_name is not None
+            ), "chat_template_name is required for multimodal inputs"
+            texts = []
+            images = []
+            for item in prompt:
+                texts.append(item.text if item.text is not None else None)
+                images.append(item.image if item.image is not None else None)
+            convs = generate_embedding_convs(texts, images, chat_template_name)
+            generate_prompts = []
+            for conv in convs:
+                generate_prompts.append(conv.get_prompt())
+            if len(generate_prompts) == 1:
+                prompt_kwargs = {"text": generate_prompts[0], "image_data": images[0]}
+            else:
+                prompt_kwargs = {"text": generate_prompts, "image_data": images}
         else:
             prompt_kwargs = {"input_ids": prompt}
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
+        elif isinstance(prompts[0], list) and isinstance(
+            prompts[0][0], MultimodalEmbeddingInput
+        ):
+            # TODO: multiple requests
+            raise NotImplementedError(
+                "Multiple requests with multimodal inputs are not supported yet"
+            )
         else:
             prompt_kwargs = {"input_ids": prompts}

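Taken together with the MultimodalEmbeddingInput schema added in sglang/srt/openai_api/protocol.py below, the v1_embedding_request changes let the OpenAI-compatible /v1/embeddings endpoint accept paired text and image inputs. A minimal client sketch, assuming a server started in embedding mode with a chat template that supports image inputs; the model name, port, and base64 encoding of the image are illustrative choices, not something this diff prescribes:

import base64

import requests

with open("cat.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

# 30000 is sglang's default HTTP port; adjust to your deployment.
resp = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={
        "model": "Alibaba-NLP/gme-Qwen2-VL-2B-Instruct",  # placeholder embedding model
        "input": [{"text": "a photo of a cat", "image": image_b64}],
    },
)
print(len(resp.json()["data"][0]["embedding"]))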

sglang/srt/openai_api/protocol.py
CHANGED
@@ -403,10 +403,17 @@ class ChatCompletionStreamResponse(BaseModel):
     usage: Optional[UsageInfo] = None


+class MultimodalEmbeddingInput(BaseModel):
+    text: Optional[str] = None
+    image: Optional[str] = None
+
+
 class EmbeddingRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings/create
-    input: Union[
+    input: Union[
+        List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
+    ]
     model: str
     encoding_format: str = "float"
     dimensions: int = None

sglang/srt/sampling/penaltylib/frequency_penalty.py
CHANGED
@@ -56,7 +56,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
         ]

     def _merge(self, their: "BatchedFrequencyPenalizer"):
-        print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}")
         self.frequency_penalties = torch.cat(
             [self.frequency_penalties, their.frequency_penalties], dim=0
         )

sglang/srt/sampling/penaltylib/presence_penalty.py
CHANGED
@@ -56,7 +56,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
         ]

     def _merge(self, their: "BatchedPresencePenalizer"):
-        print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}")
         self.presence_penalties = torch.cat(
             [self.presence_penalties, their.presence_penalties], dim=0
         )
sglang/srt/server_args.py
CHANGED
@@ -20,14 +20,13 @@ import random
 import tempfile
 from typing import List, Optional

-import torch
-
 from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
+    is_cuda,
     is_flashinfer_available,
     is_hip,
     is_port_available,
@@ -71,6 +70,7 @@ class ServerArgs:
     schedule_policy: str = "fcfs"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
+    page_size: int = 1

     # Other runtime options
     tp_size: int = 1
@@ -190,10 +190,10 @@ class ServerArgs:
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)

-        if
-            gpu_mem = get_amdgpu_memory_capacity()
-        elif torch.cuda.is_available():
+        if is_cuda():
             gpu_mem = get_nvgpu_memory_capacity()
+        elif is_hip():
+            gpu_mem = get_amdgpu_memory_capacity()
         elif self.device == "hpu":
             gpu_mem = get_hpu_memory_capacity()
         else:
@@ -220,6 +220,8 @@ class ServerArgs:
         else:
             self.chunked_prefill_size = 8192

+        assert self.chunked_prefill_size % self.page_size == 0
+
         # Set cuda graph max batch size
         if self.cuda_graph_max_bs is None:
             # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
@@ -258,16 +260,16 @@ class ServerArgs:
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )

-        #
+        # Data parallelism attention
         if self.enable_dp_attention:
-            self.dp_size = self.tp_size
-            assert self.tp_size % self.dp_size == 0
-            self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
+            assert (
+                self.dp_size > 1
+            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
+            assert self.tp_size % self.dp_size == 0
+            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
-                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
-                "Data parallel size is adjusted to be the same as tensor parallel size. "
             )

         # Speculative Decoding
@@ -278,10 +280,10 @@ class ServerArgs:
         if self.speculative_algorithm == "EAGLE":
             if self.max_running_requests is None:
                 self.max_running_requests = 32
-            self.disable_overlap_schedule = True
             self.disable_cuda_graph_padding = True
+            self.disable_overlap_schedule = True
             logger.info(
-                "Overlap scheduler
+                "Overlap scheduler is disabled because of using "
                 "eagle speculative decoding."
             )
             # The token generated from the verify step is counted.
@@ -405,6 +407,7 @@ class ServerArgs:
                 "gguf",
                 "modelopt",
                 "w8a8_int8",
+                "w8a8_fp8",
             ],
             help="The quantization method.",
         )
@@ -479,7 +482,7 @@ class ServerArgs:
             "--chunked-prefill-size",
             type=int,
             default=ServerArgs.chunked_prefill_size,
-            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
+            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
         )
         parser.add_argument(
             "--max-prefill-tokens",
@@ -504,7 +507,13 @@ class ServerArgs:
             "--cpu-offload-gb",
             type=int,
             default=ServerArgs.cpu_offload_gb,
-            help="How many GBs of RAM to reserve for CPU offloading",
+            help="How many GBs of RAM to reserve for CPU offloading.",
+        )
+        parser.add_argument(
+            "--page-size",
+            type=int,
+            default=ServerArgs.page_size,
+            help="The number of tokens in a page.",
         )

         # Other runtime options
@@ -764,7 +773,6 @@ class ServerArgs:
             "--speculative-eagle-topk",
             type=int,
             help="The number of tokens sampled from the draft model in eagle2 each step.",
-            choices=[1, 2, 4, 8],
             default=ServerArgs.speculative_eagle_topk,
         )
         parser.add_argument(

sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -7,6 +7,7 @@ import torch
 from huggingface_hub import snapshot_download

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
@@ -122,6 +123,16 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
+        elif self.server_args.attention_backend == "flashinfer_mla":
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
+                self.model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
         else:
             raise ValueError(
                 f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
@@ -302,13 +313,10 @@ class EAGLEWorker(TpModelWorker):

             # Set inputs
             forward_batch.input_ids = input_ids
+            out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
             forward_batch.out_cache_loc = out_cache_loc[
-
-
-                * i : forward_batch.batch_size
-                * self.topk
-                * (i + 1)
-            ]
+                :, self.topk * i : self.topk * (i + 1)
+            ].flatten()
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
             spec_info.hidden_states = hidden_states
@@ -353,42 +361,70 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info = res.draft_input

         if batch.return_logprob:
-
-            num_tokens_per_req = [
-                accept + 1 for accept in res.accept_length_per_req_cpu
-            ]
-            self.target_worker.model_runner.update_output_logprobs(
-                logits_output,
-                batch.sampling_info,
-                batch.top_logprobs_nums,
-                batch.token_ids_logprobs,
-                res.verified_id,
-                # +1 for bonus token.
-                num_tokens_per_req=num_tokens_per_req,
-            )
-
-            # Add output logprobs to the request.
-            pt = 0
-            # NOTE: tolist() of these values are skipped when output is processed
-            next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
-            verified_ids = res.verified_id.tolist()
-            for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
-                for _ in range(num_tokens):
-                    if req.return_logprob:
-                        token_id = verified_ids[pt]
-                        req.output_token_logprobs_val.append(next_token_logprobs[pt])
-                        req.output_token_logprobs_idx.append(token_id)
-                        if req.top_logprobs_num > 0:
-                            req.output_top_logprobs_val.append(
-                                res.logits_output.next_token_top_logprobs_val[pt]
-                            )
-                            req.output_top_logprobs_idx.append(
-                                res.logits_output.next_token_top_logprobs_idx[pt]
-                            )
-                    pt += 1
+            self.add_logprob_values(batch, res, logits_output)

         return logits_output, res, model_worker_batch

+    def add_logprob_values(
+        self,
+        batch: ScheduleBatch,
+        res: EagleVerifyOutput,
+        logits_output: LogitsProcessorOutput,
+    ):
+        # Extract args
+        logits_output = res.logits_output
+        top_logprobs_nums = batch.top_logprobs_nums
+        token_ids_logprobs = batch.token_ids_logprobs
+        logprobs = torch.nn.functional.log_softmax(
+            logits_output.next_token_logits, dim=-1
+        )
+        batch_next_token_ids = res.verified_id
+        num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
+
+        # We should repeat top_logprobs_nums to match num_tokens_per_req.
+        top_logprobs_nums_repeat_interleaved = []
+        token_ids_logprobs_repeat_interleaved = []
+        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
+            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
+        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
+            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
+
+        # Extract logprobs
+        if any(x > 0 for x in top_logprobs_nums):
+            (
+                logits_output.next_token_top_logprobs_val,
+                logits_output.next_token_top_logprobs_idx,
+            ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)
+
+        if any(x is not None for x in token_ids_logprobs):
+            (
+                logits_output.next_token_token_ids_logprobs_val,
+                logits_output.next_token_token_ids_logprobs_idx,
+            ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)
+
+        logits_output.next_token_logprobs = logprobs[
+            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
+            batch_next_token_ids,
+        ]
+
+        # Add output logprobs to the request.
+        pt = 0
+        next_token_logprobs = logits_output.next_token_logprobs.tolist()
+        verified_ids = batch_next_token_ids.tolist()
+        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+            for _ in range(num_tokens):
+                if req.return_logprob:
+                    req.output_token_logprobs_val.append(next_token_logprobs[pt])
+                    req.output_token_logprobs_idx.append(verified_ids[pt])
+                    if req.top_logprobs_num > 0:
+                        req.output_top_logprobs_val.append(
+                            res.logits_output.next_token_top_logprobs_val[pt]
+                        )
+                        req.output_top_logprobs_idx.append(
+                            res.logits_output.next_token_top_logprobs_idx[pt]
+                        )
+                pt += 1
+
     def forward_draft_extend(
         self,
         batch: ScheduleBatch,
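The new --page-size argument (and the ServerArgs.page_size field it maps to) pairs with the paged allocator added in sglang/srt/mem_cache/paged_allocator.py listed above. A standalone sketch of the divisibility constraint asserted in __post_init__; the numbers and the launch command in the comment are illustrative:

# e.g. python -m sglang.launch_server --model-path <model> --page-size 16
page_size = 16
chunked_prefill_size = 8192  # the default picked for larger-memory GPUs (see the hunk above)
assert chunked_prefill_size % page_size == 0, "chunked prefill must span whole pages"
print(chunked_prefill_size // page_size)  # 512 token pages per prefill chunk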
sglang/srt/utils.py
CHANGED
@@ -14,6 +14,7 @@
 """Common utilities."""

 import base64
+import builtins
 import ctypes
 import dataclasses
 import io
@@ -37,6 +38,7 @@ import time
 import warnings
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
+from importlib.util import find_spec
 from io import BytesIO
 from multiprocessing import Pool
 from multiprocessing.reduction import ForkingPickler
@@ -52,11 +54,13 @@ import triton
 import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
+from packaging.version import Version, parse
 from starlette.routing import Mount
 from torch import nn
 from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
+from torch.utils.cpp_extension import CUDA_HOME
 from triton.runtime.cache import (
     FileCacheManager,
     default_cache_dir,
@@ -69,14 +73,31 @@ logger = logging.getLogger(__name__)
 show_time_cost = False
 time_infos = {}

+HIP_FP8_E4M3_FNUZ_MAX = 224.0

+
+# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
 def is_hip() -> bool:
-    """Return whether it is HIP on the AMD ROCm platform."""
     return torch.version.hip is not None


+if is_hip():
+    FP8_E4M3_MAX = HIP_FP8_E4M3_FNUZ_MAX
+else:
+    FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+
+FP8_E4M3_MIN = -FP8_E4M3_MAX
+
+builtins.FP8_E4M3_MAX = FP8_E4M3_MAX
+builtins.FP8_E4M3_MIN = FP8_E4M3_MIN
+
+
+def is_rocm() -> bool:
+    return torch.cuda.is_available() and torch.version.hip
+
+
 def is_cuda():
-    return
+    return torch.cuda.is_available() and torch.version.cuda


 def is_cuda_alike():
@@ -98,11 +119,11 @@ def is_flashinfer_available():
     """
     if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
         return False
-    return
+    return is_cuda()


 def is_cuda_available():
-    return
+    return is_cuda()


 def enable_show_time_cost():
@@ -1045,6 +1066,65 @@ def get_device_name(device_id: int = 0) -> str:
         return torch.hpu.get_device_name(device_id)


+@lru_cache(maxsize=1)
+def is_habana_available() -> bool:
+    return find_spec("habana_frameworks") is not None
+
+
+@lru_cache(maxsize=8)
+def get_device(device_id: Optional[int] = None) -> str:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        if device_id is None:
+            return "cuda"
+        return "cuda:{}".format(device_id)
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        if device_id == None:
+            return "xpu"
+        return "xpu:{}".format(device_id)
+
+    if is_habana_available():
+        try:
+            import habana_frameworks.torch.hpu
+
+            if torch.hpu.is_available():
+                if device_id == None:
+                    return "hpu"
+                return "hpu:{}".format(device_id)
+        except ImportError as e:
+            raise ImportError(
+                "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
+            )
+
+    raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")
+
+
+@lru_cache(maxsize=1)
+def get_device_count() -> int:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        try:
+            return torch.cuda.device_count()
+        except RuntimeError:
+            return 0
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        try:
+            return torch.xpu.device_count()
+        except RuntimeError:
+            return 0
+
+    if is_habana_available():
+        try:
+            import habana_frameworks.torch.hpu
+
+            if torch.hpu.is_available():
+                return torch.hpu.device_count()
+        except (ImportError, RuntimeError):
+            return 0
+
+    return 0  # No accelerators available
+
+
 def get_device_core_count(device_id: int = 0) -> int:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_properties(device_id).multi_processor_count
@@ -1063,11 +1143,12 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         )
         major, minor = int(major), int(minor)

-    # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
-    # Update this once the support is available.
     if hasattr(torch, "hpu") and torch.hpu.is_available():
         try:
-
+            # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
+            # Update this once the support is available.
+            # major, minor = torch.hpu.get_device_capability(device_id)
+            major, minor = None, None
         except Exception as e:
             raise RuntimeError(
                 f"An error occurred while getting device capability of hpu: {e}."
@@ -1269,7 +1350,8 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor:
     elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
         x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
     else:
-        return x_
+        # return x_
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 8), 2, 4)

     x_ = x_.permute(0, 1, 3, 4, 2, 5)
     x_ = x_.contiguous()
@@ -1341,7 +1423,7 @@ def kill_itself_when_parent_died():
         libc = ctypes.CDLL("libc.so.6")
         libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
     else:
-        logger.
+        logger.warning("kill_itself_when_parent_died is only supported in linux.")


 def set_uvicorn_logging_configs():
@@ -1430,6 +1512,12 @@ def rank0_print(msg: str):
         print(msg, flush=True)


+def get_cuda_version():
+    if torch.version.cuda:
+        return tuple(map(int, torch.version.cuda.split(".")))
+    return (0, 0)
+
+
 def launch_dummy_health_check_server(host, port):
     import uvicorn
     from fastapi import FastAPI, Response
@@ -1466,6 +1554,13 @@ def set_cuda_arch():
     os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"


+def next_power_of_2(n: int):
+    return 1 << (n - 1).bit_length() if n > 0 else 1
+
+
+setattr(triton, "next_power_of_2", next_power_of_2)
+
+
 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.

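To close, a standalone sketch of two helpers introduced in this file: the next_power_of_2 patch applied onto triton, and the platform-dependent FP8 e4m3 range selection. Only torch is imported, and the logic is copied from the hunks above rather than from an installed sglang:

import torch


def next_power_of_2(n: int):
    # Same expression as the patched triton helper above.
    return 1 << (n - 1).bit_length() if n > 0 else 1


print([next_power_of_2(n) for n in (0, 1, 3, 8, 9)])  # [1, 1, 4, 8, 16]

# ROCm builds clamp the fp8 e4m3 range to 224 (HIP_FP8_E4M3_FNUZ_MAX above);
# CUDA's float8_e4m3fn format tops out at 448.
HIP_FP8_E4M3_FNUZ_MAX = 224.0
if torch.version.hip is not None:  # is_hip()
    FP8_E4M3_MAX = HIP_FP8_E4M3_FNUZ_MAX
else:
    FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
print(FP8_E4M3_MAX, -FP8_E4M3_MAX)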