sglang 0.5.1.post3-py3-none-any.whl → 0.5.2-py3-none-any.whl

This diff compares publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w8a8_int8.py (+83 -41)

@@ -24,6 +24,8 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
+from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
+from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
 from sglang.srt.layers.parameter import (
     ChannelQuantScaleParameter,
     ModelWeightParameter,
@@ -49,8 +51,10 @@ from sglang.srt.utils import (
 )

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.token_dispatcher import (
+        CombineInput,
+        StandardDispatchOutput,
+    )

 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -339,9 +343,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["weight"])
-            return
-
-        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
+        else:
+            layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

     def create_weights(
@@ -417,7 +420,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
@@ -428,7 +431,10 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
@@ -436,14 +442,21 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w13_weight, extra_weight_attrs)

         w2_weight = torch.nn.Parameter(
-            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

         w13_weight_scale = torch.nn.Parameter(
-            torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.ones(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         w2_weight_scale = torch.nn.Parameter(
@@ -472,10 +485,9 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 _is_cpu_amx_available
             ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
             _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
-            return
-
-        layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
-        layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
+        else:
+            layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
+            layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
             layer.w13_weight_scale.data, requires_grad=False
         )
@@ -483,23 +495,30 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_scale.data, requires_grad=False
         )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+
     def apply(
         self,
         layer: torch.nn.Module,
-        x: torch.Tensor,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
+        dispatch_output: StandardDispatchOutput,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output

         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

             topk_weights, topk_ids, _ = topk_output
             x, topk_weights = apply_topk_weights_cpu(
-                moe_runner_config.apply_router_weight_on_input, topk_weights, x
+                self.moe_runner_config.apply_router_weight_on_input, topk_weights, x
             )
-            return torch.ops.sgl_kernel.fused_experts_cpu(
+            output = torch.ops.sgl_kernel.fused_experts_cpu(
                 x,
                 layer.w13_weight,
                 layer.w2_weight,
@@ -515,20 +534,19 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
                 layer.w2_input_scale,  # a2_scale
                 True,  # is_vnni
             )
+            return StandardCombineInput(hidden_states=output)

-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            moe_runner_config=moe_runner_config,
+        quant_info = TritonMoeQuantInfo(
+            w13_weight=layer.w13_weight,
+            w2_weight=layer.w2_weight,
             use_int8_w8a8=True,
             per_channel_quant=True,
-            w1_scale=(layer.w13_weight_scale),
-            w2_scale=(layer.w2_weight_scale),
-            a1_scale=layer.w13_input_scale,
+            w13_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a13_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
         )
+        return self.runner.run(dispatch_output, quant_info)


 class NPU_W8A8LinearMethodImpl:
@@ -551,7 +569,7 @@ class NPU_W8A8LinearMethodImpl:
     def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
         params_dict = {}
         params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
-        params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
+        params_dict["input_offset"] = torch.empty(1, dtype=params_dtype)
         return params_dict

     @staticmethod
@@ -582,11 +600,11 @@ class NPU_W8A8LinearMethodImpl:
         if original_dtype != torch.int8:
             x = torch_npu.npu_quantize(
                 x,
-                layer.aclnn_input_scale,
+                layer.aclnn_input_scale_reciprocal,
                 layer.aclnn_input_offset,
                 torch.qint8,
                 -1,
-                True,
+                False,
             )
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in Attention TP>1 case)
@@ -608,6 +626,10 @@ class NPU_W8A8LinearMethodImpl:
             layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
             requires_grad=False,
         )
+        layer.aclnn_input_scale_reciprocal = 1 / torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
+            requires_grad=False,
+        )
         layer.aclnn_input_offset = torch.nn.Parameter(
             layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
             requires_grad=False,
@@ -896,7 +918,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         num_experts: int,
         hidden_size: int,
-        intermediate_size: int,
+        intermediate_size_per_partition: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ) -> None:
@@ -910,21 +932,31 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         # weight
         w13_weight = torch.nn.Parameter(
             torch.empty(
-                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=torch.int8,
             ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)
         w2_weight = torch.nn.Parameter(
-            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=torch.int8,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)
         # scale
         w13_weight_scale = torch.nn.Parameter(
-            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_scale", w13_weight_scale)
@@ -937,7 +969,9 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_scale, extra_weight_attrs)
         # offset
         w13_weight_offset = torch.nn.Parameter(
-            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            torch.empty(
+                num_experts, 2 * intermediate_size_per_partition, 1, dtype=torch.float32
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_offset", w13_weight_offset)
@@ -969,18 +1003,25 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
         )

+    def create_moe_runner(
+        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
+    ):
+        self.moe_runner_config = moe_runner_config
+
     def apply(
         self,
         layer,
-        x,
-        topk_output: TopKOutput,
-        moe_runner_config: MoeRunnerConfig,
-    ) -> torch.Tensor:
+        dispatch_output: StandardDispatchOutput,
+    ) -> CombineInput:
+        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
+
+        x = dispatch_output.hidden_states
+        topk_output = dispatch_output.topk_output

         topk_weights, topk_ids, _ = topk_output
         topk_ids = topk_ids.to(torch.int32)
         topk_weights = topk_weights.to(x.dtype)
-        return npu_fused_experts(
+        output = npu_fused_experts(
             hidden_states=x,
             w13=layer.w13_weight,
             w13_scale=layer.w13_weight_scale,
@@ -990,3 +1031,4 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
             topk_ids=topk_ids,
             top_k=topk_ids.shape[1],
         )
+        return StandardCombineInput(hidden_states=output)
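
In these hunks the quantized MoE methods switch from the old apply(layer, x, topk_output, moe_runner_config) signature to the dispatcher-based interface: apply now receives a StandardDispatchOutput, the Triton path builds a TritonMoeQuantInfo and delegates to a MoeRunner instead of calling fused_experts directly, and the result is wrapped in a StandardCombineInput. A minimal sketch of a quantization method written against this new surface (assumes sglang 0.5.2 is installed; the class itself is hypothetical, all other names come from the diff):

# Hypothetical quant method following the new runner-based MoE interface above.
import torch

from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.moe.token_dispatcher import StandardDispatchOutput


class MyInt8MoEMethod:
    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        # Called once per layer; the runner owns the fused-MoE execution path.
        self.moe_runner_config = moe_runner_config
        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)

    def apply(self, layer: torch.nn.Module, dispatch_output: StandardDispatchOutput):
        # Hidden states and top-k routing arrive bundled in dispatch_output;
        # weights and scales are packed into quant info and handed to the runner.
        quant_info = TritonMoeQuantInfo(
            w13_weight=layer.w13_weight,
            w2_weight=layer.w2_weight,
            use_int8_w8a8=True,
            per_channel_quant=True,
            w13_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a13_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
        )
        return self.runner.run(dispatch_output, quant_info)
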
sglang/srt/layers/rocm_linear_utils.py (new file, +44)

@@ -0,0 +1,44 @@
+import torch
+from aiter.ops.triton.fused_qk_concat import fused_qk_rope_cat
+from aiter.ops.triton.gemm_a16w16 import gemm_a16w16
+from aiter.ops.triton.gemm_a16w16_atomic import gemm_a16w16_atomic
+
+from sglang.srt.utils import BumpAllocator
+
+__all__ = ["fused_qk_rope_cat"]
+
+
+def aiter_dsv3_router_gemm(
+    hidden_states: torch.Tensor,
+    weight: torch.Tensor,
+    gemm_output_zero_allocator: BumpAllocator = None,
+):
+    M = hidden_states.shape[0]
+    N = weight.shape[0]
+    y = None
+
+    if M <= 256:
+        # TODO (cagri): convert to bfloat16 as part of another kernel to save time
+        # for now it is also coupled with zero allocator.
+        if gemm_output_zero_allocator != None:
+            y = gemm_output_zero_allocator.allocate(M * N).view(M, N)
+        else:
+            y = torch.zeros((M, N), dtype=torch.float32, device=hidden_states.device)
+
+    if y is not None:
+        logits = gemm_a16w16_atomic(hidden_states, weight, y=y).to(hidden_states.dtype)
+    else:
+        logits = gemm_a16w16(hidden_states, weight)
+
+    return logits
+
+
+def get_dsv3_gemm_output_zero_allocator_size(
+    n_routed_experts: int, num_moe_layers: int, allocate_size: int, embedding_dim: int
+):
+    if embedding_dim != 7168 or n_routed_experts != 256:
+        return 0
+
+    per_layer_size = 256 * (allocate_size + n_routed_experts)
+
+    return num_moe_layers * per_layer_size
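
aiter_dsv3_router_gemm computes DeepSeek-V3 router logits with aiter's A16W16 Triton GEMMs, accumulating into a zero-initialized fp32 buffer (bump-allocated when available) for small token counts. A plain-PyTorch shape reference of the same product, assuming the usual x @ w.T convention for these kernels; it needs no ROCm or aiter install and is only a sketch for checking shapes:

# Plain-PyTorch shape reference for the router GEMM above (assumed x @ w.T layout).
import torch

M, K, N = 8, 7168, 256  # tokens, hidden size, routed experts (DeepSeek-V3 sizes)
hidden_states = torch.randn(M, K, dtype=torch.bfloat16)
weight = torch.randn(N, K, dtype=torch.bfloat16)  # router weight, one row per expert

# Accumulate in fp32 like the zero-allocated path, then cast back to the activation dtype.
logits = (hidden_states.float() @ weight.float().t()).to(hidden_states.dtype)
assert logits.shape == (M, N)
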
sglang/srt/layers/rotary_embedding.py (+28 -19)

@@ -1433,24 +1433,6 @@ class MRotaryEmbedding(RotaryEmbedding):

         return position_ids, mrope_position_deltas

-    @staticmethod
-    def get_next_input_positions(
-        mrope_position_delta: int,
-        context_len: int,
-        seq_len: int,
-    ) -> torch.Tensor:
-        return torch.tensor(
-            [
-                list(
-                    range(
-                        context_len + mrope_position_delta,
-                        seq_len + mrope_position_delta,
-                    )
-                )
-                for _ in range(3)
-            ]
-        )
-

 class DualChunkRotaryEmbedding(CustomOp):
     """Rotary positional embedding for Dual Chunk Attention."""
@@ -1876,7 +1858,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(
+def apply_rotary_pos_emb_native(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
@@ -1899,6 +1881,33 @@
    return q_embed, k_embed


+def apply_rotary_pos_emb_npu(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    cos: torch.Tensor,
+    sin: torch.Tensor,
+    unsqueeze_dim=1,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    if q.shape[1] != 128:
+        return apply_rotary_pos_emb_native(q, k, cos, sin, unsqueeze_dim)
+    cos = cos.unsqueeze(unsqueeze_dim)
+    cos = torch.transpose(cos, 1, 2)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    sin = torch.transpose(sin, 1, 2)
+    q = torch.transpose(q, 1, 2)
+    k = torch.transpose(k, 1, 2)
+    q_embed, k_embed = torch_npu.npu_apply_rotary_pos_emb(q, k, cos, sin)
+    q_embed = torch.transpose(q_embed, 1, 2)
+    k_embed = torch.transpose(k_embed, 1, 2)
+    return q_embed, k_embed
+
+
+if _is_npu:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_npu
+else:
+    apply_rotary_pos_emb = apply_rotary_pos_emb_native
+
+
 def get_rope_cpu(
     head_size: int,
     rotary_dim: int,
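
apply_rotary_pos_emb is now bound at import time: on NPU it routes through torch_npu.npu_apply_rotary_pos_emb (only when q.shape[1] == 128, falling back otherwise), elsewhere it stays on the renamed apply_rotary_pos_emb_native. The native path is the usual rotate-half formulation; a standalone sketch of that math follows (independent of sglang, with rotate_half redefined locally, so exact tensor layouts may differ from the library's):

# Standalone sketch of the rotate-half RoPE math behind the native path.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb_reference(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin are [batch, seq, head_dim]; unsqueeze so they broadcast over heads.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


batch, heads, seq, head_dim = 1, 2, 4, 8
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
angles = torch.rand(batch, seq, head_dim) * 3.14
q_out, k_out = apply_rotary_pos_emb_reference(q, k, angles.cos(), angles.sin())
assert q_out.shape == q.shape and k_out.shape == k.shape
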
sglang/srt/layers/sampler.py (+29 -5)

@@ -27,6 +27,7 @@ if is_cuda():
 logger = logging.getLogger(__name__)

 SYNC_TOKEN_IDS_ACROSS_TP = get_bool_env_var("SYNC_TOKEN_IDS_ACROSS_TP")
+RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")


 class Sampler(nn.Module):
@@ -77,7 +78,12 @@ class Sampler(nn.Module):
             batch_next_token_ids = torch.argmax(logits, -1)
             if return_logprob:
                 logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
+
         else:
+            # Post process original logits. if temperatures are all 1.0, no need to rescale
+            if return_logprob and RETURN_ORIGINAL_LOGPROB:
+                logprobs = torch.softmax(logits, dim=-1)
+
             # Post process logits
             logits.div_(sampling_info.temperatures)
             logits[:] = torch.softmax(logits, dim=-1)
@@ -116,7 +122,12 @@ class Sampler(nn.Module):

             if return_logprob:
                 # clamp to avoid -inf
-                logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)
+                if RETURN_ORIGINAL_LOGPROB:
+                    logprobs = torch.log(logprobs).clamp(
+                        min=torch.finfo(logprobs.dtype).min
+                    )
+                else:
+                    logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)

         # Attach logprobs to logits_output (in-place modification)
         if return_logprob:
@@ -201,7 +212,10 @@ def top_p_normalize_probs_torch(
     return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)


-def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
+def get_top_logprobs(
+    logprobs: torch.Tensor,
+    top_logprobs_nums: List[int],
+):
     max_k = max(top_logprobs_nums)
     ret = logprobs.topk(max_k, dim=1)
     values = ret.values.tolist()
@@ -212,10 +226,17 @@ def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
     for i, k in enumerate(top_logprobs_nums):
         output_top_logprobs_val.append(values[i][:k])
         output_top_logprobs_idx.append(indices[i][:k])
-    return output_top_logprobs_val, output_top_logprobs_idx
+
+    return (
+        output_top_logprobs_val,
+        output_top_logprobs_idx,
+    )


-def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List[int]]):
+def get_token_ids_logprobs(
+    logprobs: torch.Tensor,
+    token_ids_logprobs: List[List[int]],
+):
     output_token_ids_logprobs_val = []
     output_token_ids_logprobs_idx = []
     for i, token_ids in enumerate(token_ids_logprobs):
@@ -226,7 +247,10 @@ def get_token_ids_logprobs(logprobs: torch.Tensor, token_ids_logprobs: List[List
         output_token_ids_logprobs_val.append([])
         output_token_ids_logprobs_idx.append([])

-    return output_token_ids_logprobs_val, output_token_ids_logprobs_idx
+    return (
+        output_token_ids_logprobs_val,
+        output_token_ids_logprobs_idx,
+    )


 def apply_custom_logit_processor(
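
The new RETURN_ORIGINAL_LOGPROB flag (read via get_bool_env_var, so e.g. RETURN_ORIGINAL_LOGPROB=true in the server environment) makes the sampler report log-probabilities of the original, pre-temperature distribution instead of the temperature-scaled distribution it actually samples from. A small standalone illustration of the difference between the two quantities:

# What RETURN_ORIGINAL_LOGPROB toggles: logprobs of the raw distribution vs. the
# temperature-scaled distribution that sampling uses. Standalone illustration.
import torch

logits = torch.tensor([[2.0, 1.0, 0.5]])
temperature = 0.7

# With RETURN_ORIGINAL_LOGPROB set: logprobs come from the unscaled logits.
original_logprobs = torch.log_softmax(logits, dim=-1)

# Default: logprobs come from the post-temperature probabilities, clamped to avoid -inf.
probs = torch.softmax(logits / temperature, dim=-1)
scaled_logprobs = torch.log(probs).clamp(min=torch.finfo(probs.dtype).min)

print(original_logprobs)  # tensor([[-0.4644, -1.4644, -1.9644]])
print(scaled_logprobs)    # sharper distribution for temperature < 1
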
sglang/srt/lora/backend/base_backend.py (+50 -8)

@@ -1,8 +1,9 @@
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union

 import torch

 from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class BaseLoRABackend:
@@ -10,13 +11,14 @@ class BaseLoRABackend:
     Each backend has its own implementation of Lora kernels.

     Args:
-        name: name of backend
-        batch_info: information of current batch for use
+        max_loras_per_batch: maximum number of different lora weights
+            that can be applied in a single forward batch.
+        device: the device where the backend runs.
     """

-    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
-        self.name = name
-        self.batch_info = batch_info
+    def __init__(self, max_loras_per_batch: int, device: torch.device):
+        self.max_loras_per_batch = max_loras_per_batch
+        self.device = device

     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -93,8 +95,44 @@ class BaseLoRABackend:
         """
         pass

-    def set_batch_info(self, batch_info: LoRABatchInfo):
-        self.batch_info = batch_info
+    def init_cuda_graph_batch_info(
+        self,
+        cuda_graph_batch_info: LoRABatchInfo,
+        max_bs_in_cuda_graph: int,
+    ):
+        """Initialize the batch info for CUDA Graph mode.
+
+        This method provides a hook for each backend to conduct its own initialization
+        logic for CUDA Graph mode.
+
+        Args:
+            cuda_graph_batch_info: the LoRABatchInfo object created in LoraManager
+            max_bs_in_cuda_graph: maximum batch size for CUDA Graph mode
+        """
+        pass
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],
+        lora_ranks: list[int],
+        scalings: list[float],
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        """Prepare the lora weights and batch info for current forward batch.
+
+        This method provides a hook for each backend to conduct its own preparation
+        logic for each forward batch.
+
+        Args:
+            forward_batch: the ForwardBatch object for current forward pass
+            weight_indices: list of indices of lora weights to be applied for current batch
+            lora_ranks: list of lora ranks corresponding to weight_indices
+            scalings: list of scaling factors corresponding to weight_indices
+            batch_info: optional LoRABatchInfo object, if not provided, the backend should use its own
+                internal batch info (e.g., self.cuda_graph_batch_info for CUDA Graph mode)
+        """
+        pass


 def get_backend_from_name(name: str) -> BaseLoRABackend:
@@ -105,6 +143,10 @@ def get_backend_from_name(name: str) -> BaseLoRABackend:
         from sglang.srt.lora.backend.triton_backend import TritonLoRABackend

         return TritonLoRABackend
+    # elif name == "csgmv":
+    #     from sglang.srt.lora.backend.chunked_backend import ChunkedSgmvLoRABackend
+
+    #     return ChunkedSgmvLoRABackend
     elif name == "flashinfer":
         raise ValueError(
             "FlashInfer LoRA backend has been deprecated, please use `triton` instead."
sglang/srt/lora/backend/triton_backend.py (+90 -2)

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import torch

 from sglang.srt.lora.backend.base_backend import BaseLoRABackend
@@ -8,12 +10,14 @@ from sglang.srt.lora.triton_ops import (
     sgemm_lora_b_fwd,
 )
 from sglang.srt.lora.utils import LoRABatchInfo
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class TritonLoRABackend(BaseLoRABackend):
+    name = "triton"

-    def __init__(self, name: str, batch_info: LoRABatchInfo = None):
-        super().__init__(name, batch_info)
+    def __init__(self, max_loras_per_batch: int, device: torch.device):
+        super().__init__(max_loras_per_batch, device)

     def run_lora_a_sgemm(
         self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
@@ -86,3 +90,87 @@ class TritonLoRABackend(BaseLoRABackend):
             base_output,
         )
         return lora_output
+
+    def init_cuda_graph_batch_info(
+        self, cuda_graph_batch_info: LoRABatchInfo, max_bs_in_cuda_graph: int
+    ):
+        # Initialize seg_lens and seg_indptr for CUDA graph as they remain constant
+        # across batches.
+        cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph].fill_(1)
+        torch.cumsum(
+            cuda_graph_batch_info.seg_lens[:max_bs_in_cuda_graph],
+            dim=0,
+            out=cuda_graph_batch_info.seg_indptr[1 : max_bs_in_cuda_graph + 1],
+        )
+
+    def prepare_lora_batch(
+        self,
+        forward_batch: ForwardBatch,
+        weight_indices: list[int],
+        lora_ranks: list[int],
+        scalings: list[float],
+        batch_info: Optional[LoRABatchInfo] = None,
+    ):
+        # Use pinned memory to avoid synchronizations during host-to-device transfer
+        weight_indices_tensor = torch.tensor(
+            weight_indices, dtype=torch.int32, pin_memory=True, device="cpu"
+        )
+        lora_ranks_tensor = torch.tensor(
+            lora_ranks, dtype=torch.int32, pin_memory=True, device="cpu"
+        )
+        scalings_tensor = torch.tensor(
+            scalings, dtype=torch.float, pin_memory=True, device="cpu"
+        )
+
+        bs = forward_batch.batch_size
+
+        if batch_info is not None:
+            assert (
+                batch_info.use_cuda_graph
+            ), "batch_info.use_cuda_graph must be True when batch_info is provided"
+            batch_info.bs = forward_batch.batch_size
+            batch_info.num_segments = forward_batch.batch_size
+        else:
+            max_len = (
+                # Calculate max_len from the CPU copy to avoid D2H transfer.
+                max(forward_batch.extend_seq_lens_cpu)
+                if forward_batch.forward_mode.is_extend()
+                else 1
+            )
+            seg_lens = (
+                forward_batch.extend_seq_lens
+                if forward_batch.forward_mode.is_extend()
+                else torch.ones(bs, device=self.device)
+            )
+            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+
+            batch_info = LoRABatchInfo(
+                bs=forward_batch.batch_size,
+                num_segments=forward_batch.batch_size,
+                max_len=max_len,
+                use_cuda_graph=False,
+                seg_lens=seg_lens,
+                seg_indptr=seg_indptr,
+                weight_indices=torch.empty(
+                    (bs,), dtype=torch.int32, device=self.device
+                ),
+                lora_ranks=torch.empty(
+                    (self.max_loras_per_batch,), dtype=torch.int64, device=self.device
+                ),
+                scalings=torch.empty(
+                    (self.max_loras_per_batch,), dtype=torch.float, device=self.device
+                ),
+                permutation=None,
+            )
+
+        # Copy to device asynchronously
+        batch_info.lora_ranks[: self.max_loras_per_batch].copy_(
+            lora_ranks_tensor, non_blocking=True
+        )
+        batch_info.scalings[: self.max_loras_per_batch].copy_(
+            scalings_tensor, non_blocking=True
+        )
+        batch_info.weight_indices[:bs].copy_(weight_indices_tensor, non_blocking=True)
+
+        self.batch_info = batch_info
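
prepare_lora_batch stages the per-batch metadata in pinned host tensors and copies it into preallocated device tensors with non_blocking=True, so the host-to-device transfer does not force a synchronization. The same pattern in isolation (plain PyTorch; needs a CUDA device):

# The pinned-memory + non_blocking copy pattern used above, in isolation.
import torch

assert torch.cuda.is_available(), "this sketch needs a CUDA device"
device = torch.device("cuda")

weight_indices = [0, 1, 1, 0]

# Stage on the host in page-locked (pinned) memory so the copy can be asynchronous.
host_buf = torch.tensor(weight_indices, dtype=torch.int32, pin_memory=True, device="cpu")

# Preallocated device-side destination (analogous to LoRABatchInfo.weight_indices).
device_buf = torch.empty(16, dtype=torch.int32, device=device)

# Enqueue the H2D copy without blocking the host; kernels launched later on the
# same stream observe the completed copy.
device_buf[: len(weight_indices)].copy_(host_buf, non_blocking=True)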