sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py CHANGED
@@ -38,6 +38,7 @@ from sglang.srt.conversation import (
     SeparatorStyle,
     chat_template_exists,
     generate_chat_conv,
+    generate_embedding_convs,
     register_conv_template,
 )
 from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
@@ -68,6 +69,7 @@ from sglang.srt.openai_api.protocol import (
     FileResponse,
     FunctionResponse,
     LogProbs,
+    MultimodalEmbeddingInput,
     ToolCall,
     TopLogprob,
     UsageInfo,
@@ -282,11 +284,11 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         file_request_list = []
         all_requests = []
         request_ids = []
-        for line in lines:
+        for line_id, line in enumerate(lines):
             request_data = json.loads(line)
             file_request_list.append(request_data)
             body = request_data["body"]
-            request_ids.append(request_data["custom_id"])
+            request_ids.append(f"{batch_id}-req_{line_id}")

             # Although streaming is supported for standalone completions, it is not supported in
             # batch mode (multiple completions in single request).
@@ -436,15 +438,9 @@ async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
         with open(input_file_path, "r", encoding="utf-8") as f:
             lines = f.readlines()

-        file_request_list = []
-        request_ids = []
-        for line in lines:
-            request_data = json.loads(line)
-            file_request_list.append(request_data)
-            request_ids.append(request_data["custom_id"])
-
         # Cancel requests by request_ids
-        for rid in request_ids:
+        for line_id in range(len(lines)):
+            rid = f"{batch_id}-req_{line_id}"
             tokenizer_manager.abort_request(rid=rid)

         retrieve_batch = batch_storage[batch_id]
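The batch endpoints now derive request ids from the batch id and the line index instead of the client-supplied custom_id, so cancellation can rebuild the exact ids without re-parsing the input file. A minimal sketch of the scheme (the batch id and input lines below are placeholders, not from the diff):

```python
# Sketch of the new per-batch request-id scheme.
batch_id = "batch_abc123"
lines = ['{"custom_id": "a", "body": {}}', '{"custom_id": "b", "body": {}}']

request_ids = [f"{batch_id}-req_{line_id}" for line_id in range(len(lines))]
assert request_ids == ["batch_abc123-req_0", "batch_abc123-req_1"]

# cancel_batch can regenerate the same ids from the line count alone,
# then pass each one to tokenizer_manager.abort_request(rid=rid).
```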
@@ -824,13 +820,13 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                 )

             final_usage_chunk = CompletionStreamResponse(
-                id=str(uuid.uuid4().hex),
+                id=content["meta_info"]["id"],
                 choices=[],
                 model=request.model,
                 usage=usage,
             )
             final_usage_data = final_usage_chunk.model_dump_json(
-                exclude_unset=True, exclude_none=True
+                exclude_none=True
             )
             yield f"data: {final_usage_data}\n\n"
         except ValueError as e:
@@ -1119,27 +1115,29 @@ def v1_chat_generate_response(
         else:
             reasoning_text = None

-        if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
-            if finish_reason == "stop":
-                finish_reason = "tool_calls"
-            try:
-                parser = FunctionCallParser(tools, tool_call_parser)
-                full_normal_text, call_info_list = parser.parse_non_stream(text)
-                tool_calls = [
-                    ToolCall(
-                        id=str(call_info.tool_index),
-                        function=FunctionResponse(
-                            name=call_info.name, arguments=call_info.parameters
-                        ),
+        if tool_choice != "none" and tools:
+            parser = FunctionCallParser(tools, tool_call_parser)
+            if parser.has_tool_call(text):
+                if finish_reason["type"] == "stop":
+                    finish_reason["type"] = "tool_calls"
+                    finish_reason["matched"] = None
+                try:
+                    full_normal_text, call_info_list = parser.parse_non_stream(text)
+                    tool_calls = [
+                        ToolCall(
+                            id=str(call_info.tool_index),
+                            function=FunctionResponse(
+                                name=call_info.name, arguments=call_info.parameters
+                            ),
+                        )
+                        for call_info in call_info_list
+                    ]
+                except Exception as e:
+                    logger.error(f"Exception: {e}")
+                    return create_error_response(
+                        HTTPStatus.BAD_REQUEST,
+                        "Failed to parse fc related info to json format!",
                     )
-                    for call_info in call_info_list
-                ]
-            except Exception as e:
-                logger.error(f"Exception: {e}")
-                return create_error_response(
-                    HTTPStatus.BAD_REQUEST,
-                    "Failed to parse fc related info to json format!",
-                )

         if to_file:
             # to make the choice data json serializable
@@ -1151,7 +1149,7 @@
                     "tool_calls": tool_calls,
                     "reasoning_content": reasoning_text,
                 },
-                "logprobs": choice_logprobs,
+                "logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
                 "finish_reason": (finish_reason["type"] if finish_reason else ""),
                 "matched_stop": (
                     finish_reason["matched"]
@@ -1499,13 +1497,13 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 )

             final_usage_chunk = ChatCompletionStreamResponse(
-                id=str(uuid.uuid4().hex),
+                id=content["meta_info"]["id"],
                 choices=[],
                 model=request.model,
                 usage=usage,
             )
             final_usage_data = final_usage_chunk.model_dump_json(
-                exclude_unset=True, exclude_none=True
+                exclude_none=True
             )
             yield f"data: {final_usage_data}\n\n"
         except ValueError as e:
@@ -1556,11 +1554,37 @@ def v1_embedding_request(all_requests, tokenizer_manager):
         prompt = prompts[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
             prompt_kwargs = {"text": prompt}
+        elif isinstance(prompt, list) and isinstance(
+            prompt[0], MultimodalEmbeddingInput
+        ):
+            assert (
+                chat_template_name is not None
+            ), "chat_template_name is required for multimodal inputs"
+            texts = []
+            images = []
+            for item in prompt:
+                texts.append(item.text if item.text is not None else None)
+                images.append(item.image if item.image is not None else None)
+            convs = generate_embedding_convs(texts, images, chat_template_name)
+            generate_prompts = []
+            for conv in convs:
+                generate_prompts.append(conv.get_prompt())
+            if len(generate_prompts) == 1:
+                prompt_kwargs = {"text": generate_prompts[0], "image_data": images[0]}
+            else:
+                prompt_kwargs = {"text": generate_prompts, "image_data": images}
         else:
             prompt_kwargs = {"input_ids": prompt}
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
+        elif isinstance(prompts[0], list) and isinstance(
+            prompts[0][0], MultimodalEmbeddingInput
+        ):
+            # TODO: multiple requests
+            raise NotImplementedError(
+                "Multiple requests with multimodal inputs are not supported yet"
+            )
         else:
             prompt_kwargs = {"input_ids": prompts}

sglang/srt/openai_api/protocol.py CHANGED
@@ -403,10 +403,17 @@ class ChatCompletionStreamResponse(BaseModel):
     usage: Optional[UsageInfo] = None


+class MultimodalEmbeddingInput(BaseModel):
+    text: Optional[str] = None
+    image: Optional[str] = None
+
+
 class EmbeddingRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings/create
-    input: Union[List[int], List[List[int]], str, List[str]]
+    input: Union[
+        List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
+    ]
     model: str
     encoding_format: str = "float"
     dimensions: int = None
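With the widened `input` union, an embeddings request can now carry text plus an image reference per item. A hedged sketch of a request body that would validate against the new schema (model name and image URL are placeholders; per the `v1_embedding_request` hunk above, the server must have a chat template configured for multimodal inputs):

```python
# Illustrative /v1/embeddings payload; each element maps to MultimodalEmbeddingInput.
payload = {
    "model": "placeholder-embedding-model",
    "input": [
        {"text": "a photo of a cat", "image": "https://example.com/cat.jpg"},
    ],
}
# Plain strings and token-id lists remain valid values for "input".
```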
sglang/srt/sampling/penaltylib/frequency_penalty.py CHANGED
@@ -56,7 +56,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
         ]

     def _merge(self, their: "BatchedFrequencyPenalizer"):
-        print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}")
         self.frequency_penalties = torch.cat(
             [self.frequency_penalties, their.frequency_penalties], dim=0
         )
sglang/srt/sampling/penaltylib/presence_penalty.py CHANGED
@@ -56,7 +56,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
         ]

     def _merge(self, their: "BatchedPresencePenalizer"):
-        print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}")
         self.presence_penalties = torch.cat(
             [self.presence_penalties, their.presence_penalties], dim=0
         )
sglang/srt/server_args.py CHANGED
@@ -20,14 +20,13 @@ import random
 import tempfile
 from typing import List, Optional

-import torch
-
 from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
+    is_cuda,
     is_flashinfer_available,
     is_hip,
     is_port_available,
@@ -71,6 +70,7 @@ class ServerArgs:
     schedule_policy: str = "fcfs"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
+    page_size: int = 1

     # Other runtime options
     tp_size: int = 1
@@ -190,10 +190,10 @@
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)

-        if is_hip():
-            gpu_mem = get_amdgpu_memory_capacity()
-        elif torch.cuda.is_available():
+        if is_cuda():
             gpu_mem = get_nvgpu_memory_capacity()
+        elif is_hip():
+            gpu_mem = get_amdgpu_memory_capacity()
         elif self.device == "hpu":
             gpu_mem = get_hpu_memory_capacity()
         else:
@@ -220,6 +220,8 @@
         else:
             self.chunked_prefill_size = 8192

+        assert self.chunked_prefill_size % self.page_size == 0
+
         # Set cuda graph max batch size
         if self.cuda_graph_max_bs is None:
             # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
@@ -258,16 +260,16 @@
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )

-        # Others
+        # Data parallelism attention
         if self.enable_dp_attention:
-            self.dp_size = self.tp_size
-            assert self.tp_size % self.dp_size == 0
-            self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
+            assert (
+                self.dp_size > 1
+            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
+            assert self.tp_size % self.dp_size == 0
+            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
-                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
-                "Data parallel size is adjusted to be the same as tensor parallel size. "
             )

         # Speculative Decoding
@@ -278,10 +280,10 @@
         if self.speculative_algorithm == "EAGLE":
             if self.max_running_requests is None:
                 self.max_running_requests = 32
-            self.disable_overlap_schedule = True
             self.disable_cuda_graph_padding = True
+            self.disable_overlap_schedule = True
             logger.info(
-                "Overlap scheduler are disabled because of using "
+                "Overlap scheduler is disabled because of using "
                 "eagle speculative decoding."
             )
             # The token generated from the verify step is counted.
@@ -405,6 +407,7 @@
             "gguf",
             "modelopt",
             "w8a8_int8",
+            "w8a8_fp8",
         ],
         help="The quantization method.",
     )
@@ -479,7 +482,7 @@
         "--chunked-prefill-size",
         type=int,
         default=ServerArgs.chunked_prefill_size,
-        help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
+        help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
     )
     parser.add_argument(
         "--max-prefill-tokens",
@@ -504,7 +507,13 @@
         "--cpu-offload-gb",
         type=int,
         default=ServerArgs.cpu_offload_gb,
-        help="How many GBs of RAM to reserve for CPU offloading",
+        help="How many GBs of RAM to reserve for CPU offloading.",
+    )
+    parser.add_argument(
+        "--page-size",
+        type=int,
+        default=ServerArgs.page_size,
+        help="The number of tokens in a page.",
     )

     # Other runtime options
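The new `--page-size` flag sets the number of tokens per KV-cache page (default 1), and `ServerArgs` now asserts that the chunked prefill size is a multiple of it. A quick sketch of the constraint, restating the assert and defaults from the hunks above (the launch command in the comment is illustrative):

```python
# Values mirror the defaults shown in this diff; the check restates the new assert.
page_size = 1                 # --page-size
chunked_prefill_size = 8192   # --chunked-prefill-size default in the else branch above
assert chunked_prefill_size % page_size == 0

# e.g. (illustrative) launching with a larger page size:
#   python -m sglang.launch_server --model-path <model> --page-size 16
```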
@@ -764,7 +773,6 @@
         "--speculative-eagle-topk",
         type=int,
         help="The number of tokens sampled from the draft model in eagle2 each step.",
-        choices=[1, 2, 4, 8],
         default=ServerArgs.speculative_eagle_topk,
     )
     parser.add_argument(
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -7,6 +7,7 @@ import torch
 from huggingface_hub import snapshot_download

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
@@ -122,6 +123,16 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
+        elif self.server_args.attention_backend == "flashinfer_mla":
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
+                self.model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
         else:
             raise ValueError(
                 f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
@@ -302,13 +313,10 @@

             # Set inputs
             forward_batch.input_ids = input_ids
+            out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
             forward_batch.out_cache_loc = out_cache_loc[
-                forward_batch.batch_size
-                * self.topk
-                * i : forward_batch.batch_size
-                * self.topk
-                * (i + 1)
-            ]
+                :, self.topk * i : self.topk * (i + 1)
+            ].flatten()
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
             spec_info.hidden_states = hidden_states
@@ -353,42 +361,70 @@
         batch.spec_info = res.draft_input

         if batch.return_logprob:
-            # Compute output logprobs using the sampler.
-            num_tokens_per_req = [
-                accept + 1 for accept in res.accept_length_per_req_cpu
-            ]
-            self.target_worker.model_runner.update_output_logprobs(
-                logits_output,
-                batch.sampling_info,
-                batch.top_logprobs_nums,
-                batch.token_ids_logprobs,
-                res.verified_id,
-                # +1 for bonus token.
-                num_tokens_per_req=num_tokens_per_req,
-            )
-
-            # Add output logprobs to the request.
-            pt = 0
-            # NOTE: tolist() of these values are skipped when output is processed
-            next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
-            verified_ids = res.verified_id.tolist()
-            for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
-                for _ in range(num_tokens):
-                    if req.return_logprob:
-                        token_id = verified_ids[pt]
-                        req.output_token_logprobs_val.append(next_token_logprobs[pt])
-                        req.output_token_logprobs_idx.append(token_id)
-                        if req.top_logprobs_num > 0:
-                            req.output_top_logprobs_val.append(
-                                res.logits_output.next_token_top_logprobs_val[pt]
-                            )
-                            req.output_top_logprobs_idx.append(
-                                res.logits_output.next_token_top_logprobs_idx[pt]
-                            )
-                    pt += 1
+            self.add_logprob_values(batch, res, logits_output)

         return logits_output, res, model_worker_batch

+    def add_logprob_values(
+        self,
+        batch: ScheduleBatch,
+        res: EagleVerifyOutput,
+        logits_output: LogitsProcessorOutput,
+    ):
+        # Extract args
+        logits_output = res.logits_output
+        top_logprobs_nums = batch.top_logprobs_nums
+        token_ids_logprobs = batch.token_ids_logprobs
+        logprobs = torch.nn.functional.log_softmax(
+            logits_output.next_token_logits, dim=-1
+        )
+        batch_next_token_ids = res.verified_id
+        num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
+
+        # We should repeat top_logprobs_nums to match num_tokens_per_req.
+        top_logprobs_nums_repeat_interleaved = []
+        token_ids_logprobs_repeat_interleaved = []
+        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
+            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
+        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
+            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
+
+        # Extract logprobs
+        if any(x > 0 for x in top_logprobs_nums):
+            (
+                logits_output.next_token_top_logprobs_val,
+                logits_output.next_token_top_logprobs_idx,
+            ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)
+
+        if any(x is not None for x in token_ids_logprobs):
+            (
+                logits_output.next_token_token_ids_logprobs_val,
+                logits_output.next_token_token_ids_logprobs_idx,
+            ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)
+
+        logits_output.next_token_logprobs = logprobs[
+            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
+            batch_next_token_ids,
+        ]
+
+        # Add output logprobs to the request.
+        pt = 0
+        next_token_logprobs = logits_output.next_token_logprobs.tolist()
+        verified_ids = batch_next_token_ids.tolist()
+        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+            for _ in range(num_tokens):
+                if req.return_logprob:
+                    req.output_token_logprobs_val.append(next_token_logprobs[pt])
+                    req.output_token_logprobs_idx.append(verified_ids[pt])
+                    if req.top_logprobs_num > 0:
+                        req.output_top_logprobs_val.append(
+                            res.logits_output.next_token_top_logprobs_val[pt]
+                        )
+                        req.output_top_logprobs_idx.append(
+                            res.logits_output.next_token_top_logprobs_idx[pt]
+                        )
+                pt += 1
+
     def forward_draft_extend(
         self,
         batch: ScheduleBatch,
sglang/srt/utils.py CHANGED
@@ -14,6 +14,7 @@
 """Common utilities."""

 import base64
+import builtins
 import ctypes
 import dataclasses
 import io
@@ -37,6 +38,7 @@ import time
 import warnings
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
+from importlib.util import find_spec
 from io import BytesIO
 from multiprocessing import Pool
 from multiprocessing.reduction import ForkingPickler
@@ -52,11 +54,13 @@ import triton
 import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
+from packaging.version import Version, parse
 from starlette.routing import Mount
 from torch import nn
 from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
+from torch.utils.cpp_extension import CUDA_HOME
 from triton.runtime.cache import (
     FileCacheManager,
     default_cache_dir,
@@ -69,14 +73,31 @@ logger = logging.getLogger(__name__)
 show_time_cost = False
 time_infos = {}

+HIP_FP8_E4M3_FNUZ_MAX = 224.0

+
+# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
 def is_hip() -> bool:
-    """Return whether it is HIP on the AMD ROCm platform."""
     return torch.version.hip is not None


+if is_hip():
+    FP8_E4M3_MAX = HIP_FP8_E4M3_FNUZ_MAX
+else:
+    FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+
+FP8_E4M3_MIN = -FP8_E4M3_MAX
+
+builtins.FP8_E4M3_MAX = FP8_E4M3_MAX
+builtins.FP8_E4M3_MIN = FP8_E4M3_MIN
+
+
+def is_rocm() -> bool:
+    return torch.cuda.is_available() and torch.version.hip
+
+
 def is_cuda():
-    return hasattr(torch, "cuda") and torch.version.cuda is not None
+    return torch.cuda.is_available() and torch.version.cuda


 def is_cuda_alike():
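The module now publishes platform-dependent FP8 E4M3 bounds (224.0 on HIP, where the fnuz variant saturates earlier, versus `torch.finfo(torch.float8_e4m3fn).max` on CUDA) and exposes them through `builtins` for downstream quantization code. A hedged illustration of how such bounds are typically used before an FP8 cast; this is not taken verbatim from sglang's kernels:

```python
import torch

# On CUDA builds torch.finfo(torch.float8_e4m3fn).max is 448.0; the diff pins 224.0 on HIP.
FP8_E4M3_MAX = float(torch.finfo(torch.float8_e4m3fn).max)
FP8_E4M3_MIN = -FP8_E4M3_MAX

x = torch.randn(4, 4) * 1000.0
# Clamp into the representable range before casting so out-of-range values saturate cleanly.
x_fp8 = x.clamp(FP8_E4M3_MIN, FP8_E4M3_MAX).to(torch.float8_e4m3fn)
```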
@@ -98,11 +119,11 @@ def is_flashinfer_available():
     """
     if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
         return False
-    return torch.cuda.is_available() and torch.version.cuda
+    return is_cuda()


 def is_cuda_available():
-    return torch.cuda.is_available() and torch.version.cuda
+    return is_cuda()


 def enable_show_time_cost():
@@ -1045,6 +1066,65 @@ def get_device_name(device_id: int = 0) -> str:
         return torch.hpu.get_device_name(device_id)


+@lru_cache(maxsize=1)
+def is_habana_available() -> bool:
+    return find_spec("habana_frameworks") is not None
+
+
+@lru_cache(maxsize=8)
+def get_device(device_id: Optional[int] = None) -> str:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        if device_id is None:
+            return "cuda"
+        return "cuda:{}".format(device_id)
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        if device_id == None:
+            return "xpu"
+        return "xpu:{}".format(device_id)
+
+    if is_habana_available():
+        try:
+            import habana_frameworks.torch.hpu
+
+            if torch.hpu.is_available():
+                if device_id == None:
+                    return "hpu"
+                return "hpu:{}".format(device_id)
+        except ImportError as e:
+            raise ImportError(
+                "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
+            )
+
+    raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")
+
+
+@lru_cache(maxsize=1)
+def get_device_count() -> int:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        try:
+            return torch.cuda.device_count()
+        except RuntimeError:
+            return 0
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        try:
+            return torch.xpu.device_count()
+        except RuntimeError:
+            return 0
+
+    if is_habana_available():
+        try:
+            import habana_frameworks.torch.hpu
+
+            if torch.hpu.is_available():
+                return torch.hpu.device_count()
+        except (ImportError, RuntimeError):
+            return 0
+
+    return 0  # No accelerators available
+
+
 def get_device_core_count(device_id: int = 0) -> int:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_properties(device_id).multi_processor_count
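The new `get_device` / `get_device_count` helpers resolve the accelerator string across CUDA, XPU, and HPU. A small usage sketch (assumes sglang 0.4.4.post1 is installed; the CPU fallback is this example's own choice, since `get_device()` itself raises when no accelerator is present):

```python
from sglang.srt.utils import get_device, get_device_count

if get_device_count() > 0:
    device = get_device(0)  # e.g. "cuda:0", "xpu:0", or "hpu:0"
else:
    device = "cpu"  # fallback chosen here; get_device() would raise RuntimeError
print(device)
```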
@@ -1063,11 +1143,12 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
         )
         major, minor = int(major), int(minor)

-    # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
-    # Update this once the support is available.
     if hasattr(torch, "hpu") and torch.hpu.is_available():
         try:
-            major, minor = torch.hpu.get_device_capability(device_id)
+            # TODO(HandH1998): `get_device_capability` is not supported by `torch.hpu` for now.
+            # Update this once the support is available.
+            # major, minor = torch.hpu.get_device_capability(device_id)
+            major, minor = None, None
         except Exception as e:
             raise RuntimeError(
                 f"An error occurred while getting device capability of hpu: {e}."
@@ -1269,7 +1350,8 @@ def permute_weight(x: torch.Tensor) -> torch.Tensor:
     elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
         x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
     else:
-        return x_
+        # return x_
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 8), 2, 4)

     x_ = x_.permute(0, 1, 3, 4, 2, 5)
     x_ = x_.contiguous()
@@ -1341,7 +1423,7 @@ def kill_itself_when_parent_died():
         libc = ctypes.CDLL("libc.so.6")
         libc.prctl(PR_SET_PDEATHSIG, signal.SIGKILL)
     else:
-        logger.warninig("kill_itself_when_parent_died is only supported in linux.")
+        logger.warning("kill_itself_when_parent_died is only supported in linux.")


 def set_uvicorn_logging_configs():
@@ -1430,6 +1512,12 @@ def rank0_print(msg: str):
         print(msg, flush=True)


+def get_cuda_version():
+    if torch.version.cuda:
+        return tuple(map(int, torch.version.cuda.split(".")))
+    return (0, 0)
+
+
 def launch_dummy_health_check_server(host, port):
     import uvicorn
     from fastapi import FastAPI, Response
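`get_cuda_version()` returns the toolkit version reported by torch as an integer tuple, or (0, 0) on non-CUDA builds, so version gates become plain tuple comparisons. A sketch (the 12.1 threshold is illustrative, not from the diff):

```python
from sglang.srt.utils import get_cuda_version

if get_cuda_version() >= (12, 1):
    print("CUDA toolkit is new enough for this illustrative feature gate")
```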
@@ -1466,6 +1554,13 @@ def set_cuda_arch():
     os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"


+def next_power_of_2(n: int):
+    return 1 << (n - 1).bit_length() if n > 0 else 1
+
+
+setattr(triton, "next_power_of_2", next_power_of_2)
+
+
 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.
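The new `next_power_of_2` helper is also patched onto `triton.next_power_of_2`. Its behavior, checked by hand against the one-liner above:

```python
from sglang.srt.utils import next_power_of_2

assert next_power_of_2(1) == 1
assert next_power_of_2(5) == 8
assert next_power_of_2(8) == 8
assert next_power_of_2(0) == 1  # non-positive inputs fall back to 1
```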