sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -26,9 +26,11 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 
+from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_tp_group,
     get_world_group,
@@ -45,10 +47,9 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
-from sglang.srt.layers.quantization.deep_gemm import (
-    _ENABLE_JIT_DEEPGEMM,
-    update_deep_gemm_config,
+from sglang.srt.layers.quantization import (
+    deep_gemm_wrapper,
+    monkey_patch_isinstance_for_vllm_base_layer,
 )
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -70,14 +71,17 @@ from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
 )
+from sglang.srt.mem_cache.allocator import (
+    BaseTokenToKVPoolAllocator,
+    PagedTokenToKVPoolAllocator,
+    TokenToKVPoolAllocator,
+)
 from sglang.srt.mem_cache.memory_pool import (
     DoubleSparseTokenToKVPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
-    TokenToKVPoolAllocator,
 )
-from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
@@ -93,6 +97,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
     cpu_has_amx_support,
+    dynamic_import,
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,
@@ -110,6 +115,7 @@ from sglang.srt.utils import (
 )
 
 _is_hip = is_hip()
+_is_cpu_amx_available = cpu_has_amx_support()
 
 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
@@ -149,7 +155,7 @@ class ModelRunner:
         server_args: ServerArgs,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
-        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
+        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
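The widened annotation matches the allocator refactor listed above (item 106: `paged_allocator.py` → `allocator.py`): the token-granularity and paged allocators now sit behind a shared base class, so the runner can accept either. A rough sketch of the hierarchy implied by the new import block; the `alloc`/`free` method names are assumptions for illustration, not taken from the diff:

```python
# Sketch of the shape implied by sglang.srt.mem_cache.allocator.
class BaseTokenToKVPoolAllocator:
    """Common interface: hand out KV-cache slot indices and take them back."""

    def alloc(self, need_size: int):
        raise NotImplementedError

    def free(self, free_index):
        raise NotImplementedError


class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    """Token-granularity allocation (page size 1)."""


class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
    """Page-granularity allocation (page size > 1)."""
```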
@@ -162,6 +168,7 @@ class ModelRunner:
             logger.addFilter(RankZeroFilter(tp_rank == 0))
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
         self.dist_port = nccl_port
@@ -195,6 +202,7 @@ class ModelRunner:
             | {
                 # TODO it is indeed not a "server args"
                 "use_mla_backend": self.use_mla_backend,
+                "speculative_algorithm": self.spec_algorithm,
             }
         )
 
@@ -205,8 +213,8 @@ class ModelRunner:
         min_per_gpu_memory = self.init_torch_distributed()
 
         # Update deep gemm configure
-        if _ENABLE_JIT_DEEPGEMM:
-            update_deep_gemm_config(gpu_id, server_args)
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+            deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
 
         # If it is a draft model, tp_group can be different
         self.initialize(min_per_gpu_memory)
@@ -218,6 +226,7 @@ class ModelRunner:
 
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
+
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )
@@ -272,6 +281,10 @@ class ModelRunner:
             self.apply_torch_tp()
 
         # Init lora
+        # TODO (lifuhuang): when we support dynamic LoRA loading / unloading, we should add
+        # a new server arg `enable_lora` to control whether to init LoRA manager to be more
+        # explicit, as it is perfectly valid to start a server with an empty lora_paths and
+        # load LoRA adapters dynamically later.
         if server_args.lora_paths is not None:
             self.init_lora_manager()
 
@@ -299,7 +312,7 @@ class ModelRunner:
         if (
             server_args.attention_backend == "intel_amx"
             and server_args.device == "cpu"
-            and not cpu_has_amx_support()
+            and not _is_cpu_amx_available
         ):
             logger.info(
                 "The current platform does not support Intel AMX, will fallback to torch_native backend."
@@ -543,7 +556,7 @@ class ModelRunner:
         monkey_patch_vllm_parallel_state()
         monkey_patch_isinstance_for_vllm_base_layer()
 
-        with self.memory_saver_adapter.region():
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_WEIGHTS):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
@@ -761,6 +774,9 @@ class ModelRunner:
         ]
         if load_format == "direct":
             _model_load_weights_direct(self.model, named_tensors)
+        elif load_format in self.server_args.custom_weight_loader:
+            custom_loader = dynamic_import(load_format)
+            custom_loader(self.model, named_tensors)
        elif load_format is None:
            self.model.load_weights(named_tensors)
        else:
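From this dispatch, a custom weight loader is any importable callable that takes the model and an iterable of `(name, tensor)` pairs; the load-format string itself is the dotted path that `dynamic_import` resolves, and it must be registered in `server_args.custom_weight_loader`. A minimal sketch with a hypothetical module path and a deliberately naive copy loop:

```python
# mypkg/loaders.py -- hypothetical module; registering
# "mypkg.loaders.copy_loader" as a custom weight loader lets the
# dispatch above dynamic_import and call it.
from typing import Iterable, Tuple

import torch


def copy_loader(
    model: torch.nn.Module, named_tensors: Iterable[Tuple[str, torch.Tensor]]
) -> None:
    params = dict(model.named_parameters())
    for name, tensor in named_tensors:
        # Naive in-place copy; a real loader would also handle sharding,
        # dtype casts, and missing/unexpected keys.
        params[name].data.copy_(tensor)
```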
@@ -787,7 +803,6 @@ class ModelRunner:
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
             base_model=self.model,
-            lora_paths=self.server_args.lora_paths,
             base_hf_config=self.model_config.hf_config,
             max_loras_per_batch=self.server_args.max_loras_per_batch,
             load_config=self.load_config,
@@ -796,6 +811,7 @@ class ModelRunner:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
         )
+        self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
         logger.info("LoRA manager ready.")
 
     def profile_max_num_token(self, total_gpu_memory: int):
sglang/srt/model_loader/loader.py CHANGED
@@ -337,7 +337,14 @@ class DefaultModelLoader(BaseModelLoader):
                 hf_weights_files,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(hf_weights_files)
+            from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+            weight_loader_disable_mmap = global_server_args_dict.get(
+                "weight_loader_disable_mmap"
+            )
+            weights_iterator = safetensors_weights_iterator(
+                hf_weights_files, disable_mmap=weight_loader_disable_mmap
+            )
         else:
             weights_iterator = pt_weights_iterator(hf_weights_files)
 
@@ -1259,12 +1266,19 @@ class GGUFModelLoader(BaseModelLoader):
         ):
             model_config.hf_config.update({"tie_word_embeddings": True})
 
+        target_device = torch.device(device_config.device)
         with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
+            with target_device:
                 model = _initialize_model(model_config, self.load_config)
             model.load_weights(
                 self._get_weights_iterator(local_model_path, gguf_weights_map)
             )
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
         return model
 
 
sglang/srt/model_loader/weight_utils.py CHANGED
@@ -34,6 +34,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
 from sglang.srt.utils import print_warning_once
 
 logger = logging.getLogger(__name__)
@@ -206,7 +207,10 @@ def get_quant_config(
             config["adapter_name_or_path"] = model_name_or_path
         elif model_config.quantization == "modelopt":
             if config["producer"]["name"] == "modelopt":
-                return quant_cls.from_config(config)
+                if "FP4" in config["quantization"]["quant_algo"]:
+                    return ModelOptFp4Config.from_config(config)
+                else:
+                    return quant_cls.from_config(config)
             else:
                 raise ValueError(
                     f"Unsupported quantization config"
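The new branch keys off two fields of the ModelOpt-produced quant config. Roughly the shape it inspects; the dict below is illustrative, and "NVFP4" is assumed as a typical FP4 `quant_algo` value rather than quoted from the diff:

```python
# Illustrative quant config matching the dispatch above.
config = {
    "producer": {"name": "modelopt"},
    "quantization": {"quant_algo": "NVFP4"},
}

if "FP4" in config["quantization"]["quant_algo"]:
    print("dispatch to ModelOptFp4Config.from_config(config)")
else:
    print("dispatch to the default modelopt quant class")
```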
@@ -418,6 +422,7 @@ def safetensors_weights_iterator(
     hf_weights_files: List[str],
     is_all_weights_sharded: bool = False,
     decryption_key: Optional[str] = None,
+    disable_mmap: bool = False,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files.
 
@@ -439,7 +444,11 @@
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        result = safetensors.torch.load_file(st_file, device="cpu")
+        if disable_mmap:
+            with open(st_file, "rb") as f:
+                result = safetensors.torch.load(f.read())
+        else:
+            result = safetensors.torch.load_file(st_file, device="cpu")
         for name, param in result.items():
             yield name, param
 
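The two branches trade mmap for an eager read: `safetensors.torch.load_file` memory-maps the shard and pages tensors in lazily, while `safetensors.torch.load` deserializes from a bytes object already in memory, which sidesteps mmap on filesystems where it behaves poorly. Both are public safetensors APIs; a side-by-side sketch (the path is a placeholder):

```python
import safetensors.torch

path = "model-00001-of-00002.safetensors"  # placeholder shard path

# Default path: mmap-backed, lazily paged in.
lazy = safetensors.torch.load_file(path, device="cpu")

# disable_mmap path: one eager read, then deserialize from bytes.
with open(path, "rb") as f:
    eager = safetensors.torch.load(f.read())
```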
sglang/srt/models/bert.py CHANGED
@@ -11,12 +11,13 @@ from sglang.srt.layers.linear import (
     QKVParallelLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import AttentionType, RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
 
 BertConfig = None
 
@@ -50,7 +51,8 @@ class BertEmbedding(nn.Module):
     def forward(
         self,
         input_ids: torch.Tensor,
-        position_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         input_shape = input_ids.size()
 
@@ -58,11 +60,14 @@ class BertEmbedding(nn.Module):
         inputs_embeds = self.word_embeddings(input_ids)
 
         # Position embeddings.
-        position_embeddings = self.position_embeddings(position_ids)
+        position_embeddings = self.position_embeddings(positions)
 
-        token_type_ids = torch.zeros(
-            input_shape, dtype=torch.long, device=inputs_embeds.device
-        )
+        token_type_ids = forward_batch.token_type_ids
+
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(
+                input_shape, dtype=torch.long, device=inputs_embeds.device
+            )
 
         token_type_embeddings = self.token_type_embeddings(token_type_ids)
 
@@ -71,6 +76,25 @@ class BertEmbedding(nn.Module):
         return embeddings
 
 
+class BertPooler(nn.Module):
+
+    def __init__(self, config: BertConfig):
+        super().__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.activation = nn.Tanh()
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        # "Pool" by simply taking the hidden state corresponding to the first token
+        first_token_tensor = hidden_states[0, :]
+
+        pooled_output = self.dense(first_token_tensor)
+        pooled_output = self.activation(pooled_output)
+
+        return pooled_output
+
+
 class BertEncoder(nn.Module):
 
     def __init__(
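Where the default `Pooler` takes the last token and normalizes, the new `BertPooler` is classic first-token ("CLS") pooling: first row of the hidden states, dense projection, tanh. A shapes-only toy run of the same computation (standalone, without the `ForwardBatch` plumbing):

```python
import torch

hidden_size = 768
hidden_states = torch.randn(7, hidden_size)  # 7 packed tokens in the batch

dense = torch.nn.Linear(hidden_size, hidden_size)
pooled = torch.tanh(dense(hidden_states[0, :]))  # first-token pooling
print(pooled.shape)  # torch.Size([768])
```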
@@ -113,6 +137,8 @@ class BertLayer(nn.Module):
     ):
         super().__init__()
 
+        self.layer_id = layer_id
+
         self.attention = BertAttention(
             hidden_size=config.hidden_size,
             num_attention_heads=config.num_attention_heads,
@@ -142,6 +168,7 @@ class BertLayer(nn.Module):
         attn_output = self.attention(hidden_states, forward_batch)
         intermediate_output = self.intermediate(attn_output)
         output = self.output(intermediate_output, attn_output)
+
         return output
 
 
@@ -326,16 +353,23 @@ class BertModel(nn.Module):
         *,
         config: BertConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        use_bert_pooler: bool = False,
         prefix: str = "",
     ):
         super().__init__()
+        self.use_bert_pooler = use_bert_pooler
         self.config = config
         self.embeddings = BertEmbedding(config)
         self.encoder = BertEncoder(
-            config=config, quant_config=quant_config, prefix=f"encoder"
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
+        )
+        self.pooler = (
+            BertPooler(config)
+            if self.use_bert_pooler
+            else Pooler(pooling_type=PoolingType.LAST, normalize=True)
         )
-        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-        # self.pooler = BertPooler(config)
 
     @torch.no_grad()
     def forward(
@@ -351,11 +385,16 @@
 
         hidden_states = self.embeddings(
             input_ids=input_ids,
-            position_ids=positions,
+            positions=positions,
+            forward_batch=forward_batch,
         )
 
         hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)
-        return self.pooler(hidden_states, forward_batch)
+
+        if not self.use_bert_pooler:
+            hidden_states = self.pooler(hidden_states, forward_batch)
+
+        return hidden_states
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
@@ -368,7 +407,7 @@
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             name = name.replace("self", "self_attn")
-            if "pooler" in name:
+            if not self.use_bert_pooler and "pooler" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
 
@@ -395,4 +434,65 @@ class Contriever(BertModel):
     pass
 
 
-EntryClass = [BertModel, Contriever]
+class BertForSequenceClassification(nn.Module):
+
+    def __init__(
+        self,
+        *,
+        config: BertConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.num_labels = config.num_labels
+        self.bert = BertModel(
+            config=config,
+            quant_config=quant_config,
+            use_bert_pooler=True,
+            prefix=add_prefix("bert", prefix),
+        )
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooler = CrossEncodingPooler(config, self.classifier, self.bert.pooler)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self_weights = []
+
+        def weight_filter():
+            for name, weight in weights:
+                if name.startswith("bert."):
+                    yield (name[len("bert.") :], weight)
+                else:
+                    self_weights.append((name, weight))
+
+        self.bert.load_weights(weight_filter())
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in self_weights:
+            if name.startswith("classifier"):
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        assert get_embedding == True
+
+        hidden_states = self.bert(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=input_embeds,
+            get_embedding=get_embedding,
+        )
+        return self.pooler(hidden_states, forward_batch)
+
+
+EntryClass = [BertModel, Contriever, BertForSequenceClassification]
sglang/srt/models/deepseek_nextn.py CHANGED
@@ -22,7 +22,6 @@ from transformers import PretrainedConfig
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -45,6 +44,12 @@ class DeepseekModelNextN(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
+            logger.warning(
+                "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
+            )
+            quant_config = None
+
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = VocabParallelEmbedding(
@@ -77,6 +82,7 @@
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+
         zero_allocator = BumpAllocator(
             buffer_size=2,
             dtype=torch.float32,
@@ -90,15 +96,16 @@
         else:
             hidden_states = input_embeds
 
-        hidden_states = self.eh_proj(
-            torch.cat(
-                (
-                    self.enorm(hidden_states),
-                    self.hnorm(forward_batch.spec_info.hidden_states),
-                ),
-                dim=-1,
+        if hidden_states.shape[0] > 0:
+            hidden_states = self.eh_proj(
+                torch.cat(
+                    (
+                        self.enorm(hidden_states),
+                        self.hnorm(forward_batch.spec_info.hidden_states),
+                    ),
+                    dim=-1,
+                )
             )
-        )
 
         residual = None
         hidden_states, residual = self.decoder(
@@ -106,7 +113,11 @@
         )
 
         if not forward_batch.forward_mode.is_idle():
-            hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+            if residual is not None:
+                hidden_states, _ = self.shared_head.norm(hidden_states, residual)
+            else:
+                hidden_states = self.shared_head.norm(hidden_states)
+
         return hidden_states
 
 
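The branch exists because the shared-head norm has two calling conventions: given a residual it fuses the add and returns a `(normed, new_residual)` tuple, and given none it returns a single tensor. A simplified re-implementation of that dual-mode pattern (not sglang's actual RMSNorm kernel):

```python
from typing import Optional

import torch


class DualModeRMSNorm(torch.nn.Module):
    """Pre-norm layer with an optional fused residual add."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor, residual: Optional[torch.Tensor] = None):
        if residual is not None:
            x = x + residual  # fuse the residual add
            residual = x      # the pre-norm sum becomes the next residual
        out = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        out = out * self.weight
        # One tensor without a residual, a tuple with one -- which is exactly
        # why the caller above unpacks conditionally.
        return out if residual is None else (out, residual)
```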
@@ -127,23 +138,14 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self.model = DeepseekModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
-
-        if global_server_args_dict["enable_dp_attention"]:
-            self.lm_head = ReplicatedLinear(
-                config.hidden_size,
-                config.vocab_size,
-                bias=False,
-                prefix=add_prefix("model.shared_head.head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
-        else:
-            self.lm_head = ParallelLMHead(
-                config.vocab_size,
-                config.hidden_size,
-                quant_config=quant_config,
-                prefix=add_prefix("model.shared_head.head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("model.shared_head.head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
     def forward(