sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +9 -6
  3. sglang/bench_serving.py +46 -22
  4. sglang/global_config.py +1 -1
  5. sglang/lang/backend/runtime_endpoint.py +60 -49
  6. sglang/lang/compiler.py +2 -2
  7. sglang/lang/interpreter.py +4 -2
  8. sglang/lang/ir.py +16 -7
  9. sglang/srt/constrained/base_tool_cache.py +1 -1
  10. sglang/srt/constrained/fsm_cache.py +12 -2
  11. sglang/srt/constrained/jump_forward.py +13 -2
  12. sglang/srt/layers/activation.py +32 -0
  13. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  14. sglang/srt/layers/extend_attention.py +9 -2
  15. sglang/srt/layers/fused_moe/__init__.py +1 -0
  16. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  17. sglang/srt/layers/fused_moe/layer.py +587 -0
  18. sglang/srt/layers/layernorm.py +65 -0
  19. sglang/srt/layers/logits_processor.py +7 -2
  20. sglang/srt/layers/pooler.py +50 -0
  21. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  22. sglang/srt/layers/radix_attention.py +40 -16
  23. sglang/srt/managers/detokenizer_manager.py +31 -9
  24. sglang/srt/managers/io_struct.py +63 -0
  25. sglang/srt/managers/policy_scheduler.py +173 -25
  26. sglang/srt/managers/schedule_batch.py +115 -97
  27. sglang/srt/managers/tokenizer_manager.py +194 -112
  28. sglang/srt/managers/tp_worker.py +290 -359
  29. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  30. sglang/srt/mem_cache/chunk_cache.py +43 -20
  31. sglang/srt/mem_cache/memory_pool.py +2 -2
  32. sglang/srt/mem_cache/radix_cache.py +74 -40
  33. sglang/srt/model_executor/cuda_graph_runner.py +71 -25
  34. sglang/srt/model_executor/forward_batch_info.py +293 -156
  35. sglang/srt/model_executor/model_runner.py +77 -57
  36. sglang/srt/models/chatglm.py +2 -2
  37. sglang/srt/models/commandr.py +1 -1
  38. sglang/srt/models/deepseek.py +2 -2
  39. sglang/srt/models/deepseek_v2.py +7 -6
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +11 -6
  42. sglang/srt/models/grok.py +50 -396
  43. sglang/srt/models/internlm2.py +2 -7
  44. sglang/srt/models/llama2.py +4 -4
  45. sglang/srt/models/llama_embedding.py +88 -0
  46. sglang/srt/models/minicpm.py +2 -2
  47. sglang/srt/models/mixtral.py +56 -254
  48. sglang/srt/models/mixtral_quant.py +1 -4
  49. sglang/srt/models/qwen.py +2 -2
  50. sglang/srt/models/qwen2.py +2 -2
  51. sglang/srt/models/qwen2_moe.py +2 -13
  52. sglang/srt/models/stablelm.py +1 -1
  53. sglang/srt/openai_api/adapter.py +187 -48
  54. sglang/srt/openai_api/protocol.py +37 -1
  55. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  56. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  57. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  58. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  59. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  60. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  61. sglang/srt/sampling_params.py +31 -8
  62. sglang/srt/server.py +91 -29
  63. sglang/srt/server_args.py +32 -19
  64. sglang/srt/utils.py +32 -15
  65. sglang/test/run_eval.py +10 -1
  66. sglang/test/runners.py +81 -73
  67. sglang/test/simple_eval_humaneval.py +2 -8
  68. sglang/test/simple_eval_mgsm.py +203 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  70. sglang/test/test_layernorm.py +60 -0
  71. sglang/test/test_programs.py +36 -7
  72. sglang/test/test_utils.py +24 -2
  73. sglang/utils.py +0 -1
  74. sglang/version.py +1 -1
  75. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
  76. sglang-0.2.13.dist-info/RECORD +112 -0
  77. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  78. sglang/srt/layers/linear.py +0 -884
  79. sglang/srt/layers/quantization/__init__.py +0 -64
  80. sglang/srt/layers/quantization/fp8.py +0 -677
  81. sglang/srt/model_loader/model_loader.py +0 -292
  82. sglang/srt/model_loader/utils.py +0 -275
  83. sglang-0.2.11.dist-info/RECORD +0 -102
  84. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  85. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
@@ -17,25 +17,30 @@ limitations under the License.

  import logging
  import multiprocessing
+ import os
  import pickle
  import time
  import warnings
- from typing import List, Optional
+ from typing import Any, List, Optional, Union

  import torch
+ import torch.distributed
  import torch.distributed as dist

  from sglang.global_config import global_config
  from sglang.srt.constrained.fsm_cache import FSMCache
  from sglang.srt.constrained.jump_forward import JumpForwardCache
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+ from sglang.srt.layers.logits_processor import LogitProcessorOutput
  from sglang.srt.managers.io_struct import (
  AbortReq,
+ BatchEmbeddingOut,
  BatchTokenIDOut,
  FlushCacheReq,
+ TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
  )
- from sglang.srt.managers.policy_scheduler import PolicyScheduler
+ from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
  from sglang.srt.managers.schedule_batch import (
  FINISH_ABORT,
  BaseFinishReason,
@@ -49,7 +54,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
  from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
- get_int_token_logit_bias,
  is_multimodal_model,
  set_random_seed,
  suppress_other_loggers,
@@ -59,6 +63,9 @@ from sglang.utils import get_exception_traceback
  logger = logging.getLogger(__name__)


+ crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
  class ModelTpServer:
  def __init__(
  self,
@@ -98,26 +105,24 @@ class ModelTpServer:
  nccl_port=nccl_port,
  server_args=server_args,
  )
-
- if is_multimodal_model(server_args.model_path):
- self.processor = get_processor(
- server_args.tokenizer_path,
- tokenizer_mode=server_args.tokenizer_mode,
- trust_remote_code=server_args.trust_remote_code,
- )
- self.tokenizer = self.processor.tokenizer
+ if server_args.skip_tokenizer_init:
+ self.tokenizer = self.processor = None
  else:
- self.tokenizer = get_tokenizer(
- server_args.tokenizer_path,
- tokenizer_mode=server_args.tokenizer_mode,
- trust_remote_code=server_args.trust_remote_code,
- )
+ if is_multimodal_model(server_args.model_path):
+ self.processor = get_processor(
+ server_args.tokenizer_path,
+ tokenizer_mode=server_args.tokenizer_mode,
+ trust_remote_code=server_args.trust_remote_code,
+ )
+ self.tokenizer = self.processor.tokenizer
+ else:
+ self.tokenizer = get_tokenizer(
+ server_args.tokenizer_path,
+ tokenizer_mode=server_args.tokenizer_mode,
+ trust_remote_code=server_args.trust_remote_code,
+ )
  self.max_total_num_tokens = self.model_runner.max_total_num_tokens
- self.max_prefill_tokens = (
- 16384
- if server_args.max_prefill_tokens is None
- else server_args.max_prefill_tokens
- )
+ self.max_prefill_tokens = server_args.max_prefill_tokens
  self.max_running_requests = min(
  (
  self.max_total_num_tokens // 2
@@ -126,9 +131,6 @@ class ModelTpServer:
  ),
  self.model_runner.req_to_token_pool.size - 1,
  )
- self.int_token_logit_bias = torch.tensor(
- get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
- )
  self.max_req_input_len = min(
  self.model_config.context_len - 1,
  self.max_total_num_tokens - 1,
@@ -160,13 +162,7 @@ class ModelTpServer:
  disable=server_args.disable_radix_cache,
  )
  self.tree_cache_metrics = {"total": 0, "hit": 0}
- self.scheduler = PolicyScheduler(
- self.schedule_policy,
- self.max_running_requests,
- self.max_prefill_tokens,
- self.max_total_num_tokens,
- self.tree_cache,
- )
+ self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache)
  self.req_to_token_pool = self.model_runner.req_to_token_pool
  self.token_to_kv_pool = self.model_runner.token_to_kv_pool

@@ -180,13 +176,15 @@ class ModelTpServer:
  self.last_stats_tic = time.time()

  # Init the FSM cache for constrained generation
- self.regex_fsm_cache = FSMCache(
- server_args.tokenizer_path,
- {
- "tokenizer_mode": server_args.tokenizer_mode,
- "trust_remote_code": server_args.trust_remote_code,
- },
- )
+ if not server_args.skip_tokenizer_init:
+ self.regex_fsm_cache = FSMCache(
+ server_args.tokenizer_path,
+ {
+ "tokenizer_mode": server_args.tokenizer_mode,
+ "trust_remote_code": server_args.trust_remote_code,
+ },
+ skip_tokenizer_init=server_args.skip_tokenizer_init,
+ )
  self.jump_forward_cache = JumpForwardCache()

  # Init new token estimation
@@ -201,11 +199,13 @@ class ModelTpServer:
  self.new_token_ratio = self.min_new_token_ratio
  self.new_token_ratio_decay = global_config.new_token_ratio_decay

- def exposed_step(self, recv_reqs):
+ def exposed_step(self, recv_reqs: List):
  try:
  # Recv requests
  for recv_req in recv_reqs:
- if isinstance(recv_req, TokenizedGenerateReqInput):
+ if isinstance(
+ recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+ ):
  self.handle_generate_request(recv_req)
  elif isinstance(recv_req, FlushCacheReq):
  self.flush_cache()
@@ -232,8 +232,6 @@ class ModelTpServer:
  if new_batch is not None:
  # Run a new prefill batch
  self.forward_prefill_batch(new_batch)
- self.cache_filled_batch(new_batch)
- self.filter_out_inflight(new_batch)

  if not new_batch.is_empty():
  if self.running_batch is None:
@@ -250,7 +248,7 @@ class ModelTpServer:

  # Print stats
  if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
- self.print_stats()
+ self.print_decode_stats()

  if self.running_batch.is_empty():
  self.running_batch = None
@@ -262,7 +260,7 @@ class ModelTpServer:
  self.check_memory()
  self.new_token_ratio = global_config.init_new_token_ratio

- def print_stats(self):
+ def print_decode_stats(self):
  num_used = self.max_total_num_tokens - (
  self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
  )
@@ -288,6 +286,7 @@ class ModelTpServer:
  f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
  "KV cache pool leak detected!"
  )
+ exit(1) if crash_on_warning else None

  if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
  warnings.warn(
@@ -296,44 +295,46 @@ class ModelTpServer:
  f"total slots={self.req_to_token_pool.size}\n"
  "Memory pool leak detected!"
  )
+ exit(1) if crash_on_warning else None

  def handle_generate_request(
  self,
- recv_req: TokenizedGenerateReqInput,
+ recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
  ):
  req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
- req.pixel_values = recv_req.pixel_values
- if req.pixel_values is not None:
- req.pad_value = [
- (recv_req.image_hash) % self.model_config.vocab_size,
- (recv_req.image_hash >> 16) % self.model_config.vocab_size,
- (recv_req.image_hash >> 32) % self.model_config.vocab_size,
- (recv_req.image_hash >> 64) % self.model_config.vocab_size,
- ]
- req.image_size = recv_req.image_size
- (
- req.origin_input_ids,
- req.image_offset,
- ) = self.model_runner.model.pad_input_ids(
- req.origin_input_ids_unpadded,
- req.pad_value,
- req.pixel_values.shape,
- req.image_size,
- )
- req.sampling_params = recv_req.sampling_params
- req.return_logprob = recv_req.return_logprob
- req.logprob_start_len = recv_req.logprob_start_len
- req.top_logprobs_num = recv_req.top_logprobs_num
- req.stream = recv_req.stream
  req.tokenizer = self.tokenizer
-
- # Init regex fsm
- if req.sampling_params.regex is not None:
- req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
- if not self.disable_regex_jump_forward:
- req.jump_forward_map = self.jump_forward_cache.query(
- req.sampling_params.regex
+ req.sampling_params = recv_req.sampling_params
+ if self.model_runner.is_generation:
+ req.pixel_values = recv_req.pixel_values
+ if req.pixel_values is not None:
+ req.pad_value = [
+ (recv_req.image_hash) % self.model_config.vocab_size,
+ (recv_req.image_hash >> 16) % self.model_config.vocab_size,
+ (recv_req.image_hash >> 32) % self.model_config.vocab_size,
+ (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+ ]
+ req.image_size = recv_req.image_size
+ (
+ req.origin_input_ids,
+ req.image_offset,
+ ) = self.model_runner.model.pad_input_ids(
+ req.origin_input_ids_unpadded,
+ req.pad_value,
+ req.pixel_values.shape,
+ req.image_size,
  )
+ req.return_logprob = recv_req.return_logprob
+ req.logprob_start_len = recv_req.logprob_start_len
+ req.top_logprobs_num = recv_req.top_logprobs_num
+ req.stream = recv_req.stream
+
+ # Init regex fsm
+ if req.sampling_params.regex is not None:
+ req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
+ if not self.disable_regex_jump_forward:
+ req.jump_forward_map = self.jump_forward_cache.query(
+ req.sampling_params.regex
+ )

  # Truncate prompts that are too long
  if len(req.origin_input_ids) >= self.max_req_input_len:
@@ -342,186 +343,87 @@ class ModelTpServer:
  "the max context length. Truncated!!!"
  )
  req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
- req.sampling_params.max_new_tokens = min(
- (
- req.sampling_params.max_new_tokens
- if req.sampling_params.max_new_tokens is not None
- else 1 << 30
- ),
- self.max_req_input_len - 1 - len(req.origin_input_ids),
- )
+
+ if self.model_runner.is_generation:
+ req.sampling_params.max_new_tokens = min(
+ (
+ req.sampling_params.max_new_tokens
+ if req.sampling_params.max_new_tokens is not None
+ else 1 << 30
+ ),
+ self.max_req_input_len - 1 - len(req.origin_input_ids),
+ )
+
  self.waiting_queue.append(req)

  def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
- # TODO(lsyin): organize this function
  running_bs = (
  len(self.running_batch.reqs) if self.running_batch is not None else 0
  )
  if running_bs >= self.max_running_requests:
- return
-
- # Compute matched prefix length
- for req in self.waiting_queue:
- req.input_ids = req.origin_input_ids + req.output_ids
- try_match_ids = req.input_ids
- if req.return_logprob:
- try_match_ids = req.input_ids[: req.logprob_start_len]
- # NOTE: the prefix_indices must always be aligned with last_node
- prefix_indices, last_node = self.tree_cache.match_prefix(
- rid=req.rid, key=try_match_ids
- )
- req.extend_input_len = len(req.input_ids) - len(prefix_indices)
- req.prefix_indices = prefix_indices
- req.last_node = last_node
+ return None

  # Get priority queue
- self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
-
- # Add requests if there is available space
- can_run_list = []
- new_batch_total_tokens = 0
- new_batch_input_tokens = 0
+ prefix_computed = self.scheduler.calc_priority(self.waiting_queue)

- available_size = (
- self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+ adder = PrefillAdder(
+ self.tree_cache,
+ self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
+ self.max_prefill_tokens,
+ self.chunked_prefill_size,
  )
- if self.running_batch:
- available_size -= sum(
- [
- (r.sampling_params.max_new_tokens - len(r.output_ids))
- * self.new_token_ratio
- for r in self.running_batch.reqs
- ]
- )

- # Handle the current inflight request
- take_inflight = 0
- if self.current_inflight_req:
- take_inflight = 1
- r = self.current_inflight_req
- r.input_ids = r.origin_input_ids + r.output_ids
- truncated = (
- len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
+ if self.running_batch is not None:
+ adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
+
+ has_inflight = self.current_inflight_req is not None
+ if self.current_inflight_req is not None:
+ self.current_inflight_req.init_next_round_input(
+ None if prefix_computed else self.tree_cache
  )
- r.extend_input_len = min(
- len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
+ self.current_inflight_req = adder.add_inflight_req(
+ self.current_inflight_req
  )
- r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
- can_run_list.append(r)
-
- if not truncated:
- # Finish inflight
- self.current_inflight_req = None
- new_batch_total_tokens += (
- r.extend_input_len + r.sampling_params.max_new_tokens
- )
- new_batch_input_tokens += r.extend_input_len
- else:
- new_batch_total_tokens += r.extend_input_len
- new_batch_input_tokens += r.extend_input_len

  for req in self.waiting_queue:
- if req.return_logprob and req.normalized_prompt_logprob is None:
- # Need at least two tokens to compute normalized logprob
- if req.extend_input_len < 2:
- delta = 2 - req.extend_input_len
- req.extend_input_len += delta
- req.prefix_indices = req.prefix_indices[:-delta]
- if req.image_offset is not None:
- req.image_offset += delta
- if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
- # Need at least one token to compute logits
- req.extend_input_len = 1
- req.prefix_indices = req.prefix_indices[:-1]
- if req.image_offset is not None:
- req.image_offset += 1
-
+ req.init_next_round_input(None if prefix_computed else self.tree_cache)
+ res = adder.add_one_req(req)
  if (
- req.extend_input_len
- + req.sampling_params.max_new_tokens
- + new_batch_total_tokens
- < available_size
- and (
- req.extend_input_len + new_batch_input_tokens
- <= self.max_prefill_tokens
- or len(can_run_list) == 0
- )
+ not res
+ or adder.no_remaining_tokens()
+ or running_bs + len(adder.can_run_list) >= self.max_running_requests
  ):
- delta = self.tree_cache.inc_lock_ref(req.last_node)
- available_size += delta
-
- if not (
- req.extend_input_len
- + req.sampling_params.max_new_tokens
- + new_batch_total_tokens
- < available_size
- ):
- # Undo locking
- delta = self.tree_cache.dec_lock_ref(req.last_node)
- available_size += delta
- break
- else:
- # Add this request to the running batch
- if (
- self.chunked_prefill_size is None
- or (
- new_batch_input_tokens + req.extend_input_len
- <= self.chunked_prefill_size
- )
- or (
- req.return_logprob and req.normalized_prompt_logprob is None
- )
- ):
- can_run_list.append(req)
- new_batch_total_tokens += (
- req.extend_input_len + req.sampling_params.max_new_tokens
- )
- new_batch_input_tokens += req.extend_input_len
- else:
- trunc_len = self.chunked_prefill_size - new_batch_input_tokens
-
- if trunc_len <= 0:
- # Undo locking
- delta = self.tree_cache.dec_lock_ref(req.last_node)
- available_size += delta
- break
-
- req.extend_input_len = trunc_len
- req.input_ids = req.input_ids[
- : len(req.prefix_indices) + req.extend_input_len
- ]
- can_run_list.append(req)
- self.current_inflight_req = req
- new_batch_input_tokens += req.extend_input_len
- new_batch_total_tokens += req.extend_input_len
- break
- else:
  break

- if running_bs + len(can_run_list) >= self.max_running_requests:
- break
+ can_run_list = adder.can_run_list
+
+ if adder.new_inflight_req is not None:
+ assert self.current_inflight_req is None
+ self.current_inflight_req = adder.new_inflight_req

  if len(can_run_list) == 0:
  return None

  # Print stats
  if self.tp_rank == 0:
- hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
- self.tree_cache_metrics["total"] += (
- hit_tokens + new_batch_input_tokens
- ) / 10**9
- self.tree_cache_metrics["hit"] += hit_tokens / 10**9
- tree_cache_hit_rate = (
- self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
- )
+ if isinstance(self.tree_cache, RadixCache):
+ self.tree_cache_metrics["total"] += (
+ adder.log_input_tokens + adder.log_hit_tokens
+ ) / 10**9
+ self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
+ tree_cache_hit_rate = (
+ self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+ )
+ else:
+ tree_cache_hit_rate = 0.0
  logger.info(
  f"[gpu={self.gpu_id}] Prefill batch. "
  f"#new-seq: {len(can_run_list)}, "
- f"#new-token: {new_batch_input_tokens}, "
- f"#cached-token: {hit_tokens}, "
+ f"#new-token: {adder.log_input_tokens}, "
+ f"#cached-token: {adder.log_hit_tokens}, "
  f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
  f"#running-req: {running_bs}, "
- f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
+ f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
  )

  # Return the new batch
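
The hunk above replaces the hand-rolled admission loop in get_new_prefill_batch with the new PrefillAdder helper: calc_priority orders the waiting queue, remove_running_tokens reserves budget for the running batch, add_inflight_req resumes a chunked prefill, and add_one_req greedily admits requests until the budget or the running-request cap is exhausted. PrefillAdder's real implementation lives in sglang/srt/managers/policy_scheduler.py and is not shown in this diff, so the following is only a minimal, self-contained toy of the same budget-based admission pattern; none of the names below are sglang's.

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyReq:
    input_len: int
    max_new_tokens: int


@dataclass
class ToyPrefillAdder:
    # Loosely modeled on the arguments passed above: a KV-cache token budget
    # (free + evictable) and a per-batch prefill input-token budget.
    rem_total_tokens: int
    rem_input_tokens: int
    can_run_list: List[ToyReq] = field(default_factory=list)
    log_input_tokens: int = 0

    def no_remaining_tokens(self) -> bool:
        return self.rem_total_tokens <= 0 or self.rem_input_tokens <= 0

    def add_one_req(self, req: ToyReq) -> bool:
        # Refuse the request if either budget cannot cover it.
        need_total = req.input_len + req.max_new_tokens
        if need_total > self.rem_total_tokens or req.input_len > self.rem_input_tokens:
            return False
        self.rem_total_tokens -= need_total
        self.rem_input_tokens -= req.input_len
        self.can_run_list.append(req)
        self.log_input_tokens += req.input_len
        return True


def build_prefill_batch(
    waiting_queue: List[ToyReq], adder: ToyPrefillAdder, max_running_requests: int
) -> Optional[List[ToyReq]]:
    # Same loop shape as the hunk: admit greedily, stop on the first refusal,
    # an exhausted budget, or the running-request cap.
    for req in waiting_queue:
        if (
            not adder.add_one_req(req)
            or adder.no_remaining_tokens()
            or len(adder.can_run_list) >= max_running_requests
        ):
            break
    return adder.can_run_list or None


if __name__ == "__main__":
    queue = [ToyReq(512, 128), ToyReq(2048, 256), ToyReq(64, 32)]
    batch = build_prefill_batch(queue, ToyPrefillAdder(4096, 2048), max_running_requests=8)
    print(f"admitted {len(batch)} request(s)")
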
@@ -536,45 +438,90 @@ class ModelTpServer:

  def forward_prefill_batch(self, batch: ScheduleBatch):
  # Build batch tensors
- batch.prepare_for_extend(
- self.model_config.vocab_size, self.int_token_logit_bias
- )
+ batch.prepare_for_extend(self.model_config.vocab_size)
+
+ if self.model_runner.is_generation:
+ # Forward and sample the next tokens
+ if batch.extend_num_tokens != 0:
+ output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+ next_token_ids = batch.sample(output.next_token_logits)
+
+ # Move logprobs to cpu
+ if output.next_token_logprobs is not None:
+ output.next_token_logprobs = output.next_token_logprobs[
+ torch.arange(len(next_token_ids), device=next_token_ids.device),
+ next_token_ids,
+ ].tolist()
+ output.input_token_logprobs = output.input_token_logprobs.tolist()
+ output.normalized_prompt_logprobs = (
+ output.normalized_prompt_logprobs.tolist()
+ )

- # Forward and sample the next tokens
- if batch.extend_num_tokens != 0:
- output = self.model_runner.forward(batch, ForwardMode.EXTEND)
- next_token_ids = batch.sample(output.next_token_logits)
-
- # Move logprobs to cpu
- if output.next_token_logprobs is not None:
- output.next_token_logprobs = output.next_token_logprobs[
- torch.arange(len(next_token_ids), device=next_token_ids.device),
- next_token_ids,
- ].tolist()
- output.input_token_logprobs = output.input_token_logprobs.tolist()
- output.normalized_prompt_logprobs = (
- output.normalized_prompt_logprobs.tolist()
- )
+ next_token_ids = next_token_ids.tolist()
+ else:
+ if self.tokenizer is None:
+ next_token_ids = []
+ for req in batch.reqs:
+ next_token_ids.append(
+ next(iter(req.sampling_params.stop_token_ids))
+ )
+ else:
+ next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+
+ # Check finish conditions
+ pt = 0
+ for i, req in enumerate(batch.reqs):
+ if req is not self.current_inflight_req:
+ # Inflight reqs' prefill is not finished
+ req.completion_tokens_wo_jump_forward += 1
+ req.output_ids.append(next_token_ids[i])
+ req.check_finished()
+
+ if req.finished():
+ self.tree_cache.cache_finished_req(req)
+ else:
+ self.tree_cache.cache_unfinished_req(req)

- next_token_ids = next_token_ids.tolist()
- else:
- next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+ if req is self.current_inflight_req:
+ # Inflight request would get a new req idx
+ self.req_to_token_pool.free(req.req_pool_idx)

- # Check finish conditions
- pt = 0
- for i, req in enumerate(batch.reqs):
- if req is not self.current_inflight_req:
- req.completion_tokens_wo_jump_forward += 1
- req.output_ids.append(next_token_ids[i])
- req.check_finished()
+ if req.return_logprob:
+ self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+ pt += req.extend_input_len
+ else:
+ assert batch.extend_num_tokens != 0
+ output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+ embeddings = output.embeddings.tolist()
+
+ # Check finish conditions
+ for i, req in enumerate(batch.reqs):
+ req.embedding = embeddings[i]
+ if req is not self.current_inflight_req:
+ # Inflight reqs' prefill is not finished
+ # dummy output token for embedding models
+ req.output_ids.append(0)
+ req.check_finished()
+
+ if req.finished():
+ self.tree_cache.cache_finished_req(req)
+ else:
+ self.tree_cache.cache_unfinished_req(req)

- if req.return_logprob:
- self.add_logprob_return_values(i, req, pt, next_token_ids, output)
- pt += req.extend_input_len
+ if req is self.current_inflight_req:
+ # Inflight request would get a new req idx
+ self.req_to_token_pool.free(req.req_pool_idx)

  self.handle_finished_requests(batch)

- def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+ def add_logprob_return_values(
+ self,
+ i,
+ req: Req,
+ pt: int,
+ next_token_ids: List[int],
+ output: LogitProcessorOutput,
+ ):
  if req.normalized_prompt_logprob is None:
  req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

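One detail worth calling out in the generation branch above is how the log-probability of each sampled token is gathered with a single advanced-indexing step before being moved to the CPU. Below is a standalone illustration of that indexing pattern only; the tensor names are invented for the example and do not come from sglang.

import torch

batch_size, vocab_size = 4, 32000
logits = torch.randn(batch_size, vocab_size)
logprobs = torch.log_softmax(logits, dim=-1)
next_token_ids = torch.randint(0, vocab_size, (batch_size,))

# Pair row i with column next_token_ids[i]: one scalar per request.
chosen_logprobs = logprobs[torch.arange(batch_size), next_token_ids]
print(chosen_logprobs.shape)  # torch.Size([4])
print(chosen_logprobs.tolist())
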
@@ -583,12 +530,12 @@ class ModelTpServer:
  req.input_token_logprobs = list(
  zip(
  output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
- req.input_ids[-req.extend_input_len + 1 :],
+ req.fill_ids[-req.extend_input_len + 1 :],
  )
  )
  if req.logprob_start_len == 0:
  req.input_token_logprobs = [
- (None, req.input_ids[0])
+ (None, req.fill_ids[0])
  ] + req.input_token_logprobs

  if req.last_update_decode_tokens != 0:
@@ -602,7 +549,7 @@ class ModelTpServer:
  + req.extend_input_len
  - 1
  ],
- req.input_ids[-req.last_update_decode_tokens + 1 :],
+ req.fill_ids[-req.last_update_decode_tokens + 1 :],
  )
  )
  )
@@ -623,22 +570,6 @@ class ModelTpServer:
  )
  req.output_top_logprobs.append(output.output_top_logprobs[i])

- def cache_filled_batch(self, batch: ScheduleBatch):
- for i, req in enumerate(batch.reqs):
- new_prefix_indices, new_last_node = self.tree_cache.cache_req(
- rid=req.rid,
- token_ids=tuple(req.input_ids),
- last_uncached_pos=len(req.prefix_indices),
- req_pool_idx=req.req_pool_idx,
- del_in_memory_pool=False,
- old_last_node=req.last_node,
- )
- req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
-
- if req is self.current_inflight_req:
- # inflight request would get a new req idx
- self.req_to_token_pool.free(req.req_pool_idx)
-
  def forward_decode_batch(self, batch: ScheduleBatch):
  # Check if decode out of memory
  if not batch.check_decode_mem():
@@ -689,6 +620,9 @@ class ModelTpServer:
  req.output_ids.append(next_token_id)
  req.check_finished()

+ if req.finished():
+ self.tree_cache.cache_finished_req(req)
+
  if req.return_logprob:
  req.output_token_logprobs.append(
  (next_token_logprobs[i], next_token_id)
@@ -700,20 +634,21 @@ class ModelTpServer:

  def handle_finished_requests(self, batch: ScheduleBatch):
  output_rids = []
- output_vids = []
- decoded_texts = []
- output_read_ids = []
- output_read_offsets = []
- output_skip_special_tokens = []
- output_spaces_between_special_tokens = []
  output_meta_info = []
  output_finished_reason: List[BaseFinishReason] = []
- finished_indices = []
+ if self.model_runner.is_generation:
+ output_vids = []
+ decoded_texts = []
+ output_read_ids = []
+ output_read_offsets = []
+ output_skip_special_tokens = []
+ output_spaces_between_special_tokens = []
+ else: # for embedding model
+ output_embeddings = []
  unfinished_indices = []
+
  for i, req in enumerate(batch.reqs):
- if req.finished():
- finished_indices.append(i)
- else:
+ if not req.finished() and req is not self.current_inflight_req:
  unfinished_indices.append(i)

  if req.finished() or (
@@ -726,85 +661,75 @@ class ModelTpServer:
  )
  ):
  output_rids.append(req.rid)
- output_vids.append(req.vid)
- decoded_texts.append(req.decoded_text)
- read_ids, read_offset = req.init_incremental_detokenize()
- output_read_ids.append(read_ids)
- output_read_offsets.append(read_offset)
- output_skip_special_tokens.append(
- req.sampling_params.skip_special_tokens
- )
- output_spaces_between_special_tokens.append(
- req.sampling_params.spaces_between_special_tokens
- )
-
- meta_info = {
- "prompt_tokens": len(req.origin_input_ids),
- "completion_tokens": len(req.output_ids),
- "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
- "finish_reason": str(req.finished_reason),
- }
- if req.return_logprob:
- (
- meta_info["input_token_logprobs"],
- meta_info["output_token_logprobs"],
- meta_info["input_top_logprobs"],
- meta_info["output_top_logprobs"],
- meta_info["normalized_prompt_logprob"],
- ) = (
- req.input_token_logprobs,
- req.output_token_logprobs,
- req.input_top_logprobs,
- req.output_top_logprobs,
- req.normalized_prompt_logprob,
- )
- output_meta_info.append(meta_info)
  output_finished_reason.append(req.finished_reason)
+ if self.model_runner.is_generation:
+ output_vids.append(req.vid)
+ decoded_texts.append(req.decoded_text)
+ read_ids, read_offset = req.init_incremental_detokenize()
+ output_read_ids.append(read_ids)
+ output_read_offsets.append(read_offset)
+ output_skip_special_tokens.append(
+ req.sampling_params.skip_special_tokens
+ )
+ output_spaces_between_special_tokens.append(
+ req.sampling_params.spaces_between_special_tokens
+ )
+
+ meta_info = {
+ "prompt_tokens": len(req.origin_input_ids),
+ "completion_tokens": len(req.output_ids),
+ "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+ "finish_reason": str(req.finished_reason),
+ }
+ if req.return_logprob:
+ (
+ meta_info["input_token_logprobs"],
+ meta_info["output_token_logprobs"],
+ meta_info["input_top_logprobs"],
+ meta_info["output_top_logprobs"],
+ meta_info["normalized_prompt_logprob"],
+ ) = (
+ req.input_token_logprobs,
+ req.output_token_logprobs,
+ req.input_top_logprobs,
+ req.output_top_logprobs,
+ req.normalized_prompt_logprob,
+ )
+ output_meta_info.append(meta_info)
+ else: # for embedding model
+ output_embeddings.append(req.embedding)
+ meta_info = {
+ "prompt_tokens": len(req.origin_input_ids),
+ }
+ output_meta_info.append(meta_info)

  # Send to detokenizer
  if output_rids:
- self.out_pyobjs.append(
- BatchTokenIDOut(
- output_rids,
- output_vids,
- decoded_texts,
- output_read_ids,
- output_read_offsets,
- output_skip_special_tokens,
- output_spaces_between_special_tokens,
- output_meta_info,
- output_finished_reason,
+ if self.model_runner.is_generation:
+ self.out_pyobjs.append(
+ BatchTokenIDOut(
+ output_rids,
+ output_vids,
+ decoded_texts,
+ output_read_ids,
+ output_read_offsets,
+ output_skip_special_tokens,
+ output_spaces_between_special_tokens,
+ output_meta_info,
+ output_finished_reason,
+ )
  )
- )
-
- # Remove finished reqs
- if finished_indices:
- # Update radix cache
- for i in finished_indices:
- req = batch.reqs[i]
- self.tree_cache.cache_req(
- rid=req.rid,
- token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
- last_uncached_pos=len(req.prefix_indices),
- req_pool_idx=req.req_pool_idx,
+ else: # for embedding model
+ self.out_pyobjs.append(
+ BatchEmbeddingOut(
+ output_rids,
+ output_embeddings,
+ output_meta_info,
+ output_finished_reason,
+ )
  )

- self.tree_cache.dec_lock_ref(req.last_node)
-
- # Update batch tensors
- if unfinished_indices:
- batch.filter_batch(unfinished_indices)
- else:
- batch.reqs = []
-
- def filter_out_inflight(self, batch: ScheduleBatch):
- # TODO(lsyin): reduce the overhead, make a special version for this
- if self.current_inflight_req is None:
- return
-
- to_remove = batch.reqs.index(self.current_inflight_req)
- unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
-
+ # Remove finished reqs: update batch tensors
  batch.filter_batch(unfinished_indices)

  def flush_cache(self):
@@ -871,7 +796,11 @@ def run_tp_server(


  def launch_tp_servers(
- gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+ gpu_ids: List[int],
+ tp_rank_range: List[int],
+ server_args: ServerArgs,
+ nccl_port: int,
+ model_overide_args: dict,
  ):
  """Launch multiple tensor parallel servers."""
  procs = []
@@ -886,7 +815,9 @@ def launch_tp_servers(
  return procs


- def broadcast_recv_input(data, rank, dist_group):
+ def broadcast_recv_input(
+ data: Any, rank: int, dist_group: torch.distributed.ProcessGroup
+ ):
  """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""

  if rank == 0:
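
The last hunk only adds type hints to broadcast_recv_input; the diff ends right after "if rank == 0:", so the body is not shown. For readers unfamiliar with the pattern, the sketch below shows the generic "pickle on rank 0, broadcast the size, then the bytes" approach such a helper typically takes. It is not sglang's implementation, the name broadcast_pyobj is made up to avoid implying otherwise, and it only runs inside an already-initialized torch.distributed process group.

import pickle
from typing import Any

import torch
import torch.distributed as dist


def broadcast_pyobj(data: Any, rank: int, dist_group: dist.ProcessGroup) -> Any:
    # Hypothetical sketch: rank 0 serializes and broadcasts; other ranks
    # receive the size first so they can allocate a matching byte buffer.
    if rank == 0:
        serialized = bytearray(pickle.dumps(data))
        size = torch.tensor([len(serialized)], dtype=torch.long)
        payload = torch.frombuffer(serialized, dtype=torch.uint8)
        dist.broadcast(size, src=0, group=dist_group)
        dist.broadcast(payload, src=0, group=dist_group)
        return data
    else:
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=0, group=dist_group)
        payload = torch.empty(int(size.item()), dtype=torch.uint8)
        dist.broadcast(payload, src=0, group=dist_group)
        return pickle.loads(payload.numpy().tobytes())

A simpler alternative with the same effect is torch.distributed.broadcast_object_list, which handles the pickling internally.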