sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +13 -6
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +2 -4
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +40 -35
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +110 -74
  31. sglang/srt/managers/tokenizer_manager.py +24 -15
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +60 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +118 -141
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +6 -8
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +8 -43
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/{llama2.py → llama.py} +48 -26
  49. sglang/srt/models/llama_classification.py +14 -40
  50. sglang/srt/models/llama_embedding.py +7 -6
  51. sglang/srt/models/llava.py +38 -16
  52. sglang/srt/models/llavavid.py +7 -8
  53. sglang/srt/models/minicpm.py +1 -5
  54. sglang/srt/models/minicpm3.py +665 -0
  55. sglang/srt/models/mistral.py +2 -3
  56. sglang/srt/models/mixtral.py +6 -5
  57. sglang/srt/models/mixtral_quant.py +1 -5
  58. sglang/srt/models/qwen.py +1 -5
  59. sglang/srt/models/qwen2.py +1 -5
  60. sglang/srt/models/qwen2_moe.py +6 -5
  61. sglang/srt/models/stablelm.py +1 -5
  62. sglang/srt/models/xverse.py +375 -0
  63. sglang/srt/models/xverse_moe.py +445 -0
  64. sglang/srt/openai_api/adapter.py +65 -46
  65. sglang/srt/openai_api/protocol.py +11 -3
  66. sglang/srt/sampling/sampling_batch_info.py +67 -58
  67. sglang/srt/server.py +24 -14
  68. sglang/srt/server_args.py +130 -28
  69. sglang/srt/utils.py +12 -0
  70. sglang/test/few_shot_gsm8k.py +132 -0
  71. sglang/test/runners.py +114 -22
  72. sglang/test/test_programs.py +70 -0
  73. sglang/test/test_utils.py +89 -1
  74. sglang/utils.py +38 -4
  75. sglang/version.py +1 -1
  76. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
  77. sglang-0.3.1.dist-info/RECORD +129 -0
  78. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  79. sglang-0.2.15.dist-info/RECORD +0 -118
  80. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
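
The hunks below all come from sglang/srt/managers/tp_worker.py (entry 32 in the list above). The most visible changes there: ModelConfig now lives in sglang.srt.configs.model_config, the model_override_args constructor parameter is gone in favor of server_args.json_model_override_args parsed with json.loads, and new LoRA fields (lora_paths, max_loras_per_batch) are read from ServerArgs. As a rough, non-authoritative sketch of how those new ServerArgs attributes might be populated (only the attribute names are taken from the diff; the constructor call, model path, and override payload are made-up examples):

    # Hypothetical usage sketch; attribute names come from the diff below.
    import json
    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(model_path="my-org/my-model")  # placeholder model path
    server_args.json_model_override_args = '{"max_position_embeddings": 8192}'  # example override payload
    server_args.lora_paths = ["/adapters/lora_a", "/adapters/lora_b"]  # optional LoRA adapters
    server_args.max_loras_per_batch = 4  # cap on distinct adapters per batch

    overrides = json.loads(server_args.json_model_override_args)  # what ModelTpServer now does internally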
@@ -15,19 +15,21 @@ limitations under the License.
 
  """A tensor parallel worker."""
 
+ import json
  import logging
  import multiprocessing
  import os
  import pickle
  import time
  import warnings
- from typing import Any, List, Optional, Union
+ from typing import Any, List, Optional
 
  import torch
  import torch.distributed
  import torch.distributed as dist
 
  from sglang.global_config import global_config
+ from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.constrained.fsm_cache import FSMCache
  from sglang.srt.constrained.jump_forward import JumpForwardCache
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -51,8 +53,6 @@ from sglang.srt.managers.schedule_batch import (
  )
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
- from sglang.srt.model_config import ModelConfig
- from sglang.srt.model_executor.forward_batch_info import ForwardMode
  from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
@@ -66,6 +66,7 @@ from sglang.utils import get_exception_traceback
  logger = logging.getLogger(__name__)
 
 
+ # Crash on warning if we are running CI tests
  crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
 
 
@@ -76,26 +77,26 @@ class ModelTpServer:
  tp_rank: int,
  server_args: ServerArgs,
  nccl_port: int,
- model_override_args: dict,
  ):
  suppress_other_loggers()
 
- # Copy arguments
+ # Parse arguments
  self.gpu_id = gpu_id
  self.tp_rank = tp_rank
  self.tp_size = server_args.tp_size
  self.dp_size = server_args.dp_size
  self.schedule_policy = server_args.schedule_policy
  self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+ self.lora_paths = server_args.lora_paths
+ self.max_loras_per_batch = server_args.max_loras_per_batch
 
  # Init model and tokenizer
  self.model_config = ModelConfig(
  server_args.model_path,
  server_args.trust_remote_code,
  context_length=server_args.context_length,
- model_override_args=model_override_args,
+ model_override_args=json.loads(server_args.json_model_override_args),
  )
-
  self.model_runner = ModelRunner(
  model_config=self.model_config,
  mem_fraction_static=server_args.mem_fraction_static,
@@ -129,14 +130,14 @@ class ModelTpServer:
  if server_args.max_running_requests is None
  else server_args.max_running_requests
  ),
- self.model_runner.req_to_token_pool.size - 1,
+ self.model_runner.req_to_token_pool.size,
  )
  self.max_req_input_len = min(
  self.model_config.context_len - 1,
  self.max_total_num_tokens - 1,
  )
 
- # Sync random seed
+ # Sync random seed across TP workers
  server_args.random_seed = broadcast_recv_input(
  [server_args.random_seed],
  self.tp_rank,
@@ -144,7 +145,7 @@ class ModelTpServer:
  )[0]
  set_random_seed(server_args.random_seed)
 
- # Print info
+ # Print debug info
  logger.info(
  f"max_total_num_tokens={self.max_total_num_tokens}, "
  f"max_prefill_tokens={self.max_prefill_tokens}, "
@@ -181,7 +182,7 @@ class ModelTpServer:
  self.num_generated_tokens = 0
  self.last_stats_tic = time.time()
 
- # Chunked prefill
+ # Init chunked prefill
  self.chunked_prefill_size = server_args.chunked_prefill_size
  self.current_inflight_req = None
  self.is_mixed_chunk = (
@@ -197,16 +198,6 @@ class ModelTpServer:
  "trust_remote_code": server_args.trust_remote_code,
  },
  skip_tokenizer_init=server_args.skip_tokenizer_init,
- json_schema_mode=False,
- )
- self.json_fsm_cache = FSMCache(
- server_args.tokenizer_path,
- {
- "tokenizer_mode": server_args.tokenizer_mode,
- "trust_remote_code": server_args.trust_remote_code,
- },
- skip_tokenizer_init=server_args.skip_tokenizer_init,
- json_schema_mode=True,
  )
  self.jump_forward_cache = JumpForwardCache()
 
@@ -221,15 +212,18 @@ class ModelTpServer:
  )
  self.new_token_ratio = self.min_new_token_ratio
  self.new_token_ratio_decay = global_config.new_token_ratio_decay
+ self.do_not_get_new_batch = False
 
  def exposed_step(self, recv_reqs: List):
  try:
  # Recv requests
  for recv_req in recv_reqs:
- if isinstance(
- recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
- ):
+ if isinstance(recv_req, TokenizedGenerateReqInput):
  self.handle_generate_request(recv_req)
+ self.do_not_get_new_batch = False
+ elif isinstance(recv_req, TokenizedEmbeddingReqInput):
+ self.handle_embedding_request(recv_req)
+ self.do_not_get_new_batch = False
  elif isinstance(recv_req, FlushCacheReq):
  self.flush_cache()
  elif isinstance(recv_req, AbortReq):
@@ -253,7 +247,11 @@ class ModelTpServer:
 
  @torch.inference_mode()
  def forward_step(self):
- new_batch = self.get_new_prefill_batch()
+ if self.do_not_get_new_batch and self.current_inflight_req is None:
+ new_batch = None
+ else:
+ new_batch = self.get_new_prefill_batch()
+ self.do_not_get_new_batch = False
 
  if new_batch is not None:
  # Run a new prefill batch
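
The new do_not_get_new_batch flag introduced above works together with the decode loop further down (see the @@ -700 and @@ -720 hunks): when a decode step finishes with no completed request and no new requests arrive, the worker skips get_new_prefill_batch on the next step. A condensed, illustrative restatement of that interaction, not the actual class:

    # Illustrative only; condenses logic spread across several hunks of tp_worker.py.
    def forward_step(self):
        if self.do_not_get_new_batch and self.current_inflight_req is None:
            new_batch = None  # nothing could have changed; skip the scheduling work
        else:
            new_batch = self.get_new_prefill_batch()
        self.do_not_get_new_batch = False  # reset; the decode loop may set it again

    # In the decode path, the flag is set only when no request finished,
    # and receiving any new generate/embedding request clears it.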
@@ -280,7 +278,7 @@ class ModelTpServer:
  self.running_batch = None
  break
 
- if self.out_pyobjs and self.running_batch.has_stream():
+ if self.out_pyobjs and self.running_batch.has_stream:
  break
  else:
  self.check_memory()
@@ -325,73 +323,102 @@ class ModelTpServer:
 
  def handle_generate_request(
  self,
- recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
+ recv_req: TokenizedGenerateReqInput,
  ):
- req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+ if isinstance(recv_req, TokenizedGenerateReqInput):
+ req = Req(
+ recv_req.rid,
+ recv_req.input_text,
+ recv_req.input_ids,
+ lora_path=recv_req.lora_path,
+ )
+ else:
+ req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
  req.tokenizer = self.tokenizer
  req.sampling_params = recv_req.sampling_params
- if self.model_runner.is_generation:
- req.pixel_values = recv_req.pixel_values
- if req.pixel_values is not None:
- # Use image hash as fake token_ids, which is then used
- # for prefix matching
- image_hash = hash(tuple(recv_req.image_hashes))
- req.pad_value = [
- (image_hash) % self.model_config.vocab_size,
- (image_hash >> 16) % self.model_config.vocab_size,
- (image_hash >> 32) % self.model_config.vocab_size,
- (image_hash >> 64) % self.model_config.vocab_size,
- ]
- req.image_sizes = recv_req.image_sizes
- (
- req.origin_input_ids,
- req.image_offsets,
- ) = self.model_runner.model.pad_input_ids(
- req.origin_input_ids_unpadded,
- req.pad_value,
- req.pixel_values,
- req.image_sizes,
- )
- req.return_logprob = recv_req.return_logprob
- req.logprob_start_len = recv_req.logprob_start_len
- req.top_logprobs_num = recv_req.top_logprobs_num
- req.stream = recv_req.stream
-
- # Init regex fsm fron json
+ req.pixel_values = recv_req.pixel_values
+ if req.pixel_values is not None:
+ # Use image hash as fake token_ids, which is then used
+ # for prefix matching
+ image_hash = hash(tuple(recv_req.image_hashes))
+ req.pad_value = [
+ (image_hash) % self.model_config.vocab_size,
+ (image_hash >> 16) % self.model_config.vocab_size,
+ (image_hash >> 32) % self.model_config.vocab_size,
+ (image_hash >> 64) % self.model_config.vocab_size,
+ ]
+ req.image_sizes = recv_req.image_sizes
+ (
+ req.origin_input_ids,
+ req.image_offsets,
+ ) = self.model_runner.model.pad_input_ids(
+ req.origin_input_ids_unpadded,
+ req.pad_value,
+ req.pixel_values,
+ req.image_sizes,
+ )
+ # Only when pixel values is not None we have modalities
+ req.modalities = recv_req.modalites
+ req.return_logprob = recv_req.return_logprob
+ req.top_logprobs_num = recv_req.top_logprobs_num
+ req.stream = recv_req.stream
+ req.logprob_start_len = recv_req.logprob_start_len
+
+ if req.logprob_start_len == -1:
+ # By default, only return the logprobs for output tokens
+ req.logprob_start_len = len(recv_req.input_ids) - 1
+
+ # Init regex FSM
+ if (
+ req.sampling_params.json_schema is not None
+ or req.sampling_params.regex is not None
+ ):
  if req.sampling_params.json_schema is not None:
- req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
- req.sampling_params.json_schema
+ req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+ ("json", req.sampling_params.json_schema)
  )
- if not self.disable_regex_jump_forward:
- req.jump_forward_map = self.jump_forward_cache.query(
- computed_regex_string
- )
-
- # Init regex fsm
  elif req.sampling_params.regex is not None:
- req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
- if not self.disable_regex_jump_forward:
- req.jump_forward_map = self.jump_forward_cache.query(
- req.sampling_params.regex
- )
+ req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+ ("regex", req.sampling_params.regex)
+ )
+ if not self.disable_regex_jump_forward:
+ req.jump_forward_map = self.jump_forward_cache.query(
+ computed_regex_string
+ )
 
  # Truncate prompts that are too long
  if len(req.origin_input_ids) >= self.max_req_input_len:
- logger.warn(
+ logger.warning(
  "Request length is longer than the KV cache pool size or "
  "the max context length. Truncated!!!"
  )
  req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+ req.sampling_params.max_new_tokens = min(
+ (
+ req.sampling_params.max_new_tokens
+ if req.sampling_params.max_new_tokens is not None
+ else 1 << 30
+ ),
+ self.max_req_input_len - 1 - len(req.origin_input_ids),
+ )
 
- if self.model_runner.is_generation:
- req.sampling_params.max_new_tokens = min(
- (
- req.sampling_params.max_new_tokens
- if req.sampling_params.max_new_tokens is not None
- else 1 << 30
- ),
- self.max_req_input_len - 1 - len(req.origin_input_ids),
+ self.waiting_queue.append(req)
+
+ def handle_embedding_request(
+ self,
+ recv_req: TokenizedEmbeddingReqInput,
+ ):
+ req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+ req.tokenizer = self.tokenizer
+ req.sampling_params = recv_req.sampling_params
+
+ # Truncate prompts that are too long
+ if len(req.origin_input_ids) >= self.max_req_input_len:
+ logger.warn(
+ "Request length is longer than the KV cache pool size or "
+ "the max context length. Truncated!!!"
  )
+ req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
 
  self.waiting_queue.append(req)
 
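Two behaviors in the hunk above are worth calling out. A logprob_start_len of -1 now defaults to len(input_ids) - 1, so only output-token logprobs are returned. And image inputs continue to be represented by pad values derived from an image hash, so the radix cache can prefix-match identical images. A self-contained sketch of that pad-value scheme (the vocab size and the hashed payload are placeholders):

    # Illustrative sketch of the image-hash pad values used for prefix matching.
    vocab_size = 32000  # placeholder; the real value comes from the model config
    image_hash = hash(("example-image-hash",))
    pad_value = [
        (image_hash) % vocab_size,
        (image_hash >> 16) % vocab_size,
        (image_hash >> 32) % vocab_size,
        (image_hash >> 64) % vocab_size,
    ]
    # These four pseudo-token ids stand in for the image in origin_input_ids,
    # so two requests with the same image share a cacheable prefix.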
@@ -409,6 +436,8 @@ class ModelTpServer:
 
  adder = PrefillAdder(
  self.tree_cache,
+ self.running_batch,
+ self.new_token_ratio,
  self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
  self.max_prefill_tokens,
  self.chunked_prefill_size,
@@ -416,7 +445,7 @@ class ModelTpServer:
  )
 
  if self.running_batch is not None:
- adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
+ adder.remove_running_tokens(self.running_batch)
 
  has_inflight = self.current_inflight_req is not None
  if self.current_inflight_req is not None:
@@ -427,12 +456,30 @@ class ModelTpServer:
  self.current_inflight_req
  )
 
+ if self.lora_paths is not None:
+ lora_set = (
+ set([req.lora_path for req in self.running_batch.reqs])
+ if self.running_batch is not None
+ else set([])
+ )
+
  for req in self.waiting_queue:
+ if adder.no_remaining_tokens():
+ break
  req.init_next_round_input(None if prefix_computed else self.tree_cache)
+ if (
+ self.lora_paths is not None
+ and len(
+ lora_set
+ | set([req.lora_path for req in adder.can_run_list])
+ | set([req.lora_path])
+ )
+ > self.max_loras_per_batch
+ ):
+ break
  res = adder.add_one_req(req)
  if (
  not res
- or adder.no_remaining_tokens()
  or running_bs + len(adder.can_run_list) >= self.max_running_requests
  ):
  break
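
The scheduling loop above also gains a LoRA-aware admission check: a waiting request is only added to the prefill batch if the number of distinct adapter paths across the running batch, the requests already admitted this round, and the candidate itself stays within max_loras_per_batch. A simplified, standalone version of that constraint (function and variable names are illustrative):

    # Simplified illustration of the LoRA admission constraint shown above.
    def can_admit(candidate_lora_path, running_lora_paths, admitted_lora_paths, max_loras_per_batch):
        distinct = set(running_lora_paths) | set(admitted_lora_paths) | {candidate_lora_path}
        return len(distinct) <= max_loras_per_batch

    # Two adapters already in flight, a third arrives, cap of 2 -> not admitted this round.
    print(can_admit("lora_c", ["lora_a"], ["lora_b"], max_loras_per_batch=2))  # False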
@@ -504,10 +551,9 @@ class ModelTpServer:
  if self.model_runner.is_generation:
  # Forward and sample the next tokens
  if batch.extend_num_tokens != 0:
- sample_output, logits_output = self.model_runner.forward(
- batch, ForwardMode.EXTEND
- )
- next_token_ids = batch.check_sample_results(sample_output)
+ logits_output = self.model_runner.forward(batch)
+ next_token_ids = self.model_runner.sample(logits_output, batch)
+
  batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
  next_token_ids
  )
@@ -541,7 +587,7 @@ class ModelTpServer:
  next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
 
  # Check finish conditions
- pt = 0
+ logprob_pt = 0
  for i, req in enumerate(batch.reqs):
  if req is not self.current_inflight_req:
  # Inflight reqs' prefill is not finished
@@ -565,13 +611,12 @@ class ModelTpServer:
  self.req_to_token_pool.free(req.req_pool_idx)
 
  if req.return_logprob:
- self.add_logprob_return_values(
- i, req, pt, next_token_ids, logits_output
+ logprob_pt += self.add_logprob_return_values(
+ i, req, logprob_pt, next_token_ids, logits_output
  )
- pt += req.extend_input_len
  else:
  assert batch.extend_num_tokens != 0
- logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+ logits_output = self.model_runner.forward(batch)
  embeddings = logits_output.embeddings.tolist()
 
  # Check finish conditions
@@ -596,48 +641,63 @@ class ModelTpServer:
 
  def add_logprob_return_values(
  self,
- i,
+ i: int,
  req: Req,
  pt: int,
  next_token_ids: List[int],
  output: LogitsProcessorOutput,
  ):
+ """Attach logprobs to the return values."""
+ req.output_token_logprobs.append(
+ (output.next_token_logprobs[i], next_token_ids[i])
+ )
+
+ # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+ num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
+
  if req.normalized_prompt_logprob is None:
  req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
 
  if req.input_token_logprobs is None:
- # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
- req.input_token_logprobs = list(
- zip(
- output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
- req.fill_ids[-req.extend_input_len + 1 :],
- )
- )
- if req.logprob_start_len == 0:
+ input_token_logprobs = output.input_token_logprobs[
+ pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
+ ]
+ input_token_ids = req.fill_ids[
+ len(req.fill_ids)
+ - num_input_logprobs
+ + 1 : len(req.fill_ids)
+ - req.last_update_decode_tokens
+ ]
+ req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
+
+ if (
+ req.logprob_start_len == 0
+ ):  # The first token does not have logprob, pad it.
  req.input_token_logprobs = [
  (None, req.fill_ids[0])
  ] + req.input_token_logprobs
 
  if req.last_update_decode_tokens != 0:
+ # Some decode tokens are re-computed in an extend batch
  req.output_token_logprobs.extend(
  list(
  zip(
  output.input_token_logprobs[
  pt
- + req.extend_input_len
+ + num_input_logprobs
+ - 1
  - req.last_update_decode_tokens : pt
- + req.extend_input_len
+ + num_input_logprobs
  - 1
  ],
- req.fill_ids[-req.last_update_decode_tokens + 1 :],
+ req.fill_ids[
+ len(req.fill_ids)
+ - req.last_update_decode_tokens : len(req.fill_ids)
+ ],
  )
  )
  )
 
- req.output_token_logprobs.append(
- (output.next_token_logprobs[i], next_token_ids[i])
- )
-
  if req.top_logprobs_num > 0:
  if req.input_top_logprobs is None:
  req.input_top_logprobs = output.input_top_logprobs[i]
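
The rewritten add_logprob_return_values above keys all slicing off num_input_logprobs = extend_input_len - extend_logprob_start_len and returns that count, so the caller advances the shared read pointer with logprob_pt += ... instead of adding extend_input_len. A toy walkthrough of that bookkeeping (all numbers invented):

    # Invented numbers, only to illustrate the pointer arithmetic described above.
    extend_input_len = 8           # tokens computed for this request in the extend batch
    extend_logprob_start_len = 5   # prompt tokens whose logprobs are skipped
    num_input_logprobs = extend_input_len - extend_logprob_start_len  # 3 logprobs kept

    logprob_pt = 0                     # read pointer into the batch-wide logprob buffer
    logprob_pt += num_input_logprobs   # next request's slice starts here (was += extend_input_len before)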
@@ -646,10 +706,12 @@ class ModelTpServer:
 
  if req.last_update_decode_tokens != 0:
  req.output_top_logprobs.extend(
- output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+ output.input_top_logprobs[i][-req.last_update_decode_tokens :]
  )
  req.output_top_logprobs.append(output.output_top_logprobs[i])
 
+ return num_input_logprobs
+
  def forward_decode_batch(self, batch: ScheduleBatch):
  # Check if decode out of memory
  if not batch.check_decode_mem():
@@ -682,10 +744,8 @@ class ModelTpServer:
  batch.prepare_for_decode()
 
  # Forward and sample the next tokens
- sample_output, logits_output = self.model_runner.forward(
- batch, ForwardMode.DECODE
- )
- next_token_ids = batch.check_sample_results(sample_output)
+ logits_output = self.model_runner.forward(batch)
+ next_token_ids = self.model_runner.sample(logits_output, batch)
  batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
  next_token_ids
  )
@@ -700,6 +760,7 @@ class ModelTpServer:
  next_token_ids = next_token_ids.tolist()
 
  # Check finish condition
+ has_finished = False
  for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
  req.completion_tokens_wo_jump_forward += 1
  req.output_ids.append(next_token_id)
@@ -712,6 +773,7 @@ class ModelTpServer:
 
  if req.finished():
  self.tree_cache.cache_finished_req(req)
+ has_finished = True
 
  if req.return_logprob:
  req.output_token_logprobs.append(
@@ -720,6 +782,9 @@ class ModelTpServer:
  if req.top_logprobs_num > 0:
  req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+ if not has_finished:
+ self.do_not_get_new_batch = True
+
  self.handle_finished_requests(batch)
 
  def handle_finished_requests(self, batch: ScheduleBatch):
@@ -769,7 +834,11 @@ class ModelTpServer:
  "prompt_tokens": len(req.origin_input_ids),
  "completion_tokens": len(req.output_ids),
  "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
- "finish_reason": str(req.finished_reason),
+ "finish_reason": (
+ req.finished_reason.to_json()
+ if req.finished_reason is not None
+ else None
+ ),
  }
  if req.return_logprob:
  (
@@ -876,7 +945,6 @@ def run_tp_server(
  tp_rank: int,
  server_args: ServerArgs,
  nccl_port: int,
- model_override_args: dict,
  ):
  """Run a tensor parallel model server."""
  configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -887,7 +955,6 @@ def run_tp_server(
  tp_rank,
  server_args,
  nccl_port,
- model_override_args,
  )
  tp_cpu_group = model_server.model_runner.tp_group.cpu_group
 
@@ -904,14 +971,13 @@ def launch_tp_servers(
  tp_rank_range: List[int],
  server_args: ServerArgs,
  nccl_port: int,
- model_override_args: dict,
  ):
  """Launch multiple tensor parallel servers."""
  procs = []
  for i in tp_rank_range:
  proc = multiprocessing.Process(
  target=run_tp_server,
- args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
+ args=(gpu_ids[i], i, server_args, nccl_port),
  )
  proc.start()
  procs.append(proc)