sglang 0.2.6__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. sglang/__init__.py +33 -26
  2. sglang/api.py +9 -1
  3. sglang/bench_latency.py +2 -2
  4. sglang/bench_serving.py +10 -1
  5. sglang/check_env.py +1 -1
  6. sglang/lang/backend/litellm.py +1 -1
  7. sglang/lang/backend/openai.py +1 -1
  8. sglang/lang/interpreter.py +21 -5
  9. sglang/lang/ir.py +1 -2
  10. sglang/srt/constrained/__init__.py +15 -0
  11. sglang/srt/constrained/{base_cache.py → base_tool_cache.py} +17 -2
  12. sglang/srt/constrained/fsm_cache.py +17 -2
  13. sglang/srt/constrained/jump_forward.py +17 -2
  14. sglang/srt/conversation.py +26 -0
  15. sglang/srt/hf_transformers_utils.py +15 -0
  16. sglang/srt/layers/context_flashattention_nopad.py +15 -0
  17. sglang/srt/layers/extend_attention.py +15 -0
  18. sglang/srt/layers/fused_moe.py +15 -0
  19. sglang/srt/layers/linear.py +15 -0
  20. sglang/srt/layers/logits_processor.py +41 -13
  21. sglang/srt/layers/quantization/__init__.py +15 -0
  22. sglang/srt/layers/quantization/fp8.py +15 -0
  23. sglang/srt/layers/radix_attention.py +17 -2
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
  26. sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
  27. sglang/srt/managers/detokenizer_manager.py +16 -1
  28. sglang/srt/managers/io_struct.py +36 -3
  29. sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
  30. sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +60 -21
  31. sglang/srt/managers/tokenizer_manager.py +39 -16
  32. sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +159 -46
  33. sglang/srt/mem_cache/base_cache.py +43 -0
  34. sglang/srt/mem_cache/chunk_cache.py +60 -0
  35. sglang/srt/mem_cache/flush_cache.py +33 -0
  36. sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
  37. sglang/srt/{managers/controller → mem_cache}/radix_cache.py +20 -2
  38. sglang/srt/mm_utils.py +15 -0
  39. sglang/srt/model_config.py +15 -0
  40. sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +16 -1
  41. sglang/srt/{managers/controller → model_executor}/model_runner.py +49 -14
  42. sglang/srt/model_loader/model_loader.py +15 -0
  43. sglang/srt/model_loader/utils.py +16 -1
  44. sglang/srt/models/chatglm.py +16 -1
  45. sglang/srt/models/commandr.py +16 -1
  46. sglang/srt/models/dbrx.py +16 -1
  47. sglang/srt/models/deepseek.py +16 -1
  48. sglang/srt/models/deepseek_v2.py +16 -1
  49. sglang/srt/models/gemma.py +16 -1
  50. sglang/srt/models/gemma2.py +16 -1
  51. sglang/srt/models/gpt_bigcode.py +16 -1
  52. sglang/srt/models/grok.py +16 -1
  53. sglang/srt/models/internlm2.py +16 -1
  54. sglang/srt/models/llama2.py +21 -22
  55. sglang/srt/models/llama_classification.py +16 -1
  56. sglang/srt/models/llava.py +17 -2
  57. sglang/srt/models/llavavid.py +17 -2
  58. sglang/srt/models/minicpm.py +16 -1
  59. sglang/srt/models/mistral.py +15 -0
  60. sglang/srt/models/mixtral.py +16 -1
  61. sglang/srt/models/mixtral_quant.py +16 -1
  62. sglang/srt/models/qwen.py +16 -1
  63. sglang/srt/models/qwen2.py +16 -1
  64. sglang/srt/models/qwen2_moe.py +16 -1
  65. sglang/srt/models/stablelm.py +16 -1
  66. sglang/srt/models/yivl.py +15 -0
  67. sglang/srt/openai_api/adapter.py +569 -131
  68. sglang/srt/openai_api/protocol.py +84 -2
  69. sglang/srt/sampling_params.py +15 -0
  70. sglang/srt/server.py +92 -23
  71. sglang/srt/server_args.py +52 -11
  72. sglang/srt/utils.py +15 -0
  73. sglang/test/test_programs.py +9 -6
  74. sglang/utils.py +22 -0
  75. sglang/version.py +1 -1
  76. {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/METADATA +33 -7
  77. sglang-0.2.8.dist-info/RECORD +95 -0
  78. {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/WHEEL +1 -1
  79. sglang/srt/flush_cache.py +0 -18
  80. sglang-0.2.6.dist-info/RECORD +0 -93
  81. {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/LICENSE +0 -0
  82. {sglang-0.2.6.dist-info → sglang-0.2.8.dist-info}/top_level.txt +0 -0
sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """A tensor parallel worker."""

  import logging
@@ -14,23 +29,24 @@ from sglang.global_config import global_config
  from sglang.srt.constrained.fsm_cache import FSMCache
  from sglang.srt.constrained.jump_forward import JumpForwardCache
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
- from sglang.srt.managers.controller.infer_batch import (
-     FINISH_ABORT,
-     BaseFinishReason,
-     Batch,
-     ForwardMode,
-     Req,
- )
- from sglang.srt.managers.controller.model_runner import ModelRunner
- from sglang.srt.managers.controller.radix_cache import RadixCache
- from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
  from sglang.srt.managers.io_struct import (
      AbortReq,
      BatchTokenIDOut,
      FlushCacheReq,
      TokenizedGenerateReqInput,
  )
+ from sglang.srt.managers.policy_scheduler import PolicyScheduler
+ from sglang.srt.managers.schedule_batch import (
+     FINISH_ABORT,
+     BaseFinishReason,
+     Batch,
+     ForwardMode,
+     Req,
+ )
+ from sglang.srt.mem_cache.chunk_cache import ChunkCache
+ from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.model_config import ModelConfig
+ from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
      get_int_token_logit_bias,
@@ -40,7 +56,7 @@ from sglang.srt.utils import (
  )
  from sglang.utils import get_exception_traceback

- logger = logging.getLogger("srt.tp_worker")
+ logger = logging.getLogger(__name__)


  class ModelTpServer:
@@ -59,9 +75,13 @@ class ModelTpServer:
          self.tp_rank = tp_rank
          self.tp_size = server_args.tp_size
          self.dp_size = server_args.dp_size
-         self.schedule_heuristic = server_args.schedule_heuristic
+         self.schedule_policy = server_args.schedule_policy
          self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

+         # Chunked prefill
+         self.chunked_prefill_size = server_args.chunked_prefill_size
+         self.current_inflight_req = None
+
          # Init model and tokenizer
          self.model_config = ModelConfig(
              server_args.model_path,
@@ -117,7 +137,7 @@ class ModelTpServer:

          # Print info
          logger.info(
-             f"[gpu_id={self.gpu_id}] "
+             f"[gpu={self.gpu_id}] "
              f"max_total_num_tokens={self.max_total_num_tokens}, "
              f"max_prefill_tokens={self.max_prefill_tokens}, "
              f"max_running_requests={self.max_running_requests}, "
@@ -125,14 +145,23 @@
          )

          # Init cache
-         self.tree_cache = RadixCache(
-             req_to_token_pool=self.model_runner.req_to_token_pool,
-             token_to_kv_pool=self.model_runner.token_to_kv_pool,
-             disable=server_args.disable_radix_cache,
-         )
+         if (
+             server_args.chunked_prefill_size is not None
+             and server_args.disable_radix_cache
+         ):
+             self.tree_cache = ChunkCache(
+                 req_to_token_pool=self.model_runner.req_to_token_pool,
+                 token_to_kv_pool=self.model_runner.token_to_kv_pool,
+             )
+         else:
+             self.tree_cache = RadixCache(
+                 req_to_token_pool=self.model_runner.req_to_token_pool,
+                 token_to_kv_pool=self.model_runner.token_to_kv_pool,
+                 disable=server_args.disable_radix_cache,
+             )
          self.tree_cache_metrics = {"total": 0, "hit": 0}
-         self.scheduler = ScheduleHeuristic(
-             self.schedule_heuristic,
+         self.scheduler = PolicyScheduler(
+             self.schedule_policy,
              self.max_running_requests,
              self.max_prefill_tokens,
              self.max_total_num_tokens,
@@ -142,7 +171,7 @@
          self.token_to_kv_pool = self.model_runner.token_to_kv_pool

          # Init running status
-         self.forward_queue: List[Req] = []
+         self.waiting_queue: List[Req] = []
          self.running_batch: Batch = None
          self.out_pyobjs = []
          self.decode_forward_ct = 0
@@ -205,6 +234,7 @@
              # Run a new prefill batch
              self.forward_prefill_batch(new_batch)
              self.cache_filled_batch(new_batch)
+             self.filter_out_inflight(new_batch)

              if not new_batch.is_empty():
                  if self.running_batch is None:
@@ -241,12 +271,12 @@
              self.num_generated_tokens = 0
              self.last_stats_tic = time.time()
              logger.info(
-                 f"[gpu_id={self.gpu_id}] Decode batch. "
+                 f"[gpu={self.gpu_id}] Decode batch. "
                  f"#running-req: {len(self.running_batch.reqs)}, "
                  f"#token: {num_used}, "
                  f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                  f"gen throughput (token/s): {throughput:.2f}, "
-                 f"#queue-req: {len(self.forward_queue)}"
+                 f"#queue-req: {len(self.waiting_queue)}"
              )

      def check_memory(self):
@@ -260,6 +290,14 @@
                  "KV cache pool leak detected!"
              )

+         if self.req_to_token_pool.can_use_mem_size != self.req_to_token_pool.size:
+             warnings.warn(
+                 "Warning: "
+                 f"available req slots={self.req_to_token_pool.can_use_mem_size}, "
+                 f"total slots={self.req_to_token_pool.size}\n"
+                 "Memory pool leak detected!"
+             )
+
      def handle_generate_request(
          self,
          recv_req: TokenizedGenerateReqInput,
@@ -313,9 +351,10 @@
              ),
              self.max_req_input_len - 1 - len(req.origin_input_ids),
          )
-         self.forward_queue.append(req)
+         self.waiting_queue.append(req)

      def get_new_prefill_batch(self) -> Optional[Batch]:
+         # TODO(lsyin): organize this function
          running_bs = (
              len(self.running_batch.reqs) if self.running_batch is not None else 0
          )
@@ -323,9 +362,12 @@
              return

          # Compute matched prefix length
-         for req in self.forward_queue:
+         for req in self.waiting_queue:
              req.input_ids = req.origin_input_ids + req.output_ids
-             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
+             prefix_indices, last_node = self.tree_cache.match_prefix(
+                 rid=req.rid,
+                 key=req.input_ids,
+             )
              if req.return_logprob:
                  prefix_indices = prefix_indices[: req.logprob_start_len]
              req.extend_input_len = len(req.input_ids) - len(prefix_indices)
@@ -333,7 +375,7 @@
              req.last_node = last_node

          # Get priority queue
-         self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
+         self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)

          # Add requests if there is available space
          can_run_list = []
@@ -352,7 +394,33 @@
              ]
          )

-         for req in self.forward_queue:
+         # Handle the current inflight request
+         take_inflight = 0
+         if self.current_inflight_req:
+             take_inflight = 1
+             r = self.current_inflight_req
+             r.input_ids = r.origin_input_ids + r.output_ids
+             truncated = (
+                 len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
+             )
+             r.extend_input_len = min(
+                 len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
+             )
+             r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
+             can_run_list.append(r)
+
+             if not truncated:
+                 # Finish inflight
+                 self.current_inflight_req = None
+                 new_batch_total_tokens += (
+                     r.extend_input_len + r.sampling_params.max_new_tokens
+                 )
+                 new_batch_input_tokens += r.extend_input_len
+             else:
+                 new_batch_total_tokens += r.extend_input_len
+                 new_batch_input_tokens += r.extend_input_len
+
+         for req in self.waiting_queue:
              if req.return_logprob and req.normalized_prompt_logprob is None:
                  # Need at least two tokens to compute normalized logprob
                  if req.extend_input_len < 2:
@@ -394,11 +462,39 @@
                      break
              else:
                  # Add this request to the running batch
-                 can_run_list.append(req)
-                 new_batch_total_tokens += (
-                     req.extend_input_len + req.sampling_params.max_new_tokens
-                 )
-                 new_batch_input_tokens += req.extend_input_len
+                 if (
+                     self.chunked_prefill_size is None
+                     or (
+                         new_batch_input_tokens + req.extend_input_len
+                         <= self.chunked_prefill_size
+                     )
+                     or (
+                         req.return_logprob and req.normalized_prompt_logprob is None
+                     )
+                 ):
+                     can_run_list.append(req)
+                     new_batch_total_tokens += (
+                         req.extend_input_len + req.sampling_params.max_new_tokens
+                     )
+                     new_batch_input_tokens += req.extend_input_len
+                 else:
+                     trunc_len = self.chunked_prefill_size - new_batch_input_tokens
+
+                     if trunc_len <= 0:
+                         # Undo locking
+                         delta = self.tree_cache.dec_lock_ref(req.last_node)
+                         available_size += delta
+                         break
+
+                     req.extend_input_len = trunc_len
+                     req.input_ids = req.input_ids[
+                         : len(req.prefix_indices) + req.extend_input_len
+                     ]
+                     can_run_list.append(req)
+                     self.current_inflight_req = req
+                     new_batch_input_tokens += req.extend_input_len
+                     new_batch_total_tokens += req.extend_input_len
+                     break
          else:
              break

@@ -419,13 +515,13 @@
                  self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
              )
              logger.info(
-                 f"[gpu_id={self.gpu_id}] Prefill batch. "
+                 f"[gpu={self.gpu_id}] Prefill batch. "
                  f"#new-seq: {len(can_run_list)}, "
                  f"#new-token: {new_batch_input_tokens}, "
                  f"#cached-token: {hit_tokens}, "
                  f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                  f"#running-req: {running_bs}, "
-                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
+                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
              )

          # Return the new batch
@@ -435,7 +531,7 @@
              self.token_to_kv_pool,
              self.tree_cache,
          )
-         self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
+         self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
          return new_batch

      def forward_prefill_batch(self, batch: Batch):
@@ -467,9 +563,10 @@
          # Check finish conditions
          pt = 0
          for i, req in enumerate(batch.reqs):
-             req.completion_tokens_wo_jump_forward += 1
-             req.output_ids.append(next_token_ids[i])
-             req.check_finished()
+             if req is not self.current_inflight_req:
+                 req.completion_tokens_wo_jump_forward += 1
+                 req.output_ids.append(next_token_ids[i])
+                 req.check_finished()

              if req.return_logprob:
                  self.add_logprob_return_values(i, req, pt, next_token_ids, output)
@@ -530,7 +627,8 @@
          req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
          for i, req in enumerate(batch.reqs):
              new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                 token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
+                 rid=req.rid,
+                 token_ids=tuple(req.input_ids),
                  last_uncached_pos=len(req.prefix_indices),
                  req_pool_idx=req_pool_indices_cpu[i],
                  del_in_memory_pool=False,
@@ -538,6 +636,10 @@
              )
              req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

+             if req is self.current_inflight_req:
+                 # inflight request would get a new req idx
+                 self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
+
      def forward_decode_batch(self, batch: Batch):
          # Check if decode out of memory
          if not batch.check_decode_mem():
@@ -551,7 +653,7 @@
                  f"#retracted_reqs: {len(retracted_reqs)}, "
                  f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
              )
-             self.forward_queue.extend(retracted_reqs)
+             self.waiting_queue.extend(retracted_reqs)
          else:
              self.new_token_ratio = max(
                  self.new_token_ratio - self.new_token_ratio_decay,
@@ -561,7 +663,7 @@
          if not self.disable_regex_jump_forward:
              # Check for jump-forward
              jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-             self.forward_queue.extend(jump_forward_reqs)
+             self.waiting_queue.extend(jump_forward_reqs)
          if batch.is_empty():
              return

@@ -683,6 +785,7 @@
          for i in finished_indices:
              req = batch.reqs[i]
              self.tree_cache.cache_req(
+                 rid=req.rid,
                  token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                  last_uncached_pos=len(req.prefix_indices),
                  req_pool_idx=req_pool_indices_cpu[i],
@@ -696,8 +799,18 @@
          else:
              batch.reqs = []

+     def filter_out_inflight(self, batch: Batch):
+         # TODO(lsyin): reduce the overhead, make a special version for this
+         if self.current_inflight_req is None:
+             return
+
+         to_remove = batch.reqs.index(self.current_inflight_req)
+         unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
+
+         batch.filter_batch(unfinished_indices)
+
      def flush_cache(self):
-         if len(self.forward_queue) == 0 and (
+         if len(self.waiting_queue) == 0 and (
              self.running_batch is None or len(self.running_batch.reqs) == 0
          ):
              self.tree_cache.reset()
@@ -710,20 +823,20 @@
          else:
              warnings.warn(
                  f"Cache not flushed because there are pending requests. "
-                 f"#queue-req: {len(self.forward_queue)}, "
+                 f"#queue-req: {len(self.waiting_queue)}, "
                  f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
              )

      def abort_request(self, recv_req):
          # Delete requests in the waiting queue
          to_del = None
-         for i, req in enumerate(self.forward_queue):
+         for i, req in enumerate(self.waiting_queue):
              if req.rid == recv_req.rid:
                  to_del = i
                  break

          if to_del is not None:
-             del self.forward_queue[to_del]
+             del self.waiting_queue[to_del]

          # Delete requests in the running batch
          if self.running_batch:
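Taken together, the tp_worker.py hunks implement chunked prefill: a long prompt is prefilled at most chunked_prefill_size tokens per scheduling round, and a partially prefilled request is parked in current_inflight_req between rounds. A minimal standalone sketch of the splitting rule follows; the names split_prefill, remaining, and finished are illustrative, not from the diff, but the arithmetic mirrors get_new_prefill_batch above.

    def split_prefill(input_len: int, prefix_len: int, chunk_size: int):
        """Mirror the truncation arithmetic in get_new_prefill_batch."""
        remaining = input_len - prefix_len        # tokens still to prefill
        extend_len = min(remaining, chunk_size)   # at most one chunk per round
        finished = remaining <= chunk_size        # not truncated -> leaves the inflight state
        return extend_len, finished

    # Example: a 10000-token prompt with chunk_size=4096 prefills in three
    # rounds of 4096 + 4096 + 1808 tokens.
    assert split_prefill(10000, 0, 4096) == (4096, False)
    assert split_prefill(10000, 4096, 4096) == (4096, False)
    assert split_prefill(10000, 8192, 4096) == (1808, True)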
sglang/srt/mem_cache/base_cache.py ADDED
@@ -0,0 +1,43 @@
+ from abc import ABC, abstractmethod
+
+
+ class BasePrefixCache(ABC):
+     """Cache can be indexed by either rid or key."""
+
+     @abstractmethod
+     def reset(self):
+         pass
+
+     @abstractmethod
+     def match_prefix(self, **kwargs):
+         pass
+
+     @abstractmethod
+     def insert(self, **kwargs):
+         pass
+
+     @abstractmethod
+     def cache_req(self, **kwargs):
+         pass
+
+     @abstractmethod
+     def evict(self, num_tokens, evict_callback):
+         pass
+
+     @abstractmethod
+     def inc_lock_ref(self, node):
+         pass
+
+     @abstractmethod
+     def dec_lock_ref(self, node):
+         pass
+
+     @abstractmethod
+     def evictable_size(self):
+         pass
+
+     def total_size(self):
+         raise NotImplementedError
+
+     def pretty_print(self):
+         raise NotImplementedError
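base_cache.py is the new seam between the scheduler and the cache implementations: tp_worker.py only calls BasePrefixCache methods, so RadixCache and ChunkCache are interchangeable behind it. A hedged sketch of a conforming no-op subclass, illustrative only (it assumes sglang 0.2.8 is importable; NoopCache is not part of the package):

    from sglang.srt.mem_cache.base_cache import BasePrefixCache

    class NoopCache(BasePrefixCache):
        """Satisfies the interface but caches nothing."""
        def reset(self): pass
        def match_prefix(self, **kwargs): return [], None
        def insert(self, **kwargs): raise NotImplementedError
        def cache_req(self, **kwargs): return None, None
        def evict(self, num_tokens, evict_callback): pass
        def inc_lock_ref(self, node): return 0
        def dec_lock_ref(self, node): return 0
        def evictable_size(self): return 0

    # The scheduler-style call site passes keyword arguments only:
    prefix, node = NoopCache().match_prefix(rid="req-1", key=[1, 2, 3])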
sglang/srt/mem_cache/chunk_cache.py ADDED
@@ -0,0 +1,60 @@
+ """Cache for chunked prefill, used when RadixCache is disabled."""
+
+ from sglang.srt.mem_cache.base_cache import BasePrefixCache
+
+
+ class ChunkCacheEntry:
+     def __init__(self, rid, value):
+         self.rid = rid
+         self.value = value
+
+
+ class ChunkCache(BasePrefixCache):
+     def __init__(self, req_to_token_pool, token_to_kv_pool):
+         self.disable = True
+         self.req_to_token_pool = req_to_token_pool
+         self.token_to_kv_pool = token_to_kv_pool
+
+         self.reset()
+
+     def reset(self):
+         self.entries = {}
+
+     def match_prefix(self, rid, **kwargs):
+         if rid not in self.entries:
+             return [], None
+
+         entry = self.entries[rid]
+         return entry.value, entry
+
+     def cache_req(
+         self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
+     ):
+         indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
+         if del_in_memory_pool:
+             assert rid in self.entries
+             self.req_to_token_pool.free(req_pool_idx)
+             self.token_to_kv_pool.free(indices)
+             return
+
+         if rid not in self.entries:
+             self.entries[rid] = ChunkCacheEntry(rid, indices)
+
+         entry = self.entries[rid]
+         entry.value = indices
+         return indices, entry
+
+     def insert(self):
+         raise NotImplementedError
+
+     def evict(self, num_tokens, evict_callback):
+         pass
+
+     def inc_lock_ref(self, node):
+         return 0
+
+     def dec_lock_ref(self, node):
+         return 0
+
+     def evictable_size(self):
+         return 0
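ChunkCache deliberately reports disable = True and zero evictable size: it only remembers, per request id, which KV indices a partially prefilled request has already written, so the next chunk can resume from them. It is selected only when chunked prefill is on and the radix cache is off, per the __init__ branch in tp_worker.py above. A small sketch of that predicate; the CLI flags in the comment follow sglang's usual underscore-to-dash naming and are an assumption, not verified against the 0.2.8 CLI:

    # Assumed launch flags (unverified):
    #   python -m sglang.launch_server --model-path <model> \
    #       --chunked-prefill-size 4096 --disable-radix-cache
    def pick_cache(chunked_prefill_size, disable_radix_cache):
        """Mirror of the cache-selection branch in ModelTpServer.__init__."""
        if chunked_prefill_size is not None and disable_radix_cache:
            return "ChunkCache"
        return "RadixCache"  # RadixCache(disable=...) covers the remaining cases

    assert pick_cache(4096, True) == "ChunkCache"
    assert pick_cache(4096, False) == "RadixCache"
    assert pick_cache(None, True) == "RadixCache"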
sglang/srt/mem_cache/flush_cache.py ADDED
@@ -0,0 +1,33 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ """
+ Flush the KV cache.
+
+ Usage:
+ python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:30000
+ """
+
+ import argparse
+
+ import requests
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--url", type=str, default="http://localhost:30000")
+     args = parser.parse_args()
+
+     response = requests.get(args.url + "/flush_cache")
+     assert response.status_code == 200
sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """Memory pool."""

  import logging
@@ -30,7 +45,7 @@ class ReqToTokenPool:

          return select_index

-     def free(self, free_index: int):
+     def free(self, free_index):
          self.mem_state[free_index] = True
          if isinstance(free_index, (int,)):
              self.can_use_mem_size += 1
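Loosening the annotation on free() is not cosmetic: callers may now pass either a single slot index or a whole batch of indices, and the isinstance check keeps the can_use_mem_size counter correct in both cases. A toy stand-in showing the dual contract; ToyPool is illustrative, not the real ReqToTokenPool:

    import numpy as np

    class ToyPool:
        def __init__(self, size):
            self.mem_state = np.zeros(size, dtype=bool)  # True = slot is free
            self.can_use_mem_size = 0

        def free(self, free_index):
            self.mem_state[free_index] = True  # int and array indexing both work
            if isinstance(free_index, (int,)):
                self.can_use_mem_size += 1
            else:
                self.can_use_mem_size += len(free_index)

    pool = ToyPool(8)
    pool.free(3)                    # one request slot
    pool.free(np.array([0, 1, 2]))  # a batch of slots at once
    assert pool.can_use_mem_size == 4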
sglang/srt/{managers/controller → mem_cache}/radix_cache.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """
  The radix tree data structure for managing the KV cache.
  """
@@ -8,6 +23,8 @@ from collections import defaultdict

  import torch

+ from sglang.srt.mem_cache.base_cache import BasePrefixCache
+

  class TreeNode:
      def __init__(self):
@@ -31,7 +48,7 @@ def _key_match(key0, key1):
      return i


- class RadixCache:
+ class RadixCache(BasePrefixCache):
      def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
          self.req_to_token_pool = req_to_token_pool
          self.token_to_kv_pool = token_to_kv_pool
@@ -47,7 +64,7 @@
          self.root_node.lock_ref = 1
          self.evictable_size_ = 0

-     def match_prefix(self, key):
+     def match_prefix(self, key, **kwargs):
          if self.disable:
              return [], self.root_node

@@ -75,6 +92,7 @@
          req_pool_idx,
          del_in_memory_pool=True,
          old_last_node=None,
+         **kwargs,
      ):
          # Insert the request into radix cache
          indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
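The **kwargs additions here are the other half of the BasePrefixCache contract: tp_worker.py now always calls match_prefix(rid=..., key=...) and passes rid to cache_req, RadixCache consumes key and ignores rid, and ChunkCache does the reverse. A toy illustration of that keyword-based dispatch; KeyIndexed and RidIndexed are stand-ins, not the real caches:

    class KeyIndexed:
        def match_prefix(self, key, **kwargs):  # ignores rid, like RadixCache
            return key[:1], None

    class RidIndexed:
        def match_prefix(self, rid, **kwargs):  # ignores key, like ChunkCache
            return [], None

    for cache in (KeyIndexed(), RidIndexed()):
        # The caller passes both; each implementation keeps what it needs.
        prefix, node = cache.match_prefix(rid="req-0", key=[1, 2, 3])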
sglang/srt/mm_utils.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  # Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
  import ast
  import base64
sglang/srt/model_config.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  from typing import Optional

  from transformers import PretrainedConfig
sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """Run the model with cuda graph."""

  import bisect
@@ -14,7 +29,7 @@ from sglang.srt.layers.logits_processor import (
      LogitsMetadata,
      LogitsProcessor,
  )
- from sglang.srt.managers.controller.infer_batch import (
+ from sglang.srt.managers.schedule_batch import (
      Batch,
      ForwardMode,
      InputMetadata,