sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +337 -0
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +115 -31
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/base_grammar_backend.py +4 -3
  8. sglang/srt/constrained/outlines_backend.py +39 -26
  9. sglang/srt/constrained/xgrammar_backend.py +58 -14
  10. sglang/srt/layers/activation.py +3 -0
  11. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  12. sglang/srt/layers/attention/triton_backend.py +9 -7
  13. sglang/srt/layers/custom_op_util.py +26 -0
  14. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  15. sglang/srt/layers/fused_moe/patch.py +4 -2
  16. sglang/srt/layers/layernorm.py +4 -0
  17. sglang/srt/layers/logits_processor.py +10 -10
  18. sglang/srt/layers/sampler.py +4 -8
  19. sglang/srt/layers/torchao_utils.py +2 -0
  20. sglang/srt/managers/data_parallel_controller.py +74 -9
  21. sglang/srt/managers/detokenizer_manager.py +1 -14
  22. sglang/srt/managers/io_struct.py +27 -0
  23. sglang/srt/managers/schedule_batch.py +104 -38
  24. sglang/srt/managers/schedule_policy.py +5 -1
  25. sglang/srt/managers/scheduler.py +210 -56
  26. sglang/srt/managers/session_controller.py +62 -0
  27. sglang/srt/managers/tokenizer_manager.py +38 -0
  28. sglang/srt/managers/tp_worker.py +12 -1
  29. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  30. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  31. sglang/srt/model_executor/forward_batch_info.py +109 -15
  32. sglang/srt/model_executor/model_runner.py +102 -43
  33. sglang/srt/model_parallel.py +98 -0
  34. sglang/srt/models/deepseek_v2.py +147 -44
  35. sglang/srt/models/gemma2.py +9 -8
  36. sglang/srt/models/llava.py +1 -1
  37. sglang/srt/models/llavavid.py +1 -1
  38. sglang/srt/models/olmo.py +3 -3
  39. sglang/srt/models/phi3_small.py +447 -0
  40. sglang/srt/models/qwen2_vl.py +13 -6
  41. sglang/srt/models/torch_native_llama.py +94 -78
  42. sglang/srt/openai_api/adapter.py +11 -4
  43. sglang/srt/openai_api/protocol.py +30 -27
  44. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  45. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  47. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  48. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  49. sglang/srt/sampling/sampling_batch_info.py +58 -57
  50. sglang/srt/sampling/sampling_params.py +3 -3
  51. sglang/srt/server.py +29 -2
  52. sglang/srt/server_args.py +97 -60
  53. sglang/srt/utils.py +103 -51
  54. sglang/test/runners.py +25 -6
  55. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  56. sglang/test/test_utils.py +33 -22
  57. sglang/version.py +1 -1
  58. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  59. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
  60. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  61. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  62. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,31 @@ limitations under the License.
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
+# PyTorch Tensor Parallel Available for This Model
+"""
+This model supports tensor parallelism (TP) using the PyTorch tensor parallel package.
+Reference: https://pytorch.org/docs/stable/distributed.tensor.parallel.html
+
+Here is a quick example to enable TP:
+```python
+from sglang.srt.model_parallel import tensor_parallel
+
+device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
+tensor_parallel(model, device_mesh)
+```
+
+An end-to-end example can be found in `python/sglang/bench_one_batch.py`.
+You can run it with the following command:
+```bash
+$ python3 -m sglang.bench_one_batch --correct \
+  --model meta-llama/Meta-Llama-3-8B \
+  --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' \
+  --tensor-parallel-size 2 \
+  --disable-cuda-graph
+```
+We will enable CUDA Graph support soon.
+"""
+
 import types
 from typing import Any, Dict, Iterable, Optional, Tuple
 
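The `_tp_plan` attributes added to `LlamaMLP` and `LlamaAttention` later in this file name a sharding style per projection, and the new `sglang/srt/model_parallel.py` module added in this release (not shown here) consumes them. As a rough sketch only, assuming the public `torch.distributed.tensor.parallel` API: the mapping of the plan strings to `ColwiseParallel`/`RowwiseParallel` and the helper name `apply_tp_plan` are illustrative assumptions, not the shipped implementation.

```python
# Hedged sketch: applying a per-module _tp_plan with the public PyTorch TP API.
# Assumes torch.distributed is already initialized with tp_size ranks.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

# Assumed mapping from the plan strings seen in this diff to parallel styles.
STYLE_MAP = {
    "Colwise": ColwiseParallel,
    "Colwise_Sharded": ColwiseParallel,  # assumption: treated as column-wise here
    "Rowwise": RowwiseParallel,
}


def apply_tp_plan(model: torch.nn.Module, tp_size: int) -> None:
    """Parallelize every submodule that declares a _tp_plan attribute."""
    device_mesh = init_device_mesh("cuda", (tp_size,))
    for module in model.modules():
        plan = getattr(module, "_tp_plan", None)
        if plan:
            parallelize_module(
                module,
                device_mesh,
                {name: STYLE_MAP[style]() for name, style in plan.items()},
            )
```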
@@ -24,7 +49,10 @@ import torch
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
@@ -41,35 +69,45 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
+tp_size = get_tensor_model_parallel_world_size()
+tp_rank = get_tensor_model_parallel_rank()
+
 
 def gate_up_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
-    loaded_shard_id: Optional[int] = None,
+    loaded_shard_id: int,
 ):
-    if loaded_shard_id is None:
-        shard_offsets: List[Tuple[int, int, int]] = []
-        for i, output_size in enumerate(self.output_sizes):
-            shard_offsets.append((i, current_shard_offset, output_size))
-            current_shard_offset += output_size
-        for shard_id, shard_offset, shard_size in shard_offsets:
-            loaded_weight_shard = loaded_weight.narrow(
-                output_dim, shard_offset, shard_size
-            )
-            self.weight_loader(param, loaded_weight_shard, shard_id)
-    else:
-        assert loaded_shard_id < len(self.output_sizes)
-        param_data = param.data
-        shard_size = loaded_weight.shape[0]
-        shard_offset = loaded_shard_id * shard_size
-        param_data = param_data.narrow(0, shard_offset, shard_size)
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
-        return
+    # shard_id: (shard_offset, shard_size)
+    gate_up_offsets = {}
+    current_shard_offset = 0
+    for i, output_size in enumerate(self.output_sizes):
+        # Everything shrinks by tp_size if TP enabled
+        output_size = output_size // tp_size
+        gate_up_offsets[i] = (current_shard_offset, output_size)
+        current_shard_offset += output_size
+    # Re-size the param to the size after TP
+    if current_shard_offset != param.shape[0]:
+        # The clone will free the original, full tensor
+        param.data = param.data.narrow(0, 0, current_shard_offset).clone()
+
+    # Now load gate or up
+    assert loaded_shard_id < len(self.output_sizes)
+    param_data = param.data
+    shard_offset, shard_size = gate_up_offsets[loaded_shard_id]
+    param_data = param_data.narrow(0, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
 
 
 class LlamaMLP(nn.Module):
+    _tp_plan = {
+        "gate_up_proj": "Colwise_Sharded",
+        "down_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         hidden_size: int,
@@ -104,62 +142,44 @@ class LlamaMLP(nn.Module):
         return x
 
 
-def _get_shard_offset_mapping(self, loaded_shard_id: str):
-    shard_offset_mapping = {
-        "q": 0,
-        "k": self.num_heads * self.head_size,
-        "v": (self.num_heads + self.num_kv_heads) * self.head_size,
-        "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
-    }
-    return shard_offset_mapping.get(loaded_shard_id)
-
-
-def _get_shard_size_mapping(self, loaded_shard_id: str):
-    shard_size_mapping = {
-        "q": self.num_heads * self.head_size,
-        "k": self.num_kv_heads * self.head_size,
-        "v": self.num_kv_heads * self.head_size,
-    }
-    return shard_size_mapping.get(loaded_shard_id)
-
-
 def qkv_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
-    loaded_shard_id: Optional[str] = None,
+    loaded_shard_id: str,
 ):
-    if loaded_shard_id is None:
-        shard_offsets = [
-            # (shard_id, shard_offset, shard_size)
-            ("q", 0, self.total_num_heads * self.head_size),
-            (
-                "k",
-                self.total_num_heads * self.head_size,
-                self.total_num_kv_heads * self.head_size,
-            ),
-            (
-                "v",
-                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
-                self.total_num_kv_heads * self.head_size,
-            ),
-        ]
-        for shard_id, shard_offset, shard_size in shard_offsets:
-            loaded_weight_shard = loaded_weight.narrow(
-                param.output_dim, shard_offset, shard_size
-            )
-            self.weight_loader(param, loaded_weight_shard, shard_id)
-    else:
-        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
-        shard_size = self._get_shard_size_mapping(loaded_shard_id)
-        param_data = param.data
-        param_data = param_data.narrow(0, shard_offset, shard_size)
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
-        return
+    num_heads = self.num_heads // tp_size
+    num_kv_heads = self.num_kv_heads // tp_size
+    # shard_id: (shard_offset, shard_size)
+    qkv_offsets = {
+        "q": (0, num_heads * self.head_size),
+        "k": (num_heads * self.head_size, num_kv_heads * self.head_size),
+        "v": (
+            (num_heads + num_kv_heads) * self.head_size,
+            num_kv_heads * self.head_size,
+        ),
+    }
+    total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]
+    # Re-size the param to the size after TP
+    if total_size != param.shape[0]:
+        # The clone will free the original, full tensor
+        param.data = param.data.narrow(0, 0, total_size).clone()
+
+    # Now load q, k or v
+    shard_offset, shard_size = qkv_offsets[loaded_shard_id]
+    param_data = param.data
+    param_data = param_data.narrow(0, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
 
 
 class LlamaAttention(nn.Module):
+    _tp_plan = {
+        "qkv_proj": "Colwise_Sharded",
+        "o_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         config: LlamaConfig,
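The rewritten loaders above no longer keep the full checkpoint tensor around: they shrink the parameter to its post-TP size and copy in only this rank's slice via `loaded_weight.narrow(0, tp_rank * shard_size, shard_size)`. Below is a standalone toy illustration of the same slicing arithmetic; the head counts and sizes are made up and nothing from sglang is imported.

```python
# Toy illustration of the per-rank slicing performed by qkv_proj_weight_loader.
import torch

tp_size, tp_rank = 2, 1                      # pretend we are rank 1 of 2
total_num_heads, head_size = 8, 16
hidden_size = total_num_heads * head_size

num_heads = total_num_heads // tp_size       # 4 heads per rank after TP

# Full (unsharded) checkpoint weight for the q projection: [q_out, hidden]
full_q_weight = torch.randn(total_num_heads * head_size, hidden_size)

# Per-rank shard: the same narrow() call as in the loader, applied to "q".
shard_size = num_heads * head_size
q_shard = full_q_weight.narrow(0, tp_rank * shard_size, shard_size)

assert q_shard.shape == (shard_size, hidden_size)   # (64, 128)
```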
@@ -176,7 +196,6 @@ class LlamaAttention(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
@@ -205,20 +224,12 @@
             (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
             bias=False,
         )
-        self.qkv_proj.total_num_heads = self.total_num_heads
         self.qkv_proj.head_size = self.head_dim
-        self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
         self.qkv_proj.num_heads = self.total_num_heads
         self.qkv_proj.num_kv_heads = self.total_num_kv_heads
         self.qkv_proj.weight_loader = types.MethodType(
             qkv_proj_weight_loader, self.qkv_proj
         )
-        self.qkv_proj._get_shard_offset_mapping = types.MethodType(
-            _get_shard_offset_mapping, self.qkv_proj
-        )
-        self.qkv_proj._get_shard_size_mapping = types.MethodType(
-            _get_shard_size_mapping, self.qkv_proj
-        )
         self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
         self.qkv_proj.weight.output_dim = 0
         self.o_proj = torch.nn.Linear(
@@ -385,10 +396,15 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
+        self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
+        # turning off autotune for fp8dq since it doesn't give speedup and
+        # increases compile time significantly
+        torch._inductor.config.max_autotune_gemm_backends = "ATEN"
+
     @torch.no_grad()
     def forward(
         self,
@@ -516,8 +516,9 @@ def v1_generate_request(
                 "regex": request.regex,
                 "json_schema": request.json_schema,
                 "n": request.n,
-                "ignore_eos": request.ignore_eos,
                 "no_stop_trim": request.no_stop_trim,
+                "ignore_eos": request.ignore_eos,
+                "skip_special_tokens": request.skip_special_tokens,
             }
         )
         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
@@ -928,7 +929,9 @@ def v1_chat_generate_request(
             "repetition_penalty": request.repetition_penalty,
             "regex": request.regex,
             "n": request.n,
+            "no_stop_trim": request.no_stop_trim,
             "ignore_eos": request.ignore_eos,
+            "skip_special_tokens": request.skip_special_tokens,
         }
         if request.response_format and request.response_format.type == "json_schema":
             sampling_params["json_schema"] = convert_json_schema_to_str(
@@ -986,11 +989,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
                 output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
             )
             token_logprobs = []
-            for token, logprob in zip(logprobs.tokens, logprobs.token_logprobs):
+            for token_idx, (token, logprob) in enumerate(
+                zip(logprobs.tokens, logprobs.token_logprobs)
+            ):
                 token_bytes = list(token.encode("utf-8"))
                 top_logprobs = []
                 if logprobs.top_logprobs:
-                    for top_token, top_logprob in logprobs.top_logprobs[0].items():
+                    for top_token, top_logprob in logprobs.top_logprobs[
+                        token_idx
+                    ].items():
                         top_token_bytes = list(top_token.encode("utf-8"))
                         top_logprobs.append(
                             TopLogprob(
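The loop rewrite above fixes the top-logprobs alignment: the old code read `logprobs.top_logprobs[0]` for every generated token, so each position reported the first position's alternatives. A tiny standalone illustration of the corrected indexing, with made-up data (one top-logprob dict per token):

```python
# Made-up data illustrating the indexing fix: one top-logprob dict per token.
tokens = ["Hello", ",", " world"]
token_logprobs = [-0.11, -0.52, -0.23]
top_logprobs = [
    {"Hello": -0.11, "Hi": -2.31},
    {",": -0.52, "!": -1.74},
    {" world": -0.23, " there": -2.02},
]

for token_idx, (token, logprob) in enumerate(zip(tokens, token_logprobs)):
    # New behavior: index by token_idx instead of always reading [0].
    alternatives = top_logprobs[token_idx]
    assert token in alternatives
```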
@@ -1166,7 +1173,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                     is_first = False
                     choice_data = ChatCompletionResponseStreamChoice(
                         index=index,
-                        delta=DeltaMessage(role="assistant"),
+                        delta=DeltaMessage(role="assistant", content=""),
                         finish_reason=(
                             finish_reason["type"] if finish_reason else ""
                         ),
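With this change the first streamed chunk carries an explicit empty `content` alongside the `assistant` role, presumably so clients that concatenate `delta.content` never see a missing or null field. Roughly, the serialized delta of the first chunk changes as sketched below; the surrounding fields are typical OpenAI-compatible values, not taken from this diff.

```python
# Illustrative only: shape of the first /v1/chat/completions stream chunk.
first_chunk = {
    "object": "chat.completion.chunk",
    "choices": [
        {
            "index": 0,
            "delta": {"role": "assistant", "content": ""},  # was: {"role": "assistant"}
            "finish_reason": "",
        }
    ],
}
```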
@@ -36,7 +36,7 @@ class ModelList(BaseModel):
     """Model list consists of model cards."""
 
     object: str = "list"
-    data: List[ModelCard] = []
+    data: List[ModelCard] = Field(default_factory=list)
 
 
 class ErrorResponse(BaseModel):
@@ -143,7 +143,7 @@ class BatchResponse(BaseModel):
     expired_at: Optional[int] = None
     cancelling_at: Optional[int] = None
     cancelled_at: Optional[int] = None
-    request_counts: dict = {"total": 0, "completed": 0, "failed": 0}
+    request_counts: Optional[dict] = None
     metadata: Optional[dict] = None
 
 
@@ -153,30 +153,31 @@ class CompletionRequest(BaseModel):
     model: str
     prompt: Union[List[int], List[List[int]], str, List[str]]
     best_of: Optional[int] = None
-    echo: Optional[bool] = False
-    frequency_penalty: Optional[float] = 0.0
+    echo: bool = False
+    frequency_penalty: float = 0.0
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: Optional[int] = None
-    max_tokens: Optional[int] = 16
+    max_tokens: int = 16
     n: int = 1
-    presence_penalty: Optional[float] = 0.0
+    presence_penalty: float = 0.0
     seed: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
-    stream: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+    stream: bool = False
     stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
-    temperature: Optional[float] = 1.0
-    top_p: Optional[float] = 1.0
+    temperature: float = 1.0
+    top_p: float = 1.0
     user: Optional[str] = None
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
-    regex: Optional[str] = None
     json_schema: Optional[str] = None
-    ignore_eos: bool = False
+    regex: Optional[str] = None
     min_tokens: int = 0
-    repetition_penalty: Optional[float] = 1.0
-    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
-    no_stop_trim: Union[bool, List[bool]] = False
+    repetition_penalty: float = 1.0
+    stop_token_ids: Optional[List[int]] = None
+    no_stop_trim: bool = False
+    ignore_eos: bool = False
+    skip_special_tokens: bool = True
 
 
 class CompletionResponseChoice(BaseModel):
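The reshuffled `CompletionRequest` drops `Optional[...]` where the default is always a concrete value, replaces `Field(default_factory=list)` defaults with `None`, and adds `skip_special_tokens` (plus a plain-`bool` `no_stop_trim`) to the SRT-only extras. A minimal usage sketch; the field names come from the diff above, the argument values are illustrative.

```python
# Minimal sketch: constructing a CompletionRequest with the new SRT-only extras.
# Assumes sglang 0.3.6 is installed.
from sglang.srt.openai_api.protocol import CompletionRequest

req = CompletionRequest(
    model="meta-llama/Meta-Llama-3-8B",
    prompt="Hello, my name is",
    max_tokens=32,
    temperature=0.7,
    # SRT-specific extras (ignored by upstream OpenAI):
    ignore_eos=False,
    no_stop_trim=False,
    skip_special_tokens=True,  # new in 0.3.6
)
print(req)
```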
@@ -235,7 +236,7 @@ ChatCompletionMessageContentPart = Union[
 
 
 class ChatCompletionMessageGenericParam(BaseModel):
-    role: Literal["system", "assistant"]
+    role: Literal["system", "assistant", "tool"]
     content: Union[str, List[ChatCompletionMessageContentTextPart]]
 
 
@@ -259,28 +260,30 @@ class ChatCompletionRequest(BaseModel):
     # https://platform.openai.com/docs/api-reference/chat/create
     messages: List[ChatCompletionMessageParam]
     model: str
-    frequency_penalty: Optional[float] = 0.0
+    frequency_penalty: float = 0.0
     logit_bias: Optional[Dict[str, float]] = None
-    logprobs: Optional[bool] = False
+    logprobs: bool = False
     top_logprobs: Optional[int] = None
     max_tokens: Optional[int] = None
-    n: Optional[int] = 1
-    presence_penalty: Optional[float] = 0.0
+    n: int = 1
+    presence_penalty: float = 0.0
     response_format: Optional[ResponseFormat] = None
     seed: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
-    stream: Optional[bool] = False
+    stop: Optional[Union[str, List[str]]] = None
+    stream: bool = False
     stream_options: Optional[StreamOptions] = None
-    temperature: Optional[float] = 0.7
-    top_p: Optional[float] = 1.0
+    temperature: float = 0.7
+    top_p: float = 1.0
     user: Optional[str] = None
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     regex: Optional[str] = None
-    min_tokens: Optional[int] = 0
-    repetition_penalty: Optional[float] = 1.0
-    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    min_tokens: int = 0
+    repetition_penalty: float = 1.0
+    stop_token_ids: Optional[List[int]] = None
+    no_stop_trim: bool = False
     ignore_eos: bool = False
+    skip_special_tokens: bool = True
 
 
 class ChatMessage(BaseModel):
@@ -1,40 +1,34 @@
 import abc
 import dataclasses
-import typing
+from typing import List, Set, Type, Union
 
 import torch
 
 
 @dataclasses.dataclass
 class _ReqLike:
-    origin_input_ids: typing.Union[torch.Tensor, typing.List[int]]
+    origin_input_ids: List[int]
 
 
 @dataclasses.dataclass
 class _BatchLike:
-    reqs: typing.List[_ReqLike]
+    reqs: List[_ReqLike]
 
     def batch_size(self):
         return len(self.reqs)
 
 
 class BatchedPenalizerOrchestrator:
-    batch: _BatchLike
-    device: str
-    vocab_size: int
-    penalizers: typing.Dict[typing.Type["_BatchedPenalizer"], "_BatchedPenalizer"]
-
     def __init__(
         self,
         vocab_size: int,
         batch: _BatchLike,
         device: str,
-        Penalizers: typing.Set[typing.Type["_BatchedPenalizer"]],
+        Penalizers: Set[Type["_BatchedPenalizer"]],
     ):
         self.vocab_size = vocab_size
         self.batch = batch
         self.device = device
-
         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in Penalizers}
 
         is_required = False
@@ -43,10 +37,12 @@ class BatchedPenalizerOrchestrator:
             is_required |= pen_is_required
         self.is_required = is_required
 
+        input_ids = [
+            torch.tensor(req.origin_input_ids, dtype=torch.int64, device=self.device)
+            for req in self.reqs()
+        ]
         if self.is_required:
-            self.cumulate_input_tokens(
-                input_ids=[req.origin_input_ids for req in self.reqs()]
-            )
+            self.cumulate_input_tokens(input_ids=input_ids)
 
     def reqs(self):
         return self.batch.reqs
@@ -54,34 +50,24 @@ class BatchedPenalizerOrchestrator:
     def batch_size(self):
         return self.batch.batch_size()
 
-    def cumulate_input_tokens(
-        self,
-        input_ids: typing.Union[
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
-    ):
+    def cumulate_input_tokens(self, input_ids: List[torch.Tensor]):
         """
         Feed the input tokens to the penalizers.
 
         Args:
-            input_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The input tokens.
+            input_ids (List[torch.Tensor]): The input tokens.
         """
         token_ids = _TokenIDs(orchestrator=self, token_ids=input_ids)
 
         for penalizer in self.penalizers.values():
             penalizer.cumulate_input_tokens(input_ids=token_ids)
 
-    def cumulate_output_tokens(
-        self,
-        output_ids: typing.Union[
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
-    ):
+    def cumulate_output_tokens(self, output_ids: torch.Tensor):
         """
         Feed the output tokens to the penalizers.
 
         Args:
-            output_ids (typing.Union[typing.List[torch.Tensor], typing.List[typing.List[int]]]): The output tokens.
+            output_ids (torch.Tensor): The output tokens.
         """
         if not self.is_required:
             return
@@ -112,14 +98,14 @@ class BatchedPenalizerOrchestrator:
 
     def filter(
         self,
-        indices_to_keep: typing.List[int],
+        indices_to_keep: List[int],
         indices_tensor_to_keep: torch.Tensor = None,
     ):
         """
        Filter the penalizers based on the indices to keep in the batch.
 
         Args:
-            indices_to_keep (typing.List[int]): List of indices to keep in the batch.
+            indices_to_keep (List[int]): List of indices to keep in the batch.
             indices_tensor_to_keep (torch.Tensor = None): Tensor of indices to keep in the batch. If not None, it will be used instead of converting indices_to_keep to a tensor.
         """
         if not self.is_required:
@@ -174,32 +160,18 @@ class _TokenIDs:
 
     Attributes:
         orchestrator (BatchedPenalizerOrchestrator): The orchestrator that this token IDs belong to.
-        token_ids (typing.Union[torch.Tensor, typing.List[torch.Tensor]]): The token IDs.
+        token_ids (Union[torch.Tensor, List[torch.Tensor]]): The token IDs.
         cached_counts (torch.Tensor): The cached occurrence count tensor.
     """
 
-    orchestrator: BatchedPenalizerOrchestrator
-    token_ids: typing.Union[torch.Tensor, typing.List[torch.Tensor]]
-    cached_counts: torch.Tensor = None
-
     def __init__(
         self,
         orchestrator: BatchedPenalizerOrchestrator,
-        token_ids: typing.Union[
-            typing.List[torch.Tensor], typing.List[typing.List[int]]
-        ],
+        token_ids: Union[torch.Tensor, List[torch.Tensor]],
     ):
         self.orchestrator = orchestrator
-
-        if not isinstance(token_ids[0], torch.Tensor):
-            token_ids = [
-                torch.tensor(
-                    data=ids, dtype=torch.int64, device=self.orchestrator.device
-                )
-                for ids in token_ids
-            ]
-
         self.token_ids = token_ids
+        self.cached_counts = None
 
     def occurrence_count(self) -> torch.Tensor:
         """
@@ -213,30 +185,34 @@ class _TokenIDs:
 
         token_ids = self.token_ids
 
-        if isinstance(token_ids, torch.Tensor):
-            token_ids = token_ids.unsqueeze(1)
-
-        # needs to be long to be used as index in scatter_add
-        if token_ids.dtype != torch.int64:
-            token_ids = token_ids.to(torch.int64)
-
-        padded_token_ids = torch.nn.utils.rnn.pad_sequence(
-            sequences=token_ids,
-            batch_first=True,
-            padding_value=self.orchestrator.vocab_size,
-        )
-
-        self.cached_counts = torch.zeros(
-            size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
-            dtype=torch.int64,
-            device=self.orchestrator.device,
-        ).scatter_add_(
-            dim=1,
-            index=padded_token_ids,
-            src=torch.ones_like(padded_token_ids),
-        )[
-            :, : self.orchestrator.vocab_size
-        ]
+        if isinstance(token_ids, list):
+            # TODO: optimize this part
+            padded_token_ids = torch.nn.utils.rnn.pad_sequence(
+                sequences=token_ids,
+                batch_first=True,
+                padding_value=self.orchestrator.vocab_size,
+            )
+            self.cached_counts = torch.zeros(
+                size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size + 1),
+                dtype=torch.int64,
+                device=self.orchestrator.device,
+            ).scatter_add_(
+                dim=1,
+                index=padded_token_ids,
+                src=torch.ones_like(padded_token_ids),
+            )[
+                :, : self.orchestrator.vocab_size
+            ]
+        else:
+            # TODO: optimize this part. We do not need to create this big tensor every time.
+            # We can directly apply the results on the logits.
+            self.cached_counts = torch.zeros(
+                size=(self.orchestrator.batch_size(), self.orchestrator.vocab_size),
+                device=self.orchestrator.device,
+            )
+            self.cached_counts[
+                torch.arange(len(token_ids), device=self.orchestrator.device), token_ids
+            ] = 1
 
         return self.cached_counts
 
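`occurrence_count` now has two paths: a list of 1-D tensors (per-request prompt tokens at prefill) is padded and counted with `scatter_add_`, while a flat tensor (one newly sampled token per request at decode) becomes a one-hot update. A small self-contained demo of both branches with a toy vocabulary; the values are made up.

```python
# Toy demo of the two occurrence_count branches (vocab_size=5, batch of 2).
import torch

vocab_size, batch_size = 5, 2

# Branch 1: list of per-request prompt token tensors (prefill).
token_ids = [torch.tensor([1, 1, 3]), torch.tensor([0])]
padded = torch.nn.utils.rnn.pad_sequence(
    token_ids, batch_first=True, padding_value=vocab_size
)
counts = torch.zeros(batch_size, vocab_size + 1, dtype=torch.int64).scatter_add_(
    1, padded, torch.ones_like(padded)
)[:, :vocab_size]
assert counts.tolist() == [[0, 2, 0, 1, 0], [1, 0, 0, 0, 0]]

# Branch 2: one sampled token per request (decode) -> one-hot counts.
output_ids = torch.tensor([2, 4])
one_hot = torch.zeros(batch_size, vocab_size)
one_hot[torch.arange(batch_size), output_ids] = 1
assert one_hot.tolist() == [[0, 0, 1, 0, 0], [0, 0, 0, 0, 1]]
```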
@@ -246,11 +222,9 @@ class _BatchedPenalizer(abc.ABC):
     An abstract class for a batched penalizer.
     """
 
-    orchestrator: BatchedPenalizerOrchestrator
-    _is_prepared: bool = False
-
     def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
         self.orchestrator = orchestrator
+        self._is_prepared = False
 
     def is_prepared(self) -> bool:
         return self._is_prepared
@@ -293,9 +267,7 @@ class _BatchedPenalizer(abc.ABC):
 
         return self._apply(logits=logits)
 
-    def filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         if not self.is_prepared():
             return
 
@@ -360,9 +332,7 @@ class _BatchedPenalizer(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def _filter(
-        self, indices_to_keep: typing.List[int], indices_tensor_to_keep: torch.Tensor
-    ):
+    def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         """
         Filter the penalizer (tensors or underlying data) based on the indices to keep in the batch.
         """