sglang 0.5.2rc0__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/model_config.py +2 -1
- sglang/srt/disaggregation/mini_lb.py +2 -2
- sglang/srt/distributed/parallel_state.py +46 -41
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +5 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +3 -3
- sglang/srt/entrypoints/openai/serving_completions.py +3 -1
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
- sglang/srt/entrypoints/openai/serving_responses.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
- sglang/srt/layers/moe/ep_moe/layer.py +2 -7
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/utils.py +0 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +8 -0
- sglang/srt/layers/quantization/modelopt_quant.py +35 -2
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +30 -25
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/managers/cache_controller.py +42 -39
- sglang/srt/managers/detokenizer_manager.py +0 -34
- sglang/srt/managers/multi_tokenizer_mixin.py +48 -6
- sglang/srt/managers/schedule_policy.py +3 -2
- sglang/srt/managers/scheduler.py +7 -100
- sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +15 -10
- sglang/srt/mem_cache/hiradix_cache.py +16 -0
- sglang/srt/mem_cache/memory_pool_host.py +18 -11
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +35 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/metrics/collector.py +12 -4
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/forward_batch_info.py +16 -17
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +245 -36
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/gpt_oss.py +5 -4
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/longcat_flash.py +26 -15
- sglang/srt/models/longcat_flash_nextn.py +23 -15
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/qwen2_moe.py +4 -1
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/server_args.py +79 -2
- sglang/srt/speculative/eagle_worker.py +158 -112
- sglang/srt/utils.py +12 -10
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +2 -2
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +83 -76
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
sglang/lang/interpreter.py
CHANGED
@@ -740,7 +740,7 @@ class StreamExecutor:
         # Execute the stored lazy generation calls
         self.backend.role_end_generate(self)
 
-        from sglang.srt.reasoning_parser import ReasoningParser
+        from sglang.srt.parser.reasoning_parser import ReasoningParser
 
         reasoning_parser = ReasoningParser(expr.model_type)
         other = expr.expr
sglang/srt/configs/internvl.py
CHANGED
@@ -6,11 +6,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import sentencepiece as spm
 from transformers import (
     TOKENIZER_MAPPING,
+    GptOssConfig,
     LlamaConfig,
     PretrainedConfig,
     PreTrainedTokenizer,
     Qwen2Config,
     Qwen3Config,
+    Qwen3MoeConfig,
 )
 
 from sglang.utils import logger
@@ -316,7 +318,11 @@ class InternVLChatConfig(PretrainedConfig):
         elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
             self.llm_config = Qwen2Config(**llm_config)
         elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
+            self.llm_config = Qwen3MoeConfig(**llm_config)
+        elif llm_config.get("architectures")[0] == "Qwen3ForCausalLM":
             self.llm_config = Qwen3Config(**llm_config)
+        elif llm_config.get("architectures")[0] == "GptOssForCausalLM":
+            self.llm_config = GptOssConfig(**llm_config)
         else:
             raise ValueError(
                 "Unsupported architecture: {}".format(
sglang/srt/configs/model_config.py
CHANGED
@@ -405,9 +405,10 @@ class ModelConfig:
         # compressed-tensors uses a "compression_config" key
         quant_cfg = getattr(self.hf_config, "compression_config", None)
         if quant_cfg is None:
-            # check if is modelopt model --
+            # check if is modelopt or mixed-precision model -- Both of them don't have corresponding field
             # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
             # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
+            # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
             is_local = os.path.exists(self.model_path)
             modelopt_quant_config = {"quant_method": "modelopt"}
             if not is_local:
sglang/srt/disaggregation/mini_lb.py
CHANGED
@@ -187,7 +187,7 @@ async def health_check():
 
 
 @app.get("/health_generate")
-async def
+async def health_generate():
     prefill_servers, decode_servers = (
         load_balancer.prefill_servers,
         load_balancer.decode_servers,
@@ -196,7 +196,7 @@ async def health_check():
         # Create the tasks
         tasks = []
         for server in chain(prefill_servers, decode_servers):
-            tasks.append(session.
+            tasks.append(session.get(f"{server}/health_generate"))
         for i, response in enumerate(asyncio.as_completed(tasks)):
             await response
     return Response(status_code=200)
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -43,6 +43,7 @@ from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
     get_int_env_var,
+    is_cpu,
     is_cuda_alike,
     is_hip,
     is_npu,
@@ -51,6 +52,7 @@ from sglang.srt.utils import (
 )
 
 _is_npu = is_npu()
+_is_cpu = is_cpu()
 
 IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")
 
@@ -877,17 +879,16 @@ class GroupCoordinator:
         size_tensor = torch.tensor(
             [object_tensor.numel()],
             dtype=torch.long,
-            device=
+            device="cpu",
         )
-
         # Send object size
-        torch.distributed.send(
-            size_tensor, dst=self.ranks[dst], group=self.device_group
-        )
+        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
 
         # Send object
         torch.distributed.send(
-            object_tensor,
+            object_tensor,
+            dst=self.ranks[dst],
+            group=self.device_group,
         )
 
         return None
@@ -902,13 +903,11 @@ class GroupCoordinator:
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."
 
-        size_tensor = torch.empty(
-            1, dtype=torch.long, device=torch.cuda.current_device()
-        )
+        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
 
         # Receive object size
         rank_size = torch.distributed.recv(
-            size_tensor, src=self.ranks[src], group=self.
+            size_tensor, src=self.ranks[src], group=self.cpu_group
         )
 
         # Tensor to receive serialized objects into.
@@ -926,7 +925,7 @@
             rank_object == rank_size
         ), "Received object sender rank does not match the size sender rank."
 
-        obj = pickle.loads(object_tensor.cpu().numpy()
+        obj = pickle.loads(object_tensor.cpu().numpy())
 
         return obj
 
@@ -1459,43 +1458,49 @@ def initialize_model_parallel(
         _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False
 
     moe_ep_size = expert_model_parallel_size
-
     moe_tp_size = tensor_model_parallel_size // moe_ep_size
+
     global _MOE_EP
     assert _MOE_EP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_tp_size):
-            st = i * tensor_model_parallel_size + j
-            en = (i + 1) * tensor_model_parallel_size + j
-            ranks = list(range(st, en, moe_tp_size))
-            group_ranks.append(ranks)
 
-
-
-
-
-
-
-
+    if moe_ep_size == tensor_model_parallel_size:
+        _MOE_EP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_tp_size):
+                st = i * tensor_model_parallel_size + j
+                en = (i + 1) * tensor_model_parallel_size + j
+                ranks = list(range(st, en, moe_tp_size))
+                group_ranks.append(ranks)
+        _MOE_EP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_ep",
+        )
 
     global _MOE_TP
    assert _MOE_TP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_ep_size):
-            st = i * tensor_model_parallel_size + j * moe_tp_size
-            en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
-            ranks = list(range(st, en))
-            group_ranks.append(ranks)
 
-
-
-
-
-
-
-
+    if moe_tp_size == tensor_model_parallel_size:
+        _MOE_TP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_ep_size):
+                st = i * tensor_model_parallel_size + j * moe_tp_size
+                en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
+                ranks = list(range(st, en))
+                group_ranks.append(ranks)
+        _MOE_TP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_tp",
+        )
 
     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
@@ -1643,7 +1648,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
 
         ray.shutdown()
     gc.collect()
-    if not
+    if not _is_cpu:
         if hasattr(torch, "cuda") and torch.cuda.is_available():
             torch.cuda.empty_cache()
             if hasattr(torch._C, "_host_emptyCache"):
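For illustration, the MoE rank grouping built by the new `initialize_model_parallel` branches above can be computed standalone. The sketch below is not part of the package; the world size and parallel sizes are assumed example values, chosen only to show which ranks end up in each expert-parallel (EP) and MoE tensor-parallel (TP) group.

```python
# Standalone illustration of the MoE EP/TP rank grouping from the diff above.
# Assumed example values: world_size=16, tensor_model_parallel_size=8,
# expert_model_parallel_size (moe_ep_size)=4, so moe_tp_size=2.
world_size = 16
tensor_model_parallel_size = 8
moe_ep_size = 4
moe_tp_size = tensor_model_parallel_size // moe_ep_size          # 2
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size  # 2

# EP groups: stride by moe_tp_size inside each TP group.
ep_group_ranks = []
for i in range(num_tensor_model_parallel_groups):
    for j in range(moe_tp_size):
        st = i * tensor_model_parallel_size + j
        en = (i + 1) * tensor_model_parallel_size + j
        ep_group_ranks.append(list(range(st, en, moe_tp_size)))

# MoE TP groups: contiguous blocks of moe_tp_size ranks.
tp_group_ranks = []
for i in range(num_tensor_model_parallel_groups):
    for j in range(moe_ep_size):
        st = i * tensor_model_parallel_size + j * moe_tp_size
        en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
        tp_group_ranks.append(list(range(st, en)))

print(ep_group_ranks)  # [[0, 2, 4, 6], [1, 3, 5, 7], [8, 10, 12, 14], [9, 11, 13, 15]]
print(tp_group_ranks)  # [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13], [14, 15]]
```

When `moe_ep_size` (or `moe_tp_size`) equals `tensor_model_parallel_size`, the new code skips this construction entirely and reuses the existing `_TP` group instead of creating a new process group.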
sglang/srt/entrypoints/engine.py
CHANGED
@@ -681,7 +681,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.
+            "0.3.8",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -29,6 +29,8 @@ import time
 from http import HTTPStatus
 from typing import Any, AsyncIterator, Callable, Dict, List, Optional
 
+import setproctitle
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
@@ -102,7 +104,7 @@ from sglang.srt.managers.multi_tokenizer_mixin import (
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
@@ -1166,6 +1168,7 @@ def launch_server(
     2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
     if server_args.tokenizer_worker_num > 1:
+        setproctitle.setproctitle(f"sglang::http_server/multi_tokenizer_router")
         port_args = PortArgs.init_new(server_args)
         port_args.tokenizer_worker_ipc_name = (
             f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
@@ -1174,6 +1177,7 @@ def launch_server(
             server_args=server_args, port_args=port_args
         )
     else:
+        setproctitle.setproctitle(f"sglang::http_server/tokenizer_manager")
         tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
             server_args=server_args,
         )
sglang/srt/entrypoints/openai/protocol.py
CHANGED
@@ -542,9 +542,9 @@ class ChatCompletionRequest(BaseModel):
     rid: Optional[Union[List[str], str]] = None
 
     # For PD disaggregation
-    bootstrap_host: Optional[str] = None
-    bootstrap_port: Optional[int] = None
-    bootstrap_room: Optional[int] = None
+    bootstrap_host: Optional[Union[List[str], str]] = None
+    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
+    bootstrap_room: Optional[Union[List[int], int]] = None
 
 
 class ChatMessage(BaseModel):
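A minimal sketch of what the widened `bootstrap_*` fields permit: in a batched request each field may now carry one entry per sub-request instead of a single scalar. The field names come from the `ChatCompletionRequest` diff above; the host, port, and room values below are made-up examples.

```python
# Illustrative payloads only; field names are from the ChatCompletionRequest diff above,
# the host/port/room values are invented for the example.
batched_pd_fields = {
    # Previously Optional[str] / Optional[int]; now each may also be a list,
    # one entry per request in a batch.
    "bootstrap_host": ["10.0.0.1", "10.0.0.2"],
    "bootstrap_port": [8998, None],      # Optional[int] entries are allowed
    "bootstrap_room": [12345, 67890],
}

# The single-value (non-batched) form remains valid:
single_pd_fields = {
    "bootstrap_host": "10.0.0.1",
    "bootstrap_port": 8998,
    "bootstrap_room": 12345,
}
```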
sglang/srt/entrypoints/openai/serving_chat.py
CHANGED
@@ -8,7 +8,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 from fastapi import Request
 from fastapi.responses import ORJSONResponse, StreamingResponse
 
-from sglang.srt.conversation import generate_chat_conv
 from sglang.srt.entrypoints.openai.protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -33,11 +32,12 @@ from sglang.srt.entrypoints.openai.utils import (
     to_openai_style_logprobs,
 )
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
-from sglang.srt.jinja_template_utils import process_content_for_template_format
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.
+from sglang.srt.parser.conversation import generate_chat_conv
+from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.utils import convert_json_schema_to_str
 
 logger = logging.getLogger(__name__)
sglang/srt/entrypoints/openai/serving_completions.py
CHANGED
@@ -5,7 +5,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 from fastapi import Request
 from fastapi.responses import ORJSONResponse, StreamingResponse
 
-from sglang.srt.code_completion_parser import generate_completion_prompt_from_request
 from sglang.srt.entrypoints.openai.protocol import (
     CompletionRequest,
     CompletionResponse,
@@ -23,6 +22,9 @@ from sglang.srt.entrypoints.openai.utils import (
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.parser.code_completion_parser import (
+    generate_completion_prompt_from_request,
+)
 from sglang.utils import convert_json_schema_to_str
 
 logger = logging.getLogger(__name__)
sglang/srt/entrypoints/openai/serving_embedding.py
CHANGED
@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Union
 from fastapi import Request
 from fastapi.responses import ORJSONResponse
 
-from sglang.srt.conversation import generate_embedding_convs
 from sglang.srt.entrypoints.openai.protocol import (
     EmbeddingObject,
     EmbeddingRequest,
@@ -16,6 +15,7 @@ from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
 from sglang.srt.managers.io_struct import EmbeddingReqInput
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.parser.conversation import generate_embedding_convs
 
 
 class OpenAIServingEmbedding(OpenAIServingBase):
sglang/srt/entrypoints/openai/serving_responses.py
CHANGED
@@ -56,7 +56,7 @@ from sglang.srt.entrypoints.openai.tool_server import MCPToolServer, ToolServer
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.utils import random_uuid
 
 logger = logging.getLogger(__name__)
sglang/srt/layers/attention/aiter_backend.py
CHANGED
@@ -18,7 +18,10 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
-from sglang.srt.layers.dp_attention import
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
@@ -154,6 +157,8 @@ class AiterAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
 
+        self.enable_dp_attention = is_dp_attention_enabled()
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
@@ -302,19 +307,19 @@ class AiterAttnBackend(AttentionBackend):
         if self.use_mla:
             self.mla_indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
-                forward_batch.
-
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
                 forward_batch.extend_seq_lens,
-                max(
-                    forward_batch.
+                forward_batch.extend_seq_lens.max().item(),
+                forward_batch.seq_lens.max().item(),
                 spec_info=None,
             )
-
-
-
+
+            kv_indices = self.mla_indices_updater_prefill.kv_indices
+
             self.forward_metadata = ForwardMetadata(
                 self.mla_indices_updater_prefill.kv_indptr,
-
+                kv_indices,
                 self.mla_indices_updater_prefill.qo_indptr,
                 self.kv_last_page_len[:bs],
                 self.mla_indices_updater_prefill.max_q_len,
@@ -614,66 +619,86 @@ class AiterAttnBackend(AttentionBackend):
         assert len(k.shape) == 3
         assert len(v.shape) == 3
 
-        if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if forward_batch.forward_mode.is_extend():
+            if kv_indices.shape[0] == 0:
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    qo_indptr,
+                    max_q_len,
+                    max_q_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+            elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
+                K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
+                kvc, k_pe = torch.split(
+                    K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
+                )
+                kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                kvprefix = kvprefix.view(
+                    -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
+                )
+                k_prefix, v_prefix = torch.split(
+                    kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
+                )
+                k_prefix = torch.cat(
+                    [
+                        k_prefix,
+                        torch.broadcast_to(
+                            k_pe,
+                            (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
+                        ),
+                    ],
+                    dim=-1,
+                )
+                assert (
+                    forward_batch.extend_prefix_lens.shape
+                    == forward_batch.extend_seq_lens.shape
+                )
+
+                k = k_prefix
+                v = v_prefix
+
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    kv_indptr,
+                    max_q_len,
+                    max_kv_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+
+            else:
+                if layer.qk_head_dim != layer.v_head_dim:
+                    o = q.new_empty(
+                        (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                    )
+                else:
+                    o = torch.empty_like(q)
+
+                mla_prefill_fwd(
+                    q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                    qo_indptr,
+                    kv_indptr,
+                    kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
+                return o
         elif forward_batch.forward_mode.is_target_verify():
             o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
             mla_decode_fwd(
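As background for the `flash_attn_varlen_func` calls introduced above, `qo_indptr` and `kv_indptr` are the usual cumulative-sum index pointers over per-sequence lengths of a packed batch. The sketch below is a generic illustration of that layout, not sglang code; the helper name `build_indptr` and the lengths are made up.

```python
import torch

# Generic illustration (not sglang code): build cumulative index pointers over
# per-sequence lengths, the layout consumed by variable-length attention kernels.
def build_indptr(seq_lens: torch.Tensor) -> torch.Tensor:
    # seq_lens: int32 tensor of per-request lengths, e.g. [3, 5, 2]
    indptr = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32)
    indptr[1:] = torch.cumsum(seq_lens, dim=0)
    return indptr  # e.g. [0, 3, 8, 10]

extend_seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)  # new tokens per request
seq_lens = torch.tensor([7, 9, 4], dtype=torch.int32)         # total KV length per request

qo_indptr = build_indptr(extend_seq_lens)  # offsets into the packed query tokens
kv_indptr = build_indptr(seq_lens)         # offsets into the packed KV tokens
max_q_len = int(extend_seq_lens.max())
max_kv_len = int(seq_lens.max())
```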
sglang/srt/layers/communicator.py
CHANGED
@@ -42,10 +42,24 @@ from sglang.srt.layers.moe import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_cuda,
+    is_flashinfer_available,
+    is_gfx95_supported,
+    is_hip,
+    is_sm90_supported,
+    is_sm100_supported,
+)
 
 _is_flashinfer_available = is_flashinfer_available()
+_is_sm90_supported = is_cuda() and is_sm90_supported()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
+_is_gfx95_supported = is_gfx95_supported()
+
+if _use_aiter and _is_gfx95_supported:
+    from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant
 
 FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
 
@@ -201,6 +215,7 @@ class LayerCommunicator:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
+        qaunt_format: str = "",
     ):
         if hidden_states.shape[0] == 0:
             residual = hidden_states
@@ -218,11 +233,34 @@ class LayerCommunicator:
         else:
             if residual is None:
                 residual = hidden_states
-
+
+                if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
+                    hidden_states = fused_rms_mxfp4_quant(
+                        hidden_states,
+                        self.input_layernorm.weight,
+                        self.input_layernorm.variance_epsilon,
+                        None,
+                        None,
+                        None,
+                        None,
+                    )
+                else:
+                    hidden_states = self.input_layernorm(hidden_states)
             else:
-
-                hidden_states, residual
-
+                if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
+                    hidden_states, residual = fused_rms_mxfp4_quant(
+                        hidden_states,
+                        self.input_layernorm.weight,
+                        self.input_layernorm.variance_epsilon,
+                        None,
+                        None,
+                        None,
+                        residual,
+                    )
+                else:
+                    hidden_states, residual = self.input_layernorm(
+                        hidden_states, residual
+                    )
 
         hidden_states = self._communicate_simple_fn(
             hidden_states=hidden_states,
@@ -484,11 +522,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
         # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
         # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
         if (
-            _is_sm100_supported
+            (_is_sm100_supported or _is_sm90_supported)
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
             and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-            and hidden_states.shape[0] <=
+            and hidden_states.shape[0] <= 4096
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                 hidden_states, residual
sglang/srt/layers/moe/cutlass_w4a8_moe.py
CHANGED
@@ -91,18 +91,10 @@ def cutlass_w4a8_moe(
     assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
     assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
     assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
-    assert (
-        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
-        and w1_scale.shape[2] == w1_q.shape[1] * 4
-    ), "W1 scale shape mismatch"
-    assert (
-        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
-        and w2_scale.shape[2] == w2_q.shape[1] * 4
-    ), "W2 scale shape mismatch"
 
     assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
     assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
-    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
     assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
     num_experts = w1_q.size(0)
     m = a.size(0)