sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff shows the contents of publicly released package versions as published to their public registries, and is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/layers/communicator.py
CHANGED
@@ -28,9 +28,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
+    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -226,31 +226,32 @@ class LayerCommunicator:

 @dataclass
 class CommunicateContext:
-    process_group_sizes: Dict[
+    process_group_sizes: Dict[ScatterMode, int]
     attn_tp_rank: int
     attn_tp_size: int
-
+    attn_dp_size: int
     tp_size: int

-    def is_same_group_size(self, a:
+    def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]

     @classmethod
     def init_new(cls):
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
-
+        attn_dp_size = get_attention_dp_size()
         tp_size = get_tensor_model_parallel_world_size()
         process_group_sizes = {
             ScatterMode.SCATTERED: 1,
             ScatterMode.TP_ATTN_FULL: attn_tp_size,
+            # TODO: support --moe-dense-tp-size > 1
             ScatterMode.FULL: tp_size,
         }
         return cls(
             process_group_sizes=process_group_sizes,
             attn_tp_rank=attn_tp_rank,
             attn_tp_size=attn_tp_size,
-
+            attn_dp_size=attn_dp_size,
             tp_size=tp_size,
         )

@@ -323,11 +324,16 @@ class CommunicateWithAllReduceAndLayerNormFn:

         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
-            and (
+            and (
+                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
+            )
             and (hidden_states_output_mode == ScatterMode.FULL)
             and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
         ):
-            return
+            return partial(
+                CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual,
+                residual_input_mode=residual_input_mode,
+            )

         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
@@ -360,14 +366,26 @@ class CommunicateWithAllReduceAndLayerNormFn:
             return hidden_states, residual

     @staticmethod
-    def
+    def _gather_hidden_states_and_residual(
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         layernorm: torch.nn.Module,
         context: CommunicateContext,
+        *,
+        residual_input_mode,
     ):
-        if context.
+        if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
+            residual, local_residual = (
+                forward_batch.gathered_buffer[
+                    : forward_batch.input_ids.shape[0]
+                ].clone(),
+                residual,
+            )
+            attn_tp_all_gather(
+                list(residual.tensor_split(context.attn_tp_size)), local_residual
+            )
+        if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
             hidden_states, local_hidden_states = (
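For context on the new `attn_dp_size` field, below is a minimal standalone sketch of the group-size bookkeeping that `CommunicateContext` performs. The enum members mirror the names in the diff, but the concrete sizes and the simplified `ScatterMode` definition are assumptions for illustration, not sglang's runtime state.

```python
# Minimal sketch (assumed values): CommunicateContext keeps one process-group size
# per ScatterMode so a layer can skip communication when two layouts already share
# the same group size.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict


class ScatterMode(Enum):
    SCATTERED = auto()      # tokens split across attention-TP ranks
    TP_ATTN_FULL = auto()   # full copy within the attention-TP group
    FULL = auto()           # full copy within the whole TP group


@dataclass
class CommunicateContext:
    process_group_sizes: Dict[ScatterMode, int]
    attn_tp_rank: int
    attn_tp_size: int
    attn_dp_size: int
    tp_size: int

    def is_same_group_size(self, a: ScatterMode, b: ScatterMode) -> bool:
        return self.process_group_sizes[a] == self.process_group_sizes[b]


# Hypothetical setup: attention TP = 2, attention DP = 2, total TP = 4.
ctx = CommunicateContext(
    process_group_sizes={
        ScatterMode.SCATTERED: 1,
        ScatterMode.TP_ATTN_FULL: 2,
        ScatterMode.FULL: 4,
    },
    attn_tp_rank=0,
    attn_tp_size=2,
    attn_dp_size=2,
    tp_size=4,
)
assert not ctx.is_same_group_size(ScatterMode.TP_ATTN_FULL, ScatterMode.FULL)
```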
sglang/srt/layers/dp_attention.py
CHANGED
@@ -165,7 +165,8 @@ def disable_dp_size():


 def get_dp_local_info(forward_batch: ForwardBatch):
-
+    # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
+    dp_rank = get_attention_dp_rank()

     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
@@ -238,6 +239,10 @@ def _dp_gather(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between global_tokens and local_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
     memcpy_triton(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )
@@ -288,6 +293,10 @@ def dp_scatter(
     assert (
         local_tokens.untyped_storage() is not global_tokens.untyped_storage()
     ), "aliasing between local_tokens and global_tokens not allowed"
+    if forward_batch.forward_mode.is_draft_extend():
+        shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
+        local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)
+
     memcpy_triton(
         local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
     )
@@ -301,4 +310,4 @@ def attn_tp_reduce_scatter(


 def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
-    return get_attention_tp_group().all_gather(input_,
+    return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
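The new draft-extend branch in `_dp_gather` and `dp_scatter` clamps the device-side token count to the number of rows actually present in `local_tokens`. A minimal sketch of just that clamping step follows; the tensor shapes are made up for illustration.

```python
# Sketch of the clamp added for draft-extend mode: local_num_tokens lives on the
# device as a 0-dim tensor, so the row count is clamped with torch.minimum against
# a scalar tensor built via new_full rather than with a Python int.
import torch

local_tokens = torch.randn(6, 8)     # only 6 rows exist in the local buffer
local_num_tokens = torch.tensor(10)  # count reported by the batch metadata

shape_tensor = local_num_tokens.new_full((), local_tokens.shape[0])
local_num_tokens = torch.minimum(local_num_tokens, shape_tensor)

print(local_num_tokens.item())  # 6: never copy more rows than local_tokens holds
```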
sglang/srt/layers/layernorm.py
CHANGED
@@ -20,11 +20,21 @@ import torch
 import torch.nn as nn

 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)

 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()

 if _is_cuda:
     from sgl_kernel import (
@@ -121,6 +131,23 @@ class RMSNorm(CustomOp):
         else:
             return x, residual

+    def forward_cpu(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if _is_cpu_amx_available:
+            if residual is not None:
+                torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
+                    x, residual, self.weight.data, self.variance_epsilon
+                )
+                return x, residual
+            return torch.ops.sgl_kernel.rmsnorm_cpu(
+                x, self.weight.data, self.variance_epsilon
+            )
+        else:
+            return self.forward_native(x, residual)
+

 class GemmaRMSNorm(CustomOp):
     def __init__(
@@ -187,7 +214,7 @@ class Gemma3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"


-if not (_is_cuda or _is_hip):
+if not (_is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available)):
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
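The new `forward_cpu` path calls the fused `sgl_kernel` CPU ops when AMX is available and otherwise falls back to the native implementation. As a reference point, here is a minimal sketch of the RMSNorm math that such a native fallback computes; this is the standard formulation, not sglang's exact code.

```python
# Reference RMSNorm (standard formulation): normalize by the root-mean-square over
# the hidden dimension, then scale by a learned weight. The optional residual is
# added before normalization, mirroring the fused_add_rmsnorm pattern.
import torch


def rms_norm_reference(x, weight, eps=1e-6, residual=None):
    if residual is not None:
        x = x + residual
        residual = x
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps) * weight
    return out if residual is None else (out, residual)


x = torch.randn(4, 8)
w = torch.ones(8)
print(rms_norm_reference(x, w).shape)  # torch.Size([4, 8])
```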
sglang/srt/layers/linear.py
CHANGED
@@ -546,8 +546,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             param.shard_id.append(loaded_shard_id)
             param.shard_id_map[loaded_shard_id] = len(param.data_container)
             param.data_container.append(loaded_weight)
-            if len(param.data_container) == 2:
-                self.qweight = param.materialize_nested()
             return

         param_data = param.data
@@ -961,8 +959,6 @@ class QKVParallelLinear(ColumnParallelLinear):
             param.shard_id.append(loaded_shard_id)
             param.shard_id_map[loaded_shard_id] = len(param.data_container)
             param.data_container.append(loaded_weight)
-            if len(param.data_container) == 3:
-                self.qweight = param.materialize_nested()
             return

         param_data = param.data
sglang/srt/layers/logits_processor.py
CHANGED
@@ -30,9 +30,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
+    get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
-    get_local_attention_dp_rank,
     get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -47,18 +47,6 @@ from sglang.srt.utils import dump_to_file
 logger = logging.getLogger(__name__)


-from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import (
-    CaptureHiddenMode,
-    ForwardBatch,
-    ForwardMode,
-)
-from sglang.srt.utils import dump_to_file
-
-logger = logging.getLogger(__name__)
-
-
 @dataclasses.dataclass
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -183,7 +171,7 @@ class LogitsMetadata:
             return

         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank =
+        dp_rank = get_attention_dp_rank()
         if dp_rank == 0:
             dp_local_start_pos = torch.zeros_like(
                 self.global_num_tokens_for_logprob_gpu[0]
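The switch to `dp_rank = get_attention_dp_rank()` feeds the same cumulative-sum bookkeeping used for DP gather/scatter: each rank's local slice starts where the previous ranks' token counts end. A generic sketch of that offset computation with made-up per-rank counts (not sglang's exact code, which keeps these tensors on the GPU per forward batch):

```python
# Per-rank offsets from a cumulative sum of global token counts (hypothetical counts).
import torch

global_num_tokens = torch.tensor([3, 5, 2, 4])  # tokens contributed by DP ranks 0..3
cumtokens = torch.cumsum(global_num_tokens, dim=0)

dp_rank = 2
start = torch.zeros((), dtype=cumtokens.dtype) if dp_rank == 0 else cumtokens[dp_rank - 1]
num = global_num_tokens[dp_rank]

print(start.item(), num.item())  # 8 2 -> rank 2 owns rows [8, 10) of the gathered buffer
```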
sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -4,6 +4,7 @@ from typing import List, Optional
 import torch
 import triton

+from sglang.math_utils import ceil_div
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import dispose_tensor, is_cuda

@@ -15,11 +16,6 @@ if _is_cuda:
         sglang_per_token_group_quant_fp8 as per_token_group_quant_fp8,
     )

-    try:
-        from deep_gemm import ceil_div
-    except ImportError:
-        logger.error(f"Failed to import ceil_div from deep_gemm.")
-
 import triton.language as tl


@@ -278,6 +274,7 @@ def _silu_and_mul_post_quant_kernel(
     fp8_min,
     BLOCK_N: tl.constexpr,
     NUM_STAGE: tl.constexpr,
+    SCALE_UE8M0: tl.constexpr,
 ):
     expert_id = tl.program_id(2)
     token_id = tl.program_id(1)
@@ -319,6 +316,8 @@ def _silu_and_mul_post_quant_kernel(
         gate_up = up * gate
         _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10)
         output_s = _absmax / fp8_max
+        if SCALE_UE8M0:
+            output_s = tl.exp2(tl.ceil(tl.log2(tl.abs(output_s))))
         output_q = tl.clamp(gate_up / output_s, fp8_min, fp8_max).to(
             output_ptr.dtype.element_ty
         )
@@ -339,6 +338,7 @@ def silu_and_mul_masked_post_quant_fwd(
     output_scale: torch.Tensor,
     quant_group_size: int,
     masked_m: torch.Tensor,
+    scale_ue8m0: bool = False,
 ):
     """
     input shape [expert_num, token_num_padded, hidden_dim]
@@ -395,6 +395,7 @@ def silu_and_mul_masked_post_quant_fwd(
         BLOCK_N=BLOCK_N,
         NUM_STAGE=NUM_STAGES,
         num_warps=num_warps,
+        SCALE_UE8M0=scale_ue8m0,
     )
     return

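The new `SCALE_UE8M0` flag rounds each per-group FP8 scale up to the nearest power of two (a UE8M0-representable value). A small plain-Python sketch of that rounding, matching the `tl.exp2(tl.ceil(tl.log2(...)))` expression in the Triton kernel; the input value is illustrative.

```python
# UE8M0 scale rounding: round a positive scale up to the next power of two.
import math


def round_scale_ue8m0(s: float) -> float:
    return 2.0 ** math.ceil(math.log2(abs(s)))


print(round_scale_ue8m0(0.0131))  # 0.015625 == 2**-6
```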
@@ -477,11 +478,13 @@ def post_reorder_triton_kernel(
     end_expert_id,
     topk,
     hidden_size,
+    dst_start,
     BLOCK_SIZE: tl.constexpr,
 ):
     InDtype = down_output_ptr.dtype.element_ty

-
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
     src2dst_ptr = src2dst_ptr + src_idx * topk
     topk_ids_ptr = topk_ids_ptr + src_idx * topk
     topk_weights_ptr = topk_weights_ptr + src_idx * topk
@@ -500,7 +503,9 @@ def post_reorder_triton_kernel(
         expert_id = tl.load(topk_ids_ptr + idx)
         if expert_id >= start_expert_id and expert_id <= end_expert_id:
             computed = True
-
+            dst_idx_int32 = tl.load(src2dst_ptr + idx)
+            dst_idx = dst_idx_int32.to(tl.int64)
+            dst_idx = dst_idx - dst_start
             weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
             load_ptr = down_output_ptr + dst_idx * hidden_size
             in_data = tl.load(load_ptr + offset, mask=mask)
@@ -1085,3 +1090,156 @@ def tma_align_input_scale(input_scale: torch.Tensor):
         BLOCK_SIZE_K=BLOCK_SIZE_K,
     )
     return output.t()[:m]
+
+
+@triton.jit
+def compute_masked_m_triton_kernel(seg_indptr, masked_m):
+    expert_id = tl.program_id(0)
+    start = tl.load(seg_indptr + expert_id)
+    end = tl.load(seg_indptr + expert_id + 1)
+    tl.store(masked_m + expert_id, (end - start))
+
+
+@triton.jit
+def deepgemm_compute_src2dst_triton_kernel(
+    topk_ids,
+    reorder_ids,
+    seg_indptr,
+    src2dst,
+    m_max,
+    num_toks,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    expert_id = tl.load(topk_ids + src_id, mask=(src_id < num_toks))
+    expert_dst_start = tl.load(seg_indptr + expert_id)
+    expert_dst_offset = dst_id - expert_dst_start
+    dst_id = expert_id * m_max + expert_dst_offset
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+@triton.jit
+def fill_gateup_input_triton_kernel(
+    input_ptr,
+    scale_ptr,
+    gateup_input_ptr,
+    gateup_input_scale_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    start_expert_id,
+    end_expert_id,
+    topk,
+    m_max,
+    hidden_size,
+    scale_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    src_ptr = input_ptr + src_idx * hidden_size
+    scale_src_ptr = scale_ptr + src_idx * scale_size
+
+    vec = tl.arange(0, BLOCK_SIZE)
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id >= start_expert_id and expert_id <= end_expert_id:
+            dst_idx_int32 = tl.load(src2dst_ptr + idx)
+            dst_idx = dst_idx_int32.to(tl.int64)
+            dst_idx = dst_idx - start_expert_id * m_max
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + vec
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask)
+                tl.store(dst_ptr + offset, in_data, mask=mask)
+            scale_dst_ptr = gateup_input_scale_ptr + dst_idx * scale_size
+            for start_offset in tl.range(0, scale_size, BLOCK_SIZE):
+                offset = start_offset + vec
+                mask = offset < scale_size
+                in_scale = tl.load(scale_src_ptr + offset, mask=mask)
+                tl.store(scale_dst_ptr + offset, in_scale, mask=mask)
+
+
+def moe_ep_deepgemm_preprocess(
+    topk_ids: torch.Tensor,
+    num_experts: int,
+    hidden_states: torch.Tensor,
+    top_k: int,
+    start_expert_id,
+    end_expert_id,
+    block_shape,
+    output_dtype: torch.dtype = torch.float8_e4m3fn,
+):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
+    masked_m = torch.zeros(num_experts, device=topk_ids.device, dtype=torch.int32)
+
+    compute_seg_indptr_triton_kernel[(num_experts,)](
+        reorder_topk_ids, seg_indptr, topk_ids.numel()
+    )
+
+    grid = lambda meta: (triton.cdiv(topk_ids.numel(), meta["BLOCK_SIZE"]),)
+    compute_masked_m_triton_kernel[(num_experts,)](seg_indptr, masked_m)
+
+    # For masked grouped GEMM, shape M should be multiple of the block M (current block M: {block_m}) https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/jit_kernels/m_grouped_gemm.py#L165
+    m_max = (hidden_states.size(0) + 255) // 256 * 256
+    expected_m = (topk_ids.numel() + num_experts - 1) // num_experts
+    gateup_input = torch.empty(
+        (int(end_expert_id - start_expert_id + 1), m_max, hidden_states.size(1)),
+        device=hidden_states.device,
+        dtype=output_dtype,
+    )
+
+    deepgemm_compute_src2dst_triton_kernel[grid](
+        topk_ids,
+        reorder_ids,
+        seg_indptr,
+        src2dst,
+        m_max,
+        topk_ids.numel(),
+        BLOCK_SIZE=256,
+    )
+
+    if block_shape is None:
+        block_shape = [128, 128]
+    assert len(block_shape) == 2
+    block_n, block_k = block_shape[0], block_shape[1]
+    hidden_states, scale = per_token_group_quant_fp8(hidden_states, block_k)
+
+    gateup_input_scale = torch.empty(
+        (gateup_input.size(0), gateup_input.size(1), scale.size(1)),
+        device=hidden_states.device,
+        dtype=scale.dtype,
+    )
+
+    fill_gateup_input_triton_kernel[(hidden_states.shape[0],)](
+        hidden_states,
+        scale,
+        gateup_input,
+        gateup_input_scale,
+        src2dst,
+        topk_ids,
+        start_expert_id,
+        end_expert_id,
+        top_k,
+        m_max,
+        hidden_states.size(1),
+        scale.size(1),
+        BLOCK_SIZE=1024,
+    )
+
+    return (
+        m_max,
+        masked_m[start_expert_id : (end_expert_id + 1)],
+        expected_m,
+        src2dst,
+        gateup_input,
+        gateup_input_scale,
+    )
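In `moe_ep_deepgemm_preprocess`, `m_max` pads the per-expert row count up to a multiple of 256 (the masked grouped GEMM wants M to be a multiple of its block M), and `expected_m` is the ceil-divided average number of routed tokens per expert. The padding arithmetic in isolation, with made-up sizes:

```python
# Padding arithmetic used by the preprocess step (illustrative sizes).
num_tokens = 1000          # hidden_states.size(0)
top_k = 8
num_experts = 64

m_max = (num_tokens + 255) // 256 * 256                              # -> 1024, multiple of 256
expected_m = (num_tokens * top_k + num_experts - 1) // num_experts   # ceil(8000 / 64) -> 125

print(m_max, expected_m)
```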