sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff compares the contents of two publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
Files changed (377)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +267 -32
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +8 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +360 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/model_config.py +181 -82
  17. sglang/srt/configs/qwen3_next.py +326 -0
  18. sglang/srt/configs/qwen3_vl.py +586 -0
  19. sglang/srt/connector/__init__.py +8 -1
  20. sglang/srt/connector/remote_instance.py +82 -0
  21. sglang/srt/constrained/base_grammar_backend.py +49 -12
  22. sglang/srt/constrained/llguidance_backend.py +0 -1
  23. sglang/srt/constrained/outlines_backend.py +0 -1
  24. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  25. sglang/srt/constrained/xgrammar_backend.py +30 -9
  26. sglang/srt/custom_op.py +11 -1
  27. sglang/srt/debug_utils/dump_comparator.py +81 -44
  28. sglang/srt/debug_utils/dump_loader.py +97 -0
  29. sglang/srt/debug_utils/dumper.py +21 -6
  30. sglang/srt/debug_utils/text_comparator.py +73 -11
  31. sglang/srt/disaggregation/ascend/conn.py +2 -2
  32. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  33. sglang/srt/disaggregation/base/conn.py +1 -1
  34. sglang/srt/disaggregation/common/conn.py +279 -108
  35. sglang/srt/disaggregation/decode.py +71 -19
  36. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  37. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  38. sglang/srt/disaggregation/fake/conn.py +1 -1
  39. sglang/srt/disaggregation/mini_lb.py +6 -445
  40. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  41. sglang/srt/disaggregation/nixl/conn.py +326 -53
  42. sglang/srt/disaggregation/prefill.py +36 -17
  43. sglang/srt/disaggregation/utils.py +40 -54
  44. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  45. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  46. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  47. sglang/srt/distributed/parallel_state.py +156 -80
  48. sglang/srt/entrypoints/engine.py +59 -18
  49. sglang/srt/entrypoints/grpc_request_manager.py +855 -0
  50. sglang/srt/entrypoints/grpc_server.py +810 -0
  51. sglang/srt/entrypoints/http_server.py +130 -59
  52. sglang/srt/entrypoints/openai/protocol.py +112 -4
  53. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  54. sglang/srt/entrypoints/openai/serving_chat.py +204 -55
  55. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  56. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  57. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  58. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  59. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  60. sglang/srt/environ.py +285 -0
  61. sglang/srt/eplb/eplb_manager.py +2 -2
  62. sglang/srt/eplb/expert_distribution.py +26 -13
  63. sglang/srt/eplb/expert_location.py +38 -8
  64. sglang/srt/eplb/expert_location_updater.py +1 -1
  65. sglang/srt/function_call/base_format_detector.py +3 -6
  66. sglang/srt/function_call/ebnf_composer.py +11 -9
  67. sglang/srt/function_call/function_call_parser.py +9 -2
  68. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  69. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  70. sglang/srt/function_call/json_array_parser.py +63 -0
  71. sglang/srt/function_call/kimik2_detector.py +17 -4
  72. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  73. sglang/srt/function_call/utils.py +96 -5
  74. sglang/srt/grpc/__init__.py +1 -0
  75. sglang/srt/grpc/compile_proto.py +245 -0
  76. sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
  77. sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
  78. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
  79. sglang/srt/layers/activation.py +143 -9
  80. sglang/srt/layers/attention/aiter_backend.py +14 -15
  81. sglang/srt/layers/attention/ascend_backend.py +115 -9
  82. sglang/srt/layers/attention/attention_registry.py +206 -0
  83. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  84. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  85. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  86. sglang/srt/layers/attention/fla/chunk.py +242 -0
  87. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  88. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  89. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  90. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  91. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  92. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  93. sglang/srt/layers/attention/fla/index.py +37 -0
  94. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  95. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  96. sglang/srt/layers/attention/fla/op.py +66 -0
  97. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  98. sglang/srt/layers/attention/fla/utils.py +331 -0
  99. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  100. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  101. sglang/srt/layers/attention/flashinfer_backend.py +118 -198
  102. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
  103. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  104. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  105. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  106. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  107. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  108. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
  109. sglang/srt/layers/attention/mamba/mamba.py +629 -0
  110. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  111. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  112. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  113. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  114. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  115. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  116. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  117. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  119. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  120. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  121. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  122. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  123. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  124. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  125. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  126. sglang/srt/layers/attention/nsa/utils.py +24 -0
  127. sglang/srt/layers/attention/nsa_backend.py +887 -0
  128. sglang/srt/layers/attention/tbo_backend.py +6 -6
  129. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  130. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  131. sglang/srt/layers/attention/triton_backend.py +57 -7
  132. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  133. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  134. sglang/srt/layers/attention/vision.py +58 -0
  135. sglang/srt/layers/attention/wave_backend.py +4 -4
  136. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  137. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  138. sglang/srt/layers/communicator.py +8 -0
  139. sglang/srt/layers/dp_attention.py +41 -2
  140. sglang/srt/layers/elementwise.py +3 -1
  141. sglang/srt/layers/layernorm.py +34 -15
  142. sglang/srt/layers/linear.py +55 -7
  143. sglang/srt/layers/logits_processor.py +44 -12
  144. sglang/srt/layers/moe/__init__.py +2 -1
  145. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  146. sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
  147. sglang/srt/layers/moe/ep_moe/layer.py +256 -63
  148. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  149. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  150. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  151. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  152. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  153. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  154. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  155. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  164. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  165. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  166. sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
  167. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  168. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  169. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  170. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  171. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  172. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  173. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  174. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  175. sglang/srt/layers/moe/topk.py +30 -9
  176. sglang/srt/layers/moe/utils.py +22 -6
  177. sglang/srt/layers/parameter.py +23 -6
  178. sglang/srt/layers/quantization/awq.py +19 -7
  179. sglang/srt/layers/quantization/base_config.py +11 -6
  180. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  182. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  185. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  186. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  187. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  188. sglang/srt/layers/quantization/fp8.py +78 -49
  189. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  190. sglang/srt/layers/quantization/gptq.py +25 -17
  191. sglang/srt/layers/quantization/modelopt_quant.py +190 -55
  192. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  193. sglang/srt/layers/quantization/mxfp4.py +74 -42
  194. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  195. sglang/srt/layers/quantization/unquant.py +135 -47
  196. sglang/srt/layers/quantization/w4afp8.py +26 -17
  197. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  198. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  199. sglang/srt/layers/rotary_embedding.py +78 -31
  200. sglang/srt/layers/sampler.py +213 -21
  201. sglang/srt/layers/utils.py +23 -0
  202. sglang/srt/lora/backend/base_backend.py +50 -8
  203. sglang/srt/lora/backend/chunked_backend.py +348 -0
  204. sglang/srt/lora/backend/triton_backend.py +99 -5
  205. sglang/srt/lora/layers.py +32 -0
  206. sglang/srt/lora/lora.py +8 -3
  207. sglang/srt/lora/lora_manager.py +52 -118
  208. sglang/srt/lora/mem_pool.py +25 -11
  209. sglang/srt/lora/triton_ops/__init__.py +4 -0
  210. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  211. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  212. sglang/srt/lora/utils.py +22 -11
  213. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  214. sglang/srt/managers/cache_controller.py +199 -301
  215. sglang/srt/managers/data_parallel_controller.py +115 -80
  216. sglang/srt/managers/detokenizer_manager.py +19 -15
  217. sglang/srt/managers/disagg_service.py +46 -0
  218. sglang/srt/managers/io_struct.py +340 -109
  219. sglang/srt/managers/mm_utils.py +44 -6
  220. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  221. sglang/srt/managers/multimodal_processor.py +1 -2
  222. sglang/srt/managers/overlap_utils.py +53 -0
  223. sglang/srt/managers/schedule_batch.py +240 -138
  224. sglang/srt/managers/schedule_policy.py +144 -17
  225. sglang/srt/managers/scheduler.py +502 -209
  226. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  227. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  228. sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
  229. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  230. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  231. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  232. sglang/srt/managers/tokenizer_manager.py +320 -632
  233. sglang/srt/managers/tp_worker.py +81 -22
  234. sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
  235. sglang/srt/managers/utils.py +1 -45
  236. sglang/srt/mem_cache/allocator.py +14 -20
  237. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  238. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  239. sglang/srt/mem_cache/chunk_cache.py +8 -1
  240. sglang/srt/mem_cache/evict_policy.py +23 -0
  241. sglang/srt/mem_cache/hicache_storage.py +43 -24
  242. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  243. sglang/srt/mem_cache/memory_pool.py +535 -58
  244. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  245. sglang/srt/mem_cache/radix_cache.py +222 -73
  246. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  247. sglang/srt/mem_cache/storage/__init__.py +10 -0
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  249. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  250. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  251. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  252. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  253. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  254. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  255. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  256. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  257. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  259. sglang/srt/mem_cache/swa_radix_cache.py +25 -36
  260. sglang/srt/metrics/collector.py +511 -132
  261. sglang/srt/metrics/func_timer.py +2 -7
  262. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  263. sglang/srt/metrics/utils.py +8 -1
  264. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  265. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  266. sglang/srt/model_executor/forward_batch_info.py +82 -40
  267. sglang/srt/model_executor/model_runner.py +432 -157
  268. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  269. sglang/srt/model_loader/__init__.py +9 -3
  270. sglang/srt/model_loader/loader.py +133 -5
  271. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  272. sglang/srt/model_loader/weight_utils.py +158 -3
  273. sglang/srt/models/apertus.py +686 -0
  274. sglang/srt/models/bailing_moe.py +820 -217
  275. sglang/srt/models/bailing_moe_nextn.py +168 -0
  276. sglang/srt/models/deepseek_nextn.py +6 -1
  277. sglang/srt/models/deepseek_v2.py +607 -130
  278. sglang/srt/models/dots_ocr.py +173 -0
  279. sglang/srt/models/dots_vlm.py +174 -0
  280. sglang/srt/models/dots_vlm_vit.py +337 -0
  281. sglang/srt/models/ernie4.py +1 -1
  282. sglang/srt/models/falcon_h1.py +576 -0
  283. sglang/srt/models/gemma3_causal.py +0 -2
  284. sglang/srt/models/gemma3_mm.py +1 -1
  285. sglang/srt/models/gemma3n_mm.py +2 -2
  286. sglang/srt/models/glm4_moe.py +4 -4
  287. sglang/srt/models/glm4_moe_nextn.py +2 -2
  288. sglang/srt/models/glm4v.py +5 -3
  289. sglang/srt/models/glm4v_moe.py +4 -1
  290. sglang/srt/models/gpt_oss.py +8 -31
  291. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  292. sglang/srt/models/llama.py +4 -0
  293. sglang/srt/models/llama4.py +9 -0
  294. sglang/srt/models/llama_eagle3.py +13 -0
  295. sglang/srt/models/longcat_flash.py +3 -3
  296. sglang/srt/models/longcat_flash_nextn.py +1 -1
  297. sglang/srt/models/mllama4.py +40 -4
  298. sglang/srt/models/opt.py +637 -0
  299. sglang/srt/models/qwen2_5_vl.py +29 -5
  300. sglang/srt/models/qwen2_audio.py +1 -1
  301. sglang/srt/models/qwen2_moe.py +120 -13
  302. sglang/srt/models/qwen2_vl.py +1 -1
  303. sglang/srt/models/qwen3.py +18 -3
  304. sglang/srt/models/qwen3_moe.py +32 -4
  305. sglang/srt/models/qwen3_next.py +1069 -0
  306. sglang/srt/models/qwen3_next_mtp.py +112 -0
  307. sglang/srt/models/qwen3_vl.py +787 -0
  308. sglang/srt/models/qwen3_vl_moe.py +471 -0
  309. sglang/srt/models/registry.py +15 -3
  310. sglang/srt/models/sarashina2_vision.py +269 -0
  311. sglang/srt/models/solar.py +505 -0
  312. sglang/srt/models/starcoder2.py +357 -0
  313. sglang/srt/models/step3_vl.py +1 -1
  314. sglang/srt/models/torch_native_llama.py +9 -2
  315. sglang/srt/models/utils.py +51 -0
  316. sglang/srt/multimodal/processors/base_processor.py +15 -7
  317. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  318. sglang/srt/multimodal/processors/glm4v.py +9 -9
  319. sglang/srt/multimodal/processors/internvl.py +153 -129
  320. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  321. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  322. sglang/srt/offloader.py +27 -3
  323. sglang/srt/parser/jinja_template_utils.py +6 -0
  324. sglang/srt/sampling/sampling_batch_info.py +38 -17
  325. sglang/srt/sampling/sampling_params.py +7 -0
  326. sglang/srt/server_args.py +966 -267
  327. sglang/srt/server_args_config_parser.py +146 -0
  328. sglang/srt/single_batch_overlap.py +151 -0
  329. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  330. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  331. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  332. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  333. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  334. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  335. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  336. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  337. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  338. sglang/srt/speculative/eagle_worker.py +99 -28
  339. sglang/srt/speculative/ngram_utils.py +428 -0
  340. sglang/srt/speculative/ngram_worker.py +245 -0
  341. sglang/srt/speculative/spec_info.py +52 -0
  342. sglang/srt/speculative/spec_utils.py +606 -0
  343. sglang/srt/speculative/standalone_worker.py +109 -0
  344. sglang/srt/torch_memory_saver_adapter.py +5 -7
  345. sglang/srt/tracing/trace.py +578 -0
  346. sglang/srt/two_batch_overlap.py +8 -5
  347. sglang/srt/utils/__init__.py +2 -0
  348. sglang/srt/{utils.py → utils/common.py} +433 -77
  349. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
  350. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  351. sglang/srt/utils/rpd_utils.py +452 -0
  352. sglang/srt/utils/slow_rank_detector.py +71 -0
  353. sglang/srt/warmup.py +8 -4
  354. sglang/srt/weight_sync/utils.py +2 -2
  355. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  356. sglang/test/get_logits_ut.py +57 -0
  357. sglang/test/run_eval.py +79 -11
  358. sglang/test/runners.py +5 -1
  359. sglang/test/simple_eval_common.py +5 -2
  360. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  361. sglang/test/test_block_fp8.py +2 -2
  362. sglang/test/test_cutlass_moe.py +24 -6
  363. sglang/test/test_deterministic.py +297 -0
  364. sglang/test/test_disaggregation_utils.py +77 -0
  365. sglang/test/test_fp4_moe.py +370 -1
  366. sglang/test/test_programs.py +1 -1
  367. sglang/test/test_utils.py +383 -5
  368. sglang/utils.py +21 -1
  369. sglang/version.py +1 -1
  370. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
  371. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
  372. sglang/srt/disaggregation/launch_lb.py +0 -118
  373. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  374. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  375. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
  376. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
  377. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/models/dots_ocr.py (new file)
@@ -0,0 +1,173 @@
+# coding=utf-8
+# Adapted from Qwen2.5-VL SGLang implementation
+
+import logging
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from transformers.activations import ACT2FN
+
+from sglang.srt.configs import DotsOCRConfig
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.dots_vlm_vit import DotsVisionTransformer
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
+
+logger = logging.getLogger(__name__)
+
+
+class DotsOCRForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: DotsOCRConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        # Initialize vision transformer
+        self.visual = DotsVisionTransformer(
+            config.vision_config,
+        )
+
+        # Initialize language model
+        self.model = Qwen2ForCausalLM(config, quant_config)
+
+        # Initialize LM head
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # Extract pixel values and grid information (following reference pattern)
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.visual.dtype
+        )
+        image_grid_thw = torch.concat(
+            [item.image_grid_thw for item in items], dim=0
+        ).to(self.visual.device)
+
+        # Add dimension checks like in reference code
+        assert pixel_values.dim() == 2, f"{pixel_values.dim()=}"
+        assert image_grid_thw.dim() == 2, f"{image_grid_thw.dim()=}"
+
+        # Process through vision tower
+        image_embeds = self.visual(pixel_values, image_grid_thw)
+
+        # Ensure consistent dtype for FlashInfer compatibility
+        # Force bfloat16 to match model's expected dtype
+        if hasattr(self.model, "embed_tokens"):
+            target_dtype = self.model.embed_tokens.weight.dtype
+            if image_embeds.dtype != target_dtype:
+                image_embeds = image_embeds.to(target_dtype)
+
+        return image_embeds
+
+    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+        """pad attn qkv weights for dummy heads"""
+        num_dummy_heads = self.config.vision_config.num_dummy_heads
+        if num_dummy_heads == 0:
+            return loaded_weight
+        head_dim = self.config.vision_config.head_dim
+
+        if "attn.qkv_proj" in name:
+            wq, wk, wv = loaded_weight.chunk(3, dim=0)
+            if name.endswith(".weight"):
+                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+            elif name.endswith(".bias"):
+                dummy_shape = [num_dummy_heads, head_dim]
+            else:
+                raise RuntimeError(f"Unsupported weight with name={name}")
+            pad_func = lambda x: torch.cat(
+                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+            ).flatten(0, 1)
+            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+            loaded_weight = torch.cat([wq, wk, wv], dim=0)
+        if "attn.proj.weight" in name:
+            padded_weight = loaded_weight.new_zeros(
+                loaded_weight.shape[0], head_dim * num_dummy_heads
+            )
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+        return loaded_weight
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs: object,
+    ) -> torch.Tensor:
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            multimodal_model=self,
+            language_model=self.model,
+        )
+        return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for the model, separating vision and language weights"""
+        weights = list(weights)
+
+        # Separate vision tower weights and language model weights
+        vision_weights = []
+        language_weights = []
+
+        for name, loaded_weight in weights:
+            if name.startswith("vision_tower."):
+                vision_name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
+                vision_weights.append((vision_name, loaded_weight))
+            else:
+                # All other weights go to language model
+                language_weights.append((name, loaded_weight))
+
+        # Load vision tower weights
+        vision_state_dict = dict(vision_weights)
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+
+        for name, loaded_weight in vision_state_dict.items():
+            name = name.replace("vision_tower", "visual")
+            if name not in params_dict:
+                raise ValueError(f"Weight {name} not found in params_dict")
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
+            weight_loader(param, loaded_weight)
+
+        if language_weights:
+            self.model.load_weights(language_weights)
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+
+EntryClass = [DotsOCRForCausalLM]
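
The `_pad_vit_attn_dummy_heads` helper above zero-pads per-head slices of the fused QKV weights so the vision attention head count becomes divisible by the tensor-parallel world size. A minimal standalone sketch of the same padding step, using toy sizes that are assumptions rather than the real config:

import torch

# Toy sizes (hypothetical, not the real config): 2 real heads plus
# 1 dummy head, head_dim = 4, in_features = 8.
num_heads, num_dummy_heads, head_dim, in_features = 2, 1, 4, 8

# One chunk (e.g. wq) of a fused QKV ".weight": [num_heads * head_dim, in_features].
wq = torch.randn(num_heads * head_dim, in_features)

# Same transform as pad_func above: view the rows as per-head slices,
# append zero-filled dummy heads, then flatten back to a 2-D weight.
dummy = wq.new_zeros(num_dummy_heads, head_dim, in_features)
wq_padded = torch.cat([wq.unflatten(0, (-1, head_dim)), dummy], dim=0).flatten(0, 1)

assert wq_padded.shape == (12, 8)  # (num_heads + num_dummy_heads) * head_dim rows
# The appended rows are zero, so the dummy heads contribute nothing.
assert wq_padded[num_heads * head_dim :].abs().sum() == 0

Because the appended heads are all-zero (and the matching `attn.proj.weight` columns are zero-padded too), they only fix the sharded tensor shapes; they do not change the attention output.
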
sglang/srt/models/dots_vlm.py (new file)
@@ -0,0 +1,174 @@
+# Copyright 2025 The RedNote HiLab team.
+# Copyright 2025 The SGLang team.
+#
+# This code is based on the DeepseekVL2ForCausalLM and DotsVisionTransformer
+# implementation in this library.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Dots-VL model compatible with HuggingFace weights."""
+
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.configs.dots_vlm import DotsVLMConfig
+from sglang.srt.distributed import parallel_state
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
+
+from .dots_vlm_vit import DotsVisionTransformer
+
+
+class DotsVLMForCausalLM(nn.Module):
+    """DotsVLM model for sglang inference"""
+
+    def __init__(
+        self, config: DotsVLMConfig, quant_config: Optional[QuantizationConfig] = None
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.image_token_id = config.im_span_id
+        self.video_token_id = config.video_span_id
+
+        self.language_model = DeepseekV2ForCausalLM(
+            config.language_config, quant_config
+        )
+
+        # Initialize vision tower (matching transformers naming for weight compatibility)
+        self.vision_tower = DotsVisionTransformer(config.vision_config)
+
+    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+        """pad attn qkv weights for dummy heads"""
+        num_dummy_heads = self.config.vision_config.num_dummy_heads
+        if num_dummy_heads == 0:
+            return loaded_weight
+        head_dim = self.config.vision_config.head_dim
+
+        if "attn.qkv_proj" in name:
+            wq, wk, wv = loaded_weight.chunk(3, dim=0)
+            if name.endswith(".weight"):
+                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+            elif name.endswith(".bias"):
+                dummy_shape = [num_dummy_heads, head_dim]
+            else:
+                raise RuntimeError(f"Unsupported weight with name={name}")
+            pad_func = lambda x: torch.cat(
+                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+            ).flatten(0, 1)
+            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+            loaded_weight = torch.cat([wq, wk, wv], dim=0)
+        if "attn.proj.weight" in name:
+            padded_weight = loaded_weight.new_zeros(
+                loaded_weight.shape[0], head_dim * num_dummy_heads
+            )
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+        return loaded_weight
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for the model, separating vision and language weights"""
+        weights = list(weights)
+
+        # Separate vision tower weights and language model weights
+        vision_weights = []
+        language_weights = []
+
+        for name, loaded_weight in weights:
+            if name.startswith("vision_tower."):
+                vision_name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+                vision_weights.append((vision_name, loaded_weight))
+            else:
+                # All other weights go to language model
+                language_weights.append((name, loaded_weight))
+
+        # Load vision tower weights
+        vision_state_dict = dict(vision_weights)
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in vision_state_dict.items():
+            if name not in params_dict:
+                raise ValueError(f"Weight {name} not found in params_dict")
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
+            weight_loader(param, loaded_weight)
+
+        # Load language model weights
+        if language_weights:
+            self.language_model.load_weights(language_weights)
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return DeepseekV2ForCausalLM.get_model_config_for_expert_location(config)
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        """Pad input_ids with multimodal tokens"""
+        # Get image token ID for padding pattern
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        padded_input_ids = pattern.pad_input_tokens(input_ids, mm_inputs)
+        return padded_input_ids
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # Extract pixel values and grid information (following reference pattern)
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.vision_tower.dtype
+        )
+        image_grid_thw = torch.concat(
+            [item.image_grid_thw for item in items], dim=0
+        ).to(self.vision_tower.device)
+
+        # Add dimension checks like in reference code
+        assert pixel_values.dim() == 2, f"{pixel_values.dim()=}"
+        assert image_grid_thw.dim() == 2, f"{image_grid_thw.dim()=}"
+
+        # Process through vision tower
+        image_embeds = self.vision_tower(pixel_values, image_grid_thw)
+
+        # Ensure consistent dtype for FlashInfer compatibility
+        # Force bfloat16 to match model's expected dtype
+        if image_embeds.dtype != torch.bfloat16 and hasattr(
+            self.language_model.model, "embed_tokens"
+        ):
+            target_dtype = self.language_model.model.embed_tokens.weight.dtype
+            image_embeds = image_embeds.to(target_dtype)
+
+        return image_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs: object,
+    ) -> torch.Tensor:
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            multimodal_model=self,
+            language_model=self.language_model,
+        )
+        return hidden_states
+
+
+EntryClass = [DotsVLMForCausalLM]
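
Both Dots models route checkpoint tensors by name prefix: entries under `vision_tower.` are renamed (`attn.qkv.` to `attn.qkv_proj.`) and loaded into the vision tower, while everything else is forwarded to the language model's own `load_weights`. A minimal sketch of that routing, with hypothetical tensor names and placeholder shapes:

import torch

# Hypothetical checkpoint entries; shapes are placeholders.
checkpoint = [
    ("vision_tower.blocks.0.attn.qkv.weight", torch.zeros(3, 4)),
    ("vision_tower.blocks.0.attn.proj.weight", torch.zeros(4, 4)),
    ("model.layers.0.self_attn.q_proj.weight", torch.zeros(4, 4)),
]

vision_weights, language_weights = [], []
for name, w in checkpoint:
    if name.startswith("vision_tower."):
        # Rename fused QKV entries to SGLang's qkv_proj parameter name.
        vision_weights.append((name.replace("attn.qkv.", "attn.qkv_proj."), w))
    else:
        language_weights.append((name, w))

print([n for n, _ in vision_weights])
# ['vision_tower.blocks.0.attn.qkv_proj.weight', 'vision_tower.blocks.0.attn.proj.weight']
print([n for n, _ in language_weights])
# ['model.layers.0.self_attn.q_proj.weight']
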
sglang/srt/models/dots_vlm_vit.py (new file)
@@ -0,0 +1,337 @@
+import logging
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.nn import LayerNorm
+from transformers.modeling_utils import PreTrainedModel
+
+from sglang.srt.configs.dots_vlm import DotsVisionConfig
+from sglang.srt.distributed import parallel_state
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+
+class VisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        seq = torch.arange(
+            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+        )
+        freqs = torch.outer(seq, self.inv_freq)
+        return freqs
+
+
+class PatchMerger(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        context_dim: int,
+        spatial_merge_size: int = 2,
+        pre_norm="layernorm",
+        init_merger_std=None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = context_dim * (spatial_merge_size**2)
+        self.pre_norm = pre_norm
+        if self.pre_norm == "layernorm":
+            self.ln_q = LayerNorm(context_dim, eps=1e-6)
+        elif self.pre_norm == "rmsnorm":
+            self.ln_q = RMSNorm(context_dim, eps=1e-6)
+        else:
+            logger.warning(f"no norm in patch merger: {self.pre_norm}")
+
+        self.mlp = nn.Sequential(
+            nn.Linear(self.hidden_size, self.hidden_size),
+            nn.GELU(),
+            nn.Linear(self.hidden_size, dim),
+        )
+
+        if init_merger_std is not None:
+            nn.init.normal_(self.mlp[0].weight, mean=0.0, std=init_merger_std)
+            nn.init.zeros_(self.mlp[0].bias)
+            nn.init.normal_(self.mlp[2].weight, mean=0.0, std=init_merger_std)
+            nn.init.zeros_(self.mlp[2].bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.pre_norm:
+            x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
+        else:
+            x = self.mlp(x.view(-1, self.hidden_size))
+        return x
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(dim))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+    def extra_repr(self) -> str:
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+    def _norm(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+
+class DotsSwiGLUFFN(nn.Module):
+    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        hidden_features = config.intermediate_size
+        in_features = config.embed_dim
+        bias = config.use_bias
+
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
+        self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.silu(self.fc1(x)) * self.fc3(x)
+        x = self.fc2(x)
+        return x
+
+
+class DotsPatchEmbed(nn.Module):
+    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.num_channels = config.num_channels
+        self.patch_size = config.patch_size
+        self.temporal_patch_size = config.temporal_patch_size
+        self.embed_dim = config.embed_dim
+        self.config = config
+        self.proj = nn.Conv2d(
+            config.num_channels,
+            config.embed_dim,
+            kernel_size=(config.patch_size, config.patch_size),
+            stride=(config.patch_size, config.patch_size),
+        )
+        self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
+        x = x.view(
+            -1,
+            self.num_channels,
+            self.temporal_patch_size,
+            self.patch_size,
+            self.patch_size,
+        )[:, :, 0]
+        x = self.proj(x).view(-1, self.embed_dim)
+        x = self.norm(x)
+        return x
+
+
+class DotsViTPreprocessor(nn.Module):
+    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.patch_h = config.patch_size
+        self.patch_w = config.patch_size
+        self.embed_dim = config.embed_dim
+        self.config = config
+        self.patchifier = DotsPatchEmbed(config, quant_config)
+
+    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
+        tokens = self.patchifier(x, grid_thw)
+        return tokens
+
+
+class DotsVisionBlock(nn.Module):
+    def __init__(
+        self,
+        config: DotsVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        attn_implementation: str = "flash_attention_2",
+    ):
+        super().__init__()
+        if attn_implementation == "flash_attention_2":
+            qkv_backend = "fa3"
+            softmax_in_single_precision = False
+        else:
+            raise RuntimeError("Unimplemented")
+        self.attn = VisionAttention(
+            embed_dim=config.embed_dim,
+            num_heads=config.num_attention_heads,
+            projection_size=config.embed_dim,
+            use_qkv_parallel=True,
+            qkv_backend=qkv_backend,
+            softmax_in_single_precision=softmax_in_single_precision,
+            flatten_batch=True,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+            num_dummy_heads=config.num_dummy_heads,
+            qkv_bias=config.use_bias,
+            proj_bias=config.use_bias,
+        )
+        self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+        self.mlp = DotsSwiGLUFFN(config, quant_config)
+        self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            position_embeddings=rotary_pos_emb,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class DotsVisionTransformer(PreTrainedModel):
+    def __init__(
+        self,
+        config: DotsVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config)
+        self.config = config
+        self._update_vision_config()
+        self.spatial_merge_size = config.spatial_merge_size
+
+        self.patch_embed = DotsViTPreprocessor(config, quant_config)
+        self._init_weights(self.patch_embed.patchifier.proj)
+
+        head_dim = config.embed_dim // config.num_attention_heads
+
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+        _num_hidden_layers = config.num_hidden_layers
+        self.blocks = nn.ModuleList(
+            [
+                DotsVisionBlock(
+                    config, quant_config, f"blocks.{i}", config.attn_implementation
+                )
+                for i in range(_num_hidden_layers)
+            ]
+        )
+
+        if self.config.post_norm:
+            self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+        self.merger = PatchMerger(
+            dim=config.hidden_size,
+            context_dim=config.embed_dim,
+            spatial_merge_size=config.spatial_merge_size,
+            init_merger_std=self.config.init_merger_std,
+            quant_config=quant_config,
+        )
+
+        self.gradient_checkpointing = False
+
+    def _update_vision_config(self):
+        """update vision config to support tp"""
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        num_heads = self.config.num_attention_heads
+        head_dim = self.config.embed_dim // num_heads
+        num_dummy_heads = 0
+
+        if num_heads % world_size != 0:
+            num_dummy_heads = (
+                (num_heads + world_size) // world_size
+            ) * world_size - num_heads
+
+        setattr(self.config, "head_dim", head_dim)
+        setattr(self.config, "num_dummy_heads", num_dummy_heads)
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.blocks[0].mlp.fc2.weight.dtype
+
+    @property
+    def device(self) -> torch.device:
+        return self.blocks[0].mlp.fc2.weight.device
+
+    def get_pos_ids_by_grid(self, grid_thw):
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+            wpos_ids = wpos_ids.flatten()
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+
+        return pos_ids
+
+    def rot_pos_emb(self, grid_thw):
+        pos_ids = self.get_pos_ids_by_grid(grid_thw)
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def calc_cos_sin(self, rotary_pos_emb):
+        cos = rotary_pos_emb.cos()
+        sin = rotary_pos_emb.sin()
+        cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+        sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+        rotary_pos_emb = (cos, sin)
+        return rotary_pos_emb
+
+    def forward(
+        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, bf16=True
+    ) -> torch.Tensor:
+        if bf16:
+            hidden_states = hidden_states.bfloat16()
+        hidden_states = self.patch_embed(hidden_states, grid_thw)
+
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.calc_cos_sin(rotary_pos_emb)
+
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(
+            dim=0,
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+        for blk in self.blocks:
+            hidden_states = blk(
+                hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+            )
+
+        if self.config.post_norm:
+            hidden_states = self.post_trunk_norm(hidden_states)
+
+        hidden_states = self.merger(hidden_states)
+        return hidden_states
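
In `DotsVisionTransformer.forward`, all images in a batch are flattened into one token stream, and `cu_seqlens` records the cumulative token count at each image boundary for the variable-length attention kernel. A minimal sketch of that construction with a made-up two-image grid:

import torch
import torch.nn.functional as F

# Toy grid_thw (hypothetical): two images, each with t=1 temporal frame,
# with 4x6 and 2x2 patch grids -> 24 and 4 patch tokens respectively.
grid_thw = torch.tensor([[1, 4, 6], [1, 2, 2]])

# Tokens per frame (h * w), repeated t times per image, then a running sum
# with a leading zero: the boundaries of each image's tokens in the stream.
cu_seqlens = torch.repeat_interleave(
    grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
).cumsum(dim=0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

print(cu_seqlens)  # tensor([ 0, 24, 28], dtype=torch.int32)
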
sglang/srt/models/ernie4.py
@@ -92,7 +92,7 @@ class Ernie4Moe(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.moe_num_experts,
             top_k=config.moe_k,
             hidden_size=config.hidden_size,
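
The `ernie4.py` change passes `quant_config` into `get_moe_impl_class`, so the factory can pick a MoE implementation based on the quantization scheme rather than unconditionally. A hypothetical sketch of such a factory (the body below is an illustration, not SGLang's actual implementation):

from typing import Optional

class FusedMoE:
    """Placeholder for an unquantized MoE implementation."""

class QuantizedFusedMoE(FusedMoE):
    """Placeholder for a quantization-aware MoE implementation."""

def get_moe_impl_class(quant_config: Optional[object] = None) -> type:
    # Choose a quantization-aware expert implementation when a quant
    # scheme is configured; otherwise fall back to the plain one.
    return QuantizedFusedMoE if quant_config is not None else FusedMoE

assert get_moe_impl_class(None) is FusedMoE
assert get_moe_impl_class(object()) is QuantizedFusedMoE
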