sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (62)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +337 -0
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +115 -31
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/base_grammar_backend.py +4 -3
  8. sglang/srt/constrained/outlines_backend.py +39 -26
  9. sglang/srt/constrained/xgrammar_backend.py +58 -14
  10. sglang/srt/layers/activation.py +3 -0
  11. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  12. sglang/srt/layers/attention/triton_backend.py +9 -7
  13. sglang/srt/layers/custom_op_util.py +26 -0
  14. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  15. sglang/srt/layers/fused_moe/patch.py +4 -2
  16. sglang/srt/layers/layernorm.py +4 -0
  17. sglang/srt/layers/logits_processor.py +10 -10
  18. sglang/srt/layers/sampler.py +4 -8
  19. sglang/srt/layers/torchao_utils.py +2 -0
  20. sglang/srt/managers/data_parallel_controller.py +74 -9
  21. sglang/srt/managers/detokenizer_manager.py +1 -14
  22. sglang/srt/managers/io_struct.py +27 -0
  23. sglang/srt/managers/schedule_batch.py +104 -38
  24. sglang/srt/managers/schedule_policy.py +5 -1
  25. sglang/srt/managers/scheduler.py +210 -56
  26. sglang/srt/managers/session_controller.py +62 -0
  27. sglang/srt/managers/tokenizer_manager.py +38 -0
  28. sglang/srt/managers/tp_worker.py +12 -1
  29. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  30. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  31. sglang/srt/model_executor/forward_batch_info.py +109 -15
  32. sglang/srt/model_executor/model_runner.py +102 -43
  33. sglang/srt/model_parallel.py +98 -0
  34. sglang/srt/models/deepseek_v2.py +147 -44
  35. sglang/srt/models/gemma2.py +9 -8
  36. sglang/srt/models/llava.py +1 -1
  37. sglang/srt/models/llavavid.py +1 -1
  38. sglang/srt/models/olmo.py +3 -3
  39. sglang/srt/models/phi3_small.py +447 -0
  40. sglang/srt/models/qwen2_vl.py +13 -6
  41. sglang/srt/models/torch_native_llama.py +94 -78
  42. sglang/srt/openai_api/adapter.py +11 -4
  43. sglang/srt/openai_api/protocol.py +30 -27
  44. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  45. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  47. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  48. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  49. sglang/srt/sampling/sampling_batch_info.py +58 -57
  50. sglang/srt/sampling/sampling_params.py +3 -3
  51. sglang/srt/server.py +29 -2
  52. sglang/srt/server_args.py +97 -60
  53. sglang/srt/utils.py +103 -51
  54. sglang/test/runners.py +25 -6
  55. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  56. sglang/test/test_utils.py +33 -22
  57. sglang/version.py +1 -1
  58. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  59. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
  60. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  61. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  62. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -22,7 +22,14 @@ import random
 import tempfile
 from typing import List, Optional
 
-from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
+from sglang.srt.utils import (
+    get_amdgpu_memory_capacity,
+    get_nvgpu_memory_capacity,
+    is_flashinfer_available,
+    is_hip,
+    is_ipv6,
+    is_port_available,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -64,6 +71,8 @@ class ServerArgs:
     random_seed: Optional[int] = None
     constrained_json_whitespace_pattern: Optional[str] = None
     watchdog_timeout: float = 300
+    download_dir: Optional[str] = None
+    base_gpu_id: int = 0
 
     # Logging
     log_level: str = "info"
@@ -108,8 +117,6 @@ class ServerArgs:
     grammar_backend: Optional[str] = "outlines"
 
     # Optimization/debug options
-    disable_flashinfer: bool = False
-    disable_flashinfer_sampling: bool = False
     disable_radix_cache: bool = False
     disable_jump_forward: bool = False
     disable_cuda_graph: bool = False
@@ -117,14 +124,14 @@
     disable_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
     disable_mla: bool = False
-    disable_penalizer: bool = False
-    disable_nan_detection: bool = False
-    enable_overlap_schedule: bool = False
+    disable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
+    enable_dp_attention: bool = False
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: int = 160
     torchao_config: str = ""
+    enable_nan_detection: bool = False
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
     num_continuous_decode_steps: int = 1
@@ -142,12 +149,15 @@
             # Disable chunked prefill
             self.chunked_prefill_size = None
 
+        if self.random_seed is None:
+            self.random_seed = random.randint(0, 1 << 30)
+
         # Mem fraction depends on the tensor parallelism size
         if self.mem_fraction_static is None:
             if self.tp_size >= 16:
                 self.mem_fraction_static = 0.79
             elif self.tp_size >= 8:
-                self.mem_fraction_static = 0.83
+                self.mem_fraction_static = 0.82
             elif self.tp_size >= 4:
                 self.mem_fraction_static = 0.85
             elif self.tp_size >= 2:
@@ -155,54 +165,46 @@
             else:
                 self.mem_fraction_static = 0.88
 
-        if self.random_seed is None:
-            self.random_seed = random.randint(0, 1 << 30)
-
-        # Deprecation warnings
-        if self.disable_flashinfer:
-            logger.warning(
-                "The option '--disable-flashinfer' will be deprecated in the next release. "
-                "Please use '--attention-backend triton' instead."
-            )
-            self.attention_backend = "triton"
-        if self.disable_flashinfer_sampling:
-            logger.warning(
-                "The option '--disable-flashinfer-sampling' will be deprecated in the next release. "
-                "Please use '--sampling-backend pytorch' instead. "
-            )
-            self.sampling_backend = "pytorch"
+        # Adjust for GPUs with small memory capacities
+        if is_hip():
+            gpu_mem = get_amdgpu_memory_capacity()
+        else:
+            gpu_mem = get_nvgpu_memory_capacity()
+        if gpu_mem < 25000:
+            self.chunked_prefill_size //= 4  # make it 2048
+            self.cuda_graph_max_bs = 4
+            logger.info("Automatically adjust --chunked-prefill-size for small GPUs.")
 
+        # Choose kernel backends
         if not is_flashinfer_available():
             self.attention_backend = "triton"
             self.sampling_backend = "pytorch"
 
-        # Default kernel backends
         if self.attention_backend is None:
             self.attention_backend = "flashinfer"
-
         if self.sampling_backend is None:
             self.sampling_backend = "flashinfer"
 
-        if self.enable_overlap_schedule:
-            logger.warning(
-                "Overlap scheduler mode is enabled. This is an experimental feature. "
-                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
-                "and embedding APIs are not supported and will lead to wrong results. "
-                "The NaN detection is also disabled."
+        # Others
+        if self.enable_dp_attention:
+            self.dp_size = self.tp_size
+            self.chunked_prefill_size = self.chunked_prefill_size // 2
+            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
+            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
+            self.disable_overlap_schedule = True
+            logger.info(
+                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
+                f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
+                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
+                "Data parallel size is adjusted to be the same as tensor parallel size. "
+                "Overlap schedule is disabled."
             )
-            self.disable_penalizer = True
-            self.disable_nan_detection = True
 
-        # Model-specific patches
-        if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path:
+        if self.enable_mixed_chunk:
             logger.info(
-                "Not sure why, the tokenizer will add an additional token at the end of the prompt when trust_remote_mode=True"
+                "Overlap schedule is disabled because mixed-style chunked prefill is enabled."
             )
-            self.trust_remote_code = False
-
-        if "gemma-2" in self.model_path.lower():
-            logger.info("When using sliding window in gemma-2, turn on flashinfer.")
-            self.attention_backend = "flashinfer"
+            self.disable_overlap_schedule = True
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -405,6 +407,18 @@ class ServerArgs:
             default=ServerArgs.watchdog_timeout,
             help="Set watchdog timeout in seconds. If a forward batch takes longer than this, the server will crash to prevent hanging.",
         )
+        parser.add_argument(
+            "--download-dir",
+            type=str,
+            default=ServerArgs.download_dir,
+            help="Model download directory.",
+        )
+        parser.add_argument(
+            "--base-gpu-id",
+            type=int,
+            default=ServerArgs.base_gpu_id,
+            help="The base GPU ID to start allocating GPUs from. Useful when running multiple instances on the same machine.",
+        )
 
         # Logging
         parser.add_argument(
@@ -578,16 +592,6 @@ class ServerArgs:
         )
 
         # Optimization/debug options
-        parser.add_argument(
-            "--disable-flashinfer",
-            action="store_true",
-            help="Disable flashinfer attention kernels. This option will be deprecated in the next release. Please use '--attention-backend triton' instead.",
-        )
-        parser.add_argument(
-            "--disable-flashinfer-sampling",
-            action="store_true",
-            help="Disable flashinfer sampling kernels. This option will be deprecated in the next release. Please use '--sampling-backend pytorch' instead.",
-        )
         parser.add_argument(
             "--disable-radix-cache",
             action="store_true",
@@ -623,26 +627,26 @@ class ServerArgs:
            action="store_true",
            help="Disable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
        )
-        parser.add_argument(
-            "--disable-penalizer",
-            action="store_true",
-            help="Disable the logit penalizers (e.g., frequency and repetition penalty) for better performance if they are not used in any requests.",
-        )
         parser.add_argument(
             "--disable-nan-detection",
             action="store_true",
             help="Disable the NaN detection for better performance.",
         )
         parser.add_argument(
-            "--enable-overlap-schedule",
+            "--disable-overlap-schedule",
             action="store_true",
-            help="Overlap the CPU scheduler with GPU model worker. Experimental feature.",
+            help="Disable the overlap scheduler, which overlaps the CPU scheduler with GPU model worker.",
         )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
             help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
         )
+        parser.add_argument(
+            "--enable-dp-attention",
+            action="store_true",
+            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
@@ -664,7 +668,12 @@ class ServerArgs:
             "--torchao-config",
             type=str,
             default=ServerArgs.torchao_config,
-            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo",
+            help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
+        )
+        parser.add_argument(
+            "--enable-nan-detection",
+            action="store_true",
+            help="Enable the NaN detection for debugging purposes.",
         )
         parser.add_argument(
             "--enable-p2p-check",
@@ -691,6 +700,23 @@ class ServerArgs:
             help="Delete the model checkpoint after loading the model.",
         )
 
+        # Deprecated arguments
+        parser.add_argument(
+            "--enable-overlap-schedule",
+            action=DeprecatedAction,
+            help="'--enable-overlap-schedule' is deprecated. It is enabled by default now. Please drop this argument.",
+        )
+        parser.add_argument(
+            "--disable-flashinfer",
+            action=DeprecatedAction,
+            help="'--disable-flashinfer' is deprecated. Please use '--attention-backend triton' instead.",
+        )
+        parser.add_argument(
+            "--disable-flashinfer-sampling",
+            action=DeprecatedAction,
+            help="'--disable-flashinfer-sampling' is deprecated. Please use '--sampling-backend pytroch' instead.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
@@ -717,6 +743,7 @@ class ServerArgs:
             and (self.lora_paths is None or self.disable_cuda_graph)
             and (self.lora_paths is None or self.disable_radix_cache)
         ), "compatibility of lora and cuda graph and radix attention is in progress"
+        assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
 
         if isinstance(self.lora_paths, list):
             lora_paths = self.lora_paths
@@ -761,7 +788,7 @@ class PortArgs:
 
     @staticmethod
    def init_new(server_args) -> "PortArgs":
-        port = server_args.port + 42
+        port = server_args.port + random.randint(100, 1000)
         while True:
             if is_port_available(port):
                 break
@@ -784,3 +811,13 @@ class LoRAPathAction(argparse.Action):
                 getattr(namespace, self.dest)[name] = path
             else:
                 getattr(namespace, self.dest)[lora_path] = lora_path
+
+
+class DeprecatedAction(argparse.Action):
+    def __init__(self, option_strings, dest, nargs=0, **kwargs):
+        super(DeprecatedAction, self).__init__(
+            option_strings, dest, nargs=nargs, **kwargs
+        )
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        raise ValueError(self.help)
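
Note on the deprecation handling above: DeprecatedAction replaces the 0.3.5 behavior of warning and silently remapping removed flags; any use of a deprecated flag now fails fast with the migration hint from its help text. A minimal, self-contained sketch of that argparse pattern (the parser and the --old-flag option below are hypothetical placeholders, not part of the sglang CLI):

import argparse


class DeprecatedAction(argparse.Action):
    def __init__(self, option_strings, dest, nargs=0, **kwargs):
        super().__init__(option_strings, dest, nargs=nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # Any use of the flag raises immediately, surfacing the migration hint.
        raise ValueError(self.help)


parser = argparse.ArgumentParser()
parser.add_argument(
    "--old-flag",
    action=DeprecatedAction,
    help="'--old-flag' is deprecated. Please use '--new-flag' instead.",
)

try:
    parser.parse_args(["--old-flag"])
except ValueError as err:
    print(err)  # prints the migration hint from the help text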
sglang/srt/utils.py CHANGED
@@ -27,6 +27,7 @@ import resource
 import shutil
 import signal
 import socket
+import subprocess
 import tempfile
 import time
 import warnings
@@ -70,6 +71,8 @@ def is_flashinfer_available():
     Check whether flashinfer is available.
     As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
     """
+    if os.environ.get("SGLANG_IS_FLASHINFER_AVAILABLE", "true") == "false":
+        return False
     return torch.cuda.is_available() and not is_hip()
 
 
@@ -329,6 +332,7 @@ def suppress_other_loggers():
     )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.ERROR)
+    logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)
 
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -393,6 +397,27 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
         pass
 
 
+def monkey_patch_vllm_model_config():
+    from vllm.config import ModelConfig
+
+    if not hasattr(ModelConfig, "_resolve_task"):
+        return
+
+    def _resolve_task(
+        self,
+        task_option,
+        hf_config,
+    ):
+        supported_tasks = {
+            "generate": True,
+            "embedding": False,
+        }
+        selected_task = "generate"
+        return supported_tasks, selected_task
+
+    setattr(ModelConfig, "_resolve_task", _resolve_task)
+
+
 def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     """
     Monkey patch the slow p2p access check in vllm.
@@ -404,57 +429,6 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
 
 
-def monkey_patch_vllm_dummy_weight_loader():
-    """
-    Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
-    """
-
-    from vllm.model_executor.model_loader.loader import (
-        CacheConfig,
-        DeviceConfig,
-        DummyModelLoader,
-        LoRAConfig,
-        ModelConfig,
-        ParallelConfig,
-        SchedulerConfig,
-        _initialize_model,
-        initialize_dummy_weights,
-        nn,
-        set_default_torch_dtype,
-    )
-
-    def load_model(
-        self,
-        *,
-        model_config: ModelConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        cache_config: CacheConfig,
-    ) -> nn.Module:
-        with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
-                model = _initialize_model(
-                    model_config,
-                    self.load_config,
-                    lora_config,
-                    cache_config,
-                )
-
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    quant_method.process_weights_after_loading(module)
-
-            # NOTE(woosuk): For accurate performance evaluation, we assign
-            # random values to the weights.
-            initialize_dummy_weights(model)
-        return model.eval()
-
-    setattr(DummyModelLoader, "load_model", load_model)
-
-
 vllm_all_gather_backup = None
 
 
@@ -791,3 +765,81 @@ def add_prometheus_middleware(app):
     # Workaround for 307 Redirect for /metrics
     metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
     app.routes.append(metrics_route)
+
+
+def bind_port(port):
+    """Bind to a specific port, assuming it's available."""
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # Allows address reuse
+    sock.bind(("", port))
+    sock.listen(1)
+    return sock
+
+
+def get_amdgpu_memory_capacity():
+    try:
+        # Run rocm-smi and capture the output
+        result = subprocess.run(
+            ["rocm-smi --showmeminfo vram | grep 'Total Memory' | awk '{print $NF}'"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            shell=True,
+            text=True,
+        )
+        if result.returncode != 0:
+            raise RuntimeError(f"rocm-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values in MiB
+        memory_values = [
+            float(mem) / 1024 / 1024
+            for mem in result.stdout.strip().split("\n")
+            if re.match(r"^\d+(\.\d+)?$", mem.strip())
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "rocm-smi not found. Ensure AMD ROCm drivers are installed and accessible."
+        )
+
+
+def get_nvgpu_memory_capacity():
+    try:
+        # Run nvidia-smi and capture the output
+        result = subprocess.run(
+            ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            text=True,
+        )
+
+        if result.returncode != 0:
+            raise RuntimeError(f"nvidia-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values
+        memory_values = [
+            float(mem)
+            for mem in result.stdout.strip().split("\n")
+            if re.match(r"^\d+(\.\d+)?$", mem.strip())
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
+        )
+
+
+def crash_on_warnings():
+    # Crash on warning if we are running CI tests
+    return os.getenv("SGLANG_IS_IN_CI", "false") == "true"
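
One practical consequence of the is_flashinfer_available change: the new SGLANG_IS_FLASHINFER_AVAILABLE environment variable is a kill switch for flashinfer that needs no CLI flags, and ServerArgs then falls back to the triton attention backend and pytorch sampling backend. A small sketch of that escape hatch (assumes sglang 0.3.6 is installed; the assertion relies only on the code shown in this diff):

import os

# The override is read at call time, so set it before ServerArgs is constructed.
os.environ["SGLANG_IS_FLASHINFER_AVAILABLE"] = "false"

from sglang.srt.utils import is_flashinfer_available

# With the override in place this returns False regardless of the GPU, so
# ServerArgs.__post_init__ picks attention_backend="triton" and
# sampling_backend="pytorch".
assert is_flashinfer_available() is False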
sglang/test/runners.py CHANGED
@@ -58,6 +58,28 @@ def get_top_logprobs(logits, k):
     return logprobs
 
 
+def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
+    from sentence_transformers import SentenceTransformer
+    from sentence_transformers.util import is_sentence_transformer_model
+
+    if is_sentence_transformer_model(model_path):
+        model = SentenceTransformer(
+            model_path,
+            model_kwargs={"torch_dtype": torch_dtype},
+        )
+    else:  # if no pre-trained sentence-transformers model
+        from sentence_transformers import models
+
+        word_embedding_model = models.Transformer(model_path).to(dtype=torch_dtype)
+        pooling_model = models.Pooling(
+            word_embedding_model.get_word_embedding_dimension(),
+            pooling_mode="lasttoken",
+        )
+        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+
+    return model.cuda()
+
+
 @dataclass
 class ModelOutput:
     output_strs: List[str] = None
@@ -114,12 +136,9 @@ class HFRunner:
                 low_cpu_mem_usage=True,
             ).cuda()
         elif self.model_type == "embedding":
-            from sentence_transformers import SentenceTransformer
-
-            self.model = SentenceTransformer(
-                model_path,
-                model_kwargs={"torch_dtype": torch_dtype},
-            ).cuda()
+            self.model = _get_sentence_transformer_embedding_model(
+                model_path, torch_dtype
+            )
         elif self.model_type == "reward":
             from transformers import AutoModelForSequenceClassification
 
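
The runners change factors the embedding-model setup into _get_sentence_transformer_embedding_model, which falls back to last-token pooling when the checkpoint is not a packaged sentence-transformers model. A rough sketch of what that fallback builds (the model id, dtype, and encode call are illustrative; sentence-transformers and a CUDA device are assumed):

import torch
from sentence_transformers import SentenceTransformer, models

# Hypothetical decoder-style checkpoint without a sentence-transformers config.
model_path = "my-org/my-decoder-model"

word_embedding_model = models.Transformer(model_path).to(dtype=torch.float16)
pooling_model = models.Pooling(
    word_embedding_model.get_word_embedding_dimension(),
    pooling_mode="lasttoken",  # pool the last token, as in the new helper
)
model = SentenceTransformer(modules=[word_embedding_model, pooling_model]).cuda()

embeddings = model.encode(["an example sentence"])  # numpy array, shape (1, hidden_size)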
sglang/test/srt/sampling/penaltylib/utils.py CHANGED
@@ -1,7 +1,7 @@
 import dataclasses
 import enum
-import typing
 import unittest
+from typing import Dict, List, Optional, Set, Tuple, Type
 
 import torch
 
@@ -16,7 +16,7 @@ from sglang.srt.sampling.penaltylib.orchestrator import (
 class MockSamplingParams:
     frequency_penalty: float = 0.0
     min_new_tokens: int = 0
-    stop_token_ids: typing.List[int] = None
+    stop_token_ids: List[int] = None
     presence_penalty: float = 0.0
     repetition_penalty: float = 1.0
 
@@ -24,12 +24,12 @@ class MockSamplingParams:
 @dataclasses.dataclass
 class MockTokenizer:
     eos_token_id: int
-    additional_stop_token_ids: typing.Optional[typing.List[int]] = None
+    additional_stop_token_ids: Optional[List[int]] = None
 
 
 @dataclasses.dataclass
 class MockReq:
-    origin_input_ids: typing.List[int]
+    origin_input_ids: List[int]
     sampling_params: MockSamplingParams
     tokenizer: MockTokenizer
 
@@ -42,8 +42,8 @@ class StepType(enum.Enum):
 @dataclasses.dataclass
 class Step:
     type: StepType
-    token_ids: typing.List[int]
-    expected_tensors: typing.Dict[str, torch.Tensor]
+    token_ids: List[int]
+    expected_tensors: Dict[str, torch.Tensor]
     # assume initial logits are all 1
     expected_logits: torch.Tensor
 
@@ -52,7 +52,7 @@ class Step:
 class Subject:
     sampling_params: MockSamplingParams
     # first step must be input, which will be converted to Req
-    steps: typing.List[Step]
+    steps: List[Step]
     eos_token_id: int = -1
 
     def __post_init__(self):
@@ -66,7 +66,7 @@ class Subject:
                     f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}"
                 )
 
-    def tensor_keys(self, i: int = 0) -> typing.Set[str]:
+    def tensor_keys(self, i: int = 0) -> Set[str]:
         return set(self.steps[i].expected_tensors.keys())
 
     def to_req(self) -> MockReq:
@@ -80,7 +80,7 @@ class Subject:
 @dataclasses.dataclass
 class Case:
     enabled: bool
-    test_subjects: typing.List[Subject]
+    test_subjects: List[Subject]
 
     def __post_init__(self):
         # each test_subjects.steps should have the same expected_tensors.keys()
@@ -90,12 +90,12 @@ class Case:
                     f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}"
                 )
 
-    def tensor_keys(self, i: int = 0) -> typing.List[str]:
+    def tensor_keys(self, i: int = 0) -> List[str]:
         return set(self.test_subjects[i].tensor_keys())
 
 
 class BaseBatchedPenalizerTest(unittest.TestCase):
-    Penalizer: typing.Type[_BatchedPenalizer]
+    Penalizer: Type[_BatchedPenalizer]
     device = "cuda"
     vocab_size = 5
 
@@ -115,7 +115,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
         """
         return torch.tensor(data, **kwargs, device=self.device)
 
-    def create_test_subjects(self) -> typing.List[Subject]:
+    def create_test_subjects(self) -> List[Subject]:
         raise NotImplementedError()
 
     def create_test_cases(self):
@@ -127,7 +127,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
 
     def _create_penalizer(
         self, case: Case
-    ) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
+    ) -> Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
         orchestrator = BatchedPenalizerOrchestrator(
             vocab_size=self.vocab_size,
             batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
@@ -287,22 +287,24 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
                     if i < len(subject.steps)
                 ]
 
-                inputs: typing.List[typing.List[int]] = []
-                outputs: typing.List[typing.List[int]] = []
+                inputs: List[List[int]] = []
+                outputs: List[List[int]] = []
                 for subject in filtered_subjects:
                     step = subject.steps[i]
                     if step.type == StepType.INPUT:
-                        inputs.append(step.token_ids)
-                        outputs.append([])
+                        raise NotImplementedError()
                     else:
                         inputs.append([])
                         outputs.append(step.token_ids)
 
-                if any(inputs):
-                    orchestrator.cumulate_input_tokens(inputs)
-
                 if any(outputs):
-                    orchestrator.cumulate_output_tokens(outputs)
+                    for j in range(max(len(x) for x in outputs)):
+                        tmp_outputs = torch.tensor(
+                            [x[j] for x in outputs],
+                            dtype=torch.int32,
+                            device=orchestrator.device,
+                        )
+                        orchestrator.cumulate_output_tokens(tmp_outputs)
 
                 if penalizer.is_required():
                     self.assertTrue(penalizer.is_prepared())
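
The test rewrite above reflects how the penalizer orchestrator is now driven: the test no longer feeds prompt tokens through cumulate_input_tokens, and cumulate_output_tokens receives one int32 tensor per decode step (one token id per running request) rather than a nested Python list covering all steps at once. A minimal sketch of that per-step loop, with the orchestrator call left as a comment so the snippet runs standalone (the token ids are example values):

import torch

# Two requests, each with two decoded tokens so far (example values).
outputs = [[5, 7], [9, 11]]

for j in range(max(len(x) for x in outputs)):
    # One token id per request for decode step j, packed into a single int32 tensor.
    step_tokens = torch.tensor([x[j] for x in outputs], dtype=torch.int32)
    # orchestrator.cumulate_output_tokens(step_tokens)  # as in the updated test
    print(step_tokens.tolist())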