sglang 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +6 -25
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +104 -71
  31. sglang/srt/managers/tokenizer_manager.py +17 -8
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +58 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +117 -131
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +1 -5
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +1 -5
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/llama.py +51 -5
  49. sglang/srt/models/llama_classification.py +1 -20
  50. sglang/srt/models/llava.py +30 -5
  51. sglang/srt/models/llavavid.py +2 -2
  52. sglang/srt/models/minicpm.py +1 -5
  53. sglang/srt/models/minicpm3.py +665 -0
  54. sglang/srt/models/mixtral.py +6 -5
  55. sglang/srt/models/mixtral_quant.py +1 -5
  56. sglang/srt/models/qwen.py +1 -5
  57. sglang/srt/models/qwen2.py +1 -5
  58. sglang/srt/models/qwen2_moe.py +6 -5
  59. sglang/srt/models/stablelm.py +1 -5
  60. sglang/srt/models/xverse.py +375 -0
  61. sglang/srt/models/xverse_moe.py +445 -0
  62. sglang/srt/openai_api/adapter.py +65 -46
  63. sglang/srt/openai_api/protocol.py +11 -3
  64. sglang/srt/sampling/sampling_batch_info.py +57 -44
  65. sglang/srt/server.py +24 -14
  66. sglang/srt/server_args.py +130 -28
  67. sglang/srt/utils.py +12 -0
  68. sglang/test/few_shot_gsm8k.py +132 -0
  69. sglang/test/runners.py +114 -22
  70. sglang/test/test_programs.py +7 -5
  71. sglang/test/test_utils.py +85 -1
  72. sglang/utils.py +32 -37
  73. sglang/version.py +1 -1
  74. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/METADATA +30 -18
  75. sglang-0.3.1.dist-info/RECORD +129 -0
  76. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  77. sglang-0.3.0.dist-info/RECORD +0 -118
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  79. {sglang-0.3.0.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
@@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
@@ -43,17 +37,19 @@ from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
-from sglang.global_config import global_config
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
+from sglang.srt.layers.sampler import SampleOutput, Sampler
+from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch, ModelConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -69,6 +65,8 @@ logger = logging.getLogger(__name__)
 
 
 class ModelRunner:
+    """ModelRunner runs the forward passes of the models."""
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -92,13 +90,15 @@ class ModelRunner:
         )
         global_server_args_dict.update(
             {
-                "disable_flashinfer": server_args.disable_flashinfer,
-                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
+                "attention_backend": server_args.attention_backend,
+                "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
+                "torchao_config": server_args.torchao_config,
             }
         )
 
+        # Model-specific adjustment
         if self.is_multimodal_model:
             logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
@@ -106,15 +106,19 @@ class ModelRunner:
             server_args.chunked_prefill_size = None
             server_args.mem_fraction_static *= 0.95
 
+        # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
+        self.sampler = Sampler()
         self.load_model()
+        if server_args.lora_paths is not None:
+            self.init_lora_manager()
         self.init_memory_pool(
             min_per_gpu_memory,
-            server_args.max_num_reqs,
+            server_args.max_running_requests,
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.init_flashinfer()
+        self.init_attention_backend()
         self.init_cuda_graphs()
 
     def init_torch_distributed(self):
@@ -313,6 +317,17 @@ class ModelRunner:
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights"
 
+    def init_lora_manager(self):
+        self.lora_manager = LoRAManager(
+            base_model=self.model,
+            lora_paths=self.server_args.lora_paths,
+            base_hf_config=self.model_config.hf_config,
+            max_loras_per_batch=self.server_args.max_loras_per_batch,
+            load_config=self.load_config,
+            dtype=self.dtype,
+        )
+        logger.info("LoRA manager ready.")
+
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
@@ -343,8 +358,8 @@ class ModelRunner:
     def init_memory_pool(
         self,
         total_gpu_memory: int,
-        max_num_reqs: int = None,
-        max_total_tokens: int = None,
+        max_num_reqs: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
     ):
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
@@ -378,7 +393,7 @@ class ModelRunner:
                     ),
                     2048,
                 ),
-                5120,
+                4096,
             )
 
         self.req_to_token_pool = ReqToTokenPool(
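This hunk only lowers the final cap of the default max_num_reqs computation from 5120 to 4096; the rest of the expression lies outside the hunk. A small worked example of the visible min/max clamp, with a made-up value standing in for the inner term:

# est_reqs stands in for the value computed just above this hunk
# (derived from the KV-cache token budget); 31_250 is made up.
est_reqs = 31_250

max_num_reqs = min(max(est_reqs, 2048), 4096)
print(max_num_reqs)  # 4096 -- capped by the new upper bound (was 5120 in 0.3.0)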
@@ -396,9 +411,6 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
             )
-            logger.info("using MLA Triton implementaion, flashinfer is disabled")
-            # FIXME: temporarily only Triton MLA is supported
-            self.server_args.disable_flashinfer = True
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -421,118 +433,46 @@ class ModelRunner:
         c = a @ b
         return c
 
-    def init_flashinfer(self):
-        """Init flashinfer attention kernel wrappers."""
-        if self.server_args.disable_flashinfer:
-            assert (
-                self.sliding_window_size is None
-            ), "turn on flashinfer to support window attention"
-            self.flashinfer_prefill_wrapper_ragged = None
-            self.flashinfer_prefill_wrapper_paged = None
-            self.flashinfer_decode_wrapper = None
-            return
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            self.model_config.num_attention_heads // self.tp_size,
-            self.model_config.get_num_kv_heads(self.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-
-        if self.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
-                dtype=torch.uint8,
-                device="cuda",
-            )
-            self.flashinfer_prefill_wrapper_ragged = (
-                BatchPrefillWithRaggedKVCacheWrapper(
-                    self.flashinfer_workspace_buffer, "NHD"
-                )
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer, "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_tensor_cores=use_tensor_cores,
+    def init_attention_backend(self):
+        """Init attention kernel backend."""
+        if self.server_args.attention_backend == "flashinfer":
+            self.attn_backend = FlashInferAttnBackend(self)
+        elif self.server_args.attention_backend == "triton":
+            assert self.sliding_window_size is None, (
+                "Window attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
             )
+            self.attn_backend = TritonAttnBackend(self)
         else:
-            self.flashinfer_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
-                dtype=torch.uint8,
-                device="cuda",
+            raise ValueError(
+                f"Invalid attention backend: {self.server_args.attention_backend}"
            )
-            self.flashinfer_prefill_wrapper_ragged = None
-            self.flashinfer_prefill_wrapper_paged = []
-            self.flashinfer_decode_wrapper = []
-            for i in range(2):
-                self.flashinfer_prefill_wrapper_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer, "NHD"
-                    )
-                )
-                self.flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=use_tensor_cores,
-                    )
-                )
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
+        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+
+        self.cuda_graph_runner = None
+
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return
 
-        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-
-        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
-            self.cuda_graph_runner = None
+        if self.server_args.disable_cuda_graph:
            return
 
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
-
-        if self.server_args.disable_cuda_graph_padding:
-            batch_size_list = list(range(1, 32)) + [64, 128]
-        else:
-            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
-        self.cuda_graph_runner = CudaGraphRunner(
-            self,
-            max_batch_size_to_capture=max(batch_size_list),
-            use_torch_compile=self.server_args.enable_torch_compile,
-            disable_padding=self.server_args.disable_cuda_graph_padding,
-        )
-        try:
-            self.cuda_graph_runner.capture(batch_size_list)
-        except RuntimeError as e:
-            raise Exception(
-                f"Capture cuda graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. disable cuda graph by --disable-cuda-graph\n"
-                "2. set --mem-fraction-static to a smaller value\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
-            )
+        self.cuda_graph_runner = CudaGraphRunner(self)
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if (
-            self.cuda_graph_runner
-            and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and batch.sampling_info.can_run_in_cuda_graph()
-        ):
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(batch)
+
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            ForwardMode.DECODE,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
 
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
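As an aside, here is a minimal, self-contained illustration (stub classes, not sglang code) of the dispatch pattern that the new init_attention_backend() above follows: a string-valued option selects one of the backend classes, and any other value fails fast.

class FlashInferBackendStub:
    """Hypothetical stand-in for sglang.srt.layers.attention_backend.FlashInferAttnBackend."""
    name = "flashinfer"

class TritonBackendStub:
    """Hypothetical stand-in for TritonAttnBackend."""
    name = "triton"

def pick_attention_backend(attention_backend: str):
    # Mirrors the if/elif/else chain in init_attention_backend() above.
    if attention_backend == "flashinfer":
        return FlashInferBackendStub()
    elif attention_backend == "triton":
        return TritonBackendStub()
    else:
        raise ValueError(f"Invalid attention backend: {attention_backend}")

print(pick_attention_backend("triton").name)  # triton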
@@ -540,11 +480,10 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
+
         if self.is_generation:
             return self.model.forward(
                 batch.input_ids, input_metadata.positions, input_metadata
@@ -560,11 +499,7 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
@@ -574,17 +509,68 @@ class ModelRunner:
             input_metadata.image_offsets,
         )
 
-    def forward(
-        self, batch: ScheduleBatch, forward_mode: ForwardMode
-    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
-        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
+    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
+        assert batch.forward_mode is not None
+
+        if self.is_multimodal_model and batch.forward_mode.is_extend():
             return self.forward_extend_multi_modal(batch)
-        elif forward_mode == ForwardMode.DECODE:
+        elif batch.forward_mode.is_decode():
             return self.forward_decode(batch)
-        elif forward_mode == ForwardMode.EXTEND:
+        elif batch.forward_mode.is_extend():
             return self.forward_extend(batch)
         else:
-            raise ValueError(f"Invaid forward mode: {forward_mode}")
+            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
+
+    def _check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                sample_output.success, batch_next_token_ids, argmax_ids
+            )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
+
+        return sample_output.batch_next_token_ids
+
+    def _apply_logits_bias(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ):
+        # Apply logit_bias
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+
+        return logits
+
+    def sample(
+        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
+    ) -> torch.Tensor:
+        batch.sampling_info.update_regex_vocab_mask(batch)
+        batch.sampling_info.update_penalties()
+        logits = self._apply_logits_bias(
+            logits_output.next_token_logits, batch.sampling_info
+        )
+        sample_output = self.sampler(logits, batch.sampling_info)
+        return self._check_sample_results(sample_output)
 
 
 @lru_cache()
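As an aside, the repetition-penalty branch in the new _apply_logits_bias() divides positive logits by the per-token penalty and multiplies non-positive ones, so a penalty greater than 1 always makes the affected token less likely. A runnable illustration with made-up tensors (not sglang code):

import torch

logits = torch.tensor([2.0, -1.0, 0.5])
scaling_penalties = torch.tensor([1.2, 1.2, 1.0])  # per-token repetition penalties

penalized = torch.where(
    logits > 0,
    logits / scaling_penalties,
    logits * scaling_penalties,
)
print(penalized)  # tensor([ 1.6667, -1.2000,  0.5000])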