sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_one_batch.py +4 -0
  2. sglang/bench_serving.py +13 -0
  3. sglang/check_env.py +1 -1
  4. sglang/srt/_custom_ops.py +118 -0
  5. sglang/srt/configs/device_config.py +17 -0
  6. sglang/srt/configs/load_config.py +84 -0
  7. sglang/srt/configs/model_config.py +161 -4
  8. sglang/srt/configs/qwen2vl.py +5 -8
  9. sglang/srt/constrained/outlines_backend.py +6 -1
  10. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  11. sglang/srt/distributed/__init__.py +3 -0
  12. sglang/srt/distributed/communication_op.py +34 -0
  13. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  14. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  15. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  16. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  17. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  21. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  22. sglang/srt/distributed/parallel_state.py +1275 -0
  23. sglang/srt/distributed/utils.py +223 -0
  24. sglang/srt/hf_transformers_utils.py +37 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  26. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  27. sglang/srt/layers/fused_moe_patch.py +20 -11
  28. sglang/srt/layers/linear.py +1 -0
  29. sglang/srt/layers/logits_processor.py +17 -3
  30. sglang/srt/layers/quantization/__init__.py +34 -0
  31. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  32. sglang/srt/lora/lora.py +1 -1
  33. sglang/srt/managers/io_struct.py +48 -2
  34. sglang/srt/managers/schedule_batch.py +18 -14
  35. sglang/srt/managers/schedule_policy.py +7 -4
  36. sglang/srt/managers/scheduler.py +76 -20
  37. sglang/srt/managers/tokenizer_manager.py +166 -68
  38. sglang/srt/managers/tp_worker.py +36 -3
  39. sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
  40. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  41. sglang/srt/model_executor/forward_batch_info.py +9 -4
  42. sglang/srt/model_executor/model_runner.py +136 -150
  43. sglang/srt/model_loader/__init__.py +34 -0
  44. sglang/srt/model_loader/loader.py +1139 -0
  45. sglang/srt/model_loader/utils.py +41 -0
  46. sglang/srt/model_loader/weight_utils.py +640 -0
  47. sglang/srt/models/baichuan.py +9 -10
  48. sglang/srt/models/chatglm.py +6 -15
  49. sglang/srt/models/commandr.py +2 -3
  50. sglang/srt/models/dbrx.py +2 -3
  51. sglang/srt/models/deepseek.py +4 -11
  52. sglang/srt/models/deepseek_v2.py +3 -11
  53. sglang/srt/models/exaone.py +2 -3
  54. sglang/srt/models/gemma.py +2 -6
  55. sglang/srt/models/gemma2.py +3 -14
  56. sglang/srt/models/gemma2_reward.py +0 -1
  57. sglang/srt/models/gpt2.py +5 -12
  58. sglang/srt/models/gpt_bigcode.py +6 -22
  59. sglang/srt/models/grok.py +3 -3
  60. sglang/srt/models/internlm2.py +2 -3
  61. sglang/srt/models/internlm2_reward.py +0 -1
  62. sglang/srt/models/llama.py +97 -27
  63. sglang/srt/models/llama_classification.py +1 -2
  64. sglang/srt/models/llama_embedding.py +1 -2
  65. sglang/srt/models/llama_reward.py +2 -3
  66. sglang/srt/models/llava.py +1 -4
  67. sglang/srt/models/llavavid.py +1 -2
  68. sglang/srt/models/minicpm.py +4 -7
  69. sglang/srt/models/minicpm3.py +6 -19
  70. sglang/srt/models/mixtral.py +12 -5
  71. sglang/srt/models/mixtral_quant.py +2 -3
  72. sglang/srt/models/mllama.py +3 -7
  73. sglang/srt/models/olmo.py +2 -8
  74. sglang/srt/models/olmo2.py +0 -1
  75. sglang/srt/models/olmoe.py +3 -5
  76. sglang/srt/models/phi3_small.py +8 -8
  77. sglang/srt/models/qwen.py +2 -3
  78. sglang/srt/models/qwen2.py +10 -9
  79. sglang/srt/models/qwen2_moe.py +4 -11
  80. sglang/srt/models/qwen2_vl.py +2 -6
  81. sglang/srt/models/registry.py +99 -0
  82. sglang/srt/models/stablelm.py +2 -3
  83. sglang/srt/models/torch_native_llama.py +6 -12
  84. sglang/srt/models/xverse.py +2 -4
  85. sglang/srt/models/xverse_moe.py +4 -11
  86. sglang/srt/models/yivl.py +2 -3
  87. sglang/srt/openai_api/adapter.py +9 -5
  88. sglang/srt/openai_api/protocol.py +1 -0
  89. sglang/srt/server.py +267 -170
  90. sglang/srt/server_args.py +65 -31
  91. sglang/srt/utils.py +245 -28
  92. sglang/test/test_utils.py +7 -0
  93. sglang/version.py +1 -1
  94. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
  95. sglang-0.4.0.dist-info/RECORD +184 -0
  96. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  97. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  98. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  99. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/utils.py (new file)
@@ -0,0 +1,223 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/utils.py
+# Copyright 2023 The vLLM team.
+# Adapted from
+# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+import dataclasses
+import logging
+import os
+import pickle
+import time
+from collections import deque
+from typing import Any, Deque, Dict, Optional, Sequence, Tuple
+
+import torch
+from torch.distributed import TCPStore
+
+logger = logging.getLogger(__name__)
+
+
+def ensure_divisibility(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator."""
+    assert numerator % denominator == 0, "{} is not divisible by {}".format(
+        numerator, denominator
+    )
+
+
+def divide(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator and return
+    the division value."""
+    ensure_divisibility(numerator, denominator)
+    return numerator // denominator
+
+
+def split_tensor_along_last_dim(
+    tensor: torch.Tensor,
+    num_partitions: int,
+    contiguous_split_chunks: bool = False,
+) -> Sequence[torch.Tensor]:
+    """Split a tensor along its last dimension.
+
+    Arguments:
+        tensor: input tensor.
+        num_partitions: number of partitions to split the tensor
+        contiguous_split_chunks: If True, make each chunk contiguous
+                                 in memory.
+
+    Returns:
+        A list of Tensors
+    """
+    # Get the size and dimension.
+    last_dim = tensor.dim() - 1
+    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
+    # Split.
+    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+    # NOTE: torch.split does not create contiguous tensors by default.
+    if contiguous_split_chunks:
+        return tuple(chunk.contiguous() for chunk in tensor_list)
+
+    return tensor_list
+
+
+def get_pp_indices(
+    num_hidden_layers: int, pp_rank: int, pp_size: int
+) -> Tuple[int, int]:
+    """Try to evenly distribute layers across partitions.
+    If the number of layers is not divisible by the number of partitions,
+    the last partition will have the remaining layers.
+    """
+    # partition_list_str can be set to None in sglang
+    partition_list_str = os.getenv("SGLANG_PP_LAYER_PARTITION", None)
+    if partition_list_str is not None:
+        try:
+            partitions = [int(layer) for layer in partition_list_str.split(",")]
+        except ValueError as err:
+            raise ValueError(
+                "Invalid partition string: {}".format(partition_list_str)
+            ) from err
+        if len(partitions) != pp_size:
+            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
+        if sum(partitions) != num_hidden_layers:
+            raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")
+        start_layer = sum(partitions[:pp_rank])
+        end_layer = start_layer + partitions[pp_rank]
+    else:
+        layers_per_partition = num_hidden_layers // pp_size
+        start_layer = pp_rank * layers_per_partition
+        end_layer = start_layer + layers_per_partition
+
+        if pp_rank == pp_size - 1:
+            end_layer = num_hidden_layers
+
+    return (start_layer, end_layer)
+
+
+@dataclasses.dataclass
+class StatelessProcessGroup:
+    """A dataclass to hold a metadata store, and the rank, world_size of the
+    group. Only use it to communicate metadata between processes.
+    For data-plane communication, create NCCL-related objects.
+    """
+
+    rank: int
+    world_size: int
+    store: torch._C._distributed_c10d.Store
+    data_expiration_seconds: int = 3600  # 1 hour
+
+    # dst rank -> counter
+    send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    # src rank -> counter
+    recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    broadcast_send_counter: int = 0
+    broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+
+    # A deque to store the data entries, with key and timestamp.
+    entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque)
+
+    def __post_init__(self):
+        assert self.rank < self.world_size
+        self.send_dst_counter = {i: 0 for i in range(self.world_size)}
+        self.recv_src_counter = {i: 0 for i in range(self.world_size)}
+        self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}
+
+    def send_obj(self, obj: Any, dst: int):
+        """Send an object to a destination rank."""
+        self.expire_data()
+        key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
+        self.store.set(key, pickle.dumps(obj))
+        self.send_dst_counter[dst] += 1
+        self.entries.append((key, time.time()))
+
+    def expire_data(self):
+        """Expire data that is older than `data_expiration_seconds` seconds."""
+        while self.entries:
+            # check the oldest entry
+            key, timestamp = self.entries[0]
+            if time.time() - timestamp > self.data_expiration_seconds:
+                self.store.delete_key(key)
+                self.entries.popleft()
+            else:
+                break
+
+    def recv_obj(self, src: int) -> Any:
+        """Receive an object from a source rank."""
+        obj = pickle.loads(
+            self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}")
+        )
+        self.recv_src_counter[src] += 1
+        return obj
+
+    def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
+        """Broadcast an object from a source rank to all other ranks.
+        It does not clean up after all ranks have received the object.
+        Use it for limited times, e.g., for initialization.
+        """
+        if self.rank == src:
+            self.expire_data()
+            key = f"broadcast_from/{src}/{self.broadcast_send_counter}"
+            self.store.set(key, pickle.dumps(obj))
+            self.broadcast_send_counter += 1
+            self.entries.append((key, time.time()))
+            return obj
+        else:
+            key = f"broadcast_from/{src}/{self.broadcast_recv_src_counter[src]}"
+            recv_obj = pickle.loads(self.store.get(key))
+            self.broadcast_recv_src_counter[src] += 1
+            return recv_obj
+
+    def all_gather_obj(self, obj: Any) -> list[Any]:
+        """All gather an object from all ranks."""
+        gathered_objs = []
+        for i in range(self.world_size):
+            if i == self.rank:
+                gathered_objs.append(obj)
+                self.broadcast_obj(obj, src=self.rank)
+            else:
+                recv_obj = self.broadcast_obj(None, src=i)
+                gathered_objs.append(recv_obj)
+        return gathered_objs
+
+    def barrier(self):
+        """A barrier to synchronize all ranks."""
+        for i in range(self.world_size):
+            if i == self.rank:
+                self.broadcast_obj(None, src=self.rank)
+            else:
+                self.broadcast_obj(None, src=i)
+
+    @staticmethod
+    def create(
+        host: str,
+        port: int,
+        rank: int,
+        world_size: int,
+        data_expiration_seconds: int = 3600,
+    ) -> "StatelessProcessGroup":
+        """A replacement for `torch.distributed.init_process_group` that does not
+        pollute the global state.
+
+        If we have process A and process B called `torch.distributed.init_process_group`
+        to form a group, and then we want to form another group with process A, B, C,
+        D, it is not possible in PyTorch, because process A and process B have already
+        formed a group, and process C and process D cannot join that group. This
+        function is a workaround for this issue.
+
+        `torch.distributed.init_process_group` is a global call, while this function
+        is a stateless call. It will return a `StatelessProcessGroup` object that can be
+        used for exchanging metadata. With this function, process A and process B
+        can call `StatelessProcessGroup.create` to form a group, and then process A, B,
+        C, and D can call `StatelessProcessGroup.create` to form another group.
+        """  # noqa
+        store = TCPStore(
+            host_name=host,
+            port=port,
+            world_size=world_size,
+            is_master=(rank == 0),
+        )
+
+        return StatelessProcessGroup(
+            rank=rank,
+            world_size=world_size,
+            store=store,
+            data_expiration_seconds=data_expiration_seconds,
+        )
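
The new `StatelessProcessGroup` above is a control-plane helper: it only moves pickled metadata through a `TCPStore`, leaving data-plane traffic to NCCL-backed communicators. A minimal usage sketch, with a placeholder host/port and one copy of the function run per rank:

    # Hypothetical two-rank example built on the API shown above.
    from sglang.srt.distributed.utils import StatelessProcessGroup

    def exchange_metadata(rank: int, world_size: int = 2):
        group = StatelessProcessGroup.create(
            host="127.0.0.1", port=29600, rank=rank, world_size=world_size
        )
        # Rank 0 broadcasts a small config dict; every rank gets the same object back.
        config = group.broadcast_obj({"nccl_id": b"..."} if rank == 0 else None, src=0)
        # Collect one entry per rank, then synchronize before moving on.
        names = group.all_gather_obj(f"rank-{rank}")
        group.barrier()
        return config, names

Unlike `torch.distributed.init_process_group`, nothing global is mutated, so overlapping groups (A+B, then A+B+C+D) can be created freely, which is the motivation stated in the docstring of `StatelessProcessGroup.create`.
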
sglang/srt/hf_transformers_utils.py
@@ -16,6 +16,7 @@
 import contextlib
 import os
 import warnings
+from pathlib import Path
 from typing import Dict, Optional, Type, Union
 
 from huggingface_hub import snapshot_download
@@ -27,6 +28,7 @@ from transformers import (
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 try:
     from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
@@ -60,15 +62,31 @@ def get_config(
     trust_remote_code: bool,
     revision: Optional[str] = None,
     model_override_args: Optional[dict] = None,
+    **kwargs,
 ):
+    is_gguf = check_gguf_file(model)
+    if is_gguf:
+        kwargs["gguf_file"] = model
+        model = Path(model).parent
+
     config = AutoConfig.from_pretrained(
-        model, trust_remote_code=trust_remote_code, revision=revision
+        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
     )
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model, revision=revision)
+        # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
+        setattr(config, "_name_or_path", model)
     if model_override_args:
         config.update(model_override_args)
+
+    # Special architecture mapping check for GGUF models
+    if is_gguf:
+        if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+            raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
+        model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
+        config.update({"architectures": [model_type]})
+
     return config
 
 
@@ -123,6 +141,11 @@ def get_tokenizer(
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False
 
+    is_gguf = check_gguf_file(tokenizer_name)
+    if is_gguf:
+        kwargs["gguf_file"] = tokenizer_name
+        tokenizer_name = Path(tokenizer_name).parent
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
@@ -195,3 +218,16 @@ def attach_additional_stop_token_ids(tokenizer):
         )
     else:
         tokenizer.additional_stop_token_ids = None
+
+
+def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
+    """Check if the file is a GGUF model."""
+    model = Path(model)
+    if not model.is_file():
+        return False
+    elif model.suffix == ".gguf":
+        return True
+
+    with open(model, "rb") as f:
+        header = f.read(4)
+    return header == b"GGUF"
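
Taken together, these hunks let `get_config` and `get_tokenizer` accept a GGUF checkpoint directly: `check_gguf_file` detects the `.gguf` suffix or the `GGUF` magic bytes, the file is forwarded to transformers via the `gguf_file` kwarg, and its parent directory becomes the model path. A short caller sketch, with a hypothetical local path:

    # The .gguf path below is a placeholder for illustration.
    from sglang.srt.hf_transformers_utils import check_gguf_file, get_config, get_tokenizer

    ckpt = "/models/llama-3-8b/model-q4_k_m.gguf"
    if check_gguf_file(ckpt):
        config = get_config(ckpt, trust_remote_code=False)  # architectures remapped via MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        tokenizer = get_tokenizer(ckpt)                      # rebuilt from GGUF metadata by transformers
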
sglang/srt/layers/attention/flashinfer_backend.py
@@ -18,7 +18,11 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import get_bool_env_var, is_flashinfer_available
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_flashinfer_available,
+    should_use_tensor_core,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -31,7 +35,6 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 
 class WrapperDispatch(Enum):
@@ -45,19 +48,14 @@ class FlashInferAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
 
-        # Parse constants
-        if "SGLANG_FLASHINFER_USE_TENSOR_CORE" in os.environ:
-            self.decode_use_tensor_cores = get_bool_env_var(
-                "SGLANG_FLASHINFER_USE_TENSOR_CORE"
-            )
-        else:
-            if not _grouped_size_compiled_for_decode_kernels(
-                model_runner.model_config.num_attention_heads // model_runner.tp_size,
-                model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-            ):
-                self.decode_use_tensor_cores = True
-            else:
-                self.decode_use_tensor_cores = False
+        self.decode_use_tensor_cores = should_use_tensor_core(
+            kv_cache_dtype=model_runner.kv_cache_dtype,
+            num_attention_heads=model_runner.model_config.num_attention_heads
+            // model_runner.tp_size,
+            num_kv_heads=model_runner.model_config.get_num_kv_heads(
+                model_runner.tp_size
+            ),
+        )
 
         self.max_context_len = model_runner.model_config.context_len
 
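
The inline tensor-core decision (env-var check plus `_grouped_size_compiled_for_decode_kernels`) is replaced by a single `should_use_tensor_core` helper in `sglang.srt.utils`, whose body is not shown in this diff. Purely as an illustration, a heuristic of this shape would honor an env override and prefer tensor cores for FP8 KV caches or large GQA group sizes; the real helper may differ:

    # Illustrative sketch only -- not the verbatim sglang.srt.utils implementation.
    import os
    import torch

    def should_use_tensor_core_sketch(
        kv_cache_dtype: torch.dtype,
        num_attention_heads: int,
        num_kv_heads: int,
    ) -> bool:
        env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
        if env_override is not None:
            return env_override.lower() == "true"
        if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
            return True  # FP8 decode favors the tensor-core kernels
        return (num_attention_heads // num_kv_heads) >= 4  # large GQA group size
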
sglang/srt/layers/attention/torch_native_backend.py (new file)
@@ -0,0 +1,285 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+import torch
+from torch.nn.functional import scaled_dot_product_attention
+
+from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+class TorchNativeAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.forward_metadata = None
+        self.device = model_runner.device
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        pass
+
+    def init_cuda_graph_state(self, max_bs: int):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor] = None,
+    ):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor] = None,
+    ):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+
+    def _run_sdpa_forward_extend(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        extend_prefix_lens: torch.Tensor,
+        extend_seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        """Run the extend forward by using torch native sdpa op.
+
+        Args:
+            query: [num_tokens, num_heads, head_size]
+            output: [num_tokens, num_heads, head_size]
+            k_cache: [max_total_num_tokens, num_heads, head_size]
+            v_cache: [max_total_num_tokens, num_heads, head_size]
+            req_to_token: [max_num_reqs, max_context_len]
+            req_pool_indices: [num_seqs]
+            seq_lens: [num_seqs]
+            extend_prefix_lens: [num_seqs]
+            extend_seq_lens: [num_seqs]
+            scaling: float or None
+            enable_gqa: bool
+            causal: bool
+
+        Returns:
+            output: [num_tokens, num_heads, head_size]
+        """
+
+        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
+        assert seq_lens.shape[0] == extend_seq_lens.shape[0]
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+        for seq_idx in range(seq_lens.shape[0]):
+            # TODO: this loop process a sequence per iter, this is inefficient.
+            # Need optimize the performance later.
+
+            extend_seq_len_q = extend_seq_lens[seq_idx]
+            prefill_seq_len_q = extend_prefix_lens[seq_idx]
+
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + extend_seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+            per_req_query_redudant = torch.empty(
+                (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
+                dtype=per_req_query.dtype,
+                device=per_req_query.device,
+            )
+
+            per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query
+
+            # get key and value from cache. per_req_tokens contains the kv cache
+            # index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            per_req_out_redudant = (
+                scaled_dot_product_attention(
+                    per_req_query_redudant.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    enable_gqa=enable_gqa,
+                    scale=scaling,
+                    is_causal=causal,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+            output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :]
+            start_q, start_kv = end_q, end_kv
+        return output
+
+    def _run_sdpa_forward_decode(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        """Run the decode forward by using torch native sdpa op.
+
+        Args:
+            query: [num_tokens, num_heads, head_size]
+            output: [num_tokens, num_heads, head_size]
+            k_cache: [max_total_num_tokens, num_heads, head_size]
+            v_cache: [max_total_num_tokens, num_heads, head_size]
+            req_to_token: [max_num_reqs, max_context_len]
+            req_pool_indices: [num_seqs]
+            seq_lens: [num_seqs]
+            scaling: float or None
+            enable_gqa: bool
+            causal: bool
+
+        Returns:
+            output: [num_tokens, num_heads, head_size]
+        """
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+        for seq_idx in range(seq_lens.shape[0]):
+            # TODO: this loop process a sequence per iter, this is inefficient.
+            # Need optimize the performance later.
+
+            seq_len_q = 1
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+
+            # get key and value from cache. per_req_tokens contains the kv cache
+            # index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            per_req_out = (
+                scaled_dot_product_attention(
+                    per_req_query.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    enable_gqa=enable_gqa,
+                    scale=scaling,
+                    is_causal=causal,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+            output[start_q:end_q, :, :] = per_req_out
+            start_q, start_kv = end_q, end_kv
+
+        return output
+
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer, forward_batch.out_cache_loc, k, v
+        )
+
+        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+        self._run_sdpa_forward_extend(
+            q_,
+            o_,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_prefix_lens,
+            forward_batch.extend_seq_lens,
+            scaling=layer.scaling,
+            enable_gqa=use_gqa,
+            causal=not layer.is_cross_attention,
+        )
+        return o
+
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
+        # During torch.compile, there is a bug in rotary_emb that causes the
+        # output value to have a 3D tensor shape. This reshapes the output correctly.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer, forward_batch.out_cache_loc, k, v
+        )
+
+        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+        self._run_sdpa_forward_decode(
+            q_,
+            o_,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            scaling=layer.scaling,
+            enable_gqa=use_gqa,
+            causal=False,
+        )
+
+        return o
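
The backend gathers each request's cached keys/values through `req_to_token` indexing and then issues one `torch.nn.functional.scaled_dot_product_attention` call per sequence, with `enable_gqa` covering grouped-query models. A stand-alone sketch of that core call with made-up shapes (8 query heads sharing 2 KV heads, 16 cached tokens, head size 64); note that `enable_gqa` requires a recent PyTorch:

    # Toy shapes only; decode case, so the query length is 1.
    import torch
    from torch.nn.functional import scaled_dot_product_attention

    q = torch.randn(1, 8, 1, 64)   # [batch, num_q_heads, q_len, head_size]
    k = torch.randn(1, 2, 16, 64)  # [batch, num_kv_heads, kv_len, head_size]
    v = torch.randn(1, 2, 16, 64)
    out = scaled_dot_product_attention(
        q, k, v,
        enable_gqa=True,    # 8 query heads share 2 KV heads (group size 4)
        scale=64 ** -0.5,   # plays the role of layer.scaling in the backend
        is_causal=False,    # decode attends over the full cached prefix
    )
    print(out.shape)        # torch.Size([1, 8, 1, 64])
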
sglang/srt/layers/fused_moe_patch.py
@@ -105,20 +105,29 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
 ) -> torch.Tensor:
-    assert custom_routing_function is None
-    topk_weights, topk_ids = select_experts_native(
-        hidden_states=x,
-        router_logits=router_logits,
-        use_grouped_topk=use_grouped_topk,
-        top_k=top_k,
-        renormalize=renormalize,
-        topk_group=topk_group,
-        num_expert_group=num_expert_group,
-    )
+
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            x,
+            router_logits,
+            top_k,
+            renormalize,
+            num_expert_group,
+            topk_group,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            x, router_logits, top_k, renormalize
+        )
+
     w13_weights = layer.w13_weight[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
-    x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
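
In this native path the expert math stays as einsums over per-token gathered expert weights: `t` indexes tokens, `a` the top-k experts selected for each token, and `i`/`o` the hidden and intermediate features. A small shape check with toy sizes (hidden 4, intermediate 8, top-2 routing) confirms the contractions:

    # Shape sanity check for the einsum contractions above (random toy weights).
    import torch
    import torch.nn.functional as F

    t, a, i, o = 3, 2, 4, 8            # tokens, top-k experts, hidden, intermediate
    x = torch.randn(t, i)
    w1 = torch.randn(t, a, o, i)       # one half of layer.w13_weight[topk_ids]
    w3 = torch.randn(t, a, o, i)       # the other half
    w2 = torch.randn(t, a, i, o)       # layer.w2_weight[topk_ids]
    topk_weights = torch.softmax(torch.randn(t, a), dim=-1)

    x1 = F.silu(torch.einsum("ti,taoi->tao", x, w1))
    x3 = torch.einsum("ti,taoi->tao", x, w3)
    expert_outs = torch.einsum("tao,taio->tai", x1 * x3, w2)
    out = torch.einsum("tai,ta->ti", expert_outs, topk_weights)
    print(out.shape)                   # torch.Size([3, 4])
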
sglang/srt/layers/linear.py
@@ -42,6 +42,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "Fp8LinearMethod",
     "MarlinLinearMethod",
     "GPTQLinearMethod",
+    "QQQLinearMethod",
 ]
 
 