sglang 0.3.5.post2__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +48 -20
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +71 -1
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/outlines_backend.py +15 -2
  8. sglang/srt/constrained/xgrammar_backend.py +22 -14
  9. sglang/srt/layers/activation.py +3 -0
  10. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  11. sglang/srt/layers/attention/triton_backend.py +9 -7
  12. sglang/srt/layers/custom_op_util.py +26 -0
  13. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  14. sglang/srt/layers/layernorm.py +4 -0
  15. sglang/srt/layers/logits_processor.py +10 -10
  16. sglang/srt/layers/sampler.py +4 -8
  17. sglang/srt/layers/torchao_utils.py +2 -0
  18. sglang/srt/managers/data_parallel_controller.py +74 -9
  19. sglang/srt/managers/detokenizer_manager.py +1 -0
  20. sglang/srt/managers/io_struct.py +27 -0
  21. sglang/srt/managers/schedule_batch.py +104 -38
  22. sglang/srt/managers/schedule_policy.py +5 -1
  23. sglang/srt/managers/scheduler.py +204 -54
  24. sglang/srt/managers/session_controller.py +62 -0
  25. sglang/srt/managers/tokenizer_manager.py +38 -0
  26. sglang/srt/managers/tp_worker.py +12 -1
  27. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  28. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  29. sglang/srt/model_executor/forward_batch_info.py +109 -15
  30. sglang/srt/model_executor/model_runner.py +99 -43
  31. sglang/srt/model_parallel.py +98 -0
  32. sglang/srt/models/deepseek_v2.py +147 -44
  33. sglang/srt/models/gemma2.py +9 -8
  34. sglang/srt/models/llava.py +1 -1
  35. sglang/srt/models/llavavid.py +1 -1
  36. sglang/srt/models/olmo.py +3 -3
  37. sglang/srt/models/phi3_small.py +447 -0
  38. sglang/srt/models/qwen2_vl.py +13 -6
  39. sglang/srt/models/torch_native_llama.py +94 -78
  40. sglang/srt/openai_api/adapter.py +6 -2
  41. sglang/srt/openai_api/protocol.py +1 -1
  42. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  43. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  44. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  45. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  47. sglang/srt/sampling/sampling_batch_info.py +58 -57
  48. sglang/srt/sampling/sampling_params.py +1 -1
  49. sglang/srt/server.py +27 -1
  50. sglang/srt/server_args.py +78 -62
  51. sglang/srt/utils.py +71 -52
  52. sglang/test/runners.py +25 -6
  53. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  54. sglang/test/test_utils.py +30 -19
  55. sglang/version.py +1 -1
  56. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  57. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/RECORD +60 -55
  58. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  59. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  60. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,31 @@ limitations under the License.
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
+# PyTorch Tensor Parallel Available for This Model
+"""
+This model supports tensor parallelism (TP) using the PyTorch tensor parallel package.
+Reference: https://pytorch.org/docs/stable/distributed.tensor.parallel.html
+
+Here is a quick example to enable TP:
+```python
+from sglang.srt.model_parallel import tensor_parallel
+
+device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
+tensor_parallel(model, device_mesh)
+```
+
+An end-to-end example can be found in `python/sglang/bench_one_batch.py`.
+You can run it with the following command:
+```bash
+$ python3 -m sglang.bench_one_batch --correct \
+  --model meta-llama/Meta-Llama-3-8B \
+  --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' \
+  --tensor-parallel-size 2 \
+  --disable-cuda-graph
+```
+We will enable CUDA graph support soon.
+"""
+
 import types
 from typing import Any, Dict, Iterable, Optional, Tuple
 
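The quick-start snippet in the docstring above assumes `torch`, a built `model`, and a launcher that sets up distributed state. A minimal self-contained sketch under those assumptions (`build_model` is a hypothetical stand-in for constructing `TorchNativeLlamaForCausalLM`; launching via `torchrun` is assumed so that `WORLD_SIZE` is set):

```python
import os

from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.model_parallel import tensor_parallel

# torchrun --nproc-per-node=2 this_script.py  ->  sets WORLD_SIZE/RANK
tp_size = int(os.environ["WORLD_SIZE"])
device_mesh = init_device_mesh("cuda", (tp_size,))  # also initializes the default group

model = build_model()  # hypothetical helper returning TorchNativeLlamaForCausalLM
tensor_parallel(model, device_mesh)
```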
@@ -24,7 +49,10 @@ import torch
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
@@ -41,35 +69,45 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
+tp_size = get_tensor_model_parallel_world_size()
+tp_rank = get_tensor_model_parallel_rank()
+
 
 def gate_up_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
-    loaded_shard_id: Optional[int] = None,
+    loaded_shard_id: int,
 ):
-    if loaded_shard_id is None:
-        shard_offsets: List[Tuple[int, int, int]] = []
-        for i, output_size in enumerate(self.output_sizes):
-            shard_offsets.append((i, current_shard_offset, output_size))
-            current_shard_offset += output_size
-        for shard_id, shard_offset, shard_size in shard_offsets:
-            loaded_weight_shard = loaded_weight.narrow(
-                output_dim, shard_offset, shard_size
-            )
-            self.weight_loader(param, loaded_weight_shard, shard_id)
-    else:
-        assert loaded_shard_id < len(self.output_sizes)
-        param_data = param.data
-        shard_size = loaded_weight.shape[0]
-        shard_offset = loaded_shard_id * shard_size
-        param_data = param_data.narrow(0, shard_offset, shard_size)
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
-    return
+    # shard_id: (shard_offset, shard_size)
+    gate_up_offsets = {}
+    current_shard_offset = 0
+    for i, output_size in enumerate(self.output_sizes):
+        # Everything shrinks by tp_size if TP enabled
+        output_size = output_size // tp_size
+        gate_up_offsets[i] = (current_shard_offset, output_size)
+        current_shard_offset += output_size
+    # Re-size the param to the size after TP
+    if current_shard_offset != param.shape[0]:
+        # The clone will free the original, full tensor
+        param.data = param.data.narrow(0, 0, current_shard_offset).clone()
+
+    # Now load gate or up
+    assert loaded_shard_id < len(self.output_sizes)
+    param_data = param.data
+    shard_offset, shard_size = gate_up_offsets[loaded_shard_id]
+    param_data = param_data.narrow(0, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
 
 
 class LlamaMLP(nn.Module):
+    _tp_plan = {
+        "gate_up_proj": "Colwise_Sharded",
+        "down_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         hidden_size: int,
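The rewritten loader shards each fused sub-weight row-wise across ranks and copies only the local slice, which is what `narrow(0, tp_rank * shard_size, shard_size)` computes. A toy illustration of that arithmetic with made-up small shapes (not sglang code):

```python
import torch

tp_size, tp_rank = 2, 1
output_sizes = [8, 8]  # rows of the fused [gate; up] weight in the full checkpoint
hidden = 4

full_gate = torch.randn(output_sizes[0], hidden)  # checkpoint tensor for shard_id 0
shard_size = output_sizes[0] // tp_size           # rows this rank keeps per sub-weight

# rank 1 keeps rows [4:8] of gate; the loader then writes them into the
# gate region (offset 0) of the already-narrowed local parameter
local_slice = full_gate.narrow(0, tp_rank * shard_size, shard_size)
assert local_slice.shape == (shard_size, hidden)
```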
@@ -104,62 +142,44 @@ class LlamaMLP(nn.Module):
         return x
 
 
-def _get_shard_offset_mapping(self, loaded_shard_id: str):
-    shard_offset_mapping = {
-        "q": 0,
-        "k": self.num_heads * self.head_size,
-        "v": (self.num_heads + self.num_kv_heads) * self.head_size,
-        "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
-    }
-    return shard_offset_mapping.get(loaded_shard_id)
-
-
-def _get_shard_size_mapping(self, loaded_shard_id: str):
-    shard_size_mapping = {
-        "q": self.num_heads * self.head_size,
-        "k": self.num_kv_heads * self.head_size,
-        "v": self.num_kv_heads * self.head_size,
-    }
-    return shard_size_mapping.get(loaded_shard_id)
-
-
 def qkv_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
-    loaded_shard_id: Optional[str] = None,
+    loaded_shard_id: str,
 ):
-    if loaded_shard_id is None:
-        shard_offsets = [
-            # (shard_id, shard_offset, shard_size)
-            ("q", 0, self.total_num_heads * self.head_size),
-            (
-                "k",
-                self.total_num_heads * self.head_size,
-                self.total_num_kv_heads * self.head_size,
-            ),
-            (
-                "v",
-                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
-                self.total_num_kv_heads * self.head_size,
-            ),
-        ]
-        for shard_id, shard_offset, shard_size in shard_offsets:
-            loaded_weight_shard = loaded_weight.narrow(
-                param.output_dim, shard_offset, shard_size
-            )
-            self.weight_loader(param, loaded_weight_shard, shard_id)
-    else:
-        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
-        shard_size = self._get_shard_size_mapping(loaded_shard_id)
-        param_data = param.data
-        param_data = param_data.narrow(0, shard_offset, shard_size)
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
-    return
+    num_heads = self.num_heads // tp_size
+    num_kv_heads = self.num_kv_heads // tp_size
+    # shard_id: (shard_offset, shard_size)
+    qkv_offsets = {
+        "q": (0, num_heads * self.head_size),
+        "k": (num_heads * self.head_size, num_kv_heads * self.head_size),
+        "v": (
+            (num_heads + num_kv_heads) * self.head_size,
+            num_kv_heads * self.head_size,
+        ),
+    }
+    total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]
+    # Re-size the param to the size after TP
+    if total_size != param.shape[0]:
+        # The clone will free the original, full tensor
+        param.data = param.data.narrow(0, 0, total_size).clone()
+
+    # Now load q, k or v
+    shard_offset, shard_size = qkv_offsets[loaded_shard_id]
+    param_data = param.data
+    param_data = param_data.narrow(0, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
 
 
 class LlamaAttention(nn.Module):
+    _tp_plan = {
+        "qkv_proj": "Colwise_Sharded",
+        "o_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         config: LlamaConfig,
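To make the offset arithmetic concrete, here are the values the new loader computes at `tp_size = 2` for a Llama-3-8B-style configuration (32 query heads, 8 KV heads, head size 128; the shapes are illustrative, not taken from the diff):

```python
head_size = 128
num_heads = 32 // 2     # 16 query heads per rank
num_kv_heads = 8 // 2   # 4 KV heads per rank

qkv_offsets = {
    "q": (0, num_heads * head_size),                         # (0, 2048)
    "k": (num_heads * head_size, num_kv_heads * head_size),  # (2048, 512)
    "v": (
        (num_heads + num_kv_heads) * head_size,              # 2560
        num_kv_heads * head_size,                            # 512
    ),
}
total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]       # 3072 rows per rank
```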
@@ -176,7 +196,6 @@ class LlamaAttention(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
@@ -205,20 +224,12 @@ class LlamaAttention(nn.Module):
             (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
             bias=False,
         )
-        self.qkv_proj.total_num_heads = self.total_num_heads
         self.qkv_proj.head_size = self.head_dim
-        self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
         self.qkv_proj.num_heads = self.total_num_heads
         self.qkv_proj.num_kv_heads = self.total_num_kv_heads
         self.qkv_proj.weight_loader = types.MethodType(
             qkv_proj_weight_loader, self.qkv_proj
         )
-        self.qkv_proj._get_shard_offset_mapping = types.MethodType(
-            _get_shard_offset_mapping, self.qkv_proj
-        )
-        self.qkv_proj._get_shard_size_mapping = types.MethodType(
-            _get_shard_size_mapping, self.qkv_proj
-        )
         self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
         self.qkv_proj.weight.output_dim = 0
         self.o_proj = torch.nn.Linear(
@@ -385,10 +396,15 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
+        self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
+        # turning off autotune for fp8dq since it doesn't give speedup and
+        # increases compile time significantly
+        torch._inductor.config.max_autotune_gemm_backends = "ATEN"
+
     @torch.no_grad()
     def forward(
         self,
@@ -989,11 +989,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
                 output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
             )
             token_logprobs = []
-            for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs):
+            for token_idx, (token, logprob) in enumerate(
+                zip(logprobs.tokens, logprobs.token_logprobs)
+            ):
                 token_bytes = list(token.encode("utf-8"))
                 top_logprobs = []
                 if logprobs.top_logprobs:
-                    for top_token, top_logprob in logprobs.top_logprobs[0].items():
+                    for top_token, top_logprob in logprobs.top_logprobs[
+                        token_idx
+                    ].items():
                         top_token_bytes = list(top_token.encode("utf-8"))
                         top_logprobs.append(
                             TopLogprob(
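This hunk fixes an indexing bug: `top_logprobs` holds one dict of alternatives per output position, so reading `[0]` repeated the first position's alternatives for every token. A toy illustration with made-up values:

```python
tokens = ["Hello", ",", " world"]
top_logprobs = [{"Hello": -0.1}, {",": -0.2}, {" world": -0.3}]  # one dict per position

for token_idx, token in enumerate(tokens):
    alternatives = top_logprobs[token_idx]  # before the fix: top_logprobs[0] every time
```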
@@ -236,7 +236,7 @@ ChatCompletionMessageContentPart = Union[
 
 
 class ChatCompletionMessageGenericParam(BaseModel):
-    role: Literal["system", "assistant"]
+    role: Literal["system", "assistant", "tool"]
     content: Union[str, List[ChatCompletionMessageContentTextPart]]
 
 
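With `"tool"` added to the literal, a generic message like the following should now validate (a sketch; the field values are made up):

```python
msg = ChatCompletionMessageGenericParam(
    role="tool",  # rejected by the 0.3.5.post2 schema
    content="The weather in Paris is 18°C.",
)
```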
@@ -1,40 +1,34 @@
 import abc
 import dataclasses
-import typing
+from typing import List, Set, Type, Union
 
 import torch
 
 
 @dataclasses.dataclass
 class _ReqLike:
-    origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]
+    origin_input_ids: List[int]
 
 
 @dataclasses.dataclass
 class _BatchLike:
-    reqs: typing.List[_ReqLike]
+    reqs: List[_ReqLike]
 
     def batch_size(self):
         return len(self.reqs)
 
 
 class BatchedPenalizerOrchestrator:
-    batch: _BatchLike
-    device: str
-    vocab_size: int
-    penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
-
     def __init__(
         self,
         vocab_size: int,
         batch: _BatchLike,
         device: str,
-        Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
+        Penalizers: Set[Type["_BatchedPenalizer"]],
     ):
         self.vocab_size = vocab_size
         self.batch = batch
         self.device = device
-
         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
 
         is_required = False
@@ -43,10 +37,12 @@ class BatchedPenalizerOrchestrator:
             is_required |= pen_is_required
         self.is_required = is_required
 
+        input_ids = [
+            torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
+            for req in self.reqs()
+        ]
         if self.is_required:
-            self.cumulate_input_tokens(
-                input_ids=[req.origin_input_ids for req in self.reqs()]
-            )
+            self.cumulate_input_tokens(input_ids=input_ids)
 
     def reqs(self):
         return self.batch.reqs
@@ -54,34 +50,24 @@ class BatchedPenalizerOrchestrator:
     def batch_size(self):
         return self.batch.batch_size()
 
-    def cumulate_input_tokens(
-        self,
-        input_ids: typing.Union[
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
-    ):
+    def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
         """
         Feed the input tokens to the penalizers.
 
         Args:
-            input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
+            input_ids (List[torch.Tensor]): The input tokens.
         """
         token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
 
         for penalizer in self.penalizers.values():
             penalizer.cumulate_input_tokens(input_ids=token_ids)
 
-    def cumulate_output_tokens(
-        self,
-        output_ids: typing.Union[
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
-    ):
+    def cumulate_output_tokens(self, output_ids: torch.Tensor):
         """
         Feed the output tokens to the penalizers.
 
         Args:
-            output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
+            output_ids (torch.Tensor): The output tokens.
         """
         if not self.is_required:
             return
@@ -112,14 +98,14 @@ class BatchedPenalizerOrchestrator:
 
     def filter(
         self,
-        indices_to_keep: typing.List[int],
+        indices_to_keep: List[int],
         indices_tensor_to_keep: torch.Tensor = None,
     ):
         """
         Filter the penalizers based on the indices to keep in the batch.
 
         Args:
-            indices_to_keep (typing.List[int]): List of indices to keep in the batch.
+            indices_to_keep (List[int]): List of indices to keep in the batch.
             indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
         """
         if not self.is_required:
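The orchestrator's `filter` ultimately reduces to fancy-indexing every per-request tensor with `indices_tensor_to_keep`, as the penalizer `_filter` implementations later in this diff show. A standalone sketch of the pattern:

```python
import torch

penalties = torch.tensor([0.1, 0.2, 0.3, 0.4])  # one value per request in the batch
indices_tensor_to_keep = torch.tensor([0, 2])   # requests that survive the filter
penalties = penalties[indices_tensor_to_keep]   # tensor([0.1000, 0.3000])
```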
@@ -174,32 +160,18 @@ class _TokenIDs:
 
     Attributes:
         orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
-        token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
+        token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
         cached_counts (torch.Tensor): The cached occurrence count tensor.
     """
 
-    orchestrator: BatchedPenalizerOrchestrator
-    token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
-    cached_counts: torch.Tensor = None
-
     def __init__(
         self,
         orchestrator: BatchedPenalizerOrchestrator,
-        token_ids: typing.Union[
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
+        token_ids: Union[torch.Tensor, List[torch.Tensor]],
     ):
         self.orchestrator = orchestrator
-
-        if not isinstance(token_ids[0], torch.Tensor):
-            token_ids = [
-                torch.tensor(
-                    data=ids, dtype=torch.int64, device=self.orchestrator.device
-                )
-                for ids in token_ids
-            ]
-
         self.token_ids = token_ids
+        self.cached_counts = None
 
     def occurrence_count(self) -> torch.Tensor:
         """
@@ -213,30 +185,34 @@ class _TokenIDs:
 
         token_ids = self.token_ids
 
-        if isinstance(token_ids, torch.Tensor):
-            token_ids = token_ids.unsqueeze(1)
-
-        # needs to be long to be used as index in scatter_add
-        if token_ids.dtype != torch.int64:
-            token_ids = token_ids.to(torch.int64)
-
-        padded_token_ids = torch.nn.utils.rnn.pad_sequence(
-            sequences=token_ids,
-            batch_first=True,
-            padding_value=self.orchestrator.vocab_size,
-        )
-
-        self.cached_counts = torch.zeros(
-            size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
-            dtype=torch.int64,
-            device=self.orchestrator.device,
-        ).scatter_add_(
-            dim=1,
-            index=padded_token_ids,
-            src=torch.ones_like(padded_token_ids),
-        )[
-            :, : self.orchestrator.vocab_size
-        ]
+        if isinstance(token_ids, list):
+            # TODO: optimize this part
+            padded_token_ids = torch.nn.utils.rnn.pad_sequence(
+                sequences=token_ids,
+                batch_first=True,
+                padding_value=self.orchestrator.vocab_size,
+            )
+            self.cached_counts = torch.zeros(
+                size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
+                dtype=torch.int64,
+                device=self.orchestrator.device,
+            ).scatter_add_(
+                dim=1,
+                index=padded_token_ids,
+                src=torch.ones_like(padded_token_ids),
+            )[
+                :, : self.orchestrator.vocab_size
+            ]
+        else:
+            # TODO: optimize this part. We do not need to create this big tensor every time.
+            # We can directly apply the results on the logits.
+            self.cached_counts = torch.zeros(
+                size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
+                device=self.orchestrator.device,
+            )
+            self.cached_counts[
+                torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
+            ] = 1
 
         return self.cached_counts
 
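The list branch above can be exercised in isolation; a small self-contained sketch of the pad-then-`scatter_add_` counting trick, where the extra vocabulary column absorbs the padding hits and is then sliced off:

```python
import torch

vocab_size = 8
seqs = [torch.tensor([1, 2, 2]), torch.tensor([3])]  # ragged token IDs for two requests

padded = torch.nn.utils.rnn.pad_sequence(
    seqs, batch_first=True, padding_value=vocab_size  # pad with an out-of-vocab ID
)
counts = torch.zeros(len(seqs), vocab_size + 1, dtype=torch.int64).scatter_add_(
    1, padded, torch.ones_like(padded)
)[:, :vocab_size]

assert counts[0, 2] == 2 and counts[1, 3] == 1  # padding hits landed in the cut column
```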
@@ -246,11 +222,9 @@ class _BatchedPenalizer(abc.ABC):
     An abstract class for a batched penalizer.
     """
 
-    orchestrator: BatchedPenalizerOrchestrator
-    _is_prepared: bool = False
-
     def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
         self.orchestrator = orchestrator
+        self._is_prepared = False
 
     def is_prepared(self) -> bool:
         return self._is_prepared
@@ -293,9 +267,7 @@ class _BatchedPenalizer(abc.ABC):
 
         return self._apply(logits=logits)
 
-    def filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         if not self.is_prepared():
             return
 
@@ -360,9 +332,7 @@ class _BatchedPenalizer(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         """
         Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
         """
@@ -1,8 +1,8 @@
-import typing
+from typing import List
 
 import torch
 
-from ..orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
 
 
 class BatchedFrequencyPenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
         )
 
     def _teardown(self):
-        del self.frequency_penalties
-        del self.cumulated_frequency_penalties
-
         self.frequency_penalties = None
         self.cumulated_frequency_penalties = None
 
@@ -62,9 +59,7 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
         logits -= self.cumulated_frequency_penalties
         return logits
 
-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.frequency_penalties = self.frequency_penalties[indices_tensor_to_keep]
         self.cumulated_frequency_penalties = self.cumulated_frequency_penalties[
             indices_tensor_to_keep
@@ -1,8 +1,8 @@
-import typing
+from typing import List
 
 import torch
 
-from ..orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
 
 
 class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
@@ -70,10 +70,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         )
 
     def _teardown(self):
-        del self.min_new_tokens
-        del self.stop_token_penalties
-        del self.len_output_tokens
-
         self.min_new_tokens = None
         self.stop_token_penalties = None
         self.len_output_tokens = None
@@ -89,9 +85,7 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         logits[mask] += self.stop_token_penalties[mask]
         return logits
 
-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.min_new_tokens = self.min_new_tokens[indices_tensor_to_keep]
         self.stop_token_penalties = self.stop_token_penalties[indices_tensor_to_keep]
         self.len_output_tokens = self.len_output_tokens[indices_tensor_to_keep]
@@ -1,8 +1,8 @@
-import typing
+from typing import List
 
 import torch
 
-from ..orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
 
 
 class BatchedPresencePenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
         )
 
     def _teardown(self):
-        del self.presence_penalties
-        del self.cumulated_presence_penalties
-
         self.presence_penalties = None
         self.cumulated_presence_penalties = None
 
@@ -61,9 +58,7 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
         logits -= self.cumulated_presence_penalties
         return logits
 
-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.presence_penalties = self.presence_penalties[indices_tensor_to_keep]
         self.cumulated_presence_penalties = self.cumulated_presence_penalties[
             indices_tensor_to_keep
@@ -1,8 +1,8 @@
-import typing
+from typing import List
 
 import torch
 
-from ..orchestrator import _BatchedPenalizer, _TokenIDs
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
 
 
 class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -44,9 +44,6 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
         )
 
     def _teardown(self):
-        del self.repetition_penalties
-        del self.cumulated_repetition_penalties
-
         self.repetition_penalties = None
         self.cumulated_repetition_penalties = None
 
@@ -65,9 +62,7 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
             logits * self.cumulated_repetition_penalties,
         )
 
-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
         self.cumulated_repetition_penalties = self.cumulated_repetition_penalties[
             indices_tensor_to_keep