sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/utils/torch_memory_saver_adapter.py
CHANGED
@@ -41,6 +41,12 @@ class TorchMemorySaverAdapter(ABC):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         raise NotImplementedError
 
+    def cuda_graph(self, **kwargs):
+        raise NotImplementedError
+
+    def disable(self):
+        raise NotImplementedError
+
     def pause(self, tag: str):
         raise NotImplementedError
 
@@ -61,6 +67,12 @@ class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         return _memory_saver.region(tag=tag, enable_cpu_backup=enable_cpu_backup)
 
+    def cuda_graph(self, **kwargs):
+        return _memory_saver.cuda_graph(**kwargs)
+
+    def disable(self):
+        return _memory_saver.disable()
+
     def pause(self, tag: str):
         return _memory_saver.pause(tag=tag)
 
@@ -81,6 +93,14 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
     def region(self, tag: str, enable_cpu_backup: bool = False):
         yield
 
+    @contextmanager
+    def cuda_graph(self, **kwargs):
+        yield
+
+    @contextmanager
+    def disable(self):
+        yield
+
     def pause(self, tag: str):
         pass
 
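The new cuda_graph() and disable() hooks follow the same pattern as the existing region()/pause() API: the real adapter delegates to torch_memory_saver, and the no-op adapter exposes them as empty context managers. A minimal usage sketch, assuming the underlying torch_memory_saver calls also behave as context managers (as the @contextmanager no-op variants suggest); `adapter` and `capture` are hypothetical stand-ins, not part of the diff above:

# Sketch only: `adapter` is an already-constructed TorchMemorySaverAdapter
# and `capture` is a hypothetical stand-in for CUDA-graph capture work.
def capture_with_memory_saver(adapter, capture):
    # Tag CUDA-graph allocations so the saver can pause/resume them later.
    with adapter.cuda_graph():
        graph = capture()
    # Temporarily opt out of memory tracking for work that must not be managed.
    with adapter.disable():
        capture()
    return graph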
sglang/test/kits/radix_cache_server_kit.py
ADDED
@@ -0,0 +1,50 @@
+import random
+
+import requests
+
+
+def gen_radix_tree(num_nodes=400, chunk_len=256):
+    num0 = num_nodes // 2
+    num1 = num_nodes - num0
+    nodes = [{"input_ids": [37] * 117, "decode_len": 217}]
+    for _ in range(num0):
+        parent = random.choice(nodes)
+        unique_len = random.randint(0, chunk_len)
+        decode_len = random.randint(0, chunk_len)
+        token_id = random.randint(0, 32000)
+        child = {
+            "input_ids": parent["input_ids"] + [token_id] * unique_len,
+            "decode_len": decode_len,
+        }
+        nodes.append(child)
+
+    while num1 > 0:
+        num_branch = random.randint(1, min(num1, 10))
+        parent = random.choice(nodes)
+        for _ in range(num_branch):
+            unique_len = random.randint(0, chunk_len)
+            decode_len = random.randint(0, chunk_len)
+            token_id = random.randint(0, 32000)
+            child = {
+                "input_ids": parent["input_ids"] + [token_id] * unique_len,
+                "decode_len": decode_len,
+            }
+            nodes.append(child)
+
+        num1 -= num_branch
+
+    random.shuffle(nodes)
+    return nodes
+
+
+def run_radix_attention_test(base_url: str):
+    nodes = gen_radix_tree()
+    data = {
+        "input_ids": [node["input_ids"] for node in nodes],
+        "sampling_params": [
+            {"max_new_tokens": node["decode_len"], "temperature": 0} for node in nodes
+        ],
+    }
+
+    res = requests.post(base_url + "/generate", json=data)
+    assert res.status_code == 200
sglang/test/runners.py
CHANGED
@@ -12,10 +12,11 @@
 # limitations under the License.
 # ==============================================================================
 
+import json
 import multiprocessing as mp
 import os
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -89,7 +90,9 @@ def get_token_ids_logprobs(logits, token_ids):
     return logprobs
 
 
-def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
+def _get_sentence_transformer_embedding_model(
+    model_path, torch_dtype, matryoshka_dim: Optional[int] = None
+):
     from sentence_transformers import SentenceTransformer
     from sentence_transformers.util import is_sentence_transformer_model
 
@@ -97,6 +100,7 @@ def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
         model = SentenceTransformer(
             model_path,
             model_kwargs={"torch_dtype": torch_dtype},
+            truncate_dim=matryoshka_dim,
         )
     else:  # if no pre-trained sentence-transformers model
         from sentence_transformers import models
@@ -106,7 +110,9 @@ def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
             word_embedding_model.get_word_embedding_dimension(),
             pooling_mode="lasttoken",
         )
-        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+        model = SentenceTransformer(
+            modules=[word_embedding_model, pooling_model], truncate_dim=matryoshka_dim
+        )
 
     return model.cuda()
 
@@ -135,6 +141,7 @@ class HFRunner:
         output_str_only: bool = False,
         trust_remote_code: bool = False,
         patch_model_do_sample_false: bool = False,
+        matryoshka_dim: Optional[int] = None,
     ):
         self.model_type = model_type
         self.output_str_only = output_str_only
@@ -151,6 +158,7 @@ class HFRunner:
                 self.out_queue,
                 model_path,
                 torch_dtype,
+                matryoshka_dim,
             ),
         )
         self.model_proc.start()
@@ -225,7 +233,14 @@ class HFRunner:
             embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
         return embeddings.contiguous()
 
-    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
+    def start_model_process(
+        self,
+        in_queue,
+        out_queue,
+        model_path,
+        torch_dtype,
+        matryoshka_dim: Optional[int] = None,
+    ):
         # Apply model-specific patches
         monkey_patch_gemma2_sdpa()
 
@@ -259,7 +274,7 @@ class HFRunner:
                 self.processor = AutoProcessor.from_pretrained(model_path)
             else:
                 self.model = _get_sentence_transformer_embedding_model(
-                    model_path, torch_dtype
+                    model_path, torch_dtype, matryoshka_dim=matryoshka_dim
                 )
         elif self.model_type == "reward" or self.model_type == "cross_encoder":
             from transformers import AutoModelForSequenceClassification
@@ -496,7 +511,7 @@ class SRTRunner:
         attention_backend: Optional[str] = None,
         prefill_attention_backend: Optional[str] = None,
         decode_attention_backend: Optional[str] = None,
-        lora_backend: str = "
+        lora_backend: str = "csgmv",
         disable_cuda_graph: bool = False,
         disable_radix_cache: bool = False,
         chunked_prefill_size: Optional[int] = None,
@@ -519,6 +534,7 @@ class SRTRunner:
         lora_target_modules: Optional[List[str]] = None,
         enable_lora: Optional[bool] = None,
         max_loaded_loras: Optional[int] = None,
+        json_model_override_args: Optional[dict[str, Any]] = None,
        lora_eviction_policy: str = "lru",
     ):
         self.model_type = model_type
@@ -566,6 +582,11 @@ class SRTRunner:
             lora_target_modules=lora_target_modules,
             enable_lora=enable_lora,
             max_loaded_loras=max_loaded_loras,
+            json_model_override_args=(
+                json.dumps(json_model_override_args)
+                if json_model_override_args
+                else "{}"
+            ),
             lora_eviction_policy=lora_eviction_policy,
             **spec_kwargs,
         )
@@ -594,6 +615,7 @@ class SRTRunner:
         logprob_start_len: int = 0,
         top_k: Optional[int] = None,
         token_ids_logprob: Optional[List[int]] = None,
+        dimensions: Optional[int] = None,
     ):
         if self.is_generation:
             return self.forward_generation_raw(
@@ -607,7 +629,9 @@ class SRTRunner:
             )
         else:
             if self.model_type == "embedding":
-                response = self.engine.encode(prompt=prompts, image_data=image_data)
+                response = self.engine.encode(
+                    prompt=prompts, image_data=image_data, dimensions=dimensions
+                )
                 if isinstance(response, list):
                     logits = [x["embedding"] for x in response]
                 else:
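The matryoshka_dim / dimensions plumbing above lets embedding tests truncate vectors on both sides: HFRunner forwards matryoshka_dim to SentenceTransformer's truncate_dim, and SRTRunner.forward passes dimensions through to engine.encode. A hedged sketch of how a test might exercise it; the model name and any runner usage beyond what is visible in this diff are assumptions:

# Illustrative only; model path and runner details beyond the diff are assumed.
import torch
from sglang.test.runners import HFRunner, SRTRunner

PROMPTS = ["the quick brown fox"]
MODEL = "intfloat/e5-mistral-7b-instruct"  # assumed embedding model

with HFRunner(
    MODEL, torch_dtype=torch.float16, model_type="embedding", matryoshka_dim=256
) as hf_runner:
    hf_embeddings = hf_runner.forward(PROMPTS)

with SRTRunner(MODEL, torch_dtype=torch.float16, model_type="embedding") as srt_runner:
    # `dimensions` is forwarded to engine.encode(..., dimensions=...).
    srt_embeddings = srt_runner.forward(PROMPTS, dimensions=256)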
sglang/test/simple_eval_common.py
CHANGED
@@ -148,7 +148,7 @@ class ChatCompletionSampler(SamplerBase):
                     reasoning_effort=self.reasoning_effort,
                     extra_body=self.extra_body,
                 )
-                return response.choices[0].message.content
+                return response.choices[0].message.content or ""
             # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are rerunning MMMU
             except openai.BadRequestError as e:
                 print("Bad Request Error", e)
@@ -161,7 +161,9 @@ class ChatCompletionSampler(SamplerBase):
                 )
                 time.sleep(exception_backoff)
                 trial += 1
-
+        # If all retries are exhausted, return empty string instead of None
+        print(f"All retry attempts exhausted for request. Returning empty response.")
+        return ""
 
 
 QUERY_TEMPLATE_MULTICHOICE = """
@@ -261,7 +263,7 @@ def format_multichoice_question(row):
 def check_equality(sampler: SamplerBase, expr1: str, expr2: str):
     prompt = EQUALITY_TEMPLATE % {"expression1": expr1, "expression2": expr2}
     response = sampler([dict(content=prompt, role="user")])
-    return response.lower().strip() == "yes"
+    return (response or "").lower().strip() == "yes"
 
 
 def _compute_stat(values: list, stat: str):
sglang/test/simple_eval_humaneval.py
CHANGED
@@ -80,6 +80,7 @@ class HumanEval(Eval):
         instruction = "Read the following function signature and docstring, and fully implement the function described. Your response should only contain the code for this function.\n"
 
         def find_code(completion):
+            completion = completion or ""
            pattern = re.compile(r"```python\n(.*?)```", re.DOTALL)
            matches = pattern.findall(completion)
            extracted_answer = matches[0] if len(matches) >= 1 else completion
sglang/test/simple_eval_math.py
CHANGED
@@ -54,6 +54,7 @@ class MathEval(Eval):
             sampler._pack_message(content=QUERY_TEMPLATE.format(**row), role="user")
         ]
         response_text = sampler(prompt_messages)
+        response_text = response_text or ""
         match = re.search(ANSWER_PATTERN, response_text)
         extracted_answer = match.group(1) if match else None
         score = float(
sglang/test/simple_eval_mmlu.py
CHANGED
@@ -101,6 +101,7 @@ class MMLUEval(Eval):
             )
         ]
         response_text = sampler(prompt_messages)
+        response_text = response_text or ""
         match = re.search(ANSWER_PATTERN_MULTICHOICE, response_text)
         extracted_answer = match.group(1) if match else None
         score = 1.0 if extracted_answer == row["Answer"] else 0.0
sglang/test/test_deterministic.py
CHANGED
@@ -17,7 +17,7 @@ import dataclasses
 import json
 import os
 import random
-from typing import List
+from typing import Any, Dict, List, Optional
 
 import requests
 
@@ -78,6 +78,7 @@ class BenchArgs:
                 "single",
                 "prefix",
                 "radix_cache",
+                "p_vs_d",
             ],
         )
         parser.add_argument("--profile", action="store_true")
@@ -94,18 +95,21 @@ class BenchArgs:
 
 def send_single(
     args,
-    batch_size: int = 1,
     profile: bool = False,
     profile_steps: int = 3,
     profile_by_stage: bool = False,
     return_full_response: bool = False,
     input_ids: List[int] = None,
+    prompt: List[str] = None,
     max_new_tokens: int = None,
+    extra_params: Optional[Dict[str, Any]] = None,
+    pick_first_result: bool = True,
 ):
     base_url = f"http://{args.host}:{args.port}"
 
     # Use input_ids if provided, otherwise use text prompts
     if input_ids is not None:
+        assert prompt is None
         json_data = {
             "input_ids": input_ids,
             "sampling_params": {
@@ -120,9 +124,10 @@ def send_single(
            },
            "return_logprob": args.return_logprob,
            "stream": args.stream,
+           **(extra_params or {}),
        }
     else:
-
+        assert input_ids is None
         json_data = {
             "text": prompt,
             "sampling_params": {
@@ -137,6 +142,7 @@ def send_single(
            },
            "return_logprob": args.return_logprob,
            "stream": args.stream,
+           **(extra_params or {}),
        }
 
     if args.sampling_seed is not None:
@@ -169,7 +175,8 @@ def send_single(
     else:
         ret = response.json()
 
-
+    if pick_first_result:
+        ret = ret[0] if isinstance(ret, list) else ret
 
     if return_full_response:
         return ret
@@ -177,7 +184,9 @@ def send_single(
     return ret["text"]
 
 
-def send_prefix(args, batch_size: int, prompts: List[str]):
+def send_prefix(
+    args, batch_size: int, prompts: List[str], return_full_response: bool = False
+):
     requests.post(f"http://{args.host}:{args.port}/flush_cache")
 
     batch_data = []
@@ -212,11 +221,157 @@ def send_prefix(args, batch_size: int, prompts: List[str]):
         print(ret)
         return -1, -1, -1
 
-
-
-    ret_dict[
+    if return_full_response:
+        # Return full responses grouped by prompt index
+        ret_dict = {i: [] for i in range(len(prompts))}
+        for i in range(batch_size):
+            ret_dict[sampled_indices[i]].append(ret[i])
+        return ret_dict
+    else:
+        # Return only text grouped by prompt index
+        ret_dict = {i: [] for i in range(len(prompts))}
+        for i in range(batch_size):
+            ret_dict[sampled_indices[i]].append(ret[i]["text"])
+        return ret_dict
+
+
+def compare_logprobs(logprobs1, logprobs2, tolerance=0):
+    """Compare two logprobs sequences with a tolerance."""
+    if len(logprobs1) != len(logprobs2):
+        return False, f"Length mismatch: {len(logprobs1)} vs {len(logprobs2)}"
+
+    for i, (lp1, lp2) in enumerate(zip(logprobs1, logprobs2)):
+        # Each element is [logprob, token_id]
+        if lp1[1] != lp2[1]:
+            return False, f"Token ID mismatch at position {i}: {lp1[1]} vs {lp2[1]}"
+        if abs(lp1[0] - lp2[0]) > tolerance:
+            return (
+                False,
+                f"Logprob mismatch at position {i}: {lp1[0]} vs {lp2[0]} (diff: {abs(lp1[0] - lp2[0])})",
+            )
+
+    return True, "Logprobs match"
+
 
-
+def _test_mode_p_vs_d(args, batch_size):
+    print()
+    print(f"Execute: test p_vs_d {batch_size=}")
+
+    random.seed(42)
+    args.return_logprob = True
+    query_extra_params = {
+        "logprob_start_len": 0,
+        "return_text_in_logprobs": True,
+    }
+
+    def _create_prompts():
+        ans = [PROMPT_1, PROMPT_2]
+        for i in range(batch_size - len(ans)):
+            end = random.randrange(1, 4096)
+            if random.random() < 0.5:
+                begin = 0
+            else:
+                begin = random.randrange(0, end)
+            ans.append(LONG_PROMPT[begin:end])
+        return ans[:batch_size]
+
+    # warmup + flush
+    send_single(args, input_ids=[1] * 64, max_new_tokens=65, return_full_response=True)
+    requests.post(f"http://{args.host}:{args.port}/flush_cache")
+
+    prompts = _create_prompts()
+
+    resp_a = send_single(
+        args,
+        prompt=prompts,
+        max_new_tokens=args.max_new_tokens,
+        return_full_response=True,
+        pick_first_result=False,
+        extra_params=query_extra_params,
+    )
+    info_a = _extract_ids_and_logprobs(resp_a)
+
+    requests.post(f"http://{args.host}:{args.port}/flush_cache")
+
+    resp_b = send_single(
+        args,
+        input_ids=[x["io"].token_ids for x in info_a],
+        max_new_tokens=1,
+        return_full_response=True,
+        pick_first_result=False,
+        extra_params=query_extra_params,
+    )
+    info_b = _extract_ids_and_logprobs(resp_b)
+
+    ans = []
+    for i, (info_a_item, info_b_item) in enumerate(zip(info_a, info_b, strict=True)):
+        print(f"Compare sequence {i} in batch...")
+        correct = TokenIdsAndLogprobs.compare(info_a_item["io"], info_b_item["input"])
+        ans.append(int(correct))
+
+    return ans
+
+
+@dataclasses.dataclass
+class TokenIdsAndLogprobs:
+    token_ids: List[int]
+    logprobs: List[float]
+
+    def __add__(self, other):
+        return TokenIdsAndLogprobs(
+            token_ids=self.token_ids + other.token_ids,
+            logprobs=self.logprobs + other.logprobs,
+        )
+
+    @classmethod
+    def compare(cls, a: "TokenIdsAndLogprobs", b: "TokenIdsAndLogprobs"):
+        assert len(a.token_ids) == len(b.token_ids)
+        token_match = a.token_ids == b.token_ids
+        logprobs_match = a.logprobs == b.logprobs
+
+        if token_match:
+            print(f"Token match: {a.token_ids}")
+        else:
+            print(f"❗Token mismatch: {a.token_ids=} {b.token_ids=}")
+
+        if logprobs_match:
+            print(f"Logprobs match:", a.logprobs)
+        else:
+            print(f"❗Logprobs mismatch")
+            print(
+                " A: ",
+                [f"{x:.10f}" if x is not None else "None" for x in a.logprobs],
+            )
+            print(
+                " B: ",
+                [f"{x:.10f}" if x is not None else "None" for x in b.logprobs],
+            )
+            diff = [
+                abs(x - y) if x is not None else float("nan")
+                for x, y in zip(a.logprobs, b.logprobs)
+            ]
+            print(" Diff:", [f"{x:.10e}" for x in diff])
+
+        return token_match and logprobs_match
+
+
+def _extract_ids_and_logprobs(responses):
+    def _extract_part(response, name):
+        token_ids, logprobs = [], []
+        for item in response["meta_info"][name]:
+            logprob, token_id, text = item
+            token_ids.append(token_id)
+            logprobs.append(logprob)
+        return TokenIdsAndLogprobs(token_ids=token_ids, logprobs=logprobs)
+
+    def _extract_one_response(response):
+        input = _extract_part(response, "input_token_logprobs")
+        output = _extract_part(response, "output_token_logprobs")
+        return dict(input=input, output=output, io=input + output)
+
+    if not isinstance(responses, list):
+        responses = [responses]
+    return [_extract_one_response(x) for x in responses]
 
 
 def test_deterministic(args):
@@ -225,7 +380,7 @@ def test_deterministic(args):
         texts = []
         for i in range(1, args.n_trials + 1):
             batch_size = i
-            text = send_single(args,
+            text = send_single(args, args.profile, prompt=[PROMPT_1] * batch_size)
             text = text.replace("\n", " ")
             print(f"Trial {i} with batch size {batch_size}: {text}")
             texts.append(text)
@@ -238,15 +393,28 @@ def test_deterministic(args):
         num_prompts = len(len_prefix)
         outputs = {i: [] for i in range(4)}
         prompts = [LONG_PROMPT[: len_prefix[i]] for i in range(4)]
+
+        # If return_logprob is enabled, store full responses for comparison
+        if args.return_logprob:
+            full_responses = {i: [] for i in range(4)}
+
         for i in range(args.n_start, args.n_start + args.n_trials):
             batch_size = i
-            ret_dict = send_prefix(
+            ret_dict = send_prefix(
+                args, batch_size, prompts, return_full_response=args.return_logprob
+            )
             msg = f"Testing Trial {i} with batch size {batch_size},"
             for i in range(num_prompts):
                 msg += f" # prefix length {len_prefix[i]}: {len(ret_dict[i])},"
             print(msg)
             for i in range(num_prompts):
-
+                if args.return_logprob:
+                    # Store full response for logprob comparison
+                    full_responses[i].extend(ret_dict[i])
+                    # Extract text for determinism check
+                    outputs[i].extend([resp["text"] for resp in ret_dict[i]])
+                else:
+                    outputs[i].extend(ret_dict[i])
 
         for i in range(num_prompts):
             print(
@@ -256,6 +424,54 @@ def test_deterministic(args):
         results = []
         for i in range(num_prompts):
             results.append(len(set(outputs[i])))
+
+        # If logprobs are enabled, compare them across different batch sizes
+        if args.return_logprob:
+            print(f"\n{'='*60}")
+            print("Logprobs Comparison Across Batch Sizes")
+            print("=" * 60)
+
+            logprob_results = []
+            for prompt_idx in range(num_prompts):
+                print(
+                    f"\nPrompt {prompt_idx} (prefix length {len_prefix[prompt_idx]}):"
+                )
+                responses = full_responses[prompt_idx]
+
+                if len(responses) < 2:
+                    continue
+
+                # Compare all responses against the first one
+                reference = responses[0]
+                all_match = True
+                mismatches = []
+
+                for j, resp in enumerate(responses[1:], start=1):
+                    ref_logprobs = reference["meta_info"]["output_token_logprobs"]
+                    resp_logprobs = resp["meta_info"]["output_token_logprobs"]
+
+                    match, msg = compare_logprobs(ref_logprobs, resp_logprobs)
+
+                    if not match:
+                        print(f" ✗ Sample {j+1}: {msg}")
+                        mismatches.append((j + 1, msg))
+                        all_match = False
+
+                if all_match:
+                    print(f" ✓ All {len(responses)} samples have identical logprobs")
+                    logprob_results.append(1)
+                else:
+                    print(
+                        f" ✗ Found {len(mismatches)} mismatches out of {len(responses)} samples"
+                    )
+                    logprob_results.append(0)
+
+            print(f"\n{'='*60}")
+            if all(r == 1 for r in logprob_results):
+                print("✓✓✓ Logprobs are identical across all batch sizes! ✓✓✓")
+            else:
+                print("✗✗✗ Some logprobs differ across batch sizes! ✗✗✗")
+
         return results
 
     elif args.test_mode == "radix_cache":
@@ -415,6 +631,13 @@ def test_deterministic(args):
             print("✗✗✗ TEST FAILED - Radix cache produces different results! ✗✗✗")
             return [0]
 
+    elif args.test_mode == "p_vs_d":
+        # TODO also extract other modes to functions
+        ans = []
+        for i in range(1, args.n_trials + 1):
+            ans += _test_mode_p_vs_d(args, batch_size=i)
+        return ans
+
     else:
         raise ValueError(f"Invalid test mode: {args.test_mode}")
 
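The new p_vs_d mode generates from text prompts, then replays the exact token ids through a second pass with max_new_tokens=1 and checks that the replay's input-token logprobs match the prefill-plus-decode logprobs of the first run. A minimal sketch of driving it programmatically, mirroring how test_deterministic_utils.py configures BenchArgs; the host/port values are placeholders for a running server, not taken from the diff:

# Sketch only; assumes a server is already running at the given host/port.
from sglang.test.test_deterministic import BenchArgs, test_deterministic

args = BenchArgs()
args.host, args.port = "127.0.0.1", 30000
args.test_mode = "p_vs_d"
args.n_trials = 3  # runs batch sizes 1..3

results = test_deterministic(args)
assert all(r == 1 for r in results), "prefill vs. decode logprobs diverged"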
sglang/test/test_deterministic_utils.py
CHANGED
@@ -60,7 +60,7 @@ class TestDeterministicBase(CustomTestCase):
         for result in results:
             assert result == 1
 
-    def
+    def test_prefix_with_logprobs(self):
         args = BenchArgs()
         url = DEFAULT_URL_FOR_TEST
         args.host, args.port = self._extract_host_and_port(url)
@@ -68,6 +68,7 @@ class TestDeterministicBase(CustomTestCase):
         args.n_start = 10
         args.n_trials = 10
         args.temperature = 0.5  # test for deterministic sampling
+        args.return_logprob = True  # Enable logprobs comparison
         results = test_deterministic(args)
         for result in results:
             assert result == 1
sglang/test/test_utils.py
CHANGED
@@ -84,6 +84,8 @@ DEFAULT_MODEL_NAME_FOR_TEST_AWQ_INT4 = (
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
 DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
 DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct"
+DEFAULT_EAGLE_DP_ATTENTION_TARGET_MODEL_FOR_TEST = "Qwen/Qwen3-30B-A3B"
+DEFAULT_EAGLE_DP_ATTENTION_DRAFT_MODEL_FOR_TEST = "Tengyunw/qwen3_30b_moe_eagle3"
 DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B"
 DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
     "meta-llama/Llama-3.1-8B-Instruct"
@@ -92,6 +94,10 @@ DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-I
 DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"
 
 # Other use cases
+DEFAULT_AUTOROUND_MODEL_NAME_FOR_TEST = (
+    "OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc",  # auto_round:auto_gptq
+    "Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound",  # auto_round:auto_awq
+)
 DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
     "meta-llama/Llama-4-Scout-17B-16E-Instruct"
 )
@@ -145,7 +151,7 @@ def _use_cached_default_models(model_repo: str):
 
 if is_in_ci():
     DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
-        10000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) *
+        10000 + int(os.environ.get("CUDA_VISIBLE_DEVICES", "0")[0]) * 2000
     )
 else:
     DEFAULT_PORT_FOR_SRT_TEST_RUNNER = (
sglang/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.5.4"
+__version__ = "0.5.4.post2"