sglang 0.4.3.post3__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128)
  1. sglang/bench_serving.py +2 -2
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/hf_transformers_utils.py +16 -1
  14. sglang/srt/layers/attention/flashinfer_backend.py +95 -49
  15. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  16. sglang/srt/layers/attention/triton_backend.py +5 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  18. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  19. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  20. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  21. sglang/srt/layers/attention/vision.py +43 -62
  22. sglang/srt/layers/linear.py +1 -1
  23. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  24. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  32. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  33. sglang/srt/layers/parameter.py +10 -0
  34. sglang/srt/layers/quantization/__init__.py +90 -68
  35. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  36. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  39. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/fp8.py +174 -106
  63. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  64. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  65. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  66. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  67. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  68. sglang/srt/layers/rotary_embedding.py +5 -3
  69. sglang/srt/layers/sampler.py +29 -35
  70. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  71. sglang/srt/lora/backend/__init__.py +9 -12
  72. sglang/srt/managers/cache_controller.py +72 -8
  73. sglang/srt/managers/image_processor.py +37 -631
  74. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  75. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  76. sglang/srt/managers/image_processors/llava.py +152 -0
  77. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  78. sglang/srt/managers/image_processors/mlama.py +60 -0
  79. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  80. sglang/srt/managers/io_struct.py +33 -15
  81. sglang/srt/managers/multi_modality_padding.py +134 -0
  82. sglang/srt/managers/schedule_batch.py +212 -117
  83. sglang/srt/managers/schedule_policy.py +40 -8
  84. sglang/srt/managers/scheduler.py +258 -782
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
  86. sglang/srt/managers/tokenizer_manager.py +7 -6
  87. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  88. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  89. sglang/srt/mem_cache/chunk_cache.py +12 -44
  90. sglang/srt/mem_cache/hiradix_cache.py +63 -34
  91. sglang/srt/mem_cache/memory_pool.py +112 -46
  92. sglang/srt/mem_cache/paged_allocator.py +283 -0
  93. sglang/srt/mem_cache/radix_cache.py +117 -36
  94. sglang/srt/metrics/collector.py +8 -0
  95. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  96. sglang/srt/model_executor/forward_batch_info.py +12 -8
  97. sglang/srt/model_executor/model_runner.py +153 -134
  98. sglang/srt/model_loader/loader.py +2 -1
  99. sglang/srt/model_loader/weight_utils.py +1 -1
  100. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  101. sglang/srt/models/deepseek_nextn.py +23 -3
  102. sglang/srt/models/deepseek_v2.py +25 -19
  103. sglang/srt/models/minicpmv.py +28 -89
  104. sglang/srt/models/mllama.py +1 -1
  105. sglang/srt/models/qwen2.py +0 -1
  106. sglang/srt/models/qwen2_5_vl.py +25 -50
  107. sglang/srt/models/qwen2_vl.py +33 -49
  108. sglang/srt/openai_api/adapter.py +37 -15
  109. sglang/srt/openai_api/protocol.py +8 -1
  110. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  111. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  112. sglang/srt/server_args.py +19 -20
  113. sglang/srt/speculative/build_eagle_tree.py +6 -1
  114. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
  115. sglang/srt/speculative/eagle_utils.py +2 -1
  116. sglang/srt/speculative/eagle_worker.py +109 -38
  117. sglang/srt/utils.py +104 -9
  118. sglang/test/runners.py +104 -10
  119. sglang/test/test_block_fp8.py +106 -16
  120. sglang/test/test_custom_ops.py +88 -0
  121. sglang/test/test_utils.py +20 -4
  122. sglang/utils.py +0 -4
  123. sglang/version.py +1 -1
  124. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -9
  125. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/RECORD +128 -83
  126. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
  127. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
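
The headline refactor in this release is the scheduler split: scheduler.py loses 782 lines (gaining 258) and the removed output-handling logic reappears as the new SchedulerOutputProcessorMixin, diffed in full below. Judging only from the mixin's docstring and its TYPE_CHECKING imports, the composition presumably looks like the following sketch; the Scheduler body shown here is illustrative, not the actual class shipped in the wheel.

# Hypothetical sketch (not taken from the wheel): how the mixin presumably
# attaches to the scheduler. Only SchedulerOutputProcessorMixin and its two
# entry points come from the diff below; everything else is invented.
from sglang.srt.managers.scheduler_output_processor_mixin import (
    SchedulerOutputProcessorMixin,
)


class Scheduler(SchedulerOutputProcessorMixin):
    def process_batch_result(self, batch, result):
        # The mixin methods read scheduler state (tree_cache, tp_worker,
        # token_to_kv_pool_allocator, ...) through `self`, so they behave as
        # if they were still defined inside scheduler.py.
        if batch.forward_mode.is_decode():  # assumed ScheduleBatch predicate
            self.process_batch_result_decode(batch, result)
        else:
            self.process_batch_result_prefill(batch, result)
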
sglang/srt/managers/scheduler_output_processor_mixin.py (new file)
@@ -0,0 +1,611 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut
+from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import (
+        EmbeddingBatchResult,
+        GenerationBatchResult,
+        ScheduleBatch,
+    )
+
+
+class SchedulerOutputProcessorMixin:
+    """
+    This class implements the output processing logic for Scheduler.
+    We put them into a separate file to make the `scheduler.py` shorter.
+    """
+
+    def process_batch_result_prefill(
+        self,
+        batch: ScheduleBatch,
+        result: Union[GenerationBatchResult, EmbeddingBatchResult],
+    ):
+        skip_stream_req = None
+
+        if self.is_generation:
+            (
+                logits_output,
+                next_token_ids,
+                extend_input_len_per_req,
+                extend_logprob_start_len_per_req,
+                bid,
+            ) = (
+                result.logits_output,
+                result.next_token_ids,
+                result.extend_input_len_per_req,
+                result.extend_logprob_start_len_per_req,
+                result.bid,
+            )
+
+            if self.enable_overlap:
+                logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            else:
+                # Move next_token_ids and logprobs to cpu
+                next_token_ids = next_token_ids.tolist()
+                if batch.return_logprob:
+                    if logits_output.next_token_logprobs is not None:
+                        logits_output.next_token_logprobs = (
+                            logits_output.next_token_logprobs.tolist()
+                        )
+                    if logits_output.input_token_logprobs is not None:
+                        logits_output.input_token_logprobs = tuple(
+                            logits_output.input_token_logprobs.tolist()
+                        )
+
+            hidden_state_offset = 0
+
+            # Check finish conditions
+            logprob_pt = 0
+            for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+                if req.is_retracted:
+                    continue
+
+                if self.is_mixed_chunk and self.enable_overlap and req.finished():
+                    # Free the one delayed token for the mixed decode batch
+                    j = len(batch.out_cache_loc) - len(batch.reqs) + i
+                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
+                    continue
+
+                if req.is_chunked <= 0:
+                    # req output_ids are set here
+                    req.output_ids.append(next_token_id)
+                    req.check_finished()
+
+                    if req.finished():
+                        self.tree_cache.cache_finished_req(req)
+                    elif not batch.decoding_reqs or req not in batch.decoding_reqs:
+                        # This updates radix so others can match
+                        self.tree_cache.cache_unfinished_req(req)
+
+                    if req.return_logprob:
+                        assert extend_logprob_start_len_per_req is not None
+                        assert extend_input_len_per_req is not None
+                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                        extend_input_len = extend_input_len_per_req[i]
+                        num_input_logprobs = extend_input_len - extend_logprob_start_len
+                        self.add_logprob_return_values(
+                            i,
+                            req,
+                            logprob_pt,
+                            next_token_ids,
+                            num_input_logprobs,
+                            logits_output,
+                        )
+                        logprob_pt += num_input_logprobs
+
+                    if (
+                        req.return_hidden_states
+                        and logits_output.hidden_states is not None
+                    ):
+                        req.hidden_states.append(
+                            logits_output.hidden_states[
+                                hidden_state_offset : (
+                                    hidden_state_offset := hidden_state_offset
+                                    + len(req.origin_input_ids)
+                                )
+                            ]
+                            .cpu()
+                            .clone()
+                        )
+
+                    if req.grammar is not None:
+                        req.grammar.accept_token(next_token_id)
+                        req.grammar.finished = req.finished()
+                else:
+                    # being chunked reqs' prefill is not finished
+                    req.is_chunked -= 1
+                    # There is only at most one request being currently chunked.
+                    # Because this request does not finish prefill,
+                    # we don't want to stream the request currently being chunked.
+                    skip_stream_req = req
+
+                    # Incrementally update input logprobs.
+                    if req.return_logprob:
+                        extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                        extend_input_len = extend_input_len_per_req[i]
+                        if extend_logprob_start_len < extend_input_len:
+                            # Update input logprobs.
+                            num_input_logprobs = (
+                                extend_input_len - extend_logprob_start_len
+                            )
+                            self.add_input_logprob_return_values(
+                                i,
+                                req,
+                                logits_output,
+                                logprob_pt,
+                                num_input_logprobs,
+                                last_prefill_chunk=False,
+                            )
+                            logprob_pt += num_input_logprobs
+
+            if batch.next_batch_sampling_info:
+                batch.next_batch_sampling_info.update_regex_vocab_mask()
+                self.current_stream.synchronize()
+                batch.next_batch_sampling_info.sampling_info_done.set()
+
+        else:  # embedding or reward model
+            embeddings, bid = result.embeddings, result.bid
+            embeddings = embeddings.tolist()
+
+            # Check finish conditions
+            for i, req in enumerate(batch.reqs):
+                if req.is_retracted:
+                    continue
+
+                req.embedding = embeddings[i]
+                if req.is_chunked <= 0:
+                    # Dummy output token for embedding models
+                    req.output_ids.append(0)
+                    req.check_finished()
+
+                    if req.finished():
+                        self.tree_cache.cache_finished_req(req)
+                    else:
+                        self.tree_cache.cache_unfinished_req(req)
+                else:
+                    # being chunked reqs' prefill is not finished
+                    req.is_chunked -= 1
+
+        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
+
+    def process_batch_result_decode(
+        self,
+        batch: ScheduleBatch,
+        result: GenerationBatchResult,
+    ):
+        logits_output, next_token_ids, bid = (
+            result.logits_output,
+            result.next_token_ids,
+            result.bid,
+        )
+        self.num_generated_tokens += len(batch.reqs)
+
+        if self.enable_overlap:
+            logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+            next_token_logprobs = logits_output.next_token_logprobs
+        elif batch.spec_algorithm.is_none():
+            # spec decoding handles output logprobs inside verify process.
+            next_token_ids = next_token_ids.tolist()
+            if batch.return_logprob:
+                next_token_logprobs = logits_output.next_token_logprobs.tolist()
+
+        self.token_to_kv_pool_allocator.free_group_begin()
+
+        # Check finish condition
+        # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
+        # We should ignore using next_token_ids for spec decoding cases.
+        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+            if req.is_retracted:
+                continue
+
+            if self.enable_overlap and req.finished():
+                # Free the one extra delayed token
+                if self.page_size == 1:
+                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
+                else:
+                    # Only free when the extra token is in a new page
+                    if (
+                        len(req.origin_input_ids) + len(req.output_ids) - 1
+                    ) % self.page_size == 0:
+                        self.token_to_kv_pool_allocator.free(
+                            batch.out_cache_loc[i : i + 1]
+                        )
+                continue
+
+            if batch.spec_algorithm.is_none():
+                # speculative worker will solve the output_ids in speculative decoding
+                req.output_ids.append(next_token_id)
+
+            req.check_finished()
+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+
+            if req.return_logprob and batch.spec_algorithm.is_none():
+                # speculative worker handles logprob in speculative decoding
+                req.output_token_logprobs_val.append(next_token_logprobs[i])
+                req.output_token_logprobs_idx.append(next_token_id)
+                if req.top_logprobs_num > 0:
+                    req.output_top_logprobs_val.append(
+                        logits_output.next_token_top_logprobs_val[i]
+                    )
+                    req.output_top_logprobs_idx.append(
+                        logits_output.next_token_top_logprobs_idx[i]
+                    )
+                if req.token_ids_logprob is not None:
+                    req.output_token_ids_logprobs_val.append(
+                        logits_output.next_token_token_ids_logprobs_val[i]
+                    )
+                    req.output_token_ids_logprobs_idx.append(
+                        logits_output.next_token_token_ids_logprobs_idx[i]
+                    )
+
+            if req.return_hidden_states and logits_output.hidden_states is not None:
+                req.hidden_states.append(logits_output.hidden_states[i].cpu().clone())
+
+            if req.grammar is not None and batch.spec_algorithm.is_none():
+                req.grammar.accept_token(next_token_id)
+                req.grammar.finished = req.finished()
+
+        if batch.next_batch_sampling_info:
+            batch.next_batch_sampling_info.update_regex_vocab_mask()
+            self.current_stream.synchronize()
+            batch.next_batch_sampling_info.sampling_info_done.set()
+
+        self.stream_output(batch.reqs, batch.return_logprob)
+
+        self.token_to_kv_pool_allocator.free_group_end()
+
+        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
+        if (
+            self.attn_tp_rank == 0
+            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
+        ):
+            self.log_decode_stats()
+
+    def add_input_logprob_return_values(
+        self,
+        i: int,
+        req: Req,
+        output: LogitsProcessorOutput,
+        logprob_pt: int,
+        num_input_logprobs: int,
+        last_prefill_chunk: bool,  # If True, it means prefill is finished.
+    ):
+        """Incrementally add input logprobs to `req`.
+
+        Args:
+            i: The request index in a batch.
+            req: The request. Input logprobs inside req are modified as a
+                consequence of the API
+            fill_ids: The prefill ids processed.
+            output: Logit processor output that's used to compute input logprobs
+            last_prefill_chunk: True if it is the last prefill (when chunked).
+                Some of input logprob operation should only happen at the last
+                prefill (e.g., computing input token logprobs).
+        """
+        assert output.input_token_logprobs is not None
+        if req.input_token_logprobs is None:
+            req.input_token_logprobs = []
+        if req.temp_input_top_logprobs_val is None:
+            req.temp_input_top_logprobs_val = []
+        if req.temp_input_top_logprobs_idx is None:
+            req.temp_input_top_logprobs_idx = []
+        if req.temp_input_token_ids_logprobs_val is None:
+            req.temp_input_token_ids_logprobs_val = []
+        if req.temp_input_token_ids_logprobs_idx is None:
+            req.temp_input_token_ids_logprobs_idx = []
+
+        if req.input_token_logprobs_val is not None:
+            # The input logprob has been already computed. It only happens
+            # upon retract.
+            if req.top_logprobs_num > 0:
+                assert req.input_token_logprobs_val is not None
+            return
+
+        # Important for the performance.
+        assert isinstance(output.input_token_logprobs, tuple)
+        input_token_logprobs: Tuple[int] = output.input_token_logprobs
+        input_token_logprobs = input_token_logprobs[
+            logprob_pt : logprob_pt + num_input_logprobs
+        ]
+        req.input_token_logprobs.extend(input_token_logprobs)
+
+        if req.top_logprobs_num > 0:
+            req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
+            req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])
+
+        if req.token_ids_logprob is not None:
+            req.temp_input_token_ids_logprobs_val.append(
+                output.input_token_ids_logprobs_val[i]
+            )
+            req.temp_input_token_ids_logprobs_idx.append(
+                output.input_token_ids_logprobs_idx[i]
+            )
+
+        if last_prefill_chunk:
+            input_token_logprobs = req.input_token_logprobs
+            req.input_token_logprobs = None
+            assert req.input_token_logprobs_val is None
+            assert req.input_token_logprobs_idx is None
+            assert req.input_top_logprobs_val is None
+            assert req.input_top_logprobs_idx is None
+
+            # Compute input_token_logprobs_val
+            # Always pad the first one with None.
+            req.input_token_logprobs_val = [None]
+            req.input_token_logprobs_val.extend(input_token_logprobs)
+            # The last input logprob is for sampling, so just pop it out.
+            req.input_token_logprobs_val.pop()
+
+            # Compute input_token_logprobs_idx
+            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
+            # Clip the padded hash values from image tokens.
+            # Otherwise, it will lead to detokenization errors.
+            input_token_logprobs_idx = [
+                x if x < self.model_config.vocab_size - 1 else 0
+                for x in input_token_logprobs_idx
+            ]
+            req.input_token_logprobs_idx = input_token_logprobs_idx
+
+            if req.top_logprobs_num > 0:
+                req.input_top_logprobs_val = [None]
+                req.input_top_logprobs_idx = [None]
+                assert len(req.temp_input_token_ids_logprobs_val) == len(
+                    req.temp_input_token_ids_logprobs_idx
+                )
+                for val, idx in zip(
+                    req.temp_input_top_logprobs_val,
+                    req.temp_input_top_logprobs_idx,
+                    strict=True,
+                ):
+                    req.input_top_logprobs_val.extend(val)
+                    req.input_top_logprobs_idx.extend(idx)
+
+                # Last token is a sample token.
+                req.input_top_logprobs_val.pop()
+                req.input_top_logprobs_idx.pop()
+                req.temp_input_top_logprobs_idx = None
+                req.temp_input_top_logprobs_val = None
+
+            if req.token_ids_logprob is not None:
+                req.input_token_ids_logprobs_val = [None]
+                req.input_token_ids_logprobs_idx = [None]
+
+                for val, idx in zip(
+                    req.temp_input_token_ids_logprobs_val,
+                    req.temp_input_token_ids_logprobs_idx,
+                    strict=True,
+                ):
+                    req.input_token_ids_logprobs_val.extend(val)
+                    req.input_token_ids_logprobs_idx.extend(idx)
+
+                # Last token is a sample token.
+                req.input_token_ids_logprobs_val.pop()
+                req.input_token_ids_logprobs_idx.pop()
+                req.temp_input_token_ids_logprobs_idx = None
+                req.temp_input_token_ids_logprobs_val = None
+
+            if req.return_logprob:
+                relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
+                assert len(req.input_token_logprobs_val) == relevant_tokens_len
+                assert len(req.input_token_logprobs_idx) == relevant_tokens_len
+                if req.top_logprobs_num > 0:
+                    assert len(req.input_top_logprobs_val) == relevant_tokens_len
+                    assert len(req.input_top_logprobs_idx) == relevant_tokens_len
+                if req.token_ids_logprob is not None:
+                    assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
+                    assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
+
+    def add_logprob_return_values(
+        self,
+        i: int,
+        req: Req,
+        pt: int,
+        next_token_ids: List[int],
+        num_input_logprobs: int,
+        output: LogitsProcessorOutput,
+    ):
+        """Attach logprobs to the return values."""
+        req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+        req.output_token_logprobs_idx.append(next_token_ids[i])
+
+        self.add_input_logprob_return_values(
+            i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
+        )
+
+        if req.top_logprobs_num > 0:
+            req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
+            req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
+
+        if req.token_ids_logprob is not None:
+            req.output_token_ids_logprobs_val.append(
+                output.next_token_token_ids_logprobs_val[i]
+            )
+            req.output_token_ids_logprobs_idx.append(
+                output.next_token_token_ids_logprobs_idx[i]
+            )
+
+        return num_input_logprobs
+
+    def stream_output(
+        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+    ):
+        """Stream the output to detokenizer."""
+        if self.is_generation:
+            self.stream_output_generation(reqs, return_logprob, skip_req)
+        else:  # embedding or reward model
+            self.stream_output_embedding(reqs)
+
+    def stream_output_generation(
+        self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+    ):
+        rids = []
+        finished_reasons: List[BaseFinishReason] = []
+
+        decoded_texts = []
+        decode_ids_list = []
+        read_offsets = []
+        output_ids = []
+
+        skip_special_tokens = []
+        spaces_between_special_tokens = []
+        no_stop_trim = []
+        prompt_tokens = []
+        completion_tokens = []
+        cached_tokens = []
+        spec_verify_ct = []
+        output_hidden_states = None
+
+        if return_logprob:
+            input_token_logprobs_val = []
+            input_token_logprobs_idx = []
+            output_token_logprobs_val = []
+            output_token_logprobs_idx = []
+            input_top_logprobs_val = []
+            input_top_logprobs_idx = []
+            output_top_logprobs_val = []
+            output_top_logprobs_idx = []
+            input_token_ids_logprobs_val = []
+            input_token_ids_logprobs_idx = []
+            output_token_ids_logprobs_val = []
+            output_token_ids_logprobs_idx = []
+        else:
+            input_token_logprobs_val = input_token_logprobs_idx = (
+                output_token_logprobs_val
+            ) = output_token_logprobs_idx = input_top_logprobs_val = (
+                input_top_logprobs_idx
+            ) = output_top_logprobs_val = output_top_logprobs_idx = (
+                input_token_ids_logprobs_val
+            ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
+                output_token_ids_logprobs_idx
+            ) = None
+
+        for req in reqs:
+            if req is skip_req:
+                continue
+
+            # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
+            if self.model_config.is_multimodal_gen and req.to_abort:
+                continue
+
+            if (
+                req.finished()
+                # If stream, follow the given stream_interval
+                or (req.stream and len(req.output_ids) % self.stream_interval == 0)
+                # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
+                # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
+                # always increase one-by-one.
+                or (
+                    not req.stream
+                    and len(req.output_ids) % 50 == 0
+                    and not self.model_config.is_multimodal_gen
+                )
+            ):
+                rids.append(req.rid)
+                finished_reasons.append(
+                    req.finished_reason.to_json() if req.finished_reason else None
+                )
+                decoded_texts.append(req.decoded_text)
+                decode_ids, read_offset = req.init_incremental_detokenize()
+                decode_ids_list.append(decode_ids)
+                read_offsets.append(read_offset)
+                if self.skip_tokenizer_init:
+                    output_ids.append(req.output_ids)
+                skip_special_tokens.append(req.sampling_params.skip_special_tokens)
+                spaces_between_special_tokens.append(
+                    req.sampling_params.spaces_between_special_tokens
+                )
+                no_stop_trim.append(req.sampling_params.no_stop_trim)
+                prompt_tokens.append(len(req.origin_input_ids))
+                completion_tokens.append(len(req.output_ids))
+                cached_tokens.append(req.cached_tokens)
+
+                if not self.spec_algorithm.is_none():
+                    spec_verify_ct.append(req.spec_verify_ct)
+
+                if return_logprob:
+                    input_token_logprobs_val.append(req.input_token_logprobs_val)
+                    input_token_logprobs_idx.append(req.input_token_logprobs_idx)
+                    output_token_logprobs_val.append(req.output_token_logprobs_val)
+                    output_token_logprobs_idx.append(req.output_token_logprobs_idx)
+                    input_top_logprobs_val.append(req.input_top_logprobs_val)
+                    input_top_logprobs_idx.append(req.input_top_logprobs_idx)
+                    output_top_logprobs_val.append(req.output_top_logprobs_val)
+                    output_top_logprobs_idx.append(req.output_top_logprobs_idx)
+                    input_token_ids_logprobs_val.append(
+                        req.input_token_ids_logprobs_val
+                    )
+                    input_token_ids_logprobs_idx.append(
+                        req.input_token_ids_logprobs_idx
+                    )
+                    output_token_ids_logprobs_val.append(
+                        req.output_token_ids_logprobs_val
+                    )
+                    output_token_ids_logprobs_idx.append(
+                        req.output_token_ids_logprobs_idx
+                    )
+
+                if req.return_hidden_states:
+                    if output_hidden_states is None:
+                        output_hidden_states = []
+                    output_hidden_states.append(req.hidden_states)
+
+        # Send to detokenizer
+        if rids:
+            if self.model_config.is_multimodal_gen:
+                return
+            self.send_to_detokenizer.send_pyobj(
+                BatchTokenIDOut(
+                    rids,
+                    finished_reasons,
+                    decoded_texts,
+                    decode_ids_list,
+                    read_offsets,
+                    output_ids,
+                    skip_special_tokens,
+                    spaces_between_special_tokens,
+                    no_stop_trim,
+                    prompt_tokens,
+                    completion_tokens,
+                    cached_tokens,
+                    spec_verify_ct,
+                    input_token_logprobs_val,
+                    input_token_logprobs_idx,
+                    output_token_logprobs_val,
+                    output_token_logprobs_idx,
+                    input_top_logprobs_val,
+                    input_top_logprobs_idx,
+                    output_top_logprobs_val,
+                    output_top_logprobs_idx,
+                    input_token_ids_logprobs_val,
+                    input_token_ids_logprobs_idx,
+                    output_token_ids_logprobs_val,
+                    output_token_ids_logprobs_idx,
+                    output_hidden_states,
+                )
+            )
+
+    def stream_output_embedding(self, reqs: List[Req]):
+        rids = []
+        finished_reasons: List[BaseFinishReason] = []
+
+        embeddings = []
+        prompt_tokens = []
+        cached_tokens = []
+        for req in reqs:
+            if req.finished():
+                rids.append(req.rid)
+                finished_reasons.append(req.finished_reason.to_json())
+                embeddings.append(req.embedding)
+                prompt_tokens.append(len(req.origin_input_ids))
+                cached_tokens.append(req.cached_tokens)
+        self.send_to_detokenizer.send_pyobj(
+            BatchEmbeddingOut(
+                rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
+            )
+        )
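
One construct in process_batch_result_prefill above that is easy to misread is the walrus-based slice used to split the batched hidden states per request: the assignment expression advances hidden_state_offset while the slice bound is being computed. A tiny standalone illustration with made-up values (not sglang code):

# Running-offset slicing via an assignment expression, mirroring the
# hidden_states slice in the prefill path above. Values are made up.
hidden = list(range(10))   # stand-in for the concatenated per-batch hidden states
lengths = [3, 2, 5]        # stand-in for len(req.origin_input_ids) of each request
offset = 0
chunks = [hidden[offset:(offset := offset + n)] for n in lengths]
print(chunks)  # [[0, 1, 2], [3, 4], [5, 6, 7, 8, 9]]
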
sglang/srt/managers/tokenizer_manager.py
@@ -372,13 +372,12 @@ class TokenizerManager:
                 )
             input_ids = self.tokenizer.encode(input_text)
 
+        image_inputs: Dict = await self.image_processor.process_images_async(
+            obj.image_data, input_text or input_ids, obj, self.max_req_input_len
+        )
+        if image_inputs and "input_ids" in image_inputs:
+            input_ids = image_inputs["input_ids"]
         if self.is_generation:
-            # TODO: also support getting embeddings for multimodal models
-            image_inputs: Dict = await self.image_processor.process_images_async(
-                obj.image_data, input_text or input_ids, obj, self.max_req_input_len
-            )
-            if image_inputs and "input_ids" in image_inputs:
-                input_ids = image_inputs["input_ids"]
             return_logprob = obj.return_logprob
             logprob_start_len = obj.logprob_start_len
             top_logprobs_num = obj.top_logprobs_num
@@ -438,6 +437,7 @@ class TokenizerManager:
                 obj.rid,
                 input_text,
                 input_ids,
+                image_inputs,
                 sampling_params,
             )
 
@@ -1068,6 +1068,7 @@ class TokenizerManager:
             self.metrics_collector.observe_one_finished_request(
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
+                recv_obj.cached_tokens[i],
                 state.finished_time - state.created_time,
             )
 
sglang/srt/managers/tp_worker_overlap_thread.py
@@ -103,6 +103,9 @@ class TpModelWorkerClient:
             self.worker.model_runner.token_to_kv_pool_allocator,
         )
 
+    def get_kv_cache(self):
+        return self.worker.model_runner.token_to_kv_pool
+
     def forward_thread_func(self):
         try:
             with torch.get_device_module(self.device).stream(self.forward_stream):
@@ -203,7 +206,7 @@ class TpModelWorkerClient:
             -(self.future_token_ids_ct + 1),
             -(self.future_token_ids_ct + 1 + bs),
             -1,
-            dtype=torch.int32,
+            dtype=torch.int64,
             device=self.device,
         )
         self.future_token_ids_ct = (
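
The dtype change from torch.int32 to torch.int64 above affects the placeholder ("future") token ids that the overlap worker hands to the next batch before the previous batch has finished sampling. A hypothetical, simplified illustration of that placeholder scheme follows; the buffer name and the resolution step are invented for this sketch and are not copied from sglang.

# Hypothetical sketch of the negative-placeholder trick used by the overlap
# worker: the next batch is built with negative ids, which are swapped for
# real token ids once the previous forward pass has produced them.
import torch

future_token_ids_buffer = torch.zeros(16, dtype=torch.int64)  # filled later by the worker

bs, future_token_ids_ct = 4, 0
placeholders = torch.arange(
    -(future_token_ids_ct + 1),
    -(future_token_ids_ct + 1 + bs),
    -1,
    dtype=torch.int64,
)  # tensor([-1, -2, -3, -4])

# Once the previous batch finishes sampling, the buffer holds the real ids:
future_token_ids_buffer[:bs] = torch.tensor([11, 22, 33, 44])
resolved = torch.where(
    placeholders < 0,
    future_token_ids_buffer[(-placeholders - 1).clamp(min=0)],
    placeholders,
)
print(resolved.tolist())  # [11, 22, 33, 44]

The diff itself only shows the dtype switch; presumably int64 keeps these placeholder ids directly usable as index tensors and consistent with the token id dtype used elsewhere.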