sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/weight_utils.py
@@ -1,12 +1,14 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
 
 """Utilities for downloading and initializing model weights."""
+import concurrent.futures
 import fnmatch
 import glob
 import hashlib
 import json
 import logging
 import os
+import queue
 import tempfile
 from collections import defaultdict
 from typing import (
@@ -34,6 +36,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
 from sglang.srt.utils import print_warning_once
 
 logger = logging.getLogger(__name__)
@@ -206,7 +209,10 @@ def get_quant_config(
         config["adapter_name_or_path"] = model_name_or_path
     elif model_config.quantization == "modelopt":
         if config["producer"]["name"] == "modelopt":
-            return quant_cls.from_config(config)
+            if "FP4" in config["quantization"]["quant_algo"]:
+                return ModelOptFp4Config.from_config(config)
+            else:
+                return quant_cls.from_config(config)
         else:
             raise ValueError(
                 f"Unsupported quantization config"
@@ -418,6 +424,7 @@ def safetensors_weights_iterator(
     hf_weights_files: List[str],
     is_all_weights_sharded: bool = False,
     decryption_key: Optional[str] = None,
+    disable_mmap: bool = False,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files.
 
@@ -439,11 +446,69 @@ def safetensors_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        result = safetensors.torch.load_file(st_file, device="cpu")
+        if disable_mmap:
+            with open(st_file, "rb") as f:
+                result = safetensors.torch.load(f.read())
+        else:
+            result = safetensors.torch.load_file(st_file, device="cpu")
         for name, param in result.items():
             yield name, param
 
 
+def multi_thread_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+    is_all_weights_sharded: bool = False,
+    decryption_key: Optional[str] = None,
+    max_workers: int = 4,
+    disable_mmap: bool = False,
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Multi-Thread iterate over the weights in the model safetensor files.
+
+    If is_all_weights_sharded is True, it uses more optimize read by reading an
+    entire file instead of reading each tensor one by one.
+    """
+    if decryption_key:
+        logger.warning(
+            "Multi-Thread loading is not working for encrypted safetensor weights."
+        )
+        yield from safetensors_encrypted_weights_iterator(
+            hf_weights_files, is_all_weights_sharded, decryption_key
+        )
+        return
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    def _load_file(st_file: str):
+        if disable_mmap:
+            with open(st_file, "rb") as f:
+                result = safetensors.torch.load(f.read())
+        else:
+            result = safetensors.torch.load_file(st_file, device="cpu")
+
+        return result
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [executor.submit(_load_file, st_file) for st_file in hf_weights_files]
+
+        if enable_tqdm:
+            futures_iter = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(hf_weights_files),
+                desc="Multi-thread loading shards",
+                disable=not enable_tqdm,
+                bar_format=_BAR_FORMAT,
+            )
+        else:
+            futures_iter = concurrent.futures.as_completed(futures)
+
+        for future in futures_iter:
+            state_dict = future.result()
+            for name, param in state_dict.items():
+                yield name, param
+
+
 def pt_weights_iterator(
     hf_weights_files: List[str],
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
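
For orientation, a hypothetical caller of the new iterator; the import path matches this module, while the shard glob and worker count are placeholders. Tensors stream out grouped per shard but in file-completion order, which is harmless since consumers key on the weight name:

import glob

from sglang.srt.model_loader.weight_utils import (
    multi_thread_safetensors_weights_iterator,
)

shard_files = sorted(glob.glob("/path/to/model/*.safetensors"))  # placeholder
state_dict = {}
for name, tensor in multi_thread_safetensors_weights_iterator(
    shard_files,
    max_workers=8,      # shards read concurrently
    disable_mmap=True,  # bulk-read each file instead of memory-mapping it
):
    state_dict[name] = tensor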
@@ -462,6 +527,39 @@ def pt_weights_iterator(
     del state
 
 
+def multi_thread_pt_weights_iterator(
+    hf_weights_files: List[str],
+    max_workers: int = 4,
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Multi-Thread iterate over the weights in the model bin/pt files."""
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    def _load_file(bin_file: str):
+        return torch.load(bin_file, map_location="cpu", weights_only=True)
+
+    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(_load_file, bin_file) for bin_file in hf_weights_files
+        ]
+
+        if enable_tqdm:
+            futures_iter = tqdm(
+                concurrent.futures.as_completed(futures),
+                total=len(hf_weights_files),
+                desc="Multi-thread loading pt checkpoint shards",
+                disable=not enable_tqdm,
+                bar_format=_BAR_FORMAT,
+            )
+        else:
+            futures_iter = concurrent.futures.as_completed(futures)
+
+        for future in futures_iter:
+            state = future.result()
+            yield from state.items()
+
+
 def get_gguf_extra_tensor_names(
     gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
 ) -> List[str]:
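
Both new iterators share the same submit-then-drain shape. A distilled, self-contained version (the function name here is ours, not the module's) makes the ordering property explicit: per-file tensor order is preserved, but files yield in whatever order their reads finish:

import concurrent.futures
from typing import Callable, Dict, Iterable, Iterator, Tuple


def completion_order_items(
    files: Iterable[str],
    load_one: Callable[[str], Dict[str, object]],
    max_workers: int = 4,
) -> Iterator[Tuple[str, object]]:
    # Submit one load per file, then yield items as each future completes.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(load_one, f) for f in files]
        for fut in concurrent.futures.as_completed(futures):
            yield from fut.result().items()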
sglang/srt/models/deepseek_nextn.py
@@ -22,13 +22,15 @@ from transformers import PretrainedConfig
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
@@ -45,6 +47,12 @@ class DeepseekModelNextN(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
+            logger.warning(
+                "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
+            )
+            quant_config = None
+
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = VocabParallelEmbedding(
@@ -90,23 +98,29 @@
         else:
             hidden_states = input_embeds
 
-        hidden_states = self.eh_proj(
-            torch.cat(
-                (
-                    self.enorm(hidden_states),
-                    self.hnorm(forward_batch.spec_info.hidden_states),
-                ),
-                dim=-1,
+        if hidden_states.shape[0] > 0:
+            hidden_states = self.eh_proj(
+                torch.cat(
+                    (
+                        self.enorm(hidden_states),
+                        self.hnorm(forward_batch.spec_info.hidden_states),
+                    ),
+                    dim=-1,
+                )
             )
-        )
 
         residual = None
-        hidden_states, residual = self.decoder(
-            positions, hidden_states, forward_batch, residual, zero_allocator
-        )
+        with get_global_expert_distribution_recorder().disable_this_region():
+            hidden_states, residual = self.decoder(
+                positions, hidden_states, forward_batch, residual, zero_allocator
+            )
 
         if not forward_batch.forward_mode.is_idle():
-            hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+            if residual is not None:
+                hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+            else:
+                hidden_states = self.shared_head.norm(hidden_states)
+
         return hidden_states
 
 
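The residual is not None branch tracks the dual return convention of sglang's RMSNorm (layers/layernorm.py): called with a residual it returns a (normed, residual) tuple, called without one it returns a single tensor. A toy sketch of that convention, assuming it holds as described (torch.nn.functional.rms_norm requires PyTorch >= 2.4):

from typing import Optional

import torch


def norm(x: torch.Tensor, residual: Optional[torch.Tensor] = None):
    # Toy stand-in for sglang's RMSNorm: fuse the residual add when given,
    # and mirror the tuple-vs-tensor return the new branch handles.
    if residual is not None:
        x = x + residual
        return torch.nn.functional.rms_norm(x, x.shape[-1:]), x
    return torch.nn.functional.rms_norm(x, x.shape[-1:])


h = torch.randn(2, 8)
out = norm(h)                                     # no residual -> single tensor
out, new_residual = norm(h, torch.zeros_like(h))  # residual -> tuple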
@@ -127,23 +141,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self.model = DeepseekModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
-
-        if global_server_args_dict["enable_dp_attention"]:
-            self.lm_head = ReplicatedLinear(
-                config.hidden_size,
-                config.vocab_size,
-                bias=False,
-                prefix=add_prefix("model.shared_head.head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("model.shared_head.head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("model.shared_head.head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
     def forward(