sglang 0.5.2rc1__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/disaggregation/mini_lb.py +2 -2
- sglang/srt/distributed/parallel_state.py +43 -40
- sglang/srt/entrypoints/http_server.py +5 -1
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +3 -3
- sglang/srt/entrypoints/openai/serving_completions.py +3 -1
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
- sglang/srt/entrypoints/openai/serving_responses.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/utils.py +0 -1
- sglang/srt/layers/quantization/modelopt_quant.py +35 -2
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/managers/cache_controller.py +42 -39
- sglang/srt/managers/multi_tokenizer_mixin.py +4 -0
- sglang/srt/managers/schedule_policy.py +3 -2
- sglang/srt/managers/scheduler.py +4 -100
- sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +15 -10
- sglang/srt/mem_cache/hiradix_cache.py +5 -5
- sglang/srt/mem_cache/memory_pool_host.py +16 -11
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +10 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/metrics/collector.py +12 -4
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/forward_batch_info.py +16 -17
- sglang/srt/model_executor/model_runner.py +1 -1
- sglang/srt/models/deepseek_v2.py +240 -36
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/qwen2_moe.py +4 -1
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/server_args.py +79 -2
- sglang/srt/speculative/eagle_worker.py +158 -112
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +65 -61
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
sglang/lang/interpreter.py
CHANGED
@@ -740,7 +740,7 @@ class StreamExecutor:
         # Execute the stored lazy generation calls
         self.backend.role_end_generate(self)

-        from sglang.srt.reasoning_parser import ReasoningParser
+        from sglang.srt.parser.reasoning_parser import ReasoningParser

         reasoning_parser = ReasoningParser(expr.model_type)
         other = expr.expr
sglang/srt/configs/internvl.py
CHANGED
@@ -6,11 +6,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import sentencepiece as spm
 from transformers import (
     TOKENIZER_MAPPING,
+    GptOssConfig,
     LlamaConfig,
     PretrainedConfig,
     PreTrainedTokenizer,
     Qwen2Config,
     Qwen3Config,
+    Qwen3MoeConfig,
 )

 from sglang.utils import logger
@@ -316,7 +318,11 @@ class InternVLChatConfig(PretrainedConfig):
         elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
             self.llm_config = Qwen2Config(**llm_config)
         elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
+            self.llm_config = Qwen3MoeConfig(**llm_config)
+        elif llm_config.get("architectures")[0] == "Qwen3ForCausalLM":
             self.llm_config = Qwen3Config(**llm_config)
+        elif llm_config.get("architectures")[0] == "GptOssForCausalLM":
+            self.llm_config = GptOssConfig(**llm_config)
         else:
             raise ValueError(
                 "Unsupported architecture: {}".format(
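A note on the InternVLChatConfig hunk above: the llm_config dispatch keys on the first entry of the "architectures" list, so every newly supported backbone adds another elif branch. Below is a table-driven sketch of the same idea, assuming a transformers release that ships the configs the new imports require; build_llm_config and _ARCH_TO_CONFIG are illustrative names, not sglang code.

    # Sketch only: a dict-based version of the architecture -> config dispatch.
    from transformers import GptOssConfig, Qwen2Config, Qwen3Config, Qwen3MoeConfig

    _ARCH_TO_CONFIG = {
        "Qwen2ForCausalLM": Qwen2Config,
        "Qwen3ForCausalLM": Qwen3Config,
        "Qwen3MoeForCausalLM": Qwen3MoeConfig,
        "GptOssForCausalLM": GptOssConfig,
    }


    def build_llm_config(llm_config: dict):
        arch = llm_config.get("architectures")[0]
        try:
            config_cls = _ARCH_TO_CONFIG[arch]
        except KeyError:
            raise ValueError("Unsupported architecture: {}".format(arch))
        return config_cls(**llm_config)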
sglang/srt/disaggregation/mini_lb.py
CHANGED
@@ -187,7 +187,7 @@ async def health_check()


 @app.get("/health_generate")
-async def
+async def health_generate():
     prefill_servers, decode_servers = (
         load_balancer.prefill_servers,
         load_balancer.decode_servers,
@@ -196,7 +196,7 @@ async def health_check()
     # Create the tasks
     tasks = []
     for server in chain(prefill_servers, decode_servers):
-        tasks.append(session.
+        tasks.append(session.get(f"{server}/health_generate"))
     for i, response in enumerate(asyncio.as_completed(tasks)):
         await response
     return Response(status_code=200)
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -879,17 +879,16 @@ class GroupCoordinator:
         size_tensor = torch.tensor(
             [object_tensor.numel()],
             dtype=torch.long,
-            device=
+            device="cpu",
         )
-
         # Send object size
-        torch.distributed.send(
-            size_tensor, dst=self.ranks[dst], group=self.device_group
-        )
+        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)

         # Send object
         torch.distributed.send(
-            object_tensor,
+            object_tensor,
+            dst=self.ranks[dst],
+            group=self.device_group,
         )

         return None
@@ -904,13 +903,11 @@ class GroupCoordinator:
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."

-        size_tensor = torch.empty(
-            1, dtype=torch.long, device=torch.cuda.current_device()
-        )
+        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

         # Receive object size
         rank_size = torch.distributed.recv(
-            size_tensor, src=self.ranks[src], group=self.
+            size_tensor, src=self.ranks[src], group=self.cpu_group
         )

         # Tensor to receive serialized objects into.
@@ -928,7 +925,7 @@ class GroupCoordinator:
             rank_object == rank_size
         ), "Received object sender rank does not match the size sender rank."

-        obj = pickle.loads(object_tensor.cpu().numpy()
+        obj = pickle.loads(object_tensor.cpu().numpy())

         return obj

@@ -1461,43 +1458,49 @@ def initialize_model_parallel(
         _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

     moe_ep_size = expert_model_parallel_size
-
     moe_tp_size = tensor_model_parallel_size // moe_ep_size
+
     global _MOE_EP
     assert _MOE_EP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_tp_size):
-            st = i * tensor_model_parallel_size + j
-            en = (i + 1) * tensor_model_parallel_size + j
-            ranks = list(range(st, en, moe_tp_size))
-            group_ranks.append(ranks)

-    [7 removed lines; their content is not captured in this diff view]
+    if moe_ep_size == tensor_model_parallel_size:
+        _MOE_EP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_tp_size):
+                st = i * tensor_model_parallel_size + j
+                en = (i + 1) * tensor_model_parallel_size + j
+                ranks = list(range(st, en, moe_tp_size))
+                group_ranks.append(ranks)
+        _MOE_EP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_ep",
+        )

     global _MOE_TP
     assert _MOE_TP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_ep_size):
-            st = i * tensor_model_parallel_size + j * moe_tp_size
-            en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
-            ranks = list(range(st, en))
-            group_ranks.append(ranks)

-    [7 removed lines; their content is not captured in this diff view]
+    if moe_tp_size == tensor_model_parallel_size:
+        _MOE_TP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_ep_size):
+                st = i * tensor_model_parallel_size + j * moe_tp_size
+                en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
+                ranks = list(range(st, en))
+                group_ranks.append(ranks)
+        _MOE_TP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_tp",
+        )

     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
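Why the send_object/recv_object hunks above move the size tensor to device="cpu" and cpu_group: the one-element length exchange no longer needs a GPU allocation or a NCCL call, while the serialized payload itself still travels over the device group. A minimal, gloo-only sketch of that length-prefixed pattern follows; send_pickled/recv_pickled are illustrative names, not sglang APIs, and torch.distributed is assumed to be initialized with a gloo group at hand.

    import pickle

    import torch
    import torch.distributed as dist


    def send_pickled(obj, dst, cpu_group):
        # dst is a global rank; everything stays on the CPU (gloo) group here.
        payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8)
        size = torch.tensor([payload.numel()], dtype=torch.long, device="cpu")
        dist.send(size, dst=dst, group=cpu_group)     # length first
        dist.send(payload, dst=dst, group=cpu_group)  # then the pickled bytes


    def recv_pickled(src, cpu_group):
        size = torch.empty(1, dtype=torch.long, device="cpu")
        dist.recv(size, src=src, group=cpu_group)
        payload = torch.empty(size.item(), dtype=torch.uint8)
        dist.recv(payload, src=src, group=cpu_group)
        return pickle.loads(payload.numpy().tobytes())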
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -29,6 +29,8 @@ import time
 from http import HTTPStatus
 from typing import Any, AsyncIterator, Callable, Dict, List, Optional

+import setproctitle
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

@@ -102,7 +104,7 @@ from sglang.srt.managers.multi_tokenizer_mixin import (
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
-from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     add_api_key_middleware,
@@ -1166,6 +1168,7 @@ def launch_server(
     2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
     if server_args.tokenizer_worker_num > 1:
+        setproctitle.setproctitle(f"sglang::http_server/multi_tokenizer_router")
         port_args = PortArgs.init_new(server_args)
         port_args.tokenizer_worker_ipc_name = (
             f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
@@ -1174,6 +1177,7 @@ def launch_server(
             server_args=server_args, port_args=port_args
         )
     else:
+        setproctitle.setproctitle(f"sglang::http_server/tokenizer_manager")
         tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
             server_args=server_args,
         )
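The two setproctitle.setproctitle calls added to launch_server only relabel the launcher process so it is easy to spot in ps/top output; they do not change server behavior. A quick standalone check, assuming the setproctitle package is installed:

    import setproctitle

    # Rename the current process the way launch_server now does, then look for
    # it from another shell with:  ps -ef | grep "sglang::http_server"
    setproctitle.setproctitle("sglang::http_server/tokenizer_manager")
    print(setproctitle.getproctitle())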
sglang/srt/entrypoints/openai/protocol.py
CHANGED
@@ -542,9 +542,9 @@ class ChatCompletionRequest(BaseModel):
     rid: Optional[Union[List[str], str]] = None

     # For PD disaggregation
-    bootstrap_host: Optional[str] = None
-    bootstrap_port: Optional[int] = None
-    bootstrap_room: Optional[int] = None
+    bootstrap_host: Optional[Union[List[str], str]] = None
+    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
+    bootstrap_room: Optional[Union[List[int], int]] = None


 class ChatMessage(BaseModel):
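The widened annotations above let the PD-disaggregation bootstrap fields accept either a single value or, presumably for batched requests, one value per prompt. A small standalone pydantic sketch that mirrors only those three fields (PDBootstrapFields is an illustrative name, not the full ChatCompletionRequest):

    from typing import List, Optional, Union

    from pydantic import BaseModel


    class PDBootstrapFields(BaseModel):
        # Same annotations as in the hunk; the rest of ChatCompletionRequest is omitted.
        bootstrap_host: Optional[Union[List[str], str]] = None
        bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
        bootstrap_room: Optional[Union[List[int], int]] = None


    # Scalar values (the pre-0.5.2rc2 shape) still validate:
    PDBootstrapFields(bootstrap_host="10.0.0.1", bootstrap_port=8998, bootstrap_room=7)

    # Per-prompt lists validate as well:
    PDBootstrapFields(
        bootstrap_host=["10.0.0.1", "10.0.0.2"],
        bootstrap_port=[8998, None],
        bootstrap_room=[7, 8],
    )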
sglang/srt/entrypoints/openai/serving_chat.py
CHANGED
@@ -8,7 +8,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 from fastapi import Request
 from fastapi.responses import ORJSONResponse, StreamingResponse

-from sglang.srt.conversation import generate_chat_conv
 from sglang.srt.entrypoints.openai.protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -33,11 +32,12 @@ from sglang.srt.entrypoints.openai.utils import (
     to_openai_style_logprobs,
 )
 from sglang.srt.function_call.function_call_parser import FunctionCallParser
-from sglang.srt.jinja_template_utils import process_content_for_template_format
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.
+from sglang.srt.parser.conversation import generate_chat_conv
+from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.utils import convert_json_schema_to_str

 logger = logging.getLogger(__name__)
sglang/srt/entrypoints/openai/serving_completions.py
CHANGED
@@ -5,7 +5,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 from fastapi import Request
 from fastapi.responses import ORJSONResponse, StreamingResponse

-from sglang.srt.code_completion_parser import generate_completion_prompt_from_request
 from sglang.srt.entrypoints.openai.protocol import (
     CompletionRequest,
     CompletionResponse,
@@ -23,6 +22,9 @@ from sglang.srt.entrypoints.openai.utils import (
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.parser.code_completion_parser import (
+    generate_completion_prompt_from_request,
+)
 from sglang.utils import convert_json_schema_to_str

 logger = logging.getLogger(__name__)
sglang/srt/entrypoints/openai/serving_embedding.py
CHANGED
@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Union
 from fastapi import Request
 from fastapi.responses import ORJSONResponse

-from sglang.srt.conversation import generate_embedding_convs
 from sglang.srt.entrypoints.openai.protocol import (
     EmbeddingObject,
     EmbeddingRequest,
@@ -16,6 +15,7 @@ from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
 from sglang.srt.managers.io_struct import EmbeddingReqInput
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.parser.conversation import generate_embedding_convs


 class OpenAIServingEmbedding(OpenAIServingBase):
sglang/srt/entrypoints/openai/serving_responses.py
CHANGED
@@ -56,7 +56,7 @@ from sglang.srt.entrypoints.openai.tool_server import MCPToolServer, ToolServer
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.template_manager import TemplateManager
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.utils import random_uuid

 logger = logging.getLogger(__name__)
sglang/srt/layers/attention/aiter_backend.py
CHANGED
@@ -18,7 +18,10 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
-from sglang.srt.layers.dp_attention import
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

 if TYPE_CHECKING:
@@ -154,6 +157,8 @@ class AiterAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )

+        self.enable_dp_attention = is_dp_attention_enabled()
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""

@@ -302,19 +307,19 @@ class AiterAttnBackend(AttentionBackend):
         if self.use_mla:
             self.mla_indices_updater_prefill.update(
                 forward_batch.req_pool_indices,
-                forward_batch.
-
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
                 forward_batch.extend_seq_lens,
-                max(
-                    forward_batch.
+                forward_batch.extend_seq_lens.max().item(),
+                forward_batch.seq_lens.max().item(),
                 spec_info=None,
             )
-
-
-
+
+            kv_indices = self.mla_indices_updater_prefill.kv_indices
+
             self.forward_metadata = ForwardMetadata(
                 self.mla_indices_updater_prefill.kv_indptr,
-
+                kv_indices,
                 self.mla_indices_updater_prefill.qo_indptr,
                 self.kv_last_page_len[:bs],
                 self.mla_indices_updater_prefill.max_q_len,
@@ -614,66 +619,86 @@ class AiterAttnBackend(AttentionBackend):
         assert len(k.shape) == 3
         assert len(v.shape) == 3

-        if
-        [18 removed lines; their content is not captured in this diff view]
+        if forward_batch.forward_mode.is_extend():
+            if kv_indices.shape[0] == 0:
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    qo_indptr,
+                    max_q_len,
+                    max_q_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+            elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
+                K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
+                kvc, k_pe = torch.split(
+                    K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
+                )
+                kvprefix = layer.kv_b_proj(kvc.contiguous())[0]

-        [40 removed lines; their content is not captured in this diff view]
+                kvprefix = kvprefix.view(
+                    -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
+                )
+                k_prefix, v_prefix = torch.split(
+                    kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
+                )
+                k_prefix = torch.cat(
+                    [
+                        k_prefix,
+                        torch.broadcast_to(
+                            k_pe,
+                            (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
+                        ),
+                    ],
+                    dim=-1,
+                )
+                assert (
+                    forward_batch.extend_prefix_lens.shape
+                    == forward_batch.extend_seq_lens.shape
+                )
+
+                k = k_prefix
+                v = v_prefix
+
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    kv_indptr,
+                    max_q_len,
+                    max_kv_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+
+            else:
+                if layer.qk_head_dim != layer.v_head_dim:
+                    o = q.new_empty(
+                        (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                    )
+                else:
+                    o = torch.empty_like(q)
+
+                mla_prefill_fwd(
+                    q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                    qo_indptr,
+                    kv_indptr,
+                    kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
+                return o
         elif forward_batch.forward_mode.is_target_verify():
             o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
             mla_decode_fwd(
sglang/srt/layers/communicator.py
CHANGED
@@ -42,10 +42,24 @@ from sglang.srt.layers.moe import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_cuda,
+    is_flashinfer_available,
+    is_gfx95_supported,
+    is_hip,
+    is_sm90_supported,
+    is_sm100_supported,
+)

 _is_flashinfer_available = is_flashinfer_available()
+_is_sm90_supported = is_cuda() and is_sm90_supported()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
+_is_gfx95_supported = is_gfx95_supported()
+
+if _use_aiter and _is_gfx95_supported:
+    from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant

 FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048

@@ -201,6 +215,7 @@ class LayerCommunicator:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
+        qaunt_format: str = "",
     ):
         if hidden_states.shape[0] == 0:
             residual = hidden_states
@@ -218,11 +233,34 @@ class LayerCommunicator:
         else:
             if residual is None:
                 residual = hidden_states
-
+
+                if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
+                    hidden_states = fused_rms_mxfp4_quant(
+                        hidden_states,
+                        self.input_layernorm.weight,
+                        self.input_layernorm.variance_epsilon,
+                        None,
+                        None,
+                        None,
+                        None,
+                    )
+                else:
+                    hidden_states = self.input_layernorm(hidden_states)
             else:
-
-                hidden_states, residual
-
+                if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
+                    hidden_states, residual = fused_rms_mxfp4_quant(
+                        hidden_states,
+                        self.input_layernorm.weight,
+                        self.input_layernorm.variance_epsilon,
+                        None,
+                        None,
+                        None,
+                        residual,
+                    )
+                else:
+                    hidden_states, residual = self.input_layernorm(
+                        hidden_states, residual
+                    )

         hidden_states = self._communicate_simple_fn(
             hidden_states=hidden_states,
@@ -484,11 +522,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
         # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
         # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
         if (
-            _is_sm100_supported
+            (_is_sm100_supported or _is_sm90_supported)
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
             and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-            and hidden_states.shape[0] <=
+            and hidden_states.shape[0] <= 4096
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                 hidden_states, residual