sglang-0.3.3-py3-none-any.whl → sglang-0.3.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +31 -13
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/conversation.py +11 -2
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/data_parallel_controller.py +177 -0
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +11 -2
- sglang/srt/managers/schedule_batch.py +126 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +245 -142
- sglang/srt/managers/tokenizer_manager.py +14 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +100 -36
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +97 -52
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +105 -59
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +171 -37
- sglang/srt/server_args.py +127 -48
- sglang/srt/utils.py +37 -14
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
- sglang-0.3.4.dist-info/RECORD +143 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- sglang-0.3.3.dist-info/RECORD +0 -139
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -35,11 +35,12 @@ class ServerArgs:
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
+    trust_remote_code: bool = True
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
-    trust_remote_code: bool = True
-    context_length: Optional[int] = None
     quantization: Optional[str] = None
+    context_length: Optional[int] = None
+    device: str = "cuda"
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
     is_embedding: bool = False
@@ -72,6 +73,7 @@ class ServerArgs:
     # Other
     api_key: Optional[str] = None
     file_storage_pth: str = "SGLang_storage"
+    enable_cache_report: bool = False
 
     # Data parallelism
     dp_size: int = 1
@@ -85,10 +87,23 @@ class ServerArgs:
     # Model override args in JSON
     json_model_override_args: str = "{}"
 
-    #
+    # Double Sparsity
+    enable_double_sparsity: bool = False
+    ds_channel_config_path: str = None
+    ds_heavy_channel_num: int = 32
+    ds_heavy_token_num: int = 256
+    ds_heavy_channel_type: str = "qk"
+    ds_sparse_decode_threshold: int = 4096
+
+    # LoRA
+    lora_paths: Optional[List[str]] = None
+    max_loras_per_batch: int = 8
+
+    # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
 
+    # Optimization/debug options
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
@@ -98,16 +113,16 @@ class ServerArgs:
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
+    disable_penalizer: bool = False
+    disable_nan_detection: bool = False
+    enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     max_torch_compile_bs: int = 32
     torchao_config: str = ""
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
-
-    # LoRA
-    lora_paths: Optional[List[str]] = None
-    max_loras_per_batch: int = 8
+    num_continuous_decode_steps: int = 1
 
     def __post_init__(self):
         # Set missing default values
@@ -223,6 +238,11 @@ class ServerArgs:
             '"dummy" will initialize the weights with random values, '
             "which is mainly for profiling.",
         )
+        parser.add_argument(
+            "--trust-remote-code",
+            action="store_true",
+            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
+        )
         parser.add_argument(
             "--dtype",
             type=str,
@@ -244,17 +264,6 @@ class ServerArgs:
             choices=["auto", "fp8_e5m2"],
             help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
         )
-        parser.add_argument(
-            "--trust-remote-code",
-            action="store_true",
-            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
-        )
-        parser.add_argument(
-            "--context-length",
-            type=int,
-            default=ServerArgs.context_length,
-            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
-        )
         parser.add_argument(
             "--quantization",
             type=str,
@@ -270,6 +279,19 @@ class ServerArgs:
             ],
             help="The quantization method.",
         )
+        parser.add_argument(
+            "--context-length",
+            type=int,
+            default=ServerArgs.context_length,
+            help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
+        )
+        parser.add_argument(
+            "--device",
+            type=str,
+            default="cuda",
+            choices=["cuda", "xpu"],
+            help="The device type.",
+        )
         parser.add_argument(
             "--served-model-name",
             type=str,
@@ -390,6 +412,11 @@ class ServerArgs:
             default=ServerArgs.file_storage_pth,
             help="The path of the file storage in backend.",
         )
+        parser.add_argument(
+            "--enable-cache-report",
+            action="store_true",
+            help="Return number of cached tokens in usage.prompt_tokens_details for each openai request.",
+        )
 
         # Data parallelism
         parser.add_argument(
@@ -432,7 +459,60 @@ class ServerArgs:
             default=ServerArgs.json_model_override_args,
         )
 
-        #
+        # Double Sparsity
+        parser.add_argument(
+            "--enable-double-sparsity",
+            action="store_true",
+            help="Enable double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-channel-config-path",
+            type=str,
+            default=ServerArgs.ds_channel_config_path,
+            help="The path of the double sparsity channel config",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-num",
+            type=int,
+            default=ServerArgs.ds_heavy_channel_num,
+            help="The number of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-token-num",
+            type=int,
+            default=ServerArgs.ds_heavy_token_num,
+            help="The number of heavy tokens in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-heavy-channel-type",
+            type=str,
+            default=ServerArgs.ds_heavy_channel_type,
+            help="The type of heavy channels in double sparsity attention",
+        )
+        parser.add_argument(
+            "--ds-sparse-decode-threshold",
+            type=int,
+            default=ServerArgs.ds_sparse_decode_threshold,
+            help="The type of heavy channels in double sparsity attention",
+        )
+
+        # LoRA
+        parser.add_argument(
+            "--lora-paths",
+            type=str,
+            nargs="*",
+            default=None,
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+        )
+        parser.add_argument(
+            "--max-loras-per-batch",
+            type=int,
+            default=8,
+            help="Maximum number of adapters for a running batch, include base-only request",
+        )
+
+        # Kernel backend
         parser.add_argument(
             "--attention-backend",
             type=str,
@@ -447,6 +527,8 @@ class ServerArgs:
             default=ServerArgs.sampling_backend,
             help="Choose the kernels for sampling layers.",
        )
+
+        # Optimization/debug options
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
@@ -493,6 +575,21 @@ class ServerArgs:
             action="store_true",
             help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
+        parser.add_argument(
+            "--disable-penalizer",
+            action="store_true",
+            help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
+        )
+        parser.add_argument(
+            "--disable-nan-detection",
+            action="store_true",
+            help="Disable the NaN detection for better performance.",
+        )
+        parser.add_argument(
+            "--enable-overlap-schedule",
+            action="store_true",
+            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
@@ -527,25 +624,12 @@ class ServerArgs:
             "This only affects Triton attention kernels.",
         )
         parser.add_argument(
-            "--
-            action="store_true",
-            help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
-        )
-
-        # LoRA options
-        parser.add_argument(
-            "--lora-paths",
-            type=str,
-            nargs="*",
-            default=None,
-            action=LoRAPathAction,
-            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
-        )
-        parser.add_argument(
-            "--max-loras-per-batch",
+            "--num-continuous-decode-steps",
             type=int,
-            default=
-            help="
+            default=ServerArgs.num_continuous_decode_steps,
+            help="Run multiple continuous decoding steps to reduce scheduling overhead. "
+            "This can potentially increase throughput but may also increase time-to-first-token latency. "
+            "The default value is 1, meaning only run one decoding step at a time.",
         )
 
     @classmethod
@@ -566,7 +650,7 @@ class ServerArgs:
             self.tp_size % self.nnodes == 0
         ), "tp_size must be divisible by number of nodes"
         assert not (
-            self.dp_size > 1 and self.
+            self.dp_size > 1 and self.nnodes != 1
         ), "multi-node data parallel is not supported"
         assert (
             self.max_loras_per_batch > 0
@@ -575,11 +659,6 @@ class ServerArgs:
             and (self.lora_paths is None or self.disable_radix_cache)
         ), "compatibility of lora and cuda graph and radix attention is in progress"
 
-        assert self.dp_size == 1, (
-            "The support for data parallelism is temporarily disabled during refactor. "
-            "Please use sglang<=0.3.2 or wait for later updates."
-        )
-
         if isinstance(self.lora_paths, list):
             lora_paths = self.lora_paths
             self.lora_paths = {}
@@ -618,11 +697,11 @@ class PortArgs:
     # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
     detokenizer_ipc_name: str
 
-    # The port for nccl initialization
-
+    # The port for nccl initialization (torch.dist)
+    nccl_port: int
 
-    @
-    def init_new(
+    @staticmethod
+    def init_new(server_args) -> "PortArgs":
         port = server_args.port + 1
         while True:
             if is_port_available(port):
@@ -633,7 +712,7 @@ class PortArgs:
             tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
             scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
             detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-
+            nccl_port=port,
         )
 
 
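For orientation, here is a minimal sketch of constructing `ServerArgs` with the fields added in 0.3.4. The field names are taken from the diff above; the constructor-style usage, the model path, and the specific values are illustrative assumptions, not recommended settings.

```python
from sglang.srt.server_args import ServerArgs

# Sketch only: field names come from the 0.3.4 diff above; the model path
# and the values chosen here are placeholder assumptions.
args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",  # assumed example model
    device="cuda",                   # new: "cuda" or "xpu"
    enable_cache_report=True,        # new: cached-token counts in OpenAI usage
    enable_double_sparsity=True,     # new: double sparsity attention
    ds_heavy_channel_num=32,         # new: default shown in the diff
    ds_heavy_token_num=256,          # new: default shown in the diff
    num_continuous_decode_steps=2,   # new: >1 batches decode steps per schedule
)
```

The same options are exposed on the command line as `--device`, `--enable-cache-report`, `--enable-double-sparsity`, the `--ds-*` family, and `--num-continuous-decode-steps`, per the argparse hunks above.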
sglang/srt/utils.py
CHANGED
@@ -35,7 +35,7 @@ import psutil
 import requests
 import torch
 import torch.distributed as dist
-from fastapi.responses import
+from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from torch import nn
 from torch.profiler import ProfilerActivity, profile, record_function
@@ -140,26 +140,41 @@ def calculate_time(show=False, min_cost_ms=0.0):
     return wrapper
 
 
-def get_available_gpu_memory(gpu_id, distributed=False):
+def get_available_gpu_memory(device, gpu_id, distributed=False):
     """
     Get available memory for cuda:gpu_id device.
     When distributed is True, the available memory is the minimum available memory of all GPUs.
     """
-
-
+    if device == "cuda":
+        num_gpus = torch.cuda.device_count()
+        assert gpu_id < num_gpus
+
+        if torch.cuda.current_device() != gpu_id:
+            print(
+                f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
+                "which may cause useless memory allocation for torch CUDA context.",
+            )
 
-
-
-            f"WARNING: current device is not {gpu_id}, but {torch.cuda.current_device()}, ",
-            "which may cause useless memory allocation for torch CUDA context.",
-        )
+        torch.cuda.empty_cache()
+        free_gpu_memory, _ = torch.cuda.mem_get_info(gpu_id)
 
-
-
+    elif device == "xpu":
+        num_gpus = torch.xpu.device_count()
+        assert gpu_id < num_gpus
+
+        if torch.xpu.current_device() != gpu_id:
+            print(
+                f"WARNING: current device is not {gpu_id}, but {torch.xpu.current_device()}, ",
+                "which may cause useless memory allocation for torch XPU context.",
+            )
+        torch.xpu.empty_cache()
+        used_memory = torch.xpu.memory_allocated()
+        total_gpu_memory = torch.xpu.get_device_properties(gpu_id).total_memory
+        free_gpu_memory = total_gpu_memory - used_memory
 
     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
-            torch.device(
+            torch.device(device, gpu_id)
         )
         torch.distributed.all_reduce(tensor, op=torch.distributed.ReduceOp.MIN)
         free_gpu_memory = tensor.item()
@@ -551,7 +566,7 @@ def add_api_key_middleware(app, api_key: str):
         if request.url.path.startswith("/health"):
             return await call_next(request)
         if request.headers.get("Authorization") != "Bearer " + api_key:
-            return
+            return ORJSONResponse(content={"error": "Unauthorized"}, status_code=401)
         return await call_next(request)
 
 
@@ -569,10 +584,11 @@ def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
 
 def configure_logger(server_args, prefix: str = ""):
     format = f"[%(asctime)s{prefix}] %(message)s"
+    # format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
         format=format,
-        datefmt="%H:%M:%S",
+        datefmt="%Y-%m-%d %H:%M:%S",
         force=True,
     )
 
@@ -675,3 +691,10 @@ def pytorch_profile(name, func, *args, data_size=-1):
     prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
     step_counter += 1
     return result
+
+
+def first_rank_print(*args, **kwargs):
+    if torch.cuda.current_device() == 0:
+        print(*args, **kwargs)
+    else:
+        pass
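A small sketch of the updated `get_available_gpu_memory` call, which now takes the device type as its first argument. The call site below is an assumed illustration rather than code from the package, and the unit of the returned value is not shown in this diff.

```python
import torch

from sglang.srt.utils import get_available_gpu_memory

# New 0.3.4 signature: the device type ("cuda" or "xpu") comes first.
if torch.cuda.is_available():
    free_mem = get_available_gpu_memory("cuda", gpu_id=0)
    # With distributed=True the result is min-reduced across ranks
    # (torch.distributed all_reduce with ReduceOp.MIN, as in the diff above).
    print("free memory reported for cuda:0:", free_mem)
```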
sglang/test/few_shot_gsm8k.py
CHANGED
@@ -76,7 +76,9 @@ def run_eval(args):
     def few_shot_gsm8k(s, question):
         s += few_shot_examples + question
         s += sgl.gen(
-            "answer",
+            "answer",
+            max_tokens=args.max_new_tokens,
+            stop=["Question", "Assistant:", "<|separator|>"],
         )
 
     #####################################
@@ -131,6 +133,7 @@ if __name__ == "__main__":
     parser.add_argument("--num-shots", type=int, default=5)
     parser.add_argument("--data-path", type=str, default="test.jsonl")
     parser.add_argument("--num-questions", type=int, default=200)
+    parser.add_argument("--max-new-tokens", type=int, default=512)
     parser.add_argument("--parallel", type=int, default=128)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)
sglang/test/few_shot_gsm8k_engine.py
ADDED
@@ -0,0 +1,144 @@
+import argparse
+import ast
+import asyncio
+import json
+import re
+import time
+
+import numpy as np
+
+import sglang as sgl
+from sglang.api import set_default_backend
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
+
+INVALID = -9999999
+
+
+def get_one_example(lines, i, include_answer):
+    ret = "Question: " + lines[i]["question"] + "\nAnswer:"
+    if include_answer:
+        ret += " " + lines[i]["answer"]
+    return ret
+
+
+def get_few_shot_examples(lines, k):
+    ret = ""
+    for i in range(k):
+        ret += get_one_example(lines, i, True) + "\n\n"
+    return ret
+
+
+def get_answer_value(answer_str):
+    answer_str = answer_str.replace(",", "")
+    numbers = re.findall(r"\d+", answer_str)
+    if len(numbers) < 1:
+        return INVALID
+    try:
+        return ast.literal_eval(numbers[-1])
+    except SyntaxError:
+        return INVALID
+
+
+async def concurrent_generate(engine, prompts, sampling_param):
+    tasks = []
+    for prompt in prompts:
+        tasks.append(asyncio.create_task(engine.async_generate(prompt, sampling_param)))
+
+    outputs = await asyncio.gather(*tasks)
+    return outputs
+
+
+def run_eval(args):
+    # Select backend
+    engine = sgl.Engine(model_path=args.model_path, log_level="error")
+
+    if args.local_data_path is None:
+        # Read data
+        url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+        filename = download_and_cache_file(url)
+    else:
+        filename = args.local_data_path
+
+    lines = list(read_jsonl(filename))
+
+    # Construct prompts
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
+
+    questions = []
+    labels = []
+    for i in range(len(lines[:num_questions])):
+        questions.append(get_one_example(lines, i, False))
+        labels.append(get_answer_value(lines[i]["answer"]))
+    assert all(l != INVALID for l in labels)
+    arguments = [{"question": q} for q in questions]
+
+    # construct the prompts
+    prompts = []
+    for i, arg in enumerate(arguments):
+        q = arg["question"]
+        prompt = few_shot_examples + q
+        prompts.append(prompt)
+
+    sampling_param = {
+        "stop": ["Question", "Assistant:", "<|separator|>"],
+        "max_new_tokens": 512,
+        "temperature": 0,
+    }
+
+    # Run requests
+    tic = time.time()
+
+    loop = asyncio.get_event_loop()
+
+    outputs = loop.run_until_complete(
+        concurrent_generate(engine, prompts, sampling_param)
+    )
+
+    # End requests
+    latency = time.time() - tic
+
+    # Shutdown the engine
+    engine.shutdown()
+
+    # Parse output
+    preds = []
+
+    for output in outputs:
+        preds.append(get_answer_value(output["text"]))
+
+    # Compute accuracy
+    acc = np.mean(np.array(preds) == np.array(labels))
+    invalid = np.mean(np.array(preds) == INVALID)
+
+    # Compute speed
+    num_output_tokens = sum(
+        output["meta_info"]["completion_tokens"] for output in outputs
+    )
+    output_throughput = num_output_tokens / latency
+
+    # Print results
+    print(f"Accuracy: {acc:.3f}")
+    print(f"Invalid: {invalid:.3f}")
+    print(f"Latency: {latency:.3f} s")
+    print(f"Output throughput: {output_throughput:.3f} token/s")
+
+    return {
+        "accuracy": acc,
+        "latency": latency,
+        "output_throughput": output_throughput,
+    }
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model-path", type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct"
+    )
+    parser.add_argument("--local-data-path", type=Optional[str], default=None)
+    parser.add_argument("--num-shots", type=int, default=5)
+    parser.add_argument("--num-questions", type=int, default=200)
+    args = parser.parse_args()
+    metrics = run_eval(args)
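Because the new script exposes `run_eval(args)`, it can also be invoked programmatically; the snippet below mirrors the script's own `__main__` block, with the model path and question count as placeholder assumptions.

```python
from types import SimpleNamespace

from sglang.test.few_shot_gsm8k_engine import run_eval

# Mirrors the argparse defaults above; the values are illustrative only.
args = SimpleNamespace(
    model_path="meta-llama/Meta-Llama-3.1-8B-Instruct",
    local_data_path=None,  # None -> download the GSM8K test split
    num_shots=5,
    num_questions=20,      # smaller than the default 200 for a quick run
)
metrics = run_eval(args)   # {"accuracy": ..., "latency": ..., "output_throughput": ...}
print(metrics)
```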
sglang/test/srt/sampling/penaltylib/utils.py
CHANGED
@@ -164,19 +164,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
                     msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                 )
 
-
-
-
-
-                    device=self.device,
-                )
+            original = torch.ones(
+                size=(len(case.test_subjects), self.vocab_size),
+                dtype=torch.float32,
+                device=self.device,
             )
+            actual = orchestrator.apply(original.clone())
             expected = torch.cat(
                 tensors=[
                     subject.steps[0].expected_logits
                     for subject in case.test_subjects
                 ],
             )
+            if actual is None:
+                actual = original
             torch.testing.assert_close(
                 actual=actual,
                 expected=expected,
@@ -226,6 +227,8 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
                     device=self.device,
                 )
             )
+            if actual_logits is None:
+                continue
             filtered_expected_logits = torch.cat(
                 tensors=[
                     subject.steps[0].expected_logits
@@ -317,19 +320,20 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
                     msg=f"key={key}\nactual={getattr(penalizer, key)}\nexpected={tensor}",
                 )
 
-
-
-
-
-                    device=self.device,
-                )
+            original = torch.ones(
+                size=(len(filtered_subjects), self.vocab_size),
+                dtype=torch.float32,
+                device=self.device,
             )
+            actual_logits = orchestrator.apply(original.clone())
             filtered_expected_logits = torch.cat(
                 tensors=[
                     subject.steps[i].expected_logits
                     for subject in filtered_subjects
                 ],
             )
+            if actual_logits is None:
+                actual_logits = original
             torch.testing.assert_close(
                 actual=actual_logits,
                 expected=filtered_expected_logits,
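The test changes above guard against `orchestrator.apply()` returning `None` and fall back to the unmodified logits in that case. A defensive helper in the same spirit, written as an assumption drawn only from this test diff rather than from the orchestrator's documented contract:

```python
import torch


def apply_penalties(orchestrator, logits: torch.Tensor) -> torch.Tensor:
    # Assumption based on the updated tests: a None result means there was
    # nothing to apply (or the work happened in place), so the input logits
    # are used as-is.
    out = orchestrator.apply(logits)
    return logits if out is None else out
```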
sglang/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.3.3"
+__version__ = "0.3.4"