sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (72)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_offline_throughput.py +18 -6
  3. sglang/bench_one_batch.py +13 -0
  4. sglang/bench_serving.py +8 -1
  5. sglang/check_env.py +140 -48
  6. sglang/lang/backend/runtime_endpoint.py +1 -0
  7. sglang/lang/chat_template.py +32 -0
  8. sglang/llama3_eval.py +316 -0
  9. sglang/srt/constrained/outlines_backend.py +5 -0
  10. sglang/srt/constrained/xgrammar_backend.py +9 -6
  11. sglang/srt/layers/attention/__init__.py +5 -2
  12. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  13. sglang/srt/layers/attention/flashinfer_backend.py +22 -5
  14. sglang/srt/layers/attention/torch_native_backend.py +22 -8
  15. sglang/srt/layers/attention/triton_backend.py +38 -33
  16. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  17. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  18. sglang/srt/layers/ep_moe/__init__.py +0 -0
  19. sglang/srt/layers/ep_moe/kernels.py +349 -0
  20. sglang/srt/layers/ep_moe/layer.py +665 -0
  21. sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
  22. sglang/srt/layers/fused_moe_triton/layer.py +1 -1
  23. sglang/srt/layers/logits_processor.py +133 -95
  24. sglang/srt/layers/quantization/__init__.py +2 -47
  25. sglang/srt/layers/quantization/fp8.py +607 -0
  26. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  27. sglang/srt/layers/radix_attention.py +11 -2
  28. sglang/srt/layers/sampler.py +29 -5
  29. sglang/srt/layers/torchao_utils.py +58 -45
  30. sglang/srt/managers/detokenizer_manager.py +37 -17
  31. sglang/srt/managers/io_struct.py +39 -10
  32. sglang/srt/managers/schedule_batch.py +39 -24
  33. sglang/srt/managers/schedule_policy.py +64 -5
  34. sglang/srt/managers/scheduler.py +236 -197
  35. sglang/srt/managers/tokenizer_manager.py +99 -58
  36. sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
  37. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  38. sglang/srt/mem_cache/chunk_cache.py +2 -2
  39. sglang/srt/mem_cache/memory_pool.py +5 -1
  40. sglang/srt/mem_cache/radix_cache.py +12 -2
  41. sglang/srt/model_executor/cuda_graph_runner.py +39 -11
  42. sglang/srt/model_executor/model_runner.py +24 -9
  43. sglang/srt/model_parallel.py +67 -10
  44. sglang/srt/models/commandr.py +2 -2
  45. sglang/srt/models/deepseek_v2.py +87 -7
  46. sglang/srt/models/gemma2.py +34 -0
  47. sglang/srt/models/gemma2_reward.py +0 -1
  48. sglang/srt/models/granite.py +517 -0
  49. sglang/srt/models/grok.py +72 -13
  50. sglang/srt/models/llama.py +22 -5
  51. sglang/srt/models/llama_classification.py +11 -23
  52. sglang/srt/models/llama_reward.py +0 -2
  53. sglang/srt/models/llava.py +37 -14
  54. sglang/srt/models/mixtral.py +12 -9
  55. sglang/srt/models/phi3_small.py +0 -5
  56. sglang/srt/models/qwen2.py +20 -0
  57. sglang/srt/models/qwen2_moe.py +0 -5
  58. sglang/srt/models/torch_native_llama.py +0 -5
  59. sglang/srt/openai_api/adapter.py +4 -0
  60. sglang/srt/openai_api/protocol.py +9 -4
  61. sglang/srt/sampling/sampling_batch_info.py +9 -8
  62. sglang/srt/server.py +4 -4
  63. sglang/srt/server_args.py +62 -13
  64. sglang/srt/utils.py +57 -10
  65. sglang/test/test_utils.py +3 -2
  66. sglang/utils.py +10 -3
  67. sglang/version.py +1 -1
  68. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
  69. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
  70. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
  71. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
  72. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/sampler.py

@@ -51,7 +51,6 @@ class Sampler(nn.Module):
         # Post process logits
         logits.div_(sampling_info.temperatures)
         probs = torch.softmax(logits, dim=-1)
-        logits = None
         del logits
 
         if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -84,6 +83,7 @@ class Sampler(nn.Module):
                     sampling_info.top_ks,
                     sampling_info.top_ps,
                     sampling_info.min_ps,
+                    sampling_info.need_min_p_sampling,
                 )
             else:
                 raise ValueError(
@@ -98,18 +98,42 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     top_ks: torch.Tensor,
     top_ps: torch.Tensor,
     min_ps: torch.Tensor,
+    need_min_p_sampling: bool,
 ):
     """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    min_p_thresholds = probs_sort[:, 0] * min_ps
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
     probs_sort[
         torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
         >= top_ks.view(-1, 1)
     ] = 0.0
-    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+
+    if need_min_p_sampling:
+        min_p_thresholds = probs_sort[:, 0] * min_ps
+        probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
+
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    # int32 range is enough to represent the token ids
+    probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
+
+
+def top_p_normalize_probs(
+    probs: torch.Tensor,
+    top_ps: torch.Tensor,
+):
+    if global_server_args_dict["sampling_backend"] == "flashinfer":
+        return top_p_renorm_prob(probs, top_ps)
+    elif global_server_args_dict["sampling_backend"] == "pytorch":
+        # See also top_k_top_p_min_p_sampling_from_probs_torch
+        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+        probs_sum = torch.cumsum(probs_sort, dim=-1)
+        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+        return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+    else:
+        raise ValueError(
+            f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+        )
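The behavioral change in the PyTorch fallback sampler is the filter order (top-k, then top-p, then min-p only when a request in the batch actually uses it) and the int32 cast of the gathered token ids. The sketch below reproduces that ordering as a standalone function on a toy batch; the function and variable names are local to this example, not sglang APIs.

```python
import torch


def sample_top_k_top_p_min_p(probs, top_ks, top_ps, min_ps, need_min_p_sampling):
    # Sort once; all filters operate on the sorted copy.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Top-k: zero out everything beyond each row's k-th ranked token.
    ranks = torch.arange(probs.shape[-1], device=probs.device).view(1, -1)
    probs_sort[ranks >= top_ks.view(-1, 1)] = 0.0
    # Top-p: zero out tokens whose preceding cumulative mass already exceeds p.
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    # Min-p is skipped entirely unless some request needs it.
    if need_min_p_sampling:
        min_p_thresholds = probs_sort[:, 0] * min_ps
        probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
    sampled_index = torch.multinomial(probs_sort, num_samples=1)
    # int32 is wide enough for vocabulary indices.
    return torch.gather(probs_idx.to(torch.int32), dim=1, index=sampled_index).view(-1)


probs = torch.softmax(torch.randn(2, 8), dim=-1)
next_ids = sample_top_k_top_p_min_p(
    probs,
    top_ks=torch.tensor([4, 8]),
    top_ps=torch.tensor([0.9, 1.0]),
    min_ps=torch.tensor([0.05, 0.0]),
    need_min_p_sampling=True,
)
print(next_ids)  # e.g. tensor([3, 1], dtype=torch.int32)
```

The old renormalization before sampling can be dropped because torch.multinomial accepts unnormalized non-negative weights.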
sglang/srt/layers/torchao_utils.py

@@ -2,23 +2,24 @@
 Common utilities for torchao.
 """
 
-from typing import Dict, Set
-
 import torch
 
 
-def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
-    """Quantize a Tensor with torchao quantization specified by torchao_config
+def apply_torchao_config_to_model(
+    model: torch.nn.Module, torchao_config: str, filter_fn=None
+):
+    """Quantize a modelwith torchao quantization specified by torchao_config
 
     Args:
-       `param`: weight parameter of the linear module
-       `torchao_config`: type of quantization and their arguments we want to use to
-       quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
+       `model`: a model to be quantized based on torchao_config
+       `torchao_config` (str): type of quantization and their arguments we want to use to
+       quantize the model, e.g. int4wo-128 means int4 weight only quantization with group_size
        128
     """
     # Lazy import to suppress some warnings
     from torchao.quantization import (
         float8_dynamic_activation_float8_weight,
+        float8_weight_only,
         int4_weight_only,
         int8_dynamic_activation_int8_weight,
         int8_weight_only,
@@ -26,12 +27,17 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
     )
     from torchao.quantization.observer import PerRow, PerTensor
 
-    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
-    dummy_linear.weight = param
-    if "int8wo" in torchao_config:
-        quantize_(dummy_linear, int8_weight_only())
+    if filter_fn is None:
+
+        def filter_fn(module, fqn):
+            return "proj" in fqn
+
+    if torchao_config == "" or torchao_config is None:
+        return model
+    elif "int8wo" in torchao_config:
+        quantize_(model, int8_weight_only(), filter_fn=filter_fn)
     elif "int8dq" in torchao_config:
-        quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
+        quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
     elif "int4wo" in torchao_config:
         group_size = int(torchao_config.split("-")[-1])
         assert group_size in [
@@ -40,13 +46,46 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
             128,
             256,
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
-        quantize_(dummy_linear, int4_weight_only(group_size=group_size))
-    elif "fp8wo" in torchao_config:
-        from torchao.quantization import float8_weight_only
+        quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
+    elif "gemlite" in torchao_config:
+        # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
+        # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
+        import os
+        import pwd
+
+        import gemlite
+        from gemlite.core import GemLiteLinearTriton, set_autotune
+
+        try:
+            from torchao.quantization import gemlite_uintx_weight_only
+        except:
+            print(
+                f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization"
+            )
+            return model
+
+        _quant_args = torchao_config.split("-")
+        bit_width = int(_quant_args[-2])
+        group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
+        try:
+            packing_bitwidth = int(_quant_args[-3])
+        except:
+            # if only 2 inputs found, use default value
+            packing_bitwidth = 32
+
+        quantize_(
+            model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth)
+        )
 
+        # try to load gemlite kernel config
+        GemLiteLinearTriton.load_config(
+            f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
+        )
+
+    elif "fp8wo" in torchao_config:
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
-        quantize_(dummy_linear, float8_weight_only())
+        quantize_(model, float8_weight_only(), filter_fn=filter_fn)
     elif "fp8dq" in torchao_config:
         granularity = torchao_config.split("-")[-1]
         GRANULARITY_MAP = {
@@ -57,39 +96,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
             granularity in GRANULARITY_MAP
         ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
         quantize_(
-            dummy_linear,
+            model,
             float8_dynamic_activation_float8_weight(
                 granularity=GRANULARITY_MAP[granularity]
            ),
+            filter_fn=filter_fn,
         )
     else:
         raise ValueError(f"Unexpected config: {torchao_config}")
 
-    return dummy_linear.weight
-
-
-def apply_torchao_config_(
-    self: torch.nn.Module,
-    params_dict: Dict[str, torch.Tensor],
-    param_suffixes: Set[str],
-) -> None:
-    """A util function used for quantizing the weight parameters after they are loaded if
-    self.torchao_config is specified
-
-    Args:
-       `self`: the model we want to quantize
-       `params_dict`: dictionary mapping from param_name to the parameter Tensor
-       `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
-
-    Returns:
-       None, the `params_dict` is modified inplace and the weights of `self` model are quantized
-    """
-    if self.torchao_config:
-        for param_suffix in param_suffixes:
-            for name in params_dict:
-                param = params_dict[name]
-                if param_suffix in name and param.ndim == 2:
-                    params_dict[name] = torchao_quantize_param_data(
-                        param, self.torchao_config
-                    )
-        self.load_state_dict(params_dict, assign=True)
+    return model
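The per-parameter helper (`torchao_quantize_param_data` plus `apply_torchao_config_`) is replaced by a single module-level entry point that quantizes the whole model in place and returns it. A hedged usage sketch follows; `TinyModel`, its layer names, and the device/dtype choices are illustrative (inside sglang the model comes from the model loader and the config string from the `--torchao-config` server argument), and whether a given config runs depends on the installed torchao version and hardware.

```python
import torch
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(256, 256, bias=False)
        self.lm_head = torch.nn.Linear(256, 32, bias=False)

    def forward(self, x):
        return self.lm_head(self.q_proj(x))


model = TinyModel()

# "int8wo" = int8 weight-only quantization. With the default filter_fn,
# only modules whose fully qualified name contains "proj" are quantized,
# so q_proj is converted while lm_head keeps its original weights.
model = apply_torchao_config_to_model(model, "int8wo")

# Passing an empty config string is now an explicit no-op that returns the
# model unchanged.
model = apply_torchao_config_to_model(model, "")
```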
sglang/srt/managers/detokenizer_manager.py

@@ -17,9 +17,10 @@ import dataclasses
 import logging
 import signal
 from collections import OrderedDict
-from typing import List, Union
+from typing import Dict, List, Union
 
 import psutil
+import setproctitle
 import zmq
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -28,7 +29,6 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     BatchTokenIDOut,
 )
-from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import configure_logger, get_zmq_socket
 from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,17 +75,25 @@ class DetokenizerManager:
 
         self.decode_status = LimitedCapacityDict()
 
-    def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
-        if no_stop_trim:
+    def trim_matched_stop(
+        self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
+    ):
+        if no_stop_trim or not finished_reason:
+            return output
+
+        matched = finished_reason.get("matched", None)
+        if not matched:
             return output
 
-        # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
-        if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
-            pos = output.find(finished_reason.matched)
+        # TODO(lmzheng): handle the case where multiple stop strs are hit
+
+        # Trim stop str.
+        if isinstance(matched, str) and isinstance(output, str):
+            pos = output.find(matched)
             return output[:pos] if pos != -1 else output
-        if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
-            output, list
-        ):
+
+        # Trim stop token.
+        if isinstance(matched, int) and isinstance(output, list):
             assert len(output) > 0
             return output[:-1]
         return output
@@ -124,9 +132,9 @@ class DetokenizerManager:
                     s.decode_ids = recv_obj.decode_ids[i]
 
                 read_ids.append(
-                    self.trim_eos(
+                    self.trim_matched_stop(
                         s.decode_ids[s.surr_offset :],
-                        recv_obj.finished_reason[i],
+                        recv_obj.finished_reasons[i],
                         recv_obj.no_stop_trim[i],
                     )
                 )
@@ -149,7 +157,7 @@ class DetokenizerManager:
             for i in range(bs):
                 s = self.decode_status[recv_obj.rids[i]]
                 new_text = read_texts[i][len(surr_texts[i]) :]
-                if recv_obj.finished_reason[i] is None:
+                if recv_obj.finished_reasons[i] is None:
                     # Streaming chunk: update the decode status
                     if len(new_text) > 0 and not new_text.endswith("�"):
                         s.decoded_text = s.decoded_text + new_text
@@ -160,9 +168,9 @@ class DetokenizerManager:
                     new_text = find_printable_text(new_text)
 
                 output_strs.append(
-                    self.trim_eos(
+                    self.trim_matched_stop(
                         s.decoded_text + new_text,
-                        recv_obj.finished_reason[i],
+                        recv_obj.finished_reasons[i],
                         recv_obj.no_stop_trim[i],
                     )
                 )
@@ -170,9 +178,20 @@ class DetokenizerManager:
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
                     rids=recv_obj.rids,
+                    finished_reasons=recv_obj.finished_reasons,
                     output_strs=output_strs,
-                    meta_info=recv_obj.meta_info,
-                    finished_reason=recv_obj.finished_reason,
+                    prompt_tokens=recv_obj.prompt_tokens,
+                    completion_tokens=recv_obj.completion_tokens,
+                    cached_tokens=recv_obj.cached_tokens,
+                    input_token_logprobs_val=recv_obj.input_token_logprobs_val,
+                    input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
+                    output_token_logprobs_val=recv_obj.output_token_logprobs_val,
+                    output_token_logprobs_idx=recv_obj.output_token_logprobs_idx,
+                    input_top_logprobs_val=recv_obj.input_top_logprobs_val,
+                    input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
+                    output_top_logprobs_val=recv_obj.output_top_logprobs_val,
+                    output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
+                    normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
                 )
             )
 
@@ -194,6 +213,7 @@ def run_detokenizer_process(
     server_args: ServerArgs,
     port_args: PortArgs,
 ):
+    setproctitle.setproctitle("sglang::detokenizer")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
 
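The detokenizer no longer imports FINISH_MATCHED_STR / FINISH_MATCHED_TOKEN: the finish reason now arrives as a plain dict, and the matched stop string or stop token id is read from its "matched" key. A standalone sketch of the trimming behavior (the dict shape follows the diff; the example values are made up):

```python
from typing import Dict, List, Union


def trim_matched_stop(
    output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
):
    if no_stop_trim or not finished_reason:
        return output
    matched = finished_reason.get("matched", None)
    if not matched:
        return output
    # Stop string: cut everything from the first occurrence of the match.
    if isinstance(matched, str) and isinstance(output, str):
        pos = output.find(matched)
        return output[:pos] if pos != -1 else output
    # Stop token: drop the trailing stop token id.
    if isinstance(matched, int) and isinstance(output, list):
        return output[:-1]
    return output


print(trim_matched_stop("Hello world<|eot|>", {"matched": "<|eot|>"}, no_stop_trim=False))
# -> "Hello world"
print(trim_matched_stop([12, 7, 99, 2], {"matched": 2}, no_stop_trim=False))
# -> [12, 7, 99]
```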
sglang/srt/managers/io_struct.py

@@ -308,6 +308,9 @@ class TokenizedEmbeddingReqInput:
 class BatchTokenIDOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[BaseFinishReason]
+    # For incremental decoding
     # The version id to sync decode status with in detokenizer_manager
     vids: List[int]
     decoded_texts: List[str]
@@ -315,35 +318,61 @@ class BatchTokenIDOut:
     read_offsets: List[int]
     # Only used when `--skip-tokenizer-init`
     output_ids: Optional[List[int]]
+    # Detokenization configs
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
-    meta_info: List[Dict]
-    finished_reason: List[BaseFinishReason]
     no_stop_trim: List[bool]
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
+    # Logprobs
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    input_top_logprobs_val: List[List]
+    input_top_logprobs_idx: List[List]
+    output_top_logprobs_val: List[List]
+    output_top_logprobs_idx: List[List]
+    normalized_prompt_logprob: List[float]
 
 
 @dataclass
 class BatchStrOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[dict]
     # The output decoded strings
     output_strs: List[str]
-    # The meta info
-    meta_info: List[Dict]
-    # The finish reason
-    finished_reason: List[BaseFinishReason]
+
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
+    # Logprobs
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    input_top_logprobs_val: List[List]
+    input_top_logprobs_idx: List[List]
+    output_top_logprobs_val: List[List]
+    output_top_logprobs_idx: List[List]
+    normalized_prompt_logprob: List[float]
 
 
 @dataclass
 class BatchEmbeddingOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[BaseFinishReason]
     # The output embedding
     embeddings: List[List[float]]
-    # The meta info
-    meta_info: List[Dict]
-    # The finish reason
-    finished_reason: List[BaseFinishReason]
+    # Token counts
+    prompt_tokens: List[int]
 
 
 @dataclass
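For consumers, the practical effect is that the per-request meta_info dict is gone: BatchStrOut now carries flat parallel lists indexed by each request's position in the batch. A hedged sketch of how a receiver might reassemble per-request usage info; `recv_obj` and the returned dict layout are illustrative, not the tokenizer manager's actual code.

```python
def usage_for_request(recv_obj, i: int) -> dict:
    # All fields are parallel lists aligned with recv_obj.rids.
    return {
        "rid": recv_obj.rids[i],
        "finish_reason": recv_obj.finished_reasons[i],
        "prompt_tokens": recv_obj.prompt_tokens[i],
        "completion_tokens": recv_obj.completion_tokens[i],
        "cached_tokens": recv_obj.cached_tokens[i],
    }
```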
sglang/srt/managers/schedule_batch.py

@@ -58,6 +58,7 @@ global_server_args_dict = {
     "torchao_config": ServerArgs.torchao_config,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_ep_moe": ServerArgs.enable_ep_moe,
 }
 
 
@@ -128,6 +129,7 @@ class ImageInputs:
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
+    image_pad_len: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
@@ -180,6 +182,7 @@ class ImageInputs:
         optional_args = [
             "image_sizes",
             "image_offsets",
+            "image_pad_len",
             # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
             "aspect_ratio_ids",
             "aspect_ratio_mask",
@@ -199,6 +202,9 @@ class Req:
         origin_input_text: str,
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
+        return_logprob: bool = False,
+        top_logprobs_num: int = 0,
+        stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
@@ -216,10 +222,11 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
         self.session_id = session_id
+        self.input_embeds = input_embeds
 
+        # Sampling info
         self.sampling_params = sampling_params
         self.lora_path = lora_path
-        self.input_embeds = input_embeds
 
         # Memory pool info
         self.req_pool_idx = None
@@ -227,8 +234,8 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
-        self.stream = False
         self.to_abort = False
+        self.stream = stream
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -240,37 +247,46 @@ class Req:
         # 2: read_offset
         # 3: last token
         self.vid = 0  # version id to sync decode status with in detokenizer_manager
-        self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
-
-        # The number of decoded tokens for token usage report. Note that
-        # this does not include the jump forward tokens.
-        self.completion_tokens_wo_jump_forward = 0
+        self.decoded_text = ""
 
         # For multimodal inputs
         self.image_inputs: Optional[ImageInputs] = None
 
         # Prefix info
         self.prefix_indices = []
+        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
         self.extend_input_len = 0
         self.last_node = None
+
+        # Chunked prefill
         self.is_being_chunked = 0
 
         # For retraction
         self.is_retracted = False
 
         # Logprobs (arguments)
-        self.return_logprob = False
+        self.return_logprob = return_logprob
         self.logprob_start_len = 0
-        self.top_logprobs_num = 0
+        self.top_logprobs_num = top_logprobs_num
 
         # Logprobs (return value)
         self.normalized_prompt_logprob = None
-        self.input_token_logprobs = None
-        self.input_top_logprobs = None
-        self.output_token_logprobs = []
-        self.output_top_logprobs = []
+        self.input_token_logprobs_val = None
+        self.input_token_logprobs_idx = None
+        self.input_top_logprobs_val = None
+        self.input_top_logprobs_idx = None
+
+        if return_logprob:
+            self.output_token_logprobs_val = []
+            self.output_token_logprobs_idx = []
+            self.output_top_logprobs_val = []
+            self.output_top_logprobs_idx = []
+        else:
+            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
+                self.output_top_logprobs_val
+            ) = self.output_top_logprobs_idx = None
 
         # Logprobs (internal values)
         # The tokens is prefilled but need to be considered as decode tokens
@@ -294,13 +310,14 @@ class Req:
         else:
             self.image_inputs.merge(image_inputs)
 
-    # whether request reached finished condition
     def finished(self) -> bool:
+        # Whether request reached finished condition
         return self.finished_reason is not None
 
     def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
+            # tree cache is None if the prefix is not computed with tree cache.
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
@@ -453,8 +470,10 @@ class Req:
                 k = k + 1
             else:
                 break
-        self.output_token_logprobs = self.output_token_logprobs[:k]
-        self.output_top_logprobs = self.output_top_logprobs[:k]
+        self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
+        self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
+        self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
+        self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
         self.logprob_start_len = prompt_tokens + k
         self.last_update_decode_tokens = len(self.output_ids) - k
 
@@ -469,7 +488,7 @@ bid = 0
 
 @dataclasses.dataclass
 class ScheduleBatch:
-    """Store all inforamtion of a batch on the scheduler."""
+    """Store all information of a batch on the scheduler."""
 
     # Request, memory pool, and cache
     reqs: List[Req]
@@ -1067,9 +1086,9 @@ class ScheduleBatch:
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.reqs.extend(other.reqs)
 
-        self.return_logprob = self.return_logprob or other.return_logprob
-        self.has_stream = self.has_stream or other.has_stream
-        self.has_grammar = self.has_grammar or other.has_grammar
+        self.return_logprob |= other.return_logprob
+        self.has_stream |= other.has_stream
+        self.has_grammar |= other.has_grammar
 
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
@@ -1096,7 +1115,6 @@ class ScheduleBatch:
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_sum=self.seq_lens_sum,
-            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
@@ -1151,9 +1169,6 @@ class ModelWorkerBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int
 
-    # The memory pool operation records
-    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
-
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
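With this release, `return_logprob`, `top_logprobs_num`, and `stream` are passed to `Req.__init__` instead of being patched onto the object after construction, and the output logprob buffers are only allocated when logprobs are requested. A hedged construction sketch; the `rid` keyword and the SamplingParams import path and arguments are assumptions not shown in this diff, and the snippet is not meant as the scheduler's actual call site.

```python
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.sampling.sampling_params import SamplingParams

req = Req(
    rid="req-0",                  # assumed; the request-id argument is not shown in this hunk
    origin_input_text="Hello",
    origin_input_ids=[9906],
    sampling_params=SamplingParams(max_new_tokens=16),
    return_logprob=True,          # previously set afterwards as req.return_logprob = True
    top_logprobs_num=5,           # previously set afterwards as req.top_logprobs_num = 5
    stream=True,                  # previously set afterwards as req.stream = True
)

# With return_logprob=True the output logprob buffers start as empty lists;
# otherwise they stay None and are never filled.
assert req.output_token_logprobs_val == []
```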