sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/bench_latency.py +10 -6
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +0 -4
  4. sglang/lang/backend/runtime_endpoint.py +13 -6
  5. sglang/lang/interpreter.py +1 -1
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +29 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +2 -4
  14. sglang/srt/layers/attention_backend.py +480 -0
  15. sglang/srt/layers/flashinfer_utils.py +235 -0
  16. sglang/srt/layers/logits_processor.py +64 -77
  17. sglang/srt/layers/radix_attention.py +11 -161
  18. sglang/srt/layers/sampler.py +40 -35
  19. sglang/srt/layers/torchao_utils.py +75 -0
  20. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  21. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  22. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  23. sglang/srt/lora/lora.py +403 -0
  24. sglang/srt/lora/lora_config.py +43 -0
  25. sglang/srt/lora/lora_manager.py +256 -0
  26. sglang/srt/managers/controller_multi.py +1 -5
  27. sglang/srt/managers/controller_single.py +0 -5
  28. sglang/srt/managers/io_struct.py +16 -1
  29. sglang/srt/managers/policy_scheduler.py +122 -5
  30. sglang/srt/managers/schedule_batch.py +110 -74
  31. sglang/srt/managers/tokenizer_manager.py +24 -15
  32. sglang/srt/managers/tp_worker.py +181 -115
  33. sglang/srt/model_executor/cuda_graph_runner.py +60 -133
  34. sglang/srt/model_executor/forward_batch_info.py +35 -312
  35. sglang/srt/model_executor/model_runner.py +118 -141
  36. sglang/srt/models/baichuan.py +416 -0
  37. sglang/srt/models/chatglm.py +6 -8
  38. sglang/srt/models/commandr.py +1 -5
  39. sglang/srt/models/dbrx.py +1 -5
  40. sglang/srt/models/deepseek.py +1 -5
  41. sglang/srt/models/deepseek_v2.py +1 -5
  42. sglang/srt/models/exaone.py +8 -43
  43. sglang/srt/models/gemma.py +1 -5
  44. sglang/srt/models/gemma2.py +1 -5
  45. sglang/srt/models/gpt_bigcode.py +1 -5
  46. sglang/srt/models/grok.py +1 -5
  47. sglang/srt/models/internlm2.py +1 -5
  48. sglang/srt/models/{llama2.py → llama.py} +48 -26
  49. sglang/srt/models/llama_classification.py +14 -40
  50. sglang/srt/models/llama_embedding.py +7 -6
  51. sglang/srt/models/llava.py +38 -16
  52. sglang/srt/models/llavavid.py +7 -8
  53. sglang/srt/models/minicpm.py +1 -5
  54. sglang/srt/models/minicpm3.py +665 -0
  55. sglang/srt/models/mistral.py +2 -3
  56. sglang/srt/models/mixtral.py +6 -5
  57. sglang/srt/models/mixtral_quant.py +1 -5
  58. sglang/srt/models/qwen.py +1 -5
  59. sglang/srt/models/qwen2.py +1 -5
  60. sglang/srt/models/qwen2_moe.py +6 -5
  61. sglang/srt/models/stablelm.py +1 -5
  62. sglang/srt/models/xverse.py +375 -0
  63. sglang/srt/models/xverse_moe.py +445 -0
  64. sglang/srt/openai_api/adapter.py +65 -46
  65. sglang/srt/openai_api/protocol.py +11 -3
  66. sglang/srt/sampling/sampling_batch_info.py +67 -58
  67. sglang/srt/server.py +24 -14
  68. sglang/srt/server_args.py +130 -28
  69. sglang/srt/utils.py +12 -0
  70. sglang/test/few_shot_gsm8k.py +132 -0
  71. sglang/test/runners.py +114 -22
  72. sglang/test/test_programs.py +70 -0
  73. sglang/test/test_utils.py +89 -1
  74. sglang/utils.py +38 -4
  75. sglang/version.py +1 -1
  76. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
  77. sglang-0.3.1.dist-info/RECORD +129 -0
  78. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
  79. sglang-0.2.15.dist-info/RECORD +0 -118
  80. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
  81. {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

@@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
@@ -43,17 +37,19 @@ from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
-from sglang.global_config import global_config
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
+from sglang.srt.layers.sampler import SampleOutput, Sampler
+from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch, ModelConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -69,6 +65,8 @@ logger = logging.getLogger(__name__)
 
 
 class ModelRunner:
+    """ModelRunner runs the forward passes of the models."""
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -92,13 +90,15 @@ class ModelRunner:
         )
         global_server_args_dict.update(
             {
-                "disable_flashinfer": server_args.disable_flashinfer,
-                "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
+                "attention_backend": server_args.attention_backend,
+                "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
+                "torchao_config": server_args.torchao_config,
             }
         )
 
+        # Model-specific adjustment
         if self.is_multimodal_model:
             logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
@@ -106,15 +106,19 @@ class ModelRunner:
             server_args.chunked_prefill_size = None
             server_args.mem_fraction_static *= 0.95
 
+        # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
+        self.sampler = Sampler()
         self.load_model()
+        if server_args.lora_paths is not None:
+            self.init_lora_manager()
         self.init_memory_pool(
             min_per_gpu_memory,
-            server_args.max_num_reqs,
+            server_args.max_running_requests,
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.init_flashinfer()
+        self.init_attention_backend()
         self.init_cuda_graphs()
 
     def init_torch_distributed(self):
@@ -162,6 +166,7 @@ class ModelRunner:
         return min_per_gpu_memory
 
     def load_model(self):
+        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
@@ -312,6 +317,17 @@ class ModelRunner:
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights"
 
+    def init_lora_manager(self):
+        self.lora_manager = LoRAManager(
+            base_model=self.model,
+            lora_paths=self.server_args.lora_paths,
+            base_hf_config=self.model_config.hf_config,
+            max_loras_per_batch=self.server_args.max_loras_per_batch,
+            load_config=self.load_config,
+            dtype=self.dtype,
+        )
+        logger.info("LoRA manager ready.")
+
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
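Note: this wires the new LoRA support (sglang/srt/lora/lora_manager.py, file #25 above) into the model runner. The sketch below is only an illustrative stand-in for the constructor surface used here, under the assumption that the real manager loads adapters from lora_paths and activates at most max_loras_per_batch of them per batch:

    # Illustrative stand-in (hypothetical class, not sglang's LoRAManager).
    from dataclasses import dataclass, field
    from typing import Any, Optional

    @dataclass
    class MiniLoRAManager:
        base_model: Any
        lora_paths: Any            # adapter names/paths from --lora-paths
        base_hf_config: Any
        max_loras_per_batch: int
        load_config: Any
        dtype: Any
        loaded_adapters: dict = field(default_factory=dict)

        def prepare_lora_batch(self, batch: Any, extend_seq_lens: Optional[Any] = None) -> None:
            # The real manager selects/loads at most `max_loras_per_batch`
            # adapters for the requests in `batch` before the forward pass.
            pass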
@@ -342,8 +358,8 @@ class ModelRunner:
     def init_memory_pool(
         self,
         total_gpu_memory: int,
-        max_num_reqs: int = None,
-        max_total_tokens: int = None,
+        max_num_reqs: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
     ):
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
@@ -377,7 +393,7 @@ class ModelRunner:
                     ),
                     2048,
                 ),
-                5120,
+                4096,
             )
 
         self.req_to_token_pool = ReqToTokenPool(
@@ -395,9 +411,6 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
             )
-            logger.info("using MLA Triton implementaion, flashinfer is disabled")
-            # FIXME: temporarily only Triton MLA is supported
-            self.server_args.disable_flashinfer = True
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -420,118 +433,46 @@ class ModelRunner:
         c = a @ b
         return c
 
-    def init_flashinfer(self):
-        """Init flashinfer attention kernel wrappers."""
-        if self.server_args.disable_flashinfer:
-            assert (
-                self.sliding_window_size is None
-            ), "turn on flashinfer to support window attention"
-            self.flashinfer_prefill_wrapper_ragged = None
-            self.flashinfer_prefill_wrapper_paged = None
-            self.flashinfer_decode_wrapper = None
-            return
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            self.model_config.num_attention_heads // self.tp_size,
-            self.model_config.get_num_kv_heads(self.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-
-        if self.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
-                dtype=torch.uint8,
-                device="cuda",
-            )
-            self.flashinfer_prefill_wrapper_ragged = (
-                BatchPrefillWithRaggedKVCacheWrapper(
-                    self.flashinfer_workspace_buffer, "NHD"
-                )
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer, "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_tensor_cores=use_tensor_cores,
+    def init_attention_backend(self):
+        """Init attention kernel backend."""
+        if self.server_args.attention_backend == "flashinfer":
+            self.attn_backend = FlashInferAttnBackend(self)
+        elif self.server_args.attention_backend == "triton":
+            assert self.sliding_window_size is None, (
+                "Window attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
             )
+            self.attn_backend = TritonAttnBackend(self)
         else:
-            self.flashinfer_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
-                dtype=torch.uint8,
-                device="cuda",
+            raise ValueError(
+                f"Invalid attention backend: {self.server_args.attention_backend}"
             )
-            self.flashinfer_prefill_wrapper_ragged = None
-            self.flashinfer_prefill_wrapper_paged = []
-            self.flashinfer_decode_wrapper = []
-            for i in range(2):
-                self.flashinfer_prefill_wrapper_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer, "NHD"
-                    )
-                )
-                self.flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=use_tensor_cores,
-                    )
-                )
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
+        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+
+        self.cuda_graph_runner = None
+
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return
 
-        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-
-        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
-            self.cuda_graph_runner = None
+        if self.server_args.disable_cuda_graph:
             return
 
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
-
-        if self.server_args.disable_cuda_graph_padding:
-            batch_size_list = list(range(1, 32)) + [64, 128]
-        else:
-            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
-        self.cuda_graph_runner = CudaGraphRunner(
-            self,
-            max_batch_size_to_capture=max(batch_size_list),
-            use_torch_compile=self.server_args.enable_torch_compile,
-            disable_padding=self.server_args.disable_cuda_graph_padding,
-        )
-        try:
-            self.cuda_graph_runner.capture(batch_size_list)
-        except RuntimeError as e:
-            raise Exception(
-                f"Capture cuda graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. disable cuda graph by --disable-cuda-graph\n"
-                "2. set --mem-fraction-static to a smaller value\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
-            )
+        self.cuda_graph_runner = CudaGraphRunner(self)
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if (
-            self.cuda_graph_runner
-            and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and not batch.sampling_info.has_bias()
-        ):
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(batch)
+
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            ForwardMode.DECODE,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
 
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
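Note: the hand-rolled flashinfer wrapper setup is replaced by pluggable backend objects (FlashInferAttnBackend / TritonAttnBackend from the new sglang/srt/layers/attention_backend.py, file #14 above), and CudaGraphRunner now reads its capture settings from the runner instead of explicit arguments. A minimal sketch of the dispatch pattern, with hypothetical stand-in classes:

    # Minimal sketch of the backend-dispatch pattern behind
    # init_attention_backend(); all classes here are hypothetical stand-ins.
    class AttentionBackendStub:
        def __init__(self, runner):
            self.runner = runner  # real backends build their kernel wrappers/buffers here

    class FlashInferBackendStub(AttentionBackendStub):
        pass

    class TritonBackendStub(AttentionBackendStub):
        pass

    def make_attn_backend(name: str, runner) -> AttentionBackendStub:
        backends = {"flashinfer": FlashInferBackendStub, "triton": TritonBackendStub}
        try:
            return backends[name](runner)
        except KeyError:
            raise ValueError(f"Invalid attention backend: {name}") from None

    print(type(make_attn_backend("triton", runner=None)).__name__)  # TritonBackendStub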
@@ -539,11 +480,10 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
+
         if self.is_generation:
             return self.model.forward(
                 batch.input_ids, input_metadata.positions, input_metadata
@@ -559,11 +499,7 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
@@ -573,17 +509,68 @@ class ModelRunner:
             input_metadata.image_offsets,
         )
 
-    def forward(
-        self, batch: ScheduleBatch, forward_mode: ForwardMode
-    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
-        if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
+    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
+        assert batch.forward_mode is not None
+
+        if self.is_multimodal_model and batch.forward_mode.is_extend():
             return self.forward_extend_multi_modal(batch)
-        elif forward_mode == ForwardMode.DECODE:
+        elif batch.forward_mode.is_decode():
             return self.forward_decode(batch)
-        elif forward_mode == ForwardMode.EXTEND:
+        elif batch.forward_mode.is_extend():
             return self.forward_extend(batch)
         else:
-            raise ValueError(f"Invaid forward mode: {forward_mode}")
+            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
+
+    def _check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                sample_output.success, batch_next_token_ids, argmax_ids
+            )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
+
+        return sample_output.batch_next_token_ids
+
+    def _apply_logits_bias(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ):
+        # Apply logit_bias
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+
+        return logits
+
+    def sample(
+        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
+    ) -> torch.Tensor:
+        batch.sampling_info.update_regex_vocab_mask(batch)
+        batch.sampling_info.update_penalties()
+        logits = self._apply_logits_bias(
+            logits_output.next_token_logits, batch.sampling_info
+        )
+        sample_output = self.sampler(logits, batch.sampling_info)
+        return self._check_sample_results(sample_output)
 
 
 @lru_cache()
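Note: sampling moves onto the runner: sample() refreshes the regex vocab mask and penalty state, applies logit bias, additive (min-token/presence/frequency) penalties, the multiplicative repetition penalty, and the vocab mask, then calls the new Sampler module and falls back to argmax for rows where sampling failed. A self-contained toy run of the same tensor operations used in _apply_logits_bias, on made-up values:

    # Toy reproduction of the tensor ops in _apply_logits_bias
    # (batch of 1, vocab of 4, made-up values); only torch is required.
    import torch

    logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])
    logit_bias = torch.tensor([[0.0, 0.0, 1.0, 0.0]])
    linear_penalties = torch.tensor([[0.0, -0.5, 0.0, 0.0]])   # presence/frequency-style, additive
    scaling_penalties = torch.tensor([[1.2, 1.2, 1.2, 1.2]])   # repetition penalty, multiplicative
    vocab_mask = torch.tensor([[False, False, False, True]])   # e.g. token 3 forbidden by a regex FSM

    logits = logits + logit_bias         # per-token bias
    logits = logits + linear_penalties   # additive penalties
    logits = torch.where(                # shrink positive logits, grow negative ones
        logits > 0, logits / scaling_penalties, logits * scaling_penalties
    )
    logits = logits.masked_fill(vocab_mask, float("-inf"))

    print(logits)  # tensor([[ 1.6667, -1.8000,  1.2500,    -inf]])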
@@ -606,16 +593,6 @@ def import_model_classes():
             assert entry.__name__ not in model_arch_name_to_cls
             model_arch_name_to_cls[entry.__name__] = entry
 
-        # compat: some models such as chatglm has incorrect class set in config.json
-        # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
-        if hasattr(module, "EntryClassRemapping") and isinstance(
-            module.EntryClassRemapping, list
-        ):
-            for remap in module.EntryClassRemapping:
-                if isinstance(remap, tuple) and len(remap) == 2:
-                    assert remap[0] not in model_arch_name_to_cls
-                    model_arch_name_to_cls[remap[0]] = remap[1]
-
     return model_arch_name_to_cls
 
 