sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff shows the differences between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+ {
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 128,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ }
+ }
@@ -0,0 +1,146 @@
+ {
+ "1": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "2": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "4": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 4
+ },
+ "8": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "16": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "24": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "32": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "48": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "64": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "96": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 32,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "128": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "256": {
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 16,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "512": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "1024": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "1536": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "2048": {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 2
+ },
+ "3072": {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 1,
+ "num_warps": 4,
+ "num_stages": 3
+ },
+ "4096": {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 64,
+ "num_warps": 4,
+ "num_stages": 2
+ }
+ }
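
The two JSON files above are new fused_moe_triton tuning tables: each top-level key is a token count M, and each value is the Triton launch configuration (tile sizes, expert-group size, warp count, pipeline stages) benchmarked for that M on the GPU named in the filename. A minimal sketch of how such a table could be loaded and the nearest-M entry selected follows; the function names are illustrative, not sglang's actual API.

import json

def load_moe_config(path: str) -> dict[int, dict]:
    # Keys are stored as strings in the JSON file; convert to int for lookup.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}

def pick_config(configs: dict[int, dict], num_tokens: int) -> dict:
    # Choose the tuned entry whose M is closest to the actual token count.
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]

# Example: with the first table above, 300 tokens maps to the "256" entry,
# i.e. BLOCK_SIZE_M=16, BLOCK_SIZE_N=128, BLOCK_SIZE_K=128, GROUP_SIZE_M=16.
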
@@ -11,20 +11,23 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
  import torch
  import triton
  import triton.language as tl
- from vllm import _custom_ops as ops
+ from vllm import _custom_ops as vllm_ops
 
  from sglang.srt.layers.moe.topk import select_experts
  from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
- from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8
+ from sglang.srt.layers.quantization.int8_kernel import (
+ per_token_group_quant_int8,
+ per_token_quant_int8,
+ )
  from sglang.srt.utils import (
  direct_register_custom_op,
  get_bool_env_var,
  get_device_name,
- is_cuda_available,
+ is_cuda,
  is_hip,
  )
 
- is_hip_ = is_hip()
+ _is_hip = is_hip()
 
 
  logger = logging.getLogger(__name__)
@@ -34,17 +37,17 @@ enable_moe_align_block_size_triton = bool(
  int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
  )
 
- _is_cuda = torch.cuda.is_available() and torch.version.cuda
- _is_rocm = torch.cuda.is_available() and torch.version.hip
+ _is_cuda = is_cuda()
 
  if _is_cuda:
  from sgl_kernel import gelu_and_mul, silu_and_mul
 
+ from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
  from sglang.srt.layers.quantization.fp8_kernel import (
  sglang_per_token_group_quant_fp8,
  )
 
- if _is_cuda or _is_rocm:
+ if _is_cuda or _is_hip:
  from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
 
@@ -117,6 +120,7 @@ def fused_moe_kernel(
  - expert_ids: A tensor containing the indices of the expert for each
  block. It determines which expert matrix from B should be used for
  each block in A.
+
  This kernel performs the multiplication of a token by its corresponding
  expert matrix as determined by `expert_ids`. The sorting of
  `sorted_token_ids` by expert index and padding ensures divisibility by
@@ -167,17 +171,38 @@ def fused_moe_kernel(
  )
  b_scale = tl.load(b_scale_ptrs)
 
- if use_fp8_w8a8 or use_int8_w8a8:
+ if use_fp8_w8a8:
+ # block-wise
  if group_k > 0 and group_n > 0:
  a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
  offs_bsn = offs_bn // group_n
  b_scale_ptrs = (
  b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
  )
+ # tensor-wise
  else:
  a_scale = tl.load(a_scale_ptr)
  b_scale = tl.load(b_scale_ptr + off_experts)
 
+ if use_int8_w8a8:
+ # block-wise
+ if group_k > 0 and group_n > 0:
+ a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+ offs_bsn = offs_bn // group_n
+ b_scale_ptrs = (
+ b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
+ )
+ # channel-wise
+ else:
+ # Load per-column scale for weights
+ b_scale_ptrs = (
+ b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
+ )
+ b_scale = tl.load(b_scale_ptrs)
+ # Load per-token scale for activations
+ a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+ a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
+
  # -----------------------------------------------------------
  # Iterate to compute a block of the C matrix.
  # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
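
The new use_int8_w8a8 channel-wise branch above loads one weight scale per output column (indexed by offs_bn) and one activation scale per token (indexed through offs_token). For reference, the snippet below sketches the per-token int8 quantization that the activation side assumes; it is plain PyTorch showing the math only, not the Triton kernel behind per_token_quant_int8.

import torch

def per_token_int8_quant_ref(x: torch.Tensor, eps: float = 1e-7):
    # One scale per token (row): the row's max magnitude maps to 127.
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=eps)
    scale = absmax / 127.0
    q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    return q, scale  # dequantize as q.float() * scale

x = torch.randn(4, 16)
q, s = per_token_int8_quant_ref(x)
assert (q.float() * s - x).abs().max() <= s.max() / 2 + 1e-6
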
@@ -217,7 +242,11 @@ def fused_moe_kernel(
 
  accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
  else:
- accumulator = tl.dot(a, b, acc=accumulator)
+ # fix out of shared memory issue
+ if use_fp8_w8a8:
+ accumulator = tl.dot(a, b, acc=accumulator)
+ else:
+ accumulator += tl.dot(a, b)
  else:
  accumulator += tl.dot(a, b)
  # Advance the ptrs to the next K block.
@@ -458,7 +487,7 @@ def moe_align_block_size(
  cumsum_buffer,
  )
  else:
- ops.moe_align_block_size(
+ vllm_ops.moe_align_block_size(
  topk_ids,
  num_experts,
  block_size,
@@ -497,9 +526,14 @@ def invoke_fused_moe_kernel(
  if use_fp8_w8a8:
  assert B_scale is not None
  if block_shape is None:
+ # activation tensor-wise fp8 quantization, dynamic or static
  padded_size = padding_size
- A, A_scale = ops.scaled_fp8_quant(A, A_scale)
+ if _is_cuda:
+ A, A_scale = sgl_scaled_fp8_quant(A, A_scale)
+ else:
+ A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale)
  else:
+ # activation block-wise fp8 quantization
  assert len(block_shape) == 2
  block_n, block_k = block_shape[0], block_shape[1]
  if _is_cuda:
@@ -512,9 +546,10 @@ def invoke_fused_moe_kernel(
  elif use_int8_w8a8:
  assert B_scale is not None
  if block_shape is None:
- padded_size = padding_size
- A, A_scale = ops.scaled_int8_quant(A, A_scale)
+ # activation channel-wise int8 quantization
+ A, A_scale = per_token_quant_int8(A)
  else:
+ # activation block-wise int8 quantization
  assert len(block_shape) == 2
  block_n, block_k = block_shape[0], block_shape[1]
  A, A_scale = per_token_group_quant_int8(A, block_k)
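
In invoke_fused_moe_kernel the activation quantization granularity now follows block_shape: the int8 path picks channel-wise quantization (per_token_quant_int8) when block_shape is None, and block-wise quantization (per_token_group_quant_int8 with group size block_k) otherwise. Below is a plain-PyTorch sketch of the block-wise variant, to contrast with the per-token sketch above, assuming the same round-to-127 convention; the real kernel is a Triton implementation.

import torch

def per_token_group_int8_quant_ref(x: torch.Tensor, group_k: int, eps: float = 1e-7):
    # x: [num_tokens, hidden]; one int8 scale per contiguous group of
    # `group_k` hidden elements in each row (vs. one scale per whole row).
    t, h = x.shape
    assert h % group_k == 0
    g = x.view(t, h // group_k, group_k)
    scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / 127.0
    q = torch.round(g / scale).clamp(-128, 127).to(torch.int8)
    return q.view(t, h), scale.squeeze(-1)

q, s = per_token_group_int8_quant_ref(torch.randn(4, 256), group_k=128)
print(q.shape, s.shape)  # torch.Size([4, 256]) torch.Size([4, 2])
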
@@ -648,7 +683,7 @@ def get_default_config(
  "BLOCK_SIZE_K": 128,
  "GROUP_SIZE_M": 32,
  "num_warps": 8,
- "num_stages": 2 if is_hip_ else 4,
+ "num_stages": 2 if _is_hip else 4,
  }
  if M <= E:
  config = {
@@ -657,7 +692,7 @@
  "BLOCK_SIZE_K": 128,
  "GROUP_SIZE_M": 1,
  "num_warps": 4,
- "num_stages": 2 if is_hip_ else 4,
+ "num_stages": 2 if _is_hip else 4,
  }
  else:
  # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
@@ -667,7 +702,7 @@
  "BLOCK_SIZE_K": block_shape[1],
  "GROUP_SIZE_M": 32,
  "num_warps": 4,
- "num_stages": 2 if is_hip_ else 3,
+ "num_stages": 2 if _is_hip else 3,
  }
  else:
  config = {
@@ -945,7 +980,7 @@ def fused_experts_impl(
  if (
  not (use_fp8_w8a8 or use_int8_w8a8)
  or block_shape is not None
- or (is_hip_ and get_bool_env_var("CK_MOE"))
+ or (_is_hip and get_bool_env_var("CK_MOE"))
  ):
  padded_size = 0
 
@@ -1029,7 +1064,9 @@
  # so the cache size and config are already set correctly and
  # do not need to be adjusted.
  intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
- intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+ intermediate_cache2 = intermediate_cache2[
+ : tokens_in_chunk * topk_ids.shape[1]
+ ]
  intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
  config = get_config_func(tokens_in_chunk)
 
@@ -1060,17 +1097,20 @@
  use_int8_w8a16=use_int8_w8a16,
  block_shape=block_shape,
  )
-
  if activation == "silu":
  if _is_cuda:
  silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
  else:
- ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+ vllm_ops.silu_and_mul(
+ intermediate_cache2, intermediate_cache1.view(-1, N)
+ )
  elif activation == "gelu":
  if _is_cuda:
  gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
  else:
- ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+ vllm_ops.gelu_and_mul(
+ intermediate_cache2, intermediate_cache1.view(-1, N)
+ )
  else:
  raise ValueError(f"Unsupported activation: {activation=}")
 
@@ -1101,8 +1141,8 @@
 
  if no_combine:
  pass
- elif is_hip_:
- ops.moe_sum(
+ elif _is_hip:
+ vllm_ops.moe_sum(
  intermediate_cache3.view(*intermediate_cache3.shape),
  out_hidden_states[begin_chunk_idx:end_chunk_idx],
  )
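
One behavioral fix in the fused_experts_impl hunks above deserves a note: in the chunked loop, intermediate_cache2 is the flat buffer that silu_and_mul / gelu_and_mul fill with one row per (token, expert) pair, so slicing it by tokens_in_chunk alone under-sizes it; the new code slices tokens_in_chunk * topk_ids.shape[1] rows instead. A small sketch of the bookkeeping, with illustrative sizes only:

# Sizes here are made up for illustration; the real shapes come from the
# caches preallocated in fused_experts_impl.
tokens_in_chunk, top_k = 300, 8

rows_cache1 = tokens_in_chunk          # intermediate_cache1/3 are sliced per token
rows_cache2 = tokens_in_chunk * top_k  # one row per (token, expert) pair
assert rows_cache2 == 2400             # not 300
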
@@ -27,9 +27,9 @@ else:
 
  import logging
 
- is_hip_ = is_hip()
+ _is_hip = is_hip()
 
- if is_hip_:
+ if _is_hip:
  from aiter import ck_moe
 
  logger = logging.getLogger(__name__)
@@ -102,7 +102,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  set_weight_attrs(w2_weight, extra_weight_attrs)
 
  def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
- if is_hip_ and get_bool_env_var("CK_MOE"):
+ if _is_hip and get_bool_env_var("CK_MOE"):
  layer.w13_weight = torch.nn.Parameter(
  permute_weight(layer.w13_weight.data),
  requires_grad=False,
@@ -175,7 +175,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
  correction_bias=correction_bias,
  )
 
- if is_hip_ and get_bool_env_var("CK_MOE"):
+ if _is_hip and get_bool_env_var("CK_MOE"):
  assert not no_combine, "unsupported"
  return ck_moe(
  x,
@@ -513,6 +513,10 @@ class FusedMoE(torch.nn.Module):
 
  # Case input scale: input_scale loading is only supported for fp8
  if "input_scale" in weight_name:
+ # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
+ if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+ loaded_weight = loaded_weight * 2.0
+
  # this is needed for compressed-tensors only
  loaded_weight = loaded_weight.to(param.data.device)
 
@@ -551,6 +555,10 @@ class FusedMoE(torch.nn.Module):
  # specific to each case
  quant_method = getattr(param, "quant_method", None)
  if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
+ # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
+ if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+ loaded_weight = loaded_weight * 0.5
+
  self._load_per_channel_weight_scale(
  shard_id=shard_id,
  shard_dim=shard_dim,
@@ -570,6 +578,10 @@ class FusedMoE(torch.nn.Module):
  tp_rank=tp_rank,
  )
  elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
+ # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
+ if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+ loaded_weight = loaded_weight * 2.0
+
  self._load_per_tensor_weight_scale(
  shard_id=shard_id,
  param=param,
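
The three layer.py hunks above rescale the loaded quantization parameters for the INT4-weight / FP8-compute MoE path on ROCm (USE_INT4_WEIGHT under _is_hip). A likely rationale, not stated in the diff beyond its e4m3fnuz comments: AMD's fp8 flavor e4m3fnuz decodes the same bit pattern to half the value that e4m3fn does, so fp8 scales calibrated for e4m3fn get doubled while the paired INT4 per-channel scales get halved. The 2.0 and 0.5 factors themselves come from the diff; a quick check of the two dtype ranges:

import torch

# The two fp8 encodings differ in exponent bias, so their representable
# maxima differ roughly by a factor of two.
print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0
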