sglang 0.3.2__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +46 -25
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +184 -63
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -248
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/model_executor/cuda_graph_runner.py +15 -19
- sglang/srt/model_executor/forward_batch_info.py +94 -95
- sglang/srt/model_executor/model_runner.py +76 -75
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +14 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +71 -26
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +18 -9
- sglang/version.py +1 -1
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/METADATA +37 -19
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -474
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.2.dist-info/RECORD +0 -135
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.2.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -19,9 +19,10 @@ import argparse
 import dataclasses
 import logging
 import random
-
+import tempfile
+from typing import List, Optional
 
-from sglang.srt.utils import
+from sglang.srt.utils import is_flashinfer_available, is_ipv6, is_port_available
 
 logger = logging.getLogger(__name__)
 
@@ -46,7 +47,6 @@ class ServerArgs:
     # Port
     host: str = "127.0.0.1"
     port: int = 30000
-    additional_ports: Optional[Union[List[int], int]] = None
 
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
@@ -78,9 +78,9 @@ class ServerArgs:
     load_balance_method: str = "round_robin"
 
     # Distributed args
-
+    dist_init_addr: Optional[str] = None
     nnodes: int = 1
-    node_rank:
+    node_rank: int = 0
 
     # Model override args in JSON
     json_model_override_args: str = "{}"
@@ -134,11 +134,6 @@ class ServerArgs:
         else:
             self.mem_fraction_static = 0.88
 
-        if isinstance(self.additional_ports, int):
-            self.additional_ports = [self.additional_ports]
-        elif self.additional_ports is None:
-            self.additional_ports = []
-
         if self.random_seed is None:
             self.random_seed = random.randint(0, 1 << 30)
 
@@ -156,8 +151,7 @@ class ServerArgs:
             )
             self.sampling_backend = "pytorch"
 
-
-        if is_hip():
+        if not is_flashinfer_available():
             self.attention_backend = "triton"
             self.sampling_backend = "pytorch"
 
@@ -199,13 +193,6 @@ class ServerArgs:
         parser.add_argument(
             "--port", type=int, default=ServerArgs.port, help="The port of the server."
         )
-        parser.add_argument(
-            "--additional-ports",
-            type=int,
-            nargs="*",
-            default=[],
-            help="The additional ports specified for the server.",
-        )
         parser.add_argument(
             "--tokenizer-mode",
             type=str,
@@ -279,7 +266,6 @@ class ServerArgs:
                 "marlin",
                 "gptq_marlin",
                 "awq_marlin",
-                "squeezellm",
                 "bitsandbytes",
             ],
             help="The quantization method.",
@@ -426,14 +412,17 @@ class ServerArgs:
 
         # Multi-node distributed serving args
         parser.add_argument(
-            "--
+            "--dist-init-addr",
+            "--nccl-init-addr",  # For backward compatbility. This will be removed in the future.
             type=str,
-            help="The
+            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
         )
        parser.add_argument(
             "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
         )
-        parser.add_argument(
+        parser.add_argument(
+            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
+        )
 
         # Model override args
         parser.add_argument(
@@ -567,7 +556,10 @@ class ServerArgs:
         return cls(**{attr: getattr(args, attr) for attr in attrs})
 
     def url(self):
-
+        if is_ipv6(self.host):
+            return f"http://[{self.host}]:{self.port}"
+        else:
+            return f"http://{self.host}:{self.port}"
 
     def check_server_args(self):
         assert (
@@ -583,6 +575,21 @@ class ServerArgs:
             and (self.lora_paths is None or self.disable_radix_cache)
         ), "compatibility of lora and cuda graph and radix attention is in progress"
 
+        assert self.dp_size == 1, (
+            "The support for data parallelism is temporarily disabled during refactor. "
+            "Please use sglang<=0.3.2 or wait for later updates."
+        )
+
+        if isinstance(self.lora_paths, list):
+            lora_paths = self.lora_paths
+            self.lora_paths = {}
+            for lora_path in lora_paths:
+                if "=" in lora_path:
+                    name, path = lora_path.split("=", 1)
+                    self.lora_paths[name] = path
+                else:
+                    self.lora_paths[lora_path] = lora_path
+
 
 def prepare_server_args(argv: List[str]) -> ServerArgs:
     """
@@ -604,11 +611,31 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
 
 @dataclasses.dataclass
 class PortArgs:
-
-
-
+    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
+    tokenizer_ipc_name: str
+    # The ipc filename for scheduler (rank 0) to receive inputs from tokenizer (zmq)
+    scheduler_input_ipc_name: str
+    # The ipc filename for detokenizer to receive inputs from scheduler (zmq)
+    detokenizer_ipc_name: str
+
+    # The port for nccl initialization for multiple TP groups (torch.dist)
     nccl_ports: List[int]
 
+    @classmethod
+    def init_new(self, server_args):
+        port = server_args.port + 1
+        while True:
+            if is_port_available(port):
+                break
+            port += 1
+
+        return PortArgs(
+            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
+            nccl_ports=[port],
+        )
+
 
 class LoRAPathAction(argparse.Action):
     def __call__(self, parser, namespace, values, option_string=None):
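
The `check_server_args` hunk above also changes how `--lora-paths` values are stored: the list of CLI entries is normalized into a dict, where each entry is either a bare path or a `name=path` pair. Below is a minimal standalone sketch of that normalization; `parse_lora_paths` is an illustrative helper name, not part of the sglang API.

```python
from typing import Dict, List


def parse_lora_paths(lora_paths: List[str]) -> Dict[str, str]:
    """Mirror of the list-to-dict normalization in ServerArgs.check_server_args."""
    parsed: Dict[str, str] = {}
    for lora_path in lora_paths:
        if "=" in lora_path:
            # "name=path" maps an explicit adapter name to its location.
            name, path = lora_path.split("=", 1)
            parsed[name] = path
        else:
            # A bare path is keyed by the path itself.
            parsed[lora_path] = lora_path
    return parsed


print(parse_lora_paths(["sql=/adapters/sql-lora", "/adapters/chat-lora"]))
# {'sql': '/adapters/sql-lora', '/adapters/chat-lora': '/adapters/chat-lora'}
```
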
sglang/srt/utils.py
CHANGED
@@ -16,14 +16,16 @@ limitations under the License.
 """Common utilities."""
 
 import base64
-import
+import ipaddress
+import json
 import logging
 import os
+import pickle
 import random
 import resource
 import socket
-import struct
 import time
+import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
 from typing import Any, Dict, List, Optional, Union
@@ -36,7 +38,7 @@ import torch.distributed as dist
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
 from torch import nn
-from torch.
+from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
     FileCacheManager,
     default_cache_dir,
@@ -51,11 +53,27 @@ show_time_cost = False
 time_infos = {}
 
 
-# torch flag AMD GPU
 def is_hip() -> bool:
+    """Return whether it is HIP on the AMD ROCm platform."""
     return torch.version.hip is not None
 
 
+def is_flashinfer_available():
+    """
+    Check whether flashinfer is available.
+    As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
+    """
+    return torch.cuda.is_available() and not is_hip()
+
+
+def is_ipv6(address):
+    try:
+        ipaddress.IPv6Address(address)
+        return True
+    except ipaddress.AddressValueError:
+        return False
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
@@ -170,35 +188,6 @@ def is_port_available(port):
         return False
 
 
-def allocate_init_ports(
-    port: Optional[int] = None,
-    additional_ports: Optional[List[int]] = None,
-    dp_size: int = 1,
-):
-    """Allocate ports for all connections."""
-    if additional_ports:
-        ret_ports = [port] + additional_ports
-    else:
-        ret_ports = [port]
-
-    ret_ports = list(set(x for x in ret_ports if is_port_available(x)))
-    cur_port = ret_ports[-1] + 1 if len(ret_ports) > 0 else 10000
-
-    # HTTP + Tokenizer + Controller + Detokenizer + dp_size * 1 (nccl)
-    num_ports_needed = 4 + dp_size
-    while len(ret_ports) < num_ports_needed:
-        if cur_port not in ret_ports and is_port_available(cur_port):
-            ret_ports.append(cur_port)
-        cur_port += 1
-
-    if port is not None and ret_ports[0] != port:
-        logger.warning(
-            f"WARNING: Port {port} is not available. Use port {ret_ports[0]} instead."
-        )
-
-    return ret_ports[0], ret_ports[1:num_ports_needed]
-
-
 def is_multimodal_model(model_architectures):
     if (
         "LlavaLlamaForCausalLM" in model_architectures
@@ -219,6 +208,8 @@ def is_generation_model(model_architectures, is_embedding: bool = False):
     if (
         "LlamaEmbeddingModel" in model_architectures
         or "MistralModel" in model_architectures
+        or "LlamaForSequenceClassification" in model_architectures
+        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
     ):
         return False
     else:
@@ -345,6 +336,10 @@ def suppress_other_loggers():
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.ERROR)
 
+    warnings.filterwarnings(
+        "ignore", category=UserWarning, message="The given NumPy array is not writable"
+    )
+
 
 def assert_pkg_version(pkg: str, min_version: str, message: str):
     try:
@@ -537,89 +532,6 @@ class CustomCacheManager(FileCacheManager):
             raise RuntimeError("Could not create or locate cache dir")
 
 
-def get_ip_address(ifname):
-    """
-    Get the IP address of a network interface.
-
-    :param ifname: Name of the network interface (e.g., 'eth0')
-    :return: IP address of the network interface
-    """
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    ip_address = fcntl.ioctl(
-        s.fileno(),
-        0x8915,  # SIOCGIFADDR
-        struct.pack("256s", bytes(ifname[:15], "utf-8")),
-    )[20:24]
-    return socket.inet_ntoa(ip_address)
-
-
-def send_addrs_to_rank_0(model_port_args, server_args):
-    assert server_args.node_rank != 0 and server_args.dp_size == 1
-
-    ifname = os.environ.get(
-        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
-    )
-    ip_addr = get_ip_address(ifname)
-
-    num_tp_ports = server_args.tp_size // server_args.nnodes
-    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
-    ip_addr = [int(x) for x in ip_addr.split(".")]
-    addrs_tensor = torch.tensor(
-        ip_addr + model_port_args.model_tp_ports, dtype=torch.int
-    )
-
-    init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(
-        backend="gloo",
-        init_method=init_method,
-        rank=server_args.node_rank,
-        world_size=server_args.nnodes,
-    )
-    dist.send(addrs_tensor, dst=0)
-    print(
-        f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
-    )
-
-    dist.barrier()
-    dist.destroy_process_group()
-
-
-def receive_addrs(model_port_args, server_args):
-    assert server_args.node_rank == 0 and server_args.dp_size == 1
-
-    ifname = os.environ.get(
-        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
-    )
-    ip_addr = get_ip_address(ifname)
-
-    num_tp_ports = server_args.tp_size // server_args.nnodes
-    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
-
-    init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(
-        backend="gloo",
-        init_method=init_method,
-        rank=server_args.node_rank,
-        world_size=server_args.nnodes,
-    )
-
-    for src_rank in range(1, server_args.nnodes):
-        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
-        dist.recv(tensor, src=src_rank)
-        ip = ".".join([str(x) for x in tensor[:4].tolist()])
-        ports = tensor[4:].tolist()
-        model_port_args.model_tp_ips[
-            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
-        ] = [ip] * num_tp_ports
-        model_port_args.model_tp_ports[
-            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
-        ] = ports
-        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
-
-    dist.barrier()
-    dist.destroy_process_group()
-
-
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -643,24 +555,16 @@ def add_api_key_middleware(app, api_key: str):
         return await call_next(request)
 
 
-def
+def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(model_path):
             from modelscope import snapshot_download
 
-
-
-
-
-def prepare_tokenizer(tokenizer_path: str):
-    if "SGLANG_USE_MODELSCOPE" in os.environ:
-        if not os.path.exists(tokenizer_path):
-            from modelscope import snapshot_download
-
-            return snapshot_download(
+            model_path = snapshot_download(model_path)
+            tokenizer_path = snapshot_download(
                 tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
             )
-    return tokenizer_path
+    return model_path, tokenizer_path
 
 
 def configure_logger(server_args, prefix: str = ""):
@@ -702,3 +606,72 @@ def set_weight_attrs(
     for key, value in weight_attrs.items():
         assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
         setattr(weight, key, value)
+
+
+def broadcast_pyobj(
+    data: List[Any],
+    rank: int,
+    dist_group: Optional[torch.distributed.ProcessGroup] = None,
+):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(
+                np.frombuffer(serialized_data, dtype=np.uint8)
+            )
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.cpu().numpy())
+        data = pickle.loads(serialized_data)
+        return data
+
+
+step_counter = 0
+
+
+def pytorch_profile(name, func, *args, data_size=-1):
+    """
+    Args:
+        name (string): the name of recorded function.
+        func: the function to be profiled.
+        args: the arguments of the profiled function.
+        data_size (int): some measurement of the computation complexity.
+            Usually, it could be the batch size.
+    """
+    global step_counter
+    os.makedirs("trace", exist_ok=True)
+    with profile(
+        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+        # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
+        # on_trace_ready=tensorboard_trace_handler('./log_dir'),
+        record_shapes=True,
+        profile_memory=True,
+        with_stack=True,
+    ) as prof:
+        with record_function(name):
+            with open(f"trace/size_{step_counter}.json", "w") as f:
+                json.dump({"size": data_size}, f)
+            result = func(*args)
+    prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
+    step_counter += 1
+    return result
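
The new `broadcast_pyobj` helper replaces the removed gloo-based address exchange: rank 0 pickles a list of Python objects and broadcasts its size and bytes, while the other ranks return the unpickled copy. The sketch below is a hedged usage example, not sglang's own wiring; the two-process `torchrun` launch and the gloo process group are assumptions made for the demo.

```python
import torch.distributed as dist

from sglang.srt.utils import broadcast_pyobj

# Run with, e.g.: torchrun --nproc_per_node=2 broadcast_demo.py
if __name__ == "__main__":
    # torchrun provides the env:// rendezvous variables (RANK, WORLD_SIZE, ...).
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    # Rank 0 supplies the payload; the other ranks pass an empty list and get
    # the unpickled copy back as the return value.
    payload = [{"rid": "req-0", "text": "hello"}] if rank == 0 else []
    received = broadcast_pyobj(payload, rank)
    print(f"rank {rank} received: {received}")

    dist.destroy_process_group()
```
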
sglang/test/runners.py
CHANGED
@@ -65,17 +65,18 @@ class ModelOutput:
     top_input_logprobs: List[torch.Tensor] = None
     top_output_logprobs: List[torch.Tensor] = None
     embed_logits: List[torch.Tensor] = None
+    scores: List[float] = None
 
 
 class HFRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-
-        output_str_only=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str = "generation",
+        output_str_only: bool = False,
     ):
-        self.
+        self.model_type = model_type
         self.output_str_only = output_str_only
 
         self.in_queue = mp.Queue()
@@ -92,22 +93,41 @@ class HFRunner:
         )
         self.model_proc.start()
 
+    def needs_trust_remote_code(self, model_path):
+        models_needs_trust_remote = [
+            "LxzGordon/URM-LLaMa-3.1-8B",
+        ]
+        if model_path in models_needs_trust_remote:
+            return True
+        return False
+
     def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
-        self.tokenizer = get_tokenizer(model_path)
-
+        self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
+
+        if self.model_type == "generation":
             self.base_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
                 trust_remote_code=False,
                 low_cpu_mem_usage=True,
             ).cuda()
-
+        elif self.model_type == "embedding":
             from sentence_transformers import SentenceTransformer
 
             self.model = SentenceTransformer(
                 model_path,
                 model_kwargs={"torch_dtype": torch_dtype},
-            )
+            ).cuda()
+        elif self.model_type == "reward":
+            from transformers import AutoModelForSequenceClassification
+
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                model_path,
+                torch_dtype=torch_dtype,
+                trust_remote_code=self.needs_trust_remote_code(model_path),
+            ).cuda()
+        else:
+            raise Exception(f"Unrecognized model type {self.model_type}")
 
         while True:
             prompts, max_new_tokens, lora_paths = in_queue.get()
@@ -115,7 +135,7 @@ class HFRunner:
                 assert len(prompts) == len(lora_paths)
 
             if prompts is not None:
-                if self.
+                if self.model_type == "generation":
                     output_strs = []
                     top_input_logprobs = []
                     top_output_logprobs = []
@@ -179,11 +199,27 @@ class HFRunner:
                         )
                     )
 
-
+                elif self.model_type == "embedding":
                     assert not self.output_str_only
                     logits = self.model.encode(prompts).tolist()
                     out_queue.put(ModelOutput(embed_logits=logits))
 
+                elif self.model_type == "reward":
+                    scores = []
+                    for conv in prompts:
+                        conv_formatted = self.tokenizer.apply_chat_template(
+                            conv, tokenize=False
+                        )
+                        conv_tokenized = self.tokenizer(
+                            conv_formatted, return_tensors="pt"
+                        ).to("cuda")
+                        scores.append(
+                            float(self.model(**conv_tokenized).logits[0][0].item())
+                        )
+                    out_queue.put(ModelOutput(scores=scores))
+                else:
+                    raise Exception(f"Unrecognized model type {self.model_type}")
+
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
@@ -208,23 +244,24 @@ class HFRunner:
 class SRTRunner:
     def __init__(
         self,
-        model_path,
-        torch_dtype,
-
-        tp_size=1,
-        port=DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
-        lora_paths=None,
-        max_loras_per_batch=4,
-        disable_cuda_graph=False,
-        disable_radix_cache=False,
+        model_path: str,
+        torch_dtype: torch.dtype,
+        model_type: str,
+        tp_size: int = 1,
+        port: int = DEFAULT_PORT_FOR_SRT_TEST_RUNNER,
+        lora_paths: List[str] = None,
+        max_loras_per_batch: int = 4,
+        disable_cuda_graph: bool = False,
+        disable_radix_cache: bool = False,
     ):
-        self.
+        self.model_type = model_type
+        self.is_generation = model_type == "generation"
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
-            mem_fraction_static=0.
+            mem_fraction_static=0.65,
             trust_remote_code=False,
             is_embedding=not self.is_generation,
             lora_paths=lora_paths,
@@ -285,8 +322,12 @@ class SRTRunner:
         else:
             response = self.runtime.encode(prompts)
             response = json.loads(response)
-
-
+            if self.model_type == "embedding":
+                logits = [x["embedding"] for x in response]
+                return ModelOutput(embed_logits=logits)
+            else:
+                scores = [x["embedding"][0] for x in response]
+                return ModelOutput(scores=scores)
 
     def batch_forward(
         self,
@@ -316,8 +357,12 @@ class SRTRunner:
         else:
            response = self.runtime.encode(prompts)
            response = json.loads(response)
-
-
+            if self.model_type == "embedding":
+                logits = [x["embedding"] for x in response]
+                return ModelOutput(embed_logits=logits)
+            else:
+                scores = [x["embedding"][0] for x in response]
+                return ModelOutput(scores=logits)
 
     def __enter__(self):
         return self
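
For reference, the reward branch added to `HFRunner` above scores a conversation by formatting it with the tokenizer's chat template, running it through a sequence-classification head, and reading `logits[0][0]` as a scalar reward. The standalone sketch below reproduces that path with plain `transformers`; the conversation text is invented, and the model name is simply the one whitelisted for `trust_remote_code` in the diff (any sequence-classification reward model would do).

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL = "LxzGordon/URM-LLaMa-3.1-8B"  # illustrative reward model

tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True
).cuda()

conv = [
    {"role": "user", "content": "What is the range of the output of tanh?"},
    {"role": "assistant", "content": "The output of tanh lies in (-1, 1)."},
]

# Same steps as the HFRunner reward branch: chat template -> tokenize -> logits[0][0].
conv_formatted = tokenizer.apply_chat_template(conv, tokenize=False)
conv_tokenized = tokenizer(conv_formatted, return_tensors="pt").to("cuda")

with torch.no_grad():
    score = float(model(**conv_tokenized).logits[0][0].item())
print(f"reward score: {score:.4f}")
```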