sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +2 -2
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +9 -7
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +48 -43
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +7 -2
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +227 -120
- sglang/srt/disaggregation/nixl/conn.py +1 -0
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +7 -1
- sglang/srt/entrypoints/engine.py +17 -2
- sglang/srt/entrypoints/http_server.py +17 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +1 -1
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +72 -71
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +76 -24
- sglang/srt/managers/schedule_policy.py +0 -3
- sglang/srt/managers/scheduler.py +113 -88
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +133 -34
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/memory_pool.py +2 -0
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +19 -14
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +23 -20
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +5 -6
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +30 -4
- sglang/srt/openai_api/protocol.py +0 -8
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +34 -4
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +6 -5
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +89 -14
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.utils import (
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -221,7 +222,7 @@ async def get_server_info():
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
-
+        "internal_states": internal_states,
         "version": __version__,
     }

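The `/get_server_info` payload now nests the scheduler's internal state under its own key. A minimal client sketch, assuming the route is still exposed as `GET /get_server_info` on a locally running server (address and port are illustrative):

```python
# Hedged sketch: read the new "internal_states" entry from the server info payload.
# The server address is an assumption for illustration.
import requests

info = requests.get("http://localhost:30000/get_server_info", timeout=10).json()
print(info["version"])                 # package version, as before
print(info.get("internal_states"))     # now returned under its own key
```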
@@ -337,7 +338,11 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None):
         obj = ProfileReqInput()

     await _global_state.tokenizer_manager.start_profile(
-        obj.output_dir,
+        output_dir=obj.output_dir,
+        num_steps=obj.num_steps,
+        activities=obj.activities,
+        with_stack=obj.with_stack,
+        record_shapes=obj.record_shapes,
     )
     return Response(
         content="Start profiling.\n",
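The profiling endpoint now forwards the full set of `ProfileReqInput` fields instead of only `output_dir`. A hedged client sketch; the server address and every field value below are assumptions for illustration:

```python
# Hedged sketch: exercise the expanded /start_profile parameters shown above.
# Server address and all values are illustrative assumptions.
import requests

requests.post(
    "http://localhost:30000/start_profile",
    json={
        "output_dir": "/tmp/sglang_trace",  # where trace files are written
        "num_steps": 5,                     # stop automatically after a few steps
        "activities": ["CPU", "GPU"],       # assumed activity names
        "with_stack": True,                 # record stack traces
        "record_shapes": True,              # record operator input shapes
    },
    timeout=30,
)
```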
@@ -539,6 +544,16 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
     return Response(status_code=200)


+@app.post("/abort_request")
+async def abort_request(obj: AbortReq, request: Request):
+    """Abort a request."""
+    try:
+        _global_state.tokenizer_manager.abort_request(rid=obj.rid)
+        return Response(status_code=200)
+    except Exception as e:
+        return _create_error_response(e)
+
+
 @app.post("/parse_function_call")
 async def parse_function_call_request(obj: ParseFunctionCallReq, request: Request):
     """
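A hedged sketch of calling the new `/abort_request` route; the server address and the `rid` value are illustrative assumptions:

```python
# Hedged sketch: abort an in-flight request via the new /abort_request route.
# The server URL and the rid value are assumptions for illustration.
import requests

resp = requests.post(
    "http://localhost:30000/abort_request",
    json={"rid": "example-request-id"},  # rid of the request to abort
    timeout=10,
)
print(resp.status_code)  # 200 when the abort was accepted
```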
sglang/srt/function_call_parser.py
CHANGED
@@ -86,8 +86,8 @@ class StructureInfo:

 _GetInfoFunc = Callable[[str], StructureInfo]
 """
-
-
+Helper alias of function
+Usually it is a function that takes a name string and returns a StructureInfo object,
 which can be used to construct a structural_tag object
 """

sglang/srt/layers/attention/flashattention_backend.py
CHANGED
@@ -308,7 +308,7 @@ class FlashAttentionBackend(AttentionBackend):
         ), "Sliding window and cross attention are not supported together"

         self.forward_metadata: FlashAttentionMetadata = None
-        # extra
+        # extra metadata for handling speculative decoding topk > 1, extended draft decode and verify
         self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py
CHANGED
@@ -919,7 +919,7 @@ def _fwd_kernel(

         e_max = n_e_max

-    # stage 2: compute the
+    # stage 2: compute the triangle part

     cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
     for start_n in range(0, cur_block_m_end, BLOCK_N):
sglang/srt/layers/attention/utils.py
CHANGED
@@ -28,7 +28,8 @@ def create_flashinfer_kv_indices_triton(

     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
     for i in range(num_loop):
-
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
         mask = offset < kv_end - kv_start
         data = tl.load(
             req_to_token_ptr
@@ -70,8 +71,9 @@ def create_flashmla_kv_indices_triton(
     num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)

     for i in range(num_pages_loop):
+        # index into req_to_token_ptr needs to be int64
         paged_offset = (
-            tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+            tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK
         ) * PAGED_SIZE
         paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK

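Both kernels now cast the block offsets to `int64` before indexing `req_to_token_ptr`. A back-of-the-envelope check of why 32-bit offsets can overflow; the table dimensions below are assumptions chosen for illustration:

```python
# Hedged arithmetic note on the int64 cast above: with a large request-to-token
# table the flat element offset can exceed the int32 range. Sizes are assumptions.
max_running_requests, max_context_len = 4096, 1_048_576
flat_offset = (max_running_requests - 1) * max_context_len + (max_context_len - 1)
print(flat_offset)               # 4294967295
print(flat_offset > 2**31 - 1)   # True -> an int32 offset would wrap around
```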
sglang/srt/layers/dp_attention.py
CHANGED
@@ -24,8 +24,10 @@ if TYPE_CHECKING:
 _ATTN_TP_GROUP = None
 _ATTN_TP_RANK = None
 _ATTN_TP_SIZE = None
-
-
+_ATTN_DP_RANK = None
+_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_RANK = None


 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
@@ -33,9 +35,27 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
         return tp_rank, tp_size, 0

     attn_tp_size = tp_size // dp_size
-
+    attn_dp_rank = tp_rank // attn_tp_size
     attn_tp_rank = tp_rank % attn_tp_size
-
+
+    return attn_tp_rank, attn_tp_size, attn_dp_rank
+
+
+def compute_dp_attention_local_info(
+    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
+    local_tp_rank = tp_rank % local_tp_size
+    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
+
+    local_attn_tp_size = local_tp_size // local_dp_size
+    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
+    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
+
+    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank


 def initialize_dp_attention(
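To make the new local DP/TP rank arithmetic concrete, here is a standalone driver around the `compute_dp_attention_local_info` logic added above; the sizes passed in are assumptions chosen for illustration:

```python
# Hedged worked example for the local DP/TP rank math introduced above.
# The sizes below are illustrative assumptions, not values from the diff.
def compute_dp_attention_local_info(
    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0

    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
    local_tp_rank = tp_rank % local_tp_size
    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))

    local_attn_tp_size = local_tp_size // local_dp_size
    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
    local_attn_tp_rank = local_tp_rank % local_attn_tp_size

    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank


for tp_rank in range(8):
    # tp_size=8, dp_size=4, moe_dense_tp_size=2 -> local_tp_size=2, local_dp_size=1
    print(tp_rank, compute_dp_attention_local_info(True, tp_rank, 8, 4, 2))
# Every rank ends up with local_attn_tp_size=2 and local_attn_dp_rank=0, while
# local_attn_tp_rank alternates 0/1 with the rank parity.
```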
@@ -43,22 +63,32 @@ def initialize_dp_attention(
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    moe_dense_tp_size: int,
     pp_size: int,
 ):
-    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE,
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK

     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

-    _ATTN_TP_RANK, _ATTN_TP_SIZE,
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
+    _, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+    )

     if enable_dp_attention:
         local_rank = tp_rank % (tp_size // dp_size)
-
+        _ATTN_DP_SIZE = dp_size
+        if moe_dense_tp_size is None:
+            _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
+        else:
+            _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
         local_rank = tp_rank
-
+        _ATTN_DP_SIZE = 1
+        _LOCAL_ATTN_DP_SIZE = 1

     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
@@ -93,13 +123,33 @@ def get_attention_tp_size():


 def get_attention_dp_rank():
-    assert
-    return
+    assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_DP_RANK


 def get_attention_dp_size():
-    assert
-    return
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE


 @contextmanager
@@ -112,19 +162,19 @@ def disable_dp_size():
     Args:
         tp_group (GroupCoordinator): the tp group coordinator
     """
-    global
-    assert
+    global _ATTN_DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"

-    old_dp_size =
-
+    old_dp_size = _ATTN_DP_SIZE
+    _ATTN_DP_SIZE = 1
     try:
         yield
     finally:
-
+        _ATTN_DP_SIZE = old_dp_size


 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank =
+    dp_rank = get_local_attention_dp_rank()

     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
@@ -201,7 +251,7 @@ def _dp_gather(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )

-    # Input IDs are in int 32. We should use inplace_all_reduce for local case
+    # Input IDs are in int 32. We should use inplace_all_reduce for local case because of custom all reduce.
     NUM_GPUS_PER_NODE = 8
     if (
         not local_tokens.dtype.is_floating_point
@@ -252,12 +302,12 @@ def dp_scatter(
     )


-def
+def attn_tp_reduce_scatter(
     output: torch.Tensor,
     input_list: List[torch.Tensor],
 ):
     return get_attention_tp_group().reduce_scatter(output, input_list)


-def
+def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
     return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
sglang/srt/layers/layernorm.py
CHANGED
@@ -76,7 +76,7 @@ class RMSNorm(CustomOp):
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
-            # NOTE:
+            # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
         if residual is not None:
             fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
sglang/srt/layers/logits_processor.py
CHANGED
@@ -23,15 +23,17 @@ import triton.language as tl
 from torch import nn

 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
-    get_attention_dp_rank,
     get_attention_dp_size,
+    get_attention_tp_size,
+    get_local_attention_dp_rank,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -45,6 +47,18 @@ from sglang.srt.utils import dump_to_file
 logger = logging.getLogger(__name__)


+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.utils import dump_to_file
+
+logger = logging.getLogger(__name__)
+
+
 @dataclasses.dataclass
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -169,7 +183,7 @@ class LogitsMetadata:
             return

         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank =
+        dp_rank = get_local_attention_dp_rank()
         if dp_rank == 0:
             dp_local_start_pos = torch.zeros_like(
                 self.global_num_tokens_for_logprob_gpu[0]
@@ -198,12 +212,20 @@ class LogitsProcessor(nn.Module):
         super().__init__()
         self.config = config
         self.logit_scale = logit_scale
-        self.
-
-
-
-
-
+        self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        if self.use_attn_tp_group:
+            self.attn_tp_size = get_attention_tp_size()
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and self.attn_tp_size > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = False
+        else:
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = (
+                self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
+            )
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
@@ -315,7 +337,8 @@ class LogitsProcessor(nn.Module):

         if self.debug_tensor_dump_output_folder:
             assert (
-                not self.do_tensor_parallel_all_gather
+                not self.do_tensor_parallel_all_gather
+                or get_local_attention_dp_size() == 1
             ), "dp attention + sharded lm_head doesn't support full logits"
             full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
             dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
@@ -442,7 +465,19 @@ class LogitsProcessor(nn.Module):
             logits.mul_(self.logit_scale)

         if self.do_tensor_parallel_all_gather:
-
+            if self.use_attn_tp_group:
+                global_logits = torch.empty(
+                    (self.config.vocab_size, logits.shape[0]),
+                    device=logits.device,
+                    dtype=logits.dtype,
+                )
+                global_logits = global_logits.T
+                attn_tp_all_gather(
+                    list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
+                )
+                logits = global_logits
+            else:
+                logits = tensor_model_parallel_all_gather(logits)

         if self.do_tensor_parallel_all_gather_dp_attn:
             logits, global_logits = (
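The new `use_attn_tp_group` branch gathers vocab-sharded logits from the attention TP group into a transposed buffer so that each rank's shard lands in a contiguous slice. A single-process sketch of the same shape bookkeeping, with toy sizes chosen as assumptions and the collective replaced by plain copies:

```python
# Hedged single-process sketch of the gather layout used above when the lm_head
# is sharded across the attention TP group. Toy sizes are illustrative assumptions;
# the real code calls attn_tp_all_gather on a process group instead of copy_.
import torch

attn_tp_size, num_tokens, vocab_size = 2, 3, 8
shard = vocab_size // attn_tp_size

# Each rank produces logits only for its vocab shard: (num_tokens, shard).
per_rank_logits = [
    torch.full((num_tokens, shard), float(rank)) for rank in range(attn_tp_size)
]

# Allocate (vocab_size, num_tokens) and view it transposed, so splitting the last
# dim hands out contiguous per-rank chunks of the underlying buffer.
global_logits = torch.empty((vocab_size, num_tokens)).T

# Emulate the all-gather: rank r's shard lands in the r-th split along the vocab dim.
for chunk, local in zip(global_logits.tensor_split(attn_tp_size, dim=-1), per_rank_logits):
    chunk.copy_(local)

print(global_logits.shape)  # torch.Size([3, 8]) -> full-vocab logits per token
```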
sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -116,7 +116,7 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
     seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)

-    # Find
+    # Find offset
     expert_ids = torch.arange(
         num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
     )
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -611,7 +611,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 self.quant_config.weight_block_size[1],
             )
             # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by
+            # Required by column parallel or enabling merged weights
             if intermediate_size % block_n != 0:
                 raise ValueError(
                     f"The output_size of gate's and up's weight = "
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -994,7 +994,7 @@ def get_default_config(
             "num_stages": 2 if _is_hip else 4,
         }
     else:
-        # Block-wise quant: BLOCK_SIZE_K must be
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
        config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": block_shape[0],
sglang/srt/layers/moe/topk.py
CHANGED
@@ -270,7 +270,7 @@ def select_experts(
     routed_scaling_factor: Optional[float] = None,
 ):
     n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
-    #
+    # DeepSeek V2/V3/R1 series models use grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -109,7 +109,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
         raise ValueError(
             f"{quantization} quantization requires some operators from vllm. "
-            "
+            "Please install vllm by `pip install vllm==0.8.4`"
         )

     return QUANTIZATION_METHODS[quantization]
sglang/srt/layers/quantization/blockwise_int8.py
CHANGED
@@ -152,7 +152,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
                 f"{input_size_per_partition} is not divisible by "
                 f"weight quantization block_k = {block_k}."
             )
-        # Required by
+        # Required by column parallel or enabling merged weights
         if (tp_size > 1 and output_size // output_size_per_partition == tp_size) or len(
             output_partition_sizes
         ) > 1:
@@ -285,7 +285,7 @@ class BlockInt8MoEMethod:
                 self.quant_config.weight_block_size[1],
             )
             # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by
+            # Required by column parallel or enabling merged weights
             if intermediate_size % block_n != 0:
                 raise ValueError(
                     f"The output_size of gate's and up's weight = "