sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_latency.py +17 -8
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +5 -17
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -4
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +33 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/activation.py +12 -0
  15. sglang/srt/layers/attention_backend.py +480 -0
  16. sglang/srt/layers/flashinfer_utils.py +235 -0
  17. sglang/srt/layers/fused_moe/layer.py +27 -7
  18. sglang/srt/layers/layernorm.py +12 -0
  19. sglang/srt/layers/logits_processor.py +64 -77
  20. sglang/srt/layers/radix_attention.py +11 -161
  21. sglang/srt/layers/sampler.py +38 -122
  22. sglang/srt/layers/torchao_utils.py +75 -0
  23. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  24. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  25. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  26. sglang/srt/lora/lora.py +403 -0
  27. sglang/srt/lora/lora_config.py +43 -0
  28. sglang/srt/lora/lora_manager.py +259 -0
  29. sglang/srt/managers/controller_multi.py +1 -5
  30. sglang/srt/managers/controller_single.py +0 -5
  31. sglang/srt/managers/io_struct.py +16 -1
  32. sglang/srt/managers/policy_scheduler.py +122 -5
  33. sglang/srt/managers/schedule_batch.py +105 -71
  34. sglang/srt/managers/tokenizer_manager.py +17 -8
  35. sglang/srt/managers/tp_worker.py +188 -121
  36. sglang/srt/model_executor/cuda_graph_runner.py +69 -133
  37. sglang/srt/model_executor/forward_batch_info.py +35 -312
  38. sglang/srt/model_executor/model_runner.py +123 -154
  39. sglang/srt/models/baichuan.py +416 -0
  40. sglang/srt/models/chatglm.py +1 -5
  41. sglang/srt/models/commandr.py +1 -5
  42. sglang/srt/models/dbrx.py +1 -5
  43. sglang/srt/models/deepseek.py +1 -5
  44. sglang/srt/models/deepseek_v2.py +7 -6
  45. sglang/srt/models/exaone.py +1 -5
  46. sglang/srt/models/gemma.py +1 -5
  47. sglang/srt/models/gemma2.py +1 -5
  48. sglang/srt/models/gpt_bigcode.py +1 -5
  49. sglang/srt/models/grok.py +1 -5
  50. sglang/srt/models/internlm2.py +1 -5
  51. sglang/srt/models/llama.py +51 -5
  52. sglang/srt/models/llama_classification.py +1 -20
  53. sglang/srt/models/llava.py +30 -5
  54. sglang/srt/models/llavavid.py +2 -2
  55. sglang/srt/models/minicpm.py +1 -5
  56. sglang/srt/models/minicpm3.py +669 -0
  57. sglang/srt/models/mixtral.py +6 -5
  58. sglang/srt/models/mixtral_quant.py +1 -5
  59. sglang/srt/models/olmoe.py +415 -0
  60. sglang/srt/models/qwen.py +1 -5
  61. sglang/srt/models/qwen2.py +1 -5
  62. sglang/srt/models/qwen2_moe.py +6 -5
  63. sglang/srt/models/stablelm.py +1 -5
  64. sglang/srt/models/xverse.py +375 -0
  65. sglang/srt/models/xverse_moe.py +445 -0
  66. sglang/srt/openai_api/adapter.py +65 -46
  67. sglang/srt/openai_api/protocol.py +11 -3
  68. sglang/srt/sampling/sampling_batch_info.py +46 -80
  69. sglang/srt/server.py +30 -15
  70. sglang/srt/server_args.py +163 -28
  71. sglang/srt/utils.py +19 -51
  72. sglang/test/few_shot_gsm8k.py +132 -0
  73. sglang/test/runners.py +114 -22
  74. sglang/test/test_programs.py +7 -5
  75. sglang/test/test_utils.py +85 -2
  76. sglang/utils.py +32 -37
  77. sglang/version.py +1 -1
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
  79. sglang-0.3.1.post1.dist-info/RECORD +130 -0
  80. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
  81. sglang-0.3.0.dist-info/RECORD +0 -118
  82. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
  83. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -21,9 +21,22 @@ import logging
 import random
 from typing import List, Optional, Union
 
+from sglang.srt.utils import is_hip
+
 logger = logging.getLogger(__name__)
 
 
+class LoRAPathAction(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, {})
+        for lora_path in values:
+            if "=" in lora_path:
+                name, path = lora_path.split("=", 1)
+                getattr(namespace, self.dest)[name] = path
+            else:
+                getattr(namespace, self.dest)[lora_path] = lora_path
+
+
 @dataclasses.dataclass
 class ServerArgs:
     # Model and tokenizer
@@ -49,7 +62,6 @@ class ServerArgs:
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
     max_running_requests: Optional[int] = None
-    max_num_reqs: Optional[int] = None
    max_total_tokens: Optional[int] = None
     chunked_prefill_size: int = 8192
     max_prefill_tokens: int = 16384
@@ -60,6 +72,7 @@ class ServerArgs:
     tp_size: int = 1
     stream_interval: int = 1
     random_seed: Optional[int] = None
+    constrained_json_whitespace_pattern: Optional[str] = None
 
     # Logging
     log_level: str = "info"
@@ -75,7 +88,18 @@ class ServerArgs:
     dp_size: int = 1
     load_balance_method: str = "round_robin"
 
+    # Distributed args
+    nccl_init_addr: Optional[str] = None
+    nnodes: int = 1
+    node_rank: Optional[int] = None
+
+    # Model override args in JSON
+    json_model_override_args: str = "{}"
+
     # Optimization/debug options
+    attention_backend: Optional[str] = None
+    sampling_backend: Optional[str] = None
+
     disable_flashinfer: bool = False
     disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
@@ -86,16 +110,18 @@ class ServerArgs:
     disable_custom_all_reduce: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
+    max_torch_compile_bs: int = 32
+    torchao_config: str = ""
     enable_p2p_check: bool = False
     enable_mla: bool = False
     triton_attention_reduce_in_fp32: bool = False
 
-    # Distributed args
-    nccl_init_addr: Optional[str] = None
-    nnodes: int = 1
-    node_rank: Optional[int] = None
+    # LoRA
+    lora_paths: Optional[List[str]] = None
+    max_loras_per_batch: int = 8
 
     def __post_init__(self):
+        # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
 
@@ -106,6 +132,7 @@ class ServerArgs:
             # Disable chunked prefill
             self.chunked_prefill_size = None
 
+        # Mem fraction depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
@@ -126,6 +153,47 @@ class ServerArgs:
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)
 
+        # Deprecation warnings
+        if self.disable_flashinfer:
+            logger.warning(
+                "The option '--disable-flashinfer' will be deprecated in the next release. "
+                "Please use '--attention-backend triton' instead."
+            )
+            self.attention_backend = "triton"
+        if self.disable_flashinfer_sampling:
+            logger.warning(
+                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
+                "Please use '--sampling-backend pytorch' instead. "
+            )
+            self.sampling_backend = "pytorch"
+
+        # ROCm: flashinfer available later
+        if is_hip():
+            self.attention_backend = "triton"
+            self.sampling_backend = "pytorch"
+
+        # Default kernel backends
+        if self.enable_mla:
+            logger.info("MLA optimization is tunred on. Use triton backend.")
+            self.attention_backend = "triton"
+
+        if self.attention_backend is None:
+            self.attention_backend = "flashinfer"
+
+        if self.sampling_backend is None:
+            self.sampling_backend = "flashinfer"
+
+        # Model-specific patches
+        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+            logger.info(
+                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+            )
+            self.trust_remote_code = False
+
+        if "gemma-2" in self.model_path.lower():
+            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
+            self.attention_backend = "flashinfer"
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
@@ -209,11 +277,6 @@ class ServerArgs:
             action="store_true",
             help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
         )
-        parser.add_argument(
-            "--is-embedding",
-            action="store_true",
-            help="Whether to use a CausalLM as an embedding model.",
-        )
         parser.add_argument(
             "--context-length",
             type=int,
@@ -248,6 +311,11 @@ class ServerArgs:
             default=ServerArgs.chat_template,
             help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
         )
+        parser.add_argument(
+            "--is-embedding",
+            action="store_true",
+            help="Whether to use a CausalLM as an embedding model.",
+        )
         parser.add_argument(
             "--mem-fraction-static",
             type=float,
@@ -260,17 +328,12 @@ class ServerArgs:
             default=ServerArgs.max_running_requests,
             help="The maximum number of running requests.",
         )
-        parser.add_argument(
-            "--max-num-reqs",
-            type=int,
-            default=ServerArgs.max_num_reqs,
-            help="The maximum number of requests to serve in the memory pool. If the model have a large context length, you may need to decrease this value to avoid out-of-memory errors.",
-        )
         parser.add_argument(
             "--max-total-tokens",
            type=int,
             default=ServerArgs.max_total_tokens,
-            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. This option is typically used for development and debugging purposes.",
+            help="The maximum number of tokens in the memory pool. If not specified, it will be automatically calculated based on the memory usage fraction. "
+            "This option is typically used for development and debugging purposes.",
         )
         parser.add_argument(
             "--chunked-prefill-size",
@@ -316,6 +379,12 @@ class ServerArgs:
             default=ServerArgs.random_seed,
             help="The random seed.",
         )
+        parser.add_argument(
+            "--constrained-json-whitespace-pattern",
+            type=str,
+            default=ServerArgs.constrained_json_whitespace_pattern,
+            help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
+        )
         parser.add_argument(
             "--log-level",
             type=str,
@@ -381,16 +450,38 @@ class ServerArgs:
         )
         parser.add_argument("--node-rank", type=int, help="The node rank.")
 
+        # Model override args
+        parser.add_argument(
+            "--json-model-override-args",
+            type=str,
+            help="A dictionary in JSON string format used to override default model configurations.",
+            default=ServerArgs.json_model_override_args,
+        )
+
         # Optimization/debug options
+        parser.add_argument(
+            "--attention-backend",
+            type=str,
+            choices=["flashinfer", "triton"],
+            default=ServerArgs.attention_backend,
+            help="Choose the kernels for attention layers.",
+        )
+        parser.add_argument(
+            "--sampling-backend",
+            type=str,
+            choices=["flashinfer", "pytorch"],
+            default=ServerArgs.sampling_backend,
+            help="Choose the kernels for sampling layers.",
+        )
         parser.add_argument(
             "--disable-flashinfer",
             action="store_true",
-            help="Disable flashinfer attention kernels.",
+            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
         )
         parser.add_argument(
             "--disable-flashinfer-sampling",
             action="store_true",
-            help="Disable flashinfer sampling kernels.",
+            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
         )
         parser.add_argument(
             "--disable-radix-cache",
@@ -431,7 +522,19 @@ class ServerArgs:
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
-            help="Optimize the model with torch.compile, experimental feature.",
+            help="Optimize the model with torch.compile. Experimental feature.",
+        )
+        parser.add_argument(
+            "--max-torch-compile-bs",
+            type=int,
+            default=ServerArgs.max_torch_compile_bs,
+            help="Set the maximum batch size when using torch compile.",
+        )
+        parser.add_argument(
+            "--torchao-config",
+            type=str,
+            default=ServerArgs.torchao_config,
+            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
         )
         parser.add_argument(
             "--enable-p2p-check",
@@ -455,6 +558,22 @@ class ServerArgs:
             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
         )
 
+        # LoRA options
+        parser.add_argument(
+            "--lora-paths",
+            type=str,
+            nargs="*",
+            default=None,
+            action=LoRAPathAction,
+            help="The list of LoRA adapters. You can provide a list of either path in str or renamed path in the format {name}={path}",
+        )
+        parser.add_argument(
+            "--max-loras-per-batch",
+            type=int,
+            default=8,
+            help="Maximum number of adapters for a running batch, include base-only request",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
@@ -472,14 +591,30 @@ class ServerArgs:
         assert not (
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
-        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
-            logger.info(
-                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
-            )
-            self.trust_remote_code = False
-        if "gemma-2" in self.model_path.lower():
-            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
-            self.disable_flashinfer = False
+        assert (
+            self.max_loras_per_batch > 0
+            # FIXME
+            and (self.lora_paths is None or self.disable_cuda_graph)
+            and (self.lora_paths is None or self.disable_radix_cache)
+        ), "compatibility of lora and cuda graph and radix attention is in progress"
+
+
+def prepare_server_args(argv: List[str]) -> ServerArgs:
+    """
+    Prepare the server arguments from the command line arguments.
+
+    Args:
+        args: The command line arguments. Typically, it should be `sys.argv[1:]`
+            to ensure compatibility with `parse_args` when no arguments are passed.
+
+    Returns:
+        The server arguments.
+    """
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    raw_args = parser.parse_args(argv)
+    server_args = ServerArgs.from_cli_args(raw_args)
+    return server_args
 
 
 @dataclasses.dataclass
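
The new LoRAPathAction and prepare_server_args make these options easy to exercise programmatically. Below is a minimal sketch (not part of the diff), assuming sglang 0.3.1.post1 is installed; the model name and adapter paths are placeholders chosen only for illustration.

# Hypothetical values used only to illustrate the new CLI surface.
from sglang.srt.server_args import prepare_server_args

server_args = prepare_server_args(
    [
        "--model-path", "meta-llama/Meta-Llama-3-8B-Instruct",
        "--attention-backend", "triton",    # new flag replacing --disable-flashinfer
        "--sampling-backend", "pytorch",    # new flag replacing --disable-flashinfer-sampling
        "--lora-paths", "sql=/models/lora-sql", "/models/lora-chat",
        # Per the new assert above, LoRA currently has to be combined with
        # --disable-cuda-graph and --disable-radix-cache.
        "--disable-cuda-graph", "--disable-radix-cache",
    ]
)
print(server_args.attention_backend)  # "triton"
print(server_args.lora_paths)
# {"sql": "/models/lora-sql", "/models/lora-chat": "/models/lora-chat"}

LoRAPathAction collects the --lora-paths values into a {name: path} dict, using the path itself as the name when no "name=" prefix is given.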
sglang/srt/utils.py CHANGED
@@ -35,6 +35,7 @@ import torch
 import torch.distributed as dist
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
+from torch import nn
 from torch.nn.parameter import Parameter
 from triton.runtime.cache import (
     FileCacheManager,
@@ -50,6 +51,11 @@ show_time_cost = False
 time_infos = {}
 
 
+# torch flag AMD GPU
+def is_hip() -> bool:
+    return torch.version.hip is not None
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
@@ -186,7 +192,7 @@ def allocate_init_ports(
         cur_port += 1
 
     if port is not None and ret_ports[0] != port:
-        logger.warn(
+        logger.warning(
            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
         )
 
@@ -622,56 +628,7 @@ def set_ulimit(target_soft_limit=65535):
     try:
         resource.setrlimit(resource_type, (target_soft_limit, current_hard))
     except ValueError as e:
-        logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
-
-
-def is_llama3_405b_fp8_head_16(model_config):
-    """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
-    if (
-        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
-        and model_config.hf_config.hidden_size == 16384
-        and model_config.hf_config.intermediate_size == 53248
-        and model_config.hf_config.num_hidden_layers == 126
-        and model_config.hf_config.num_key_value_heads == 16
-        and hasattr(model_config.hf_config, "quantization_config")
-        and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
-    ):
-        return True
-    return False
-
-
-def monkey_patch_vllm_qvk_linear_loader():
-    """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
-    from vllm.model_executor.layers.linear import QKVParallelLinear
-
-    origin_weight_loader = QKVParallelLinear.weight_loader
-
-    def get_original_weight(loaded_weight, head_dim):
-        n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
-        dim = loaded_weight.shape[1]
-        for i in range(n_kv_head):
-            loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
-                2 * i * head_dim : (2 * i + 1) * head_dim, :
-            ]
-        original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
-        assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
-        return original_kv_weight
-
-    def weight_loader_srt(
-        self,
-        param: Parameter,
-        loaded_weight: torch.Tensor,
-        loaded_shard_id: Optional[str] = None,
-    ):
-        if (
-            loaded_shard_id in ["k", "v"]
-            and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
-        ):
-            loaded_weight = get_original_weight(loaded_weight, self.head_size)
-
-        origin_weight_loader(self, param, loaded_weight, loaded_shard_id)
-
-    setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
+        logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")
 
 
 def add_api_key_middleware(app, api_key: str):
@@ -714,3 +671,14 @@ def configure_logger(server_args, prefix: str = ""):
         datefmt="%H:%M:%S",
         force=True,
     )
+
+
+# source: https://github.com/vllm-project/vllm/blob/93b38bea5dd03e1b140ca997dfaadef86f8f1855/vllm/lora/utils.py#L9
+def replace_submodule(
+    model: nn.Module, module_name: str, new_module: nn.Module
+) -> nn.Module:
+    """Replace a submodule in a model with a new module."""
+    parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
+    target_name = module_name.split(".")[-1]
+    setattr(parent, target_name, new_module)
+    return new_module
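
The new replace_submodule helper (ported from vLLM's LoRA utils) swaps a child module in place by its dotted name, which the new LoRA manager presumably uses to wrap target layers. A minimal sketch, assuming sglang 0.3.1.post1 is installed; TinyBlock is a made-up toy module.

import torch.nn as nn

from sglang.srt.utils import replace_submodule

class TinyBlock(nn.Module):
    # Toy module invented for this example.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

model = nn.Sequential()
model.add_module("block", TinyBlock())

# Replace the nested linear addressed by its dotted name "block.proj".
new_proj = nn.Linear(8, 8, bias=False)
replace_submodule(model, "block.proj", new_proj)
assert model.block.proj is new_proj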
sglang/test/few_shot_gsm8k.py ADDED
@@ -0,0 +1,132 @@
+"""
+Run few-shot GSM-8K evaluation.
+
+Usage:
+python3 -m sglang.test.few_shot_gsm8k --num-questions 200
+"""
+
+import argparse
+import ast
+import re
+import time
+
+import numpy as np
+
+from sglang.api import set_default_backend
+from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
+from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl
+
+INVALID = -9999999
+
+
+def get_one_example(lines, i, include_answer):
+    ret = "Question: " + lines[i]["question"] + "\nAnswer:"
+    if include_answer:
+        ret += " " + lines[i]["answer"]
+    return ret
+
+
+def get_few_shot_examples(lines, k):
+    ret = ""
+    for i in range(k):
+        ret += get_one_example(lines, i, True) + "\n\n"
+    return ret
+
+
+def get_answer_value(answer_str):
+    answer_str = answer_str.replace(",", "")
+    numbers = re.findall(r"\d+", answer_str)
+    if len(numbers) < 1:
+        return INVALID
+    try:
+        return ast.literal_eval(numbers[-1])
+    except SyntaxError:
+        return INVALID
+
+
+def main(args):
+    # Select backend
+    set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))
+
+    # Read data
+    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+    filename = download_and_cache_file(url)
+    lines = list(read_jsonl(filename))
+
+    # Construct prompts
+    num_questions = args.num_questions
+    num_shots = args.num_shots
+    few_shot_examples = get_few_shot_examples(lines, num_shots)
+
+    questions = []
+    labels = []
+    for i in range(len(lines[:num_questions])):
+        questions.append(get_one_example(lines, i, False))
+        labels.append(get_answer_value(lines[i]["answer"]))
+    assert all(l != INVALID for l in labels)
+    arguments = [{"question": q} for q in questions]
+
+    #####################################
+    ######### SGL Program Begin #########
+    #####################################
+
+    import sglang as sgl
+
+    @sgl.function
+    def few_shot_gsm8k(s, question):
+        s += few_shot_examples + question
+        s += sgl.gen(
+            "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
+        )
+
+    #####################################
+    ########## SGL Program End ##########
+    #####################################
+
+    # Run requests
+    tic = time.time()
+    states = few_shot_gsm8k.run_batch(
+        arguments,
+        temperature=0,
+        num_threads=args.parallel,
+        progress_bar=True,
+    )
+    latency = time.time() - tic
+
+    preds = []
+    for i in range(len(states)):
+        preds.append(get_answer_value(states[i]["answer"]))
+
+    # print(f"{preds=}")
+    # print(f"{labels=}")
+
+    # Compute accuracy
+    acc = np.mean(np.array(preds) == np.array(labels))
+    invalid = np.mean(np.array(preds) == INVALID)
+
+    # Compute speed
+    num_output_tokens = sum(
+        s.get_meta_info("answer")["completion_tokens"] for s in states
+    )
+    output_throughput = num_output_tokens / latency
+
+    # Print results
+    print(f"Accuracy: {acc:.3f}")
+    print(f"Invalid: {invalid:.3f}")
+    print(f"Latency: {latency:.3f} s")
+    print(f"Output throughput: {output_throughput:.3f} token/s")
+
+    # Dump results
+    dump_state_text("tmp_output_gsm8k.txt", states)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num-shots", type=int, default=5)
+    parser.add_argument("--data-path", type=str, default="test.jsonl")
+    parser.add_argument("--num-questions", type=int, default=200)
+    parser.add_argument("--parallel", type=int, default=128)
+    parser.add_argument("--host", type=str, default="http://127.0.0.1")
+    parser.add_argument("--port", type=int, default=30000)
+    args = parser.parse_args()
+    main(args)
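
The answer parsing in this new evaluation script strips commas and takes the last integer in the completion, falling back to the INVALID sentinel when nothing parses. A minimal sketch, assuming sglang 0.3.1.post1 is installed so the module above is importable:

from sglang.test.few_shot_gsm8k import INVALID, get_answer_value

# Commas are removed and the last number wins.
assert get_answer_value("... so the total is 1,234 apples.") == 1234
# No digits at all yields the INVALID sentinel.
assert get_answer_value("I am not sure.") == INVALID

To run the benchmark end to end, start a server first (for example: python3 -m sglang.launch_server --model-path <model> --port 30000) and then invoke the command from the module docstring above.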