sglang 0.5.3rc0__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff compares the contents of two publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (482)
  1. sglang/bench_one_batch.py +54 -37
  2. sglang/bench_one_batch_server.py +340 -34
  3. sglang/bench_serving.py +340 -159
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/backend/runtime_endpoint.py +1 -1
  9. sglang/lang/interpreter.py +1 -0
  10. sglang/lang/ir.py +13 -0
  11. sglang/launch_server.py +9 -2
  12. sglang/profiler.py +20 -3
  13. sglang/srt/_custom_ops.py +1 -1
  14. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  15. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +547 -0
  16. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  17. sglang/srt/compilation/backend.py +437 -0
  18. sglang/srt/compilation/compilation_config.py +20 -0
  19. sglang/srt/compilation/compilation_counter.py +47 -0
  20. sglang/srt/compilation/compile.py +210 -0
  21. sglang/srt/compilation/compiler_interface.py +503 -0
  22. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  23. sglang/srt/compilation/fix_functionalization.py +134 -0
  24. sglang/srt/compilation/fx_utils.py +83 -0
  25. sglang/srt/compilation/inductor_pass.py +140 -0
  26. sglang/srt/compilation/pass_manager.py +66 -0
  27. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  28. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  29. sglang/srt/configs/__init__.py +8 -0
  30. sglang/srt/configs/deepseek_ocr.py +262 -0
  31. sglang/srt/configs/deepseekvl2.py +194 -96
  32. sglang/srt/configs/dots_ocr.py +64 -0
  33. sglang/srt/configs/dots_vlm.py +2 -7
  34. sglang/srt/configs/falcon_h1.py +309 -0
  35. sglang/srt/configs/load_config.py +33 -2
  36. sglang/srt/configs/mamba_utils.py +117 -0
  37. sglang/srt/configs/model_config.py +284 -118
  38. sglang/srt/configs/modelopt_config.py +30 -0
  39. sglang/srt/configs/nemotron_h.py +286 -0
  40. sglang/srt/configs/olmo3.py +105 -0
  41. sglang/srt/configs/points_v15_chat.py +29 -0
  42. sglang/srt/configs/qwen3_next.py +11 -47
  43. sglang/srt/configs/qwen3_omni.py +613 -0
  44. sglang/srt/configs/qwen3_vl.py +576 -0
  45. sglang/srt/connector/remote_instance.py +1 -1
  46. sglang/srt/constrained/base_grammar_backend.py +6 -1
  47. sglang/srt/constrained/llguidance_backend.py +5 -0
  48. sglang/srt/constrained/outlines_backend.py +1 -1
  49. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  50. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  51. sglang/srt/constrained/utils.py +12 -0
  52. sglang/srt/constrained/xgrammar_backend.py +26 -15
  53. sglang/srt/debug_utils/dumper.py +10 -3
  54. sglang/srt/disaggregation/ascend/conn.py +2 -2
  55. sglang/srt/disaggregation/ascend/transfer_engine.py +48 -10
  56. sglang/srt/disaggregation/base/conn.py +17 -4
  57. sglang/srt/disaggregation/common/conn.py +268 -98
  58. sglang/srt/disaggregation/decode.py +172 -39
  59. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  60. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  61. sglang/srt/disaggregation/fake/conn.py +11 -3
  62. sglang/srt/disaggregation/mooncake/conn.py +203 -555
  63. sglang/srt/disaggregation/nixl/conn.py +217 -63
  64. sglang/srt/disaggregation/prefill.py +113 -270
  65. sglang/srt/disaggregation/utils.py +36 -5
  66. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  67. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  68. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  69. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  70. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  71. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  72. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  73. sglang/srt/distributed/naive_distributed.py +5 -4
  74. sglang/srt/distributed/parallel_state.py +203 -97
  75. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  76. sglang/srt/entrypoints/context.py +3 -2
  77. sglang/srt/entrypoints/engine.py +85 -65
  78. sglang/srt/entrypoints/grpc_server.py +632 -305
  79. sglang/srt/entrypoints/harmony_utils.py +2 -2
  80. sglang/srt/entrypoints/http_server.py +169 -17
  81. sglang/srt/entrypoints/http_server_engine.py +1 -7
  82. sglang/srt/entrypoints/openai/protocol.py +327 -34
  83. sglang/srt/entrypoints/openai/serving_base.py +74 -8
  84. sglang/srt/entrypoints/openai/serving_chat.py +202 -118
  85. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  86. sglang/srt/entrypoints/openai/serving_completions.py +20 -4
  87. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  88. sglang/srt/entrypoints/openai/serving_responses.py +47 -2
  89. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  90. sglang/srt/environ.py +323 -0
  91. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  92. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  93. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  94. sglang/srt/eplb/expert_distribution.py +3 -4
  95. sglang/srt/eplb/expert_location.py +30 -5
  96. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  97. sglang/srt/eplb/expert_location_updater.py +2 -2
  98. sglang/srt/function_call/base_format_detector.py +17 -18
  99. sglang/srt/function_call/function_call_parser.py +21 -16
  100. sglang/srt/function_call/glm4_moe_detector.py +4 -8
  101. sglang/srt/function_call/gpt_oss_detector.py +24 -1
  102. sglang/srt/function_call/json_array_parser.py +61 -0
  103. sglang/srt/function_call/kimik2_detector.py +17 -4
  104. sglang/srt/function_call/utils.py +98 -7
  105. sglang/srt/grpc/compile_proto.py +245 -0
  106. sglang/srt/grpc/grpc_request_manager.py +915 -0
  107. sglang/srt/grpc/health_servicer.py +189 -0
  108. sglang/srt/grpc/scheduler_launcher.py +181 -0
  109. sglang/srt/grpc/sglang_scheduler_pb2.py +81 -68
  110. sglang/srt/grpc/sglang_scheduler_pb2.pyi +124 -61
  111. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +92 -1
  112. sglang/srt/layers/activation.py +11 -7
  113. sglang/srt/layers/attention/aiter_backend.py +17 -18
  114. sglang/srt/layers/attention/ascend_backend.py +125 -10
  115. sglang/srt/layers/attention/attention_registry.py +226 -0
  116. sglang/srt/layers/attention/base_attn_backend.py +32 -4
  117. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  118. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  119. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  120. sglang/srt/layers/attention/fla/chunk.py +0 -1
  121. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  122. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  123. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  124. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  125. sglang/srt/layers/attention/fla/index.py +0 -2
  126. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  127. sglang/srt/layers/attention/fla/utils.py +0 -3
  128. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  129. sglang/srt/layers/attention/flashattention_backend.py +52 -15
  130. sglang/srt/layers/attention/flashinfer_backend.py +357 -212
  131. sglang/srt/layers/attention/flashinfer_mla_backend.py +31 -33
  132. sglang/srt/layers/attention/flashmla_backend.py +9 -7
  133. sglang/srt/layers/attention/hybrid_attn_backend.py +12 -4
  134. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +236 -133
  135. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  136. sglang/srt/layers/attention/mamba/causal_conv1d.py +2 -1
  137. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +24 -103
  138. sglang/srt/layers/attention/mamba/mamba.py +514 -1
  139. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  140. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  141. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  142. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  143. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  144. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  145. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  146. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  147. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +261 -0
  148. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  149. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  150. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  151. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  152. sglang/srt/layers/attention/nsa/nsa_indexer.py +718 -0
  153. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  154. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  155. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  156. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  157. sglang/srt/layers/attention/nsa/utils.py +23 -0
  158. sglang/srt/layers/attention/nsa_backend.py +1201 -0
  159. sglang/srt/layers/attention/tbo_backend.py +6 -6
  160. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  161. sglang/srt/layers/attention/triton_backend.py +249 -42
  162. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  163. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  164. sglang/srt/layers/attention/trtllm_mha_backend.py +7 -9
  165. sglang/srt/layers/attention/trtllm_mla_backend.py +523 -48
  166. sglang/srt/layers/attention/utils.py +11 -7
  167. sglang/srt/layers/attention/vision.py +61 -3
  168. sglang/srt/layers/attention/wave_backend.py +4 -4
  169. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  170. sglang/srt/layers/communicator.py +19 -7
  171. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  172. sglang/srt/layers/deep_gemm_wrapper/configurer.py +25 -0
  173. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  174. sglang/srt/layers/dp_attention.py +28 -1
  175. sglang/srt/layers/elementwise.py +3 -1
  176. sglang/srt/layers/layernorm.py +47 -15
  177. sglang/srt/layers/linear.py +30 -5
  178. sglang/srt/layers/logits_processor.py +161 -18
  179. sglang/srt/layers/modelopt_utils.py +11 -0
  180. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  181. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  182. sglang/srt/layers/moe/ep_moe/kernels.py +36 -458
  183. sglang/srt/layers/moe/ep_moe/layer.py +243 -448
  184. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  185. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  186. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  187. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  188. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  189. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  190. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +17 -5
  191. sglang/srt/layers/moe/fused_moe_triton/layer.py +86 -81
  192. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  193. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  194. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  195. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  196. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  197. sglang/srt/layers/moe/router.py +51 -15
  198. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  199. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  200. sglang/srt/layers/moe/token_dispatcher/deepep.py +177 -106
  201. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  202. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  203. sglang/srt/layers/moe/topk.py +3 -2
  204. sglang/srt/layers/moe/utils.py +27 -1
  205. sglang/srt/layers/parameter.py +23 -6
  206. sglang/srt/layers/quantization/__init__.py +2 -53
  207. sglang/srt/layers/quantization/awq.py +183 -6
  208. sglang/srt/layers/quantization/awq_triton.py +29 -0
  209. sglang/srt/layers/quantization/base_config.py +20 -1
  210. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  211. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +21 -49
  212. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  213. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +5 -0
  214. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  215. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  216. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  217. sglang/srt/layers/quantization/fp8.py +86 -20
  218. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  219. sglang/srt/layers/quantization/fp8_utils.py +43 -15
  220. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  221. sglang/srt/layers/quantization/gptq.py +0 -1
  222. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  223. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  224. sglang/srt/layers/quantization/modelopt_quant.py +141 -81
  225. sglang/srt/layers/quantization/mxfp4.py +17 -34
  226. sglang/srt/layers/quantization/petit.py +1 -1
  227. sglang/srt/layers/quantization/quark/quark.py +3 -1
  228. sglang/srt/layers/quantization/quark/quark_moe.py +18 -5
  229. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  230. sglang/srt/layers/quantization/unquant.py +1 -4
  231. sglang/srt/layers/quantization/utils.py +0 -1
  232. sglang/srt/layers/quantization/w4afp8.py +51 -24
  233. sglang/srt/layers/quantization/w8a8_int8.py +45 -27
  234. sglang/srt/layers/radix_attention.py +59 -9
  235. sglang/srt/layers/rotary_embedding.py +750 -46
  236. sglang/srt/layers/sampler.py +84 -16
  237. sglang/srt/layers/sparse_pooler.py +98 -0
  238. sglang/srt/layers/utils.py +23 -1
  239. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  240. sglang/srt/lora/backend/base_backend.py +3 -3
  241. sglang/srt/lora/backend/chunked_backend.py +348 -0
  242. sglang/srt/lora/backend/triton_backend.py +9 -4
  243. sglang/srt/lora/eviction_policy.py +139 -0
  244. sglang/srt/lora/lora.py +7 -5
  245. sglang/srt/lora/lora_manager.py +33 -7
  246. sglang/srt/lora/lora_registry.py +1 -1
  247. sglang/srt/lora/mem_pool.py +41 -17
  248. sglang/srt/lora/triton_ops/__init__.py +4 -0
  249. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  250. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +176 -0
  251. sglang/srt/lora/utils.py +7 -5
  252. sglang/srt/managers/cache_controller.py +83 -152
  253. sglang/srt/managers/data_parallel_controller.py +156 -87
  254. sglang/srt/managers/detokenizer_manager.py +51 -24
  255. sglang/srt/managers/io_struct.py +223 -129
  256. sglang/srt/managers/mm_utils.py +49 -10
  257. sglang/srt/managers/multi_tokenizer_mixin.py +83 -98
  258. sglang/srt/managers/multimodal_processor.py +1 -2
  259. sglang/srt/managers/overlap_utils.py +130 -0
  260. sglang/srt/managers/schedule_batch.py +340 -529
  261. sglang/srt/managers/schedule_policy.py +158 -18
  262. sglang/srt/managers/scheduler.py +665 -620
  263. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  264. sglang/srt/managers/scheduler_metrics_mixin.py +150 -131
  265. sglang/srt/managers/scheduler_output_processor_mixin.py +337 -122
  266. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  267. sglang/srt/managers/scheduler_profiler_mixin.py +62 -15
  268. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  269. sglang/srt/managers/scheduler_update_weights_mixin.py +40 -14
  270. sglang/srt/managers/tokenizer_communicator_mixin.py +141 -19
  271. sglang/srt/managers/tokenizer_manager.py +462 -226
  272. sglang/srt/managers/tp_worker.py +217 -156
  273. sglang/srt/managers/utils.py +79 -47
  274. sglang/srt/mem_cache/allocator.py +21 -22
  275. sglang/srt/mem_cache/allocator_ascend.py +42 -28
  276. sglang/srt/mem_cache/base_prefix_cache.py +3 -3
  277. sglang/srt/mem_cache/chunk_cache.py +20 -2
  278. sglang/srt/mem_cache/common.py +480 -0
  279. sglang/srt/mem_cache/evict_policy.py +38 -0
  280. sglang/srt/mem_cache/hicache_storage.py +44 -2
  281. sglang/srt/mem_cache/hiradix_cache.py +134 -34
  282. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  283. sglang/srt/mem_cache/memory_pool.py +602 -208
  284. sglang/srt/mem_cache/memory_pool_host.py +134 -183
  285. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  286. sglang/srt/mem_cache/radix_cache.py +263 -78
  287. sglang/srt/mem_cache/radix_cache_cpp.py +29 -21
  288. sglang/srt/mem_cache/storage/__init__.py +10 -0
  289. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +157 -0
  290. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +97 -0
  291. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  292. sglang/srt/mem_cache/storage/eic/eic_storage.py +777 -0
  293. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  294. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  295. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +180 -59
  296. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +15 -9
  297. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +217 -26
  298. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  299. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  300. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  301. sglang/srt/mem_cache/swa_radix_cache.py +115 -58
  302. sglang/srt/metrics/collector.py +113 -120
  303. sglang/srt/metrics/func_timer.py +3 -8
  304. sglang/srt/metrics/utils.py +8 -1
  305. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  306. sglang/srt/model_executor/cuda_graph_runner.py +81 -36
  307. sglang/srt/model_executor/forward_batch_info.py +40 -50
  308. sglang/srt/model_executor/model_runner.py +507 -319
  309. sglang/srt/model_executor/npu_graph_runner.py +11 -5
  310. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  311. sglang/srt/model_loader/__init__.py +1 -1
  312. sglang/srt/model_loader/loader.py +438 -37
  313. sglang/srt/model_loader/utils.py +0 -1
  314. sglang/srt/model_loader/weight_utils.py +200 -27
  315. sglang/srt/models/apertus.py +2 -3
  316. sglang/srt/models/arcee.py +2 -2
  317. sglang/srt/models/bailing_moe.py +40 -56
  318. sglang/srt/models/bailing_moe_nextn.py +3 -4
  319. sglang/srt/models/bert.py +1 -1
  320. sglang/srt/models/deepseek_nextn.py +25 -4
  321. sglang/srt/models/deepseek_ocr.py +1516 -0
  322. sglang/srt/models/deepseek_v2.py +793 -235
  323. sglang/srt/models/dots_ocr.py +171 -0
  324. sglang/srt/models/dots_vlm.py +0 -1
  325. sglang/srt/models/dots_vlm_vit.py +1 -1
  326. sglang/srt/models/falcon_h1.py +570 -0
  327. sglang/srt/models/gemma3_causal.py +0 -2
  328. sglang/srt/models/gemma3_mm.py +17 -1
  329. sglang/srt/models/gemma3n_mm.py +2 -3
  330. sglang/srt/models/glm4_moe.py +17 -40
  331. sglang/srt/models/glm4_moe_nextn.py +4 -4
  332. sglang/srt/models/glm4v.py +3 -2
  333. sglang/srt/models/glm4v_moe.py +6 -6
  334. sglang/srt/models/gpt_oss.py +12 -35
  335. sglang/srt/models/grok.py +10 -23
  336. sglang/srt/models/hunyuan.py +2 -7
  337. sglang/srt/models/interns1.py +0 -1
  338. sglang/srt/models/kimi_vl.py +1 -7
  339. sglang/srt/models/kimi_vl_moonvit.py +4 -2
  340. sglang/srt/models/llama.py +6 -2
  341. sglang/srt/models/llama_eagle3.py +1 -1
  342. sglang/srt/models/longcat_flash.py +6 -23
  343. sglang/srt/models/longcat_flash_nextn.py +4 -15
  344. sglang/srt/models/mimo.py +2 -13
  345. sglang/srt/models/mimo_mtp.py +1 -2
  346. sglang/srt/models/minicpmo.py +7 -5
  347. sglang/srt/models/mixtral.py +1 -4
  348. sglang/srt/models/mllama.py +1 -1
  349. sglang/srt/models/mllama4.py +27 -6
  350. sglang/srt/models/nemotron_h.py +511 -0
  351. sglang/srt/models/olmo2.py +31 -4
  352. sglang/srt/models/opt.py +5 -5
  353. sglang/srt/models/phi.py +1 -1
  354. sglang/srt/models/phi4mm.py +1 -1
  355. sglang/srt/models/phimoe.py +0 -1
  356. sglang/srt/models/pixtral.py +0 -3
  357. sglang/srt/models/points_v15_chat.py +186 -0
  358. sglang/srt/models/qwen.py +0 -1
  359. sglang/srt/models/qwen2.py +0 -7
  360. sglang/srt/models/qwen2_5_vl.py +5 -5
  361. sglang/srt/models/qwen2_audio.py +2 -15
  362. sglang/srt/models/qwen2_moe.py +70 -4
  363. sglang/srt/models/qwen2_vl.py +6 -3
  364. sglang/srt/models/qwen3.py +18 -3
  365. sglang/srt/models/qwen3_moe.py +50 -38
  366. sglang/srt/models/qwen3_next.py +43 -21
  367. sglang/srt/models/qwen3_next_mtp.py +3 -4
  368. sglang/srt/models/qwen3_omni_moe.py +661 -0
  369. sglang/srt/models/qwen3_vl.py +791 -0
  370. sglang/srt/models/qwen3_vl_moe.py +343 -0
  371. sglang/srt/models/registry.py +15 -3
  372. sglang/srt/models/roberta.py +55 -3
  373. sglang/srt/models/sarashina2_vision.py +268 -0
  374. sglang/srt/models/solar.py +505 -0
  375. sglang/srt/models/starcoder2.py +357 -0
  376. sglang/srt/models/step3_vl.py +3 -5
  377. sglang/srt/models/torch_native_llama.py +9 -2
  378. sglang/srt/models/utils.py +61 -0
  379. sglang/srt/multimodal/processors/base_processor.py +21 -9
  380. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  381. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  382. sglang/srt/multimodal/processors/dots_vlm.py +2 -4
  383. sglang/srt/multimodal/processors/glm4v.py +1 -5
  384. sglang/srt/multimodal/processors/internvl.py +20 -10
  385. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  386. sglang/srt/multimodal/processors/mllama4.py +0 -8
  387. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  388. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  389. sglang/srt/multimodal/processors/qwen_vl.py +83 -17
  390. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  391. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  392. sglang/srt/parser/conversation.py +41 -0
  393. sglang/srt/parser/jinja_template_utils.py +6 -0
  394. sglang/srt/parser/reasoning_parser.py +0 -1
  395. sglang/srt/sampling/custom_logit_processor.py +77 -2
  396. sglang/srt/sampling/sampling_batch_info.py +36 -23
  397. sglang/srt/sampling/sampling_params.py +75 -0
  398. sglang/srt/server_args.py +1300 -338
  399. sglang/srt/server_args_config_parser.py +146 -0
  400. sglang/srt/single_batch_overlap.py +161 -0
  401. sglang/srt/speculative/base_spec_worker.py +34 -0
  402. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  403. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  404. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  405. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  406. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  407. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  408. sglang/srt/speculative/draft_utils.py +226 -0
  409. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +26 -8
  410. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +26 -3
  411. sglang/srt/speculative/eagle_info.py +786 -0
  412. sglang/srt/speculative/eagle_info_v2.py +458 -0
  413. sglang/srt/speculative/eagle_utils.py +113 -1270
  414. sglang/srt/speculative/eagle_worker.py +120 -285
  415. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  416. sglang/srt/speculative/ngram_info.py +433 -0
  417. sglang/srt/speculative/ngram_worker.py +246 -0
  418. sglang/srt/speculative/spec_info.py +49 -0
  419. sglang/srt/speculative/spec_utils.py +641 -0
  420. sglang/srt/speculative/standalone_worker.py +4 -14
  421. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  422. sglang/srt/tracing/trace.py +32 -6
  423. sglang/srt/two_batch_overlap.py +35 -18
  424. sglang/srt/utils/__init__.py +2 -0
  425. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  426. sglang/srt/{utils.py → utils/common.py} +583 -113
  427. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +86 -19
  428. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  429. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  430. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  431. sglang/srt/utils/profile_merger.py +199 -0
  432. sglang/srt/utils/rpd_utils.py +452 -0
  433. sglang/srt/utils/slow_rank_detector.py +71 -0
  434. sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +5 -7
  435. sglang/srt/warmup.py +8 -4
  436. sglang/srt/weight_sync/utils.py +1 -1
  437. sglang/test/attention/test_flashattn_backend.py +1 -1
  438. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  439. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  440. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  441. sglang/test/few_shot_gsm8k_engine.py +2 -4
  442. sglang/test/get_logits_ut.py +57 -0
  443. sglang/test/kit_matched_stop.py +157 -0
  444. sglang/test/longbench_v2/__init__.py +1 -0
  445. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  446. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  447. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  448. sglang/test/run_eval.py +120 -11
  449. sglang/test/runners.py +3 -1
  450. sglang/test/send_one.py +42 -7
  451. sglang/test/simple_eval_common.py +8 -2
  452. sglang/test/simple_eval_gpqa.py +0 -1
  453. sglang/test/simple_eval_humaneval.py +0 -3
  454. sglang/test/simple_eval_longbench_v2.py +344 -0
  455. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  456. sglang/test/test_block_fp8.py +3 -4
  457. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  458. sglang/test/test_cutlass_moe.py +1 -2
  459. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  460. sglang/test/test_deterministic.py +430 -0
  461. sglang/test/test_deterministic_utils.py +73 -0
  462. sglang/test/test_disaggregation_utils.py +93 -1
  463. sglang/test/test_marlin_moe.py +0 -1
  464. sglang/test/test_programs.py +1 -1
  465. sglang/test/test_utils.py +432 -16
  466. sglang/utils.py +10 -1
  467. sglang/version.py +1 -1
  468. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/METADATA +64 -43
  469. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/RECORD +476 -346
  470. sglang/srt/entrypoints/grpc_request_manager.py +0 -580
  471. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -32
  472. sglang/srt/managers/tp_worker_overlap_thread.py +0 -319
  473. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  474. sglang/srt/speculative/build_eagle_tree.py +0 -427
  475. sglang/test/test_block_fp8_ep.py +0 -358
  476. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  477. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  478. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  479. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  480. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  481. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  482. {sglang-0.5.3rc0.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_ocr.py (new file)
@@ -0,0 +1,1516 @@
+ # Copyright 2025 The SwissAI Initiative
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
+ """Inference-only Apertus model compatible with HuggingFace weights."""
+ import copy
+ import logging
+ import math
+ from functools import partial
+ from typing import Iterable, List, Optional, Set, Tuple, Type, TypeAlias, Union
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+ from transformers.models.vitdet.modeling_vitdet import get_rel_pos
+
+ from sglang.srt.configs.deepseek_ocr import DeepseekVLV2Config
+ from sglang.srt.layers.quantization import QuantizationConfig
+ from sglang.srt.managers.mm_utils import (
+     MultiModalityDataPaddingPatternMultimodalTokens,
+     general_mm_embed_routine,
+ )
+ from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.deepseek import DeepseekForCausalLM
+ from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM, DeepseekV3ForCausalLM
+ from sglang.srt.models.transformers import maybe_prefix
+
+ NestedTensors: TypeAlias = Union[
+     list["NestedTensors"],
+     list["torch.Tensor"],
+     "torch.Tensor",
+     tuple["torch.Tensor", ...],
+ ]
+
+ MultiModalEmbeddings: TypeAlias = list[Tensor] | Tensor | tuple[Tensor, ...]
+
+ logger = logging.getLogger(__name__)
+
+
+ def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor:
+     """
+     Recursively flattens and concatenates NestedTensors on all but the last
+     dimension.
+     """
+
+     if isinstance(embeddings, torch.Tensor):
+         # Flatten all but the last dimension.
+         return embeddings.flatten(0, -2)
+
+     return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings))
+
+
+ def _embedding_count_expression(embeddings: NestedTensors) -> str:
+     """
+     Constructs a debugging representation of the number of embeddings in the
+     NestedTensors.
+     """
+
+     if isinstance(embeddings, torch.Tensor):
+         return " x ".join([str(dim) for dim in embeddings.shape[:-1]])
+
+     return " + ".join(_embedding_count_expression(inner) for inner in embeddings)
+
+
+ def _merge_multimodal_embeddings(
+     inputs_embeds: torch.Tensor,
+     multimodal_embeddings: NestedTensors,
+     is_multimodal: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
+     positions in `inputs_embeds` corresponding to placeholder tokens in
+     `input_ids`.
+
+     Note:
+         This updates `inputs_embeds` in place.
+     """
+     if len(multimodal_embeddings) == 0:
+         return inputs_embeds
+
+     mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
+     input_dtype = inputs_embeds.dtype
+
+     try:
+         # NOTE: This can avoid D2H sync (#22105), but fails to
+         # raise an error if is_multimodal.sum() < len(mm_embeds_flat)
+         inputs_embeds.masked_scatter_(
+             is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype)
+         )
+     except RuntimeError as e:
+         num_actual_tokens = len(mm_embeds_flat)
+         num_expected_tokens = is_multimodal.sum().item()
+
+         if num_actual_tokens != num_expected_tokens:
+             expr = _embedding_count_expression(multimodal_embeddings)
+
+             raise ValueError(
+                 f"Attempted to assign {expr} = {num_actual_tokens} "
+                 f"multimodal tokens to {num_expected_tokens} placeholders"
+             ) from e
+
+         raise ValueError("Error during masked scatter operation") from e
+
+     return inputs_embeds
+
+
+ def isin_list(
+     elements: torch.Tensor,
+     test_elements_list: list[int],
+ ) -> torch.Tensor:
+     test_elements = torch.tensor(test_elements_list, pin_memory=True).to(
+         device=elements.device, non_blocking=True
+     )
+
+     return torch.isin(elements, test_elements)
+
+
+ def merge_multimodal_embeddings(
+     input_ids: torch.Tensor,
+     inputs_embeds: torch.Tensor,
+     multimodal_embeddings: NestedTensors,
+     placeholder_token_id: int | list[int],
+ ) -> torch.Tensor:
+     """
+     Merge `multimodal_embeddings` into `inputs_embeds` by overwriting the
+     positions in `inputs_embeds` corresponding to placeholder tokens in
+     `input_ids`.
+
+     `placeholder_token_id` can be a list of token ids (e.g, token ids
+     of img_start, img_break, and img_end tokens) when needed: This means
+     the order of these tokens in the `input_ids` MUST MATCH the order of
+     their embeddings in `multimodal_embeddings` since we need to
+     slice-merge instead of individually scattering.
+
+     For example, if input_ids is "TTTTTSIIIBIIIBIIIETTT", where
+     - T is text token
+     - S is image start token
+     - I is image embedding token
+     - B is image break token
+     - E is image end token.
+
+     Then the image embeddings (that correspond to I's) from vision encoder
+     must be padded with embeddings of S, B, and E in the same order of
+     input_ids for a correct embedding merge.
+
+     Note:
+         This updates `inputs_embeds` in place.
+     """
+     if isinstance(placeholder_token_id, list):
+         is_multimodal = isin_list(input_ids, placeholder_token_id)
+     else:
+         is_multimodal = input_ids == placeholder_token_id
+
+     return _merge_multimodal_embeddings(
+         inputs_embeds,
+         multimodal_embeddings=multimodal_embeddings,
+         is_multimodal=is_multimodal,
+     )
+
+
+ class MlpProjector(nn.Module):
+
+     def __init__(
+         self,
+         projector_type,
+         input_dim,
+         n_embed,
+         depth=1,
+         mlp_ratio=1,
+         downsample_ratio=4,
+     ):
+         self.projector_type = projector_type
+         self.input_dim = input_dim
+         self.n_embed = n_embed
+         self.depth = depth
+         self.token_pooling = False
+         self.conv_fusion_high_low_features = False
+
+         super().__init__()
+
+         if projector_type == "identity":
+             modules = nn.Identity()
+
+         elif projector_type == "linear":
+             modules = nn.Linear(input_dim, n_embed)
+
+         elif projector_type == "mlp_gelu":
+             mlp_depth = depth
+             modules = [nn.Linear(input_dim, n_embed)]
+             for _ in range(1, mlp_depth):
+                 modules.append(nn.GELU())
+                 modules.append(nn.Linear(n_embed, n_embed))
+             modules = nn.Sequential(*modules)
+
+         elif projector_type == "normlayer_downsample_mlp_gelu":
+             mlp_depth = depth
+             mlp_ratio = mlp_ratio
+             modules = [
+                 nn.LayerNorm(input_dim * downsample_ratio * downsample_ratio),
+                 nn.Linear(
+                     input_dim * downsample_ratio * downsample_ratio,
+                     n_embed * mlp_ratio,
+                 ),
+             ]
+             for _ in range(1, mlp_depth - 1):
+                 modules.append(nn.GELU())
+                 modules.append(nn.Linear(n_embed * mlp_ratio, n_embed * mlp_ratio))
+             modules.append(nn.GELU())
+             modules.append(nn.Linear(n_embed * mlp_ratio, n_embed))
+             modules = nn.Sequential(*modules)
+
+         elif projector_type == "downsample_mlp_gelu":
+             mlp_depth = depth
+             mlp_ratio = mlp_ratio
+             modules = [
+                 nn.Linear(
+                     input_dim * downsample_ratio * downsample_ratio,
+                     n_embed * mlp_ratio,
+                 )
+             ]
+             for _ in range(1, mlp_depth - 1):
+                 modules.append(nn.GELU())
+                 modules.append(nn.Linear(n_embed * mlp_ratio, n_embed * mlp_ratio))
+             modules.append(nn.GELU())
+             modules.append(nn.Linear(n_embed * mlp_ratio, n_embed))
+             modules = nn.Sequential(*modules)
+
+         elif projector_type == "low_high_hybrid_split_mlp_gelu":
+             mlp_depth = depth
+             self.high_up_proj = nn.Linear(input_dim, n_embed // 2)
+             self.low_up_proj = nn.Linear(input_dim, n_embed // 2)
+
+             modules = []
+             for _ in range(1, mlp_depth):
+                 modules.append(nn.GELU())
+                 modules.append(nn.Linear(n_embed, n_embed))
+             modules = nn.Sequential(*modules)
+
+         elif projector_type == "hybrid_split_feature_mlp_gelu":
+             mlp_depth = depth
+             channel_div = 0.5
+             self.high_up_proj = nn.Linear(input_dim[0], int(n_embed * channel_div))
+             self.low_up_proj = nn.Linear(
+                 input_dim[1], n_embed - int(n_embed * channel_div)
+             )
+
+             modules = []
+             for _ in range(1, mlp_depth):
+                 modules.append(nn.GELU())
+                 modules.append(nn.Linear(n_embed, n_embed))
+             modules = nn.Sequential(*modules)
+
+         elif projector_type == "low_high_split_mlp_gelu":
+             mlp_depth = depth
+             modules = []
+             for _ in range(1, mlp_depth):
+                 modules.append(nn.GELU())
+                 modules.append(nn.Linear(n_embed // 2, n_embed // 2))
+             modules = nn.Sequential(*modules)
+             self.high_layers = nn.Sequential(*modules)
+             self.low_layers = copy.deepcopy(modules)
+
+         else:
+             raise ValueError(f"Unknown projector type: {projector_type}")
+
+         self.layers = modules
+
+     def forward(self, x):
+         if self.token_pooling:
+             batch_size, wxh, channels = x.shape
+             w = h = int(wxh**0.5)
+             x = x.view(batch_size, w, h, channels)
+             x = x.permute(0, 3, 1, 2)
+             patches = x.unfold(2, 2, 2).unfold(3, 2, 2)
+             batch_size, channels, h_patches, w_patches, _, _ = patches.size()
+             # Concatenate on channel dimension
+             patches = patches.contiguous().view(
+                 batch_size, channels, h_patches * w_patches, -1
+             )
+
+             # Pass through linear layer
+             patches = patches.permute(0, 2, 1, 3).contiguous()
+             patches = patches.view(batch_size, h_patches * w_patches, channels * 4)
+
+             x = self.token_pooling_layer(patches)
+
+         if self.conv_fusion_high_low_features:
+             x = self.fusion_layer(x[:, 0]) + x[:, 1]
+
+         if self.projector_type == "low_high_hybrid_split_mlp_gelu":
+             high_x, low_x = x[0], x[1]
+             high_x = self.high_up_proj(high_x)
+             low_x = self.low_up_proj(low_x)
+             x = torch.concat([high_x, low_x], dim=-1)
+
+         if self.projector_type == "hybrid_split_feature_mlp_gelu":
+             high_x = x[..., : self.input_dim[0]]
+             low_x = x[..., self.input_dim[0] :]
+             high_x = self.high_up_proj(high_x)
+             low_x = self.low_up_proj(low_x)
+             x = torch.concat([high_x, low_x], dim=-1)
+
+         if self.projector_type == "low_high_split_mlp_gelu":
+             high_x, low_x = x[0], x[1]
+             high_x = self.high_layers(high_x)
+             low_x = self.low_layers(low_x)
+             x = torch.concat([high_x, low_x], dim=-1)
+             return x
+
+         if (
+             self.projector_type == "downsample_mlp_gelu"
+             or self.projector_type == "normlayer_downsample_mlp_gelu"
+         ):
+             bs, hw, input_dim = x.shape
+             h = w = int((hw) ** 0.5)
+
+             """compute padding"""
+             if h % self.downsample_ratio:
+                 pad = self.downsample_ratio - h % self.downsample_ratio
+             else:
+                 pad = 0
+             x = x.reshape(bs, h, w, input_dim)
+             if pad > 0:
+                 x = F.pad(x, (0, 0, 0, pad, 0, pad), "constant", 0)
+
+             """4 to 1 concat"""
+             x = x.permute(0, 3, 1, 2)  # B, C, H, W
+             x = F.unfold(
+                 x,
+                 kernel_size=self.downsample_ratio,
+                 stride=self.downsample_ratio,
+                 padding=0,
+             )  # B, C*4, HW // 4
+             x = x.permute(0, 2, 1)
+
+         return self.layers(x)
+
+
+ class LayerNorm2d(nn.Module):
+     def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(num_channels))
+         self.bias = nn.Parameter(torch.zeros(num_channels))
+         self.eps = eps
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         u = x.mean(1, keepdim=True)
+         s = (x - u).pow(2).mean(1, keepdim=True)
+         x = (x - u) / torch.sqrt(s + self.eps)
+         x = self.weight[:, None, None] * x + self.bias[:, None, None]
+         return x
+
+
+ class MLPBlock(nn.Module):
+     def __init__(
+         self,
+         embedding_dim: int,
+         mlp_dim: int,
+         act: Type[nn.Module] = nn.GELU,
+     ) -> None:
+         super().__init__()
+         self.lin1 = nn.Linear(embedding_dim, mlp_dim)
+         self.lin2 = nn.Linear(mlp_dim, embedding_dim)
+         self.act = act()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return self.lin2(self.act(self.lin1(x)))
+
+
+ def add_decomposed_rel_pos(
+     q: torch.Tensor,
+     rel_pos_h: torch.Tensor,
+     rel_pos_w: torch.Tensor,
+     q_size: Tuple[int, int],
+     k_size: Tuple[int, int],
+ ) -> torch.Tensor:
+     """
+     Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
+     https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py  # noqa B950
+     Args:
+         q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
+         rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
+         rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
+         q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
+         k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
+     Returns:
+         attn (Tensor): attention map with added relative positional embeddings.
+     """
+     q_h, q_w = q_size
+     k_h, k_w = k_size
+     Rh = get_rel_pos(q_h, k_h, rel_pos_h)
+     Rw = get_rel_pos(q_w, k_w, rel_pos_w)
+
+     B, _, dim = q.shape
+     r_q = q.reshape(B, q_h, q_w, dim)
+     rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
+     rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
+     rel_h = rel_h.unsqueeze(-1)
+     rel_w = rel_w.unsqueeze(-2)
+     rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
+     rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
+
+     return rel_h, rel_w
+
+
+ class Attention(nn.Module):
+     """Multi-head Attention block with relative position embeddings."""
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = True,
+         use_rel_pos: bool = False,
+         rel_pos_zero_init: bool = True,
+         input_size: Optional[Tuple[int, int]] = None,
+     ) -> None:
+         """
+         Args:
+             dim (int): Number of input channels.
+             num_heads (int): Number of attention heads.
+             qkv_bias (bool): If True, add a learnable bias to query, key, value.
+             rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+             input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                 positional parameter size.
+         """
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.proj = nn.Linear(dim, dim)
+
+         self.use_rel_pos = use_rel_pos
+         if self.use_rel_pos:
+             assert (
+                 input_size is not None
+             ), "Input size must be provided if using relative positional encoding."
+             # initialize relative positional embeddings
+             self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
+             self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         B, H, W, _ = x.shape
+         # qkv with shape (3, B, nHead, H * W, C)
+         qkv = (
+             self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+         )
+         # q, k, v with shape (B * nHead, H * W, C)
+         q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
+
+         rel_h, rel_w = None, None
+         if self.use_rel_pos:
+             rel_h, rel_w = add_decomposed_rel_pos(
+                 q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
+             )
+
+         q = q.view(B, self.num_heads, H * W, -1)
+         k = k.view(B, self.num_heads, H * W, -1)
+         v = v.view(B, self.num_heads, H * W, -1)
+
+         if self.use_rel_pos:
+             rel_h = rel_h.view(
+                 B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3)
+             )
+             rel_w = rel_w.view(
+                 B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3)
+             )
+             attn_bias = (rel_h + rel_w).view(
+                 B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
+             )
+             x = torch.nn.functional.scaled_dot_product_attention(
+                 q, k, v, attn_mask=attn_bias
+             )
+             # x = _attention_rel_h_rel_w(q, k, v, rel_h, rel_w)
+         else:
+             x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
+
+         x = (
+             x.view(B, self.num_heads, H, W, -1)
+             .permute(0, 2, 3, 1, 4)
+             .reshape(B, H, W, -1)
+         )
+
+         x = self.proj(x)
+
+         return x
+
+
+ def window_partition(
+     x: torch.Tensor, window_size: int
+ ) -> Tuple[torch.Tensor, Tuple[int, int]]:
+     """
+     Partition into non-overlapping windows with padding if needed.
+     Args:
+         x (tensor): input tokens with [B, H, W, C].
+         window_size (int): window size.
+     Returns:
+         windows: windows after partition with [B * num_windows, window_size, window_size, C].
+         (Hp, Wp): padded height and width before partition
+     """
+     B, H, W, C = x.shape
+
+     pad_h = (window_size - H % window_size) % window_size
+     pad_w = (window_size - W % window_size) % window_size
+     if pad_h > 0 or pad_w > 0:
+         x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
+     Hp, Wp = H + pad_h, W + pad_w
+
+     x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
+     windows = (
+         x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+     )
+     return windows, (Hp, Wp)
+
+
+ def window_unpartition(
+     windows: torch.Tensor,
+     window_size: int,
+     pad_hw: Tuple[int, int],
+     hw: Tuple[int, int],
+ ) -> torch.Tensor:
+     """
+     Window unpartition into original sequences and removing padding.
+     Args:
+         windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
+         window_size (int): window size.
+         pad_hw (Tuple): padded height and width (Hp, Wp).
+         hw (Tuple): original height and width (H, W) before padding.
+     Returns:
+         x: unpartitioned sequences with [B, H, W, C].
+     """
+     Hp, Wp = pad_hw
+     H, W = hw
+     B = windows.shape[0] // (Hp * Wp // window_size // window_size)
+     x = windows.view(
+         B, Hp // window_size, Wp // window_size, window_size, window_size, -1
+     )
+     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
+
+     if Hp > H or Wp > W:
+         x = x[:, :H, :W, :].contiguous()
+     return x
+
+
+ class Block(nn.Module):
+     """Transformer blocks with support of window attention and residual propagation blocks"""
+
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = True,
+         norm_layer: Type[nn.Module] = nn.LayerNorm,
+         act_layer: Type[nn.Module] = nn.GELU,
+         use_rel_pos: bool = False,
+         rel_pos_zero_init: bool = True,
+         window_size: int = 0,
+         input_size: Optional[Tuple[int, int]] = None,
+     ) -> None:
+         """
+         Args:
+             dim (int): Number of input channels.
+             num_heads (int): Number of attention heads in each ViT block.
+             mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+             qkv_bias (bool): If True, add a learnable bias to query, key, value.
+             norm_layer (nn.Module): Normalization layer.
+             act_layer (nn.Module): Activation layer.
+             use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+             rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+             window_size (int): Window size for window attention blocks. If it equals 0, then
+                 use global attention.
+             input_size (tuple(int, int) or None): Input resolution for calculating the relative
+                 positional parameter size.
+         """
+         super().__init__()
+         self.norm1 = norm_layer(dim)
+         self.attn = Attention(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             use_rel_pos=use_rel_pos,
+             rel_pos_zero_init=rel_pos_zero_init,
+             input_size=input_size if window_size == 0 else (window_size, window_size),
+         )
+
+         self.norm2 = norm_layer(dim)
+         self.mlp = MLPBlock(
+             embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
+         )
+
+         self.window_size = window_size
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         shortcut = x
+         x = self.norm1(x)
+         # Window partition
+         if self.window_size > 0:
+             H, W = x.shape[1], x.shape[2]
+             x, pad_hw = window_partition(x, self.window_size)
+
+         x = self.attn(x)
+         # Reverse window partition
+         if self.window_size > 0:
+             x = window_unpartition(x, self.window_size, pad_hw, (H, W))
+
+         x = shortcut + x
+         x = x + self.mlp(self.norm2(x))
+
+         return x
+
+
+ class PatchEmbed(nn.Module):
+     """
+     Image to Patch Embedding.
+     """
+
+     def __init__(
+         self,
+         kernel_size: Tuple[int, int] = (16, 16),
+         stride: Tuple[int, int] = (16, 16),
+         padding: Tuple[int, int] = (0, 0),
+         in_chans: int = 3,
+         embed_dim: int = 768,
+     ) -> None:
+         """
+         Args:
+             kernel_size (Tuple): kernel size of the projection layer.
+             stride (Tuple): stride of the projection layer.
+             padding (Tuple): padding size of the projection layer.
+             in_chans (int): Number of input image channels.
+             embed_dim (int): Patch embedding dimension.
+         """
+         super().__init__()
+
+         self.proj = nn.Conv2d(
+             in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.proj(x)
+         # B C H W -> B H W C
+         x = x.permute(0, 2, 3, 1)
+         return x
+
+
+ def get_abs_pos_sam(abs_pos, tgt_size):
+     dtype = abs_pos.dtype
+
+     src_size = abs_pos.size(1)
+
+     if src_size != tgt_size:
+         old_pos_embed = abs_pos.permute(0, 3, 1, 2)
+         old_pos_embed = old_pos_embed.to(torch.float32)
+         new_pos_embed = F.interpolate(
+             old_pos_embed,
+             size=(tgt_size, tgt_size),
+             mode="bicubic",
+             antialias=True,
+             align_corners=False,
+         ).to(dtype)
+         new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
+         return new_pos_embed
+     else:
+         return abs_pos
+
+
+ # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py  # noqa
+ class ImageEncoderViT(nn.Module):
+     def __init__(
+         self,
+         img_size: int = 1024,
+         patch_size: int = 16,
+         in_chans: int = 3,
+         embed_dim: int = 768,
+         depth: int = 12,
+         num_heads: int = 12,
+         mlp_ratio: float = 4.0,
+         out_chans: int = 256,
+         qkv_bias: bool = True,
+         norm_layer: Type[nn.Module] = nn.LayerNorm,
+         act_layer: Type[nn.Module] = nn.GELU,
+         use_abs_pos: bool = True,
+         use_rel_pos: bool = False,
+         rel_pos_zero_init: bool = True,
+         window_size: int = 0,
+         global_attn_indexes: Tuple[int, ...] = (),
+     ) -> None:
+         """
+         Args:
+             img_size (int): Input image size.
+             patch_size (int): Patch size.
+             in_chans (int): Number of input image channels.
+             embed_dim (int): Patch embedding dimension.
+             depth (int): Depth of ViT.
+             num_heads (int): Number of attention heads in each ViT block.
+             mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+             qkv_bias (bool): If True, add a learnable bias to query, key, value.
+             norm_layer (nn.Module): Normalization layer.
+             act_layer (nn.Module): Activation layer.
+             use_abs_pos (bool): If True, use absolute positional embeddings.
+             use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
+             rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
+             window_size (int): Window size for window attention blocks.
+             global_attn_indexes (list): Indexes for blocks using global attention.
+         """
+         super().__init__()
+         self.img_size = img_size
+
+         self.patch_embed = PatchEmbed(
+             kernel_size=(patch_size, patch_size),
+             stride=(patch_size, patch_size),
+             in_chans=in_chans,
+             embed_dim=embed_dim,
+         )
+
+         self.pos_embed: Optional[nn.Parameter] = None
+         if use_abs_pos:
+             # Initialize absolute positional embedding with pretrain image size.
+             self.pos_embed = nn.Parameter(
+                 torch.zeros(
+                     1, img_size // patch_size, img_size // patch_size, embed_dim
+                 )
+             )
+
+         self.blocks = nn.ModuleList()
+         for i in range(depth):
+             block = Block(
+                 dim=embed_dim,
+                 num_heads=num_heads,
+                 mlp_ratio=mlp_ratio,
+                 qkv_bias=qkv_bias,
+                 norm_layer=norm_layer,
+                 act_layer=act_layer,
+                 use_rel_pos=use_rel_pos,
+                 rel_pos_zero_init=rel_pos_zero_init,
+                 window_size=window_size if i not in global_attn_indexes else 0,
+                 input_size=(img_size // patch_size, img_size // patch_size),
+             )
+             self.blocks.append(block)
+
+         self.neck = nn.Sequential(
+             nn.Conv2d(
+                 embed_dim,
+                 out_chans,
+                 kernel_size=1,
+                 bias=False,
+             ),
+             LayerNorm2d(out_chans),
+             nn.Conv2d(
+                 out_chans,
+                 out_chans,
+                 kernel_size=3,
+                 padding=1,
+                 bias=False,
+             ),
+             LayerNorm2d(out_chans),
+         )
+
+         self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
+         self.net_3 = nn.Conv2d(
+             512, 1024, kernel_size=3, stride=2, padding=1, bias=False
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.patch_embed(x)
+         if self.pos_embed is not None:
+             x = x + get_abs_pos_sam(self.pos_embed, x.size(1))
+
+         for blk in self.blocks:
+             x = blk(x)
+
+         x = self.neck(x.permute(0, 3, 1, 2))
+         x2 = self.net_2(x)
+         x3 = self.net_3(x2.clone())
+
+         return x3
+
796
+
797
+ def _build_sam(
798
+ encoder_embed_dim,
799
+ encoder_depth,
800
+ encoder_num_heads,
801
+ encoder_global_attn_indexes,
802
+ checkpoint=None,
803
+ ):
804
+ prompt_embed_dim = 256
805
+ image_size = 1024
806
+ vit_patch_size = 16
807
+ image_encoder = ImageEncoderViT(
808
+ depth=encoder_depth,
809
+ embed_dim=encoder_embed_dim,
810
+ img_size=image_size,
811
+ mlp_ratio=4,
812
+ norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
813
+ num_heads=encoder_num_heads,
814
+ patch_size=vit_patch_size,
815
+ qkv_bias=True,
816
+ use_rel_pos=True,
817
+ global_attn_indexes=encoder_global_attn_indexes,
818
+ window_size=14,
819
+ out_chans=prompt_embed_dim,
820
+ )
821
+ image_encoder.eval()
822
+ if checkpoint is not None:
823
+ state_dict = torch.load(checkpoint)
824
+ image_encoder.load_state_dict(
825
+ {k[30:]: v for k, v in state_dict.items() if "vision_tower_high" in k},
826
+ strict=True,
827
+ )
828
+ return image_encoder
829
+
830
+
831
+ def build_sam_vit_b(checkpoint=None):
832
+ return _build_sam(
833
+ encoder_embed_dim=768,
834
+ encoder_depth=12,
835
+ encoder_num_heads=12,
836
+ encoder_global_attn_indexes=[2, 5, 8, 11],
837
+ checkpoint=checkpoint,
838
+ )
839
+
840
+
841
+ def get_abs_pos(abs_pos, tgt_size):
842
+ # abs_pos: L, C
843
+ # tgt_size: M
844
+ # return: M, C
845
+ dim = abs_pos.size(-1)
846
+ abs_pos_new = abs_pos.squeeze(0)
847
+ cls_token, old_pos_embed = abs_pos_new[:1], abs_pos_new[1:]
848
+
849
+ src_size = int(math.sqrt(abs_pos_new.shape[0] - 1))
850
+ tgt_size = int(math.sqrt(tgt_size))
851
+ dtype = abs_pos.dtype
852
+
853
+ if src_size != tgt_size:
854
+ old_pos_embed = (
855
+ old_pos_embed.view(1, src_size, src_size, dim)
856
+ .permute(0, 3, 1, 2)
857
+ .contiguous()
858
+ )
859
+ old_pos_embed = old_pos_embed.to(torch.float32)
860
+ new_pos_embed = F.interpolate(
861
+ old_pos_embed,
862
+ size=(tgt_size, tgt_size),
863
+ mode="bicubic",
864
+ antialias=True,
865
+ align_corners=False,
866
+ ).to(dtype)
867
+ new_pos_embed = new_pos_embed.permute(0, 2, 3, 1)
868
+ new_pos_embed = new_pos_embed.view(tgt_size * tgt_size, dim)
869
+ vision_pos_embed = torch.cat([cls_token, new_pos_embed], dim=0)
870
+ vision_pos_embed = vision_pos_embed.view(1, tgt_size * tgt_size + 1, dim)
871
+ return vision_pos_embed
872
+ else:
873
+ return abs_pos
874
+
875
+
876
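+ # Worked example for get_abs_pos (illustrative): with a CLIP-L table of shape
+ # (1, 257, 1024) (16x16 patch grid + 1 cls token) and tgt_size = 1025 tokens,
+ # src_size = 16 and the target grid becomes int(sqrt(1025)) = 32, so the 16x16 grid
+ # is bicubically resized to 32x32 and the cls embedding is re-attached, giving a
+ # (1, 1025, 1024) table. When the grid sizes already match, the input is returned unchanged.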
+ class CLIPVisionEmbeddings(nn.Module):
877
+ def __init__(self, hidden_size=1024, image_size=224, patch_size=14, num_channels=3):
878
+ super().__init__()
879
+ self.embed_dim = hidden_size
880
+ self.image_size = image_size
881
+ self.patch_size = patch_size
882
+
883
+ self.class_embedding = torch.nn.Parameter(torch.randn(self.embed_dim))
884
+
885
+ self.patch_embedding = torch.nn.Conv2d(
886
+ in_channels=num_channels,
887
+ out_channels=self.embed_dim,
888
+ kernel_size=self.patch_size,
889
+ stride=self.patch_size,
890
+ bias=False,
891
+ )
892
+
893
+ self.num_patches = (self.image_size // self.patch_size) ** 2
894
+ self.num_positions = self.num_patches + 1
895
+ self.position_embedding = torch.nn.Embedding(self.num_positions, self.embed_dim)
896
+ self.register_buffer(
897
+ "position_ids", torch.arange(self.num_positions).expand((1, -1))
898
+ )
899
+
900
+ def forward(self, pixel_values, patch_embeds):
901
+ batch_size = pixel_values.shape[0]
902
+
903
+ if patch_embeds is not None:
904
+ patch_embeds = patch_embeds  # already computed upstream (e.g., SAM encoder features)
905
+ else:
906
+ patch_embeds = self.patch_embedding(pixel_values)
907
+
908
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
909
+
910
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1)
911
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
912
+
913
+ embeddings = embeddings + get_abs_pos(
914
+ self.position_embedding(self.position_ids), embeddings.size(1)
915
+ )
916
+ return embeddings
917
+
918
+
919
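+ # Illustrative shape trace for CLIPVisionEmbeddings.forward with the defaults
+ # (image_size 224, patch_size 14, hidden_size 1024):
+ #   pixel_values (B, 3, 224, 224) -> patch_embedding -> (B, 1024, 16, 16)
+ #   flatten/transpose -> (B, 256, 1024); prepend cls -> (B, 257, 1024)
+ # Position embeddings are resized via get_abs_pos only when patch_embeds come from a
+ # larger grid. In this model, patch_embeds are normally the (B, 1024, 16, 16) SAM
+ # feature map rather than the output of patch_embedding.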
+ class NoTPAttention(torch.nn.Module):
920
+ def __init__(self, cfg):
921
+ super().__init__()
922
+ self.num_heads = cfg["num_attention_heads"]
923
+ self.n_local_heads = cfg["num_attention_heads"]
924
+ self.head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]
925
+ self.max_seq_len = cfg["seq_length"]
926
+ self.use_flash_attention = cfg["use_flash_attn"]
927
+
928
+ self.qkv_proj = torch.nn.Linear(
929
+ cfg["hidden_size"], cfg["hidden_size"] * 3, bias=True
930
+ )
931
+ self.out_proj = torch.nn.Linear(
932
+ cfg["hidden_size"], cfg["hidden_size"], bias=True
933
+ )
934
+
935
+ # self.core_attention = CoreAttention(cfg, AttnType.self_attn)
936
+
937
+ self.attn_drop = cfg["attention_dropout"]
938
+
939
+ def forward(
940
+ self,
941
+ x: torch.Tensor,
942
+ ):
943
+ bsz, seqlen, _ = x.shape
944
+ xqkv = self.qkv_proj(x)
945
+ xqkv = xqkv.view(bsz, seqlen, 3, self.num_heads, self.head_dim)
946
+
947
+ if self.use_flash_attention:
948
+
949
+ xq, xk, xv = torch.split(xqkv, 1, dim=2)
950
+ xq = xq.squeeze(2)
951
+ xk = xk.squeeze(2)
952
+ xv = xv.squeeze(2)
953
+ # xq, xk, xv = xqkv[:, :, 0, ...], xqkv[:, :, 1, ...], xqkv[:, :, 2, ...]
954
+
955
+ # (B, num_head, S, head_size)
956
+ xq = xq.permute(0, 2, 1, 3)
957
+ xk = xk.permute(0, 2, 1, 3)
958
+ xv = xv.permute(0, 2, 1, 3)
959
+ output = torch.nn.functional.scaled_dot_product_attention(
960
+ xq, xk, xv, attn_mask=None
961
+ )
962
+ output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
963
+ else:
964
+ xq, xk, xv = torch.split(xqkv, 1, dim=2)
965
+ xq = xq.squeeze(2)
966
+ xk = xk.squeeze(2)
967
+ xv = xv.squeeze(2)
968
+
969
+ xq = xq.permute(0, 2, 1, 3)
970
+ xk = xk.permute(0, 2, 1, 3)
971
+ xv = xv.permute(0, 2, 1, 3)
972
+ output = torch.nn.functional.scaled_dot_product_attention(
973
+ xq, xk, xv, attn_mask=None
974
+ )
975
+ output = output.permute(0, 2, 1, 3).reshape(bsz, seqlen, -1)
976
+ output = self.out_proj(output)
977
+ return output
978
+
979
+
980
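+ # Note on the attention above: both branches split the fused QKV projection of shape
+ # (B, S, 3, H, D) into per-head (B, H, S, D) tensors and call PyTorch's
+ # scaled_dot_product_attention, so use_flash_attention currently selects the same math
+ # path; flash kernels are used only if SDPA dispatches to them.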
+ @torch.jit.script
981
+ def quick_gelu(x):
982
+ return x * torch.sigmoid(1.702 * x)
983
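+ # quick_gelu is the sigmoid approximation of GELU used by the original CLIP:
+ # GELU(x) ~= x * sigmoid(1.702 * x).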
+
984
+
985
+ class NoTPFeedForward(nn.Module):
986
+ def __init__(
987
+ self,
988
+ cfg,
989
+ dim: int,
990
+ hidden_dim: int,
991
+ ):
992
+ super().__init__()
993
+
994
+ self.fc1 = torch.nn.Linear(dim, hidden_dim, bias=True)
995
+ self.fc2 = torch.nn.Linear(hidden_dim, dim, bias=True)
996
+
997
+ def forward(self, x):
998
+ output = self.fc2(quick_gelu(self.fc1(x)))
999
+ return output
1000
+
1001
+
1002
+ class LayerNormfp32(torch.nn.LayerNorm):
1003
+ """Subclass torch's LayerNorm to handle fp16."""
1004
+
1005
+ def forward(self, x: torch.Tensor):
1006
+ orig_type = x.dtype
1007
+ ret = super().forward(x.type(torch.float32))
1008
+ return ret.type(orig_type)
1009
+
1010
+
1011
+ class NoTPTransformerBlock(nn.Module):
1012
+ def __init__(self, cfg, layer_id: int, multiple_of=256):
1013
+ super().__init__()
1014
+
1015
+ self.n_heads = cfg["num_attention_heads"]
1016
+ self.dim = cfg["hidden_size"]
1017
+ self.head_dim = cfg["hidden_size"] // cfg["num_attention_heads"]
1018
+ self.self_attn = NoTPAttention(cfg)
1019
+ self.mlp = NoTPFeedForward(
1020
+ cfg, dim=cfg["hidden_size"], hidden_dim=cfg["ffn_hidden_size"]
1021
+ )
1022
+ self.layer_id = layer_id
1023
+ self.layer_norm1 = torch.nn.LayerNorm(
1024
+ cfg["hidden_size"], eps=cfg["layernorm_epsilon"]
1025
+ )
1026
+ self.layer_norm2 = torch.nn.LayerNorm(
1027
+ cfg["hidden_size"], eps=cfg["layernorm_epsilon"]
1028
+ )
1029
+
1030
+ def forward(self, x: torch.Tensor):
1031
+ residual = self.self_attn.forward(self.layer_norm1(x))
1032
+ h = x + residual
1033
+ out = h + self.mlp.forward(self.layer_norm2(h))
1034
+ return out
1035
+
1036
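+ # The block above follows the standard pre-norm CLIP layout:
+ #   h = x + attn(layer_norm1(x)); out = h + mlp(layer_norm2(h))
+ # (the local variable named "residual" holds the attention output, not the residual stream).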
+
1037
+ class NoTPTransformer(nn.Module):
1038
+ def __init__(self, cfg):
1039
+ super().__init__()
1040
+
1041
+ self.cfg = cfg
1042
+ self.num_layers = cfg["num_layers"]
1043
+
1044
+ self.layers = torch.nn.ModuleList()
1045
+ for layer_id in range(self.num_layers):
1046
+ self.layers.append(
1047
+ NoTPTransformerBlock(
1048
+ cfg,
1049
+ layer_id + 1,
1050
+ )
1051
+ )
1052
+
1053
+ def forward(
1054
+ self,
1055
+ hidden_states,
1056
+ ):
1057
+
1058
+ for layer in self.layers:
1059
+ hidden_states = layer(hidden_states)
1060
+
1061
+ return hidden_states
1062
+
1063
+
1064
+ class VitModel(nn.Module):
1065
+ def __init__(self, cfg, freeze_embed=False, freeze_pre_norm=False) -> None:
1066
+ super().__init__()
1067
+
1068
+ self.embeddings = CLIPVisionEmbeddings(
1069
+ hidden_size=cfg["hidden_size"],
1070
+ image_size=cfg["image_size"],
1071
+ patch_size=cfg["patch_size"],
1072
+ )
1073
+
1074
+ if freeze_embed:
1075
+ for _, param in self.embeddings.named_parameters():
1076
+ param.requires_grad = False
1077
+
1078
+ self.transformer = NoTPTransformer(cfg=cfg)
1079
+
1080
+ if cfg.get("fp32norm", False):
1081
+ logger.info("Load fp32 layernorm for ViT.")
1082
+ self.pre_layrnorm = LayerNormfp32(
1083
+ cfg["hidden_size"],
1084
+ eps=cfg.get("pre_layernorm_epsilon", 1e-5),
1085
+ )
1086
+ else:
1087
+ self.pre_layrnorm = torch.nn.LayerNorm(
1088
+ cfg["hidden_size"],
1089
+ eps=cfg.get("pre_layernorm_epsilon", 1e-5),
1090
+ )
1091
+
1092
+ if freeze_pre_norm:
1093
+ for _, param in self.pre_layrnorm.named_parameters():
1094
+ param.requires_grad = False
1095
+
1096
+ for p in self.parameters():
1097
+ p.micro_dp = True
1098
+
1099
+ @property
1100
+ def dtype(self):
1101
+ return next(self.parameters()).dtype
1102
+
1103
+ def set_input_tensor(self, input_tensor):
1104
+ if not isinstance(input_tensor, list):
1105
+ input_tensor = [input_tensor]
1106
+ self.transformer.set_input_tensor(input_tensor[0])
1107
+
1108
+ def __str__(self) -> str:
1109
+ return "open_clip"
1110
+
1111
+ def forward(self, x, patch_embeds):
1112
+ x = self.embeddings(x, patch_embeds)
1113
+ hidden_states = self.pre_layrnorm(x)
1114
+
1115
+ output = self.transformer(hidden_states)
1116
+
1117
+ return output
1118
+
1119
+
1120
+ vit_model_cfg = dict(
1121
+ num_layers=24,
1122
+ hidden_size=1024,
1123
+ num_heads=16,
1124
+ num_attention_heads=16,
1125
+ ffn_hidden_size=4096,
1126
+ seq_length=256,
1127
+ max_position_embeddings=256,
1128
+ use_flash_attn=False,
1129
+ understand_projector_stride=2,
1130
+ hidden_dropout=0.0,
1131
+ attention_dropout=0.0,
1132
+ no_persist_layer_norm=False,
1133
+ layernorm_epsilon=1e-5,
1134
+ pre_layernorm_epsilon=1e-5,
1135
+ image_size=224,
1136
+ patch_size=14,
1137
+ recompute_list=[],
1138
+ )
1139
+
1140
+
1141
+ def build_clip_l():
1142
+ return VitModel(
1143
+ cfg=vit_model_cfg,
1144
+ freeze_embed=False,
1145
+ freeze_pre_norm=False,
1146
+ )
1147
+
1148
+
1149
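+ # vit_model_cfg corresponds to a CLIP ViT-L/14 tower: 24 layers, hidden size 1024,
+ # 16 heads, FFN 4096, 224x224 input with 14x14 patches (256 patches + cls = 257 tokens).
+ # Illustrative usage, assuming SAM features are passed as patch_embeds:
+ #   vit = build_clip_l()
+ #   out = vit(pixel_values, sam_features)   # -> (B, 257, 1024)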
+ class DeepseekOCRForCausalLM(nn.Module):
1150
+ def __init__(
1151
+ self,
1152
+ *,
1153
+ config: DeepseekVLV2Config,
1154
+ quant_config: Optional[QuantizationConfig] = None,
1155
+ prefix: str = "",
1156
+ ):
1157
+ super().__init__()
1158
+
1159
+ self.config = config
1160
+
1161
+ self.vision_config = config.vision_config
1162
+ self.projector_config = config.projector_config
1163
+ self.text_config = config.text_config
1164
+
1165
+ n_embed = 1280
1166
+
1167
+ self.tile_tag = config.tile_tag
1168
+ self.global_view_pos = config.global_view_pos
1169
+
1170
+ # special token for image token sequence format
1171
+ embed_std = 1 / torch.sqrt(torch.tensor(n_embed, dtype=torch.float32))
1172
+ if self.tile_tag == "2D":
1173
+ # <|view_separator|>, <|\n|>
1174
+ self.image_newline = nn.Parameter(torch.randn(n_embed) * embed_std)
1175
+ self.view_seperator = nn.Parameter(torch.randn(n_embed) * embed_std)
1176
+ else:
1177
+ raise ValueError(
1178
+ f"Only 2D tile_tag is supported currently, got: {self.tile_tag}"
1179
+ )
1180
+
1181
+ if self.text_config.topk_method == "noaux_tc":
1182
+ self.model = DeepseekV3ForCausalLM(
1183
+ config=config.text_config,
1184
+ quant_config=quant_config,
1185
+ prefix=maybe_prefix(prefix, "language"),
1186
+ )
1187
+ elif not self.text_config.use_mla:
1188
+ self.model = DeepseekForCausalLM(
1189
+ config=config.text_config,
1190
+ quant_config=quant_config,
1191
+ prefix=maybe_prefix(prefix, "language"),
1192
+ )
1193
+ else:
1194
+ self.model = DeepseekV2ForCausalLM(
1195
+ config=config.text_config,
1196
+ quant_config=quant_config,
1197
+ prefix=maybe_prefix(prefix, "language"),
1198
+ )
1199
+
1200
+ self.sam_model = build_sam_vit_b()
1201
+ self.vision_model = build_clip_l()
1202
+ n_embed = 1280
1203
+ self.projector = MlpProjector(
1204
+ projector_type="linear",
1205
+ input_dim=2048,
1206
+ n_embed=n_embed,
1207
+ )
1208
+
1209
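+ # Wiring summary for the constructor above: a SAM ViT-B encoder provides
+ # high-resolution features, a CLIP-L tower re-encodes the same crops, and their
+ # 1024-dim features are concatenated to 2048 dims, then projected by a linear
+ # MlpProjector to the 1280-dim embedding space expected by the language model.
+ # The text backbone is chosen from the config: DeepseekV3 when topk_method == "noaux_tc",
+ # Deepseek when MLA is disabled, otherwise DeepseekV2.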
+ def _parse_and_validate_image_input(self, **kwargs: object):
1210
+
1211
+ pixel_values = kwargs.pop("pixel_values", None)
1212
+ images_spatial_crop = kwargs.pop("images_spatial_crop", None)
1213
+ images_crop = kwargs.pop("images_crop", None)
1214
+
1215
+ if pixel_values is None or torch.sum(pixel_values).item() == 0:
1216
+ return None
1217
+
1218
+ if pixel_values is not None:
1219
+ if not isinstance(pixel_values, (torch.Tensor, list)):
1220
+ raise ValueError(
1221
+ "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
1222
+ )
1223
+
1224
+ if not isinstance(images_spatial_crop, (torch.Tensor, list)):
1225
+ raise ValueError(
1226
+ "Incorrect type of image sizes. "
1227
+ f"Got type: {type(images_spatial_crop)}"
1228
+ )
1229
+
1230
+ if not isinstance(images_crop, (torch.Tensor, list)):
1231
+ raise ValueError(
1232
+ "Incorrect type of image crop. " f"Got type: {type(images_crop)}"
1233
+ )
1234
+
1235
+ return [pixel_values, images_crop, images_spatial_crop]
1236
+
1237
+ raise AssertionError("This line should be unreachable.")
1238
+
1239
+ def _pixel_values_to_embedding(
1240
+ self,
1241
+ pixel_values: torch.Tensor,
1242
+ images_crop: torch.Tensor,
1243
+ images_spatial_crop: torch.Tensor,
1244
+ ) -> NestedTensors:
1245
+
1246
+ # pixel_values (global view): [n_image, batch_size, 3, height, width]
1247
+ # images_spatial_crop: [n_image, batch_size, [num_tiles_w, num_tiles_h]]
1248
+ # images_crop (local view): [n_image, batch_size, num_patches, 3, h, w]
1249
+ # Split pixel_values and images_crop per image; batch_size is always 1 here.
1250
+
1251
+ images_in_this_batch = []
1252
+
1253
+ with torch.no_grad():
1254
+ for jdx in range(images_spatial_crop.size(0)):
1255
+ patches = images_crop[jdx][0].to(torch.bfloat16)
1256
+ image_ori = pixel_values[jdx]
1257
+ crop_shape = images_spatial_crop[jdx][0]
1258
+
1259
+ if torch.sum(patches).item() != 0:
1260
+ local_features_1 = self.sam_model(patches)
1261
+ local_features_2 = self.vision_model(patches, local_features_1)
1262
+
1263
+ local_features = torch.cat(
1264
+ (
1265
+ local_features_2[:, 1:],
1266
+ local_features_1.flatten(2).permute(0, 2, 1),
1267
+ ),
1268
+ dim=-1,
1269
+ )
1270
+ local_features = self.projector(local_features)
1271
+
1272
+ global_features_1 = self.sam_model(image_ori)
1273
+ global_features_2 = self.vision_model(image_ori, global_features_1)
1274
+ global_features = torch.cat(
1275
+ (
1276
+ global_features_2[:, 1:],
1277
+ global_features_1.flatten(2).permute(0, 2, 1),
1278
+ ),
1279
+ dim=-1,
1280
+ )
1281
+ global_features = self.projector(global_features)
1282
+
1283
+ _, hw, n_dim = global_features.shape
1284
+ h = w = int(hw**0.5)
1285
+
1286
+ _2, hw2, n_dim2 = local_features.shape
1287
+ h2 = w2 = int(hw2**0.5)
1288
+
1289
+ width_crop_num, height_crop_num = int(crop_shape[0]), int(
1290
+ crop_shape[1]
1291
+ )
1292
+
1293
+ global_features = global_features.view(h, w, n_dim)
1294
+
1295
+ global_features = torch.cat(
1296
+ [
1297
+ global_features,
1298
+ self.image_newline[None, None, :].expand(h, 1, n_dim),
1299
+ ],
1300
+ dim=1,
1301
+ )
1302
+
1303
+ global_features = global_features.view(-1, n_dim)
1304
+
1305
+ local_features = (
1306
+ local_features.view(
1307
+ height_crop_num, width_crop_num, h2, w2, n_dim2
1308
+ )
1309
+ .permute(0, 2, 1, 3, 4)
1310
+ .reshape(height_crop_num * h2, width_crop_num * w2, n_dim2)
1311
+ )
1312
+ local_features = torch.cat(
1313
+ [
1314
+ local_features,
1315
+ self.image_newline[None, None, :].expand(
1316
+ height_crop_num * h2, 1, n_dim2
1317
+ ),
1318
+ ],
1319
+ dim=1,
1320
+ )
1321
+ local_features = local_features.view(-1, n_dim2)
1322
+
1323
+ global_local_features = torch.cat(
1324
+ [local_features, global_features, self.view_seperator[None, :]],
1325
+ dim=0,
1326
+ )
1327
+
1328
+ else:
1329
+ global_features_1 = self.sam_model(image_ori)
1330
+ global_features_2 = self.vision_model(image_ori, global_features_1)
1331
+ global_features = torch.cat(
1332
+ (
1333
+ global_features_2[:, 1:],
1334
+ global_features_1.flatten(2).permute(0, 2, 1),
1335
+ ),
1336
+ dim=-1,
1337
+ )
1338
+ global_features = self.projector(global_features)
1339
+
1340
+ _, hw, n_dim = global_features.shape
1341
+ h = w = int(hw**0.5)
1342
+
1343
+ global_features = global_features.view(h, w, n_dim)
1344
+
1345
+ global_features = torch.cat(
1346
+ [
1347
+ global_features,
1348
+ self.image_newline[None, None, :].expand(h, 1, n_dim),
1349
+ ],
1350
+ dim=1,
1351
+ )
1352
+
1353
+ global_features = global_features.view(-1, n_dim)
1354
+
1355
+ global_local_features = torch.cat(
1356
+ [global_features, self.view_seperator[None, :]], dim=0
1357
+ )
1358
+
1359
+ images_in_this_batch.append(global_local_features)
1360
+
1361
+ return images_in_this_batch
1362
+
1363
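+ # Layout produced above (illustrative, per image): CLIP tokens (minus cls) and the
+ # flattened SAM map are concatenated to (N, 256, 2048) and projected to (N, 256, 1280).
+ # Global features form a 16x16 grid with an image_newline token appended to each row;
+ # local (tiled) features are stitched into a (num_tiles_h*16) x (num_tiles_w*16) grid
+ # with the same per-row newline, and a single view_seperator token closes the sequence.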
+ def _process_image_input(self, mm_items: List[MultimodalDataItem]) -> torch.Tensor:
1364
+ pixel_values = torch.stack([item.feature for item in mm_items], dim=0).type(
1365
+ self.vision_model.dtype
1366
+ )
1367
+
1368
+ images_crop = (
1369
+ torch.stack([item.images_crop for item in mm_items], dim=0)
1370
+ .type(torch.long)
1371
+ .to(device=pixel_values.device)
1372
+ )
1373
+ images_spatial_crop = (
1374
+ torch.cat([item.images_spatial_crop for item in mm_items], dim=0)
1375
+ .type(torch.long)
1376
+ .to(device=pixel_values.device)
1377
+ )
1378
+
1379
+ assert images_crop.dim() == 6
1380
+ assert images_spatial_crop.dim() == 3
1381
+
1382
+ vision_feature_lists = self._pixel_values_to_embedding(
1383
+ pixel_values=pixel_values,
1384
+ images_crop=images_crop,
1385
+ images_spatial_crop=images_spatial_crop,
1386
+ )
1387
+ vision_features = torch.cat(vision_feature_lists, dim=0).type(
1388
+ self.vision_model.dtype
1389
+ )
1390
+
1391
+ return vision_features
1392
+
1393
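+ # _process_image_input expects one MultimodalDataItem per image, with images_crop
+ # stacked into a 6-D tensor and images_spatial_crop concatenated into a 3-D tensor,
+ # matching the per-image loop in _pixel_values_to_embedding.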
+ def get_language_model(self) -> torch.nn.Module:
1394
+ return self.model
1395
+
1396
+ def get_multimodal_embeddings(
1397
+ self, **kwargs: object
1398
+ ) -> Optional[MultiModalEmbeddings]:
1399
+ image_input = self._parse_and_validate_image_input(**kwargs)
1400
+ if image_input is None:
1401
+ return None
1402
+ vision_embeddings = self._process_image_input(image_input)
1403
+ return vision_embeddings
1404
+
1405
+ def get_input_embeddings(
1406
+ self,
1407
+ input_ids: torch.Tensor,
1408
+ multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
1409
+ ) -> torch.Tensor:
1410
+
1411
+ inputs_embeds = self.model.get_input_embeddings(input_ids)
1412
+
1413
+ if multimodal_embeddings is not None:
1414
+ inputs_embeds = merge_multimodal_embeddings(
1415
+ input_ids, inputs_embeds, multimodal_embeddings, self.image_token_id
1416
+ )
1417
+
1418
+ return inputs_embeds
1419
+
1420
+ def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
1421
+ pattern = MultiModalityDataPaddingPatternMultimodalTokens()
1422
+ return pattern.pad_input_tokens(input_ids, mm_inputs)
1423
+
1424
+ def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
1425
+ vision_embeddings = self._process_image_input(items)
1426
+ return vision_embeddings
1427
+
1428
+ def forward(
1429
+ self,
1430
+ input_ids: torch.Tensor,
1431
+ positions: torch.Tensor,
1432
+ forward_batch: ForwardBatch,
1433
+ **kwargs: object,
1434
+ ):
1435
+ hidden_states = general_mm_embed_routine(
1436
+ input_ids=input_ids,
1437
+ forward_batch=forward_batch,
1438
+ language_model=self.model,
1439
+ multimodal_model=self,
1440
+ positions=positions,
1441
+ )
1442
+
1443
+ return hidden_states
1444
+
1445
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
1446
+ stacked_params_mapping = [
1447
+ # (param_name, shard_name, shard_id)
1448
+ (".qkv_proj", ".q_proj", "q"),
1449
+ (".qkv_proj", ".k_proj", "k"),
1450
+ (".qkv_proj", ".v_proj", "v"),
1451
+ (".gate_up_proj", ".gate_proj", 0),
1452
+ (".gate_up_proj", ".up_proj", 1),
1453
+ ]
1454
+
1455
+ params_dict = dict(self.named_parameters())
1456
+ loaded_params: Set[str] = set()
1457
+
1458
+ for name, loaded_weight in weights:
1459
+ if "rotary_emb.inv_freq" in name:
1460
+ continue
1461
+ if name == "lm_head.weight":
1462
+ name = "model.lm_head.weight"
1463
+ elif name.startswith("model."):
1464
+ if (
1465
+ "image_newline" in name
1466
+ or ".projector" in name
1467
+ or "vision_model" in name
1468
+ or "sam_model" in name
1469
+ or "view_seperator" in name
1470
+ ):
1471
+ name = name[len("model.") :]
1472
+ elif not (
1473
+ ".projector" in name
1474
+ or "vision_model" in name
1475
+ or "sam_model" in name
1476
+ or "image_newline" in name
1477
+ ):
1478
+ name = name.replace("model.", "model.model.")
1479
+
1480
+ for param_name, weight_name, shard_id in stacked_params_mapping:
1481
+ if weight_name not in name:
1482
+ continue
1483
+ name = name.replace(weight_name, param_name)
1484
+ # Skip loading extra bias for GPTQ models.
1485
+ if name.endswith(".bias") and name not in params_dict:
1486
+ continue
1487
+ # Skip experts that are not assigned to this worker.
1488
+ if (
1489
+ "mlp.experts." in name or "mlp.shared_experts." in name
1490
+ ) and name not in params_dict:
1491
+ continue
1492
+ param = params_dict[name]
1493
+ weight_loader = param.weight_loader
1494
+ weight_loader(param, loaded_weight, shard_id)
1495
+ break
1496
+ else:
1497
+ # Skip loading extra bias for GPTQ models.
1498
+ if name.endswith(".bias") and name not in params_dict:
1499
+ continue
1500
+ # Skip experts that are not assigned to this worker.
1501
+ if (
1502
+ "mlp.experts." in name or "mlp.shared_experts." in name
1503
+ ) and name not in params_dict:
1504
+ continue
1505
+ param = params_dict[name]
1506
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
1507
+ weight_loader(param, loaded_weight)
1508
+ loaded_params.add(name)
1509
+ unloaded_params = params_dict.keys() - loaded_params
1510
+ if unloaded_params:
1511
+ raise RuntimeError(
1512
+ f"Some weights are not initialized from checkpoints: {unloaded_params}"
1513
+ )
1514
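+ # Checkpoint key remapping used above: "lm_head.weight" moves under "model.";
+ # vision-side keys (sam_model, vision_model, projector, image_newline, view_seperator)
+ # drop their leading "model." prefix, while the remaining language-model weights gain
+ # an extra "model." so they land inside the wrapped *ForCausalLM module.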
+
1515
+
1516
+ EntryClass = [DeepseekOCRForCausalLM]