sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -60,7 +60,8 @@ _is_npu = is_npu()
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logits of the next tokens. shape: [#seq, vocab_size]
-    next_token_logits: torch.Tensor
+    # Can be None for certain prefill-only requests (e.g., multi-item scoring) that don't need next token generation
+    next_token_logits: Optional[torch.Tensor]
     # Used by speculative decoding (EAGLE)
     # The last hidden layers
     hidden_states: Optional[torch.Tensor] = None
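
Note: after this change, call sites can no longer assume `next_token_logits` is populated. A minimal consumer-side sketch of the required None check (the helper name is hypothetical, not part of this diff):

```python
from typing import Optional

import torch


def pick_next_token(next_token_logits: Optional[torch.Tensor]):
    # Prefill-only batches (e.g., multi-item scoring) produce no logits to sample from.
    if next_token_logits is None:
        return None
    return next_token_logits.argmax(dim=-1)
```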
@@ -85,7 +86,10 @@ class LogitsProcessorOutput:
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
     # The logprobs and ids of the requested token ids in input positions. shape: [#seq, n] (n is the number of requested token ids)
-    input_token_ids_logprobs_val: Optional[List] = None
+    # Can contain either lists or GPU tensors (for delayed GPU-to-CPU transfer optimization)
+    input_token_ids_logprobs_val: Optional[List[Union[List[float], torch.Tensor]]] = (
+        None
+    )
     input_token_ids_logprobs_idx: Optional[List] = None
@@ -127,6 +131,9 @@ class LogitsMetadata:
     # for padding
     padded_static_len: int = -1

+    # Whether this batch is prefill-only (no token generation needed)
+    is_prefill_only: bool = False
+
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
         if (
@@ -169,6 +176,7 @@ class LogitsMetadata:
             token_ids_logprobs=forward_batch.token_ids_logprobs,
             extend_input_logprob_token_ids_gpu=forward_batch.extend_input_logprob_token_ids_gpu,
             padded_static_len=forward_batch.padded_static_len,
+            is_prefill_only=forward_batch.is_prefill_only,
             global_num_tokens_gpu=forward_batch.global_num_tokens_gpu,
             dp_local_start_pos=forward_batch.dp_local_start_pos,
             dp_local_num_tokens=forward_batch.dp_local_num_tokens,
@@ -247,6 +255,108 @@ class LogitsProcessor(nn.Module):
             "debug_tensor_dump_output_folder", None
         )

+    def compute_logprobs_for_multi_item_scoring(
+        self,
+        input_ids,
+        hidden_states,
+        lm_head: VocabParallelEmbedding,
+        logits_metadata: Union[LogitsMetadata, ForwardBatch],
+        delimiter_token: int,
+    ):
+        """
+        Compute logprobs for multi-item scoring using delimiter-based token extraction.
+
+        This method is designed for scenarios where you want to score multiple items/candidates
+        against a single query by combining them into one sequence separated by delimiters.
+
+        Sequence format: Query<delimiter>Item1<delimiter>Item2<delimiter>...
+        Scoring positions: Extracts logprobs at positions before each <delimiter>
+
+        Args:
+            input_ids (torch.Tensor): Input token IDs containing query and items separated by delimiters.
+                Shape: [total_sequence_length] for single request or [batch_total_length] for batch.
+            hidden_states (torch.Tensor): Hidden states from the model.
+                Shape: [sequence_length, hidden_dim].
+            lm_head (VocabParallelEmbedding): Language model head for computing logits.
+            logits_metadata (Union[LogitsMetadata, ForwardBatch]): Metadata containing batch info
+                and token ID specifications for logprob extraction.
+            delimiter_token (int): Token ID used as delimiter between query and items.
+
+        Returns:
+            LogitsProcessorOutput: Contains:
+                - next_token_logits: None (not needed for scoring-only requests)
+                - input_token_logprobs: Logprobs of delimiter tokens at scoring positions
+                - input_top_logprobs_val: Top-k logprobs at delimiter positions (if requested)
+                - input_top_logprobs_idx: Top-k token indices at delimiter positions (if requested)
+                - input_token_ids_logprobs_val: Logprobs for user-requested token IDs (if any)
+                - input_token_ids_logprobs_idx: Indices for user-requested token IDs (if any)
+        """
+        multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[
+            0
+        ] - 1
+        # Extract hidden states at delimiter positions for multi-item scoring
+        sliced_hidden = hidden_states[multi_item_indices]
+
+        sliced_logits = self._get_logits(sliced_hidden, lm_head, logits_metadata)
+        sliced_logprobs = torch.nn.functional.log_softmax(sliced_logits, dim=-1)
+
+        # Initialize return values
+        input_token_ids_logprobs_val = []
+        input_token_ids_logprobs_idx = []
+        input_top_logprobs_val = None
+        input_top_logprobs_idx = None
+
+        # Recalculate extend_logprob_pruned_lens_cpu to match delimiter counts per request
+        # Original contains sequence lengths, but we need delimiter counts for sliced_logprobs
+        if (
+            logits_metadata.token_ids_logprobs
+            or logits_metadata.extend_return_top_logprob
+        ):
+            logits_metadata.extend_logprob_pruned_lens_cpu = []
+
+            if logits_metadata.extend_seq_lens_cpu is not None:
+                # Multi-request batch: count delimiters per request
+                input_pt = 0
+                for req_seq_len in logits_metadata.extend_seq_lens_cpu:
+                    req_input_ids = input_ids[input_pt : input_pt + req_seq_len]
+                    delimiter_count = (req_input_ids == delimiter_token).sum().item()
+                    logits_metadata.extend_logprob_pruned_lens_cpu.append(
+                        delimiter_count
+                    )
+                    input_pt += req_seq_len
+            else:
+                # Single request case: one request gets all delimiters
+                total_delimiters = (input_ids == delimiter_token).sum().item()
+                logits_metadata.extend_logprob_pruned_lens_cpu = [total_delimiters]
+
+        # Get the logprobs of specified token ids
+        if logits_metadata.extend_token_ids_logprob:
+            (
+                input_token_ids_logprobs_val,
+                input_token_ids_logprobs_idx,
+            ) = self.get_token_ids_logprobs(
+                sliced_logprobs, logits_metadata, delay_cpu_copy=True
+            )
+
+        # Get the logprob of top-k tokens
+        if logits_metadata.extend_return_top_logprob:
+            (
+                input_top_logprobs_val,
+                input_top_logprobs_idx,
+            ) = self.get_top_logprobs(sliced_logprobs, logits_metadata)
+
+        # For input_token_logprobs, use delimiter token logprobs
+        input_token_logprobs = sliced_logprobs[:, delimiter_token]
+
+        return LogitsProcessorOutput(
+            next_token_logits=None,  # Multi-item scoring doesn't need next token logits
+            input_token_logprobs=input_token_logprobs,
+            input_top_logprobs_val=input_top_logprobs_val,
+            input_top_logprobs_idx=input_top_logprobs_idx,
+            input_token_ids_logprobs_val=input_token_ids_logprobs_val,
+            input_token_ids_logprobs_idx=input_token_ids_logprobs_idx,
+        )
+
     def forward(
         self,
         input_ids,
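
Note: the delimiter indexing at the top of the new method is easiest to follow on a toy tensor. A standalone sketch, using a made-up six-token sequence and delimiter id 99 (both chosen for illustration only):

```python
import torch

delimiter_token = 99
input_ids = torch.tensor([11, 12, 99, 21, 22, 99])  # Query <d> Item1 <d>

# Scoring positions are the tokens immediately before each delimiter.
multi_item_indices = (input_ids == delimiter_token).nonzero(as_tuple=True)[0] - 1
print(multi_item_indices.tolist())  # [1, 4]
```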
@@ -257,6 +367,16 @@ class LogitsProcessor(nn.Module):
     ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
            logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
+
+        # Check if multi-item scoring is enabled via server args (only for prefill-only requests)
+        multi_item_delimiter = global_server_args_dict.get(
+            "multi_item_scoring_delimiter"
+        )
+        if multi_item_delimiter is not None and logits_metadata.is_prefill_only:
+            return self.compute_logprobs_for_multi_item_scoring(
+                input_ids, hidden_states, lm_head, logits_metadata, multi_item_delimiter
+            )
+
         # Get the last hidden states and last logits for the next token prediction
         if (
             logits_metadata.forward_mode.is_decode_or_idle()
@@ -584,7 +704,9 @@ class LogitsProcessor(nn.Module):

     @staticmethod
     def get_token_ids_logprobs(
-        all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata
+        all_logprobs: torch.Tensor,
+        logits_metadata: LogitsMetadata,
+        delay_cpu_copy: bool = False,
     ):
         input_token_ids_logprobs_val, input_token_ids_logprobs_idx = [], []
         pt = 0
@@ -597,9 +719,17 @@ class LogitsProcessor(nn.Module):
                 input_token_ids_logprobs_idx.append([])
                 continue

-            input_token_ids_logprobs_val.append(
-                [all_logprobs[pt + j, token_ids].tolist() for j in range(pruned_len)]
-            )
+            position_logprobs = all_logprobs[
+                pt : pt + pruned_len, token_ids
+            ]  # Shape: [pruned_len, num_tokens]
+
+            if delay_cpu_copy:
+                # Keep as tensor to delay GPU-to-CPU transfer
+                input_token_ids_logprobs_val.append(position_logprobs)
+            else:
+                # Convert to list immediately (default behavior)
+                input_token_ids_logprobs_val.append(position_logprobs.tolist())
+
             input_token_ids_logprobs_idx.append([token_ids for _ in range(pruned_len)])
             pt += pruned_len
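
Note: the `delay_cpu_copy` path avoids one synchronizing `.tolist()` per request inside the loop; the slices stay on the GPU and are converted later (in this diff, they travel as tensors inside `input_token_ids_logprobs_val`). A minimal sketch of the pattern, with made-up shapes:

```python
import torch

all_logprobs = torch.randn(8, 3)  # [num_positions, num_requested_token_ids]

vals = []
for pt, pruned_len in ((0, 4), (4, 4)):
    # Slicing does not force a GPU-to-CPU transfer; the data stays put.
    vals.append(all_logprobs[pt : pt + pruned_len])

# Later, off the hot path, perform the conversions in one place.
vals = [v.tolist() for v in vals]
```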
sglang/srt/layers/modelopt_utils.py (new file)

@@ -0,0 +1,11 @@
+"""
+ModelOpt related constants
+"""
+
+QUANT_CFG_CHOICES = {
+    "fp8": "FP8_DEFAULT_CFG",
+    "int4_awq": "INT4_AWQ_CFG",  # TODO: add support for int4_awq
+    "w4a8_awq": "W4A8_AWQ_BETA_CFG",  # TODO: add support for w4a8_awq
+    "nvfp4": "NVFP4_DEFAULT_CFG",
+    "nvfp4_awq": "NVFP4_AWQ_LITE_CFG",  # TODO: add support for nvfp4_awq
+}
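
Note: a hypothetical consumer sketch for this mapping; the lookup-and-resolve pattern (including the `getattr` target) is an assumption, not shown in this diff:

```python
from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES

choice = "fp8"
cfg_name = QUANT_CFG_CHOICES.get(choice)
if cfg_name is None:
    raise ValueError(f"unsupported quantization choice: {choice!r}")
# cfg_name ("FP8_DEFAULT_CFG") would then be resolved against ModelOpt's
# config module, e.g. getattr(some_modelopt_config_module, cfg_name).
print(cfg_name)
```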
sglang/srt/layers/moe/cutlass_w4a8_moe.py

@@ -13,22 +13,18 @@ from sgl_kernel import (
 from sglang.srt.layers.moe.ep_moe.kernels import (
     post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
-    run_cutlass_moe_ep_preproess,
+    run_moe_ep_preproess,
 )


 def cutlass_w4a8_moe(
-    start_expert_id: int,
-    end_expert_id: int,
-    total_num_experts: int,
     a: torch.Tensor,
     w1_q: torch.Tensor,
     w2_q: torch.Tensor,
     w1_scale: torch.Tensor,
     w2_scale: torch.Tensor,
     topk_weights: torch.Tensor,
-    topk_ids_: torch.Tensor,
-    local_topk_ids: torch.Tensor,
+    topk_ids: torch.Tensor,
     a_strides1: torch.Tensor,
     b_strides1: torch.Tensor,
     c_strides1: torch.Tensor,
@@ -64,6 +60,7 @@ def cutlass_w4a8_moe(
     - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
         Shape: [num_experts, N // 512, K * 4]
     - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - topk_ids (torch.Tensor): The ids of each token->expert mapping.
     - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
     - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
     - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
@@ -83,7 +80,7 @@ def cutlass_w4a8_moe(
     Returns:
     - torch.Tensor: The fp8 output tensor after applying the MoE layer.
     """
-    assert topk_weights.shape == topk_ids_.shape, "topk shape mismatch"
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert w1_q.dtype == torch.int8
     assert w2_q.dtype == torch.int8
     assert a.shape[1] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
@@ -96,20 +93,21 @@ def cutlass_w4a8_moe(
     assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
     assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
     assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
-    num_experts = w1_q.size(0)
+    num_local_experts = w1_q.size(0)
     m = a.size(0)
     k = w1_q.size(2) * 2  # w1_q is transposed and packed
     n = w2_q.size(2) * 2  # w2_q is transposed and packed
-    topk = topk_ids_.size(1)
+    topk = topk_ids.size(1)

     if apply_router_weight_on_input:
         assert topk == 1, "apply_router_weight_on_input is only implemented for topk=1"

     device = a.device
+    topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)

-    _, src2dst, _ = run_cutlass_moe_ep_preproess(
-        local_topk_ids,
-        num_experts,
+    _, src2dst, _ = run_moe_ep_preproess(
+        topk_ids,
+        num_local_experts,
     )

     gateup_input = torch.empty(
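
Note: the new `torch.where` line appears to redirect top-k entries that no local expert handles (marked `-1`; our reading of the expert-parallel convention, not stated in this diff) to the extra slot `num_local_experts`, so the preprocessing kernel can bucket them past the last real expert. On a toy tensor:

```python
import torch

num_local_experts = 4
topk_ids = torch.tensor([[0, 3], [-1, 2], [1, -1]])
topk_ids = torch.where(topk_ids == -1, num_local_experts, topk_ids)
print(topk_ids.tolist())  # [[0, 3], [4, 2], [1, 4]]
```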
@@ -122,9 +120,9 @@ def cutlass_w4a8_moe(
         a,
         gateup_input,
         src2dst,
-        local_topk_ids,
+        topk_ids,
         a1_scale,
-        total_num_experts,
+        num_local_experts,
         topk,
         k,
         BLOCK_SIZE=512,
@@ -133,16 +131,16 @@ def cutlass_w4a8_moe(
     # NOTE: a_map and c_map are not used in the get_cutlass_w4a8_moe_mm_data kernel,
     # they are kept to allow for a quick switch of the permutation logic
     # from the current triton kernel implementation to the cutlass-based one if needed.
-    a_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
-    c_map = torch.empty((local_topk_ids.numel()), dtype=torch.int32, device=device)
+    a_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
+    c_map = torch.empty((topk_ids.numel()), dtype=torch.int32, device=device)
     get_cutlass_w4a8_moe_mm_data(
-        local_topk_ids,
+        topk_ids,
         expert_offsets,
         problem_sizes1,
         problem_sizes2,
         a_map,
         c_map,
-        num_experts,
+        num_local_experts,
         n,
         k,
     )
@@ -195,12 +193,11 @@ def cutlass_w4a8_moe(
         c2,
         output,
         src2dst,
-        local_topk_ids,
+        topk_ids,
         topk_weights,
-        num_experts,
         topk,
+        num_local_experts,
         k,
-        0,
         BLOCK_SIZE=512,
     )
     return output
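
Note: as a reading aid, a pure-PyTorch stand-in for what the post-reorder step computes, scattering expert-major outputs back to token order with router weights applied; the `src2dst` layout and all shapes here are assumptions for illustration, and the real Triton kernel additionally handles the fp8 output dtype:

```python
import torch

num_tokens, topk, hidden = 3, 2, 4
c2 = torch.randn(num_tokens * topk, hidden)  # expert-major rows, one per (token, expert)
src2dst = torch.tensor([[0, 3], [1, 4], [2, 5]])  # token-major -> expert-major row map
topk_weights = torch.tensor([[0.7, 0.3], [0.5, 0.5], [1.0, 0.0]])

output = torch.zeros(num_tokens, hidden)
for t in range(num_tokens):
    for j in range(topk):
        output[t] += topk_weights[t, j] * c2[src2dst[t, j]]
```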