sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_latency.py +1 -553
  4. sglang/bench_offline_throughput.py +48 -20
  5. sglang/bench_one_batch.py +472 -0
  6. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  7. sglang/bench_serving.py +125 -6
  8. sglang/check_env.py +3 -6
  9. sglang/lang/backend/base_backend.py +1 -1
  10. sglang/lang/backend/runtime_endpoint.py +2 -2
  11. sglang/srt/configs/model_config.py +13 -14
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +28 -17
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +47 -58
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +16 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +106 -54
  21. sglang/srt/layers/attention/triton_backend.py +9 -7
  22. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  23. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  24. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  25. sglang/srt/layers/custom_op_util.py +25 -0
  26. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  27. sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
  28. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  29. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  30. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  31. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  32. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  33. sglang/srt/layers/layernorm.py +17 -15
  34. sglang/srt/layers/logits_processor.py +23 -25
  35. sglang/srt/layers/quantization/__init__.py +77 -17
  36. sglang/srt/layers/radix_attention.py +13 -15
  37. sglang/srt/layers/rotary_embedding.py +13 -13
  38. sglang/srt/layers/sampler.py +4 -8
  39. sglang/srt/layers/torchao_utils.py +2 -0
  40. sglang/srt/lora/lora.py +13 -14
  41. sglang/srt/lora/lora_config.py +13 -14
  42. sglang/srt/lora/lora_manager.py +22 -24
  43. sglang/srt/managers/data_parallel_controller.py +98 -27
  44. sglang/srt/managers/detokenizer_manager.py +13 -15
  45. sglang/srt/managers/io_struct.py +63 -21
  46. sglang/srt/managers/schedule_batch.py +154 -59
  47. sglang/srt/managers/schedule_policy.py +18 -16
  48. sglang/srt/managers/scheduler.py +278 -109
  49. sglang/srt/managers/session_controller.py +61 -0
  50. sglang/srt/managers/tokenizer_manager.py +63 -18
  51. sglang/srt/managers/tp_worker.py +25 -16
  52. sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
  53. sglang/srt/metrics/collector.py +13 -15
  54. sglang/srt/metrics/func_timer.py +13 -15
  55. sglang/srt/mm_utils.py +13 -14
  56. sglang/srt/model_executor/cuda_graph_runner.py +63 -25
  57. sglang/srt/model_executor/forward_batch_info.py +128 -32
  58. sglang/srt/model_executor/model_runner.py +132 -64
  59. sglang/srt/model_parallel.py +98 -0
  60. sglang/srt/models/chatglm.py +15 -16
  61. sglang/srt/models/commandr.py +15 -16
  62. sglang/srt/models/dbrx.py +15 -16
  63. sglang/srt/models/deepseek.py +15 -15
  64. sglang/srt/models/deepseek_v2.py +162 -59
  65. sglang/srt/models/exaone.py +14 -15
  66. sglang/srt/models/gemma.py +14 -14
  67. sglang/srt/models/gemma2.py +31 -25
  68. sglang/srt/models/gemma2_reward.py +13 -14
  69. sglang/srt/models/gpt_bigcode.py +14 -14
  70. sglang/srt/models/grok.py +15 -15
  71. sglang/srt/models/internlm2.py +13 -15
  72. sglang/srt/models/internlm2_reward.py +13 -14
  73. sglang/srt/models/llama.py +21 -21
  74. sglang/srt/models/llama_classification.py +13 -14
  75. sglang/srt/models/llama_reward.py +13 -14
  76. sglang/srt/models/llava.py +14 -16
  77. sglang/srt/models/llavavid.py +14 -16
  78. sglang/srt/models/minicpm.py +13 -15
  79. sglang/srt/models/minicpm3.py +13 -15
  80. sglang/srt/models/mistral.py +13 -15
  81. sglang/srt/models/mixtral.py +15 -15
  82. sglang/srt/models/mixtral_quant.py +14 -14
  83. sglang/srt/models/olmo.py +22 -20
  84. sglang/srt/models/olmoe.py +23 -20
  85. sglang/srt/models/phi3_small.py +447 -0
  86. sglang/srt/models/qwen.py +14 -14
  87. sglang/srt/models/qwen2.py +22 -19
  88. sglang/srt/models/qwen2_moe.py +17 -18
  89. sglang/srt/models/qwen2_vl.py +13 -6
  90. sglang/srt/models/stablelm.py +18 -16
  91. sglang/srt/models/torch_native_llama.py +107 -93
  92. sglang/srt/models/xverse.py +13 -14
  93. sglang/srt/models/xverse_moe.py +15 -16
  94. sglang/srt/models/yivl.py +13 -15
  95. sglang/srt/openai_api/adapter.py +19 -17
  96. sglang/srt/openai_api/protocol.py +14 -16
  97. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  98. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  99. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  100. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  101. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  102. sglang/srt/sampling/sampling_batch_info.py +61 -57
  103. sglang/srt/sampling/sampling_params.py +14 -16
  104. sglang/srt/server.py +86 -35
  105. sglang/srt/server_args.py +96 -80
  106. sglang/srt/utils.py +266 -68
  107. sglang/test/few_shot_gsm8k.py +8 -4
  108. sglang/test/runners.py +38 -20
  109. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  110. sglang/test/test_utils.py +31 -20
  111. sglang/version.py +1 -1
  112. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  113. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
  114. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  115. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
  116. sglang/srt/layers/fused_moe/__init__.py +0 -1
  117. sglang-0.3.5.post2.dist-info/RECORD +0 -156
  118. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py CHANGED
@@ -1,22 +1,21 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Common utilities."""
 
 import base64
 import ipaddress
+import itertools
 import json
 import logging
 import os
@@ -33,7 +32,7 @@ import time
 import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
 
 import numpy as np
 import psutil
@@ -46,6 +45,8 @@ from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from starlette.routing import Mount
 from torch import nn
+from torch.func import functional_call
+from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
     FileCacheManager,
@@ -71,6 +72,8 @@ def is_flashinfer_available():
     Check whether flashinfer is available.
     As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
     """
+    if os.environ.get("SGLANG_IS_FLASHINFER_AVAILABLE", "true") == "false":
+        return False
     return torch.cuda.is_available() and not is_hip()
 
 
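Illustrative sketch (not part of the package diff): the new SGLANG_IS_FLASHINFER_AVAILABLE environment variable lets a deployment opt out of flashinfer even on machines where it would be auto-detected. Assuming the variable is set before is_flashinfer_available() is called:

    import os

    # Hypothetical usage: opt out of flashinfer detection for this process.
    os.environ["SGLANG_IS_FLASHINFER_AVAILABLE"] = "false"

    from sglang.srt.utils import is_flashinfer_available

    assert is_flashinfer_available() is False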
@@ -190,6 +193,94 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
     return free_gpu_memory / (1 << 30)
 
 
+def is_pin_memory_available() -> bool:
+    return torch.cuda.is_available()
+
+
+_CPU_OFFLOAD_BYTES = 0
+_CPU_OFFLOAD_MAX_BYTES = 0
+
+
+def set_cpu_offload_max_bytes(max_bytes: int) -> None:
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    _CPU_OFFLOAD_BYTES = 0
+    _CPU_OFFLOAD_MAX_BYTES = max_bytes
+
+
+def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
+    device = next(module.parameters()).device
+
+    if device == torch.device("cpu"):
+        return module
+
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+        return module
+
+    pin_memory = is_pin_memory_available()
+    # offload parameters to CPU
+    # use pin_memory if possible, which helps cudagraph capture speed
+    offloaded_parameters = False
+    for p in module.parameters():
+        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+            # we use per-parameter offloading
+            # one module might have some parameters offloaded and some not
+            break
+
+        # `torch.empty_like` does not support `pin_memory` argument
+        cpu_data = torch.empty_strided(
+            size=p.data.size(),
+            stride=p.data.stride(),
+            dtype=p.data.dtype,
+            layout=p.data.layout,
+            device="cpu",
+            pin_memory=pin_memory,
+        )
+        cpu_data.copy_(p.data)
+        p.data = cpu_data
+        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+        offloaded_parameters = True
+
+    if offloaded_parameters:
+        original_forward = module.forward
+
+        def forward(*args, **kwargs):
+            module.forward = original_forward
+            device_state = {
+                # here we blindly call `to(device)`
+                # if the parameter is already on the device, it will be a no-op
+                k: v.to(device, non_blocking=True)
+                for k, v in module.state_dict().items()
+            }
+            output = functional_call(module, device_state, args=args, kwargs=kwargs)
+            module.forward = forward
+            return output
+
+        module.forward = forward
+
+    return module
+
+
+class LayerFn(Protocol):
+
+    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
+
+
+def make_layers(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str = "",
+) -> Tuple[int, int, torch.nn.ModuleList]:
+    """Make a list of layers with the given layer function"""
+    modules = torch.nn.ModuleList(
+        [
+            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}"))
+            for idx in range(num_hidden_layers)
+        ]
+    )
+    return modules
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
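Illustrative sketch (not part of the package diff): the offload helpers above keep a global budget of parameter bytes that may be parked in (pinned) host memory, and maybe_offload_to_cpu patches a module's forward so the weights are copied back to their original device through functional_call on each call. A minimal usage sketch, assuming a CUDA device and using nn.Linear as a stand-in for a model layer:

    import torch
    from torch import nn

    from sglang.srt.utils import maybe_offload_to_cpu, set_cpu_offload_max_bytes

    # Hypothetical budget: allow up to 2 GiB of parameters in host memory.
    set_cpu_offload_max_bytes(2 * 1024**3)

    layer = nn.Linear(4096, 4096).cuda()
    layer = maybe_offload_to_cpu(layer)  # parameters now live in CPU memory

    x = torch.randn(8, 4096, device="cuda")
    y = layer(x)  # the patched forward() moves the weights back to the GPU for this call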
@@ -330,6 +421,7 @@ def suppress_other_loggers():
     )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.ERROR)
+    logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)
 
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -394,6 +486,27 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
         pass
 
 
+def monkey_patch_vllm_model_config():
+    from vllm.config import ModelConfig
+
+    if not hasattr(ModelConfig, "_resolve_task"):
+        return
+
+    def _resolve_task(
+        self,
+        task_option,
+        hf_config,
+    ):
+        supported_tasks = {
+            "generate": True,
+            "embedding": False,
+        }
+        selected_task = "generate"
+        return supported_tasks, selected_task
+
+    setattr(ModelConfig, "_resolve_task", _resolve_task)
+
+
 def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     """
     Monkey patch the slow p2p access check in vllm.
@@ -405,57 +518,6 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
 
 
-def monkey_patch_vllm_dummy_weight_loader():
-    """
-    Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
-    """
-
-    from vllm.model_executor.model_loader.loader import (
-        CacheConfig,
-        DeviceConfig,
-        DummyModelLoader,
-        LoRAConfig,
-        ModelConfig,
-        ParallelConfig,
-        SchedulerConfig,
-        _initialize_model,
-        initialize_dummy_weights,
-        nn,
-        set_default_torch_dtype,
-    )
-
-    def load_model(
-        self,
-        *,
-        model_config: ModelConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        cache_config: CacheConfig,
-    ) -> nn.Module:
-        with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
-                model = _initialize_model(
-                    model_config,
-                    self.load_config,
-                    lora_config,
-                    cache_config,
-                )
-
-        for _, module in model.named_modules():
-            quant_method = getattr(module, "quant_method", None)
-            if quant_method is not None:
-                quant_method.process_weights_after_loading(module)
-
-        # NOTE(woosuk): For accurate performance evaluation, we assign
-        # random values to the weights.
-        initialize_dummy_weights(model)
-        return model.eval()
-
-    setattr(DummyModelLoader, "load_model", load_model)
-
-
 vllm_all_gather_backup = None
 
 
@@ -794,7 +856,48 @@ def add_prometheus_middleware(app):
     app.routes.append(metrics_route)
 
 
-def get_gpu_memory_capacity():
+def bind_port(port):
+    """Bind to a specific port, assuming it's available."""
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # Allows address reuse
+    sock.bind(("", port))
+    sock.listen(1)
+    return sock
+
+
+def get_amdgpu_memory_capacity():
+    try:
+        # Run rocm-smi and capture the output
+        result = subprocess.run(
+            ["rocm-smi --showmeminfo vram | grep 'Total Memory' | awk '{print $NF}'"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            shell=True,
+            text=True,
+        )
+        if result.returncode != 0:
+            raise RuntimeError(f"rocm-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values in MiB
+        memory_values = [
+            float(mem) / 1024 / 1024
+            for mem in result.stdout.strip().split("\n")
+            if re.match(r"^\d+(\.\d+)?$", mem.strip())
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "rocm-smi not found. Ensure AMD ROCm drivers are installed and accessible."
+        )
+
+
+def get_nvgpu_memory_capacity():
     try:
         # Run nvidia-smi and capture the output
         result = subprocess.run(
@@ -824,3 +927,98 @@ def get_gpu_memory_capacity():
         raise RuntimeError(
             "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
         )
+
+
+def crash_on_warnings():
+    # Crash on warning if we are running CI tests
+    return os.getenv("SGLANG_IS_IN_CI", "false").lower() == "true"
+
+
+def get_device_name(device_id: int = 0) -> str:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        return torch.cuda.get_device_name(device_id)
+
+    if hasattr(torch, "hip") and torch.hip.is_available():
+        return torch.hip.get_device_name(device_id)
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        return torch.xpu.get_device_name(device_id)
+
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        return torch.hpu.get_device_name(device_id)
+
+
+sglang_lib = Library("sglang", "FRAGMENT")  # noqa
+
+
+def direct_register_custom_op(
+    op_name: str,
+    op_func: Callable,
+    mutates_args: List[str],
+    fake_impl: Optional[Callable] = None,
+    target_lib: Optional[Library] = None,
+):
+    """
+    `torch.library.custom_op` can have significant overhead because it
+    needs to consider complicated dispatching logic. This function
+    directly registers a custom op and dispatches it to the CUDA backend.
+    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
+    for more details.
+
+    By default, the custom op is registered to the vLLM library. If you
+    want to register it to a different library, you can pass the library
+    object to the `target_lib` argument.
+
+    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
+    library object. If you want to bind the operator to a different library,
+    make sure the library object is alive when the operator is used.
+    """
+    import torch.library
+
+    if hasattr(torch.library, "infer_schema"):
+        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
+    else:
+        # for pytorch 2.4
+        import torch._custom_op.impl
+
+        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
+
+    my_lib = target_lib or sglang_lib
+    my_lib.define(op_name + schema_str)
+    my_lib.impl(op_name, op_func, "CUDA")
+    if fake_impl is not None:
+        my_lib._register_fake(op_name, fake_impl)
+
+
+def gpu_proc_affinity(
+    tp_size: int,
+    nnodes: int,
+    gpu_id: int,
+):
+    # current process
+    pid = os.getpid()
+    p = psutil.Process(pid)
+
+    tp_size_per_node = tp_size // nnodes
+
+    # total physical cores
+    total_pcores = psutil.cpu_count(logical=False)
+    # physical cores per TP (N.B. more Cores than GPUs on node)
+    num_cores_bind = total_pcores // tp_size_per_node
+
+    # able to handle multiple DP per node
+    start_cpu_id = (gpu_id * num_cores_bind) % total_pcores
+    end_cpu_id = start_cpu_id + num_cores_bind
+
+    if psutil.cpu_count() != psutil.cpu_count(logical=False):
+        # HT on
+        upper_cpu_ids = [id for id in range(start_cpu_id, end_cpu_id)]
+        lower_cpu_ids = [id + total_pcores for id in range(start_cpu_id, end_cpu_id)]
+        bind_cpu_ids = list(itertools.chain(upper_cpu_ids, lower_cpu_ids))
+    else:
+        # HT off
+        bind_cpu_ids = [id for id in range(start_cpu_id, end_cpu_id)]
+
+    # set cpu_affinity to current process
+    p.cpu_affinity(bind_cpu_ids)
+    logger.info(f"Process {pid} gpu_id {gpu_id} is running on CPUs: {p.cpu_affinity()}")
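Illustrative sketch (not part of the package diff): direct_register_custom_op defines an operator directly on the new "sglang" fragment library, avoiding the heavier torch.library.custom_op dispatch path. A hypothetical registration of an in-place scaling op (the op name and both implementations below are invented for illustration):

    import torch

    from sglang.srt.utils import direct_register_custom_op

    def scale_inplace(x: torch.Tensor, factor: float) -> None:
        # Real implementation, dispatched for CUDA tensors.
        x.mul_(factor)

    def scale_inplace_fake(x: torch.Tensor, factor: float) -> None:
        # Fake implementation used during tracing/compilation; nothing to do.
        pass

    direct_register_custom_op(
        op_name="scale_inplace",
        op_func=scale_inplace,
        mutates_args=["x"],
        fake_impl=scale_inplace_fake,
    )

    t = torch.ones(4, device="cuda")
    torch.ops.sglang.scale_inplace(t, 2.0)  # the op lives under the "sglang" namespace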
sglang/test/few_shot_gsm8k.py CHANGED
@@ -48,9 +48,13 @@ def run_eval(args):
     # Select backend
     set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))
 
-    # Read data
-    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
-    filename = download_and_cache_file(url)
+    if args.data_path is None:
+        # Read data
+        url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+        filename = download_and_cache_file(url)
+    else:
+        filename = args.data_path
+
     lines = list(read_jsonl(filename))
 
     # Construct prompts
@@ -131,7 +135,7 @@ def run_eval(args):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--num-shots", type=int, default=5)
-    parser.add_argument("--data-path", type=str, default="test.jsonl")
+    parser.add_argument("--data-path", type=str)
     parser.add_argument("--num-questions", type=int, default=200)
     parser.add_argument("--max-new-tokens", type=int, default=512)
     parser.add_argument("--parallel", type=int, default=128)
sglang/test/runners.py CHANGED
@@ -1,17 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import json
 import multiprocessing as mp
@@ -58,6 +57,28 @@ def get_top_logprobs(logits, k):
     return logprobs
 
 
+def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
+    from sentence_transformers import SentenceTransformer
+    from sentence_transformers.util import is_sentence_transformer_model
+
+    if is_sentence_transformer_model(model_path):
+        model = SentenceTransformer(
+            model_path,
+            model_kwargs={"torch_dtype": torch_dtype},
+        )
+    else:  # if no pre-trained sentence-transformers model
+        from sentence_transformers import models
+
+        word_embedding_model = models.Transformer(model_path).to(dtype=torch_dtype)
+        pooling_model = models.Pooling(
+            word_embedding_model.get_word_embedding_dimension(),
+            pooling_mode="lasttoken",
+        )
+        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+
+    return model.cuda()
+
+
 @dataclass
 class ModelOutput:
     output_strs: List[str] = None
@@ -114,12 +135,9 @@ class HFRunner:
                 low_cpu_mem_usage=True,
             ).cuda()
         elif self.model_type == "embedding":
-            from sentence_transformers import SentenceTransformer
-
-            self.model = SentenceTransformer(
-                model_path,
-                model_kwargs={"torch_dtype": torch_dtype},
-            ).cuda()
+            self.model = _get_sentence_transformer_embedding_model(
+                model_path, torch_dtype
+            )
         elif self.model_type == "reward":
             from transformers import AutoModelForSequenceClassification
 
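Illustrative sketch (not part of the package diff): the new _get_sentence_transformer_embedding_model helper loads a native sentence-transformers checkpoint when one exists and otherwise assembles a plain Transformer module with last-token pooling. A rough usage sketch with a placeholder model path:

    import torch

    from sglang.test.runners import _get_sentence_transformer_embedding_model

    # "my-org/my-embedding-model" is a placeholder; any HF checkpoint works,
    # with or without a sentence-transformers config.
    model = _get_sentence_transformer_embedding_model(
        "my-org/my-embedding-model", torch.float16
    )
    embeddings = model.encode(["hello world"], convert_to_tensor=True)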
sglang/test/srt/sampling/penaltylib/utils.py CHANGED
@@ -1,7 +1,7 @@
 import dataclasses
 import enum
-import typing
 import unittest
+from typing import Dict, List, Optional, Set, Tuple, Type
 
 import torch
 
@@ -16,7 +16,7 @@ from sglang.srt.sampling.penaltylib.orchestrator import (
 class MockSamplingParams:
     frequency_penalty: float = 0.0
     min_new_tokens: int = 0
-    stop_token_ids: typing.List[int] = None
+    stop_token_ids: List[int] = None
     presence_penalty: float = 0.0
     repetition_penalty: float = 1.0
 
@@ -24,12 +24,12 @@ class MockSamplingParams:
 @dataclasses.dataclass
 class MockTokenizer:
     eos_token_id: int
-    additional_stop_token_ids: typing.Optional[typing.List[int]] = None
+    additional_stop_token_ids: Optional[List[int]] = None
 
 
 @dataclasses.dataclass
 class MockReq:
-    origin_input_ids: typing.List[int]
+    origin_input_ids: List[int]
     sampling_params: MockSamplingParams
     tokenizer: MockTokenizer
 
@@ -42,8 +42,8 @@ class StepType(enum.Enum):
 @dataclasses.dataclass
 class Step:
     type: StepType
-    token_ids: typing.List[int]
-    expected_tensors: typing.Dict[str, torch.Tensor]
+    token_ids: List[int]
+    expected_tensors: Dict[str, torch.Tensor]
     # assume initial logits are all 1
     expected_logits: torch.Tensor
 
@@ -52,7 +52,7 @@ class Step:
 class Subject:
     sampling_params: MockSamplingParams
     # first step must be input, which will be converted to Req
-    steps: typing.List[Step]
+    steps: List[Step]
     eos_token_id: int = -1
 
     def __post_init__(self):
@@ -66,7 +66,7 @@ class Subject:
                 f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}"
             )
 
-    def tensor_keys(self, i: int = 0) -> typing.Set[str]:
+    def tensor_keys(self, i: int = 0) -> Set[str]:
         return set(self.steps[i].expected_tensors.keys())
 
     def to_req(self) -> MockReq:
@@ -80,7 +80,7 @@ class Subject:
 @dataclasses.dataclass
 class Case:
     enabled: bool
-    test_subjects: typing.List[Subject]
+    test_subjects: List[Subject]
 
     def __post_init__(self):
         # each test_subjects.steps should have the same expected_tensors.keys()
@@ -90,12 +90,12 @@ class Case:
                 f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}"
             )
 
-    def tensor_keys(self, i: int = 0) -> typing.List[str]:
+    def tensor_keys(self, i: int = 0) -> List[str]:
         return set(self.test_subjects[i].tensor_keys())
 
 
 class BaseBatchedPenalizerTest(unittest.TestCase):
-    Penalizer: typing.Type[_BatchedPenalizer]
+    Penalizer: Type[_BatchedPenalizer]
     device = "cuda"
     vocab_size = 5
 
@@ -115,7 +115,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
         """
        return torch.tensor(data, **kwargs, device=self.device)
 
-    def create_test_subjects(self) -> typing.List[Subject]:
+    def create_test_subjects(self) -> List[Subject]:
         raise NotImplementedError()
 
     def create_test_cases(self):
@@ -127,7 +127,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
 
     def _create_penalizer(
         self, case: Case
-    ) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
+    ) -> Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
         orchestrator = BatchedPenalizerOrchestrator(
             vocab_size=self.vocab_size,
             batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
@@ -287,22 +287,24 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
                 if i < len(subject.steps)
             ]
 
-            inputs: typing.List[typing.List[int]] = []
-            outputs: typing.List[typing.List[int]] = []
+            inputs: List[List[int]] = []
+            outputs: List[List[int]] = []
             for subject in filtered_subjects:
                 step = subject.steps[i]
                 if step.type == StepType.INPUT:
-                    inputs.append(step.token_ids)
-                    outputs.append([])
+                    raise NotImplementedError()
                 else:
                     inputs.append([])
                     outputs.append(step.token_ids)
 
-            if any(inputs):
-                orchestrator.cumulate_input_tokens(inputs)
-
             if any(outputs):
-                orchestrator.cumulate_output_tokens(outputs)
+                for j in range(max(len(x) for x in outputs)):
+                    tmp_outputs = torch.tensor(
+                        [x[j] for x in outputs],
+                        dtype=torch.int32,
+                        device=orchestrator.device,
+                    )
+                    orchestrator.cumulate_output_tokens(tmp_outputs)
 
             if penalizer.is_required():
                 self.assertTrue(penalizer.is_prepared())