sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/conversation.py +6 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +196 -51
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +18 -13
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +128 -43
- sglang/srt/disaggregation/utils.py +127 -123
- sglang/srt/entrypoints/engine.py +15 -1
- sglang/srt/entrypoints/http_server.py +13 -2
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +15 -2
- sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
- sglang/srt/layers/attention/flashattention_backend.py +53 -64
- sglang/srt/layers/attention/flashinfer_backend.py +1 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/triton_backend.py +119 -119
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +23 -5
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
- sglang/srt/layers/moe/ep_moe/layer.py +42 -32
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
- sglang/srt/layers/moe/topk.py +16 -8
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/lora/lora_manager.py +79 -34
- sglang/srt/lora/mem_pool.py +4 -5
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/io_struct.py +28 -4
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +39 -6
- sglang/srt/managers/scheduler.py +73 -17
- sglang/srt/managers/tokenizer_manager.py +29 -2
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/model_executor/cuda_graph_runner.py +122 -55
- sglang/srt/model_executor/forward_batch_info.py +14 -5
- sglang/srt/model_executor/model_runner.py +6 -6
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_v2.py +113 -155
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +162 -4
- sglang/srt/openai_api/protocol.py +37 -1
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +318 -233
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
- sglang/srt/speculative/eagle_utils.py +389 -109
- sglang/srt/speculative/eagle_worker.py +134 -43
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +58 -0
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,7 @@
|
|
16
16
|
import time
|
17
17
|
from typing import Dict, List, Optional, Union
|
18
18
|
|
19
|
-
from pydantic import BaseModel, Field, root_validator
|
19
|
+
from pydantic import BaseModel, Field, model_serializer, root_validator
|
20
20
|
from typing_extensions import Literal
|
21
21
|
|
22
22
|
|
@@ -182,6 +182,7 @@ class CompletionRequest(BaseModel):
|
|
182
182
|
skip_special_tokens: bool = True
|
183
183
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
184
184
|
session_params: Optional[Dict] = None
|
185
|
+
return_hidden_states: Optional[bool] = False
|
185
186
|
|
186
187
|
# For PD disaggregation
|
187
188
|
bootstrap_host: Optional[str] = None
|
@@ -195,6 +196,11 @@ class CompletionResponseChoice(BaseModel):
|
|
195
196
|
logprobs: Optional[LogProbs] = None
|
196
197
|
finish_reason: Literal["stop", "length", "content_filter", "abort"]
|
197
198
|
matched_stop: Union[None, int, str] = None
|
199
|
+
hidden_states: Optional[object] = None
|
200
|
+
|
201
|
+
@model_serializer
|
202
|
+
def _serialize(self):
|
203
|
+
return exclude_if_none(self, ["hidden_states"])
|
198
204
|
|
199
205
|
|
200
206
|
class CompletionResponse(BaseModel):
|
@@ -212,6 +218,11 @@ class CompletionResponseStreamChoice(BaseModel):
|
|
212
218
|
logprobs: Optional[LogProbs] = None
|
213
219
|
finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
|
214
220
|
matched_stop: Union[None, int, str] = None
|
221
|
+
hidden_states: Optional[object] = None
|
222
|
+
|
223
|
+
@model_serializer
|
224
|
+
def _serialize(self):
|
225
|
+
return exclude_if_none(self, ["hidden_states"])
|
215
226
|
|
216
227
|
|
217
228
|
class CompletionStreamResponse(BaseModel):
|
@@ -405,6 +416,9 @@ class ChatCompletionRequest(BaseModel):
|
|
405
416
|
bootstrap_port: Optional[int] = None
|
406
417
|
bootstrap_room: Optional[int] = None
|
407
418
|
|
419
|
+
# Hidden States
|
420
|
+
return_hidden_states: Optional[bool] = False
|
421
|
+
|
408
422
|
|
409
423
|
class ChatMessage(BaseModel):
|
410
424
|
role: Optional[str] = None
|
@@ -421,6 +435,11 @@ class ChatCompletionResponseChoice(BaseModel):
|
|
421
435
|
"stop", "length", "tool_calls", "content_filter", "function_call", "abort"
|
422
436
|
]
|
423
437
|
matched_stop: Union[None, int, str] = None
|
438
|
+
hidden_states: Optional[object] = None
|
439
|
+
|
440
|
+
@model_serializer
|
441
|
+
def _serialize(self):
|
442
|
+
return exclude_if_none(self, ["hidden_states"])
|
424
443
|
|
425
444
|
|
426
445
|
class ChatCompletionResponse(BaseModel):
|
@@ -437,6 +456,11 @@ class DeltaMessage(BaseModel):
|
|
437
456
|
content: Optional[str] = None
|
438
457
|
reasoning_content: Optional[str] = None
|
439
458
|
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
459
|
+
hidden_states: Optional[object] = None
|
460
|
+
|
461
|
+
@model_serializer
|
462
|
+
def _serialize(self):
|
463
|
+
return exclude_if_none(self, ["hidden_states"])
|
440
464
|
|
441
465
|
|
442
466
|
class ChatCompletionResponseStreamChoice(BaseModel):
|
@@ -513,3 +537,15 @@ class ScoringResponse(BaseModel):
|
|
513
537
|
model: str
|
514
538
|
usage: Optional[UsageInfo] = None
|
515
539
|
object: str = "scoring"
|
540
|
+
|
541
|
+
|
542
|
+
class RerankResponse(BaseModel):
|
543
|
+
score: float
|
544
|
+
document: str
|
545
|
+
index: int
|
546
|
+
meta_info: Optional[dict] = None
|
547
|
+
|
548
|
+
|
549
|
+
def exclude_if_none(obj, field_names: List[str]):
|
550
|
+
omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
|
551
|
+
return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
|
@@ -10,6 +10,7 @@ import torch
|
|
10
10
|
import sglang.srt.sampling.penaltylib as penaltylib
|
11
11
|
from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
|
12
12
|
from sglang.srt.sampling.sampling_params import TOP_K_ALL
|
13
|
+
from sglang.srt.utils import merge_bias_tensor
|
13
14
|
|
14
15
|
if TYPE_CHECKING:
|
15
16
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
@@ -63,6 +64,9 @@ class SamplingBatchInfo:
|
|
63
64
|
# Device
|
64
65
|
device: str = "cuda"
|
65
66
|
|
67
|
+
# Handle logit bias
|
68
|
+
logit_bias: Optional[torch.Tensor] = None
|
69
|
+
|
66
70
|
@classmethod
|
67
71
|
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
|
68
72
|
reqs = batch.reqs
|
@@ -85,6 +89,14 @@ class SamplingBatchInfo:
|
|
85
89
|
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
86
90
|
).to(device, non_blocking=True)
|
87
91
|
|
92
|
+
logit_bias = None
|
93
|
+
if any(r.sampling_params.logit_bias is not None for r in reqs):
|
94
|
+
logit_bias = torch.zeros(len(reqs), vocab_size, device=device)
|
95
|
+
for i, r in enumerate(reqs):
|
96
|
+
if r.sampling_params.logit_bias is not None:
|
97
|
+
for key, value in r.sampling_params.logit_bias.items():
|
98
|
+
logit_bias[i, int(key)] = value
|
99
|
+
|
88
100
|
# Check if any request has custom logit processor
|
89
101
|
has_custom_logit_processor = (
|
90
102
|
batch.enable_custom_logit_processor # check the flag first.
|
@@ -150,6 +162,7 @@ class SamplingBatchInfo:
|
|
150
162
|
custom_params=custom_params,
|
151
163
|
custom_logit_processor=merged_custom_logit_processor,
|
152
164
|
device=device,
|
165
|
+
logit_bias=logit_bias,
|
153
166
|
)
|
154
167
|
return ret
|
155
168
|
|
@@ -206,6 +219,9 @@ class SamplingBatchInfo:
|
|
206
219
|
if self.vocab_mask is not None:
|
207
220
|
self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)
|
208
221
|
|
222
|
+
if self.logit_bias is not None:
|
223
|
+
logits.add_(self.logit_bias)
|
224
|
+
|
209
225
|
def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
|
210
226
|
self.penalizer_orchestrator.filter(keep_indices_device)
|
211
227
|
|
@@ -221,6 +237,9 @@ class SamplingBatchInfo:
|
|
221
237
|
value = getattr(self, item, None)
|
222
238
|
setattr(self, item, value[keep_indices_device])
|
223
239
|
|
240
|
+
if self.logit_bias is not None:
|
241
|
+
self.logit_bias = self.logit_bias[keep_indices_device]
|
242
|
+
|
224
243
|
def _filter_batch_custom_logit_processor(
|
225
244
|
self, keep_indices: List[int], keep_indices_device: torch.Tensor
|
226
245
|
):
|
@@ -321,3 +340,8 @@ class SamplingBatchInfo:
|
|
321
340
|
self.need_top_p_sampling |= other.need_top_p_sampling
|
322
341
|
self.need_top_k_sampling |= other.need_top_k_sampling
|
323
342
|
self.need_min_p_sampling |= other.need_min_p_sampling
|
343
|
+
|
344
|
+
# Merge logit bias
|
345
|
+
self.logit_bias = merge_bias_tensor(
|
346
|
+
self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
|
347
|
+
)
|
@@ -52,6 +52,7 @@ class SamplingParams:
|
|
52
52
|
no_stop_trim: bool = False,
|
53
53
|
custom_params: Optional[Dict[str, Any]] = None,
|
54
54
|
stream_interval: Optional[int] = None,
|
55
|
+
logit_bias: Optional[Dict[str, float]] = None,
|
55
56
|
) -> None:
|
56
57
|
self.max_new_tokens = max_new_tokens
|
57
58
|
self.stop_strs = stop
|
@@ -78,6 +79,7 @@ class SamplingParams:
|
|
78
79
|
self.no_stop_trim = no_stop_trim
|
79
80
|
self.custom_params = custom_params
|
80
81
|
self.stream_interval = stream_interval
|
82
|
+
self.logit_bias = logit_bias
|
81
83
|
|
82
84
|
# Process some special cases
|
83
85
|
if 0 <= self.temperature < _SAMPLING_EPS:
|