sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +48 -33
- sglang/bench_server_latency.py +0 -6
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/hf_transformers_utils.py +1 -0
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/patch.py +117 -0
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +187 -68
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -247
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/mem_cache/radix_cache.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +25 -25
- sglang/srt/model_executor/forward_batch_info.py +94 -97
- sglang/srt/model_executor/model_runner.py +76 -78
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +22 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/openai_api/adapter.py +7 -0
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +76 -33
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +53 -9
- sglang/version.py +1 -1
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -482
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.1.post3.dist-info/RECORD +0 -134
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -19,9 +19,10 @@ import argparse
 import dataclasses
 import logging
 import random
-from typing import List, Optional, Union
+import tempfile
+from typing import List, Optional
 
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
 
 logger = logging.getLogger(__name__)
 
@@ -46,7 +47,6 @@ class ServerArgs:
     # Port
     host: str = "127.0.0.1"
     port: int = 30000
-    additional_ports: Optional[Union[List[int], int]] = None
 
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
@@ -78,9 +78,9 @@ class ServerArgs:
     load_balance_method: str = "round_robin"
 
     # Distributed args
-    nccl_init_addr: Optional[str] = None
+    dist_init_addr: Optional[str] = None
     nnodes: int = 1
-    node_rank: Optional[int] = None
+    node_rank: int = 0
 
     # Model override args in JSON
     json_model_override_args: str = "{}"
@@ -134,11 +134,6 @@ class ServerArgs:
         else:
             self.mem_fraction_static = 0.88
 
-        if isinstance(self.additional_ports, int):
-            self.additional_ports = [self.additional_ports]
-        elif self.additional_ports is None:
-            self.additional_ports = []
-
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)
 
@@ -156,8 +151,7 @@ class ServerArgs:
             )
             self.sampling_backend = "pytorch"
 
-
-        if is_hip():
+        if not is_flashinfer_available():
             self.attention_backend = "triton"
             self.sampling_backend = "pytorch"
 
@@ -199,13 +193,6 @@ class ServerArgs:
         parser.add_argument(
             "--port", type=int, default=ServerArgs.port, help="The port of the server."
         )
-        parser.add_argument(
-            "--additional-ports",
-            type=int,
-            nargs="*",
-            default=[],
-            help="The additional ports specified for the server.",
-        )
         parser.add_argument(
             "--tokenizer-mode",
             type=str,
@@ -279,7 +266,6 @@ class ServerArgs:
                 "marlin",
                 "gptq_marlin",
                 "awq_marlin",
-                "squeezellm",
                 "bitsandbytes",
             ],
             help="The quantization method.",
@@ -426,14 +412,17 @@ class ServerArgs:
 
         # Multi-node distributed serving args
         parser.add_argument(
-            "--nccl-init-addr",
+            "--dist-init-addr",
+            "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
             type=str,
-            help="The
+            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
         )
         parser.add_argument(
             "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
        )
-        parser.add_argument(
+        parser.add_argument(
+            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
+        )
 
         # Model override args
         parser.add_argument(
@@ -567,7 +556,10 @@ class ServerArgs:
         return cls(**{attr: getattr(args, attr) for attr in attrs})
 
     def url(self):
-        return f"http://{self.host}:{self.port}"
+        if is_ipv6(self.host):
+            return f"http://[{self.host}]:{self.port}"
+        else:
+            return f"http://{self.host}:{self.port}"
 
     def check_server_args(self):
         assert (
@@ -583,6 +575,21 @@ class ServerArgs:
             and (self.lora_paths is None or self.disable_radix_cache)
         ), "compatibility of lora and cuda graph and radix attention is in progress"
 
+        assert self.dp_size == 1, (
+            "The support for data parallelism is temporarily disabled during refactor. "
+            "Please use sglang<=0.3.2 or wait for later updates."
+        )
+
+        if isinstance(self.lora_paths, list):
+            lora_paths = self.lora_paths
+            self.lora_paths = {}
+            for lora_path in lora_paths:
+                if "=" in lora_path:
+                    name, path = lora_path.split("=", 1)
+                    self.lora_paths[name] = path
+                else:
+                    self.lora_paths[lora_path] = lora_path
+
 
 def prepare_server_args(argv: List[str]) -> ServerArgs:
     """
@@ -604,11 +611,31 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
 
 @dataclasses.dataclass
 class PortArgs:
-    tokenizer_port: int
-    controller_port: int
-    detokenizer_port: int
+    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
+    tokenizer_ipc_name: str
+    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
+    scheduler_input_ipc_name: str
+    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
+    detokenizer_ipc_name: str
+
+    # The port for nccl initialization for multiple TP groups (torch.dist)
     nccl_ports: List[int]
 
+    @classmethod
+    def init_new(self, server_args):
+        port = server_args.port + 1
+        while True:
+            if is_port_available(port):
+                break
+            port += 1
+
+        return PortArgs(
+            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            nccl_ports=[port],
+        )
+
 
 class LoRAPathAction(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
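
Note on the server_args changes above: `--additional-ports` is gone, the multi-node flag is now `--dist-init-addr` (with `--nccl-init-addr` kept as a deprecated alias), and `lora_paths` entries may be given as `name=path`. A standalone sketch of the `name=path` normalization performed in `check_server_args` follows; the adapter names and paths are made-up examples, not values from the package.

    # Standalone sketch of the lora_paths normalization added in 0.3.3's
    # ServerArgs.check_server_args: "name=path" entries become {name: path},
    # bare paths map to themselves. The example paths are hypothetical.
    from typing import Dict, List


    def normalize_lora_paths(lora_paths: List[str]) -> Dict[str, str]:
        normalized = {}
        for lora_path in lora_paths:
            if "=" in lora_path:
                name, path = lora_path.split("=", 1)
                normalized[name] = path
            else:
                normalized[lora_path] = lora_path
        return normalized


    if __name__ == "__main__":
        print(normalize_lora_paths(["sql=/ckpts/sql-lora", "/ckpts/chat-lora"]))
        # {'sql': '/ckpts/sql-lora', '/ckpts/chat-lora': '/ckpts/chat-lora'}
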
sglang/srt/utils.py
CHANGED
@@ -16,14 +16,16 @@ limitations under the License.
 """Common utilities."""
 
 import base64
-import fcntl
+import ipaddress
+import json
 import logging
 import os
+import pickle
 import random
 import resource
 import socket
-import struct
 import time
+import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import Any, Dict, List, Optional, Union
@@ -36,7 +38,7 @@ import torch.distributed as dist
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
 from torch import nn
-from torch.
+from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
     FileCacheManager,
     default_cache_dir,
@@ -51,11 +53,27 @@ show_time_cost = False
 time_infos = {}
 
 
-# torch flag AMD GPU
 def is_hip() -> bool:
+    """Return whether it is HIP on the AMD ROCm platform."""
     return torch.version.hip is not None
 
 
+def is_flashinfer_available():
+    """
+    Check whether flashinfer is available.
+    As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
+    """
+    return torch.cuda.is_available() and not is_hip()
+
+
+def is_ipv6(address):
+    try:
+        ipaddress.IPv6Address(address)
+        return True
+    except ipaddress.AddressValueError:
+        return False
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
@@ -170,35 +188,6 @@ def is_port_available(port):
     return False
 
 
-def allocate_init_ports(
-    port: Optional[int] = None,
-    additional_ports: Optional[List[int]] = None,
-    dp_size: int = 1,
-):
-    """Allocate ports for all connections."""
-    if additional_ports:
-        ret_ports = [port] + additional_ports
-    else:
-        ret_ports = [port]
-
-    ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
-    cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
-
-    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
-    num_ports_needed = 4 + dp_size
-    while len(ret_ports) < num_ports_needed:
-        if cur_port not in ret_ports and is_port_available(cur_port):
-            ret_ports.append(cur_port)
-        cur_port += 1
-
-    if port is not None and ret_ports[0] != port:
-        logger.warning(
-            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
-        )
-
-    return ret_ports[0], ret_ports[1:num_ports_needed]
-
-
 def is_multimodal_model(model_architectures):
     if (
         "LlavaLlamaForCausalLM" in model_architectures
@@ -219,6 +208,8 @@ def is_generation_model(model_architectures, is_embedding: bool = False):
     if (
         "LlamaEmbeddingModel" in model_architectures
         or "MistralModel" in model_architectures
+        or "LlamaForSequenceClassification" in model_architectures
+        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
     ):
         return False
     else:
@@ -345,6 +336,10 @@ def suppress_other_loggers():
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.ERROR)
 
+    warnings.filterwarnings(
+        "ignore", category=UserWarning, message="The given NumPy array is not writable"
+    )
+
 
 def assert_pkg_version(pkg: str, min_version: str, message: str):
     try:
@@ -537,89 +532,6 @@ class CustomCacheManager(FileCacheManager):
             raise RuntimeError("Could not create or locate cache dir")
 
 
-def get_ip_address(ifname):
-    """
-    Get the IP address of a network interface.
-
-    :param ifname: Name of the network interface (e.g., 'eth0')
-    :return: IP address of the network interface
-    """
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    ip_address = fcntl.ioctl(
-        s.fileno(),
-        0x8915,  # SIOCGIFADDR
-        struct.pack("256s", bytes(ifname[:15], "utf-8")),
-    )[20:24]
-    return socket.inet_ntoa(ip_address)
-
-
-def send_addrs_to_rank_0(model_port_args, server_args):
-    assert server_args.node_rank != 0 and server_args.dp_size == 1
-
-    ifname = os.environ.get(
-        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
-    )
-    ip_addr = get_ip_address(ifname)
-
-    num_tp_ports = server_args.tp_size // server_args.nnodes
-    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
-    ip_addr = [int(x) for x in ip_addr.split(".")]
-    addrs_tensor = torch.tensor(
-        ip_addr + model_port_args.model_tp_ports, dtype=torch.int
-    )
-
-    init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(
-        backend="gloo",
-        init_method=init_method,
-        rank=server_args.node_rank,
-        world_size=server_args.nnodes,
-    )
-    dist.send(addrs_tensor, dst=0)
-    print(
-        f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
-    )
-
-    dist.barrier()
-    dist.destroy_process_group()
-
-
-def receive_addrs(model_port_args, server_args):
-    assert server_args.node_rank == 0 and server_args.dp_size == 1
-
-    ifname = os.environ.get(
-        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
-    )
-    ip_addr = get_ip_address(ifname)
-
-    num_tp_ports = server_args.tp_size // server_args.nnodes
-    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
-
-    init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(
-        backend="gloo",
-        init_method=init_method,
-        rank=server_args.node_rank,
-        world_size=server_args.nnodes,
-    )
-
-    for src_rank in range(1, server_args.nnodes):
-        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
-        dist.recv(tensor, src=src_rank)
-        ip = ".".join([str(x) for x in tensor[:4].tolist()])
-        ports = tensor[4:].tolist()
-        model_port_args.model_tp_ips[
-            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
-        ] = [ip] * num_tp_ports
-        model_port_args.model_tp_ports[
-            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
-        ] = ports
-        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
-
-    dist.barrier()
-    dist.destroy_process_group()
-
-
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -643,24 +555,16 @@ def add_api_key_middleware(app, api_key: str):
         return await call_next(request)
 
 
-def prepare_model(model_path: str):
+def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(model_path):
            from modelscope import snapshot_download
 
-            return snapshot_download(model_path)
-    return model_path
-
-
-def prepare_tokenizer(tokenizer_path: str):
-    if "SGLANG_USE_MODELSCOPE" in os.environ:
-        if not os.path.exists(tokenizer_path):
-            from modelscope import snapshot_download
-
-            return snapshot_download(
+            model_path = snapshot_download(model_path)
+            tokenizer_path = snapshot_download(
                 tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
             )
-    return tokenizer_path
+    return model_path, tokenizer_path
 
 
 def configure_logger(server_args, prefix: str = ""):
@@ -702,3 +606,72 @@ def set_weight_attrs(
     for key, value in weight_attrs.items():
         assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
         setattr(weight, key, value)
+
+
+def broadcast_pyobj(
+    data: List[Any],
+    rank: int,
+    dist_group: Optional[torch.distributed.ProcessGroup] = None,
+):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(
+                np.frombuffer(serialized_data, dtype=np.uint8)
+            )
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.cpu().numpy())
+        data = pickle.loads(serialized_data)
+        return data
+
+
+step_counter = 0
+
+
+def pytorch_profile(name, func, *args, data_size=-1):
+    """
+    Args:
+        name (string): the name of recorded function.
+        func: the function to be profiled.
+        args: the arguments of the profiled function.
+        data_size (int): some measurement of the computation complexity.
+            Usually, it could be the batch size.
+    """
+    global step_counter
+    os.makedirs("trace", exist_ok=True)
+    with profile(
+        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+        # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
+        # on_trace_ready=tensorboard_trace_handler('./log_dir'),
+        record_shapes=True,
+        profile_memory=True,
+        with_stack=True,
+    ) as prof:
+        with record_function(name):
+            with open(f"trace/size_{step_counter}.json", "w") as f:
+                json.dump({"size": data_size}, f)
+            result = func(*args)
+    prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
+    step_counter += 1
+    return result
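
Note on the utils changes above: the gloo-based `send_addrs_to_rank_0`/`receive_addrs` handshake and `allocate_init_ports` are replaced by `broadcast_pyobj`, which pickles a list of Python objects on rank 0 and broadcasts its size and bytes. A minimal single-process sketch of how it can be called (assumes sglang 0.3.3 and torch are installed; the TCP port is arbitrary):

    # Single-process sketch: with world_size=1 the broadcast has no peers, but
    # the call shape matches how request objects are shared across TP ranks.
    import torch.distributed as dist

    from sglang.srt.utils import broadcast_pyobj

    if __name__ == "__main__":
        dist.init_process_group(
            backend="gloo", init_method="tcp://127.0.0.1:29512", rank=0, world_size=1
        )
        payload = [{"rid": "example-request", "text": "hello"}]
        received = broadcast_pyobj(payload, rank=dist.get_rank())
        print(received)  # rank 0 gets its own list back; other ranks would unpickle it
        dist.destroy_process_group()
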
sglang/test/runners.py
CHANGED
@@ -21,19 +21,19 @@ from typing import List, Union
 
 import torch
 import torch.nn.functional as F
-from peft import PeftModel
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM
 
+from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.server import Runtime
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
 
 DEFAULT_PROMPTS = [
-    # the output of gemma-2-2b from SRT is unstable on the commented prompt
-    # "The capital of France is",
     "Apple is red. Banana is Yellow. " * 800 + "Apple is",
     "The capital of the United Kingdom is",
     "Today is a sunny day and I like",
     "AI is a field of computer science focused on",
+    # the output of gemma-2-2b from SRT is unstable on the commented prompt
+    # "The capital of France is",
 ]
 
 dirpath = os.path.dirname(__file__)
@@ -65,17 +65,18 @@ class ModelOutput:
     top_input_logprobs: List[torch.Tensor] = None
     top_output_logprobs: List[torch.Tensor] = None
     embed_logits: List[torch.Tensor] = None
+    scores: List[float] = None
 
 
 class HFRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-        is_generation,
-        output_str_only=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str = "generation",
+        output_str_only: bool = False,
     ):
-        self.is_generation = is_generation
+        self.model_type = model_type
         self.output_str_only = output_str_only
 
         self.in_queue = mp.Queue()
@@ -92,26 +93,41 @@ class HFRunner:
         )
         self.model_proc.start()
 
+    def needs_trust_remote_code(self, model_path):
+        models_needs_trust_remote = [
+            "LxzGordon/URM-LLaMa-3.1-8B",
+        ]
+        if model_path in models_needs_trust_remote:
+            return True
+        return False
+
     def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
-        self.tokenizer = AutoTokenizer.from_pretrained(
-            model_path,
-            torch_dtype=torch_dtype,
-        )
+        self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
 
-        if self.is_generation:
+        if self.model_type == "generation":
             self.base_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
                 trust_remote_code=False,
                 low_cpu_mem_usage=True,
             ).cuda()
-        else:
+        elif self.model_type == "embedding":
             from sentence_transformers import SentenceTransformer
 
             self.model = SentenceTransformer(
                 model_path,
                 model_kwargs={"torch_dtype": torch_dtype},
-            )
+            ).cuda()
+        elif self.model_type == "reward":
+            from transformers import AutoModelForSequenceClassification
+
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                trust_remote_code=self.needs_trust_remote_code(model_path),
+            ).cuda()
+        else:
+            raise Exception(f"Unrecognized model type {self.model_type}")
 
         while True:
             prompts, max_new_tokens, lora_paths = in_queue.get()
@@ -119,7 +135,7 @@ class HFRunner:
                 assert len(prompts) == len(lora_paths)
 
             if prompts is not None:
-                if self.is_generation:
+                if self.model_type == "generation":
                     output_strs = []
                     top_input_logprobs = []
                     top_output_logprobs = []
@@ -132,6 +148,8 @@ class HFRunner:
                         input_ids = torch.tensor([p], device="cuda")
 
                         if lora_paths is not None and lora_paths[i] is not None:
+                            from peft import PeftModel
+
                             self.model = PeftModel.from_pretrained(
                                 self.base_model,
                                 lora_paths[i],
@@ -181,11 +199,27 @@ class HFRunner:
                         )
                     )
 
-                else:
+                elif self.model_type == "embedding":
                     assert not self.output_str_only
                     logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))
 
+                elif self.model_type == "reward":
+                    scores = []
+                    for conv in prompts:
+                        conv_formatted = self.tokenizer.apply_chat_template(
+                            conv, tokenize=False
+                        )
+                        conv_tokenized = self.tokenizer(
+                            conv_formatted, return_tensors="pt"
+                        ).to("cuda")
+                        scores.append(
+                            float(self.model(**conv_tokenized).logits[0][0].item())
+                        )
+                    out_queue.put(ModelOutput(scores=scores))
+                else:
+                    raise Exception(f"Unrecognized model type {self.model_type}")
+
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
@@ -210,23 +244,24 @@ class HFRunner:
 class SRTRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-        is_generation,
-        tp_size=1,
-        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
-        lora_paths=None,
-        max_loras_per_batch=4,
-        disable_cuda_graph=False,
-        disable_radix_cache=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str,
+        tp_size: int = 1,
+        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
+        lora_paths: List[str] = None,
+        max_loras_per_batch: int = 4,
+        disable_cuda_graph: bool = False,
+        disable_radix_cache: bool = False,
     ):
-        self.is_generation = is_generation
+        self.model_type = model_type
+        self.is_generation = model_type == "generation"
         self.runtime = Runtime(
            model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
-            mem_fraction_static=0.
+            mem_fraction_static=0.65,
             trust_remote_code=False,
             is_embedding=not self.is_generation,
             lora_paths=lora_paths,
@@ -287,8 +322,12 @@ class SRTRunner:
         else:
             response = self.runtime.encode(prompts)
             response = json.loads(response)
-            logits = [x["embedding"] for x in response]
-            return ModelOutput(embed_logits=logits)
+            if self.model_type == "embedding":
+                logits = [x["embedding"] for x in response]
+                return ModelOutput(embed_logits=logits)
+            else:
+                scores = [x["embedding"][0] for x in response]
+                return ModelOutput(scores=scores)
 
     def batch_forward(
         self,
@@ -318,8 +357,12 @@ class SRTRunner:
         else:
             response = self.runtime.encode(prompts)
             response = json.loads(response)
-            logits = [x["embedding"] for x in response]
-            return ModelOutput(embed_logits=logits)
+            if self.model_type == "embedding":
+                logits = [x["embedding"] for x in response]
+                return ModelOutput(embed_logits=logits)
+            else:
+                scores = [x["embedding"][0] for x in response]
+                return ModelOutput(scores=logits)
 
     def __enter__(self):
         return self
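
Note on the runner changes above: both test runners now take a `model_type` of "generation", "embedding", or "reward", and reward models return one value per conversation in `ModelOutput.scores`. A hedged usage sketch for the HFRunner reward path follows; the conversation text is invented, the model id is the one listed in `needs_trust_remote_code`, and worker-process cleanup is omitted.

    # Sketch only: assumes sglang 0.3.3, transformers, and a CUDA GPU. The reward
    # branch applies the tokenizer's chat template to each conversation and reads
    # one score from the sequence-classification head.
    import torch

    from sglang.test.runners import HFRunner

    if __name__ == "__main__":
        conversations = [
            [
                {"role": "user", "content": "What is 2 + 2?"},
                {"role": "assistant", "content": "2 + 2 = 4."},
            ],
        ]
        hf_runner = HFRunner(
            "LxzGordon/URM-LLaMa-3.1-8B",
            torch_dtype=torch.bfloat16,
            model_type="reward",
        )
        output = hf_runner.forward(conversations)
        print(output.scores)  # one float per conversation
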