sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +30 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/backend/runtime_endpoint.py +18 -14
  6. sglang/bench_latency.py +317 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +41 -6
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +6 -2
  11. sglang/lang/ir.py +74 -28
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +14 -6
  15. sglang/srt/constrained/fsm_cache.py +6 -3
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +2 -0
  19. sglang/srt/hf_transformers_utils.py +68 -9
  20. sglang/srt/layers/extend_attention.py +2 -1
  21. sglang/srt/layers/fused_moe.py +280 -169
  22. sglang/srt/layers/logits_processor.py +106 -42
  23. sglang/srt/layers/radix_attention.py +53 -29
  24. sglang/srt/layers/token_attention.py +4 -1
  25. sglang/srt/managers/controller/dp_worker.py +6 -3
  26. sglang/srt/managers/controller/infer_batch.py +144 -69
  27. sglang/srt/managers/controller/manager_multi.py +5 -5
  28. sglang/srt/managers/controller/manager_single.py +9 -4
  29. sglang/srt/managers/controller/model_runner.py +167 -55
  30. sglang/srt/managers/controller/radix_cache.py +4 -0
  31. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  32. sglang/srt/managers/controller/tp_worker.py +156 -134
  33. sglang/srt/managers/detokenizer_manager.py +19 -21
  34. sglang/srt/managers/io_struct.py +11 -5
  35. sglang/srt/managers/tokenizer_manager.py +16 -14
  36. sglang/srt/model_config.py +89 -4
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +2 -2
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/gemma.py +5 -1
  41. sglang/srt/models/gemma2.py +436 -0
  42. sglang/srt/models/grok.py +204 -137
  43. sglang/srt/models/llama2.py +12 -5
  44. sglang/srt/models/llama_classification.py +107 -0
  45. sglang/srt/models/llava.py +11 -8
  46. sglang/srt/models/llavavid.py +1 -1
  47. sglang/srt/models/minicpm.py +373 -0
  48. sglang/srt/models/mixtral.py +164 -115
  49. sglang/srt/models/mixtral_quant.py +0 -1
  50. sglang/srt/models/qwen.py +1 -1
  51. sglang/srt/models/qwen2.py +1 -1
  52. sglang/srt/models/qwen2_moe.py +454 -0
  53. sglang/srt/models/stablelm.py +1 -1
  54. sglang/srt/models/yivl.py +2 -2
  55. sglang/srt/openai_api_adapter.py +35 -25
  56. sglang/srt/openai_protocol.py +2 -2
  57. sglang/srt/server.py +69 -19
  58. sglang/srt/server_args.py +76 -43
  59. sglang/srt/utils.py +177 -35
  60. sglang/test/test_programs.py +28 -10
  61. sglang/utils.py +4 -3
  62. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
  63. sglang-0.1.19.dist-info/RECORD +81 -0
  64. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
  65. sglang/srt/managers/router/infer_batch.py +0 -596
  66. sglang/srt/managers/router/manager.py +0 -82
  67. sglang/srt/managers/router/model_rpc.py +0 -818
  68. sglang/srt/managers/router/model_runner.py +0 -445
  69. sglang/srt/managers/router/radix_cache.py +0 -267
  70. sglang/srt/managers/router/scheduler.py +0 -59
  71. sglang-0.1.17.dist-info/RECORD +0 -81
  72. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  73. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0

sglang/srt/managers/controller/tp_worker.py
@@ -1,9 +1,11 @@
+"""A tensor parallel worker."""
+
 import asyncio
 import logging
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import List, Optional

 import rpyc
 import torch
@@ -13,23 +15,30 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.managers.controller.infer_batch import (
+    FINISH_ABORT,
+    BaseFinishReason,
+    Batch,
+    ForwardMode,
+    Req,
+)
+from sglang.srt.managers.controller.model_runner import ModelRunner
+from sglang.srt.managers.controller.radix_cache import RadixCache
+from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
-from sglang.srt.managers.controller.model_runner import ModelRunner
-from sglang.srt.managers.controller.radix_cache import RadixCache
-from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import ModelPortArgs, ServerArgs
 from sglang.srt.utils import (
+    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
-    start_rpyc_process,
+    start_rpyc_service_process,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -88,16 +97,16 @@ class ModelTpServer:
             trust_remote_code=server_args.trust_remote_code,
         )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
-        self.max_prefill_tokens = max(
-            self.model_config.context_len,
-            (
-                min(self.max_total_num_tokens // 6, 65536)
-                if server_args.max_prefill_tokens is None
-                else server_args.max_prefill_tokens
-            ),
+        self.max_prefill_tokens = (
+            4096
+            if server_args.max_prefill_tokens is None
+            else server_args.max_prefill_tokens
+        )
+        self.max_running_requests = (
+            self.max_total_num_tokens // 2
+            if server_args.max_running_requests is None
+            else server_args.max_running_requests
         )
-        self.max_running_requests = (self.max_total_num_tokens // 2
-            if server_args.max_running_requests is None else server_args.max_running_requests)
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
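The hunk above replaces the derived prefill-token budget with a flat default and adds a default cap on concurrently running requests. A minimal sketch of how the new defaults resolve, with purely illustrative values and hypothetical variable names (no CLI overrides passed):

# Illustrative only: max_total_num_tokens depends on the GPU and model.
max_total_num_tokens = 200_000
cli_max_prefill_tokens = None        # --max-prefill-tokens not given
cli_max_running_requests = None      # --max-running-requests not given

# 0.1.17 default was max(context_len, min(max_total_num_tokens // 6, 65536));
# 0.1.19 uses a flat 4096 unless overridden.
max_prefill_tokens = 4096 if cli_max_prefill_tokens is None else cli_max_prefill_tokens
max_running_requests = (
    max_total_num_tokens // 2
    if cli_max_running_requests is None
    else cli_max_running_requests
)
print(max_prefill_tokens, max_running_requests)  # 4096 100000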
@@ -108,7 +117,7 @@ class ModelTpServer:
             f"[gpu_id={self.gpu_id}] "
             f"max_total_num_tokens={self.max_total_num_tokens}, "
             f"max_prefill_tokens={self.max_prefill_tokens}, "
-            f"context_len={self.model_config.context_len}, "
+            f"context_len={self.model_config.context_len}"
         )
         if self.tp_rank == 0:
             logger.info(
@@ -242,7 +251,7 @@ class ModelTpServer:
                         self.running_batch = None
                         break

-                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                    if self.out_pyobjs and self.running_batch.has_stream():
                         break
             else:
                 # Check the available size
@@ -271,13 +280,14 @@ class ModelTpServer:
                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
             ]
             req.image_size = recv_req.image_size
-            req.origin_input_ids, req.image_offset = (
-                self.model_runner.model.pad_input_ids(
-                    req.origin_input_ids_unpadded,
-                    req.pad_value,
-                    req.pixel_values.shape,
-                    req.image_size,
-                )
+            (
+                req.origin_input_ids,
+                req.image_offset,
+            ) = self.model_runner.model.pad_input_ids(
+                req.origin_input_ids_unpadded,
+                req.pad_value,
+                req.pixel_values.shape,
+                req.image_size,
             )
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
@@ -303,7 +313,7 @@ class ModelTpServer:
         )
         self.forward_queue.append(req)

-    def get_new_fill_batch(self):
+    def get_new_fill_batch(self) -> Optional[Batch]:
         if (
             self.running_batch is not None
             and len(self.running_batch.reqs) > self.max_running_requests
@@ -312,10 +322,7 @@ class ModelTpServer:

         # Compute matched prefix length
         for req in self.forward_queue:
-            assert (
-                len(req.output_ids) == 0
-            ), "The output ids should be empty when prefilling"
-            req.input_ids = req.origin_input_ids + req.prev_output_ids
+            req.input_ids = req.origin_input_ids + req.output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]
@@ -361,8 +368,11 @@ class ModelTpServer:
             if (
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
-                and req.extend_input_len + new_batch_input_tokens
-                < self.max_prefill_tokens
+                and (
+                    req.extend_input_len + new_batch_input_tokens
+                    <= self.max_prefill_tokens
+                    or len(can_run_list) == 0
+                )
             ):
                 delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta
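The reworked condition above lets a request whose prompt alone exceeds max_prefill_tokens still be scheduled, provided it is the only request in the new prefill batch. A standalone sketch of that admission test, with hypothetical argument names rather than the server's attributes:

def can_admit(extend_len, max_new_tokens, batch_total_tokens, batch_input_tokens,
              available_size, max_prefill_tokens, num_admitted):
    # Fits in the KV cache together with the tokens it may still generate.
    fits_memory = extend_len + max_new_tokens + batch_total_tokens < available_size
    # Respects the prefill-token budget, or is the first (possibly oversized) request.
    fits_budget = (extend_len + batch_input_tokens <= max_prefill_tokens
                   or num_admitted == 0)
    return fits_memory and fits_budget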
@@ -401,7 +411,7 @@ class ModelTpServer:
                 self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
             )
             logger.info(
-                f"[gpu_id={self.gpu_id}] Prefil batch. "
+                f"[gpu_id={self.gpu_id}] Prefill batch. "
                 f"#new-seq: {len(can_run_list)}, "
                 f"#new-token: {new_batch_input_tokens}, "
                 f"#cached-token: {hit_tokens}, "
@@ -432,97 +442,93 @@ class ModelTpServer:
             self.model_config.vocab_size, self.int_token_logit_bias
         )

+        # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
-            # Forward
-            logits, (
-                prefill_token_logprobs,
-                normalized_prompt_logprobs,
-                prefill_top_logprobs,
-                decode_top_logprobs,
-                last_logprobs,
-            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            if prefill_token_logprobs is not None:
-                prefill_token_logprobs = prefill_token_logprobs.tolist()
-                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
-
-            next_token_ids, _ = batch.sample(logits)
-
-            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
-            if last_logprobs is not None:
-                last_token_logprobs = last_logprobs[
-                    torch.arange(len(batch.reqs), device=next_token_ids.device),
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            next_token_ids, _ = batch.sample(output.next_token_logits)
+
+            # Move logprobs to cpu
+            if output.next_token_logprobs is not None:
+                output.next_token_logprobs = output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=next_token_ids.device),
                     next_token_ids,
                 ].tolist()
+                output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+                output.normalized_prompt_logprobs = (
+                    output.normalized_prompt_logprobs.tolist()
+                )

             next_token_ids = next_token_ids.tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

-        # Check finish condition
+        # Check finish conditions
         pt = 0
         for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
-            req.output_ids = [next_token_ids[i]]
+            req.output_ids.append(next_token_ids[i])
             req.check_finished()

             if req.return_logprob:
-                if req.normalized_prompt_logprob is None:
-                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
-
-                if req.prefill_token_logprobs is None:
-                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-                    req.prefill_token_logprobs = list(
-                        zip(
-                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
-                            req.input_ids[-req.extend_input_len + 1 :],
-                        )
-                    )
-                    if req.logprob_start_len == 0:
-                        req.prefill_token_logprobs = [
-                            (None, req.input_ids[0])
-                        ] + req.prefill_token_logprobs
-
-                if req.last_update_decode_tokens != 0:
-                    req.decode_token_logprobs.extend(
-                        list(
-                            zip(
-                                prefill_token_logprobs[
-                                    pt
-                                    + req.extend_input_len
-                                    - req.last_update_decode_tokens : pt
-                                    + req.extend_input_len
-                                    - 1
-                                ],
-                                req.input_ids[-req.last_update_decode_tokens + 1 :],
-                            )
-                        )
-                    )
+                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+            pt += req.extend_input_len

-                req.decode_token_logprobs.append(
-                    (last_token_logprobs[i], next_token_ids[i])
-                )
+        self.handle_finished_requests(batch)

-                if req.top_logprobs_num > 0:
-                    if req.prefill_top_logprobs is None:
-                        req.prefill_top_logprobs = prefill_top_logprobs[i]
-                        if req.logprob_start_len == 0:
-                            req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+        if req.normalized_prompt_logprob is None:
+            req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

-                    if req.last_update_decode_tokens != 0:
-                        req.decode_top_logprobs.extend(
-                            prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+        if req.prefill_token_logprobs is None:
+            # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+            req.prefill_token_logprobs = list(
+                zip(
+                    output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                    req.input_ids[-req.extend_input_len + 1 :],
+                )
+            )
+            if req.logprob_start_len == 0:
+                req.prefill_token_logprobs = [
+                    (None, req.input_ids[0])
+                ] + req.prefill_token_logprobs
+
+        if req.last_update_decode_tokens != 0:
+            req.decode_token_logprobs.extend(
+                list(
+                    zip(
+                        output.prefill_token_logprobs[
+                            pt
+                            + req.extend_input_len
+                            - req.last_update_decode_tokens : pt
+                            + req.extend_input_len
+                            - 1
+                        ],
+                        req.input_ids[-req.last_update_decode_tokens + 1 :],
                     )
-                    req.decode_top_logprobs.append(decode_top_logprobs[i])
+                )
+            )

-            pt += req.extend_input_len
+        req.decode_token_logprobs.append(
+            (output.next_token_logprobs[i], next_token_ids[i])
+        )

-        self.handle_finished_requests(batch)
+        if req.top_logprobs_num > 0:
+            if req.prefill_top_logprobs is None:
+                req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+                if req.logprob_start_len == 0:
+                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+            if req.last_update_decode_tokens != 0:
+                req.decode_top_logprobs.extend(
+                    output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                )
+            req.decode_top_logprobs.append(output.decode_top_logprobs[i])

     def cache_filled_batch(self, batch: Batch):
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
         for i, req in enumerate(batch.reqs):
             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                 last_uncached_pos=len(req.prefix_indices),
                 req_pool_idx=req_pool_indices_cpu[i],
                 del_in_memory_pool=False,
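Both the extend and decode paths above index into the logprob tensor with the sampled token ids before calling .tolist(), so only one float per request is copied to the CPU instead of a full vocabulary row. A small self-contained illustration of that indexing pattern (the shapes and values are made up):

import torch

batch_size, vocab_size = 3, 8
logprobs = torch.randn(batch_size, vocab_size).log_softmax(dim=-1)
next_token_ids = torch.tensor([5, 0, 2])

# Advanced indexing picks one value per row; the result has shape [batch_size].
chosen = logprobs[torch.arange(batch_size), next_token_ids].tolist()
print(chosen)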
@@ -531,7 +537,7 @@ class ModelTpServer:
             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

     def forward_decode_batch(self, batch: Batch):
-        # check if decode out of memory
+        # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
             self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
@@ -550,9 +556,8 @@ class ModelTpServer:
             )

         if not self.disable_regex_jump_forward:
-            # check for jump-forward
+            # Check for jump-forward
             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-
             self.forward_queue.extend(jump_forward_reqs)
             if batch.is_empty():
                 return
@@ -561,23 +566,19 @@ class ModelTpServer:
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         batch.prepare_for_decode()

-        # Forward
-        logits, (
-            _,
-            _,
-            _,
-            decode_top_logprobs,
-            last_logprobs,
-        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.tolist()
+        # Forward and sample the next tokens
+        output = self.model_runner.forward(batch, ForwardMode.DECODE)
+        next_token_ids, _ = batch.sample(output.next_token_logits)

-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        if last_logprobs is not None:
-            new_token_logprobs = last_logprobs[
-                torch.arange(len(batch.reqs)), next_token_ids
+        # Move logprobs to cpu
+        if output.next_token_logprobs is not None:
+            next_token_logprobs = output.next_token_logprobs[
+                torch.arange(len(next_token_ids), device=next_token_ids.device),
+                next_token_ids,
             ].tolist()

+        next_token_ids = next_token_ids.tolist()
+
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
@@ -585,17 +586,19 @@ class ModelTpServer:
             req.check_finished()

             if req.return_logprob:
-                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
-
-                if req.top_logprobs_num > 0:
-                    req.decode_top_logprobs.append(decode_top_logprobs[i])
+                req.decode_token_logprobs.append(
+                    (next_token_logprobs[i], next_token_id)
+                )
+                if req.top_logprobs_num > 0:
+                    req.decode_top_logprobs.append(output.decode_top_logprobs[i])

         self.handle_finished_requests(batch)

     def handle_finished_requests(self, batch: Batch):
         output_rids = []
-        prev_output_strs = []
-        output_tokens = []
+        decoded_texts = []
+        surr_output_ids = []
+        read_output_ids = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
         output_meta_info = []
@@ -618,8 +621,10 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
-                prev_output_strs.append(req.prev_output_str)
-                output_tokens.append(req.output_ids)
+                decoded_texts.append(req.decoded_text)
+                surr_ids, read_ids, _ = req.init_detokenize_incrementally()
+                surr_output_ids.append(surr_ids)
+                read_output_ids.append(read_ids)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
@@ -629,7 +634,7 @@ class ModelTpServer:

                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
-                    "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
+                    "completion_tokens": len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
                     "finish_reason": str(req.finished_reason),
                 }
@@ -655,8 +660,9 @@ class ModelTpServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
-                    prev_output_strs,
-                    output_tokens,
+                    decoded_texts,
+                    surr_output_ids,
+                    read_output_ids,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
                     output_meta_info,
@@ -671,7 +677,7 @@ class ModelTpServer:
             for i in finished_indices:
                 req = batch.reqs[i]
                 self.tree_cache.cache_req(
-                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                     last_uncached_pos=len(req.prefix_indices),
                     req_pool_idx=req_pool_indices_cpu[i],
                 )
@@ -758,12 +764,28 @@ class ModelTpClient:
         else:
             with ThreadPoolExecutor(self.tp_size) as executor:
                 # Launch model processes
-                rets = executor.map(
-                    lambda args: start_rpyc_process(*args),
-                    [(ModelTpService, p) for p in model_port_args.model_tp_ports],
+                if server_args.nnodes == 1:
+                    self.procs = list(
+                        executor.map(
+                            lambda args: start_rpyc_service_process(*args),
+                            [
+                                (ModelTpService, p)
+                                for p in model_port_args.model_tp_ports
+                            ],
+                        )
+                    )
+                    addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
+                else:
+                    addrs = [
+                        (ip, port)
+                        for ip, port in zip(
+                            model_port_args.model_tp_ips, model_port_args.model_tp_ports
+                        )
+                    ]
+
+                self.model_services = list(
+                    executor.map(lambda args: connect_rpyc_service(*args), addrs)
                 )
-                self.model_services = [x[0] for x in rets]
-                self.procs = [x[1] for x in rets]

                 # Init model
                 def init_model(i):
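The new nnodes branch above spawns tensor-parallel worker processes locally when running on a single node, but connects to already-running services on other machines in the multi-node case. A rough sketch of that address-selection idea using plain rpyc (connect_workers and its arguments are hypothetical helpers, not sglang's own):

import rpyc

def connect_workers(nnodes, tp_ports, tp_ips=None):
    if nnodes == 1:
        # Workers live on this machine, so connect over localhost.
        addrs = [("localhost", p) for p in tp_ports]
    else:
        # Workers were started elsewhere; pair each remote IP with its port.
        addrs = list(zip(tp_ips, tp_ports))
    # rpyc.connect is the standard rpyc client entry point.
    return [rpyc.connect(host, port) for host, port in addrs]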
@@ -775,7 +797,7 @@ class ModelTpClient:
                         model_overide_args,
                     )

-                self.model_servers = executor.map(init_model, range(self.tp_size))
+                self.model_servers = list(executor.map(init_model, range(self.tp_size)))

             # Wrap functions
             def async_wrap(func_name):
780
802
  # Wrap functions
781
803
  def async_wrap(func_name):
@@ -788,4 +810,4 @@ class ModelTpClient:
788
810
 
789
811
  return _func
790
812
 
791
- self.step = async_wrap("step")
813
+ self.step = async_wrap("step")

sglang/srt/managers/detokenizer_manager.py
@@ -1,3 +1,5 @@
+"""DetokenizerManager is a process that detokenizes the token ids."""
+
 import asyncio
 import inspect

@@ -6,10 +8,10 @@ import zmq
 import zmq.asyncio

 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import get_exception_traceback, graceful_registry
-from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
+from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -38,30 +40,26 @@ class DetokenizerManager:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
             assert isinstance(recv_obj, BatchTokenIDOut)

-            output_tokens = recv_obj.output_tokens
-
             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
-            output_strs = self.tokenizer.batch_decode(
-                output_tokens,
+            surr_texts = self.tokenizer.batch_decode(
+                recv_obj.surr_output_ids,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
+            )
+            read_texts = self.tokenizer.batch_decode(
+                recv_obj.read_output_ids,
                 skip_special_tokens=recv_obj.skip_special_tokens[0],
-                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
-                    0
-                ],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[0],
             )

             # Trim stop str
             # TODO(lmzheng): handle the case where multiple stop strs are hit
-            for i in range(len(output_strs)):
-                if len(output_tokens[i]) > 0:
-                    first_token = self.tokenizer.convert_ids_to_tokens(
-                        int(output_tokens[i][0])
-                    )
-                    if not isinstance(first_token, str):
-                        first_token = first_token.decode("utf-8", errors="ignore")
-                    if first_token.startswith("▁"):
-                        output_strs[i] = " " + output_strs[i]
-
-                output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
+            output_strs = []
+            for i in range(len(recv_obj.rids)):
+                new_text = read_texts[i][len(surr_texts[i]) :]
+                if recv_obj.finished_reason[i] is None:
+                    new_text = find_printable_text(new_text)
+                output_strs.append(recv_obj.decoded_texts[i] + new_text)

                 if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
                     pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
@@ -71,7 +69,7 @@ class DetokenizerManager:
             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
                     rids=recv_obj.rids,
-                    output_str=output_strs,
+                    output_strs=output_strs,
                     meta_info=recv_obj.meta_info,
                     finished_reason=recv_obj.finished_reason,
                 )
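The detokenizer changes above implement incremental decoding: each update decodes a short "surrounding" window and a longer "read" window of token ids, then keeps only the text the read window adds, which avoids broken multi-byte characters at chunk boundaries. A minimal sketch of the idea with a Hugging Face tokenizer (the helper name and the choice of "gpt2" are illustrative, not sglang's code):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

def incremental_text(surr_ids, read_ids):
    # Decode both windows and return only the new suffix.
    surr_text = tokenizer.decode(surr_ids, skip_special_tokens=True)
    read_text = tokenizer.decode(read_ids, skip_special_tokens=True)
    return read_text[len(surr_text):]

ids = tokenizer.encode("Hello world, how are you?")
print(incremental_text(ids[:3], ids[:5]))  # text contributed by tokens 3 and 4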

sglang/srt/managers/io_struct.py
@@ -1,9 +1,14 @@
+"""
+The definition of objects transfered between different
+processes (TokenizerManager, DetokenizerManager, Controller).
+"""
+
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union

-from sglang.srt.sampling_params import SamplingParams
 from sglang.srt.managers.controller.infer_batch import BaseFinishReason
+from sglang.srt.sampling_params import SamplingParams


 @dataclass
@@ -30,7 +35,6 @@ class GenerateReqInput:
     stream: bool = False

     def post_init(self):
-
         if (self.text is None and self.input_ids is None) or (
             self.text is not None and self.input_ids is not None
         ):
@@ -106,17 +110,19 @@
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
-    prev_output_strs: List[str]
-    output_tokens: List[List[int]]
+    decoded_texts: List[str]
+    surr_output_ids: List[List[int]]
+    read_output_ids: List[List[int]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]

+
 @dataclass
 class BatchStrOut:
     rids: List[str]
-    output_str: List[str]
+    output_strs: List[str]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
