sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +17 -8
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +5 -17
- sglang/lang/backend/runtime_endpoint.py +5 -2
- sglang/lang/interpreter.py +1 -4
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +33 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +1 -3
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/fused_moe/layer.py +27 -7
- sglang/srt/layers/layernorm.py +12 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +38 -122
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +259 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +105 -71
- sglang/srt/managers/tokenizer_manager.py +17 -8
- sglang/srt/managers/tp_worker.py +188 -121
- sglang/srt/model_executor/cuda_graph_runner.py +69 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +123 -154
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +1 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/exaone.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama.py +51 -5
- sglang/srt/models/llama_classification.py +1 -20
- sglang/srt/models/llava.py +30 -5
- sglang/srt/models/llavavid.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +669 -0
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/olmoe.py +415 -0
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +46 -80
- sglang/srt/server.py +30 -15
- sglang/srt/server_args.py +163 -28
- sglang/srt/utils.py +19 -51
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +7 -5
- sglang/test/test_utils.py +85 -2
- sglang/utils.py +32 -37
- sglang/version.py +1 -1
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
- sglang-0.3.1.post1.dist-info/RECORD +130 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
- sglang-0.3.0.dist-info/RECORD +0 -118
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -21,9 +21,22 @@ import logging
import random
from typing import List, Optional, Union

+from sglang.srt.utils import is_hip
+
logger = logging.getLogger(__name__)


+class LoRAPathAction(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, {})
+        for lora_path in values:
+            if "=" in lora_path:
+                name, path = lora_path.split("=", 1)
+                getattr(namespace, self.dest)[name] = path
+            else:
+                getattr(namespace, self.dest)[lora_path] = lora_path
+
+
@dataclasses.dataclass
class ServerArgs:
    # Model and tokenizer
@@ -49,7 +62,6 @@ class ServerArgs:
    # Memory and scheduling
    mem_fraction_static: Optional[float] = None
    max_running_requests: Optional[int] = None
-    max_num_reqs: Optional[int] = None
    max_total_tokens: Optional[int] = None
    chunked_prefill_size: int = 8192
    max_prefill_tokens: int = 16384
@@ -60,6 +72,7 @@ class ServerArgs:
    tp_size: int = 1
    stream_interval: int = 1
    random_seed: Optional[int] = None
+    constrained_json_whitespace_pattern: Optional[str] = None

    # Logging
    log_level: str = "info"
@@ -75,7 +88,18 @@ class ServerArgs:
    dp_size: int = 1
    load_balance_method: str = "round_robin"

+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
+
+    # Model override args in JSON
+    json_model_override_args: str = "{}"
+
    # Optimization/debug options
+    attention_backend: Optional[str] = None
+    sampling_backend: Optional[str] = None
+
    disable_flashinfer: bool = False
    disable_flashinfer_sampling: bool = False
    disable_radix_cache: bool = False
@@ -86,16 +110,18 @@ class ServerArgs:
    disable_custom_all_reduce: bool = False
    enable_mixed_chunk: bool = False
    enable_torch_compile: bool = False
+    max_torch_compile_bs: int = 32
+    torchao_config: str = ""
    enable_p2p_check: bool = False
    enable_mla: bool = False
    triton_attention_reduce_in_fp32: bool = False

-    #
-
-
-    node_rank: Optional[int] = None
+    # LoRA
+    lora_paths: Optional[List[str]] = None
+    max_loras_per_batch: int = 8

    def __post_init__(self):
+        # Set missing default values
        if self.tokenizer_path is None:
            self.tokenizer_path = self.model_path

@@ -106,6 +132,7 @@ class ServerArgs:
            # Disable chunked prefill
            self.chunked_prefill_size = None

+        # Mem fraction depends on the tensor parallelism size
        if self.mem_fraction_static is None:
            if self.tp_size >= 16:
                self.mem_fraction_static = 0.79
@@ -126,6 +153,47 @@ class ServerArgs:
        if self.random_seed is None:
            self.random_seed = random.randint(0, 1 << 30)

+        # Deprecation warnings
+        if self.disable_flashinfer:
+            logger.warning(
+                "The option '--disable-flashinfer' will be deprecated in the next release. "
+                "Please use '--attention-backend triton' instead."
+            )
+            self.attention_backend = "triton"
+        if self.disable_flashinfer_sampling:
+            logger.warning(
+                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
+                "Please use '--sampling-backend pytorch' instead. "
+            )
+            self.sampling_backend = "pytorch"
+
+        # ROCm: flashinfer available later
+        if is_hip():
+            self.attention_backend = "triton"
+            self.sampling_backend = "pytorch"
+
+        # Default kernel backends
+        if self.enable_mla:
+            logger.info("MLA optimization is tunred on. Use triton backend.")
+            self.attention_backend = "triton"
+
+        if self.attention_backend is None:
+            self.attention_backend = "flashinfer"
+
+        if self.sampling_backend is None:
+            self.sampling_backend = "flashinfer"
+
+        # Model-specific patches
+        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+            logger.info(
+                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+            )
+            self.trust_remote_code = False
+
+        if "gemma-2" in self.model_path.lower():
+            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
+            self.attention_backend = "flashinfer"
+
    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser):
        parser.add_argument(
@@ -209,11 +277,6 @@ class ServerArgs:
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
-        parser.add_argument(
-            "--is-embedding",
-            action="store_true",
-            help="Whether to use a CausalLM as an embedding model.",
-        )
        parser.add_argument(
            "--context-length",
            type=int,
@@ -248,6 +311,11 @@ class ServerArgs:
            default=ServerArgs.chat_template,
            help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
        )
+        parser.add_argument(
+            "--is-embedding",
+            action="store_true",
+            help="Whether to use a CausalLM as an embedding model.",
+        )
        parser.add_argument(
            "--mem-fraction-static",
            type=float,
@@ -260,17 +328,12 @@ class ServerArgs:
            default=ServerArgs.max_running_requests,
            help="The maximum number of running requests.",
        )
-        parser.add_argument(
-            "--max-num-reqs",
-            type=int,
-            default=ServerArgs.max_num_reqs,
-            help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
-        )
        parser.add_argument(
            "--max-total-tokens",
            type=int,
            default=ServerArgs.max_total_tokens,
-            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction.
+            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
+            "This option is typically used for development and debugging purposes.",
        )
        parser.add_argument(
            "--chunked-prefill-size",
@@ -316,6 +379,12 @@ class ServerArgs:
            default=ServerArgs.random_seed,
            help="The random seed.",
        )
+        parser.add_argument(
+            "--constrained-json-whitespace-pattern",
+            type=str,
+            default=ServerArgs.constrained_json_whitespace_pattern,
+            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
+        )
        parser.add_argument(
            "--log-level",
            type=str,
@@ -381,16 +450,38 @@ class ServerArgs:
        )
        parser.add_argument("--node-rank", type=int, help="The node rank.")

+        # Model override args
+        parser.add_argument(
+            "--json-model-override-args",
+            type=str,
+            help="A dictionary in JSON string format used to override default model configurations.",
+            default=ServerArgs.json_model_override_args,
+        )
+
        # Optimization/debug options
+        parser.add_argument(
+            "--attention-backend",
+            type=str,
+            choices=["flashinfer", "triton"],
+            default=ServerArgs.attention_backend,
+            help="Choose the kernels for attention layers.",
+        )
+        parser.add_argument(
+            "--sampling-backend",
+            type=str,
+            choices=["flashinfer", "pytorch"],
+            default=ServerArgs.sampling_backend,
+            help="Choose the kernels for sampling layers.",
+        )
        parser.add_argument(
            "--disable-flashinfer",
            action="store_true",
-            help="Disable flashinfer attention kernels.",
+            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
        )
        parser.add_argument(
            "--disable-flashinfer-sampling",
            action="store_true",
-            help="Disable flashinfer sampling kernels.",
+            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
        )
        parser.add_argument(
            "--disable-radix-cache",
@@ -431,7 +522,19 @@ class ServerArgs:
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
-            help="Optimize the model with torch.compile
+            help="Optimize the model with torch.compile. Experimental feature.",
+        )
+        parser.add_argument(
+            "--max-torch-compile-bs",
+            type=int,
+            default=ServerArgs.max_torch_compile_bs,
+            help="Set the maximum batch size when using torch compile.",
+        )
+        parser.add_argument(
+            "--torchao-config",
+            type=str,
+            default=ServerArgs.torchao_config,
+            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
        )
        parser.add_argument(
            "--enable-p2p-check",
@@ -455,6 +558,22 @@ class ServerArgs:
            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
        )

+        # LoRA options
+        parser.add_argument(
+            "--lora-paths",
+            type=str,
+            nargs="*",
+            default=None,
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+        )
+        parser.add_argument(
+            "--max-loras-per-batch",
+            type=int,
+            default=8,
+            help="Maximum number of adapters for a running batch, include base-only request",
+        )
+
    @classmethod
    def from_cli_args(cls, args: argparse.Namespace):
        args.tp_size = args.tensor_parallel_size
@@ -472,14 +591,30 @@ class ServerArgs:
        assert not (
            self.dp_size > 1 and self.node_rank is not None
        ), "multi-node data parallel is not supported"
-
-
-
-            )
-            self.
-
-
-
+        assert (
+            self.max_loras_per_batch > 0
+            # FIXME
+            and (self.lora_paths is None or self.disable_cuda_graph)
+            and (self.lora_paths is None or self.disable_radix_cache)
+        ), "compatibility of lora and cuda graph and radix attention is in progress"
+
+
+def prepare_server_args(argv: List[str]) -> ServerArgs:
+    """
+    Prepare the server arguments from the command line arguments.
+
+    Args:
+        args: The command line arguments. Typically, it should be `sys.argv[1:]`
+            to ensure compatibility with `parse_args` when no arguments are passed.
+
+    Returns:
+        The server arguments.
+    """
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    raw_args = parser.parse_args(argv)
+    server_args = ServerArgs.from_cli_args(raw_args)
+    return server_args


@dataclasses.dataclass
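The snippet below is not part of the diff; it is a minimal sketch of how the new prepare_server_args helper and the LoRAPathAction parser added above might be exercised. The model path, adapter paths, and extra disable flags are hypothetical values, chosen to be consistent with the new LoRA compatibility assert.

from sglang.srt.server_args import prepare_server_args

# Parse CLI-style arguments into a ServerArgs instance (all values hypothetical).
server_args = prepare_server_args(
    [
        "--model-path", "meta-llama/Llama-2-7b-hf",
        "--attention-backend", "triton",
        "--lora-paths", "sql=/adapters/sql", "/adapters/chat",
        "--disable-cuda-graph",
        "--disable-radix-cache",
    ]
)

print(server_args.attention_backend)  # "triton"
# LoRAPathAction builds a dict: {"sql": "/adapters/sql", "/adapters/chat": "/adapters/chat"}
print(server_args.lora_paths)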
sglang/srt/utils.py
CHANGED
@@ -35,6 +35,7 @@ import torch
import torch.distributed as dist
from fastapi.responses import JSONResponse
from packaging import version as pkg_version
+from torch import nn
from torch.nn.parameter import Parameter
from triton.runtime.cache import (
    FileCacheManager,
@@ -50,6 +51,11 @@ show_time_cost = False
time_infos = {}


+# torch flag AMD GPU
+def is_hip() -> bool:
+    return torch.version.hip is not None
+
+
def enable_show_time_cost():
    global show_time_cost
    show_time_cost = True
@@ -186,7 +192,7 @@ def allocate_init_ports(
        cur_port += 1

    if port is not None and ret_ports[0] != port:
-        logger.
+        logger.warning(
            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
        )

@@ -622,56 +628,7 @@ def set_ulimit(target_soft_limit=65535):
    try:
        resource.setrlimit(resource_type, (target_soft_limit, current_hard))
    except ValueError as e:
-        logger.
-
-
-def is_llama3_405b_fp8_head_16(model_config):
-    """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
-    if (
-        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
-        and model_config.hf_config.hidden_size == 16384
-        and model_config.hf_config.intermediate_size == 53248
-        and model_config.hf_config.num_hidden_layers == 126
-        and model_config.hf_config.num_key_value_heads == 16
-        and hasattr(model_config.hf_config, "quantization_config")
-        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
-    ):
-        return True
-    return False
-
-
-def monkey_patch_vllm_qvk_linear_loader():
-    """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
-    from vllm.model_executor.layers.linear import QKVParallelLinear
-
-    origin_weight_loader = QKVParallelLinear.weight_loader
-
-    def get_original_weight(loaded_weight, head_dim):
-        n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
-        dim = loaded_weight.shape[1]
-        for i in range(n_kv_head):
-            loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
-                2 * i * head_dim : (2 * i + 1) * head_dim, :
-            ]
-        original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
-        assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
-        return original_kv_weight
-
-    def weight_loader_srt(
-        self,
-        param: Parameter,
-        loaded_weight: torch.Tensor,
-        loaded_shard_id: Optional[str] = None,
-    ):
-        if (
-            loaded_shard_id in ["k", "v"]
-            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
-        ):
-            loaded_weight = get_original_weight(loaded_weight, self.head_size)
-
-        origin_weight_loader(self, param, loaded_weight, loaded_shard_id)
-
-    setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
+        logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")


def add_api_key_middleware(app, api_key: str):
@@ -714,3 +671,14 @@ def configure_logger(server_args, prefix: str = ""):
        datefmt="%H:%M:%S",
        force=True,
    )
+
+
+# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
+def replace_submodule(
+    model: nn.Module, module_name: str, new_module: nn.Module
+) -> nn.Module:
+    """Replace a submodule in a model with a new module."""
+    parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
+    target_name = module_name.split(".")[-1]
+    setattr(parent, target_name, new_module)
+    return new_module
sglang/test/few_shot_gsm8k.py
ADDED
@@ -0,0 +1,132 @@
+"""
+Run few-shot GSM-8K evaluation.
+
+Usage:
+python3 -m sglang.test.few_shot_gsm8k --num-questions 200
+"""
+
+import argparse
+import ast
+import re
+import time
+
+import numpy as np
+
+from sglang.api import set_default_backend
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
+
+INVALID = -9999999
+
+
+def get_one_example(lines, i, include_answer):
+    ret = "Question: " + lines[i]["question"] + "\nAnswer:"
+    if include_answer:
+        ret += " " + lines[i]["answer"]
+    return ret
+
+
+def get_few_shot_examples(lines, k):
+    ret = ""
+    for i in range(k):
+        ret += get_one_example(lines, i, True) + "\n\n"
+    return ret
+
+
+def get_answer_value(answer_str):
+    answer_str = answer_str.replace(",", "")
+    numbers = re.findall(r"\d+", answer_str)
+    if len(numbers) < 1:
+        return INVALID
+    try:
+        return ast.literal_eval(numbers[-1])
+    except SyntaxError:
+        return INVALID
+
+
+def main(args):
+    # Select backend
+    set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))
+
+    # Read data
+    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
+
+    # Construct prompts
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
+
+    questions = []
+    labels = []
+    for i in range(len(lines[:num_questions])):
+        questions.append(get_one_example(lines, i, False))
+        labels.append(get_answer_value(lines[i]["answer"]))
+    assert all(l != INVALID for l in labels)
+    arguments = [{"question": q} for q in questions]
+
+    #####################################
+    ######### SGL Program Begin #########
+    #####################################
+
+    import sglang as sgl
+
+    @sgl.function
+    def few_shot_gsm8k(s, question):
+        s += few_shot_examples + question
+        s += sgl.gen(
+            "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
+        )
+
+    #####################################
+    ########## SGL Program End ##########
+    #####################################
+
+    # Run requests
+    tic = time.time()
+    states = few_shot_gsm8k.run_batch(
+        arguments,
+        temperature=0,
+        num_threads=args.parallel,
+        progress_bar=True,
+    )
+    latency = time.time() - tic
+
+    preds = []
+    for i in range(len(states)):
+        preds.append(get_answer_value(states[i]["answer"]))
+
+    # print(f"{preds=}")
+    # print(f"{labels=}")
+
+    # Compute accuracy
+    acc = np.mean(np.array(preds) == np.array(labels))
+    invalid = np.mean(np.array(preds) == INVALID)
+
+    # Compute speed
+    num_output_tokens = sum(
+        s.get_meta_info("answer")["completion_tokens"] for s in states
+    )
+    output_throughput = num_output_tokens / latency
+
+    # Print results
+    print(f"Accuracy: {acc:.3f}")
+    print(f"Invalid: {invalid:.3f}")
+    print(f"Latency: {latency:.3f} s")
+    print(f"Output throughput: {output_throughput:.3f} token/s")
+
+    # Dump results
+    dump_state_text("tmp_output_gsm8k.txt", states)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num-shots", type=int, default=5)
+    parser.add_argument("--data-path", type=str, default="test.jsonl")
+    parser.add_argument("--num-questions", type=int, default=200)
+    parser.add_argument("--parallel", type=int, default=128)
+    parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=30000)
+    args = parser.parse_args()
+    main(args)