sglang 0.2.12__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +3 -2
  3. sglang/global_config.py +1 -1
  4. sglang/lang/backend/runtime_endpoint.py +60 -49
  5. sglang/lang/interpreter.py +4 -2
  6. sglang/lang/ir.py +13 -4
  7. sglang/srt/constrained/jump_forward.py +13 -2
  8. sglang/srt/layers/activation.py +0 -1
  9. sglang/srt/layers/extend_attention.py +3 -1
  10. sglang/srt/layers/fused_moe/__init__.py +1 -0
  11. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  12. sglang/srt/layers/fused_moe/layer.py +587 -0
  13. sglang/srt/layers/logits_processor.py +4 -4
  14. sglang/srt/layers/radix_attention.py +38 -14
  15. sglang/srt/managers/schedule_batch.py +9 -14
  16. sglang/srt/managers/tokenizer_manager.py +1 -1
  17. sglang/srt/managers/tp_worker.py +1 -7
  18. sglang/srt/model_executor/cuda_graph_runner.py +48 -17
  19. sglang/srt/model_executor/forward_batch_info.py +132 -58
  20. sglang/srt/model_executor/model_runner.py +61 -28
  21. sglang/srt/models/chatglm.py +2 -2
  22. sglang/srt/models/commandr.py +1 -1
  23. sglang/srt/models/deepseek.py +2 -2
  24. sglang/srt/models/deepseek_v2.py +7 -6
  25. sglang/srt/models/gemma.py +1 -1
  26. sglang/srt/models/gemma2.py +11 -5
  27. sglang/srt/models/grok.py +50 -396
  28. sglang/srt/models/minicpm.py +2 -2
  29. sglang/srt/models/mixtral.py +56 -254
  30. sglang/srt/models/mixtral_quant.py +1 -4
  31. sglang/srt/models/qwen.py +2 -2
  32. sglang/srt/models/qwen2.py +2 -2
  33. sglang/srt/models/qwen2_moe.py +2 -2
  34. sglang/srt/models/stablelm.py +1 -1
  35. sglang/srt/openai_api/adapter.py +32 -21
  36. sglang/srt/sampling_params.py +0 -4
  37. sglang/srt/server.py +23 -15
  38. sglang/srt/server_args.py +7 -1
  39. sglang/srt/utils.py +1 -2
  40. sglang/test/runners.py +18 -10
  41. sglang/test/test_programs.py +32 -5
  42. sglang/test/test_utils.py +5 -1
  43. sglang/version.py +1 -1
  44. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/METADATA +12 -4
  45. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/RECORD +48 -48
  46. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  47. sglang/srt/model_loader/model_loader.py +0 -292
  48. sglang/srt/model_loader/utils.py +0 -275
  49. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  50. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py}
@@ -1,20 +1,5 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
 # Adapted from
-# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/fused_moe/fused_moe.py#L1
+# https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
 """Fused MoE kernel."""
 import functools
 import json
@@ -24,6 +9,7 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
+import vllm.envs as envs
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 
@@ -373,6 +359,31 @@ def get_default_config(
     return config
 
 
+def try_get_optimal_moe_config(
+    w1_shape: Tuple[int, ...],
+    w2_shape: Tuple[int, ...],
+    top_k: int,
+    dtype: Optional[str],
+    M: int,
+    override_config: Optional[Dict[str, Any]] = None,
+):
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        E, _, N = w2_shape
+        configs = get_moe_configs(E, N, dtype)
+
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = get_default_config(M, E, N, w1_shape[2], top_k, dtype)
+    return config
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
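The new helper picks kernel launch parameters by selecting the tuned entry whose benchmarked batch size is closest to the current `M`. A minimal, self-contained sketch of that nearest-key lookup, using a hypothetical `configs` dict in place of the JSON files that `get_moe_configs` loads:

```python
# Hypothetical tuned-config map: keys are the batch sizes (M) that were
# benchmarked offline, values are Triton launch parameters for that size.
configs = {
    1: {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64},
    64: {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128},
    1024: {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256},
}

M = 100  # number of tokens in the current batch (or chunk)

# Same selection rule as in the diff: take the benchmarked M closest to ours.
best_key = min(configs.keys(), key=lambda x: abs(x - M))
config = configs[best_key]
print(best_key, config)  # 64 {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 128}
```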
@@ -403,6 +414,41 @@ def fused_topk(
     return topk_weights, topk_ids
 
 
+# This is used by the Deepseek-V2 model
+def grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+):
+
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    scores = torch.softmax(gating_output, dim=-1)
+    num_token = scores.shape[0]
+    group_scores = (
+        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
+        1
+    ]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
 def fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
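For readers unfamiliar with DeepSeek-V2-style routing, here is a small, self-contained sketch of the group-then-topk selection that `grouped_topk` implements, on made-up numbers (3 tokens, 8 experts split into 4 groups, keeping the top 2 groups and then the top 2 experts):

```python
import torch

torch.manual_seed(0)
num_experts, num_expert_group, topk_group, topk = 8, 4, 2, 2
gating_output = torch.randn(3, num_experts)  # 3 tokens, 8 experts

scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]

# Score each group by its best expert and keep only the top `topk_group` groups.
group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)

# Broadcast the group mask back to experts, zero out experts in discarded
# groups, then run an ordinary top-k over what is left.
score_mask = (
    group_mask.unsqueeze(-1)
    .expand(num_token, num_expert_group, num_experts // num_expert_group)
    .reshape(num_token, -1)
)
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
print(topk_ids)  # every selected expert lives in one of the 2 surviving groups
```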
@@ -425,24 +471,23 @@ def fused_experts(
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
     assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
 
-    M, _ = hidden_states.shape
+    num_tokens, _ = hidden_states.shape
     E, N, _ = w1.shape
+    # We execute the fused_moe kernel in chunks to circumvent this issue:
+    # https://github.com/vllm-project/vllm/issues/5938
+    CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
+    M = min(num_tokens, CHUNK_SIZE)
+
+    get_config_func = functools.partial(
+        try_get_optimal_moe_config,
+        w1.shape,
+        w2.shape,
+        topk_ids.shape[1],
+        "float8" if use_fp8 else None,
+        override_config=override_config,
+    )
 
-    if override_config:
-        config = override_config
-    else:
-        # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
-
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Else use the default config
-            config = get_default_config(
-                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
-            )
+    config = get_config_func(M)
 
     intermediate_cache1 = torch.empty(
         (M, topk_ids.shape[1], N),
@@ -460,56 +505,85 @@ def fused_experts(
         dtype=hidden_states.dtype,
     )
 
-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config["BLOCK_SIZE_M"], E
-    )
     compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
 
-    invoke_fused_moe_kernel(
-        hidden_states,
-        w1,
-        intermediate_cache1,
-        a1_scale,
-        w1_scale,
-        topk_weights,
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        False,
-        topk_ids.shape[1],
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
+    if inplace:
+        out_hidden_states = hidden_states
+    else:
+        out_hidden_states = torch.empty_like(hidden_states)
 
-    ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
+        begin_chunk_idx, end_chunk_idx = (
+            chunk * CHUNK_SIZE,
+            min((chunk + 1) * CHUNK_SIZE, num_tokens),
+        )
+        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
+        tokens_in_chunk, _ = curr_hidden_states.shape
+
+        if tokens_in_chunk == 0:
+            break
+
+        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
+            # Adjust the intermediate cache size and config for the last
+            # chunk. Note that in most cases we only have one chunk
+            # so the cache size and config are already set correctly and
+            # do not need to be adjusted.
+            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
+            intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
+            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
+            config = get_config_func(tokens_in_chunk)
+
+        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
+        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
+
+        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+            curr_topk_ids, config["BLOCK_SIZE_M"], E
+        )
 
-    invoke_fused_moe_kernel(
-        intermediate_cache2,
-        w2,
-        intermediate_cache3,
-        a2_scale,
-        w2_scale,
-        topk_weights,
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        True,
-        1,
-        config,
-        compute_type=compute_type,
-        use_fp8=use_fp8,
-    )
+        invoke_fused_moe_kernel(
+            curr_hidden_states,
+            w1,
+            intermediate_cache1,
+            a1_scale,
+            w1_scale,
+            curr_topk_weights,
+            curr_topk_ids,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            False,
+            topk_ids.shape[1],
+            config,
+            compute_type=compute_type,
+            use_fp8=use_fp8,
+        )
 
-    if inplace:
-        return torch.sum(
+        ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+
+        invoke_fused_moe_kernel(
+            intermediate_cache2,
+            w2,
+            intermediate_cache3,
+            a2_scale,
+            w2_scale,
+            curr_topk_weights,
+            curr_topk_ids,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            True,
+            1,
+            config,
+            compute_type=compute_type,
+            use_fp8=use_fp8,
+        )
+
+        torch.sum(
            intermediate_cache3.view(*intermediate_cache3.shape),
            dim=1,
-            out=hidden_states,
+            out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
+    return out_hidden_states
 
 
 def fused_moe(
@@ -521,6 +595,9 @@ def fused_moe(
     renormalize: bool,
     inplace: bool = False,
     override_config: Optional[Dict[str, Any]] = None,
+    use_grouped_topk: bool = False,
+    num_expert_group: Optional[int] = None,
+    topk_group: Optional[int] = None,
     use_fp8: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
@@ -543,6 +620,10 @@ def fused_moe(
         Defaults to False.
     - override_config (Optional[Dict[str, Any]]): Optional override
         for the kernel configuration.
+    - num_expert_group: Optional[int]: additional parameter for grouped_topk
+    - topk_group: Optional[int]: additional parameter for grouped_topk
+    - use_grouped_topk: If True, use grouped_topk instead of fused_topk
+        note: Deepseekv2 model uses grouped_topk
     - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -556,12 +637,18 @@ def fused_moe(
     # Check constraints.
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
 
-    if hasattr(ops, "topk_softmax"):
-        topk_weights, topk_ids = fused_topk(
-            hidden_states, gating_output, topk, renormalize
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            hidden_states,
+            gating_output,
+            topk,
+            renormalize,
+            num_expert_group,
+            topk_group,
         )
     else:
-        topk_weights, topk_ids = fused_topk_v0_4_3(
+        topk_weights, topk_ids = fused_topk(
             hidden_states, gating_output, topk, renormalize
         )
 
@@ -579,33 +666,3 @@ def fused_moe(
         a1_scale=a1_scale,
         a2_scale=a2_scale,
     )
-
-
-def fused_topk_v0_4_3(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-):
-    import vllm._moe_C as moe_kernels
-
-    M, _ = hidden_states.shape
-
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
-    moe_kernels.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
-    )
-    del token_expert_indicies  # Not used. Will be used in the future.
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    return topk_weights, topk_ids
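Putting the new arguments together, a rough usage sketch of the updated entry point with grouped top-k routing enabled. All shapes, group counts, and expert sizes below are illustrative assumptions, and running it requires a CUDA device with the vLLM custom ops installed:

```python
import torch

from sglang.srt.layers.fused_moe.fused_moe import fused_moe

# Hypothetical shapes: 16 tokens, hidden size 128, 8 experts in 4 groups.
hidden_states = torch.randn(16, 128, dtype=torch.float16, device="cuda")
w1 = torch.randn(8, 512, 128, dtype=torch.float16, device="cuda")  # [E, 2 * intermediate, hidden]
w2 = torch.randn(8, 128, 256, dtype=torch.float16, device="cuda")  # [E, hidden, intermediate]
gating_output = torch.randn(16, 8, dtype=torch.float16, device="cuda")  # [tokens, E]

out = fused_moe(
    hidden_states,
    w1,
    w2,
    gating_output,
    topk=2,
    renormalize=True,
    use_grouped_topk=True,  # route with grouped_topk instead of fused_topk
    num_expert_group=4,     # required when use_grouped_topk=True
    topk_group=2,
)
print(out.shape)  # torch.Size([16, 128])
```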