sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py (+85 -28)

@@ -16,6 +16,7 @@
 from __future__ import annotations

 import bisect
+import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable

@@ -33,7 +34,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_available_gpu_memory, is_hip

 _is_hip = is_hip()

@@ -81,7 +82,9 @@ def patch_model(
             # tp_group.ca_comm = None
             yield torch.compile(
                 torch.no_grad()(model.forward),
-                mode=
+                mode=os.environ.get(
+                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
+                ),
                 dynamic=False,
             )
         else:
@@ -117,7 +120,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         else:
             capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
     else:
-
+        # Since speculative decoding requires more cuda graph memory, we
+        # capture less.
+        capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]

     if _is_hip:
         capture_bs += [i * 8 for i in range(21, 33)]
@@ -125,16 +130,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.
-        capture_bs
-
-
-            capture_bs
-            + [model_runner.req_to_token_pool.size - 1]
-            + [model_runner.req_to_token_pool.size]
-        )
-        )
-        )
+        capture_bs += [model_runner.req_to_token_pool.size - 1] + [
+            model_runner.req_to_token_pool.size
+        ]

+    capture_bs = list(sorted(set(capture_bs)))
     capture_bs = [
         bs
         for bs in capture_bs
@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size

@@ -219,7 +220,19 @@ class CudaGraphRunner:
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)

         # Speculative_inference
-        if
+        if (
+            model_runner.spec_algorithm.is_eagle3()
+            and not model_runner.is_draft_worker
+        ):
+            self.hidden_states = torch.zeros(
+                (
+                    self.max_num_token,
+                    3 * self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            self.model_runner.model.set_eagle3_layers_to_capture()
+        elif model_runner.spec_algorithm.is_eagle():
             self.hidden_states = torch.zeros(
                 (self.max_num_token, self.model_runner.model_config.hidden_size),
                 dtype=self.model_runner.dtype,
@@ -236,7 +249,7 @@ class CudaGraphRunner:
         if self.enable_dp_attention:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.dp_size,
+                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
@@ -276,13 +289,12 @@ class CudaGraphRunner:

     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention:
-
-
-            ), max(forward_batch.global_num_tokens_cpu)
+            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
+
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-
+                total_global_tokens in self.graphs
                 if self.disable_padding
-                else
+                else total_global_tokens <= self.max_bs
             )
         else:
             is_bs_supported = (
@@ -304,6 +316,9 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
+            avail_mem = get_available_gpu_memory(
+                self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
+            )
             # Reverse the order to enable better memory sharing across cuda graphs.
             capture_range = (
                 tqdm.tqdm(list(reversed(self.capture_bs)))
@@ -311,6 +326,16 @@ class CudaGraphRunner:
                 else reversed(self.capture_bs)
             )
             for bs in capture_range:
+                if get_tensor_model_parallel_rank() == 0:
+                    avail_mem = get_available_gpu_memory(
+                        self.model_runner.device,
+                        self.model_runner.gpu_id,
+                        empty_cache=False,
+                    )
+                    capture_range.set_description(
+                        f"Capturing batches ({avail_mem=:.2f} GB)"
+                    )
+
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -345,8 +370,18 @@ class CudaGraphRunner:
         mrope_positions = self.mrope_positions[:, :bs]

         if self.enable_dp_attention:
-
-
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < bs % self.dp_size)
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
             global_num_tokens = None
             gathered_buffer = None
@@ -371,7 +406,7 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-
+            global_num_tokens_gpu=global_num_tokens,
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
@@ -392,6 +427,9 @@ class CudaGraphRunner:

         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
             return logits_output.next_token_logits, logits_output.hidden_states

@@ -426,7 +464,7 @@ class CudaGraphRunner:
            self.capture_hidden_mode = hidden_mode_from_spec_info
            self.capture()

-    def
+    def replay_prepare(self, forward_batch: ForwardBatch):
         self.recapture_if_needed(forward_batch)

         raw_bs = forward_batch.batch_size
@@ -435,7 +473,7 @@ class CudaGraphRunner:
         # Pad
         if self.enable_dp_attention:
             index = bisect.bisect_left(
-                self.capture_bs,
+                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
@@ -459,6 +497,8 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if self.enable_dp_attention:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)

         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
@@ -475,14 +515,31 @@
             seq_lens_cpu=self.seq_lens_cpu,
         )

+        # Store fields
+        self.raw_bs = raw_bs
+        self.raw_num_token = raw_num_token
+        self.bs = bs
+
+    def replay(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
+        if not skip_attn_backend_init:
+            self.replay_prepare(forward_batch)
+        else:
+            # In speculative decoding, these two fields are still needed.
+            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
+            self.positions[: self.raw_num_token].copy_(forward_batch.positions)
+
         # Replay
-        self.graphs[bs].replay()
-        next_token_logits, hidden_states = self.output_buffers[bs]
+        self.graphs[self.bs].replay()
+        next_token_logits, hidden_states = self.output_buffers[self.bs]

         logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits[:raw_num_token],
+            next_token_logits=next_token_logits[: self.raw_num_token],
             hidden_states=(
-                hidden_states[:raw_num_token]
+                hidden_states[: self.raw_num_token]
+                if hidden_states is not None
+                else None
             ),
         )
         return logits_output
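The patch_model hunk above makes the torch.compile mode configurable through the SGLANG_TORCH_COMPILE_MODE environment variable, falling back to "max-autotune-no-cudagraphs". A minimal standalone sketch of the same pattern; compile_forward is an illustrative name, not an sglang API:

    import os
    import torch

    def compile_forward(forward_fn):
        # Read the compile mode from the environment, defaulting to the same
        # value the patch uses when the variable is unset.
        mode = os.environ.get("SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs")
        return torch.compile(torch.no_grad()(forward_fn), mode=mode, dynamic=False)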
sglang/srt/model_executor/forward_batch_info.py (+81 -15)

@@ -33,16 +33,17 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional, Union

+import numpy as np
 import torch
 import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import
+    from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
     from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -175,7 +176,7 @@ class ForwardBatch:
     extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

     # For multimodal
-
+    mm_inputs: Optional[List[MultimodalInputs]] = None

     # Encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -241,7 +242,7 @@ class ForwardBatch:
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             out_cache_loc=batch.out_cache_loc,
-
+            mm_inputs=batch.multimodal_inputs,
             encoder_cached=batch.encoder_cached,
             encoder_lens=batch.encoder_lens,
             encoder_lens_cpu=batch.encoder_lens_cpu,
@@ -263,15 +264,24 @@ class ForwardBatch:
             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )

+        # For DP attention
         if batch.global_num_tokens is not None:
             ret.global_num_tokens_cpu = batch.global_num_tokens
-
+            ret.global_num_tokens_gpu = torch.tensor(
+                batch.global_num_tokens, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_gpu = torch.tensor(
+                batch.global_num_tokens_for_logprob, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            sum_len = sum(batch.global_num_tokens)
             ret.gathered_buffer = torch.zeros(
-                (
+                (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
                 device=device,
             )
-
         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), device=device)
             return ret
@@ -322,6 +332,53 @@ class ForwardBatch:

         return ret

+    def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
+        """
+        Merge all image inputs in the batch into a single MultiModalInputs object.
+
+        Returns:
+            if none, current batch contains no image input
+
+        """
+        if not self.mm_inputs or all(x is None for x in self.mm_inputs):
+            return None
+
+        # Filter out None values
+        valid_inputs = [x for x in self.mm_inputs if x is not None]
+
+        # Start with the first valid image input
+        merged = valid_inputs[0]
+
+        # Merge remaining inputs
+        for mm_input in valid_inputs[1:]:
+            merged.merge(mm_input)
+
+        if isinstance(merged.pixel_values, np.ndarray):
+            merged.pixel_values = torch.from_numpy(merged.pixel_values)
+        if isinstance(merged.audio_features, np.ndarray):
+            merged.audio_features = torch.from_numpy(merged.audio_features)
+
+        return merged
+
+    def contains_image_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_image_inputs()
+            for mm_input in self.mm_inputs
+        )
+
+    def contains_audio_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_audio_inputs()
+            for mm_input in self.mm_inputs
+        )
+
+    def contains_mm_inputs(self) -> bool:
+        return self.contains_audio_inputs() or self.contains_image_inputs()
+
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
@@ -332,8 +389,8 @@ class ForwardBatch:
         for i, _ in enumerate(mrope_positions_list):
             mrope_position_delta = (
                 0
-                if batch.
-                else batch.
+                if batch.multimodal_inputs[i] is None
+                else batch.multimodal_inputs[i].mrope_position_delta
             )
             mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
                 mrope_position_delta,
@@ -342,13 +399,13 @@ class ForwardBatch:
             )
         elif self.forward_mode.is_extend():
             extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i,
+            for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
                 extend_start_loc, extend_seq_len, extend_prefix_len = (
                     extend_start_loc_cpu[i],
                     batch.extend_seq_lens[i],
                     batch.extend_prefix_lens[i],
                 )
-                if
+                if multimodal_inputs is None:
                     # text only
                     mrope_positions = [
                         [
@@ -365,16 +422,25 @@ class ForwardBatch:
                             input_tokens=self.input_ids[
                                 extend_start_loc : extend_start_loc + extend_seq_len
                             ],
-                            image_grid_thw=
+                            image_grid_thw=multimodal_inputs.image_grid_thws,
+                            video_grid_thw=multimodal_inputs.video_grid_thws,
+                            image_token_id=multimodal_inputs.im_token_id,
+                            video_token_id=multimodal_inputs.video_token_id,
                             vision_start_token_id=hf_config.vision_start_token_id,
+                            vision_end_token_id=hf_config.vision_end_token_id,
                             spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                             context_len=0,
+                            seq_len=len(self.input_ids),
+                            second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
+                            tokens_per_second=hf_config.vision_config.tokens_per_second,
                         )
                     )
-                batch.
+                batch.multimodal_inputs[i].mrope_position_delta = (
+                    mrope_position_delta
+                )
                 mrope_positions_list[i] = mrope_positions

-        self.mrope_positions = torch.
+        self.mrope_positions = torch.cat(
             [torch.tensor(pos, device=device) for pos in mrope_positions_list],
             axis=1,
         )
@@ -440,7 +506,7 @@ def compute_position_kernel(
 def compute_position_torch(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
 ):
-    positions = torch.
+    positions = torch.cat(
         [
             torch.arange(
                 prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
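The last hunk above rewrites compute_position_torch to build the flattened position ids with torch.cat over per-request torch.arange ranges. A small self-contained illustration of that layout; the example lengths are made up:

    import torch

    extend_prefix_lens = torch.tensor([0, 5])  # cached prefix tokens per request
    extend_seq_lens = torch.tensor([3, 2])     # newly extended tokens per request

    # Each request contributes positions prefix_len .. prefix_len + extend_len - 1,
    # concatenated into one flat tensor.
    positions = torch.cat(
        [
            torch.arange(int(p), int(p) + int(e))
            for p, e in zip(extend_prefix_lens, extend_seq_lens)
        ]
    )
    print(positions)  # tensor([0, 1, 2, 5, 6])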
sglang/srt/model_executor/model_runner.py (+70 -6)

@@ -145,10 +145,12 @@ class ModelRunner:
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
+                "enable_deepep_moe": server_args.enable_deepep_moe,
                 "device": server_args.device,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
@@ -187,9 +189,6 @@ class ModelRunner:
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
             self.apply_torch_tp()
-            self.torch_tp_applied = True
-        else:
-            self.torch_tp_applied = False

         # Init lora
         if server_args.lora_paths is not None:
@@ -209,6 +208,10 @@ class ModelRunner:
         self.cuda_graph_runner = None
         self.init_attention_backend()

+        # auxiliary hidden capture mode. TODO: expose this to server args?
+        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
+            self.model.set_eagle3_layers_to_capture()
+
     def model_specific_adjustment(self):
         server_args = self.server_args

@@ -223,6 +226,9 @@ class ModelRunner:
                     "MLA optimization is turned on. Use flashinfer mla backend."
                 )
                 server_args.attention_backend = "flashinfer_mla"
+            elif server_args.enable_flashmla:
+                logger.info("MLA optimization is turned on. Use flashmla decode.")
+                server_args.attention_backend = "flashmla"
             else:
                 logger.info("MLA optimization is turned on. Use triton backend.")
                 server_args.attention_backend = "triton"
@@ -254,18 +260,41 @@ class ModelRunner:

         if self.model_config.hf_config.architectures == [
             "Qwen2VLForConditionalGeneration"
+        ] or self.model_config.hf_config.architectures == [
+            "Qwen2_5_VLForConditionalGeneration"
         ]:
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+            # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
+            )
+            server_args.chunked_prefill_size = -1
+            server_args.disable_radix_cache = True
+
+        if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
+            # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
             logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for
+                "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
             )
             server_args.chunked_prefill_size = -1
             server_args.disable_radix_cache = True

+        if server_args.enable_deepep_moe:
+            logger.info("DeepEP is turned on.")
+            assert (
+                server_args.enable_dp_attention == True
+            ), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")

-
+        try:
+            torch.get_device_module(self.device).set_device(self.gpu_id)
+        except Exception:
+            logger.warning(
+                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
+            )
+            raise
+
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
@@ -606,6 +635,8 @@ class ModelRunner:
             load_config=self.load_config,
             dtype=self.dtype,
             lora_backend=self.server_args.lora_backend,
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
         )
         logger.info("LoRA manager ready.")

@@ -840,6 +871,23 @@ class ModelRunner:
             )

             self.attn_backend = FlashInferMLAAttnBackend(self)
+        elif self.server_args.attention_backend == "flashmla":
+            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
+
+            self.attn_backend = FlashMLABackend(self)
+        elif self.server_args.attention_backend == "fa3":
+            assert torch.cuda.get_device_capability()[0] >= 9, (
+                "FlashAttention v3 Backend requires SM>=90. "
+                "Please use `--attention-backend flashinfer`."
+            )
+            logger.warning(
+                "FlashAttention v3 Backend is in Beta. Multimodal, Page > 1, FP8, MLA and Speculative Decoding are not supported."
+            )
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionBackend,
+            )
+
+            self.attn_backend = FlashAttentionBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
@@ -1009,6 +1057,22 @@ class ModelRunner:
             return False
         return rope_scaling.get("type", None) == "mrope"

+    def save_remote_model(self, url: str):
+        from sglang.srt.model_loader.loader import RemoteModelLoader
+
+        logger.info(f"Saving model to {url}")
+        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)
+
+    def save_sharded_model(
+        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
+    ):
+        from sglang.srt.model_loader.loader import ShardedStateLoader
+
+        logger.info(
+            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
+        )
+        ShardedStateLoader.save_model(self.model, path, pattern, max_size)
+

 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
     params_dict = dict(model.named_parameters())