sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,549 @@
1
+ # Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/batch_invariant_ops.py
2
+
3
+ import contextlib
4
+ from collections import namedtuple
5
+ from collections.abc import Callable
6
+ from typing import Any, Dict
7
+
8
+ import torch
9
+ import triton
10
+ import triton.language as tl
11
+
12
+ __all__ = [
13
+ "set_batch_invariant_mode",
14
+ "is_batch_invariant_mode_enabled",
15
+ "disable_batch_invariant_mode",
16
+ "enable_batch_invariant_mode",
17
+ ]
18
+
19
+
20
+ def _matmul_launch_metadata(
21
+ grid: Callable[..., Any], kernel: Any, args: Dict[str, Any]
22
+ ) -> Dict[str, Any]:
23
+ ret = {}
24
+ m, n, k = args["M"], args["N"], args["K"]
25
+ ret["name"] = f"{kernel.name} [M={m}, N={n}, K={k}]"
26
+ if "tiles_per_update" in args:
27
+ ret["name"] = (
28
+ f"{kernel.name} [M={m}, N={n}, K={k}, tiles_per_update={args['tiles_per_update']:02}]"
29
+ )
30
+ if "c_ptr" in args:
31
+ bytes_per_elem = args["c_ptr"].element_size()
32
+ else:
33
+ bytes_per_elem = 1 if args["FP8_OUTPUT"] else 2
34
+ ret[f"flops{bytes_per_elem * 8}"] = 2.0 * m * n * k
35
+ ret["bytes"] = bytes_per_elem * (m * k + n * k + m * n)
36
+ return ret
37
+
38
+
39
@triton.jit
def _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS):
    # Translate a linear tile index into (pid_m, pid_n) tile coordinates.
    # Tiles are numbered group by group; a group covers up to GROUP_SIZE_M
    # rows of tiles. NUM_SMS is accepted but not read here.
    group_index = tile_id // num_pid_in_group
    group_first_row = group_index * GROUP_SIZE_M
    # The final group may contain fewer than GROUP_SIZE_M rows.
    rows_in_group = min(num_pid_m - group_first_row, GROUP_SIZE_M)
    tile_row = group_first_row + (tile_id % rows_in_group)
    tile_col = (tile_id % num_pid_in_group) // rows_in_group
    return tile_row, tile_col
47
+
48
+
49
@triton.jit(launch_metadata=_matmul_launch_metadata)
def matmul_kernel_persistent(
    a_ptr,
    b_ptr,
    c_ptr,  #
    bias_ptr,
    M,
    N,
    K,  #
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,  #
    BLOCK_SIZE_N: tl.constexpr,  #
    BLOCK_SIZE_K: tl.constexpr,  #
    GROUP_SIZE_M: tl.constexpr,  #
    NUM_SMS: tl.constexpr,  #
    A_LARGE: tl.constexpr,
    B_LARGE: tl.constexpr,
    C_LARGE: tl.constexpr,
    HAS_BIAS: tl.constexpr,
):
    """
    Persistent matmul: C = A @ B (+ bias), accumulated in float32.

    Each program starts at tile ``program_id(0)`` and strides over output
    tiles in steps of NUM_SMS. The *_LARGE flags switch the corresponding
    offsets to int64. When HAS_BIAS is set, ``bias_ptr`` is indexed by output
    column and added to every row of the tile.
    """
    start_pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    k_tiles = tl.cdiv(K, BLOCK_SIZE_K)
    num_tiles = num_pid_m * num_pid_n

    # Epilogue tile counter: starts one stride behind and is advanced by
    # NUM_SMS before each store, so it equals tile_id at that point.
    tile_id_c = start_pid - NUM_SMS

    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
        pid_m, pid_n = _compute_pid(
            tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
        offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)
        offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)
        if A_LARGE:
            offs_am = offs_am.to(tl.int64)
        if B_LARGE:
            offs_bn = offs_bn.to(tl.int64)
        # Out-of-range rows/columns are redirected to index 0 so the loads stay
        # in bounds; those results are dropped by c_mask at store time.
        offs_am = tl.where(offs_am < M, offs_am, 0)
        offs_bn = tl.where(offs_bn < N, offs_bn, 0)
        offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
        offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
        for ki in range(k_tiles):
            if A_LARGE or B_LARGE:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K).to(tl.int64)
            else:
                offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
            a_ptrs = a_ptr + (
                offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
            )
            b_ptrs = b_ptr + (
                offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
            )

            # Only the K tail is masked; masked lanes contribute 0 to the dot.
            a = tl.load(
                a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            b = tl.load(
                b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0
            )
            accumulator = tl.dot(a, b, accumulator)

        tile_id_c += NUM_SMS
        pid_m, pid_n = _compute_pid(
            tile_id_c, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS
        )
        offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        if C_LARGE:
            offs_cm = offs_cm.to(tl.int64)
            offs_cn = offs_cn.to(tl.int64)
        c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
        c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
        if HAS_BIAS:
            bias_ptrs = bias_ptr + offs_cn
            bias = tl.load(bias_ptrs, mask=offs_cn < N, other=0.0).to(tl.float32)
            accumulator += bias
        # Cast straight to the destination element type. Going through float16
        # for every non-fp8 output would round bfloat16 results twice, clamp
        # them to the float16 range, and throw away float32 precision.
        if c_ptr.dtype.element_ty == tl.float8e4nv:
            c = accumulator.to(tl.float8e4nv)
        elif c_ptr.dtype.element_ty == tl.bfloat16:
            c = accumulator.to(tl.bfloat16)
        elif c_ptr.dtype.element_ty == tl.float32:
            c = accumulator.to(tl.float32)
        else:
            c = accumulator.to(tl.float16)
        tl.store(c_ptrs, c, mask=c_mask)
143
+
144
+
145
+ def matmul_persistent(
146
+ a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
147
+ ):
148
+ # Check constraints.
149
+ assert a.shape[1] == b.shape[0], "Incompatible dimensions"
150
+ assert a.dtype == b.dtype, "Incompatible dtypes"
151
+ assert (
152
+ bias is None or bias.dim() == 1
153
+ ), "Currently assuming bias is 1D, let Horace know if you run into this"
154
+ NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
155
+ M, K = a.shape
156
+ K, N = b.shape
157
+ dtype = a.dtype
158
+ # Allocates output.
159
+ c = torch.empty((M, N), device=a.device, dtype=dtype)
160
+
161
+ # 1D launch kernel where each block gets its own program.
162
+ def grid(META):
163
+ return (
164
+ min(
165
+ NUM_SMS,
166
+ triton.cdiv(M, META["BLOCK_SIZE_M"])
167
+ * triton.cdiv(N, META["BLOCK_SIZE_N"]),
168
+ ),
169
+ )
170
+
171
+ configs = {
172
+ torch.bfloat16: {
173
+ "BLOCK_SIZE_M": 128,
174
+ "BLOCK_SIZE_N": 128,
175
+ "BLOCK_SIZE_K": 64,
176
+ "GROUP_SIZE_M": 8,
177
+ "num_stages": 3,
178
+ "num_warps": 8,
179
+ },
180
+ torch.float16: {
181
+ "BLOCK_SIZE_M": 128,
182
+ "BLOCK_SIZE_N": 256,
183
+ "BLOCK_SIZE_K": 64,
184
+ "GROUP_SIZE_M": 8,
185
+ "num_stages": 3,
186
+ "num_warps": 8,
187
+ },
188
+ torch.float32: {
189
+ "BLOCK_SIZE_M": 128,
190
+ "BLOCK_SIZE_N": 128,
191
+ "BLOCK_SIZE_K": 32,
192
+ "GROUP_SIZE_M": 8,
193
+ "num_stages": 3,
194
+ "num_warps": 8,
195
+ },
196
+ }
197
+ # print(a.device, b.device, c.device)
198
+ matmul_kernel_persistent[grid](
199
+ a,
200
+ b,
201
+ c, #
202
+ bias,
203
+ M,
204
+ N,
205
+ K, #
206
+ a.stride(0),
207
+ a.stride(1), #
208
+ b.stride(0),
209
+ b.stride(1), #
210
+ c.stride(0),
211
+ c.stride(1), #
212
+ NUM_SMS=NUM_SMS, #
213
+ A_LARGE=a.numel() > 2**31,
214
+ B_LARGE=b.numel() > 2**31,
215
+ C_LARGE=c.numel() > 2**31,
216
+ HAS_BIAS=bias is not None,
217
+ **configs[dtype],
218
+ )
219
+ return c
220
+
221
+
222
@triton.jit
def _log_softmax_kernel(
    input_ptr,
    output_ptr,
    input_row_stride,
    output_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Row-wise log_softmax over the last dimension of a 2D tensor.
    One program handles one row, sweeping it three times in chunks of
    BLOCK_SIZE columns: row maximum, sum of exponentials, final values.
    """
    row = tl.program_id(0).to(tl.int64)

    # Base addresses of this program's input and output rows.
    in_row_ptr = input_ptr + row * input_row_stride
    out_row_ptr = output_ptr + row * output_row_stride

    # Pass 1: row maximum, used to keep the exponentials numerically stable.
    row_max = -float("inf")
    for chunk_start in range(0, n_cols, BLOCK_SIZE):
        cols = chunk_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_cols
        chunk = tl.load(in_row_ptr + cols, mask=in_bounds, other=-float("inf"))
        row_max = tl.max(tl.maximum(chunk, row_max))

    # Pass 2: accumulate sum(exp(x - row_max)); padded lanes contribute zero.
    exp_total = 0.0
    for chunk_start in range(0, n_cols, BLOCK_SIZE):
        cols = chunk_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_cols
        chunk = tl.load(in_row_ptr + cols, mask=in_bounds, other=0.0)
        shifted_exp = tl.exp(chunk - row_max)
        exp_total += tl.sum(tl.where(in_bounds, shifted_exp, 0.0))

    log_exp_total = tl.log(exp_total)

    # Pass 3: write x - row_max - log(sum_exp) for every in-bounds column.
    for chunk_start in range(0, n_cols, BLOCK_SIZE):
        cols = chunk_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_cols
        chunk = tl.load(in_row_ptr + cols, mask=in_bounds)
        result = chunk - row_max - log_exp_total
        tl.store(out_row_ptr + cols, result, mask=in_bounds)
283
+
284
+
285
def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Compute log_softmax over the last dimension using the Triton kernel.

    Args:
        input: Input tensor
        dim: Dimension to reduce over; must be -1 or ``input.ndim - 1``

    Returns:
        Tensor of the same shape as ``input`` with log_softmax applied along
        the last dimension

    Raises:
        ValueError: If ``dim`` does not refer to the last dimension
    """
    last_dim = input.ndim - 1
    if dim not in (-1, last_dim):
        raise ValueError(
            "This implementation only supports log_softmax along the last dimension"
        )

    # Collapse every leading dimension into rows; the kernel needs contiguous rows.
    rows = input.reshape(-1, input.shape[-1]).contiguous()
    n_rows, n_cols = rows.shape

    result = torch.empty_like(rows)

    # Fixed chunk width the kernel uses when sweeping a row.
    BLOCK_SIZE = 1024

    # One program per row.
    _log_softmax_kernel[(n_rows,)](
        rows,
        result,
        rows.stride(0),
        result.stride(0),
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return result.reshape(input.shape)
326
+
327
+
328
@triton.jit
def mean_kernel(
    input_ptr,
    output_ptr,
    input_stride0,
    input_stride1,
    input_stride2,
    output_stride0,
    output_stride1,
    M,  # size before reduction dim
    N,  # size of reduction dim
    K,  # size after reduction dim
    BLOCK_SIZE: tl.constexpr,
):
    """
    Mean over the middle axis of an input addressed as (M, N, K).
    Each program produces exactly one output element (m, k).
    """
    # The launch index enumerates output elements in row-major (m, k) order.
    out_linear = tl.program_id(0)
    out_row = out_linear // K
    out_col = out_linear % K

    # Nothing to do for indices outside the output.
    if out_row >= M or out_col >= K:
        return

    # Sum the reduction axis in chunks of BLOCK_SIZE; padded lanes load 0.
    total = 0.0
    for chunk_start in range(0, N, BLOCK_SIZE):
        reduce_idx = chunk_start + tl.arange(0, BLOCK_SIZE)
        valid = reduce_idx < N
        src_offsets = (
            out_row * input_stride0 + reduce_idx * input_stride1 + out_col * input_stride2
        )
        chunk = tl.load(input_ptr + src_offsets, mask=valid, other=0.0)
        total += tl.sum(chunk)

    # Divide by the reduction length and write the single result.
    dst_offset = out_row * output_stride0 + out_col * output_stride1
    tl.store(output_ptr + dst_offset, total / N)
376
+
377
+
378
+ def mean_dim(
379
+ input: torch.Tensor,
380
+ dim: int,
381
+ keepdim: bool = False,
382
+ dtype: torch.dtype | None = None,
383
+ ) -> torch.Tensor:
384
+ """
385
+ Triton implementation of torch.mean with single dimension reduction.
386
+
387
+ Args:
388
+ input: Input tensor
389
+ dim: Single dimension along which to compute mean
390
+ keepdim: Whether to keep the reduced dimension
391
+ dtype: Output dtype. If None, uses input dtype (or float32 for integer inputs)
392
+
393
+ Returns:
394
+ Tensor with mean values along specified dimension
395
+ """
396
+ # Validate inputs
397
+ assert input.is_cuda, "Input must be a CUDA tensor"
398
+ assert (
399
+ -input.ndim <= dim < input.ndim
400
+ ), f"Invalid dimension {dim} for tensor with {input.ndim} dimensions"
401
+
402
+ # Handle negative dim
403
+ if dim < 0:
404
+ dim = dim + input.ndim
405
+
406
+ # Handle dtype
407
+ if dtype is None:
408
+ if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
409
+ dtype = torch.float32
410
+ else:
411
+ dtype = input.dtype
412
+
413
+ # Convert input to appropriate dtype if needed
414
+ if input.dtype != dtype:
415
+ input = input.to(dtype)
416
+
417
+ # Get input shape and strides
418
+ shape = list(input.shape)
419
+
420
+ # Calculate dimensions for kernel
421
+ M = 1
422
+ for i in range(dim):
423
+ M *= shape[i]
424
+
425
+ N = shape[dim]
426
+
427
+ K = 1
428
+ for i in range(dim + 1, len(shape)):
429
+ K *= shape[i]
430
+
431
+ # Reshape input to 3D view (M, N, K)
432
+ input_3d = input.reshape(M, N, K)
433
+
434
+ # Create output shape
435
+ if keepdim:
436
+ output_shape = shape.copy()
437
+ output_shape[dim] = 1
438
+ else:
439
+ output_shape = shape[:dim] + shape[dim + 1 :]
440
+
441
+ # Create output tensor
442
+ output = torch.empty(output_shape, dtype=dtype, device=input.device)
443
+
444
+ # Reshape output for kernel
445
+ if keepdim:
446
+ output_2d = output.reshape(M, 1, K).squeeze(1)
447
+ else:
448
+ output_2d = output.reshape(M, K)
449
+
450
+ # Launch kernel
451
+ grid = (M * K,)
452
+ BLOCK_SIZE = 1024
453
+
454
+ mean_kernel[grid](
455
+ input_3d,
456
+ output_2d,
457
+ input_3d.stride(0),
458
+ input_3d.stride(1),
459
+ input_3d.stride(2),
460
+ output_2d.stride(0),
461
+ output_2d.stride(1) if output_2d.ndim > 1 else 0,
462
+ M,
463
+ N,
464
+ K,
465
+ BLOCK_SIZE,
466
+ )
467
+
468
+ return output
469
+
470
+
471
def mm_batch_invariant(a, b):
    """Matrix-multiply ``a`` and ``b`` through ``matmul_persistent`` (no bias)."""
    product = matmul_persistent(a, b)
    return product
473
+
474
+
475
def addmm_batch_invariant(bias, a, b, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (a @ b)`` through ``matmul_persistent``.

    ``beta`` and ``alpha`` mirror the keyword-only scalars of ``aten::addmm``.
    Without them, a call that passes a non-default scale would fail with a
    ``TypeError`` instead of being computed.

    Args:
        bias: Tensor added to the product (``matmul_persistent`` expects 1D).
        a: Left matrix operand.
        b: Right matrix operand.
        beta: Multiplier applied to ``bias``.
        alpha: Multiplier applied to ``a @ b``.

    Returns:
        The result tensor produced from ``matmul_persistent``'s output.
    """
    if beta == 1 and alpha == 1:
        # Common case: the bias is added inside the kernel.
        return matmul_persistent(a, b, bias=bias)

    # Scaled case: take the plain product, then apply the scales elementwise.
    # The product is already rounded to the output dtype before scaling.
    product = matmul_persistent(a, b)
    if alpha != 1:
        product = product * alpha
    if beta == 0:
        # With beta == 0 the bias does not contribute at all.
        return product
    return product + bias * beta
477
+
478
+
479
+ def _log_softmax_batch_invariant(input, dim, _half_to_float):
480
+ assert not _half_to_float, "not implemented"
481
+ return log_softmax(input, dim=dim)
482
+
483
+
484
+ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
485
+ assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
486
+ if len(dim) == 1:
487
+ return mean_dim(input, dim[0], keepdim=keepdim)
488
+ else:
489
+ assert input.dtype in {
490
+ torch.float16,
491
+ torch.bfloat16,
492
+ torch.float32,
493
+ }, "only float types supported for now"
494
+ n_elems = 1
495
+ for d in dim:
496
+ n_elems *= input.shape[d]
497
+ return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
498
+
499
+
500
# Module-level switch state: whether the overrides are currently active, and
# the library handle that holds them (None while inactive).
_batch_invariant_MODE = False
_batch_invariant_LIB = None


def is_batch_invariant_mode_enabled():
    """Return the current value of the module-level batch-invariant flag."""
    return _batch_invariant_MODE
506
+
507
+
508
def enable_batch_invariant_mode():
    """Register the batch-invariant CUDA overrides; a no-op if already enabled.

    Creates an ``aten`` IMPL library and installs overrides for ``mm``,
    ``addmm``, ``_log_softmax`` and ``mean.dim`` under the CUDA dispatch key.
    """
    global _batch_invariant_MODE, _batch_invariant_LIB
    if _batch_invariant_MODE:
        return

    _batch_invariant_MODE = True
    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")

    # (operator name, replacement) pairs, registered in this order.
    cuda_overrides = (
        ("aten::mm", mm_batch_invariant),
        ("aten::addmm", addmm_batch_invariant),
        ("aten::_log_softmax", _log_softmax_batch_invariant),
        ("aten::mean.dim", mean_batch_invariant),
    )
    for op_name, replacement in cuda_overrides:
        _batch_invariant_LIB.impl(op_name, replacement, "CUDA")
521
+
522
+
523
def disable_batch_invariant_mode():
    """Destroy the override library (if any) and clear the module state."""
    global _batch_invariant_MODE, _batch_invariant_LIB
    active_lib = _batch_invariant_LIB
    if active_lib is not None:
        active_lib._destroy()
    _batch_invariant_MODE, _batch_invariant_LIB = False, None
529
+
530
+
531
@contextlib.contextmanager
def set_batch_invariant_mode(enabled: bool = True):
    """Temporarily switch batch-invariant mode on or off.

    The mode that was active on entry is restored on exit, including when the
    body raises. Restoration goes through ``enable_batch_invariant_mode`` /
    ``disable_batch_invariant_mode`` so the enabled flag and the registered
    library always agree. In particular, a mode that was already enabled on
    entry is still backed by a live library after exit, rather than by a
    handle that has been destroyed.

    Args:
        enabled: Desired mode inside the ``with`` block.
    """
    was_enabled = _batch_invariant_MODE
    if enabled:
        enable_batch_invariant_mode()
    else:
        disable_batch_invariant_mode()
    try:
        yield
    finally:
        # Both helpers are idempotent: enabling while enabled is a no-op, and
        # disabling while disabled only resets the (already cleared) state.
        if was_enabled:
            enable_batch_invariant_mode()
        else:
            disable_batch_invariant_mode()
543
+
544
+
545
# Pair of attention tile sizes, addressable as .block_m / .block_n.
AttentionBlockSize = namedtuple("AttentionBlockSize", "block_m block_n")


def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
    """Return the fixed attention block size: 16 for both block_m and block_n."""
    block_m, block_n = 16, 16
    return AttentionBlockSize(block_m, block_n)
@@ -1,8 +1,10 @@
1
1
  from sglang.srt.configs.chatglm import ChatGLMConfig
2
2
  from sglang.srt.configs.dbrx import DbrxConfig
3
3
  from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
4
+ from sglang.srt.configs.dots_ocr import DotsOCRConfig
4
5
  from sglang.srt.configs.dots_vlm import DotsVLMConfig
5
6
  from sglang.srt.configs.exaone import ExaoneConfig
7
+ from sglang.srt.configs.falcon_h1 import FalconH1Config
6
8
  from sglang.srt.configs.janus_pro import MultiModalityConfig
7
9
  from sglang.srt.configs.kimi_vl import KimiVLConfig
8
10
  from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
@@ -28,4 +30,6 @@ __all__ = [
28
30
  "Step3VisionEncoderConfig",
29
31
  "Qwen3NextConfig",
30
32
  "DotsVLMConfig",
33
+ "DotsOCRConfig",
34
+ "FalconH1Config",
31
35
  ]
@@ -0,0 +1,64 @@
1
+ from typing import Optional
2
+
3
+ from transformers import AutoProcessor, Qwen2_5_VLProcessor
4
+ from transformers.image_processing_utils import BaseImageProcessor
5
+ from transformers.models.qwen2 import Qwen2Config
6
+
7
+ from sglang.srt.configs.dots_vlm import DotsVisionConfig
8
+
9
+
10
class DotsOCRConfig(Qwen2Config):
    """Qwen2-based configuration for ``dots_ocr`` with an attached vision config."""

    model_type = "dots_ocr"

    def __init__(
        self,
        image_token_id=151665,
        video_token_id=151656,
        vision_config: Optional[dict] = None,
        *args,
        **kwargs
    ):
        """Initialise the base config, then attach the multimodal fields.

        Args:
            image_token_id: Stored as ``self.image_token_id``.
            video_token_id: Stored as ``self.video_token_id``.
            vision_config: Keyword arguments for ``DotsVisionConfig``; ``None``
                (or any falsy value) builds it with no arguments.
            *args, **kwargs: Forwarded unchanged to ``Qwen2Config``.
        """
        super().__init__(*args, **kwargs)
        vision_kwargs = vision_config or {}
        self.image_token_id = image_token_id
        self.video_token_id = video_token_id
        self.vision_config = DotsVisionConfig(**vision_kwargs)

    def save_pretrained(self, save_directory, **kwargs):
        """Clear ``_auto_class`` and delegate to the parent ``save_pretrained``."""
        # NOTE(review): presumably keeps custom auto-class export out of the
        # saved config; confirm against the transformers behaviour.
        self._auto_class = None
        super().save_pretrained(save_directory, **kwargs)
29
+
30
+
31
class DummyVideoProcessor(BaseImageProcessor):
    """Placeholder video processor: callable, but always produces ``None``."""

    model_input_names = ["pixel_values"]

    def __call__(self, *args, **kwargs):
        """Accept any arguments, ignore them, and return ``None``."""
        return None
36
+
37
+
38
class DotsVLProcessor(Qwen2_5_VLProcessor):
    """Qwen2.5-VL processor variant with a default video processor and image token."""

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        video_processor=None,
        chat_template=None,
        **kwargs
    ):
        """Build the processor and resolve the image token and its id.

        Args:
            image_processor: Passed through to the parent constructor.
            tokenizer: Passed through; also consulted for ``image_token`` and
                ``image_token_id``.
            video_processor: Replaced by a ``DummyVideoProcessor`` when ``None``.
            chat_template: Passed through to the parent constructor.
            **kwargs: Accepted but not forwarded to the parent constructor.
        """
        video_processor = (
            DummyVideoProcessor() if video_processor is None else video_processor
        )
        super().__init__(
            image_processor, tokenizer, video_processor, chat_template=chat_template
        )

        # Use the tokenizer's own image token when it defines one.
        self.image_token = getattr(tokenizer, "image_token", "<|imgpad|>")

        # Prefer an explicit id on the tokenizer; otherwise look the token up.
        known_id = getattr(tokenizer, "image_token_id", None)
        if known_id is not None:
            self.image_token_id = known_id
        else:
            self.image_token_id = tokenizer.convert_tokens_to_ids(self.image_token)
62
+
63
+
64
+ AutoProcessor.register(DotsOCRConfig, DotsVLProcessor)