sglang-0.4.8.post1-py3-none-any.whl → sglang-0.4.9.post1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py
CHANGED
@@ -13,7 +13,8 @@
 # ==============================================================================
 """Common utilities."""

-import
+from __future__ import annotations
+
 import builtins
 import ctypes
 import dataclasses
@@ -40,6 +41,7 @@ import threading
 import time
 import traceback
 import warnings
+from collections import OrderedDict, defaultdict
 from contextlib import contextmanager
 from enum import Enum
 from functools import lru_cache
@@ -65,6 +67,7 @@ from typing import (

 import numpy as np
 import psutil
+import pybase64
 import requests
 import torch
 import torch.distributed
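
The standard-library base64 calls in this module are replaced by pybase64, which mirrors the base64 API. A minimal sketch of the drop-in swap (the payload below is illustrative):

import base64
import pybase64

payload = b"sglang"
encoded = pybase64.b64encode(payload).decode("utf-8")

# validate=True rejects characters outside the base64 alphabet instead of silently skipping them.
decoded = pybase64.b64decode(encoded, validate=True)
assert decoded == payload == base64.b64decode(encoded)
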
@@ -80,12 +83,7 @@ from torch.func import functional_call
 from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from torch.utils._contextlib import _DecoratorContextManager
-from triton.runtime.cache import (
-    FileCacheManager,
-    default_cache_dir,
-    default_dump_dir,
-    default_override_dir,
-)
+from triton.runtime.cache import FileCacheManager

 logger = logging.getLogger(__name__)

@@ -94,35 +92,6 @@ time_infos = {}

 HIP_FP8_E4M3_FNUZ_MAX = 224.0

-_warned_bool_env_var_keys = set()
-
-
-def get_bool_env_var(name: str, default: str = "false") -> bool:
-    value = os.getenv(name, default)
-    value = value.lower()
-
-    truthy_values = ("true", "1")
-    falsy_values = ("false", "0")
-
-    if (value not in truthy_values) and (value not in falsy_values):
-        if value not in _warned_bool_env_var_keys:
-            logger.warning(
-                f"get_bool_env_var({name}) see non-understandable value={value} and treat as false"
-            )
-            _warned_bool_env_var_keys.add(value)
-
-    return value in truthy_values
-
-
-def get_int_env_var(name: str, default: int = 0) -> int:
-    value = os.getenv(name)
-    if value is None or not value.strip():
-        return default
-    try:
-        return int(value)
-    except ValueError:
-        return default
-

 # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
 def is_hip() -> bool:
@@ -173,6 +142,82 @@ def is_cpu() -> bool:
     return os.getenv("SGLANG_USE_CPU_ENGINE", "0") == "1" and is_host_cpu_x86()


+def get_cuda_version():
+    if torch.version.cuda:
+        return tuple(map(int, torch.version.cuda.split(".")))
+    return (0, 0)
+
+
+def _check(cc_major):
+    if not is_cuda():
+        return False
+    return torch.cuda.get_device_capability()[0] == cc_major and tuple(
+        map(int, torch.version.cuda.split(".")[:2])
+    ) >= (12, 3)
+
+
+is_ampere_with_cuda_12_3 = lambda: _check(8)
+is_hopper_with_cuda_12_3 = lambda: _check(9)
+
+
+def is_blackwell():
+    if not is_cuda():
+        return False
+    return torch.cuda.get_device_capability()[0] == 10
+
+
+_warned_bool_env_var_keys = set()
+
+
+def get_bool_env_var(name: str, default: str = "false") -> bool:
+    value = os.getenv(name, default)
+    value = value.lower()
+
+    truthy_values = ("true", "1")
+    falsy_values = ("false", "0")
+
+    if (value not in truthy_values) and (value not in falsy_values):
+        if value not in _warned_bool_env_var_keys:
+            logger.warning(
+                f"get_bool_env_var({name}) see non-understandable value={value} and treat as false"
+            )
+            _warned_bool_env_var_keys.add(value)
+
+    return value in truthy_values
+
+
+def get_int_env_var(name: str, default: int = 0) -> int:
+    value = os.getenv(name)
+    if value is None or not value.strip():
+        return default
+    try:
+        return int(value)
+    except ValueError:
+        return default
+
+
+def support_triton(backend: str) -> bool:
+    return backend not in ["torch_native", "intel_amx"]
+
+
+try:
+    import sgl_kernel
+
+    is_intel_amx_backend_available = hasattr(
+        torch.ops.sgl_kernel, "convert_weight_packed"
+    )
+except:
+    is_intel_amx_backend_available = False
+
+
+def cpu_has_amx_support():
+    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
+
+
+def use_intel_amx_backend(layer):
+    return getattr(layer, "use_intel_amx_backend", False)
+
+
 def is_flashinfer_available():
     """
     Check whether flashinfer is available.
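
For reference, a small usage sketch of the relocated environment-variable helpers, assuming the 0.4.9.post1 wheel and its dependencies are importable; the variable names are made up for illustration:

import os

os.environ["MY_FEATURE_FLAG"] = "TRUE"   # hypothetical flag
os.environ["MY_WORKER_COUNT"] = "8"      # hypothetical count

from sglang.srt.utils import get_bool_env_var, get_int_env_var

assert get_bool_env_var("MY_FEATURE_FLAG") is True          # "true"/"1" are truthy (case-insensitive)
assert get_bool_env_var("MISSING_FLAG") is False            # falls back to default="false"
assert get_int_env_var("MY_WORKER_COUNT", default=1) == 8   # unset or non-integer values return default
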
@@ -500,6 +545,46 @@ def set_random_seed(seed: int) -> None:
         torch.cuda.manual_seed_all(seed)


+def find_process_using_port(port: int) -> Optional[psutil.Process]:
+    for conn in psutil.net_connections(kind="inet"):
+        if conn.laddr.port == port:
+            try:
+                return psutil.Process(conn.pid)
+            except psutil.NoSuchProcess:
+                # It could happen by race condition (the proc dies when psutil.Process is called).
+                pass
+
+    return None
+
+
+def wait_port_available(
+    port: int, port_name: str, timeout_s: int = 30, raise_exception: bool = True
+) -> bool:
+    for i in range(timeout_s):
+        if is_port_available(port):
+            return True
+
+        if i > 10 and i % 5 == 0:
+            process = find_process_using_port(port)
+            if process is None:
+                logger.warning(
+                    f"The port {port} is in use, but we could not find the process that uses it."
+                )
+
+            pid = process.pid
+            error_message = f"{port_name} is used by a process already. {process.name()=}' {process.cmdline()=} {process.status()=} {pid=}"
+            logger.info(
+                f"port {port} is in use. Waiting for {i} seconds for {port_name} to be available. {error_message}"
+            )
+        time.sleep(0.1)
+
+    if raise_exception:
+        raise ValueError(
+            f"{port_name} at {port} is not available in {timeout_s} seconds. {error_message}"
+        )
+    return False
+
+
 def is_port_available(port):
     """Return whether a port is available."""
     with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
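
A sketch of how the new port helpers could be used before binding a server socket; the port number and name are arbitrary, and enumerating connections via psutil may require sufficient privileges:

from sglang.srt.utils import find_process_using_port, wait_port_available

# Poll for up to 30 seconds until port 30000 is free; raises ValueError if it never becomes free.
wait_port_available(30000, port_name="http_server_port", timeout_s=30)

# Or inspect the current holder of a busy port without waiting.
proc = find_process_using_port(30000)
if proc is not None:
    print(proc.pid, proc.name())
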
@@ -514,11 +599,24 @@ def is_port_available(port):
             return False


+def get_free_port():
+    # try ipv4
+    try:
+        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+    except OSError:
+        # try ipv6
+        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
+            s.bind(("", 0))
+            return s.getsockname()[1]
+
+
 def decode_video_base64(video_base64):
     from PIL import Image

     # Decode the base64 string
-    video_bytes =
+    video_bytes = pybase64.b64decode(video_base64, validate=True)

     # Placeholder for the start indices of each PNG image
     img_starts = []
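
get_free_port asks the OS for an unused port by binding to port 0 (IPv4 first, then IPv6). A tiny sketch; note the usual caveat that the port could be taken by another process between this call and the actual bind:

from sglang.srt.utils import get_free_port

port = get_free_port()
print(f"launching helper service on 127.0.0.1:{port}")
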
@@ -604,7 +702,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray:
         audio, original_sr = sf.read(BytesIO(audio_file))
     elif audio_file.startswith("data:"):
         audio_file = audio_file.split(",")[1]
-        audio, original_sr = sf.read(
+        audio, original_sr = sf.read(
+            BytesIO(pybase64.b64decode(audio_file, validate=True))
+        )
     elif audio_file.startswith("http://") or audio_file.startswith("https://"):
         timeout = int(os.getenv("REQUEST_TIMEOUT", "5"))
         response = requests.get(audio_file, stream=True, timeout=timeout)
@@ -673,12 +773,12 @@ def load_image(
         image = Image.open(image_file)
     elif image_file.startswith("data:"):
         image_file = image_file.split(",")[1]
-        image = Image.open(BytesIO(
+        image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     elif image_file.startswith("video:"):
         image_file = image_file.replace("video:", "")
         image, image_size = decode_video_base64(image_file)
     elif isinstance(image_file, str):
-        image = Image.open(BytesIO(
+        image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
     else:
         raise ValueError(f"Invalid image: {image}")

@@ -816,24 +916,51 @@ def maybe_set_triton_cache_manager() -> None:
 class CustomCacheManager(FileCacheManager):
     # Adapted from: https://github.com/tdoublep/vllm/blob/3307522289fdfefe323b6c00d0db696651989a2f/vllm/triton_utils/custom_cache_manager.py
     def __init__(self, key, override=False, dump=False):
+        from sglang.srt.distributed.parallel_state import get_tp_group

         self.key = key
         self.lock_path = None
+
+        try:
+            module_path = "triton.runtime.cache"
+            cache_module = importlib.import_module(module_path)
+
+            default_cache_dir = getattr(cache_module, "default_cache_dir", None)
+            default_dump_dir = getattr(cache_module, "default_dump_dir", None)
+            default_override_dir = getattr(cache_module, "default_override_dir", None)
+        except (ModuleNotFoundError, AttributeError) as e:
+            default_cache_dir = None
+            default_dump_dir = None
+            default_override_dir = None
+
         if dump:
-            self.cache_dir =
+            self.cache_dir = (
+                default_dump_dir()
+                if default_dump_dir is not None
+                else os.path.join(Path.home(), ".triton", "dump")
+            )
             self.cache_dir = os.path.join(self.cache_dir, self.key)
             self.lock_path = os.path.join(self.cache_dir, "lock")
             os.makedirs(self.cache_dir, exist_ok=True)
         elif override:
-            self.cache_dir =
+            self.cache_dir = (
+                default_override_dir()
+                if default_override_dir is not None
+                else os.path.join(Path.home(), ".triton", "override")
+            )
             self.cache_dir = os.path.join(self.cache_dir, self.key)
         else:
             # create cache directory if it doesn't exist
-            self.cache_dir = (
-
+            self.cache_dir = os.getenv("TRITON_CACHE_DIR", "").strip() or (
+                default_cache_dir()
+                if default_cache_dir is not None
+                else os.path.join(Path.home(), ".triton", "cache")
             )
             if self.cache_dir:
-
+                try:
+                    self.cache_dir = f"{self.cache_dir}_{get_tp_group().local_rank}"
+                except:
+                    self.cache_dir = f"{self.cache_dir}_{os.getpid()}"
                 self.cache_dir = os.path.join(self.cache_dir, self.key)
                 self.lock_path = os.path.join(self.cache_dir, "lock")
                 os.makedirs(self.cache_dir, exist_ok=True)
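
The rewritten constructor resolves the Triton cache directory defensively: an explicit TRITON_CACHE_DIR wins, then the default_* helpers if this Triton version still exports them, then a hard-coded ~/.triton path, with a per-rank (or per-PID) suffix appended. A standalone sketch of that resolution order; resolve_cache_dir is a hypothetical helper written only for illustration:

import os
from pathlib import Path

def resolve_cache_dir(default_cache_dir=None):
    # An explicit environment override takes precedence over any library default.
    explicit = os.getenv("TRITON_CACHE_DIR", "").strip()
    if explicit:
        return explicit
    # Fall back to the library-provided callable when it exists.
    if default_cache_dir is not None:
        return default_cache_dir()
    # Last resort: the conventional location under the home directory.
    return os.path.join(Path.home(), ".triton", "cache")

print(resolve_cache_dir())
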
@@ -997,36 +1124,48 @@ def point_to_point_pyobj(
     src: int = 0,
     dst: int = 1,
 ):
-    """Send data from src to dst in group."""
+    """Send data from src to dst in group using DeviceToDevice communication."""

     if rank == src:
         if len(data) == 0:
-            tensor_size = torch.tensor(
+            tensor_size = torch.tensor(
+                [0], dtype=torch.long, device=torch.cuda.current_device()
+            )
             dist.send(tensor_size, dst=dst, group=group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
             tensor_data = torch.ByteTensor(
                 np.frombuffer(serialized_data, dtype=np.uint8)
+            ).cuda(
+                device=torch.cuda.current_device()
+            )  # Move to GPU
+            tensor_size = torch.tensor(
+                [size], dtype=torch.long, device=torch.cuda.current_device()
             )
-            tensor_size = torch.tensor([size], dtype=torch.long)

             dist.send(tensor_size, dst=dst, group=group)
             dist.send(tensor_data, dst=dst, group=group)
         return data

     elif rank == dst:
-        tensor_size = torch.tensor(
+        tensor_size = torch.tensor(
+            [0], dtype=torch.long, device=torch.cuda.current_device()
+        )
         dist.recv(tensor_size, src=src, group=group)
         size = tensor_size.item()

         if size == 0:
             return []

-        tensor_data = torch.empty(
+        tensor_data = torch.empty(
+            size, dtype=torch.uint8, device=torch.cuda.current_device()
+        )
         dist.recv(tensor_data, src=src, group=group)

-        serialized_data = bytes(
+        serialized_data = bytes(
+            tensor_data.cpu().numpy()
+        )  # Move back to host for deserialization
         data = pickle.loads(serialized_data)
         return data

@@ -1428,6 +1567,15 @@ def is_habana_available() -> bool:

 @lru_cache(maxsize=8)
 def get_device(device_id: Optional[int] = None) -> str:
+    if is_cpu():
+        if cpu_has_amx_support():
+            logger.info("Intel AMX is detected, using CPU with Intel AMX support.")
+        else:
+            logger.warning(
+                "CPU device enabled, using torch native backend, low performance expected."
+            )
+        return "cpu"
+
     if hasattr(torch, "cuda") and torch.cuda.is_available():
         if device_id is None:
             return "cuda"
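
With this change the CPU check runs before any accelerator probing, so the engine flag alone selects the CPU path. A sketch, assuming an x86 host with the wheel installed (get_device also logs whether Intel AMX support was detected):

import os

os.environ["SGLANG_USE_CPU_ENGINE"] = "1"

from sglang.srt.utils import get_device

print(get_device())  # expected to print "cpu" on such a host
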
@@ -1456,15 +1604,6 @@ def get_device(device_id: Optional[int] = None) -> str:
             "Habana frameworks detected, but failed to import 'habana_frameworks.torch.hpu'."
         )

-    if is_cpu():
-        if cpu_has_amx_support():
-            logger.info("Intel AMX is detected, using CPU with Intel AMX support.")
-        else:
-            logger.warning(
-                "CPU device enabled, using torch native backend, low performance expected."
-            )
-        return "cpu"
-
     raise RuntimeError("No accelerator (CUDA, XPU, HPU) is available.")


@@ -1729,7 +1868,7 @@ class MultiprocessingSerializer:

         if output_str:
             # Convert bytes to base64-encoded string
-            output =
+            output = pybase64.b64encode(output).decode("utf-8")

         return output

@@ -1746,7 +1885,7 @@ class MultiprocessingSerializer:
         """
         if isinstance(data, str):
             # Decode base64 string to bytes
-            data =
+            data = pybase64.b64decode(data, validate=True)

         return ForkingPickler.loads(data)

@@ -1917,20 +2056,11 @@ def configure_ipv6(dist_init_addr):
     return port, host


-def
+def rank0_log(msg: str):
     from sglang.srt.distributed import get_tensor_model_parallel_rank

     if get_tensor_model_parallel_rank() == 0:
-
-
-
-rank0_log = rank0_print
-
-
-def get_cuda_version():
-    if torch.version.cuda:
-        return tuple(map(int, torch.version.cuda.split(".")))
-    return (0, 0)
+        logger.info(msg)


 def launch_dummy_health_check_server(host, port):
@@ -2092,14 +2222,14 @@ class DeepEPMode(Enum):
     def enable_low_latency(self):
         return self in [DeepEPMode.low_latency, DeepEPMode.auto]

-    def resolve(self,
+    def resolve(self, is_extend_in_batch: bool):
         if self != DeepEPMode.auto:
             return self

-        if
-            return DeepEPMode.low_latency
-        else:
+        if is_extend_in_batch:
             return DeepEPMode.normal
+        else:
+            return DeepEPMode.low_latency


 def is_non_idle_and_non_empty(forward_mode, hidden_states):
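
The auto mode now resolves from the batch state: extend (prefill-style) batches get the normal kernels, other batches get the low-latency path. A small sketch of the new signature, assuming the wheel is importable:

from sglang.srt.utils import DeepEPMode

# Non-auto modes resolve to themselves.
assert DeepEPMode.normal.resolve(is_extend_in_batch=True) == DeepEPMode.normal

# auto picks the kernel family based on whether the batch is extending.
assert DeepEPMode.auto.resolve(is_extend_in_batch=True) == DeepEPMode.normal
assert DeepEPMode.auto.resolve(is_extend_in_batch=False) == DeepEPMode.low_latency
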
@@ -2119,35 +2249,12 @@ def fast_topk(values, topk, dim):
     return torch.topk(values, topk, dim=dim)


-def
-    if not
-
-
-
-
-
-
-is_ampere_with_cuda_12_3 = lambda: _check(8)
-is_hopper_with_cuda_12_3 = lambda: _check(9)
-
-
-def is_blackwell():
-    if not is_cuda():
-        return False
-    return torch.cuda.get_device_capability()[0] == 10
-
-
-def get_free_port():
-    # try ipv4
-    try:
-        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-            s.bind(("", 0))
-            return s.getsockname()[1]
-    except OSError:
-        # try ipv6
-        with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
-            s.bind(("", 0))
-            return s.getsockname()[1]
+def bind_or_assign(target, source):
+    if target is not None:
+        target.copy_(source)
+        return target
+    else:
+        return source


 def get_local_ip_auto() -> str:
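
bind_or_assign either copies into an existing buffer or adopts the source tensor, a common pattern when a destination may or may not have been preallocated. A tiny sketch:

import torch

from sglang.srt.utils import bind_or_assign

src = torch.arange(4, dtype=torch.float32)

# With a preallocated target, the data is copied in place and the same buffer is returned.
buf = torch.zeros(4)
assert bind_or_assign(buf, src) is buf and torch.equal(buf, src)

# Without a target, the source tensor itself is returned.
assert bind_or_assign(None, src) is src
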
@@ -2344,45 +2451,6 @@ def require_mlp_sync(server_args):
     return server_args.enable_dp_attention or require_gathered_buffer(server_args)


-def merge_bias_tensor(
-    lhs: Optional[torch.Tensor],
-    rhs: Optional[torch.Tensor],
-    bs1: int,
-    bs2: int,
-    device: str,
-    default: float,
-):
-    """Merge two bias tensors for batch merging.
-
-    Args:
-        lhs: Left-hand side tensor
-        rhs: Right-hand side tensor
-        bs1: Batch size of left-hand side tensor
-        bs2: Batch size of right-hand side tensor
-        device: Device to place the merged tensor on
-        default: Default value for missing tensor elements
-
-    Returns:
-        Merged tensor or None if both inputs are None
-    """
-    if lhs is None and rhs is None:
-        return None
-
-    if lhs is not None and rhs is not None:
-        return torch.cat([lhs, rhs])
-    else:
-        if lhs is not None:
-            shape, dtype = lhs.shape[1:], lhs.dtype
-        else:
-            shape, dtype = rhs.shape[1:], rhs.dtype
-
-        if lhs is None:
-            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
-        if rhs is None:
-            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
-        return torch.cat([lhs, rhs])
-
-
 def find_local_repo_dir(repo_id: str, revision: Optional[str] = None) -> Optional[str]:
     import huggingface_hub as hf

@@ -2439,24 +2507,6 @@ def bind_or_assign(target, source):
     return source


-def support_triton(backend: str) -> bool:
-    return backend not in ["torch_native", "intel_amx"]
-
-
-try:
-    import sgl_kernel
-
-    is_intel_amx_backend_available = hasattr(
-        torch.ops.sgl_kernel, "convert_weight_packed"
-    )
-except:
-    is_intel_amx_backend_available = False
-
-
-def cpu_has_amx_support():
-    return torch._C._cpu._is_amx_tile_supported() and is_intel_amx_backend_available
-
-
 def prepack_weight_if_needed(weight):
     if weight.device != torch.device("cpu"):
         return weight
@@ -2577,3 +2627,133 @@ def configure_gc_logger():
     )

     gc.callbacks.append(gc_callback)
+
+
+# COPIED FROM DeepGEMM
+def align(x: int, y: int) -> int:
+    return ceil_div(x, y) * y
+
+
+# COPIED FROM DeepGEMM
+def ceil_div(x: int, y: int) -> int:
+    return (x + y - 1) // y
+
+
+def parse_lscpu_topology():
+    try:
+        # Get CPU topology: CPU,Core,Socket,Node
+        output = subprocess.check_output(
+            ["lscpu", "-p=CPU,Core,Socket,Node"], text=True
+        )
+    except Exception as e:
+        raise RuntimeError(f"Unexpected error running 'lscpu': {e}")
+
+    # Parse only data lines (skip comments)
+    cpu_info = []
+    for line in output.splitlines():
+        if not line.startswith("#"):
+            cpu, core, socket, node = map(int, line.strip().split(","))
+            cpu_info.append((cpu, core, socket, node))
+
+    # [(0,0,0,0),(1,1,0,0),...,(43,43,0,1),...,(256,0,0,0),...]
+    return cpu_info
+
+
+def get_physical_cpus_by_numa():
+    cpu_info = parse_lscpu_topology()
+
+    # Map NUMA node -> set of (core_id, socket) to avoid duplicates
+    # 0: {(0,0): 0, (1, 0): 1,...}
+    # ...
+    # 5: {(214,1): 214, (215,1): 215}
+    physical_by_node = defaultdict(dict)  # node -> core_id -> cpu_id
+
+    for cpu, core, socket, node in cpu_info:
+        key = (core, socket)
+        if key not in physical_by_node[node]:
+            physical_by_node[node][
+                key
+            ] = cpu  # pick first CPU seen for that physical core
+
+    # Retrieves CPUs that the current process is allowed to run on
+    cpus_allowed_list = psutil.Process().cpu_affinity()
+
+    # Convert to list of physical CPUs per node
+    # 0: [0,1,2,...,42]
+    # ...
+    # 2: [86,87,...,127]
+    # ...
+    # 5: [214,215,...,255]
+    node_to_cpus = {}
+    for node, core_to_cpu in physical_by_node.items():
+        cpus = sorted(core_to_cpu.values())
+        allowed_cpus = set(cpus).intersection(cpus_allowed_list)
+        node_to_cpus[node] = allowed_cpus
+
+    return node_to_cpus
+
+
+# Only physical cores are used. Logical cores are excluded.
+def get_cpu_ids_by_node():
+    node_to_cpus = get_physical_cpus_by_numa()
+    # Sort by NUMA node index
+    cpu_ids = [
+        ",".join(map(str, sorted(node_to_cpus[node]))) for node in sorted(node_to_cpus)
+    ]
+
+    # ['0,1,2,3', '4,5,6,7', '8,9,10,11', '12,13,14,15', '16,17,18,19', '20,21,22,23']
+    return cpu_ids
+
+
+def is_shm_available(dtype, world_size, local_size):
+    return (
+        cpu_has_amx_support()
+        and dtype in [torch.bfloat16, torch.float]
+        and world_size >= 1
+        and world_size == local_size
+    )
+
+
+def lru_cache_frozenset(maxsize=128):
+    def _to_hashable(o):
+        try:
+            hash(o)
+            return o
+        except TypeError:
+            # Not hashable; convert based on type
+            if isinstance(o, (dict)):
+                return frozenset(
+                    (_to_hashable(k), _to_hashable(v)) for k, v in o.items()
+                )
+            elif isinstance(o, set):
+                return frozenset(_to_hashable(v) for v in o)
+            elif isinstance(o, (list, tuple)) or (
+                isinstance(o, Sequence) and not isinstance(o, (str, bytes))
+            ):
+                return tuple(_to_hashable(v) for v in o)
+            else:
+                raise TypeError(f"Cannot make hashable: {type(o)}")
+
+    def decorator(func):
+        cache = OrderedDict()
+
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            h_args = tuple(_to_hashable(a) for a in args)
+            h_kwargs = frozenset(
+                (_to_hashable(k), _to_hashable(v)) for k, v in kwargs.items()
+            )
+            key = (h_args, h_kwargs)
+            if key in cache:
+                cache.move_to_end(key)
+                return cache[key]
+            result = func(*args, **kwargs)
+            cache[key] = result
+            if maxsize is not None and len(cache) > maxsize:
+                cache.popitem(last=False)
+            return result
+
+        wrapper.cache_clear = cache.clear  # For manual cache clearing
+        return wrapper
+
+    return decorator