sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in that registry.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/test/get_logits_ut.py
ADDED
@@ -0,0 +1,57 @@
+import torch
+import torch.nn as nn
+
+
+class DummyModel(nn.Module):
+    def __init__(self, d_in=2048, n_heads=128, softmax_scale=0.5):
+        super().__init__()
+        self.weights_proj = nn.Linear(d_in, 1024)
+        self.n_heads = n_heads
+        self.softmax_scale = softmax_scale
+
+    def _get_logits_head_gate_orig(self, x: torch.Tensor, q_scale: torch.Tensor):
+        weights = self.weights_proj(x)
+        weights = weights * self.n_heads**-0.5
+        q_scale = q_scale.unsqueeze(1)  # (B,1,1)
+        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
+        return weights
+
+    def _get_logits_head_gate_opt(self, x: torch.Tensor, q_scale: torch.Tensor):
+        weights = self.weights_proj(x)
+        q_scale = q_scale.unsqueeze(1)  # (B,1,1)
+        scale_const = self.n_heads**-0.5 * q_scale * self.softmax_scale  # (B,1,1)
+        weights = weights.unsqueeze(-1) * scale_const  # (B,1024,1)
+        return weights
+
+
+def main():
+    torch.manual_seed(0)
+    model = DummyModel(d_in=2048, n_heads=128, softmax_scale=0.5)
+    x = torch.randn(128, 2048)  # batch=128, d_in=2048
+    q_scale = torch.randn(128, 1)
+
+    import time
+
+    start = time.time()
+    for _ in range(1000):
+        out_orig = model._get_logits_head_gate_orig(x, q_scale)
+    print("Original version time:", time.time() - start)
+
+    start = time.time()
+    for _ in range(1000):
+        out_opt = model._get_logits_head_gate_opt(x, q_scale)
+    print("Optimized version time:", time.time() - start)
+
+    print("Difference:", (out_orig - out_opt).abs().max().item())
+    assert torch.allclose(out_orig, out_opt), "Mismatch between original and optimized"
+
+
+if __name__ == "__main__":
+    main()
+
+
+"""
+Original version time: 0.49235057830810547
+Optimized version time: 0.4087331295013428
+Difference: 1.4901161193847656e-08
+"""
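The new file is a standalone micro-benchmark: the optimized variant folds the two scalar factors (`n_heads**-0.5` and `softmax_scale`) into the small `(B,1,1)` tensor `scale_const`, avoiding one full `(B,1024)` elementwise pass, and the trailing docstring records one observed run. Assuming the wheel is installed, those numbers should be reproducible by running the module directly (it takes no arguments):

    python -m sglang.test.get_logits_ut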
sglang/test/run_eval.py
CHANGED
@@ -10,11 +10,46 @@ import time
 
 from sglang.test.simple_eval_common import (
     ChatCompletionSampler,
+    Eval,
     make_report,
     set_ulimit,
 )
 
 
+def get_thinking_kwargs(args):
+    thinking_mode = getattr(args, "thinking_mode", None)
+    if thinking_mode in THINKING_MODE_CHOICES:
+        if thinking_mode == "deepseek-v3":
+            thinking_param = "thinking"
+        else:
+            thinking_param = "enable_thinking"
+        return {
+            "chat_template_kwargs": {thinking_param: True},
+        }
+    return {}
+
+
+def run_eval_once(args, base_url: str, eval_obj: Eval) -> dict:
+    # Get thinking kwargs based on user's choice
+    thinking_kwargs = get_thinking_kwargs(args)
+
+    sampler = ChatCompletionSampler(
+        model=args.model,
+        max_tokens=getattr(args, "max_tokens", 2048),
+        base_url=base_url,
+        temperature=getattr(args, "temperature", 0.0),
+        reasoning_effort=getattr(args, "reasoning_effort", None),
+        extra_body=thinking_kwargs,
+    )
+
+    # Run eval
+    tic = time.perf_counter()
+    result = eval_obj(sampler)
+    latency = time.perf_counter() - tic
+
+    return result, latency, sampler
+
+
 def run_eval(args):
     set_ulimit()
 
@@ -60,21 +95,40 @@ def run_eval(args):
         from sglang.test.simple_eval_humaneval import HumanEval
 
         eval_obj = HumanEval(args.num_examples, args.num_threads)
+    elif args.eval_name == "mmmu":
+        # VLM MMMU evaluation with fixed 100 examples by default
+        from sglang.test.simple_eval_mmmu_vlm import MMMUVLMEval
+
+        eval_obj = MMMUVLMEval(args.num_examples, args.num_threads)
     else:
         raise ValueError(f"Invalid eval name: {args.eval_name}")
 
-    sampler = ChatCompletionSampler(
-        model=args.model,
-        max_tokens=getattr(args, "max_tokens", 2048),
-        base_url=base_url,
-        temperature=getattr(args, "temperature", 0.0),
-        reasoning_effort=getattr(args, "reasoning_effort", None),
-    )
+    if getattr(args, "repeat", 1) == 1:
+        result, latency, sampler = run_eval_once(args, base_url, eval_obj)
+    else:
+        from concurrent.futures import ThreadPoolExecutor
 
-    # Run eval
-    tic = time.perf_counter()
-    result = eval_obj(sampler)
-    latency = time.perf_counter() - tic
+        executor = ThreadPoolExecutor(max_workers=args.repeat)
+
+        futures = [
+            executor.submit(run_eval_once, args, base_url, eval_obj)
+            for _ in range(args.repeat)
+        ]
+
+        scores_repeat = []
+
+        for f in futures:
+            result, latency, sampler = f.result()
+            scores_repeat.append(result.score)
+
+        mean_score = sum(scores_repeat) / len(scores_repeat)
+        scores_repeat = [f"{s:.3f}" for s in scores_repeat]
+        print("=" * 20)
+        print(f"Repeat: {args.repeat}, mean: {mean_score:.3f}")
+        print(f"Scores: {scores_repeat}")
+        print("=" * 20)
+
+        executor.shutdown()
 
     # Dump reports
     metrics = result.metrics | {"score": result.score}
@@ -94,9 +148,13 @@ def run_eval(args):
     print(f"Total latency: {latency:.3f} s")
     print(f"Score: {metrics['score']:.3f}")
 
+    if getattr(args, "return_latency", False):
+        return metrics, latency
     return metrics
 
 
+THINKING_MODE_CHOICES = ["deepseek-r1", "deepseek-v3", "qwen3"]
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -118,12 +176,22 @@ if __name__ == "__main__":
         type=str,
         help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
     )
+    parser.add_argument(
+        "--repeat", type=int, default=1, help="repeat the evaluation n times"
+    )
     parser.add_argument("--eval-name", type=str, default="mmlu")
     parser.add_argument("--num-examples", type=int)
    parser.add_argument("--num-threads", type=int, default=512)
     parser.add_argument("--max-tokens", type=int, default=2048)
     parser.add_argument("--temperature", type=float, default=0.0)
     parser.add_argument("--reasoning-effort", type=str)
+    parser.add_argument(
+        "--thinking-mode",
+        default=None,
+        type=str,
+        choices=THINKING_MODE_CHOICES,
+        help="Enable thinking mode in Deepseek R1, V3.1/3.2, or Qwen3",
+    )
     args = parser.parse_args()
 
     run_eval(args)
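Putting the new pieces together, a hypothetical invocation of the extended script (assuming an sglang server is already up and serving the OpenAI-compatible endpoint that `ChatCompletionSampler` targets) would run the new MMMU eval four times concurrently and print the per-run scores plus their mean:

    python -m sglang.test.run_eval --eval-name mmmu --num-examples 100 \
        --repeat 4 --thinking-mode qwen3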
sglang/test/runners.py
CHANGED
@@ -30,8 +30,8 @@ from transformers import (
 )
 
 from sglang.srt.entrypoints.engine import Engine
-from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.utils import load_image
+from sglang.srt.utils.hf_transformers_utils import get_tokenizer
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
 
 DEFAULT_PROMPTS = [
sglang/test/simple_eval_common.py
CHANGED
@@ -93,6 +93,7 @@ class ChatCompletionSampler(SamplerBase):
         temperature: float = 0.0,
         reasoning_effort: Optional[str] = None,
         max_tokens: int = 2048,
+        extra_body: Optional[Dict[str, Any]] = None,
     ):
         self.client = OpenAI(base_url=base_url, http_client=LargerHttpxClient())
 
@@ -104,9 +105,10 @@ class ChatCompletionSampler(SamplerBase):
         self.temperature = temperature
         self.max_tokens = max_tokens
         self.reasoning_effort = reasoning_effort
+        self.extra_body = extra_body
         self.image_format = "url"
         print(
-            f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=}"
+            f"ChatCompletionSampler initialized with {self.system_message=} {self.temperature=} {self.max_tokens=} {self.reasoning_effort=} {self.extra_body=}"
         )
 
     def _handle_image(
@@ -136,7 +138,7 @@ class ChatCompletionSampler(SamplerBase):
             self._pack_message("system", self.system_message)
         ] + message_list
         trial = 0
-        while True:
+        while trial < 6:  # 126 seconds in total
             try:
                 response = self.client.chat.completions.create(
                     model=self.model,
@@ -144,6 +146,7 @@ class ChatCompletionSampler(SamplerBase):
                     temperature=self.temperature,
                     max_tokens=self.max_tokens,
                     reasoning_effort=self.reasoning_effort,
+                    extra_body=self.extra_body,
                 )
                 return response.choices[0].message.content
             # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
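The `extra_body` parameter added above is forwarded verbatim to `client.chat.completions.create`, which is how `run_eval_once` threads `chat_template_kwargs` through to the server. A minimal sketch for direct callers (model name and URL are hypothetical placeholders):

    from sglang.test.simple_eval_common import ChatCompletionSampler

    sampler = ChatCompletionSampler(
        model="my-model",  # hypothetical model name
        base_url="http://127.0.0.1:30000/v1",  # hypothetical endpoint
        # Rides along on every request the sampler makes:
        extra_body={"chat_template_kwargs": {"enable_thinking": True}},
    )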
sglang/test/simple_eval_mmmu_vlm.py
ADDED
@@ -0,0 +1,441 @@
+"""
+MMMU evaluation for VLMs using the run_eval simple-evals interface.
+
+"""
+
+from __future__ import annotations
+
+import base64
+import io
+from typing import List, Optional, Tuple
+
+from datasets import concatenate_datasets, load_dataset
+from PIL import Image
+
+from sglang.test import simple_eval_common as common
+from sglang.test.simple_eval_common import (
+    HTML_JINJA,
+    Eval,
+    EvalResult,
+    SamplerBase,
+    SingleEvalResult,
+    map_with_progress,
+)
+
+
+class MMMUVLMEval(Eval):
+    DOMAIN_CAT2SUB_CAT = {
+        "Art and Design": ["Art", "Art_Theory", "Design", "Music"],
+        "Business": ["Accounting", "Economics", "Finance", "Manage", "Marketing"],
+        "Science": ["Biology", "Chemistry", "Geography", "Math", "Physics"],
+        "Health and Medicine": [
+            "Basic_Medical_Science",
+            "Clinical_Medicine",
+            "Diagnostics_and_Laboratory_Medicine",
+            "Pharmacy",
+            "Public_Health",
+        ],
+        "Humanities and Social Science": [
+            "History",
+            "Literature",
+            "Sociology",
+            "Psychology",
+        ],
+        "Tech and Engineering": [
+            "Agriculture",
+            "Architecture_and_Engineering",
+            "Computer_Science",
+            "Electronics",
+            "Energy_and_Power",
+            "Materials",
+            "Mechanical_Engineering",
+        ],
+    }
+
+    def __init__(
+        self, num_examples: Optional[int] = 100, num_threads: int = 32, seed: int = 42
+    ):
+        """Create MMMU VLM eval (Math subset, 100 fixed samples by default)."""
+        self.num_examples = num_examples
+        self.num_threads = num_threads
+        self.seed = seed
+        # Prepare samples deterministically across all MMMU subjects (validation split)
+        self.samples = self._prepare_mmmu_samples(self.num_examples)
+
+    @staticmethod
+    def _to_data_uri(image: Image.Image) -> str:
+        if image.mode == "RGBA":
+            image = image.convert("RGB")
+        buf = io.BytesIO()
+        image.save(buf, format="PNG")
+        b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
+        return f"data:image/png;base64,{b64}"
+
+    @staticmethod
+    def _build_mc_mapping(options: List[str]) -> Tuple[dict, List[str]]:
+        index2ans = {}
+        all_choices = []
+        ch = ord("A")
+        for opt in options:
+            letter = chr(ch)
+            index2ans[letter] = opt
+            all_choices.append(letter)
+            ch += 1
+        return index2ans, all_choices
+
+    def _prepare_mmmu_samples(self, k: int) -> List[dict]:
+        # Subjects and domains copied from MMMU data_utils to categorize results
+        subjects: List[str] = []
+        for subs in self.DOMAIN_CAT2SUB_CAT.values():
+            subjects.extend(subs)
+
+        # Load validation split of each subject
+        datasets = []
+        for subj in subjects:
+            try:
+                d = load_dataset("MMMU/MMMU", subj, split="validation")
+                # attach subject info via transform
+                d = d.add_column("__subject__", [subj] * len(d))
+                datasets.append(d)
+            except Exception:
+                continue
+        if not datasets:
+            raise RuntimeError("Failed to load MMMU datasets")
+
+        merged = concatenate_datasets(datasets)
+
+        # Deterministic selection: sort by id (fallback to subject+index)
+        def _key(idx):
+            ex = merged[idx]
+            return str(ex.get("id", f"{ex['__subject__']}:{idx}"))
+
+        order = sorted(range(len(merged)), key=_key)
+        picked_indices = order[:k]
+
+        samples: List[dict] = []
+        for idx in picked_indices:
+            ex = merged[idx]
+            subject = ex["__subject__"]
+            image = ex.get("image_1")
+            if image is None or not hasattr(image, "convert"):
+                continue
+            data_uri = self._to_data_uri(image)
+            question = ex.get("question", "")
+            answer = ex.get("answer")
+            raw_options = ex.get("options")
+            question_type = "open"
+            index2ans = None
+            all_choices = None
+            options = None
+            if raw_options:
+                try:
+                    options = (
+                        raw_options
+                        if isinstance(raw_options, list)
+                        else list(eval(raw_options))
+                    )
+                    if isinstance(options, list) and len(options) > 0:
+                        index2ans, all_choices = self._build_mc_mapping(options)
+                        question_type = "multiple-choice"
+                except Exception:
+                    options = None
+
+            # Build final textual prompt; include choices if MC
+            prompt_text = f"Question: {question}\n\n"
+            if options:
+                letters = [chr(ord("A") + i) for i in range(len(options))]
+                for letter, opt in zip(letters, options):
+                    prompt_text += f"{letter}) {opt}\n"
+            prompt_text += "\nAnswer: "
+
+            samples.append(
+                {
+                    "id": ex.get("id", f"{subject}:{idx}"),
+                    "final_input_prompt": prompt_text,
+                    "image_data": data_uri,
+                    "answer": answer,
+                    "question_type": question_type,
+                    "index2ans": index2ans,
+                    "all_choices": all_choices,
+                    "category": subject,
+                }
+            )
+
+        return samples
+
+    @staticmethod
+    def _split_prompt_for_image(prompt: str) -> tuple[str, str]:
+        """Split a prompt containing an inline image tag into prefix and suffix.
+
+        If no tag is present, treat the whole prompt as prefix and empty suffix.
+        """
+        if "<" in prompt and ">" in prompt:
+            prefix = prompt.split("<")[0]
+            suffix = prompt.split(">", 1)[1]
+            return prefix, suffix
+        return prompt, ""
+
+    @staticmethod
+    def build_chat_messages_from_prompt(prompt: str, image_data) -> List:
+        """Split a prompt containing an inline image tag into prefix and suffix.
+
+        If no tag is present, treat the whole prompt as prefix and empty suffix.
+        """
+        # Build a vision+text message for OpenAI-compatible API
+        prefix, suffix = MMMUVLMEval._split_prompt_for_image(prompt)
+
+        content: List[dict] = []
+        if prefix:
+            content.append({"type": "text", "text": prefix})
+        content.append({"type": "image_url", "image_url": {"url": image_data}})
+        if suffix:
+            content.append({"type": "text", "text": suffix})
+        prompt_messages = [{"role": "user", "content": content}]
+
+        return prompt_messages
+
+    def __call__(self, sampler: SamplerBase) -> EvalResult:
+        def fn(sample: dict):
+            prompt = sample["final_input_prompt"]
+            image_data = sample["image_data"]
+            prompt_messages = MMMUVLMEval.build_chat_messages_from_prompt(
+                prompt, image_data
+            )
+
+            # Sample
+            response_text = sampler(prompt_messages)
+
+            # Parse and score
+            gold = sample["answer"]
+            if (
+                sample["question_type"] == "multiple-choice"
+                and sample["all_choices"]
+                and sample["index2ans"]
+            ):
+                pred = _parse_multi_choice_response(
+                    response_text, sample["all_choices"], sample["index2ans"]
+                )
+                score = 1.0 if (gold is not None and pred == gold) else 0.0
+                extracted_answer = pred
+            else:
+                parsed_list = _parse_open_response(response_text)
+                score = (
+                    1.0 if (gold is not None and _eval_open(gold, parsed_list)) else 0.0
+                )
+                extracted_answer = ", ".join(map(str, parsed_list))
+
+            html_rendered = common.jinja_env.from_string(HTML_JINJA).render(
+                prompt_messages=prompt_messages,
+                next_message=dict(content=response_text, role="assistant"),
+                score=score,
+                correct_answer=gold,
+                extracted_answer=extracted_answer,
+            )
+
+            convo = prompt_messages + [dict(content=response_text, role="assistant")]
+            return SingleEvalResult(
+                html=html_rendered,
+                score=score,
+                metrics={"__category__": sample["category"]},
+                convo=convo,
+            )
+
+        results = map_with_progress(fn, self.samples, self.num_threads)
+
+        # Build category table and overall accuracy
+        # Gather per-sample correctness and category
+        per_cat_total: dict[str, int] = {}
+        per_cat_correct: dict[str, int] = {}
+        htmls = []
+        convos = []
+        scores: List[float] = []
+        for r in results:
+            # __category__ stored under metrics
+            cat = r.metrics.get("__category__") if r.metrics else None
+            if cat is None:
+                cat = "Unknown"
+            per_cat_total[cat] = per_cat_total.get(cat, 0) + 1
+            if r.score:
+                per_cat_correct[cat] = per_cat_correct.get(cat, 0) + 1
+            htmls.append(r.html)
+            convos.append(r.convo)
+            if r.score is not None:
+                scores.append(r.score)
+
+        evaluation_result = {}
+        for cat, tot in per_cat_total.items():
+            corr = per_cat_correct.get(cat, 0)
+            acc = (corr / tot) if tot > 0 else 0.0
+            evaluation_result[cat] = {"acc": round(acc, 3), "num_example": tot}
+
+        printable_results = {}
+        # Domains first
+        for domain, cats in self.DOMAIN_CAT2SUB_CAT.items():
+            acc_sum = 0.0
+            num_sum = 0
+            for cat in cats:
+                if cat in evaluation_result:
+                    acc_sum += (
+                        evaluation_result[cat]["acc"]
+                        * evaluation_result[cat]["num_example"]
+                    )
+                    num_sum += evaluation_result[cat]["num_example"]
+            if num_sum > 0:
+                printable_results[f"Overall-{domain}"] = {
+                    "num": num_sum,
+                    "acc": round(acc_sum / num_sum, 3),
+                }
+            # add each sub-category row if present
+            for cat in cats:
+                if cat in evaluation_result:
+                    printable_results[cat] = {
+                        "num": evaluation_result[cat]["num_example"],
+                        "acc": evaluation_result[cat]["acc"],
+                    }
+
+        # Overall
+        total_num = sum(v["num_example"] for v in evaluation_result.values())
+        overall_acc = (
+            sum(v["acc"] * v["num_example"] for v in evaluation_result.values())
+            / total_num
+            if total_num > 0
+            else 0.0
+        )
+        printable_results["Overall"] = {"num": total_num, "acc": round(overall_acc, 3)}
+
+        # Build EvalResult
+        return EvalResult(
+            score=overall_acc, metrics=printable_results, htmls=htmls, convos=convos
+        )
+
+
+def _parse_multi_choice_response(
+    response: str, all_choices: List[str], index2ans: dict
+) -> str:
+    # loosely adapted from benchmark mmmu eval
+    for char in [",", ".", "!", "?", ";", ":", "'"]:
+        response = response.strip(char)
+    response = " " + response + " "
+
+    # Prefer explicit letter with bracket e.g. (A)
+    candidates: List[str] = []
+    for choice in all_choices:
+        if f"({choice})" in response:
+            candidates.append(choice)
+    if not candidates:
+        for choice in all_choices:
+            if f" {choice} " in response:
+                candidates.append(choice)
+    if not candidates and len(response.split()) > 5:
+        # try match by option text
+        for idx, ans in index2ans.items():
+            if ans and ans.lower() in response.lower():
+                candidates.append(idx)
+    if not candidates:
+        # fallback to first choice
+        return all_choices[0]
+    if len(candidates) == 1:
+        return candidates[0]
+    # choose the last occurrence
+    starts = []
+    for can in candidates:
+        pos = response.rfind(f"({can})")
+        if pos == -1:
+            pos = response.rfind(f" {can} ")
+        if pos == -1 and index2ans.get(can):
+            pos = response.lower().rfind(index2ans[can].lower())
+        starts.append(pos)
+    return candidates[int(max(range(len(starts)), key=lambda i: starts[i]))]
+
+
+def _check_is_number(s: str) -> bool:
+    try:
+        float(s.replace(",", ""))
+        return True
+    except Exception:
+        return False
+
+
+def _normalize_str(s: str):
+    s = s.strip()
+    if _check_is_number(s):
+        s = s.replace(",", "")
+        try:
+            v = round(float(s), 2)
+            return [v]
+        except Exception:
+            return [s.lower()]
+    return [s.lower()] if len(s) > 1 else [" " + s, s + " "]
+
+
+def _extract_numbers(s: str) -> List[str]:
+    import re as _re
+
+    pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
+    pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
+    pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"
+    return (
+        _re.findall(pattern_commas, s)
+        + _re.findall(pattern_scientific, s)
+        + _re.findall(pattern_simple, s)
+    )
+
+
+def _parse_open_response(response: str) -> List[str]:
+    import re as _re
+
+    def get_key_subresponses(resp: str) -> List[str]:
+        resp = resp.strip().strip(".").lower()
+        subs = _re.split(r"\.\s(?=[A-Z])|\n", resp)
+        indicators = [
+            "could be ",
+            "so ",
+            "is ",
+            "thus ",
+            "therefore ",
+            "final ",
+            "answer ",
+            "result ",
+        ]
+        keys = []
+        for i, s in enumerate(subs):
+            cands = [*indicators]
+            if i == len(subs) - 1:
+                cands.append("=")
+            shortest = None
+            for ind in cands:
+                if ind in s:
+                    part = s.split(ind)[-1].strip()
+                    if not shortest or len(part) < len(shortest):
+                        shortest = part
+            if shortest and shortest not in [":", ",", ".", "!", "?", ";", ":", "'"]:
+                keys.append(shortest)
+        return keys or [resp]
+
+    key_resps = get_key_subresponses(response)
+    pred_list = key_resps.copy()
+    for r in key_resps:
+        pred_list.extend(_extract_numbers(r))
+    out = []
+    for x in pred_list:
+        out.extend(_normalize_str(x))
+    # dedup
+    return list(dict.fromkeys(out))
+
+
+def _eval_open(gold, preds: List[str]) -> bool:
+    if isinstance(gold, list):
+        norm_answers = []
+        for ans in gold:
+            norm_answers.extend(_normalize_str(ans))
+    else:
+        norm_answers = _normalize_str(gold)
+    for p in preds:
+        if isinstance(p, str):
+            for na in norm_answers:
+                if isinstance(na, str) and na in p:
+                    return True
+        else:
+            if p in norm_answers:
+                return True
+    return False
sglang/test/test_block_fp8.py
CHANGED
@@ -621,11 +621,11 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
             w_s,
         )
 
-        from deep_gemm import
+        from deep_gemm import fp8_m_grouped_gemm_nt_masked
 
         with torch.inference_mode():
             ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
-
+            fp8_m_grouped_gemm_nt_masked(lhs, rhs, oe, masked_m, expected_m)
             out = oe[:, :M, :]
 
         self.assertTrue(