sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -24,6 +24,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import (
     ChannelQuantScaleParameter,
     ModelWeightParameter,
@@ -49,8 +51,10 @@ from sglang.srt.utils import (
 )
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )
 
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -339,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["weight"])
-            return
-
-        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
 
     def create_weights(
@@ -417,7 +420,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -428,7 +431,10 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
@@ -436,14 +442,21 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w13_weight, extra_weight_attrs)
 
         w2_weight = torch.nn.Parameter(
-            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.ones(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         w2_weight_scale = torch.nn.Parameter(
@@ -472,10 +485,9 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-            return
-
-        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
-        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        else:
+            layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+            layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
             layer.w13_weight_scale.data, requires_grad=False
         )
@@ -483,23 +495,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale.data, requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
+        dispatch_output: StandardDispatchOutput,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
 
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
 
             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                moe_runner_config.apply_router_weight_on_input, topk_weights, x
+                self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
-            return torch.ops.sgl_kernel.fused_experts_cpu(
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -515,20 +534,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 layer.w2_input_scale,  # a2_scale
                 True,  # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)
 
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_int8_w8a8=True,
             per_channel_quant=True,
-            w1_scale=(layer.w13_weight_scale),
-            w2_scale=(layer.w2_weight_scale),
-            a1_scale=layer.w13_input_scale,
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)
 
 
 class NPU_W8A8LinearMethodImpl:
@@ -900,7 +918,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
@@ -914,21 +932,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         # weight
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
         w2_weight = torch.nn.Parameter(
-            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
         # scale
         w13_weight_scale = torch.nn.Parameter(
-            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
@@ -941,7 +969,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
         # offset
         w13_weight_offset = torch.nn.Parameter(
-            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_offset", w13_weight_offset)
@@ -973,18 +1003,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
         )
 
+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer,
-        x,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output
 
         topk_weights, topk_ids, _ = topk_output
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = topk_weights.to(x.dtype)
-        return npu_fused_experts(
+        output = npu_fused_experts(
             hidden_states=x,
             w13=layer.w13_weight,
             w13_scale=layer.w13_weight_scale,
@@ -994,3 +1031,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             topk_ids=topk_ids,
             top_k=topk_ids.shape[1],
         )
+        return StandardCombineInput(hidden_states=output)
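The hunks above appear to come from sglang/srt/layers/quantization/w8a8_int8.py (the +76 -38 entry in the file list): quantization methods no longer receive `x`, `topk_output`, and a `MoeRunnerConfig` on every `apply()` call, but instead implement `create_moe_runner()` once and then consume a `StandardDispatchOutput`, returning a combine input. A minimal sketch of the new contract follows; the import paths are copied from the hunks above, but the exact re-exports and keyword names should be treated as assumptions rather than a stable API.

# Sketch only: the create_moe_runner/apply contract shown in the hunks above.
# Import paths are taken from this diff; treat them as assumptions.
import torch

from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput


class SketchInt8MoEMethod:
    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # Called once: keep the config and build a backend-specific runner instead of
        # threading the config through every apply() call.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput):
        # Per step: weights and scales are packaged into a quant-info object and the
        # runner produces the combine input that the token dispatcher consumes.
        quant_info = TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
            use_int8_w8a8=True,
            per_channel_quant=True,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )
        return self.runner.run(dispatch_output, quant_info)

Several other quantization backends in the file list (fp8.py, mxfp4.py, moe_wna16.py, unquant.py, and more) were also touched in this release, presumably migrating to the same runner interface.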
@@ -0,0 +1,44 @@
+import torch
+from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
+from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
+
+from sglang.srt.utils import BumpAllocator
+
+__all__ = ["fused_qk_rope_cat"]
+
+
+def aiter_dsv3_router_gemm(
+    hidden_states: torch.Tensor,
+    weight: torch.Tensor,
+    gemm_output_zero_allocator: BumpAllocator = None,
+):
+    M = hidden_states.shape[0]
+    N = weight.shape[0]
+    y = None
+
+    if M <= 256:
+        # TODO (cagri): convert to bfloat16 as part of another kernel to save time
+        # for now it is also coupled with zero allocator.
+        if gemm_output_zero_allocator != None:
+            y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
+        else:
+            y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
+
+    if y is not None:
+        logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
+    else:
+        logits = gemm_a16w16(hidden_states, weight)
+
+    return logits
+
+
+def get_dsv3_gemm_output_zero_allocator_size(
+    n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
+):
+    if embedding_dim != 7168 or n_routed_experts != 256:
+        return 0
+
+    per_layer_size = 256 * (allocate_size + n_routed_experts)
+
+    return num_moe_layers * per_layer_size
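This 44-line new-file hunk lines up with the sglang/srt/layers/rocm_linear_utils.py (+44 -0) entry in the file list. Below is a hypothetical usage sketch; it assumes a ROCm build with the aiter package installed and uses DeepSeek-V3-style router shapes (hidden size 7168, 256 routed experts) as illustrative values, not call sites taken from this diff.

# Hypothetical usage sketch (assumes ROCm + aiter; shapes are illustrative).
import torch

from sglang.srt.layers.rocm_linear_utils import aiter_dsv3_router_gemm

hidden_states = torch.randn(64, 7168, dtype=torch.bfloat16, device="cuda")   # [M, hidden]
router_weight = torch.randn(256, 7168, dtype=torch.bfloat16, device="cuda")  # [experts, hidden]

# With M = 64 <= 256 and no BumpAllocator passed, the helper zero-fills a float32
# buffer and takes the atomic a16w16 GEMM path; larger M falls back to gemm_a16w16.
logits = aiter_dsv3_router_gemm(hidden_states, router_weight)
print(logits.shape, logits.dtype)  # torch.Size([64, 256]), cast back to the input dtype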
@@ -1433,24 +1433,6 @@ class MRotaryEmbedding(RotaryEmbedding):
 
         return position_ids, mrope_position_deltas
 
-    @staticmethod
-    def get_next_input_positions(
-        mrope_position_delta: int,
-        context_len: int,
-        seq_len: int,
-    ) -> torch.Tensor:
-        return torch.tensor(
-            [
-                list(
-                    range(
-                        context_len + mrope_position_delta,
-                        seq_len + mrope_position_delta,
-                    )
-                )
-                for _ in range(3)
-            ]
-        )
-
 
 class DualChunkRotaryEmbedding(CustomOp):
     """Rotary positional embedding for Dual Chunk Attention."""
@@ -1,5 +1,5 @@
 import logging
-from typing import List
+from typing import List, Tuple
 
 import torch
 import torch.distributed as dist
@@ -39,6 +39,25 @@ class Sampler(nn.Module):
         if is_dp_attention_enabled():
             self.tp_sync_group = get_attention_tp_group().device_group
 
+    def _preprocess_logits(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ) -> torch.Tensor:
+        """Apply custom logit processors and handle NaN detection."""
+        # Apply the custom logit processors if registered in the sampling info
+        if sampling_info.has_custom_logit_processor:
+            apply_custom_logit_processor(logits, sampling_info)
+
+        # Detect and handle NaN values in logits
+        if self.use_nan_detection and torch.any(torch.isnan(logits)):
+            logger.warning("Detected errors during sampling! NaN in the logits.")
+            logits = torch.where(
+                torch.isnan(logits), torch.full_like(logits, -1e5), logits
+            )
+            if crash_on_warnings():
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
+
+        return logits
+
     def forward(
         self,
         logits_output: LogitsProcessorOutput,
@@ -61,17 +80,8 @@ class Sampler(nn.Module):
         """
         logits = logits_output.next_token_logits
 
-        # Apply the custom logit processors if registered in the sampling info.
-        if sampling_info.has_custom_logit_processor:
-            apply_custom_logit_processor(logits, sampling_info)
-
-        if self.use_nan_detection and torch.any(torch.isnan(logits)):
-            logger.warning("Detected errors during sampling! NaN in the logits.")
-            logits = torch.where(
-                torch.isnan(logits), torch.full_like(logits, -1e5), logits
-            )
-            if crash_on_warnings():
-                raise ValueError("Detected errors during sampling! NaN in the logits.")
+        # Preprocess logits (custom processors and NaN handling)
+        logits = self._preprocess_logits(logits, sampling_info)
 
         if sampling_info.is_all_greedy:
             # Use torch.argmax if all requests use greedy sampling
@@ -80,9 +90,9 @@ class Sampler(nn.Module):
             logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
 
         else:
-            # Post process original logits. if temperatures are all 1.0, no need to rescale
+            # If requested, cache probabilities from original logits before temperature scaling.
             if return_logprob and RETURN_ORIGINAL_LOGPROB:
-                logprobs = torch.softmax(logits, dim=-1)
+                probs_without_temp_scaling = torch.softmax(logits, dim=-1)
 
             # Post process logits
             logits.div_(sampling_info.temperatures)
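The hunk above caches the pre-temperature probabilities under a separate name when RETURN_ORIGINAL_LOGPROB is set; the next hunk then takes the log of that cached tensor and frees it with `del`. A standalone sketch (not sglang code) of why the original and temperature-scaled logprobs differ whenever the temperature is not 1.0:

# Standalone illustration (not sglang code): logprobs of the original logits vs.
# logprobs of the temperature-scaled distribution that is actually sampled from.
import torch

logits = torch.tensor([[2.0, 1.0, 0.5]])
temperature = 0.7

probs_without_temp_scaling = torch.softmax(logits, dim=-1)
logprobs_original = torch.log(probs_without_temp_scaling).clamp(
    min=torch.finfo(probs_without_temp_scaling.dtype).min
)
logprobs_scaled = torch.log_softmax(logits / temperature, dim=-1)

print(logprobs_original)  # approx. [-0.46, -1.46, -1.96]
print(logprobs_scaled)    # sharper distribution, approx. [-0.31, -1.73, -2.45]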
@@ -123,9 +133,10 @@ class Sampler(nn.Module):
         if return_logprob:
             # clamp to avoid -inf
             if RETURN_ORIGINAL_LOGPROB:
-                logprobs = torch.log(logprobs).clamp(
-                    min=torch.finfo(logprobs.dtype).min
+                logprobs = torch.log(probs_without_temp_scaling).clamp(
+                    min=torch.finfo(probs_without_temp_scaling.dtype).min
                 )
+                del probs_without_temp_scaling
             else:
                 logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
 
@@ -164,6 +175,54 @@ class Sampler(nn.Module):
 
         return batch_next_token_ids
 
+    def compute_logprobs_only(
+        self,
+        logits_output: LogitsProcessorOutput,
+        sampling_info: SamplingBatchInfo,
+        return_logprob: bool,
+        top_logprobs_nums: List[int],
+        token_ids_logprobs: List[List[int]],
+    ) -> None:
+        """
+        Compute logprobs for requested token IDs without performing sampling.
+
+        Optimized for prefill-only scoring requests that need token probabilities
+        but don't require next token generation.
+        """
+        if logits_output.next_token_logits is None:
+            logger.warning("No logits available for logprob computation")
+            return
+
+        # Check if any requests actually need logprobs computation
+        needs_token_ids_logprobs = any(
+            token_ids is not None and len(token_ids) > 0
+            for token_ids in token_ids_logprobs
+        )
+        needs_top_logprobs = any(x > 0 for x in top_logprobs_nums)
+
+        if not (needs_token_ids_logprobs or needs_top_logprobs):
+            return
+
+        # Preprocess logits (custom processors and NaN handling)
+        logits = self._preprocess_logits(logits_output.next_token_logits, sampling_info)
+
+        # Compute logprobs
+        logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+
+        # Handle top logprobs if requested
+        if needs_top_logprobs:
+            (
+                logits_output.next_token_top_logprobs_val,
+                logits_output.next_token_top_logprobs_idx,
+            ) = get_top_logprobs(logprobs, top_logprobs_nums)
+
+        # Handle token_ids logprobs if requested
+        if needs_token_ids_logprobs:
+            (
+                logits_output.next_token_token_ids_logprobs_val,
+                logits_output.next_token_token_ids_logprobs_idx,
+            ) = get_token_ids_logprobs_batch_optimized(logprobs, token_ids_logprobs)
+
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
     probs: torch.Tensor,
@@ -233,10 +292,95 @@ def get_top_logprobs(
     )
 
 
-def get_token_ids_logprobs(
+def get_token_ids_logprobs_batch_optimized(
     logprobs: torch.Tensor,
     token_ids_logprobs: List[List[int]],
-):
+) -> Tuple[List, List]:
+    """
+    Vectorized batch processing for token ID logprobs extraction.
+
+    Uses a single GPU kernel call for the entire batch instead of multiple
+    separate calls, significantly improving performance for large batches.
+
+    Args:
+        logprobs: Log probabilities tensor [batch_size, vocab_size]
+        token_ids_logprobs: List of token IDs to extract logprobs for
+
+    Example:
+        # Input: batch_size=3, vocab_size=5
+        logprobs = torch.tensor([
+            [-1.2, -2.1, -0.8, -3.0, -1.5],  # batch 0
+            [-0.5, -1.8, -2.2, -1.1, -2.7],  # batch 1
+            [-2.0, -0.9, -1.4, -2.8, -1.6],  # batch 2
+        ])
+        token_ids_logprobs = [[1, 3], [2], [0, 2, 4]]
+
+        # Output:
+        # values = [tensor([-2.1, -3.0]), tensor([-2.2]), tensor([-2.0, -1.4, -1.6])]
+        # indices = [[1, 3], [2], [0, 2, 4]]
+    """
+    batch_size = len(token_ids_logprobs)
+    device = logprobs.device
+
+    # Step 1: Calculate lengths for each request, treating None as empty list
+    # Example: [[1, 3], [2], [0, 2, 4]] -> token_lengths = tensor([2, 1, 3])
+    token_lengths = torch.tensor(
+        [len(token_ids or []) for token_ids in token_ids_logprobs], device=device
+    )
+    total_tokens = int(token_lengths.sum().item())  # 2 + 1 + 3 = 6
+
+    # Handle edge case where no tokens are requested
+    if total_tokens == 0:
+        return [logprobs.new_empty(0) for _ in token_ids_logprobs], [
+            [] for _ in token_ids_logprobs
+        ]
+
+    # Step 2: Build flattened indices using torch operations
+    # Example: row_indices = [0, 0, 1, 2, 2, 2] (batch indices repeated by their lengths)
+    row_indices = torch.repeat_interleave(
+        torch.arange(batch_size, device=device), token_lengths
+    )
+    # Example: col_indices = [1, 3, 2, 0, 2, 4] (flattened token IDs from all requests)
+    col_indices = torch.tensor(
+        [
+            token_id
+            for token_ids in token_ids_logprobs
+            for token_id in (token_ids or [])
+        ],
+        device=device,
+        dtype=torch.long,
+    )
+
+    # Step 3: Single vectorized gather operation
+    # Example: logprobs[row_indices, col_indices] -> [-2.1, -3.0, -2.2, -2.0, -1.4, -1.6]
+    gathered_logprobs = logprobs[row_indices, col_indices]
+
+    # Step 4: Split results back per request using torch operations
+    # Example: split tensor [6] into chunks of sizes [2, 1, 3] -> [tensor(2), tensor(1), tensor(3)]
+    split_logprobs = torch.split_with_sizes(
+        gathered_logprobs, token_lengths.tolist(), dim=0
+    )
+
+    # Step 5: Format output to match expected return structure
+    # Example: Convert split tensors back to list format with proper empty handling
+    #   i=0: [1,3] -> append split_logprobs[0] and [1,3]
+    #   i=1: [2] -> append split_logprobs[1] and [2]
+    #   i=2: [0,2,4] -> append split_logprobs[2] and [0,2,4]
+    output_token_ids_logprobs_val = []
+    output_token_ids_logprobs_idx = []
+
+    for i, token_ids in enumerate(token_ids_logprobs):
+        if token_ids is not None and len(token_ids) > 0:
+            output_token_ids_logprobs_val.append(split_logprobs[i])
+            output_token_ids_logprobs_idx.append(token_ids)
+        else:
+            output_token_ids_logprobs_val.append(logprobs.new_empty(0))
+            output_token_ids_logprobs_idx.append([])
+
+    return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+
+
+def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
     output_token_ids_logprobs_val = []
     output_token_ids_logprobs_idx = []
     for i, token_ids in enumerate(token_ids_logprobs):
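The docstring above already traces a three-request example; the snippet below is a runnable standalone version of the same gather pattern (repeat_interleave to build row indices, one advanced-indexing gather, split_with_sizes to fan results back out per request). It is an illustration, not a call into sglang.

# Standalone reproduction of the vectorized gather used by
# get_token_ids_logprobs_batch_optimized (illustration only).
import torch

logprobs = torch.tensor(
    [
        [-1.2, -2.1, -0.8, -3.0, -1.5],  # request 0
        [-0.5, -1.8, -2.2, -1.1, -2.7],  # request 1
        [-2.0, -0.9, -1.4, -2.8, -1.6],  # request 2
    ]
)
token_ids_logprobs = [[1, 3], [2], [0, 2, 4]]

token_lengths = torch.tensor([len(ids or []) for ids in token_ids_logprobs])
row_indices = torch.repeat_interleave(torch.arange(len(token_ids_logprobs)), token_lengths)
col_indices = torch.tensor([t for ids in token_ids_logprobs for t in (ids or [])])

gathered = logprobs[row_indices, col_indices]  # one gather for the whole batch
per_request = torch.split_with_sizes(gathered, token_lengths.tolist())

print([t.tolist() for t in per_request])
# approx. [[-2.1, -3.0], [-2.2], [-2.0, -1.4, -1.6]], matching the docstring example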
@@ -1,8 +1,9 @@
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 
 from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class BaseLoRABackend:
@@ -10,13 +11,14 @@ class BaseLoRABackend:
     Each backend has its own implementation of Lora kernels.
 
     Args:
-        name: name of backend
-        batch_info: information of current batch for use
+        max_loras_per_batch: maximum number of different lora weights
+            that can be applied in a single forward batch.
+        device: the device where the backend runs.
     """
 
-    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
-        self.name = name
-        self.batch_info = batch_info
+    def __init__(self, max_loras_per_batch: int, device: torch.device):
+        self.max_loras_per_batch = max_loras_per_batch
+        self.device = device
 
     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -93,8 +95,44 @@ class BaseLoRABackend:
         """
         pass
 
-    def set_batch_info(self, batch_info: LoRABatchInfo):
-        self.batch_info = batch_info
+    def init_cuda_graph_batch_info(
+        self,
+        cuda_graph_batch_info: LoRABatchInfo,
+        max_bs_in_cuda_graph: int,
+    ):
+        """Initialize the batch info for CUDA Graph mode.
+
+        This method provides a hook for each backend to conduct its own initialization
+        logic for CUDA Graph mode.
+
+        Args:
+            cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager
+            max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
+        """
+        pass
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],
+        lora_ranks: list[int],
+        scalings: list[float],
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        """Prepare the lora weights and batch info for current forward batch.
+
+        This method provides a hook for each backend to conduct its own preparation
+        logic for each forward batch.
+
+        Args:
+            forward_batch: the ForwardBatch object for current forward pass
+            weight_indices: list of indices of lora weights to be applied for current batch
+            lora_ranks: list of lora ranks corresponding to weight_indices
+            scalings: list of scaling factors corresponding to weight_indices
+            batch_info: optional LoRABatchInfo object, if not provided, the backend should use its own
+                internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
+        """
+        pass
 
 
 def get_backend_from_name(name: str) -> BaseLoRABackend:
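Together with the new constructor signature (`max_loras_per_batch`, `device`) shown earlier, these hooks replace the old `set_batch_info` API: the manager presumably calls `prepare_lora_batch` each step and `init_cuda_graph_batch_info` once for CUDA Graph capture. A minimal subclass sketch under those assumptions follows; the attribute names and the preparation logic are illustrative placeholders, not the real TritonLoRABackend.

# Sketch of a backend filling in the new hooks (illustrative; internals are placeholders).
from typing import Optional

import torch

from sglang.srt.lora.backend.base_backend import BaseLoRABackend
from sglang.srt.lora.utils import LoRABatchInfo
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class SketchLoRABackend(BaseLoRABackend):
    def __init__(self, max_loras_per_batch: int, device: torch.device):
        super().__init__(max_loras_per_batch, device)
        self.cuda_graph_batch_info: Optional[LoRABatchInfo] = None

    def init_cuda_graph_batch_info(
        self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
    ):
        # Keep a reference to the pre-allocated batch info reused across graph replays.
        self.cuda_graph_batch_info = cuda_graph_batch_info

    def prepare_lora_batch(
        self,
        forward_batch: ForwardBatch,
        weight_indices: list[int],
        lora_ranks: list[int],
        scalings: list[float],
        batch_info: Optional[LoRABatchInfo] = None,
    ):
        # Fall back to the CUDA-graph batch info when none is passed, mirroring the
        # contract documented in the base-class docstring above.
        target = batch_info if batch_info is not None else self.cuda_graph_batch_info
        # ... populate `target` with weight indices, ranks, and scalings for this batch ...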
@@ -105,6 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
         from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
 
         return TritonLoRABackend
+    # elif name == "csgmv":
+    #     from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+
+    #     return ChunkedSgmvLoRABackend
     elif name == "flashinfer":
         raise ValueError(
             "FlashInfer LoRA backend has been deprecated, please use `triton` instead."