sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/configs/internvl.py +3 -0
  3. sglang/srt/configs/model_config.py +4 -0
  4. sglang/srt/constrained/base_grammar_backend.py +10 -2
  5. sglang/srt/constrained/xgrammar_backend.py +7 -5
  6. sglang/srt/conversation.py +16 -1
  7. sglang/srt/debug_utils/__init__.py +0 -0
  8. sglang/srt/debug_utils/dump_comparator.py +131 -0
  9. sglang/srt/debug_utils/dumper.py +108 -0
  10. sglang/srt/debug_utils/text_comparator.py +172 -0
  11. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  12. sglang/srt/disaggregation/mooncake/conn.py +16 -0
  13. sglang/srt/disaggregation/prefill.py +13 -1
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/openai/serving_chat.py +132 -79
  16. sglang/srt/function_call/ebnf_composer.py +10 -3
  17. sglang/srt/function_call/function_call_parser.py +2 -0
  18. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  19. sglang/srt/function_call/qwen3_coder_detector.py +1 -0
  20. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  21. sglang/srt/layers/attention/vision.py +56 -8
  22. sglang/srt/layers/layernorm.py +26 -1
  23. sglang/srt/layers/logits_processor.py +14 -3
  24. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  27. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  28. sglang/srt/layers/moe/topk.py +84 -22
  29. sglang/srt/layers/multimodal.py +11 -8
  30. sglang/srt/layers/quantization/fp8.py +25 -247
  31. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  32. sglang/srt/layers/quantization/modelopt_quant.py +25 -10
  33. sglang/srt/layers/quantization/unquant.py +24 -76
  34. sglang/srt/layers/quantization/w4afp8.py +68 -17
  35. sglang/srt/lora/lora_registry.py +93 -29
  36. sglang/srt/managers/cache_controller.py +9 -7
  37. sglang/srt/managers/mm_utils.py +154 -35
  38. sglang/srt/managers/multimodal_processor.py +3 -14
  39. sglang/srt/managers/schedule_batch.py +14 -8
  40. sglang/srt/managers/scheduler.py +35 -1
  41. sglang/srt/managers/tokenizer_manager.py +37 -6
  42. sglang/srt/managers/tp_worker.py +3 -0
  43. sglang/srt/mem_cache/hiradix_cache.py +5 -2
  44. sglang/srt/model_executor/model_runner.py +68 -14
  45. sglang/srt/models/deepseek_v2.py +62 -28
  46. sglang/srt/models/glm4_moe.py +1035 -0
  47. sglang/srt/models/glm4_moe_nextn.py +167 -0
  48. sglang/srt/models/interns1.py +328 -0
  49. sglang/srt/models/internvl.py +143 -47
  50. sglang/srt/models/llava.py +9 -5
  51. sglang/srt/models/minicpmo.py +4 -1
  52. sglang/srt/models/qwen2_moe.py +2 -2
  53. sglang/srt/models/qwen3_moe.py +5 -2
  54. sglang/srt/multimodal/processors/base_processor.py +20 -6
  55. sglang/srt/multimodal/processors/clip.py +2 -2
  56. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  57. sglang/srt/multimodal/processors/gemma3.py +2 -2
  58. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  59. sglang/srt/multimodal/processors/internvl.py +21 -8
  60. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  61. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  62. sglang/srt/multimodal/processors/llava.py +4 -4
  63. sglang/srt/multimodal/processors/minicpm.py +2 -3
  64. sglang/srt/multimodal/processors/mlama.py +2 -2
  65. sglang/srt/multimodal/processors/mllama4.py +18 -111
  66. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  67. sglang/srt/multimodal/processors/pixtral.py +2 -2
  68. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  69. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  70. sglang/srt/multimodal/processors/vila.py +3 -1
  71. sglang/srt/reasoning_parser.py +2 -1
  72. sglang/srt/server_args.py +57 -6
  73. sglang/srt/utils.py +96 -1
  74. sglang/srt/weight_sync/utils.py +119 -0
  75. sglang/test/runners.py +4 -0
  76. sglang/test/test_utils.py +65 -5
  77. sglang/utils.py +19 -0
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +4 -4
  80. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +83 -73
  81. sglang/srt/debug_utils.py +0 -74
  82. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
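Note: this new tuning table appears to be keyed by the token-batch dimension M, with each entry holding Triton launch parameters tuned for that size. A minimal, illustrative lookup sketch follows; the helper names and the nearest-key rule are assumptions for illustration, not the exact sglang lookup code.

import json

def load_moe_config(path: str) -> dict:
    # Keys are token counts (strings in the JSON); values are Triton launch parameters.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def pick_config(configs: dict, num_tokens: int) -> dict:
    # Assumed nearest-key rule: use the tuned entry whose M is closest to the runtime size.
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]

# e.g. a batch of 300 tokens would map to the "256" entry in the table above.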
sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -75,8 +75,9 @@ class FusedMoE(torch.nn.Module):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
-        enable_flashinfer_moe: Optional[bool] = False,
+        enable_flashinfer_cutlass_moe: Optional[bool] = False,
         enable_ep_moe: Optional[bool] = False,
+        skip_quant: Optional[bool] = False,
     ):
         super().__init__()
 
@@ -92,16 +93,13 @@ class FusedMoE(torch.nn.Module):
         self.num_experts = num_experts
         self.expert_map = None
 
-        if enable_flashinfer_moe and quant_config is None:
+        if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
-            enable_flashinfer_moe = False
+            enable_flashinfer_cutlass_moe = False
             enable_ep_moe = False
 
-        self.enable_flashinfer_moe = enable_flashinfer_moe
+        self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
         if enable_ep_moe:
-            assert (
-                self.enable_flashinfer_moe
-            ), "FusedMoE only supports EP with --enable-flashinfer-moe"
             self.ep_size = self.tp_size
             self.ep_rank = self.tp_rank
             self.tp_size = 1
@@ -110,16 +108,16 @@ class FusedMoE(torch.nn.Module):
             self.expert_map = torch.full((self.num_experts,), -1, dtype=torch.int32)
             # Create a expert map for the local experts
             assert num_experts % self.ep_size == 0
-            self.local_num_experts = num_experts // self.ep_size
+            self.num_local_experts = num_experts // self.ep_size
             self.expert_map[
                 self.ep_rank
-                * self.local_num_experts : (self.ep_rank + 1)
-                * self.local_num_experts
-            ] = torch.arange(0, self.local_num_experts, dtype=torch.int32, device="cpu")
+                * self.num_local_experts : (self.ep_rank + 1)
+                * self.num_local_experts
+            ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
         else:
             self.ep_size = 1
             self.ep_rank = 0
-            self.local_num_experts = num_experts
+            self.num_local_experts = num_experts
         self.routed_scaling_factor = routed_scaling_factor
         assert intermediate_size % self.tp_size == 0
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
@@ -134,6 +132,9 @@ class FusedMoE(torch.nn.Module):
             not _is_cpu and global_server_args_dict["enable_triton_kernel_moe"]
         )
 
+        if skip_quant:
+            return
+
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
                 self.use_triton_kernels
@@ -141,13 +142,15 @@ class FusedMoE(torch.nn.Module):
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)
             if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod":
-                self.quant_method.enable_flashinfer_moe = self.enable_flashinfer_moe
+                self.quant_method.enable_flashinfer_cutlass_moe = (
+                    self.enable_flashinfer_cutlass_moe
+                )
         assert self.quant_method is not None
 
         self.quant_config = quant_config
         self.quant_method.create_weights(
             layer=self,
-            num_experts=self.local_num_experts,
+            num_experts=self.num_local_experts,
             hidden_size=hidden_size,
             # FIXME: figure out which intermediate_size to use
             intermediate_size=self.intermediate_size_per_partition,
@@ -376,6 +379,23 @@ class FusedMoE(torch.nn.Module):
         if expert_id == -1:
             return
 
+        self._weight_loader_impl(
+            param=param,
+            loaded_weight=loaded_weight,
+            weight_name=weight_name,
+            shard_id=shard_id,
+            expert_id=expert_id,
+        )
+
+    def _weight_loader_impl(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
+    ) -> None:
+
         # TP rank is set to 0 if EP is enabled
         tp_rank = 0 if self.ep_size > 1 else get_tensor_model_parallel_rank()
 
@@ -396,6 +416,10 @@ class FusedMoE(torch.nn.Module):
                 f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}."
             )
 
+        # Flashinfer assumes w31 format for w13_weight. Same for the scales.
+        if getattr(self, "use_flashinfer_trtllm_moe", False):
+            shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
+
         WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
         # Fetch the dim to shard the parameter/loaded weight
         # based on the shard id. This will be whatever
@@ -603,37 +627,3 @@ class FusedMoE(torch.nn.Module):
                 ("w3", ckpt_up_proj_name),
             ]
         ]
-
-    def _load_fp8_scale(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        param_data = param.data
-
-        # Input scales can be loaded directly and should be equal.
-        if "input_scale" in weight_name:
-            if (
-                param_data[expert_id] != 1
-                and (param_data[expert_id] - loaded_weight).abs() > 1e-5
-            ):
-                raise ValueError(
-                    "input_scales of w1 and w3 of a layer "
-                    f"must be equal. But got {param_data[expert_id]} "
-                    f"vs. {loaded_weight}"
-                )
-            param_data[expert_id] = loaded_weight
-        # Weight scales
-        elif "weight_scale" in weight_name:
-            # If we are in merged column case (gate_up_proj)
-            if shard_id in ("w1", "w3"):
-                # We have to keep the weight scales of w1 and w3 because
-                # we need to re-quantize w1/w3 weights after weight loading.
-                idx = 0 if shard_id == "w1" else 1
-                param_data[expert_id][idx] = loaded_weight
-            # If we are in the row parallel case (down_proj)
-            else:
-                param_data[expert_id] = loaded_weight
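Aside from the rename to enable_flashinfer_cutlass_moe and the new skip_quant early return, the behavioral addition here is the shard-id swap for FlashInfer's w31 weight layout. A standalone sketch of that remapping, with the function name invented for illustration:

def remap_shard_for_flashinfer(shard_id: str, use_flashinfer_trtllm_moe: bool) -> str:
    # Per the comment in the diff, FlashInfer expects the fused gate/up projection
    # (w13) stored in w3/w1 order, so gate (w1) and up (w3) shards trade places at
    # load time; the down projection (w2) is untouched.
    if use_flashinfer_trtllm_moe:
        return {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
    return shard_id

assert remap_shard_for_flashinfer("w1", True) == "w3"
assert remap_shard_for_flashinfer("w3", True) == "w1"
assert remap_shard_for_flashinfer("w2", True) == "w2"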
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
@@ -1,21 +1,25 @@
 # Adapted from https://github.com/vllm-project/vllm/pull/18595/files#diff-f426a6de78c82ffec568eff6811bfbf0043dab5f87f1a8c0cffdbdcb8a81e035
-from typing import Optional
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
 
 import torch
 from sgl_kernel import gelu_and_mul, silu_and_mul
 from triton_kernels.matmul_ogs import matmul_ogs
-from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
 
 from sglang.srt.utils import direct_register_custom_op
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
     w2: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
+    topk_output: TopKOutput,
     inplace: bool = False,
     activation: str = "silu",
     apply_router_weight_on_input: bool = False,
@@ -30,9 +34,8 @@ def triton_kernel_moe_forward(
     block_shape: Optional[list[int]] = None,
 ) -> torch.Tensor:
 
-    if not renormalize:
-        gating_output = torch.softmax(gating_output, dim=-1)
-    routing_data, gather_idx, scatter_idx = routing(gating_output, topk, renormalize)
+    assert topk_output.format.is_triton_kernel()
+    routing_data, gather_idx, scatter_idx = topk_output
 
     return triton_kernel_fused_experts(
         hidden_states,
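With this change the function no longer computes routing from raw gating logits; it receives a pre-built TopKOutput and simply unpacks it. The bare tuple-unpacking works because TritonKernelTopKOutput (defined in topk.py below) is a NamedTuple. A tiny self-contained illustration of that unpacking with stand-in field types:

from typing import NamedTuple

class FakeTritonKernelTopKOutput(NamedTuple):
    # Stand-ins for triton_kernels' RoutingData / GatherIndx / ScatterIndx.
    routing_data: object
    gather_indx: object
    scatter_indx: object

out = FakeTritonKernelTopKOutput("routing", "gather", "scatter")
routing_data, gather_idx, scatter_idx = out  # same positional unpacking as in the diff
print(routing_data, gather_idx, scatter_idx)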
sglang/srt/layers/moe/topk.py
@@ -15,7 +15,8 @@
 from __future__ import annotations
 
 import math
-from typing import Callable, NamedTuple, Optional
+from enum import Enum, auto
+from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
 
 import torch
 import torch.nn.functional as F
@@ -27,6 +28,7 @@ from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -37,6 +39,12 @@ from sglang.srt.utils import (
     is_npu,
 )
 
+try:
+    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx, routing
+except ImportError:
+    pass
+
+
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
@@ -58,15 +66,58 @@ if _is_npu:
     import torch_npu
 
 
-class TopKOutput(NamedTuple):
+# -------------------------------- TopKOutput ---------------------------------------
+
+
+class TopKOutputFormat(Enum):
+    STANDARD = auto()
+    TRITON_KERNEL = auto()
+
+    def is_standard(self) -> bool:
+        return self == TopKOutputFormat.STANDARD
+
+    def is_triton_kernel(self) -> bool:
+        return self == TopKOutputFormat.TRITON_KERNEL
+
+
+@runtime_checkable
+class TopKOutput(Protocol):
+    """Protocol for top-k outputs in different formats."""
+
+    @property
+    def format(self) -> TopKOutputFormat:
+        """The format of the output."""
+        ...
+
+
+class StandardTopKOutput(NamedTuple):
+    """Standard top-k output format."""
+
     topk_weights: torch.Tensor
     topk_ids: torch.Tensor
     router_logits: torch.Tensor
 
+    @property
+    def format(self) -> TopKOutputFormat:
+        return TopKOutputFormat.STANDARD
 
-class TopK(CustomOp):
 
-    # TODO(ch-wan): support triton_kernels
+class TritonKernelTopKOutput(NamedTuple):
+    """Triton kernel top-k output format."""
+
+    routing_data: RoutingData
+    gather_indx: GatherIndx
+    scatter_indx: ScatterIndx
+
+    @property
+    def format(self) -> TopKOutputFormat:
+        return TopKOutputFormat.TRITON_KERNEL
+
+
+# -------------------------------- TopK ---------------------------------------
+
+
+class TopK(CustomOp):
 
     def __init__(
         self,
@@ -97,6 +148,8 @@ class TopK(CustomOp):
         self.correction_bias = correction_bias
         self.routed_scaling_factor = routed_scaling_factor
 
+        self.use_triton_kernels = global_server_args_dict["enable_triton_kernel_moe"]
+
     def forward_native(
         self,
         hidden_states: torch.Tensor,
@@ -131,23 +184,29 @@ class TopK(CustomOp):
         num_token_non_padded: Optional[torch.Tensor] = None,
         expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
     ) -> TopKOutput:
-        torch_native = False
-        return select_experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=self.use_grouped_topk,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            custom_routing_function=self.custom_routing_function,
-            correction_bias=self.correction_bias,
-            torch_native=torch_native,
-            routed_scaling_factor=self.routed_scaling_factor,
-            num_token_non_padded=num_token_non_padded,
-            expert_location_dispatch_info=expert_location_dispatch_info,
-        )
+        if self.use_triton_kernels:
+            routing_data, gather_idx, scatter_idx = routing(
+                router_logits, self.top_k, self.renormalize
+            )
+            return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
+        else:
+            torch_native = False
+            return select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=self.use_grouped_topk,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                custom_routing_function=self.custom_routing_function,
+                correction_bias=self.correction_bias,
+                torch_native=torch_native,
+                routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
+                expert_location_dispatch_info=expert_location_dispatch_info,
+            )
 
     def forward_cpu(
         self,
@@ -217,6 +276,9 @@ class TopK(CustomOp):
         )
 
 
+# ------------------------------- TopK implementation -------------------------------------
+
+
 def fused_topk_torch_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -680,4 +742,4 @@ def select_experts(
 
     get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
 
-    return TopKOutput(topk_weights, topk_ids, router_logits)
+    return StandardTopKOutput(topk_weights, topk_ids, router_logits)
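The net effect is that downstream MoE code can branch on the output's format instead of on global flags, while select_experts keeps returning the standard variant. A self-contained sketch of that dispatch pattern, using stand-in classes that mirror the definitions above (the consumer function run_moe is hypothetical):

from enum import Enum, auto
from typing import NamedTuple, Protocol, runtime_checkable

class TopKOutputFormat(Enum):
    STANDARD = auto()
    TRITON_KERNEL = auto()

    def is_standard(self) -> bool:
        return self == TopKOutputFormat.STANDARD

    def is_triton_kernel(self) -> bool:
        return self == TopKOutputFormat.TRITON_KERNEL

@runtime_checkable
class TopKOutput(Protocol):
    @property
    def format(self) -> TopKOutputFormat: ...

class StandardTopKOutput(NamedTuple):
    topk_weights: object
    topk_ids: object
    router_logits: object

    @property
    def format(self) -> TopKOutputFormat:
        return TopKOutputFormat.STANDARD

def run_moe(topk_output: TopKOutput) -> str:
    # Hypothetical consumer: choose the kernel path from the output format,
    # the way the Triton path asserts topk_output.format.is_triton_kernel().
    if topk_output.format.is_standard():
        return "fused_experts(topk_weights, topk_ids, ...)"
    return "triton_kernel_fused_experts(routing_data, gather_indx, scatter_indx, ...)"

out = StandardTopKOutput([[0.7, 0.3]], [[1, 4]], None)
assert isinstance(out, TopKOutput)  # runtime_checkable protocol check
print(run_moe(out))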
sglang/srt/layers/multimodal.py
@@ -55,14 +55,17 @@ def gpu_tensor_hash(tensor: torch.Tensor) -> int:
 
     intermediate_hashes = torch.empty(n, dtype=torch.int64, device=tensor.device)
 
-    hash_kernel[grid](
-        tensor,
-        intermediate_hashes,
-        n,
-        BLOCK_SIZE=BLOCK_SIZE,
-        PRIME=PRIME_1,
-        XCONST=PRIME_2,
-    )
+    # Set cuda device to prevent ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
+    # Solution from Tri: https://github.com/Dao-AILab/flash-attention/issues/523#issuecomment-1707611579
+    with torch.cuda.device(tensor.device):
+        hash_kernel[grid](
+            tensor,
+            intermediate_hashes,
+            n,
+            BLOCK_SIZE=BLOCK_SIZE,
+            PRIME=PRIME_1,
+            XCONST=PRIME_2,
+        )
 
     # TODO: threads can't be synced on triton kernel
     final_hash = intermediate_hashes.sum().item()
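The fix pins the active CUDA device to the input tensor's device for the duration of the Triton launch, which is what avoids the "Pointer argument ... cannot be accessed from Triton" error when the tensor lives on a non-default GPU. A minimal sketch of the same pattern with a toy kernel; it assumes CUDA-enabled PyTorch and Triton are installed, and the kernel is illustrative rather than sglang's hash kernel:

import torch
import triton
import triton.language as tl

@triton.jit
def add_one_kernel(x_ptr, n, BLOCK_SIZE: tl.constexpr):
    # Each program handles one BLOCK_SIZE-wide slice of the flat tensor.
    offs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0)
    tl.store(x_ptr + offs, x + 1, mask=mask)

def add_one(x: torch.Tensor) -> torch.Tensor:
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    # As in the diff: launch under the tensor's own device so Triton resolves the
    # pointer on the correct GPU even if another device is currently active.
    with torch.cuda.device(x.device):
        add_one_kernel[grid](x, n, BLOCK_SIZE=1024)
    return x

# Usage (on a machine with at least one GPU):
# y = add_one(torch.zeros(4096, device="cuda:0"))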