sglang 0.4.7__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/lang/interpreter.py +40 -1
  5. sglang/lang/ir.py +27 -0
  6. sglang/math_utils.py +8 -0
  7. sglang/srt/configs/model_config.py +6 -0
  8. sglang/srt/conversation.py +6 -0
  9. sglang/srt/disaggregation/base/__init__.py +1 -1
  10. sglang/srt/disaggregation/base/conn.py +25 -11
  11. sglang/srt/disaggregation/common/__init__.py +5 -1
  12. sglang/srt/disaggregation/common/utils.py +42 -0
  13. sglang/srt/disaggregation/decode.py +196 -51
  14. sglang/srt/disaggregation/fake/__init__.py +1 -1
  15. sglang/srt/disaggregation/fake/conn.py +15 -9
  16. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +18 -13
  18. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  19. sglang/srt/disaggregation/nixl/conn.py +17 -12
  20. sglang/srt/disaggregation/prefill.py +128 -43
  21. sglang/srt/disaggregation/utils.py +127 -123
  22. sglang/srt/entrypoints/engine.py +15 -1
  23. sglang/srt/entrypoints/http_server.py +13 -2
  24. sglang/srt/eplb_simulator/__init__.py +1 -0
  25. sglang/srt/eplb_simulator/reader.py +51 -0
  26. sglang/srt/layers/activation.py +19 -0
  27. sglang/srt/layers/attention/aiter_backend.py +15 -2
  28. sglang/srt/layers/attention/cutlass_mla_backend.py +38 -15
  29. sglang/srt/layers/attention/flashattention_backend.py +53 -64
  30. sglang/srt/layers/attention/flashinfer_backend.py +1 -2
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +22 -24
  32. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  33. sglang/srt/layers/attention/triton_backend.py +119 -119
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  35. sglang/srt/layers/attention/vision.py +51 -24
  36. sglang/srt/layers/communicator.py +23 -5
  37. sglang/srt/layers/linear.py +0 -4
  38. sglang/srt/layers/logits_processor.py +0 -12
  39. sglang/srt/layers/moe/ep_moe/kernels.py +6 -5
  40. sglang/srt/layers/moe/ep_moe/layer.py +42 -32
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  42. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -4
  43. sglang/srt/layers/moe/topk.py +16 -8
  44. sglang/srt/layers/pooler.py +56 -0
  45. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  46. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  47. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  49. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  50. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  51. sglang/srt/layers/radix_attention.py +2 -3
  52. sglang/srt/lora/lora_manager.py +79 -34
  53. sglang/srt/lora/mem_pool.py +4 -5
  54. sglang/srt/managers/cache_controller.py +2 -1
  55. sglang/srt/managers/io_struct.py +28 -4
  56. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  57. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  58. sglang/srt/managers/schedule_batch.py +39 -6
  59. sglang/srt/managers/scheduler.py +73 -17
  60. sglang/srt/managers/tokenizer_manager.py +29 -2
  61. sglang/srt/mem_cache/chunk_cache.py +1 -0
  62. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  63. sglang/srt/mem_cache/memory_pool.py +111 -407
  64. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  65. sglang/srt/mem_cache/radix_cache.py +36 -12
  66. sglang/srt/model_executor/cuda_graph_runner.py +122 -55
  67. sglang/srt/model_executor/forward_batch_info.py +14 -5
  68. sglang/srt/model_executor/model_runner.py +6 -6
  69. sglang/srt/model_loader/loader.py +8 -1
  70. sglang/srt/models/bert.py +113 -13
  71. sglang/srt/models/deepseek_v2.py +113 -155
  72. sglang/srt/models/internvl.py +46 -102
  73. sglang/srt/models/roberta.py +117 -9
  74. sglang/srt/models/vila.py +305 -0
  75. sglang/srt/openai_api/adapter.py +162 -4
  76. sglang/srt/openai_api/protocol.py +37 -1
  77. sglang/srt/sampling/sampling_batch_info.py +24 -0
  78. sglang/srt/sampling/sampling_params.py +2 -0
  79. sglang/srt/server_args.py +318 -233
  80. sglang/srt/speculative/build_eagle_tree.py +1 -1
  81. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -3
  82. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +5 -2
  83. sglang/srt/speculative/eagle_utils.py +389 -109
  84. sglang/srt/speculative/eagle_worker.py +134 -43
  85. sglang/srt/two_batch_overlap.py +4 -2
  86. sglang/srt/utils.py +58 -0
  87. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  88. sglang/test/runners.py +38 -3
  89. sglang/test/test_block_fp8.py +1 -0
  90. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  91. sglang/test/test_block_fp8_ep.py +1 -0
  92. sglang/test/test_utils.py +3 -1
  93. sglang/utils.py +9 -0
  94. sglang/version.py +1 -1
  95. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +5 -5
  96. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +99 -88
  97. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.7.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,7 @@
16
16
  import time
17
17
  from typing import Dict, List, Optional, Union
18
18
 
19
- from pydantic import BaseModel, Field, root_validator
19
+ from pydantic import BaseModel, Field, model_serializer, root_validator
20
20
  from typing_extensions import Literal
21
21
 
22
22
 
@@ -182,6 +182,7 @@ class CompletionRequest(BaseModel):
182
182
  skip_special_tokens: bool = True
183
183
  lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
184
184
  session_params: Optional[Dict] = None
185
+ return_hidden_states: Optional[bool] = False
185
186
 
186
187
  # For PD disaggregation
187
188
  bootstrap_host: Optional[str] = None
@@ -195,6 +196,11 @@ class CompletionResponseChoice(BaseModel):
195
196
  logprobs: Optional[LogProbs] = None
196
197
  finish_reason: Literal["stop", "length", "content_filter", "abort"]
197
198
  matched_stop: Union[None, int, str] = None
199
+ hidden_states: Optional[object] = None
200
+
201
+ @model_serializer
202
+ def _serialize(self):
203
+ return exclude_if_none(self, ["hidden_states"])
198
204
 
199
205
 
200
206
  class CompletionResponse(BaseModel):
@@ -212,6 +218,11 @@ class CompletionResponseStreamChoice(BaseModel):
212
218
  logprobs: Optional[LogProbs] = None
213
219
  finish_reason: Optional[Literal["stop", "length", "content_filter"]] = None
214
220
  matched_stop: Union[None, int, str] = None
221
+ hidden_states: Optional[object] = None
222
+
223
+ @model_serializer
224
+ def _serialize(self):
225
+ return exclude_if_none(self, ["hidden_states"])
215
226
 
216
227
 
217
228
  class CompletionStreamResponse(BaseModel):
@@ -405,6 +416,9 @@ class ChatCompletionRequest(BaseModel):
405
416
  bootstrap_port: Optional[int] = None
406
417
  bootstrap_room: Optional[int] = None
407
418
 
419
+ # Hidden States
420
+ return_hidden_states: Optional[bool] = False
421
+
408
422
 
409
423
  class ChatMessage(BaseModel):
410
424
  role: Optional[str] = None
@@ -421,6 +435,11 @@ class ChatCompletionResponseChoice(BaseModel):
421
435
  "stop", "length", "tool_calls", "content_filter", "function_call", "abort"
422
436
  ]
423
437
  matched_stop: Union[None, int, str] = None
438
+ hidden_states: Optional[object] = None
439
+
440
+ @model_serializer
441
+ def _serialize(self):
442
+ return exclude_if_none(self, ["hidden_states"])
424
443
 
425
444
 
426
445
  class ChatCompletionResponse(BaseModel):
@@ -437,6 +456,11 @@ class DeltaMessage(BaseModel):
437
456
  content: Optional[str] = None
438
457
  reasoning_content: Optional[str] = None
439
458
  tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
459
+ hidden_states: Optional[object] = None
460
+
461
+ @model_serializer
462
+ def _serialize(self):
463
+ return exclude_if_none(self, ["hidden_states"])
440
464
 
441
465
 
442
466
  class ChatCompletionResponseStreamChoice(BaseModel):
@@ -513,3 +537,15 @@ class ScoringResponse(BaseModel):
513
537
  model: str
514
538
  usage: Optional[UsageInfo] = None
515
539
  object: str = "scoring"
540
+
541
+
542
+ class RerankResponse(BaseModel):
543
+ score: float
544
+ document: str
545
+ index: int
546
+ meta_info: Optional[dict] = None
547
+
548
+
549
+ def exclude_if_none(obj, field_names: List[str]):
550
+ omit_if_none_fields = {k for k, v in obj.model_fields.items() if k in field_names}
551
+ return {k: v for k, v in obj if k not in omit_if_none_fields or v is not None}
@@ -10,6 +10,7 @@ import torch
10
10
  import sglang.srt.sampling.penaltylib as penaltylib
11
11
  from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
12
12
  from sglang.srt.sampling.sampling_params import TOP_K_ALL
13
+ from sglang.srt.utils import merge_bias_tensor
13
14
 
14
15
  if TYPE_CHECKING:
15
16
  from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -63,6 +64,9 @@ class SamplingBatchInfo:
63
64
  # Device
64
65
  device: str = "cuda"
65
66
 
67
+ # Handle logit bias
68
+ logit_bias: Optional[torch.Tensor] = None
69
+
66
70
  @classmethod
67
71
  def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
68
72
  reqs = batch.reqs
@@ -85,6 +89,14 @@ class SamplingBatchInfo:
85
89
  [r.sampling_params.min_p for r in reqs], dtype=torch.float
86
90
  ).to(device, non_blocking=True)
87
91
 
92
+ logit_bias = None
93
+ if any(r.sampling_params.logit_bias is not None for r in reqs):
94
+ logit_bias = torch.zeros(len(reqs), vocab_size, device=device)
95
+ for i, r in enumerate(reqs):
96
+ if r.sampling_params.logit_bias is not None:
97
+ for key, value in r.sampling_params.logit_bias.items():
98
+ logit_bias[i, int(key)] = value
99
+
88
100
  # Check if any request has custom logit processor
89
101
  has_custom_logit_processor = (
90
102
  batch.enable_custom_logit_processor # check the flag first.
@@ -150,6 +162,7 @@ class SamplingBatchInfo:
150
162
  custom_params=custom_params,
151
163
  custom_logit_processor=merged_custom_logit_processor,
152
164
  device=device,
165
+ logit_bias=logit_bias,
153
166
  )
154
167
  return ret
155
168
 
@@ -206,6 +219,9 @@ class SamplingBatchInfo:
206
219
  if self.vocab_mask is not None:
207
220
  self.apply_mask_func(logits=logits, vocab_mask=self.vocab_mask)
208
221
 
222
+ if self.logit_bias is not None:
223
+ logits.add_(self.logit_bias)
224
+
209
225
  def filter_batch(self, keep_indices: List[int], keep_indices_device: torch.Tensor):
210
226
  self.penalizer_orchestrator.filter(keep_indices_device)
211
227
 
@@ -221,6 +237,9 @@ class SamplingBatchInfo:
221
237
  value = getattr(self, item, None)
222
238
  setattr(self, item, value[keep_indices_device])
223
239
 
240
+ if self.logit_bias is not None:
241
+ self.logit_bias = self.logit_bias[keep_indices_device]
242
+
224
243
  def _filter_batch_custom_logit_processor(
225
244
  self, keep_indices: List[int], keep_indices_device: torch.Tensor
226
245
  ):
@@ -321,3 +340,8 @@ class SamplingBatchInfo:
321
340
  self.need_top_p_sampling |= other.need_top_p_sampling
322
341
  self.need_top_k_sampling |= other.need_top_k_sampling
323
342
  self.need_min_p_sampling |= other.need_min_p_sampling
343
+
344
+ # Merge logit bias
345
+ self.logit_bias = merge_bias_tensor(
346
+ self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
347
+ )
@@ -52,6 +52,7 @@ class SamplingParams:
52
52
  no_stop_trim: bool = False,
53
53
  custom_params: Optional[Dict[str, Any]] = None,
54
54
  stream_interval: Optional[int] = None,
55
+ logit_bias: Optional[Dict[str, float]] = None,
55
56
  ) -> None:
56
57
  self.max_new_tokens = max_new_tokens
57
58
  self.stop_strs = stop
@@ -78,6 +79,7 @@ class SamplingParams:
78
79
  self.no_stop_trim = no_stop_trim
79
80
  self.custom_params = custom_params
80
81
  self.stream_interval = stream_interval
82
+ self.logit_bias = logit_bias
81
83
 
82
84
  # Process some special cases
83
85
  if 0 <= self.temperature < _SAMPLING_EPS: