sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/_custom_ops.py +29 -1
  3. sglang/srt/configs/internvl.py +3 -0
  4. sglang/srt/configs/model_config.py +5 -1
  5. sglang/srt/constrained/base_grammar_backend.py +10 -2
  6. sglang/srt/constrained/xgrammar_backend.py +7 -5
  7. sglang/srt/conversation.py +17 -2
  8. sglang/srt/debug_utils/__init__.py +0 -0
  9. sglang/srt/debug_utils/dump_comparator.py +131 -0
  10. sglang/srt/debug_utils/dumper.py +108 -0
  11. sglang/srt/debug_utils/text_comparator.py +172 -0
  12. sglang/srt/disaggregation/common/conn.py +34 -6
  13. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  14. sglang/srt/disaggregation/mini_lb.py +3 -2
  15. sglang/srt/disaggregation/mooncake/conn.py +65 -20
  16. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  17. sglang/srt/disaggregation/nixl/conn.py +17 -13
  18. sglang/srt/disaggregation/prefill.py +13 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  21. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  23. sglang/srt/distributed/parallel_state.py +70 -15
  24. sglang/srt/entrypoints/engine.py +5 -9
  25. sglang/srt/entrypoints/http_server.py +20 -32
  26. sglang/srt/entrypoints/openai/protocol.py +3 -3
  27. sglang/srt/entrypoints/openai/serving_chat.py +148 -72
  28. sglang/srt/function_call/base_format_detector.py +74 -12
  29. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  30. sglang/srt/function_call/ebnf_composer.py +105 -66
  31. sglang/srt/function_call/function_call_parser.py +6 -4
  32. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  33. sglang/srt/function_call/kimik2_detector.py +41 -16
  34. sglang/srt/function_call/llama32_detector.py +6 -3
  35. sglang/srt/function_call/mistral_detector.py +11 -3
  36. sglang/srt/function_call/pythonic_detector.py +16 -14
  37. sglang/srt/function_call/qwen25_detector.py +12 -3
  38. sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
  39. sglang/srt/layers/activation.py +11 -3
  40. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  41. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  42. sglang/srt/layers/attention/vision.py +56 -8
  43. sglang/srt/layers/communicator.py +12 -12
  44. sglang/srt/layers/dp_attention.py +72 -24
  45. sglang/srt/layers/layernorm.py +26 -1
  46. sglang/srt/layers/logits_processor.py +46 -25
  47. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  48. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
  51. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  52. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  53. sglang/srt/layers/moe/topk.py +88 -34
  54. sglang/srt/layers/multimodal.py +11 -8
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
  56. sglang/srt/layers/quantization/fp8.py +25 -247
  57. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  58. sglang/srt/layers/quantization/modelopt_quant.py +33 -14
  59. sglang/srt/layers/quantization/unquant.py +24 -76
  60. sglang/srt/layers/quantization/utils.py +0 -9
  61. sglang/srt/layers/quantization/w4afp8.py +68 -17
  62. sglang/srt/layers/radix_attention.py +5 -3
  63. sglang/srt/lora/lora_manager.py +133 -169
  64. sglang/srt/lora/lora_registry.py +188 -0
  65. sglang/srt/lora/mem_pool.py +2 -2
  66. sglang/srt/managers/cache_controller.py +62 -13
  67. sglang/srt/managers/io_struct.py +19 -1
  68. sglang/srt/managers/mm_utils.py +154 -35
  69. sglang/srt/managers/multimodal_processor.py +3 -14
  70. sglang/srt/managers/schedule_batch.py +27 -11
  71. sglang/srt/managers/scheduler.py +48 -26
  72. sglang/srt/managers/tokenizer_manager.py +62 -28
  73. sglang/srt/managers/tp_worker.py +5 -4
  74. sglang/srt/mem_cache/allocator.py +67 -7
  75. sglang/srt/mem_cache/hicache_storage.py +17 -1
  76. sglang/srt/mem_cache/hiradix_cache.py +35 -18
  77. sglang/srt/mem_cache/memory_pool_host.py +3 -0
  78. sglang/srt/model_executor/cuda_graph_runner.py +61 -25
  79. sglang/srt/model_executor/forward_batch_info.py +201 -29
  80. sglang/srt/model_executor/model_runner.py +109 -37
  81. sglang/srt/models/deepseek_v2.py +63 -30
  82. sglang/srt/models/glm4_moe.py +1035 -0
  83. sglang/srt/models/glm4_moe_nextn.py +167 -0
  84. sglang/srt/models/interns1.py +328 -0
  85. sglang/srt/models/internvl.py +143 -47
  86. sglang/srt/models/llava.py +9 -5
  87. sglang/srt/models/minicpmo.py +4 -1
  88. sglang/srt/models/mllama4.py +10 -3
  89. sglang/srt/models/qwen2_moe.py +2 -6
  90. sglang/srt/models/qwen3_moe.py +6 -8
  91. sglang/srt/multimodal/processors/base_processor.py +20 -6
  92. sglang/srt/multimodal/processors/clip.py +2 -2
  93. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  94. sglang/srt/multimodal/processors/gemma3.py +2 -2
  95. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  96. sglang/srt/multimodal/processors/internvl.py +21 -8
  97. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  98. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  99. sglang/srt/multimodal/processors/llava.py +4 -4
  100. sglang/srt/multimodal/processors/minicpm.py +2 -3
  101. sglang/srt/multimodal/processors/mlama.py +2 -2
  102. sglang/srt/multimodal/processors/mllama4.py +18 -111
  103. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  104. sglang/srt/multimodal/processors/pixtral.py +2 -2
  105. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  106. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  107. sglang/srt/multimodal/processors/vila.py +3 -1
  108. sglang/srt/reasoning_parser.py +48 -5
  109. sglang/srt/sampling/sampling_batch_info.py +6 -5
  110. sglang/srt/server_args.py +132 -60
  111. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  112. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
  113. sglang/srt/speculative/eagle_utils.py +51 -23
  114. sglang/srt/speculative/eagle_worker.py +59 -44
  115. sglang/srt/two_batch_overlap.py +9 -5
  116. sglang/srt/utils.py +113 -69
  117. sglang/srt/weight_sync/utils.py +119 -0
  118. sglang/test/runners.py +4 -0
  119. sglang/test/test_activation.py +50 -1
  120. sglang/test/test_utils.py +65 -5
  121. sglang/utils.py +19 -0
  122. sglang/version.py +1 -1
  123. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
  124. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
  125. sglang/srt/debug_utils.py +0 -74
  126. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  127. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  128. {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/lang/chat_template.py CHANGED
@@ -448,6 +448,19 @@ register_chat_template(
     )
 )
 
+register_chat_template(
+    ChatTemplate(
+        name="interns1",
+        default_system_prompt="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
+        role_prefix_and_suffix={
+            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
+            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
+        },
+        stop_str=["<|im_end|>", "<|action_end|>"],
+    )
+)
+
 register_chat_template(
     ChatTemplate(
         name="granite-3-instruct",
@@ -609,6 +622,14 @@ def match_internvl_chat(model_path: str):
         return "internvl-2-5"
 
 
+@register_chat_template_matching_function
+def match_interns1_chat(model_path: str):
+    if re.search(r"intern-s1", model_path, re.IGNORECASE):
+        return "interns1"
+    if re.search(r"interns1", model_path, re.IGNORECASE):
+        return "interns1"
+
+
 if __name__ == "__main__":
     messages = [
         {"role": "system", "content": None},  # None means default
sglang/srt/_custom_ops.py CHANGED
@@ -1,6 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
 import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 
@@ -114,6 +114,34 @@ else:
     def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
         return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
 
+    # ROCM custom quick allreduce
+
+    def init_custom_qr(
+        rank: int, world_size: int, qr_max_size: Optional[int] = None
+    ) -> int:
+        return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size)
+
+    def qr_get_handle(fa: int) -> torch.Tensor:
+        return sgl_kernel.allreduce.qr_get_handle(fa)
+
+    def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
+        sgl_kernel.allreduce.qr_open_handles(fa, handles)
+
+    def qr_all_reduce(
+        fa: int,
+        inp: torch.Tensor,
+        out: torch.Tensor,
+        quant_level: int,
+        cast_bf2half: bool,
+    ) -> None:
+        sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)
+
+    def qr_destroy(fa: int) -> None:
+        sgl_kernel.allreduce.qr_destroy(fa)
+
+    def qr_max_size() -> int:
+        return sgl_kernel.allreduce.qr_max_size()
+
 
     def mscclpp_generate_unique_id() -> bytes:
         return sgl_kernel.allreduce.mscclpp_generate_unique_id()
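The hunk above only registers thin wrappers; sgl_kernel does the actual work. Below is a minimal lifecycle sketch of how these wrappers compose, assuming a ROCm build of sgl_kernel and an already-initialized torch.distributed process group. The handle-exchange step and the quant_level semantics are assumptions here; the real orchestration lives in the new quick_all_reduce.py communicator (file 21 in the list above):

```python
import torch
import torch.distributed as dist

from sglang.srt import _custom_ops as ops


def quick_allreduce(inp: torch.Tensor) -> torch.Tensor:
    rank, world_size = dist.get_rank(), dist.get_world_size()

    # Create the communicator; the returned int is an opaque handle ("fa").
    fa = ops.init_custom_qr(rank, world_size)

    # Exchange IPC handles so each rank can map its peers' buffers.
    # (Assumption: the real communicator may exchange handles differently.)
    local_handle = ops.qr_get_handle(fa)
    all_handles = [torch.empty_like(local_handle) for _ in range(world_size)]
    dist.all_gather(all_handles, local_handle)
    ops.qr_open_handles(fa, all_handles)

    # quant_level selects the on-the-wire quantization (0 meaning "none" is an
    # assumption); cast_bf2half asks the kernel to run bf16 inputs as fp16.
    out = torch.empty_like(inp)
    ops.qr_all_reduce(fa, inp, out, quant_level=0, cast_bf2half=False)

    ops.qr_destroy(fa)
    return out
```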
sglang/srt/configs/internvl.py CHANGED
@@ -10,6 +10,7 @@ from transformers import (
     PretrainedConfig,
     PreTrainedTokenizer,
     Qwen2Config,
+    Qwen3Config,
 )
 
 from sglang.utils import logger
@@ -314,6 +315,8 @@ class InternVLChatConfig(PretrainedConfig):
             self.llm_config = InternLM2Config(**llm_config)
         elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
             self.llm_config = Qwen2Config(**llm_config)
+        elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
+            self.llm_config = Qwen3Config(**llm_config)
         else:
             raise ValueError(
                 "Unsupported architecture: {}".format(
sglang/srt/configs/model_config.py CHANGED
@@ -127,6 +127,9 @@ class ModelConfig:
         ):
             self.hf_config.architectures[0] = "DeepseekV3ForCausalLMNextN"
 
+        if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
+            self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"
+
         if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
             self.hf_config.architectures[0] = "MiMoMTP"
         # Check model type
@@ -475,7 +478,7 @@
 
     def get_hf_eos_token_id(self) -> Optional[Set[int]]:
         eos_ids = getattr(self.hf_config, "eos_token_id", None)
-        if eos_ids:
+        if eos_ids is not None:
             # it can be either int or list of int
             eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
         if eos_ids is None:
@@ -635,6 +638,7 @@ multimodal_model_archs = [
     "Qwen2_5_VLForConditionalGeneration",
     "KimiVLForConditionalGeneration",
     "InternVLChatModel",
+    "InternS1ForConditionalGeneration",
    "Phi4MMForCausalLM",
     "VILAForConditionalGeneration",
 ]
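The `get_hf_eos_token_id` change is a truthiness bug fix: an `eos_token_id` of `0` is falsy, so the old guard skipped normalization for it. A minimal illustration:

```python
# Hypothetical config whose EOS token id is 0 (falsy as an int).
eos_ids = 0

# Old guard: `if eos_ids:` is False for 0, so the raw int was never
# normalized into a set.
assert not bool(eos_ids)

# New guard: 0 is not None, so normalization runs as intended.
if eos_ids is not None:
    eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
assert eos_ids == {0}
```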
sglang/srt/constrained/base_grammar_backend.py CHANGED
@@ -168,7 +168,10 @@ class BaseGrammarBackend:
 
 
 def create_grammar_backend(
-    server_args: ServerArgs, tokenizer, vocab_size: int
+    server_args: ServerArgs,
+    tokenizer,
+    vocab_size: int,
+    eos_token_ids: Optional[set] = None,
 ) -> Optional[BaseGrammarBackend]:
     if server_args.grammar_backend == "outlines":
         from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
@@ -180,7 +183,12 @@ def create_grammar_backend(
     elif server_args.grammar_backend == "xgrammar":
         from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
 
-        grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
+        # Convert Set[int] to List[int] if needed
+        eos_list = list(eos_token_ids) if eos_token_ids else None
+
+        grammar_backend = XGrammarGrammarBackend(
+            tokenizer, vocab_size=vocab_size, model_eos_token_ids=eos_list
+        )
     elif server_args.grammar_backend == "llguidance":
         from sglang.srt.constrained.llguidance_backend import GuidanceBackend
 
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -150,14 +150,16 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         self,
         tokenizer,
         vocab_size: int,
+        model_eos_token_ids: Optional[List[int]] = None,
     ):
         super().__init__()
 
-        if True:
-            tokenizer_info = TokenizerInfo.from_huggingface(
-                tokenizer, vocab_size=vocab_size
-            )
-            override_stop_tokens = None
+        # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
+        # This ensures consistency between what the model considers EOS and what XGrammar uses
+        tokenizer_info = TokenizerInfo.from_huggingface(
+            tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
+        )
+        override_stop_tokens = None
 
         self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
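Taken together with the base_grammar_backend.py change above, the model's EOS ids now flow from the model config into XGrammar's TokenizerInfo. A hedged end-to-end sketch; `build_backend` is a hypothetical helper, and the real call site (scheduler setup) may differ:

```python
from typing import Optional, Set

from sglang.srt.constrained.base_grammar_backend import create_grammar_backend


def build_backend(server_args, tokenizer, model_config):
    # ModelConfig.get_hf_eos_token_id() returns Optional[Set[int]];
    # create_grammar_backend converts it to a list and hands it to
    # XGrammarGrammarBackend as model_eos_token_ids.
    eos_ids: Optional[Set[int]] = model_config.get_hf_eos_token_id()
    return create_grammar_backend(
        server_args,
        tokenizer,
        vocab_size=model_config.vocab_size,  # assumption: attribute name
        eos_token_ids=eos_ids,
    )
```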
sglang/srt/conversation.py CHANGED
@@ -623,7 +623,7 @@ def generate_chat_conv(
                     real_content += content.text
                 elif content.type == "image_url":
                     # NOTE: works for llava and intervl2_5
-                    if conv.name == "internvl-2-5":
+                    if conv.name in ["internvl-2-5", "interns1"]:
                         real_content = image_token + real_content
                     else:
                         real_content += image_token
@@ -817,6 +817,19 @@ register_conv_template(
     )
 )
 
+register_conv_template(
+    Conversation(
+        name="interns1",
+        system_template="<|im_start|>system\n{system_message}",
+        system_message="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
+        roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
+        sep_style=SeparatorStyle.MPT,
+        sep="<|im_end|>\n",
+        stop_str=["<|im_end|>", "<|action_end|>"],
+        image_token="<image>",
+    )
+)
+
 # Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
 register_conv_template(
     Conversation(
@@ -984,8 +997,10 @@ register_conv_template(
 
 @register_conv_template_matching_function
 def match_internvl(model_path: str):
-    if re.search(r"internvl2_5", model_path, re.IGNORECASE):
+    if re.search(r"internvl", model_path, re.IGNORECASE):
         return "internvl-2-5"
+    if re.search(r"intern.*s1", model_path, re.IGNORECASE):
+        return "interns1"
 
 
 @register_conv_template_matching_function
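Restating the broadened matcher as a self-contained check (the model paths are illustrative examples, not from the diff):

```python
import re


def match_internvl(model_path: str):
    # Same logic as the hunk above, copied here so the asserts can run.
    if re.search(r"internvl", model_path, re.IGNORECASE):
        return "internvl-2-5"
    if re.search(r"intern.*s1", model_path, re.IGNORECASE):
        return "interns1"


# Any "internvl" path now keeps the internvl-2-5 template...
assert match_internvl("OpenGVLab/InternVL2_5-8B") == "internvl-2-5"
# ...while Intern-S1 paths fall through to the new interns1 template.
assert match_internvl("internlm/Intern-S1") == "interns1"
```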
sglang/srt/debug_utils/__init__.py ADDED
(new empty file; no content to show)
sglang/srt/debug_utils/dump_comparator.py ADDED
@@ -0,0 +1,131 @@
+import argparse
+import functools
+import re
+from pathlib import Path
+
+import polars as pl
+import torch
+
+from sglang.srt.debug_utils.dumper import get_truncated_value
+
+
+def main(args):
+    df_target = read_meta(args.target_path)
+    df_target = df_target.sort("rank", "dump_index")
+    df_target = df_target.filter(
+        (pl.col("forward_pass_id") >= args.start_id)
+        & (pl.col("forward_pass_id") <= args.end_id)
+    )
+    assert all(
+        c in df_target.columns
+        for c in ["rank", "forward_pass_id", "dump_index", "name"]
+    )
+
+    df_baseline = read_meta(args.baseline_path)
+    print("df_target", df_target)
+    print("df_baseline", df_baseline)
+
+    for row in df_target.iter_rows(named=True):
+        rows_baseline = df_baseline.filter(
+            (
+                pl.col("forward_pass_id")
+                == row["forward_pass_id"] - args.start_id + args.baseline_start_id
+            )
+            & functools.reduce(
+                lambda a, b: a & b,
+                [
+                    pl.col(col) == row[col]
+                    for col in row.keys()
+                    if col not in ["forward_pass_id", "dump_index", "filename"]
+                ],
+            )
+        )
+        assert len(rows_baseline) == 1, f"{rows_baseline=}"
+        row_baseline = rows_baseline.to_dicts()[0]
+
+        path_baseline = Path(args.baseline_path) / row_baseline["filename"]
+        path_target = Path(args.target_path) / row["filename"]
+        print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
+        check_tensor_pair(path_baseline=path_baseline, path_target=path_target)
+        print()
+
+
+def read_meta(directory):
+    directory = Path(directory)
+    assert directory.is_dir(), f"{directory=} should be a directory"
+
+    rows = []
+    for p in directory.glob("*.pt"):
+        full_kwargs = {}
+        for kv in p.stem.split("___"):
+            k, v = kv.split("=")
+            full_kwargs[k] = v
+        rows.append(
+            {
+                "filename": str(p.name),
+                **full_kwargs,
+            }
+        )
+
+    df = pl.DataFrame(rows)
+    df = df.with_columns(
+        pl.col("forward_pass_id").cast(int),
+        pl.col("rank").cast(int),
+    )
+    return df
+
+
+def check_tensor_pair(path_baseline, path_target):
+    x_baseline = torch.load(path_baseline, weights_only=True)
+    x_target = torch.load(path_target, weights_only=True)
+
+    print(
+        f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
+        f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
+    )
+
+    if x_baseline.shape != x_target.shape:
+        print(f"❌ Shape mismatch")
+        return
+
+    raw_abs_diff = (x_target - x_baseline).abs()
+
+    max_abs_diff = raw_abs_diff.max().item()
+    mean_abs_diff = raw_abs_diff.mean().item()
+    rel_diff = _calc_rel_diff(x_target, x_baseline)
+
+    needs_print = max_abs_diff > 1e-3
+
+    print(
+        "\t".join(
+            f"{'❌' if value > 1e-3 else '✅'} {name}={value}"
+            for name, value in [
+                ("rel_diff", rel_diff),
+                ("max_abs_diff", max_abs_diff),
+                ("mean_abs_diff", mean_abs_diff),
+            ]
+        )
+    )
+
+    if needs_print:
+        print(f"x_baseline(sample)={get_truncated_value(x_baseline)}")
+        print(f"x_target(sample)={get_truncated_value(x_target)}")
+
+
+# Copied from DeepGEMM
+def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
+    x, y = x.double(), y.double()
+    denominator = (x * x + y * y).sum()
+    sim = 2 * (x * y).sum() / denominator
+    return 1 - sim
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--baseline-path", type=str)
+    parser.add_argument("--target-path", type=str)
+    parser.add_argument("--start-id", type=int, default=0)
+    parser.add_argument("--end-id", type=int, default=1000000)
+    parser.add_argument("--baseline-start-id", type=int, default=0)
+    args = parser.parse_args()
+    main(args)
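For reference, `_calc_rel_diff` above computes a symmetric, cosine-style divergence (the code credits DeepGEMM). In math form:

```latex
% rel_diff as computed by _calc_rel_diff, in double precision:
\mathrm{rel\_diff}(x, y) = 1 - \frac{2 \sum_i x_i y_i}{\sum_i \left( x_i^{2} + y_i^{2} \right)}
% Equals 0 when x = y exactly, and grows as the tensors diverge
% (reaching 2 in the anti-correlated extreme y = -x).
```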
sglang/srt/debug_utils/dumper.py ADDED
@@ -0,0 +1,108 @@
+import os
+import time
+from pathlib import Path
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+
+class _Dumper:
+    """Utility to dump tensors, which can be useful when comparison checking models.
+
+    Example usage:
+    dumper.on_forward_pass_start()
+    dumper.dump("layer_start__hidden_states", hidden_states, layer_id=self.layer_id)
+
+    Import from non-SGLang system:
+    ```
+    import sys
+    sys.path.append("/YOUR_PATH/sglang/python/sglang/srt/debug_utils")
+    from dumper import dumper
+    ```
+
+    Related: `sglang.srt.debug_utils.dump_comparator` for dump comparison
+    """
+
+    def __init__(self):
+        # Do not import `sglang` to make this file standalone
+        self._enable = bool(int(os.environ.get("SGLANG_DUMPER_ENABLE", "1")))
+        self._base_dir = Path(os.environ.get("SGLANG_DUMPER_DIR", "/tmp"))
+        self._enable_write_file = bool(
+            int(os.environ.get("SGLANG_DUMPER_WRITE_FILE", "1"))
+        )
+        self._partial_name: Optional[str] = None
+        self._dump_index = 0
+        self._forward_pass_id = 0
+
+    def on_forward_pass_start(self):
+        self._forward_pass_id += 1
+        print(
+            f"[Dumper] [{time.time()}] on_forward_pass_start id={self._forward_pass_id}"
+        )
+
+    def dump(self, name, value, **kwargs):
+        if not self._enable:
+            return
+
+        assert (
+            self._forward_pass_id >= 1
+        ), "Do you forget to call `dumper.on_forward_pass_start()`?"
+        self._dump_index += 1
+
+        if self._partial_name is None:
+            self._partial_name = _get_partial_name()
+
+        rank = dist.get_rank()
+        full_kwargs = dict(
+            forward_pass_id=self._forward_pass_id,
+            rank=rank,
+            name=name,
+            dump_index=self._dump_index,
+            **kwargs,
+        )
+        full_filename = "___".join(f"{k}={v}" for k, v in full_kwargs.items()) + ".pt"
+        path = self._base_dir / f"sglang_dump_{self._partial_name}" / full_filename
+
+        sample_value = get_truncated_value(value)
+
+        print(
+            f"[Dumper] [{rank}, {time.time()}] {path} "
+            f"type={type(value)} "
+            f"shape={value.shape if isinstance(value, torch.Tensor) else None} "
+            f"dtype={value.dtype if isinstance(value, torch.Tensor) else None} "
+            f"sample_value={sample_value}"
+        )
+
+        if self._enable_write_file:
+            path.parent.mkdir(parents=True, exist_ok=True)
+            torch.save(value, str(path))
+
+
+def _get_partial_name():
+    rank = dist.get_rank()
+    object_list = [str(time.time()) if rank == 0 else None]
+    dist.broadcast_object_list(object_list, device="cuda")
+    return object_list[0]
+
+
+def get_truncated_value(value):
+    if value is None:
+        return None
+
+    if isinstance(value, tuple):
+        return [get_truncated_value(x) for x in value]
+
+    if not isinstance(value, torch.Tensor):
+        return None
+
+    if value.numel() < 200:
+        return value
+
+    slices = [
+        slice(0, 5) if dim_size > 200 else slice(None) for dim_size in value.shape
+    ]
+    return value[tuple(slices)]
+
+
+dumper = _Dumper()
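A hedged sketch of the on-disk layout `dump()` produces (the timestamp placeholder and the `layer_id` kwarg are illustrative). This `key=value___` naming scheme is exactly what `dump_comparator.read_meta` parses back into a DataFrame:

```python
# dumper.dump("layer_start__hidden_states", hidden_states, layer_id=3) on
# rank 0 of the first forward pass writes roughly:
path = (
    "/tmp/sglang_dump_<timestamp>/"  # SGLANG_DUMPER_DIR + broadcast suffix
    "forward_pass_id=1___rank=0"
    "___name=layer_start__hidden_states"
    "___dump_index=1___layer_id=3.pt"
)
# Load a dump by hand the same way dump_comparator does:
#   torch.load(path, weights_only=True)
```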
sglang/srt/debug_utils/text_comparator.py ADDED
@@ -0,0 +1,172 @@
+import argparse
+import json
+from pathlib import Path
+
+import polars as pl
+
+_DESCRIPTION = """Compare and find differences to benchmark outputs.
+
+Supported inputs:
+* The samples jsonl from `lm_eval --log_samples --output_path FOLDER_NAME`
+* The output from `gsm8k/bench_sglang.py --raw-result-file FILE_NAME` (or mmlu)
+"""
+
+
+def main(args):
+    df_input = _transform_df_input(_compute_df_raw(args))
+    assert all(
+        c in df_input.columns
+        for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
+    )
+
+    df_meta = _compute_df_meta(df_input)
+
+    df_correctness_per_trial = df_input.group_by(
+        "category", "trial_index", maintain_order=True
+    ).agg(pl.col("correct").mean())
+    df_correctness_delta = (
+        df_meta.group_by("correctness_delta").len().sort("correctness_delta")
+    )
+    df_good_to_bad = df_meta.filter(pl.col("correctness_delta") < 0)
+    df_bad_to_good = df_meta.filter(pl.col("correctness_delta") > 0)
+
+    print(f"Dump output to {args.output_path}")
+    Path(args.output_path).write_text(
+        json.dumps(
+            dict(
+                df_meta=df_meta.to_dicts(),
+                df_good_to_bad=df_good_to_bad.to_dicts(),
+                df_bad_to_good=df_bad_to_good.to_dicts(),
+            )
+        )
+    )
+
+    if not args.disable_print_details:
+        with pl.Config(
+            fmt_str_lengths=10000,
+            tbl_cols=-1,
+            tbl_rows=-1,
+            tbl_width_chars=-1,
+            tbl_formatting="UTF8_FULL",
+        ):
+            print("====== Correctness per trial ======")
+            print(df_correctness_per_trial)
+
+            print(
+                "====== Correctness Delta (-1.0 means all-right becomes all-wrong) ======"
+            )
+            print(df_correctness_delta)
+
+            for name, df in [
+                ("Good->Bad", df_good_to_bad),
+                ("Bad->Good", df_bad_to_good),
+            ]:
+                print(f"====== Concrete Examples: {name} ======")
+                print(df)
+
+
+def _compute_df_raw(args):
+    return pl.concat(
+        [
+            _read_df_raw(p, category=category, trial_index=i)
+            for category, paths in [
+                ("baseline", args.baseline_path),
+                ("target", args.target_path),
+            ]
+            for i, p in enumerate(paths)
+        ]
+    )
+
+
+def _read_df_raw(path: str, category: str, trial_index: int):
+    return pl.read_ndjson(path).with_columns(
+        category=pl.lit(category), trial_index=trial_index
+    )
+
+
+def _transform_df_input(df: pl.DataFrame):
+    if "doc_id" in df.columns:
+        print("Transform mode: lm_eval")
+
+        filter_names = df["filter"].unique(maintain_order=True).to_list()
+        if len(filter_names) > 1:
+            filter_name = filter_names[0]
+            print(f"Choose {filter_name=} among {filter_names}")
+            df = df.filter(pl.col("filter") == filter_name)
+
+        df = df.select(
+            pl.col("category"),
+            pl.col("trial_index"),
+            prompt_id=pl.col("doc_id"),
+            prompt=pl.col("arguments").struct.field("gen_args_0").struct.field("arg_0"),
+            output=pl.col("resps").list.get(0).list.get(0),
+            correct=pl.col("exact_match").cast(bool),
+        )
+
+        return df
+    elif "prompt_id" in df.columns:
+        print("Transform mode: SGLang bench")
+        return df
+    else:
+        raise Exception(f"Unknown data: {df.columns}")
+
+
+def _compute_df_meta(df_input: pl.DataFrame):
+    df_input = df_input.sort("prompt_id", "category", "trial_index")
+    df_meta = pl.DataFrame(
+        [
+            _handle_one_prompt(df_one_prompt)
+            for df_one_prompt in df_input.partition_by("prompt_id", maintain_order=True)
+        ]
+    )
+    df_meta = df_meta.with_columns(
+        correctness_delta=pl.col("correctness_target") - pl.col("correctness_baseline"),
+    )
+    df_meta = df_meta.sort("correctness_delta", "output_same_prefix_len")
+    return df_meta
+
+
+def _handle_one_prompt(df_one_prompt: pl.DataFrame):
+    assert len(set(df_one_prompt["prompt"])) == 1
+
+    df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
+    df_target = df_one_prompt.filter(pl.col("category") == "target")
+
+    outputs_baseline = df_baseline["output"].to_list()
+    outputs_target = df_target["output"].to_list()
+
+    output_same_prefix_len = max(
+        _compute_str_prefix_len(output_baseline, output_target)
+        for output_baseline in outputs_baseline
+        for output_target in outputs_target
+    )
+
+    return dict(
+        prompt_id=df_one_prompt[0, "prompt_id"],
+        correctness_baseline=df_baseline["correct"].mean(),
+        correctness_target=df_target["correct"].mean(),
+        output_same_prefix_len=output_same_prefix_len,
+        prompt=df_one_prompt[0, "prompt"],
+        outputs_baseline=outputs_baseline,
+        outputs_target=outputs_target,
+    )
+
+
+def _compute_str_prefix_len(a: str, b: str) -> int:
+    min_len = min(len(a), len(b))
+    for i in range(min_len):
+        if a[i] != b[i]:
+            return i
+    return min_len
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description=_DESCRIPTION)
+    parser.add_argument("--baseline-path", type=str, nargs="+")
+    parser.add_argument("--target-path", type=str, nargs="+")
+    parser.add_argument(
+        "--output-path", type=str, default="/tmp/text_comparator_output.json"
+    )
+    parser.add_argument("--disable-print-details", action="store_true")
+    args = parser.parse_args()
+    main(args)
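To make the expected input concrete, here is a minimal sketch of the "SGLang bench" ndjson shape that `_transform_df_input` passes through unchanged. The `category` and `trial_index` columns are attached by `_read_df_raw`, so the raw file only needs the remaining columns; file names and rows are illustrative:

```python
import json

rows = [
    {"prompt_id": 0, "prompt": "1+1=", "output": "2", "correct": True},
    {"prompt_id": 1, "prompt": "2+2=", "output": "5", "correct": False},
]
with open("/tmp/baseline.jsonl", "w") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")

# Then, e.g. (one or more files per side):
#   python -m sglang.srt.debug_utils.text_comparator \
#       --baseline-path /tmp/baseline.jsonl --target-path /tmp/target.jsonl
```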