sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler_output_processor_mixin.py (new file)
@@ -0,0 +1,614 @@
+ from __future__ import annotations
+
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+ from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+ from sglang.srt.managers.io_struct import BatchEmbeddingOut, BatchTokenIDOut
+ from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
+
+ if TYPE_CHECKING:
+     from sglang.srt.managers.scheduler import (
+         EmbeddingBatchResult,
+         GenerationBatchResult,
+         ScheduleBatch,
+     )
+
+
+ class SchedulerOutputProcessorMixin:
+     """
+     This class implements the output processing logic for Scheduler.
+     We put them into a separate file to make the `scheduler.py` shorter.
+     """
+
+     def process_batch_result_prefill(
+         self,
+         batch: ScheduleBatch,
+         result: Union[GenerationBatchResult, EmbeddingBatchResult],
+     ):
+         skip_stream_req = None
+
+         if self.is_generation:
+             (
+                 logits_output,
+                 next_token_ids,
+                 extend_input_len_per_req,
+                 extend_logprob_start_len_per_req,
+                 bid,
+             ) = (
+                 result.logits_output,
+                 result.next_token_ids,
+                 result.extend_input_len_per_req,
+                 result.extend_logprob_start_len_per_req,
+                 result.bid,
+             )
+
+             if self.enable_overlap:
+                 logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+             else:
+                 # Move next_token_ids and logprobs to cpu
+                 next_token_ids = next_token_ids.tolist()
+                 if batch.return_logprob:
+                     if logits_output.next_token_logprobs is not None:
+                         logits_output.next_token_logprobs = (
+                             logits_output.next_token_logprobs.tolist()
+                         )
+                     if logits_output.input_token_logprobs is not None:
+                         logits_output.input_token_logprobs = tuple(
+                             logits_output.input_token_logprobs.tolist()
+                         )
+
+             hidden_state_offset = 0
+
+             # Check finish conditions
+             logprob_pt = 0
+             for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+                 if req.is_retracted:
+                     continue
+
+                 if self.is_mixed_chunk and self.enable_overlap and req.finished():
+                     # Free the one delayed token for the mixed decode batch
+                     j = len(batch.out_cache_loc) - len(batch.reqs) + i
+                     self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
+                     continue
+
+                 if req.is_chunked <= 0:
+                     # req output_ids are set here
+                     req.output_ids.append(next_token_id)
+                     req.check_finished()
+
+                     if req.finished():
+                         self.tree_cache.cache_finished_req(req)
+                     elif not batch.decoding_reqs or req not in batch.decoding_reqs:
+                         # This updates radix so others can match
+                         self.tree_cache.cache_unfinished_req(req)
+
+                     if req.return_logprob:
+                         assert extend_logprob_start_len_per_req is not None
+                         assert extend_input_len_per_req is not None
+                         extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                         extend_input_len = extend_input_len_per_req[i]
+                         num_input_logprobs = extend_input_len - extend_logprob_start_len
+                         self.add_logprob_return_values(
+                             i,
+                             req,
+                             logprob_pt,
+                             next_token_ids,
+                             num_input_logprobs,
+                             logits_output,
+                         )
+                         logprob_pt += num_input_logprobs
+
+                     if (
+                         req.return_hidden_states
+                         and logits_output.hidden_states is not None
+                     ):
+                         req.hidden_states.append(
+                             logits_output.hidden_states[
+                                 hidden_state_offset : (
+                                     hidden_state_offset := hidden_state_offset
+                                     + len(req.origin_input_ids)
+                                 )
+                             ]
+                             .cpu()
+                             .clone()
+                             .tolist()
+                         )
+
+                     if req.grammar is not None:
+                         req.grammar.accept_token(next_token_id)
+                         req.grammar.finished = req.finished()
+                 else:
+                     # being chunked reqs' prefill is not finished
+                     req.is_chunked -= 1
+                     # There is only at most one request being currently chunked.
+                     # Because this request does not finish prefill,
+                     # we don't want to stream the request currently being chunked.
+                     skip_stream_req = req
+
+                     # Incrementally update input logprobs.
+                     if req.return_logprob:
+                         extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                         extend_input_len = extend_input_len_per_req[i]
+                         if extend_logprob_start_len < extend_input_len:
+                             # Update input logprobs.
+                             num_input_logprobs = (
+                                 extend_input_len - extend_logprob_start_len
+                             )
+                             self.add_input_logprob_return_values(
+                                 i,
+                                 req,
+                                 logits_output,
+                                 logprob_pt,
+                                 num_input_logprobs,
+                                 last_prefill_chunk=False,
+                             )
+                             logprob_pt += num_input_logprobs
+
+             if batch.next_batch_sampling_info:
+                 batch.next_batch_sampling_info.update_regex_vocab_mask()
+                 self.current_stream.synchronize()
+                 batch.next_batch_sampling_info.sampling_info_done.set()
+
+         else:  # embedding or reward model
+             embeddings, bid = result.embeddings, result.bid
+             embeddings = embeddings.tolist()
+
+             # Check finish conditions
+             for i, req in enumerate(batch.reqs):
+                 if req.is_retracted:
+                     continue
+
+                 req.embedding = embeddings[i]
+                 if req.is_chunked <= 0:
+                     # Dummy output token for embedding models
+                     req.output_ids.append(0)
+                     req.check_finished()
+
+                     if req.finished():
+                         self.tree_cache.cache_finished_req(req)
+                     else:
+                         self.tree_cache.cache_unfinished_req(req)
+                 else:
+                     # being chunked reqs' prefill is not finished
+                     req.is_chunked -= 1
+
+         self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)
+
+     def process_batch_result_decode(
+         self,
+         batch: ScheduleBatch,
+         result: GenerationBatchResult,
+     ):
+         logits_output, next_token_ids, bid = (
+             result.logits_output,
+             result.next_token_ids,
+             result.bid,
+         )
+         self.num_generated_tokens += len(batch.reqs)
+
+         if self.enable_overlap:
+             logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
+             next_token_logprobs = logits_output.next_token_logprobs
+         elif batch.spec_algorithm.is_none():
+             # spec decoding handles output logprobs inside verify process.
+             next_token_ids = next_token_ids.tolist()
+             if batch.return_logprob:
+                 next_token_logprobs = logits_output.next_token_logprobs.tolist()
+
+         self.token_to_kv_pool_allocator.free_group_begin()
+
+         # Check finish condition
+         # NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
+         # We should ignore using next_token_ids for spec decoding cases.
+         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+             if req.is_retracted:
+                 continue
+
+             if self.enable_overlap and req.finished():
+                 # Free the one extra delayed token
+                 if self.page_size == 1:
+                     self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
+                 else:
+                     # Only free when the extra token is in a new page
+                     if (
+                         len(req.origin_input_ids) + len(req.output_ids) - 1
+                     ) % self.page_size == 0:
+                         self.token_to_kv_pool_allocator.free(
+                             batch.out_cache_loc[i : i + 1]
+                         )
+                 continue
+
+             if batch.spec_algorithm.is_none():
+                 # speculative worker will solve the output_ids in speculative decoding
+                 req.output_ids.append(next_token_id)
+
+             req.check_finished()
+             if req.finished():
+                 self.tree_cache.cache_finished_req(req)
+
+             if req.return_logprob and batch.spec_algorithm.is_none():
+                 # speculative worker handles logprob in speculative decoding
+                 req.output_token_logprobs_val.append(next_token_logprobs[i])
+                 req.output_token_logprobs_idx.append(next_token_id)
+                 if req.top_logprobs_num > 0:
+                     req.output_top_logprobs_val.append(
+                         logits_output.next_token_top_logprobs_val[i]
+                     )
+                     req.output_top_logprobs_idx.append(
+                         logits_output.next_token_top_logprobs_idx[i]
+                     )
+                 if req.token_ids_logprob is not None:
+                     req.output_token_ids_logprobs_val.append(
+                         logits_output.next_token_token_ids_logprobs_val[i]
+                     )
+                     req.output_token_ids_logprobs_idx.append(
+                         logits_output.next_token_token_ids_logprobs_idx[i]
+                     )
+
+             if req.return_hidden_states and logits_output.hidden_states is not None:
+                 req.hidden_states.append(
+                     logits_output.hidden_states[i].cpu().clone().tolist()
+                 )
+
+             if req.grammar is not None and batch.spec_algorithm.is_none():
+                 req.grammar.accept_token(next_token_id)
+                 req.grammar.finished = req.finished()
+
+         if batch.next_batch_sampling_info:
+             batch.next_batch_sampling_info.update_regex_vocab_mask()
+             self.current_stream.synchronize()
+             batch.next_batch_sampling_info.sampling_info_done.set()
+
+         self.stream_output(batch.reqs, batch.return_logprob)
+
+         self.token_to_kv_pool_allocator.free_group_end()
+
+         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
+         if (
+             self.attn_tp_rank == 0
+             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
+         ):
+             self.log_decode_stats()
+
+     def add_input_logprob_return_values(
+         self,
+         i: int,
+         req: Req,
+         output: LogitsProcessorOutput,
+         logprob_pt: int,
+         num_input_logprobs: int,
+         last_prefill_chunk: bool,  # If True, it means prefill is finished.
+     ):
+         """Incrementally add input logprobs to `req`.
+
+         Args:
+             i: The request index in a batch.
+             req: The request. Input logprobs inside req are modified as a
+                 consequence of the API
+             fill_ids: The prefill ids processed.
+             output: Logit processor output that's used to compute input logprobs
+             last_prefill_chunk: True if it is the last prefill (when chunked).
+                 Some of input logprob operation should only happen at the last
+                 prefill (e.g., computing input token logprobs).
+         """
+         assert output.input_token_logprobs is not None
+         if req.input_token_logprobs is None:
+             req.input_token_logprobs = []
+         if req.temp_input_top_logprobs_val is None:
+             req.temp_input_top_logprobs_val = []
+         if req.temp_input_top_logprobs_idx is None:
+             req.temp_input_top_logprobs_idx = []
+         if req.temp_input_token_ids_logprobs_val is None:
+             req.temp_input_token_ids_logprobs_val = []
+         if req.temp_input_token_ids_logprobs_idx is None:
+             req.temp_input_token_ids_logprobs_idx = []
+
+         if req.input_token_logprobs_val is not None:
+             # The input logprob has been already computed. It only happens
+             # upon retract.
+             if req.top_logprobs_num > 0:
+                 assert req.input_token_logprobs_val is not None
+             return
+
+         # Important for the performance.
+         assert isinstance(output.input_token_logprobs, tuple)
+         input_token_logprobs: Tuple[int] = output.input_token_logprobs
+         input_token_logprobs = input_token_logprobs[
+             logprob_pt : logprob_pt + num_input_logprobs
+         ]
+         req.input_token_logprobs.extend(input_token_logprobs)
+
+         if req.top_logprobs_num > 0:
+             req.temp_input_top_logprobs_val.append(output.input_top_logprobs_val[i])
+             req.temp_input_top_logprobs_idx.append(output.input_top_logprobs_idx[i])
+
+         if req.token_ids_logprob is not None:
+             req.temp_input_token_ids_logprobs_val.append(
+                 output.input_token_ids_logprobs_val[i]
+             )
+             req.temp_input_token_ids_logprobs_idx.append(
+                 output.input_token_ids_logprobs_idx[i]
+             )
+
+         if last_prefill_chunk:
+             input_token_logprobs = req.input_token_logprobs
+             req.input_token_logprobs = None
+             assert req.input_token_logprobs_val is None
+             assert req.input_token_logprobs_idx is None
+             assert req.input_top_logprobs_val is None
+             assert req.input_top_logprobs_idx is None
+
+             # Compute input_token_logprobs_val
+             # Always pad the first one with None.
+             req.input_token_logprobs_val = [None]
+             req.input_token_logprobs_val.extend(input_token_logprobs)
+             # The last input logprob is for sampling, so just pop it out.
+             req.input_token_logprobs_val.pop()
+
+             # Compute input_token_logprobs_idx
+             input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
+             # Clip the padded hash values from image tokens.
+             # Otherwise, it will lead to detokenization errors.
+             input_token_logprobs_idx = [
+                 x if x < self.model_config.vocab_size - 1 else 0
+                 for x in input_token_logprobs_idx
+             ]
+             req.input_token_logprobs_idx = input_token_logprobs_idx
+
+             if req.top_logprobs_num > 0:
+                 req.input_top_logprobs_val = [None]
+                 req.input_top_logprobs_idx = [None]
+                 assert len(req.temp_input_token_ids_logprobs_val) == len(
+                     req.temp_input_token_ids_logprobs_idx
+                 )
+                 for val, idx in zip(
+                     req.temp_input_top_logprobs_val,
+                     req.temp_input_top_logprobs_idx,
+                     strict=True,
+                 ):
+                     req.input_top_logprobs_val.extend(val)
+                     req.input_top_logprobs_idx.extend(idx)
+
+                 # Last token is a sample token.
+                 req.input_top_logprobs_val.pop()
+                 req.input_top_logprobs_idx.pop()
+                 req.temp_input_top_logprobs_idx = None
+                 req.temp_input_top_logprobs_val = None
+
+             if req.token_ids_logprob is not None:
+                 req.input_token_ids_logprobs_val = [None]
+                 req.input_token_ids_logprobs_idx = [None]
+
+                 for val, idx in zip(
+                     req.temp_input_token_ids_logprobs_val,
+                     req.temp_input_token_ids_logprobs_idx,
+                     strict=True,
+                 ):
+                     req.input_token_ids_logprobs_val.extend(val)
+                     req.input_token_ids_logprobs_idx.extend(idx)
+
+                 # Last token is a sample token.
+                 req.input_token_ids_logprobs_val.pop()
+                 req.input_token_ids_logprobs_idx.pop()
+                 req.temp_input_token_ids_logprobs_idx = None
+                 req.temp_input_token_ids_logprobs_val = None
+
+             if req.return_logprob:
+                 relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
+                 assert len(req.input_token_logprobs_val) == relevant_tokens_len
+                 assert len(req.input_token_logprobs_idx) == relevant_tokens_len
+                 if req.top_logprobs_num > 0:
+                     assert len(req.input_top_logprobs_val) == relevant_tokens_len
+                     assert len(req.input_top_logprobs_idx) == relevant_tokens_len
+                 if req.token_ids_logprob is not None:
+                     assert len(req.input_token_ids_logprobs_val) == relevant_tokens_len
+                     assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len
+
+     def add_logprob_return_values(
+         self,
+         i: int,
+         req: Req,
+         pt: int,
+         next_token_ids: List[int],
+         num_input_logprobs: int,
+         output: LogitsProcessorOutput,
+     ):
+         """Attach logprobs to the return values."""
+         req.output_token_logprobs_val.append(output.next_token_logprobs[i])
+         req.output_token_logprobs_idx.append(next_token_ids[i])
+
+         self.add_input_logprob_return_values(
+             i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
+         )
+
+         if req.top_logprobs_num > 0:
+             req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
+             req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
+
+         if req.token_ids_logprob is not None:
+             req.output_token_ids_logprobs_val.append(
+                 output.next_token_token_ids_logprobs_val[i]
+             )
+             req.output_token_ids_logprobs_idx.append(
+                 output.next_token_token_ids_logprobs_idx[i]
+             )
+
+         return num_input_logprobs
+
+     def stream_output(
+         self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+     ):
+         """Stream the output to detokenizer."""
+         if self.is_generation:
+             self.stream_output_generation(reqs, return_logprob, skip_req)
+         else:  # embedding or reward model
+             self.stream_output_embedding(reqs)
+
+     def stream_output_generation(
+         self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
+     ):
+         rids = []
+         finished_reasons: List[BaseFinishReason] = []
+
+         decoded_texts = []
+         decode_ids_list = []
+         read_offsets = []
+         output_ids = []
+
+         skip_special_tokens = []
+         spaces_between_special_tokens = []
+         no_stop_trim = []
+         prompt_tokens = []
+         completion_tokens = []
+         cached_tokens = []
+         spec_verify_ct = []
+         output_hidden_states = None
+
+         if return_logprob:
+             input_token_logprobs_val = []
+             input_token_logprobs_idx = []
+             output_token_logprobs_val = []
+             output_token_logprobs_idx = []
+             input_top_logprobs_val = []
+             input_top_logprobs_idx = []
+             output_top_logprobs_val = []
+             output_top_logprobs_idx = []
+             input_token_ids_logprobs_val = []
+             input_token_ids_logprobs_idx = []
+             output_token_ids_logprobs_val = []
+             output_token_ids_logprobs_idx = []
+         else:
+             input_token_logprobs_val = input_token_logprobs_idx = (
+                 output_token_logprobs_val
+             ) = output_token_logprobs_idx = input_top_logprobs_val = (
+                 input_top_logprobs_idx
+             ) = output_top_logprobs_val = output_top_logprobs_idx = (
+                 input_token_ids_logprobs_val
+             ) = input_token_ids_logprobs_idx = output_token_ids_logprobs_val = (
+                 output_token_ids_logprobs_idx
+             ) = None
+
+         for req in reqs:
+             if req is skip_req:
+                 continue
+
+             # Multimodal partial stream chunks break the detokenizer, so drop aborted requests here.
+             if self.model_config.is_multimodal_gen and req.to_abort:
+                 continue
+
+             if (
+                 req.finished()
+                 # If stream, follow the given stream_interval
+                 or (req.stream and len(req.output_ids) % self.stream_interval == 0)
+                 # If not stream, we still want to output some tokens to get the benefit of incremental decoding.
+                 # TODO(lianmin): this is wrong for speculative decoding because len(req.output_ids) does not
+                 # always increase one-by-one.
+                 or (
+                     not req.stream
+                     and len(req.output_ids) % 50 == 0
+                     and not self.model_config.is_multimodal_gen
+                 )
+             ):
+                 rids.append(req.rid)
+                 finished_reasons.append(
+                     req.finished_reason.to_json() if req.finished_reason else None
+                 )
+                 decoded_texts.append(req.decoded_text)
+                 decode_ids, read_offset = req.init_incremental_detokenize()
+                 decode_ids_list.append(decode_ids)
+                 read_offsets.append(read_offset)
+                 if self.skip_tokenizer_init:
+                     output_ids.append(req.output_ids)
+                 skip_special_tokens.append(req.sampling_params.skip_special_tokens)
+                 spaces_between_special_tokens.append(
+                     req.sampling_params.spaces_between_special_tokens
+                 )
+                 no_stop_trim.append(req.sampling_params.no_stop_trim)
+                 prompt_tokens.append(len(req.origin_input_ids))
+                 completion_tokens.append(len(req.output_ids))
+                 cached_tokens.append(req.cached_tokens)
+
+                 if not self.spec_algorithm.is_none():
+                     spec_verify_ct.append(req.spec_verify_ct)
+
+                 if return_logprob:
+                     input_token_logprobs_val.append(req.input_token_logprobs_val)
+                     input_token_logprobs_idx.append(req.input_token_logprobs_idx)
+                     output_token_logprobs_val.append(req.output_token_logprobs_val)
+                     output_token_logprobs_idx.append(req.output_token_logprobs_idx)
+                     input_top_logprobs_val.append(req.input_top_logprobs_val)
+                     input_top_logprobs_idx.append(req.input_top_logprobs_idx)
+                     output_top_logprobs_val.append(req.output_top_logprobs_val)
+                     output_top_logprobs_idx.append(req.output_top_logprobs_idx)
+                     input_token_ids_logprobs_val.append(
+                         req.input_token_ids_logprobs_val
+                     )
+                     input_token_ids_logprobs_idx.append(
+                         req.input_token_ids_logprobs_idx
+                     )
+                     output_token_ids_logprobs_val.append(
+                         req.output_token_ids_logprobs_val
+                     )
+                     output_token_ids_logprobs_idx.append(
+                         req.output_token_ids_logprobs_idx
+                     )
+
+                 if req.return_hidden_states:
+                     if output_hidden_states is None:
+                         output_hidden_states = []
+                     output_hidden_states.append(req.hidden_states)
+
+         # Send to detokenizer
+         if rids:
+             if self.model_config.is_multimodal_gen:
+                 return
+             self.send_to_detokenizer.send_pyobj(
+                 BatchTokenIDOut(
+                     rids,
+                     finished_reasons,
+                     decoded_texts,
+                     decode_ids_list,
+                     read_offsets,
+                     output_ids,
+                     skip_special_tokens,
+                     spaces_between_special_tokens,
+                     no_stop_trim,
+                     prompt_tokens,
+                     completion_tokens,
+                     cached_tokens,
+                     spec_verify_ct,
+                     input_token_logprobs_val,
+                     input_token_logprobs_idx,
+                     output_token_logprobs_val,
+                     output_token_logprobs_idx,
+                     input_top_logprobs_val,
+                     input_top_logprobs_idx,
+                     output_top_logprobs_val,
+                     output_top_logprobs_idx,
+                     input_token_ids_logprobs_val,
+                     input_token_ids_logprobs_idx,
+                     output_token_ids_logprobs_val,
+                     output_token_ids_logprobs_idx,
+                     output_hidden_states,
+                 )
+             )
+
+     def stream_output_embedding(self, reqs: List[Req]):
+         rids = []
+         finished_reasons: List[BaseFinishReason] = []
+
+         embeddings = []
+         prompt_tokens = []
+         cached_tokens = []
+         for req in reqs:
+             if req.finished():
+                 rids.append(req.rid)
+                 finished_reasons.append(req.finished_reason.to_json())
+                 embeddings.append(req.embedding)
+                 prompt_tokens.append(len(req.origin_input_ids))
+                 cached_tokens.append(req.cached_tokens)
+         self.send_to_detokenizer.send_pyobj(
+             BatchEmbeddingOut(
+                 rids, finished_reasons, embeddings, prompt_tokens, cached_tokens
+             )
+         )
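The class above is consumed as a mixin: per its docstring, the output-processing methods were split out of `scheduler.py` (note the -683 lines on `sglang/srt/managers/scheduler.py` in the file list) and run with `self` bound to the scheduler, which supplies attributes such as `tree_cache`, `tp_worker`, and `send_to_detokenizer`. A minimal, hypothetical sketch of that pattern is shown below; `OutputProcessorMixin`, `ToyScheduler`, and `finished_ids` are illustrative names, not part of sglang.

# Hypothetical sketch of the mixin split, not sglang's actual classes: the mixin
# holds processing logic but relies on attributes that only the host class defines.
from dataclasses import dataclass, field
from typing import List


class OutputProcessorMixin:
    """Output-processing helpers; expects the host class to provide `finished_ids`."""

    def process_batch_result(self, next_token_ids: List[int]) -> None:
        # `self.finished_ids` is defined by the class that mixes this in.
        self.finished_ids.extend(next_token_ids)


@dataclass
class ToyScheduler(OutputProcessorMixin):
    finished_ids: List[int] = field(default_factory=list)


if __name__ == "__main__":
    scheduler = ToyScheduler()
    scheduler.process_batch_result([1, 2, 3])
    print(scheduler.finished_ids)  # [1, 2, 3]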
sglang/srt/managers/tokenizer_manager.py
@@ -372,13 +372,12 @@ class TokenizerManager:
                  )
              input_ids = self.tokenizer.encode(input_text)
  
+         image_inputs: Dict = await self.image_processor.process_images_async(
+             obj.image_data, input_text or input_ids, obj, self.max_req_input_len
+         )
+         if image_inputs and "input_ids" in image_inputs:
+             input_ids = image_inputs["input_ids"]
          if self.is_generation:
-             # TODO: also support getting embeddings for multimodal models
-             image_inputs: Dict = await self.image_processor.process_images_async(
-                 obj.image_data, input_text or input_ids, obj, self.max_req_input_len
-             )
-             if image_inputs and "input_ids" in image_inputs:
-                 input_ids = image_inputs["input_ids"]
              return_logprob = obj.return_logprob
              logprob_start_len = obj.logprob_start_len
              top_logprobs_num = obj.top_logprobs_num
@@ -438,6 +437,7 @@ class TokenizerManager:
              obj.rid,
              input_text,
              input_ids,
+             image_inputs,
              sampling_params,
          )
  
sglang/srt/managers/tp_worker_overlap_thread.py
@@ -103,6 +103,9 @@ class TpModelWorkerClient:
              self.worker.model_runner.token_to_kv_pool_allocator,
          )
  
+     def get_kv_cache(self):
+         return self.worker.model_runner.token_to_kv_pool
+
      def forward_thread_func(self):
          try:
              with torch.get_device_module(self.device).stream(self.forward_stream):
@@ -203,7 +206,7 @@ class TpModelWorkerClient:
              -(self.future_token_ids_ct + 1),
              -(self.future_token_ids_ct + 1 + bs),
              -1,
-             dtype=torch.int32,
+             dtype=torch.int64,
              device=self.device,
          )
          self.future_token_ids_ct = (
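The dtype change above widens the overlap worker's placeholder token IDs from int32 to int64. The surrounding context suggests the pattern: a descending torch.arange of negative indices stands in for token IDs the forward thread has not produced yet and is resolved later. A hedged, self-contained sketch of that idea follows; `future_token_ids_map` and the resolution step are assumptions for demonstration, not sglang's exact implementation.

# Illustrative sketch of the negative-placeholder scheme visible in the hunk above.
import torch

future_token_ids_ct = 0
bs = 4

# Same arange pattern as the diff: -(ct + 1) down to -(ct + 1 + bs), step -1 -> [-1, -2, -3, -4]
placeholders = torch.arange(
    -(future_token_ids_ct + 1),
    -(future_token_ids_ct + 1 + bs),
    -1,
    dtype=torch.int64,  # the diff widens this from int32 to int64
)

# Later, the worker thread fills in the real token ids (hypothetical buffer)...
future_token_ids_map = torch.tensor([11, 22, 33, 44, 0, 0], dtype=torch.int64)

# ...and negative placeholders are swapped for real ids by indexing with their magnitude.
resolved = future_token_ids_map[(-placeholders) - 1]
print(resolved.tolist())  # [11, 22, 33, 44]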
sglang/srt/mem_cache/base_prefix_cache.py
@@ -1,5 +1,5 @@
  from abc import ABC, abstractmethod
- from typing import Callable, List, Tuple
+ from typing import Any, List, Tuple
  
  
  class BasePrefixCache(ABC):
@@ -26,24 +26,22 @@ class BasePrefixCache(ABC):
          pass
  
      @abstractmethod
-     def evict(self, num_tokens: int, evict_callback: Callable):
+     def evict(self, num_tokens: int):
          pass
  
      @abstractmethod
-     def inc_lock_ref(self, node):
+     def inc_lock_ref(self, node: Any):
          pass
  
      @abstractmethod
-     def dec_lock_ref(self, node):
+     def dec_lock_ref(self, node: Any):
          pass
  
-     @abstractmethod
      def evictable_size(self):
-         pass
+         return 0
  
-     @abstractmethod
      def protected_size(self):
-         raise NotImplementedError()
+         return 0
  
      def total_size(self):
          raise NotImplementedError()
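The BasePrefixCache changes above drop `evict_callback` from the abstract `evict` signature and turn `evictable_size`/`protected_size` into concrete defaults returning 0, so simple caches no longer have to override them. A minimal, hypothetical subclass under the relaxed interface might look like the sketch below; it is limited to the methods visible in this hunk (the real base class has additional abstract methods), and `NoEvictionCache` is an illustration, not part of sglang.

# Minimal sketch of the relaxed interface, restricted to the methods shown in the hunk.
from abc import ABC, abstractmethod
from typing import Any


class BasePrefixCacheSketch(ABC):
    @abstractmethod
    def evict(self, num_tokens: int):
        pass

    @abstractmethod
    def inc_lock_ref(self, node: Any):
        pass

    @abstractmethod
    def dec_lock_ref(self, node: Any):
        pass

    def evictable_size(self):
        return 0

    def protected_size(self):
        return 0


class NoEvictionCache(BasePrefixCacheSketch):
    """Illustrative subclass: nothing is evictable, so the new defaults suffice."""

    def evict(self, num_tokens: int):
        pass  # nothing to evict

    def inc_lock_ref(self, node: Any):
        pass

    def dec_lock_ref(self, node: Any):
        pass


print(NoEvictionCache().evictable_size())  # 0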