sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/models/solar.py
@@ -0,0 +1,505 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ # Adapted from
+ # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
+ # Copyright 2023 The vLLM team.
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/solar.py
+ from collections.abc import Iterable
+ from typing import Any, List, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig
+
+ from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+ from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+ from sglang.srt.layers.quantization import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.utils import PPMissingLayer
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     DEFAULT_VOCAB_PADDING_SIZE,
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+ from sglang.srt.model_loader.weight_utils import (
+     default_weight_loader,
+     kv_cache_scales_loader,
+ )
+ from sglang.srt.utils import add_prefix, make_layers
+
+
+ class SolarMLP(nn.Module):
+
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         quant_config: Optional[QuantizationConfig] = None,
+         bias: bool = False,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             input_size=hidden_size,
+             output_sizes=[intermediate_size] * 2,
+             bias=bias,
+             quant_config=quant_config,
+             prefix=f"{prefix}.gate_up_proj",
+         )
+         self.down_proj = RowParallelLinear(
+             input_size=intermediate_size,
+             output_size=hidden_size,
+             bias=bias,
+             quant_config=quant_config,
+             prefix=f"{prefix}.down_proj",
+         )
+         if hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {hidden_act}. "
+                 "Only silu is supported for now."
+             )
+         self.act_fn = SiluAndMul()
+
+     def forward(self, x):
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class SolarAttention(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[dict[str, Any]] = None,
+         max_position_embeddings: int = 8192,
+         quant_config: Optional[QuantizationConfig] = None,
+         bias: bool = False,
+         prefix: str = "",
+         layer_id: int = 0,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+         self.head_dim = getattr(config, "head_dim", None)
+         if self.head_dim is None:
+             self.head_dim = self.hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size=hidden_size,
+             head_size=self.head_dim,
+             total_num_heads=self.total_num_heads,
+             total_num_kv_heads=self.total_num_kv_heads,
+             bias=bias,
+             quant_config=quant_config,
+             prefix=f"{prefix}.qkv_proj",
+         )
+         self.o_proj = RowParallelLinear(
+             input_size=self.total_num_heads * self.head_dim,
+             output_size=hidden_size,
+             bias=bias,
+             quant_config=quant_config,
+             prefix=f"{prefix}.o_proj",
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             quant_config=quant_config,
+             prefix=f"{prefix}.attn",
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         hidden_states: torch.Tensor,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, forward_batch=forward_batch)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class SolarDecoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 10000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+
+         if rope_scaling is not None and getattr(
+             config, "original_max_position_embeddings", None
+         ):
+             rope_scaling["original_max_position_embeddings"] = (
+                 config.original_max_position_embeddings
+             )
+         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+
+         attention_bias = getattr(config, "attention_bias", False) or getattr(
+             config, "bias", False
+         )
+         self.self_attn = SolarAttention(
+             config=config,
+             layer_id=layer_id,
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=getattr(
+                 config, "num_key_value_heads", config.num_attention_heads
+             ),
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             max_position_embeddings=max_position_embeddings,
+             quant_config=quant_config,
+             bias=attention_bias,
+             prefix=f"{prefix}.self_attn",
+         )
+         self.mlp = SolarMLP(
+             hidden_size=self.hidden_size,
+             intermediate_size=config.intermediate_size,
+             hidden_act=config.hidden_act,
+             quant_config=quant_config,
+             bias=getattr(config, "mlp_bias", False),
+             prefix=f"{prefix}.mlp",
+         )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         forward_batch: ForwardBatch,
+         residual: Optional[torch.Tensor],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             forward_batch=forward_batch,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
+
+
+ class SolarModel(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.config = config
+
+         self.vocab_size = config.vocab_size
+         self.org_vocab_size = config.vocab_size
+         self.pp_group = get_pp_group()
+         if self.pp_group.is_first_rank:
+             self.embed_tokens = VocabParallelEmbedding(
+                 config.vocab_size,
+                 config.hidden_size,
+                 quant_config=quant_config,
+                 prefix=add_prefix("embed_tokens", prefix),
+             )
+         else:
+             self.embed_tokens = PPMissingLayer()
+         self.start_layer, self.end_layer, self.layers = make_layers(
+             config.num_hidden_layers,
+             lambda idx, prefix: SolarDecoderLayer(
+                 config=config,
+                 quant_config=quant_config,
+                 layer_id=idx,
+                 prefix=prefix,
+             ),
+             prefix=f"{prefix}.layers",
+         )
+         if get_pp_group().is_last_rank:
+             self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         else:
+             self.norm = PPMissingLayer()
+
+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.embed_tokens(input_ids)
+
+     def forward(
+         self,
+         input_ids: Optional[torch.Tensor],
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         pp_proxy_tensors: Optional[PPProxyTensors] = None,
+     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
+         if self.pp_group.is_first_rank:
+             if inputs_embeds is not None:
+                 hidden_states = inputs_embeds
+             else:
+                 hidden_states = self.get_input_embeddings(input_ids)
+             residual = None
+         else:
+             assert pp_proxy_tensors is not None
+
+             hidden_states = pp_proxy_tensors["hidden_states"]
+             residual = pp_proxy_tensors["residual"]
+
+         # Depth up-scaling mechanism: caches hidden states and residuals from intermediate layers and interpolates them with the states of later layers.
+         # `bskcn` stands for "backbone skip connection".
+         bskcn_h_1 = None
+         bskcn_h_2 = None
+         bskcn_r_1 = None
+         bskcn_r_2 = None
+         bskcn_tv = self.config.bskcn_tv[0] if self.training else self.config.bskcn_tv[1]
+
+         for i in range(self.start_layer, self.end_layer):
+             if i in self.config.bskcn_1:
+                 bskcn_h_1 = hidden_states.clone()
+                 bskcn_r_1 = residual.clone() if residual is not None else None
+             if i in self.config.bskcn_2:
+                 bskcn_h_2 = hidden_states.clone()
+                 bskcn_r_2 = residual.clone() if residual is not None else None
+             if i in self.config.bskcn_3:
+                 hidden_states = bskcn_h_1 * bskcn_tv + hidden_states * (1 - bskcn_tv)
+                 if bskcn_r_1 is not None and residual is not None:
+                     residual = bskcn_r_1 * bskcn_tv + residual * (1 - bskcn_tv)
+             if i in self.config.bskcn_4:
+                 hidden_states = bskcn_h_2 * bskcn_tv + hidden_states * (1 - bskcn_tv)
+                 if bskcn_r_2 is not None and residual is not None:
+                     residual = bskcn_r_2 * bskcn_tv + residual * (1 - bskcn_tv)
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions=positions,
+                 hidden_states=hidden_states,
+                 forward_batch=forward_batch,
+                 residual=residual,
+             )
+
+         if not self.pp_group.is_last_rank:
+             return PPProxyTensors(
+                 {"hidden_states": hidden_states, "residual": residual}
+             )
+
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+
+     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+         tp_size = get_tensor_model_parallel_world_size()
+         tp_rank = get_tensor_model_parallel_rank()
+         for layer_idx, scaling_factor in kv_cache_scales_loader(
+             quantization_param_path,
+             tp_rank,
+             tp_size,
+             self.config.num_hidden_layers,
+             self.config.__class__.model_type,
+         ):
+             if not isinstance(self.layers[layer_idx], nn.Identity):
+                 layer_self_attn = self.layers[layer_idx].self_attn
+
+             if hasattr(layer_self_attn.attn, "k_scale"):
+                 layer_self_attn.attn.k_scale = scaling_factor
+                 layer_self_attn.attn.v_scale = scaling_factor
+             else:
+                 raise RuntimeError(
+                     "Self attention has no KV cache scaling factor attribute!"
+                 )
+
+
+ class SolarForCausalLM(nn.Module):
+
+     packed_modules_mapping = {
+         "qkv_proj": [
+             ("q_proj", "q"),
+             ("k_proj", "k"),
+             ("v_proj", "v"),
+         ],
+         "gate_up_proj": [
+             ("gate_proj", 0),
+             ("up_proj", 1),
+         ],
+     }
+
+     default_bitsandbytes_target_modules = [
+         ".gate_proj.",
+         ".down_proj.",
+         ".up_proj.",
+         ".q_proj.",
+         ".k_proj.",
+         ".v_proj.",
+         ".o_proj.",
+     ]
+     column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+     bitsandbytes_stacked_params_mapping = {
+         ".q_proj": (".qkv_proj", 0),
+         ".k_proj": (".qkv_proj", 1),
+         ".v_proj": (".qkv_proj", 2),
+         ".gate_proj": (".gate_up_proj", 0),
+         ".up_proj": (".gate_up_proj", 1),
+     }
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.pp_group = get_pp_group()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = SolarModel(
+             config=config,
+             quant_config=self.quant_config,
+             prefix=add_prefix("model", prefix),
+         )
+
+         if self.pp_group.is_last_rank:
+             self.unpadded_vocab_size = config.vocab_size
+             self.lm_head = ParallelLMHead(
+                 self.unpadded_vocab_size,
+                 config.hidden_size,
+                 org_num_embeddings=config.vocab_size,
+                 padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+                 quant_config=quant_config,
+             )
+             if config.tie_word_embeddings and self.pp_group.is_first_rank:
+                 self.lm_head.weight = self.model.embed_tokens.weight
+
+             logit_scale = getattr(config, "logit_scale", 1.0)
+             self.logits_processor = LogitsProcessor(
+                 self.unpadded_vocab_size, config.vocab_size, logit_scale
+             )
+         else:
+             self.lm_head = PPMissingLayer()
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         inputs_embeds: Optional[torch.Tensor] = None,
+     ) -> Union[torch.Tensor, LogitsProcessorOutput]:
+         hidden_states = self.model(
+             input_ids=input_ids,
+             positions=positions,
+             forward_batch=forward_batch,
+             inputs_embeds=inputs_embeds,
+         )
+
+         if self.pp_group.is_last_rank:
+             logits = self.logits_processor(self.lm_head, hidden_states, forward_batch)
+             return logits
+
+         return hidden_states
+
+     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+
+             is_packed = False
+             for packed_name, sources in self.packed_modules_mapping.items():
+                 for src_name, shard_id in sources:
+                     if src_name in name:
+
+                         model_param_name = name.replace(src_name, packed_name)
+
+                         if model_param_name in params_dict:
+                             param = params_dict[model_param_name]
+                             weight_loader = getattr(
+                                 param, "weight_loader", default_weight_loader
+                             )
+                             weight_loader(param, loaded_weight, shard_id)
+                             is_packed = True
+                         break
+                 if is_packed:
+                     break
+
+             if is_packed:
+                 continue
+
+             if name in params_dict:
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+
+ EntryClass = SolarForCausalLM