sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +16 -6
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +27 -12
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +76 -102
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +26 -17
- sglang/srt/layers/quantization/__init__.py +22 -23
- sglang/srt/layers/quantization/fp8.py +112 -55
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +2 -3
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +17 -4
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +46 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -8
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +54 -15
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +319 -181
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +303 -158
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +110 -77
- sglang/srt/metrics/collector.py +25 -11
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +80 -21
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +41 -4
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +52 -4
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +153 -9
- sglang/srt/sampling/sampling_params.py +4 -2
- sglang/srt/server.py +4 -1037
- sglang/srt/server_args.py +84 -32
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +130 -63
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -23,15 +23,15 @@ from typing import List, Optional
 import torch

 from sglang.srt.hf_transformers_utils import check_gguf_file
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
     is_flashinfer_available,
     is_hip,
-    is_ipv6,
     is_port_available,
+    is_valid_ipv6_address,
+    nullable_str,
 )

 logger = logging.getLogger(__name__)
@@ -47,6 +47,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
+    quantization_param_path: nullable_str = None
     quantization: Optional[str] = None
     context_length: Optional[int] = None
     device: str = "cuda"
@@ -55,7 +56,6 @@ class ServerArgs:
     is_embedding: bool = False
     revision: Optional[str] = None
     skip_tokenizer_init: bool = False
-    return_token_ids: bool = False

     # Port for the HTTP server
     host: str = "127.0.0.1"
@@ -91,7 +91,7 @@ class ServerArgs:

     # API related
     api_key: Optional[str] = None
-    file_storage_pth: str = "SGLang_storage"
+    file_storage_pth: str = "sglang_storage"
     enable_cache_report: bool = False

     # Data parallelism
@@ -156,6 +156,11 @@ class ServerArgs:
     triton_attention_num_kv_splits: int = 8
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
+    enable_memory_saver: bool = False
+    allow_auto_truncate: bool = False
+
+    # Custom logit processor
+    enable_custom_logit_processor: bool = False

     def __post_init__(self):
         # Set missing default values
@@ -239,14 +244,13 @@ class ServerArgs:
         # Others
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
+            assert self.tp_size % self.dp_size == 0
             self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
-            self.disable_overlap_schedule = True
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
-                "Overlap scheduler is disabled."
             )

         # Speculative Decoding
@@ -296,6 +300,11 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -346,8 +355,17 @@ class ServerArgs:
             "--kv-cache-dtype",
             type=str,
             default=ServerArgs.kv_cache_dtype,
-            choices=["auto", "fp8_e5m2"],
-            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
+        )
+        parser.add_argument(
+            "--quantization-param-path",
+            type=nullable_str,
+            default=None,
+            help="Path to the JSON file containing the KV cache "
+            "scaling factors. This should generally be supplied, when "
+            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
+            "default to 1.0, which may cause accuracy issues. ",
         )
         parser.add_argument(
             "--quantization",
@@ -363,6 +381,7 @@ class ServerArgs:
                 "bitsandbytes",
                 "gguf",
                 "modelopt",
+                "w8a8_int8",
             ],
             help="The quantization method.",
         )
@@ -376,7 +395,7 @@ class ServerArgs:
             "--device",
             type=str,
             default="cuda",
-            choices=["cuda", "xpu", "hpu"],
+            choices=["cuda", "xpu", "hpu", "cpu"],
             help="The device type.",
         )
         parser.add_argument(
@@ -404,18 +423,6 @@ class ServerArgs:
             "name, a tag name, or a commit id. If unspecified, will use "
             "the default version.",
         )
-        parser.add_argument(
-            "--skip-tokenizer-init",
-            action="store_true",
-            help="If set, skip init tokenizer and pass input_ids in generate request",
-        )
-        parser.add_argument(
-            "--return-token-ids",
-            action="store_true",
-            default=ServerArgs.return_token_ids,
-            help="Whether to return token IDs in the output, this may introduce additional overhead.",
-        )
-
         # Memory and scheduling
         parser.add_argument(
             "--mem-fraction-static",
@@ -551,7 +558,7 @@ class ServerArgs:
             "--decode-log-interval",
             type=int,
             default=ServerArgs.decode_log_interval,
-            help="The log interval of decode batch",
+            help="The log interval of decode batch.",
         )

         # API related
@@ -851,6 +858,21 @@ class ServerArgs:
             action="store_true",
             help="Delete the model checkpoint after loading the model.",
         )
+        parser.add_argument(
+            "--enable-memory-saver",
+            action="store_true",
+            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
+        )
+        parser.add_argument(
+            "--allow-auto-truncate",
+            action="store_true",
+            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
+        )
+        parser.add_argument(
+            "--enable-custom-logit-processor",
+            action="store_true",
+            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -861,7 +883,7 @@ class ServerArgs:
         return cls(**{attr: getattr(args, attr) for attr in attrs})

     def url(self):
-        if is_ipv6(self.host):
+        if is_valid_ipv6_address(self.host):
             return f"http://[{self.host}]:{self.port}"
         else:
             return f"http://{self.host}:{self.port}"
@@ -871,8 +893,8 @@ class ServerArgs:
             self.tp_size % self.nnodes == 0
         ), "tp_size must be divisible by number of nodes"
         assert not (
-            self.dp_size > 1 and self.nnodes != 1
-        ), "multi-node data parallel is not supported"
+            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
+        ), "multi-node data parallel is not supported unless dp attention!"
         assert (
             self.max_loras_per_batch > 0
             # FIXME
@@ -910,6 +932,9 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
     return server_args


+ZMQ_TCP_PORT_DELTA = 233
+
+
 @dataclasses.dataclass
 class PortArgs:
     # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
@@ -923,7 +948,7 @@ class PortArgs:
     nccl_port: int

     @staticmethod
-    def init_new(server_args) -> "PortArgs":
+    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
         port = server_args.port + random.randint(100, 1000)
         while True:
             if is_port_available(port):
@@ -933,12 +958,39 @@ class PortArgs:
             else:
                 port -= 43

-        return PortArgs(
-            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            nccl_port=port,
-        )
+        if not server_args.enable_dp_attention:
+            # Normal case, use IPC within a single node
+            return PortArgs(
+                tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                nccl_port=port,
+            )
+        else:
+            # DP attention. Use TCP + port to handle both single-node and multi-node.
+            if server_args.nnodes == 1 and server_args.dist_init_addr is None:
+                dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
+            else:
+                dist_init_addr = server_args.dist_init_addr.split(":")
+            assert (
+                len(dist_init_addr) == 2
+            ), "please provide --dist-init-addr as host:port of head node"
+
+            dist_init_host, dist_init_port = dist_init_addr
+            port_base = int(dist_init_port) + 1
+            if dp_rank is None:
+                scheduler_input_port = (
+                    port_base + 2
+                )  # TokenizerManager to DataParallelController
+            else:
+                scheduler_input_port = port_base + 2 + 1 + dp_rank
+
+            return PortArgs(
+                tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
+                scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
+                detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
+                nccl_port=port,
+            )


 class LoRAPathAction(argparse.Action):
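For reference, the DP-attention branch above derives every inter-process ZMQ endpoint from a single base port, so only --dist-init-addr (or the default derived from --port) must be agreed on across nodes. A worked sketch of the arithmetic, not part of the diff itself, assuming --port 30000, a single node, no --dist-init-addr, and four dp ranks (values illustrative):

    ZMQ_TCP_PORT_DELTA = 233
    dist_init_port = 30000 + ZMQ_TCP_PORT_DELTA      # 30233
    port_base = dist_init_port + 1                   # 30234 -> tokenizer_ipc_name
    detokenizer_port = port_base + 1                 # 30235 -> detokenizer_ipc_name
    controller_input_port = port_base + 2            # 30236 when dp_rank is None
    per_rank_ports = [port_base + 2 + 1 + r for r in range(4)]  # [30237, 30238, 30239, 30240]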
sglang/srt/torch_memory_saver_adapter.py
ADDED
@@ -0,0 +1,59 @@
+from abc import ABC
+from contextlib import contextmanager
+
+try:
+    import torch_memory_saver
+
+    _primary_memory_saver = torch_memory_saver.TorchMemorySaver()
+except ImportError:
+    pass
+
+
+class TorchMemorySaverAdapter(ABC):
+    @staticmethod
+    def create(enable: bool):
+        return (
+            _TorchMemorySaverAdapterReal() if enable else _TorchMemorySaverAdapterNoop()
+        )
+
+    def configure_subprocess(self):
+        raise NotImplementedError
+
+    def region(self):
+        raise NotImplementedError
+
+    def pause(self):
+        raise NotImplementedError
+
+    def resume(self):
+        raise NotImplementedError
+
+
+class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
+    def configure_subprocess(self):
+        return torch_memory_saver.configure_subprocess()
+
+    def region(self):
+        return _primary_memory_saver.region()
+
+    def pause(self):
+        return _primary_memory_saver.pause()
+
+    def resume(self):
+        return _primary_memory_saver.resume()
+
+
+class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
+    @contextmanager
+    def configure_subprocess(self):
+        yield
+
+    @contextmanager
+    def region(self):
+        yield
+
+    def pause(self):
+        pass
+
+    def resume(self):
+        pass
sglang/srt/utils.py
CHANGED
@@ -59,6 +59,7 @@ from triton.runtime.cache import (
     default_dump_dir,
     default_override_dir,
 )
+from uvicorn.config import LOGGING_CONFIG

 logger = logging.getLogger(__name__)

@@ -97,12 +98,8 @@ def is_flashinfer_available():
     return torch.cuda.is_available() and torch.version.cuda


-def is_ipv6(address):
-    try:
-        ipaddress.IPv6Address(address)
-        return True
-    except ipaddress.AddressValueError:
-        return False
+def is_cuda_available():
+    return torch.cuda.is_available() and torch.version.cuda


 def enable_show_time_cost():
@@ -218,6 +215,10 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True

         free_gpu_memory, total_gpu_memory = torch.hpu.mem_get_info()

+    elif device == "cpu":
+        # TODO: rename the variables in the current function to be not GPU specific
+        free_gpu_memory = psutil.virtual_memory().available
+
     if distributed:
         tensor = torch.tensor(free_gpu_memory, dtype=torch.float32).to(
             torch.device(device, gpu_id)
@@ -442,6 +443,8 @@ def load_image(image_file: Union[str, bytes]):
     else:
         raise ValueError(f"Invalid image: {image}")

+    # if image_size is None:
+    # image_size = image.size
     return image, image_size


@@ -507,76 +510,32 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N
        pass


-def monkey_patch_vllm_p2p_access_check(gpu_id: int):
+def monkey_patch_p2p_access_check():
     """
-    Monkey patch the slow p2p access check
+    Monkey patch the slow p2p access check.
     NOTE: We assume the p2p access is always allowed, which can be wrong for some setups.
     """

-    import vllm.distributed.device_communicators.custom_all_reduce_utils as tgt
+    import sglang.srt.distributed.device_communicators.custom_all_reduce_utils as tgt

     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)

     # Suppress the warnings from this delete function when using sglang.bench_one_batch
-    from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
+    from sglang.srt.distributed.device_communicators.custom_all_reduce import (
+        CustomAllreduce,
+    )

     setattr(CustomAllreduce, "__del__", lambda *args, **kwargs: None)


-vllm_all_gather_backup = None
-
-
-def monkey_patch_vllm_all_gather(reverse: bool = False):
-    """Monkey patch all-gather to remove in-place operations."""
-    from torch.distributed import _functional_collectives as funcol
-    from vllm.distributed.parallel_state import GroupCoordinator
-
-    global vllm_all_gather_backup
-    if vllm_all_gather_backup is None:
-        vllm_all_gather_backup = GroupCoordinator.all_gather
-
-    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
-        world_size = self.world_size
-        # Bypass the function if we are using only 1 GPU.
-        if world_size == 1:
-            return input_
-        assert (
-            -input_.dim() <= dim < input_.dim()
-        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
-        if dim < 0:
-            # Convert negative dim to positive.
-            dim += input_.dim()
-        input_size = input_.size()
-        # Allocate output tensor.
-        output_tensor = torch.empty(
-            (world_size,) + input_size, dtype=input_.dtype, device=input_.device
-        )
-
-        output_tensor = funcol.all_gather_tensor(
-            input_, gather_dim=0, group=self.device_group
-        ).view((world_size,) + input_size)
-
-        # Reshape
-        output_tensor = output_tensor.movedim(0, dim)
-        output_tensor = output_tensor.reshape(
-            input_size[:dim] + (world_size * input_size[dim],) + input_size[dim + 1 :]
-        )
-        return output_tensor
-
-    if reverse:
-        setattr(GroupCoordinator, "all_gather", vllm_all_gather_backup)
-    else:
-        setattr(GroupCoordinator, "all_gather", all_gather)
-
-
 def monkey_patch_vllm_gguf_config():
-    from vllm.model_executor.layers.linear import LinearBase
     from vllm.model_executor.layers.quantization.gguf import (
         GGUFConfig,
         GGUFEmbeddingMethod,
         GGUFLinearMethod,
     )

+    from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding

     def get_quant_method_with_embedding_replaced(
@@ -784,7 +743,9 @@ def first_rank_print(*args, **kwargs):
     pass


-def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint: str):
+def get_zmq_socket(
+    context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
+):
     mem = psutil.virtual_memory()
     total_mem = mem.total / 1024**3
     available_mem = mem.available / 1024**3
@@ -797,14 +758,17 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
     if socket_type == zmq.PUSH:
         socket.setsockopt(zmq.SNDHWM, 0)
         socket.setsockopt(zmq.SNDBUF, buf_size)
-        socket.connect(f"ipc://{endpoint}")
     elif socket_type == zmq.PULL:
         socket.setsockopt(zmq.RCVHWM, 0)
         socket.setsockopt(zmq.RCVBUF, buf_size)
-        socket.bind(f"ipc://{endpoint}")
     else:
         raise ValueError(f"Unsupported socket type: {socket_type}")

+    if bind:
+        socket.bind(endpoint)
+    else:
+        socket.connect(endpoint)
+
     return socket

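The reworked get_zmq_socket moves the bind/connect decision to the caller and takes a full endpoint string rather than assuming ipc://, which is what lets the DP-attention PortArgs hand out tcp:// endpoints. A minimal sketch of the pairing (endpoint path illustrative):

    import zmq

    from sglang.srt.utils import get_zmq_socket

    ctx = zmq.Context()
    # The receiving side binds, the sending side connects; a tcp://host:port
    # endpoint works the same way across nodes.
    receiver = get_zmq_socket(ctx, zmq.PULL, "ipc:///tmp/sglang_demo", bind=True)
    sender = get_zmq_socket(ctx, zmq.PUSH, "ipc:///tmp/sglang_demo", bind=False)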
@@ -1246,9 +1210,9 @@ def dataclass_to_string_truncated(data, max_length=2048):
     if isinstance(data, str):
         if len(data) > max_length:
             half_length = max_length // 2
-            return f"{data[:half_length]} ... {data[-half_length:]}"
+            return f"{repr(data[:half_length])} ... {repr(data[-half_length:])}"
         else:
-            return f"{data}"
+            return f"{repr(data)}"
     elif isinstance(data, (list, tuple)):
         if len(data) > max_length:
             half_length = max_length // 2
@@ -1259,7 +1223,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
         return (
             "{"
             + ", ".join(
-                f"{k}: {dataclass_to_string_truncated(v, max_length)}"
+                f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
                 for k, v in data.items()
             )
             + "}"
@@ -1340,6 +1304,25 @@ def parse_tool_response(text, tools, **kwargs):
     return text, call_info_list


+def permute_weight(x: torch.Tensor) -> torch.Tensor:
+    b_ = x.shape[0]
+    n_ = x.shape[1]
+    k_ = x.shape[2]
+
+    x_ = x
+    if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8)
+    elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
+    else:
+        return x_
+
+    x_ = x_.permute(0, 1, 3, 4, 2, 5)
+    x_ = x_.contiguous()
+    x_ = x_.view(*x.shape)
+    return x_
+
+
 class MultiprocessingSerializer:
     @staticmethod
     def serialize(obj):
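permute_weight, added above, re-tiles a 3-D weight tensor (e.g. per-expert MoE weights) into a blocked element order while keeping its logical shape, with tile sizes chosen by dtype; tensors of other dtypes are returned unchanged. A small sketch of the shape contract (sizes illustrative; for 16-bit dtypes the second dim must be divisible by 16 and the third by 32):

    import torch

    from sglang.srt.utils import permute_weight

    w = torch.randn(8, 64, 128, dtype=torch.float16)  # (experts, n, k)
    w_packed = permute_weight(w)
    assert w_packed.shape == w.shape  # same shape, reordered elements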
@@ -1375,3 +1358,87 @@ def debug_timing(func):
         return func(*args, **kwargs)

     return wrapper
+
+
+def nullable_str(val: str):
+    if not val or val == "None":
+        return None
+    return val
+
+
+def set_uvicorn_logging_configs():
+    LOGGING_CONFIG["formatters"]["default"][
+        "fmt"
+    ] = "[%(asctime)s] %(levelprefix)s %(message)s"
+    LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+    LOGGING_CONFIG["formatters"]["access"][
+        "fmt"
+    ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
+    LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+
+
+def get_ip() -> str:
+    # SGLANG_HOST_IP env can be ignore
+    host_ip = os.getenv("SGLANG_HOST_IP", "") or os.getenv("HOST_IP", "")
+    if host_ip:
+        return host_ip
+
+    # IP is not set, try to get it from the network interface
+
+    # try ipv4
+    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+    try:
+        s.connect(("8.8.8.8", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        pass
+
+    # try ipv6
+    try:
+        s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+        # Google's public DNS server, see
+        # https://developers.google.com/speed/public-dns/docs/using#addresses
+        s.connect(("2001:4860:4860::8888", 80))  # Doesn't need to be reachable
+        return s.getsockname()[0]
+    except Exception:
+        pass
+
+    warnings.warn(
+        "Failed to get the IP address, using 0.0.0.0 by default."
+        "The value can be set by the environment variable"
+        " SGLANG_HOST_IP or HOST_IP.",
+        stacklevel=2,
+    )
+    return "0.0.0.0"
+
+
+def get_open_port() -> int:
+
+    port = os.getenv("SGLANG_PORT")
+    if port is not None:
+        while True:
+            try:
+                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                    s.bind(("", port))
+                    return port
+            except OSError:
+                port += 1  # Increment port number if already in use
+                logger.info("Port %d is already in use, trying port %d", port - 1, port)
+    # try ipv4
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+    except OSError:
+        # try ipv6
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+
+
+def is_valid_ipv6_address(address: str) -> bool:
+    try:
+        ipaddress.IPv6Address(address)
+        return True
+    except ValueError:
+        return False
sglang/test/runners.py
CHANGED
@@ -12,7 +12,6 @@
 # limitations under the License.
 # ==============================================================================

-import json
 import multiprocessing as mp
 import os
 from dataclasses import dataclass
@@ -22,8 +21,8 @@ import torch
 import torch.nn.functional as F
 from transformers import AutoModelForCausalLM

+from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.server import Runtime
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER

 DEFAULT_PROMPTS = [
@@ -278,7 +277,7 @@ class SRTRunner:
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
-        self.runtime = Runtime(
+        self.engine = Engine(
            model_path=model_path,
            tp_size=tp_size,
            dtype=get_dtype_str(torch_dtype),
@@ -306,7 +305,7 @@ class SRTRunner:
            top_output_logprobs = []
            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
            for i, prompt in enumerate(prompts):
-                response = self.runtime.generate(
+                response = self.engine.generate(
                    prompt,
                    lora_path=lora_paths[i] if lora_paths else None,
                    sampling_params=sampling_params,
@@ -314,7 +313,6 @@ class SRTRunner:
                    logprob_start_len=0,
                    top_logprobs_num=NUM_TOP_LOGPROBS,
                )
-                response = json.loads(response)
                output_strs.append(response["text"])
                top_input_logprobs.append(
                    [
@@ -343,8 +341,7 @@ class SRTRunner:
                top_output_logprobs=top_output_logprobs,
            )
        else:
-            response = self.runtime.encode(prompts)
-            response = json.loads(response)
+            response = self.engine.encode(prompts)
            if self.model_type == "embedding":
                logits = [x["embedding"] for x in response]
                return ModelOutput(embed_logits=logits)
@@ -366,20 +363,18 @@ class SRTRunner:
            # the return value contains logprobs from prefill
            output_strs = []
            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-            response = self.runtime.generate(
+            response = self.engine.generate(
                prompts,
                lora_path=lora_paths if lora_paths else None,
                sampling_params=sampling_params,
            )
-            response = json.loads(response)
            output_strs = [r["text"] for r in response]

            return ModelOutput(
                output_strs=output_strs,
            )
        else:
-            response = self.runtime.encode(prompts)
-            response = json.loads(response)
+            response = self.engine.encode(prompts)
            if self.model_type == "embedding":
                logits = [x["embedding"] for x in response]
                return ModelOutput(embed_logits=logits)
@@ -391,8 +386,8 @@ class SRTRunner:
        return self

    def __exit__(self, exc_type, exc_value, traceback):
-        self.runtime.shutdown()
-        del self.runtime
+        self.engine.shutdown()
+        del self.engine


 def monkey_patch_gemma2_sdpa():
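With SRTRunner now built on sglang.srt.entrypoints.engine.Engine instead of the HTTP-backed Runtime, generate() and encode() return parsed Python objects, which is why the json.loads round-trips disappear throughout this diff. A minimal sketch of the new flow (model path illustrative; running it requires a GPU and a model download):

    from sglang.srt.entrypoints.engine import Engine

    engine = Engine(model_path="meta-llama/Llama-3.2-1B-Instruct", tp_size=1)
    out = engine.generate(
        "The capital of France is",
        sampling_params={"max_new_tokens": 8, "temperature": 0},
    )
    print(out["text"])  # already a dict, no json.loads needed
    engine.shutdown()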
|