sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_latency.py +17 -8
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +5 -17
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -4
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +33 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/activation.py +12 -0
  15. sglang/srt/layers/attention_backend.py +480 -0
  16. sglang/srt/layers/flashinfer_utils.py +235 -0
  17. sglang/srt/layers/fused_moe/layer.py +27 -7
  18. sglang/srt/layers/layernorm.py +12 -0
  19. sglang/srt/layers/logits_processor.py +64 -77
  20. sglang/srt/layers/radix_attention.py +11 -161
  21. sglang/srt/layers/sampler.py +38 -122
  22. sglang/srt/layers/torchao_utils.py +75 -0
  23. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  24. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  25. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  26. sglang/srt/lora/lora.py +403 -0
  27. sglang/srt/lora/lora_config.py +43 -0
  28. sglang/srt/lora/lora_manager.py +259 -0
  29. sglang/srt/managers/controller_multi.py +1 -5
  30. sglang/srt/managers/controller_single.py +0 -5
  31. sglang/srt/managers/io_struct.py +16 -1
  32. sglang/srt/managers/policy_scheduler.py +122 -5
  33. sglang/srt/managers/schedule_batch.py +105 -71
  34. sglang/srt/managers/tokenizer_manager.py +17 -8
  35. sglang/srt/managers/tp_worker.py +188 -121
  36. sglang/srt/model_executor/cuda_graph_runner.py +69 -133
  37. sglang/srt/model_executor/forward_batch_info.py +35 -312
  38. sglang/srt/model_executor/model_runner.py +123 -154
  39. sglang/srt/models/baichuan.py +416 -0
  40. sglang/srt/models/chatglm.py +1 -5
  41. sglang/srt/models/commandr.py +1 -5
  42. sglang/srt/models/dbrx.py +1 -5
  43. sglang/srt/models/deepseek.py +1 -5
  44. sglang/srt/models/deepseek_v2.py +7 -6
  45. sglang/srt/models/exaone.py +1 -5
  46. sglang/srt/models/gemma.py +1 -5
  47. sglang/srt/models/gemma2.py +1 -5
  48. sglang/srt/models/gpt_bigcode.py +1 -5
  49. sglang/srt/models/grok.py +1 -5
  50. sglang/srt/models/internlm2.py +1 -5
  51. sglang/srt/models/llama.py +51 -5
  52. sglang/srt/models/llama_classification.py +1 -20
  53. sglang/srt/models/llava.py +30 -5
  54. sglang/srt/models/llavavid.py +2 -2
  55. sglang/srt/models/minicpm.py +1 -5
  56. sglang/srt/models/minicpm3.py +669 -0
  57. sglang/srt/models/mixtral.py +6 -5
  58. sglang/srt/models/mixtral_quant.py +1 -5
  59. sglang/srt/models/olmoe.py +415 -0
  60. sglang/srt/models/qwen.py +1 -5
  61. sglang/srt/models/qwen2.py +1 -5
  62. sglang/srt/models/qwen2_moe.py +6 -5
  63. sglang/srt/models/stablelm.py +1 -5
  64. sglang/srt/models/xverse.py +375 -0
  65. sglang/srt/models/xverse_moe.py +445 -0
  66. sglang/srt/openai_api/adapter.py +65 -46
  67. sglang/srt/openai_api/protocol.py +11 -3
  68. sglang/srt/sampling/sampling_batch_info.py +46 -80
  69. sglang/srt/server.py +30 -15
  70. sglang/srt/server_args.py +163 -28
  71. sglang/srt/utils.py +19 -51
  72. sglang/test/few_shot_gsm8k.py +132 -0
  73. sglang/test/runners.py +114 -22
  74. sglang/test/test_programs.py +7 -5
  75. sglang/test/test_utils.py +85 -2
  76. sglang/utils.py +32 -37
  77. sglang/version.py +1 -1
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
  79. sglang-0.3.1.post1.dist-info/RECORD +130 -0
  80. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
  81. sglang-0.3.0.dist-info/RECORD +0 -118
  82. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
  83. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
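The most visible interface change in the tp_worker.py hunks below is that model override arguments are no longer threaded through as a separate model_override_args dict; ModelTpServer now reads a JSON string from server_args.json_model_override_args and parses it with json.loads. A minimal caller-side sketch, assuming only the ServerArgs field name shown in the hunk (the model path and the override key used here are illustrative and not taken from this diff):

    import json

    from sglang.srt.server_args import ServerArgs

    # Hypothetical example: pass HF-config overrides as a JSON string via the
    # field seen in the tp_worker.py hunk below. The override key is illustrative.
    server_args = ServerArgs(
        model_path="meta-llama/Meta-Llama-3-8B-Instruct",
        json_model_override_args=json.dumps({"max_position_embeddings": 8192}),
    )

    # Mirrors what ModelTpServer now does internally when building ModelConfig:
    overrides = json.loads(server_args.json_model_override_args)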
sglang/srt/managers/tp_worker.py
@@ -15,19 +15,21 @@ limitations under the License.
 
  """A tensor parallel worker."""
 
+ import json
  import logging
  import multiprocessing
  import os
  import pickle
  import time
  import warnings
- from typing import Any, List, Optional, Union
+ from typing import Any, List, Optional
 
  import torch
  import torch.distributed
  import torch.distributed as dist
 
  from sglang.global_config import global_config
+ from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.constrained.fsm_cache import FSMCache
  from sglang.srt.constrained.jump_forward import JumpForwardCache
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -51,8 +53,6 @@ from sglang.srt.managers.schedule_batch import (
  )
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
  from sglang.srt.mem_cache.radix_cache import RadixCache
- from sglang.srt.model_config import ModelConfig
- from sglang.srt.model_executor.forward_batch_info import ForwardMode
  from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
@@ -66,6 +66,7 @@ from sglang.utils import get_exception_traceback
  logger = logging.getLogger(__name__)
 
 
+ # Crash on warning if we are running CI tests
  crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
 
 
@@ -76,26 +77,26 @@ class ModelTpServer:
  tp_rank: int,
  server_args: ServerArgs,
  nccl_port: int,
- model_override_args: dict,
  ):
  suppress_other_loggers()
 
- # Copy arguments
+ # Parse arguments
  self.gpu_id = gpu_id
  self.tp_rank = tp_rank
  self.tp_size = server_args.tp_size
  self.dp_size = server_args.dp_size
  self.schedule_policy = server_args.schedule_policy
  self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+ self.lora_paths = server_args.lora_paths
+ self.max_loras_per_batch = server_args.max_loras_per_batch
 
  # Init model and tokenizer
  self.model_config = ModelConfig(
  server_args.model_path,
  server_args.trust_remote_code,
  context_length=server_args.context_length,
- model_override_args=model_override_args,
+ model_override_args=json.loads(server_args.json_model_override_args),
  )
-
  self.model_runner = ModelRunner(
  model_config=self.model_config,
  mem_fraction_static=server_args.mem_fraction_static,
@@ -129,14 +130,14 @@ class ModelTpServer:
  if server_args.max_running_requests is None
  else server_args.max_running_requests
  ),
- self.model_runner.req_to_token_pool.size - 1,
+ self.model_runner.req_to_token_pool.size,
  )
  self.max_req_input_len = min(
  self.model_config.context_len - 1,
  self.max_total_num_tokens - 1,
  )
 
- # Sync random seed
+ # Sync random seed across TP workers
  server_args.random_seed = broadcast_recv_input(
  [server_args.random_seed],
  self.tp_rank,
@@ -144,7 +145,7 @@ class ModelTpServer:
  )[0]
  set_random_seed(server_args.random_seed)
 
- # Print info
+ # Print debug info
  logger.info(
  f"max_total_num_tokens={self.max_total_num_tokens}, "
  f"max_prefill_tokens={self.max_prefill_tokens}, "
@@ -181,7 +182,7 @@ class ModelTpServer:
  self.num_generated_tokens = 0
  self.last_stats_tic = time.time()
 
- # Chunked prefill
+ # Init chunked prefill
  self.chunked_prefill_size = server_args.chunked_prefill_size
  self.current_inflight_req = None
  self.is_mixed_chunk = (
@@ -197,16 +198,7 @@ class ModelTpServer:
  "trust_remote_code": server_args.trust_remote_code,
  },
  skip_tokenizer_init=server_args.skip_tokenizer_init,
- json_schema_mode=False,
- )
- self.json_fsm_cache = FSMCache(
- server_args.tokenizer_path,
- {
- "tokenizer_mode": server_args.tokenizer_mode,
- "trust_remote_code": server_args.trust_remote_code,
- },
- skip_tokenizer_init=server_args.skip_tokenizer_init,
- json_schema_mode=True,
+ constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
  )
  self.jump_forward_cache = JumpForwardCache()
 
@@ -221,15 +213,18 @@ class ModelTpServer:
  )
  self.new_token_ratio = self.min_new_token_ratio
  self.new_token_ratio_decay = global_config.new_token_ratio_decay
+ self.do_not_get_new_batch = False
 
  def exposed_step(self, recv_reqs: List):
  try:
  # Recv requests
  for recv_req in recv_reqs:
- if isinstance(
- recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
- ):
+ if isinstance(recv_req, TokenizedGenerateReqInput):
  self.handle_generate_request(recv_req)
+ self.do_not_get_new_batch = False
+ elif isinstance(recv_req, TokenizedEmbeddingReqInput):
+ self.handle_embedding_request(recv_req)
+ self.do_not_get_new_batch = False
  elif isinstance(recv_req, FlushCacheReq):
  self.flush_cache()
  elif isinstance(recv_req, AbortReq):
@@ -253,7 +248,11 @@ class ModelTpServer:
 
  @torch.inference_mode()
  def forward_step(self):
- new_batch = self.get_new_prefill_batch()
+ if self.do_not_get_new_batch and self.current_inflight_req is None:
+ new_batch = None
+ else:
+ new_batch = self.get_new_prefill_batch()
+ self.do_not_get_new_batch = False
 
  if new_batch is not None:
  # Run a new prefill batch
@@ -280,7 +279,7 @@ class ModelTpServer:
  self.running_batch = None
  break
 
- if self.out_pyobjs and self.running_batch.has_stream():
+ if self.out_pyobjs and self.running_batch.has_stream:
  break
  else:
  self.check_memory()
@@ -325,73 +324,102 @@ class ModelTpServer:
 
  def handle_generate_request(
  self,
- recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
+ recv_req: TokenizedGenerateReqInput,
  ):
- req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+ if isinstance(recv_req, TokenizedGenerateReqInput):
+ req = Req(
+ recv_req.rid,
+ recv_req.input_text,
+ recv_req.input_ids,
+ lora_path=recv_req.lora_path,
+ )
+ else:
+ req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
  req.tokenizer = self.tokenizer
  req.sampling_params = recv_req.sampling_params
- if self.model_runner.is_generation:
- req.pixel_values = recv_req.pixel_values
- if req.pixel_values is not None:
- # Use image hash as fake token_ids, which is then used
- # for prefix matching
- image_hash = hash(tuple(recv_req.image_hashes))
- req.pad_value = [
- (image_hash) % self.model_config.vocab_size,
- (image_hash >> 16) % self.model_config.vocab_size,
- (image_hash >> 32) % self.model_config.vocab_size,
- (image_hash >> 64) % self.model_config.vocab_size,
- ]
- req.image_sizes = recv_req.image_sizes
- (
- req.origin_input_ids,
- req.image_offsets,
- ) = self.model_runner.model.pad_input_ids(
- req.origin_input_ids_unpadded,
- req.pad_value,
- req.pixel_values,
- req.image_sizes,
- )
- req.return_logprob = recv_req.return_logprob
- req.logprob_start_len = recv_req.logprob_start_len
- req.top_logprobs_num = recv_req.top_logprobs_num
- req.stream = recv_req.stream
-
- # Init regex fsm fron json
+ req.pixel_values = recv_req.pixel_values
+ if req.pixel_values is not None:
+ # Use image hash as fake token_ids, which is then used
+ # for prefix matching
+ image_hash = hash(tuple(recv_req.image_hashes))
+ req.pad_value = [
+ (image_hash) % self.model_config.vocab_size,
+ (image_hash >> 16) % self.model_config.vocab_size,
+ (image_hash >> 32) % self.model_config.vocab_size,
+ (image_hash >> 64) % self.model_config.vocab_size,
+ ]
+ req.image_sizes = recv_req.image_sizes
+ (
+ req.origin_input_ids,
+ req.image_offsets,
+ ) = self.model_runner.model.pad_input_ids(
+ req.origin_input_ids_unpadded,
+ req.pad_value,
+ req.pixel_values,
+ req.image_sizes,
+ )
+ # Only when pixel values is not None we have modalities
+ req.modalities = recv_req.modalites
+ req.return_logprob = recv_req.return_logprob
+ req.top_logprobs_num = recv_req.top_logprobs_num
+ req.stream = recv_req.stream
+ req.logprob_start_len = recv_req.logprob_start_len
+
+ if req.logprob_start_len == -1:
+ # By default, only return the logprobs for output tokens
+ req.logprob_start_len = len(recv_req.input_ids) - 1
+
+ # Init regex FSM
+ if (
+ req.sampling_params.json_schema is not None
+ or req.sampling_params.regex is not None
+ ):
  if req.sampling_params.json_schema is not None:
- req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
- req.sampling_params.json_schema
+ req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+ ("json", req.sampling_params.json_schema)
  )
- if not self.disable_regex_jump_forward:
- req.jump_forward_map = self.jump_forward_cache.query(
- computed_regex_string
- )
-
- # Init regex fsm
  elif req.sampling_params.regex is not None:
- req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
- if not self.disable_regex_jump_forward:
- req.jump_forward_map = self.jump_forward_cache.query(
- req.sampling_params.regex
- )
+ req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query(
+ ("regex", req.sampling_params.regex)
+ )
+ if not self.disable_regex_jump_forward:
+ req.jump_forward_map = self.jump_forward_cache.query(
+ computed_regex_string
+ )
 
  # Truncate prompts that are too long
  if len(req.origin_input_ids) >= self.max_req_input_len:
- logger.warn(
+ logger.warning(
  "Request length is longer than the KV cache pool size or "
  "the max context length. Truncated!!!"
  )
  req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
+ req.sampling_params.max_new_tokens = min(
+ (
+ req.sampling_params.max_new_tokens
+ if req.sampling_params.max_new_tokens is not None
+ else 1 << 30
+ ),
+ self.max_req_input_len - 1 - len(req.origin_input_ids),
+ )
 
- if self.model_runner.is_generation:
- req.sampling_params.max_new_tokens = min(
- (
- req.sampling_params.max_new_tokens
- if req.sampling_params.max_new_tokens is not None
- else 1 << 30
- ),
- self.max_req_input_len - 1 - len(req.origin_input_ids),
+ self.waiting_queue.append(req)
+
+ def handle_embedding_request(
+ self,
+ recv_req: TokenizedEmbeddingReqInput,
+ ):
+ req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+ req.tokenizer = self.tokenizer
+ req.sampling_params = recv_req.sampling_params
+
+ # Truncate prompts that are too long
+ if len(req.origin_input_ids) >= self.max_req_input_len:
+ logger.warning(
+ "Request length is longer than the KV cache pool size or "
+ "the max context length. Truncated!!!"
  )
+ req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
 
  self.waiting_queue.append(req)
 
@@ -409,6 +437,8 @@ class ModelTpServer:
 
  adder = PrefillAdder(
  self.tree_cache,
+ self.running_batch,
+ self.new_token_ratio,
  self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
  self.max_prefill_tokens,
  self.chunked_prefill_size,
@@ -416,7 +446,7 @@ class ModelTpServer:
  )
 
  if self.running_batch is not None:
- adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
+ adder.remove_running_tokens(self.running_batch)
 
  has_inflight = self.current_inflight_req is not None
  if self.current_inflight_req is not None:
@@ -427,12 +457,30 @@ class ModelTpServer:
  self.current_inflight_req
  )
 
+ if self.lora_paths is not None:
+ lora_set = (
+ set([req.lora_path for req in self.running_batch.reqs])
+ if self.running_batch is not None
+ else set([])
+ )
+
  for req in self.waiting_queue:
+ if adder.no_remaining_tokens():
+ break
  req.init_next_round_input(None if prefix_computed else self.tree_cache)
+ if (
+ self.lora_paths is not None
+ and len(
+ lora_set
+ | set([req.lora_path for req in adder.can_run_list])
+ | set([req.lora_path])
+ )
+ > self.max_loras_per_batch
+ ):
+ break
  res = adder.add_one_req(req)
  if (
  not res
- or adder.no_remaining_tokens()
  or running_bs + len(adder.can_run_list) >= self.max_running_requests
  ):
  break
@@ -504,10 +552,9 @@ class ModelTpServer:
  if self.model_runner.is_generation:
  # Forward and sample the next tokens
  if batch.extend_num_tokens != 0:
- sample_output, logits_output = self.model_runner.forward(
- batch, ForwardMode.EXTEND
- )
- next_token_ids = batch.check_sample_results(sample_output)
+ logits_output = self.model_runner.forward(batch)
+ next_token_ids = self.model_runner.sample(logits_output, batch)
+
  batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
  next_token_ids
  )
@@ -541,7 +588,7 @@ class ModelTpServer:
  next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
 
  # Check finish conditions
- pt = 0
+ logprob_pt = 0
  for i, req in enumerate(batch.reqs):
  if req is not self.current_inflight_req:
  # Inflight reqs' prefill is not finished
@@ -565,13 +612,12 @@ class ModelTpServer:
  self.req_to_token_pool.free(req.req_pool_idx)
 
  if req.return_logprob:
- self.add_logprob_return_values(
- i, req, pt, next_token_ids, logits_output
+ logprob_pt += self.add_logprob_return_values(
+ i, req, logprob_pt, next_token_ids, logits_output
  )
- pt += req.extend_input_len
  else:
  assert batch.extend_num_tokens != 0
- logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+ logits_output = self.model_runner.forward(batch)
  embeddings = logits_output.embeddings.tolist()
 
  # Check finish conditions
@@ -596,48 +642,63 @@ class ModelTpServer:
 
  def add_logprob_return_values(
  self,
- i,
+ i: int,
  req: Req,
  pt: int,
  next_token_ids: List[int],
  output: LogitsProcessorOutput,
  ):
+ """Attach logprobs to the return values."""
+ req.output_token_logprobs.append(
+ (output.next_token_logprobs[i], next_token_ids[i])
+ )
+
+ # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+ num_input_logprobs = req.extend_input_len - req.extend_logprob_start_len
+
  if req.normalized_prompt_logprob is None:
  req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
 
  if req.input_token_logprobs is None:
- # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
- req.input_token_logprobs = list(
- zip(
- output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
- req.fill_ids[-req.extend_input_len + 1 :],
- )
- )
- if req.logprob_start_len == 0:
+ input_token_logprobs = output.input_token_logprobs[
+ pt : pt + num_input_logprobs - 1 - req.last_update_decode_tokens
+ ]
+ input_token_ids = req.fill_ids[
+ len(req.fill_ids)
+ - num_input_logprobs
+ + 1 : len(req.fill_ids)
+ - req.last_update_decode_tokens
+ ]
+ req.input_token_logprobs = list(zip(input_token_logprobs, input_token_ids))
+
+ if (
+ req.logprob_start_len == 0
+ ): # The first token does not have logprob, pad it.
  req.input_token_logprobs = [
  (None, req.fill_ids[0])
  ] + req.input_token_logprobs
 
  if req.last_update_decode_tokens != 0:
+ # Some decode tokens are re-computed in an extend batch
  req.output_token_logprobs.extend(
  list(
  zip(
  output.input_token_logprobs[
  pt
- + req.extend_input_len
+ + num_input_logprobs
+ - 1
  - req.last_update_decode_tokens : pt
- + req.extend_input_len
+ + num_input_logprobs
  - 1
  ],
- req.fill_ids[-req.last_update_decode_tokens + 1 :],
+ req.fill_ids[
+ len(req.fill_ids)
+ - req.last_update_decode_tokens : len(req.fill_ids)
+ ],
  )
  )
  )
 
- req.output_token_logprobs.append(
- (output.next_token_logprobs[i], next_token_ids[i])
- )
-
  if req.top_logprobs_num > 0:
  if req.input_top_logprobs is None:
  req.input_top_logprobs = output.input_top_logprobs[i]
@@ -646,10 +707,12 @@ class ModelTpServer:
 
  if req.last_update_decode_tokens != 0:
  req.output_top_logprobs.extend(
- output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+ output.input_top_logprobs[i][-req.last_update_decode_tokens :]
  )
  req.output_top_logprobs.append(output.output_top_logprobs[i])
 
+ return num_input_logprobs
+
  def forward_decode_batch(self, batch: ScheduleBatch):
  # Check if decode out of memory
  if not batch.check_decode_mem():
@@ -682,10 +745,8 @@ class ModelTpServer:
  batch.prepare_for_decode()
 
  # Forward and sample the next tokens
- sample_output, logits_output = self.model_runner.forward(
- batch, ForwardMode.DECODE
- )
- next_token_ids = batch.check_sample_results(sample_output)
+ logits_output = self.model_runner.forward(batch)
+ next_token_ids = self.model_runner.sample(logits_output, batch)
  batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
  next_token_ids
  )
@@ -700,6 +761,7 @@ class ModelTpServer:
  next_token_ids = next_token_ids.tolist()
 
  # Check finish condition
+ has_finished = False
  for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
  req.completion_tokens_wo_jump_forward += 1
  req.output_ids.append(next_token_id)
@@ -712,6 +774,7 @@ class ModelTpServer:
 
  if req.finished():
  self.tree_cache.cache_finished_req(req)
+ has_finished = True
 
  if req.return_logprob:
  req.output_token_logprobs.append(
@@ -720,6 +783,9 @@ class ModelTpServer:
  if req.top_logprobs_num > 0:
  req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
 
+ if not has_finished:
+ self.do_not_get_new_batch = True
+
  self.handle_finished_requests(batch)
 
  def handle_finished_requests(self, batch: ScheduleBatch):
@@ -742,12 +808,10 @@ class ModelTpServer:
  unfinished_indices.append(i)
 
  if req.finished() or (
- (
- req.stream
- and (
- self.decode_forward_ct % self.stream_interval == 0
- or len(req.output_ids) == 1
- )
+ req.stream
+ and (
+ self.decode_forward_ct % self.stream_interval == 0
+ or len(req.output_ids) == 1
  )
  ):
  output_rids.append(req.rid)
@@ -769,7 +833,11 @@ class ModelTpServer:
  "prompt_tokens": len(req.origin_input_ids),
  "completion_tokens": len(req.output_ids),
  "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
- "finish_reason": str(req.finished_reason),
+ "finish_reason": (
+ req.finished_reason.to_json()
+ if req.finished_reason is not None
+ else None
+ ),
  }
  if req.return_logprob:
  (
@@ -868,6 +936,8 @@ class ModelTpServer:
  if success:
  flash_cache_success = self.flush_cache()
  assert flash_cache_success, "Cache flush failed after updating weights"
+ else:
+ logger.error(message)
  return success, message
 
 
@@ -876,7 +946,6 @@ def run_tp_server(
  tp_rank: int,
  server_args: ServerArgs,
  nccl_port: int,
- model_override_args: dict,
  ):
  """Run a tensor parallel model server."""
  configure_logger(server_args, prefix=f" TP{tp_rank}")
@@ -887,7 +956,6 @@ def run_tp_server(
  tp_rank,
  server_args,
  nccl_port,
- model_override_args,
  )
  tp_cpu_group = model_server.model_runner.tp_group.cpu_group
 
@@ -904,14 +972,13 @@ def launch_tp_servers(
  tp_rank_range: List[int],
  server_args: ServerArgs,
  nccl_port: int,
- model_override_args: dict,
  ):
  """Launch multiple tensor parallel servers."""
  procs = []
  for i in tp_rank_range:
  proc = multiprocessing.Process(
  target=run_tp_server,
- args=(gpu_ids[i], i, server_args, nccl_port, model_override_args),
+ args=(gpu_ids[i], i, server_args, nccl_port),
  )
  proc.start()
  procs.append(proc)