sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +337 -0
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +115 -31
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/base_grammar_backend.py +4 -3
  8. sglang/srt/constrained/outlines_backend.py +39 -26
  9. sglang/srt/constrained/xgrammar_backend.py +58 -14
  10. sglang/srt/layers/activation.py +3 -0
  11. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  12. sglang/srt/layers/attention/triton_backend.py +9 -7
  13. sglang/srt/layers/custom_op_util.py +26 -0
  14. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  15. sglang/srt/layers/fused_moe/patch.py +4 -2
  16. sglang/srt/layers/layernorm.py +4 -0
  17. sglang/srt/layers/logits_processor.py +10 -10
  18. sglang/srt/layers/sampler.py +4 -8
  19. sglang/srt/layers/torchao_utils.py +2 -0
  20. sglang/srt/managers/data_parallel_controller.py +74 -9
  21. sglang/srt/managers/detokenizer_manager.py +1 -14
  22. sglang/srt/managers/io_struct.py +27 -0
  23. sglang/srt/managers/schedule_batch.py +104 -38
  24. sglang/srt/managers/schedule_policy.py +5 -1
  25. sglang/srt/managers/scheduler.py +210 -56
  26. sglang/srt/managers/session_controller.py +62 -0
  27. sglang/srt/managers/tokenizer_manager.py +38 -0
  28. sglang/srt/managers/tp_worker.py +12 -1
  29. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  30. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  31. sglang/srt/model_executor/forward_batch_info.py +109 -15
  32. sglang/srt/model_executor/model_runner.py +102 -43
  33. sglang/srt/model_parallel.py +98 -0
  34. sglang/srt/models/deepseek_v2.py +147 -44
  35. sglang/srt/models/gemma2.py +9 -8
  36. sglang/srt/models/llava.py +1 -1
  37. sglang/srt/models/llavavid.py +1 -1
  38. sglang/srt/models/olmo.py +3 -3
  39. sglang/srt/models/phi3_small.py +447 -0
  40. sglang/srt/models/qwen2_vl.py +13 -6
  41. sglang/srt/models/torch_native_llama.py +94 -78
  42. sglang/srt/openai_api/adapter.py +11 -4
  43. sglang/srt/openai_api/protocol.py +30 -27
  44. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  45. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  47. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  48. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  49. sglang/srt/sampling/sampling_batch_info.py +58 -57
  50. sglang/srt/sampling/sampling_params.py +3 -3
  51. sglang/srt/server.py +29 -2
  52. sglang/srt/server_args.py +97 -60
  53. sglang/srt/utils.py +103 -51
  54. sglang/test/runners.py +25 -6
  55. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  56. sglang/test/test_utils.py +33 -22
  57. sglang/version.py +1 -1
  58. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  59. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
  60. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  61. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  62. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py

@@ -8,7 +8,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
  """

  from enum import Enum, auto
- from typing import TYPE_CHECKING
+ from typing import TYPE_CHECKING, List

  import torch
  import triton
@@ -136,15 +136,17 @@ class FlashInferAttnBackend(AttentionBackend):
  prefix_lens = forward_batch.extend_prefix_lens

  # Some heuristics to check whether to use ragged forward
- use_ragged = False
  if forward_batch.extend_num_tokens >= 4096 and self.num_wrappers == 1:
  use_ragged = True
-
- extend_no_prefix = not torch.any(forward_batch.extend_prefix_lens).item()
+ extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
+ else:
+ use_ragged = False
+ extend_no_prefix = False

  self.indices_updater_prefill.update(
  forward_batch.req_pool_indices,
  forward_batch.seq_lens,
+ forward_batch.seq_lens_sum,
  prefix_lens,
  use_ragged=use_ragged,
  encoder_lens=forward_batch.encoder_lens,
@@ -314,7 +316,6 @@ class FlashInferIndicesUpdaterDecode:
  self.head_dim = model_runner.model_config.head_dim
  self.data_type = model_runner.kv_cache_dtype
  self.q_data_type = model_runner.dtype
- self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
  self.sliding_window_size = model_runner.sliding_window_size

  self.attn_backend = attn_backend
@@ -335,7 +336,12 @@ class FlashInferIndicesUpdaterDecode:
  self.update = self.update_single_wrapper

  def update(
- self, req_pool_indices, seq_lens, seq_lens_sum, decode_wrappers, encoder_lens
+ self,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ decode_wrappers: List,
+ encoder_lens: torch.Tensor,
  ):
  # Keep the signature for type checking. It will be assigned during runtime.
  raise NotImplementedError()
@@ -345,8 +351,8 @@ class FlashInferIndicesUpdaterDecode:
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
- decode_wrappers=None,
- encoder_lens=None,
+ decode_wrappers: List,
+ encoder_lens: torch.Tensor,
  ):
  decode_wrappers = decode_wrappers or self.decode_wrappers
  self.call_begin_forward(
@@ -363,8 +369,8 @@ class FlashInferIndicesUpdaterDecode:
  req_pool_indices: torch.Tensor,
  seq_lens: torch.Tensor,
  seq_lens_sum: int,
- decode_wrappers=None,
- encoder_lens=None,
+ decode_wrappers: List,
+ encoder_lens: torch.Tensor,
  ):
  decode_wrappers = decode_wrappers or self.decode_wrappers

@@ -394,11 +400,11 @@ class FlashInferIndicesUpdaterDecode:

  def update_cross_attention(
  self,
- req_pool_indices,
- seq_lens,
- seq_lens_sum,
- decode_wrappers=None,
- encoder_lens=None,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ decode_wrappers: List,
+ encoder_lens: torch.Tensor,
  ):
  decode_wrappers = decode_wrappers or self.decode_wrappers

@@ -425,11 +431,11 @@ class FlashInferIndicesUpdaterDecode:
  def call_begin_forward(
  self,
  wrapper,
- req_pool_indices,
- paged_kernel_lens,
- paged_kernel_lens_sum,
- kv_indptr,
- kv_start_idx,
+ req_pool_indices: torch.Tensor,
+ paged_kernel_lens: torch.Tensor,
+ paged_kernel_lens_sum: int,
+ kv_indptr: torch.Tensor,
+ kv_start_idx: torch.Tensor,
  ):
  bs = len(req_pool_indices)
  kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -445,7 +451,7 @@ class FlashInferIndicesUpdaterDecode:
  kv_indptr,
  kv_start_idx,
  kv_indices,
- self.max_context_len,
+ self.req_to_token.shape[1],
  )

  wrapper.end_forward()
@@ -474,7 +480,6 @@ class FlashInferIndicesUpdaterPrefill:
  self.head_dim = model_runner.model_config.head_dim
  self.data_type = model_runner.kv_cache_dtype
  self.q_data_type = model_runner.dtype
- self.max_context_len = model_runner.req_to_token_pool.req_to_token.size(1)
  self.sliding_window_size = model_runner.sliding_window_size

  self.attn_backend = attn_backend
@@ -496,23 +501,40 @@ class FlashInferIndicesUpdaterPrefill:
  assert self.attn_backend.num_wrappers == 1
  self.update = self.update_single_wrapper

- def update(self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens):
+ def update(
+ self,
+ req_pool_indices: torch.Tnesor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ prefix_lens: torch.Tensor,
+ use_ragged: bool,
+ encoder_lens: torch.Tensor,
+ ):
  # Keep the signature for type checking. It will be assigned during runtime.
  raise NotImplementedError()

  def update_single_wrapper(
- self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+ self,
+ req_pool_indices: torch.Tnesor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ prefix_lens: torch.Tensor,
+ use_ragged: bool,
+ encoder_lens: torch.Tensor,
  ):
  if use_ragged:
  paged_kernel_lens = prefix_lens
+ paged_kernel_lens_sum = paged_kernel_lens.sum().item()
  else:
  paged_kernel_lens = seq_lens
+ paged_kernel_lens_sum = seq_lens_sum

  self.call_begin_forward(
  self.wrapper_ragged,
  self.wrappers_paged[0],
  req_pool_indices,
  paged_kernel_lens,
+ paged_kernel_lens_sum,
  seq_lens,
  prefix_lens,
  None,
@@ -522,7 +544,13 @@ class FlashInferIndicesUpdaterPrefill:
  )

  def update_sliding_window(
- self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+ self,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ prefix_lens: torch.Tensor,
+ use_ragged: bool,
+ encoder_lens: torch.Tensor,
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
@@ -531,9 +559,12 @@ class FlashInferIndicesUpdaterPrefill:
  seq_lens,
  torch.tensor(self.sliding_window_size) + seq_lens - prefix_lens,
  )
+ paged_kernel_lens_sum = paged_kernel_lens.sum().item()
  else:
  # full attention
  paged_kernel_lens = seq_lens
+ paged_kernel_lens_sum = seq_lens_sum
+
  kv_start_idx = seq_lens - paged_kernel_lens

  self.call_begin_forward(
@@ -541,6 +572,7 @@ class FlashInferIndicesUpdaterPrefill:
  self.wrappers_paged[wrapper_id],
  req_pool_indices,
  paged_kernel_lens,
+ paged_kernel_lens_sum,
  seq_lens,
  prefix_lens,
  kv_start_idx,
@@ -550,23 +582,32 @@ class FlashInferIndicesUpdaterPrefill:
  )

  def update_cross_attention(
- self, req_pool_indices, seq_lens, prefix_lens, use_ragged, encoder_lens
+ self,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ prefix_lens: torch.Tensor,
+ use_ragged: bool,
+ encoder_lens: torch.Tensor,
  ):
  for wrapper_id in range(2):
  if wrapper_id == 0:
  # normal attention
  paged_kernel_lens = seq_lens
  kv_start_idx = encoder_lens
+ paged_kernel_lens_sum = seq_lens_sum
  else:
  # cross attention
  paged_kernel_lens = encoder_lens
  kv_start_idx = torch.zeros_like(encoder_lens)
+ paged_kernel_lens_sum = paged_kernel_lens.sum().item()

  self.call_begin_forward(
  self.wrapper_ragged,
  self.wrappers_paged[wrapper_id],
  req_pool_indices,
  paged_kernel_lens,
+ paged_kernel_lens_sum,
  seq_lens,
  prefix_lens,
  kv_start_idx,
@@ -579,19 +620,22 @@ class FlashInferIndicesUpdaterPrefill:
  self,
  wrapper_ragged,
  wrapper_paged,
- req_pool_indices,
- paged_kernel_lens,
- seq_lens,
- prefix_lens,
- kv_start_idx,
- kv_indptr,
- qo_indptr,
- use_ragged,
+ req_pool_indices: torch.Tensor,
+ paged_kernel_lens: torch.Tensor,
+ paged_kernel_lens_sum: int,
+ seq_lens: torch.Tensor,
+ prefix_lens: torch.Tensor,
+ kv_start_idx: torch.Tensor,
+ kv_indptr: torch.Tensor,
+ qo_indptr: torch.Tensor,
+ use_ragged: bool,
  ):
  bs = len(req_pool_indices)
  kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
  kv_indptr = kv_indptr[: bs + 1]
- kv_indices = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+ kv_indices = torch.empty(
+ paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+ )
  create_flashinfer_kv_indices_triton[(bs,)](
  self.req_to_token,
  req_pool_indices,
@@ -599,7 +643,7 @@ class FlashInferIndicesUpdaterPrefill:
  kv_indptr,
  kv_start_idx,
  kv_indices,
- self.max_context_len,
+ self.req_to_token.shape[1],
  )

  qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
@@ -638,10 +682,11 @@ def create_flashinfer_kv_indices_triton(
  kv_indptr,
  kv_start_idx,
  kv_indices_ptr,
- max_context_len: tl.constexpr,
+ req_to_token_ptr_stride: tl.constexpr,
  ):
  BLOCK_SIZE: tl.constexpr = 512
  pid = tl.program_id(axis=0)
+
  req_pool_index = tl.load(req_pool_indices_ptr + pid)
  kv_indices_offset = tl.load(kv_indptr + pid)

@@ -652,15 +697,15 @@ def create_flashinfer_kv_indices_triton(
  kv_end = kv_start
  kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

- req_to_token_ptr += req_pool_index * max_context_len
- kv_indices_ptr += kv_indices_offset
-
- ld_offset = kv_start + tl.arange(0, BLOCK_SIZE)
- st_offset = tl.arange(0, BLOCK_SIZE)
  num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
- for _ in range(num_loop):
- mask = ld_offset < kv_end
- data = tl.load(req_to_token_ptr + ld_offset, mask=mask)
- tl.store(kv_indices_ptr + st_offset, data, mask=mask)
- ld_offset += BLOCK_SIZE
- st_offset += BLOCK_SIZE
+ for i in range(num_loop):
+ offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+ mask = offset < kv_end - kv_start
+ data = tl.load(
+ req_to_token_ptr
+ + req_pool_index * req_to_token_ptr_stride
+ + kv_start
+ + offset,
+ mask=mask,
+ )
+ tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
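Note: the rewritten Triton kernel above only changes how the gather is indexed (an explicit per-iteration offset instead of advancing the base pointers), not what it computes. For reference, a minimal pure-PyTorch sketch of the same gather, using the argument names from the diff (this sketch is not part of the package):

    import torch

    def create_kv_indices_reference(req_to_token, req_pool_indices,
                                    paged_kernel_lens, kv_indptr,
                                    kv_start_idx, kv_indices):
        # For request i, copy paged_kernel_lens[i] KV-cache slot indices from row
        # req_pool_indices[i] of the req_to_token table (starting at kv_start_idx[i],
        # or 0 when kv_start_idx is None) into the flat kv_indices buffer at kv_indptr[i].
        for i in range(len(req_pool_indices)):
            row = req_to_token[req_pool_indices[i]]
            start = 0 if kv_start_idx is None else int(kv_start_idx[i])
            length = int(paged_kernel_lens[i])
            out = int(kv_indptr[i])
            kv_indices[out : out + length] = row[start : start + length]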
sglang/srt/layers/attention/triton_backend.py

@@ -3,7 +3,6 @@ from __future__ import annotations
  from typing import TYPE_CHECKING

  import torch
- import torch.nn as nn

  from sglang.srt.layers.attention import AttentionBackend
  from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -28,9 +27,13 @@ class TritonAttnBackend(AttentionBackend):

  self.decode_attention_fwd = decode_attention_fwd
  self.extend_attention_fwd = extend_attention_fwd
- self.num_head = (
- model_runner.model_config.num_attention_heads // model_runner.tp_size
- )
+
+ if model_runner.server_args.enable_dp_attention:
+ self.num_head = model_runner.model_config.num_attention_heads
+ else:
+ self.num_head = (
+ model_runner.model_config.num_attention_heads // model_runner.tp_size
+ )

  if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
  self.reduce_dtype = torch.float32
@@ -50,7 +53,7 @@ class TritonAttnBackend(AttentionBackend):
  start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32)
  start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0)

- total_num_tokens = torch.sum(forward_batch.seq_lens).item()
+ total_num_tokens = forward_batch.seq_lens_sum
  attn_logits = torch.empty(
  (self.num_head, total_num_tokens),
  dtype=self.reduce_dtype,
@@ -61,8 +64,7 @@ class TritonAttnBackend(AttentionBackend):
  max_extend_len = None
  else:
  start_loc = attn_logits = max_seq_len = None
- prefix_lens = forward_batch.extend_prefix_lens
- max_extend_len = torch.max(forward_batch.seq_lens - prefix_lens).item()
+ max_extend_len = torch.max(forward_batch.extend_seq_lens).item()

  self.forward_metadata = start_loc, attn_logits, max_seq_len, max_extend_len

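Note: `forward_batch.seq_lens_sum` serves the same purpose here as in the FlashInfer backend above. A hedged illustration of the pattern being replaced (not package code): calling `.item()` on a GPU reduction blocks until the device catches up, while a precomputed Python int does not.

    import torch

    seq_lens = torch.tensor([5, 9, 3])  # per-request lengths (a GPU tensor in practice)

    # Old: reduce on the device, then .item() forces a device-to-host sync per batch.
    total_num_tokens = int(torch.sum(seq_lens).item())

    # New: the batch already carries the sum as a plain Python int (seq_lens_sum),
    # so no device round-trip is needed.
    seq_lens_sum = 5 + 9 + 3
    total_num_tokens = seq_lens_sum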
sglang/srt/layers/custom_op_util.py (new file)

@@ -0,0 +1,26 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ from vllm.model_executor.custom_op import CustomOp
+
+
+ def register_custom_op(op_name):
+ def decorator(cls):
+ if hasattr(CustomOp, "register"):
+ return CustomOp.register(op_name)(cls)
+ else:
+ return cls
+
+ return decorator
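Note: `register_custom_op` exists to stay compatible with different vLLM releases: when `CustomOp.register` is available it registers the class under the given name, otherwise it is a no-op. A minimal usage sketch (the class name here is made up; the real uses appear in the layernorm diff below):

    from vllm.model_executor.custom_op import CustomOp

    from sglang.srt.layers.custom_op_util import register_custom_op


    @register_custom_op("sglang_example_op")  # falls back to returning the class unchanged
    class ExampleOp(CustomOp):
        def forward_native(self, x):
            return x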
sglang/srt/layers/fused_moe/fused_moe.py

@@ -250,9 +250,12 @@ def invoke_fused_moe_kernel(
  assert topk_weights.stride(1) == 1
  assert sorted_token_ids.stride(0) == 1

+ padded_size = padding_size
  if not use_fp8:
  assert A_scale is None
  assert B_scale is None
+ # MOE_PADDING FP8 only
+ padded_size = 0
  else:
  A, A_scale = ops.scaled_fp8_quant(A, A_scale)
  assert B_scale is not None
@@ -262,7 +265,7 @@ def invoke_fused_moe_kernel(
  * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
  )

- K = B.shape[2] - padding_size
+ K = B.shape[2] - padded_size
  if K % config["BLOCK_SIZE_K"] == 0:
  even_ks = True
  else:
@@ -279,7 +282,7 @@ def invoke_fused_moe_kernel(
  expert_ids,
  num_tokens_post_padded,
  B.shape[1],
- B.shape[2] - padding_size,
+ B.shape[2] - padded_size,
  sorted_token_ids.shape[0],
  topk_ids.numel(),
  A.stride(0),
@@ -480,8 +483,12 @@ def fused_experts(
  a1_scale: Optional[torch.Tensor] = None,
  a2_scale: Optional[torch.Tensor] = None,
  ):
+ padded_size = padding_size
+ if not use_fp8:
+ # MOE_PADDING FP8 only
+ padded_size = 0
  # Check constraints.
- assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
+ assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
  assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
  assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
  assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -498,7 +505,7 @@ def fused_experts(
  get_config_func = functools.partial(
  try_get_optimal_moe_config,
  w1.shape,
- (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
+ (w2.shape[0], w2.shape[1], w2.shape[2] - padded_size),
  topk_ids.shape[1],
  "float8" if use_fp8 else None,
  override_config=override_config,
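Note: the `padded_size` changes make the MoE weight padding apply only when FP8 weights are used; for other dtypes the effective padding is zero, so the kernel sees the true hidden size. A standalone sketch of the intent (the value of `padding_size` is assumed here for illustration):

    padding_size = 128  # assumed module-level padding constant for FP8 MoE weights

    def effective_padding(use_fp8: bool) -> int:
        # Weights are only padded along K when FP8 is enabled, so only then must
        # the padding be subtracted when computing the K dimension.
        return padding_size if use_fp8 else 0

    # e.g. K = w1.shape[2] - effective_padding(use_fp8)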
sglang/srt/layers/fused_moe/patch.py

@@ -1,4 +1,4 @@
- from typing import Optional
+ from typing import Callable, Optional

  import torch
  from torch.nn import functional as F
@@ -98,7 +98,9 @@ def fused_moe_forward_native(
  renormalize: bool,
  topk_group: Optional[int] = None,
  num_expert_group: Optional[int] = None,
+ custom_routing_function: Optional[Callable] = None,
  ) -> torch.Tensor:
+ assert custom_routing_function is None
  topk_weights, topk_ids = select_experts_native(
  hidden_states=x,
  router_logits=router_logits,
@@ -114,4 +116,4 @@ def fused_moe_forward_native(
  x1 = F.silu(torch.einsum("ti,taoi -> tao", x, w1_weights))
  x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
  expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
- return torch.einsum("tai,ta -> ti", expert_outs, topk_weights)
+ return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
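Note: the added `.to(expert_outs.dtype)` matters because `torch.einsum` expects its operands to share a dtype; router weights are often kept in float32 while expert outputs may be half precision. A small illustration (shapes and dtypes are made up):

    import torch

    expert_outs = torch.randn(4, 2, 8, dtype=torch.bfloat16)  # [tokens, topk, hidden]
    topk_weights = torch.rand(4, 2, dtype=torch.float32)       # router weights

    # Without the cast this einsum can fail with a dtype mismatch; with it,
    # both operands are bfloat16 and the weighted sum over experts works.
    out = torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))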
sglang/srt/layers/layernorm.py

@@ -33,9 +33,12 @@ if is_flashinfer_available():

  from vllm.model_executor.custom_op import CustomOp

+ from sglang.srt.layers.custom_op_util import register_custom_op
+
  logger = logging.getLogger(__name__)


+ @register_custom_op("sglang_rmsnorm")
  class RMSNorm(CustomOp):
  def __init__(
  self,
@@ -78,6 +81,7 @@ class RMSNorm(CustomOp):
  return x, residual


+ @register_custom_op("sglang_gemma_rmsnorm")
  class GemmaRMSNorm(CustomOp):
  def __init__(
  self,
sglang/srt/layers/logits_processor.py

@@ -62,21 +62,21 @@ class LogitsMetadata:

  @classmethod
  def from_forward_batch(cls, forward_batch: ForwardBatch):
+ extend_logprob_pruned_lens_cpu = None
+
  if forward_batch.return_logprob:
  return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
+ if forward_batch.forward_mode.is_extend():
+ extend_logprob_pruned_lens_cpu = [
+ extend_len - start_len
+ for extend_len, start_len in zip(
+ forward_batch.extend_seq_lens_cpu,
+ forward_batch.extend_logprob_start_lens_cpu,
+ )
+ ]
  else:
  return_top_logprob = False

- if forward_batch.forward_mode.is_extend():
- extend_logprob_pruned_lens_cpu = [
- extend_len - start_len
- for extend_len, start_len in zip(
- forward_batch.extend_seq_lens,
- forward_batch.extend_logprob_start_lens_cpu,
- )
- ]
- else:
- extend_logprob_pruned_lens_cpu = None
  return cls(
  forward_mode=forward_batch.forward_mode,
  top_logprobs_nums=forward_batch.top_logprobs_nums,
sglang/srt/layers/sampler.py

@@ -1,5 +1,4 @@
  import logging
- import os
  from typing import Union

  import torch
@@ -8,7 +7,7 @@ from torch import nn
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
- from sglang.srt.utils import is_flashinfer_available
+ from sglang.srt.utils import crash_on_warnings, is_flashinfer_available

  if is_flashinfer_available():
  from flashinfer.sampling import (
@@ -19,17 +18,13 @@ if is_flashinfer_available():
  )


- # Crash on warning if we are running CI tests
- crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
-
-
  logger = logging.getLogger(__name__)


  class Sampler(nn.Module):
  def __init__(self):
  super().__init__()
- self.use_nan_detectioin = not global_server_args_dict["disable_nan_detection"]
+ self.use_nan_detectioin = global_server_args_dict["enable_nan_detection"]

  def forward(
  self,
@@ -46,7 +41,8 @@ class Sampler(nn.Module):
  logits = torch.where(
  torch.isnan(logits), torch.full_like(logits, -1e5), logits
  )
- exit(1) if crash_on_warning else None
+ if crash_on_warnings():
+ raise ValueError("Detected errors during sampling! NaN in the logits.")

  if sampling_info.is_all_greedy:
  # Use torch.argmax if all requests use greedy sampling
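Note: the ad-hoc `crash_on_warning` flag is replaced by `crash_on_warnings()` from `sglang.srt.utils`. Based on the lines removed above, the helper presumably amounts to something like the sketch below (not the packaged implementation):

    import os

    def crash_on_warnings() -> bool:
        # Crash instead of warn when running inside SGLang's CI,
        # mirroring the inline check that was removed from sampler.py.
        return os.getenv("SGLANG_IS_IN_CI", "false") == "true"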
sglang/srt/layers/torchao_utils.py

@@ -62,6 +62,8 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
  granularity=GRANULARITY_MAP[granularity]
  ),
  )
+ else:
+ raise ValueError(f"Unexpected config: {torchao_config}")

  return dummy_linear.weight

sglang/srt/managers/data_parallel_controller.py

@@ -17,6 +17,7 @@ limitations under the License.

  import logging
  import multiprocessing as mp
+ import threading
  from enum import Enum, auto

  import zmq
@@ -28,6 +29,7 @@ from sglang.srt.managers.io_struct import (
  from sglang.srt.managers.scheduler import run_scheduler_process
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.utils import (
+ bind_port,
  configure_logger,
  get_zmq_socket,
  kill_parent_process,
@@ -80,20 +82,62 @@ class DataParallelController:

  # Start data parallel workers
  base_gpu_id = 0
- self.workers = []
+ self.workers = [None] * server_args.dp_size
+
+ threads = []
+ sockets = []
  for dp_rank in range(server_args.dp_size):
  tmp_port_args = PortArgs.init_new(server_args)
+ tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
  tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name

- send_to = self.launch_tensor_parallel_group(
- server_args,
- tmp_port_args,
- base_gpu_id,
- dp_rank,
+ if server_args.enable_dp_attention:
+ # Data parallelism resues the tensor parallelism group,
+ # so all dp ranks should use the same nccl port.
+ tmp_port_args.nccl_port = port_args.nccl_port
+ else:
+ # This port is checked free in PortArgs.init_new.
+ # We hold it first so that the next dp worker gets a different port
+ sockets.append(bind_port(tmp_port_args.nccl_port))
+
+ # Create a thread for each worker
+ thread = threading.Thread(
+ target=self.launch_worker_func,
+ args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
  )
+ threads.append(thread)
+ base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size
+
+ # Free all sockets before starting the threads to launch TP workers
+ for sock in sockets:
+ sock.close()
+
+ # Start all threads
+ for thread in threads:
+ thread.start()
+ for thread in threads:
+ thread.join()
+
+ def launch_worker_func(
+ self,
+ server_args: ServerArgs,
+ port_args: PortArgs,
+ base_gpu_id: int,
+ dp_rank: int,
+ ):
+ logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")

- self.workers.append(send_to)
- base_gpu_id += server_args.tp_size
+ launch_func_ = (
+ self.launch_tensor_parallel_process
+ if server_args.enable_dp_attention
+ else self.launch_tensor_parallel_group
+ )
+ self.workers[dp_rank] = launch_func_(
+ server_args,
+ port_args,
+ base_gpu_id,
+ dp_rank,
+ )

  def launch_tensor_parallel_group(
  self,
@@ -112,7 +156,7 @@
  )
  for tp_rank in tp_rank_range:
  reader, writer = mp.Pipe(duplex=False)
- gpu_id = base_gpu_id + tp_rank % tp_size_per_node
+ gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
  proc = mp.Process(
  target=run_scheduler_process,
  args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
@@ -131,6 +175,27 @@

  return send_to

+ def launch_tensor_parallel_process(
+ self,
+ server_args: ServerArgs,
+ port_args: PortArgs,
+ base_gpu_id: int,
+ dp_rank: int,
+ ):
+ reader, writer = mp.Pipe(duplex=False)
+ gpu_id = base_gpu_id
+ tp_rank = dp_rank
+ proc = mp.Process(
+ target=run_scheduler_process,
+ args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
+ )
+ proc.start()
+ send_to = get_zmq_socket(
+ self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
+ )
+ reader.recv()
+ return send_to
+
  def round_robin_scheduler(self, req):
  self.workers[self.round_robin_counter].send_pyobj(req)
  self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
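Note: `bind_port` is imported from `sglang.srt.utils`; its implementation is not shown in this diff. Judging by the surrounding comments (hold a free nccl port so the next DP worker cannot be assigned the same one, then close the socket before the workers start), it presumably behaves roughly like this sketch:

    import socket

    def bind_port(port: int) -> socket.socket:
        # Reserve the port now; the controller closes the returned socket
        # right before launching the TP workers that will actually use it.
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(("", port))
        sock.listen(1)
        return sock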