sglang 0.5.2rc0__py3-none-any.whl → 0.5.2rc2__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (84)
  1. sglang/lang/interpreter.py +1 -1
  2. sglang/srt/configs/internvl.py +6 -0
  3. sglang/srt/configs/model_config.py +2 -1
  4. sglang/srt/disaggregation/mini_lb.py +2 -2
  5. sglang/srt/distributed/parallel_state.py +46 -41
  6. sglang/srt/entrypoints/engine.py +1 -1
  7. sglang/srt/entrypoints/http_server.py +5 -1
  8. sglang/srt/entrypoints/openai/protocol.py +3 -3
  9. sglang/srt/entrypoints/openai/serving_chat.py +3 -3
  10. sglang/srt/entrypoints/openai/serving_completions.py +3 -1
  11. sglang/srt/entrypoints/openai/serving_embedding.py +1 -1
  12. sglang/srt/entrypoints/openai/serving_responses.py +1 -1
  13. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  14. sglang/srt/layers/attention/aiter_backend.py +93 -68
  15. sglang/srt/layers/communicator.py +45 -7
  16. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  17. sglang/srt/layers/moe/ep_moe/layer.py +2 -7
  18. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  19. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  21. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  22. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  24. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  25. sglang/srt/layers/moe/utils.py +0 -1
  26. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +8 -0
  27. sglang/srt/layers/quantization/modelopt_quant.py +35 -2
  28. sglang/srt/layers/quantization/mxfp4.py +4 -1
  29. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  30. sglang/srt/layers/quantization/quark/utils.py +97 -0
  31. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  32. sglang/srt/layers/quantization/w4afp8.py +30 -25
  33. sglang/srt/layers/rocm_linear_utils.py +44 -0
  34. sglang/srt/layers/rotary_embedding.py +0 -18
  35. sglang/srt/managers/cache_controller.py +42 -39
  36. sglang/srt/managers/detokenizer_manager.py +0 -34
  37. sglang/srt/managers/multi_tokenizer_mixin.py +48 -6
  38. sglang/srt/managers/schedule_policy.py +3 -2
  39. sglang/srt/managers/scheduler.py +7 -100
  40. sglang/srt/managers/scheduler_metrics_mixin.py +113 -7
  41. sglang/srt/managers/template_manager.py +3 -3
  42. sglang/srt/managers/tokenizer_manager.py +1 -0
  43. sglang/srt/mem_cache/allocator.py +1 -1
  44. sglang/srt/mem_cache/hicache_storage.py +15 -10
  45. sglang/srt/mem_cache/hiradix_cache.py +16 -0
  46. sglang/srt/mem_cache/memory_pool_host.py +18 -11
  47. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  48. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +35 -6
  49. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +32 -13
  50. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  51. sglang/srt/metrics/collector.py +12 -4
  52. sglang/srt/metrics/utils.py +48 -0
  53. sglang/srt/model_executor/forward_batch_info.py +16 -17
  54. sglang/srt/model_executor/model_runner.py +1 -1
  55. sglang/srt/models/deepseek_v2.py +245 -36
  56. sglang/srt/models/glm4_moe.py +10 -1
  57. sglang/srt/models/gpt_oss.py +5 -4
  58. sglang/srt/models/internvl.py +28 -0
  59. sglang/srt/models/longcat_flash.py +26 -15
  60. sglang/srt/models/longcat_flash_nextn.py +23 -15
  61. sglang/srt/models/minicpmv.py +165 -3
  62. sglang/srt/models/qwen2_moe.py +4 -1
  63. sglang/srt/models/qwen3.py +8 -2
  64. sglang/srt/models/qwen3_moe.py +39 -8
  65. sglang/srt/models/torch_native_llama.py +1 -1
  66. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  67. sglang/srt/server_args.py +79 -2
  68. sglang/srt/speculative/eagle_worker.py +158 -112
  69. sglang/srt/utils.py +12 -10
  70. sglang/test/few_shot_gsm8k.py +1 -0
  71. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  72. sglang/utils.py +1 -0
  73. sglang/version.py +1 -1
  74. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/METADATA +2 -2
  75. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/RECORD +83 -76
  76. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  77. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  78. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  79. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  80. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  81. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  82. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/WHEEL +0 -0
  83. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.5.2rc0.dist-info → sglang-0.5.2rc2.dist-info}/top_level.txt +0 -0

sglang/lang/interpreter.py
@@ -740,7 +740,7 @@ class StreamExecutor:
  # Execute the stored lazy generation calls
  self.backend.role_end_generate(self)

- from sglang.srt.reasoning_parser import ReasoningParser
+ from sglang.srt.parser.reasoning_parser import ReasoningParser

  reasoning_parser = ReasoningParser(expr.model_type)
  other = expr.expr

sglang/srt/configs/internvl.py
@@ -6,11 +6,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
  import sentencepiece as spm
  from transformers import (
  TOKENIZER_MAPPING,
+ GptOssConfig,
  LlamaConfig,
  PretrainedConfig,
  PreTrainedTokenizer,
  Qwen2Config,
  Qwen3Config,
+ Qwen3MoeConfig,
  )

  from sglang.utils import logger
@@ -316,7 +318,11 @@ class InternVLChatConfig(PretrainedConfig):
  elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
  self.llm_config = Qwen2Config(**llm_config)
  elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
+ self.llm_config = Qwen3MoeConfig(**llm_config)
+ elif llm_config.get("architectures")[0] == "Qwen3ForCausalLM":
  self.llm_config = Qwen3Config(**llm_config)
+ elif llm_config.get("architectures")[0] == "GptOssForCausalLM":
+ self.llm_config = GptOssConfig(**llm_config)
  else:
  raise ValueError(
  "Unsupported architecture: {}".format(

sglang/srt/configs/model_config.py
@@ -405,9 +405,10 @@ class ModelConfig:
  # compressed-tensors uses a "compression_config" key
  quant_cfg = getattr(self.hf_config, "compression_config", None)
  if quant_cfg is None:
- # check if is modelopt model -- modelopt doesn't have corresponding field
+ # check if is modelopt or mixed-precision model -- Both of them don't have corresponding field
  # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
  # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
+ # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
  is_local = os.path.exists(self.model_path)
  modelopt_quant_config = {"quant_method": "modelopt"}
  if not is_local:

sglang/srt/disaggregation/mini_lb.py
@@ -187,7 +187,7 @@ async def health_check():


  @app.get("/health_generate")
- async def health_check():
+ async def health_generate():
  prefill_servers, decode_servers = (
  load_balancer.prefill_servers,
  load_balancer.decode_servers,
@@ -196,7 +196,7 @@ async def health_check():
  # Create the tasks
  tasks = []
  for server in chain(prefill_servers, decode_servers):
- tasks.append(session.post(f"{server}/health_generate"))
+ tasks.append(session.get(f"{server}/health_generate"))
  for i, response in enumerate(asyncio.as_completed(tasks)):
  await response
  return Response(status_code=200)

sglang/srt/distributed/parallel_state.py
@@ -43,6 +43,7 @@ from sglang.srt.utils import (
  direct_register_custom_op,
  get_bool_env_var,
  get_int_env_var,
+ is_cpu,
  is_cuda_alike,
  is_hip,
  is_npu,
@@ -51,6 +52,7 @@ from sglang.srt.utils import (
  )

  _is_npu = is_npu()
+ _is_cpu = is_cpu()

  IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")

@@ -877,17 +879,16 @@ class GroupCoordinator:
  size_tensor = torch.tensor(
  [object_tensor.numel()],
  dtype=torch.long,
- device=torch.cuda.current_device(),
+ device="cpu",
  )
-
  # Send object size
- torch.distributed.send(
- size_tensor, dst=self.ranks[dst], group=self.device_group
- )
+ torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)

  # Send object
  torch.distributed.send(
- object_tensor, dst=self.ranks[dst], group=self.device_group
+ object_tensor,
+ dst=self.ranks[dst],
+ group=self.device_group,
  )

  return None
@@ -902,13 +903,11 @@ class GroupCoordinator:
  src != self.rank_in_group
  ), "Invalid source rank. Source rank is the same as the current rank."

- size_tensor = torch.empty(
- 1, dtype=torch.long, device=torch.cuda.current_device()
- )
+ size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

  # Receive object size
  rank_size = torch.distributed.recv(
- size_tensor, src=self.ranks[src], group=self.device_group
+ size_tensor, src=self.ranks[src], group=self.cpu_group
  )

  # Tensor to receive serialized objects into.
@@ -926,7 +925,7 @@ class GroupCoordinator:
  rank_object == rank_size
  ), "Received object sender rank does not match the size sender rank."

- obj = pickle.loads(object_tensor.cpu().numpy().tobytes())
+ obj = pickle.loads(object_tensor.cpu().numpy())

  return obj

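
Note: the GroupCoordinator hunks above move only the size handshake of send_object / recv_object onto the CPU (gloo) group, while the pickled payload still travels over the device group. A minimal standalone sketch of that pattern follows; it is illustrative only (global ranks, pre-built cpu_group / device_group, and a CUDA device group are assumed), not sglang's implementation.

import pickle

import torch
import torch.distributed as dist


def send_object(obj, dst: int, cpu_group, device_group) -> None:
    # Serialize on CPU, then stage the payload on the GPU for the device group.
    payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8).cuda()
    # Size handshake over the CPU group: the receiver learns the length
    # without touching the GPU.
    size = torch.tensor([payload.numel()], dtype=torch.long, device="cpu")
    dist.send(size, dst=dst, group=cpu_group)
    # The serialized bytes still go over the device group (e.g. NCCL).
    dist.send(payload, dst=dst, group=device_group)


def recv_object(src: int, cpu_group, device_group):
    size = torch.empty(1, dtype=torch.long, device="cpu")
    dist.recv(size, src=src, group=cpu_group)
    payload = torch.empty(size.item(), dtype=torch.uint8, device="cuda")
    dist.recv(payload, src=src, group=device_group)
    return pickle.loads(payload.cpu().numpy())
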
@@ -1459,43 +1458,49 @@ def initialize_model_parallel(
  _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

  moe_ep_size = expert_model_parallel_size
-
  moe_tp_size = tensor_model_parallel_size // moe_ep_size
+
  global _MOE_EP
  assert _MOE_EP is None, "expert model parallel group is already initialized"
- group_ranks = []
- for i in range(num_tensor_model_parallel_groups):
- for j in range(moe_tp_size):
- st = i * tensor_model_parallel_size + j
- en = (i + 1) * tensor_model_parallel_size + j
- ranks = list(range(st, en, moe_tp_size))
- group_ranks.append(ranks)

- _MOE_EP = init_model_parallel_group(
- group_ranks,
- get_world_group().local_rank,
- backend,
- use_custom_allreduce=False,
- group_name="moe_ep",
- )
+ if moe_ep_size == tensor_model_parallel_size:
+ _MOE_EP = _TP
+ else:
+ # TODO(ch-wan): use split_group to save memory
+ group_ranks = []
+ for i in range(num_tensor_model_parallel_groups):
+ for j in range(moe_tp_size):
+ st = i * tensor_model_parallel_size + j
+ en = (i + 1) * tensor_model_parallel_size + j
+ ranks = list(range(st, en, moe_tp_size))
+ group_ranks.append(ranks)
+ _MOE_EP = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="moe_ep",
+ )

  global _MOE_TP
  assert _MOE_TP is None, "expert model parallel group is already initialized"
- group_ranks = []
- for i in range(num_tensor_model_parallel_groups):
- for j in range(moe_ep_size):
- st = i * tensor_model_parallel_size + j * moe_tp_size
- en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
- ranks = list(range(st, en))
- group_ranks.append(ranks)

- _MOE_TP = init_model_parallel_group(
- group_ranks,
- get_world_group().local_rank,
- backend,
- use_custom_allreduce=False,
- group_name="moe_tp",
- )
+ if moe_tp_size == tensor_model_parallel_size:
+ _MOE_TP = _TP
+ else:
+ # TODO(ch-wan): use split_group to save memory
+ group_ranks = []
+ for i in range(num_tensor_model_parallel_groups):
+ for j in range(moe_ep_size):
+ st = i * tensor_model_parallel_size + j * moe_tp_size
+ en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
+ ranks = list(range(st, en))
+ group_ranks.append(ranks)
+ _MOE_TP = init_model_parallel_group(
+ group_ranks,
+ get_world_group().local_rank,
+ backend,
+ group_name="moe_tp",
+ )

  # Build the pipeline model-parallel groups.
  num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
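
Note: to make the new grouping easier to read, the following standalone sketch (a hypothetical helper, not sglang's API) reproduces the rank layout the hunk above builds for _MOE_EP and _MOE_TP; in the diff itself, the existing _TP group is reused whenever the requested size equals the tensor-parallel size, so no duplicate communicator is created.

from typing import List, Tuple


def moe_group_ranks(
    world_size: int, tp_size: int, ep_size: int
) -> Tuple[List[List[int]], List[List[int]]]:
    # Mirrors the loops in the hunk above: EP groups take a strided slice of
    # each TP group, MoE-TP groups take a contiguous slice.
    assert world_size % tp_size == 0 and tp_size % ep_size == 0
    num_tp_groups = world_size // tp_size
    moe_tp_size = tp_size // ep_size
    ep_groups, tp_groups = [], []
    for i in range(num_tp_groups):
        for j in range(moe_tp_size):
            st = i * tp_size + j
            en = (i + 1) * tp_size + j
            ep_groups.append(list(range(st, en, moe_tp_size)))
        for j in range(ep_size):
            st = i * tp_size + j * moe_tp_size
            en = i * tp_size + (j + 1) * moe_tp_size
            tp_groups.append(list(range(st, en)))
    return ep_groups, tp_groups


# Example: 8 ranks, tp=8, ep=4 -> EP groups [[0, 2, 4, 6], [1, 3, 5, 7]],
# MoE-TP groups [[0, 1], [2, 3], [4, 5], [6, 7]].
print(moe_group_ranks(8, 8, 4))
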
@@ -1643,7 +1648,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):

  ray.shutdown()
  gc.collect()
- if not current_platform.is_cpu():
+ if not _is_cpu:
  if hasattr(torch, "cuda") and torch.cuda.is_available():
  torch.cuda.empty_cache()
  if hasattr(torch._C, "_host_emptyCache"):

sglang/srt/entrypoints/engine.py
@@ -681,7 +681,7 @@ def _set_envs_and_config(server_args: ServerArgs):
  if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
  assert_pkg_version(
  "sgl-kernel",
- "0.3.7.post1",
+ "0.3.8",
  "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
  )


sglang/srt/entrypoints/http_server.py
@@ -29,6 +29,8 @@ import time
  from http import HTTPStatus
  from typing import Any, AsyncIterator, Callable, Dict, List, Optional

+ import setproctitle
+
  # Fix a bug of Python threading
  setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

@@ -102,7 +104,7 @@ from sglang.srt.managers.multi_tokenizer_mixin import (
  from sglang.srt.managers.template_manager import TemplateManager
  from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
  from sglang.srt.metrics.func_timer import enable_func_timer
- from sglang.srt.reasoning_parser import ReasoningParser
+ from sglang.srt.parser.reasoning_parser import ReasoningParser
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.utils import (
  add_api_key_middleware,
@@ -1166,6 +1168,7 @@ def launch_server(
  2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
  """
  if server_args.tokenizer_worker_num > 1:
+ setproctitle.setproctitle(f"sglang::http_server/multi_tokenizer_router")
  port_args = PortArgs.init_new(server_args)
  port_args.tokenizer_worker_ipc_name = (
  f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"
@@ -1174,6 +1177,7 @@ def launch_server(
  server_args=server_args, port_args=port_args
  )
  else:
+ setproctitle.setproctitle(f"sglang::http_server/tokenizer_manager")
  tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
  server_args=server_args,
  )

sglang/srt/entrypoints/openai/protocol.py
@@ -542,9 +542,9 @@ class ChatCompletionRequest(BaseModel):
  rid: Optional[Union[List[str], str]] = None

  # For PD disaggregation
- bootstrap_host: Optional[str] = None
- bootstrap_port: Optional[int] = None
- bootstrap_room: Optional[int] = None
+ bootstrap_host: Optional[Union[List[str], str]] = None
+ bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
+ bootstrap_room: Optional[Union[List[int], int]] = None


  class ChatMessage(BaseModel):
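
Note: a minimal pydantic sketch of the widened PD-disaggregation fields above; it mirrors only the field typing from the hunk (the real ChatCompletionRequest carries many more fields) and shows that both the old scalar form and a per-item list now validate.

from typing import List, Optional, Union

from pydantic import BaseModel


class PDBootstrapFields(BaseModel):
    # Same annotations as in the hunk above, extracted into a toy model.
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None


# Scalar form (old behavior) still works:
print(PDBootstrapFields(bootstrap_host="10.0.0.1", bootstrap_port=8998, bootstrap_room=42))
# List form (one entry per request item) is now accepted as well:
print(PDBootstrapFields(bootstrap_host=["10.0.0.1", "10.0.0.2"], bootstrap_port=[8998, None], bootstrap_room=[1, 2]))
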

sglang/srt/entrypoints/openai/serving_chat.py
@@ -8,7 +8,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
  from fastapi import Request
  from fastapi.responses import ORJSONResponse, StreamingResponse

- from sglang.srt.conversation import generate_chat_conv
  from sglang.srt.entrypoints.openai.protocol import (
  ChatCompletionRequest,
  ChatCompletionResponse,
@@ -33,11 +32,12 @@ from sglang.srt.entrypoints.openai.utils import (
  to_openai_style_logprobs,
  )
  from sglang.srt.function_call.function_call_parser import FunctionCallParser
- from sglang.srt.jinja_template_utils import process_content_for_template_format
  from sglang.srt.managers.io_struct import GenerateReqInput
  from sglang.srt.managers.template_manager import TemplateManager
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
- from sglang.srt.reasoning_parser import ReasoningParser
+ from sglang.srt.parser.conversation import generate_chat_conv
+ from sglang.srt.parser.jinja_template_utils import process_content_for_template_format
+ from sglang.srt.parser.reasoning_parser import ReasoningParser
  from sglang.utils import convert_json_schema_to_str

  logger = logging.getLogger(__name__)

sglang/srt/entrypoints/openai/serving_completions.py
@@ -5,7 +5,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
  from fastapi import Request
  from fastapi.responses import ORJSONResponse, StreamingResponse

- from sglang.srt.code_completion_parser import generate_completion_prompt_from_request
  from sglang.srt.entrypoints.openai.protocol import (
  CompletionRequest,
  CompletionResponse,
@@ -23,6 +22,9 @@ from sglang.srt.entrypoints.openai.utils import (
  from sglang.srt.managers.io_struct import GenerateReqInput
  from sglang.srt.managers.template_manager import TemplateManager
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
+ from sglang.srt.parser.code_completion_parser import (
+ generate_completion_prompt_from_request,
+ )
  from sglang.utils import convert_json_schema_to_str

  logger = logging.getLogger(__name__)

sglang/srt/entrypoints/openai/serving_embedding.py
@@ -3,7 +3,6 @@ from typing import Any, Dict, List, Optional, Union
  from fastapi import Request
  from fastapi.responses import ORJSONResponse

- from sglang.srt.conversation import generate_embedding_convs
  from sglang.srt.entrypoints.openai.protocol import (
  EmbeddingObject,
  EmbeddingRequest,
@@ -16,6 +15,7 @@ from sglang.srt.entrypoints.openai.serving_base import OpenAIServingBase
  from sglang.srt.managers.io_struct import EmbeddingReqInput
  from sglang.srt.managers.template_manager import TemplateManager
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
+ from sglang.srt.parser.conversation import generate_embedding_convs


  class OpenAIServingEmbedding(OpenAIServingBase):

sglang/srt/entrypoints/openai/serving_responses.py
@@ -56,7 +56,7 @@ from sglang.srt.entrypoints.openai.tool_server import MCPToolServer, ToolServer
  from sglang.srt.managers.io_struct import GenerateReqInput
  from sglang.srt.managers.template_manager import TemplateManager
  from sglang.srt.managers.tokenizer_manager import TokenizerManager
- from sglang.srt.reasoning_parser import ReasoningParser
+ from sglang.srt.parser.reasoning_parser import ReasoningParser
  from sglang.srt.utils import random_uuid

  logger = logging.getLogger(__name__)

sglang/srt/function_call/gpt_oss_detector.py
@@ -10,7 +10,7 @@ from sglang.srt.function_call.core_types import (
  ToolCallItem,
  _GetInfoFunc,
  )
- from sglang.srt.harmony_parser import HarmonyParser
+ from sglang.srt.parser.harmony_parser import HarmonyParser

  logger = logging.getLogger(__name__)


sglang/srt/layers/attention/aiter_backend.py
@@ -18,7 +18,10 @@ import triton.language as tl
  from sglang.global_config import global_config
  from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
  from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
- from sglang.srt.layers.dp_attention import get_attention_tp_size
+ from sglang.srt.layers.dp_attention import (
+ get_attention_tp_size,
+ is_dp_attention_enabled,
+ )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

  if TYPE_CHECKING:
@@ -154,6 +157,8 @@ class AiterAttnBackend(AttentionBackend):
  (max_bs + 1,), dtype=torch.int32, device=model_runner.device
  )

+ self.enable_dp_attention = is_dp_attention_enabled()
+
  def init_forward_metadata(self, forward_batch: ForwardBatch):
  """Init auxiliary variables for triton attention backend."""

@@ -302,19 +307,19 @@ class AiterAttnBackend(AttentionBackend):
  if self.use_mla:
  self.mla_indices_updater_prefill.update(
  forward_batch.req_pool_indices,
- forward_batch.extend_prefix_lens,
- sum(forward_batch.extend_prefix_lens_cpu),
+ forward_batch.seq_lens,
+ forward_batch.seq_lens_sum,
  forward_batch.extend_seq_lens,
- max(forward_batch.extend_seq_lens_cpu),
- forward_batch.seq_lens_cpu.max().item(),
+ forward_batch.extend_seq_lens.max().item(),
+ forward_batch.seq_lens.max().item(),
  spec_info=None,
  )
- self.mla_indices_updater_prefill.kv_indptr += (
- self.mla_indices_updater_prefill.qo_indptr
- )
+
+ kv_indices = self.mla_indices_updater_prefill.kv_indices
+
  self.forward_metadata = ForwardMetadata(
  self.mla_indices_updater_prefill.kv_indptr,
- self.mla_indices_updater_prefill.kv_indices,
+ kv_indices,
  self.mla_indices_updater_prefill.qo_indptr,
  self.kv_last_page_len[:bs],
  self.mla_indices_updater_prefill.max_q_len,
@@ -614,66 +619,86 @@ class AiterAttnBackend(AttentionBackend):
  assert len(k.shape) == 3
  assert len(v.shape) == 3

- if kv_indices.shape[0] == 0:
- o = flash_attn_varlen_func(
- q,
- k,
- v,
- qo_indptr,
- qo_indptr,
- max_q_len,
- max_q_len,
- softmax_scale=layer.scaling,
- causal=True,
- )
- return o
- elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
- K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
- kvc, k_pe = torch.split(
- K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
- )
- kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
+ if forward_batch.forward_mode.is_extend():
+ if kv_indices.shape[0] == 0:
+ o = flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ qo_indptr,
+ qo_indptr,
+ max_q_len,
+ max_q_len,
+ softmax_scale=layer.scaling,
+ causal=True,
+ )
+ return o
+ elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
+ K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
+ kvc, k_pe = torch.split(
+ K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
+ )
+ kvprefix = layer.kv_b_proj(kvc.contiguous())[0]

- kvprefix = kvprefix.view(
- -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
- )
- k_prefix, v_prefix = torch.split(
- kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
- )
- k_prefix = torch.cat(
- [
- k_prefix,
- torch.broadcast_to(
- k_pe,
- (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
- ),
- ],
- dim=-1,
- )
- assert (
- forward_batch.extend_prefix_lens.shape
- == forward_batch.extend_seq_lens.shape
- )
- k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
- k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
- assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
- k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
- v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
- v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
- v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
-
- o = flash_attn_varlen_func(
- q,
- k,
- v,
- qo_indptr,
- kv_indptr,
- max_q_len,
- max_kv_len,
- softmax_scale=layer.scaling,
- causal=True,
- )
- return o
+ kvprefix = kvprefix.view(
+ -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
+ )
+ k_prefix, v_prefix = torch.split(
+ kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
+ )
+ k_prefix = torch.cat(
+ [
+ k_prefix,
+ torch.broadcast_to(
+ k_pe,
+ (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
+ ),
+ ],
+ dim=-1,
+ )
+ assert (
+ forward_batch.extend_prefix_lens.shape
+ == forward_batch.extend_seq_lens.shape
+ )
+
+ k = k_prefix
+ v = v_prefix
+
+ o = flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ qo_indptr,
+ kv_indptr,
+ max_q_len,
+ max_kv_len,
+ softmax_scale=layer.scaling,
+ causal=True,
+ )
+ return o
+
+ else:
+ if layer.qk_head_dim != layer.v_head_dim:
+ o = q.new_empty(
+ (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+ )
+ else:
+ o = torch.empty_like(q)
+
+ mla_prefill_fwd(
+ q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+ K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+ o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+ qo_indptr,
+ kv_indptr,
+ kv_indices,
+ self.forward_metadata.kv_last_page_len,
+ self.forward_metadata.max_q_len,
+ layer.scaling,
+ layer.logit_cap,
+ )
+ K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
+ return o
  elif forward_batch.forward_mode.is_target_verify():
  o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
  mla_decode_fwd(

sglang/srt/layers/communicator.py
@@ -42,10 +42,24 @@ from sglang.srt.layers.moe import (
  )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
+ from sglang.srt.utils import (
+ get_bool_env_var,
+ is_cuda,
+ is_flashinfer_available,
+ is_gfx95_supported,
+ is_hip,
+ is_sm90_supported,
+ is_sm100_supported,
+ )

  _is_flashinfer_available = is_flashinfer_available()
+ _is_sm90_supported = is_cuda() and is_sm90_supported()
  _is_sm100_supported = is_cuda() and is_sm100_supported()
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
+ _is_gfx95_supported = is_gfx95_supported()
+
+ if _use_aiter and _is_gfx95_supported:
+ from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant

  FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048

@@ -201,6 +215,7 @@ class LayerCommunicator:
  hidden_states: torch.Tensor,
  residual: torch.Tensor,
  forward_batch: ForwardBatch,
+ qaunt_format: str = "",
  ):
  if hidden_states.shape[0] == 0:
  residual = hidden_states
@@ -218,11 +233,34 @@ class LayerCommunicator:
  else:
  if residual is None:
  residual = hidden_states
- hidden_states = self.input_layernorm(hidden_states)
+
+ if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
+ hidden_states = fused_rms_mxfp4_quant(
+ hidden_states,
+ self.input_layernorm.weight,
+ self.input_layernorm.variance_epsilon,
+ None,
+ None,
+ None,
+ None,
+ )
+ else:
+ hidden_states = self.input_layernorm(hidden_states)
  else:
- hidden_states, residual = self.input_layernorm(
- hidden_states, residual
- )
+ if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
+ hidden_states, residual = fused_rms_mxfp4_quant(
+ hidden_states,
+ self.input_layernorm.weight,
+ self.input_layernorm.variance_epsilon,
+ None,
+ None,
+ None,
+ residual,
+ )
+ else:
+ hidden_states, residual = self.input_layernorm(
+ hidden_states, residual
+ )

  hidden_states = self._communicate_simple_fn(
  hidden_states=hidden_states,
@@ -484,11 +522,11 @@ class CommunicateWithAllReduceAndLayerNormFn:
  # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
  # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
  if (
- _is_sm100_supported
+ (_is_sm100_supported or _is_sm90_supported)
  and _is_flashinfer_available
  and hasattr(layernorm, "forward_with_allreduce_fusion")
  and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
- and hidden_states.shape[0] <= 2048
+ and hidden_states.shape[0] <= 4096
  ):
  hidden_states, residual = layernorm.forward_with_allreduce_fusion(
  hidden_states, residual

sglang/srt/layers/moe/cutlass_w4a8_moe.py
@@ -91,18 +91,10 @@ def cutlass_w4a8_moe(
  assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
  assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
  assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
- assert (
- w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
- and w1_scale.shape[2] == w1_q.shape[1] * 4
- ), "W1 scale shape mismatch"
- assert (
- w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
- and w2_scale.shape[2] == w2_q.shape[1] * 4
- ), "W2 scale shape mismatch"

  assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
  assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
- assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+ assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
  assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
  num_experts = w1_q.size(0)
  m = a.size(0)