sglang 0.2.14.post1__py3-none-any.whl → 0.2.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sglang/api.py +2 -0
  2. sglang/bench_latency.py +39 -28
  3. sglang/lang/interpreter.py +3 -0
  4. sglang/lang/ir.py +5 -0
  5. sglang/launch_server_llavavid.py +26 -0
  6. sglang/srt/configs/__init__.py +5 -0
  7. sglang/srt/configs/exaone.py +195 -0
  8. sglang/srt/constrained/fsm_cache.py +1 -1
  9. sglang/srt/conversation.py +24 -2
  10. sglang/srt/hf_transformers_utils.py +11 -160
  11. sglang/srt/layers/activation.py +10 -4
  12. sglang/srt/layers/extend_attention.py +13 -8
  13. sglang/srt/layers/layernorm.py +47 -1
  14. sglang/srt/layers/logits_processor.py +4 -4
  15. sglang/srt/layers/sampler.py +69 -16
  16. sglang/srt/managers/controller_multi.py +5 -5
  17. sglang/srt/managers/controller_single.py +5 -5
  18. sglang/srt/managers/io_struct.py +11 -5
  19. sglang/srt/managers/schedule_batch.py +25 -13
  20. sglang/srt/managers/tokenizer_manager.py +76 -63
  21. sglang/srt/managers/tp_worker.py +47 -36
  22. sglang/srt/model_config.py +3 -3
  23. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  24. sglang/srt/model_executor/forward_batch_info.py +78 -43
  25. sglang/srt/model_executor/model_runner.py +29 -18
  26. sglang/srt/models/chatglm.py +5 -13
  27. sglang/srt/models/commandr.py +5 -1
  28. sglang/srt/models/dbrx.py +5 -1
  29. sglang/srt/models/deepseek.py +5 -1
  30. sglang/srt/models/deepseek_v2.py +57 -25
  31. sglang/srt/models/exaone.py +399 -0
  32. sglang/srt/models/gemma.py +7 -3
  33. sglang/srt/models/gemma2.py +6 -52
  34. sglang/srt/models/gpt_bigcode.py +5 -1
  35. sglang/srt/models/grok.py +14 -4
  36. sglang/srt/models/internlm2.py +5 -1
  37. sglang/srt/models/llama2.py +10 -7
  38. sglang/srt/models/llama_classification.py +2 -6
  39. sglang/srt/models/llama_embedding.py +3 -4
  40. sglang/srt/models/llava.py +69 -91
  41. sglang/srt/models/llavavid.py +40 -86
  42. sglang/srt/models/minicpm.py +5 -1
  43. sglang/srt/models/mixtral.py +6 -2
  44. sglang/srt/models/mixtral_quant.py +5 -1
  45. sglang/srt/models/qwen.py +5 -2
  46. sglang/srt/models/qwen2.py +9 -6
  47. sglang/srt/models/qwen2_moe.py +12 -33
  48. sglang/srt/models/stablelm.py +5 -1
  49. sglang/srt/models/yivl.py +2 -7
  50. sglang/srt/openai_api/adapter.py +16 -1
  51. sglang/srt/openai_api/protocol.py +5 -5
  52. sglang/srt/sampling/sampling_batch_info.py +79 -6
  53. sglang/srt/server.py +9 -9
  54. sglang/srt/utils.py +18 -36
  55. sglang/test/runners.py +2 -2
  56. sglang/test/test_layernorm.py +53 -1
  57. sglang/version.py +1 -1
  58. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/METADATA +8 -8
  59. sglang-0.2.15.dist-info/RECORD +118 -0
  60. sglang-0.2.14.post1.dist-info/RECORD +0 -114
  61. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/LICENSE +0 -0
  62. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/WHEEL +0 -0
  63. {sglang-0.2.14.post1.dist-info → sglang-0.2.15.dist-info}/top_level.txt +0 -0

sglang/srt/hf_transformers_utils.py
@@ -15,6 +15,7 @@ limitations under the License.
 
 """Utilities for Huggingface Transformers."""
 
+import contextlib
 import functools
 import json
 import os
@@ -34,15 +35,20 @@ from transformers import (
 try:
     from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
 
+    from sglang.srt.configs import ExaoneConfig
+
     _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
         ChatGLMConfig.model_type: ChatGLMConfig,
         DbrxConfig.model_type: DbrxConfig,
+        ExaoneConfig.model_type: ExaoneConfig,
     }
 except ImportError:
     # We want this file to run without vllm dependency
     _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}
 
-from sglang.srt.utils import is_multimodal_model
+for name, cls in _CONFIG_REGISTRY.items():
+    with contextlib.suppress(ValueError):
+        AutoConfig.register(name, cls)
 
 
 def download_from_hf(model_path: str):
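
For context, AutoConfig.register raises ValueError when a model type is already registered, which is why the new loop wraps each call in contextlib.suppress. A minimal, self-contained sketch of the same pattern, using a stand-in config class instead of sglang's ExaoneConfig (which only imports when vllm is available):

```python
import contextlib

from transformers import AutoConfig, PretrainedConfig


class ExaoneLikeConfig(PretrainedConfig):
    # Stand-in for sglang.srt.configs.ExaoneConfig; model_type is the key
    # AutoConfig uses to dispatch a config.json to this class.
    model_type = "exaone-like-demo"


_CONFIG_REGISTRY = {ExaoneLikeConfig.model_type: ExaoneLikeConfig}

# Mirror of the registration loop added above: suppressing ValueError keeps a
# second registration (e.g. a re-import) from crashing the process.
for name, cls in _CONFIG_REGISTRY.items():
    with contextlib.suppress(ValueError):
        AutoConfig.register(name, cls)

print(type(AutoConfig.for_model("exaone-like-demo")).__name__)  # ExaoneLikeConfig
```
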
@@ -52,17 +58,11 @@ def download_from_hf(model_path: str):
     return snapshot_download(model_path, allow_patterns=["*.json", "*.bin", "*.model"])
 
 
-def get_config_json(model_path: str):
-    with open(os.path.join(model_path, "config.json")) as f:
-        config = json.load(f)
-    return config
-
-
 def get_config(
     model: str,
     trust_remote_code: bool,
     revision: Optional[str] = None,
-    model_overide_args: Optional[dict] = None,
+    model_override_args: Optional[dict] = None,
 ):
     config = AutoConfig.from_pretrained(
         model, trust_remote_code=trust_remote_code, revision=revision
@@ -70,8 +70,8 @@ def get_config(
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model, revision=revision)
-    if model_overide_args:
-        config.update(model_overide_args)
+    if model_override_args:
+        config.update(model_override_args)
     return config
 
 
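
The two get_config hunks above fix the long-standing model_overide_args typo; the same rename runs through the controller, manager, and server modules later in this diff. A hedged sketch of a call using the corrected keyword (the checkpoint name is an arbitrary public model chosen for illustration, not something this diff touches):

```python
from sglang.srt.hf_transformers_utils import get_config

# model_override_args is applied on top of the downloaded config via
# PretrainedConfig.update(), so any config field can be patched this way.
config = get_config(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",      # illustrative model id only
    trust_remote_code=False,
    model_override_args={"max_position_embeddings": 4096},
)
print(config.max_position_embeddings)  # 4096
```
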
@@ -89,7 +89,7 @@ CONTEXT_LENGTH_KEYS = [
 
 
 def get_context_length(config):
-    """Get the context length of a model from a huggingface model config."""
+    """Get the context length of a model from a huggingface model configs."""
     rope_scaling = getattr(config, "rope_scaling", None)
     if rope_scaling:
         rope_scaling_factor = config.rope_scaling["factor"]
@@ -119,24 +119,7 @@ def get_tokenizer(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
-    if tokenizer_name.endswith(".json"):
-        return TiktokenTokenizer(tokenizer_name)
-
-    if tokenizer_name.endswith(".model"):
-        return SentencePieceTokenizer(tokenizer_name)
-
     """Gets a tokenizer for the given model name via Huggingface."""
-    if is_multimodal_model(tokenizer_name):
-        processor = get_processor(
-            tokenizer_name,
-            *args,
-            trust_remote_code=trust_remote_code,
-            tokenizer_revision=tokenizer_revision,
-            **kwargs,
-        )
-        tokenizer = processor.tokenizer
-        return tokenizer
-
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
@@ -199,135 +182,3 @@ def get_processor(
         **kwargs,
     )
     return processor
-
-
-class TiktokenTokenizer:
-    def __init__(self, tokenizer_path):
-        import tiktoken
-        from jinja2 import Template
-
-        PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
-
-        # Read JSON
-        name = "tmp-json"
-        with open(tokenizer_path, "rb") as fin:
-            tok_dict = json.load(fin)
-
-        mergeable_ranks = {
-            bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
-        }
-        special_tokens = {
-            bytes(item["bytes"]).decode(): item["token"]
-            for item in tok_dict["special_tokens"]
-        }
-        assert tok_dict["word_split"] == "V1"
-
-        default_allowed_special = None
-
-        kwargs = {
-            "name": name,
-            "pat_str": tok_dict.get("pat_str", PAT_STR_B),
-            "mergeable_ranks": mergeable_ranks,
-            "special_tokens": special_tokens,
-        }
-        if "default_allowed_special" in tok_dict:
-            default_allowed_special = set(
-                [
-                    bytes(bytes_list).decode()
-                    for bytes_list in tok_dict["default_allowed_special"]
-                ]
-            )
-        if "vocab_size" in tok_dict:
-            kwargs["explicit_n_vocab"] = tok_dict["vocab_size"]
-
-        PAD = "<|pad|>"
-        EOS = "<|eos|>"
-        SEP = "<|separator|>"
-
-        DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}
-
-        tokenizer = tiktoken.Encoding(**kwargs)
-        tokenizer._default_allowed_special = default_allowed_special or set()
-        tokenizer._control_tokens = DEFAULT_CONTROL_TOKENS
-
-        def encode_patched(
-            self,
-            text: str,
-            *,
-            allowed_special: Union[
-                Literal["all"], AbstractSet[str]
-            ] = set(),  # noqa: B006
-            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
-        ) -> List[int]:
-            if isinstance(allowed_special, set):
-                allowed_special |= self._default_allowed_special
-            return tiktoken.Encoding.encode(
-                self,
-                text,
-                allowed_special=allowed_special,
-                disallowed_special=(),
-            )
-
-        tokenizer.encode = functools.partial(encode_patched, tokenizer)
-
-        # Convert to HF interface
-        self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer._special_tokens[EOS]
-        self.vocab_size = tokenizer.n_vocab
-        self.chat_template = Template(
-            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
-        )
-
-    def encode(self, x, add_special_tokens=False):
-        return self.tokenizer.encode(x)
-
-    def decode(self, x):
-        return self.tokenizer.decode(x)
-
-    def batch_decode(
-        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
-    ):
-        if isinstance(batch[0], int):
-            batch = [[x] for x in batch]
-        return self.tokenizer.decode_batch(batch)
-
-    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(
-            messages=messages, add_generation_prompt=add_generation_prompt
-        )
-        return self.encode(ret) if tokenize else ret
-
-
-class SentencePieceTokenizer:
-    def __init__(self, tokenizer_path):
-        import sentencepiece as spm
-        from jinja2 import Template
-
-        tokenizer = spm.SentencePieceProcessor(model_file=tokenizer_path)
-
-        # Convert to HF interface
-        self.tokenizer = tokenizer
-        self.eos_token_id = tokenizer.eos_id()
-        self.vocab_size = tokenizer.vocab_size()
-        self.chat_template = Template(
-            "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
-        )
-
-    def encode(self, x, add_special_tokens=False):
-        return self.tokenizer.encode(x)
-
-    def decode(self, x):
-        return self.tokenizer.decode(x)
-
-    def batch_decode(
-        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
-    ):
-        if isinstance(batch[0], int):
-            batch = [[x] for x in batch]
-        return self.tokenizer.decode(batch)
-
-    def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(
-            messages=messages, add_generation_prompt=add_generation_prompt
-        )
-        return self.encode(ret) if tokenize else ret

sglang/srt/layers/activation.py
@@ -18,7 +18,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
+from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 from vllm.distributed import (
     divide,
     get_tensor_model_parallel_rank,
@@ -43,18 +43,24 @@ class SiluAndMul(CustomOp):
 
 
 class GeluAndMul(CustomOp):
-    def __init__(self, **kwargs):
+    def __init__(self, approximate="tanh"):
         super().__init__()
+        self.approximate = approximate
 
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate="tanh") * x[..., d:]
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
-        gelu_tanh_and_mul(x, out)
+        if self.approximate == "tanh":
+            gelu_tanh_and_mul(x, out)
+        elif self.approximate == "none":
+            gelu_and_mul(x, out)
+        else:
+            raise RuntimeError("GeluAndMul only support tanh or none")
         return out
 
 
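
Since forward_native above is plain PyTorch, the effect of the new approximate argument can be checked without flashinfer. A small stand-alone sketch of the gated-GELU math for both modes (not sglang code, just the same formula):

```python
import torch
import torch.nn.functional as F


def gelu_and_mul_reference(x: torch.Tensor, approximate: str = "tanh") -> torch.Tensor:
    # Mirrors GeluAndMul.forward_native: the last dim holds [gate, up] halves;
    # GELU is applied to the gate half and multiplied into the up half.
    d = x.shape[-1] // 2
    return F.gelu(x[..., :d], approximate=approximate) * x[..., d:]


x = torch.randn(2, 8)
print(gelu_and_mul_reference(x, "tanh").shape)   # tanh approximation
print(gelu_and_mul_reference(x, "none").shape)   # exact erf-based GELU
```

On GPU, forward_cuda now dispatches to flashinfer's gelu_tanh_and_mul or gelu_and_mul accordingly.
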

sglang/srt/layers/extend_attention.py
@@ -127,8 +127,7 @@ def _fwd_kernel(
         )
         k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
 
-        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k)
+        qk = tl.dot(q.to(k.dtype), k)
         if BLOCK_DPE > 0:
             offs_kpe = (
                 offs_kv_loc[None, :] * stride_buf_kbs
@@ -140,7 +139,7 @@
                 mask=mask_n[None, :],
                 other=0.0,
             )
-            qk += tl.dot(qpe, kpe)
+            qk += tl.dot(qpe.to(kpe.dtype), kpe)
         qk *= sm_scale
 
         if logit_cap > 0:
@@ -179,9 +178,7 @@
         )
         k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)
 
-        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
-        qk += tl.dot(q, k)
-
+        qk = tl.dot(q, k, out_dtype=tl.float32)
         if BLOCK_DPE > 0:
             offs_kpe = (
                 (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
@@ -276,9 +273,17 @@
     BLOCK_DV = Lv
 
     if CUDA_CAPABILITY[0] >= 9:
-        BLOCK_M, BLOCK_N = (128, 64)
+        if Lq <= 256:
+            BLOCK_M, BLOCK_N = (128, 64)
+        else:
+            BLOCK_M, BLOCK_N = (32, 64)
     elif CUDA_CAPABILITY[0] >= 8:
-        BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
+        if Lq <= 128:
+            BLOCK_M, BLOCK_N = (128, 128)
+        elif Lq <= 256:
+            BLOCK_M, BLOCK_N = (64, 64)
+        else:
+            BLOCK_M, BLOCK_N = (32, 64)
     else:
         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
 
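
Read as a table, the new tile sizes depend on the query head dimension Lq as well as the compute capability. The function below merely restates the selection logic from the hunk above for readability (it is not part of the package), mainly to make the Lq > 256 path visible:

```python
def pick_blocks(cuda_major: int, Lq: int):
    """(BLOCK_M, BLOCK_N) as chosen by extend_attention_fwd in 0.2.15."""
    if cuda_major >= 9:                      # compute capability 9.x
        return (128, 64) if Lq <= 256 else (32, 64)
    if cuda_major >= 8:                      # compute capability 8.x
        if Lq <= 128:
            return (128, 128)
        return (64, 64) if Lq <= 256 else (32, 64)
    return (64, 64) if Lq <= 128 else (32, 32)


print(pick_blocks(9, 576))  # head dims above 256 now fall back to (32, 64)
```
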

sglang/srt/layers/layernorm.py
@@ -19,7 +19,12 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from flashinfer.norm import fused_add_rmsnorm, rmsnorm
+from flashinfer.norm import (
+    fused_add_rmsnorm,
+    gemma_fused_add_rmsnorm,
+    gemma_rmsnorm,
+    rmsnorm,
+)
 from vllm.model_executor.custom_op import CustomOp
 
 
@@ -63,3 +68,44 @@ class RMSNorm(CustomOp):
             return x
         else:
             return x, residual
+
+
+class GemmaRMSNorm(CustomOp):
+    def __init__(
+        self,
+        hidden_size: int,
+        eps: float = 1e-6,
+    ) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.zeros(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        orig_dtype = x.dtype
+        if residual is not None:
+            x = x + residual
+            residual = x
+
+        x = x.float()
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x * (1.0 + self.weight.float())
+        x = x.to(orig_dtype)
+        return x if residual is None else (x, residual)
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            gemma_fused_add_rmsnorm(
+                x, residual, self.weight.data, self.variance_epsilon
+            )
+            return x, residual
+        out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
+        return out
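
The distinguishing details of the new GemmaRMSNorm are the float32 accumulation and the (1 + weight) scaling, with weight initialised to zeros; forward_cuda defers to flashinfer's gemma_rmsnorm / gemma_fused_add_rmsnorm kernels. A pure-torch sketch of the same math, useful as a quick sanity check (this is not the sglang class itself):

```python
import torch


def gemma_rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Same steps as GemmaRMSNorm.forward_native: compute in float32, scale by
    # (1 + weight), then cast back to the input dtype.
    orig_dtype = x.dtype
    xf = x.float()
    variance = xf.pow(2).mean(dim=-1, keepdim=True)
    xf = xf * torch.rsqrt(variance + eps)
    xf = xf * (1.0 + weight.float())
    return xf.to(orig_dtype)


x = torch.randn(4, 16, dtype=torch.float16)
w = torch.zeros(16)  # a freshly initialised layer is a plain RMS normalisation
print(gemma_rms_norm_reference(x, w).dtype, gemma_rms_norm_reference(x, w).shape)
```
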

sglang/srt/layers/logits_processor.py
@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
 
 
 @dataclasses.dataclass
-class LogitProcessorOutput:
+class LogitsProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
@@ -185,7 +185,7 @@ class LogitsProcessor(nn.Module):
 
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
-            return LogitProcessorOutput(
+            return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
             else:
                 output_top_logprobs = None
 
-            return LogitProcessorOutput(
+            return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=None,
@@ -278,7 +278,7 @@ class LogitsProcessor(nn.Module):
             # Remove the last token logprob for the prefill tokens.
             input_token_logprobs = input_token_logprobs[:-1]
 
-            return LogitProcessorOutput(
+            return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=last_logprobs,
                 normalized_prompt_logprobs=normalized_prompt_logprobs,

sglang/srt/layers/sampler.py
@@ -1,4 +1,6 @@
+import dataclasses
 import logging
+from typing import Union
 
 import torch
 from flashinfer.sampling import (
@@ -9,6 +11,8 @@ from flashinfer.sampling import (
 )
 from vllm.model_executor.custom_op import CustomOp
 
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+
 # TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -16,30 +20,71 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 logger = logging.getLogger(__name__)
 
 
+@dataclasses.dataclass
+class SampleOutput:
+    success: torch.Tensor
+    probs: torch.Tensor
+    batch_next_token_ids: torch.Tensor
+
+
 class Sampler(CustomOp):
     def __init__(self):
         super().__init__()
 
-    def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        return logits
+
+    def _get_probs(
+        self,
+        logits: torch.Tensor,
+        sampling_info: SamplingBatchInfo,
+        is_torch_compile: bool = False,
+    ):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
+        if is_torch_compile:
+            # FIXME: Temporary workaround for unknown bugs in torch.compile
+            logits.add_(0)
+
         if sampling_info.logit_bias is not None:
             logits.add_(sampling_info.logit_bias)
 
         if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
 
-        logits = sampling_info.penalizer_orchestrator.apply(logits)
+        logits = self._apply_penalties(logits, sampling_info)
 
-        probs = torch.softmax(logits, dim=-1)
+        return torch.softmax(logits, dim=-1)
+
+    def forward_cuda(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info)
 
         if not global_server_args_dict["disable_flashinfer_sampling"]:
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
             )
-            if sampling_info.min_ps.any():
+            if sampling_info.need_min_p_sampling:
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids, success = min_p_sampling_from_probs(
@@ -55,18 +100,23 @@ class Sampler(CustomOp):
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
 
-        if not torch.all(success):
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                success, batch_next_token_ids, argmax_ids
-            )
+        return SampleOutput(success, probs, batch_next_token_ids)
 
-        return batch_next_token_ids
+    def forward_native(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+
+        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+        )
 
-    def forward_native():
-        raise NotImplementedError("Native forward is not implemented yet.")
+        return SampleOutput(success, probs, batch_next_token_ids)
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
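
With this hunk the sampler no longer repairs failed samples itself; forward_cuda and forward_native both return a SampleOutput and leave the fallback to the caller (presumably the tp_worker / model_runner changes listed above, which are not shown here). A hedged sketch of what such caller-side handling could look like, reusing the logic that was removed:

```python
import torch


def resolve_sample_output(success: torch.Tensor, probs: torch.Tensor,
                          batch_next_token_ids: torch.Tensor) -> torch.Tensor:
    # Same idea as the fallback deleted from Sampler.forward_cuda: wherever
    # sampling reported failure, drop NaNs and fall back to greedy argmax.
    if not torch.all(success):
        probs = probs.masked_fill(torch.isnan(probs), 0.0)
        argmax_ids = torch.argmax(probs, dim=-1)
        batch_next_token_ids = torch.where(success, batch_next_token_ids, argmax_ids)
    return batch_next_token_ids


probs = torch.softmax(torch.randn(2, 10), dim=-1)
ids = torch.multinomial(probs, num_samples=1).squeeze(-1)
success = torch.tensor([True, False])
print(resolve_sample_output(success, probs, ids))
```
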
@@ -87,7 +137,10 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        # FIXME: torch.multiomial does not support num_samples = 1
+        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
+            :, :1
+        ]
     except RuntimeError as e:
         logger.warning(f"Sampling error: {e}")
         batch_next_token_ids = torch.zeros(

sglang/srt/managers/controller_multi.py
@@ -71,12 +71,12 @@ class ControllerMulti:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_overide_args,
+        model_override_args,
     ):
         # Parse args
         self.server_args = server_args
         self.port_args = port_args
-        self.model_overide_args = model_overide_args
+        self.model_override_args = model_override_args
         self.load_balance_method = LoadBalanceMethod.from_str(
             server_args.load_balance_method
         )
@@ -114,7 +114,7 @@ class ControllerMulti:
                 self.server_args,
                 self.port_args,
                 pipe_controller_writer,
-                self.model_overide_args,
+                self.model_override_args,
                 True,
                 gpu_ids,
                 dp_worker_id,
@@ -189,14 +189,14 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
-    model_overide_args: dict,
+    model_override_args: dict,
 ):
     """Start a controller process."""
 
     configure_logger(server_args)
 
     try:
-        controller = ControllerMulti(server_args, port_args, model_overide_args)
+        controller = ControllerMulti(server_args, port_args, model_override_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise

sglang/srt/managers/controller_single.py
@@ -40,7 +40,7 @@ class ControllerSingle:
         self,
         server_args: ServerArgs,
         port_args: PortArgs,
-        model_overide_args: dict,
+        model_override_args: dict,
         gpu_ids: List[int],
         is_data_parallel_worker: bool,
         dp_worker_id: int,
@@ -76,7 +76,7 @@ class ControllerSingle:
                 tp_rank_range,
                 server_args,
                 port_args.nccl_ports[dp_worker_id],
-                model_overide_args,
+                model_override_args,
             )
 
         # Launch tp rank 0
@@ -85,7 +85,7 @@ class ControllerSingle:
             0,
             server_args,
             port_args.nccl_ports[dp_worker_id],
-            model_overide_args,
+            model_override_args,
         )
         self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
 
@@ -126,7 +126,7 @@ def start_controller_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer: multiprocessing.connection.Connection,
-    model_overide_args: dict,
+    model_override_args: dict,
     is_data_parallel_worker: bool = False,
     gpu_ids: List[int] = None,
     dp_worker_id: int = None,
@@ -149,7 +149,7 @@
         controller = ControllerSingle(
             server_args,
             port_args,
-            model_overide_args,
+            model_override_args,
             gpu_ids,
             is_data_parallel_worker,
             dp_worker_id,

sglang/srt/managers/io_struct.py
@@ -18,8 +18,9 @@ The definition of objects transfered between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
+import copy
 import uuid
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
@@ -55,6 +56,7 @@ class GenerateReqInput:
             self.text is not None and self.input_ids is not None
         ):
             raise ValueError("Either text or input_ids should be provided.")
+
         if (
             isinstance(self.sampling_params, dict)
             and self.sampling_params.get("n", 1) != 1
@@ -161,10 +163,10 @@ class TokenizedGenerateReqInput:
     input_ids: List[int]
     # The pixel values for input images
     pixel_values: List[float]
-    # The hash of input images
-    image_hash: int
-    # The image size
-    image_size: List[int]
+    # The hash values of input images
+    image_hashes: List[int]
+    # The image sizes
+    image_sizes: List[List[int]]
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
@@ -248,6 +250,10 @@ class BatchTokenIDOut:
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
 
+    def __post_init__(self):
+        # deepcopy meta_info to avoid modification in place
+        self.meta_info = copy.deepcopy(self.meta_info)
+
 
 @dataclass
 class BatchStrOut:
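
The new __post_init__ hook deep-copies meta_info so that mutating one output batch cannot leak into other objects that still reference the same dicts. A self-contained illustration of the aliasing problem this guards against (a generic dataclass, not the sglang type):

```python
import copy
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class WithCopy:
    meta_info: List[Dict]

    def __post_init__(self):
        # Same pattern as BatchTokenIDOut: detach from the caller's objects.
        self.meta_info = copy.deepcopy(self.meta_info)


shared = [{"prompt_tokens": 3}]
out = WithCopy(shared)
shared[0]["prompt_tokens"] = 999          # the caller keeps mutating its dict

print(out.meta_info[0]["prompt_tokens"])  # 3 -- the dataclass kept its own copy
```
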