sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -60,7 +60,6 @@ import torch.distributed as dist
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed.parallel_state import destroy_distributed_environment
 from sglang.srt.entrypoints.engine import _set_envs_and_config
-from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.managers.scheduler import Scheduler
@@ -78,6 +77,7 @@ from sglang.srt.utils import (
     set_gpu_proc_affinity,
     suppress_other_loggers,
 )
+from sglang.srt.utils.hf_transformers_utils import get_tokenizer


 @dataclasses.dataclass
@@ -443,11 +443,9 @@ def latency_test_run_once(

     if profile:
         profiler.stop()
-        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
-        _save_profile_trace_results(profiler, profile_filename)
-        rank_print(
-            f"torch profiler chrome trace for prefill saved to {profile_filename}"
-        )
+        trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
+        _save_profile_trace_results(profiler, trace_filename)
+        rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")

     # Decode
     decode_latencies = []
@@ -479,10 +477,10 @@

         if profile and i == output_len / 2:
             profiler.stop()
-            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
-            _save_profile_trace_results(profiler, profile_filename)
+            trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
+            _save_profile_trace_results(profiler, trace_filename)
             rank_print(
-                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
+                f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
             )

     # Record decode timing from 2nd output
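
Note on the import change above: this release moves several helper modules under sglang.srt.utils (items 256-258 in the file list), so downstream code importing the old sglang.srt.hf_transformers_utils path needs a one-line update. A minimal migration sketch, assuming the 0.5.3rc2 wheel is installed (the model id is a placeholder):

# Old (0.5.3rc0):
#   from sglang.srt.hf_transformers_utils import get_tokenizer
# New (0.5.3rc2), matching the import added in bench_one_batch.py above:
from sglang.srt.utils.hf_transformers_utils import get_tokenizer

tokenizer = get_tokenizer("meta-llama/Meta-Llama-3.1-8B")  # placeholder model id
print(tokenizer.encode("hello world"))
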
sglang/bench_one_batch_server.py CHANGED
@@ -9,6 +9,7 @@ python3 -m sglang.bench_one_batch_server --model meta-llama/Meta-Llama-3.1-8B --

 python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8
 python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --show-report --profile --profile-by-stage
+python3 -m sglang.bench_one_batch_server --model None --base-url http://localhost:30000 --batch-size 16 --input-len 1024 --output-len 8 --output-path results.json --profile
 """

 import argparse
@@ -17,12 +18,19 @@ import itertools
 import json
 import multiprocessing
 import os
+import random
 import time
-from typing import List, Tuple
+from typing import List, Optional, Tuple

+import numpy as np
 import requests
+from pydantic import BaseModel

-from sglang.bench_serving import get_tokenizer, sample_random_requests
+from sglang.bench_serving import (
+    get_tokenizer,
+    sample_mmmu_requests,
+    sample_random_requests,
+)
 from sglang.profiler import run_profile
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import ServerArgs
@@ -30,9 +38,112 @@ from sglang.srt.utils import is_blackwell, kill_process_tree
 from sglang.test.test_utils import is_in_ci, write_github_step_summary


+class ProfileLinks(BaseModel):
+    """Pydantic model for profile trace links."""
+
+    extend: Optional[str] = None
+    decode: Optional[str] = None
+
+
+class BenchmarkResult(BaseModel):
+    """Pydantic model for benchmark results table data, for a single isl and osl"""
+
+    model_path: str
+    run_name: str
+    batch_size: int
+    input_len: int
+    output_len: int
+    latency: float
+    ttft: float
+    input_throughput: float
+    output_throughput: float
+    overall_throughput: float
+    last_gen_throughput: float
+    acc_length: Optional[float] = None
+    profile_links: Optional[ProfileLinks] = None
+
+    @staticmethod
+    def help_str() -> str:
+        return f"""
+Note: To view the traces through perfetto-ui, please:
+1. open with Google Chrome
+2. allow popup
+"""
+
+    def to_markdown_row(
+        self, trace_dir, base_url: str = "", relay_base: str = ""
+    ) -> str:
+        """Convert this benchmark result to a markdown table row."""
+        # Calculate costs (assuming H100 pricing for now)
+        hourly_cost_per_gpu = 2  # $2/hour for one H100
+        hourly_cost = hourly_cost_per_gpu * 1  # Assuming tp_size = 1 for simplicity
+        input_util = 0.7
+        accept_length = (
+            round(self.acc_length, 2) if self.acc_length is not None else "n/a"
+        )
+        itl = 1 / (self.output_throughput / self.batch_size) * 1000
+        input_cost = 1e6 / (self.input_throughput * input_util) / 3600 * hourly_cost
+        output_cost = 1e6 / self.output_throughput / 3600 * hourly_cost
+
+        def get_perfetto_relay_link_from_trace_file(trace_file: str):
+            import os
+            from urllib.parse import quote
+
+            rel_path = os.path.relpath(trace_file, trace_dir)
+            raw_file_link = f"{base_url}/{rel_path}"
+            relay_link = (
+                f"{relay_base}?src={quote(raw_file_link, safe='')}"
+                if relay_base and quote
+                else raw_file_link
+            )
+            return relay_link
+
+        # Handle profile links
+        profile_link = "NA | NA"
+        if self.profile_links:
+            if self.profile_links.extend or self.profile_links.decode:
+                # Create a combined link or use the first available one
+                trace_files = [self.profile_links.extend, self.profile_links.decode]
+                trace_files_relay_links = [
+                    f"[trace]({get_perfetto_relay_link_from_trace_file(trace_file)})"
+                    for trace_file in trace_files
+                ]
+
+                profile_link = " | ".join(trace_files_relay_links)
+
+        # Build the row
+        return f"| {self.batch_size} | {self.input_len} | {self.latency:.2f} | {self.input_throughput:.2f} | {self.output_throughput:.2f} | {accept_length} | {itl:.2f} | {input_cost:.2f} | {output_cost:.2f} | {profile_link} |\n"
+
+    @classmethod
+    def generate_markdown_report(
+        cls, trace_dir, results: List["BenchmarkResult"]
+    ) -> str:
+        """Generate a markdown report from a list of BenchmarkResult object from a single run."""
+        import os
+
+        summary = f"### {results[0].model_path}\n"
+
+        # summary += (
+        #     f"Input lens: {result.input_len}. Output lens: {result.output_len}.\n"
+        # )
+        summary += "| batch size | input len | latency (s) | input throughput (tok/s) | output throughput (tok/s) | acc length | ITL (ms) | input cost ($/1M) | output cost ($/1M) | profile (extend) | profile (decode)|\n"
+        summary += "| ---------- | --------- | ----------- | ------------------------- | ------------------------- | ---------- | -------- | ----------------- | ------------------ | --------------- | -------------- |\n"
+
+        # all results should share the same isl & osl
+        for result in results:
+            base_url = os.getenv("TRACE_BASE_URL", "").rstrip("/")
+            relay_base = os.getenv("PERFETTO_RELAY_URL", "").rstrip("/")
+            relay_base = "https://docs.sglang.ai/ci-data/pages/perfetto_relay.html"
+            # base_url = "https://github.com/sgl-project/ci-data/traces"
+            summary += result.to_markdown_row(trace_dir, base_url, relay_base)
+
+        return summary
+
+
 @dataclasses.dataclass
 class BenchArgs:
     run_name: str = "default"
+    seed: int = 42
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
@@ -47,11 +158,17 @@ class BenchArgs:
     profile: bool = False
     profile_steps: int = 3
     profile_by_stage: bool = False
+    profile_filename_prefix: str = None
+    append_to_github_summary: bool = True
     dataset_path: str = ""
+    parallel_batch: bool = False
+    dataset_name: str = "random"
+    output_path: Optional[str] = None

     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument("--run-name", type=str, default=BenchArgs.run_name)
+        parser.add_argument("--seed", type=int, default=BenchArgs.seed)
         parser.add_argument(
             "--batch-size", type=int, nargs="+", default=BenchArgs.batch_size
         )
@@ -62,6 +179,13 @@ class BenchArgs:
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
         parser.add_argument("--temperature", type=float, default=BenchArgs.temperature)
+        parser.add_argument(
+            "--dataset-name",
+            type=str,
+            default=BenchArgs.dataset_name,
+            choices=["mmmu", "random"],
+            help="Name of the dataset to benchmark on.",
+        )
         parser.add_argument("--return-logprob", action="store_true")
         parser.add_argument(
             "--client-stream-interval",
@@ -90,14 +214,37 @@ class BenchArgs:
             default=BenchArgs.dataset_path,
             help="Path to the dataset.",
         )
+        parser.add_argument("--parallel-batch", action="store_true")
+        parser.add_argument(
+            "--profile-filename-prefix",
+            type=str,
+            default=BenchArgs.profile_filename_prefix,
+        )
+        parser.add_argument(
+            "--no-append-to-github-summary",
+            action="store_false",
+            dest="append_to_github_summary",
+            help="Disable appending the output of this run to github ci summary",
+        )
+        parser.add_argument(
+            "--output-path",
+            type=str,
+            default=BenchArgs.output_path,
+            help="Path to save benchmark results as JSON format. If not specified, results will only be saved to result-filename.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         # use the default value's type to cast the args into correct types.
         attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
-        return cls(
-            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
-        )
+        kwargs = {}
+        for attr, attr_type in attrs:
+            val = getattr(args, attr)
+            if attr_type is type(None):
+                kwargs[attr] = val
+            else:
+                kwargs[attr] = attr_type(val)
+        return cls(**kwargs)


 def launch_server_internal(server_args):
@@ -142,22 +289,35 @@
     run_name: str,
     result_filename: str,
     tokenizer,
+    dataset_name="",
     profile: bool = False,
     profile_steps: int = 3,
     profile_by_stage: bool = False,
+    profile_filename_prefix: str = None,
     dataset_path: str = "",
+    parallel_batch: bool = False,
 ):
     requests.post(url + "/flush_cache")
-    input_requests = sample_random_requests(
-        input_len=input_len,
-        output_len=output_len,
-        num_prompts=batch_size,
-        range_ratio=1.0,
-        tokenizer=tokenizer,
-        dataset_path=dataset_path,
-        random_sample=True,
-        return_text=False,
-    )
+    # TODO: reuse bench_serving.get_dataset ?
+    if dataset_name == "mmmu":
+        input_requests = sample_mmmu_requests(
+            num_requests=batch_size,
+            tokenizer=tokenizer,
+            fixed_output_len=output_len,
+            apply_chat_template=True,
+            random_sample=False,
+        )
+    elif dataset_name == "random":
+        input_requests = sample_random_requests(
+            input_len=input_len,
+            output_len=output_len,
+            num_prompts=batch_size,
+            range_ratio=1.0,
+            tokenizer=tokenizer,
+            dataset_path=dataset_path,
+            random_sample=True,
+            return_text=False,
+        )

     use_structured_outputs = False
     if use_structured_outputs:
@@ -174,25 +334,48 @@

     profile_link = None
     if profile:
+        output_dir, profile_name = None, None
+        if profile_filename_prefix:
+            output_dir = os.path.dirname(profile_filename_prefix)
+            profile_name = os.path.basename(profile_filename_prefix)
         profile_link: str = run_profile(
-            url, profile_steps, ["CPU", "GPU"], None, None, profile_by_stage
+            url,
+            profile_steps,
+            ["CPU", "GPU"],
+            output_dir,
+            profile_name,
+            profile_by_stage,
         )

     tic = time.perf_counter()
+
+    payload = {
+        "sampling_params": {
+            "temperature": temperature,
+            "max_new_tokens": output_len,
+            "ignore_eos": True,
+            "json_schema": json_schema,
+            "stream_interval": stream_interval,
+        },
+        "return_logprob": return_logprob,
+        "stream": True,
+        **({"parallel_batch": parallel_batch} if parallel_batch else {}),
+    }
+    if dataset_name == "mmmu":
+        # vlm
+        input_ids = []
+        for input_req in input_requests:
+            input_ids += [tokenizer.encode(input_req.prompt)]
+        payload["image_data"] = [req.image_data for req in input_requests]
+
+    else:
+        input_ids = [req.prompt for req in input_requests]
+
+    payload["input_ids"] = input_ids
+
     response = requests.post(
         url + "/generate",
-        json={
-            "input_ids": [req.prompt for req in input_requests],
-            "sampling_params": {
-                "temperature": temperature,
-                "max_new_tokens": output_len,
-                "ignore_eos": True,
-                "json_schema": json_schema,
-                "stream_interval": stream_interval,
-            },
-            "return_logprob": return_logprob,
-            "stream": True,
-        },
+        json=payload,
         stream=True,
     )

@@ -256,10 +439,100 @@
         overall_throughput,
         last_gen_throughput,
         acc_length,
-        profile_link if profile else None,
+        profile_link,
     )


+def save_results_as_json(result: List[Tuple], bench_args: BenchArgs, model: str):
+    """Save benchmark results as JSON using Pydantic models."""
+    json_results = []
+
+    # Generate all parameter combinations to match with results
+    param_combinations = list(
+        itertools.product(
+            bench_args.batch_size, bench_args.input_len, bench_args.output_len
+        )
+    )
+
+    for i, (
+        batch_size,
+        latency,
+        ttft,
+        input_throughput,
+        output_throughput,
+        overall_throughput,
+        last_gen_throughput,
+        acc_length,
+        profile_link,
+    ) in enumerate(result):
+        # Get the corresponding parameters for this result
+        bs, input_len, output_len = param_combinations[i]
+
+        # Parse profile links if available
+        profile_links = None
+        if profile_link:
+            profile_links = parse_profile_links(
+                profile_link, batch_size, input_len, output_len
+            )
+
+        benchmark_result = BenchmarkResult(
+            model_path=model,
+            run_name=bench_args.run_name,
+            batch_size=batch_size,
+            input_len=input_len,
+            output_len=output_len,
+            latency=latency,
+            ttft=ttft,
+            input_throughput=input_throughput,
+            output_throughput=output_throughput,
+            overall_throughput=overall_throughput,
+            last_gen_throughput=last_gen_throughput,
+            acc_length=acc_length,
+            profile_links=profile_links,
+        )
+        json_results.append(benchmark_result.model_dump())
+
+    # Save to JSON file
+    with open(bench_args.output_path, "w", encoding="utf-8") as f:
+        json.dump(json_results, f, indent=2, ensure_ascii=False)
+
+    print(f"Results saved as JSON to {bench_args.output_path}")
+
+
+def parse_profile_links(
+    profile_dir: str, batch_size: int, input_len: int, output_len: int
+) -> Optional[ProfileLinks]:
+    """Parse profile directory to extract extend and decode trace file links."""
+    if not profile_dir or not os.path.exists(profile_dir):
+        return None
+
+    extend_link = None
+    decode_link = None
+
+    # Look for extend/prefill trace files
+    for file in os.listdir(profile_dir):
+        if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
+            if "extend" in file.lower() or "prefill" in file.lower():
+                extend_link = os.path.join(profile_dir, file)
+            elif "decode" in file.lower():
+                decode_link = os.path.join(profile_dir, file)
+
+    # If no specific extend/decode files found, try to find files with batch/input/output info
+    if not extend_link or not decode_link:
+        for file in os.listdir(profile_dir):
+            if file.endswith(".trace.json.gz") or file.endswith(".trace.json"):
+                if f"_batch{batch_size}_input{input_len}_output{output_len}_" in file:
+                    if "prefill" in file.lower() or "extend" in file.lower():
+                        extend_link = os.path.join(profile_dir, file)
+                    elif "decode" in file.lower():
+                        decode_link = os.path.join(profile_dir, file)
+
+    if extend_link or decode_link:
+        return ProfileLinks(extend=extend_link, decode=decode_link)
+
+    return None
+
+
 def get_report_summary(
     result: List[Tuple], server_args: ServerArgs, bench_args: BenchArgs
 ):
@@ -350,10 +623,12 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
             return_logprob=bench_args.return_logprob,
             stream_interval=bench_args.client_stream_interval,
             input_len_step_percentage=bench_args.input_len_step_percentage,
+            dataset_name=bench_args.dataset_name,
             run_name="",
             result_filename="",
             tokenizer=tokenizer,
             dataset_path=bench_args.dataset_path,
+            parallel_batch=bench_args.parallel_batch,
         )
         print("=" * 8 + " Warmup End " + "=" * 8 + "\n")

@@ -375,8 +650,12 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                 stream_interval=bench_args.client_stream_interval,
                 input_len_step_percentage=bench_args.input_len_step_percentage,
                 run_name=bench_args.run_name,
+                dataset_name=bench_args.dataset_name,
                 result_filename=bench_args.result_filename,
                 tokenizer=tokenizer,
+                dataset_path=bench_args.dataset_path,
+                parallel_batch=bench_args.parallel_batch,
+                profile_filename_prefix=bench_args.profile_filename_prefix,
             )
         )

@@ -399,9 +678,13 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
                     run_name=bench_args.run_name,
                     result_filename=bench_args.result_filename,
                     tokenizer=tokenizer,
+                    dataset_name=bench_args.dataset_name,
                     profile=bench_args.profile,
                     profile_steps=bench_args.profile_steps,
                     profile_by_stage=bench_args.profile_by_stage,
+                    dataset_path=bench_args.dataset_path,
+                    parallel_batch=bench_args.parallel_batch,
+                    profile_filename_prefix=bench_args.profile_filename_prefix,
                 )[-1],
             )
         )
@@ -414,13 +697,16 @@

     print(f"\nResults are saved to {bench_args.result_filename}")

+    # Save results as JSON if output_path is specified
+    if bench_args.output_path:
+        save_results_as_json(result, bench_args, model=server_args.model_path)
+
     if not bench_args.show_report:
         return

     summary = get_report_summary(result, server_args, bench_args)
-    print(summary)

-    if is_in_ci():
+    if is_in_ci() and bench_args.append_to_github_summary:
         write_github_step_summary(summary)


@@ -429,6 +715,10 @@ def main():
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
     args = parser.parse_args()
+
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+
     server_args = ServerArgs.from_cli_args(args)
     bench_args = BenchArgs.from_cli_args(args)

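The new --output-path flag (handled by save_results_as_json above) writes one BenchmarkResult.model_dump() dict per (batch size, input len, output len) combination. A minimal post-processing sketch, assuming a results.json produced by the --output-path command in the module docstring; the inter-token-latency formula mirrors to_markdown_row:

import json

from sglang.bench_one_batch_server import BenchmarkResult

with open("results.json", "r", encoding="utf-8") as f:
    rows = json.load(f)  # a list of BenchmarkResult.model_dump() dicts

results = [BenchmarkResult.model_validate(row) for row in rows]
for r in results:
    # Same ITL computation as to_markdown_row: ms per token per request
    itl_ms = 1 / (r.output_throughput / r.batch_size) * 1000
    print(f"bs={r.batch_size} isl={r.input_len} osl={r.output_len} itl={itl_ms:.2f} ms")
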
sglang/bench_serving.py CHANGED
@@ -208,6 +208,10 @@ async def async_request_openai_completions(
         "ignore_eos": not args.disable_ignore_eos,
         **request_func_input.extra_request_body,
     }
+
+    if request_func_input.image_data:
+        payload.update({"image_data": request_func_input.image_data})
+
     headers = get_auth_headers()

     output = RequestFuncOutput.init_new(request_func_input)
@@ -631,7 +635,7 @@ def get_tokenizer(
     if pretrained_model_name_or_path.endswith(
         ".json"
     ) or pretrained_model_name_or_path.endswith(".model"):
-        from sglang.srt.hf_transformers_utils import get_tokenizer
+        from sglang.srt.utils.hf_transformers_utils import get_tokenizer

         return get_tokenizer(pretrained_model_name_or_path)

@@ -1110,7 +1114,8 @@ def sample_sharegpt_requests(
             add_generation_prompt=True,
             tokenize=False,
         )
-        prompt = prompt.replace(tokenizer.bos_token, "")
+        if tokenizer.bos_token:
+            prompt = prompt.replace(tokenizer.bos_token, "")

         prompt_token_ids = tokenizer.encode(prompt)
         completion = dataset[i][1]
@@ -1758,7 +1763,9 @@ async def benchmark(
         pbar.close()

     if "sglang" in backend:
-        server_info = requests.get(base_url + "/get_server_info")
+        server_info = requests.get(
+            base_url + "/get_server_info", headers=get_auth_headers()
+        )
         if server_info.status_code == 200:
             server_info_json = server_info.json()
             if "decode" in server_info_json:
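The sample_sharegpt_requests change above guards against tokenizers whose bos_token is None, where the old unconditional strip raised a TypeError. A self-contained sketch of the failure mode, using a hypothetical stand-in tokenizer:

class FakeTokenizer:  # hypothetical stand-in; some real tokenizers define no BOS token
    bos_token = None

tokenizer = FakeTokenizer()
prompt = "<s> Hello"

# Old behavior: prompt.replace(tokenizer.bos_token, "") raises
# TypeError: replace() argument 1 must be str, not None
# New behavior: skip the strip when there is no BOS token.
if tokenizer.bos_token:
    prompt = prompt.replace(tokenizer.bos_token, "")
print(prompt)  # "<s> Hello", unchanged
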
sglang/global_config.py CHANGED
@@ -37,8 +37,8 @@ class GlobalConfig:
         )
         # Runtime constants: others
         self.retract_decode_steps = 20
-        self.flashinfer_workspace_size = os.environ.get(
-            "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024
+        self.flashinfer_workspace_size = int(
+            os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
         )

         # Output tokenization configs
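
The int(...) wrapper above fixes a latent type bug: os.environ.get returns the int default when FLASHINFER_WORKSPACE_SIZE is unset, but a str when it is set, so the old attribute could silently become a string. A quick illustration:

import os

os.environ["FLASHINFER_WORKSPACE_SIZE"] = "134217728"  # e.g. user overrides to 128 MiB
raw = os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024)
print(type(raw))  # <class 'str'> whenever the variable is set

size = int(os.environ.get("FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024))
print(size // (1024 * 1024))  # 128; an int in both the set and unset cases
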
sglang/lang/backend/runtime_endpoint.py CHANGED
@@ -433,7 +433,7 @@ class Runtime:
         self.endpoint.cache_prefix(prefix)

     def get_tokenizer(self):
-        from sglang.srt.hf_transformers_utils import get_tokenizer
+        from sglang.srt.utils.hf_transformers_utils import get_tokenizer

         return get_tokenizer(
             self.server_args.tokenizer_path,
sglang/launch_server.py CHANGED
@@ -7,9 +7,23 @@ from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import prepare_server_args
 from sglang.srt.utils import kill_process_tree

+MOVE_ENVS_WARN = """
+########################################################################
+# For contributors and developers:                                     #
+# Please move environment variable definitions to sglang.srt.environ   #
+# using the following pattern:                                         #
+#     SGLANG_XXX = EnvBool(False)                                      #
+#                                                                      #
+########################################################################
+"""
+
 if __name__ == "__main__":
     server_args = prepare_server_args(sys.argv[1:])

+    from sglang.srt.server_args import print_deprecated_warning
+
+    print_deprecated_warning(MOVE_ENVS_WARN)
+
     try:
         launch_server(server_args)
     finally:
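
The banner points contributors at the new sglang.srt.environ module (a 285-line addition, item 44 in the file list). Its real implementation is not part of this diff; the following toy descriptor only illustrates the registration shape SGLANG_XXX = EnvBool(False) quoted in the banner, and is not sglang's actual API:

import os


class EnvBool:
    """Toy stand-in: a descriptor that reads a boolean flag from the environment."""

    def __init__(self, default: bool):
        self.default = default

    def __set_name__(self, owner, name):
        self.name = name  # picks up e.g. "SGLANG_ENABLE_FOO" automatically

    def __get__(self, obj, objtype=None) -> bool:
        raw = os.environ.get(self.name)
        if raw is None:
            return self.default
        return raw.strip().lower() in ("1", "true", "yes")


class Envs:
    SGLANG_ENABLE_FOO = EnvBool(False)  # hypothetical flag name


print(Envs().SGLANG_ENABLE_FOO)  # False unless SGLANG_ENABLE_FOO is set truthy
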
sglang/profiler.py CHANGED
@@ -15,7 +15,7 @@ from typing import List, Optional

 import requests

-PARENT_FOLDER = "/tmp/sglang-profile"
+PROFILER_DIR = os.getenv("SGLANG_TORCH_PROFILER_DIR", "/tmp")

 def _run_profile(
@@ -27,7 +27,7 @@ def _run_profile(
     profile_by_stage: bool = False,
 ) -> str:
     if output_dir is None:
-        output_dir = PARENT_FOLDER
+        output_dir = PROFILER_DIR

     output_dir = os.path.normpath(output_dir)
     output_dir = os.path.abspath(output_dir)
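
With this change the client-side profiler's default output directory honors SGLANG_TORCH_PROFILER_DIR, falling back to /tmp rather than the old hard-coded /tmp/sglang-profile. A minimal sketch, assuming a server is already running on localhost:30000; the positional arguments follow the run_profile call site in the bench_one_batch_server.py hunk above, and the trace directory is a placeholder:

import os

# PROFILER_DIR is read at import time in sglang/profiler.py, so set the
# variable before importing.
os.environ["SGLANG_TORCH_PROFILER_DIR"] = "/data/traces"  # placeholder path

from sglang.profiler import run_profile

# url, profile_steps, activities, output_dir, profile_name, profile_by_stage;
# output_dir=None defers to PROFILER_DIR via _run_profile's fallback above.
run_profile("http://localhost:30000", 3, ["CPU", "GPU"], None, None, False)
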
sglang/srt/batch_invariant_ops/__init__.py ADDED
@@ -0,0 +1,27 @@
+# Adapted from https://github.com/thinking-machines-lab/batch_invariant_ops/blob/main/batch_invariant_ops/__init__.py
+
+from .batch_invariant_ops import (
+    AttentionBlockSize,
+    disable_batch_invariant_mode,
+    enable_batch_invariant_mode,
+    get_batch_invariant_attention_block_size,
+    is_batch_invariant_mode_enabled,
+    log_softmax,
+    matmul_persistent,
+    mean_dim,
+    set_batch_invariant_mode,
+)
+
+__version__ = "0.1.0"
+
+__all__ = [
+    "set_batch_invariant_mode",
+    "is_batch_invariant_mode_enabled",
+    "disable_batch_invariant_mode",
+    "enable_batch_invariant_mode",
+    "matmul_persistent",
+    "log_softmax",
+    "mean_dim",
+    "get_batch_invariant_attention_block_size",
+    "AttentionBlockSize",
+]
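
A hedged usage sketch for the new package: the imported names come from the re-exports above, while the assumption that set_batch_invariant_mode works as a context manager (making results independent of batch size) follows the upstream batch_invariant_ops project this file says it is adapted from:

import torch

from sglang.srt.batch_invariant_ops import (
    is_batch_invariant_mode_enabled,
    set_batch_invariant_mode,
)

a = torch.randn(32, 64, device="cuda")
b = torch.randn(64, 16, device="cuda")

with set_batch_invariant_mode():
    assert is_batch_invariant_mode_enabled()
    # With batch-invariant kernels, row 0 of a batched matmul should match the
    # same computation run at batch size 1, bit for bit.
    out_batched = (a @ b)[:1]
    out_single = a[:1] @ b
    assert torch.equal(out_batched, out_single)
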