sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/utils/torch_memory_saver_adapter.py CHANGED
@@ -41,6 +41,12 @@ class TorchMemorySaverAdapter(ABC):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         raise NotImplementedError
 
+    def cuda_graph(self, **kwargs):
+        raise NotImplementedError
+
+    def disable(self):
+        raise NotImplementedError
+
     def pause(self, tag: str):
         raise NotImplementedError
 
@@ -61,6 +67,12 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         return _memory_saver.region(tag=tag, enable_cpu_backup=enable_cpu_backup)
 
+    def cuda_graph(self, **kwargs):
+        return _memory_saver.cuda_graph(**kwargs)
+
+    def disable(self):
+        return _memory_saver.disable()
+
     def pause(self, tag: str):
         return _memory_saver.pause(tag=tag)
 
@@ -81,6 +93,14 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         yield
 
+    @contextmanager
+    def cuda_graph(self, **kwargs):
+        yield
+
+    @contextmanager
+    def disable(self):
+        yield
+
     def pause(self, tag: str):
         pass
 
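Note: the hunks above add two context-manager hooks (cuda_graph, disable) to the memory-saver adapter. The sketch below only illustrates how a caller might use them; the TorchMemorySaverAdapter.create(enable=...) factory and the empty bodies are assumptions for illustration, not the actual sglang call sites in this release.

# Illustrative sketch only; create(enable=...) and the usage pattern are assumptions.
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter

adapter = TorchMemorySaverAdapter.create(enable=False)  # Noop variant: both hooks just yield

with adapter.cuda_graph():
    pass  # CUDA graph capture would run inside this region

with adapter.disable():
    pass  # memory-saver bookkeeping is bypassed here, e.g. during weight updates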
sglang/test/kits/radix_cache_server_kit.py ADDED
@@ -0,0 +1,50 @@
+import random
+
+import requests
+
+
+def gen_radix_tree(num_nodes=400, chunk_len=256):
+    num0 = num_nodes // 2
+    num1 = num_nodes - num0
+    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]
+    for _ in range(num0):
+        parent = random.choice(nodes)
+        unique_len = random.randint(0, chunk_len)
+        decode_len = random.randint(0, chunk_len)
+        token_id = random.randint(0, 32000)
+        child = {
+            "input_ids": parent["input_ids"] + [token_id] * unique_len,
+            "decode_len": decode_len,
+        }
+        nodes.append(child)
+
+    while num1 > 0:
+        num_branch = random.randint(1, min(num1, 10))
+        parent = random.choice(nodes)
+        for _ in range(num_branch):
+            unique_len = random.randint(0, chunk_len)
+            decode_len = random.randint(0, chunk_len)
+            token_id = random.randint(0, 32000)
+            child = {
+                "input_ids": parent["input_ids"] + [token_id] * unique_len,
+                "decode_len": decode_len,
+            }
+            nodes.append(child)
+
+        num1 -= num_branch
+
+    random.shuffle(nodes)
+    return nodes
+
+
+def run_radix_attention_test(base_url: str):
+    nodes = gen_radix_tree()
+    data = {
+        "input_ids": [node["input_ids"] for node in nodes],
+        "sampling_params": [
+            {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
+        ],
+    }
+
+    res = requests.post(base_url + "/generate", json=data)
+    assert res.status_code == 200
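Note: the new kit above builds a randomized radix tree of prompts and replays it through the HTTP API. A possible driver is sketched below, assuming a server is already running; the host and port are placeholders.

# Usage sketch for the new kit; host/port below are placeholders.
from sglang.test.kits.radix_cache_server_kit import gen_radix_tree, run_radix_attention_test

# Inspect the synthetic request tree the test will send.
nodes = gen_radix_tree(num_nodes=16, chunk_len=32)
print(len(nodes), "requests; first prompt length:", len(nodes[0]["input_ids"]))

# Replays the whole batch via POST /generate and asserts a 200 response,
# against a server started elsewhere (e.g. python -m sglang.launch_server --model-path <model> --port 30000).
run_radix_attention_test(base_url="http://127.0.0.1:30000")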
sglang/test/runners.py CHANGED
@@ -12,10 +12,11 @@
 # limitations under the License.
 # ==============================================================================
 
+import json
 import multiprocessing as mp
 import os
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -89,7 +90,9 @@ def get_token_ids_logprobs(logits, token_ids):
     return logprobs
 
 
-def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
+def _get_sentence_transformer_embedding_model(
+    model_path, torch_dtype, matryoshka_dim: Optional[int] = None
+):
     from sentence_transformers import SentenceTransformer
     from sentence_transformers.util import is_sentence_transformer_model
 
@@ -97,6 +100,7 @@ def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
         model = SentenceTransformer(
             model_path,
             model_kwargs={"torch_dtype": torch_dtype},
+            truncate_dim=matryoshka_dim,
         )
     else:  # if no pre-trained sentence-transformers model
         from sentence_transformers import models
@@ -106,7 +110,9 @@ def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
             word_embedding_model.get_word_embedding_dimension(),
             pooling_mode="lasttoken",
         )
-        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+        model = SentenceTransformer(
+            modules=[word_embedding_model, pooling_model], truncate_dim=matryoshka_dim
+        )
 
     return model.cuda()
 
@@ -135,6 +141,7 @@ class HFRunner:
         output_str_only: bool = False,
         trust_remote_code: bool = False,
         patch_model_do_sample_false: bool = False,
+        matryoshka_dim: Optional[int] = None,
     ):
         self.model_type = model_type
         self.output_str_only = output_str_only
@@ -151,6 +158,7 @@ class HFRunner:
                 self.out_queue,
                 model_path,
                 torch_dtype,
+                matryoshka_dim,
             ),
         )
         self.model_proc.start()
@@ -225,7 +233,14 @@ class HFRunner:
             embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
             return embeddings.contiguous()
 
-    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
+    def start_model_process(
+        self,
+        in_queue,
+        out_queue,
+        model_path,
+        torch_dtype,
+        matryoshka_dim: Optional[int] = None,
+    ):
         # Apply model-specific patches
         monkey_patch_gemma2_sdpa()
 
@@ -259,7 +274,7 @@ class HFRunner:
                 self.processor = AutoProcessor.from_pretrained(model_path)
             else:
                 self.model = _get_sentence_transformer_embedding_model(
-                    model_path, torch_dtype
+                    model_path, torch_dtype, matryoshka_dim=matryoshka_dim
                 )
         elif self.model_type == "reward" or self.model_type == "cross_encoder":
             from transformers import AutoModelForSequenceClassification
@@ -496,7 +511,7 @@ class SRTRunner:
         attention_backend: Optional[str] = None,
         prefill_attention_backend: Optional[str] = None,
         decode_attention_backend: Optional[str] = None,
-        lora_backend: str = "triton",
+        lora_backend: str = "csgmv",
         disable_cuda_graph: bool = False,
         disable_radix_cache: bool = False,
         chunked_prefill_size: Optional[int] = None,
@@ -519,6 +534,7 @@ class SRTRunner:
         lora_target_modules: Optional[List[str]] = None,
         enable_lora: Optional[bool] = None,
         max_loaded_loras: Optional[int] = None,
+        json_model_override_args: Optional[dict[str, Any]] = None,
         lora_eviction_policy: str = "lru",
     ):
         self.model_type = model_type
@@ -566,6 +582,11 @@ class SRTRunner:
             lora_target_modules=lora_target_modules,
             enable_lora=enable_lora,
             max_loaded_loras=max_loaded_loras,
+            json_model_override_args=(
+                json.dumps(json_model_override_args)
+                if json_model_override_args
+                else "{}"
+            ),
             lora_eviction_policy=lora_eviction_policy,
             **spec_kwargs,
         )
@@ -594,6 +615,7 @@ class SRTRunner:
         logprob_start_len: int = 0,
         top_k: Optional[int] = None,
         token_ids_logprob: Optional[List[int]] = None,
+        dimensions: Optional[int] = None,
     ):
         if self.is_generation:
             return self.forward_generation_raw(
@@ -607,7 +629,9 @@ class SRTRunner:
             )
         else:
             if self.model_type == "embedding":
-                response = self.engine.encode(prompt=prompts, image_data=image_data)
+                response = self.engine.encode(
+                    prompt=prompts, image_data=image_data, dimensions=dimensions
+                )
                 if isinstance(response, list):
                     logits = [x["embedding"] for x in response]
                 else:
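Note: the runners.py changes thread a Matryoshka truncation dimension through both runners (matryoshka_dim on the HF side, dimensions on the SRT side). A hedged sketch of how a test might compare the two paths follows; the model path, prompt, and context-manager usage are illustrative assumptions, not code from this diff.

# Hedged sketch; model path and prompt are placeholders.
import torch

from sglang.test.runners import HFRunner, SRTRunner

MODEL = "Qwen/Qwen3-Embedding-0.6B"  # assumed example of a Matryoshka-capable embedder
PROMPTS = ["The capital of France is Paris."]

with HFRunner(MODEL, torch_dtype=torch.float16, model_type="embedding",
              matryoshka_dim=128) as hf_runner:
    hf_out = hf_runner.forward(PROMPTS)  # HF embeddings truncated to 128 dims

with SRTRunner(MODEL, torch_dtype=torch.float16, model_type="embedding") as srt_runner:
    srt_out = srt_runner.forward(PROMPTS, dimensions=128)  # server-side truncation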
sglang/test/simple_eval_common.py CHANGED
@@ -148,7 +148,7 @@ class ChatCompletionSampler(SamplerBase):
                     reasoning_effort=self.reasoning_effort,
                     extra_body=self.extra_body,
                 )
-                return response.choices[0].message.content
+                return response.choices[0].message.content or ""
             # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
             except openai.BadRequestError as e:
                 print("Bad Request Error", e)
@@ -161,7 +161,9 @@ class ChatCompletionSampler(SamplerBase):
                 )
                 time.sleep(exception_backoff)
                 trial += 1
-        # unknown error shall throw exception
+        # If all retries are exhausted, return empty string instead of None
+        print(f"All retry attempts exhausted for request. Returning empty response.")
+        return ""
 
 
 QUERY_TEMPLATE_MULTICHOICE = """
@@ -261,7 +263,7 @@ def format_multichoice_question(row):
 def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
     prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
     response = sampler([dict(content=prompt, role="user")])
-    return response.lower().strip() == "yes"
+    return (response or "").lower().strip() == "yes"
 
 
 def _compute_stat(values: list, stat: str):
sglang/test/simple_eval_humaneval.py CHANGED
@@ -80,6 +80,7 @@ class HumanEval(Eval):
         instruction = "Read the following function signature and docstring, and fully implement the function described. Your response should only contain the code for this function.\n"
 
         def find_code(completion):
+            completion = completion or ""
             pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
             matches = pattern.findall(completion)
             extracted_answer = matches[0] if len(matches) >= 1 else completion
sglang/test/simple_eval_math.py CHANGED
@@ -54,6 +54,7 @@ class MathEval(Eval):
                 sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
             ]
             response_text = sampler(prompt_messages)
+            response_text = response_text or ""
             match = re.search(ANSWER_PATTERN, response_text)
             extracted_answer = match.group(1) if match else None
             score = float(
sglang/test/simple_eval_mmlu.py CHANGED
@@ -101,6 +101,7 @@ class MMLUEval(Eval):
                 )
             ]
             response_text = sampler(prompt_messages)
+            response_text = response_text or ""
             match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
             extracted_answer = match.group(1) if match else None
             score = 1.0 if extracted_answer == row["Answer"] else 0.0
sglang/test/simple_eval_mmmu_vlm.py CHANGED
@@ -204,6 +204,7 @@ class MMMUVLMEval(Eval):
 
             # Sample
             response_text = sampler(prompt_messages)
+            response_text = response_text or ""
 
             # Parse and score
             gold = sample["answer"]
sglang/test/test_deterministic.py CHANGED
@@ -17,7 +17,7 @@ import dataclasses
 import json
 import os
 import random
-from typing import List
+from typing import Any, Dict, List, Optional
 
 import requests
 
@@ -78,6 +78,7 @@ class BenchArgs:
                 "single",
                 "prefix",
                 "radix_cache",
+                "p_vs_d",
             ],
         )
         parser.add_argument("--profile", action="store_true")
@@ -94,18 +95,21 @@
 
 def send_single(
     args,
-    batch_size: int = 1,
     profile: bool = False,
     profile_steps: int = 3,
     profile_by_stage: bool = False,
     return_full_response: bool = False,
     input_ids: List[int] = None,
+    prompt: List[str] = None,
     max_new_tokens: int = None,
+    extra_params: Optional[Dict[str, Any]] = None,
+    pick_first_result: bool = True,
 ):
     base_url = f"http://{args.host}:{args.port}"
 
     # Use input_ids if provided, otherwise use text prompts
     if input_ids is not None:
+        assert prompt is None
        json_data = {
             "input_ids": input_ids,
             "sampling_params": {
@@ -120,9 +124,10 @@ def send_single(
             },
             "return_logprob": args.return_logprob,
             "stream": args.stream,
+            **(extra_params or {}),
         }
     else:
-        prompt = [PROMPT_1] * batch_size
+        assert input_ids is None
         json_data = {
             "text": prompt,
             "sampling_params": {
@@ -137,6 +142,7 @@ def send_single(
             },
             "return_logprob": args.return_logprob,
             "stream": args.stream,
+            **(extra_params or {}),
         }
 
     if args.sampling_seed is not None:
@@ -169,7 +175,8 @@ def send_single(
     else:
         ret = response.json()
 
-    ret = ret[0] if isinstance(ret, list) else ret
+    if pick_first_result:
+        ret = ret[0] if isinstance(ret, list) else ret
 
     if return_full_response:
         return ret
@@ -177,7 +184,9 @@ def send_single(
     return ret["text"]
 
 
-def send_prefix(args, batch_size: int, prompts: List[str]):
+def send_prefix(
+    args, batch_size: int, prompts: List[str], return_full_response: bool = False
+):
     requests.post(f"http://{args.host}:{args.port}/flush_cache")
 
     batch_data = []
@@ -212,11 +221,157 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
         print(ret)
         return -1, -1, -1
 
-    ret_dict = {i: [] for i in range(len(prompts))}
-    for i in range(batch_size):
-        ret_dict[sampled_indices[i]].append(ret[i]["text"])
+    if return_full_response:
+        # Return full responses grouped by prompt index
+        ret_dict = {i: [] for i in range(len(prompts))}
+        for i in range(batch_size):
+            ret_dict[sampled_indices[i]].append(ret[i])
+        return ret_dict
+    else:
+        # Return only text grouped by prompt index
+        ret_dict = {i: [] for i in range(len(prompts))}
+        for i in range(batch_size):
+            ret_dict[sampled_indices[i]].append(ret[i]["text"])
+        return ret_dict
+
+
+def compare_logprobs(logprobs1, logprobs2, tolerance=0):
+    """Compare two logprobs sequences with a tolerance."""
+    if len(logprobs1) != len(logprobs2):
+        return False, f"Length mismatch: {len(logprobs1)} vs {len(logprobs2)}"
+
+    for i, (lp1, lp2) in enumerate(zip(logprobs1, logprobs2)):
+        # Each element is [logprob, token_id]
+        if lp1[1] != lp2[1]:
+            return False, f"Token ID mismatch at position {i}: {lp1[1]} vs {lp2[1]}"
+        if abs(lp1[0] - lp2[0]) > tolerance:
+            return (
+                False,
+                f"Logprob mismatch at position {i}: {lp1[0]} vs {lp2[0]} (diff: {abs(lp1[0] - lp2[0])})",
+            )
+
+    return True, "Logprobs match"
+
 
-    return ret_dict
+def _test_mode_p_vs_d(args, batch_size):
+    print()
+    print(f"Execute: test p_vs_d {batch_size=}")
+
+    random.seed(42)
+    args.return_logprob = True
+    query_extra_params = {
+        "logprob_start_len": 0,
+        "return_text_in_logprobs": True,
+    }
+
+    def _create_prompts():
+        ans = [PROMPT_1, PROMPT_2]
+        for i in range(batch_size - len(ans)):
+            end = random.randrange(1, 4096)
+            if random.random() < 0.5:
+                begin = 0
+            else:
+                begin = random.randrange(0, end)
+            ans.append(LONG_PROMPT[begin:end])
+        return ans[:batch_size]
+
+    # warmup + flush
+    send_single(args, input_ids=[1] * 64, max_new_tokens=65, return_full_response=True)
+    requests.post(f"http://{args.host}:{args.port}/flush_cache")
+
+    prompts = _create_prompts()
+
+    resp_a = send_single(
+        args,
+        prompt=prompts,
+        max_new_tokens=args.max_new_tokens,
+        return_full_response=True,
+        pick_first_result=False,
+        extra_params=query_extra_params,
+    )
+    info_a = _extract_ids_and_logprobs(resp_a)
+
+    requests.post(f"http://{args.host}:{args.port}/flush_cache")
+
+    resp_b = send_single(
+        args,
+        input_ids=[x["io"].token_ids for x in info_a],
+        max_new_tokens=1,
+        return_full_response=True,
+        pick_first_result=False,
+        extra_params=query_extra_params,
+    )
+    info_b = _extract_ids_and_logprobs(resp_b)
+
+    ans = []
+    for i, (info_a_item, info_b_item) in enumerate(zip(info_a, info_b, strict=True)):
+        print(f"Compare sequence {i} in batch...")
+        correct = TokenIdsAndLogprobs.compare(info_a_item["io"], info_b_item["input"])
+        ans.append(int(correct))
+
+    return ans
+
+
+@dataclasses.dataclass
+class TokenIdsAndLogprobs:
+    token_ids: List[int]
+    logprobs: List[float]
+
+    def __add__(self, other):
+        return TokenIdsAndLogprobs(
+            token_ids=self.token_ids + other.token_ids,
+            logprobs=self.logprobs + other.logprobs,
+        )
+
+    @classmethod
+    def compare(cls, a: "TokenIdsAndLogprobs", b: "TokenIdsAndLogprobs"):
+        assert len(a.token_ids) == len(b.token_ids)
+        token_match = a.token_ids == b.token_ids
+        logprobs_match = a.logprobs == b.logprobs
+
+        if token_match:
+            print(f"Token match: {a.token_ids}")
+        else:
+            print(f"❗Token mismatch: {a.token_ids=} {b.token_ids=}")
+
+        if logprobs_match:
+            print(f"Logprobs match:", a.logprobs)
+        else:
+            print(f"❗Logprobs mismatch")
+            print(
+                "  A: ",
+                [f"{x:.10f}" if x is not None else "None" for x in a.logprobs],
+            )
+            print(
+                "  B: ",
+                [f"{x:.10f}" if x is not None else "None" for x in b.logprobs],
+            )
+            diff = [
+                abs(x - y) if x is not None else float("nan")
+                for x, y in zip(a.logprobs, b.logprobs)
+            ]
+            print("  Diff:", [f"{x:.10e}" for x in diff])
+
+        return token_match and logprobs_match
+
+
+def _extract_ids_and_logprobs(responses):
+    def _extract_part(response, name):
+        token_ids, logprobs = [], []
+        for item in response["meta_info"][name]:
+            logprob, token_id, text = item
+            token_ids.append(token_id)
+            logprobs.append(logprob)
+        return TokenIdsAndLogprobs(token_ids=token_ids, logprobs=logprobs)
+
+    def _extract_one_response(response):
+        input = _extract_part(response, "input_token_logprobs")
+        output = _extract_part(response, "output_token_logprobs")
+        return dict(input=input, output=output, io=input + output)
+
+    if not isinstance(responses, list):
+        responses = [responses]
+    return [_extract_one_response(x) for x in responses]
 
 
 def test_deterministic(args):
@@ -225,7 +380,7 @@ def test_deterministic(args):
         texts = []
         for i in range(1, args.n_trials + 1):
             batch_size = i
-            text = send_single(args, batch_size, args.profile)
+            text = send_single(args, args.profile, prompt=[PROMPT_1] * batch_size)
             text = text.replace("\n", " ")
             print(f"Trial {i} with batch size {batch_size}: {text}")
             texts.append(text)
@@ -238,15 +393,28 @@
         num_prompts = len(len_prefix)
         outputs = {i: [] for i in range(4)}
         prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)]
+
+        # If return_logprob is enabled, store full responses for comparison
+        if args.return_logprob:
+            full_responses = {i: [] for i in range(4)}
+
         for i in range(args.n_start, args.n_start + args.n_trials):
             batch_size = i
-            ret_dict = send_prefix(args, batch_size, prompts)
+            ret_dict = send_prefix(
+                args, batch_size, prompts, return_full_response=args.return_logprob
+            )
             msg = f"Testing Trial {i} with batch size {batch_size},"
             for i in range(num_prompts):
                 msg += f" # prefix length {len_prefix[i]}: {len(ret_dict[i])},"
             print(msg)
             for i in range(num_prompts):
-                outputs[i].extend(ret_dict[i])
+                if args.return_logprob:
+                    # Store full response for logprob comparison
+                    full_responses[i].extend(ret_dict[i])
+                    # Extract text for determinism check
+                    outputs[i].extend([resp["text"] for resp in ret_dict[i]])
+                else:
+                    outputs[i].extend(ret_dict[i])
 
         for i in range(num_prompts):
             print(
@@ -256,6 +424,54 @@
         results = []
         for i in range(num_prompts):
             results.append(len(set(outputs[i])))
+
+        # If logprobs are enabled, compare them across different batch sizes
+        if args.return_logprob:
+            print(f"\n{'='*60}")
+            print("Logprobs Comparison Across Batch Sizes")
+            print("=" * 60)
+
+            logprob_results = []
+            for prompt_idx in range(num_prompts):
+                print(
+                    f"\nPrompt {prompt_idx} (prefix length {len_prefix[prompt_idx]}):"
+                )
+                responses = full_responses[prompt_idx]
+
+                if len(responses) < 2:
+                    continue
+
+                # Compare all responses against the first one
+                reference = responses[0]
+                all_match = True
+                mismatches = []
+
+                for j, resp in enumerate(responses[1:], start=1):
+                    ref_logprobs = reference["meta_info"]["output_token_logprobs"]
+                    resp_logprobs = resp["meta_info"]["output_token_logprobs"]
+
+                    match, msg = compare_logprobs(ref_logprobs, resp_logprobs)
+
+                    if not match:
+                        print(f"  ✗ Sample {j+1}: {msg}")
+                        mismatches.append((j + 1, msg))
+                        all_match = False
+
+                if all_match:
+                    print(f"  ✓ All {len(responses)} samples have identical logprobs")
+                    logprob_results.append(1)
+                else:
+                    print(
+                        f"  ✗ Found {len(mismatches)} mismatches out of {len(responses)} samples"
+                    )
+                    logprob_results.append(0)
+
+            print(f"\n{'='*60}")
+            if all(r == 1 for r in logprob_results):
+                print("✓✓✓ Logprobs are identical across all batch sizes! ✓✓✓")
+            else:
+                print("✗✗✗ Some logprobs differ across batch sizes! ✗✗✗")
+
         return results
 
     elif args.test_mode == "radix_cache":
@@ -415,6 +631,13 @@ def test_deterministic(args):
             print("✗✗✗ TEST FAILED - Radix cache produces different results! ✗✗✗")
             return [0]
 
+    elif args.test_mode == "p_vs_d":
+        # TODO also extract other modes to functions
+        ans = []
+        for i in range(1, args.n_trials + 1):
+            ans += _test_mode_p_vs_d(args, batch_size=i)
+        return ans
+
     else:
         raise ValueError(f"Invalid test mode: {args.test_mode}")
 
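Note: the new p_vs_d mode above replays the tokens produced by a decode run as a single prefill request and checks that the token ids and logprobs agree. A rough sketch of driving it programmatically is shown below; the host/port are placeholders, and normally this runs through the script's command-line interface.

# Rough usage sketch for the new prefill-vs-decode check (host/port are placeholders).
from sglang.test.test_deterministic import BenchArgs, test_deterministic

args = BenchArgs()
args.host, args.port = "127.0.0.1", 30000
args.test_mode = "p_vs_d"
args.n_trials = 3  # batch sizes 1..3

results = test_deterministic(args)  # one 0/1 entry per compared sequence
assert all(r == 1 for r in results), "prefill and decode disagree"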
sglang/test/test_deterministic_utils.py CHANGED
@@ -60,7 +60,7 @@ class TestDeterministicBase(CustomTestCase):
         for result in results:
             assert result == 1
 
-    def test_prefix(self):
+    def test_prefix_with_logprobs(self):
         args = BenchArgs()
         url = DEFAULT_URL_FOR_TEST
         args.host, args.port = self._extract_host_and_port(url)
@@ -68,6 +68,7 @@ class TestDeterministicBase(CustomTestCase):
         args.n_start = 10
         args.n_trials = 10
         args.temperature = 0.5  # test for deterministic sampling
+        args.return_logprob = True  # Enable logprobs comparison
         results = test_deterministic(args)
         for result in results:
             assert result == 1
sglang/test/test_utils.py CHANGED
@@ -84,6 +84,8 @@ DEFAULT_MODEL_NAME_FOR_TEST_AWQ_INT4 = (
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
 DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct"
+DEFAULT_EAGLE_DP_ATTENTION_TARGET_MODEL_FOR_TEST = "Qwen/Qwen3-30B-A3B"
+DEFAULT_EAGLE_DP_ATTENTION_DRAFT_MODEL_FOR_TEST = "Tengyunw/qwen3_30b_moe_eagle3"
 DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B"
 DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
     "meta-llama/Llama-3.1-8B-Instruct"
@@ -92,6 +94,10 @@ DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-I
 DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"
 
 # Other use cases
+DEFAULT_AUTOROUND_MODEL_NAME_FOR_TEST = (
+    "OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc",  # auto_round:auto_gptq
+    "Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound",  # auto_round:auto_awq
+)
 DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
     "meta-llama/Llama-4-Scout-17B-16E-Instruct"
 )
@@ -145,7 +151,7 @@ def _use_cached_default_models(model_repo: str):
 
 if is_in_ci():
     DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
-        10000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 1000
+        10000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 2000
     )
 else:
     DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
sglang/version.py CHANGED
@@ -1 +1 @@
-__version__ = "0.5.4"
+__version__ = "0.5.4.post2"