sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +133 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +32 -21
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +133 -30
- sglang/srt/managers/scheduler.py +273 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +27 -13
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +208 -77
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +124 -28
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +99 -9
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py +27 -13

@@ -16,6 +16,7 @@
 from __future__ import annotations
 
 import bisect
+import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
 
@@ -81,7 +82,9 @@ def patch_model(
             # tp_group.ca_comm = None
             yield torch.compile(
                 torch.no_grad()(model.forward),
-                mode=
+                mode=os.environ.get(
+                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
+                ),
                 dynamic=False,
             )
         else:
@@ -117,7 +120,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         else:
             capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
     else:
-
+        # Since speculative decoding requires more cuda graph memory, we
+        # capture less.
+        capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
 
     if _is_hip:
         capture_bs += [i * 8 for i in range(21, 33)]
@@ -125,16 +130,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.
-        capture_bs
-
-
-            capture_bs
-            + [model_runner.req_to_token_pool.size - 1]
-            + [model_runner.req_to_token_pool.size]
-            )
-        )
-    )
+        capture_bs += [model_runner.req_to_token_pool.size - 1] + [
+            model_runner.req_to_token_pool.size
+        ]
 
+    capture_bs = list(sorted(set(capture_bs)))
     capture_bs = [
         bs
         for bs in capture_bs
@@ -220,7 +220,19 @@ class CudaGraphRunner:
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
 
         # Speculative_inference
-        if
+        if (
+            model_runner.spec_algorithm.is_eagle3()
+            and not model_runner.is_draft_worker
+        ):
+            self.hidden_states = torch.zeros(
+                (
+                    self.max_num_token,
+                    3 * self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            self.model_runner.model.set_eagle3_layers_to_capture()
+        elif model_runner.spec_algorithm.is_eagle():
             self.hidden_states = torch.zeros(
                 (self.max_num_token, self.model_runner.model_config.hidden_size),
                 dtype=self.model_runner.dtype,
@@ -508,7 +520,9 @@ class CudaGraphRunner:
         self.raw_num_token = raw_num_token
         self.bs = bs
 
-    def replay(
+    def replay(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
        if not skip_attn_backend_init:
            self.replay_prepare(forward_batch)
        else:
sglang/srt/model_executor/forward_batch_info.py +68 -11

@@ -33,6 +33,7 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional, Union
 
+import numpy as np
 import torch
 import triton
 import triton.language as tl
@@ -42,7 +43,7 @@ from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import
+    from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
     from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -175,7 +176,7 @@ class ForwardBatch:
     extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
 
     # For multimodal
-
+    mm_inputs: Optional[List[MultimodalInputs]] = None
 
     # Encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -241,7 +242,7 @@ class ForwardBatch:
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             out_cache_loc=batch.out_cache_loc,
-
+            mm_inputs=batch.multimodal_inputs,
             encoder_cached=batch.encoder_cached,
             encoder_lens=batch.encoder_lens,
             encoder_lens_cpu=batch.encoder_lens_cpu,
@@ -331,6 +332,53 @@ class ForwardBatch:
 
         return ret
 
+    def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
+        """
+        Merge all image inputs in the batch into a single MultiModalInputs object.
+
+        Returns:
+            if none, current batch contains no image input
+
+        """
+        if not self.mm_inputs or all(x is None for x in self.mm_inputs):
+            return None
+
+        # Filter out None values
+        valid_inputs = [x for x in self.mm_inputs if x is not None]
+
+        # Start with the first valid image input
+        merged = valid_inputs[0]
+
+        # Merge remaining inputs
+        for mm_input in valid_inputs[1:]:
+            merged.merge(mm_input)
+
+        if isinstance(merged.pixel_values, np.ndarray):
+            merged.pixel_values = torch.from_numpy(merged.pixel_values)
+        if isinstance(merged.audio_features, np.ndarray):
+            merged.audio_features = torch.from_numpy(merged.audio_features)
+
+        return merged
+
+    def contains_image_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_image_inputs()
+            for mm_input in self.mm_inputs
+        )
+
+    def contains_audio_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_audio_inputs()
+            for mm_input in self.mm_inputs
+        )
+
+    def contains_mm_inputs(self) -> bool:
+        return self.contains_audio_inputs() or self.contains_image_inputs()
+
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
@@ -341,8 +389,8 @@ class ForwardBatch:
             for i, _ in enumerate(mrope_positions_list):
                 mrope_position_delta = (
                     0
-                    if batch.
-                    else batch.
+                    if batch.multimodal_inputs[i] is None
+                    else batch.multimodal_inputs[i].mrope_position_delta
                 )
                 mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
                     mrope_position_delta,
@@ -351,13 +399,13 @@ class ForwardBatch:
                 )
         elif self.forward_mode.is_extend():
             extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i,
+            for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
                 extend_start_loc, extend_seq_len, extend_prefix_len = (
                     extend_start_loc_cpu[i],
                     batch.extend_seq_lens[i],
                     batch.extend_prefix_lens[i],
                 )
-                if
+                if multimodal_inputs is None:
                     # text only
                     mrope_positions = [
                         [
@@ -374,16 +422,25 @@ class ForwardBatch:
                         input_tokens=self.input_ids[
                             extend_start_loc : extend_start_loc + extend_seq_len
                         ],
-                        image_grid_thw=
+                        image_grid_thw=multimodal_inputs.image_grid_thws,
+                        video_grid_thw=multimodal_inputs.video_grid_thws,
+                        image_token_id=multimodal_inputs.im_token_id,
+                        video_token_id=multimodal_inputs.video_token_id,
                         vision_start_token_id=hf_config.vision_start_token_id,
+                        vision_end_token_id=hf_config.vision_end_token_id,
                         spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                         context_len=0,
+                        seq_len=len(self.input_ids),
+                        second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
+                        tokens_per_second=hf_config.vision_config.tokens_per_second,
                     )
                 )
-                batch.
+                batch.multimodal_inputs[i].mrope_position_delta = (
+                    mrope_position_delta
+                )
                 mrope_positions_list[i] = mrope_positions
 
-        self.mrope_positions = torch.
+        self.mrope_positions = torch.cat(
             [torch.tensor(pos, device=device) for pos in mrope_positions_list],
             axis=1,
         )
@@ -449,7 +506,7 @@ def compute_position_kernel(
 def compute_position_torch(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
 ):
-    positions = torch.
+    positions = torch.cat(
         [
             torch.arange(
                 prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
sglang/srt/model_executor/model_runner.py +70 -6

@@ -145,10 +145,12 @@ class ModelRunner:
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
+                "enable_deepep_moe": server_args.enable_deepep_moe,
                 "device": server_args.device,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
@@ -187,9 +189,6 @@ class ModelRunner:
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
             self.apply_torch_tp()
-            self.torch_tp_applied = True
-        else:
-            self.torch_tp_applied = False
 
         # Init lora
         if server_args.lora_paths is not None:
@@ -209,6 +208,10 @@ class ModelRunner:
         self.cuda_graph_runner = None
         self.init_attention_backend()
 
+        # auxiliary hidden capture mode. TODO: expose this to server args?
+        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
+            self.model.set_eagle3_layers_to_capture()
+
     def model_specific_adjustment(self):
         server_args = self.server_args
 
@@ -223,6 +226,9 @@ class ModelRunner:
                     "MLA optimization is turned on. Use flashinfer mla backend."
                 )
                 server_args.attention_backend = "flashinfer_mla"
+            elif server_args.enable_flashmla:
+                logger.info("MLA optimization is turned on. Use flashmla decode.")
+                server_args.attention_backend = "flashmla"
             else:
                 logger.info("MLA optimization is turned on. Use triton backend.")
                 server_args.attention_backend = "triton"
@@ -254,18 +260,41 @@ class ModelRunner:
 
         if self.model_config.hf_config.architectures == [
             "Qwen2VLForConditionalGeneration"
+        ] or self.model_config.hf_config.architectures == [
+            "Qwen2_5_VLForConditionalGeneration"
         ]:
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+            # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
+            )
+            server_args.chunked_prefill_size = -1
+            server_args.disable_radix_cache = True
+
+        if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
+            # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
             logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for
+                "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
             )
             server_args.chunked_prefill_size = -1
             server_args.disable_radix_cache = True
 
+        if server_args.enable_deepep_moe:
+            logger.info("DeepEP is turned on.")
+            assert (
+                server_args.enable_dp_attention == True
+            ), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
 
-
+        try:
+            torch.get_device_module(self.device).set_device(self.gpu_id)
+        except Exception:
+            logger.warning(
+                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
+            )
+            raise
+
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
@@ -606,6 +635,8 @@ class ModelRunner:
                 load_config=self.load_config,
                 dtype=self.dtype,
                 lora_backend=self.server_args.lora_backend,
+                tp_size=self.tp_size,
+                tp_rank=self.tp_rank,
             )
             logger.info("LoRA manager ready.")
 
@@ -840,6 +871,23 @@ class ModelRunner:
                 )
 
             self.attn_backend = FlashInferMLAAttnBackend(self)
+        elif self.server_args.attention_backend == "flashmla":
+            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
+
+            self.attn_backend = FlashMLABackend(self)
+        elif self.server_args.attention_backend == "fa3":
+            assert torch.cuda.get_device_capability()[0] >= 9, (
+                "FlashAttention v3 Backend requires SM>=90. "
+                "Please use `--attention-backend flashinfer`."
+            )
+            logger.warning(
+                "FlashAttention v3 Backend is in Beta. Multimodal, Page > 1, FP8, MLA and Speculative Decoding are not supported."
+            )
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionBackend,
+            )
+
+            self.attn_backend = FlashAttentionBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
@@ -1009,6 +1057,22 @@ class ModelRunner:
             return False
         return rope_scaling.get("type", None) == "mrope"
 
+    def save_remote_model(self, url: str):
+        from sglang.srt.model_loader.loader import RemoteModelLoader
+
+        logger.info(f"Saving model to {url}")
+        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)
+
+    def save_sharded_model(
+        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
+    ):
+        from sglang.srt.model_loader.loader import ShardedStateLoader
+
+        logger.info(
+            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
+        )
+        ShardedStateLoader.save_model(self.model, path, pattern, max_size)
+
 
 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
     params_dict = dict(model.named_parameters())
sglang/srt/model_loader/loader.py +160 -2

@@ -9,6 +9,7 @@ import json
 import logging
 import math
 import os
+import time
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
@@ -25,6 +26,12 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.connector import (
+    ConnectorType,
+    create_remote_connector,
+    get_connector_type,
+)
+from sglang.srt.connector.utils import parse_model_name
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -46,6 +53,7 @@ from sglang.srt.model_loader.weight_utils import (
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
+    set_runai_streamer_env,
 )
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -194,7 +202,7 @@ class DefaultModelLoader(BaseModelLoader):
     def _maybe_download_from_modelscope(
         self, model: str, revision: Optional[str]
     ) -> Optional[str]:
-        """Download model from ModelScope hub if
+        """Download model from ModelScope hub if SGLANG_USE_MODELSCOPE is True.
 
         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
@@ -490,7 +498,7 @@ class ShardedStateLoader(BaseModelLoader):
     Model loader that directly loads each worker's model state dict, which
     enables a fast load path for large tensor-parallel models where each worker
     only needs to read its own shard rather than the entire checkpoint. See
-    `examples/save_sharded_state.py` for creating a sharded checkpoint.
+    `examples/runtime/engine/save_sharded_state.py` for creating a sharded checkpoint.
     """
 
     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
@@ -1204,6 +1212,153 @@ class GGUFModelLoader(BaseModelLoader):
         return model
 
 
+class RemoteModelLoader(BaseModelLoader):
+    """Model loader that can load Tensors from remote database."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # TODO @DellCurry: move to s3 connector only
+        set_runai_streamer_env(load_config)
+
+    def _get_weights_iterator_kv(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.KV
+        rank = get_tensor_model_parallel_rank()
+        return client.weight_iterator(rank)
+
+    def _get_weights_iterator_fs(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.FS
+        return client.weight_iterator()
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass
+
+    @staticmethod
+    def save_model(
+        model: torch.nn.Module,
+        model_path: str,
+        url: str,
+    ) -> None:
+        with create_remote_connector(url) as client:
+            assert get_connector_type(client) == ConnectorType.KV
+            model_name = parse_model_name(url)
+            rank = get_tensor_model_parallel_rank()
+            state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+            for key, tensor in state_dict.items():
+                r_key = f"{model_name}/keys/rank_{rank}/{key}"
+                client.set(r_key, tensor)
+
+            for root, _, files in os.walk(model_path):
+                for file_name in files:
+                    # ignore hidden files
+                    if file_name.startswith("."):
+                        continue
+                    if os.path.splitext(file_name)[1] not in (
+                        ".bin",
+                        ".pt",
+                        ".safetensors",
+                    ):
+                        file_path = os.path.join(root, file_name)
+                        with open(file_path, encoding="utf-8") as file:
+                            file_content = file.read()
+                            f_key = f"{model_name}/files/{file_name}"
+                            client.setstr(f_key, file_content)
+
+    def _load_model_from_remote_kv(self, model: nn.Module, client):
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                quant_method.process_weights_after_loading(module)
+        weights_iterator = self._get_weights_iterator_kv(client)
+        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+        for key, tensor in weights_iterator:
+            # If loading with LoRA enabled, additional padding may
+            # be added to certain parameters. We only load into a
+            # narrowed view of the parameter data.
+            param_data = state_dict[key].data
+            param_shape = state_dict[key].shape
+            for dim, size in enumerate(tensor.shape):
+                if size < param_shape[dim]:
+                    param_data = param_data.narrow(dim, 0, size)
+            if tensor.shape != param_shape:
+                logger.warning(
+                    "loading tensor of shape %s into " "parameter '%s' of shape %s",
+                    tensor.shape,
+                    key,
+                    param_shape,
+                )
+            param_data.copy_(tensor)
+            state_dict.pop(key)
+        if state_dict:
+            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+    def _load_model_from_remote_fs(
+        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+
+        target_device = torch.device(device_config.device)
+        with set_default_torch_dtype(model_config.dtype):
+            model.load_weights(self._get_weights_iterator_fs(client))
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    # When quant methods need to process weights after loading
+                    # (for repacking, quantizing, etc), they expect parameters
+                    # to be on the global target device. This scope is for the
+                    # case where cpu offloading is used, where we will move the
+                    # parameters onto device for processing and back off after.
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        logger.info("Loading weights from remote storage ...")
+        start = time.perf_counter()
+        load_config = self.load_config
+
+        assert load_config.load_format == LoadFormat.REMOTE, (
+            f"Model loader {self.load_config.load_format} is not supported for "
+            f"load format {load_config.load_format}"
+        )
+
+        model_weights = model_config.model_path
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config)
+                for _, module in model.named_modules():
+                    quant_method = getattr(module, "quant_method", None)
+                    if quant_method is not None:
+                        quant_method.process_weights_after_loading(module)
+
+        with create_remote_connector(model_weights, device_config.device) as client:
+            connector_type = get_connector_type(client)
+            if connector_type == ConnectorType.KV:
+                self._load_model_from_remote_kv(model, client)
+            elif connector_type == ConnectorType.FS:
+                self._load_model_from_remote_fs(
+                    model, client, model_config, device_config
+                )
+
+        end = time.perf_counter()
+        logger.info("Loaded weights from remote storage in %.2f seconds.", end - start)
+        return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
 
@@ -1225,4 +1380,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.LAYERED:
         return LayeredModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.REMOTE:
+        return RemoteModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
sglang/srt/model_loader/weight_utils.py +45 -0

@@ -585,6 +585,51 @@ def composed_weight_loader(
     return composed_loader
 
 
+def runai_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files."""
+    from runai_model_streamer import SafetensorsStreamer
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    with SafetensorsStreamer() as streamer:
+        for st_file in tqdm(
+            hf_weights_files,
+            desc="Loading safetensors using Runai Model Streamer",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+        ):
+            streamer.stream_file(st_file)
+            yield from streamer.get_tensors()
+
+
+def set_runai_streamer_env(load_config: LoadConfig):
+    if load_config.model_loader_extra_config:
+        extra_config = load_config.model_loader_extra_config
+
+        if "concurrency" in extra_config and isinstance(
+            extra_config.get("concurrency"), int
+        ):
+            os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
+                extra_config.get("concurrency")
+            )
+
+        if "memory_limit" in extra_config and isinstance(
+            extra_config.get("memory_limit"), int
+        ):
+            os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
+                extra_config.get("memory_limit")
+            )
+
+    runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
+    aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
+    if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
+        os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
+
+
 def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,