sglang 0.4.3.post3__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/bench_serving.py +2 -2
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/hf_transformers_utils.py +16 -1
  14. sglang/srt/layers/attention/flashinfer_backend.py +95 -49
  15. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  16. sglang/srt/layers/attention/triton_backend.py +5 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  18. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  19. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  20. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  21. sglang/srt/layers/attention/vision.py +43 -62
  22. sglang/srt/layers/linear.py +1 -1
  23. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  24. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  33. sglang/srt/layers/parameter.py +10 -0
  34. sglang/srt/layers/quantization/__init__.py +90 -68
  35. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  36. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/fp8.py +174 -106
  63. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  64. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  65. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  66. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  67. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  68. sglang/srt/layers/rotary_embedding.py +5 -3
  69. sglang/srt/layers/sampler.py +29 -35
  70. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  71. sglang/srt/lora/backend/__init__.py +9 -12
  72. sglang/srt/managers/cache_controller.py +72 -8
  73. sglang/srt/managers/image_processor.py +37 -631
  74. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  75. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  76. sglang/srt/managers/image_processors/llava.py +152 -0
  77. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  78. sglang/srt/managers/image_processors/mlama.py +60 -0
  79. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  80. sglang/srt/managers/io_struct.py +33 -15
  81. sglang/srt/managers/multi_modality_padding.py +134 -0
  82. sglang/srt/managers/schedule_batch.py +212 -117
  83. sglang/srt/managers/schedule_policy.py +40 -8
  84. sglang/srt/managers/scheduler.py +258 -782
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
  86. sglang/srt/managers/tokenizer_manager.py +7 -6
  87. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  88. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  89. sglang/srt/mem_cache/chunk_cache.py +12 -44
  90. sglang/srt/mem_cache/hiradix_cache.py +63 -34
  91. sglang/srt/mem_cache/memory_pool.py +112 -46
  92. sglang/srt/mem_cache/paged_allocator.py +283 -0
  93. sglang/srt/mem_cache/radix_cache.py +117 -36
  94. sglang/srt/metrics/collector.py +8 -0
  95. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  96. sglang/srt/model_executor/forward_batch_info.py +12 -8
  97. sglang/srt/model_executor/model_runner.py +153 -134
  98. sglang/srt/model_loader/loader.py +2 -1
  99. sglang/srt/model_loader/weight_utils.py +1 -1
  100. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  101. sglang/srt/models/deepseek_nextn.py +23 -3
  102. sglang/srt/models/deepseek_v2.py +25 -19
  103. sglang/srt/models/minicpmv.py +28 -89
  104. sglang/srt/models/mllama.py +1 -1
  105. sglang/srt/models/qwen2.py +0 -1
  106. sglang/srt/models/qwen2_5_vl.py +25 -50
  107. sglang/srt/models/qwen2_vl.py +33 -49
  108. sglang/srt/openai_api/adapter.py +37 -15
  109. sglang/srt/openai_api/protocol.py +8 -1
  110. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  111. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  112. sglang/srt/server_args.py +19 -20
  113. sglang/srt/speculative/build_eagle_tree.py +6 -1
  114. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
  115. sglang/srt/speculative/eagle_utils.py +2 -1
  116. sglang/srt/speculative/eagle_worker.py +109 -38
  117. sglang/srt/utils.py +104 -9
  118. sglang/test/runners.py +104 -10
  119. sglang/test/test_block_fp8.py +106 -16
  120. sglang/test/test_custom_ops.py +88 -0
  121. sglang/test/test_utils.py +20 -4
  122. sglang/utils.py +0 -4
  123. sglang/version.py +1 -1
  124. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -9
  125. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/RECORD +128 -83
  126. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
  127. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py CHANGED
@@ -38,6 +38,7 @@ from sglang.srt.conversation import (
     SeparatorStyle,
     chat_template_exists,
     generate_chat_conv,
+    generate_embedding_convs,
     register_conv_template,
 )
 from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
@@ -68,6 +69,7 @@ from sglang.srt.openai_api.protocol import (
     FileResponse,
     FunctionResponse,
     LogProbs,
+    MultimodalEmbeddingInput,
     ToolCall,
     TopLogprob,
     UsageInfo,
@@ -282,11 +284,11 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         file_request_list = []
         all_requests = []
         request_ids = []
-        for line in lines:
+        for line_id, line in enumerate(lines):
            request_data = json.loads(line)
            file_request_list.append(request_data)
            body = request_data["body"]
-           request_ids.append(request_data["custom_id"])
+           request_ids.append(f"{batch_id}-req_{line_id}")

         # Although streaming is supported for standalone completions, it is not supported in
         # batch mode (multiple completions in single request).
@@ -436,15 +438,9 @@ async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
         with open(input_file_path, "r", encoding="utf-8") as f:
             lines = f.readlines()

-        file_request_list = []
-        request_ids = []
-        for line in lines:
-            request_data = json.loads(line)
-            file_request_list.append(request_data)
-            request_ids.append(request_data["custom_id"])
-
         # Cancel requests by request_ids
-        for rid in request_ids:
+        for line_id in range(len(lines)):
+            rid = f"{batch_id}-req_{line_id}"
             tokenizer_manager.abort_request(rid=rid)

         retrieve_batch = batch_storage[batch_id]
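A side effect of this change: batch request ids are now derived from the batch id and the line index rather than the user-supplied `custom_id`, so `cancel_batch` can rebuild them without re-parsing the input file. A small illustrative sketch (the values are made up):

batch_id = "batch_abc123"  # hypothetical batch id
lines = ['{"custom_id": "a", "body": {}}', '{"custom_id": "b", "body": {}}']

request_ids = [f"{batch_id}-req_{line_id}" for line_id, _ in enumerate(lines)]
print(request_ids)  # ['batch_abc123-req_0', 'batch_abc123-req_1']

# Cancellation only needs the line count to reconstruct the same ids:
for line_id in range(len(lines)):
    rid = f"{batch_id}-req_{line_id}"
    # tokenizer_manager.abort_request(rid=rid)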
@@ -824,13 +820,13 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                 )

                 final_usage_chunk = CompletionStreamResponse(
-                    id=str(uuid.uuid4().hex),
+                    id=content["meta_info"]["id"],
                     choices=[],
                     model=request.model,
                     usage=usage,
                 )
                 final_usage_data = final_usage_chunk.model_dump_json(
-                    exclude_unset=True, exclude_none=True
+                    exclude_none=True
                 )
                 yield f"data: {final_usage_data}\n\n"
         except ValueError as e:
@@ -1151,7 +1147,7 @@ def v1_chat_generate_response(
                 "tool_calls": tool_calls,
                 "reasoning_content": reasoning_text,
             },
-            "logprobs": choice_logprobs,
+            "logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
            "finish_reason": (finish_reason["type"] if finish_reason else ""),
            "matched_stop": (
                finish_reason["matched"]
@@ -1499,13 +1495,13 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                 )

                 final_usage_chunk = ChatCompletionStreamResponse(
-                    id=str(uuid.uuid4().hex),
+                    id=content["meta_info"]["id"],
                     choices=[],
                     model=request.model,
                     usage=usage,
                 )
                 final_usage_data = final_usage_chunk.model_dump_json(
-                    exclude_unset=True, exclude_none=True
+                    exclude_none=True
                 )
                 yield f"data: {final_usage_data}\n\n"
         except ValueError as e:
@@ -1556,11 +1552,37 @@ def v1_embedding_request(all_requests, tokenizer_manager):
         prompt = prompts[0]
         if isinstance(prompt, str) or isinstance(prompt[0], str):
             prompt_kwargs = {"text": prompt}
+        elif isinstance(prompt, list) and isinstance(
+            prompt[0], MultimodalEmbeddingInput
+        ):
+            assert (
+                chat_template_name is not None
+            ), "chat_template_name is required for multimodal inputs"
+            texts = []
+            images = []
+            for item in prompt:
+                texts.append(item.text if item.text is not None else None)
+                images.append(item.image if item.image is not None else None)
+            convs = generate_embedding_convs(texts, images, chat_template_name)
+            generate_prompts = []
+            for conv in convs:
+                generate_prompts.append(conv.get_prompt())
+            if len(generate_prompts) == 1:
+                prompt_kwargs = {"text": generate_prompts[0], "image_data": images[0]}
+            else:
+                prompt_kwargs = {"text": generate_prompts, "image_data": images}
         else:
             prompt_kwargs = {"input_ids": prompt}
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
+        elif isinstance(prompts[0], list) and isinstance(
+            prompts[0][0], MultimodalEmbeddingInput
+        ):
+            # TODO: multiple requests
+            raise NotImplementedError(
+                "Multiple requests with multimodal inputs are not supported yet"
+            )
         else:
             prompt_kwargs = {"input_ids": prompts}

sglang/srt/openai_api/protocol.py CHANGED
@@ -403,10 +403,17 @@ class ChatCompletionStreamResponse(BaseModel):
     usage: Optional[UsageInfo] = None


+class MultimodalEmbeddingInput(BaseModel):
+    text: Optional[str] = None
+    image: Optional[str] = None
+
+
 class EmbeddingRequest(BaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings/create
-    input: Union[List[int], List[List[int]], str, List[str]]
+    input: Union[
+        List[int], List[List[int]], str, List[str], List[MultimodalEmbeddingInput]
+    ]
     model: str
     encoding_format: str = "float"
     dimensions: int = None
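Together with the adapter change above, `EmbeddingRequest.input` can now carry a list of `MultimodalEmbeddingInput` items. A minimal sketch of such a request against a locally running OpenAI-compatible endpoint; the URL, model name, and image string are placeholders, and the server must be launched with a multimodal embedding model and a registered chat template for this path to be taken:

import requests

payload = {
    "model": "my-embedding-model",  # placeholder model name
    "input": [
        {"text": "a photo of a cat", "image": "<base64-encoded image>"},
    ],
}
resp = requests.post("http://localhost:30000/v1/embeddings", json=payload)
print(resp.json())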
sglang/srt/sampling/penaltylib/frequency_penalty.py CHANGED
@@ -56,7 +56,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
         ]

     def _merge(self, their: "BatchedFrequencyPenalizer"):
-        print(f"{self.frequency_penalties.shape=}, {their.frequency_penalties.shape=}")
         self.frequency_penalties = torch.cat(
             [self.frequency_penalties, their.frequency_penalties], dim=0
         )
sglang/srt/sampling/penaltylib/presence_penalty.py CHANGED
@@ -56,7 +56,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
         ]

     def _merge(self, their: "BatchedPresencePenalizer"):
-        print(f"{self.presence_penalties.shape=}, {their.presence_penalties.shape=}")
         self.presence_penalties = torch.cat(
             [self.presence_penalties, their.presence_penalties], dim=0
         )
sglang/srt/server_args.py CHANGED
@@ -20,14 +20,13 @@ import random
 import tempfile
 from typing import List, Optional

-import torch
-
 from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
+    is_cuda,
     is_flashinfer_available,
     is_hip,
     is_port_available,
@@ -71,7 +70,7 @@ class ServerArgs:
     schedule_policy: str = "fcfs"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
-    prefill_only_one_req: bool = False
+    page_size: int = 1

     # Other runtime options
     tp_size: int = 1
@@ -191,10 +190,10 @@ class ServerArgs:
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)

-        if is_hip():
-            gpu_mem = get_amdgpu_memory_capacity()
-        elif torch.cuda.is_available():
+        if is_cuda():
             gpu_mem = get_nvgpu_memory_capacity()
+        elif is_hip():
+            gpu_mem = get_amdgpu_memory_capacity()
         elif self.device == "hpu":
             gpu_mem = get_hpu_memory_capacity()
         else:
@@ -221,6 +220,8 @@ class ServerArgs:
         else:
             self.chunked_prefill_size = 8192

+        assert self.chunked_prefill_size % self.page_size == 0
+
         # Set cuda graph max batch size
         if self.cuda_graph_max_bs is None:
             # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
@@ -259,7 +260,7 @@ class ServerArgs:
                 f"EP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )

-        # Others
+        # Data parallelism attention
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
             assert self.tp_size % self.dp_size == 0
@@ -277,19 +278,17 @@ class ServerArgs:
             self.speculative_algorithm = "EAGLE"

         if self.speculative_algorithm == "EAGLE":
-            self.disable_overlap_schedule = True
-            self.prefill_only_one_req = True
-            self.disable_cuda_graph_padding = True
             if self.max_running_requests is None:
                 self.max_running_requests = 32
+            self.disable_cuda_graph_padding = True
+            self.disable_overlap_schedule = True
             logger.info(
-                "Overlap scheduler are disabled because of using "
+                "Overlap scheduler is disabled because of using "
                 "eagle speculative decoding."
-                "Max running request set to 32 because of using eagle speculative decoding."
             )
             # The token generated from the verify step is counted.
             # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
-            assert self.speculative_num_steps < self.speculative_num_draft_tokens
+            # assert self.speculative_num_steps < self.speculative_num_draft_tokens

         # GGUF
         if (
@@ -408,6 +407,7 @@ class ServerArgs:
                 "gguf",
                 "modelopt",
                 "w8a8_int8",
+                "w8a8_fp8",
             ],
             help="The quantization method.",
         )
@@ -482,7 +482,7 @@ class ServerArgs:
             "--chunked-prefill-size",
             type=int,
             default=ServerArgs.chunked_prefill_size,
-            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill",
+            help="The maximum number of tokens in a chunk for the chunked prefill. Setting this to -1 means disabling chunked prefill.",
         )
         parser.add_argument(
             "--max-prefill-tokens",
@@ -507,13 +507,13 @@ class ServerArgs:
             "--cpu-offload-gb",
             type=int,
             default=ServerArgs.cpu_offload_gb,
-            help="How many GBs of RAM to reserve for CPU offloading",
+            help="How many GBs of RAM to reserve for CPU offloading.",
         )
         parser.add_argument(
-            "--prefill-only-one-req",
-            type=bool,
-            help="If true, we only prefill one request at one prefill batch",
-            default=ServerArgs.prefill_only_one_req,
+            "--page-size",
+            type=int,
+            default=ServerArgs.page_size,
+            help="The number of tokens in a page.",
         )

         # Other runtime options
@@ -773,7 +773,6 @@ class ServerArgs:
             "--speculative-eagle-topk",
             type=int,
             help="The number of tokens sampled from the draft model in eagle2 each step.",
-            choices=[1, 2, 4, 8],
             default=ServerArgs.speculative_eagle_topk,
         )
         parser.add_argument(
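The new `--page-size` option (replacing `--prefill-only-one-req`) and the added `w8a8_fp8` quantization choice surface in the offline engine as well. A minimal sketch, assuming `sglang.Engine` keeps forwarding keyword arguments to `ServerArgs` as in earlier releases; the model path is a placeholder, and note that `chunked_prefill_size` must be divisible by `page_size`, which the new assertion enforces:

import sglang as sgl

# Hypothetical launch: page_size and quantization map onto the new ServerArgs fields.
llm = sgl.Engine(
    model_path="path/to/model",   # placeholder
    page_size=16,                 # tokens per KV-cache page (default is 1)
    quantization="w8a8_fp8",      # newly accepted choice in 0.4.4
)
print(llm.generate("Hello", {"max_new_tokens": 8}))
llm.shutdown()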
sglang/srt/speculative/build_eagle_tree.py CHANGED
@@ -26,7 +26,12 @@ def build_tree_kernel_efficient_preprocess(

     draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
     draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
-    parent_list = torch.cat(parents_list[:-1], dim=1)
+
+    if len(parents_list) > 1:
+        parent_list = torch.cat(parents_list[:-1], dim=1)
+    else:
+        batch_size = parents_list[0].shape[0]
+        parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)

     return parent_list, top_scores_index, draft_tokens
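The guard covers the single-step case: `parents_list[:-1]` is empty when only one parents tensor exists, and `torch.cat` on an empty list raises. A minimal standalone sketch of the two branches (shapes are illustrative):

import torch

parents_list = [torch.zeros(2, 4, dtype=torch.long)]  # batch_size=2, one step only

# torch.cat(parents_list[:-1], dim=1) would raise: cat expects a non-empty list.
if len(parents_list) > 1:
    parent_list = torch.cat(parents_list[:-1], dim=1)
else:
    batch_size = parents_list[0].shape[0]
    parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)

print(parent_list.shape)  # torch.Size([2, 0])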
 
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py CHANGED
@@ -1,7 +1,6 @@
 from __future__ import annotations

 import bisect
-import time
 from typing import TYPE_CHECKING, Callable

 import torch
@@ -162,20 +161,11 @@ class EAGLEDraftCudaGraphRunner:

         run_once()

-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         with torch.cuda.graph(
             graph, pool=get_global_graph_memory_pool(), stream=stream
         ):
             out = run_once()

-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         set_global_graph_memory_pool(graph.pool())
         return graph, out

@@ -204,7 +194,7 @@ class EAGLEDraftCudaGraphRunner:

         # Attention backend
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
-            forward_batch
+            forward_batch, forward_batch.batch_size
         )

         # Replay
sglang/srt/speculative/eagle_utils.py CHANGED
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List
+from typing import TYPE_CHECKING, List

 import torch
 import torch.nn.functional as F
@@ -62,6 +62,7 @@ class EagleDraftInput:
             batch.input_ids[pt : pt + extend_len] = torch.concat(
                 (input_ids[1:], self.verified_id[i].reshape(1))
             )
+            pt += extend_len

     def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
         assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
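The added `pt += extend_len` advances the write offset after each request; without it, every iteration writes over the start of `batch.input_ids`. A minimal standalone sketch of the pattern (tensor sizes are illustrative):

import torch

input_ids = torch.zeros(6, dtype=torch.long)  # flattened batch buffer
per_req_chunks = [torch.tensor([1, 2]), torch.tensor([3, 4]), torch.tensor([5, 6])]

pt = 0
for chunk in per_req_chunks:
    extend_len = chunk.numel()
    input_ids[pt : pt + extend_len] = chunk
    pt += extend_len  # without this line, each chunk overwrites positions [0:2]

print(input_ids.tolist())  # [1, 2, 3, 4, 5, 6]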
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -1,20 +1,20 @@
 import logging
 import os
 import time
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple

 import torch
 from huggingface_hub import snapshot_download

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
@@ -27,7 +27,6 @@ from sglang.srt.speculative.eagle_utils import (
     fast_topk,
     select_top_k_tokens,
 )
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import get_available_gpu_memory

 logger = logging.getLogger(__name__)
@@ -44,16 +43,30 @@ class EAGLEWorker(TpModelWorker):
         nccl_port: int,
         target_worker: TpModelWorker,
     ):
+        # Parse arguments
+        self.server_args = server_args
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.padded_static_len = self.speculative_num_steps + 1
+        self.enable_nan_detection = server_args.enable_nan_detection
+        self.gpu_id = gpu_id
+        self.device = server_args.device
+        self.target_worker = target_worker
+
         # Override context length with target model's context length
         server_args.context_length = target_worker.model_runner.model_config.context_len
-        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"

         # Do not capture cuda graph in `super().__init__()`
-        # We will capture it later
+        # It will be captured later.
         backup_disable_cuda_graph = server_args.disable_cuda_graph
         server_args.disable_cuda_graph = True
+        # Share the allocator with a target worker.
+        # Draft and target worker own their own KV cache pools.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )

-        # Lossy optimization by using hot tokens
+        # Load hot token ids
         if server_args.speculative_token_map is not None:
             self.hot_token_id = load_token_map(server_args.speculative_token_map)
             server_args.json_model_override_args = (
@@ -62,13 +75,7 @@ class EAGLEWorker(TpModelWorker):
         else:
             self.hot_token_id = None

-        # We share the allocator with a target worker. Draft/target worker
-        # owns its own KV cache.
-        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
-            target_worker.get_memory_pool()
-        )
-
-        # Init target worker
+        # Init draft worker
         super().__init__(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -79,18 +86,6 @@ class EAGLEWorker(TpModelWorker):
             req_to_token_pool=self.req_to_token_pool,
             token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
         )
-        self.target_worker = target_worker
-
-        # Parse arguments
-        self.topk = server_args.speculative_eagle_topk
-        self.speculative_num_steps = server_args.speculative_num_steps
-        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
-            server_args.speculative_algorithm
-        )
-        self.server_args = server_args
-        self.use_nan_detection = self.server_args.enable_nan_detection
-        self.device = self.model_runner.device
-        self.gpu_id = self.model_runner.gpu_id

         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -103,8 +98,12 @@ class EAGLEWorker(TpModelWorker):
             backup_disable_cuda_graph
         )

+        self.init_attention_backend()
+        self.init_cuda_graphs()
+
+    def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
-        if server_args.attention_backend == "flashinfer":
+        if self.server_args.attention_backend == "flashinfer":
             from sglang.srt.layers.attention.flashinfer_backend import (
                 FlashInferMultiStepDraftBackend,
             )
@@ -114,7 +113,7 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-        elif server_args.attention_backend == "triton":
+        elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
                 TritonMultiStepDraftBackend,
             )
@@ -124,13 +123,21 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
+        elif self.server_args.attention_backend == "flashinfer_mla":
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
+                self.model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
         else:
             raise ValueError(
-                f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
+                f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
             )
-
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
-        self.init_cuda_graphs()

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
@@ -306,13 +313,10 @@ class EAGLEWorker(TpModelWorker):

             # Set inputs
             forward_batch.input_ids = input_ids
+            out_cache_loc = out_cache_loc.view(forward_batch.batch_size, -1)
             forward_batch.out_cache_loc = out_cache_loc[
-                forward_batch.batch_size
-                * self.topk
-                * i : forward_batch.batch_size
-                * self.topk
-                * (i + 1)
-            ]
+                :, self.topk * i : self.topk * (i + 1)
+            ].flatten()
             forward_batch.positions.add_(1)
             forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
             spec_info.hidden_states = hidden_states
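The draft loop now slices `out_cache_loc` per request rather than per step: the buffer is viewed as `[batch_size, num_steps * topk]` and, for step `i`, the `topk` slots of every request are gathered. A minimal sketch with illustrative sizes, assuming the buffer is laid out request-major (each request's `num_steps * topk` slots contiguous):

import torch

batch_size, topk, num_steps = 2, 2, 3
# Flat cache locations; each row of the view holds one request's slots.
out_cache_loc = torch.arange(batch_size * topk * num_steps)

out_cache_loc = out_cache_loc.view(batch_size, -1)  # shape [2, 6]
i = 1                                               # second draft step
step_loc = out_cache_loc[:, topk * i : topk * (i + 1)].flatten()
print(step_loc.tolist())  # [2, 3, 8, 9] -> step-1 slots of request 0 and request 1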
@@ -356,8 +360,71 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input

+        if batch.return_logprob:
+            self.add_logprob_values(batch, res, logits_output)
+
         return logits_output, res, model_worker_batch

+    def add_logprob_values(
+        self,
+        batch: ScheduleBatch,
+        res: EagleVerifyOutput,
+        logits_output: LogitsProcessorOutput,
+    ):
+        # Extract args
+        logits_output = res.logits_output
+        top_logprobs_nums = batch.top_logprobs_nums
+        token_ids_logprobs = batch.token_ids_logprobs
+        logprobs = torch.nn.functional.log_softmax(
+            logits_output.next_token_logits, dim=-1
+        )
+        batch_next_token_ids = res.verified_id
+        num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
+
+        # We should repeat top_logprobs_nums to match num_tokens_per_req.
+        top_logprobs_nums_repeat_interleaved = []
+        token_ids_logprobs_repeat_interleaved = []
+        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
+            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
+        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
+            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
+
+        # Extract logprobs
+        if any(x > 0 for x in top_logprobs_nums):
+            (
+                logits_output.next_token_top_logprobs_val,
+                logits_output.next_token_top_logprobs_idx,
+            ) = get_top_logprobs(logprobs, top_logprobs_nums_repeat_interleaved)
+
+        if any(x is not None for x in token_ids_logprobs):
+            (
+                logits_output.next_token_token_ids_logprobs_val,
+                logits_output.next_token_token_ids_logprobs_idx,
+            ) = get_token_ids_logprobs(logprobs, token_ids_logprobs_repeat_interleaved)
+
+        logits_output.next_token_logprobs = logprobs[
+            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
+            batch_next_token_ids,
+        ]
+
+        # Add output logprobs to the request.
+        pt = 0
+        next_token_logprobs = logits_output.next_token_logprobs.tolist()
+        verified_ids = batch_next_token_ids.tolist()
+        for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+            for _ in range(num_tokens):
+                if req.return_logprob:
+                    req.output_token_logprobs_val.append(next_token_logprobs[pt])
+                    req.output_token_logprobs_idx.append(verified_ids[pt])
+                    if req.top_logprobs_num > 0:
+                        req.output_top_logprobs_val.append(
+                            res.logits_output.next_token_top_logprobs_val[pt]
+                        )
+                        req.output_top_logprobs_idx.append(
+                            res.logits_output.next_token_top_logprobs_idx[pt]
+                        )
+                pt += 1
+
     def forward_draft_extend(
         self,
         batch: ScheduleBatch,
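Because the verify step can accept several draft tokens per request, the per-request `top_logprobs_nums` must be repeated once per emitted token before calling `get_top_logprobs`. A small worked example of that expansion (values are illustrative):

top_logprobs_nums = [2, 0, 5]            # one entry per request
accept_length_per_req_cpu = [3, 0, 1]    # accepted draft tokens per request
num_tokens_per_req = [a + 1 for a in accept_length_per_req_cpu]  # verify adds one token

expanded = []
for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
    expanded.extend([num] * num_tokens)

print(num_tokens_per_req)  # [4, 1, 2]
print(expanded)            # [2, 2, 2, 2, 0, 5, 5]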
@@ -381,6 +448,7 @@ class EAGLEWorker(TpModelWorker):
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
+        forward_batch.return_logprob = False
         logits_output = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
@@ -393,6 +461,8 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         # We don't need logprob for this extend.
+        original_return_logprob = batch.return_logprob
+        batch.return_logprob = False
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -404,6 +474,7 @@ class EAGLEWorker(TpModelWorker):

         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+        batch.return_logprob = original_return_logprob
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup

@@ -415,7 +486,7 @@ class EAGLEWorker(TpModelWorker):
         draft_input.hidden_states = logits_output.hidden_states

     def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
-        if self.use_nan_detection:
+        if self.enable_nan_detection:
             logits = logits_output.next_token_logits
             if torch.any(torch.isnan(logits)):
                 logger.warning("Detected errors during sampling! NaN in the logits.")