sglang 0.1.14__py3-none-any.whl → 0.1.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/__init__.py +57 -2
  2. sglang/api.py +8 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +83 -2
  9. sglang/lang/interpreter.py +92 -35
  10. sglang/lang/ir.py +12 -9
  11. sglang/lang/tracer.py +6 -4
  12. sglang/launch_server_llavavid.py +31 -0
  13. sglang/srt/constrained/fsm_cache.py +1 -0
  14. sglang/srt/constrained/jump_forward.py +1 -0
  15. sglang/srt/conversation.py +2 -2
  16. sglang/srt/flush_cache.py +16 -0
  17. sglang/srt/hf_transformers_utils.py +10 -2
  18. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  19. sglang/srt/layers/extend_attention.py +1 -0
  20. sglang/srt/layers/logits_processor.py +114 -54
  21. sglang/srt/layers/radix_attention.py +2 -1
  22. sglang/srt/layers/token_attention.py +1 -0
  23. sglang/srt/managers/detokenizer_manager.py +5 -1
  24. sglang/srt/managers/io_struct.py +27 -3
  25. sglang/srt/managers/router/infer_batch.py +97 -48
  26. sglang/srt/managers/router/manager.py +11 -8
  27. sglang/srt/managers/router/model_rpc.py +169 -90
  28. sglang/srt/managers/router/model_runner.py +110 -166
  29. sglang/srt/managers/router/radix_cache.py +89 -51
  30. sglang/srt/managers/router/scheduler.py +17 -28
  31. sglang/srt/managers/tokenizer_manager.py +110 -33
  32. sglang/srt/memory_pool.py +5 -14
  33. sglang/srt/model_config.py +11 -0
  34. sglang/srt/models/commandr.py +372 -0
  35. sglang/srt/models/dbrx.py +412 -0
  36. sglang/srt/models/dbrx_config.py +281 -0
  37. sglang/srt/models/gemma.py +24 -25
  38. sglang/srt/models/llama2.py +25 -26
  39. sglang/srt/models/llava.py +8 -10
  40. sglang/srt/models/llavavid.py +307 -0
  41. sglang/srt/models/mixtral.py +29 -33
  42. sglang/srt/models/qwen.py +34 -25
  43. sglang/srt/models/qwen2.py +25 -26
  44. sglang/srt/models/stablelm.py +26 -26
  45. sglang/srt/models/yivl.py +3 -5
  46. sglang/srt/openai_api_adapter.py +356 -0
  47. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  48. sglang/srt/sampling_params.py +2 -0
  49. sglang/srt/server.py +91 -456
  50. sglang/srt/server_args.py +79 -49
  51. sglang/srt/utils.py +212 -47
  52. sglang/srt/weight_utils.py +417 -0
  53. sglang/test/test_programs.py +8 -7
  54. sglang/test/test_utils.py +195 -7
  55. sglang/utils.py +77 -26
  56. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/METADATA +20 -18
  57. sglang-0.1.16.dist-info/RECORD +72 -0
  58. sglang-0.1.14.dist-info/RECORD +0 -64
  59. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/LICENSE +0 -0
  60. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/WHEEL +0 -0
  61. {sglang-0.1.14.dist-info → sglang-0.1.16.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/managers/router/model_rpc.py
+++ b/sglang/srt/managers/router/model_rpc.py
@@ -4,13 +4,18 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import Any, Dict, List, Optional, Tuple, Union
 
-import numpy as np
 import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
+
+try:
+    from vllm.logger import _default_handler as vllm_default_logger
+except ImportError:
+    from vllm.logger import logger as vllm_default_logger
+
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
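The try/except import added above is a version-compatibility shim: older vllm releases expose a module-level _default_handler in vllm.logger, newer ones expose a logger object, and either supports setLevel(). A minimal stand-alone sketch of the same fallback, assuming vllm is installed (names taken from the hunks in this file):

import logging

# Bind whichever logging handle the installed vllm provides (per the diff above).
try:
    from vllm.logger import _default_handler as vllm_default_logger
except ImportError:
    from vllm.logger import logger as vllm_default_logger

# Both objects accept setLevel(), which the module-level code in this file uses
# to quiet vllm's startup logging.
vllm_default_logger.setLevel(logging.WARN)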
@@ -19,7 +24,7 @@ from sglang.srt.managers.io_struct import (
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req, FinishReason
 from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.managers.router.scheduler import Scheduler
@@ -31,17 +36,20 @@ from sglang.srt.utils import (
     is_multimodal_model,
     set_random_seed,
 )
-from vllm.logger import _default_handler as vllm_default_handler
+
 
 logger = logging.getLogger("model_rpc")
+vllm_default_logger.setLevel(logging.WARN)
+logging.getLogger("vllm.utils").setLevel(logging.WARN)
 
 
-class ModelRpcServer(rpyc.Service):
-    def exposed_init_model(
+class ModelRpcServer:
+    def __init__(
         self,
         tp_rank: int,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: Optional[dict] = None,
     ):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]
 
@@ -50,18 +58,16 @@ class ModelRpcServer(rpyc.Service):
         self.tp_size = server_args.tp_size
         self.schedule_heuristic = server_args.schedule_heuristic
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
-        vllm_default_handler.setLevel(
-            level=getattr(logging, server_args.log_level.upper())
-        )
 
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
+            model_overide_args=model_overide_args,
         )
 
-        # for model end global settings
+        # For model end global settings
         server_args_dict = {
             "enable_flashinfer": server_args.enable_flashinfer,
             "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
@@ -90,7 +96,6 @@ class ModelRpcServer(rpyc.Service):
             tokenizer_mode=server_args.tokenizer_mode,
             trust_remote_code=server_args.trust_remote_code,
         )
-        self.eos_token_id = self.tokenizer.eos_token_id
         self.max_total_num_token = self.model_runner.max_total_num_token
         self.max_num_running_seq = self.max_total_num_token // 2
         self.max_prefill_num_token = max(
@@ -111,10 +116,15 @@ class ModelRpcServer(rpyc.Service):
             f"max_prefill_num_token={self.max_prefill_num_token}, "
             f"context_len={self.model_config.context_len}, "
         )
-        logger.info(server_args.get_optional_modes_logging())
+        if self.tp_rank == 0:
+            logger.info(f"server_args: {server_args.print_mode_args()}")
 
         # Init cache
-        self.tree_cache = RadixCache(server_args.disable_radix_cache)
+        self.tree_cache = RadixCache(
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            disable=server_args.disable_radix_cache,
+        )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
@@ -132,6 +142,8 @@ class ModelRpcServer(rpyc.Service):
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()
 
         # Init the FSM cache for constrained generation
         self.regex_fsm_cache = FSMCache(
@@ -161,7 +173,7 @@ class ModelRpcServer(rpyc.Service):
             logger.info("Cache flushed successfully!")
         else:
             warnings.warn(
-                "Cache not flushed because there are pending requests. "
+                f"Cache not flushed because there are pending requests. "
                 f"#queue-req: {len(self.forward_queue)}, "
                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
             )
@@ -198,6 +210,8 @@ class ModelRpcServer(rpyc.Service):
             # Run new fill batch
             self.forward_fill_batch(new_batch)
 
+            self.cache_filled_batch(new_batch)
+
             if not new_batch.is_empty():
                 if self.running_batch is None:
                     self.running_batch = new_batch
@@ -208,6 +222,7 @@ class ModelRpcServer(rpyc.Service):
             if self.running_batch is not None:
                 # Run a few decode batches continuously for reducing overhead
                 for _ in range(10):
+                    self.num_generated_tokens += len(self.running_batch.reqs)
                     self.forward_decode_batch(self.running_batch)
 
                     if self.running_batch.is_empty():
@@ -223,10 +238,14 @@ class ModelRpcServer(rpyc.Service):
                                 self.token_to_kv_pool.available_size()
                                 + self.tree_cache.evictable_size()
                             )
+                            throuhgput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+                            self.num_generated_tokens = 0
+                            self.last_stats_tic = time.time()
                             logger.info(
                                 f"#running-req: {len(self.running_batch.reqs)}, "
                                 f"#token: {num_used}, "
                                 f"token usage: {num_used / self.max_total_num_token:.2f}, "
+                                f"gen throughput (token/s): {throuhgput:.2f}, "
                                 f"#queue-req: {len(self.forward_queue)}"
                             )
             else:
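The new num_generated_tokens / last_stats_tic counters feed the "gen throughput (token/s)" figure in the periodic stats line above: tokens generated since the last report, divided by the wall-clock time elapsed, with both counters reset after each report. A small self-contained illustration of that arithmetic (all numbers are made up):

import time

last_stats_tic = time.time()
num_generated_tokens = 0

time.sleep(0.25)              # stand-in for a stretch of decode iterations
num_generated_tokens += 640   # e.g. 8 running requests x 80 decode steps (hypothetical)

throughput = num_generated_tokens / (time.time() - last_stats_tic)
print(f"gen throughput (token/s): {throughput:.2f}")   # roughly 2560 in this toy run

# Reset, exactly as the hunk above does after logging.
num_generated_tokens = 0
last_stats_tic = time.time()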
@@ -262,6 +281,7 @@ class ModelRpcServer(rpyc.Service):
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
         req.logprob_start_len = recv_req.logprob_start_len
+        req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
         req.tokenizer = self.tokenizer
 
@@ -338,25 +358,26 @@ class ModelRpcServer(rpyc.Service):
                 and req.extend_input_len + new_batch_input_tokens
                 < self.max_prefill_num_token
             ):
-                delta = self.tree_cache.inc_ref_counter(req.last_node)
+                delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta
 
                 if not (
                     req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                     < available_size
                 ):
-                    # Undo the insertion
-                    delta = self.tree_cache.dec_ref_counter(req.last_node)
+                    # Undo locking
+                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                     available_size += delta
+                    break
                 else:
                     # Add this request to the running batch
-                    self.token_to_kv_pool.add_refs(req.prefix_indices)
                     can_run_list.append(req)
                     new_batch_total_tokens += (
                         req.extend_input_len + req.max_new_tokens()
                     )
                     new_batch_input_tokens += req.extend_input_len
-
+            else:
+                break
         if len(can_run_list) == 0:
             return None
 
@@ -380,12 +401,12 @@ class ModelRpcServer(rpyc.Service):
                 f"#running_req: {running_req}. "
                 f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
             )
-            logger.debug(
-                f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-                f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-                f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-                f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            )
+            #logger.debug(
+            #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
+            #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
+            #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
+            #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
+            #)
 
         new_batch = Batch.init_new(
             can_run_list,
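In the admission loop above, the radix-cache calls are renamed from inc_ref_counter / dec_ref_counter to inc_lock_ref / dec_lock_ref, and the loop now breaks at the first request that does not fit (both the inner break and the new else: break), so prefill admission proceeds strictly in queue order. The rename reflects what the counter is for: an admitted request pins ("locks") the tree nodes holding its cached prefix so they cannot be evicted while it runs; the lock is dropped when admission is undone or, later in this file, when the request finishes. A toy illustration of that lifecycle, not sglang's implementation:

class ToyNode:
    """Stand-in for a radix-tree node holding cached KV indices."""
    def __init__(self, name: str):
        self.name = name
        self.lock_ref = 0          # number of running requests depending on this node

def inc_lock_ref(node: ToyNode) -> None:
    node.lock_ref += 1             # request admitted: pin its cached prefix

def dec_lock_ref(node: ToyNode) -> None:
    node.lock_ref -= 1             # admission undone or request finished: unpin

def evictable(node: ToyNode) -> bool:
    return node.lock_ref == 0      # only unpinned nodes may be evicted for new KV cache

prefix = ToyNode("shared prompt prefix")
inc_lock_ref(prefix)
assert not evictable(prefix)       # pinned while the request runs
dec_lock_ref(prefix)
assert evictable(prefix)           # reclaimable again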
@@ -402,56 +423,80 @@ class ModelRpcServer(rpyc.Service):
             self.model_config.vocab_size, self.int_token_logit_bias
         )
 
-        logprobs = None
         if batch.extend_num_tokens != 0:
             # Forward
             logits, (
-                prefill_logprobs,
-                normalized_logprobs,
+                prefill_token_logprobs,
+                normalized_prompt_logprobs,
+                prefill_top_logprobs,
+                decode_top_logprobs,
                 last_logprobs,
-            ) = self.model_runner.forward(
-                batch, ForwardMode.EXTEND, batch.return_logprob
-            )
-            if prefill_logprobs is not None:
-                logprobs = prefill_logprobs.cpu().tolist()
-                normalized_logprobs = normalized_logprobs.cpu().tolist()
+            ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            if prefill_token_logprobs is not None:
+                prefill_token_logprobs = prefill_token_logprobs.tolist()
+                normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
 
             next_token_ids, _ = batch.sample(logits)
-            next_token_ids = next_token_ids.cpu().tolist()
+
+            # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
+            if last_logprobs is not None:
+                last_token_logprobs = (
+                    last_logprobs[
+                        torch.arange(len(batch.reqs), device=next_token_ids.device),
+                        next_token_ids].tolist()
+                )
+
+            next_token_ids = next_token_ids.tolist()
         else:
             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            logits = logprobs = normalized_logprobs = last_logprobs = None
-
-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
-        if last_logprobs is not None:
-            last_logprobs = (
-                last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
-            )
 
         # Check finish condition
         pt = 0
-        for i, req in enumerate(reqs):
+        for i, req in enumerate(batch.reqs):
             req.completion_tokens_wo_jump_forward += 1
             req.output_ids = [next_token_ids[i]]
             req.check_finished()
 
-            if logprobs is not None:
-                req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
-                req.normalized_logprob = normalized_logprobs[i]
+            if req.return_logprob:
+                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
+
+                # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+                req.prefill_token_logprobs = list(
+                    zip(
+                        prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                        req.input_ids[-req.extend_input_len + 1 :],
+                    )
+                )
+                if req.logprob_start_len == 0:
+                    req.prefill_token_logprobs = [
+                        (None, req.input_ids[0])
+                    ] + req.prefill_token_logprobs
+                req.decode_token_logprobs = [
+                    (last_token_logprobs[i], next_token_ids[i])
+                ]
 
-                # If logprob_start_len > 0, then first logprob_start_len prompt tokens
-                # will be ignored.
-                prompt_token_len = len(req.logprob)
-                token_ids = req.input_ids[-prompt_token_len:] + [next_token_ids[i]]
-                token_logprobs = req.logprob + [last_logprobs[i]]
-                req.token_logprob = list(zip(token_ids, token_logprobs))
+            if req.top_logprobs_num > 0:
+                req.prefill_top_logprobs = prefill_top_logprobs[i]
                 if req.logprob_start_len == 0:
-                    req.token_logprob = [(req.input_ids[0], None)] + req.token_logprob
-            pt += req.extend_input_len
+                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+                req.decode_top_logprobs = [decode_top_logprobs[i]]
+
+            pt += req.extend_input_len
 
         self.handle_finished_requests(batch)
 
+    def cache_filled_batch(self, batch: Batch):
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+        for i, req in enumerate(batch.reqs):
+            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
+                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                last_uncached_pos=len(req.prefix_indices),
+                req_pool_idx=req_pool_indices_cpu[i],
+                del_in_memory_pool=False,
+                old_last_node=req.last_node,
+            )
+            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
+
     def forward_decode_batch(self, batch: Batch):
         # check if decode out of memory
         if not batch.check_decode_mem():
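The logprob bookkeeping above changes shape, not just names: the old req.token_logprob held (token_id, logprob) pairs, while the new req.prefill_token_logprobs and req.decode_token_logprobs hold (logprob, token_id) pairs, and when logprob_start_len == 0 a (None, first_token_id) placeholder is prepended because the first prompt token has no conditional logprob. A runnable toy with made-up ids and values, showing how the prefill pairs are assembled:

# Hypothetical token ids and logprobs, mirroring the zip() in the hunk above.
input_ids = [101, 2023, 2003, 1037]          # prompt tokens (made up)
prefill_token_logprobs = [-1.2, -0.7, -2.3]  # one logprob per prompt token after the first

pairs = list(zip(prefill_token_logprobs, input_ids[1:]))
# logprob_start_len == 0: prepend the placeholder for the first token.
pairs = [(None, input_ids[0])] + pairs
print(pairs)   # [(None, 101), (-1.2, 2023), (-0.7, 2003), (-2.3, 1037)]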
@@ -497,29 +542,33 @@ class ModelRpcServer(rpyc.Service):
         batch.prepare_for_decode()
 
         # Forward
-        logits, (_, _, last_logprobs) = self.model_runner.forward(
-            batch,
-            ForwardMode.DECODE,
-            batch.return_logprob,
-        )
+        logits, (
+            _,
+            _,
+            _,
+            decode_top_logprobs,
+            last_logprobs,
+        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
         next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.cpu().tolist()
+        next_token_ids = next_token_ids.tolist()
 
         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
         if last_logprobs is not None:
-            last_logprobs = last_logprobs[
-                torch.arange(len(reqs)), next_token_ids
+            new_token_logprobs = last_logprobs[
+                torch.arange(len(batch.reqs)), next_token_ids
             ].tolist()
 
         # Check finish condition
-        for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
+        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             req.completion_tokens_wo_jump_forward += 1
-            req.output_ids.append(next_tok_id)
+            req.output_ids.append(next_token_id)
             req.check_finished()
 
-            if last_logprobs is not None:
-                req.token_logprob.append((next_tok_id, last_logprobs[i]))
+            if req.return_logprob:
+                req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
+
+            if req.top_logprobs_num > 0:
+                req.decode_top_logprobs.append(decode_top_logprobs[i])
 
         self.handle_finished_requests(batch)
 
@@ -529,6 +578,7 @@ class ModelRpcServer(rpyc.Service):
         output_and_jump_forward_strs = []
         output_hit_stop_str = []
         output_skip_special_tokens = []
+        output_spaces_between_special_tokens = []
         output_meta_info = []
         output_finished = []
         finished_indices = []
@@ -555,6 +605,9 @@ class ModelRpcServer(rpyc.Service):
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
                 )
+                output_spaces_between_special_tokens.append(
+                    req.sampling_params.spaces_between_special_tokens
+                )
 
                 meta_info = {
                     "prompt_tokens": req.prompt_tokens,
@@ -562,11 +615,23 @@ class ModelRpcServer(rpyc.Service):
                     + len(req.output_ids)
                     - req.prompt_tokens,
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                    "finish_reason": FinishReason.to_str(req.finish_reason),
+                    "hit_stop_str": req.hit_stop_str,
                 }
                 if req.return_logprob:
-                    meta_info["prompt_logprob"] = req.logprob
-                    meta_info["token_logprob"] = req.token_logprob
-                    meta_info["normalized_prompt_logprob"] = req.normalized_logprob
+                    (
+                        meta_info["prefill_token_logprobs"],
+                        meta_info["decode_token_logprobs"],
+                        meta_info["prefill_top_logprobs"],
+                        meta_info["decode_top_logprobs"],
+                        meta_info["normalized_prompt_logprob"],
+                    ) = (
+                        req.prefill_token_logprobs,
+                        req.decode_token_logprobs,
+                        req.prefill_top_logprobs,
+                        req.decode_top_logprobs,
+                        req.normalized_prompt_logprob,
+                    )
                 output_meta_info.append(meta_info)
                 output_finished.append(req.finished)
 
@@ -579,6 +644,7 @@ class ModelRpcServer(rpyc.Service):
                     output_and_jump_forward_strs,
                     output_hit_stop_str,
                     output_skip_special_tokens,
+                    output_spaces_between_special_tokens,
                     output_meta_info,
                     output_finished,
                 )
@@ -587,20 +653,16 @@ class ModelRpcServer(rpyc.Service):
         # Remove finished reqs
         if finished_indices:
             # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+            req_pool_indices_cpu = batch.req_pool_indices.tolist()
             for i in finished_indices:
                 req = batch.reqs[i]
-                req_pool_idx = req_pool_indices_cpu[i]
-                token_ids = tuple(req.input_ids + req.output_ids)
-                seq_len = len(token_ids) - 1
-                indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
-                prefix_len = self.tree_cache.insert(
-                    token_ids[:seq_len], indices.clone()
+                self.tree_cache.cache_req(
+                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    last_uncached_pos=len(req.prefix_indices),
+                    req_pool_idx=req_pool_indices_cpu[i],
                 )
 
-                self.token_to_kv_pool.free(indices[:prefix_len])
-                self.req_to_token_pool.free(req_pool_idx)
-                self.tree_cache.dec_ref_counter(req.last_node)
+                self.tree_cache.dec_lock_ref(req.last_node)
 
         # Update batch tensors
         if unfinished_indices:
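With RadixCache now constructed around both memory pools (see the constructor change earlier in this diff), the manual insert / token_to_kv_pool.free / req_to_token_pool.free / dec_ref_counter sequence collapses into cache_req. The same method is called two ways in this file: from cache_filled_batch with del_in_memory_pool=False and old_last_node, because the request is still running, and here for finished requests with the defaults, which (judging by the free() calls it replaces) hands pool cleanup over to the cache. A hypothetical stand-in, only to contrast the two call sites; the real semantics live in sglang/srt/managers/router/radix_cache.py:

def toy_cache_req(token_ids, last_uncached_pos, req_pool_idx,
                  del_in_memory_pool=True, old_last_node=None):
    # Records how the arguments differ between the two call sites above.
    return {
        "req_pool_idx": req_pool_idx,
        "newly_cached_tokens": len(token_ids) - last_uncached_pos,
        "releases_pool_entries": del_in_memory_pool,   # inferred, not from the source
        "keeps_last_node": old_last_node is not None,
    }

# After a prefill batch (cache_filled_batch): request keeps running, nothing released.
print(toy_cache_req(token_ids=(7, 8, 9, 10), last_uncached_pos=2, req_pool_idx=0,
                    del_in_memory_pool=False, old_last_node="node"))
# When a request finishes (handle_finished_requests): defaults let the cache clean up.
print(toy_cache_req(token_ids=(7, 8, 9, 10, 11), last_uncached_pos=2, req_pool_idx=0))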
@@ -609,14 +671,21 @@ class ModelRpcServer(rpyc.Service):
             batch.reqs = []
 
 
+class ModelRpcService(rpyc.Service):
+    exposed_ModelRpcServer = ModelRpcServer
+
+
 class ModelRpcClient:
-    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
+    def __init__(
+        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
+    ):
         tp_size = server_args.tp_size
 
         if tp_size == 1:
             # Init model
-            self.model_server = ModelRpcServer()
-            self.model_server.exposed_init_model(0, server_args, port_args)
+            self.model_server = ModelRpcService().exposed_ModelRpcServer(
+                0, server_args, port_args, model_overide_args
+            )
 
             # Wrap functions
             def async_wrap(f):
@@ -630,14 +699,16 @@ class ModelRpcClient:
            with ThreadPoolExecutor(tp_size) as executor:
                # Launch model processes
                rets = executor.map(start_model_process, port_args.model_rpc_ports)
-               self.model_servers = [x[0] for x in rets]
+               self.remote_services = [x[0] for x in rets]
                self.procs = [x[1] for x in rets]
 
                # Init model
                def init_model(i):
-                   return self.model_servers[i].init_model(i, server_args, port_args)
+                   return self.remote_services[i].ModelRpcServer(
+                       i, server_args, port_args, model_overide_args
+                   )
 
-               rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]
+               self.model_servers = executor.map(init_model, range(tp_size))
 
                # Wrap functions
                def async_wrap(func_name):
@@ -655,9 +726,13 @@ class ModelRpcClient:
 
 def _init_service(port):
     t = ThreadedServer(
-        ModelRpcServer(),
+        ModelRpcService(),
         port=port,
-        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 1800,
+        },
     )
     t.start()
 
@@ -673,7 +748,11 @@ def start_model_process(port):
         con = rpyc.connect(
             "localhost",
             port,
-            config={"allow_pickle": True, "sync_request_timeout": 1800},
+            config={
+                "allow_public_attrs": True,
+                "allow_pickle": True,
+                "sync_request_timeout": 1800,
+            },
         )
         break
     except ConnectionRefusedError:
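The RPC plumbing in the last few hunks moves from call-style RPC (a ModelRpcServer service with an exposed_init_model method) to rpyc's exposed-class pattern: ModelRpcService exposes the ModelRpcServer class through the exposed_ prefix, both endpoints add allow_public_attrs to their protocol config, and the client constructs the model server remotely through the service proxy returned by start_model_process (self.remote_services[i].ModelRpcServer(...)). A minimal self-contained sketch of that pattern with a toy class, assuming rpyc is installed; port numbers and names here are arbitrary:

import threading
import time

import rpyc
from rpyc.utils.server import ThreadedServer

class Counter:
    """Toy stand-in for ModelRpcServer."""
    def __init__(self, start: int):
        self.value = start

    def step(self) -> int:
        self.value += 1
        return self.value

class CounterService(rpyc.Service):
    # The exposed_ prefix makes the class reachable as conn.root.Counter.
    exposed_Counter = Counter

config = {"allow_public_attrs": True, "allow_pickle": True, "sync_request_timeout": 1800}

server = ThreadedServer(CounterService(), port=18861, protocol_config=config)
threading.Thread(target=server.start, daemon=True).start()
time.sleep(0.5)   # crude wait; the real code retries on ConnectionRefusedError instead

conn = rpyc.connect("localhost", 18861, config=config)
remote_counter = conn.root.Counter(10)   # instance lives in the server process
print(remote_counter.step())             # 11, computed remotely
conn.close()
server.close()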