sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +4 -2
- sglang/bench_one_batch.py +2 -2
- sglang/bench_one_batch_server.py +143 -15
- sglang/bench_serving.py +9 -7
- sglang/compile_deep_gemm.py +1 -1
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +48 -43
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +7 -2
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +227 -120
- sglang/srt/disaggregation/nixl/conn.py +1 -0
- sglang/srt/disaggregation/prefill.py +7 -4
- sglang/srt/disaggregation/utils.py +7 -1
- sglang/srt/entrypoints/engine.py +17 -2
- sglang/srt/entrypoints/http_server.py +17 -2
- sglang/srt/function_call_parser.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +1 -1
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +4 -2
- sglang/srt/layers/dp_attention.py +71 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/topk.py +1 -1
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +72 -71
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/int8_kernel.py +2 -2
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/io_struct.py +3 -1
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/schedule_batch.py +76 -24
- sglang/srt/managers/schedule_policy.py +0 -3
- sglang/srt/managers/scheduler.py +113 -88
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/tokenizer_manager.py +133 -34
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/memory_pool.py +2 -0
- sglang/srt/metrics/collector.py +312 -37
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +1 -1
- sglang/srt/model_executor/model_runner.py +19 -14
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +23 -20
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +5 -6
- sglang/srt/models/llava.py +248 -5
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +30 -4
- sglang/srt/openai_api/protocol.py +0 -8
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +34 -4
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +7 -7
- sglang/srt/speculative/eagle_worker.py +22 -19
- sglang/srt/utils.py +6 -5
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +89 -14
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/deep_gemm.py
CHANGED
@@ -15,12 +15,9 @@ _ENABLE_JIT_DEEPGEMM = False
 if is_cuda():
     import deep_gemm
     from deep_gemm import get_num_sms
+    from deep_gemm.jit.compiler import get_nvcc_compiler
     from deep_gemm.jit_kernels.gemm import get_best_configs
-    from deep_gemm.jit_kernels.
-    from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
-    from deep_gemm.jit_kernels.m_grouped_gemm import (
-        template as deep_gemm_grouped_gemm_template,
-    )
+    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
     from deep_gemm.jit_kernels.tuner import jit_tuner

 sm_version = get_device_sm()
@@ -45,10 +42,25 @@ _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
 _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")

 # Force redirect deep_gemm cache_dir
-os.environ["
-    "SGL_DG_CACHE_DIR", os.path.expanduser("~")
+os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
+    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )

+# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
+# NVRTC may have performance loss with some cases.
+# And NVCC JIT speed is also 9x faster in the ref commit
+_USE_NVRTC_DEFAULT = "0"
+if _ENABLE_JIT_DEEPGEMM:
+    try:
+        get_nvcc_compiler()
+    except:
+        logger.warning(
+            "NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
+            "and may have performance loss with some cases."
+        )
+        _USE_NVRTC_DEFAULT = "1"
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
+

 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     global _BUILTIN_M_LIST
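The hunk above wires up two environment overrides: `SGL_DG_CACHE_DIR` relocates the DeepGEMM JIT cache (now defaulting to `~/.cache/deep_gemm`), and `SGL_DG_USE_NVRTC` can force NVRTC even though NVCC is preferred when present. A minimal sketch of that resolution order, assuming the variable names from the diff (the helper function itself is illustrative, not part of sglang):

```python
import os

def resolve_deep_gemm_env(nvcc_available: bool) -> dict:
    """Illustrative sketch: mirror the override order shown in the diff."""
    cache_dir = os.getenv(
        "SGL_DG_CACHE_DIR",
        os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm"),
    )
    # NVCC is preferred; NVRTC becomes the default only when no compiler is
    # found, unless the user forces a choice via SGL_DG_USE_NVRTC.
    default_nvrtc = "0" if nvcc_available else "1"
    use_nvrtc = os.getenv("SGL_DG_USE_NVRTC", default_nvrtc)
    return {"DG_JIT_CACHE_DIR": cache_dir, "DG_JIT_USE_NVRTC": use_nvrtc}

print(resolve_deep_gemm_env(nvcc_available=True))
```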
@@ -103,10 +115,10 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic
 def _compile_warning_1():
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
-            "Entering DeepGEMM JIT Pre-
+            "Entering DeepGEMM JIT Pre-Compile session. "
             "And it may takes a long time(Typically 10-20 mins) "
            "if you have not run `sglang.compile_deep_gemm`. "
-            "
+            "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
             " for pre-compilation to reduce the overhead if you have not run it before. "
             "For example: "
             "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
@@ -115,7 +127,7 @@ def _compile_warning_1():

 def _compile_warning_2():
     logger.warning(
-        "Entering DeepGEMM JIT Single Kernel
+        "Entering DeepGEMM JIT Single Kernel Compile session. "
         "And it will makes inference throughput becomes flaky. "
         "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
         " for pre-compilation to solve this issue. "
@@ -130,10 +142,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-
-
-
-
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -146,24 +166,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE":
+            "GEMM_TYPE": GemmType.GroupedMasked,
         },
         space=(),
-
-
-        ("lhs", torch.float8_e4m3fn),
-        ("lhs_scales", torch.float),
-        ("rhs", torch.float8_e4m3fn),
-        ("rhs_scales", torch.float),
-        ("out", torch.bfloat16),
-        ("grouped_layout", torch.int32),
-        ("m", int),
-        ("stream", torch.cuda.Stream),
-        ("num_sms", int),
-        ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )

@@ -173,9 +180,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-
-
-
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -188,25 +204,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE":
+            "GEMM_TYPE": GemmType.GroupedContiguous,
         },
         space=(),
-
-
-        ("lhs", torch.float8_e4m3fn),
-        ("lhs_scales", torch.float),
-        ("rhs", torch.float8_e4m3fn),
-        ("rhs_scales", torch.float),
-        ("out", torch.bfloat16),
-        ("grouped_layout", torch.int32),
-        ("m", int),
-        ("num_groups", int),
-        ("stream", torch.cuda.Stream),
-        ("num_sms", int),
-        ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )

@@ -216,9 +218,20 @@ def _compile_gemm_nt_f8f8bf16_one(
     _: int,  # _ is a dummy parameter to align with other interfaces
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-
-
-
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "GEMM_TYPE": GemmType.Normal,
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "NUM_GROUPS": 1,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -232,20 +245,8 @@ def _compile_gemm_nt_f8f8bf16_one(
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
         },
         space=(),
-
-
-        ("lhs", torch.float8_e4m3fn),
-        ("lhs_scales", torch.float),
-        ("rhs", torch.float8_e4m3fn),
-        ("rhs_scales", torch.float),
-        ("out", torch.bfloat16),
-        ("m", int),
-        ("stream", torch.cuda.Stream),
-        ("num_sms", int),
-        ("smem_size", int),
-        ),
-        template=deep_gemm_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )

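All three compile helpers now build a plain kwargs dict and hand the DeepGEMM runtime class to the tuner instead of passing a kernel template and a typed argument list. A rough sketch of the shared shape implied by the three hunks above, with the per-variant differences factored out (names follow the diff; treat this as an illustration rather than the exact sglang code):

```python
# Illustrative only: per-variant settings implied by the three hunks above.
GEMM_VARIANTS = {
    "gemm_fp8_fp8_bf16_nt": {"GEMM_TYPE": "Normal", "NUM_GROUPS": 1},
    "m_grouped_gemm_fp8_fp8_bf16_nt (masked)": {"GEMM_TYPE": "GroupedMasked"},
    "m_grouped_gemm_fp8_fp8_bf16_nt (contiguous)": {"GEMM_TYPE": "GroupedContiguous"},
}

def build_common_kwargs(num_sms: int, smem_size: int) -> dict:
    """Shared JIT-compile kwargs that every variant in the diff passes."""
    return {
        "NUM_TMA_THREADS": 128,
        "NUM_MATH_THREADS_PER_GROUP": 128,
        "BLOCK_K": 128,
        "NUM_SMS": num_sms,
        "SMEM_SIZE": smem_size,
    }
```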
@@ -298,7 +299,7 @@ def _maybe_compile_deep_gemm_one_type_all(
         logger.info(
             f"Try DeepGEMM JIT Compiling for "
             f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
-            f"{' It only takes a
+            f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
         )

     # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
@@ -373,7 +374,7 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):

     from deep_gemm.jit.runtime import RuntimeCache

-    origin_func = RuntimeCache.
+    origin_func = RuntimeCache.get

     def __patched_func(self, *args, **kwargs):
         ret = origin_func(self, *args, **kwargs)
@@ -385,6 +386,6 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
             )
         return ret

-    RuntimeCache.
+    RuntimeCache.get = __patched_func
     yield
-    RuntimeCache.
+    RuntimeCache.get = origin_func
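The two `_log_jit_build` hunks only show fragments of the patch; the underlying pattern is a context manager that temporarily wraps `RuntimeCache.get` so that a cache miss (which triggers a JIT build) can be logged, then restores the original method. A self-contained sketch of that pattern, using a stand-in class rather than DeepGEMM's real `RuntimeCache`:

```python
import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)

class RuntimeCache:
    """Stand-in for deep_gemm.jit.runtime.RuntimeCache used only in this sketch."""
    def __init__(self):
        self._store = {}
    def get(self, key):
        return self._store.get(key)

@contextmanager
def log_jit_build(m: int, n: int, k: int):
    """Temporarily patch RuntimeCache.get to log cache misses (i.e. JIT builds)."""
    origin_func = RuntimeCache.get

    def patched_get(self, *args, **kwargs):
        ret = origin_func(self, *args, **kwargs)
        if ret is None:  # cache miss -> a JIT build is about to happen
            logger.warning("DeepGEMM JIT build for M=%d, N=%d, K=%d", m, n, k)
        return ret

    RuntimeCache.get = patched_get
    try:
        yield
    finally:
        RuntimeCache.get = origin_func

# Usage: wrap the GEMM call so unexpected JIT builds show up in the logs.
with log_jit_build(64, 4096, 7168):
    RuntimeCache().get(("gemm", 64, 4096, 7168))
```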
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -235,7 +235,7 @@ class Fp8LinearMethod(LinearMethodBase):
                     f"{input_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}."
                 )
-            # Required by
+            # Required by column parallel or enabling merged weights
             if (
                 tp_size > 1 and output_size // output_size_per_partition == tp_size
             ) or len(output_partition_sizes) > 1:
@@ -491,7 +491,7 @@ class Fp8MoEMethod:
                 self.quant_config.weight_block_size[1],
             )
             # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by
+            # Required by column parallel or enabling merged weights
             if intermediate_size % block_n != 0:
                 raise ValueError(
                     f"The output_size of gate's and up's weight = "
sglang/srt/layers/quantization/fp8_kernel.py
CHANGED
@@ -104,7 +104,7 @@ def _per_token_group_quant_fp8(
     y_s_ptr,
     # Stride of input
     y_stride,
-    #
+    # Columns of input
     N,
     # Avoid to divide zero
     eps,
@@ -342,7 +342,7 @@ def _static_quant_fp8(
     y_s_repeat_ptr,
     # Stride of input
     y_stride,
-    #
+    # Columns of input
     N,
     # Information for float8
     fp8_min,
@@ -794,7 +794,7 @@ def w8a8_block_fp8_matmul(
         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
         # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": block_size[0],
sglang/srt/layers/quantization/int8_kernel.py
CHANGED
@@ -76,7 +76,7 @@ def _per_token_group_quant_int8(
     y_s_ptr,
     # Stride of input
     y_stride,
-    #
+    # Columns of input
     N,
     # Avoid to divide zero
     eps,
@@ -370,7 +370,7 @@ def w8a8_block_int8_matmul(
         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
         # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
        config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": block_size[0],
sglang/srt/layers/sampler.py
CHANGED
@@ -239,10 +239,6 @@ def top_p_normalize_probs_torch(


 def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
-    assert len(top_logprobs_nums) == logprobs.shape[0], (
-        len(top_logprobs_nums),
-        logprobs.shape[0],
-    )
     max_k = max(top_logprobs_nums)
     ret = logprobs.topk(max_k, dim=1)
     values = ret.values.tolist()
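For context, the surrounding function pads every request to the largest requested `k` with a single batched `topk` and then slices per request on the host; the removed assert only re-checked that the batch sizes matched. A small standalone illustration of that shape (assuming this structure from the visible context lines, not the exact sglang function):

```python
import torch

def top_logprobs_sketch(logprobs: torch.Tensor, top_logprobs_nums: list):
    """Illustrative: one batched topk at max_k, then per-row slicing."""
    max_k = max(top_logprobs_nums)
    ret = logprobs.topk(max_k, dim=1)
    values, indices = ret.values.tolist(), ret.indices.tolist()
    return [(values[i][:k], indices[i][:k]) for i, k in enumerate(top_logprobs_nums)]

out = top_logprobs_sketch(torch.randn(2, 10).log_softmax(dim=1), [3, 1])
```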
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
|
|
13
13
|
get_tensor_model_parallel_world_size,
|
14
14
|
tensor_model_parallel_all_reduce,
|
15
15
|
)
|
16
|
+
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
|
16
17
|
from sglang.srt.layers.parameter import BasevLLMParameter
|
17
18
|
from sglang.srt.layers.quantization.base_config import (
|
18
19
|
QuantizationConfig,
|
@@ -214,12 +215,14 @@ class VocabParallelEmbedding(torch.nn.Module):
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__()
@@ -227,9 +230,14 @@ class VocabParallelEmbedding(torch.nn.Module):

         self.enable_tp = enable_tp
         if self.enable_tp:
-
-
+            if use_attn_tp_group:
+                tp_rank = get_attention_tp_rank()
+                self.tp_size = get_attention_tp_size()
+            else:
+                tp_rank = get_tensor_model_parallel_rank()
+                self.tp_size = get_tensor_model_parallel_world_size()
         else:
+            assert use_attn_tp_group is False
             tp_rank = 0
             self.tp_size = 1

@@ -519,22 +527,25 @@ class ParallelLMHead(VocabParallelEmbedding):
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         bias: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__(
             num_embeddings,
             embedding_dim,
-            params_dtype,
-            org_num_embeddings,
-            padding_size,
-            quant_config,
-            prefix,
+            params_dtype=params_dtype,
+            org_num_embeddings=org_num_embeddings,
+            padding_size=padding_size,
+            quant_config=quant_config,
+            prefix=prefix,
+            use_attn_tp_group=use_attn_tp_group,
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
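Because everything after `embedding_dim` is now keyword-only in both constructors, positional calls such as the old `super().__init__(num_embeddings, embedding_dim, params_dtype, ...)` no longer work, and callers opt into sharding over the attention TP group with the new flag. A hedged usage sketch (the argument values are placeholders, and constructing the layer requires an initialized tensor-parallel process group):

```python
# Hedged usage sketch: values are illustrative; run inside an initialized
# tensor-parallel context. Only the keyword-only calling style is the point.
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

lm_head = ParallelLMHead(
    num_embeddings=129280,        # e.g. a large vocab size
    embedding_dim=7168,
    prefix="lm_head",
    use_attn_tp_group=True,       # shard over the attention TP group instead of full TP
)
```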
sglang/srt/lora/lora_manager.py
CHANGED
@@ -100,7 +100,7 @@ class LoRAManager:
             self.configs[name] = LoRAConfig(path)
             self.hf_target_names.update(self.configs[name].target_modules)

-        # Target lora weight names for lora_a and lora_b modules
+        # Target lora weight names for lora_a and lora_b modules respectively.
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
         self.lora_weight_names: Set[Tuple[str]] = set(
             [get_stacked_name(module) for module in self.hf_target_names]
sglang/srt/lora/mem_pool.py
CHANGED
@@ -50,15 +50,15 @@ class LoRAMemoryPool:
         self.uid_to_buffer_id: Dict[Optional[str], int] = {}

         # Buffer idx -> lora uid in memory pool
-        # All uids are
-        # Here we don't
+        # All uids are initialized as empty strings for empty buffer slots
+        # Here we don't initialize to None since None is a valid uid
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch

     def get_lora_A_shape(
         self, module_name: str, base_model: torch.nn.Module
     ) -> Tuple[int]:
         """
-        Given a module_name (might be a stacked name), return the hidden dims of modules'
+        Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
@@ -75,7 +75,7 @@ class LoRAMemoryPool:
         self, module_name: str, base_model: torch.nn.Module
     ) -> Tuple[int]:
         """
-        Given a module_name (might be a stacked name), return the hidden dims of modules'
+        Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
sglang/srt/lora/triton_ops/gate_up_lora_b.py
CHANGED
@@ -77,7 +77,7 @@ def _gate_up_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )

-    #
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
sglang/srt/lora/triton_ops/qkv_lora_b.py
CHANGED
@@ -79,7 +79,7 @@ def _qkv_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )

-    #
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
sglang/srt/lora/triton_ops/sgemm_lora_a.py
CHANGED
@@ -67,7 +67,7 @@ def _sgemm_lora_a_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )

-    #
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
sglang/srt/lora/triton_ops/sgemm_lora_b.py
CHANGED
@@ -69,7 +69,7 @@ def _sgemm_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )

-    #
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(
sglang/srt/lora/utils.py
CHANGED
@@ -79,7 +79,7 @@ def get_hidden_dim(
     module_name: str, config: AutoConfig, base_model: torch.nn.Module
 ) -> Tuple[int]:
     """
-    Given a module_name (might be a stacked name), return the hidden dims of modules'
+    Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
     """

     if hasattr(base_model, "get_hidden_dim"):
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -17,13 +17,13 @@ import logging
 import multiprocessing as mp
 import signal
 import threading
+import time
 from enum import Enum, auto

 import psutil
 import setproctitle
 import zmq

-from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
@@ -158,7 +158,7 @@ class DataParallelController:
         # This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
         # function in scheduler.py will kill the scheduler.
         while True:
-
+            time.sleep(30 * 24 * 3600)

     def launch_dp_attention_schedulers(self, server_args, port_args):
         self.launch_tensor_parallel_group(server_args, port_args, 0, None)
@@ -210,7 +210,7 @@ class DataParallelController:
             )
             # compute zmq ports for this dp rank
             rank_port_args = PortArgs.init_new(server_args, dp_rank)
-            # Data parallelism
+            # Data parallelism reuses the tensor parallelism group,
             # so all dp ranks should use the same nccl port.
             rank_port_args.nccl_port = port_args.nccl_port

sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -28,6 +28,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchMultimodalDecodeReq,
+    BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
 )
@@ -60,6 +61,8 @@ class DecodeStatus:
     decode_ids: List[int]
     surr_offset: int
     read_offset: int
+    # Offset that's sent to tokenizer for incremental update.
+    sent_offset: int = 0


 class DetokenizerManager:
@@ -151,7 +154,7 @@ class DetokenizerManager:
                 self.decode_status[rid] = s
             else:
                 s = self.decode_status[rid]
-                s.decode_ids
+                s.decode_ids.extend(recv_obj.decode_ids[i])

             read_ids.append(
                 self.trim_matched_stop(
@@ -199,13 +202,15 @@ class DetokenizerManager:
             else:
                 new_text = find_printable_text(new_text)

-
-
-
-
-                recv_obj.no_stop_trim[i],
-            )
+            output_str = self.trim_matched_stop(
+                s.decoded_text + new_text,
+                recv_obj.finished_reasons[i],
+                recv_obj.no_stop_trim[i],
             )
+            # Incrementally send text.
+            incremental_output = output_str[s.sent_offset :]
+            s.sent_offset = len(output_str)
+            output_strs.append(incremental_output)

         return BatchStrOut(
             rids=recv_obj.rids,
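The new `sent_offset` field turns the detokenizer output into an incremental stream: each iteration re-trims the full decoded string and forwards only the suffix that has not been sent yet. A minimal standalone sketch of that bookkeeping, independent of the sglang classes:

```python
class StreamState:
    """Minimal stand-in for DecodeStatus: tracks how much text was already sent."""
    def __init__(self):
        self.sent_offset = 0

def next_chunk(s: StreamState, full_text: str) -> str:
    """Return only the not-yet-sent suffix, mirroring the sent_offset logic."""
    chunk = full_text[s.sent_offset:]
    s.sent_offset = len(full_text)
    return chunk

s = StreamState()
assert next_chunk(s, "Hello") == "Hello"
assert next_chunk(s, "Hello, world") == ", world"
assert next_chunk(s, "Hello, world") == ""  # nothing new to send
```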
@@ -232,7 +237,15 @@ class DetokenizerManager:
         )

     def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
-
+        outputs = self.tokenizer.detokenize(recv_obj)
+        return BatchMultimodalOut(
+            rids=recv_obj.rids,
+            finished_reasons=recv_obj.finished_reasons,
+            outputs=outputs,
+            prompt_tokens=recv_obj.prompt_tokens,
+            completion_tokens=recv_obj.completion_tokens,
+            cached_tokens=recv_obj.cached_tokens,
+        )


 class LimitedCapacityDict(OrderedDict):
sglang/srt/managers/io_struct.py
CHANGED
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 """
-The definition of objects
+The definition of objects transferred between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """

@@ -836,6 +836,8 @@ class ProfileReqInput:
     # the caller doesn't need to run stop_profile.
     num_steps: Optional[int] = None
     activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None
+    with_stack: Optional[bool] = None
+    record_shapes: Optional[bool] = None


 class ProfileReqType(Enum):
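The two new fields share their names with `torch.profiler` options and presumably flow through to it; a request that leaves them unset keeps the previous behavior. A hedged example of constructing the request (field values are illustrative):

```python
from sglang.srt.managers.io_struct import ProfileReqInput

# Hedged example: profile 10 steps and also record Python stacks and tensor
# shapes; the field names mirror torch.profiler's with_stack / record_shapes.
req = ProfileReqInput(
    num_steps=10,
    activities=["CPU", "GPU"],
    with_stack=True,
    record_shapes=True,
)
```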
sglang/srt/managers/mm_utils.py
CHANGED
@@ -51,7 +51,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
-        This function will replace the data-tokens
+        This function will replace the data-tokens in between with pad_values accordingly
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
         data_token_pairs = self.data_token_id_pairs