sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/solar.py
@@ -0,0 +1,505 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+# Adapted from
+# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+# Copyright 2023 The vLLM team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/solar.py
+from collections.abc import Iterable
+from typing import Any, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+)
+from sglang.srt.utils import add_prefix, make_layers
+
+
+class SolarMLP(nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=hidden_size,
+            output_sizes=[intermediate_size] * 2,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            input_size=intermediate_size,
+            output_size=hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class SolarAttention(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        prefix: str = "",
+        layer_id: int = 0,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        self.head_dim = getattr(config, "head_dim", None)
+        if self.head_dim is None:
+            self.head_dim = self.hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=self.total_num_heads,
+            total_num_kv_heads=self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            input_size=self.total_num_heads * self.head_dim,
+            output_size=hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch=forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class SolarDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        self.self_attn = SolarAttention(
+            config=config,
+            layer_id=layer_id,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=getattr(
+                config, "num_key_value_heads", config.num_attention_heads
+            ),
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            bias=attention_bias,
+            prefix=f"{prefix}.self_attn",
+        )
+        self.mlp = SolarMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            bias=getattr(config, "mlp_bias", False),
+            prefix=f"{prefix}.mlp",
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class SolarModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.config = config
+
+        self.vocab_size = config.vocab_size
+        self.org_vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: SolarDecoderLayer(
+                config=config,
+                quant_config=quant_config,
+                layer_id=idx,
+                prefix=prefix,
+            ),
+            prefix=f"{prefix}.layers",
+        )
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        # Depth up-scaling mechanism: caches hidden states and residuals from
+        # intermediate layers and interpolates them with the states of later layers.
+        # `bskcn` stands for "backbone skip connection".
+        bskcn_h_1 = None
+        bskcn_h_2 = None
+        bskcn_r_1 = None
+        bskcn_r_2 = None
+        bskcn_tv = self.config.bskcn_tv[0] if self.training else self.config.bskcn_tv[1]
+
+        for i in range(self.start_layer, self.end_layer):
+            if i in self.config.bskcn_1:
+                bskcn_h_1 = hidden_states.clone()
+                bskcn_r_1 = residual.clone() if residual is not None else None
+            if i in self.config.bskcn_2:
+                bskcn_h_2 = hidden_states.clone()
+                bskcn_r_2 = residual.clone() if residual is not None else None
+            if i in self.config.bskcn_3:
+                hidden_states = bskcn_h_1 * bskcn_tv + hidden_states * (1 - bskcn_tv)
+                if bskcn_r_1 is not None and residual is not None:
+                    residual = bskcn_r_1 * bskcn_tv + residual * (1 - bskcn_tv)
+            if i in self.config.bskcn_4:
+                hidden_states = bskcn_h_2 * bskcn_tv + hidden_states * (1 - bskcn_tv)
+                if bskcn_r_2 is not None and residual is not None:
+                    residual = bskcn_r_2 * bskcn_tv + residual * (1 - bskcn_tv)
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+                residual=residual,
+            )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling factor attribute!"
+                )
+
+
+class SolarForCausalLM(nn.Module):
+
+    packed_modules_mapping = {
+        "qkv_proj": [
+            ("q_proj", "q"),
+            ("k_proj", "k"),
+            ("v_proj", "v"),
+        ],
+        "gate_up_proj": [
+            ("gate_proj", 0),
+            ("up_proj", 1),
+        ],
+    }
+
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+        ".gate_proj": (".gate_up_proj", 0),
+        ".up_proj": (".gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.pp_group = get_pp_group()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = SolarModel(
+            config=config,
+            quant_config=self.quant_config,
+            prefix=add_prefix("model", prefix),
+        )
+
+        if self.pp_group.is_last_rank:
+            self.unpadded_vocab_size = config.vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+                quant_config=quant_config,
+            )
+            if config.tie_word_embeddings and self.pp_group.is_first_rank:
+                self.lm_head.weight = self.model.embed_tokens.weight
+
+            logit_scale = getattr(config, "logit_scale", 1.0)
+            self.logits_processor = LogitsProcessor(
+                self.unpadded_vocab_size, config.vocab_size, logit_scale
+            )
+        else:
+            self.lm_head = PPMissingLayer()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, LogitsProcessorOutput]:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            inputs_embeds=inputs_embeds,
+        )
+
+        if self.pp_group.is_last_rank:
+            logits = self.logits_processor(self.lm_head, hidden_states, forward_batch)
+            return logits
+
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+
+            is_packed = False
+            for packed_name, sources in self.packed_modules_mapping.items():
+                for src_name, shard_id in sources:
+                    if src_name in name:
+
+                        model_param_name = name.replace(src_name, packed_name)
+
+                        if model_param_name in params_dict:
+                            param = params_dict[model_param_name]
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            weight_loader(param, loaded_weight, shard_id)
+                            is_packed = True
+                        break
+                if is_packed:
+                    break
+
+            if is_packed:
+                continue
+
+            if name in params_dict:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = SolarForCausalLM
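
The final `EntryClass` assignment is what registers this implementation with sglang's model registry, so checkpoints whose `config.json` lists `SolarForCausalLM` in `architectures` resolve to this file. A minimal usage sketch through the offline `sglang.Engine` API follows; the model path `upstage/solar-pro-preview-instruct` and the sampling parameters are illustrative assumptions, not part of this diff.

import sglang as sgl

if __name__ == "__main__":
    # Assumed checkpoint: any model whose config declares the SolarForCausalLM
    # architecture (illustrative example: upstage/solar-pro-preview-instruct).
    llm = sgl.Engine(model_path="upstage/solar-pro-preview-instruct")
    output = llm.generate(
        "Depth up-scaling works by",
        {"temperature": 0.0, "max_new_tokens": 32},
    )
    print(output["text"])
    llm.shutdown()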