sglang 0.5.2rc1__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/lang/interpreter.py +1 -1
  2. sglang/srt/configs/internvl.py +6 -0
  3. sglang/srt/disaggregation/mini_lb.py +2 -2
  4. sglang/srt/distributed/parallel_state.py +43 -40
  5. sglang/srt/entrypoints/http_server.py +5 -1
  6. sglang/srt/entrypoints/openai/protocol.py +3 -3
  7. sglang/srt/entrypoints/openai/serving_chat.py +3 -3
  8. sglang/srt/entrypoints/openai/serving_completions.py +3 -1
  9. sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
  10. sglang/srt/entrypoints/openai/serving_responses.py +1 -1
  11. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  12. sglang/srt/layers/attention/aiter_backend.py +93 -68
  13. sglang/srt/layers/communicator.py +45 -7
  14. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  15. sglang/srt/layers/moe/utils.py +0 -1
  16. sglang/srt/layers/quantization/modelopt_quant.py +35 -2
  17. sglang/srt/layers/quantization/mxfp4.py +4 -1
  18. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  19. sglang/srt/layers/quantization/quark/utils.py +97 -0
  20. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  21. sglang/srt/layers/rocm_linear_utils.py +44 -0
  22. sglang/srt/layers/rotary_embedding.py +0 -18
  23. sglang/srt/managers/cache_controller.py +42 -39
  24. sglang/srt/managers/multi_tokenizer_mixin.py +4 -0
  25. sglang/srt/managers/schedule_policy.py +3 -2
  26. sglang/srt/managers/scheduler.py +4 -100
  27. sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
  28. sglang/srt/managers/template_manager.py +3 -3
  29. sglang/srt/managers/tokenizer_manager.py +1 -0
  30. sglang/srt/mem_cache/allocator.py +1 -1
  31. sglang/srt/mem_cache/hicache_storage.py +15 -10
  32. sglang/srt/mem_cache/hiradix_cache.py +5 -5
  33. sglang/srt/mem_cache/memory_pool_host.py +16 -11
  34. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +10 -2
  35. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
  36. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  37. sglang/srt/metrics/collector.py +12 -4
  38. sglang/srt/metrics/utils.py +48 -0
  39. sglang/srt/model_executor/forward_batch_info.py +16 -17
  40. sglang/srt/model_executor/model_runner.py +1 -1
  41. sglang/srt/models/deepseek_v2.py +240 -36
  42. sglang/srt/models/glm4_moe.py +10 -1
  43. sglang/srt/models/internvl.py +28 -0
  44. sglang/srt/models/minicpmv.py +165 -3
  45. sglang/srt/models/qwen2_moe.py +4 -1
  46. sglang/srt/models/qwen3.py +8 -2
  47. sglang/srt/models/qwen3_moe.py +39 -8
  48. sglang/srt/models/torch_native_llama.py +1 -1
  49. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  50. sglang/srt/server_args.py +79 -2
  51. sglang/srt/speculative/eagle_worker.py +158 -112
  52. sglang/srt/utils.py +12 -0
  53. sglang/test/few_shot_gsm8k.py +1 -0
  54. sglang/utils.py +1 -0
  55. sglang/version.py +1 -1
  56. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +1 -1
  57. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +65 -61
  58. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  59. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  60. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  61. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  62. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  63. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  64. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
  65. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
  66. {sglang-0.5.2rc1.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0
sglang/lang/interpreter.py
@@ -740,7 +740,7 @@ class StreamExecutor:
  # Execute the stored lazy generation calls
  self.backend.role_end_generate(self)

- from sglang.srt.reasoning_parser import ReasoningParser
+ from sglang.srt.parser.reasoning_parser import ReasoningParser

  reasoning_parser = ReasoningParser(expr.model_type)
  other = expr.expr

sglang/srt/configs/internvl.py
@@ -6,11 +6,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
  import sentencepiece as spm
  from transformers import (
  TOKENIZER_MAPPING,
+ GptOssConfig,
  LlamaConfig,
  PretrainedConfig,
  PreTrainedTokenizer,
  Qwen2Config,
  Qwen3Config,
+ Qwen3MoeConfig,
  )

  from sglang.utils import logger
@@ -316,7 +318,11 @@ class InternVLChatConfig(PretrainedConfig):
  elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
  self.llm_config = Qwen2Config(**llm_config)
  elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
+ self.llm_config = Qwen3MoeConfig(**llm_config)
+ elif llm_config.get("architectures")[0] == "Qwen3ForCausalLM":
  self.llm_config = Qwen3Config(**llm_config)
+ elif llm_config.get("architectures")[0] == "GptOssForCausalLM":
+ self.llm_config = GptOssConfig(**llm_config)
  else:
  raise ValueError(
  "Unsupported architecture: {}".format(

sglang/srt/disaggregation/mini_lb.py
@@ -187,7 +187,7 @@ async def health_check():


  @app.get("/health_generate")
- async def health_check():
+ async def health_generate():
  prefill_servers, decode_servers = (
  load_balancer.prefill_servers,
  load_balancer.decode_servers,
@@ -196,7 +196,7 @@ async def health_check():
  # Create the tasks
  tasks = []
  for server in chain(prefill_servers, decode_servers):
- tasks.append(session.post(f"{server}/health_generate"))
+ tasks.append(session.get(f"{server}/health_generate"))
  for i, response in enumerate(asyncio.as_completed(tasks)):
  await response
  return Response(status_code=200)

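The two mini_lb hunks above rename the duplicated health_check handler to health_generate and switch the backend fan-out from POST to GET, so the load balancer now probes each backend's /health_generate route the same way a client would. Below is a minimal, hedged sketch of that GET fan-out using aiohttp and asyncio.gather; the function and parameter names are illustrative only, not mini_lb internals.

    # Sketch: probe every prefill/decode backend with GET /health_generate.
    import asyncio
    from itertools import chain

    import aiohttp


    async def probe_health_generate(prefill_servers, decode_servers) -> bool:
        async with aiohttp.ClientSession() as session:
            tasks = [
                session.get(f"{server}/health_generate")
                for server in chain(prefill_servers, decode_servers)
            ]
            responses = await asyncio.gather(*tasks)
            # Healthy only if every backend answered 200.
            return all(resp.status == 200 for resp in responses)
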
sglang/srt/distributed/parallel_state.py
@@ -879,17 +879,16 @@ class GroupCoordinator:
  size_tensor = torch.tensor(
  [object_tensor.numel()],
  dtype=torch.long,
- device=torch.cuda.current_device(),
+ device="cpu",
  )
-
  # Send object size
- torch.distributed.send(
- size_tensor, dst=self.ranks[dst], group=self.device_group
- )
+ torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)

  # Send object
  torch.distributed.send(
- object_tensor, dst=self.ranks[dst], group=self.device_group
+ object_tensor,
+ dst=self.ranks[dst],
+ group=self.device_group,
  )

  return None
@@ -904,13 +903,11 @@ class GroupCoordinator:
  src != self.rank_in_group
  ), "Invalid source rank. Source rank is the same as the current rank."

- size_tensor = torch.empty(
- 1, dtype=torch.long, device=torch.cuda.current_device()
- )
+ size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

  # Receive object size
  rank_size = torch.distributed.recv(
- size_tensor, src=self.ranks[src], group=self.device_group
+ size_tensor, src=self.ranks[src], group=self.cpu_group
  )

  # Tensor to receive serialized objects into.
@@ -928,7 +925,7 @@ class GroupCoordinator:
  rank_object == rank_size
  ), "Received object sender rank does not match the size sender rank."

- obj = pickle.loads(object_tensor.cpu().numpy().tobytes())
+ obj = pickle.loads(object_tensor.cpu().numpy())

  return obj

@@ -1461,43 +1458,49 @@ def initialize_model_parallel(
  _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

  moe_ep_size = expert_model_parallel_size
-
  moe_tp_size = tensor_model_parallel_size // moe_ep_size
+
  global _MOE_EP
  assert _MOE_EP is None, "expert model parallel group is already initialized"
- group_ranks = []
- for i in range(num_tensor_model_parallel_groups):
- for j in range(moe_tp_size):
- st = i * tensor_model_parallel_size + j
- en = (i + 1) * tensor_model_parallel_size + j
- ranks = list(range(st, en, moe_tp_size))
- group_ranks.append(ranks)

- _MOE_EP = init_model_parallel_group(
- group_ranks,
- get_world_group().local_rank,
- backend,
- use_custom_allreduce=False,
- group_name="moe_ep",
- )
+ if moe_ep_size == tensor_model_parallel_size:
+ _MOE_EP = _TP
+ else:
+ # TODO(ch-wan): use split_group to save memory
+ group_ranks = []
+ for i in range(num_tensor_model_parallel_groups):
+ for j in range(moe_tp_size):
+ st = i * tensor_model_parallel_size + j
+ en = (i + 1) * tensor_model_parallel_size + j
+ ranks = list(range(st, en, moe_tp_size))
+ group_ranks.append(ranks)
+ _MOE_EP = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="moe_ep",
+ )

  global _MOE_TP
  assert _MOE_TP is None, "expert model parallel group is already initialized"
- group_ranks = []
- for i in range(num_tensor_model_parallel_groups):
- for j in range(moe_ep_size):
- st = i * tensor_model_parallel_size + j * moe_tp_size
- en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
- ranks = list(range(st, en))
- group_ranks.append(ranks)

- _MOE_TP = init_model_parallel_group(
- group_ranks,
- get_world_group().local_rank,
- backend,
- use_custom_allreduce=False,
- group_name="moe_tp",
- )
+ if moe_tp_size == tensor_model_parallel_size:
+ _MOE_TP = _TP
+ else:
+ # TODO(ch-wan): use split_group to save memory
+ group_ranks = []
+ for i in range(num_tensor_model_parallel_groups):
+ for j in range(moe_ep_size):
+ st = i * tensor_model_parallel_size + j * moe_tp_size
+ en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
+ ranks = list(range(st, en))
+ group_ranks.append(ranks)
+ _MOE_TP = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="moe_tp",
+ )

  # Build the pipeline model-parallel groups.
  num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size

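The initialize_model_parallel hunk above now aliases the MoE expert-parallel (or MoE tensor-parallel) group to the existing _TP group when its size equals tensor_model_parallel_size, and otherwise builds the same strided/contiguous rank lists as before. The helper below is a standalone sketch (hypothetical name, not sglang code) that reproduces only the rank-grouping arithmetic so the layout can be checked by hand.

    # Standalone sketch of the MoE EP / MoE-TP rank layout.
    def moe_group_ranks(world_size: int, tp_size: int, ep_size: int):
        assert tp_size % ep_size == 0
        moe_tp_size = tp_size // ep_size
        num_tp_groups = world_size // tp_size

        ep_groups, moe_tp_groups = [], []
        for i in range(num_tp_groups):
            # EP groups take strided ranks inside one TP group.
            for j in range(moe_tp_size):
                st = i * tp_size + j
                en = (i + 1) * tp_size + j
                ep_groups.append(list(range(st, en, moe_tp_size)))
            # MoE-TP groups take contiguous ranks inside one TP group.
            for j in range(ep_size):
                st = i * tp_size + j * moe_tp_size
                en = i * tp_size + (j + 1) * moe_tp_size
                moe_tp_groups.append(list(range(st, en)))
        return ep_groups, moe_tp_groups


    # Example: world_size=8, tp=8, ep=4 -> EP groups [0,2,4,6], [1,3,5,7];
    # MoE-TP groups [0,1], [2,3], [4,5], [6,7]. When ep_size == tp_size the
    # new code skips this construction and reuses the existing _TP group.
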
sglang/srt/entrypoints/http_server.py
@@ -29,6 +29,8 @@ import time
  from http import HTTPStatus
  from typing import Any, AsyncIterator, Callable, Dict, List, Optional

+ import setproctitle
+
  # Fix a bug of Python threading
  setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

@@ -102,7 +104,7 @@ from sglang.srt.managers.multi_tokenizer_mixin import (
  from sglang.srt.managers.template_manager import TemplateManager
  from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
  from sglang.srt.metrics.func_timer import enable_func_timer
- from sglang.srt.reasoning_parser import ReasoningParser
+ from sglang.srt.parser.reasoning_parser import ReasoningParser
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.utils import (
  add_api_key_middleware,
@@ -1166,6 +1168,7 @@ def launch_server(
  2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
  """
  if server_args.tokenizer_worker_num > 1:
+ setproctitle.setproctitle(f"sglang::http_server/multi_tokenizer_router")
  port_args = PortArgs.init_new(server_args)
  port_args.tokenizer_worker_ipc_name = (
  f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
@@ -1174,6 +1177,7 @@
  server_args=server_args, port_args=port_args
  )
  else:
+ setproctitle.setproctitle(f"sglang::http_server/tokenizer_manager")
  tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
  server_args=server_args,
  )

sglang/srt/entrypoints/openai/protocol.py
@@ -542,9 +542,9 @@ class ChatCompletionRequest(BaseModel):
  rid: Optional[Union[List[str], str]] = None

  # For PD disaggregation
- bootstrap_host: Optional[str] = None
- bootstrap_port: Optional[int] = None
- bootstrap_room: Optional[int] = None
+ bootstrap_host: Optional[Union[List[str], str]] = None
+ bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
+ bootstrap_room: Optional[Union[List[int], int]] = None


  class ChatMessage(BaseModel):

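The protocol hunk above widens the PD-disaggregation bootstrap fields so a request may carry either a single value or a per-item list. The snippet below is a self-contained Pydantic sketch (not the real ChatCompletionRequest) showing that both forms validate against the same annotations.

    # Sketch of the widened bootstrap annotations; model name is hypothetical.
    from typing import List, Optional, Union

    from pydantic import BaseModel


    class BootstrapFields(BaseModel):
        bootstrap_host: Optional[Union[List[str], str]] = None
        bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
        bootstrap_room: Optional[Union[List[int], int]] = None


    # Scalar form (one prefill endpoint) ...
    BootstrapFields(bootstrap_host="10.0.0.1", bootstrap_port=8998, bootstrap_room=7)
    # ... or list form (one entry per batched item; ports may be None).
    BootstrapFields(
        bootstrap_host=["10.0.0.1", "10.0.0.2"],
        bootstrap_port=[8998, None],
        bootstrap_room=[7, 8],
    )
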
sglang/srt/entrypoints/openai/serving_chat.py
@@ -8,7 +8,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
  from fastapi import Request
  from fastapi.responses import ORJSONResponse, StreamingResponse

- from sglang.srt.conversation import generate_chat_conv
  from sglang.srt.entrypoints.openai.protocol import (
  ChatCompletionRequest,
  ChatCompletionResponse,
@@ -33,11 +32,12 @@ from sglang.srt.entrypoints.openai.utils import (
  to_openai_style_logprobs,
  )
  from sglang.srt.function_call.function_call_parser import FunctionCallParser
- from sglang.srt.jinja_template_utils import process_content_for_template_format
  from sglang.srt.managers.io_struct import GenerateReqInput
  from sglang.srt.managers.template_manager import TemplateManager
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
- from sglang.srt.reasoning_parser import ReasoningParser
+ from sglang.srt.parser.conversation import generate_chat_conv
+ from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
+ from sglang.srt.parser.reasoning_parser import ReasoningParser
  from sglang.utils import convert_json_schema_to_str

  logger = logging.getLogger(__name__)

sglang/srt/entrypoints/openai/serving_completions.py
@@ -5,7 +5,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
  from fastapi import Request
  from fastapi.responses import ORJSONResponse, StreamingResponse

- from sglang.srt.code_completion_parser import generate_completion_prompt_from_request
  from sglang.srt.entrypoints.openai.protocol import (
  CompletionRequest,
  CompletionResponse,
@@ -23,6 +22,9 @@ from sglang.srt.entrypoints.openai.utils import (
  from sglang.srt.managers.io_struct import GenerateReqInput
  from sglang.srt.managers.template_manager import TemplateManager
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
+ from sglang.srt.parser.code_completion_parser import (
+ generate_completion_prompt_from_request,
+ )
  from sglang.utils import convert_json_schema_to_str

  logger = logging.getLogger(__name__)

sglang/srt/entrypoints/openai/serving_embedding.py
@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Union
  from fastapi import Request
  from fastapi.responses import ORJSONResponse

- from sglang.srt.conversation import generate_embedding_convs
  from sglang.srt.entrypoints.openai.protocol import (
  EmbeddingObject,
  EmbeddingRequest,
@@ -16,6 +15,7 @@ from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
  from sglang.srt.managers.io_struct import EmbeddingReqInput
  from sglang.srt.managers.template_manager import TemplateManager
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
+ from sglang.srt.parser.conversation import generate_embedding_convs


  class OpenAIServingEmbedding(OpenAIServingBase):

sglang/srt/entrypoints/openai/serving_responses.py
@@ -56,7 +56,7 @@ from sglang.srt.entrypoints.openai.tool_server import MCPToolServer, ToolServer
  from sglang.srt.managers.io_struct import GenerateReqInput
  from sglang.srt.managers.template_manager import TemplateManager
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
- from sglang.srt.reasoning_parser import ReasoningParser
+ from sglang.srt.parser.reasoning_parser import ReasoningParser
  from sglang.srt.utils import random_uuid

  logger = logging.getLogger(__name__)

sglang/srt/function_call/gpt_oss_detector.py
@@ -10,7 +10,7 @@ from sglang.srt.function_call.core_types import (
  ToolCallItem,
  _GetInfoFunc,
  )
- from sglang.srt.harmony_parser import HarmonyParser
+ from sglang.srt.parser.harmony_parser import HarmonyParser

  logger = logging.getLogger(__name__)

sglang/srt/layers/attention/aiter_backend.py
@@ -18,7 +18,10 @@ import triton.language as tl
  from sglang.global_config import global_config
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
- from sglang.srt.layers.dp_attention import get_attention_tp_size
+ from sglang.srt.layers.dp_attention import (
+ get_attention_tp_size,
+ is_dp_attention_enabled,
+ )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

  if TYPE_CHECKING:
@@ -154,6 +157,8 @@ class AiterAttnBackend(AttentionBackend):
  (max_bs + 1,), dtype=torch.int32, device=model_runner.device
  )

+ self.enable_dp_attention = is_dp_attention_enabled()
+
  def init_forward_metadata(self, forward_batch: ForwardBatch):
  """Init auxiliary variables for triton attention backend."""

@@ -302,19 +307,19 @@ class AiterAttnBackend(AttentionBackend):
  if self.use_mla:
  self.mla_indices_updater_prefill.update(
  forward_batch.req_pool_indices,
- forward_batch.extend_prefix_lens,
- sum(forward_batch.extend_prefix_lens_cpu),
+ forward_batch.seq_lens,
+ forward_batch.seq_lens_sum,
  forward_batch.extend_seq_lens,
- max(forward_batch.extend_seq_lens_cpu),
- forward_batch.seq_lens_cpu.max().item(),
+ forward_batch.extend_seq_lens.max().item(),
+ forward_batch.seq_lens.max().item(),
  spec_info=None,
  )
- self.mla_indices_updater_prefill.kv_indptr += (
- self.mla_indices_updater_prefill.qo_indptr
- )
+
+ kv_indices = self.mla_indices_updater_prefill.kv_indices
+
  self.forward_metadata = ForwardMetadata(
  self.mla_indices_updater_prefill.kv_indptr,
- self.mla_indices_updater_prefill.kv_indices,
+ kv_indices,
  self.mla_indices_updater_prefill.qo_indptr,
  self.kv_last_page_len[:bs],
  self.mla_indices_updater_prefill.max_q_len,
@@ -614,66 +619,86 @@ class AiterAttnBackend(AttentionBackend):
  assert len(k.shape) == 3
  assert len(v.shape) == 3

- if kv_indices.shape[0] == 0:
- o = flash_attn_varlen_func(
- q,
- k,
- v,
- qo_indptr,
- qo_indptr,
- max_q_len,
- max_q_len,
- softmax_scale=layer.scaling,
- causal=True,
- )
- return o
- elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
- K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
- kvc, k_pe = torch.split(
- K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
- )
- kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
+ if forward_batch.forward_mode.is_extend():
+ if kv_indices.shape[0] == 0:
+ o = flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ qo_indptr,
+ qo_indptr,
+ max_q_len,
+ max_q_len,
+ softmax_scale=layer.scaling,
+ causal=True,
+ )
+ return o
+ elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
+ K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
+ kvc, k_pe = torch.split(
+ K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
+ )
+ kvprefix = layer.kv_b_proj(kvc.contiguous())[0]

- kvprefix = kvprefix.view(
- -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
- )
- k_prefix, v_prefix = torch.split(
- kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
- )
- k_prefix = torch.cat(
- [
- k_prefix,
- torch.broadcast_to(
- k_pe,
- (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
- ),
- ],
- dim=-1,
- )
- assert (
- forward_batch.extend_prefix_lens.shape
- == forward_batch.extend_seq_lens.shape
- )
- k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
- k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
- assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
- k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
- v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
- v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
- v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
-
- o = flash_attn_varlen_func(
- q,
- k,
- v,
- qo_indptr,
- kv_indptr,
- max_q_len,
- max_kv_len,
- softmax_scale=layer.scaling,
- causal=True,
- )
- return o
+ kvprefix = kvprefix.view(
+ -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
+ )
+ k_prefix, v_prefix = torch.split(
+ kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
+ )
+ k_prefix = torch.cat(
+ [
+ k_prefix,
+ torch.broadcast_to(
+ k_pe,
+ (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
+ ),
+ ],
+ dim=-1,
+ )
+ assert (
+ forward_batch.extend_prefix_lens.shape
+ == forward_batch.extend_seq_lens.shape
+ )
+
+ k = k_prefix
+ v = v_prefix
+
+ o = flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ qo_indptr,
+ kv_indptr,
+ max_q_len,
+ max_kv_len,
+ softmax_scale=layer.scaling,
+ causal=True,
+ )
+ return o
+
+ else:
+ if layer.qk_head_dim != layer.v_head_dim:
+ o = q.new_empty(
+ (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+ )
+ else:
+ o = torch.empty_like(q)
+
+ mla_prefill_fwd(
+ q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+ K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+ o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+ qo_indptr,
+ kv_indptr,
+ kv_indices,
+ self.forward_metadata.kv_last_page_len,
+ self.forward_metadata.max_q_len,
+ layer.scaling,
+ layer.logit_cap,
+ )
+ K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
+ return o
  elif forward_batch.forward_mode.is_target_verify():
  o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
  mla_decode_fwd(

sglang/srt/layers/communicator.py
@@ -42,10 +42,24 @@ from sglang.srt.layers.moe import (
  )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
+ from sglang.srt.utils import (
+ get_bool_env_var,
+ is_cuda,
+ is_flashinfer_available,
+ is_gfx95_supported,
+ is_hip,
+ is_sm90_supported,
+ is_sm100_supported,
+ )

  _is_flashinfer_available = is_flashinfer_available()
+ _is_sm90_supported = is_cuda() and is_sm90_supported()
  _is_sm100_supported = is_cuda() and is_sm100_supported()
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
+ _is_gfx95_supported = is_gfx95_supported()
+
+ if _use_aiter and _is_gfx95_supported:
+ from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant

  FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048

@@ -201,6 +215,7 @@ class LayerCommunicator:
  hidden_states: torch.Tensor,
  residual: torch.Tensor,
  forward_batch: ForwardBatch,
+ qaunt_format: str = "",
  ):
  if hidden_states.shape[0] == 0:
  residual = hidden_states
@@ -218,11 +233,34 @@ class LayerCommunicator:
  else:
  if residual is None:
  residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
+
+ if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
+ hidden_states = fused_rms_mxfp4_quant(
+ hidden_states,
+ self.input_layernorm.weight,
+ self.input_layernorm.variance_epsilon,
+ None,
+ None,
+ None,
+ None,
+ )
+ else:
+ hidden_states = self.input_layernorm(hidden_states)
  else:
- hidden_states, residual = self.input_layernorm(
- hidden_states, residual
- )
+ if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
+ hidden_states, residual = fused_rms_mxfp4_quant(
+ hidden_states,
+ self.input_layernorm.weight,
+ self.input_layernorm.variance_epsilon,
+ None,
+ None,
+ None,
+ residual,
+ )
+ else:
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual
+ )

  hidden_states = self._communicate_simple_fn(
  hidden_states=hidden_states,
@@ -484,11 +522,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
  # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
  # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
  if (
- _is_sm100_supported
+ (_is_sm100_supported or _is_sm90_supported)
  and _is_flashinfer_available
  and hasattr(layernorm, "forward_with_allreduce_fusion")
  and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
- and hidden_states.shape[0] <= 2048
+ and hidden_states.shape[0] <= 4096
  ):
  hidden_states, residual = layernorm.forward_with_allreduce_fusion(
  hidden_states, residual
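The final communicator.py hunk relaxes the flashinfer allreduce+RMSNorm fusion gate to also cover SM90 GPUs and raises the token cap from 2048 to 4096. A hedged sketch of that predicate as a standalone function (hypothetical helper, not sglang API) makes the gating conditions explicit.

    # Sketch of the fusion gate after this change.
    def should_fuse_allreduce_layernorm(
        num_tokens: int,
        is_sm90: bool,
        is_sm100: bool,
        flashinfer_available: bool,
        fusion_flag_enabled: bool,
        max_fused_tokens: int = 4096,
    ) -> bool:
        return (
            (is_sm100 or is_sm90)
            and flashinfer_available
            and fusion_flag_enabled
            and num_tokens <= max_fused_tokens
        )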