sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations

 import logging
 from contextlib import suppress
@@ -18,12 +19,8 @@ from compressed_tensors.quantization import (
 )
 from pydantic import BaseModel

-from sglang.srt.layers.linear import (
-    LinearBase,
-    LinearMethodBase,
-    UnquantizedLinearMethod,
-)
 from sglang.srt.layers.quantization.base_config import (
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
@@ -40,9 +37,13 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
     is_activation_quantization_format,
     should_ignore_layer,
 )
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod

 try:
-    import vllm
+    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
+        WNA16_SUPPORTED_BITS,
+        CompressedTensorsWNA16,
+    )

     VLLM_AVAILABLE = True
 except ImportError:
@@ -97,7 +98,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         self.config = config
         self.packed_modules_mapping = packed_modules_mapping

-    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
+    def get_linear_method(self) -> CompressedTensorsLinearMethod:
         return CompressedTensorsLinearMethod(self)

     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -117,7 +118,8 @@ class CompressedTensorsConfig(QuantizationConfig):
         self,
         layer: torch.nn.Module,
         prefix: str,
-    ) -> Optional["QuantizeMethodBase"]:
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase

         # Check if the layer is skipped for quantization.
         # TODO (@robertgshaw2): support module names
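The hunk above drops the module-level import of LinearBase (previously pulled in from sglang.srt.layers.linear together with LinearMethodBase and UnquantizedLinearMethod) and re-imports it inside get_quant_method instead. Below is a minimal sketch of that deferred-import pattern, assuming the intent is to keep the quantization config importable without loading the linear layer module at import time; everything other than the LinearBase import path is illustrative.

# Sketch only: a hypothetical config class showing the function-local import.
from typing import Optional

import torch


class SomeQuantConfig:  # hypothetical stand-in for CompressedTensorsConfig
    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[object]:
        # Imported here rather than at module scope, so importing this module
        # never forces sglang.srt.layers.linear (and whatever it imports back
        # from the quantization package) to load, avoiding a potential cycle.
        from sglang.srt.layers.linear import LinearBase

        if isinstance(layer, LinearBase):
            return object()  # placeholder for the real quantize method
        return None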
@@ -138,7 +140,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         return None

     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
+    def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig:
         ignore: List[str] = cast(List[str], config.get("ignore", []))
         quant_format = cast(str, config.get("format"))
         target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
@@ -357,7 +359,7 @@ class CompressedTensorsConfig(QuantizationConfig):

     def _get_scheme_from_parts(
         self, weight_quant: BaseModel, input_quant: BaseModel
-    ) -> "CompressedTensorsScheme":
+    ) -> CompressedTensorsScheme:

         # Detect If Mixed Precision
         if self._is_wNa16_group_channel(weight_quant, input_quant):
@@ -435,7 +437,7 @@ class CompressedTensorsConfig(QuantizationConfig):

     def get_scheme(
         self, layer: torch.nn.Module, layer_name: Optional[str] = None
-    ) -> Optional["CompressedTensorsScheme"]:
+    ) -> Optional[CompressedTensorsScheme]:
         """
         compressed-tensors supports non uniform in the following way:

sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -1,15 +1,17 @@
 # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations

 import enum
 import logging
 from enum import Enum
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, List, Optional

 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy

+from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
@@ -18,14 +20,21 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
+        CompressedTensorsConfig,
+    )

 _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()

-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant

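The TYPE_CHECKING block above is the annotation-only half of the same decoupling: TopKOutput and CompressedTensorsConfig are needed purely for type hints, and with `from __future__ import annotations` (added at the top of this file) those hints are never evaluated at runtime. A small self-contained sketch of the pattern, with the function body as a placeholder:

from __future__ import annotations

from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Only static type checkers follow this import; it never runs at runtime,
    # so it cannot create an import cycle or slow down module loading.
    from sglang.srt.layers.moe.topk import TopKOutput


def apply(x: torch.Tensor, topk_output: TopKOutput) -> torch.Tensor:
    # With postponed evaluation of annotations (PEP 563), the TopKOutput
    # annotation is stored as a string, so the name does not need to exist
    # when this function is defined or called.
    return x  # placeholder body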
@@ -51,7 +60,7 @@ __all__ = [
 ]


-class CompressedTensorsMoEMethod:
+class CompressedTensorsMoEMethod(FusedMoEMethodBase):
     def __new__(cls, *args, **kwargs):
         if cls is CompressedTensorsMoEMethod:
             return super().__new__(cls)
@@ -59,7 +68,7 @@ class CompressedTensorsMoEMethod:

     @staticmethod
     def get_moe_method(
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        quant_config: CompressedTensorsConfig,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -82,9 +91,7 @@ class CompressedTensorsMoEMethod:

 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

-    def __init__(
-        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
-    ):
+    def __init__(self, quant_config: CompressedTensorsConfig):
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -270,47 +277,21 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
-        apply_router_weight_on_input: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )

         return fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
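The net effect of this hunk is that the quantized MoE method no longer performs expert selection itself: the removed select_experts call moves out to the caller, which passes the precomputed routing in as a single topk_output argument, with the remaining options keyword-only. A hedged sketch of the resulting call shape follows; resolve_routing is a hypothetical stand-in for wherever that top-k computation now lives, not sglang API.

# Hypothetical call-site sketch (resolve_routing is illustrative only).
import torch


def resolve_routing(router_logits: torch.Tensor, top_k: int, renormalize: bool = True):
    # Minimal softmax-then-top-k routing, standing in for the removed
    # select_experts(...) call.
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids, router_logits


x = torch.randn(4, 16)             # 4 tokens, hidden size 16
router_logits = torch.randn(4, 8)  # 8 experts
topk_output = resolve_routing(router_logits, top_k=2)

# The refactored method then takes the routing result positionally and the
# rest keyword-only (layer and moe_method are not defined in this sketch):
# out = moe_method.apply(layer, x, topk_output, activation="silu", inplace=True)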
@@ -327,9 +308,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):

-    def __init__(
-        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
-    ):
+    def __init__(self, quant_config: CompressedTensorsConfig):
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -628,43 +607,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
-        routed_scaling_factor: Optional[float] = None,
+        **kwargs,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.topk import select_experts

         assert activation == "silu", "Only SiLU activation is supported."
-        if expert_map is not None:
-            raise NotImplementedError(
-                "Expert Parallelism is not supported for " "fused Marlin MoE method."
-            )

-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
+        topk_weights, topk_ids, router_logits = topk_output

         return torch.ops.vllm.fused_marlin_moe(
             x,
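The unpacking topk_weights, topk_ids, router_logits = topk_output in this hunk implies that TopKOutput is, or at least behaves like, a three-element tuple. A purely illustrative stand-in is sketched below; the actual definition lives in sglang/srt/layers/moe/topk.py (file 34 in the list above) and may carry additional fields.

# Illustrative stand-in only; field names inferred from the unpacking in the diff.
from typing import NamedTuple

import torch


class TopKOutput(NamedTuple):
    topk_weights: torch.Tensor   # [num_tokens, top_k] routing weights per selected expert
    topk_ids: torch.Tensor       # [num_tokens, top_k] selected expert indices
    router_logits: torch.Tensor  # raw gate logits, carried along for consumers that still need them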