sglang 0.2.5__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +33 -26
  2. sglang/api.py +9 -1
  3. sglang/bench_latency.py +2 -2
  4. sglang/bench_serving.py +10 -1
  5. sglang/check_env.py +1 -1
  6. sglang/lang/backend/litellm.py +1 -1
  7. sglang/lang/backend/openai.py +1 -1
  8. sglang/lang/backend/runtime_endpoint.py +4 -4
  9. sglang/lang/interpreter.py +24 -9
  10. sglang/lang/ir.py +1 -1
  11. sglang/srt/constrained/__init__.py +15 -0
  12. sglang/srt/constrained/base_cache.py +15 -0
  13. sglang/srt/constrained/fsm_cache.py +36 -1
  14. sglang/srt/constrained/jump_forward.py +15 -0
  15. sglang/srt/conversation.py +26 -0
  16. sglang/srt/hf_transformers_utils.py +18 -1
  17. sglang/srt/layers/context_flashattention_nopad.py +15 -0
  18. sglang/srt/layers/extend_attention.py +15 -0
  19. sglang/srt/layers/fused_moe.py +15 -0
  20. sglang/srt/layers/linear.py +15 -0
  21. sglang/srt/layers/logits_processor.py +109 -72
  22. sglang/srt/layers/quantization/__init__.py +15 -0
  23. sglang/srt/layers/quantization/fp8.py +15 -0
  24. sglang/srt/layers/radix_attention.py +21 -3
  25. sglang/srt/layers/token_attention.py +16 -1
  26. sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
  27. sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
  28. sglang/srt/managers/detokenizer_manager.py +16 -1
  29. sglang/srt/managers/io_struct.py +38 -5
  30. sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
  31. sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +85 -25
  32. sglang/srt/managers/tokenizer_manager.py +99 -57
  33. sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +177 -81
  34. sglang/srt/mem_cache/flush_cache.py +33 -0
  35. sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
  36. sglang/srt/{managers/controller → mem_cache}/radix_cache.py +15 -0
  37. sglang/srt/mm_utils.py +15 -0
  38. sglang/srt/model_config.py +20 -0
  39. sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +42 -18
  40. sglang/srt/{managers/controller → model_executor}/model_runner.py +51 -16
  41. sglang/srt/model_loader/model_loader.py +15 -0
  42. sglang/srt/model_loader/utils.py +16 -1
  43. sglang/srt/models/chatglm.py +16 -1
  44. sglang/srt/models/commandr.py +16 -1
  45. sglang/srt/models/dbrx.py +16 -1
  46. sglang/srt/models/deepseek.py +16 -1
  47. sglang/srt/models/deepseek_v2.py +532 -0
  48. sglang/srt/models/gemma.py +16 -1
  49. sglang/srt/models/gemma2.py +16 -1
  50. sglang/srt/models/gpt_bigcode.py +16 -1
  51. sglang/srt/models/grok.py +16 -1
  52. sglang/srt/models/internlm2.py +16 -1
  53. sglang/srt/models/llama2.py +16 -1
  54. sglang/srt/models/llama_classification.py +19 -4
  55. sglang/srt/models/llava.py +17 -2
  56. sglang/srt/models/llavavid.py +17 -2
  57. sglang/srt/models/minicpm.py +16 -1
  58. sglang/srt/models/mistral.py +15 -0
  59. sglang/srt/models/mixtral.py +16 -1
  60. sglang/srt/models/mixtral_quant.py +16 -1
  61. sglang/srt/models/qwen.py +16 -1
  62. sglang/srt/models/qwen2.py +16 -1
  63. sglang/srt/models/qwen2_moe.py +16 -1
  64. sglang/srt/models/stablelm.py +16 -1
  65. sglang/srt/models/yivl.py +15 -0
  66. sglang/srt/openai_api/adapter.py +545 -160
  67. sglang/srt/openai_api/protocol.py +65 -1
  68. sglang/srt/sampling_params.py +20 -4
  69. sglang/srt/server.py +90 -37
  70. sglang/srt/server_args.py +76 -17
  71. sglang/srt/utils.py +15 -0
  72. sglang/test/test_programs.py +5 -1
  73. sglang/utils.py +22 -0
  74. sglang/version.py +1 -1
  75. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/METADATA +40 -12
  76. sglang-0.2.7.dist-info/RECORD +93 -0
  77. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/WHEEL +1 -1
  78. sglang/srt/flush_cache.py +0 -18
  79. sglang-0.2.5.dist-info/RECORD +0 -92
  80. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.5.dist-info → sglang-0.2.7.dist-info}/top_level.txt +0 -0
sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} RENAMED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """A tensor parallel worker."""

  import logging
@@ -14,23 +29,23 @@ from sglang.global_config import global_config
  from sglang.srt.constrained.fsm_cache import FSMCache
  from sglang.srt.constrained.jump_forward import JumpForwardCache
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
- from sglang.srt.managers.controller.infer_batch import (
-     FINISH_ABORT,
-     BaseFinishReason,
-     Batch,
-     ForwardMode,
-     Req,
- )
- from sglang.srt.managers.controller.model_runner import ModelRunner
- from sglang.srt.managers.controller.radix_cache import RadixCache
- from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
  from sglang.srt.managers.io_struct import (
      AbortReq,
      BatchTokenIDOut,
      FlushCacheReq,
      TokenizedGenerateReqInput,
  )
+ from sglang.srt.managers.policy_scheduler import PolicyScheduler
+ from sglang.srt.managers.schedule_batch import (
+     FINISH_ABORT,
+     BaseFinishReason,
+     Batch,
+     ForwardMode,
+     Req,
+ )
+ from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.model_config import ModelConfig
+ from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
      get_int_token_logit_bias,
@@ -40,7 +55,7 @@ from sglang.srt.utils import (
  )
  from sglang.utils import get_exception_traceback

- logger = logging.getLogger("srt.tp_worker")
+ logger = logging.getLogger(__name__)


  class ModelTpServer:
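
For downstream code that imported from the old controller package, the corresponding 0.2.7 locations (taken directly from the hunk above; assumes sglang 0.2.7 is installed) are:

    # 0.2.5 layout (removed)
    # from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, Req
    # from sglang.srt.managers.controller.model_runner import ModelRunner
    # from sglang.srt.managers.controller.radix_cache import RadixCache
    # from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic

    # 0.2.7 layout (added)
    from sglang.srt.managers.policy_scheduler import PolicyScheduler
    from sglang.srt.managers.schedule_batch import Batch, ForwardMode, Req
    from sglang.srt.mem_cache.radix_cache import RadixCache
    from sglang.srt.model_executor.model_runner import ModelRunner
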
@@ -59,9 +74,13 @@ class ModelTpServer:
          self.tp_rank = tp_rank
          self.tp_size = server_args.tp_size
          self.dp_size = server_args.dp_size
-         self.schedule_heuristic = server_args.schedule_heuristic
+         self.schedule_policy = server_args.schedule_policy
          self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

+         # Chunked prefill
+         self.chunked_prefill_size = server_args.chunked_prefill_size
+         self.current_inflight_req = None
+
          # Init model and tokenizer
          self.model_config = ModelConfig(
              server_args.model_path,
@@ -98,22 +117,26 @@ class ModelTpServer:
              if server_args.max_prefill_tokens is None
              else server_args.max_prefill_tokens
          )
-         self.max_running_requests = (
-             self.max_total_num_tokens // 2
-             if server_args.max_running_requests is None
-             else server_args.max_running_requests
-         )
          self.max_running_requests = min(
-             self.max_running_requests, self.model_runner.req_to_token_pool.size - 1
+             (
+                 self.max_total_num_tokens // 2
+                 if server_args.max_running_requests is None
+                 else server_args.max_running_requests
+             ),
+             self.model_runner.req_to_token_pool.size - 1,
          )
          self.int_token_logit_bias = torch.tensor(
              get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
          )
+         self.max_req_input_len = min(
+             self.model_config.context_len - 1,
+             self.max_total_num_tokens - 1,
+         )
          set_random_seed(server_args.random_seed)

          # Print info
          logger.info(
-             f"[gpu_id={self.gpu_id}] "
+             f"[gpu={self.gpu_id}] "
              f"max_total_num_tokens={self.max_total_num_tokens}, "
              f"max_prefill_tokens={self.max_prefill_tokens}, "
              f"max_running_requests={self.max_running_requests}, "
@@ -127,8 +150,8 @@ class ModelTpServer:
              disable=server_args.disable_radix_cache,
          )
          self.tree_cache_metrics = {"total": 0, "hit": 0}
-         self.scheduler = ScheduleHeuristic(
-             self.schedule_heuristic,
+         self.scheduler = PolicyScheduler(
+             self.schedule_policy,
              self.max_running_requests,
              self.max_prefill_tokens,
              self.max_total_num_tokens,
@@ -138,7 +161,7 @@ class ModelTpServer:
          self.token_to_kv_pool = self.model_runner.token_to_kv_pool

          # Init running status
-         self.forward_queue: List[Req] = []
+         self.waiting_queue: List[Req] = []
          self.running_batch: Batch = None
          self.out_pyobjs = []
          self.decode_forward_ct = 0
@@ -201,6 +224,7 @@ class ModelTpServer:
              # Run a new prefill batch
              self.forward_prefill_batch(new_batch)
              self.cache_filled_batch(new_batch)
+             self.filter_out_inflight(new_batch)

              if not new_batch.is_empty():
                  if self.running_batch is None:
@@ -237,12 +261,12 @@ class ModelTpServer:
                  self.num_generated_tokens = 0
                  self.last_stats_tic = time.time()
                  logger.info(
-                     f"[gpu_id={self.gpu_id}] Decode batch. "
+                     f"[gpu={self.gpu_id}] Decode batch. "
                      f"#running-req: {len(self.running_batch.reqs)}, "
                      f"#token: {num_used}, "
                      f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                      f"gen throughput (token/s): {throughput:.2f}, "
-                     f"#queue-req: {len(self.forward_queue)}"
+                     f"#queue-req: {len(self.waiting_queue)}"
                  )

      def check_memory(self):
@@ -295,21 +319,24 @@ class ModelTpServer:
                  )

          # Truncate prompts that are too long
-         req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
+         if len(req.origin_input_ids) >= self.max_req_input_len:
+             logger.warn(
+                 "Request length is longer than the KV cache pool size or "
+                 "the max context length. Truncated!!!"
+             )
+             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
          req.sampling_params.max_new_tokens = min(
-             req.sampling_params.max_new_tokens,
-             self.model_config.context_len - 1 - len(req.origin_input_ids),
-             self.max_total_num_tokens - 128 - len(req.origin_input_ids),
+             (
+                 req.sampling_params.max_new_tokens
+                 if req.sampling_params.max_new_tokens is not None
+                 else 1 << 30
+             ),
+             self.max_req_input_len - 1 - len(req.origin_input_ids),
          )
-         if req.sampling_params.max_new_tokens < 0:
-             req.origin_input_ids = req.origin_input_ids[
-                 : self.max_total_num_tokens - 128
-             ]
-             logger.error("Request longer than memory pool size, truncated!!!")
-
-         self.forward_queue.append(req)
+         self.waiting_queue.append(req)

      def get_new_prefill_batch(self) -> Optional[Batch]:
+         # TODO(lsyin): organize this function
          running_bs = (
              len(self.running_batch.reqs) if self.running_batch is not None else 0
          )
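
A minimal sketch of the new clamping rule above, with hypothetical sizes (the constants below are illustrative, not defaults): max_req_input_len is min(context_len - 1, max_total_num_tokens - 1), over-long prompts are cut to that length, and an unset max_new_tokens is treated as effectively unbounded (1 << 30) before being clamped.

    # Illustrative only: reproduces the truncation/clamping above in isolation.
    context_len = 4096
    max_total_num_tokens = 8192
    max_req_input_len = min(context_len - 1, max_total_num_tokens - 1)  # 4095

    origin_input_ids = list(range(3000))   # hypothetical 3000-token prompt
    requested_max_new_tokens = None        # client did not set max_new_tokens

    if len(origin_input_ids) >= max_req_input_len:
        origin_input_ids = origin_input_ids[:max_req_input_len]

    max_new_tokens = min(
        requested_max_new_tokens if requested_max_new_tokens is not None else 1 << 30,
        max_req_input_len - 1 - len(origin_input_ids),
    )
    print(max_new_tokens)  # 1094
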
@@ -317,7 +344,7 @@ class ModelTpServer:
              return

          # Compute matched prefix length
-         for req in self.forward_queue:
+         for req in self.waiting_queue:
              req.input_ids = req.origin_input_ids + req.output_ids
              prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
              if req.return_logprob:
@@ -327,7 +354,7 @@ class ModelTpServer:
              req.last_node = last_node

          # Get priority queue
-         self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
+         self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)

          # Add requests if there is available space
          can_run_list = []
@@ -346,7 +373,33 @@ class ModelTpServer:
                  ]
              )

-         for req in self.forward_queue:
+         # Handle the current inflight request
+         take_inflight = 0
+         if self.current_inflight_req:
+             take_inflight = 1
+             r = self.current_inflight_req
+             r.input_ids = r.origin_input_ids + r.output_ids
+             truncated = (
+                 len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
+             )
+             r.extend_input_len = min(
+                 len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
+             )
+             r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
+             can_run_list.append(r)
+
+             if not truncated:
+                 # Finish inflight
+                 self.current_inflight_req = None
+                 new_batch_total_tokens += (
+                     r.extend_input_len + r.sampling_params.max_new_tokens
+                 )
+                 new_batch_input_tokens += r.extend_input_len
+             else:
+                 new_batch_total_tokens += r.extend_input_len
+                 new_batch_input_tokens += r.extend_input_len
+
+         for req in self.waiting_queue:
              if req.return_logprob and req.normalized_prompt_logprob is None:
                  # Need at least two tokens to compute normalized logprob
                  if req.extend_input_len < 2:
@@ -388,11 +441,39 @@ class ModelTpServer:
                      break
                  else:
                      # Add this request to the running batch
-                     can_run_list.append(req)
-                     new_batch_total_tokens += (
-                         req.extend_input_len + req.sampling_params.max_new_tokens
-                     )
-                     new_batch_input_tokens += req.extend_input_len
+                     if (
+                         self.chunked_prefill_size is None
+                         or (
+                             new_batch_input_tokens + req.extend_input_len
+                             <= self.chunked_prefill_size
+                         )
+                         or (
+                             req.return_logprob and req.normalized_prompt_logprob is None
+                         )
+                     ):
+                         can_run_list.append(req)
+                         new_batch_total_tokens += (
+                             req.extend_input_len + req.sampling_params.max_new_tokens
+                         )
+                         new_batch_input_tokens += req.extend_input_len
+                     else:
+                         trunc_len = self.chunked_prefill_size - new_batch_input_tokens
+
+                         if trunc_len <= 0:
+                             # Undo locking
+                             delta = self.tree_cache.dec_lock_ref(req.last_node)
+                             available_size += delta
+                             break
+
+                         req.extend_input_len = trunc_len
+                         req.input_ids = req.input_ids[
+                             : len(req.prefix_indices) + req.extend_input_len
+                         ]
+                         can_run_list.append(req)
+                         self.current_inflight_req = req
+                         new_batch_input_tokens += req.extend_input_len
+                         new_batch_total_tokens += req.extend_input_len
+                         break
              else:
                  break

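
The two hunks above are the core of the new chunked prefill path: when server_args.chunked_prefill_size is set, a long prompt is admitted in fixed-size slices, and a request that still has un-prefilled tokens is carried over as current_inflight_req into the next scheduling round. A minimal standalone sketch of that slicing (a hypothetical helper for illustration, not part of the package):

    def chunked_prefill_rounds(prompt_len: int, prefix_len: int, chunk_size: int):
        """Yield (extend_len, finished) per scheduling round, mimicking the logic above."""
        done = prefix_len
        while done < prompt_len:
            extend_len = min(prompt_len - done, chunk_size)  # cf. r.extend_input_len
            done += extend_len
            yield extend_len, done >= prompt_len

    # A 10k-token prompt with a 2k-token cached prefix and chunked_prefill_size=4096:
    print(list(chunked_prefill_rounds(10_000, 2_000, 4_096)))
    # [(4096, False), (3904, True)]
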
@@ -413,13 +494,13 @@ class ModelTpServer:
                  self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
              )
              logger.info(
-                 f"[gpu_id={self.gpu_id}] Prefill batch. "
+                 f"[gpu={self.gpu_id}] Prefill batch. "
                  f"#new-seq: {len(can_run_list)}, "
                  f"#new-token: {new_batch_input_tokens}, "
                  f"#cached-token: {hit_tokens}, "
                  f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                  f"#running-req: {running_bs}, "
-                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
+                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
              )

          # Return the new batch
@@ -429,7 +510,7 @@ class ModelTpServer:
              self.token_to_kv_pool,
              self.tree_cache,
          )
-         self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
+         self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
          return new_batch

      def forward_prefill_batch(self, batch: Batch):
@@ -449,7 +530,7 @@ class ModelTpServer:
                      torch.arange(len(next_token_ids), device=next_token_ids.device),
                      next_token_ids,
                  ].tolist()
-                 output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+                 output.input_token_logprobs = output.input_token_logprobs.tolist()
                  output.normalized_prompt_logprobs = (
                      output.normalized_prompt_logprobs.tolist()
                  )
@@ -461,9 +542,10 @@ class ModelTpServer:
          # Check finish conditions
          pt = 0
          for i, req in enumerate(batch.reqs):
-             req.completion_tokens_wo_jump_forward += 1
-             req.output_ids.append(next_token_ids[i])
-             req.check_finished()
+             if req is not self.current_inflight_req:
+                 req.completion_tokens_wo_jump_forward += 1
+                 req.output_ids.append(next_token_ids[i])
+                 req.check_finished()

              if req.return_logprob:
                  self.add_logprob_return_values(i, req, pt, next_token_ids, output)
@@ -475,24 +557,24 @@ class ModelTpServer:
          if req.normalized_prompt_logprob is None:
              req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

-         if req.prefill_token_logprobs is None:
+         if req.input_token_logprobs is None:
              # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-             req.prefill_token_logprobs = list(
+             req.input_token_logprobs = list(
                  zip(
-                     output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                     output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
                      req.input_ids[-req.extend_input_len + 1 :],
                  )
              )
              if req.logprob_start_len == 0:
-                 req.prefill_token_logprobs = [
+                 req.input_token_logprobs = [
                      (None, req.input_ids[0])
-                 ] + req.prefill_token_logprobs
+                 ] + req.input_token_logprobs

          if req.last_update_decode_tokens != 0:
-             req.decode_token_logprobs.extend(
+             req.output_token_logprobs.extend(
                  list(
                      zip(
-                         output.prefill_token_logprobs[
+                         output.input_token_logprobs[
                              pt
                              + req.extend_input_len
                              - req.last_update_decode_tokens : pt
@@ -504,27 +586,27 @@ class ModelTpServer:
                  )
              )

-         req.decode_token_logprobs.append(
+         req.output_token_logprobs.append(
              (output.next_token_logprobs[i], next_token_ids[i])
          )

          if req.top_logprobs_num > 0:
-             if req.prefill_top_logprobs is None:
-                 req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+             if req.input_top_logprobs is None:
+                 req.input_top_logprobs = output.input_top_logprobs[i]
                  if req.logprob_start_len == 0:
-                     req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+                     req.input_top_logprobs = [None] + req.input_top_logprobs

              if req.last_update_decode_tokens != 0:
-                 req.decode_top_logprobs.extend(
-                     output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                 req.output_top_logprobs.extend(
+                     output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
                  )
-             req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+             req.output_top_logprobs.append(output.output_top_logprobs[i])

      def cache_filled_batch(self, batch: Batch):
          req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
          for i, req in enumerate(batch.reqs):
              new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                 token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
+                 token_ids=tuple(req.input_ids),
                  last_uncached_pos=len(req.prefix_indices),
                  req_pool_idx=req_pool_indices_cpu[i],
                  del_in_memory_pool=False,
@@ -532,6 +614,10 @@ class ModelTpServer:
              )
              req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

+             if req is self.current_inflight_req:
+                 # inflight request would get a new req idx
+                 self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
+
      def forward_decode_batch(self, batch: Batch):
          # Check if decode out of memory
          if not batch.check_decode_mem():
@@ -545,7 +631,7 @@ class ModelTpServer:
                  f"#retracted_reqs: {len(retracted_reqs)}, "
                  f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
              )
-             self.forward_queue.extend(retracted_reqs)
+             self.waiting_queue.extend(retracted_reqs)
          else:
              self.new_token_ratio = max(
                  self.new_token_ratio - self.new_token_ratio_decay,
@@ -555,7 +641,7 @@ class ModelTpServer:
          if not self.disable_regex_jump_forward:
              # Check for jump-forward
              jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-             self.forward_queue.extend(jump_forward_reqs)
+             self.waiting_queue.extend(jump_forward_reqs)
              if batch.is_empty():
                  return

@@ -583,11 +669,11 @@ class ModelTpServer:
              req.check_finished()

              if req.return_logprob:
-                 req.decode_token_logprobs.append(
+                 req.output_token_logprobs.append(
                      (next_token_logprobs[i], next_token_id)
                  )
                  if req.top_logprobs_num > 0:
-                     req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+                     req.output_top_logprobs.append(output.output_top_logprobs[i])

          self.handle_finished_requests(batch)

@@ -639,16 +725,16 @@ class ModelTpServer:
                  }
                  if req.return_logprob:
                      (
-                         meta_info["prefill_token_logprobs"],
-                         meta_info["decode_token_logprobs"],
-                         meta_info["prefill_top_logprobs"],
-                         meta_info["decode_top_logprobs"],
+                         meta_info["input_token_logprobs"],
+                         meta_info["output_token_logprobs"],
+                         meta_info["input_top_logprobs"],
+                         meta_info["output_top_logprobs"],
                          meta_info["normalized_prompt_logprob"],
                      ) = (
-                         req.prefill_token_logprobs,
-                         req.decode_token_logprobs,
-                         req.prefill_top_logprobs,
-                         req.decode_top_logprobs,
+                         req.input_token_logprobs,
+                         req.output_token_logprobs,
+                         req.input_top_logprobs,
+                         req.output_top_logprobs,
                          req.normalized_prompt_logprob,
                      )
                  output_meta_info.append(meta_info)
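
These logprob field renames are visible to API clients: response metadata that previously used prefill_token_logprobs / decode_token_logprobs / prefill_top_logprobs / decode_top_logprobs now uses the input_/output_ names. A hedged client-side sketch (the key names come from this hunk; the /generate endpoint, payload shape, and a server on localhost:30000 are assumptions for illustration):

    import requests

    resp = requests.post(
        "http://localhost:30000/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {"max_new_tokens": 8},
            "return_logprob": True,
        },
    ).json()

    meta = resp["meta_info"]
    # 0.2.5 keys were prefill_token_logprobs / decode_token_logprobs, etc.
    print(meta["input_token_logprobs"])   # (logprob, token_id) pairs for the prompt
    print(meta["output_token_logprobs"])  # (logprob, token_id) pairs for the generation
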
@@ -690,8 +776,18 @@ class ModelTpServer:
          else:
              batch.reqs = []

+     def filter_out_inflight(self, batch: Batch):
+         # TODO(lsyin): reduce the overhead, make a special version for this
+         if self.current_inflight_req is None:
+             return
+
+         to_remove = batch.reqs.index(self.current_inflight_req)
+         unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
+
+         batch.filter_batch(unfinished_indices)
+
      def flush_cache(self):
-         if len(self.forward_queue) == 0 and (
+         if len(self.waiting_queue) == 0 and (
              self.running_batch is None or len(self.running_batch.reqs) == 0
          ):
              self.tree_cache.reset()
@@ -704,20 +800,20 @@ class ModelTpServer:
          else:
              warnings.warn(
                  f"Cache not flushed because there are pending requests. "
-                 f"#queue-req: {len(self.forward_queue)}, "
+                 f"#queue-req: {len(self.waiting_queue)}, "
                  f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
              )

      def abort_request(self, recv_req):
          # Delete requests in the waiting queue
          to_del = None
-         for i, req in enumerate(self.forward_queue):
+         for i, req in enumerate(self.waiting_queue):
              if req.rid == recv_req.rid:
                  to_del = i
                  break

          if to_del is not None:
-             del self.forward_queue[to_del]
+             del self.waiting_queue[to_del]

          # Delete requests in the running batch
          if self.running_batch:
sglang/srt/mem_cache/flush_cache.py ADDED
@@ -0,0 +1,33 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ """
+ Flush the KV cache.
+
+ Usage:
+ python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:30000
+ """
+
+ import argparse
+
+ import requests
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--url", type=str, default="http://localhost:30000")
+     args = parser.parse_args()
+
+     response = requests.get(args.url + "/flush_cache")
+     assert response.status_code == 200
sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} RENAMED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """Memory pool."""

  import logging
@@ -30,7 +45,7 @@ class ReqToTokenPool:

          return select_index

-     def free(self, free_index: int):
+     def free(self, free_index):
          self.mem_state[free_index] = True
          if isinstance(free_index, (int,)):
              self.can_use_mem_size += 1
sglang/srt/{managers/controller → mem_cache}/radix_cache.py RENAMED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """
  The radix tree data structure for managing the KV cache.
  """
sglang/srt/mm_utils.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  # Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
  import ast
  import base64
sglang/srt/model_config.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  from typing import Optional

  from transformers import PretrainedConfig
@@ -36,6 +51,11 @@ class ModelConfig:
              "head_dim",
              self.hf_config.hidden_size // self.hf_config.num_attention_heads,
          )
+
+         # FIXME: temporary special judge for deepseek v2 MLA architecture
+         if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
+             self.head_dim = 256
+
          self.num_attention_heads = self.hf_config.num_attention_heads
          self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
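
The model_config.py change above hard-codes head_dim = 256 whenever the checkpoint's architectures list contains DeepseekV2ForCausalLM, because DeepSeek-V2's MLA attention does not follow the usual hidden_size // num_attention_heads rule that the surrounding default assumes. A hedged sketch of the resulting selection order (the stand-in config values below are placeholders, not real checkpoint numbers):

    # Illustrative only: mirrors the head_dim selection with a stand-in config object.
    class FakeHFConfig:
        architectures = ["DeepseekV2ForCausalLM"]
        hidden_size = 5120
        num_attention_heads = 128
        # no explicit head_dim attribute, so the fallback division would apply

    hf_config = FakeHFConfig()
    head_dim = getattr(
        hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads
    )
    if "DeepseekV2ForCausalLM" in hf_config.architectures:
        head_dim = 256  # special-cased for the MLA KV layout, per the diff above
    print(head_dim)  # 256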