sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +3 -6
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +6 -2
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +105 -6
- sglang/srt/disaggregation/mini_lb.py +74 -9
- sglang/srt/disaggregation/mooncake/conn.py +33 -63
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +137 -17
- sglang/srt/disaggregation/utils.py +32 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +883 -209
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +20 -5
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +17 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -88
- sglang/srt/layers/quantization/fp8_utils.py +189 -264
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +9 -8
- sglang/srt/layers/sampler.py +7 -12
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +15 -4
- sglang/srt/managers/scheduler.py +28 -77
- sglang/srt/managers/tokenizer_manager.py +116 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +41 -29
- sglang/srt/mem_cache/memory_pool.py +38 -15
- sglang/srt/model_executor/cuda_graph_runner.py +15 -10
- sglang/srt/model_executor/model_runner.py +39 -31
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +292 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +31 -203
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +86 -72
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +6 -14
- sglang/srt/utils.py +62 -6
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w8a8_int8.py
CHANGED
@@ -1,13 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
-
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
-
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import int8_scaled_mm
-
 from torch.nn.parameter import Parameter
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -18,6 +11,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.utils import is_cuda, set_weight_attrs
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import int8_scaled_mm
 
 
 class W8A8Int8Config(QuantizationConfig):
@@ -233,6 +231,7 @@ class W8A8Int8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -248,6 +247,7 @@ class W8A8Int8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
        )
 
         return fused_experts(
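Note: the hunks above are part of a release-wide rename from `is_cuda_available()` to `is_cuda()`, with the optional `sgl_kernel` extension imported only when CUDA is usable, plus a new `routed_scaling_factor` argument threaded through to `select_experts`. A minimal, standalone sketch of the guarded-import idea (the try/except is an extra safety net added here for illustration; it is not in the sglang code):

```python
import torch


def is_cuda() -> bool:
    # Stand-in for sglang.srt.utils.is_cuda: decided once at import time.
    return torch.cuda.is_available()


_is_cuda = is_cuda()

HAS_FUSED_INT8_MM = False
if _is_cuda:
    try:
        # Optional fused kernel, only present in CUDA builds that ship sgl_kernel.
        from sgl_kernel import int8_scaled_mm  # noqa: F401

        HAS_FUSED_INT8_MM = True
    except ImportError:
        pass

print("fused int8 GEMM available:", HAS_FUSED_INT8_MM)
```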
sglang/srt/layers/radix_attention.py
CHANGED
@@ -13,6 +13,7 @@
 # ==============================================================================
 """Radix attention."""
 
+from enum import Enum
 from typing import Optional
 
 from torch import nn
@@ -22,6 +23,18 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
+class AttentionType(Enum):
+    """
+    Attention type.
+    Use string to be compatible with `torch.compile`.
+    """
+
+    # Decoder attention between previous layer Q/K/V
+    DECODER = "decoder"
+    # Encoder attention between previous layer Q/K/V
+    ENCODER_ONLY = "encoder_only"
+
+
 class RadixAttention(nn.Module):
     """
     The attention layer implementation.
@@ -39,6 +52,7 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
+        attn_type=AttentionType.DECODER,
         prefix: str = "",
         use_irope: bool = False,
     ):
@@ -64,6 +78,7 @@ class RadixAttention(nn.Module):
         self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
         if self.quant_method is not None:
             self.quant_method.create_weights(self)
+        self.attn_type = attn_type
 
     def forward(
         self,
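The new `AttentionType` enum lets a model mark a layer as encoder-only so attention backends can skip the causal mask and decode-time KV bookkeeping. A standalone sketch of how the flag can be consumed (simplified class, not the real `RadixAttention`); an encoder-only model such as the BERT implementation added in this release would pass `ENCODER_ONLY`:

```python
from enum import Enum


class AttentionType(Enum):
    DECODER = "decoder"            # causal attention, uses the KV cache
    ENCODER_ONLY = "encoder_only"  # bidirectional attention, no causal mask


class TinyAttention:
    def __init__(self, attn_type: AttentionType = AttentionType.DECODER):
        self.attn_type = attn_type

    @property
    def is_causal(self) -> bool:
        return self.attn_type == AttentionType.DECODER


layer = TinyAttention(attn_type=AttentionType.ENCODER_ONLY)
assert not layer.is_causal
```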
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -8,13 +8,14 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda
 
-_is_cuda_available = is_cuda_available()
-if _is_cuda_available:
+_is_cuda = is_cuda()
+
+if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
-    from vllm import _custom_ops as ops
+    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -81,7 +82,7 @@ class RotaryEmbedding(CustomOp):
 
         cache = self._compute_cos_sin_cache()
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
-        if not _is_cuda_available:
+        if not _is_cuda:
             cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
@@ -148,7 +149,7 @@ class RotaryEmbedding(CustomOp):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
+        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
@@ -159,7 +160,7 @@ class RotaryEmbedding(CustomOp):
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            ops.rotary_embedding(
+            vllm_rotary_embedding(
                 positions,
                 query,
                 key,
@@ -651,7 +652,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
     def forward(self, *args, **kwargs):
        if torch.compiler.is_compiling():
             return self.forward_native(*args, **kwargs)
-        if _is_cuda_available:
+        if _is_cuda:
             return self.forward_cuda(*args, **kwargs)
         else:
             return self.forward_native(*args, **kwargs)
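The rotary-embedding hunks keep the same dispatch shape as before: the fused `apply_rope_with_cos_sin_cache_inplace` kernel is used only on CUDA and only for the head sizes it supports; everything else takes the vLLM fallback. A standalone sketch of that gate (the returned strings are just labels, not real calls):

```python
_SUPPORTED_HEAD_SIZES = (64, 128, 256, 512)


def pick_rope_impl(head_size: int, is_cuda: bool) -> str:
    # Mirrors `if _is_cuda and (self.head_size in [64, 128, 256, 512])`.
    if is_cuda and head_size in _SUPPORTED_HEAD_SIZES:
        return "sgl_kernel apply_rope_with_cos_sin_cache_inplace"
    return "vllm rotary_embedding fallback"


assert pick_rope_impl(128, is_cuda=True).startswith("sgl_kernel")
assert pick_rope_impl(96, is_cuda=True).endswith("fallback")
assert pick_rope_impl(128, is_cuda=False).endswith("fallback")
```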
sglang/srt/layers/sampler.py
CHANGED
@@ -10,9 +10,9 @@ from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
+from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda
 
-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,
@@ -93,28 +93,23 @@ class Sampler(nn.Module):
             ).clamp(min=torch.finfo(probs.dtype).min)
 
             max_top_k_round, batch_size = 32, probs.shape[0]
-            uniform_samples = torch.rand(
-                (max_top_k_round, batch_size), device=probs.device
-            )
             if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids = min_p_sampling_from_probs(
-                    probs, uniform_samples, sampling_info.min_ps
+                    probs, sampling_info.min_ps
                 )
             else:
-                batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                # Check Nan will throw exception, only check when crash_on_warnings is True
+                check_nan = self.use_nan_detection and crash_on_warnings()
+                batch_next_token_ids = top_k_top_p_sampling_from_probs(
                     probs,
-                    uniform_samples,
                     sampling_info.top_ks,
                     sampling_info.top_ps,
                     filter_apply_order="joint",
+                    check_nan=check_nan,
                 )
 
-            if self.use_nan_detection and not torch.all(success):
-                logger.warning("Detected errors during sampling!")
-                batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-
         elif global_server_args_dict["sampling_backend"] == "pytorch":
             # A slower fallback implementation with torch native operations.
             batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
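The sampler change drops the pre-generated `uniform_samples` tensor and the old `success`-flag check: the sgl_kernel samplers now draw their own randomness and can raise on NaN input when `check_nan` is set (only enabled together with crash-on-warnings). A pure-PyTorch sketch of top-k/top-p sampling with an optional NaN check, for readers who want the semantics without sgl_kernel (this is not the kernel implementation):

```python
import torch


def sample_top_k_top_p(
    probs: torch.Tensor, top_k: int, top_p: float, check_nan: bool = False
) -> torch.Tensor:
    if check_nan and torch.isnan(probs).any():
        raise ValueError("NaN detected in sampling probabilities")
    # Keep the top-k candidates, then apply a nucleus (top-p) cutoff.
    values, indices = torch.topk(probs, k=min(top_k, probs.shape[-1]), dim=-1)
    exclusive_cumsum = torch.cumsum(values, dim=-1) - values
    values = torch.where(exclusive_cumsum > top_p, torch.zeros_like(values), values)
    values = values / values.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(values, num_samples=1)
    return indices.gather(-1, choice).squeeze(-1)


probs = torch.softmax(torch.randn(2, 8), dim=-1)
tokens = sample_top_k_top_p(probs, top_k=4, top_p=0.9, check_nan=True)
print(tokens.shape)  # torch.Size([2])
```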
sglang/srt/lora/backend/base_backend.py
CHANGED
@@ -75,7 +75,7 @@ class BaseLoRABackend:
         qkv_lora_a: torch.Tensor,
         qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for QKV Layer.
 
@@ -98,7 +98,7 @@ class BaseLoRABackend:
         gate_up_lora_a: torch.Tensor,
         gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
 
@@ -115,3 +115,19 @@ class BaseLoRABackend:
 
     def set_batch_info(self, batch_info: LoRABatchInfo):
         self.batch_info = batch_info
+
+
+def get_backend_from_name(name: str) -> BaseLoRABackend:
+    """
+    Get corresponding backend class from backend's name
+    """
+    if name == "triton":
+        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
+
+        return TritonLoRABackend
+    elif name == "flashinfer":
+        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
+
+        return FlashInferLoRABackend
+    else:
+        raise ValueError(f"Invalid backend: {name}")
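`get_backend_from_name` now lives beside `BaseLoRABackend`; the old `sglang/srt/lora/backend/__init__.py` re-exports were removed (the +0 -25 entry in the file list), which is why the imports in the following files switch to the full `lora.backend.base_backend` path. A standalone sketch of the lookup it performs, with placeholder classes instead of the real backends:

```python
class TritonLoRABackend:  # placeholder for sglang's Triton LoRA backend
    pass


class FlashInferLoRABackend:  # placeholder for sglang's FlashInfer LoRA backend
    pass


def get_backend_from_name(name: str):
    """Map a backend name to its class, mirroring the helper added above."""
    if name == "triton":
        return TritonLoRABackend
    elif name == "flashinfer":
        return FlashInferLoRABackend
    raise ValueError(f"Invalid backend: {name}")


assert get_backend_from_name("triton") is TritonLoRABackend
```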
@@ -2,7 +2,7 @@ from typing import Tuple
 
 import torch
 
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import is_flashinfer_available
 
sglang/srt/lora/layers.py
CHANGED
@@ -16,7 +16,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 
 
 class BaseLayerWithLoRA(nn.Module):
sglang/srt/lora/lora.py
CHANGED
@@ -27,7 +27,7 @@ from torch import nn
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
 
sglang/srt/lora/lora_manager.py
CHANGED
@@ -22,7 +22,7 @@ import torch
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
sglang/srt/managers/data_parallel_controller.py
CHANGED
@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
 from sglang.utils import get_exception_traceback
 
@@ -174,6 +175,10 @@ class DataParallelController:
         if not server_args.enable_dp_attention:
             logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
 
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
+
         # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
         tp_size_per_node = server_args.tp_size // server_args.nnodes
@@ -208,7 +213,8 @@ class DataParallelController:
                 target=run_scheduler_process,
                 args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
             )
-            proc.start()
+            with memory_saver_adapter.configure_subprocess():
+                proc.start()
             self.scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)
 
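Starting the scheduler subprocess inside `memory_saver_adapter.configure_subprocess()` lets the adapter set up torch-memory-saver hooks for the child only when `--enable-memory-saver` is active. A standalone sketch of a context manager that degrades to a no-op when the feature is off (this is not the real `TorchMemorySaverAdapter`):

```python
from contextlib import contextmanager, nullcontext


class MemorySaverAdapterSketch:
    def __init__(self, enable: bool):
        self.enable = enable

    def configure_subprocess(self):
        # Return a do-nothing context when the feature is disabled.
        return self._configured() if self.enable else nullcontext()

    @contextmanager
    def _configured(self):
        print("configuring memory-saver hooks for the child process")
        yield


adapter = MemorySaverAdapterSketch(enable=False)
with adapter.configure_subprocess():
    pass  # proc.start() would run here
```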
sglang/srt/managers/io_struct.py
CHANGED
@@ -96,8 +96,8 @@ class GenerateReqInput:
     return_hidden_states: bool = False
 
     # For disaggregated inference
-    bootstrap_host: Optional[str] = None
-    bootstrap_room: Optional[int] = None
+    bootstrap_host: Optional[Union[List[str], str]] = None
+    bootstrap_room: Optional[Union[List[int], int]] = None
 
     def normalize_batch_and_arguments(self):
         """
@@ -397,6 +397,12 @@ class GenerateReqInput:
                 else None
             ),
             return_hidden_states=self.return_hidden_states,
+            bootstrap_host=(
+                self.bootstrap_host[i] if self.bootstrap_host is not None else None
+            ),
+            bootstrap_room=(
+                self.bootstrap_room[i] if self.bootstrap_room is not None else None
+            ),
         )
 
 
@@ -665,10 +671,15 @@ class BatchEmbeddingOut:
 
 
 @dataclass
-class FlushCacheReq:
+class FlushCacheReqInput:
     pass
 
 
+@dataclass
+class FlushCacheReqOutput:
+    success: bool
+
+
 @dataclass
 class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
@@ -834,6 +845,7 @@ class ProfileReq:
     activities: Optional[List[str]] = None
     with_stack: Optional[bool] = None
     record_shapes: Optional[bool] = None
+    profile_id: Optional[str] = None
 
 
 @dataclass
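With `bootstrap_host` and `bootstrap_room` now allowed to be per-request lists, splitting a batched `GenerateReqInput` indexes into those lists, as the second hunk above shows. A simplified standalone sketch of that slicing (`TinyReq` and `split_batch` are illustrative names, not the real classes):

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class TinyReq:
    text: str
    bootstrap_host: Optional[str] = None
    bootstrap_room: Optional[int] = None


def split_batch(
    texts: List[str],
    bootstrap_host: Optional[Union[List[str], str]] = None,
    bootstrap_room: Optional[Union[List[int], int]] = None,
) -> List[TinyReq]:
    # Each request in the batch gets its own bootstrap endpoint and room.
    return [
        TinyReq(
            text=t,
            bootstrap_host=bootstrap_host[i] if bootstrap_host is not None else None,
            bootstrap_room=bootstrap_room[i] if bootstrap_room is not None else None,
        )
        for i, t in enumerate(texts)
    ]


reqs = split_batch(["a", "b"], bootstrap_host=["h0", "h1"], bootstrap_room=[0, 1])
assert reqs[1].bootstrap_host == "h1" and reqs[1].bootstrap_room == 1
```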
sglang/srt/managers/mm_utils.py
CHANGED
@@ -1,7 +1,8 @@
 """
-
+Multi-modality utils
 """
 
+import logging
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple
 
@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
-    logger,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import print_warning_once
-
+
+logger = logging.getLogger(__name__)
 
 
 class MultiModalityDataPaddingPattern:
sglang/srt/managers/multimodal_processors/base_processor.py
CHANGED
@@ -8,8 +8,6 @@ from typing import List, Optional
 
 import numpy as np
 import PIL
-from decord import VideoReader, cpu
-from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality
@@ -102,6 +100,9 @@ class BaseMultimodalProcessor(ABC):
         """
         estimate the total frame count from all visual input
         """
+        # Lazy import because decord is not available on some arm platforms.
+        from decord import VideoReader, cpu
+
         # Before processing inputs
         estimated_frames_list = []
         for image in image_data:
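Moving the `decord` import into the method makes it a lazy dependency: importing `base_processor` no longer fails on platforms without a decord wheel (for example some ARM builds), and the import cost is paid only when video frames are actually counted. A minimal sketch of the same pattern:

```python
def count_video_frames(video_path: str) -> int:
    # Lazy import: decord is only required when this function is called.
    from decord import VideoReader, cpu

    reader = VideoReader(video_path, ctx=cpu(0))
    return len(reader)
```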
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -67,7 +67,6 @@ global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
     "sampling_backend": ServerArgs.sampling_backend,
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
@@ -77,12 +76,11 @@ global_server_args_dict = {
     "device": ServerArgs.device,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
-    "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
     "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
 }
 
@@ -541,6 +539,11 @@ class Req:
         # The first output_id transferred from prefill instance.
         self.transferred_output_id: Optional[int] = None
 
+        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
+        # This is because kv is not ready in `process_prefill_chunk`.
+        # We use `tmp_end_idx` to store the end index of the kv cache to send.
+        self.tmp_end_idx: int = -1
+
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
@@ -573,6 +576,14 @@ class Req:
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
+        elif enable_hierarchical_cache:
+            # in case last_node is evicted during scheduling, we need to update the prefix_indices
+            while self.last_node.evicted:
+                self.prefix_indices = self.prefix_indices[
+                    : -len(self.last_node.host_value)
+                ]
+                self.last_node = self.last_node.parent
+
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
@@ -1481,7 +1492,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 global_server_args_dict["use_mla_backend"]
                 and global_server_args_dict["attention_backend"] == "flashinfer"
             )
-            or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "flashmla"
             or global_server_args_dict["attention_backend"] == "fa3"
         ):
             seq_lens_cpu = self.seq_lens.cpu()
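The hierarchical-cache branch in the `Req` hunk above walks back up the radix tree when the matched node was evicted between prefix matching and scheduling, trimming the cached prefix by each evicted node's span. A toy standalone version of that loop (simplified node class, not the real hierarchical radix cache):

```python
class Node:
    def __init__(self, parent=None, host_value=(), evicted=False):
        self.parent = parent
        self.host_value = host_value
        self.evicted = evicted


root = Node()
child = Node(parent=root, host_value=(7, 8, 9), evicted=True)

prefix_indices = [1, 2, 3, 7, 8, 9]
last_node = child
while last_node.evicted:
    # Drop the evicted node's tokens from the usable prefix and retry its parent.
    prefix_indices = prefix_indices[: -len(last_node.host_value)]
    last_node = last_node.parent

assert prefix_indices == [1, 2, 3] and last_node is root
```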