sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +16 -6
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +27 -12
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +76 -102
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
  41. sglang/srt/layers/moe/topk.py +4 -2
  42. sglang/srt/layers/parameter.py +26 -17
  43. sglang/srt/layers/quantization/__init__.py +22 -23
  44. sglang/srt/layers/quantization/fp8.py +112 -55
  45. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  46. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +2 -3
  48. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  49. sglang/srt/layers/radix_attention.py +2 -0
  50. sglang/srt/layers/rotary_embedding.py +1179 -31
  51. sglang/srt/layers/sampler.py +39 -1
  52. sglang/srt/layers/vocab_parallel_embedding.py +17 -4
  53. sglang/srt/lora/lora.py +1 -9
  54. sglang/srt/managers/configure_logging.py +46 -0
  55. sglang/srt/managers/data_parallel_controller.py +79 -72
  56. sglang/srt/managers/detokenizer_manager.py +23 -8
  57. sglang/srt/managers/image_processor.py +158 -2
  58. sglang/srt/managers/io_struct.py +54 -15
  59. sglang/srt/managers/schedule_batch.py +49 -22
  60. sglang/srt/managers/schedule_policy.py +26 -12
  61. sglang/srt/managers/scheduler.py +319 -181
  62. sglang/srt/managers/session_controller.py +1 -0
  63. sglang/srt/managers/tokenizer_manager.py +303 -158
  64. sglang/srt/managers/tp_worker.py +6 -4
  65. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  66. sglang/srt/managers/utils.py +44 -0
  67. sglang/srt/mem_cache/memory_pool.py +110 -77
  68. sglang/srt/metrics/collector.py +25 -11
  69. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  70. sglang/srt/model_executor/model_runner.py +80 -21
  71. sglang/srt/model_loader/loader.py +8 -6
  72. sglang/srt/model_loader/weight_utils.py +55 -2
  73. sglang/srt/models/baichuan.py +6 -6
  74. sglang/srt/models/chatglm.py +2 -2
  75. sglang/srt/models/commandr.py +3 -3
  76. sglang/srt/models/dbrx.py +4 -4
  77. sglang/srt/models/deepseek.py +3 -3
  78. sglang/srt/models/deepseek_v2.py +8 -8
  79. sglang/srt/models/exaone.py +2 -2
  80. sglang/srt/models/gemma.py +2 -2
  81. sglang/srt/models/gemma2.py +6 -24
  82. sglang/srt/models/gpt2.py +3 -5
  83. sglang/srt/models/gpt_bigcode.py +1 -1
  84. sglang/srt/models/granite.py +2 -2
  85. sglang/srt/models/grok.py +3 -3
  86. sglang/srt/models/internlm2.py +2 -2
  87. sglang/srt/models/llama.py +41 -4
  88. sglang/srt/models/minicpm.py +2 -2
  89. sglang/srt/models/minicpm3.py +6 -6
  90. sglang/srt/models/minicpmv.py +1238 -0
  91. sglang/srt/models/mixtral.py +3 -3
  92. sglang/srt/models/mixtral_quant.py +3 -3
  93. sglang/srt/models/mllama.py +2 -2
  94. sglang/srt/models/olmo.py +3 -3
  95. sglang/srt/models/olmo2.py +4 -4
  96. sglang/srt/models/olmoe.py +7 -13
  97. sglang/srt/models/phi3_small.py +2 -2
  98. sglang/srt/models/qwen.py +2 -2
  99. sglang/srt/models/qwen2.py +52 -4
  100. sglang/srt/models/qwen2_eagle.py +131 -0
  101. sglang/srt/models/qwen2_moe.py +3 -3
  102. sglang/srt/models/qwen2_vl.py +22 -122
  103. sglang/srt/models/stablelm.py +2 -2
  104. sglang/srt/models/torch_native_llama.py +3 -3
  105. sglang/srt/models/xverse.py +6 -6
  106. sglang/srt/models/xverse_moe.py +6 -6
  107. sglang/srt/openai_api/protocol.py +2 -0
  108. sglang/srt/sampling/custom_logit_processor.py +38 -0
  109. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  110. sglang/srt/sampling/sampling_batch_info.py +153 -9
  111. sglang/srt/sampling/sampling_params.py +4 -2
  112. sglang/srt/server.py +4 -1037
  113. sglang/srt/server_args.py +84 -32
  114. sglang/srt/speculative/eagle_worker.py +1 -0
  115. sglang/srt/torch_memory_saver_adapter.py +59 -0
  116. sglang/srt/utils.py +130 -63
  117. sglang/test/runners.py +8 -13
  118. sglang/test/test_programs.py +1 -1
  119. sglang/test/test_utils.py +3 -1
  120. sglang/utils.py +12 -2
  121. sglang/version.py +1 -1
  122. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
  123. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
  124. sglang/launch_server_llavavid.py +0 -25
  125. sglang/srt/constrained/__init__.py +0 -16
  126. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  127. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  129. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
@@ -24,9 +24,8 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.rotary_embedding import get_rope

+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.linear import (
      MergedColumnParallelLinear,
@@ -36,6 +35,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
@@ -47,17 +47,17 @@ import torch
  from torch import nn
  from torch.nn.parameter import Parameter
  from transformers import LlamaConfig
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
@@ -21,19 +21,19 @@ from typing import Any, Dict, Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import LlamaConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.activation import SiluAndMul
- from vllm.model_executor.layers.layernorm import RMSNorm
- from vllm.model_executor.layers.linear import (
+
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
      RowParallelLinear,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
@@ -18,25 +18,25 @@ from typing import Any, Dict, Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
- from vllm.model_executor.layers.activation import SiluAndMul
- from vllm.model_executor.layers.layernorm import RMSNorm
- from vllm.model_executor.layers.linear import (
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
      ReplicatedLinear,
      RowParallelLinear,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.fused_moe_triton import fused_moe
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
@@ -180,6 +180,7 @@ class CompletionRequest(BaseModel):
      ignore_eos: bool = False
      skip_special_tokens: bool = True
      lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+     session_params: Optional[Dict] = None


  class CompletionResponseChoice(BaseModel):
@@ -322,6 +323,7 @@ class ChatCompletionRequest(BaseModel):
      ignore_eos: bool = False
      skip_special_tokens: bool = True
      lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+     session_params: Optional[Dict] = None


  class FunctionResponse(BaseModel):
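The only change to both request models is the new optional session_params field, an opaque dict validated as Optional[Dict]. A hedged illustration of a request body carrying it; the key inside the dict is purely illustrative, not a documented schema:

# Hypothetical payload; "session_params" is the new field added above, and the
# "id" key is an illustrative placeholder rather than a confirmed parameter name.
payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Hello"}],
    "session_params": {"id": "example-session"},
}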
@@ -0,0 +1,38 @@
1
+ import json
2
+ from abc import ABC, abstractmethod
3
+ from functools import lru_cache
4
+ from typing import Any, Dict, List, Optional
5
+
6
+ import dill
7
+ import torch
8
+
9
+
10
+ @lru_cache(maxsize=None)
11
+ def _cache_from_str(json_str: str):
12
+ """Deserialize a json string to a Callable object.
13
+ This function is cached to avoid redundant deserialization.
14
+ """
15
+ data = json.loads(json_str)
16
+ return dill.loads(bytes.fromhex(data["callable"]))
17
+
18
+
19
+ class CustomLogitProcessor(ABC):
20
+ """Abstract base class for callable functions."""
21
+
22
+ @abstractmethod
23
+ def __call__(
24
+ self,
25
+ logits: torch.Tensor,
26
+ custom_param_list: Optional[List[Dict[str, Any]]] = None,
27
+ ) -> torch.Tensor:
28
+ """Define the callable behavior."""
29
+ raise NotImplementedError
30
+
31
+ def to_str(self) -> str:
32
+ """Serialize the callable function to a JSON-compatible string."""
33
+ return json.dumps({"callable": dill.dumps(self).hex()})
34
+
35
+ @classmethod
36
+ def from_str(cls, json_str: str):
37
+ """Deserialize a callable function from a JSON string."""
38
+ return _cache_from_str(json_str)
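The new module fixes the serialization contract (dill bytes, hex-encoded inside JSON) and caches deserialization with lru_cache. A minimal sketch of how a subclass could be written and round-tripped, based only on the interface above; the subclass name and the "blocked_token_id" key are illustrative, not part of the package:

import torch

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class DisallowTokenProcessor(CustomLogitProcessor):
    """Illustrative subclass: masks one token id per request."""

    def __call__(self, logits, custom_param_list=None):
        if custom_param_list:
            for i, params in enumerate(custom_param_list):
                if params and "blocked_token_id" in params:
                    logits[i, params["blocked_token_id"]] = float("-inf")
        return logits


# Round trip: the client sends the processor as a string (dill bytes hex-encoded
# inside JSON); the server deserializes it once and caches the result.
serialized = DisallowTokenProcessor().to_str()
restored = CustomLogitProcessor.from_str(serialized)
out = restored(torch.zeros(2, 8), [{"blocked_token_id": 3}, None])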
@@ -3,6 +3,11 @@ from typing import List
  import torch

  from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
+ from sglang.srt.utils import is_cuda_available
+
+ is_cuda = is_cuda_available()
+ if is_cuda:
+     from sgl_kernel import sampling_scaling_penalties


  class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -56,11 +61,16 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
          self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

      def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-         return torch.where(
-             logits > 0,
-             logits / self.cumulated_repetition_penalties,
-             logits * self.cumulated_repetition_penalties,
-         )
+         if is_cuda:
+             return sampling_scaling_penalties(
+                 logits, self.cumulated_repetition_penalties
+             )
+         else:
+             return torch.where(
+                 logits > 0,
+                 logits / self.cumulated_repetition_penalties,
+                 logits * self.cumulated_repetition_penalties,
+             )

      def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
          self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
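On CUDA builds the penalty is now delegated to the sampling_scaling_penalties kernel from sgl_kernel; elsewhere the original torch.where formulation is kept. A small standalone sketch of what that fallback computes (values are illustrative): positive logits are divided by the penalty and negative logits are multiplied, so both move toward lower probability.

import torch

logits = torch.tensor([[2.0, -2.0, 0.5]])
penalties = torch.tensor([[1.5, 1.5, 1.5]])
penalized = torch.where(logits > 0, logits / penalties, logits * penalties)
# tensor([[ 1.3333, -3.0000,  0.3333]])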
@@ -3,11 +3,18 @@ from __future__ import annotations
  import dataclasses
  import logging
  import threading
- from typing import TYPE_CHECKING, Callable, List, Optional
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

  import torch

+ from sglang.srt.utils import is_cuda_available
+
+ is_cuda = is_cuda_available()
+ if is_cuda:
+     from sgl_kernel import sampling_scaling_penalties
+
  import sglang.srt.sampling.penaltylib as penaltylib
+ from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

  logger = logging.getLogger(__name__)

@@ -30,6 +37,9 @@ class SamplingBatchInfo:
      # Dispatch in CUDA graph
      need_min_p_sampling: bool

+     # Whether any request has custom logit processor
+     has_custom_logit_processor: bool
+
      # Bias Tensors
      vocab_size: int
      grammars: Optional[List] = None
@@ -46,6 +56,14 @@
      # Device
      device: str = "cuda"

+     # Custom Parameters
+     custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
+
+     # Custom Logit Processor
+     custom_logit_processor: Optional[
+         Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
+     ] = None
+
      @classmethod
      def from_schedule_batch(
          cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
@@ -70,6 +88,39 @@ class SamplingBatchInfo:
              [r.sampling_params.min_p for r in reqs], dtype=torch.float
          ).to(device, non_blocking=True)

+         # Check if any request has custom logit processor
+         has_custom_logit_processor = (
+             batch.enable_custom_logit_processor  # check the flag first.
+             and any(r.custom_logit_processor for r in reqs)  # then check the requests.
+         )
+
+         if has_custom_logit_processor:
+             # Merge the same type of custom logit processors together
+             processor_dict = {}
+             for i, r in enumerate(reqs):
+                 if r.custom_logit_processor is None:
+                     continue
+                 processor_str = r.custom_logit_processor
+                 if processor_str not in processor_dict:
+                     processor_dict[processor_str] = []
+                 processor_dict[processor_str].append(i)
+
+             merged_custom_logit_processor = {
+                 hash(processor_str): (
+                     # The deserialized custom logit processor object
+                     CustomLogitProcessor.from_str(processor_str),
+                     # The mask tensor for the requests that use this custom logit processor
+                     torch.zeros(len(reqs), dtype=torch.bool)
+                     .scatter_(0, torch.tensor(true_indices), True)
+                     .to(device, non_blocking=True),
+                 )
+                 for processor_str, true_indices in processor_dict.items()
+             }
+             custom_params = [r.sampling_params.custom_params for r in reqs]
+         else:
+             merged_custom_logit_processor = None
+             custom_params = None
+
          ret = cls(
              temperatures=temperatures,
              top_ps=top_ps,
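The block added above groups requests by their serialized processor string and builds one boolean mask per distinct processor, keyed by the string's hash. A standalone sketch of that grouping with stand-in strings (the values are illustrative, not taken from the package):

import torch

processor_strs = ["procA", None, "procA", "procB"]  # stand-ins for serialized processors
processor_dict = {}
for i, s in enumerate(processor_strs):
    if s is not None:
        processor_dict.setdefault(s, []).append(i)

masks = {
    hash(s): torch.zeros(len(processor_strs), dtype=torch.bool).scatter_(
        0, torch.tensor(idx), True
    )
    for s, idx in processor_dict.items()
}
# masks[hash("procA")] -> tensor([ True, False,  True, False])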
@@ -77,8 +128,11 @@
              min_ps=min_ps,
              need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
              is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+             has_custom_logit_processor=has_custom_logit_processor,
              vocab_size=vocab_size,
              device=device,
+             custom_params=custom_params,
+             custom_logit_processor=merged_custom_logit_processor,
          )
          # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

@@ -178,6 +232,8 @@

      def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
          self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+         if self.has_custom_logit_processor:
+             self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)

          for item in [
              "temperatures",
@@ -190,6 +246,27 @@
              if value is not None:  # logit_bias can be None
                  setattr(self, item, value[new_indices])

+     def _filter_batch_custom_logit_processor(
+         self, unfinished_indices: List[int], new_indices: torch.Tensor
+     ):
+         """Filter the custom logit processor and custom params"""
+
+         self.custom_logit_processor = {
+             k: (p, mask[new_indices])
+             for k, (p, mask) in self.custom_logit_processor.items()
+             if any(
+                 mask[new_indices]
+             )  # ignore the custom logit processor whose mask is all False
+         }
+         self.custom_params = [self.custom_params[i] for i in unfinished_indices]
+
+         # If the custom logit processor is an empty dict, set the flag to False,
+         # and set the custom logit processor and custom params to None.
+         if len(self.custom_logit_processor) == 0:
+             self.custom_logit_processor = None
+             self.custom_params = None
+             self.has_custom_logit_processor = False
+
      @staticmethod
      def merge_bias_tensor(
          lhs: torch.Tensor,
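The new filter helper re-indexes each processor's mask with the surviving request indices and drops entries whose re-indexed mask is all False. A tiny sketch of that re-indexing (illustrative values):

import torch

mask = torch.tensor([True, False, True, False])
new_indices = torch.tensor([1, 3])   # positions of the unfinished requests
filtered = mask[new_indices]         # tensor([False, False])
keep = bool(filtered.any())          # False -> this processor entry is dropped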
@@ -215,9 +292,76 @@

          return None

+     @staticmethod
+     def merge_custom_logit_processor(
+         lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+         rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+         bs1: int,
+         bs2: int,
+         device: str,
+     ):
+         if lhs is None and rhs is None:
+             return None
+         lhs, rhs = lhs or {}, rhs or {}
+
+         keys = set(lhs.keys()).union(set(rhs.keys()))
+         merged_dict = {}
+
+         for k in keys:
+             # Get the logit processor object
+             processor = lhs[k][0] if k in lhs else rhs[k][0]
+             # Get and merge the mask tensors from the two dicts
+             left_mask = (
+                 lhs[k][1]
+                 if k in lhs
+                 else torch.zeros(bs1, dtype=torch.bool, device=device)
+             )
+             right_mask = (
+                 rhs[k][1]
+                 if k in rhs
+                 else torch.zeros(bs2, dtype=torch.bool, device=device)
+             )
+             merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
+
+             assert merged_dict[k][1].shape[0] == bs1 + bs2, (
+                 f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
+                 f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
+                 f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
+                 f"\n{lhs=}\n{rhs=}"
+             )
+
+         return merged_dict
+
      def merge_batch(self, other: "SamplingBatchInfo"):
          self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

+         # Merge the logit bias tensor
+         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
+             self.logit_bias, other.logit_bias, len(self), len(other), self.device
+         )
+         # Merge the custom logit processors and custom params lists
+         if self.has_custom_logit_processor or other.has_custom_logit_processor:
+             # Merge the custom logit processors
+             self.custom_logit_processor = (
+                 SamplingBatchInfo.merge_custom_logit_processor(
+                     self.custom_logit_processor,
+                     other.custom_logit_processor,
+                     len(self),
+                     len(other),
+                     self.device,
+                 )
+             )
+             # Merge the custom params lists
+             self.custom_params = self.custom_params or [None] * len(self)
+             other.custom_params = other.custom_params or [None] * len(other)
+             self.custom_params.extend(other.custom_params)
+
+             # Set the flag to True if any of the two has custom logit processor
+             self.has_custom_logit_processor = True
+
+         # Note: becasue the __len()__ operator is defined on the temperatures tensor,
+         # please make sure any merge operation with len(self) or len(other) is done before
+         # the merge operation of the temperatures tensor below.
          for item in [
              "temperatures",
              "top_ps",
@@ -229,9 +373,6 @@
              setattr(self, item, torch.concat([self_val, other_val]))

          self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
-         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
-             self.logit_bias, other.logit_bias, len(self), len(other), self.device
-         )

          self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling
      def apply_logits_bias(self, logits: torch.Tensor):
@@ -245,11 +386,14 @@

          # repetition
          if self.scaling_penalties is not None:
-             logits[:] = torch.where(
-                 logits > 0,
-                 logits / self.scaling_penalties,
-                 logits * self.scaling_penalties,
-             )
+             if is_cuda:
+                 logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
+             else:
+                 logits[:] = torch.where(
+                     logits > 0,
+                     logits / self.scaling_penalties,
+                     logits * self.scaling_penalties,
+                 )

          # Apply regex vocab_mask
          if self.vocab_mask is not None:
@@ -13,7 +13,7 @@
  # ==============================================================================
  """Sampling parameters for text generation."""

- from typing import List, Optional, Union
+ from typing import Any, Dict, List, Optional, Union

  _SAMPLING_EPS = 1e-6

@@ -23,7 +23,7 @@ class SamplingParams:
      The sampling parameters.

      See docs/references/sampling_params.md or
-     https://sgl-project.github.io/references/sampling_params.html
+     https://docs.sglang.ai/references/sampling_params.html
      for the documentation.
      """

@@ -48,6 +48,7 @@
          no_stop_trim: bool = False,
          ignore_eos: bool = False,
          skip_special_tokens: bool = True,
+         custom_params: Optional[Dict[str, Any]] = None,
      ) -> None:
          self.temperature = temperature
          self.top_p = top_p
@@ -71,6 +72,7 @@
          self.json_schema = json_schema
          self.ebnf = ebnf
          self.no_stop_trim = no_stop_trim
+         self.custom_params = custom_params

          # Process some special cases
          if self.temperature < _SAMPLING_EPS:
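The new custom_params dict travels with each request's sampling parameters and is later handed to that request's custom logit processor as one entry of custom_param_list. A hedged sketch, assuming the remaining constructor arguments keep their defaults; the "blocked_token_id" key is illustrative only:

from sglang.srt.sampling.sampling_params import SamplingParams

# custom_params is opaque to the scheduler; it is simply stored and forwarded
# to the request's custom logit processor.
params = SamplingParams(temperature=0.7, custom_params={"blocked_token_id": 3})
assert params.custom_params == {"blocked_token_id": 3}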