sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. sglang/bench_one_batch.py +1 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/lang/chat_template.py +44 -0
  4. sglang/srt/configs/deepseekvl2.py +3 -0
  5. sglang/srt/configs/device_config.py +1 -1
  6. sglang/srt/configs/internvl.py +696 -0
  7. sglang/srt/configs/janus_pro.py +3 -0
  8. sglang/srt/configs/model_config.py +17 -0
  9. sglang/srt/constrained/xgrammar_backend.py +11 -19
  10. sglang/srt/conversation.py +30 -3
  11. sglang/srt/disaggregation/decode.py +4 -1
  12. sglang/srt/disaggregation/mini_lb.py +74 -23
  13. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  14. sglang/srt/disaggregation/nixl/conn.py +241 -71
  15. sglang/srt/disaggregation/utils.py +44 -1
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  17. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  19. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  20. sglang/srt/distributed/parallel_state.py +22 -1
  21. sglang/srt/entrypoints/engine.py +14 -2
  22. sglang/srt/entrypoints/http_server.py +28 -1
  23. sglang/srt/entrypoints/verl_engine.py +3 -2
  24. sglang/srt/hf_transformers_utils.py +20 -1
  25. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  26. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  27. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  28. sglang/srt/layers/attention/merge_state.py +46 -0
  29. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  30. sglang/srt/layers/attention/vision.py +290 -163
  31. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  32. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  33. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
  37. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  38. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  39. sglang/srt/layers/quantization/deep_gemm.py +5 -0
  40. sglang/srt/layers/quantization/fp8.py +108 -95
  41. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  42. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  43. sglang/srt/layers/quantization/kv_cache.py +3 -10
  44. sglang/srt/layers/quantization/utils.py +0 -5
  45. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  46. sglang/srt/lora/lora_manager.py +10 -13
  47. sglang/srt/managers/cache_controller.py +115 -119
  48. sglang/srt/managers/io_struct.py +10 -0
  49. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  50. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  51. sglang/srt/managers/schedule_batch.py +19 -1
  52. sglang/srt/managers/schedule_policy.py +11 -5
  53. sglang/srt/managers/scheduler.py +28 -13
  54. sglang/srt/managers/tokenizer_manager.py +24 -13
  55. sglang/srt/managers/tp_worker.py +9 -12
  56. sglang/srt/mem_cache/chunk_cache.py +2 -0
  57. sglang/srt/mem_cache/memory_pool.py +2 -2
  58. sglang/srt/model_executor/model_runner.py +44 -33
  59. sglang/srt/model_loader/loader.py +18 -11
  60. sglang/srt/models/clip.py +4 -4
  61. sglang/srt/models/deepseek_janus_pro.py +1 -1
  62. sglang/srt/models/deepseek_nextn.py +1 -20
  63. sglang/srt/models/deepseek_v2.py +55 -20
  64. sglang/srt/models/gemma3_mm.py +1 -1
  65. sglang/srt/models/internlm2.py +3 -0
  66. sglang/srt/models/internvl.py +670 -0
  67. sglang/srt/models/llama.py +1 -1
  68. sglang/srt/models/llama4.py +53 -7
  69. sglang/srt/models/minicpmv.py +1 -1
  70. sglang/srt/models/mllama.py +1 -1
  71. sglang/srt/models/phi3_small.py +16 -2
  72. sglang/srt/models/qwen2_5_vl.py +8 -4
  73. sglang/srt/models/qwen2_vl.py +4 -4
  74. sglang/srt/models/xiaomi_mimo.py +171 -0
  75. sglang/srt/openai_api/adapter.py +24 -40
  76. sglang/srt/openai_api/protocol.py +28 -16
  77. sglang/srt/reasoning_parser.py +2 -2
  78. sglang/srt/sampling/sampling_batch_info.py +54 -2
  79. sglang/srt/sampling/sampling_params.py +2 -0
  80. sglang/srt/server_args.py +30 -6
  81. sglang/srt/utils.py +35 -1
  82. sglang/test/test_block_fp8.py +2 -2
  83. sglang/test/test_deepep_utils.py +219 -0
  84. sglang/test/test_utils.py +3 -1
  85. sglang/version.py +1 -1
  86. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
  87. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
  88. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  89. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  90. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/token_dispatcher.py

@@ -1,14 +1,19 @@
+from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.utils import DeepEPMode
 
 try:
     from deep_ep import Buffer
 
+    from sglang.srt.layers.quantization.fp8_kernel import (
+        sglang_per_token_group_quant_fp8,
+    )
+
     use_deepep = True
 except ImportError:
     use_deepep = False
 
 from enum import IntEnum, auto
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
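Note: the new imports gate an fp8 dispatch path — when JIT DeepGEMM is enabled, dispatch_a (below) quantizes hidden states in 128-wide groups before communication. As a rough illustration of what a per-token-group fp8 quantizer computes — a pure-PyTorch sketch, not the fused SGLang kernel:

    import torch

    def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int = 128):
        # Each contiguous group of `group_size` features per token gets its
        # own scale so the group spans the fp8 e4m3 range.
        assert x.shape[-1] % group_size == 0
        fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
        g = x.reshape(*x.shape[:-1], -1, group_size).float()
        scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
        q = (g / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        return q.reshape(x.shape), scale.squeeze(-1)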
@@ -78,7 +83,6 @@ class DeepEPBuffer:
             ),
             num_rdma_bytes,
         )
-
         cls._buffer = Buffer(
             group,
             num_nvl_bytes,
@@ -181,44 +185,74 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_weights: torch.Tensor,
     ):
         topk_idx = topk_idx.to(torch.int64)
+        if _ENABLE_JIT_DEEPGEMM:
+            # TODO: hard-coded 128 block quant; use fp8 communication
+            hidden_states = sglang_per_token_group_quant_fp8(hidden_states, 128)
         previous_event = Buffer.capture() if self.async_finish else None
         return hidden_states, topk_idx, topk_weights, previous_event
 
     def dispatch_b(self, hidden_states, topk_idx, topk_weights, previous_event):
-        (
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            event,
-        ) = self._dispatch_core(hidden_states, topk_idx, topk_weights, previous_event)
-        event.current_stream_wait() if self.async_finish else ()
-        if hidden_states.shape[0] > 0:
-            reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
-                hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
-            )
-        else:
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
-            )
-            seg_indptr = torch.zeros(
-                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
-            )
+        if _ENABLE_JIT_DEEPGEMM:
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                num_recv_tokens_per_expert_list,
+                event,
+            ) = self._dispatch_core(
+                hidden_states, topk_idx, topk_weights, previous_event
+            )
+            event.current_stream_wait() if self.async_finish else ()
+            return (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                None,
+                num_recv_tokens_per_expert_list,
+                None,
+                None,
+                None,
+            )
+        else:
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                num_recv_tokens_per_expert_list,
+                event,
+            ) = self._dispatch_core(
+                hidden_states, topk_idx, topk_weights, previous_event
+            )
+            event.current_stream_wait() if self.async_finish else ()
+            if hidden_states.shape[0] > 0:
+                reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
+                    hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
+                )
+            else:
+                reorder_topk_ids = torch.empty(
+                    (0,), device=hidden_states.device, dtype=torch.int64
+                )
+                seg_indptr = torch.zeros(
+                    (self.num_experts + 1,),
+                    device=hidden_states.device,
+                    dtype=torch.int64,
+                )
 
-        masked_m = expected_m = None
-
-        return (
-            hidden_states,
-            topk_idx,
-            topk_weights,
-            reorder_topk_ids,
-            seg_indptr,
-            masked_m,
-            expected_m,
-        )
+            masked_m = expected_m = None
+            return (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                None,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            )
 
     def _dispatch_core(
         self,
-        x: torch.Tensor,
+        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         previous_event,
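Note: dispatch_b in normal mode now returns an 8-tuple on both branches; the new fifth slot carries num_recv_tokens_per_expert_list, and unused slots are filled with None. A hypothetical caller (names taken from this diff) would unpack it as:

    (
        hidden_states,                    # (fp8 tensor, scales) on the DeepGEMM path
        topk_idx,
        topk_weights,
        reorder_topk_ids,                 # None on the DeepGEMM path
        num_recv_tokens_per_expert_list,  # None on the permute path
        seg_indptr,                       # None on the DeepGEMM path
        masked_m,
        expected_m,
    ) = dispatcher.dispatch(hidden_states, topk_idx, topk_weights)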
@@ -246,7 +280,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
-            _,  # num_recv_tokens_per_expert_list
+            num_recv_tokens_per_expert_list,
             self.handle,
             event,
         ) = buffer.dispatch(
@@ -260,12 +294,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             previous_event=previous_event,
             async_finish=self.async_finish,
             allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
+            expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
         )
 
         return (
             recv_x,
             recv_topk_idx,
             recv_topk_weights,
+            num_recv_tokens_per_expert_list,
             event,
         )
 
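Note: expert_alignment=128 appears to ask DeepEP to pad each expert's received token count up to a multiple of 128, matching the hard-coded DeepGEMM block quant size above. The rounding rule this implies:

    def align_up(n_tokens: int, alignment: int = 128) -> int:
        # Round a token count up to the next multiple of `alignment`.
        return -(-n_tokens // alignment) * alignment

    assert align_up(1) == 128 and align_up(200) == 256 and align_up(256) == 256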
@@ -314,29 +350,32 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
     ):
-        if hidden_states.shape[0] > 0:
-            num_tokens = self.src2dst.shape[0] // self.router_topk
-            output = torch.empty(
-                (num_tokens, hidden_states.shape[1]),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-            deepep_post_reorder_triton_kernel[(num_tokens,)](
-                hidden_states,
-                output,
-                self.src2dst,
-                topk_idx,
-                topk_weights,
-                self.router_topk,
-                hidden_states.shape[1],
-                BLOCK_SIZE=512,
-            )
+        if _ENABLE_JIT_DEEPGEMM:
+            output = hidden_states
         else:
-            output = torch.zeros(
-                (0, hidden_states.shape[1]),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
+            if hidden_states.shape[0] > 0:
+                num_tokens = self.src2dst.shape[0] // self.router_topk
+                output = torch.empty(
+                    (num_tokens, hidden_states.shape[1]),
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
+                deepep_post_reorder_triton_kernel[(num_tokens,)](
+                    hidden_states,
+                    output,
+                    self.src2dst,
+                    topk_idx,
+                    topk_weights,
+                    self.router_topk,
+                    hidden_states.shape[1],
+                    BLOCK_SIZE=512,
+                )
+            else:
+                output = torch.zeros(
+                    (0, hidden_states.shape[1]),
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
         previous_event = Buffer.capture() if self.async_finish else None
         return output, previous_event
 
@@ -360,6 +399,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
 
     def _get_buffer(self):
         DeepEPBuffer.set_dispatch_mode_as_normal()
+
         return DeepEPBuffer.get_deepep_buffer(
             self.group,
             self.hidden_size,
@@ -426,6 +466,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
             topk_idx,
             topk_weights,
             reorder_topk_ids,
+            None,
             seg_indptr,
             masked_m,
             expected_m,
@@ -570,7 +611,8 @@ class DeepEPDispatcher:
 
     def dispatch(self, *args, **kwargs) -> Tuple:
         self.dispatch_a(*args, **kwargs)
-        return self.dispatch_b()
+        ret = self.dispatch_b()
+        return ret
 
     def dispatch_a(
         self,
@@ -593,7 +635,8 @@ class DeepEPDispatcher:
 
     def combine(self, *args, **kwargs) -> Tuple:
         self.combine_a(*args, **kwargs)
-        return self.combine_b()
+        ret = self.combine_b()
+        return ret
 
     def combine_a(
         self,
sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json (new file)

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json (new file)

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    }
+}
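Note: both new files follow the fused-MoE tuning-table format: each top-level key is a batch-size (M) bucket and each value a Triton tile configuration (block sizes, group size, warps, stages) tuned for that M. At runtime get_moe_configs loads the table and picks the entry whose key is nearest the observed token count; a simplified sketch of that lookup:

    import json

    def pick_config(path: str, m: int) -> dict:
        # Return the tuned entry whose batch-size key is closest to M.
        with open(path) as f:
            configs = {int(k): v for k, v in json.load(f).items()}
        return configs[min(configs, key=lambda k: abs(k - m))]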
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -29,6 +29,7 @@ from sglang.srt.utils import (
     get_device_name,
     is_cuda,
     is_hip,
+    log_info_on_rank0,
 )
 
 _is_hip = is_hip()
@@ -945,7 +946,9 @@ def get_moe_configs(
             # For example, updating the Triton version might cause all old configs to become suboptimal.
             # To achieve the best performance, consider re-tuning the Triton fused MOE kernel in your environment.
             # For the tuning method, refer to: https://github.com/sgl-project/sglang/tree/main/benchmark/kernels/fused_moe_triton
-            logger.info("Using MoE kernel config from %s.", config_file_path)
+            log_info_on_rank0(
+                logger, f"Using MoE kernel config from {config_file_path}."
+            )
             # If a configuration has been found, return it
             return {int(key): val for key, val in json.load(f).items()}
 
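Note: swapping logger.info for log_info_on_rank0 prints the config message once per job instead of once per rank. A minimal sketch of such a helper (the real one lives in sglang.srt.utils; this behavior is assumed from the call site):

    import logging

    import torch.distributed as dist

    def log_info_on_rank0(logger: logging.Logger, msg: str) -> None:
        # Log only on rank 0, or always when torch.distributed is not set up.
        if not dist.is_initialized() or dist.get_rank() == 0:
            logger.info(msg)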
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -10,16 +10,14 @@ import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
-    is_cuda,
-    is_fp8_fnuz,
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs
 
 _is_cuda = is_cuda()
 
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py

@@ -15,11 +15,12 @@ from sglang.srt.layers.parameter import (
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
+from sglang.srt.layers.quantization.utils import requantize_with_max_scale
 
 __all__ = ["CompressedTensorsW8A8Fp8"]
 
sglang/srt/layers/quantization/deep_gemm.py

@@ -28,6 +28,11 @@ if is_cuda():
     if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
         _ENABLE_JIT_DEEPGEMM = True
 
+
+def get_enable_jit_deepgemm():
+    return _ENABLE_JIT_DEEPGEMM
+
+
 logger = logging.getLogger(__name__)
 
 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
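Note: the new accessor lets callers read the flag at call time rather than binding the module-level _ENABLE_JIT_DEEPGEMM once at import (a `from ... import _ENABLE_JIT_DEEPGEMM` freezes whatever value the name had when the import ran). Hypothetical usage:

    from sglang.srt.layers.quantization.deep_gemm import get_enable_jit_deepgemm

    if get_enable_jit_deepgemm():
        ...  # take the fp8 / DeepGEMM dispatch path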