sglang 0.2.14__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (42)
  1. sglang/srt/constrained/fsm_cache.py +11 -2
  2. sglang/srt/constrained/jump_forward.py +1 -0
  3. sglang/srt/layers/activation.py +83 -7
  4. sglang/srt/layers/layernorm.py +0 -3
  5. sglang/srt/layers/logits_processor.py +4 -4
  6. sglang/srt/layers/sampler.py +15 -68
  7. sglang/srt/managers/schedule_batch.py +15 -20
  8. sglang/srt/managers/tp_worker.py +40 -33
  9. sglang/srt/model_executor/cuda_graph_runner.py +17 -31
  10. sglang/srt/model_executor/forward_batch_info.py +1 -8
  11. sglang/srt/model_executor/model_runner.py +5 -11
  12. sglang/srt/models/chatglm.py +12 -4
  13. sglang/srt/models/commandr.py +1 -5
  14. sglang/srt/models/dbrx.py +1 -5
  15. sglang/srt/models/deepseek.py +1 -5
  16. sglang/srt/models/deepseek_v2.py +1 -5
  17. sglang/srt/models/gemma.py +1 -5
  18. sglang/srt/models/gemma2.py +1 -5
  19. sglang/srt/models/gpt_bigcode.py +2 -6
  20. sglang/srt/models/grok.py +1 -5
  21. sglang/srt/models/internlm2.py +1 -5
  22. sglang/srt/models/llama2.py +3 -7
  23. sglang/srt/models/llama_classification.py +2 -2
  24. sglang/srt/models/minicpm.py +1 -5
  25. sglang/srt/models/mixtral.py +1 -5
  26. sglang/srt/models/mixtral_quant.py +1 -5
  27. sglang/srt/models/qwen.py +2 -5
  28. sglang/srt/models/qwen2.py +2 -6
  29. sglang/srt/models/qwen2_moe.py +14 -5
  30. sglang/srt/models/stablelm.py +1 -5
  31. sglang/srt/openai_api/adapter.py +85 -4
  32. sglang/srt/openai_api/protocol.py +2 -0
  33. sglang/srt/sampling/sampling_batch_info.py +1 -74
  34. sglang/srt/sampling/sampling_params.py +4 -0
  35. sglang/srt/server.py +8 -1
  36. sglang/test/runners.py +1 -1
  37. sglang/version.py +1 -1
  38. {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +10 -4
  39. {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/RECORD +42 -42
  40. {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
  41. {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
  42. {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
sglang/srt/constrained/fsm_cache.py
@@ -15,6 +15,8 @@ limitations under the License.
 
 """Cache for the compressed finite state machine."""
 
+from outlines.fsm.json_schema import build_regex_from_schema
+
 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache
 
@@ -26,9 +28,12 @@ class FSMCache(BaseToolCache):
         tokenizer_args_dict,
         enable=True,
         skip_tokenizer_init=False,
+        json_schema_mode=False,
     ):
         super().__init__(enable=enable)
 
+        self.json_schema_mode = json_schema_mode
+
         if (
             skip_tokenizer_init
             or tokenizer_path.endswith(".json")
@@ -72,5 +77,9 @@ class FSMCache(BaseToolCache):
                 tokenizer_path, **tokenizer_args_dict
             )
 
-    def init_value(self, regex):
-        return RegexGuide(regex, self.outlines_tokenizer)
+    def init_value(self, value):
+        if self.json_schema_mode:
+            regex = build_regex_from_schema(value)
+            return RegexGuide(regex, self.outlines_tokenizer), regex
+        else:
+            return RegexGuide(value, self.outlines_tokenizer)
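
In json_schema_mode, the cache key is the JSON Schema string itself and init_value compiles it into a regex before building the guide, returning the regex alongside the guide so the caller can reuse it (for example to build a jump-forward map). A minimal standalone sketch of that schema-to-regex step, using outlines directly; the schema below is illustrative:

import json

from outlines.fsm.json_schema import build_regex_from_schema

# Hypothetical schema; any JSON Schema serialized as a string works.
schema = json.dumps(
    {
        "type": "object",
        "properties": {"name": {"type": "string"}},
        "required": ["name"],
    }
)

# json_schema_mode=True: compile the schema to a regex, then build the guide
# from that regex (FSMCache.init_value also returns the regex itself).
regex = build_regex_from_schema(schema)
# guide = RegexGuide(regex, outlines_tokenizer)  # needs a TransformerTokenizer instance
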
sglang/srt/constrained/jump_forward.py
@@ -23,6 +23,7 @@ from collections import defaultdict
 
 import interegular
 import outlines.caching
+from outlines.fsm.json_schema import build_regex_from_schema
 
 from sglang.srt.constrained import (
     FSMInfo,
sglang/srt/layers/activation.py
@@ -13,25 +13,28 @@ limitations under the License.
 
 """Fused operators for activation layers."""
 
+from typing import Optional
+
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 from flashinfer.activation import gelu_tanh_and_mul, silu_and_mul
+from vllm.distributed import (
+    divide,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.custom_op import CustomOp
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.utils import set_weight_attrs
 
 
 class SiluAndMul(CustomOp):
-    def __init__(self, **kwargs):
-        super().__init__()
-        self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
-
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]
 
     def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
-        if self.is_lower_sm80:
-            return self.forward_native(x)
-
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -53,3 +56,76 @@ class GeluAndMul(CustomOp):
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
         gelu_tanh_and_mul(x, out)
         return out
+
+
+class ScaledActivation(nn.Module):
+    """An activation function with post-scale parameters.
+
+    This is used for some quantization methods like AWQ.
+    """
+
+    def __init__(
+        self,
+        act_module: nn.Module,
+        intermediate_size: int,
+        input_is_parallel: bool = True,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.act = act_module
+        self.input_is_parallel = input_is_parallel
+        if input_is_parallel:
+            tp_size = get_tensor_model_parallel_world_size()
+            intermediate_size_per_partition = divide(intermediate_size, tp_size)
+        else:
+            intermediate_size_per_partition = intermediate_size
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.scales = nn.Parameter(
+            torch.empty(intermediate_size_per_partition, dtype=params_dtype)
+        )
+        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.act(x) / self.scales
+
+    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+        param_data = param.data
+        if self.input_is_parallel:
+            tp_rank = get_tensor_model_parallel_rank()
+            shard_size = param_data.shape[0]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+
+_ACTIVATION_REGISTRY = {
+    "gelu": nn.GELU(),
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+}
+
+
+def get_act_fn(
+    act_fn_name: str,
+    quant_config: Optional[QuantizationConfig] = None,
+    intermediate_size: Optional[int] = None,
+    input_is_parallel: bool = True,
+    params_dtype: Optional[torch.dtype] = None,
+) -> nn.Module:
+    """Get an activation function by name."""
+    act_fn_name = act_fn_name.lower()
+    if act_fn_name not in _ACTIVATION_REGISTRY:
+        raise ValueError(f"Activation function {act_fn_name!r} is not supported.")
+
+    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
+    if quant_config is not None and act_fn_name in quant_config.get_scaled_act_names():
+        if intermediate_size is None:
+            raise ValueError(
+                "intermediate_size must be specified for scaled "
+                "activation functions."
+            )
+        return ScaledActivation(
+            act_fn, intermediate_size, input_is_parallel, params_dtype
+        )
+    return act_fn
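
For context, get_act_fn looks the name up in _ACTIVATION_REGISTRY and only wraps the module in ScaledActivation when the quantization config marks that activation as scaled. A rough usage sketch, assuming sglang 0.2.14.post1 and its vllm/flashinfer dependencies are installed; the tensor sizes are arbitrary:

import torch

from sglang.srt.layers.activation import get_act_fn

# Plain registry path: no quant_config, so the module comes back unwrapped.
act = get_act_fn("gelu_pytorch_tanh")   # nn.GELU(approximate="tanh")
x = torch.randn(2, 11008)
y = act(x)                              # elementwise activation, same shape as x

# If a quant_config whose get_scaled_act_names() contains the name were passed,
# get_act_fn would instead return ScaledActivation(act, intermediate_size, ...),
# which divides the activation output by a learned per-partition scale vector.
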
sglang/srt/layers/layernorm.py
@@ -32,15 +32,12 @@ class RMSNorm(CustomOp):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
-        self.is_lower_sm80 = torch.cuda.get_device_capability()[0] < 8
 
     def forward_cuda(
         self,
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if self.is_lower_sm80:
-            return self.forward_native(x, residual)
 
         if residual is not None:
             fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon)
sglang/srt/layers/logits_processor.py
@@ -29,7 +29,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 
 
 @dataclasses.dataclass
-class LogitsProcessorOutput:
+class LogitProcessorOutput:
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # The logprobs of the next tokens. shape: [#seq, vocab_size]
@@ -185,7 +185,7 @@ class LogitsProcessor(nn.Module):
 
         # Return only last_logits if logprob is not requested
         if not logits_metadata.return_logprob:
-            return LogitsProcessorOutput(
+            return LogitProcessorOutput(
                 next_token_logits=last_logits,
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
         else:
             output_top_logprobs = None
 
-        return LogitsProcessorOutput(
+        return LogitProcessorOutput(
            next_token_logits=last_logits,
            next_token_logprobs=last_logprobs,
            normalized_prompt_logprobs=None,
@@ -278,7 +278,7 @@ class LogitsProcessor(nn.Module):
         # Remove the last token logprob for the prefill tokens.
         input_token_logprobs = input_token_logprobs[:-1]
 
-        return LogitsProcessorOutput(
+        return LogitProcessorOutput(
            next_token_logits=last_logits,
            next_token_logprobs=last_logprobs,
            normalized_prompt_logprobs=normalized_prompt_logprobs,
sglang/srt/layers/sampler.py
@@ -1,6 +1,4 @@
-import dataclasses
 import logging
-from typing import Union
 
 import torch
 from flashinfer.sampling import (
@@ -11,8 +9,6 @@ from flashinfer.sampling import (
 )
 from vllm.model_executor.custom_op import CustomOp
 
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-
 # TODO: move this dict to another place
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -20,71 +16,30 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 logger = logging.getLogger(__name__)
 
 
-@dataclasses.dataclass
-class SampleOutput:
-    success: torch.Tensor
-    probs: torch.Tensor
-    batch_next_token_ids: torch.Tensor
-
-
 class Sampler(CustomOp):
     def __init__(self):
         super().__init__()
 
-    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        return logits
-
-    def _get_probs(
-        self,
-        logits: torch.Tensor,
-        sampling_info: SamplingBatchInfo,
-        is_torch_compile: bool = False,
-    ):
+    def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(sampling_info.temperatures)
-        if is_torch_compile:
-            # FIXME: Temporary workaround for unknown bugs in torch.compile
-            logits.add_(0)
-
         if sampling_info.logit_bias is not None:
             logits.add_(sampling_info.logit_bias)
 
         if sampling_info.vocab_mask is not None:
             logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf"))
 
-        logits = self._apply_penalties(logits, sampling_info)
+        logits = sampling_info.penalizer_orchestrator.apply(logits)
 
-        return torch.softmax(logits, dim=-1)
-
-    def forward_cuda(
-        self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
-        sampling_info: SamplingBatchInfo,
-    ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        probs = self._get_probs(logits, sampling_info)
+        probs = torch.softmax(logits, dim=-1)
 
         if not global_server_args_dict["disable_flashinfer_sampling"]:
             max_top_k_round, batch_size = 32, probs.shape[0]
             uniform_samples = torch.rand(
                 (max_top_k_round, batch_size), device=probs.device
             )
-            if sampling_info.need_min_p_sampling:
+            if sampling_info.min_ps.any():
                 probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids, success = min_p_sampling_from_probs(
@@ -100,23 +55,18 @@ class Sampler(CustomOp):
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
 
-        return SampleOutput(success, probs, batch_next_token_ids)
-
-    def forward_native(
-        self,
-        logits: Union[torch.Tensor, LogitsProcessorOutput],
-        sampling_info: SamplingBatchInfo,
-    ):
-        if isinstance(logits, LogitsProcessorOutput):
-            logits = logits.next_token_logits
-
-        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+        if not torch.all(success):
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                success, batch_next_token_ids, argmax_ids
+            )
 
-        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
-            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
-        )
+        return batch_next_token_ids
 
-        return SampleOutput(success, probs, batch_next_token_ids)
+    def forward_native():
+        raise NotImplementedError("Native forward is not implemented yet.")
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -137,10 +87,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
-        # FIXME: torch.multiomial does not support num_samples = 1
-        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
-            :, :1
-        ]
+        sampled_index = torch.multinomial(probs_sort, num_samples=1)
    except RuntimeError as e:
        logger.warning(f"Sampling error: {e}")
        batch_next_token_ids = torch.zeros(
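
The forward_cuda path now folds the old check_sample_results fallback into the sampler itself: rows where flashinfer sampling reports failure fall back to an argmax over the NaN-masked probabilities. The pure-torch path (top_k_top_p_min_p_sampling_from_probs_torch) filters probabilities before multinomial sampling; below is a self-contained sketch of that filtering with illustrative names, normalizing by the row sum instead of the row max and omitting the module's try/except guard:

import torch


def sample_top_k_top_p_min_p(probs, top_ks, top_ps, min_ps):
    # probs: [batch, vocab]; top_ks: int tensor [batch]; top_ps, min_ps: float tensors [batch].
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # min-p threshold is relative to the highest-probability token in each row.
    min_p_thresholds = probs_sort[:, 0] * min_ps
    # top-p: zero out tokens once the cumulative mass before them exceeds top_p.
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    # top-k: keep only the first top_k sorted positions per row.
    ranks = torch.arange(probs.shape[-1], device=probs.device)
    probs_sort[ranks.view(1, -1) >= top_ks.view(-1, 1)] = 0.0
    # min-p: drop tokens below min_p * max-prob.
    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
    # Renormalize, draw one token per row, and map back to vocabulary ids.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    sampled = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, dim=1, index=sampled).view(-1)


# Example: batch of 2 rows over a toy 5-token vocabulary.
probs = torch.softmax(torch.randn(2, 5), dim=-1)
ids = sample_top_k_top_p_min_p(
    probs,
    top_ks=torch.tensor([3, 5]),
    top_ps=torch.tensor([0.9, 1.0]),
    min_ps=torch.tensor([0.0, 0.05]),
)
print(ids)
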
sglang/srt/managers/schedule_batch.py
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,7 +17,7 @@ limitations under the License.
 
 import logging
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import List, Optional, Union
 
 import torch
 
@@ -31,10 +29,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
-if TYPE_CHECKING:
-    from sglang.srt.layers.sampler import SampleOutput
-
-
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
@@ -268,7 +262,14 @@ class Req:
 
         all_text = self.origin_input_text + self.decoded_text + jump_forward_str
         all_ids = self.tokenizer.encode(all_text)
+        if not all_ids:
+            logger.warning("Encoded all_text resulted in empty all_ids")
+            return False
+
         prompt_tokens = len(self.origin_input_ids_unpadded)
+        if prompt_tokens > len(all_ids):
+            logger.warning("prompt_tokens is larger than encoded all_ids")
+            return False
 
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
@@ -677,17 +678,11 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-    def check_sample_results(self, sample_output: SampleOutput):
-        if not torch.all(sample_output.success):
-            probs = sample_output.probs
-            batch_next_token_ids = sample_output.batch_next_token_ids
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                sample_output.success, batch_next_token_ids, argmax_ids
-            )
-            sample_output.probs = probs
-            sample_output.batch_next_token_ids = batch_next_token_ids
+    def sample(self, logits: torch.Tensor):
+        from sglang.srt.layers.sampler import Sampler
+
+        sampler = Sampler()
+
+        batch_next_token_ids = sampler(logits, self.sampling_info)
 
-        return sample_output.batch_next_token_ids
+        return batch_next_token_ids
sglang/srt/managers/tp_worker.py
@@ -31,7 +31,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -197,6 +197,16 @@ class ModelTpServer:
                 "trust_remote_code": server_args.trust_remote_code,
             },
             skip_tokenizer_init=server_args.skip_tokenizer_init,
+            json_schema_mode=False,
+        )
+        self.json_fsm_cache = FSMCache(
+            server_args.tokenizer_path,
+            {
+                "tokenizer_mode": server_args.tokenizer_mode,
+                "trust_remote_code": server_args.trust_remote_code,
+            },
+            skip_tokenizer_init=server_args.skip_tokenizer_init,
+            json_schema_mode=True,
         )
         self.jump_forward_cache = JumpForwardCache()
 
@@ -349,8 +359,17 @@ class ModelTpServer:
             req.top_logprobs_num = recv_req.top_logprobs_num
             req.stream = recv_req.stream
 
+            # Init regex fsm fron json
+            if req.sampling_params.json_schema is not None:
+                req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
+                    req.sampling_params.json_schema
+                )
+                if not self.disable_regex_jump_forward:
+                    req.jump_forward_map = self.jump_forward_cache.query(
+                        computed_regex_string
+                    )
             # Init regex fsm
-            if req.sampling_params.regex is not None:
+            elif req.sampling_params.regex is not None:
                 req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
                 if not self.disable_regex_jump_forward:
                     req.jump_forward_map = self.jump_forward_cache.query(
@@ -486,29 +505,21 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
            if batch.extend_num_tokens != 0:
-                sample_output, logits_output = self.model_runner.forward(
-                    batch, ForwardMode.EXTEND
-                )
-                next_token_ids = batch.check_sample_results(sample_output)
+                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+                next_token_ids = batch.sample(output.next_token_logits)
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
                 )
 
                 # Move logprobs to cpu
-                if logits_output.next_token_logprobs is not None:
-                    logits_output.next_token_logprobs = (
-                        logits_output.next_token_logprobs[
-                            torch.arange(
-                                len(next_token_ids), device=next_token_ids.device
-                            ),
-                            next_token_ids,
-                        ].tolist()
-                    )
-                    logits_output.input_token_logprobs = (
-                        logits_output.input_token_logprobs.tolist()
-                    )
-                    logits_output.normalized_prompt_logprobs = (
-                        logits_output.normalized_prompt_logprobs.tolist()
+                if output.next_token_logprobs is not None:
+                    output.next_token_logprobs = output.next_token_logprobs[
+                        torch.arange(len(next_token_ids), device=next_token_ids.device),
+                        next_token_ids,
+                    ].tolist()
+                    output.input_token_logprobs = output.input_token_logprobs.tolist()
+                    output.normalized_prompt_logprobs = (
+                        output.normalized_prompt_logprobs.tolist()
                     )
 
                 next_token_ids = next_token_ids.tolist()
@@ -547,14 +558,12 @@ class ModelTpServer:
                     self.req_to_token_pool.free(req.req_pool_idx)
 
                 if req.return_logprob:
-                    self.add_logprob_return_values(
-                        i, req, pt, next_token_ids, logits_output
-                    )
+                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
                     pt += req.extend_input_len
         else:
            assert batch.extend_num_tokens != 0
-            logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            embeddings = logits_output.embeddings.tolist()
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = output.embeddings.tolist()
 
         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -582,7 +591,7 @@ class ModelTpServer:
         req: Req,
         pt: int,
         next_token_ids: List[int],
-        output: LogitsProcessorOutput,
+        output: LogitProcessorOutput,
     ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -664,17 +673,15 @@ class ModelTpServer:
             batch.prepare_for_decode()
 
             # Forward and sample the next tokens
-            sample_output, logits_output = self.model_runner.forward(
-                batch, ForwardMode.DECODE
-            )
-            next_token_ids = batch.check_sample_results(sample_output)
+            output = self.model_runner.forward(batch, ForwardMode.DECODE)
+            next_token_ids = batch.sample(output.next_token_logits)
             batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                 next_token_ids
            )
 
             # Move logprobs to cpu
-            if logits_output.next_token_logprobs is not None:
-                next_token_logprobs = logits_output.next_token_logprobs[
+            if output.next_token_logprobs is not None:
+                next_token_logprobs = output.next_token_logprobs[
                     torch.arange(len(next_token_ids), device=next_token_ids.device),
                     next_token_ids,
                 ].tolist()
@@ -700,7 +707,7 @@ class ModelTpServer:
                     (next_token_logprobs[i], next_token_id)
                 )
                 if req.top_logprobs_num > 0:
-                    req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
+                    req.output_top_logprobs.append(output.output_top_logprobs[i])
 
         self.handle_finished_requests(batch)
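
Taken together with the sampling_params and OpenAI-adapter changes listed in the file table above, these tp_worker changes are what expose JSON-schema-constrained decoding. A hedged end-to-end sketch against a locally running sglang server, assuming the json_schema field added to SamplingParams in this release is accepted verbatim by the native /generate endpoint; the port, prompt, and schema are illustrative:

import json

import requests

schema = json.dumps(
    {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }
)

# Assumes a server started with `python -m sglang.launch_server ...` on port 30000.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Produce a JSON object for a person named Alice, aged 30: ",
        "sampling_params": {
            "max_new_tokens": 64,
            "temperature": 0,
            "json_schema": schema,  # new in 0.2.14.post1; compiled to a regex FSM
        },
    },
)
print(resp.json()["text"])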