sglang 0.4.0.post2__py3-none-any.whl → 0.4.1.post1__py3-none-any.whl

This diff shows the changes between two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (53)
  1. sglang/bench_offline_throughput.py +0 -12
  2. sglang/bench_one_batch.py +0 -12
  3. sglang/bench_serving.py +11 -2
  4. sglang/lang/backend/openai.py +10 -0
  5. sglang/srt/aio_rwlock.py +100 -0
  6. sglang/srt/configs/model_config.py +8 -1
  7. sglang/srt/constrained/xgrammar_backend.py +6 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +49 -5
  9. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
  10. sglang/srt/layers/linear.py +20 -2
  11. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39 (path moved; see the import note after this list)
  12. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  13. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  14. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +124 -99
  15. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
  16. sglang/srt/layers/moe/topk.py +205 -0
  17. sglang/srt/layers/quantization/__init__.py +3 -3
  18. sglang/srt/layers/quantization/fp8.py +169 -32
  19. sglang/srt/layers/quantization/fp8_kernel.py +292 -0
  20. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  21. sglang/srt/layers/torchao_utils.py +11 -15
  22. sglang/srt/managers/schedule_batch.py +16 -10
  23. sglang/srt/managers/schedule_policy.py +1 -1
  24. sglang/srt/managers/scheduler.py +13 -16
  25. sglang/srt/managers/tokenizer_manager.py +130 -111
  26. sglang/srt/mem_cache/memory_pool.py +15 -8
  27. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  28. sglang/srt/model_loader/loader.py +22 -11
  29. sglang/srt/models/dbrx.py +1 -1
  30. sglang/srt/models/deepseek.py +1 -1
  31. sglang/srt/models/deepseek_v2.py +67 -18
  32. sglang/srt/models/gemma2.py +19 -0
  33. sglang/srt/models/grok.py +1 -1
  34. sglang/srt/models/llama.py +2 -2
  35. sglang/srt/models/mixtral.py +2 -2
  36. sglang/srt/models/olmoe.py +1 -1
  37. sglang/srt/models/qwen2_moe.py +1 -1
  38. sglang/srt/models/xverse_moe.py +1 -1
  39. sglang/srt/openai_api/adapter.py +23 -0
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_params.py +9 -2
  42. sglang/srt/server.py +21 -37
  43. sglang/srt/utils.py +33 -44
  44. sglang/test/test_block_fp8.py +341 -0
  45. sglang/version.py +1 -1
  46. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/METADATA +4 -4
  47. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/RECORD +52 -48
  48. sglang/srt/layers/fused_moe_patch.py +0 -133
  49. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  50. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  51. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/LICENSE +0 -0
  52. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/top_level.txt +0 -0
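
Note on the renames above: `ep_moe` and `fused_moe_triton` now live under `sglang/srt/layers/moe/` (entries 11 and 13-15), so downstream code importing the old paths needs updating. A minimal compatibility sketch, assuming `fused_moe` remains an exported symbol (the exact exports are not confirmed by this diff):

# Import shim for the moe package move in 0.4.1.post1; the symbol name
# `fused_moe` is an assumption based on the module names listed above.
try:
    from sglang.srt.layers.moe.fused_moe_triton import fused_moe  # 0.4.1.post1 and later
except ImportError:
    from sglang.srt.layers.fused_moe_triton import fused_moe  # 0.4.0.post2 and earlier
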
sglang/srt/layers/fused_moe_patch.py (deleted, entry 48)
@@ -1,133 +0,0 @@
-"""
-Torch-native implementation for FusedMoE. This is used for torch.compile.
-It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
-"""
-
-from typing import Callable, Optional
-
-import torch
-from torch.nn import functional as F
-
-
-def fused_topk_native(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-):
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-    M, _ = hidden_states.shape
-    topk_weights = torch.empty(
-        M, topk, dtype=torch.float32, device=hidden_states.device
-    )
-    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    topk_weights = F.softmax(gating_output.float(), dim=-1)
-    topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-    return topk_weights, topk_ids
-
-
-# This is used by the Deepseek-V2 model
-def grouped_topk(
-    hidden_states: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    num_expert_group: int = 0,
-    topk_group: int = 0,
-):
-
-    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
-
-    scores = torch.softmax(gating_output, dim=-1)
-    num_token = scores.shape[0]
-    group_scores = (
-        scores.view(num_token, num_expert_group, -1).max(dim=-1).values
-    )  # [n, n_group]
-    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
-        1
-    ]  # [n, top_k_group]
-    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
-    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
-    score_mask = (
-        group_mask.unsqueeze(-1)
-        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
-        .reshape(num_token, -1)
-    )  # [n, e]
-    tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
-
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-    return topk_weights, topk_ids
-
-
-def select_experts_native(
-    hidden_states: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    use_grouped_topk: bool,
-    renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-):
-    # DeekSeekv2 uses grouped_top_k
-    if use_grouped_topk:
-        assert topk_group is not None
-        assert num_expert_group is not None
-        topk_weights, topk_ids = grouped_topk(
-            hidden_states=hidden_states,
-            gating_output=router_logits,
-            topk=top_k,
-            renormalize=renormalize,
-            num_expert_group=num_expert_group,
-            topk_group=topk_group,
-        )
-    else:
-        topk_weights, topk_ids = fused_topk_native(
-            hidden_states=hidden_states,
-            gating_output=router_logits,
-            topk=top_k,
-            renormalize=renormalize,
-        )
-    return topk_weights, topk_ids
-
-
-def fused_moe_forward_native(
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    use_grouped_topk: bool,
-    top_k: int,
-    router_logits: torch.Tensor,
-    renormalize: bool,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-) -> torch.Tensor:
-
-    if use_grouped_topk:
-        assert num_expert_group is not None and topk_group is not None
-        topk_weights, topk_ids = grouped_topk(
-            x,
-            router_logits,
-            top_k,
-            renormalize,
-            num_expert_group,
-            topk_group,
-        )
-    elif custom_routing_function is None:
-        topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize)
-    else:
-        topk_weights, topk_ids = custom_routing_function(
-            x, router_logits, top_k, renormalize
-        )
-
-    w13_weights = layer.w13_weight[topk_ids]
-    w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
-    w2_weights = layer.w2_weight[topk_ids]
-    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
-    x1 = F.silu(x1)
-    x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
-    expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
-    return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
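
The deleted `grouped_topk` above implements DeepSeek-V2-style group-limited routing: experts are partitioned into `num_expert_group` groups, each group is scored by its best expert, only the `topk_group` best groups survive the mask, and the final top-k is taken over the surviving experts. A toy walk-through of that masking logic (illustrative numbers, not from the diff):

# 1 token, 4 experts in 2 groups of 2; keep topk_group=1 group, then pick topk=2 experts.
import torch

scores = torch.tensor([[0.1, 0.4, 0.3, 0.2]])  # post-softmax router scores, [n=1, e=4]
group_scores = scores.view(1, 2, -1).max(dim=-1).values  # [[0.4, 0.3]]: best expert per group
group_idx = torch.topk(group_scores, k=1, dim=-1, sorted=False)[1]  # [[0]]: group 0 wins
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1)  # [[1., 0.]]
score_mask = group_mask.unsqueeze(-1).expand(1, 2, 2).reshape(1, -1)  # [[1., 1., 0., 0.]]
masked = scores.masked_fill(~score_mask.bool(), 0.0)  # [[0.1, 0.4, 0.0, 0.0]]
topk_weights, topk_ids = torch.topk(masked, k=2, dim=-1)  # weights [[0.4, 0.1]], ids [[1, 0]]
print(topk_weights, topk_ids)  # both selected experts come from the surviving group
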
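For reference, a smoke-test sketch of the deleted `fused_moe_forward_native` path. Its successor appears to live in `sglang/srt/layers/moe/fused_moe_native.py` (entry 12), whose contents this diff does not show; the `DummyMoE` module, all shapes, and the weight layouts here are assumptions inferred from the einsums above:

import torch

E, H, I, T, K = 8, 16, 32, 4, 2  # experts, hidden size, intermediate size, tokens, top-k

class DummyMoE(torch.nn.Module):
    # Layouts inferred from the indexing and einsums in the deleted file:
    # w13_weight packs the gate and up projections; w2_weight is the down projection.
    def __init__(self):
        super().__init__()
        self.w13_weight = torch.nn.Parameter(torch.randn(E, 2 * I, H))
        self.w2_weight = torch.nn.Parameter(torch.randn(E, H, I))

layer = DummyMoE()
x = torch.randn(T, H)
router_logits = torch.randn(T, E)
out = fused_moe_forward_native(  # assumes the deleted function above is in scope
    layer, x, use_grouped_topk=False, top_k=K,
    router_logits=router_logits, renormalize=True,
)
assert out.shape == (T, H)  # one hidden vector per token, mixed over its top-k experts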