sglang 0.3.5__py3-none-any.whl → 0.3.5.post1__py3-none-any.whl

This diff compares the contents of two package versions that were publicly released to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that public registry.
Files changed (50)
  1. sglang/bench_serving.py +113 -3
  2. sglang/srt/configs/model_config.py +5 -2
  3. sglang/srt/constrained/__init__.py +2 -66
  4. sglang/srt/constrained/base_grammar_backend.py +72 -0
  5. sglang/srt/constrained/outlines_backend.py +165 -0
  6. sglang/srt/constrained/outlines_jump_forward.py +182 -0
  7. sglang/srt/constrained/xgrammar_backend.py +114 -0
  8. sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
  9. sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
  10. sglang/srt/layers/fused_moe/fused_moe.py +23 -7
  11. sglang/srt/layers/quantization/base_config.py +4 -6
  12. sglang/srt/layers/vocab_parallel_embedding.py +216 -150
  13. sglang/srt/managers/io_struct.py +5 -3
  14. sglang/srt/managers/schedule_batch.py +14 -20
  15. sglang/srt/managers/scheduler.py +153 -94
  16. sglang/srt/managers/tokenizer_manager.py +81 -17
  17. sglang/srt/metrics/collector.py +211 -0
  18. sglang/srt/metrics/func_timer.py +108 -0
  19. sglang/srt/mm_utils.py +1 -1
  20. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  21. sglang/srt/model_executor/forward_batch_info.py +7 -3
  22. sglang/srt/model_executor/model_runner.py +2 -1
  23. sglang/srt/models/gemma2_reward.py +69 -0
  24. sglang/srt/models/gpt2.py +31 -37
  25. sglang/srt/models/internlm2_reward.py +62 -0
  26. sglang/srt/models/llama.py +11 -6
  27. sglang/srt/models/llama_reward.py +5 -26
  28. sglang/srt/models/qwen2_vl.py +5 -7
  29. sglang/srt/openai_api/adapter.py +6 -2
  30. sglang/srt/sampling/sampling_batch_info.py +2 -3
  31. sglang/srt/sampling/sampling_params.py +0 -14
  32. sglang/srt/server.py +58 -16
  33. sglang/srt/server_args.py +42 -22
  34. sglang/srt/utils.py +87 -0
  35. sglang/test/simple_eval_common.py +1 -1
  36. sglang/test/simple_eval_humaneval.py +2 -2
  37. sglang/test/simple_eval_mgsm.py +2 -2
  38. sglang/test/test_utils.py +18 -4
  39. sglang/utils.py +1 -0
  40. sglang/version.py +1 -1
  41. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/METADATA +11 -7
  42. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/RECORD +45 -42
  43. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/WHEEL +1 -1
  44. sglang/srt/constrained/base_tool_cache.py +0 -65
  45. sglang/srt/constrained/bnf_cache.py +0 -61
  46. sglang/srt/constrained/fsm_cache.py +0 -95
  47. sglang/srt/constrained/grammar.py +0 -190
  48. sglang/srt/constrained/jump_forward.py +0 -203
  49. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/LICENSE +0 -0
  50. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma2_reward.py ADDED
@@ -0,0 +1,69 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import Gemma2Config
+
+ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.models.gemma2 import Gemma2ForCausalLM, Gemma2Model
+
+
+ class Gemma2ForSequenceClassification(nn.Module):
+     def __init__(
+         self,
+         config: Gemma2Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config=None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.torchao_config = None
+         self.quant_config = quant_config
+         self.num_labels = config.num_labels
+         self.model = Gemma2Model(config, quant_config=quant_config)
+         self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
+
+         self.eos_token_id = config.eos_token_id
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+         get_embedding: bool = True,
+     ) -> EmbeddingPoolerOutput:
+         assert (
+             get_embedding
+         ), "Gemma2ForSequenceClassification is only used for embedding"
+
+         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+         last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
+         scores = self.score(last_token_hidden)
+
+         return EmbeddingPoolerOutput(scores)
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         Gemma2ForCausalLM.load_weights(self, weights)
+
+
+ EntryClass = [Gemma2ForSequenceClassification]
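The new Gemma2ForSequenceClassification follows sglang's last-token pooling pattern: the Pooler with PoolingType.LAST selects each sequence's final hidden state, and a bias-free linear head maps it to num_labels scores. A minimal standalone sketch of that pattern in plain PyTorch (illustrative names only, independent of sglang's ForwardBatch plumbing):

import torch
from torch import nn

# Illustrative stand-in for sglang's LAST pooling + score head; not the actual Pooler API.
hidden_size, num_labels = 8, 2
score = nn.Linear(hidden_size, num_labels, bias=False)

# Batch of 2 sequences with lengths 3 and 5, hidden states flattened the way sglang batches them.
seq_lens = torch.tensor([3, 5])
hidden_states = torch.randn(int(seq_lens.sum()), hidden_size)

# LAST pooling: index of each sequence's final token in the flattened batch.
last_token_idx = torch.cumsum(seq_lens, dim=0) - 1
last_token_hidden = hidden_states[last_token_idx]  # (batch, hidden_size)
scores = score(last_token_hidden)                  # (batch, num_labels)
print(scores.shape)  # torch.Size([2, 2])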
sglang/srt/models/gpt2.py CHANGED
@@ -28,7 +28,7 @@ from vllm.model_executor.layers.activation import get_act_fn
  from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- #from sglang.srt.layers.activation import get_act_fn
+ # from sglang.srt.layers.activation import get_act_fn
  from sglang.srt.layers.linear import (
      ColumnParallelLinear,
      QKVParallelLinear,
@@ -47,15 +47,14 @@ class GPT2Attention(nn.Module):
          self,
          layer_id: int,
          config: GPT2Config,
-         cache_config = None,
+         cache_config=None,
          quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
      ):
          super().__init__()
          self.hidden_size = config.hidden_size
          total_num_heads = config.num_attention_heads
-         tensor_model_parallel_world_size = (
-             get_tensor_model_parallel_world_size())
+         tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
          assert total_num_heads % tensor_model_parallel_world_size == 0
          self.num_heads = total_num_heads // tensor_model_parallel_world_size
          self.head_dim = self.hidden_size // total_num_heads
@@ -76,11 +75,13 @@
              quant_config=quant_config,
              prefix=f"{prefix}.c_proj",
          )
-         self.attn = RadixAttention(self.num_heads,
-                                    self.head_dim,
-                                    scaling=self.scale,
-                                    num_kv_heads=total_num_heads,
-                                    layer_id=layer_id)
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             scaling=self.scale,
+             num_kv_heads=total_num_heads,
+             layer_id=layer_id,
+         )

      def forward(
          self,
@@ -119,10 +120,14 @@ class GPT2MLP(nn.Module):
              quant_config=quant_config,
              prefix=f"{prefix}.c_proj",
          )
-         self.act = get_act_fn(config.activation_function, quant_config,
-                               intermediate_size)
+         self.act = get_act_fn(
+             config.activation_function, quant_config, intermediate_size
+         )

-     def forward(self, hidden_states: torch.Tensor,) -> torch.Tensor:
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+     ) -> torch.Tensor:
          hidden_states, _ = self.c_fc(hidden_states)
          hidden_states = self.act(hidden_states)
          hidden_states, _ = self.c_proj(hidden_states)
@@ -135,27 +140,20 @@ class GPT2Block(nn.Module):
          self,
          layer_id: int,
          config: GPT2Config,
-         cache_config = None,
-
+         cache_config=None,
          quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
      ):
          super().__init__()
          hidden_size = config.hidden_size
-         inner_dim = (config.n_inner if config.n_inner is not None else 4 *
-                      hidden_size)
+         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

          self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-         self.attn = GPT2Attention(layer_id,
-                                   config,
-                                   cache_config,
-                                   quant_config,
-                                   prefix=f"{prefix}.attn")
+         self.attn = GPT2Attention(
+             layer_id, config, cache_config, quant_config, prefix=f"{prefix}.attn"
+         )
          self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-         self.mlp = GPT2MLP(inner_dim,
-                            config,
-                            quant_config,
-                            prefix=f"{prefix}.mlp")
+         self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")

      def forward(
          self,
@@ -179,13 +177,12 @@ class GPT2Block(nn.Module):
          return hidden_states


-
  class GPT2Model(nn.Module):

      def __init__(
          self,
          config: GPT2Config,
-         cache_config = None,
+         cache_config=None,
          quant_config: Optional[QuantizationConfig] = None,
          prefix: str = "",
      ):
@@ -229,16 +226,15 @@ class GPT2LMHeadModel(nn.Module):
      def __init__(
          self,
          config: GPT2Config,
-         cache_config = None,
+         cache_config=None,
          quant_config: Optional[QuantizationConfig] = None,
      ):
          super().__init__()
          self.config = config
          self.quant_config = quant_config
-         self.transformer = GPT2Model(config,
-                                      cache_config,
-                                      quant_config,
-                                      prefix="transformer")
+         self.transformer = GPT2Model(
+             config, cache_config, quant_config, prefix="transformer"
+         )
          self.lm_head = self.transformer.wte

          self.logits_processor = LogitsProcessor(config)
@@ -254,8 +250,6 @@ class GPT2LMHeadModel(nn.Module):
              input_ids, hidden_states, self.lm_head.weight, forward_batch
          )

-
-
      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
          params_dict = dict(self.named_parameters(remove_duplicate=False))
          for name, loaded_weight in weights:
@@ -280,8 +274,8 @@ class GPT2LMHeadModel(nn.Module):
              if not name.endswith(".weight"):
                  continue
              loaded_weight = loaded_weight.t()
-             weight_loader = getattr(param, "weight_loader",
-                                     default_weight_loader)
+             weight_loader = getattr(param, "weight_loader", default_weight_loader)
              weight_loader(param, loaded_weight)

- EntryClass = GPT2LMHeadModel
+
+ EntryClass = GPT2LMHeadModel
sglang/srt/models/internlm2_reward.py ADDED
@@ -0,0 +1,62 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig
+
+ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.models.internlm2 import InternLM2ForCausalLM, InternLM2Model
+
+
+ class InternLM2ForRewardModel(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config=None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.vocab_size = config.vocab_size
+         self.model = InternLM2Model(config, quant_config)
+         self.v_head = nn.Linear(config.hidden_size, 1, bias=False)
+         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+         get_embedding: bool = True,
+     ) -> EmbeddingPoolerOutput:
+         assert get_embedding, "InternLM2ForRewardModel is only used for embedding"
+         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+         last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
+         scores = self.v_head(last_token_hidden)
+         return EmbeddingPoolerOutput(scores)
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         return InternLM2ForCausalLM.load_weights(self, weights)
+
+
+ EntryClass = InternLM2ForRewardModel
sglang/srt/models/llama.py CHANGED
@@ -380,6 +380,12 @@ class LlamaForCausalLM(nn.Module):
          ]
          params_dict = dict(self.named_parameters())

+         load_tie_word_embeddings = (
+             hasattr(self.config, "tie_word_embeddings")
+             and self.config.tie_word_embeddings
+             and "lm_head.weight" in params_dict
+         )
+
          for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name or "projector" in name:
                  continue
@@ -412,15 +418,14 @@ class LlamaForCausalLM(nn.Module):
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)

-         if (
-             hasattr(self.config, "tie_word_embeddings")
-             and self.config.tie_word_embeddings
-             and "lm_head.weight" in params_dict
-         ):
+             if load_tie_word_embeddings and name == "model.embed_tokens.weight":
+                 embed_tokens_weight = loaded_weight
+
+         if load_tie_word_embeddings:
              # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
              param = self.lm_head.weight
              weight_loader = getattr(param, "weight_loader", default_weight_loader)
-             weight_loader(param, self.model.embed_tokens.weight)
+             weight_loader(param, embed_tokens_weight)

          apply_torchao_config_(self, params_dict, set(["proj.weight"]))

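The reworked tying logic above computes load_tie_word_embeddings once, captures the embedding tensor while iterating over the checkpoint stream, and only afterwards feeds it to lm_head's weight_loader, presumably so the tied head receives the checkpoint tensor itself rather than reading back the (possibly sharded) module parameter. A rough sketch of that capture-then-tie pattern with hypothetical names:

import torch

# Illustrative sketch of "capture during iteration, tie afterwards"; names are hypothetical.
def load_and_tie(weights, params_dict, tie_word_embeddings=True):
    embed_tokens_weight = None
    for name, loaded_weight in weights:
        params_dict[name] = loaded_weight  # stand-in for the real weight_loader call
        if tie_word_embeddings and name == "model.embed_tokens.weight":
            embed_tokens_weight = loaded_weight  # remember the tensor we just loaded

    if tie_word_embeddings and embed_tokens_weight is not None:
        # lm_head reuses the input embedding when the checkpoint has no lm_head.weight
        params_dict["lm_head.weight"] = embed_tokens_weight
    return params_dict

weights = [("model.embed_tokens.weight", torch.randn(10, 4))]
print(load_and_tie(weights, {}).keys())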
sglang/srt/models/llama_reward.py CHANGED
@@ -18,9 +18,7 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import LlamaConfig
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -59,22 +57,13 @@ class LlamaForSequenceClassification(nn.Module):
          ), "LlamaForSequenceClassification is only used for embedding"

          hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-         scores = self.score(hidden_states)
+         last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
+         scores = self.score(last_token_hidden)

-         return self.pooler(scores, forward_batch)
+         return EmbeddingPoolerOutput(scores)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-         params_dict = dict(self.named_parameters())
-
-         for name, loaded_weight in weights:
-             if "classification_head" in name:
-                 param = params_dict[name]
-                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                 weight_loader(param, loaded_weight)
-             elif "lm_head" in name:
-                 continue
-             else:
-                 LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
+         return LlamaForCausalLM.load_weights(self, weights)


  class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassification):
@@ -127,17 +116,7 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
          return EmbeddingPoolerOutput(scores)

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-         params_dict = dict(self.named_parameters())
-
-         for name, loaded_weight in weights:
-             if "classification_head" in name:
-                 param = params_dict[name]
-                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                 weight_loader(param, loaded_weight)
-             elif "lm_head" in name:
-                 continue
-             else:
-                 LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
+         return super().load_weights(weights)


  EntryClass = [
sglang/srt/models/qwen2_vl.py CHANGED
@@ -57,27 +57,27 @@ logger = init_logger(__name__)

  class Qwen2VLImageInputs(TypedDict):
      pixel_values: torch.Tensor
-     """Shape: 
+     """Shape:
      `(num_patches, num_channels * patch_size * patch_size)`
      """

      image_grid_thw: torch.Tensor
      """Shape: `(num_images, 3)`
-     
+
      This should be in `(grid_t, grid_h, grid_w)` format.
      """


  class Qwen2VLVideoInputs(TypedDict):
      pixel_values_videos: torch.Tensor
-     """Shape: 
-     `(num_patches, 
+     """Shape:
+     `(num_patches,
      num_channels * temporal_patch_size * patch_size * patch_size)`
      """

      video_grid_thw: torch.Tensor
      """Shape: `(num_videos, 3)`
-     
+
      This should be in `(grid_t, grid_h, grid_w)` format.
      """

@@ -649,8 +649,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                  ]
                  image_embeds_offset += num_image_tokens

-         input_ids = None
-
          hidden_states = self.model(
              input_ids=input_ids,
              positions=positions,
sglang/srt/openai_api/adapter.py CHANGED
@@ -498,6 +498,10 @@ def v1_generate_request(
              )

          prompts.append(request.prompt)
+         if request.echo and request.logprobs:
+             current_logprob_start_len = 0
+         else:
+             current_logprob_start_len = -1
          sampling_params_list.append(
              {
                  "temperature": request.temperature,
@@ -517,7 +521,7 @@ def v1_generate_request(
              }
          )
          return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
-         logprob_start_lens.append(-1)
+         logprob_start_lens.append(current_logprob_start_len)
          top_logprobs_nums.append(
              request.logprobs if request.logprobs is not None else 0
          )
@@ -1277,7 +1281,7 @@ def v1_embedding_request(all_requests, tokenizer_manager):
          else:
              prompt_kwargs = {"input_ids": prompt}
      else:
-         if isinstance(prompts[0], str) or isinstance(propmt[0][0], str):
+         if isinstance(prompts[0], str) or isinstance(propmts[0][0], str):
              prompt_kwargs = {"text": prompts}
          else:
              prompt_kwargs = {"input_ids": prompts}
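With the v1_generate_request change above, a completion request that sets both echo and logprobs now gets a logprob start offset of 0 instead of -1, so logprobs are computed from the first prompt token rather than only for generated tokens. A rough client-side sketch against an OpenAI-compatible /v1/completions endpoint (URL and model name are placeholders; the response fields assume the standard completions logprobs schema):

import requests

# Placeholder endpoint/model; assumes an sglang server speaking the OpenAI completions API.
resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={
        "model": "default",
        "prompt": "The capital of France is",
        "max_tokens": 8,
        "echo": True,   # echo the prompt in the completion text
        "logprobs": 1,  # request token logprobs
    },
)
choice = resp.json()["choices"][0]
# With echo + logprobs, the logprobs should now cover the prompt tokens as well.
print(choice["logprobs"]["tokens"][:5])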
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, List, Optional
  import torch

  import sglang.srt.sampling.penaltylib as penaltylib
- from sglang.srt.constrained.grammar import Grammar

  if TYPE_CHECKING:
      from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -31,7 +30,7 @@ class SamplingBatchInfo:
      logit_bias: torch.Tensor = None
      vocab_mask: Optional[torch.Tensor] = None

-     grammars: Optional[List[Optional[Grammar]]] = None
+     grammars: Optional[List] = None

      # Penalizer
      penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
@@ -146,7 +145,7 @@ class SamplingBatchInfo:
          )
          for i, grammar in enumerate(self.grammars):
              if grammar is not None:
-                 grammar.fill_vocab_mask(self.vocab_mask[i], self.vocab_size)
+                 grammar.fill_vocab_mask(self.vocab_mask[i])

      def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
          if self.penalizer_orchestrator:
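This hunk tracks the constrained-decoding refactor visible in the file list (base_grammar_backend / outlines_backend / xgrammar_backend replacing the old grammar and cache modules): the batch no longer passes vocab_size into fill_vocab_mask, the grammar object is expected to carry it. A toy sketch of the masking contract, with an illustrative grammar class that is not the real backend API:

import torch

# Illustrative grammar object; the real backends (outlines/xgrammar) track allowed
# token ids per parser state, this toy version just whitelists a fixed set.
class ToyGrammar:
    def __init__(self, allowed_token_ids, vocab_size):
        self.allowed = torch.tensor(allowed_token_ids)
        self.vocab_size = vocab_size  # the grammar carries its own vocab size now

    def fill_vocab_mask(self, vocab_mask: torch.Tensor) -> None:
        # True marks tokens that must be blocked for this request.
        vocab_mask.fill_(True)
        vocab_mask[self.allowed] = False

vocab_size, batch_size = 16, 2
vocab_mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool)
grammars = [ToyGrammar([1, 2, 3], vocab_size), None]
for i, grammar in enumerate(grammars):
    if grammar is not None:
        grammar.fill_vocab_mask(vocab_mask[i])

logits = torch.randn(batch_size, vocab_size)
logits.masked_fill_(vocab_mask, float("-inf"))  # blocked tokens can never be sampled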
sglang/srt/sampling/sampling_params.py CHANGED
@@ -133,17 +133,3 @@ class SamplingParams:
              else:
                  stop_str_max_len = max(stop_str_max_len, len(stop_str))
          self.stop_str_max_len = stop_str_max_len
-
-     def to_srt_kwargs(self):
-         return {
-             "max_new_tokens": self.max_new_tokens,
-             "stop": self.stop_strs,
-             "stop_token_ids": list(self.stop_token_ids),
-             "temperature": self.temperature,
-             "top_p": self.top_p,
-             "top_k": self.top_k,
-             "frequency_penalty": self.frequency_penalty,
-             "presence_penalty": self.presence_penalty,
-             "ignore_eos": self.ignore_eos,
-             "regex": self.regex,
-         }