sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/vision.py (new file)
@@ -0,0 +1,204 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from sglang.srt.distributed import parallel_state
+from sglang.srt.distributed import utils as dist_utils
+from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+    context_attention_fwd,
+)
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.quantization import QuantizationConfig
+
+
+def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(
+            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+        )
+
+
+def apply_rotary_emb_torch(
+    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
+) -> torch.Tensor:
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    """
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    cos = repeat(
+        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    sin = repeat(
+        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    return torch.cat(
+        [
+            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+            x[..., ro_dim:],
+        ],
+        dim=-1,
+    )
+
+
+def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    t_ = t.float()
+    cos = freqs.cos()
+    sin = freqs.sin()
+    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+    return output
+
+
+class VisionAttention(nn.Module):
+    """Multi-headed attention without any cache, mostly used for ViT."""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        projection_size: int,
+        use_qkv_parallel: bool,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+
+        self.hidden_size_per_attention_head = dist_utils.divide(
+            projection_size, num_heads
+        )
+        self.num_attention_heads_per_partition = dist_utils.divide(
+            num_heads, world_size
+        )
+        # self.tp_size = get_tensor_model_parallel_world_size()
+        # num_heads = self.num_heads_per_partition
+        self.use_qkv_parallel = use_qkv_parallel
+        if use_qkv_parallel:
+            self.head_dim = embed_dim // num_heads
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size=embed_dim,
+                head_size=self.head_dim,
+                total_num_heads=num_heads,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+        else:
+            self.qkv_proj = ColumnParallelLinear(
+                input_size=embed_dim,
+                output_size=3 * projection_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+        self.proj = RowParallelLinear(
+            input_size=embed_dim,
+            output_size=embed_dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        """
+        Input shape: [b, s, embed_dim]
+        Output shape: [s, b, num_heads * head_size]
+        """
+
+        bsz, s, _ = x.shape
+        if self.use_qkv_parallel:
+            # [b, s, embed_dim] --> [b, s, embed_dim]
+            qkv, _ = self.qkv_proj(x)
+            q, k, v = qkv.chunk(3, dim=-1)
+
+            # [b, s, embed_dim] --> [b * s, num_heads, head_size]
+            q, k, v = [
+                x.reshape(
+                    bsz * s, self.num_attention_heads_per_partition, -1
+                ).contiguous()
+                for x in (q, k, v)
+            ]
+        else:
+            # [b, s, embed_dim] --> [s, b, embed_dim]
+            x = rearrange(x, "b s ... -> s b ...")
+            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+            qkv, _ = self.qkv_proj(x)
+            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+            new_x_shape = qkv.size()[:-1] + (
+                self.num_attention_heads_per_partition,
+                3 * self.hidden_size_per_attention_head,
+            )
+            qkv = qkv.view(*new_x_shape)
+
+            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+
+            # [s, b, head, head_dim] --> [b, s, head, head_dim]
+            q, k, v = [
+                rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
+            ]
+
+        if rotary_pos_emb is not None:
+            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+
+        if self.use_qkv_parallel:
+            pass
+        else:
+            # [b, s, head, head_dim] --> [b * s, head, head_dim]
+            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
+
+        # [b * s, num_heads, head_size]
+        output = torch.empty_like(q)
+
+        seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
+        max_seqlen = seq_lens.max().item()
+
+        context_attention_fwd(
+            q,
+            k,
+            v,
+            output,
+            cu_seqlens.cuda(),
+            seq_lens,
+            max_seqlen,
+            is_causal=False,
+        )
+
+        if self.use_qkv_parallel:
+
+            # [b * s, head, head_dim] --> [b, s, head * head_dim]
+            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
+
+            # [b, s, head, head_dim] --> [b, s, head, head_dim]
+            output, _ = self.proj(output)
+        else:
+            # [b * s, head, head_dim] --> [b, s, head, head_dim]
+            context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+
+            # [s, b, num_heads * head_size]
+            context_layer = rearrange(
+                context_layer, "b s h d -> s b (h d)"
+            ).contiguous()
+
+            # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
+            output, _ = self.proj(context_layer)
+
+        output = output.view(bsz, s, -1)
+
+        return output
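
Note: the rotary helpers in this new module implement the standard (non-interleaved) rotary embedding in plain PyTorch. A minimal, self-contained smoke test of apply_rotary_pos_emb_vision is sketched below; the tensor shapes and the frequency construction are illustrative assumptions, not taken from the diff.

import torch

from sglang.srt.layers.attention.vision import apply_rotary_pos_emb_vision

# Assumed shapes: [batch, seqlen, heads, head_dim], full-dimension rotary.
batch, seqlen, heads, head_dim = 2, 16, 4, 64
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(seqlen).float(), inv_freq)  # [seqlen, head_dim / 2]

q = torch.randn(batch, seqlen, heads, head_dim)
q_rot = apply_rotary_pos_emb_vision(q, freqs)  # rotated in float32, cast back to q.dtype
assert q_rot.shape == q.shape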
sglang/srt/layers/dp_attention.py (new file)
@@ -0,0 +1,71 @@
+import torch
+
+from sglang.srt.distributed import GroupCoordinator, get_tp_group
+
+_ATTN_TP_GROUP = None
+_ATTN_TP_RANK = None
+_ATTN_TP_SIZE = None
+_DP_RANK = None
+_DP_SIZE = None
+
+
+def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    attn_tp_size = tp_size // dp_size
+    dp_rank = tp_rank // attn_tp_size
+    attn_tp_rank = tp_rank % attn_tp_size
+    return attn_tp_rank, attn_tp_size, dp_rank
+
+
+def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+
+    from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
+
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size
+    )
+    _DP_SIZE = dp_size
+
+    tp_group = get_tp_group()
+    _ATTN_TP_GROUP = GroupCoordinator(
+        [
+            list(range(head, head + _ATTN_TP_SIZE))
+            for head in range(0, tp_size, _ATTN_TP_SIZE)
+        ],
+        tp_rank,
+        torch.distributed.get_backend(tp_group.device_group),
+        SYNC_TOKEN_IDS_ACROSS_TP,
+        False,
+        False,
+        False,
+        False,
+        group_name="attention_tp",
+    )
+
+
+def get_attention_tp_group():
+    assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
+    return _ATTN_TP_GROUP
+
+
+def get_attention_tp_rank():
+    assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_TP_RANK
+
+
+def get_attention_tp_size():
+    assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_TP_SIZE
+
+
+def get_attention_dp_rank():
+    assert _DP_RANK is not None, "dp attention not initialized!"
+    return _DP_RANK
+
+
+def get_attention_dp_size():
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+    return _DP_SIZE
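
Note: compute_dp_attention_world_info is pure integer arithmetic, so the rank layout it produces can be checked without any distributed initialization. A small worked example (tp_size=8 and dp_size=2 are arbitrary assumptions):

# Each DP replica owns an attention-TP group of tp_size // dp_size = 4 ranks.
for tp_rank in range(8):
    attn_tp_rank, attn_tp_size, dp_rank = compute_dp_attention_world_info(
        True, tp_rank, 8, 2
    )
    assert attn_tp_size == 4
    assert dp_rank == tp_rank // 4      # ranks 0-3 -> replica 0, ranks 4-7 -> replica 1
    assert attn_tp_rank == tp_rank % 4  # local rank inside the replica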
sglang/srt/layers/layernorm.py
@@ -19,10 +19,10 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn
 
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_cuda_available
 
-if is_flashinfer_available():
-    from flashinfer.norm import (
+if is_cuda_available():
+    from sgl_kernel import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
         gemma_rmsnorm,
@@ -121,8 +121,8 @@ class GemmaRMSNorm(CustomOp):
         return out
 
 
-if not is_flashinfer_available():
+if not is_cuda_available():
    logger.info(
-        "FlashInfer is not available on Non-NV platforms. Fallback to other kernel libraries."
+        "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
    )
    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
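
Note: both kernel providers compute the same fused-add RMSNorm; only the backing library changes (sgl_kernel on CUDA builds, vLLM's PyTorch layers elsewhere). For reference, a plain-PyTorch sketch of the math the fused kernel performs (a sketch of the computation only, not either library's API):

import torch

def fused_add_rmsnorm_reference(x, residual, weight, eps=1e-6):
    # The fused kernel adds the residual and RMS-normalizes the sum in one pass.
    hidden = x + residual
    variance = hidden.float().pow(2).mean(-1, keepdim=True)
    normed = (hidden.float() * torch.rsqrt(variance + eps)).to(hidden.dtype) * weight
    return normed, hidden  # normalized output and the updated residual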
sglang/srt/layers/linear.py
@@ -7,7 +7,8 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -15,10 +16,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
-
-# Workaround: many QuantizationConfig still depends on this, so we have to use vLLM's LinearBase now.
-from vllm.model_executor.layers.linear import LinearBase
-
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     PackedColumnParameter,
@@ -174,6 +171,45 @@ class UnquantizedLinearMethod(LinearMethodBase):
         return F.linear(x, layer.weight, bias)
 
 
+class LinearBase(torch.nn.Module):
+    """Base linear layer.
+
+    Args:
+        input_size: input dimension of the linear layer.
+        output_size: output dimension of the linear layer.
+        bias: If true, add bias.
+        skip_bias_add: If true, skip adding bias but instead return it.
+        params_dtype: Data type for the parameters.
+        quant_config: Quantization configure.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        # Keep input parameters
+        self.input_size = input_size
+        self.output_size = output_size
+        self.skip_bias_add = skip_bias_add
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+        if quant_config is None:
+            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedLinearMethod()
+        else:
+            self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        raise NotImplementedError
+
+
 class ReplicatedLinear(LinearBase):
     """Replicated linear layer.
 
@@ -293,12 +329,14 @@ class ColumnParallelLinear(LinearBase):
         prefix: str = "",
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
+        use_presharded_weights: bool = False,
     ):
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
         )
 
         self.gather_output = gather_output
+        self.use_presharded_weights = use_presharded_weights
 
         # Divide the weight matrix along the last dimension.
         if tp_rank is None:
@@ -366,7 +404,8 @@ class ColumnParallelLinear(LinearBase):
         if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = self.tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+            if not self.use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
         # Special case for loading scales off disk, which often do not
         # have a shape (such as in the case of AutoFP8).
@@ -382,7 +421,11 @@ class ColumnParallelLinear(LinearBase):
         if len(loaded_weight.shape) == 0:
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
-        param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)
+        param.load_column_parallel_weight(
+            loaded_weight,
+            tp_rank=self.tp_rank,
+            use_presharded_weights=self.use_presharded_weights,
+        )
 
     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None
@@ -463,7 +506,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             prefix=prefix,
             tp_rank=tp_rank,
            tp_size=tp_size,
+            use_presharded_weights=use_presharded_weights,
         )
+        self.prefix = prefix
 
     def weight_loader(
         self,
@@ -707,6 +752,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         prefix: str = "",
         tp_rank: Optional[int] = None,
         tp_size: Optional[int] = None,
+        load_presharded_attn: bool = False,
     ):
         self.hidden_size = hidden_size
         self.head_size = head_size
@@ -736,6 +782,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             self.num_kv_heads * self.head_size * tp_size,  # k_proj
             self.num_kv_heads * self.head_size * tp_size,  # v_proj
         ]
+        self.use_presharded_weights = load_presharded_attn
 
         super().__init__(
             input_size=input_size,
@@ -748,6 +795,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             prefix=prefix,
             tp_rank=tp_rank,
             tp_size=tp_size,
+            use_presharded_weights=self.use_presharded_weights,
         )
 
     def _get_shard_offset_mapping(self, loaded_shard_id: str):
@@ -806,9 +854,10 @@ class QKVParallelLinear(ColumnParallelLinear):
                     shard_size=shard_size, shard_offset=shard_offset
                 )
 
-            loaded_weight_shard = loaded_weight.narrow(
-                param.output_dim, shard_offset, shard_size
-            )
+            if not self.use_presharded_weights:
+                loaded_weight_shard = loaded_weight.narrow(
+                    param.output_dim, shard_offset, shard_size
+                )
             self.weight_loader_v2(param, loaded_weight_shard, shard_id)
 
     def weight_loader_v2(
@@ -846,6 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             shard_offset=shard_offset,
             shard_size=shard_size,
             tp_rank=self.tp_rank,
+            use_presharded_weights=self.use_presharded_weights,
         )
 
     def weight_loader(
@@ -951,9 +1001,10 @@ class QKVParallelLinear(ColumnParallelLinear):
                     param, orig_qkv_offsets, shard_id
                 )
 
-                loaded_weight_shard = loaded_weight.narrow(
-                    output_dim, shard_offset, shard_size
-                )
+                if not self.use_presharded_weights:
+                    loaded_weight_shard = loaded_weight.narrow(
+                        output_dim, shard_offset, shard_size
+                    )
                 self.weight_loader(param, loaded_weight_shard, shard_id)
                 return
 
@@ -1013,7 +1064,7 @@ class QKVParallelLinear(ColumnParallelLinear):
 
            # bitsandbytes loads the weights of the specific portion
            # no need to narrow here
-            if not use_bitsandbytes_4bit:
+            if not use_bitsandbytes_4bit and not self.use_presharded_weights:
                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
 
            # Special case for for AQLM codebooks.
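
Note: the new use_presharded_weights / load_presharded_attn flags let the linear weight loaders skip the narrow() step when a checkpoint already stores one shard per tensor-parallel rank. A small sketch of the difference, with made-up sizes:

import torch

# Hypothetical full weight: 8 output rows, tp_size=2, so each rank owns 4 rows.
full_weight = torch.arange(8 * 3, dtype=torch.float32).reshape(8, 3)
tp_rank, tp_size, output_dim = 1, 2, 0
shard_size = full_weight.shape[output_dim] // tp_size

# Default path: slice this rank's shard out of the full checkpoint tensor.
shard = full_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
assert shard.shape == (4, 3)

# With use_presharded_weights=True, the checkpoint already contains only this
# rank's 4x3 shard, so the loader copies it as-is and the narrow() is skipped.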
sglang/srt/layers/logits_processor.py
@@ -14,17 +14,18 @@
 """Logits processing."""
 
 import dataclasses
+import logging
 from typing import List, Optional, Union
 
 import torch
 import triton
 import triton.language as tl
 from torch import nn
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
-
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
@@ -32,6 +33,8 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 
+logger = logging.getLogger(__name__)
+
 
 @dataclasses.dataclass
 class LogitsProcessorOutput:
@@ -50,8 +53,6 @@ class LogitsProcessorOutput:
     next_token_top_logprobs_idx: Optional[List] = None
 
     ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
-    # The normlaized logprobs of prompts. shape: [#seq]
-    normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token]
     input_token_logprobs: torch.Tensor = None
     # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
@@ -129,59 +130,70 @@ class LogitsProcessor(nn.Module):
         hidden_states,
         lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
-    ):
+    ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
 
         # Get the last hidden states and last logits for the next token prediction
         if (
-            logits_metadata.forward_mode.is_decode()
+            logits_metadata.forward_mode.is_decode_or_idle()
             or logits_metadata.forward_mode.is_target_verify()
         ):
-            last_index = None
-            last_hidden = hidden_states
-        else:
+            pruned_states = hidden_states
+            sample_indices = None
+        elif (
+            logits_metadata.forward_mode.is_extend()
+            and not logits_metadata.extend_return_logprob
+        ):
+            # Prefill without input logprobs.
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
-            last_hidden = hidden_states[last_index]
+            pruned_states = hidden_states[last_index]
+            sample_indices = None
+        else:
+            # Slice the requested tokens to compute logprob
+            sample_index_pt = -1
+            sample_indices = []
+            pt, pruned_states, pruned_input_ids = 0, [], []
+            for start_len, extend_len in zip(
+                logits_metadata.extend_logprob_start_lens_cpu,
+                logits_metadata.extend_seq_lens_cpu,
+            ):
+                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                sample_index_pt += extend_len - start_len
+                sample_indices.append(sample_index_pt)
+                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                pt += extend_len
+
+            pruned_states = torch.cat(pruned_states)
+
+        # Compute logits for both input and sampled tokens.
+        logits = self._get_logits(pruned_states, lm_head, logits_metadata)
+        sampled_logits = (
+            logits[sample_indices] if sample_indices is not None else logits
+        )
 
-        # Compute logits
-        last_logits = self._get_logits(last_hidden, lm_head)
         if (
             not logits_metadata.extend_return_logprob
             or logits_metadata.capture_hidden_mode.need_capture()
         ):
             # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
-                next_token_logits=last_logits,
+                next_token_logits=sampled_logits,
                 hidden_states=(
                     hidden_states
                     if logits_metadata.capture_hidden_mode.is_full()
                     else (
-                        last_hidden
+                        pruned_states
                        if logits_metadata.capture_hidden_mode.is_last()
                        else None
                    )
                ),
            )
        else:
-            # Slice the requested tokens to compute logprob
-            pt, pruned_states, pruned_input_ids = 0, [], []
-            for start_len, extend_len in zip(
-                logits_metadata.extend_logprob_start_lens_cpu,
-                logits_metadata.extend_seq_lens_cpu,
-            ):
-                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
-                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
-                pt += extend_len
-
-            # Compute the logits of all required tokens
-            pruned_states = torch.cat(pruned_states)
-            del hidden_states
-            input_token_logits = self._get_logits(pruned_states, lm_head)
-            del pruned_states
+            input_logprobs = logits
+            del hidden_states, logits
 
             # Normalize the logprob w/o temperature, top-p
-            input_logprobs = input_token_logits
             input_logprobs = self.compute_temp_top_p_normalized_logprobs(
                 input_logprobs, logits_metadata
             )
@@ -195,25 +207,18 @@ class LogitsProcessor(nn.Module):
            else:
                input_top_logprobs_val = input_top_logprobs_idx = None
 
-            # Compute the normalized logprobs for the requested tokens.
-            # Note that we pad a zero at the end for easy batching.
            input_token_logprobs = input_logprobs[
-                torch.arange(input_logprobs.shape[0], device="cuda"),
+                torch.arange(input_logprobs.shape[0], device=input_logprobs.device),
                torch.cat(
                    [
                        torch.cat(pruned_input_ids)[1:],
-                        torch.tensor([0], device="cuda"),
+                        torch.tensor([0], device=input_logprobs.device),
                    ]
                ),
            ]
-            normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                input_token_logprobs,
-                logits_metadata,
-            )
 
            return LogitsProcessorOutput(
-                next_token_logits=last_logits,
-                normalized_prompt_logprobs=normalized_prompt_logprobs,
+                next_token_logits=sampled_logits,
                input_token_logprobs=input_token_logprobs,
                input_top_logprobs_val=input_top_logprobs_val,
                input_top_logprobs_idx=input_top_logprobs_idx,
@@ -223,8 +228,11 @@ class LogitsProcessor(nn.Module):
         self,
         hidden_states: torch.Tensor,
         lm_head: VocabParallelEmbedding,
+        logits_metadata: LogitsMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        """Get logits from hidden_states."""
+
         if hasattr(lm_head, "weight"):
             logits = torch.matmul(hidden_states, lm_head.weight.T)
         else:
@@ -237,8 +245,6 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather:
             logits = tensor_model_parallel_all_gather(logits)
 
-        # Compute the normalized logprobs for the requested tokens.
-        # Note that we pad a zero at the end for easy batching.
         logits = logits[:, : self.config.vocab_size].float()
 
         if self.final_logit_softcapping:
@@ -246,27 +252,6 @@ class LogitsProcessor(nn.Module):
 
         return logits
 
-    @staticmethod
-    def _get_normalized_prompt_logprobs(
-        input_token_logprobs: torch.Tensor,
-        logits_metadata: LogitsMetadata,
-    ):
-        logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32)
-        pruned_lens = torch.tensor(
-            logits_metadata.extend_logprob_pruned_lens_cpu, device="cuda"
-        )
-
-        start = torch.zeros_like(pruned_lens)
-        start[1:] = torch.cumsum(pruned_lens[:-1], dim=0)
-        end = torch.clamp(
-            start + pruned_lens - 2, min=0, max=logprobs_cumsum.shape[0] - 1
-        )
-        sum_logp = (
-            logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start]
-        )
-        normalized_prompt_logprobs = sum_logp / (pruned_lens - 1).clamp(min=1)
-        return normalized_prompt_logprobs
-
     @staticmethod
     def get_top_logprobs(all_logprobs: torch.Tensor, logits_metadata: LogitsMetadata):
         max_k = max(logits_metadata.top_logprobs_nums)
@@ -311,7 +296,7 @@ def fused_softcap_kernel(
     n_elements,
     BLOCK_SIZE: tl.constexpr,
 ):
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
     block_start = pid * BLOCK_SIZE
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
     mask = offsets < n_elements
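
Note: the rewritten LogitsProcessor.forward computes logits once over the pruned prefill tokens and then selects each request's last position through sample_indices. The index arithmetic can be traced by hand; a worked example with made-up request lengths:

# Two requests: extend lengths [4, 3], logprob start lengths [1, 0].
extend_seq_lens_cpu = [4, 3]
extend_logprob_start_lens_cpu = [1, 0]

sample_index_pt, sample_indices, pt = -1, [], 0
for start_len, extend_len in zip(extend_logprob_start_lens_cpu, extend_seq_lens_cpu):
    # Each request contributes (extend_len - start_len) rows to pruned_states.
    sample_index_pt += extend_len - start_len
    sample_indices.append(sample_index_pt)
    pt += extend_len

# Request 0 keeps rows 0-2, request 1 keeps rows 3-5; their last rows are 2 and 5.
assert sample_indices == [2, 5]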