sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +48 -33
  4. sglang/bench_server_latency.py +0 -6
  5. sglang/bench_serving.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +14 -1
  7. sglang/lang/interpreter.py +16 -6
  8. sglang/lang/ir.py +20 -4
  9. sglang/srt/configs/model_config.py +11 -9
  10. sglang/srt/constrained/fsm_cache.py +9 -1
  11. sglang/srt/constrained/jump_forward.py +15 -2
  12. sglang/srt/hf_transformers_utils.py +1 -0
  13. sglang/srt/layers/activation.py +4 -4
  14. sglang/srt/layers/attention/__init__.py +49 -0
  15. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  16. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  17. sglang/srt/layers/attention/triton_backend.py +161 -0
  18. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  19. sglang/srt/layers/fused_moe/patch.py +117 -0
  20. sglang/srt/layers/layernorm.py +4 -4
  21. sglang/srt/layers/logits_processor.py +19 -15
  22. sglang/srt/layers/pooler.py +3 -3
  23. sglang/srt/layers/quantization/__init__.py +0 -2
  24. sglang/srt/layers/radix_attention.py +6 -4
  25. sglang/srt/layers/sampler.py +6 -4
  26. sglang/srt/layers/torchao_utils.py +18 -0
  27. sglang/srt/lora/lora.py +20 -21
  28. sglang/srt/lora/lora_manager.py +97 -25
  29. sglang/srt/managers/detokenizer_manager.py +31 -18
  30. sglang/srt/managers/image_processor.py +187 -0
  31. sglang/srt/managers/io_struct.py +99 -75
  32. sglang/srt/managers/schedule_batch.py +187 -68
  33. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  34. sglang/srt/managers/scheduler.py +1021 -0
  35. sglang/srt/managers/tokenizer_manager.py +120 -247
  36. sglang/srt/managers/tp_worker.py +28 -925
  37. sglang/srt/mem_cache/memory_pool.py +34 -52
  38. sglang/srt/mem_cache/radix_cache.py +5 -5
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -25
  40. sglang/srt/model_executor/forward_batch_info.py +94 -97
  41. sglang/srt/model_executor/model_runner.py +76 -78
  42. sglang/srt/models/baichuan.py +10 -10
  43. sglang/srt/models/chatglm.py +12 -12
  44. sglang/srt/models/commandr.py +10 -10
  45. sglang/srt/models/dbrx.py +12 -12
  46. sglang/srt/models/deepseek.py +10 -10
  47. sglang/srt/models/deepseek_v2.py +14 -15
  48. sglang/srt/models/exaone.py +10 -10
  49. sglang/srt/models/gemma.py +10 -10
  50. sglang/srt/models/gemma2.py +11 -11
  51. sglang/srt/models/gpt_bigcode.py +10 -10
  52. sglang/srt/models/grok.py +10 -10
  53. sglang/srt/models/internlm2.py +10 -10
  54. sglang/srt/models/llama.py +22 -10
  55. sglang/srt/models/llama_classification.py +5 -5
  56. sglang/srt/models/llama_embedding.py +4 -4
  57. sglang/srt/models/llama_reward.py +142 -0
  58. sglang/srt/models/llava.py +39 -33
  59. sglang/srt/models/llavavid.py +31 -28
  60. sglang/srt/models/minicpm.py +10 -10
  61. sglang/srt/models/minicpm3.py +14 -15
  62. sglang/srt/models/mixtral.py +10 -10
  63. sglang/srt/models/mixtral_quant.py +10 -10
  64. sglang/srt/models/olmoe.py +10 -10
  65. sglang/srt/models/qwen.py +10 -10
  66. sglang/srt/models/qwen2.py +11 -11
  67. sglang/srt/models/qwen2_moe.py +10 -10
  68. sglang/srt/models/stablelm.py +10 -10
  69. sglang/srt/models/torch_native_llama.py +506 -0
  70. sglang/srt/models/xverse.py +10 -10
  71. sglang/srt/models/xverse_moe.py +10 -10
  72. sglang/srt/openai_api/adapter.py +7 -0
  73. sglang/srt/sampling/sampling_batch_info.py +36 -27
  74. sglang/srt/sampling/sampling_params.py +3 -1
  75. sglang/srt/server.py +170 -119
  76. sglang/srt/server_args.py +54 -27
  77. sglang/srt/utils.py +101 -128
  78. sglang/test/runners.py +76 -33
  79. sglang/test/test_programs.py +38 -5
  80. sglang/test/test_utils.py +53 -9
  81. sglang/version.py +1 -1
  82. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
  83. sglang-0.3.3.dist-info/RECORD +139 -0
  84. sglang/srt/layers/attention_backend.py +0 -482
  85. sglang/srt/managers/controller_multi.py +0 -207
  86. sglang/srt/managers/controller_single.py +0 -164
  87. sglang-0.3.1.post3.dist-info/RECORD +0 -134
  88. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  89. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  90. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  92. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -19,9 +19,10 @@ import argparse
 import dataclasses
 import logging
 import random
-from typing import List, Optional, Union
+import tempfile
+from typing import List, Optional
 
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
 
 logger = logging.getLogger(__name__)
 
@@ -46,7 +47,6 @@ class ServerArgs:
     # Port
     host: str = "127.0.0.1"
     port: int = 30000
-    additional_ports: Optional[Union[List[int], int]] = None
 
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
@@ -78,9 +78,9 @@ class ServerArgs:
     load_balance_method: str = "round_robin"
 
     # Distributed args
-    nccl_init_addr: Optional[str] = None
+    dist_init_addr: Optional[str] = None
     nnodes: int = 1
-    node_rank: Optional[int] = None
+    node_rank: int = 0
 
     # Model override args in JSON
     json_model_override_args: str = "{}"
@@ -134,11 +134,6 @@ class ServerArgs:
             else:
                 self.mem_fraction_static = 0.88
 
-        if isinstance(self.additional_ports, int):
-            self.additional_ports = [self.additional_ports]
-        elif self.additional_ports is None:
-            self.additional_ports = []
-
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)
 
@@ -156,8 +151,7 @@ class ServerArgs:
             )
             self.sampling_backend = "pytorch"
 
-        # ROCm: flashinfer available later
-        if is_hip():
+        if not is_flashinfer_available():
             self.attention_backend = "triton"
             self.sampling_backend = "pytorch"
 
@@ -199,13 +193,6 @@ class ServerArgs:
         parser.add_argument(
             "--port", type=int, default=ServerArgs.port, help="The port of the server."
         )
-        parser.add_argument(
-            "--additional-ports",
-            type=int,
-            nargs="*",
-            default=[],
-            help="The additional ports specified for the server.",
-        )
         parser.add_argument(
             "--tokenizer-mode",
             type=str,
@@ -279,7 +266,6 @@ class ServerArgs:
                 "marlin",
                 "gptq_marlin",
                 "awq_marlin",
-                "squeezellm",
                 "bitsandbytes",
             ],
             help="The quantization method.",
@@ -426,14 +412,17 @@ class ServerArgs:
 
         # Multi-node distributed serving args
         parser.add_argument(
-            "--nccl-init-addr",
+            "--dist-init-addr",
+            "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
             type=str,
-            help="The nccl init address of multi-node server.",
+            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
         )
         parser.add_argument(
             "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
         )
-        parser.add_argument("--node-rank", type=int, help="The node rank.")
+        parser.add_argument(
+            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
+        )
 
         # Model override args
         parser.add_argument(
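
A quick usage sketch, not part of the diff: the renamed flag can be exercised through prepare_server_args, and the old --nccl-init-addr spelling still parses because it is registered as a deprecated alias. The model path below is a placeholder; 192.168.0.2:25000 is the example address from the flag's help text.

    from sglang.srt.server_args import prepare_server_args

    server_args = prepare_server_args(
        [
            "--model-path", "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder
            "--nnodes", "2",
            "--node-rank", "0",
            "--dist-init-addr", "192.168.0.2:25000",
        ]
    )
    assert server_args.dist_init_addr == "192.168.0.2:25000"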
@@ -567,7 +556,10 @@ class ServerArgs:
         return cls(**{attr: getattr(args, attr) for attr in attrs})
 
     def url(self):
-        return f"http://{self.host}:{self.port}"
+        if is_ipv6(self.host):
+            return f"http://[{self.host}]:{self.port}"
+        else:
+            return f"http://{self.host}:{self.port}"
 
     def check_server_args(self):
         assert (
@@ -583,6 +575,21 @@ class ServerArgs:
             and (self.lora_paths is None or self.disable_radix_cache)
         ), "compatibility of lora and cuda graph and radix attention is in progress"
 
+        assert self.dp_size == 1, (
+            "The support for data parallelism is temporarily disabled during refactor. "
+            "Please use sglang<=0.3.2 or wait for later updates."
+        )
+
+        if isinstance(self.lora_paths, list):
+            lora_paths = self.lora_paths
+            self.lora_paths = {}
+            for lora_path in lora_paths:
+                if "=" in lora_path:
+                    name, path = lora_path.split("=", 1)
+                    self.lora_paths[name] = path
+                else:
+                    self.lora_paths[lora_path] = lora_path
+
 
 def prepare_server_args(argv: List[str]) -> ServerArgs:
     """
@@ -604,11 +611,31 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
 
 @dataclasses.dataclass
 class PortArgs:
-    tokenizer_port: int
-    controller_port: int
-    detokenizer_port: int
+    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
+    tokenizer_ipc_name: str
+    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
+    scheduler_input_ipc_name: str
+    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
+    detokenizer_ipc_name: str
+
+    # The port for nccl initialization for multiple TP groups (torch.dist)
     nccl_ports: List[int]
 
+    @classmethod
+    def init_new(self, server_args):
+        port = server_args.port + 1
+        while True:
+            if is_port_available(port):
+                break
+            port += 1
+
+        return PortArgs(
+            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            nccl_ports=[port],
+        )
+
 
 class LoRAPathAction(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
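
A minimal sketch of the new PortArgs (the model path is a dummy): init_new allocates three throwaway ipc filenames for the ZMQ links that replace the old tokenizer/controller/detokenizer TCP ports, plus the first free port above server_args.port for nccl initialization.

    from sglang.srt.server_args import PortArgs, ServerArgs

    server_args = ServerArgs(model_path="dummy-model")  # placeholder path
    port_args = PortArgs.init_new(server_args)
    print(port_args.tokenizer_ipc_name)  # e.g. /tmp/tmpab12cd34
    print(port_args.nccl_ports)          # e.g. [30001]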
sglang/srt/utils.py CHANGED
@@ -16,14 +16,16 @@ limitations under the License.
 """Common utilities."""
 
 import base64
-import fcntl
+import ipaddress
+import json
 import logging
 import os
+import pickle
 import random
 import resource
 import socket
-import struct
 import time
+import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import Any, Dict, List, Optional, Union
@@ -36,7 +38,7 @@ import torch.distributed as dist
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
 from torch import nn
-from torch.nn.parameter import Parameter
+from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
     FileCacheManager,
     default_cache_dir,
@@ -51,11 +53,27 @@ show_time_cost = False
 time_infos = {}
 
 
-# torch flag AMD GPU
 def is_hip() -> bool:
+    """Return whether it is HIP on the AMD ROCm platform."""
     return torch.version.hip is not None
 
 
+def is_flashinfer_available():
+    """
+    Check whether flashinfer is available.
+    As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
+    """
+    return torch.cuda.is_available() and not is_hip()
+
+
+def is_ipv6(address):
+    try:
+        ipaddress.IPv6Address(address)
+        return True
+    except ipaddress.AddressValueError:
+        return False
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
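
A small sketch (not from the package) of how the new helpers are used downstream: ServerArgs falls back to the triton/pytorch backends whenever flashinfer is unavailable, and url() brackets IPv6 hosts.

    from sglang.srt.utils import is_flashinfer_available, is_ipv6

    attention_backend = "flashinfer" if is_flashinfer_available() else "triton"

    host, port = "2001:db8::1", 30000
    url = f"http://[{host}]:{port}" if is_ipv6(host) else f"http://{host}:{port}"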
@@ -170,35 +188,6 @@ def is_port_available(port):
     return False
 
 
-def allocate_init_ports(
-    port: Optional[int] = None,
-    additional_ports: Optional[List[int]] = None,
-    dp_size: int = 1,
-):
-    """Allocate ports for all connections."""
-    if additional_ports:
-        ret_ports = [port] + additional_ports
-    else:
-        ret_ports = [port]
-
-    ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
-    cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
-
-    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
-    num_ports_needed = 4 + dp_size
-    while len(ret_ports) < num_ports_needed:
-        if cur_port not in ret_ports and is_port_available(cur_port):
-            ret_ports.append(cur_port)
-        cur_port += 1
-
-    if port is not None and ret_ports[0] != port:
-        logger.warning(
-            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
-        )
-
-    return ret_ports[0], ret_ports[1:num_ports_needed]
-
-
 def is_multimodal_model(model_architectures):
     if (
         "LlavaLlamaForCausalLM" in model_architectures
@@ -219,6 +208,8 @@ def is_generation_model(model_architectures, is_embedding: bool = False):
     if (
         "LlamaEmbeddingModel" in model_architectures
         or "MistralModel" in model_architectures
+        or "LlamaForSequenceClassification" in model_architectures
+        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
     ):
         return False
     else:
@@ -345,6 +336,10 @@ def suppress_other_loggers():
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.ERROR)
 
+    warnings.filterwarnings(
+        "ignore", category=UserWarning, message="The given NumPy array is not writable"
+    )
+
 
 def assert_pkg_version(pkg: str, min_version: str, message: str):
     try:
@@ -537,89 +532,6 @@ class CustomCacheManager(FileCacheManager):
             raise RuntimeError("Could not create or locate cache dir")
 
 
-def get_ip_address(ifname):
-    """
-    Get the IP address of a network interface.
-
-    :param ifname: Name of the network interface (e.g., 'eth0')
-    :return: IP address of the network interface
-    """
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    ip_address = fcntl.ioctl(
-        s.fileno(),
-        0x8915,  # SIOCGIFADDR
-        struct.pack("256s", bytes(ifname[:15], "utf-8")),
-    )[20:24]
-    return socket.inet_ntoa(ip_address)
-
-
-def send_addrs_to_rank_0(model_port_args, server_args):
-    assert server_args.node_rank != 0 and server_args.dp_size == 1
-
-    ifname = os.environ.get(
-        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
-    )
-    ip_addr = get_ip_address(ifname)
-
-    num_tp_ports = server_args.tp_size // server_args.nnodes
-    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
-    ip_addr = [int(x) for x in ip_addr.split(".")]
-    addrs_tensor = torch.tensor(
-        ip_addr + model_port_args.model_tp_ports, dtype=torch.int
-    )
-
-    init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(
-        backend="gloo",
-        init_method=init_method,
-        rank=server_args.node_rank,
-        world_size=server_args.nnodes,
-    )
-    dist.send(addrs_tensor, dst=0)
-    print(
-        f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
-    )
-
-    dist.barrier()
-    dist.destroy_process_group()
-
-
-def receive_addrs(model_port_args, server_args):
-    assert server_args.node_rank == 0 and server_args.dp_size == 1
-
-    ifname = os.environ.get(
-        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
-    )
-    ip_addr = get_ip_address(ifname)
-
-    num_tp_ports = server_args.tp_size // server_args.nnodes
-    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
-
-    init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(
-        backend="gloo",
-        init_method=init_method,
-        rank=server_args.node_rank,
-        world_size=server_args.nnodes,
-    )
-
-    for src_rank in range(1, server_args.nnodes):
-        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
-        dist.recv(tensor, src=src_rank)
-        ip = ".".join([str(x) for x in tensor[:4].tolist()])
-        ports = tensor[4:].tolist()
-        model_port_args.model_tp_ips[
-            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
-        ] = [ip] * num_tp_ports
-        model_port_args.model_tp_ports[
-            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
-        ] = ports
-        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
-
-    dist.barrier()
-    dist.destroy_process_group()
-
-
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -643,24 +555,16 @@ def add_api_key_middleware(app, api_key: str):
         return await call_next(request)
 
 
-def prepare_model(model_path: str):
+def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(model_path):
            from modelscope import snapshot_download
 
-            return snapshot_download(model_path)
-    return model_path
-
-
-def prepare_tokenizer(tokenizer_path: str):
-    if "SGLANG_USE_MODELSCOPE" in os.environ:
-        if not os.path.exists(tokenizer_path):
-            from modelscope import snapshot_download
-
-            return snapshot_download(
+            model_path = snapshot_download(model_path)
+            tokenizer_path = snapshot_download(
                 tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
             )
-    return tokenizer_path
+    return model_path, tokenizer_path
 
 
 def configure_logger(server_args, prefix: str = ""):
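
The two ModelScope helpers collapse into a single call; a usage sketch (the repo id is illustrative), which is a no-op unless SGLANG_USE_MODELSCOPE is set and the paths are not already local:

    from sglang.srt.utils import prepare_model_and_tokenizer

    model_path, tokenizer_path = prepare_model_and_tokenizer(
        "qwen/Qwen2-7B-Instruct", "qwen/Qwen2-7B-Instruct"
    )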
@@ -702,3 +606,72 @@ def set_weight_attrs(
     for key, value in weight_attrs.items():
         assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
         setattr(weight, key, value)
+
+
+def broadcast_pyobj(
+    data: List[Any],
+    rank: int,
+    dist_group: Optional[torch.distributed.ProcessGroup] = None,
+):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(
+                np.frombuffer(serialized_data, dtype=np.uint8)
+            )
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.cpu().numpy())
+        data = pickle.loads(serialized_data)
+        return data
+
+
+step_counter = 0
+
+
+def pytorch_profile(name, func, *args, data_size=-1):
+    """
+    Args:
+        name (string): the name of recorded function.
+        func: the function to be profiled.
+        args: the arguments of the profiled function.
+        data_size (int): some measurement of the computation complexity.
+            Usually, it could be the batch size.
+    """
+    global step_counter
+    os.makedirs("trace", exist_ok=True)
+    with profile(
+        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+        # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
+        # on_trace_ready=tensorboard_trace_handler('./log_dir'),
+        record_shapes=True,
+        profile_memory=True,
+        with_stack=True,
+    ) as prof:
+        with record_function(name):
+            with open(f"trace/size_{step_counter}.json", "w") as f:
+                json.dump({"size": data_size}, f)
+            result = func(*args)
+    prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
+    step_counter += 1
+    return result
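
A usage sketch for the two new helpers. The broadcast assumes a torch.distributed process group (e.g. gloo) is already initialized; the profiling example wraps a plain matmul, with traces written under ./trace/.

    import torch
    import torch.distributed as dist
    from sglang.srt.utils import broadcast_pyobj, pytorch_profile

    # Ship rank 0's picklable objects to every rank; non-zero ranks pass an
    # empty list, which is ignored.
    rank = dist.get_rank()
    reqs = [{"rid": "abc", "input_ids": [1, 2, 3]}] if rank == 0 else []
    reqs = broadcast_pyobj(reqs, rank)

    # Profile one call; the chrome trace lands in trace/matmul_0.json.
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    out = pytorch_profile("matmul", torch.matmul, a, b, data_size=1024)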
sglang/test/runners.py CHANGED
@@ -21,19 +21,19 @@ from typing import List, Union
 
 import torch
 import torch.nn.functional as F
-from peft import PeftModel
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM
 
+from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.server import Runtime
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
 
 DEFAULT_PROMPTS = [
-    # the output of gemma-2-2b from SRT is unstable on the commented prompt
-    # "The capital of France is",
     "Apple is red. Banana is Yellow. " * 800 + "Apple is",
     "The capital of the United Kingdom is",
     "Today is a sunny day and I like",
     "AI is a field of computer science focused on",
+    # the output of gemma-2-2b from SRT is unstable on the commented prompt
+    # "The capital of France is",
 ]
 
 dirpath = os.path.dirname(__file__)
@@ -65,17 +65,18 @@ class ModelOutput:
     top_input_logprobs: List[torch.Tensor] = None
     top_output_logprobs: List[torch.Tensor] = None
     embed_logits: List[torch.Tensor] = None
+    scores: List[float] = None
 
 
 class HFRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-        is_generation,
-        output_str_only=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str = "generation",
+        output_str_only: bool = False,
     ):
-        self.is_generation = is_generation
+        self.model_type = model_type
         self.output_str_only = output_str_only
 
         self.in_queue = mp.Queue()
@@ -92,26 +93,41 @@ class HFRunner:
         )
         self.model_proc.start()
 
+    def needs_trust_remote_code(self, model_path):
+        models_needs_trust_remote = [
+            "LxzGordon/URM-LLaMa-3.1-8B",
+        ]
+        if model_path in models_needs_trust_remote:
+            return True
+        return False
+
     def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_path,
-            torch_dtype=torch_dtype,
-        )
+        self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
 
-        if self.is_generation:
+        if self.model_type == "generation":
             self.base_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
                 trust_remote_code=False,
                 low_cpu_mem_usage=True,
             ).cuda()
-        else:
+        elif self.model_type == "embedding":
             from sentence_transformers import SentenceTransformer
 
             self.model = SentenceTransformer(
                 model_path,
                 model_kwargs={"torch_dtype": torch_dtype},
-            )
+            ).cuda()
+        elif self.model_type == "reward":
+            from transformers import AutoModelForSequenceClassification
+
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                trust_remote_code=self.needs_trust_remote_code(model_path),
+            ).cuda()
+        else:
+            raise Exception(f"Unrecognized model type {self.model_type}")
 
         while True:
             prompts, max_new_tokens, lora_paths = in_queue.get()
@@ -119,7 +135,7 @@ class HFRunner:
                 assert len(prompts) == len(lora_paths)
 
             if prompts is not None:
-                if self.is_generation:
+                if self.model_type == "generation":
                     output_strs = []
                     top_input_logprobs = []
                     top_output_logprobs = []
@@ -132,6 +148,8 @@ class HFRunner:
                         input_ids = torch.tensor([p], device="cuda")
 
                         if lora_paths is not None and lora_paths[i] is not None:
+                            from peft import PeftModel
+
                             self.model = PeftModel.from_pretrained(
                                 self.base_model,
                                 lora_paths[i],
@@ -181,11 +199,27 @@ class HFRunner:
                         )
                     )
 
-                else:
+                elif self.model_type == "embedding":
                     assert not self.output_str_only
                     logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))
 
+                elif self.model_type == "reward":
+                    scores = []
+                    for conv in prompts:
+                        conv_formatted = self.tokenizer.apply_chat_template(
+                            conv, tokenize=False
+                        )
+                        conv_tokenized = self.tokenizer(
+                            conv_formatted, return_tensors="pt"
+                        ).to("cuda")
+                        scores.append(
+                            float(self.model(**conv_tokenized).logits[0][0].item())
+                        )
+                    out_queue.put(ModelOutput(scores=scores))
+                else:
+                    raise Exception(f"Unrecognized model type {self.model_type}")
+
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
@@ -210,23 +244,24 @@ class HFRunner:
 class SRTRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-        is_generation,
-        tp_size=1,
-        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
-        lora_paths=None,
-        max_loras_per_batch=4,
-        disable_cuda_graph=False,
-        disable_radix_cache=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str,
+        tp_size: int = 1,
+        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
+        lora_paths: List[str] = None,
+        max_loras_per_batch: int = 4,
+        disable_cuda_graph: bool = False,
+        disable_radix_cache: bool = False,
     ):
-        self.is_generation = is_generation
+        self.model_type = model_type
+        self.is_generation = model_type == "generation"
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
-            mem_fraction_static=0.69,
+            mem_fraction_static=0.65,
             trust_remote_code=False,
             is_embedding=not self.is_generation,
             lora_paths=lora_paths,
@@ -287,8 +322,12 @@ class SRTRunner:
         else:
             response = self.runtime.encode(prompts)
             response = json.loads(response)
-            logits = [x["embedding"] for x in response]
-            return ModelOutput(embed_logits=logits)
+            if self.model_type == "embedding":
+                logits = [x["embedding"] for x in response]
+                return ModelOutput(embed_logits=logits)
+            else:
+                scores = [x["embedding"][0] for x in response]
+                return ModelOutput(scores=scores)
 
     def batch_forward(
         self,
@@ -318,8 +357,12 @@ class SRTRunner:
         else:
             response = self.runtime.encode(prompts)
             response = json.loads(response)
-            logits = [x["embedding"] for x in response]
-            return ModelOutput(embed_logits=logits)
+            if self.model_type == "embedding":
+                logits = [x["embedding"] for x in response]
+                return ModelOutput(embed_logits=logits)
+            else:
+                scores = [x["embedding"][0] for x in response]
+                return ModelOutput(scores=logits)
 
     def __enter__(self):
         return self
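
A sketch of the new reward path (the model id is the one the diff whitelists for trust_remote_code; the conversation and scores are illustrative). HFRunner applies the chat template itself and scores with AutoModelForSequenceClassification; SRTRunner serves the model with is_embedding=True and reads each score out of embedding[0].

    import torch
    from transformers import AutoTokenizer
    from sglang.test.runners import HFRunner, SRTRunner

    MODEL = "LxzGordon/URM-LLaMa-3.1-8B"
    conv = [
        {"role": "user", "content": "What is the range of tanh?"},
        {"role": "assistant", "content": "The output range of tanh is (-1, 1)."},
    ]

    with HFRunner(MODEL, torch.float16, model_type="reward") as hf_runner:
        hf_scores = hf_runner.forward([conv]).scores

    # For SRT, apply the chat template first and send the flat string.
    prompt = AutoTokenizer.from_pretrained(MODEL).apply_chat_template(
        conv, tokenize=False
    )
    with SRTRunner(MODEL, torch.float16, model_type="reward") as srt_runner:
        srt_scores = srt_runner.forward([prompt]).scores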