sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/two_batch_overlap.py
CHANGED
@@ -11,7 +11,7 @@ from sglang.srt.layers.communicator import (
     ScatterMode,
 )
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.quantization
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.operations import execute_operations, execute_overlapped_operations
@@ -479,7 +479,9 @@ def _model_forward_tbo(
     )
     del inputs
 
-    with configure_deep_gemm_num_sms(
+    with deep_gemm_wrapper.configure_deep_gemm_num_sms(
+        operations_strategy.deep_gemm_num_sms
+    ):
         outputs_arr = execute_overlapped_operations(
             inputs_arr=inputs_arr,
             operations_arr=[operations_strategy.operations] * 2,
sglang/srt/utils.py
CHANGED
@@ -17,6 +17,7 @@ import base64
 import builtins
 import ctypes
 import dataclasses
+import functools
 import importlib
 import io
 import ipaddress
@@ -159,7 +160,7 @@ def is_npu() -> bool:
     return hasattr(torch, "npu") and torch.npu.is_available()
 
 
-def is_cpu() -> bool:
+def is_host_cpu_x86() -> bool:
     machine = platform.machine().lower()
     return (
         machine in ("x86_64", "amd64", "i386", "i686")
@@ -168,6 +169,10 @@ def is_cpu() -> bool:
     )
 
 
+def is_cpu() -> bool:
+    return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86()
+
+
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
@@ -837,6 +842,7 @@ class CustomCacheManager(FileCacheManager):
 
 
 def set_ulimit(target_soft_limit=65535):
+    # number of open files
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
 
@@ -846,6 +852,18 @@ def set_ulimit(target_soft_limit=65535):
     except ValueError as e:
         logger.warning(f"Fail to set RLIMIT_NOFILE: {e}")
 
+    # stack size
+    resource_type = resource.RLIMIT_STACK
+    current_soft, current_hard = resource.getrlimit(resource_type)
+    target_soft_limit_stack_size = 1024 * target_soft_limit
+    if current_soft < target_soft_limit_stack_size:
+        try:
+            resource.setrlimit(
+                resource_type, (target_soft_limit_stack_size, current_hard)
+            )
+        except ValueError as e:
+            logger.warning(f"Fail to set RLIMIT_STACK: {e}")
+
 
 def add_api_key_middleware(app, api_key: str):
     @app.middleware("http")
@@ -1277,6 +1295,15 @@ def get_hpu_memory_capacity():
     )
 
 
+def get_npu_memory_capacity():
+    try:
+        import torch_npu
+
+        return torch.npu.mem_get_info()[1] // 1024 // 1024  # unit: MB
+    except ImportError as e:
+        raise ImportError("torch_npu is required when run on npu device.")
+
+
 def get_device_memory_capacity(device: str = None):
     if is_cuda():
         gpu_mem = get_nvgpu_memory_capacity()
@@ -1284,6 +1311,8 @@ def get_device_memory_capacity(device: str = None):
         gpu_mem = get_amdgpu_memory_capacity()
     elif device == "hpu":
         gpu_mem = get_hpu_memory_capacity()
+    elif device == "npu":
+        gpu_mem = get_npu_memory_capacity()
     else:
         # GPU memory is not known yet or no GPU is available.
         gpu_mem = None
@@ -1373,6 +1402,11 @@ def print_warning_once(msg: str) -> None:
     logger.warning(msg, stacklevel=2)
 
 
+@functools.lru_cache(None)
+def print_info_once(msg: str) -> None:
+    logger.info(msg)
+
+
 def get_device_name(device_id: int = 0) -> str:
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         return torch.cuda.get_device_name(device_id)
@@ -1404,6 +1438,11 @@ def get_device(device_id: Optional[int] = None) -> str:
            return "xpu"
        return "xpu:{}".format(device_id)
 
+    if hasattr(torch, "npu") and torch.npu.is_available():
+        if device_id == None:
+            return "npu"
+        return "npu:{}".format(device_id)
+
     if is_habana_available():
         try:
             import habana_frameworks.torch.hpu
@@ -1417,6 +1456,15 @@ def get_device(device_id: Optional[int] = None) -> str:
                 "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
             )
 
+    if is_cpu():
+        if cpu_has_amx_support():
+            logger.info("Intel AMX is detected, using CPU with Intel AMX support.")
+        else:
+            logger.warning(
+                "CPU device enabled, using torch native backend, low performance expected."
+            )
+        return "cpu"
+
     raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")
 
 
@@ -1478,15 +1526,35 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
     return major, minor
 
 
+def get_npu_compiler_config():
+    config = {
+        "frozen_parameter": True,
+        "tiling_schedule_optimize": True,
+        "topology_sorting_strategy": "StableRDFS",
+    }
+    return config
+
+
 def get_compiler_backend() -> str:
     if hasattr(torch, "hpu") and torch.hpu.is_available():
         return "hpu_backend"
 
     if hasattr(torch, "npu") and torch.npu.is_available():
-
+        try:
+            import torchair
+            import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
+            from torchair.configs.compiler_config import CompilerConfig
+        except ImportError as e:
+            raise ImportError(
+                "NPU detected, but torchair package is not installed. "
+                "Please install torchair for torch.compile support on NPU."
+            )
+        compiler_config = CompilerConfig()
+        predefined_config = get_npu_compiler_config()
+        for k, v in predefined_config.items():
+            setattr(compiler_config.experimental_config, k, v)
 
-
-        npu_backend = torchair.get_npu_backend(compiler_config=config)
+        npu_backend = torchair.get_npu_backend(compiler_config=compiler_config)
         return npu_backend
 
     return "inductor"
@@ -1849,13 +1917,6 @@ def configure_ipv6(dist_init_addr):
     return port, host
 
 
-def rank0_log(msg: str):
-    from sglang.srt.distributed import get_tensor_model_parallel_rank
-
-    if get_tensor_model_parallel_rank() == 0:
-        logger.info(msg)
-
-
 def rank0_print(msg: str):
     from sglang.srt.distributed import get_tensor_model_parallel_rank
 
@@ -1863,6 +1924,9 @@ def rank0_print(msg: str):
         print(msg, flush=True)
 
 
+rank0_log = rank0_print
+
+
 def get_cuda_version():
     if torch.version.cuda:
         return tuple(map(int, torch.version.cuda.split(".")))
@@ -2086,6 +2150,44 @@ def get_free_port():
         return s.getsockname()[1]
 
 
+def get_local_ip_auto() -> str:
+    interface = os.environ.get("SGLANG_LOCAL_IP_NIC", None)
+    return (
+        get_local_ip_by_nic(interface)
+        if interface is not None
+        else get_local_ip_by_remote()
+    )
+
+
+def get_local_ip_by_nic(interface: str) -> str:
+    try:
+        import netifaces
+    except ImportError as e:
+        raise ImportError(
+            "Environment variable SGLANG_LOCAL_IP_NIC requires package netifaces, please install it through 'pip install netifaces'"
+        ) from e
+
+    try:
+        addresses = netifaces.ifaddresses(interface)
+        if netifaces.AF_INET in addresses:
+            for addr_info in addresses[netifaces.AF_INET]:
+                ip = addr_info.get("addr")
+                if ip and ip != "127.0.0.1" and ip != "0.0.0.0":
+                    return ip
+        if netifaces.AF_INET6 in addresses:
+            for addr_info in addresses[netifaces.AF_INET6]:
+                ip = addr_info.get("addr")
+                if ip and not ip.startswith("fe80::") and ip != "::1":
+                    return ip.split("%")[0]
+    except (ValueError, OSError) as e:
+        raise ValueError(
+            "Can not get local ip from NIC. Please verify whether SGLANG_LOCAL_IP_NIC is set correctly."
+        )
+
+    # Fallback
+    return get_local_ip_by_remote()
+
+
 def get_local_ip_by_remote() -> str:
     # try ipv4
     s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
@@ -2197,6 +2299,90 @@ class Withable(Generic[T]):
         self._value = None
 
 
+def require_mlp_tp_gather(server_args):
+    """
+    Check if the input of MLP is obtained by all-gather rather than all-reduce. This only happens when each MLP TP group contains multiple attention DP groups.
+    """
+    if server_args.enable_dp_attention:
+        assert server_args.dp_size > 1, "dp_size must be greater than 1"
+        if (
+            server_args.moe_dense_tp_size is None
+        ):  # TODO(ch-wan): some MoE models do not have dense layers
+            return True
+        elif not server_args.enable_dp_lm_head:
+            return True
+        elif not server_args.enable_deepep_moe:
+            return True
+        else:
+            return (
+                server_args.moe_dense_tp_size
+                > server_args.tp_size // server_args.dp_size
+            )
+    else:
+        return False
+
+
+def require_attn_tp_gather(server_args):
+    """
+    Check if the input of attention is scattered.
+    """
+    assert server_args.moe_dense_tp_size in [1, None]
+    if server_args.enable_deepep_moe or server_args.moe_dense_tp_size == 1:
+        if server_args.enable_dp_attention:
+            return server_args.dp_size < server_args.tp_size
+        else:
+            return True
+    else:
+        return False
+
+
+def require_gathered_buffer(server_args):
+    return require_mlp_tp_gather(server_args) or require_attn_tp_gather(server_args)
+
+
+def require_mlp_sync(server_args):
+    return server_args.enable_dp_attention or require_gathered_buffer(server_args)
+
+
+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value for missing tensor elements
+
+    Returns:
+        Merged tensor or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+        return torch.cat([lhs, rhs])
+
+
 def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
     import huggingface_hub as hf
 
@@ -2282,3 +2468,41 @@ class LazyValue:
             self._value = self._creator()
             self._creator = None
         return self._value
+
+
+def dynamic_import(func_path: str):
+    parts = func_path.split(".")
+    if len(parts) < 2:
+        raise ValueError(
+            "func_path should contain both module name and func name (such as 'module.func')"
+        )
+    module_path = ".".join(parts[:-1])
+    func_name = parts[-1]
+    module = importlib.import_module(module_path)
+    func = getattr(module, func_name)
+    return func
+
+
+def configure_gc_logger():
+    logger.info("Enable GC Logger")
+
+    import gc
+
+    gc_start_time = {}
+
+    def gc_callback(phase, info):
+        gen = info.get("generation", "?")
+        if phase == "start":
+            gc_start_time[gen] = time.time()
+            logger.info(f"GC start: Time {time.time()} | Generation {gen}")
+        elif phase == "stop":
+            duration = time.time() - gc_start_time.get(gen, time.time())
+            collected = info.get("collected", "?")
+            uncollectable = info.get("uncollectable", "?")
+            logger.info(
+                f"GC end: Time {time.time()} | Generation {gen} | "
+                f"Duration: {duration:.4f}s | Collected: {collected} | Uncollectable: {uncollectable} "
+                f'{"(LONG GC)" if duration > 0.1 else ""}'
+            )
+
+    gc.callbacks.append(gc_callback)
sglang/test/attention/test_prefix_chunk_info.py
CHANGED
@@ -2,6 +2,8 @@ import unittest
 
 import torch
 
+from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.test.test_utils import CustomTestCase
sglang/test/runners.py
CHANGED
@@ -42,6 +42,21 @@ DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
     # "The capital of France is",
 ]
+TEST_RERANK_QUERY_DOCS = [
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin is well known for its museums.",
+        ],
+    },
+    {
+        "query": "How many people live in Berlin?",
+        "documents": [
+            "Berlin had a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.",
+            "Berlin is well known for its museums.",
+        ],
+    },
+]
 
 dirpath = os.path.dirname(__file__)
 with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f:
@@ -241,7 +256,7 @@ class HFRunner:
             self.model = _get_sentence_transformer_embedding_model(
                 model_path, torch_dtype
             )
-        elif self.model_type == "reward":
+        elif self.model_type == "reward" or self.model_type == "cross_encoder":
             from transformers import AutoModelForSequenceClassification
 
             self.model = AutoModelForSequenceClassification.from_pretrained(
@@ -303,6 +318,15 @@ class HFRunner:
                 else:
                     logits = self.model.encode(prompts).tolist()
                 out_queue.put(ModelOutput(embed_logits=logits))
+            elif self.model_type == "cross_encoder":
+                inputs = self.tokenizer(
+                    prompts, padding=True, return_tensors="pt"
+                ).to("cuda")
+                scores = self.model(**inputs).logits
+                scores = scores.squeeze().tolist()
+                if not isinstance(scores, list):
+                    scores = [scores]
+                out_queue.put(ModelOutput(scores=scores))
 
             elif self.model_type == "reward":
                 scores = []
@@ -322,7 +346,9 @@ class HFRunner:
 
     def forward(
         self,
-        prompts: Union[
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -526,7 +552,9 @@ class SRTRunner:
 
     def forward(
        self,
-        prompts: Union[
+        prompts: Union[
+            List[List[str]], List[str], List[torch.Tensor]
+        ] = DEFAULT_PROMPTS,
         image_data: Optional[List[str]] = None,
         max_new_tokens: int = 8,
         lora_paths: Optional[List[str]] = None,
@@ -552,6 +580,13 @@ class SRTRunner:
             else:
                 logits = [response["embedding"]]
             return ModelOutput(embed_logits=logits)
+        # cross encoder model
+        elif self.model_type == "cross_encoder":
+            response = self.engine.rerank(prompts)
+            if not isinstance(response, list):
+                response = [response]
+            scores = [x["embedding"] for x in response]
+            return ModelOutput(scores=scores)
         # reward model
         else:
             response = self.engine.encode(prompts)