sglang 0.4.1.post6__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +1 -0
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +41 -5
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +29 -5
  41. sglang/srt/layers/parameter.py +2 -1
  42. sglang/srt/layers/quantization/__init__.py +20 -23
  43. sglang/srt/layers/quantization/fp8.py +6 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  45. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  46. sglang/srt/layers/radix_attention.py +2 -2
  47. sglang/srt/layers/rotary_embedding.py +1179 -31
  48. sglang/srt/layers/sampler.py +39 -1
  49. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  50. sglang/srt/lora/lora.py +1 -9
  51. sglang/srt/managers/configure_logging.py +3 -0
  52. sglang/srt/managers/data_parallel_controller.py +79 -72
  53. sglang/srt/managers/detokenizer_manager.py +23 -6
  54. sglang/srt/managers/image_processor.py +158 -2
  55. sglang/srt/managers/io_struct.py +25 -2
  56. sglang/srt/managers/schedule_batch.py +49 -22
  57. sglang/srt/managers/schedule_policy.py +26 -12
  58. sglang/srt/managers/scheduler.py +277 -178
  59. sglang/srt/managers/session_controller.py +1 -0
  60. sglang/srt/managers/tokenizer_manager.py +206 -121
  61. sglang/srt/managers/tp_worker.py +6 -4
  62. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  63. sglang/srt/managers/utils.py +44 -0
  64. sglang/srt/mem_cache/memory_pool.py +10 -32
  65. sglang/srt/metrics/collector.py +15 -6
  66. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  67. sglang/srt/model_executor/model_runner.py +37 -15
  68. sglang/srt/model_loader/loader.py +8 -6
  69. sglang/srt/model_loader/weight_utils.py +55 -2
  70. sglang/srt/models/baichuan.py +6 -6
  71. sglang/srt/models/chatglm.py +2 -2
  72. sglang/srt/models/commandr.py +3 -3
  73. sglang/srt/models/dbrx.py +4 -4
  74. sglang/srt/models/deepseek.py +3 -3
  75. sglang/srt/models/deepseek_v2.py +8 -8
  76. sglang/srt/models/exaone.py +2 -2
  77. sglang/srt/models/gemma.py +2 -2
  78. sglang/srt/models/gemma2.py +6 -24
  79. sglang/srt/models/gpt2.py +3 -5
  80. sglang/srt/models/gpt_bigcode.py +1 -1
  81. sglang/srt/models/granite.py +2 -2
  82. sglang/srt/models/grok.py +3 -3
  83. sglang/srt/models/internlm2.py +2 -2
  84. sglang/srt/models/llama.py +7 -5
  85. sglang/srt/models/minicpm.py +2 -2
  86. sglang/srt/models/minicpm3.py +6 -6
  87. sglang/srt/models/minicpmv.py +1238 -0
  88. sglang/srt/models/mixtral.py +3 -3
  89. sglang/srt/models/mixtral_quant.py +3 -3
  90. sglang/srt/models/mllama.py +2 -2
  91. sglang/srt/models/olmo.py +3 -3
  92. sglang/srt/models/olmo2.py +4 -4
  93. sglang/srt/models/olmoe.py +7 -13
  94. sglang/srt/models/phi3_small.py +2 -2
  95. sglang/srt/models/qwen.py +2 -2
  96. sglang/srt/models/qwen2.py +41 -4
  97. sglang/srt/models/qwen2_moe.py +3 -3
  98. sglang/srt/models/qwen2_vl.py +22 -122
  99. sglang/srt/models/stablelm.py +2 -2
  100. sglang/srt/models/torch_native_llama.py +3 -3
  101. sglang/srt/models/xverse.py +6 -6
  102. sglang/srt/models/xverse_moe.py +6 -6
  103. sglang/srt/openai_api/protocol.py +2 -0
  104. sglang/srt/sampling/custom_logit_processor.py +38 -0
  105. sglang/srt/sampling/sampling_batch_info.py +139 -4
  106. sglang/srt/sampling/sampling_params.py +3 -1
  107. sglang/srt/server.py +4 -1090
  108. sglang/srt/server_args.py +57 -14
  109. sglang/srt/utils.py +103 -65
  110. sglang/test/runners.py +8 -13
  111. sglang/test/test_programs.py +1 -1
  112. sglang/test/test_utils.py +3 -1
  113. sglang/utils.py +12 -2
  114. sglang/version.py +1 -1
  115. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +16 -5
  116. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +119 -115
  117. sglang/launch_server_llavavid.py +0 -25
  118. sglang/srt/constrained/__init__.py +0 -16
  119. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  120. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  121. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/custom_logit_processor.py (new file)
@@ -0,0 +1,38 @@
+ import json
+ from abc import ABC, abstractmethod
+ from functools import lru_cache
+ from typing import Any, Dict, List, Optional
+
+ import dill
+ import torch
+
+
+ @lru_cache(maxsize=None)
+ def _cache_from_str(json_str: str):
+     """Deserialize a json string to a Callable object.
+     This function is cached to avoid redundant deserialization.
+     """
+     data = json.loads(json_str)
+     return dill.loads(bytes.fromhex(data["callable"]))
+
+
+ class CustomLogitProcessor(ABC):
+     """Abstract base class for callable functions."""
+
+     @abstractmethod
+     def __call__(
+         self,
+         logits: torch.Tensor,
+         custom_param_list: Optional[List[Dict[str, Any]]] = None,
+     ) -> torch.Tensor:
+         """Define the callable behavior."""
+         raise NotImplementedError
+
+     def to_str(self) -> str:
+         """Serialize the callable function to a JSON-compatible string."""
+         return json.dumps({"callable": dill.dumps(self).hex()})
+
+     @classmethod
+     def from_str(cls, json_str: str):
+         """Deserialize a callable function from a JSON string."""
+         return _cache_from_str(json_str)
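The new CustomLogitProcessor base class defines the contract for user-supplied logit processors: subclass it, implement __call__(logits, custom_param_list), and serialize the instance with to_str(), which dill-pickles it and hex-encodes it inside a JSON object; from_str() (backed by the lru_cache'd _cache_from_str helper) recovers the callable. A minimal editorial sketch of such a subclass follows; BanTokenProcessor and the "banned_token_id" key are hypothetical names used only for illustration and are not part of the package.

import torch

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class BanTokenProcessor(CustomLogitProcessor):
    """Hypothetical example: mask out one token id per request."""

    def __call__(self, logits, custom_param_list=None):
        # custom_param_list carries one params dict (or None) per request in the batch.
        if custom_param_list is None:
            return logits
        for i, params in enumerate(custom_param_list):
            if params and "banned_token_id" in params:
                logits[i, params["banned_token_id"]] = float("-inf")
        return logits


# Round-trip through the serialization helpers defined above.
serialized = BanTokenProcessor().to_str()
processor = CustomLogitProcessor.from_str(serialized)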
sglang/srt/sampling/sampling_batch_info.py
@@ -3,7 +3,7 @@ from __future__ import annotations
  import dataclasses
  import logging
  import threading
- from typing import TYPE_CHECKING, Callable, List, Optional
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

  import torch

@@ -14,6 +14,7 @@ if is_cuda:
      from sgl_kernel import sampling_scaling_penalties

  import sglang.srt.sampling.penaltylib as penaltylib
+ from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

  logger = logging.getLogger(__name__)

@@ -36,6 +37,9 @@ class SamplingBatchInfo:
      # Dispatch in CUDA graph
      need_min_p_sampling: bool

+     # Whether any request has custom logit processor
+     has_custom_logit_processor: bool
+
      # Bias Tensors
      vocab_size: int
      grammars: Optional[List] = None
@@ -52,6 +56,14 @@ class SamplingBatchInfo:
      # Device
      device: str = "cuda"

+     # Custom Parameters
+     custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
+
+     # Custom Logit Processor
+     custom_logit_processor: Optional[
+         Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
+     ] = None
+
      @classmethod
      def from_schedule_batch(
          cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
@@ -76,6 +88,39 @@ class SamplingBatchInfo:
              [r.sampling_params.min_p for r in reqs], dtype=torch.float
          ).to(device, non_blocking=True)

+         # Check if any request has custom logit processor
+         has_custom_logit_processor = (
+             batch.enable_custom_logit_processor  # check the flag first.
+             and any(r.custom_logit_processor for r in reqs)  # then check the requests.
+         )
+
+         if has_custom_logit_processor:
+             # Merge the same type of custom logit processors together
+             processor_dict = {}
+             for i, r in enumerate(reqs):
+                 if r.custom_logit_processor is None:
+                     continue
+                 processor_str = r.custom_logit_processor
+                 if processor_str not in processor_dict:
+                     processor_dict[processor_str] = []
+                 processor_dict[processor_str].append(i)
+
+             merged_custom_logit_processor = {
+                 hash(processor_str): (
+                     # The deserialized custom logit processor object
+                     CustomLogitProcessor.from_str(processor_str),
+                     # The mask tensor for the requests that use this custom logit processor
+                     torch.zeros(len(reqs), dtype=torch.bool)
+                     .scatter_(0, torch.tensor(true_indices), True)
+                     .to(device, non_blocking=True),
+                 )
+                 for processor_str, true_indices in processor_dict.items()
+             }
+             custom_params = [r.sampling_params.custom_params for r in reqs]
+         else:
+             merged_custom_logit_processor = None
+             custom_params = None
+
          ret = cls(
              temperatures=temperatures,
              top_ps=top_ps,
@@ -83,8 +128,11 @@ class SamplingBatchInfo:
              min_ps=min_ps,
              need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
              is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+             has_custom_logit_processor=has_custom_logit_processor,
              vocab_size=vocab_size,
              device=device,
+             custom_params=custom_params,
+             custom_logit_processor=merged_custom_logit_processor,
          )
          # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

@@ -184,6 +232,8 @@ class SamplingBatchInfo:

      def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
          self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+         if self.has_custom_logit_processor:
+             self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)

          for item in [
              "temperatures",
@@ -196,6 +246,27 @@ class SamplingBatchInfo:
              if value is not None:  # logit_bias can be None
                  setattr(self, item, value[new_indices])

+     def _filter_batch_custom_logit_processor(
+         self, unfinished_indices: List[int], new_indices: torch.Tensor
+     ):
+         """Filter the custom logit processor and custom params"""
+
+         self.custom_logit_processor = {
+             k: (p, mask[new_indices])
+             for k, (p, mask) in self.custom_logit_processor.items()
+             if any(
+                 mask[new_indices]
+             )  # ignore the custom logit processor whose mask is all False
+         }
+         self.custom_params = [self.custom_params[i] for i in unfinished_indices]
+
+         # If the custom logit processor is an empty dict, set the flag to False,
+         # and set the custom logit processor and custom params to None.
+         if len(self.custom_logit_processor) == 0:
+             self.custom_logit_processor = None
+             self.custom_params = None
+             self.has_custom_logit_processor = False
+
      @staticmethod
      def merge_bias_tensor(
          lhs: torch.Tensor,
@@ -221,9 +292,76 @@ class SamplingBatchInfo:

          return None

+     @staticmethod
+     def merge_custom_logit_processor(
+         lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+         rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+         bs1: int,
+         bs2: int,
+         device: str,
+     ):
+         if lhs is None and rhs is None:
+             return None
+         lhs, rhs = lhs or {}, rhs or {}
+
+         keys = set(lhs.keys()).union(set(rhs.keys()))
+         merged_dict = {}
+
+         for k in keys:
+             # Get the logit processor object
+             processor = lhs[k][0] if k in lhs else rhs[k][0]
+             # Get and merge the mask tensors from the two dicts
+             left_mask = (
+                 lhs[k][1]
+                 if k in lhs
+                 else torch.zeros(bs1, dtype=torch.bool, device=device)
+             )
+             right_mask = (
+                 rhs[k][1]
+                 if k in rhs
+                 else torch.zeros(bs2, dtype=torch.bool, device=device)
+             )
+             merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
+
+             assert merged_dict[k][1].shape[0] == bs1 + bs2, (
+                 f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
+                 f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
+                 f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
+                 f"\n{lhs=}\n{rhs=}"
+             )
+
+         return merged_dict
+
      def merge_batch(self, other: "SamplingBatchInfo"):
          self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

+         # Merge the logit bias tensor
+         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
+             self.logit_bias, other.logit_bias, len(self), len(other), self.device
+         )
+         # Merge the custom logit processors and custom params lists
+         if self.has_custom_logit_processor or other.has_custom_logit_processor:
+             # Merge the custom logit processors
+             self.custom_logit_processor = (
+                 SamplingBatchInfo.merge_custom_logit_processor(
+                     self.custom_logit_processor,
+                     other.custom_logit_processor,
+                     len(self),
+                     len(other),
+                     self.device,
+                 )
+             )
+             # Merge the custom params lists
+             self.custom_params = self.custom_params or [None] * len(self)
+             other.custom_params = other.custom_params or [None] * len(other)
+             self.custom_params.extend(other.custom_params)
+
+             # Set the flag to True if any of the two has custom logit processor
+             self.has_custom_logit_processor = True
+
+         # Note: becasue the __len()__ operator is defined on the temperatures tensor,
+         # please make sure any merge operation with len(self) or len(other) is done before
+         # the merge operation of the temperatures tensor below.
          for item in [
              "temperatures",
              "top_ps",
@@ -235,9 +373,6 @@ class SamplingBatchInfo:
              setattr(self, item, torch.concat([self_val, other_val]))

          self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
-         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
-             self.logit_bias, other.logit_bias, len(self), len(other), self.device
-         )
          self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling

      def apply_logits_bias(self, logits: torch.Tensor):
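To make the mask bookkeeping in merge_custom_logit_processor above concrete, the standalone sketch below (an editorial illustration; the keys "h1"/"h2", the placeholder processor strings, and the batch sizes are invented) replays a merge of two batches: a processor absent from one side gets an all-False mask of that side's batch size, and the two masks are concatenated so every merged mask has length bs1 + bs2.

import torch

bs1, bs2 = 3, 2  # batch sizes of the two SamplingBatchInfo objects being merged

# lhs: requests 0 and 2 use processor "h1"; rhs: request 1 uses processor "h2".
lhs = {"h1": ("proc_a", torch.tensor([True, False, True]))}
rhs = {"h2": ("proc_b", torch.tensor([False, True]))}

merged = {}
for k in set(lhs) | set(rhs):
    processor = lhs[k][0] if k in lhs else rhs[k][0]
    left = lhs[k][1] if k in lhs else torch.zeros(bs1, dtype=torch.bool)
    right = rhs[k][1] if k in rhs else torch.zeros(bs2, dtype=torch.bool)
    merged[k] = (processor, torch.cat([left, right]))  # mask length is bs1 + bs2

print(merged["h1"][1])  # tensor([ True, False,  True, False, False])
print(merged["h2"][1])  # tensor([False, False, False, False,  True])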
sglang/srt/sampling/sampling_params.py
@@ -13,7 +13,7 @@
  # ==============================================================================
  """Sampling parameters for text generation."""

- from typing import List, Optional, Union
+ from typing import Any, Dict, List, Optional, Union

  _SAMPLING_EPS = 1e-6

@@ -48,6 +48,7 @@ class SamplingParams:
          no_stop_trim: bool = False,
          ignore_eos: bool = False,
          skip_special_tokens: bool = True,
+         custom_params: Optional[Dict[str, Any]] = None,
      ) -> None:
          self.temperature = temperature
          self.top_p = top_p
@@ -71,6 +72,7 @@ class SamplingParams:
          self.json_schema = json_schema
          self.ebnf = ebnf
          self.no_stop_trim = no_stop_trim
+         self.custom_params = custom_params

          # Process some special cases
          if self.temperature < _SAMPLING_EPS:
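The custom_params dict added to SamplingParams above is collected per request into SamplingBatchInfo.custom_params (see the sampling_batch_info.py hunk earlier), which is the per-request list a processor's __call__ is designed to receive as custom_param_list. A rough usage sketch, assuming the other keyword arguments keep the names shown in the diff context; the "banned_token_id" key is the same hypothetical key used in the processor sketch above and is not defined by the package.

from sglang.srt.sampling.sampling_params import SamplingParams

params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    custom_params={"banned_token_id": 42},  # hypothetical key for illustration only
)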