sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/api.py +13 -1
  2. sglang/bench_latency.py +10 -5
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/global_config.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +60 -49
  7. sglang/lang/chat_template.py +10 -5
  8. sglang/lang/compiler.py +4 -0
  9. sglang/lang/interpreter.py +5 -2
  10. sglang/lang/ir.py +22 -4
  11. sglang/launch_server.py +8 -1
  12. sglang/srt/constrained/jump_forward.py +13 -2
  13. sglang/srt/conversation.py +50 -1
  14. sglang/srt/hf_transformers_utils.py +22 -23
  15. sglang/srt/layers/activation.py +24 -2
  16. sglang/srt/layers/decode_attention.py +338 -50
  17. sglang/srt/layers/extend_attention.py +3 -1
  18. sglang/srt/layers/fused_moe/__init__.py +1 -0
  19. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  20. sglang/srt/layers/fused_moe/layer.py +587 -0
  21. sglang/srt/layers/layernorm.py +3 -0
  22. sglang/srt/layers/logits_processor.py +64 -27
  23. sglang/srt/layers/radix_attention.py +41 -18
  24. sglang/srt/layers/sampler.py +154 -0
  25. sglang/srt/managers/controller_multi.py +2 -8
  26. sglang/srt/managers/controller_single.py +7 -10
  27. sglang/srt/managers/detokenizer_manager.py +20 -9
  28. sglang/srt/managers/io_struct.py +44 -11
  29. sglang/srt/managers/policy_scheduler.py +5 -2
  30. sglang/srt/managers/schedule_batch.py +59 -179
  31. sglang/srt/managers/tokenizer_manager.py +193 -84
  32. sglang/srt/managers/tp_worker.py +131 -50
  33. sglang/srt/mem_cache/memory_pool.py +82 -8
  34. sglang/srt/mm_utils.py +79 -7
  35. sglang/srt/model_executor/cuda_graph_runner.py +97 -28
  36. sglang/srt/model_executor/forward_batch_info.py +188 -82
  37. sglang/srt/model_executor/model_runner.py +269 -87
  38. sglang/srt/models/chatglm.py +6 -14
  39. sglang/srt/models/commandr.py +6 -2
  40. sglang/srt/models/dbrx.py +5 -1
  41. sglang/srt/models/deepseek.py +7 -3
  42. sglang/srt/models/deepseek_v2.py +12 -7
  43. sglang/srt/models/gemma.py +6 -2
  44. sglang/srt/models/gemma2.py +22 -8
  45. sglang/srt/models/gpt_bigcode.py +5 -1
  46. sglang/srt/models/grok.py +66 -398
  47. sglang/srt/models/internlm2.py +5 -1
  48. sglang/srt/models/llama2.py +7 -3
  49. sglang/srt/models/llama_classification.py +2 -2
  50. sglang/srt/models/llama_embedding.py +4 -0
  51. sglang/srt/models/llava.py +176 -59
  52. sglang/srt/models/minicpm.py +7 -3
  53. sglang/srt/models/mixtral.py +61 -255
  54. sglang/srt/models/mixtral_quant.py +6 -5
  55. sglang/srt/models/qwen.py +7 -4
  56. sglang/srt/models/qwen2.py +15 -5
  57. sglang/srt/models/qwen2_moe.py +7 -16
  58. sglang/srt/models/stablelm.py +6 -2
  59. sglang/srt/openai_api/adapter.py +149 -58
  60. sglang/srt/sampling/sampling_batch_info.py +209 -0
  61. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
  62. sglang/srt/server.py +107 -71
  63. sglang/srt/server_args.py +49 -15
  64. sglang/srt/utils.py +27 -18
  65. sglang/test/runners.py +38 -38
  66. sglang/test/simple_eval_common.py +9 -10
  67. sglang/test/simple_eval_gpqa.py +2 -1
  68. sglang/test/simple_eval_humaneval.py +2 -2
  69. sglang/test/simple_eval_math.py +2 -1
  70. sglang/test/simple_eval_mmlu.py +2 -1
  71. sglang/test/test_activation.py +55 -0
  72. sglang/test/test_programs.py +32 -5
  73. sglang/test/test_utils.py +37 -50
  74. sglang/version.py +1 -1
  75. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
  76. sglang-0.2.14.dist-info/RECORD +114 -0
  77. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
  78. sglang/launch_server_llavavid.py +0 -29
  79. sglang/srt/model_loader/model_loader.py +0 -292
  80. sglang/srt/model_loader/utils.py +0 -275
  81. sglang-0.2.12.dist-info/RECORD +0 -112
  82. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
  83. {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -17,9 +17,12 @@ limitations under the License.
 
 import argparse
 import dataclasses
+import logging
 import random
 from typing import List, Optional, Union
 
+logger = logging.getLogger(__name__)
+
 
 @dataclasses.dataclass
 class ServerArgs:
@@ -30,11 +33,13 @@ class ServerArgs:
     skip_tokenizer_init: bool = False
     load_format: str = "auto"
     dtype: str = "auto"
+    kv_cache_dtype: str = "auto"
     trust_remote_code: bool = True
     context_length: Optional[int] = None
     quantization: Optional[str] = None
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
+    is_embedding: bool = False
 
     # Port
     host: str = "127.0.0.1"
@@ -46,7 +51,7 @@ class ServerArgs:
     max_running_requests: Optional[int] = None
     max_num_reqs: Optional[int] = None
     max_total_tokens: Optional[int] = None
-    chunked_prefill_size: int = -1
+    chunked_prefill_size: int = 8192
     max_prefill_tokens: int = 16384
     schedule_policy: str = "lpm"
     schedule_conservativeness: float = 1.0
@@ -76,12 +81,14 @@ class ServerArgs:
     disable_radix_cache: bool = False
     disable_regex_jump_forward: bool = False
     disable_cuda_graph: bool = False
+    disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
+    disable_custom_all_reduce: bool = False
+    enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
     enable_mla: bool = False
-    attention_reduce_in_fp32: bool = False
-    efficient_weight_load: bool = False
+    triton_attention_reduce_in_fp32: bool = False
 
     # Distributed args
     nccl_init_addr: Optional[str] = None
@@ -190,11 +197,23 @@ class ServerArgs:
             '* "float" is shorthand for FP32 precision.\n'
             '* "float32" for FP32 precision.',
         )
+        parser.add_argument(
+            "--kv-cache-dtype",
+            type=str,
+            default=ServerArgs.kv_cache_dtype,
+            choices=["auto", "fp8_e5m2"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+        )
        parser.add_argument(
            "--trust-remote-code",
            action="store_true",
            help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
        )
+        parser.add_argument(
+            "--is-embedding",
+            action="store_true",
+            help="Whether to use a CausalLM as an embedding model.",
+        )
        parser.add_argument(
            "--context-length",
            type=int,
@@ -388,11 +407,27 @@ class ServerArgs:
            action="store_true",
            help="Disable cuda graph.",
        )
+        parser.add_argument(
+            "--disable-cuda-graph-padding",
+            action="store_true",
+            help="Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed.",
+        )
        parser.add_argument(
            "--disable-disk-cache",
            action="store_true",
            help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
        )
+        parser.add_argument(
+            "--disable-custom-all-reduce",
+            action="store_true",
+            default=False,
+            help="Disable the custom all-reduce kernel and fall back to NCCL.",
+        )
+        parser.add_argument(
+            "--enable-mixed-chunk",
+            action="store_true",
+            help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
+        )
        parser.add_argument(
            "--enable-torch-compile",
            action="store_true",
@@ -406,13 +441,13 @@ class ServerArgs:
        parser.add_argument(
            "--enable-mla",
            action="store_true",
-            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
        parser.add_argument(
-            "--attention-reduce-in-fp32",
+            "--triton-attention-reduce-in-fp32",
            action="store_true",
            help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
-            "This only affects Triton attention kernels",
+            "This only affects Triton attention kernels.",
        )
        parser.add_argument(
            "--efficient-weight-load",
@@ -430,15 +465,6 @@ class ServerArgs:
    def url(self):
        return f"http://{self.host}:{self.port}"
 
-    def print_mode_args(self):
-        return (
-            f"disable_flashinfer={self.disable_flashinfer}, "
-            f"attention_reduce_in_fp32={self.attention_reduce_in_fp32}, "
-            f"disable_radix_cache={self.disable_radix_cache}, "
-            f"disable_regex_jump_forward={self.disable_regex_jump_forward}, "
-            f"disable_disk_cache={self.disable_disk_cache}, "
-        )
-
    def check_server_args(self):
        assert (
            self.tp_size % self.nnodes == 0
@@ -446,6 +472,14 @@ class ServerArgs:
        assert not (
            self.dp_size > 1 and self.node_rank is not None
        ), "multi-node data parallel is not supported"
+        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+            logger.info(
+                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+            )
+            self.trust_remote_code = False
+        if "gemma-2" in self.model_path.lower():
+            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
+            self.disable_flashinfer = False
 
 
 @dataclasses.dataclass
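
Note: the hunks above add several new server options (kv_cache_dtype, is_embedding, disable_cuda_graph_padding, disable_custom_all_reduce, enable_mixed_chunk), change the chunked_prefill_size default to 8192, and rename attention_reduce_in_fp32 to triton_attention_reduce_in_fp32. A minimal sketch of setting the new fields programmatically; the model path is a placeholder and every field not shown is assumed to keep its default:

    from sglang.srt.server_args import ServerArgs

    # Hypothetical configuration exercising the fields introduced in 0.2.14.
    args = ServerArgs(
        model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
        kv_cache_dtype="fp8_e5m2",       # new: FP8 E5M2 KV cache (CUDA 11.8+)
        chunked_prefill_size=8192,       # new default (previously -1, i.e. disabled)
        enable_mixed_chunk=True,         # new: mix prefill and decode in one batch
        disable_custom_all_reduce=True,  # new: fall back to NCCL all-reduce
        is_embedding=False,              # new: set True to serve a CausalLM as an embedding model
    )
    args.check_server_args()
    print(args.url())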
sglang/srt/utils.py CHANGED
@@ -35,7 +35,6 @@ import torch
 import torch.distributed as dist
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from starlette.middleware.base import BaseHTTPMiddleware
 from torch.nn.parameter import Parameter
 from triton.runtime.cache import (
     FileCacheManager,
@@ -225,13 +224,18 @@ def is_multimodal_model(model):
     raise ValueError("unrecognized type")
 
 
-def is_generation_model(model_architectures):
+def is_generation_model(model_architectures, is_embedding: bool = False):
+    # We have two ways to determine whether a model is a generative model.
+    # 1. Check the model architectue
+    # 2. check the `is_embedding` server args
+
     if (
         "LlamaEmbeddingModel" in model_architectures
         or "MistralModel" in model_architectures
     ):
         return False
-    return True
+    else:
+        return not is_embedding
 
 
 def decode_video_base64(video_base64):
@@ -348,7 +352,7 @@ def suppress_other_loggers():
         logging.WARN
     )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
-    logging.getLogger("vllm.utils").setLevel(logging.WARN)
+    logging.getLogger("vllm.utils").setLevel(logging.ERROR)
 
 
 def assert_pkg_version(pkg: str, min_version: str, message: str):
@@ -370,14 +374,11 @@ def kill_parent_process():
     """Kill the parent process and all children of the parent process."""
     current_process = psutil.Process()
     parent_process = current_process.parent()
-    children = parent_process.children(recursive=True)
-    for child in children:
-        if child.pid != current_process.pid:
-            os.kill(child.pid, 9)
-    os.kill(parent_process.pid, 9)
+    kill_child_process(parent_process.pid, skip_pid=current_process.pid)
 
 
-def kill_child_process(pid, including_parent=True):
+def kill_child_process(pid, including_parent=True, skip_pid=None):
+    """Kill the process and all its children process."""
     try:
         parent = psutil.Process(pid)
     except psutil.NoSuchProcess:
@@ -385,6 +386,8 @@ def kill_child_process(pid, including_parent=True):
 
     children = parent.children(recursive=True)
     for child in children:
+        if child.pid == skip_pid:
+            continue
         try:
             child.kill()
         except psutil.NoSuchProcess:
@@ -453,10 +456,6 @@ def monkey_patch_vllm_dummy_weight_loader():
            quant_method = getattr(module, "quant_method", None)
            if quant_method is not None:
                quant_method.process_weights_after_loading(module)
-            # FIXME: Remove this after Mixtral is updated
-            # to use quant_method.
-            if hasattr(module, "process_weights_after_loading"):
-                module.process_weights_after_loading()
 
        # NOTE(woosuk): For accurate performance evaluation, we assign
        # random values to the weights.
@@ -644,7 +643,7 @@ def set_ulimit(target_soft_limit=65535):
        logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
 
 
-def is_llama3_405b_fp8(model_config):
+def is_llama3_405b_fp8_head_16(model_config):
    """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
    if (
        model_config.hf_config.architectures[0] == "LlamaForCausalLM"
@@ -693,7 +692,7 @@ def monkey_patch_vllm_qvk_linear_loader():
    setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
 
 
-def add_api_key_middleware(app, api_key):
+def add_api_key_middleware(app, api_key: str):
    @app.middleware("http")
    async def authentication(request, call_next):
        if request.method == "OPTIONS":
@@ -705,7 +704,7 @@ def add_api_key_middleware(app, api_key):
        return await call_next(request)
 
 
-def prepare_model(model_path):
+def prepare_model(model_path: str):
    if "SGLANG_USE_MODELSCOPE" in os.environ:
        if not os.path.exists(model_path):
            from modelscope import snapshot_download
@@ -714,7 +713,7 @@ def prepare_model(model_path):
    return model_path
 
 
-def prepare_tokenizer(tokenizer_path):
+def prepare_tokenizer(tokenizer_path: str):
    if "SGLANG_USE_MODELSCOPE" in os.environ:
        if not os.path.exists(tokenizer_path):
            from modelscope import snapshot_download
@@ -723,3 +722,13 @@ def prepare_tokenizer(tokenizer_path):
            tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
        )
    return tokenizer_path
+
+
+def configure_logger(server_args, prefix: str = ""):
+    format = f"[%(asctime)s{prefix}] %(message)s"
+    logging.basicConfig(
+        level=getattr(logging, server_args.log_level.upper()),
+        format=format,
+        datefmt="%H:%M:%S",
+        force=True,
+    )
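
Note: a rough usage sketch for the two helpers touched above, assuming only what the diff shows (configure_logger reads a log_level attribute; kill_child_process now accepts skip_pid). The SimpleNamespace stand-in is illustrative and not part of sglang:

    import os
    from types import SimpleNamespace

    from sglang.srt.utils import configure_logger, kill_child_process

    # Per-process logging with a prefix, e.g. to tag a tensor-parallel rank.
    fake_server_args = SimpleNamespace(log_level="info")  # stand-in for the real ServerArgs
    configure_logger(fake_server_args, prefix=" TP0")

    # Kill all children of this process while keeping the process itself alive.
    kill_child_process(os.getpid(), including_parent=False)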
sglang/test/runners.py CHANGED
@@ -14,7 +14,8 @@ limitations under the License.
 """
 
 import json
-import multiprocessing
+import multiprocessing as mp
+import os
 from dataclasses import dataclass
 from typing import List, Union
 
@@ -23,16 +24,22 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from sglang.srt.server import Runtime
-from sglang.srt.utils import is_generation_model
+from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
 
 DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
+    "Apple is red. Banana is Yellow. " * 800 + "Apple is",
     "The capital of the United Kindom is",
     "Today is a sunny day and I like",
     "AI is a field of computer science focused on",
 ]
 
+dirpath = os.path.dirname(__file__)
+with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
+    long_prompt = f.read()
+DEFAULT_PROMPTS.append(long_prompt)
+
 NUM_TOP_LOGPROBS = 5
 
 
@@ -56,44 +63,37 @@ class HFRunner:
     def __init__(
         self,
         model_path,
-        torch_dtype=torch.float16,
-        is_generation_model=None,
+        torch_dtype,
+        is_generation,
     ):
-        self.in_queue = multiprocessing.Queue()
-        self.out_queue = multiprocessing.Queue()
+        self.is_generation = is_generation
+
+        self.in_queue = mp.Queue()
+        self.out_queue = mp.Queue()
 
-        self.model_proc = multiprocessing.Process(
+        self.model_proc = mp.Process(
            target=self.start_model_process,
            args=(
                self.in_queue,
                self.out_queue,
                model_path,
                torch_dtype,
-                is_generation_model,
            ),
        )
        self.model_proc.start()
 
-    def start_model_process(
-        self, in_queue, out_queue, model_path, torch_dtype, is_generation_model
-    ):
+    def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
-            trust_remote_code=True,
        )
 
-        self.is_generation_model = (
-            is_generation_model(model_path)
-            if is_generation_model is None
-            else is_generation_model
-        )
-        if self.is_generation_model:
+        if self.is_generation:
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
+                trust_remote_code=False,
                low_cpu_mem_usage=True,
-                trust_remote_code=True,
            ).cuda()
        else:
            from sentence_transformers import SentenceTransformer
@@ -106,7 +106,7 @@
        while True:
            prompts, max_new_tokens = in_queue.get()
            if prompts is not None:
-                if self.is_generation_model:
+                if self.is_generation:
                    output_strs = []
                    prefill_logprobs = []
                    for p in prompts:
@@ -125,16 +125,14 @@
                        )
 
                        logits = self.model.forward(input_ids).logits[0]
-                        logprobs = F.log_softmax(
-                            logits, dim=-1, dtype=torch.float32
-                        ).tolist()
-                        # index_of_max = (lambda nums: nums.index(max(nums)))(logprobs[-1])
-                        # print("index", index_of_max)
-                        logprobs = [
-                            sorted(token_logprobs, reverse=True)[:NUM_TOP_LOGPROBS]
-                            for token_logprobs in logprobs
-                        ]
-                        prefill_logprobs.append(logprobs)
+                        logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+                        logprobs, top_indices = torch.topk(
+                            logprobs, k=NUM_TOP_LOGPROBS, dim=-1
+                        )
+                        # print("index", top_indices)
+                        prefill_logprobs.append(logprobs.tolist())
+                        del logits
+                        del logprobs
 
                    out_queue.put(
                        ModelOutput(
@@ -171,19 +169,20 @@ class SRTRunner:
    def __init__(
        self,
        model_path,
+        torch_dtype,
+        is_generation,
        tp_size=1,
-        torch_dtype=torch.float16,
-        is_generation_model=None,
+        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
    ):
-        self.is_generation_model = (
-            is_generation_model(model_path)
-            if is_generation_model is None
-            else is_generation_model
-        )
+        self.is_generation = is_generation
        self.runtime = Runtime(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
+            port=port,
+            mem_fraction_static=0.69,
+            trust_remote_code=False,
+            is_embedding=not self.is_generation,
        )
 
    def forward(
@@ -191,7 +190,7 @@
        prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
        max_new_tokens=8,
    ):
-        if self.is_generation_model:
+        if self.is_generation:
            # the return value contains logprobs from prefill
            output_strs = []
            top_input_logprobs = []
@@ -201,6 +200,7 @@
                    prompt,
                    sampling_params=sampling_params,
                    return_logprob=True,
+                    logprob_start_len=0,
                    top_logprobs_num=NUM_TOP_LOGPROBS,
                )
                response = json.loads(response)
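
Note: both test runners now take torch_dtype and is_generation explicitly (the old is_generation_model=None auto-detection is gone), and SRTRunner defaults to a dedicated test port. A hypothetical harness built on the updated constructors; the model name is a placeholder and the HFRunner.forward call assumes that method, which is outside these hunks, kept its existing name:

    import torch

    from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner

    model_path = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder

    hf = HFRunner(model_path, torch_dtype=torch.float16, is_generation=True)
    srt = SRTRunner(model_path, torch_dtype=torch.float16, is_generation=True, tp_size=1)

    # Compare greedy outputs and prefill logprobs on a single short prompt.
    hf_out = hf.forward(DEFAULT_PROMPTS[1:2], max_new_tokens=8)
    srt_out = srt.forward(DEFAULT_PROMPTS[1:2], max_new_tokens=8)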
sglang/test/simple_eval_common.py CHANGED
@@ -1,13 +1,12 @@
 # Adapted from https://github.com/openai/simple-evals/
 
-import base64
 import os
 import resource
 import time
 from collections import defaultdict
 from dataclasses import dataclass, field
 from multiprocessing.pool import ThreadPool
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import httpx
 import jinja2
@@ -44,8 +43,8 @@ class EvalResult:
     Result of running an evaluation (usually consisting of many samples)
     """
 
-    score: float | None  # top-line metric
-    metrics: Dict[str, float] | None  # other metrics
+    score: Optional[float]  # top-line metric
+    metrics: Optional[Dict[str, float]]  # other metrics
     htmls: List[str]  # strings of valid HTML
     convos: List[MessageList]  # sampled conversations
 
@@ -56,10 +55,10 @@ class SingleEvalResult:
     Result of evaluating a single sample
     """
 
-    score: float | None
+    score: Optional[float]
     metrics: Dict[str, float] = field(default_factory=dict)
-    html: str | None = None
-    convo: MessageList | None = None  # sampled conversation
+    html: Optional[str] = None
+    convo: Optional[MessageList] = None  # sampled conversation
 
 
 class Eval:
@@ -89,8 +88,8 @@ class ChatCompletionSampler(SamplerBase):
     def __init__(
         self,
         base_url: str = None,
-        model: str | None = None,
-        system_message: str | None = None,
+        model: Optional[str] = None,
+        system_message: Optional[str] = None,
         temperature: float = 0.0,
         max_tokens: int = 2048,
     ):
@@ -272,7 +271,7 @@ def _compute_stat(values: list, stat: str):
 def aggregate_results(
     single_eval_results: List[SingleEvalResult],
     default_stats: Tuple[str] = ("mean", "std"),
-    name2stats: Dict[str, Tuple[str]] | None = None,
+    name2stats: Optional[Dict[str, Tuple[str]]] = None,
 ) -> EvalResult:
     """
     Aggregate results from multiple evaluations into a single EvalResult.
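
Note: the changes in this file and in the four simple_eval_* modules below all replace PEP 604 unions (X | None) with typing.Optional. The motivation is not stated in the diff, but a likely reason is that X | None in a runtime-evaluated annotation raises TypeError on Python 3.9 and earlier, while Optional[X] is equivalent and works on older interpreters:

    from typing import Optional

    # Works on Python 3.8+:
    def top_line(score: Optional[float]) -> float:
        return 0.0 if score is None else score

    # Same meaning, but the annotation only evaluates on Python 3.10+:
    # def top_line(score: float | None) -> float: ...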
sglang/test/simple_eval_gpqa.py CHANGED
@@ -8,6 +8,7 @@ https://arxiv.org/abs/2311.12022
 
 import random
 import re
+from typing import Optional
 
 import pandas
 
@@ -28,7 +29,7 @@ class GPQAEval(Eval):
     def __init__(
         self,
         filename: str,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
         n_repeats: int = 1,
     ):
sglang/test/simple_eval_humaneval.py CHANGED
@@ -9,7 +9,7 @@ https://arxiv.org/abs/2107.03374 https://github.com/openai/human-eval/
 import random
 import re
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Dict, List
+from typing import Dict, List, Optional
 
 import tqdm
 
@@ -61,7 +61,7 @@ def evaluate_functional_correctness(
 class HumanEval(Eval):
     def __init__(
         self,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
         num_samples_per_task: int = 5,
         ks_passes: List[int] = [1, 2, 5],
sglang/test/simple_eval_math.py CHANGED
@@ -8,6 +8,7 @@ https://arxiv.org/abs/2103.03874
 
 import random
 import re
+from typing import Optional
 
 import pandas
 
@@ -36,7 +37,7 @@ class MathEval(Eval):
         self,
         filename: str,
         equality_checker: SamplerBase,
-        num_examples: int | None,
+        num_examples: Optional[int],
         num_threads: int,
     ):
         df = pandas.read_csv(filename)
sglang/test/simple_eval_mmlu.py CHANGED
@@ -8,6 +8,7 @@ https://arxiv.org/abs/2009.03300
 
 import random
 import re
+from typing import Optional
 
 import pandas
 
@@ -84,7 +85,7 @@ subject2category = {
 
 
 class MMLUEval(Eval):
-    def __init__(self, filename: str, num_examples: int | None, num_threads: int):
+    def __init__(self, filename: str, num_examples: Optional[int], num_threads: int):
         df = pandas.read_csv(filename)
         examples = [row.to_dict() for _, row in df.iterrows()]
         if num_examples:
sglang/test/test_activation.py ADDED
@@ -0,0 +1,55 @@
+import itertools
+import unittest
+
+import torch
+
+from sglang.srt.layers.activation import GeluAndMul
+
+
+class TestGeluAndMul(unittest.TestCase):
+    DTYPES = [torch.half, torch.bfloat16]
+    NUM_TOKENS = [7, 83, 2048]
+    D = [512, 4096, 5120, 13824]
+    SEEDS = [0]
+
+    @classmethod
+    def setUpClass(cls):
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest("CUDA is not available")
+        torch.set_default_device("cuda")
+
+    def _run_gelu_and_mul_test(self, num_tokens, d, dtype, seed):
+        torch.manual_seed(seed)
+
+        layer = GeluAndMul().to(dtype=dtype)
+        x = torch.randn(num_tokens, 2 * d, dtype=dtype)
+
+        with torch.inference_mode():
+            ref_out = layer.forward_native(x)
+            out = layer.forward_cuda(x)
+
+        if dtype == torch.bfloat16:
+            atol = rtol = 1e-2
+        else:
+            atol = rtol = 1e-3
+
+        self.assertTrue(torch.allclose(out, ref_out, atol=atol, rtol=rtol))
+
+    def test_gelu_and_mul(self):
+        for params in itertools.product(
+            self.NUM_TOKENS,
+            self.D,
+            self.DTYPES,
+            self.SEEDS,
+        ):
+            with self.subTest(
+                num_tokens=params[0],
+                d=params[1],
+                dtype=params[2],
+                seed=params[3],
+            ):
+                self._run_gelu_and_mul_test(*params)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
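
Note: the new test compares GeluAndMul.forward_cuda against forward_native on inputs of shape (num_tokens, 2 * d). The activation itself is not part of this diff; a rough reference of the conventional gated-GELU pattern it exercises, where the split order and the approximate= mode are assumptions on my part:

    import torch
    import torch.nn.functional as F

    def gelu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
        # x has shape (..., 2 * d): apply GELU to the first half and gate the second half.
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]

    out = gelu_and_mul_reference(torch.randn(7, 2 * 512))
    assert out.shape == (7, 512)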
sglang/test/test_programs.py CHANGED
@@ -103,16 +103,19 @@ def test_decode_int():
 def test_decode_json_regex():
     @sgl.function
     def decode_json(s):
-        from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
+        from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STR
 
         s += "Generate a JSON object to describe the basic city information of Paris.\n"
+        s += "Here are the JSON object:\n"
+
+        # NOTE: we recommend using dtype gen or whole regex string to control the output
 
         with s.var_scope("json_output"):
             s += "{\n"
-            s += ' "name": ' + sgl.gen(regex=REGEX_STRING + ",") + "\n"
-            s += ' "population": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
-            s += ' "area": ' + sgl.gen(regex=REGEX_INT + ",") + "\n"
-            s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT) + "\n"
+            s += ' "name": ' + sgl.gen(regex=REGEX_STR) + ",\n"
+            s += ' "population": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
+            s += ' "area": ' + sgl.gen(regex=REGEX_INT, stop=[" ", "\n"]) + ",\n"
+            s += ' "latitude": ' + sgl.gen(regex=REGEX_FLOAT, stop=[" ", "\n"]) + "\n"
             s += "}"
 
         ret = decode_json.run(temperature=0.0)
@@ -359,6 +362,30 @@ def test_regex():
     assert re.match(regex, answer)
 
 
+def test_dtype_gen():
+    @sgl.function
+    def dtype_gen(s):
+        s += "Q: What is the full name of DNS?\n"
+        s += "A: The full nams is " + sgl.gen("str_res", dtype=str, stop="\n") + "\n"
+        s += "Q: Which year was DNS invented?\n"
+        s += "A: " + sgl.gen("int_res", dtype=int) + "\n"
+        s += "Q: What is the value of pi?\n"
+        s += "A: " + sgl.gen("float_res", dtype=float) + "\n"
+        s += "Q: Is the sky blue?\n"
+        s += "A: " + sgl.gen("bool_res", dtype=bool) + "\n"
+
+    state = dtype_gen.run()
+
+    try:
+        state["int_res"] = int(state["int_res"])
+        state["float_res"] = float(state["float_res"])
+        state["bool_res"] = bool(state["bool_res"])
+        # assert state["str_res"].startswith('"') and state["str_res"].endswith('"')
+    except ValueError:
+        print(state)
+        raise
+
+
 def test_completion_speculative():
     @sgl.function(num_api_spec_tokens=64)
     def gen_character_spec(s):
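
Note: test_dtype_gen above exercises the new dtype= argument of sgl.gen. A minimal standalone sketch of the same pattern; the endpoint URL is a placeholder and a running sglang server (or another configured backend) is assumed:

    import sglang as sgl

    @sgl.function
    def city_facts(s):
        s += "Q: Which year was the Eiffel Tower completed?\n"
        s += "A: " + sgl.gen("year", dtype=int) + "\n"
        s += "Q: What is the latitude of Paris?\n"
        s += "A: " + sgl.gen("latitude", dtype=float) + "\n"

    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))  # placeholder endpoint
    state = city_facts.run(temperature=0.0)
    print(int(state["year"]), float(state["latitude"]))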