sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
Files changed (87)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +46 -25
  4. sglang/bench_serving.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +14 -1
  6. sglang/lang/interpreter.py +16 -6
  7. sglang/lang/ir.py +20 -4
  8. sglang/srt/configs/model_config.py +11 -9
  9. sglang/srt/constrained/fsm_cache.py +9 -1
  10. sglang/srt/constrained/jump_forward.py +15 -2
  11. sglang/srt/layers/activation.py +4 -4
  12. sglang/srt/layers/attention/__init__.py +49 -0
  13. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  14. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  15. sglang/srt/layers/attention/triton_backend.py +161 -0
  16. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  17. sglang/srt/layers/layernorm.py +4 -4
  18. sglang/srt/layers/logits_processor.py +19 -15
  19. sglang/srt/layers/pooler.py +3 -3
  20. sglang/srt/layers/quantization/__init__.py +0 -2
  21. sglang/srt/layers/radix_attention.py +6 -4
  22. sglang/srt/layers/sampler.py +6 -4
  23. sglang/srt/layers/torchao_utils.py +18 -0
  24. sglang/srt/lora/lora.py +20 -21
  25. sglang/srt/lora/lora_manager.py +97 -25
  26. sglang/srt/managers/detokenizer_manager.py +31 -18
  27. sglang/srt/managers/image_processor.py +187 -0
  28. sglang/srt/managers/io_struct.py +99 -75
  29. sglang/srt/managers/schedule_batch.py +184 -63
  30. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  31. sglang/srt/managers/scheduler.py +1021 -0
  32. sglang/srt/managers/tokenizer_manager.py +120 -248
  33. sglang/srt/managers/tp_worker.py +28 -925
  34. sglang/srt/mem_cache/memory_pool.py +34 -52
  35. sglang/srt/model_executor/cuda_graph_runner.py +15 -19
  36. sglang/srt/model_executor/forward_batch_info.py +94 -95
  37. sglang/srt/model_executor/model_runner.py +76 -75
  38. sglang/srt/models/baichuan.py +10 -10
  39. sglang/srt/models/chatglm.py +12 -12
  40. sglang/srt/models/commandr.py +10 -10
  41. sglang/srt/models/dbrx.py +12 -12
  42. sglang/srt/models/deepseek.py +10 -10
  43. sglang/srt/models/deepseek_v2.py +14 -15
  44. sglang/srt/models/exaone.py +10 -10
  45. sglang/srt/models/gemma.py +10 -10
  46. sglang/srt/models/gemma2.py +11 -11
  47. sglang/srt/models/gpt_bigcode.py +10 -10
  48. sglang/srt/models/grok.py +10 -10
  49. sglang/srt/models/internlm2.py +10 -10
  50. sglang/srt/models/llama.py +14 -10
  51. sglang/srt/models/llama_classification.py +5 -5
  52. sglang/srt/models/llama_embedding.py +4 -4
  53. sglang/srt/models/llama_reward.py +142 -0
  54. sglang/srt/models/llava.py +39 -33
  55. sglang/srt/models/llavavid.py +31 -28
  56. sglang/srt/models/minicpm.py +10 -10
  57. sglang/srt/models/minicpm3.py +14 -15
  58. sglang/srt/models/mixtral.py +10 -10
  59. sglang/srt/models/mixtral_quant.py +10 -10
  60. sglang/srt/models/olmoe.py +10 -10
  61. sglang/srt/models/qwen.py +10 -10
  62. sglang/srt/models/qwen2.py +11 -11
  63. sglang/srt/models/qwen2_moe.py +10 -10
  64. sglang/srt/models/stablelm.py +10 -10
  65. sglang/srt/models/torch_native_llama.py +506 -0
  66. sglang/srt/models/xverse.py +10 -10
  67. sglang/srt/models/xverse_moe.py +10 -10
  68. sglang/srt/sampling/sampling_batch_info.py +36 -27
  69. sglang/srt/sampling/sampling_params.py +3 -1
  70. sglang/srt/server.py +170 -119
  71. sglang/srt/server_args.py +54 -27
  72. sglang/srt/utils.py +101 -128
  73. sglang/test/runners.py +71 -26
  74. sglang/test/test_programs.py +38 -5
  75. sglang/test/test_utils.py +18 -9
  76. sglang/version.py +1 -1
  77. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
  78. sglang-0.3.3.dist-info/RECORD +139 -0
  79. sglang/srt/layers/attention_backend.py +0 -474
  80. sglang/srt/managers/controller_multi.py +0 -207
  81. sglang/srt/managers/controller_single.py +0 -164
  82. sglang-0.3.2.dist-info/RECORD +0 -135
  83. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  84. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  85. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  86. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  87. {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -19,9 +19,10 @@ import argparse
 import dataclasses
 import logging
 import random
-from typing import List, Optional, Union
+import tempfile
+from typing import List, Optional

-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available

 logger = logging.getLogger(__name__)

@@ -46,7 +47,6 @@ class ServerArgs:
     # Port
     host: str = "127.0.0.1"
     port: int = 30000
-    additional_ports: Optional[Union[List[int], int]] = None

     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
@@ -78,9 +78,9 @@ class ServerArgs:
     load_balance_method: str = "round_robin"

     # Distributed args
-    nccl_init_addr: Optional[str] = None
+    dist_init_addr: Optional[str] = None
     nnodes: int = 1
-    node_rank: Optional[int] = None
+    node_rank: int = 0

     # Model override args in JSON
     json_model_override_args: str = "{}"
@@ -134,11 +134,6 @@ class ServerArgs:
         else:
             self.mem_fraction_static = 0.88

-        if isinstance(self.additional_ports, int):
-            self.additional_ports = [self.additional_ports]
-        elif self.additional_ports is None:
-            self.additional_ports = []
-
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)

@@ -156,8 +151,7 @@ class ServerArgs:
             )
             self.sampling_backend = "pytorch"

-        # ROCm: flashinfer available later
-        if is_hip():
+        if not is_flashinfer_available():
             self.attention_backend = "triton"
             self.sampling_backend = "pytorch"

@@ -199,13 +193,6 @@ class ServerArgs:
         parser.add_argument(
             "--port", type=int, default=ServerArgs.port, help="The port of the server."
         )
-        parser.add_argument(
-            "--additional-ports",
-            type=int,
-            nargs="*",
-            default=[],
-            help="The additional ports specified for the server.",
-        )
         parser.add_argument(
             "--tokenizer-mode",
             type=str,
@@ -279,7 +266,6 @@ class ServerArgs:
                 "marlin",
                 "gptq_marlin",
                 "awq_marlin",
-                "squeezellm",
                 "bitsandbytes",
             ],
             help="The quantization method.",
@@ -426,14 +412,17 @@ class ServerArgs:

         # Multi-node distributed serving args
         parser.add_argument(
-            "--nccl-init-addr",
+            "--dist-init-addr",
+            "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
             type=str,
-            help="The nccl init address of multi-node server.",
+            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
         )
         parser.add_argument(
             "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
         )
-        parser.add_argument("--node-rank", type=int, help="The node rank.")
+        parser.add_argument(
+            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
+        )

         # Model override args
         parser.add_argument(
@@ -567,7 +556,10 @@ class ServerArgs:
         return cls(**{attr: getattr(args, attr) for attr in attrs})

     def url(self):
-        return f"http://{self.host}:{self.port}"
+        if is_ipv6(self.host):
+            return f"http://[{self.host}]:{self.port}"
+        else:
+            return f"http://{self.host}:{self.port}"

     def check_server_args(self):
         assert (
@@ -583,6 +575,21 @@ class ServerArgs:
             and (self.lora_paths is None or self.disable_radix_cache)
         ), "compatibility of lora and cuda graph and radix attention is in progress"

+        assert self.dp_size == 1, (
+            "The support for data parallelism is temporarily disabled during refactor. "
+            "Please use sglang<=0.3.2 or wait for later updates."
+        )
+
+        if isinstance(self.lora_paths, list):
+            lora_paths = self.lora_paths
+            self.lora_paths = {}
+            for lora_path in lora_paths:
+                if "=" in lora_path:
+                    name, path = lora_path.split("=", 1)
+                    self.lora_paths[name] = path
+                else:
+                    self.lora_paths[lora_path] = lora_path
+

 def prepare_server_args(argv: List[str]) -> ServerArgs:
     """
@@ -604,11 +611,31 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:

 @dataclasses.dataclass
 class PortArgs:
-    tokenizer_port: int
-    controller_port: int
-    detokenizer_port: int
+    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
+    tokenizer_ipc_name: str
+    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
+    scheduler_input_ipc_name: str
+    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
+    detokenizer_ipc_name: str
+
+    # The port for nccl initialization for multiple TP groups (torch.dist)
     nccl_ports: List[int]

+    @classmethod
+    def init_new(self, server_args):
+        port = server_args.port + 1
+        while True:
+            if is_port_available(port):
+                break
+            port += 1
+
+        return PortArgs(
+            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            nccl_ports=[port],
+        )
+

 class LoRAPathAction(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
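The hunks above drop the old TCP port plumbing (`additional_ports`, `tokenizer_port`, `controller_port`, `detokenizer_port`) in favor of ZMQ IPC filenames plus a single auto-probed NCCL port, and make `url()` IPv6-aware. A minimal sketch of how the new pieces fit together; the model path and host below are placeholder values, not taken from the diff:

# Sketch only: illustrates the new PortArgs / url() behavior shown above.
from sglang.srt.server_args import PortArgs, ServerArgs

args = ServerArgs(model_path="dummy/model", host="::1", port=30000)  # hypothetical values
print(args.url())  # IPv6 hosts are now bracketed: http://[::1]:30000

# PortArgs.init_new() probes for the first free port above `port` for NCCL init
# and uses temp-file names as the ZMQ IPC endpoints between the processes.
port_args = PortArgs.init_new(args)
print(port_args.tokenizer_ipc_name, port_args.nccl_ports)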
sglang/srt/utils.py CHANGED
@@ -16,14 +16,16 @@ limitations under the License.
 """Common utilities."""

 import base64
-import fcntl
+import ipaddress
+import json
 import logging
 import os
+import pickle
 import random
 import resource
 import socket
-import struct
 import time
+import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import Any, Dict, List, Optional, Union
@@ -36,7 +38,7 @@ import torch.distributed as dist
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
 from torch import nn
-from torch.nn.parameter import Parameter
+from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
     FileCacheManager,
     default_cache_dir,
@@ -51,11 +53,27 @@ show_time_cost = False
 time_infos = {}


-# torch flag AMD GPU
 def is_hip() -> bool:
+    """Return whether it is HIP on the AMD ROCm platform."""
     return torch.version.hip is not None


+def is_flashinfer_available():
+    """
+    Check whether flashinfer is available.
+    As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
+    """
+    return torch.cuda.is_available() and not is_hip()
+
+
+def is_ipv6(address):
+    try:
+        ipaddress.IPv6Address(address)
+        return True
+    except ipaddress.AddressValueError:
+        return False
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
@@ -170,35 +188,6 @@ def is_port_available(port):
     return False


-def allocate_init_ports(
-    port: Optional[int] = None,
-    additional_ports: Optional[List[int]] = None,
-    dp_size: int = 1,
-):
-    """Allocate ports for all connections."""
-    if additional_ports:
-        ret_ports = [port] + additional_ports
-    else:
-        ret_ports = [port]
-
-    ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
-    cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
-
-    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
-    num_ports_needed = 4 + dp_size
-    while len(ret_ports) < num_ports_needed:
-        if cur_port not in ret_ports and is_port_available(cur_port):
-            ret_ports.append(cur_port)
-        cur_port += 1
-
-    if port is not None and ret_ports[0] != port:
-        logger.warning(
-            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
-        )
-
-    return ret_ports[0], ret_ports[1:num_ports_needed]
-
-
 def is_multimodal_model(model_architectures):
     if (
         "LlavaLlamaForCausalLM" in model_architectures
@@ -219,6 +208,8 @@ def is_generation_model(model_architectures, is_embedding: bool = False):
     if (
         "LlamaEmbeddingModel" in model_architectures
         or "MistralModel" in model_architectures
+        or "LlamaForSequenceClassification" in model_architectures
+        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
     ):
         return False
     else:
@@ -345,6 +336,10 @@ def suppress_other_loggers():
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.ERROR)

+    warnings.filterwarnings(
+        "ignore", category=UserWarning, message="The given NumPy array is not writable"
+    )
+

 def assert_pkg_version(pkg: str, min_version: str, message: str):
     try:
@@ -537,89 +532,6 @@ class CustomCacheManager(FileCacheManager):
             raise RuntimeError("Could not create or locate cache dir")


-def get_ip_address(ifname):
-    """
-    Get the IP address of a network interface.
-
-    :param ifname: Name of the network interface (e.g., 'eth0')
-    :return: IP address of the network interface
-    """
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    ip_address = fcntl.ioctl(
-        s.fileno(),
-        0x8915,  # SIOCGIFADDR
-        struct.pack("256s", bytes(ifname[:15], "utf-8")),
-    )[20:24]
-    return socket.inet_ntoa(ip_address)
-
-
-def send_addrs_to_rank_0(model_port_args, server_args):
-    assert server_args.node_rank != 0 and server_args.dp_size == 1
-
-    ifname = os.environ.get(
-        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
-    )
-    ip_addr = get_ip_address(ifname)
-
-    num_tp_ports = server_args.tp_size // server_args.nnodes
-    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
-    ip_addr = [int(x) for x in ip_addr.split(".")]
-    addrs_tensor = torch.tensor(
-        ip_addr + model_port_args.model_tp_ports, dtype=torch.int
-    )
-
-    init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(
-        backend="gloo",
-        init_method=init_method,
-        rank=server_args.node_rank,
-        world_size=server_args.nnodes,
-    )
-    dist.send(addrs_tensor, dst=0)
-    print(
-        f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
-    )
-
-    dist.barrier()
-    dist.destroy_process_group()
-
-
-def receive_addrs(model_port_args, server_args):
-    assert server_args.node_rank == 0 and server_args.dp_size == 1
-
-    ifname = os.environ.get(
-        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
-    )
-    ip_addr = get_ip_address(ifname)
-
-    num_tp_ports = server_args.tp_size // server_args.nnodes
-    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
-
-    init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(
-        backend="gloo",
-        init_method=init_method,
-        rank=server_args.node_rank,
-        world_size=server_args.nnodes,
-    )
-
-    for src_rank in range(1, server_args.nnodes):
-        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
-        dist.recv(tensor, src=src_rank)
-        ip = ".".join([str(x) for x in tensor[:4].tolist()])
-        ports = tensor[4:].tolist()
-        model_port_args.model_tp_ips[
-            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
-        ] = [ip] * num_tp_ports
-        model_port_args.model_tp_ports[
-            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
-        ] = ports
-        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
-
-    dist.barrier()
-    dist.destroy_process_group()
-
-
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -643,24 +555,16 @@ def add_api_key_middleware(app, api_key: str):
         return await call_next(request)


-def prepare_model(model_path: str):
+def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(model_path):
             from modelscope import snapshot_download

-            return snapshot_download(model_path)
-    return model_path
-
-
-def prepare_tokenizer(tokenizer_path: str):
-    if "SGLANG_USE_MODELSCOPE" in os.environ:
-        if not os.path.exists(tokenizer_path):
-            from modelscope import snapshot_download
-
-            return snapshot_download(
+            model_path = snapshot_download(model_path)
+            tokenizer_path = snapshot_download(
                 tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
             )
-    return tokenizer_path
+    return model_path, tokenizer_path


 def configure_logger(server_args, prefix: str = ""):
@@ -702,3 +606,72 @@ def set_weight_attrs(
     for key, value in weight_attrs.items():
         assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
         setattr(weight, key, value)
+
+
+def broadcast_pyobj(
+    data: List[Any],
+    rank: int,
+    dist_group: Optional[torch.distributed.ProcessGroup] = None,
+):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(
+                np.frombuffer(serialized_data, dtype=np.uint8)
+            )
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.cpu().numpy())
+        data = pickle.loads(serialized_data)
+        return data
+
+
+step_counter = 0
+
+
+def pytorch_profile(name, func, *args, data_size=-1):
+    """
+    Args:
+        name (string): the name of recorded function.
+        func: the function to be profiled.
+        args: the arguments of the profiled function.
+        data_size (int): some measurement of the computation complexity.
+            Usually, it could be the batch size.
+    """
+    global step_counter
+    os.makedirs("trace", exist_ok=True)
+    with profile(
+        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+        # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
+        # on_trace_ready=tensorboard_trace_handler('./log_dir'),
+        record_shapes=True,
+        profile_memory=True,
+        with_stack=True,
+    ) as prof:
+        with record_function(name):
+            with open(f"trace/size_{step_counter}.json", "w") as f:
+                json.dump({"size": data_size}, f)
            result = func(*args)
    prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
    step_counter += 1
    return result
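Of the helpers added above, `broadcast_pyobj` is the one the refactored single-process scheduler path relies on to fan Python request objects out across tensor-parallel ranks. A hedged sketch of how it might be called; the process-group setup and request payload are illustrative assumptions, not part of the diff:

# Sketch only: assumes dist.init_process_group(...) was already run on every rank.
import torch.distributed as dist

from sglang.srt.utils import broadcast_pyobj

rank = dist.get_rank()
reqs = [{"rid": "abc", "input_ids": [1, 2, 3]}] if rank == 0 else []  # hypothetical payload
reqs = broadcast_pyobj(reqs, rank)  # every rank now holds the rank-0 objects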
sglang/test/runners.py CHANGED
@@ -65,17 +65,18 @@ class ModelOutput:
     top_input_logprobs: List[torch.Tensor] = None
     top_output_logprobs: List[torch.Tensor] = None
     embed_logits: List[torch.Tensor] = None
+    scores: List[float] = None


 class HFRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-        is_generation,
-        output_str_only=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str = "generation",
+        output_str_only: bool = False,
     ):
-        self.is_generation = is_generation
+        self.model_type = model_type
         self.output_str_only = output_str_only

         self.in_queue = mp.Queue()
@@ -92,22 +93,41 @@ class HFRunner:
         )
         self.model_proc.start()

+    def needs_trust_remote_code(self, model_path):
+        models_needs_trust_remote = [
+            "LxzGordon/URM-LLaMa-3.1-8B",
+        ]
+        if model_path in models_needs_trust_remote:
+            return True
+        return False
+
     def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
-        self.tokenizer = get_tokenizer(model_path)
-        if self.is_generation:
+        self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
+
+        if self.model_type == "generation":
             self.base_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
                 trust_remote_code=False,
                 low_cpu_mem_usage=True,
             ).cuda()
-        else:
+        elif self.model_type == "embedding":
             from sentence_transformers import SentenceTransformer

             self.model = SentenceTransformer(
                 model_path,
                 model_kwargs={"torch_dtype": torch_dtype},
-            )
+            ).cuda()
+        elif self.model_type == "reward":
+            from transformers import AutoModelForSequenceClassification
+
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                trust_remote_code=self.needs_trust_remote_code(model_path),
+            ).cuda()
+        else:
+            raise Exception(f"Unrecognized model type {self.model_type}")

         while True:
             prompts, max_new_tokens, lora_paths = in_queue.get()
@@ -115,7 +135,7 @@ class HFRunner:
             assert len(prompts) == len(lora_paths)

             if prompts is not None:
-                if self.is_generation:
+                if self.model_type == "generation":
                     output_strs = []
                     top_input_logprobs = []
                     top_output_logprobs = []
@@ -179,11 +199,27 @@ class HFRunner:
                         )
                     )

-                else:
+                elif self.model_type == "embedding":
                     assert not self.output_str_only
                     logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))

+                elif self.model_type == "reward":
+                    scores = []
+                    for conv in prompts:
+                        conv_formatted = self.tokenizer.apply_chat_template(
+                            conv, tokenize=False
+                        )
+                        conv_tokenized = self.tokenizer(
+                            conv_formatted, return_tensors="pt"
+                        ).to("cuda")
+                        scores.append(
+                            float(self.model(**conv_tokenized).logits[0][0].item())
+                        )
+                    out_queue.put(ModelOutput(scores=scores))
+                else:
+                    raise Exception(f"Unrecognized model type {self.model_type}")
+
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
@@ -208,23 +244,24 @@ class HFRunner:
 class SRTRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-        is_generation,
-        tp_size=1,
-        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
-        lora_paths=None,
-        max_loras_per_batch=4,
-        disable_cuda_graph=False,
-        disable_radix_cache=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str,
+        tp_size: int = 1,
+        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
+        lora_paths: List[str] = None,
+        max_loras_per_batch: int = 4,
+        disable_cuda_graph: bool = False,
+        disable_radix_cache: bool = False,
     ):
-        self.is_generation = is_generation
+        self.model_type = model_type
+        self.is_generation = model_type == "generation"
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
-            mem_fraction_static=0.69,
+            mem_fraction_static=0.65,
             trust_remote_code=False,
             is_embedding=not self.is_generation,
             lora_paths=lora_paths,
@@ -285,8 +322,12 @@ class SRTRunner:
         else:
             response = self.runtime.encode(prompts)
             response = json.loads(response)
-            logits = [x["embedding"] for x in response]
-            return ModelOutput(embed_logits=logits)
+            if self.model_type == "embedding":
+                logits = [x["embedding"] for x in response]
+                return ModelOutput(embed_logits=logits)
+            else:
+                scores = [x["embedding"][0] for x in response]
+                return ModelOutput(scores=scores)

     def batch_forward(
         self,
@@ -316,8 +357,12 @@ class SRTRunner:
         else:
             response = self.runtime.encode(prompts)
             response = json.loads(response)
-            logits = [x["embedding"] for x in response]
-            return ModelOutput(embed_logits=logits)
+            if self.model_type == "embedding":
+                logits = [x["embedding"] for x in response]
+                return ModelOutput(embed_logits=logits)
+            else:
+                scores = [x["embedding"][0] for x in response]
+                return ModelOutput(scores=logits)

     def __enter__(self):
         return self
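The test runners replace the boolean `is_generation` with a `model_type` string and gain a third mode, "reward", which scores whole conversations via a sequence-classification head on the HF side and the embedding endpoint on the SRT side. A hedged usage sketch with a placeholder reward model and conversation; whether `forward` accepts chat-formatted prompts exactly like this is an assumption, not confirmed by the diff:

# Sketch only: placeholder model and conversation; requires a GPU and a running SRT runtime.
import torch

from sglang.test.runners import SRTRunner

conv = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]
with SRTRunner(
    "LxzGordon/URM-LLaMa-3.1-8B",  # placeholder reward model
    torch_dtype=torch.float16,
    model_type="reward",
) as runner:
    out = runner.forward([conv])
    print(out.scores)  # one scalar score per conversation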