sglang 0.4.0.post1__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. sglang/bench_offline_throughput.py +18 -6
  2. sglang/bench_one_batch.py +13 -0
  3. sglang/bench_serving.py +8 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/constrained/xgrammar_backend.py +4 -1
  9. sglang/srt/layers/attention/flashinfer_backend.py +2 -0
  10. sglang/srt/layers/attention/triton_backend.py +16 -25
  11. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  12. sglang/srt/layers/ep_moe/layer.py +4 -0
  13. sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
  14. sglang/srt/layers/fused_moe_triton/layer.py +1 -1
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/quantization/__init__.py +2 -47
  17. sglang/srt/layers/quantization/fp8.py +58 -10
  18. sglang/srt/layers/radix_attention.py +8 -1
  19. sglang/srt/layers/sampler.py +27 -5
  20. sglang/srt/layers/torchao_utils.py +35 -0
  21. sglang/srt/managers/detokenizer_manager.py +37 -17
  22. sglang/srt/managers/io_struct.py +39 -10
  23. sglang/srt/managers/schedule_batch.py +38 -24
  24. sglang/srt/managers/schedule_policy.py +64 -5
  25. sglang/srt/managers/scheduler.py +169 -134
  26. sglang/srt/managers/tokenizer_manager.py +99 -58
  27. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  28. sglang/srt/mem_cache/chunk_cache.py +2 -2
  29. sglang/srt/mem_cache/radix_cache.py +12 -2
  30. sglang/srt/model_executor/cuda_graph_runner.py +24 -10
  31. sglang/srt/model_executor/model_runner.py +22 -14
  32. sglang/srt/model_parallel.py +66 -5
  33. sglang/srt/models/gemma2.py +34 -0
  34. sglang/srt/models/gemma2_reward.py +0 -1
  35. sglang/srt/models/granite.py +517 -0
  36. sglang/srt/models/grok.py +72 -8
  37. sglang/srt/models/llama.py +22 -0
  38. sglang/srt/models/llama_classification.py +11 -23
  39. sglang/srt/models/llama_reward.py +0 -2
  40. sglang/srt/models/llava.py +37 -14
  41. sglang/srt/models/qwen2.py +20 -0
  42. sglang/srt/openai_api/adapter.py +4 -0
  43. sglang/srt/openai_api/protocol.py +9 -4
  44. sglang/srt/server.py +1 -1
  45. sglang/srt/server_args.py +19 -9
  46. sglang/srt/utils.py +7 -10
  47. sglang/test/test_utils.py +3 -2
  48. sglang/utils.py +10 -3
  49. sglang/version.py +1 -1
  50. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +11 -6
  51. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +54 -52
  52. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
  53. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
  54. {sglang-0.4.0.post1.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py

@@ -1,9 +1,11 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
+import os
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
+import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
@@ -24,11 +26,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
 )
 from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
-from sglang.srt.layers.fused_moe_triton import (
-    FusedMoE,
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
+from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -100,6 +98,8 @@ class Fp8Config(QuantizationConfig):
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
+        from sglang.srt.layers.fused_moe_triton import FusedMoE
+
         if isinstance(layer, LinearBase):
             if is_layer_skipped(prefix, self.ignored_layers):
                 return UnquantizedLinearMethod()
@@ -306,7 +306,7 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
 
-class Fp8MoEMethod(FusedMoEMethodBase):
+class Fp8MoEMethod:
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
     dynamic/static activation scale.
@@ -319,7 +319,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: Fp8Config):
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    def __init__(self, quant_config):
        self.quant_config = quant_config
 
     def create_weights(
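Note: the Fp8MoEMethod change above replaces static inheritance from FusedMoEMethodBase with a subclass synthesized inside __new__, so fused_moe_triton is only imported when the method is instantiated and the import cycle at module load time is avoided. Below is a minimal, self-contained sketch of that deferred-base-class pattern; the names Base and Method are illustrative only and not part of sglang. The remaining fp8.py hunks continue after it.

class Base:
    def hello(self):
        return "base"


class Method:
    def __new__(cls, *args, **kwargs):
        # Synthesize a subclass of Base that carries Method's own attributes,
        # then construct and initialize an instance of that new class directly.
        new_cls = type(
            cls.__name__,
            (Base,),
            {k: v for k, v in cls.__dict__.items() if k not in ("__dict__", "__weakref__")},
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, value):
        self.value = value


m = Method(42)
# The returned object inherits from Base, not from Method itself.
assert isinstance(m, Base) and m.value == 42 and m.hello() == "base"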
@@ -331,6 +349,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
 
         if self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
@@ -404,7 +423,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
 
     def process_weights_after_loading(self, layer: Module) -> None:
 
-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
             fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
@@ -428,6 +447,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return
 
         # If checkpoint is fp8, we need to handle that the
@@ -456,6 +488,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.w2_input_scale = torch.nn.Parameter(
                     layer.w2_input_scale.max(), requires_grad=False
                 )
+
             # If ROCm, normalize the weights and scales to e4m3fnuz
             if is_hip():
                 # Normalize the weights and scales
@@ -507,6 +540,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight_scale = torch.nn.Parameter(
                 max_w13_scales, requires_grad=False
             )
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return
 
     def apply(
@@ -521,9 +567,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
+        from sglang.srt.layers.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
 
-        from vllm.model_executor.layers.fused_moe import fused_experts
-
+        # Expert selection
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -535,6 +582,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             custom_routing_function=custom_routing_function,
         )
 
+        # Expert fusion with FP8 quantization
         return fused_experts(
             x,
             layer.w13_weight,
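The two MOE_PADDING blocks above pad the last dimension of the expert weight tensors on ROCm when the MOE_PADDING environment variable is set, trading a little memory for less memory-channel contention. A hedged, standalone sketch of just the padding step; the tensor shape and the padding_size value of 128 are assumptions for illustration (sglang imports the real value from fused_moe_triton.fused_moe), and the is_hip() check is omitted.

import os

import torch
import torch.nn.functional as F

padding_size = 128  # assumed value for illustration

# (num_experts, 2 * intermediate_size, hidden_size), made-up shape
w13_weight = torch.randn(8, 4096, 2048)

if bool(int(os.getenv("MOE_PADDING", "0"))):
    # F.pad's (left, right) pair applies to the last dimension only.
    w13_weight = F.pad(w13_weight, (0, padding_size), "constant", 0)

print(w13_weight.shape)  # torch.Size([8, 4096, 2176]) when MOE_PADDING=1, unchanged otherwise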
sglang/srt/layers/radix_attention.py

@@ -48,7 +48,14 @@ class RadixAttention(nn.Module):
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention
 
-    def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True):
+    def forward(
+        self,
+        q,
+        k,
+        v,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
sglang/srt/layers/sampler.py

@@ -51,7 +51,6 @@ class Sampler(nn.Module):
         # Post process logits
         logits.div_(sampling_info.temperatures)
         probs = torch.softmax(logits, dim=-1)
-        logits = None
         del logits
 
         if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -84,6 +83,7 @@ class Sampler(nn.Module):
                 sampling_info.top_ks,
                 sampling_info.top_ps,
                 sampling_info.min_ps,
+                sampling_info.need_min_p_sampling,
             )
         else:
             raise ValueError(
@@ -98,20 +98,42 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     top_ks: torch.Tensor,
     top_ps: torch.Tensor,
     min_ps: torch.Tensor,
+    need_min_p_sampling: bool,
 ):
     """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
     probs_sum = torch.cumsum(probs_sort, dim=-1)
-    min_p_thresholds = probs_sort[:, 0] * min_ps
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
     probs_sort[
         torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
         >= top_ks.view(-1, 1)
     ] = 0.0
-    probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-    probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+
+    if need_min_p_sampling:
+        min_p_thresholds = probs_sort[:, 0] * min_ps
+        probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
+
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
     # int32 range is enough to represent the token ids
     probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
+
+
+def top_p_normalize_probs(
+    probs: torch.Tensor,
+    top_ps: torch.Tensor,
+):
+    if global_server_args_dict["sampling_backend"] == "flashinfer":
+        return top_p_renorm_prob(probs, top_ps)
+    elif global_server_args_dict["sampling_backend"] == "pytorch":
+        # See also top_k_top_p_min_p_sampling_from_probs_torch
+        probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+        probs_sum = torch.cumsum(probs_sort, dim=-1)
+        probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+        return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+    else:
+        raise ValueError(
+            f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+        )
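For reference, a small self-contained check of the pytorch branch of the new top_p_normalize_probs: tokens outside the top-p nucleus are zeroed, the remainder is renormalized to sum to 1, and the result is scattered back to the original token order. The probability values below are made up for illustration.

import torch


def top_p_normalize_probs_torch(probs: torch.Tensor, top_ps: torch.Tensor):
    # Same logic as the pytorch branch above, standalone for illustration.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # A token is kept while the mass of the tokens ranked above it is <= top_p.
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)


probs = torch.tensor([[0.1, 0.5, 0.3, 0.1]])
out = top_p_normalize_probs_torch(probs, torch.tensor([0.6]))
# Nucleus for top_p=0.6 is {0.5, 0.3}; renormalized to [0, 0.625, 0.375, 0].
print(out)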
sglang/srt/layers/torchao_utils.py

@@ -47,6 +47,41 @@ def apply_torchao_config_to_model(
             256,
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
         quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
+    elif "gemlite" in torchao_config:
+        # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
+        # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
+        import os
+        import pwd
+
+        import gemlite
+        from gemlite.core import GemLiteLinearTriton, set_autotune
+
+        try:
+            from torchao.quantization import gemlite_uintx_weight_only
+        except:
+            print(
+                f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization"
+            )
+            return model
+
+        _quant_args = torchao_config.split("-")
+        bit_width = int(_quant_args[-2])
+        group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
+        try:
+            packing_bitwidth = int(_quant_args[-3])
+        except:
+            # if only 2 inputs found, use default value
+            packing_bitwidth = 32
+
+        quantize_(
+            model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth)
+        )
+
+        # try to load gemlite kernel config
+        GemLiteLinearTriton.load_config(
+            f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
+        )
+
     elif "fp8wo" in torchao_config:
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
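The gemlite branch above parses the torchao_config string from the right: the last two fields are the bit width and group size, and an optional third-from-last field overrides the packing bitwidth (default 32). A small sketch of just that parsing step, derived from the diff; the bare except clauses are narrowed to concrete exception types for the example.

def parse_gemlite_config(torchao_config: str):
    _quant_args = torchao_config.split("-")
    bit_width = int(_quant_args[-2])
    group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
    try:
        packing_bitwidth = int(_quant_args[-3])
    except (IndexError, ValueError):
        # only two numeric fields given, fall back to the default
        packing_bitwidth = 32
    return packing_bitwidth, bit_width, group_size


assert parse_gemlite_config("gemlite-4-64") == (32, 4, 64)
assert parse_gemlite_config("gemlite-8-4-64") == (8, 4, 64)
assert parse_gemlite_config("gemlite-4-None") == (32, 4, None)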
sglang/srt/managers/detokenizer_manager.py

@@ -17,9 +17,10 @@ import dataclasses
 import logging
 import signal
 from collections import OrderedDict
-from typing import List, Union
+from typing import Dict, List, Union
 
 import psutil
+import setproctitle
 import zmq
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -28,7 +29,6 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     BatchTokenIDOut,
 )
-from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import configure_logger, get_zmq_socket
 from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,17 +75,25 @@ class DetokenizerManager:
 
         self.decode_status = LimitedCapacityDict()
 
-    def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
-        if no_stop_trim:
+    def trim_matched_stop(
+        self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
+    ):
+        if no_stop_trim or not finished_reason:
+            return output
+
+        matched = finished_reason.get("matched", None)
+        if not matched:
             return output
 
-        # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
-        if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
-            pos = output.find(finished_reason.matched)
+        # TODO(lmzheng): handle the case where multiple stop strs are hit
+
+        # Trim stop str.
+        if isinstance(matched, str) and isinstance(output, str):
+            pos = output.find(matched)
             return output[:pos] if pos != -1 else output
-        if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
-            output, list
-        ):
+
+        # Trim stop token.
+        if isinstance(matched, int) and isinstance(output, list):
             assert len(output) > 0
             return output[:-1]
         return output
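With this change the finish reason reaches the detokenizer as a plain dict whose "matched" field is either a stop string or a stop token id, instead of a FINISH_MATCHED_STR / FINISH_MATCHED_TOKEN object. A slightly simplified, standalone sketch of the trimming behavior (not the class method itself); the remaining detokenizer_manager.py hunks continue below.

from typing import Dict, List, Union


def trim_matched_stop(
    output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
):
    if no_stop_trim or not finished_reason:
        return output
    matched = finished_reason.get("matched", None)
    if not matched:
        return output
    if isinstance(matched, str) and isinstance(output, str):
        # Cut the decoded text at the first occurrence of the stop string.
        pos = output.find(matched)
        return output[:pos] if pos != -1 else output
    if isinstance(matched, int) and isinstance(output, list):
        # Drop the trailing stop token id.
        return output[:-1]
    return output


assert trim_matched_stop("Hello!</s>", {"matched": "</s>"}, False) == "Hello!"
assert trim_matched_stop([11, 22, 2], {"matched": 2}, False) == [11, 22]
assert trim_matched_stop("Hello!</s>", {"matched": "</s>"}, True) == "Hello!</s>"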
@@ -124,9 +132,9 @@ class DetokenizerManager:
                 s.decode_ids = recv_obj.decode_ids[i]
 
             read_ids.append(
-                self.trim_eos(
+                self.trim_matched_stop(
                     s.decode_ids[s.surr_offset :],
-                    recv_obj.finished_reason[i],
+                    recv_obj.finished_reasons[i],
                     recv_obj.no_stop_trim[i],
                 )
             )
@@ -149,7 +157,7 @@ class DetokenizerManager:
         for i in range(bs):
             s = self.decode_status[recv_obj.rids[i]]
             new_text = read_texts[i][len(surr_texts[i]) :]
-            if recv_obj.finished_reason[i] is None:
+            if recv_obj.finished_reasons[i] is None:
                 # Streaming chunk: update the decode status
                 if len(new_text) > 0 and not new_text.endswith("�"):
                     s.decoded_text = s.decoded_text + new_text
@@ -160,9 +168,9 @@ class DetokenizerManager:
                    new_text = find_printable_text(new_text)
 
             output_strs.append(
-                self.trim_eos(
+                self.trim_matched_stop(
                     s.decoded_text + new_text,
-                    recv_obj.finished_reason[i],
+                    recv_obj.finished_reasons[i],
                     recv_obj.no_stop_trim[i],
                 )
             )
@@ -170,9 +178,20 @@ class DetokenizerManager:
         self.send_to_tokenizer.send_pyobj(
             BatchStrOut(
                 rids=recv_obj.rids,
+                finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
-                meta_info=recv_obj.meta_info,
-                finished_reason=recv_obj.finished_reason,
+                prompt_tokens=recv_obj.prompt_tokens,
+                completion_tokens=recv_obj.completion_tokens,
+                cached_tokens=recv_obj.cached_tokens,
+                input_token_logprobs_val=recv_obj.input_token_logprobs_val,
+                input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
+                output_token_logprobs_val=recv_obj.output_token_logprobs_val,
+                output_token_logprobs_idx=recv_obj.output_token_logprobs_idx,
+                input_top_logprobs_val=recv_obj.input_top_logprobs_val,
+                input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
+                output_top_logprobs_val=recv_obj.output_top_logprobs_val,
+                output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
+                normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
             )
         )
 
@@ -194,6 +213,7 @@ def run_detokenizer_process(
     server_args: ServerArgs,
     port_args: PortArgs,
 ):
+    setproctitle.setproctitle("sglang::detokenizer")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
 
sglang/srt/managers/io_struct.py

@@ -308,6 +308,9 @@ class TokenizedEmbeddingReqInput:
 class BatchTokenIDOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[BaseFinishReason]
+    # For incremental decoding
     # The version id to sync decode status with in detokenizer_manager
     vids: List[int]
     decoded_texts: List[str]
@@ -315,35 +318,61 @@ class BatchTokenIDOut:
     read_offsets: List[int]
     # Only used when `--skip-tokenizer-init`
     output_ids: Optional[List[int]]
+    # Detokenization configs
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
-    meta_info: List[Dict]
-    finished_reason: List[BaseFinishReason]
     no_stop_trim: List[bool]
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
+    # Logprobs
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    input_top_logprobs_val: List[List]
+    input_top_logprobs_idx: List[List]
+    output_top_logprobs_val: List[List]
+    output_top_logprobs_idx: List[List]
+    normalized_prompt_logprob: List[float]
 
 
 @dataclass
 class BatchStrOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[dict]
     # The output decoded strings
     output_strs: List[str]
-    # The meta info
-    meta_info: List[Dict]
-    # The finish reason
-    finished_reason: List[BaseFinishReason]
+
+    # Token counts
+    prompt_tokens: List[int]
+    completion_tokens: List[int]
+    cached_tokens: List[int]
+    # Logprobs
+    input_token_logprobs_val: List[float]
+    input_token_logprobs_idx: List[int]
+    output_token_logprobs_val: List[float]
+    output_token_logprobs_idx: List[int]
+    input_top_logprobs_val: List[List]
+    input_top_logprobs_idx: List[List]
+    output_top_logprobs_val: List[List]
+    output_top_logprobs_idx: List[List]
+    normalized_prompt_logprob: List[float]
 
 
 @dataclass
 class BatchEmbeddingOut:
     # The request id
     rids: List[str]
+    # The finish reason
+    finished_reasons: List[BaseFinishReason]
     # The output embedding
     embeddings: List[List[float]]
-    # The meta info
-    meta_info: List[Dict]
-    # The finish reason
-    finished_reason: List[BaseFinishReason]
+    # Token counts
+    prompt_tokens: List[int]
 
 
 @dataclass
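The batch output payloads above now carry token counts and logprobs as flat, parallel value/index lists per request instead of a meta_info dict. A hedged illustration of pairing the parallel lists back up on the consumer side; the numbers are invented for the example.

# Invented example values; in sglang these arrive inside a BatchStrOut message.
output_token_logprobs_val = [-0.11, -1.73, -0.42]
output_token_logprobs_idx = [1734, 11, 2]

# Recombine into (token_id, logprob) pairs for downstream use.
pairs = list(zip(output_token_logprobs_idx, output_token_logprobs_val))
print(pairs)  # [(1734, -0.11), (11, -1.73), (2, -0.42)]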
sglang/srt/managers/schedule_batch.py

@@ -129,6 +129,7 @@ class ImageInputs:
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
+    image_pad_len: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
@@ -181,6 +182,7 @@ class ImageInputs:
         optional_args = [
             "image_sizes",
             "image_offsets",
+            "image_pad_len",
             # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
             "aspect_ratio_ids",
             "aspect_ratio_mask",
@@ -200,6 +202,9 @@ class Req:
         origin_input_text: str,
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
+        return_logprob: bool = False,
+        top_logprobs_num: int = 0,
+        stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
@@ -217,10 +222,11 @@ class Req:
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
         self.session_id = session_id
+        self.input_embeds = input_embeds
 
+        # Sampling info
         self.sampling_params = sampling_params
         self.lora_path = lora_path
-        self.input_embeds = input_embeds
 
         # Memory pool info
         self.req_pool_idx = None
@@ -228,8 +234,8 @@ class Req:
         # Check finish
         self.tokenizer = None
         self.finished_reason = None
-        self.stream = False
         self.to_abort = False
+        self.stream = stream
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -241,37 +247,46 @@ class Req:
         # 2: read_offset
         # 3: last token
         self.vid = 0  # version id to sync decode status with in detokenizer_manager
-        self.decoded_text = ""
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
-
-        # The number of decoded tokens for token usage report. Note that
-        # this does not include the jump forward tokens.
-        self.completion_tokens_wo_jump_forward = 0
+        self.decoded_text = ""
 
         # For multimodal inputs
         self.image_inputs: Optional[ImageInputs] = None
 
         # Prefix info
         self.prefix_indices = []
+        # Tokens to run prefill. input_tokens - shared_prefix_tokens.
         self.extend_input_len = 0
         self.last_node = None
+
+        # Chunked prefill
         self.is_being_chunked = 0
 
         # For retraction
         self.is_retracted = False
 
         # Logprobs (arguments)
-        self.return_logprob = False
+        self.return_logprob = return_logprob
         self.logprob_start_len = 0
-        self.top_logprobs_num = 0
+        self.top_logprobs_num = top_logprobs_num
 
         # Logprobs (return value)
         self.normalized_prompt_logprob = None
-        self.input_token_logprobs = None
-        self.input_top_logprobs = None
-        self.output_token_logprobs = []
-        self.output_top_logprobs = []
+        self.input_token_logprobs_val = None
+        self.input_token_logprobs_idx = None
+        self.input_top_logprobs_val = None
+        self.input_top_logprobs_idx = None
+
+        if return_logprob:
+            self.output_token_logprobs_val = []
+            self.output_token_logprobs_idx = []
+            self.output_top_logprobs_val = []
+            self.output_top_logprobs_idx = []
+        else:
+            self.output_token_logprobs_val = self.output_token_logprobs_idx = (
+                self.output_top_logprobs_val
+            ) = self.output_top_logprobs_idx = None
 
         # Logprobs (internal values)
         # The tokens is prefilled but need to be considered as decode tokens
@@ -295,13 +310,14 @@ class Req:
         else:
             self.image_inputs.merge(image_inputs)
 
-    # whether request reached finished condition
     def finished(self) -> bool:
+        # Whether request reached finished condition
        return self.finished_reason is not None
 
     def init_next_round_input(self, tree_cache: Optional[BasePrefixCache] = None):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
+            # tree cache is None if the prefix is not computed with tree cache.
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
@@ -454,8 +470,10 @@ class Req:
                 k = k + 1
             else:
                 break
-        self.output_token_logprobs = self.output_token_logprobs[:k]
-        self.output_top_logprobs = self.output_top_logprobs[:k]
+        self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
+        self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
+        self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
+        self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
         self.logprob_start_len = prompt_tokens + k
         self.last_update_decode_tokens = len(self.output_ids) - k
 
@@ -470,7 +488,7 @@ bid = 0
 
 @dataclasses.dataclass
 class ScheduleBatch:
-    """Store all inforamtion of a batch on the scheduler."""
+    """Store all information of a batch on the scheduler."""
 
     # Request, memory pool, and cache
     reqs: List[Req]
@@ -1068,9 +1086,9 @@ class ScheduleBatch:
         self.top_logprobs_nums = [0] * len(self.reqs) + other.top_logprobs_nums
         self.reqs.extend(other.reqs)
 
-        self.return_logprob = self.return_logprob or other.return_logprob
-        self.has_stream = self.has_stream or other.has_stream
-        self.has_grammar = self.has_grammar or other.has_grammar
+        self.return_logprob |= other.return_logprob
+        self.has_stream |= other.has_stream
+        self.has_grammar |= other.has_grammar
 
     def get_model_worker_batch(self):
         if self.forward_mode.is_decode() or self.forward_mode.is_idle():
@@ -1097,7 +1115,6 @@ class ScheduleBatch:
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_sum=self.seq_lens_sum,
-            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             global_num_tokens=self.global_num_tokens,
@@ -1152,9 +1169,6 @@ class ModelWorkerBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int
 
-    # The memory pool operation records
-    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
-
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]