sglang 0.2.6__py3-none-any.whl → 0.2.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/__init__.py +33 -26
  2. sglang/api.py +9 -1
  3. sglang/bench_latency.py +2 -2
  4. sglang/bench_serving.py +10 -1
  5. sglang/check_env.py +1 -1
  6. sglang/lang/backend/litellm.py +1 -1
  7. sglang/lang/backend/openai.py +1 -1
  8. sglang/lang/interpreter.py +20 -5
  9. sglang/lang/ir.py +1 -1
  10. sglang/srt/constrained/__init__.py +15 -0
  11. sglang/srt/constrained/base_cache.py +15 -0
  12. sglang/srt/constrained/fsm_cache.py +15 -0
  13. sglang/srt/constrained/jump_forward.py +15 -0
  14. sglang/srt/conversation.py +26 -0
  15. sglang/srt/hf_transformers_utils.py +15 -0
  16. sglang/srt/layers/context_flashattention_nopad.py +15 -0
  17. sglang/srt/layers/extend_attention.py +15 -0
  18. sglang/srt/layers/fused_moe.py +15 -0
  19. sglang/srt/layers/linear.py +15 -0
  20. sglang/srt/layers/logits_processor.py +41 -13
  21. sglang/srt/layers/quantization/__init__.py +15 -0
  22. sglang/srt/layers/quantization/fp8.py +15 -0
  23. sglang/srt/layers/radix_attention.py +17 -2
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/{controller/manager_multi.py → controller_multi.py} +17 -2
  26. sglang/srt/managers/{controller/manager_single.py → controller_single.py} +17 -2
  27. sglang/srt/managers/detokenizer_manager.py +16 -1
  28. sglang/srt/managers/io_struct.py +36 -3
  29. sglang/srt/managers/{controller/schedule_heuristic.py → policy_scheduler.py} +37 -22
  30. sglang/srt/managers/{controller/infer_batch.py → schedule_batch.py} +31 -12
  31. sglang/srt/managers/tokenizer_manager.py +39 -16
  32. sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} +130 -40
  33. sglang/srt/mem_cache/flush_cache.py +33 -0
  34. sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} +16 -1
  35. sglang/srt/{managers/controller → mem_cache}/radix_cache.py +15 -0
  36. sglang/srt/mm_utils.py +15 -0
  37. sglang/srt/model_config.py +15 -0
  38. sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py +16 -1
  39. sglang/srt/{managers/controller → model_executor}/model_runner.py +32 -12
  40. sglang/srt/model_loader/model_loader.py +15 -0
  41. sglang/srt/model_loader/utils.py +16 -1
  42. sglang/srt/models/chatglm.py +16 -1
  43. sglang/srt/models/commandr.py +16 -1
  44. sglang/srt/models/dbrx.py +16 -1
  45. sglang/srt/models/deepseek.py +16 -1
  46. sglang/srt/models/deepseek_v2.py +16 -1
  47. sglang/srt/models/gemma.py +16 -1
  48. sglang/srt/models/gemma2.py +16 -1
  49. sglang/srt/models/gpt_bigcode.py +16 -1
  50. sglang/srt/models/grok.py +16 -1
  51. sglang/srt/models/internlm2.py +16 -1
  52. sglang/srt/models/llama2.py +16 -1
  53. sglang/srt/models/llama_classification.py +16 -1
  54. sglang/srt/models/llava.py +17 -2
  55. sglang/srt/models/llavavid.py +17 -2
  56. sglang/srt/models/minicpm.py +16 -1
  57. sglang/srt/models/mistral.py +15 -0
  58. sglang/srt/models/mixtral.py +16 -1
  59. sglang/srt/models/mixtral_quant.py +16 -1
  60. sglang/srt/models/qwen.py +16 -1
  61. sglang/srt/models/qwen2.py +16 -1
  62. sglang/srt/models/qwen2_moe.py +16 -1
  63. sglang/srt/models/stablelm.py +16 -1
  64. sglang/srt/models/yivl.py +15 -0
  65. sglang/srt/openai_api/adapter.py +520 -135
  66. sglang/srt/openai_api/protocol.py +64 -0
  67. sglang/srt/sampling_params.py +15 -0
  68. sglang/srt/server.py +89 -23
  69. sglang/srt/server_args.py +49 -11
  70. sglang/srt/utils.py +15 -0
  71. sglang/utils.py +22 -0
  72. sglang/version.py +1 -1
  73. {sglang-0.2.6.dist-info → sglang-0.2.7.dist-info}/METADATA +32 -6
  74. sglang-0.2.7.dist-info/RECORD +93 -0
  75. {sglang-0.2.6.dist-info → sglang-0.2.7.dist-info}/WHEEL +1 -1
  76. sglang/srt/flush_cache.py +0 -18
  77. sglang-0.2.6.dist-info/RECORD +0 -93
  78. {sglang-0.2.6.dist-info → sglang-0.2.7.dist-info}/LICENSE +0 -0
  79. {sglang-0.2.6.dist-info → sglang-0.2.7.dist-info}/top_level.txt +0 -0
sglang/srt/managers/{controller/tp_worker.py → tp_worker.py} RENAMED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """A tensor parallel worker."""

  import logging
@@ -14,23 +29,23 @@ from sglang.global_config import global_config
  from sglang.srt.constrained.fsm_cache import FSMCache
  from sglang.srt.constrained.jump_forward import JumpForwardCache
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
- from sglang.srt.managers.controller.infer_batch import (
-     FINISH_ABORT,
-     BaseFinishReason,
-     Batch,
-     ForwardMode,
-     Req,
- )
- from sglang.srt.managers.controller.model_runner import ModelRunner
- from sglang.srt.managers.controller.radix_cache import RadixCache
- from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
  from sglang.srt.managers.io_struct import (
      AbortReq,
      BatchTokenIDOut,
      FlushCacheReq,
      TokenizedGenerateReqInput,
  )
+ from sglang.srt.managers.policy_scheduler import PolicyScheduler
+ from sglang.srt.managers.schedule_batch import (
+     FINISH_ABORT,
+     BaseFinishReason,
+     Batch,
+     ForwardMode,
+     Req,
+ )
+ from sglang.srt.mem_cache.radix_cache import RadixCache
  from sglang.srt.model_config import ModelConfig
+ from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
      get_int_token_logit_bias,
@@ -40,7 +55,7 @@ from sglang.srt.utils import (
  )
  from sglang.utils import get_exception_traceback

- logger = logging.getLogger("srt.tp_worker")
+ logger = logging.getLogger(__name__)


  class ModelTpServer:
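
The switch from the fixed "srt.tp_worker" name to logging.getLogger(__name__) puts each module's logger under its dotted import path, so one parent logger can control all of them. A minimal sketch of what that enables (this is not sglang's own logging setup):

import logging

logging.basicConfig(level=logging.INFO)
# One knob now controls every sglang.srt.* logger at once.
logging.getLogger("sglang").setLevel(logging.DEBUG)

# This is what __name__ resolves to inside the tp_worker module.
logger = logging.getLogger("sglang.srt.managers.tp_worker")
logger.debug("emitted because the parent 'sglang' logger is at DEBUG")
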
@@ -59,9 +74,13 @@
          self.tp_rank = tp_rank
          self.tp_size = server_args.tp_size
          self.dp_size = server_args.dp_size
-         self.schedule_heuristic = server_args.schedule_heuristic
+         self.schedule_policy = server_args.schedule_policy
          self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

+         # Chunked prefill
+         self.chunked_prefill_size = server_args.chunked_prefill_size
+         self.current_inflight_req = None
+
          # Init model and tokenizer
          self.model_config = ModelConfig(
              server_args.model_path,
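
The two new fields drive chunked prefill: chunked_prefill_size caps how many prompt tokens one prefill batch may process, and current_inflight_req holds the single request whose prompt is still only partially prefilled. A minimal sketch of the chunking arithmetic, with hypothetical lengths (not sglang code):

def prefill_chunks(prompt_len: int, prefix_len: int, chunk_size: int):
    """Yield (start, end) token ranges, one per prefill step."""
    pos = prefix_len  # tokens already covered by the radix-cache prefix
    while pos < prompt_len:
        end = min(pos + chunk_size, prompt_len)
        yield pos, end
        pos = end

# A 10000-token prompt with no cached prefix and chunk size 4096 takes
# three prefill steps: (0, 4096), (4096, 8192), (8192, 10000).
print(list(prefill_chunks(10000, 0, 4096)))
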
@@ -117,7 +136,7 @@

          # Print info
          logger.info(
-             f"[gpu_id={self.gpu_id}] "
+             f"[gpu={self.gpu_id}] "
              f"max_total_num_tokens={self.max_total_num_tokens}, "
              f"max_prefill_tokens={self.max_prefill_tokens}, "
              f"max_running_requests={self.max_running_requests}, "
@@ -131,8 +150,8 @@
              disable=server_args.disable_radix_cache,
          )
          self.tree_cache_metrics = {"total": 0, "hit": 0}
-         self.scheduler = ScheduleHeuristic(
-             self.schedule_heuristic,
+         self.scheduler = PolicyScheduler(
+             self.schedule_policy,
              self.max_running_requests,
              self.max_prefill_tokens,
              self.max_total_num_tokens,
@@ -142,7 +161,7 @@
          self.token_to_kv_pool = self.model_runner.token_to_kv_pool

          # Init running status
-         self.forward_queue: List[Req] = []
+         self.waiting_queue: List[Req] = []
          self.running_batch: Batch = None
          self.out_pyobjs = []
          self.decode_forward_ct = 0
@@ -205,6 +224,7 @@
              # Run a new prefill batch
              self.forward_prefill_batch(new_batch)
              self.cache_filled_batch(new_batch)
+             self.filter_out_inflight(new_batch)

              if not new_batch.is_empty():
                  if self.running_batch is None:
@@ -241,12 +261,12 @@
              self.num_generated_tokens = 0
              self.last_stats_tic = time.time()
              logger.info(
-                 f"[gpu_id={self.gpu_id}] Decode batch. "
+                 f"[gpu={self.gpu_id}] Decode batch. "
                  f"#running-req: {len(self.running_batch.reqs)}, "
                  f"#token: {num_used}, "
                  f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                  f"gen throughput (token/s): {throughput:.2f}, "
-                 f"#queue-req: {len(self.forward_queue)}"
+                 f"#queue-req: {len(self.waiting_queue)}"
              )

      def check_memory(self):
@@ -313,9 +333,10 @@
              ),
              self.max_req_input_len - 1 - len(req.origin_input_ids),
          )
-         self.forward_queue.append(req)
+         self.waiting_queue.append(req)

      def get_new_prefill_batch(self) -> Optional[Batch]:
+         # TODO(lsyin): organize this function
          running_bs = (
              len(self.running_batch.reqs) if self.running_batch is not None else 0
          )
@@ -323,7 +344,7 @@
              return

          # Compute matched prefix length
-         for req in self.forward_queue:
+         for req in self.waiting_queue:
              req.input_ids = req.origin_input_ids + req.output_ids
              prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
              if req.return_logprob:
@@ -333,7 +354,7 @@
              req.last_node = last_node

          # Get priority queue
-         self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
+         self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)

          # Add requests if there is available space
          can_run_list = []
@@ -352,7 +373,33 @@
                  ]
              )

-         for req in self.forward_queue:
+         # Handle the current inflight request
+         take_inflight = 0
+         if self.current_inflight_req:
+             take_inflight = 1
+             r = self.current_inflight_req
+             r.input_ids = r.origin_input_ids + r.output_ids
+             truncated = (
+                 len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
+             )
+             r.extend_input_len = min(
+                 len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
+             )
+             r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
+             can_run_list.append(r)
+
+             if not truncated:
+                 # Finish inflight
+                 self.current_inflight_req = None
+                 new_batch_total_tokens += (
+                     r.extend_input_len + r.sampling_params.max_new_tokens
+                 )
+                 new_batch_input_tokens += r.extend_input_len
+             else:
+                 new_batch_total_tokens += r.extend_input_len
+                 new_batch_input_tokens += r.extend_input_len
+
+         for req in self.waiting_queue:
              if req.return_logprob and req.normalized_prompt_logprob is None:
                  # Need at least two tokens to compute normalized logprob
                  if req.extend_input_len < 2:
@@ -394,11 +441,39 @@
                      break
                  else:
                      # Add this request to the running batch
-                     can_run_list.append(req)
-                     new_batch_total_tokens += (
-                         req.extend_input_len + req.sampling_params.max_new_tokens
-                     )
-                     new_batch_input_tokens += req.extend_input_len
+                     if (
+                         self.chunked_prefill_size is None
+                         or (
+                             new_batch_input_tokens + req.extend_input_len
+                             <= self.chunked_prefill_size
+                         )
+                         or (
+                             req.return_logprob and req.normalized_prompt_logprob is None
+                         )
+                     ):
+                         can_run_list.append(req)
+                         new_batch_total_tokens += (
+                             req.extend_input_len + req.sampling_params.max_new_tokens
+                         )
+                         new_batch_input_tokens += req.extend_input_len
+                     else:
+                         trunc_len = self.chunked_prefill_size - new_batch_input_tokens
+
+                         if trunc_len <= 0:
+                             # Undo locking
+                             delta = self.tree_cache.dec_lock_ref(req.last_node)
+                             available_size += delta
+                             break
+
+                         req.extend_input_len = trunc_len
+                         req.input_ids = req.input_ids[
+                             : len(req.prefix_indices) + req.extend_input_len
+                         ]
+                         can_run_list.append(req)
+                         self.current_inflight_req = req
+                         new_batch_input_tokens += req.extend_input_len
+                         new_batch_total_tokens += req.extend_input_len
+                         break
              else:
                  break

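The new else branch splits a request when admitting its whole prompt would exceed the per-batch input budget: trunc_len = chunked_prefill_size - new_batch_input_tokens tokens are admitted now, and the remainder carries over as the inflight request. The same rule, restated with hypothetical numbers:

chunked_prefill_size = 8192    # per-batch budget for new input tokens
new_batch_input_tokens = 6000  # tokens already admitted to this batch
extend_input_len = 5000        # uncached prompt tokens of the next request

if new_batch_input_tokens + extend_input_len <= chunked_prefill_size:
    admitted = extend_input_len  # the request fits whole
else:
    trunc_len = chunked_prefill_size - new_batch_input_tokens
    admitted = trunc_len         # 2192 now; the remaining 2808 tokens
                                 # wait as the inflight request
print(admitted)  # 2192
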
@@ -419,13 +494,13 @@
                  self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
              )
              logger.info(
-                 f"[gpu_id={self.gpu_id}] Prefill batch. "
+                 f"[gpu={self.gpu_id}] Prefill batch. "
                  f"#new-seq: {len(can_run_list)}, "
                  f"#new-token: {new_batch_input_tokens}, "
                  f"#cached-token: {hit_tokens}, "
                  f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                  f"#running-req: {running_bs}, "
-                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
+                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
              )

          # Return the new batch
@@ -435,7 +510,7 @@
              self.token_to_kv_pool,
              self.tree_cache,
          )
-         self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
+         self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
          return new_batch

      def forward_prefill_batch(self, batch: Batch):
@@ -467,9 +542,10 @@
          # Check finish conditions
          pt = 0
          for i, req in enumerate(batch.reqs):
-             req.completion_tokens_wo_jump_forward += 1
-             req.output_ids.append(next_token_ids[i])
-             req.check_finished()
+             if req is not self.current_inflight_req:
+                 req.completion_tokens_wo_jump_forward += 1
+                 req.output_ids.append(next_token_ids[i])
+                 req.check_finished()

              if req.return_logprob:
                  self.add_logprob_return_values(i, req, pt, next_token_ids, output)
@@ -530,7 +606,7 @@
          req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
          for i, req in enumerate(batch.reqs):
              new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                 token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
+                 token_ids=tuple(req.input_ids),
                  last_uncached_pos=len(req.prefix_indices),
                  req_pool_idx=req_pool_indices_cpu[i],
                  del_in_memory_pool=False,
@@ -538,6 +614,10 @@
              )
              req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

+             if req is self.current_inflight_req:
+                 # inflight request would get a new req idx
+                 self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
+
      def forward_decode_batch(self, batch: Batch):
          # Check if decode out of memory
          if not batch.check_decode_mem():
@@ -551,7 +631,7 @@
                  f"#retracted_reqs: {len(retracted_reqs)}, "
                  f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
              )
-             self.forward_queue.extend(retracted_reqs)
+             self.waiting_queue.extend(retracted_reqs)
          else:
              self.new_token_ratio = max(
                  self.new_token_ratio - self.new_token_ratio_decay,
@@ -561,7 +641,7 @@
          if not self.disable_regex_jump_forward:
              # Check for jump-forward
              jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-             self.forward_queue.extend(jump_forward_reqs)
+             self.waiting_queue.extend(jump_forward_reqs)
              if batch.is_empty():
                  return

@@ -696,8 +776,18 @@
          else:
              batch.reqs = []

+     def filter_out_inflight(self, batch: Batch):
+         # TODO(lsyin): reduce the overhead, make a special version for this
+         if self.current_inflight_req is None:
+             return
+
+         to_remove = batch.reqs.index(self.current_inflight_req)
+         unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
+
+         batch.filter_batch(unfinished_indices)
+
      def flush_cache(self):
-         if len(self.forward_queue) == 0 and (
+         if len(self.waiting_queue) == 0 and (
              self.running_batch is None or len(self.running_batch.reqs) == 0
          ):
              self.tree_cache.reset()
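
filter_out_inflight runs right after prefill and caching: the partially prefilled request has no sampled token yet, so it is dropped from the batch before decoding while staying tracked as current_inflight_req. The index filtering reduces to this (a toy restatement, not sglang code):

reqs = ["req_a", "req_b", "inflight"]  # stand-ins for batch.reqs
current_inflight_req = "inflight"

to_remove = reqs.index(current_inflight_req)
keep = [i for i in range(len(reqs)) if i != to_remove]
print([reqs[i] for i in keep])  # ['req_a', 'req_b'] continue to decode
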
@@ -710,20 +800,20 @@
          else:
              warnings.warn(
                  f"Cache not flushed because there are pending requests. "
-                 f"#queue-req: {len(self.forward_queue)}, "
+                 f"#queue-req: {len(self.waiting_queue)}, "
                  f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
              )

      def abort_request(self, recv_req):
          # Delete requests in the waiting queue
          to_del = None
-         for i, req in enumerate(self.forward_queue):
+         for i, req in enumerate(self.waiting_queue):
              if req.rid == recv_req.rid:
                  to_del = i
                  break

          if to_del is not None:
-             del self.forward_queue[to_del]
+             del self.waiting_queue[to_del]

          # Delete requests in the running batch
          if self.running_batch:
sglang/srt/mem_cache/flush_cache.py ADDED
@@ -0,0 +1,33 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ """
+ Flush the KV cache.
+
+ Usage:
+ python3 -m sglang.srt.mem_cache.flush_cache --url http://localhost:30000
+ """
+
+ import argparse
+
+ import requests
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--url", type=str, default="http://localhost:30000")
+     args = parser.parse_args()
+
+     response = requests.get(args.url + "/flush_cache")
+     assert response.status_code == 200
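
The new script is a thin client for the server's GET /flush_cache endpoint (the route the tp_worker's flush_cache handler serves). A programmatic equivalent with a timeout and an explicit error instead of a bare assert, assuming a local server on the default port:

import requests

resp = requests.get("http://localhost:30000/flush_cache", timeout=10)
resp.raise_for_status()  # flushing only succeeds when no requests are pending
print("KV cache flushed")
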
sglang/srt/{memory_pool.py → mem_cache/memory_pool.py} RENAMED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """Memory pool."""

  import logging
@@ -30,7 +45,7 @@ class ReqToTokenPool:

          return select_index

-     def free(self, free_index: int):
+     def free(self, free_index):
          self.mem_state[free_index] = True
          if isinstance(free_index, (int,)):
              self.can_use_mem_size += 1
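
Dropping the int annotation reflects that free() takes either a single slot index (as the tp_worker now does for the inflight request) or an array of indices: with a boolean numpy mem_state, the same assignment handles both, and the isinstance check credits can_use_mem_size by one or by the batch size. A minimal sketch of that behavior, with a hypothetical pool size:

import numpy as np

mem_state = np.zeros(8, dtype=bool)  # False = slot in use
can_use_mem_size = 0

def free(free_index):
    global can_use_mem_size
    mem_state[free_index] = True  # works for an int or an index array
    if isinstance(free_index, (int,)):
        can_use_mem_size += 1
    else:
        can_use_mem_size += len(free_index)

free(3)                  # e.g. the inflight request's old req idx
free(np.array([0, 5]))   # a batch of finished requests
print(can_use_mem_size)  # 3
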
sglang/srt/{managers/controller → mem_cache}/radix_cache.py RENAMED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """
  The radix tree data structure for managing the KV cache.
  """
sglang/srt/mm_utils.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  # Source: https://github.com/haotian-liu/LLaVA/blob/main/llava/mm_utils.py
  import ast
  import base64
sglang/srt/model_config.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  from typing import Optional

  from transformers import PretrainedConfig
sglang/srt/{managers/controller → model_executor}/cuda_graph_runner.py RENAMED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """Run the model with cuda graph."""

  import bisect
@@ -14,7 +29,7 @@ from sglang.srt.layers.logits_processor import (
      LogitsMetadata,
      LogitsProcessor,
  )
- from sglang.srt.managers.controller.infer_batch import (
+ from sglang.srt.managers.schedule_batch import (
      Batch,
      ForwardMode,
      InputMetadata,
sglang/srt/{managers/controller → model_executor}/model_runner.py RENAMED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  """ModelRunner runs the forward passes of the models."""

  import importlib
@@ -25,13 +40,13 @@ from vllm.distributed import (
  from vllm.model_executor.models import ModelRegistry

  from sglang.global_config import global_config
- from sglang.srt.managers.controller.infer_batch import (
+ from sglang.srt.managers.schedule_batch import (
      Batch,
      ForwardMode,
      InputMetadata,
      global_server_args_dict,
  )
- from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
      get_available_gpu_memory,
@@ -42,7 +57,7 @@ from sglang.srt.utils import (
      monkey_patch_vllm_qvk_linear_loader,
  )

- logger = logging.getLogger("srt.model_runner")
+ logger = logging.getLogger(__name__)


  class ModelRunner:
@@ -75,7 +90,7 @@ class ModelRunner:

          # Init torch distributed
          torch.cuda.set_device(self.gpu_id)
-         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
+         logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")

          if not server_args.enable_p2p_check:
              monkey_patch_vllm_p2p_access_check(self.gpu_id)
@@ -115,7 +130,7 @@ class ModelRunner:

      def load_model(self):
          logger.info(
-             f"[gpu_id={self.gpu_id}] Load weight begin. "
+             f"[gpu={self.gpu_id}] Load weight begin. "
              f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
          )

@@ -163,7 +178,7 @@
              cache_config=None,
          )
          logger.info(
-             f"[gpu_id={self.gpu_id}] Load weight end. "
+             f"[gpu={self.gpu_id}] Load weight end. "
              f"type={type(self.model).__name__}, "
              f"dtype={self.dtype}, "
              f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
@@ -197,9 +212,14 @@
          )

          if max_num_reqs is None:
-             max_num_reqs = max(
-                 int(self.max_total_num_tokens / self.model_config.context_len * 512),
-                 2048,
+             max_num_reqs = min(
+                 max(
+                     int(
+                         self.max_total_num_tokens / self.model_config.context_len * 512
+                     ),
+                     2048,
+                 ),
+                 5120,
              )

          self.req_to_token_pool = ReqToTokenPool(
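
The default max_num_reqs is now clamped into [2048, 5120] rather than only floored at 2048, which bounds the ReqToTokenPool size on very large KV pools. The rule, restated with hypothetical pool sizes:

def default_max_num_reqs(max_total_num_tokens: int, context_len: int) -> int:
    return min(max(int(max_total_num_tokens / context_len * 512), 2048), 5120)

print(default_max_num_reqs(500_000, 4096))  # 62500 -> capped at 5120
print(default_max_num_reqs(10_000, 8192))   # 625   -> floored at 2048
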
@@ -214,7 +234,7 @@
              layer_num=self.model_config.num_hidden_layers,
          )
          logger.info(
-             f"[gpu_id={self.gpu_id}] Memory pool end. "
+             f"[gpu={self.gpu_id}] Memory pool end. "
              f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
          )

@@ -258,14 +278,14 @@
          )

      def init_cuda_graphs(self):
-         from sglang.srt.managers.controller.cuda_graph_runner import CudaGraphRunner
+         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

          if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
              self.cuda_graph_runner = None
              return

          logger.info(
-             f"[gpu_id={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
+             f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
          )
          batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
          self.cuda_graph_runner = CudaGraphRunner(
sglang/srt/model_loader/model_loader.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  # temporarily adapted from https://github.com/vllm-project/vllm/blob/10383887e03412196a2689b9398290719c4797bf/vllm/model_executor/model_loader/loader.py
  # FIXME: in progress of refactoring the model loader

sglang/srt/model_loader/utils.py CHANGED
@@ -1,3 +1,18 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
  # temporarily adapted from vLLM
  # FIXME: in progress of refactoring the model loader
  """Utilities for selecting and loading models."""
@@ -23,7 +38,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf

  from sglang.srt.layers.quantization import get_quantization_config

- logger = logging.getLogger("srt.model_loader")
+ logger = logging.getLogger(__name__)
  temp_dir = tempfile.gettempdir()
