sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/srt/layers/attention/__init__.py +14 -5
  3. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  4. sglang/srt/layers/attention/flashinfer_backend.py +211 -81
  5. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  6. sglang/srt/layers/attention/triton_backend.py +20 -11
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  8. sglang/srt/layers/logits_processor.py +167 -212
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
  31. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
  32. sglang/srt/layers/quantization/fp8.py +2 -2
  33. sglang/srt/layers/sampler.py +57 -21
  34. sglang/srt/layers/torchao_utils.py +17 -3
  35. sglang/srt/managers/io_struct.py +1 -2
  36. sglang/srt/managers/schedule_batch.py +26 -2
  37. sglang/srt/managers/schedule_policy.py +159 -90
  38. sglang/srt/managers/scheduler.py +62 -26
  39. sglang/srt/managers/tokenizer_manager.py +22 -20
  40. sglang/srt/managers/tp_worker.py +16 -4
  41. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  42. sglang/srt/model_executor/cuda_graph_runner.py +118 -73
  43. sglang/srt/model_executor/forward_batch_info.py +33 -8
  44. sglang/srt/model_executor/model_runner.py +63 -61
  45. sglang/srt/models/deepseek_v2.py +34 -7
  46. sglang/srt/models/grok.py +97 -26
  47. sglang/srt/openai_api/adapter.py +0 -17
  48. sglang/srt/openai_api/protocol.py +3 -3
  49. sglang/srt/sampling/sampling_batch_info.py +21 -0
  50. sglang/srt/sampling/sampling_params.py +9 -1
  51. sglang/srt/server.py +9 -5
  52. sglang/srt/server_args.py +108 -57
  53. sglang/srt/speculative/build_eagle_tree.py +347 -0
  54. sglang/srt/speculative/eagle_utils.py +618 -0
  55. sglang/srt/speculative/eagle_worker.py +170 -0
  56. sglang/srt/speculative/spec_info.py +5 -0
  57. sglang/srt/utils.py +15 -2
  58. sglang/version.py +1 -1
  59. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
  60. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/RECORD +63 -39
  61. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
  62. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
  63. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 4
+    }
+}
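The tuned H200 MoE kernel tables above map a token-bucket key (roughly the per-expert batch size M) to Triton tile parameters (BLOCK_SIZE_M/N/K, GROUP_SIZE_M, num_warps, num_stages). As an illustrative sketch only, not the exact lookup sglang performs in fused_moe.py, a config for an arbitrary M can be chosen by falling back to the nearest available key:

import json
from typing import Dict

def load_moe_config(path: str) -> Dict[int, dict]:
    # JSON keys are strings ("1" ... "4096"); convert to ints for numeric lookup.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}

def pick_config(configs: Dict[int, dict], m: int) -> dict:
    # Nearest-key fallback; an assumption for illustration, matching how such
    # tuned tables are commonly consumed, not necessarily sglang's exact rule.
    return configs[min(configs, key=lambda k: abs(k - m))]

# Hypothetical usage (file name follows the naming scheme in the list above):
# cfg = pick_config(load_moe_config("E=8,N=7168,device_name=NVIDIA_H200.json"), m=96)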
@@ -17,15 +17,21 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
 
-not_hip = False
+is_hip_flag = False
 if not is_hip():
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
-    not_hip = True
+    is_hip_flag = False
+else:
+    is_hip_flag = True
 
 logger = logging.getLogger(__name__)
 padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
 
+enable_moe_align_block_size_triton = bool(
+    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
+)
+
 
 @triton.jit
 def fused_moe_kernel(
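ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON is captured into a module-level constant when fused_moe.py is imported, so opting into the Triton alignment path has to happen before the import. A minimal sketch:

import os

# Must be set before importing sglang.srt.layers.moe.fused_moe_triton.fused_moe,
# because the flag is read once at module import time (see the hunk above).
os.environ["ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"] = "1"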
@@ -222,6 +228,139 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
+def ceil_div(a, b):
+    return (a + b - 1) // b
+
+
+@triton.jit
+def moe_align_block_size_stage1(
+    topk_ids_ptr,
+    tokens_cnts_ptr,
+    num_experts: tl.constexpr,
+    numel: tl.constexpr,
+    tokens_per_thread: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    start_idx = pid * tokens_per_thread
+
+    off_c = (pid + 1) * num_experts
+
+    for i in range(tokens_per_thread):
+        if start_idx + i < numel:
+            idx = tl.load(topk_ids_ptr + start_idx + i)
+            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
+            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
+
+
+@triton.jit
+def moe_align_block_size_stage2(
+    tokens_cnts_ptr,
+    num_experts: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    last_cnt = 0
+    for i in range(1, num_experts + 1):
+        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
+        last_cnt = last_cnt + token_cnt
+        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
+
+
+@triton.jit
+def moe_align_block_size_stage3(
+    total_tokens_post_pad_ptr,
+    tokens_cnts_ptr,
+    cumsum_ptr,
+    num_experts: tl.constexpr,
+    block_size: tl.constexpr,
+):
+    last_cumsum = 0
+    off_cnt = num_experts * num_experts
+    for i in range(1, num_experts + 1):
+        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
+        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
+        tl.store(cumsum_ptr + i, last_cumsum)
+    tl.store(total_tokens_post_pad_ptr, last_cumsum)
+
+
+@triton.jit
+def moe_align_block_size_stage4(
+    topk_ids_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    tokens_cnts_ptr,
+    cumsum_ptr,
+    num_experts: tl.constexpr,
+    block_size: tl.constexpr,
+    numel: tl.constexpr,
+    tokens_per_thread: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    start_idx = tl.load(cumsum_ptr + pid)
+    end_idx = tl.load(cumsum_ptr + pid + 1)
+
+    for i in range(start_idx, end_idx, block_size):
+        tl.store(expert_ids_ptr + i // block_size, pid)
+
+    start_idx = pid * tokens_per_thread
+    off_t = pid * num_experts
+
+    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
+        expert_id = tl.load(topk_ids_ptr + i)
+        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
+        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
+        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
+        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
+
+
+def moe_align_block_size_triton(
+    topk_ids: torch.Tensor,
+    num_experts: int,
+    block_size: int,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_pad: torch.Tensor,
+) -> None:
+    numel = topk_ids.numel()
+    grid = (num_experts,)
+    tokens_cnts = torch.zeros(
+        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
+    )
+    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
+    tokens_per_thread = ceil_div(numel, num_experts)
+
+    moe_align_block_size_stage1[grid](
+        topk_ids,
+        tokens_cnts,
+        num_experts,
+        numel,
+        tokens_per_thread,
+    )
+    moe_align_block_size_stage2[grid](
+        tokens_cnts,
+        num_experts,
+    )
+    moe_align_block_size_stage3[(1,)](
+        num_tokens_post_pad,
+        tokens_cnts,
+        cumsum,
+        num_experts,
+        block_size,
+    )
+    moe_align_block_size_stage4[grid](
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        tokens_cnts,
+        cumsum,
+        num_experts,
+        block_size,
+        numel,
+        tokens_per_thread,
+    )
+
+
 def moe_align_block_size(
     topk_ids: torch.Tensor, block_size: int, num_experts: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
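The four-stage Triton implementation above computes the same layout as moe_align_block_size: token indices grouped by expert, each expert's group padded to a multiple of block_size, plus one expert id per block and the padded total. A rough PyTorch reference of that layout, as a sketch for understanding rather than the path sglang executes:

import torch

def moe_align_block_size_reference(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    # topk_ids: (num_tokens, top_k) expert assignments with values in [0, num_experts).
    numel = topk_ids.numel()
    flat = topk_ids.reshape(-1)
    counts = torch.bincount(flat, minlength=num_experts)
    padded = ((counts + block_size - 1) // block_size) * block_size  # per-expert padding
    cumsum = torch.cat([torch.zeros(1, dtype=torch.long), padded.cumsum(0)])

    # Unfilled slots keep a sentinel value; `numel` is used here as that sentinel.
    sorted_token_ids = torch.full((int(cumsum[-1]),), numel, dtype=torch.int32)
    expert_ids = torch.repeat_interleave(
        torch.arange(num_experts, dtype=torch.int32), padded // block_size
    )
    offsets = cumsum[:-1].clone()
    for i, e in enumerate(flat.tolist()):  # place each (token, expert) pair in expert order
        sorted_token_ids[offsets[e]] = i
        offsets[e] += 1
    return sorted_token_ids, expert_ids, int(cumsum[-1])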
@@ -272,24 +411,36 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    if not_hip and num_experts >= 224:
-        token_cnts_buffer = torch.empty(
-            (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
-        )
-        cumsum_buffer = torch.empty(
-            num_experts + 1, dtype=torch.int32, device=topk_ids.device
-        )
+    if num_experts >= 224:
+        if enable_moe_align_block_size_triton or is_hip_flag:
+            moe_align_block_size_triton(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids,
+                expert_ids,
+                num_tokens_post_pad,
+            )
+        else:
+            token_cnts_buffer = torch.empty(
+                (num_experts + 1) * num_experts,
+                dtype=torch.int32,
+                device=topk_ids.device,
+            )
+            cumsum_buffer = torch.empty(
+                num_experts + 1, dtype=torch.int32, device=topk_ids.device
+            )
 
-        sgl_moe_align_block_size(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-            token_cnts_buffer,
-            cumsum_buffer,
-        )
+            sgl_moe_align_block_size(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids,
+                expert_ids,
+                num_tokens_post_pad,
+                token_cnts_buffer,
+                cumsum_buffer,
+            )
     else:
         ops.moe_align_block_size(
             topk_ids,
@@ -326,9 +477,9 @@ def invoke_fused_moe_kernel(
 
     padded_size = 0
     if use_fp8_w8a8:
-        padded_size = padding_size
         assert B_scale is not None
         if block_shape is None:
+            padded_size = padding_size
             A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         else:
             assert len(block_shape) == 2
@@ -463,7 +614,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": 128,
                 "GROUP_SIZE_M": 32,
                 "num_warps": 8,
-                "num_stages": 4,
+                "num_stages": 2 if is_hip_flag else 4,
             }
             if M <= E:
                 config = {
@@ -472,7 +623,7 @@ def get_default_config(
                     "BLOCK_SIZE_K": 128,
                     "GROUP_SIZE_M": 1,
                     "num_warps": 4,
-                    "num_stages": 4,
+                    "num_stages": 2 if is_hip_flag else 4,
                 }
         else:
             # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
@@ -482,7 +633,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": block_shape[1],
                 "GROUP_SIZE_M": 32,
                 "num_warps": 4,
-                "num_stages": 3,
+                "num_stages": 2 if is_hip_flag else 3,
             }
     else:
         config = {
@@ -727,7 +878,7 @@ def fused_experts_impl(
     block_shape: Optional[List[int]] = None,
 ):
     padded_size = padding_size
-    if not use_fp8_w8a8:
+    if not use_fp8_w8a8 or block_shape is not None:
         padded_size = 0
 
     # Check constraints.
@@ -854,11 +1005,18 @@ def fused_experts_impl(
             block_shape=block_shape,
         )
 
-        torch.sum(
-            intermediate_cache3.view(*intermediate_cache3.shape),
-            dim=1,
-            out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
-        )
+        if is_hip_flag:
+            ops.moe_sum(
+                intermediate_cache3.view(*intermediate_cache3.shape),
+                out_hidden_states[begin_chunk_idx:end_chunk_idx],
+            )
+        else:
+            torch.sum(
+                intermediate_cache3.view(*intermediate_cache3.shape),
+                dim=1,
+                out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+            )
+
 
     return out_hidden_states
 
@@ -321,9 +321,12 @@ class FusedMoE(torch.nn.Module):
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
         shard_size = expert_data.shape[shard_dim] // 2
-        loaded_weight = loaded_weight.narrow(
-            shard_dim, shard_size * tp_rank, shard_size
-        )
+
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                shard_dim, shard_size * tp_rank, shard_size
+            )
+
 
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
@@ -347,9 +350,12 @@ class FusedMoE(torch.nn.Module):
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
         shard_size = expert_data.shape[shard_dim]
-        loaded_weight = loaded_weight.narrow(
-            shard_dim, shard_size * tp_rank, shard_size
-        )
+
+        if not self.use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                shard_dim, shard_size * tp_rank, shard_size
+            )
+
         # w2, down_proj: Load into only logical weight of w2.
         expert_data.copy_(loaded_weight)
 
@@ -389,7 +395,9 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
+        use_presharded_weights: bool = False,
     ) -> None:
+        self.use_presharded_weights = use_presharded_weights
 
         # compressed-tensors checkpoints with packed weights are stored flipped
         # TODO (mgoin): check self.quant_method.quant_config.quant_format
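use_presharded_weights lets checkpoints that are already split per tensor-parallel rank skip the narrow() slice during expert weight loading. A small sketch of the slice that gets skipped, with made-up shapes:

import torch

tp_rank, tp_size = 1, 4
full_weight = torch.randn(1024, 256)          # hypothetical unsharded expert weight
shard_size = full_weight.shape[0] // tp_size

# Default path: carve this rank's slice out of the full tensor along the sharded dim.
shard = full_weight.narrow(0, shard_size * tp_rank, shard_size)

# With use_presharded_weights=True the loaded tensor is assumed to already be
# this rank's (shard_size, 256) slice, so the narrow() above is skipped.
assert shard.shape == (shard_size, 256)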
@@ -280,9 +280,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 weight_scale=layer.weight_scale_inv,
                 input_scale=None,
             )
-            layer.weight = torch.nn.Parameter(weight, require_grad=False)
+            layer.weight = torch.nn.Parameter(weight, requires_grad=False)
             layer.weight_scale_inv = torch.nn.Parameter(
-                weight_scale, require_grad=False
+                weight_scale, requires_grad=False
             )
             layer.input_scale = None
             return
@@ -1,5 +1,5 @@
 import logging
-from typing import Union
+from typing import List
 
 import torch
 from torch import nn
@@ -28,13 +28,12 @@ class Sampler(nn.Module):
 
     def forward(
         self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        logits_output: LogitsProcessorOutput,
         sampling_info: SamplingBatchInfo,
+        return_logprob: bool,
+        top_logprobs_nums: List[int],
     ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        logits = logits.contiguous()
+        logits = logits_output.next_token_logits
 
         if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
             logger.warning("Detected errors during sampling! NaN in the logits.")
@@ -47,6 +46,8 @@ class Sampler(nn.Module):
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
             batch_next_token_ids = torch.argmax(logits, -1)
+            if return_logprob:
+                logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
         else:
             # Post process logits
             logits.div_(sampling_info.temperatures)
@@ -54,6 +55,14 @@ class Sampler(nn.Module):
             del logits
 
             if global_server_args_dict["sampling_backend"] == "flashinfer":
+                if return_logprob:
+                    # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
+                    # https://github.com/flashinfer-ai/flashinfer/issues/708
+                    # so we use the torch implementation.
+                    logprobs = torch.log(
+                        top_p_normalize_probs_torch(probs, sampling_info.top_ps)
+                    )
+
                 max_top_k_round, batch_size = 32, probs.shape[0]
                 uniform_samples = torch.rand(
                     (max_top_k_round, batch_size), device=probs.device
@@ -76,6 +85,7 @@ class Sampler(nn.Module):
                 if self.use_nan_detectioin and not torch.all(success):
                     logger.warning("Detected errors during sampling!")
                     batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+
             elif global_server_args_dict["sampling_backend"] == "pytorch":
                 # A slower fallback implementation with torch native operations.
                 batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
@@ -85,12 +95,31 @@ class Sampler(nn.Module):
                     sampling_info.min_ps,
                     sampling_info.need_min_p_sampling,
                 )
+                if return_logprob:
+                    logprobs = torch.log(
+                        top_p_normalize_probs_torch(probs, sampling_info.top_ps)
+                    )
             else:
                 raise ValueError(
                     f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
                 )
 
-        return batch_next_token_ids.to(torch.int32)
+        batch_next_token_ids = batch_next_token_ids.to(torch.int32)
+
+        # Attach logprobs to logits_output (in-place modification)
+        if return_logprob:
+            if any(x > 0 for x in top_logprobs_nums):
+                (
+                    logits_output.next_token_top_logprobs_val,
+                    logits_output.next_token_top_logprobs_idx,
+                ) = get_top_logprobs(logprobs, top_logprobs_nums)
+
+            logits_output.next_token_logprobs = logprobs[
+                torch.arange(len(batch_next_token_ids), device=sampling_info.device),
+                batch_next_token_ids,
+            ]
+
+        return batch_next_token_ids
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -120,20 +149,27 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     return batch_next_token_ids
 
 
-def top_p_normalize_probs(
+def top_p_normalize_probs_torch(
     probs: torch.Tensor,
     top_ps: torch.Tensor,
 ):
-    if global_server_args_dict["sampling_backend"] == "flashinfer":
-        return top_p_renorm_prob(probs, top_ps)
-    elif global_server_args_dict["sampling_backend"] == "pytorch":
-        # See also top_k_top_p_min_p_sampling_from_probs_torch
-        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-        probs_sum = torch.cumsum(probs_sort, dim=-1)
-        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
-        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
-        return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
-    else:
-        raise ValueError(
-            f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
-        )
+    # See also top_k_top_p_min_p_sampling_from_probs_torch
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+
+
+def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
+    max_k = max(top_logprobs_nums)
+    ret = logprobs.topk(max_k, dim=1)
+    values = ret.values.tolist()
+    indices = ret.indices.tolist()
+
+    output_top_logprobs_val = []
+    output_top_logprobs_idx = []
+    for i, k in enumerate(top_logprobs_nums):
+        output_top_logprobs_val.append(values[i][:k])
+        output_top_logprobs_idx.append(indices[i][:k])
+    return output_top_logprobs_val, output_top_logprobs_idx
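top_p_normalize_probs_torch masks out the probability tail that falls outside each request's top-p budget and renormalizes what remains; get_top_logprobs then trims a shared top-k to each request's requested count. A tiny worked example of the renormalization (values are illustrative):

import torch

probs = torch.tensor([[0.6, 0.2, 0.15, 0.05]])
top_ps = torch.tensor([0.7])

# Same steps as top_p_normalize_probs_torch above: sort, drop tokens whose preceding
# cumulative mass already exceeds top_p, renormalize, scatter back to original order.
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
renormed = torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
# renormed ~= [[0.75, 0.25, 0.0, 0.0]]: only tokens inside the 0.7 mass survive.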
@@ -11,6 +11,22 @@ import torch
 logger = logging.getLogger(__name__)
 
 
+def get_gemlite_cache_path() -> str:
+    return f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
+
+
+def save_gemlite_cache(print_error: bool = False) -> bool:
+    try:
+        from gemlite.core import GemLiteLinearTriton
+
+        GemLiteLinearTriton.cache_config(get_gemlite_cache_path())
+    except Exception:
+        if print_error:
+            logger.error("Failed to save the GemLite cache.")
+        return False
+    return True
+
+
 def apply_torchao_config_to_model(
     model: torch.nn.Module, torchao_config: str, filter_fn=None
 ):
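The two helpers give callers a single source of truth for the GemLite tuning-cache path and a guarded way to persist it. A usage sketch (the import path follows the file shown in the list above, sglang/srt/layers/torchao_utils.py):

from sglang.srt.layers.torchao_utils import get_gemlite_cache_path, save_gemlite_cache

print(get_gemlite_cache_path())            # e.g. /tmp/<gecos>_gemlite.json
ok = save_gemlite_cache(print_error=True)  # False (and an error log) if gemlite is unavailable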
@@ -74,9 +90,7 @@ def apply_torchao_config_to_model(
         )
 
         # try to load gemlite kernel config
-        GemLiteLinearTriton.load_config(
-            f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
-        )
+        GemLiteLinearTriton.load_config(get_gemlite_cache_path())
 
     elif "fp8wo" in torchao_config:
         # this requires newer hardware
@@ -426,8 +426,7 @@ class UpdateWeightsFromDistributedReqOutput:
 
 @dataclass
 class UpdateWeightsFromTensorReqInput:
-    name: str
-    tensor: torch.Tensor
+    serialized_named_tensors: bytes  # indeed Dict[str, torch.Tensor]
 
 
 @dataclass
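UpdateWeightsFromTensorReqInput now carries one opaque bytes payload rather than a single named tensor. The serialization format is not shown in this diff; the sketch below uses torch.save on a name-to-tensor dict purely as an assumption to illustrate the shape of the data, not as sglang's actual codec:

import io
from typing import Dict

import torch

def pack_named_tensors(named: Dict[str, torch.Tensor]) -> bytes:
    # Assumption: any bytes round-trip matches the dataclass field; the real
    # producer/consumer in sglang may use a different serializer.
    buf = io.BytesIO()
    torch.save(named, buf)
    return buf.getvalue()

def unpack_named_tensors(blob: bytes) -> Dict[str, torch.Tensor]:
    return torch.load(io.BytesIO(blob))

payload = pack_named_tensors({"lm_head.weight": torch.zeros(4, 4)})
assert unpack_named_tensors(payload)["lm_head.weight"].shape == (4, 4)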