sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +0 -2
- sglang/bench_serving.py +224 -127
- sglang/compile_deep_gemm.py +3 -0
- sglang/launch_server.py +0 -14
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/falcon_h1.py +12 -58
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +68 -31
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +11 -43
- sglang/srt/disaggregation/decode.py +7 -18
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/nixl/conn.py +55 -23
- sglang/srt/disaggregation/prefill.py +17 -32
- sglang/srt/entrypoints/engine.py +2 -2
- sglang/srt/entrypoints/grpc_request_manager.py +10 -23
- sglang/srt/entrypoints/grpc_server.py +220 -80
- sglang/srt/entrypoints/http_server.py +49 -1
- sglang/srt/entrypoints/openai/protocol.py +159 -31
- sglang/srt/entrypoints/openai/serving_chat.py +13 -71
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +4 -0
- sglang/srt/function_call/function_call_parser.py +8 -6
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
- sglang/srt/layers/attention/attention_registry.py +31 -22
- sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
- sglang/srt/layers/attention/flashattention_backend.py +0 -1
- sglang/srt/layers/attention/flashinfer_backend.py +223 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/triton_backend.py +1 -1
- sglang/srt/layers/logits_processor.py +136 -6
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
- sglang/srt/layers/moe/ep_moe/layer.py +8 -286
- sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/utils.py +7 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/modelopt_quant.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/w4afp8.py +2 -16
- sglang/srt/lora/lora_manager.py +0 -8
- sglang/srt/managers/overlap_utils.py +18 -16
- sglang/srt/managers/schedule_batch.py +119 -90
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +213 -126
- sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
- sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
- sglang/srt/managers/tokenizer_manager.py +270 -53
- sglang/srt/managers/tp_worker.py +39 -28
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +162 -68
- sglang/srt/mem_cache/radix_cache.py +8 -3
- sglang/srt/mem_cache/swa_radix_cache.py +70 -14
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_executor/forward_batch_info.py +4 -18
- sglang/srt/model_executor/model_runner.py +55 -51
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +187 -6
- sglang/srt/model_loader/weight_utils.py +3 -0
- sglang/srt/models/falcon_h1.py +11 -9
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +11 -1
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/utils.py +5 -1
- sglang/srt/sampling/sampling_batch_info.py +11 -9
- sglang/srt/server_args.py +100 -33
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_utils.py +0 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils/common.py +18 -0
- sglang/srt/utils/hf_transformers_utils.py +2 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +40 -0
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +18 -2
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +63 -0
- sglang/test/test_utils.py +32 -11
- sglang/version.py +1 -1
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

```diff
@@ -60,7 +60,8 @@ _is_npu = is_npu()
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logits of the next tokens. shape: [#seq, vocab_size]
-    next_token_logits: torch.Tensor
+    # Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
+    next_token_logits: Optional[torch.Tensor]
     # Used by speculative decoding (EAGLE)
     # The last hidden layers
     hidden_states: Optional[torch.Tensor] = None
@@ -85,7 +86,10 @@ class LogitsProcessorOutput:
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
     # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
-    input_token_ids_logprobs_val: Optional[List] = None
+    # Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
+    input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
+        None
+    )
     input_token_ids_logprobs_idx: Optional[List] = None
 
 
@@ -127,6 +131,9 @@ class LogitsMetadata:
     # for padding
     padded_static_len: int = -1
 
+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         if (
@@ -169,6 +176,7 @@ class LogitsMetadata:
             token_ids_logprobs=forward_batch.token_ids_logprobs,
             extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
             padded_static_len=forward_batch.padded_static_len,
+            is_prefill_only=forward_batch.is_prefill_only,
             global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
            dp_local_start_pos=forward_batch.dp_local_start_pos,
             dp_local_num_tokens=forward_batch.dp_local_num_tokens,
```
```diff
@@ -247,6 +255,108 @@ class LogitsProcessor(nn.Module):
             "debug_tensor_dump_output_folder", None
         )
 
+    def compute_logprobs_for_multi_item_scoring(
+        self,
+        input_ids,
+        hidden_states,
+        lm_head: VocabParallelEmbedding,
+        logits_metadata: Union[LogitsMetadata, ForwardBatch],
+        delimiter_token: int,
+    ):
+        """
+        Compute logprobs for multi-item scoring using delimiter-based token extraction.
+
+        This method is designed for scenarios where you want to score multiple items/candidates
+        against a single query by combining them into one sequence separated by delimiters.
+
+        Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
+        Scoring positions: Extracts logprobs at positions before each <delimiter>
+
+        Args:
+            input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
+                Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
+            hidden_states (torch.Tensor): Hidden states from the model.
+                Shape: [sequence_length, hidden_dim].
+            lm_head (VocabParallelEmbedding): Language model head for computing logits.
+            logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
+                and token ID specifications for logprob extraction.
+            delimiter_token (int): Token ID used as delimiter between query and items.
+
+        Returns:
+            LogitsProcessorOutput: Contains:
+                - next_token_logits: None (not needed for scoring-only requests)
+                - input_token_logprobs: Logprobs of delimiter tokens at scoring positions
+                - input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
+                - input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
+                - input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
+                - input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
+        """
+        multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[
+            0
+        ] - 1
+        # Extract hidden states at delimiter positions for multi-item scoring
+        sliced_hidden = hidden_states[multi_item_indices]
+
+        sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
+        sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)
+
+        # Initialize return values
+        input_token_ids_logprobs_val = []
+        input_token_ids_logprobs_idx = []
+        input_top_logprobs_val = None
+        input_top_logprobs_idx = None
+
+        # Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
+        # Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
+        if (
+            logits_metadata.token_ids_logprobs
+            or logits_metadata.extend_return_top_logprob
+        ):
+            logits_metadata.extend_logprob_pruned_lens_cpu = []
+
+            if logits_metadata.extend_seq_lens_cpu is not None:
+                # Multi-request batch: count delimiters per request
+                input_pt = 0
+                for req_seq_len in logits_metadata.extend_seq_lens_cpu:
+                    req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
+                    delimiter_count = (req_input_ids == delimiter_token).sum().item()
+                    logits_metadata.extend_logprob_pruned_lens_cpu.append(
+                        delimiter_count
+                    )
+                    input_pt += req_seq_len
+            else:
+                # Single request case: one request gets all delimiters
+                total_delimiters = (input_ids == delimiter_token).sum().item()
+                logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]
+
+        # Get the logprobs of specified token ids
+        if logits_metadata.extend_token_ids_logprob:
+            (
+                input_token_ids_logprobs_val,
+                input_token_ids_logprobs_idx,
+            ) = self.get_token_ids_logprobs(
+                sliced_logprobs, logits_metadata, delay_cpu_copy=True
+            )
+
+        # Get the logprob of top-k tokens
+        if logits_metadata.extend_return_top_logprob:
+            (
+                input_top_logprobs_val,
+                input_top_logprobs_idx,
+            ) = self.get_top_logprobs(sliced_logprobs, logits_metadata)
+
+        # For input_token_logprobs, use delimiter token logprobs
+        input_token_logprobs = sliced_logprobs[:, delimiter_token]
+
+        return LogitsProcessorOutput(
+            next_token_logits=None,  # Multi-item scoring doesn't need next token logits
+            input_token_logprobs=input_token_logprobs,
+            input_top_logprobs_val=input_top_logprobs_val,
+            input_top_logprobs_idx=input_top_logprobs_idx,
+            input_token_ids_logprobs_val=input_token_ids_logprobs_val,
+            input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
+        )
+
     def forward(
         self,
         input_ids,
@@ -257,6 +367,16 @@ class LogitsProcessor(nn.Module):
     ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
+
+        # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
+        multi_item_delimiter = global_server_args_dict.get(
+            "multi_item_scoring_delimiter"
+        )
+        if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
+            return self.compute_logprobs_for_multi_item_scoring(
+                input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
+            )
+
         # Get the last hidden states and last logits for the next token prediction
         if (
             logits_metadata.forward_mode.is_decode_or_idle()
```
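The arithmetic behind `multi_item_indices` above: the logits computed from the hidden state at position p predict the token at position p + 1, so the logprob of a delimiter at position p is read from position p - 1, hence the `- 1`. A toy sketch of that indexing with made-up token ids (standalone tensors, not SGLang's actual batch plumbing):

```python
import torch

delimiter_token = 7  # hypothetical delimiter id
# Query(5, 9) <d> Item1(3, 4) <d> Item2(8) <d>
input_ids = torch.tensor([5, 9, 7, 3, 4, 7, 8, 7])

# Positions whose *next* token is a delimiter are the scoring positions.
multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[0] - 1
print(multi_item_indices.tolist())  # [1, 4, 6] -> end of query, Item1, Item2

# Only these rows of hidden_states go through lm_head, so a single prefill
# pass scores every item against the query at once.
hidden_states = torch.randn(input_ids.numel(), 16)
sliced_hidden = hidden_states[multi_item_indices]  # shape [3, 16]
```

Note also the gating in `forward()`: the path is taken only when a delimiter is configured in `global_server_args_dict["multi_item_scoring_delimiter"]` and the batch is prefill-only, which lines up with the `server_args.py` and `is_prefill_only` changes elsewhere in this release.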
sglang/srt/layers/logits_processor.py

```diff
@@ -584,7 +704,9 @@ class LogitsProcessor(nn.Module):
 
     @staticmethod
     def get_token_ids_logprobs(
-        all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
+        all_logprobs: torch.Tensor,
+        logits_metadata: LogitsMetadata,
+        delay_cpu_copy: bool = False,
     ):
         input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
         pt = 0
@@ -597,9 +719,17 @@ class LogitsProcessor(nn.Module):
                 input_token_ids_logprobs_idx.append([])
                 continue
 
-            input_token_ids_logprobs_val.append(
-                [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
-            )
+            position_logprobs = all_logprobs[
+                pt : pt + pruned_len, token_ids
+            ]  # Shape: [pruned_len, num_tokens]
+
+            if delay_cpu_copy:
+                # Keep as tensor to delay GPU-to-CPU transfer
+                input_token_ids_logprobs_val.append(position_logprobs)
+            else:
+                # Convert to list immediately (default behavior)
+                input_token_ids_logprobs_val.append(position_logprobs.tolist())
+
             input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
             pt += pruned_len
 
```
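The `delay_cpu_copy` flag added here matters because `.tolist()` forces an immediate, blocking GPU-to-CPU transfer for every request; keeping the slice as a tensor lets the caller (the multi-item scoring path passes `delay_cpu_copy=True`) defer the device-to-host copy and batch it later, which is also why the `input_token_ids_logprobs_val` field above now admits `Union[List[float], torch.Tensor]`. A rough sketch of the two paths with placeholder shapes:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
all_logprobs = torch.randn(4, 320, device=device).log_softmax(dim=-1)
token_ids = [42, 137]  # hypothetical requested token ids
pt, pruned_len = 0, 2

position_logprobs = all_logprobs[pt : pt + pruned_len, token_ids]

# Default path: synchronizes with the GPU right here, once per request.
eager_vals = position_logprobs.tolist()

# delay_cpu_copy path: stays a (possibly CUDA) tensor; the output processor
# can move many of these to the CPU later in one batched step.
delayed_vals = position_logprobs
assert delayed_vals.shape == (pruned_len, len(token_ids))
```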
sglang/srt/layers/modelopt_utils.py (new file)

```diff
@@ -0,0 +1,11 @@
+"""
+ModelOpt related constants
+"""
+
+QUANT_CFG_CHOICES = {
+    "fp8": "FP8_DEFAULT_CFG",
+    "int4_awq": "INT4_AWQ_CFG",  # TODO: add support for int4_awq
+    "w4a8_awq": "W4A8_AWQ_BETA_CFG",  # TODO: add support for w4a8_awq
+    "nvfp4": "NVFP4_DEFAULT_CFG",
+    "nvfp4_awq": "NVFP4_AWQ_LITE_CFG",  # TODO: add support for nvfp4_awq
+}
```
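`QUANT_CFG_CHOICES` maps SGLang-facing format names to the names of config objects in NVIDIA ModelOpt's quantization module. A plausible consumer, assuming the `nvidia-modelopt` package is installed; the `resolve_quant_cfg` helper is illustrative, not part of SGLang:

```python
import modelopt.torch.quantization as mtq  # from the nvidia-modelopt package

from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES


def resolve_quant_cfg(name: str):
    """Illustrative: map a format name like 'fp8' to a ModelOpt config object."""
    if name not in QUANT_CFG_CHOICES:
        raise ValueError(
            f"Unknown quant format {name!r}; choices: {sorted(QUANT_CFG_CHOICES)}"
        )
    # e.g. "fp8" -> "FP8_DEFAULT_CFG" -> mtq.FP8_DEFAULT_CFG
    return getattr(mtq, QUANT_CFG_CHOICES[name])
```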
sglang/srt/layers/moe/cutlass_w4a8_moe.py

```diff
@@ -13,22 +13,18 @@ from sgl_kernel import (
 from sglang.srt.layers.moe.ep_moe.kernels import (
     post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
-    run_cutlass_moe_ep_preproess,
+    run_moe_ep_preproess,
 )
 
 
 def cutlass_w4a8_moe(
-    start_expert_id: int,
-    end_expert_id: int,
-    total_num_experts: int,
     a: torch.Tensor,
     w1_q: torch.Tensor,
     w2_q: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
     topk_weights: torch.Tensor,
-    topk_ids_: torch.Tensor,
-    local_topk_ids: torch.Tensor,
+    topk_ids: torch.Tensor,
     a_strides1: torch.Tensor,
     b_strides1: torch.Tensor,
     c_strides1: torch.Tensor,
@@ -64,6 +60,7 @@ def cutlass_w4a8_moe(
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts, N // 512, K * 4]
     - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - topk_ids (torch.Tensor): The ids of each token->expert mapping.
     - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
     - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
     - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
@@ -83,7 +80,7 @@ def cutlass_w4a8_moe(
     Returns:
     - torch.Tensor: The fp8 output tensor after applying the MoE layer.
     """
-    assert topk_weights.shape == local_topk_ids.shape, "topk shape mismatch"
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert w1_q.dtype == torch.int8
     assert w2_q.dtype == torch.int8
     assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
@@ -96,20 +93,21 @@ def cutlass_w4a8_moe(
     assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
     assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
     assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
-    num_experts = w1_q.size(0)
+    num_local_experts = w1_q.size(0)
     m = a.size(0)
     k = w1_q.size(2) * 2  # w1_q is transposed and packed
     n = w2_q.size(2) * 2  # w2_q is transposed and packed
-    topk = local_topk_ids.size(1)
+    topk = topk_ids.size(1)
 
     if apply_router_weight_on_input:
         assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"
 
     device = a.device
+    topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
 
-    _, src2dst, _ = run_cutlass_moe_ep_preproess(
-        local_topk_ids,
-        num_experts,
+    _, src2dst, _ = run_moe_ep_preproess(
+        topk_ids,
+        num_local_experts,
     )
 
     gateup_input = torch.empty(
@@ -122,9 +120,9 @@ def cutlass_w4a8_moe(
         a,
         gateup_input,
         src2dst,
-        local_topk_ids,
+        topk_ids,
         a1_scale,
-        num_experts,
+        num_local_experts,
         topk,
         k,
         BLOCK_SIZE=512,
@@ -133,16 +131,16 @@ def cutlass_w4a8_moe(
     # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
     # they are kept to allow for a quick switch of the permutation logic
     # from the current triton kernel implementation to the cutlass-based one if needed.
-    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
-    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
     get_cutlass_w4a8_moe_mm_data(
-        local_topk_ids,
+        topk_ids,
         expert_offsets,
         problem_sizes1,
         problem_sizes2,
         a_map,
         c_map,
-        num_experts,
+        num_local_experts,
         n,
         k,
     )
@@ -195,12 +193,11 @@ def cutlass_w4a8_moe(
         c2,
         output,
         src2dst,
-        local_topk_ids,
+        topk_ids,
         topk_weights,
-        num_experts,
         topk,
+        num_local_experts,
         k,
-        0,
         BLOCK_SIZE=512,
     )
     return output
```
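Two things stand out in the `cutlass_w4a8_moe` rework. First, the signature drops the per-rank expert-range parameters and the separate local topk-id input in favor of a single `topk_ids` plus `num_local_experts = w1_q.size(0)`, so the function reasons purely in local-expert space. Second, under expert parallelism, entries of `topk_ids` equal to -1 mark token-to-expert assignments owned by other ranks; remapping them to `num_local_experts` gives the reorder kernels one contiguous id range with an extra "discard" bucket instead of a special case for -1. A standalone illustration of the remap (toy values, not the kernel itself):

```python
import torch

num_local_experts = 4  # this rank owns experts 0..3

# -1 marks assignments routed to experts that live on other ranks.
topk_ids = torch.tensor([[0, -1], [3, 2], [-1, -1]])

# Map invalid ids to a sentinel one past the last local expert, so bucketing
# kernels see the contiguous range [0, num_local_experts], with bucket 4
# acting as "not mine".
topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
print(topk_ids.tolist())  # [[0, 4], [3, 2], [4, 4]]
```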