sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff compares two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py
CHANGED
@@ -36,12 +36,13 @@ import tempfile
 import threading
 import time
 import warnings
+from contextlib import contextmanager
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from importlib.util import find_spec
 from io import BytesIO
-from multiprocessing import Pool
 from multiprocessing.reduction import ForkingPickler
+from pathlib import Path
 from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union

 import numpy as np
@@ -54,13 +55,13 @@ import triton
 import zmq
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
-from
+from PIL import Image
 from starlette.routing import Mount
 from torch import nn
 from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
-from torch.utils.
+from torch.utils._contextlib import _DecoratorContextManager
 from triton.runtime.cache import (
     FileCacheManager,
     default_cache_dir,
@@ -76,6 +77,11 @@ time_infos = {}
 HIP_FP8_E4M3_FNUZ_MAX = 224.0


+def get_bool_env_var(name: str, default: str = "false") -> bool:
+    value = os.getenv(name, default)
+    return value.lower() in ("true", "1")
+
+
 # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
 def is_hip() -> bool:
     return torch.version.hip is not None
@@ -126,6 +132,63 @@ def is_cuda_available():
     return is_cuda()


+_ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
+    "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
+)
+
+
+class DynamicGradMode(_DecoratorContextManager):
+    """
+    A combination of torch.no_grad and torch.inference_mode,
+    with their behavior controlled by an environment variable. Just refer to them.
+    """
+
+    @staticmethod
+    def set_inference_mode(mode: bool):
+        if isinstance(mode, bool):
+            global _ENABLE_TORCH_INFERENCE_MODE
+
+            _ENABLE_TORCH_INFERENCE_MODE = mode
+        else:
+            logger.warning("mode is not a boolean object")
+
+    def __init__(self, mode=True):
+        if not torch._jit_internal.is_scripting():
+            super().__init__()
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self.mode = mode
+        else:
+            self.prev = False
+
+    def __new__(cls, mode_or_orig_func=True if _ENABLE_TORCH_INFERENCE_MODE else None):
+        if mode_or_orig_func is None or isinstance(mode_or_orig_func, bool):
+            return super().__new__(cls)
+        return cls()(mode_or_orig_func)
+
+    def __enter__(self) -> None:
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self._inference_mode_context = torch._C._InferenceMode(self.mode)
+            self._inference_mode_context.__enter__()
+        else:
+            self.prev = torch.is_grad_enabled()
+            torch.set_grad_enabled(False)
+
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
+        else:
+            torch.set_grad_enabled(self.prev)
+
+    def clone(self) -> "DynamicGradMode":
+        r"""
+        Create a copy of this class
+        """
+        if _ENABLE_TORCH_INFERENCE_MODE:
+            return self.__class__(self.mode)
+        else:
+            return self.__class__()
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
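A brief usage sketch of the new DynamicGradMode helper (not part of the diff itself). It assumes SGLANG_ENABLE_TORCH_INFERENCE_MODE is set before sglang.srt.utils is first imported, since _ENABLE_TORCH_INFERENCE_MODE is read at import time; run_step is a hypothetical function.

# Illustrative sketch only.
import os
os.environ["SGLANG_ENABLE_TORCH_INFERENCE_MODE"] = "true"  # read at import time

import torch
from sglang.srt.utils import DynamicGradMode

@DynamicGradMode()          # usable as a decorator, like torch.no_grad
def run_step(x):
    return x * 2

with DynamicGradMode():     # or as a context manager
    y = run_step(torch.ones(4))
    assert not y.requires_grad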
@@ -198,7 +261,7 @@ def get_available_gpu_memory(device, gpu_id, distributed=False, empty_cache=True
     When distributed is True, the available memory is the minimum available memory of all GPUs.
     """
     if device == "cuda":
-        num_gpus =
+        num_gpus = cuda_device_count_stateless()
         assert gpu_id < num_gpus

         if torch.cuda.current_device() != gpu_id:
@@ -443,17 +506,46 @@ def decode_video_base64(video_base64):
     )  # Return an empty array and size tuple if no frames were found


-def
-
+def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray:
+    # Use soundfile here, since librosa use it under the hood,
+    # and librosa will not support audio loading in the future
+    import soundfile as sf
+    from scipy.signal import resample
+
+    # print(f"loading {audio_file}")
+    # Load audio data
+    if isinstance(audio_file, bytes):
+        audio, original_sr = sf.read(BytesIO(audio_file))
+    elif audio_file.startswith("data:"):
+        audio_file = audio_file.split(",")[1]
+        audio, original_sr = sf.read(BytesIO(base64.b64decode(audio_file)))
+    elif isinstance(audio_file, str):
+        audio, original_sr = sf.read(audio_file)
+    else:
+        raise ValueError(f"Invalid audio format: {audio_file}")

+    # Resample audio if the original sample rate is different from the desired sample rate
+    if original_sr != sr:
+        num_samples = int(len(audio) * float(sr) / original_sr)
+        audio = resample(audio, num_samples)
+
+    # Convert to mono if requested and audio is stereo
+    if mono and len(audio.shape) > 1:
+        audio = np.mean(audio, axis=1)
+
+    return audio
+
+
+def load_image(image_file: Union[str, bytes]) -> tuple[Image, tuple[int, int]]:
     image = image_size = None

     if isinstance(image_file, bytes):
         image = Image.open(BytesIO(image_file))
     elif image_file.startswith("http://") or image_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "3"))
-        response = requests.get(image_file, timeout=timeout)
-        image = Image.open(
+        response = requests.get(image_file, stream=True, timeout=timeout).raw
+        image = Image.open(response)
+        response.close()
     elif image_file.lower().endswith(("png", "jpg", "jpeg", "webp", "gif")):
         image = Image.open(image_file)
     elif image_file.startswith("data:"):
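An illustrative call to the new load_audio helper; per the branches above it accepts a file path, raw bytes, or a data: URI, resamples to the requested rate, and optionally downmixes to mono. The sample.wav path is a made-up example.

# Illustrative only; "sample.wav" is a hypothetical local file.
from sglang.srt.utils import load_audio

audio = load_audio("sample.wav", sr=16000, mono=True)  # 1-D numpy array at 16 kHz
print(audio.shape)

with open("sample.wav", "rb") as f:
    audio_from_bytes = load_audio(f.read())  # raw bytes are accepted too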
@@ -471,7 +563,10 @@ def load_image(image_file: Union[str, bytes]):


 def suppress_other_loggers():
-
+    try:
+        from vllm.logger import logger as vllm_default_logger
+    except ImportError:
+        return

     vllm_default_logger.setLevel(logging.WARN)
     logging.getLogger("vllm.distributed.device_communicators.pynccl").setLevel(
@@ -480,6 +575,7 @@ def suppress_other_loggers():
     logging.getLogger("vllm.distributed.device_communicators.shm_broadcast").setLevel(
         logging.WARN
     )
+    logging.getLogger("vllm.config").setLevel(logging.ERROR)

     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -527,6 +623,10 @@ def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = N

     if include_parent:
         try:
+            if parent_pid == os.getpid():
+                itself.kill()
+                sys.exit(0)
+
             itself.kill()

             # Sometime processes cannot be killed with SIGKILL (e.g, PID=1 launched by kubernetes),
@@ -555,11 +655,14 @@ def monkey_patch_p2p_access_check():


 def monkey_patch_vllm_gguf_config():
-
-
-
-
-
+    try:
+        from vllm.model_executor.layers.quantization.gguf import (
+            GGUFConfig,
+            GGUFEmbeddingMethod,
+            GGUFLinearMethod,
+        )
+    except ImportError:
+        return

     from sglang.srt.layers.linear import LinearBase
     from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -651,6 +754,16 @@ def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):


 def configure_logger(server_args, prefix: str = ""):
+    if SGLANG_LOGGING_CONFIG_PATH := os.getenv("SGLANG_LOGGING_CONFIG_PATH"):
+        if not os.path.exists(SGLANG_LOGGING_CONFIG_PATH):
+            raise Exception(
+                "Setting SGLANG_LOGGING_CONFIG_PATH from env with "
+                f"{SGLANG_LOGGING_CONFIG_PATH} but it does not exist!"
+            )
+        with open(SGLANG_LOGGING_CONFIG_PATH, encoding="utf-8") as file:
+            custom_config = json.loads(file.read())
+        logging.config.dictConfig(custom_config)
+        return
     format = f"[%(asctime)s{prefix}] %(message)s"
     # format = f"[%(asctime)s.%(msecs)03d{prefix}] %(message)s"
     logging.basicConfig(
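The file referenced by SGLANG_LOGGING_CONFIG_PATH is passed straight to logging.config.dictConfig, so it follows the standard dictConfig schema. A minimal sketch of such a config follows; the field values are illustrative, not taken from the package.

# Illustrative only: write a dictConfig-style JSON file and point the env var at it.
import json
import logging.config

custom_config = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {"plain": {"format": "[%(asctime)s] %(levelname)s %(message)s"}},
    "handlers": {"console": {"class": "logging.StreamHandler", "formatter": "plain"}},
    "root": {"level": "INFO", "handlers": ["console"]},
}

with open("sglang_logging.json", "w", encoding="utf-8") as f:
    json.dump(custom_config, f)

# export SGLANG_LOGGING_CONFIG_PATH=sglang_logging.json before launching the server;
# configure_logger() will then call logging.config.dictConfig(...) as shown above.
logging.config.dictConfig(custom_config)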
@@ -774,12 +887,22 @@ def get_zmq_socket(
     buf_size = -1

     socket = context.socket(socket_type)
-
+
+    def set_send_opt():
         socket.setsockopt(zmq.SNDHWM, 0)
         socket.setsockopt(zmq.SNDBUF, buf_size)
-
+
+    def set_recv_opt():
         socket.setsockopt(zmq.RCVHWM, 0)
         socket.setsockopt(zmq.RCVBUF, buf_size)
+
+    if socket_type == zmq.PUSH:
+        set_send_opt()
+    elif socket_type == zmq.PULL:
+        set_recv_opt()
+    elif socket_type == zmq.DEALER:
+        set_send_opt()
+        set_recv_opt()
     else:
         raise ValueError(f"Unsupported socket type: {socket_type}")

@@ -910,6 +1033,13 @@ def get_amdgpu_memory_capacity():
     )


+def get_device_sm():
+    if torch.cuda.is_available():
+        major, minor = torch.cuda.get_device_capability()
+        return major * 10 + minor
+    return 0
+
+
 def get_nvgpu_memory_capacity():
     try:
         # Run nvidia-smi and capture the output
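get_device_sm collapses the CUDA compute capability (major, minor) into one integer, e.g. 90 for SM 9.0, and returns 0 when no CUDA device is visible. A hedged example of the kind of check this enables:

# Illustrative only.
from sglang.srt.utils import get_device_sm

sm = get_device_sm()
if sm >= 90:        # Hopper-class (SM 9.0) or newer
    print("device supports the newest kernel paths")
elif sm == 0:
    print("no CUDA device available")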
@@ -1246,11 +1376,6 @@ def set_gpu_proc_affinity(
     logger.info(f"Process {pid} gpu_id {gpu_id} is running on CPUs: {p.cpu_affinity()}")


-def get_bool_env_var(name: str, default: str = "false") -> bool:
-    value = os.getenv(name, default)
-    return value.lower() in ("true", "1")
-
-
 @lru_cache(maxsize=2)
 def disable_request_logging() -> bool:
     return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
@@ -1477,6 +1602,7 @@ def get_ip() -> str:
 def get_open_port() -> int:
     port = os.getenv("SGLANG_PORT")
     if port is not None:
+        port = int(port)
         while True:
             try:
                 with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
@@ -1505,6 +1631,38 @@ def is_valid_ipv6_address(address: str) -> bool:
     return False


+def configure_ipv6(dist_init_addr):
+    addr = dist_init_addr
+    end = addr.find("]")
+    if end == -1:
+        raise ValueError("invalid IPv6 address format: missing ']'")
+
+    host = addr[: end + 1]
+
+    # this only validates the address without brackets: we still need the below checks.
+    # if it's invalid, immediately raise an error so we know it's not formatting issues.
+    if not is_valid_ipv6_address(host[1:end]):
+        raise ValueError(f"invalid IPv6 address: {host}")
+
+    port_str = None
+    if len(addr) > end + 1:
+        if addr[end + 1] == ":":
+            port_str = addr[end + 2 :]
+        else:
+            raise ValueError("received IPv6 address format: expected ':' after ']'")
+
+    if not port_str:
+        raise ValueError(
+            "a port must be specified in IPv6 address (format: [ipv6]:port)"
+        )
+
+    try:
+        port = int(port_str)
+    except ValueError:
+        raise ValueError(f"invalid port in IPv6 address: '{port_str}'")
+    return port, host
+
+
 def rank0_print(msg: str):
     from sglang.srt.distributed import get_tensor_model_parallel_rank

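configure_ipv6 parses a bracketed [ipv6]:port distributed-init address and returns (port, bracketed_host); a port is mandatory. A quick illustrative example with a made-up address:

# Illustrative only.
from sglang.srt.utils import configure_ipv6

port, host = configure_ipv6("[2001:db8::1]:5000")
assert port == 5000 and host == "[2001:db8::1]"

# configure_ipv6("[2001:db8::1]")  # would raise ValueError: a port must be specified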
@@ -1561,6 +1719,16 @@ def next_power_of_2(n: int):
 setattr(triton, "next_power_of_2", next_power_of_2)


+@contextmanager
+def empty_context(*args, **kwargs):
+    try:
+        # Setup code goes here
+        yield
+    finally:
+        # Cleanup code goes here
+        pass
+
+
 def add_prefix(name: str, prefix: str) -> str:
     """Add a weight path prefix to a module name.

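empty_context is a no-op context manager (in spirit, contextlib.nullcontext with ignored arguments), so a call site can pick a real context manager or the no-op at runtime without branching around the with statement. A small sketch, not taken from the package:

# Illustrative only.
import torch
from sglang.srt.utils import empty_context

disable_grad = True
ctx = torch.no_grad() if disable_grad else empty_context()
with ctx:
    y = torch.ones(2) * 3  # same body either way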
@@ -1572,3 +1740,29 @@ def add_prefix(name: str, prefix: str) -> str:
         The string `prefix.name` if prefix is non-empty, otherwise just `name`.
     """
     return name if not prefix else f"{prefix}.{name}"
+
+
+def is_remote_url(url: Union[str, Path]) -> bool:
+    """
+    Check if the URL is a remote URL of the format:
+    <connector_type>://<host>:<port>/<model_name>
+    """
+    if isinstance(url, Path):
+        return False
+
+    pattern = r"(.+)://(.*)"
+    m = re.match(pattern, url)
+    return m is not None
+
+
+def parse_connector_type(url: str) -> str:
+    """
+    Parse the connector type from the URL of the format:
+    <connector_type>://<path>
+    """
+    pattern = r"(.+)://(.*)"
+    m = re.match(pattern, url)
+    if m is None:
+        return ""
+
+    return m.group(1)
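Both helpers match a single <connector_type>:// prefix; this is how a remote model source (for example the new redis/s3 connectors under sglang/srt/connector/) can be told apart from a local path. A hedged example with made-up URLs:

# Illustrative only; the URLs are made-up examples.
from pathlib import Path
from sglang.srt.utils import is_remote_url, parse_connector_type

assert is_remote_url("s3://my-bucket/llama-3-8b")
assert not is_remote_url(Path("/models/llama-3-8b"))

assert parse_connector_type("redis://host:6379/model") == "redis"
assert parse_connector_type("/models/llama-3-8b") == ""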
sglang/test/__init__.py
ADDED
File without changes

sglang/test/attention/__init__.py
ADDED
File without changes

sglang/test/attention/test_flashattn_backend.py
ADDED
@@ -0,0 +1,312 @@
+import unittest
+
+import torch
+
+from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.test.test_utils import CustomTestCase
+
+
+class MockModelRunner:
+    model_config = type(
+        "ModelConfig", (), {"context_len": 2048, "is_multimodal": False}
+    )
+    sliding_window_size = None
+
+    def __init__(self, device="cuda"):
+        self.device = device
+        # Create a proper req_to_token_pool with the req_to_token attribute
+        self.req_to_token_pool = type(
+            "TokenPool",
+            (),
+            {
+                "size": 160,  # a typical max_bs * max_context_len for cuda graph decode
+                "req_to_token": torch.zeros(
+                    160, 2048, dtype=torch.int32, device=device
+                ),  # Add req_to_token attribute
+            },
+        )
+
+
+class MockReqToTokenPool:
+    def __init__(self, batch_size, seq_len, device):
+        self.req_to_token = (
+            torch.arange(batch_size * seq_len, device=device)
+            .reshape(batch_size, seq_len)
+            .to(torch.int32)
+        )
+
+
+@unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA")
+class TestFlashAttentionBackend(CustomTestCase):
+    def setUp(self):
+        """Set up test fixtures before each test method."""
+        self.model_runner = MockModelRunner()
+        self.backend = FlashAttentionBackend(self.model_runner)
+
+        # Common test parameters
+        self.batch_size = 2
+        self.seq_len = 4
+        self.num_heads = 2
+        self.head_dim = 8
+        self.device = "cuda"
+        self.dtype = torch.float16
+
+    def _create_attention_layer(self):
+        """Helper method to create an attention layer."""
+        return RadixAttention(
+            num_heads=self.num_heads,
+            head_dim=self.head_dim,
+            scaling=1.0,
+            num_kv_heads=self.num_heads,
+            layer_id=0,
+        )
+
+    def _create_kv_pool(self, size):
+        """Helper method to create a KV pool."""
+        return MHATokenToKVPool(
+            size=size,
+            page_size=1,  # only consider page=1 for unit test
+            dtype=self.dtype,
+            head_num=self.num_heads,
+            head_dim=self.head_dim,
+            layer_num=1,  # only consider layer=1 for unit test
+            device=self.device,
+            enable_memory_saver=False,
+        )
+
+    def _create_qkv_tensors(self, tokens_len):
+        """Helper method to create q, k, v tensors."""
+        return (
+            torch.randn(
+                tokens_len,
+                self.num_heads,
+                self.head_dim,
+                dtype=self.dtype,
+                device=self.device,
+            ),
+            torch.randn(
+                tokens_len,
+                self.num_heads,
+                self.head_dim,
+                dtype=self.dtype,
+                device=self.device,
+            ),
+            torch.randn(
+                tokens_len,
+                self.num_heads,
+                self.head_dim,
+                dtype=self.dtype,
+                device=self.device,
+            ),
+        )
+
+    def _verify_output(self, output, expected_shape):
+        """Helper method to verify output."""
+        self.assertEqual(
+            output.shape,
+            expected_shape,
+            f"Expected shape {expected_shape}, got {output.shape}",
+        )
+        self.assertEqual(output.dtype, self.dtype)
+        self.assertEqual(output.device.type, "cuda")
+        self.assertEqual(
+            torch.isnan(output).sum().item(), 0, "Output contains NaN values"
+        )
+
+    def test_forward_extend(self):
+        """Test the standard extend operation."""
+        # Create test inputs
+        q, k, v = self._create_qkv_tensors(self.batch_size * self.seq_len)
+
+        # Create attention layer
+        layer = self._create_attention_layer()
+
+        # Create forward batch
+        forward_batch = ForwardBatch(
+            batch_size=self.batch_size,
+            input_ids=torch.randint(
+                0, 100, (self.batch_size, self.seq_len), device=self.device
+            ),
+            out_cache_loc=torch.arange(
+                self.batch_size * self.seq_len, device=self.device
+            ),
+            seq_lens_sum=self.batch_size * self.seq_len,
+            forward_mode=ForwardMode.EXTEND,
+            req_pool_indices=torch.arange(self.batch_size, device=self.device),
+            seq_lens=torch.tensor([self.seq_len] * self.batch_size, device=self.device),
+            # 0 prefix, 4 extend
+            extend_prefix_lens=torch.tensor([0] * self.batch_size, device=self.device),
+            extend_seq_lens=torch.tensor([4] * self.batch_size, device=self.device),
+            attn_backend=self.backend,
+        )
+
+        # Add token pool and KV cache
+        forward_batch.req_to_token_pool = MockReqToTokenPool(
+            self.batch_size, self.seq_len, self.device
+        )
+        forward_batch.token_to_kv_pool = self._create_kv_pool(
+            self.batch_size * self.seq_len
+        )
+
+        # Initialize forward metadata before running the attention
+        self.backend.init_forward_metadata(forward_batch)
+
+        # Run forward_extend
+        output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+
+        # Verify output
+        expected_shape = (
+            self.batch_size * self.seq_len,
+            self.num_heads * self.head_dim,
+        )
+        self._verify_output(output, expected_shape)
+
+    def test_forward_decode(self):
+        """Test the decode operation with cached tokens."""
+        # For decode, we only have one token per sequence
+        decode_len = 1
+        curr_seq_len = self.seq_len + decode_len
+
+        # Create test inputs
+        q, k, v = self._create_qkv_tensors(self.batch_size * decode_len)
+
+        # Create attention layer
+        layer = self._create_attention_layer()
+
+        # Create forward batch
+        forward_batch = ForwardBatch(
+            batch_size=self.batch_size,
+            input_ids=torch.randint(
+                0, 100, (self.batch_size, decode_len), device=self.device
+            ),
+            out_cache_loc=torch.arange(
+                self.batch_size * self.seq_len,
+                self.batch_size * curr_seq_len,
+                device=self.device,
+            ),
+            seq_lens_sum=self.batch_size * curr_seq_len,
+            forward_mode=ForwardMode.DECODE,
+            req_pool_indices=torch.arange(self.batch_size, device=self.device),
+            seq_lens=torch.tensor([curr_seq_len] * self.batch_size, device=self.device),
+            attn_backend=self.backend,
+        )
+
+        # Add token pool and KV cache
+        forward_batch.req_to_token_pool = MockReqToTokenPool(
+            self.batch_size, curr_seq_len, self.device
+        )
+        forward_batch.token_to_kv_pool = self._create_kv_pool(
+            self.batch_size * curr_seq_len
+        )
+
+        # Pre-fill KV cache
+        cache_k, cache_v, _ = self._create_qkv_tensors(self.batch_size * self.seq_len)
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer,
+            torch.arange(self.batch_size * self.seq_len, device=self.device),
+            cache_k,
+            cache_v,
+            layer.k_scale,
+            layer.v_scale,
+        )
+
+        # Initialize forward metadata before running the attention
+        self.backend.init_forward_metadata(forward_batch)
+
+        # Run forward_decode
+        output = self.backend.forward_decode(q, k, v, layer, forward_batch)
+
+        # Verify output
+        expected_shape = (self.batch_size, self.num_heads * self.head_dim)
+        self._verify_output(output, expected_shape)
+
+    def test_forward_extend_with_prefix(self):
+        """Test extending from cached prefix tokens."""
+        # Define prefix and extend lengths
+        prefix_len = 2
+        extend_len = 2
+        total_len = prefix_len + extend_len
+
+        # Create test inputs for the extend portion
+        q, k, v = self._create_qkv_tensors(self.batch_size * extend_len)
+
+        # Create attention layer
+        layer = self._create_attention_layer()
+
+        # Create forward batch
+        forward_batch = ForwardBatch(
+            batch_size=self.batch_size,
+            input_ids=torch.randint(
+                0, 100, (self.batch_size, extend_len), device=self.device
+            ),
+            out_cache_loc=torch.arange(
+                self.batch_size * prefix_len,
+                self.batch_size * total_len,
+                device=self.device,
+            ),
+            seq_lens_sum=self.batch_size * total_len,
+            forward_mode=ForwardMode.EXTEND,
+            req_pool_indices=torch.arange(self.batch_size, device=self.device),
+            seq_lens=torch.tensor([total_len] * self.batch_size, device=self.device),
+            extend_prefix_lens=torch.tensor(
+                [prefix_len] * self.batch_size, device=self.device
+            ),
+            extend_seq_lens=torch.tensor(
+                [extend_len] * self.batch_size, device=self.device
+            ),
+            attn_backend=self.backend,
+        )
+
+        # Add token pool and KV cache
+        forward_batch.req_to_token_pool = MockReqToTokenPool(
+            self.batch_size, total_len, self.device
+        )
+        forward_batch.token_to_kv_pool = self._create_kv_pool(
+            self.batch_size * total_len
+        )
+
+        # Pre-fill the KV cache for prefix with known values
+        cache_k = torch.ones(
+            self.batch_size * prefix_len,
+            self.num_heads,
+            self.head_dim,
+            dtype=self.dtype,
+            device=self.device,
+        )
+        cache_v = (
+            torch.ones(
+                self.batch_size * prefix_len,
+                self.num_heads,
+                self.head_dim,
+                dtype=self.dtype,
+                device=self.device,
+            )
+            * 2
+        )
+
+        # Set the prefix KV cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer,
+            torch.arange(self.batch_size * prefix_len, device=self.device),
+            cache_k,
+            cache_v,
+            layer.k_scale,
+            layer.v_scale,
+        )
+
+        # Initialize forward metadata before running the attention
+        self.backend.init_forward_metadata(forward_batch)
+
+        # Run forward_extend
+        output = self.backend.forward_extend(q, k, v, layer, forward_batch)
+
+        # Verify output
+        expected_shape = (self.batch_size * extend_len, self.num_heads * self.head_dim)
+        self._verify_output(output, expected_shape)
+
+
+if __name__ == "__main__":
+    unittest.main()