sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/test/runners.py CHANGED
@@ -15,15 +15,15 @@
  import multiprocessing as mp
  import os
  from dataclasses import dataclass
- from typing import List, Union
+ from typing import List, Optional, Tuple, Union

  import torch
  import torch.nn.functional as F
  from transformers import AutoModelForCausalLM

- from sglang.srt.entrypoints.engine import Engine
  from sglang.srt.hf_transformers_utils import get_tokenizer
- from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER
+ from sglang.srt.server import Engine
+ from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l

  DEFAULT_PROMPTS = [
      "Apple is red. Banana is Yellow. " * 800 + "Apple is",
@@ -56,6 +56,13 @@ def get_top_logprobs(logits, k):
      return logprobs


+ def get_token_ids_logprobs(logits, token_ids):
+     logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+     del logits
+     logprobs = logprobs[..., token_ids]
+     return logprobs
+
+
  def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
      from sentence_transformers import SentenceTransformer
      from sentence_transformers.util import is_sentence_transformer_model
@@ -84,8 +91,13 @@ class ModelOutput:
      output_ids: List[int] = None
      top_input_logprobs: List[torch.Tensor] = None
      top_output_logprobs: List[torch.Tensor] = None
+     top_output_logprob_idx: List[List[int]] = None
      embed_logits: List[torch.Tensor] = None
      scores: List[float] = None
+     input_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
+     output_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
+     token_ids_input_logprobs: List[torch.Tensor] = None
+     token_ids_output_logprobs: List[torch.Tensor] = None


  class HFRunner:
@@ -95,9 +107,11 @@ class HFRunner:
          torch_dtype: torch.dtype,
          model_type: str = "generation",
          output_str_only: bool = False,
+         trust_remote_code: bool = False,
      ):
          self.model_type = model_type
          self.output_str_only = output_str_only
+         self.trust_remote_code = trust_remote_code

          self.in_queue = mp.Queue()
          self.out_queue = mp.Queue()
@@ -130,7 +144,7 @@ class HFRunner:
              self.base_model = AutoModelForCausalLM.from_pretrained(
                  model_path,
                  torch_dtype=torch_dtype,
-                 trust_remote_code=False,
+                 trust_remote_code=self.trust_remote_code,
                  low_cpu_mem_usage=True,
              ).cuda()
          elif self.model_type == "embedding":
@@ -147,79 +161,32 @@ class HFRunner:
              ).cuda()
          else:
              raise Exception(f"Unrecognized model type {self.model_type}")
-         self.tokenizer = get_tokenizer(model_path, torch_dtype=torch.dtype)
+         self.tokenizer = get_tokenizer(
+             model_path,
+             torch_dtype=torch.dtype,
+             trust_remote_code=self.trust_remote_code,
+         )

          # Run forward
          while True:
-             prompts, max_new_tokens, lora_paths = in_queue.get()
+             prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get()
              if lora_paths is not None:
                  assert len(prompts) == len(lora_paths)

              if prompts is not None:
                  if self.model_type == "generation":
-                     output_strs = []
-                     top_input_logprobs = []
-                     top_output_logprobs = []
-                     for i, p in enumerate(prompts):
-                         if isinstance(p, str):
-                             input_ids = self.tokenizer.encode(
-                                 p, return_tensors="pt"
-                             ).cuda()
-                         else:
-                             input_ids = torch.tensor([p], device="cuda")
-
-                         if lora_paths is not None and lora_paths[i] is not None:
-                             from peft import PeftModel
-
-                             self.model = PeftModel.from_pretrained(
-                                 self.base_model,
-                                 lora_paths[i],
-                                 torch_dtype=torch_dtype,
-                                 is_trainable=False,
-                             )
-                         else:
-                             self.model = self.base_model
-
-                         outputs = self.model.generate(
-                             input_ids,
-                             do_sample=False,
-                             temperature=None,
-                             top_p=None,
-                             max_new_tokens=max_new_tokens,
-                             return_dict_in_generate=True,
-                             output_scores=(not self.output_str_only),
-                         )
-                         output_strs.append(
-                             self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
-                         )
-                         if not self.output_str_only:
-                             # outputs.scores: (num_token, 1, vocab_size)
-                             top_output_logprobs.append(
-                                 [
-                                     get_top_logprobs(
-                                         logits[0], NUM_TOP_LOGPROBS
-                                     ).tolist()
-                                     for logits in outputs.scores
-                                 ]
-                             )
-                             del outputs
-
-                             input_logits = self.model.forward(input_ids).logits[0]
-                             top_input_logprobs.append(
-                                 get_top_logprobs(
-                                     input_logits, NUM_TOP_LOGPROBS
-                                 ).tolist()
-                             )
-                             del input_logits
-
                      out_queue.put(
-                         ModelOutput(
-                             output_strs=output_strs,
-                             top_input_logprobs=top_input_logprobs,
-                             top_output_logprobs=top_output_logprobs,
+                         self.forward_generation_raw(
+                             base_model=self.base_model,
+                             prompts=prompts,
+                             max_new_tokens=max_new_tokens,
+                             tokenizer=self.tokenizer,
+                             lora_paths=lora_paths,
+                             torch_dtype=torch_dtype,
+                             output_str_only=self.output_str_only,
+                             token_ids_logprob=token_ids_logprob,
                          )
                      )
-
              elif self.model_type == "embedding":
                  assert not self.output_str_only
                  logits = self.model.encode(prompts).tolist()
@@ -244,10 +211,11 @@ class HFRunner:
      def forward(
          self,
          prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-         max_new_tokens=8,
-         lora_paths=None,
+         max_new_tokens: int = 8,
+         lora_paths: Optional[List[str]] = None,
+         token_ids_logprob: Optional[int] = None,
      ):
-         self.in_queue.put((prompts, max_new_tokens, lora_paths))
+         self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob))
          return self.out_queue.get()

      def terminate(self):
@@ -261,6 +229,101 @@ class HFRunner:
          self.model_proc.terminate()
          self.in_queue = self.out_queue = None

+     @staticmethod
+     def forward_generation_raw(
+         base_model,
+         prompts: Union[List[str], List[torch.Tensor]],
+         max_new_tokens: int,
+         tokenizer,
+         torch_dtype: torch.dtype,
+         lora_paths: Optional[List[str]] = None,
+         output_str_only: bool = False,
+         token_ids_logprob: Optional[int] = None,
+     ) -> ModelOutput:
+         output_strs = []
+         top_input_logprobs = []
+         top_output_logprobs = []
+         if token_ids_logprob is not None:
+             token_ids_input_logprobs = []
+             token_ids_output_logprobs = []
+         else:
+             token_ids_input_logprobs = token_ids_output_logprobs = None
+
+         for i, p in enumerate(prompts):
+             if isinstance(p, str):
+                 input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
+             else:
+                 input_ids = torch.tensor([p], device="cuda")
+
+             if lora_paths is not None and lora_paths[i] is not None:
+                 from peft import PeftModel
+
+                 model = PeftModel.from_pretrained(
+                     base_model,
+                     lora_paths[i],
+                     torch_dtype=torch_dtype,
+                     is_trainable=False,
+                 )
+             else:
+                 model = base_model
+
+             outputs = model.generate(
+                 input_ids,
+                 do_sample=False,
+                 temperature=None,
+                 top_p=None,
+                 max_new_tokens=max_new_tokens,
+                 return_dict_in_generate=True,
+                 output_scores=(not output_str_only),
+             )
+
+             text = tokenizer.decode(
+                 outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
+             )
+             # Check if the text is empty or only whitespace.
+             if not text.strip():
+                 raise ValueError(
+                     "Received an empty text response. Please verify your input or model configuration."
+                 )
+             output_strs.append(text)
+
+             if not output_str_only:
+                 # outputs.scores: (num_token, 1, vocab_size)
+                 top_output_logprobs.append(
+                     [
+                         get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
+                         for logits in outputs.scores
+                     ]
+                 )
+                 if token_ids_logprob is not None:
+                     token_ids_output_logprobs.append(
+                         [
+                             get_token_ids_logprobs(
+                                 logits[0], token_ids_logprob
+                             ).tolist()
+                             for logits in outputs.scores
+                         ]
+                     )
+                 del outputs
+
+                 input_logits = model.forward(input_ids).logits[0]
+                 top_input_logprobs.append(
+                     get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
+                 )
+                 if token_ids_logprob is not None:
+                     token_ids_input_logprobs.append(
+                         get_token_ids_logprobs(input_logits, token_ids_logprob).tolist()
+                     )
+                 del input_logits
+
+         return ModelOutput(
+             output_strs=output_strs,
+             top_input_logprobs=top_input_logprobs,
+             top_output_logprobs=top_output_logprobs,
+             token_ids_input_logprobs=token_ids_input_logprobs,
+             token_ids_output_logprobs=token_ids_output_logprobs,
+         )
+

  class SRTRunner:
      def __init__(
@@ -275,72 +338,79 @@ class SRTRunner:
          lora_backend: str = "triton",
          disable_cuda_graph: bool = False,
          disable_radix_cache: bool = False,
+         chunked_prefill_size: Optional[int] = None,
+         dp_size: int = 1,
+         tokenizer_path: Optional[str] = None,
+         enable_ep_moe: bool = False,
+         mem_fraction_static: float = 0.65,
+         trust_remote_code: bool = False,
+         speculative_draft_model_path: Optional[str] = None,
+         speculative_algorithm: Optional[str] = None,
+         speculative_num_steps: Optional[int] = None,
+         speculative_eagle_topk: Optional[int] = None,
+         speculative_num_draft_tokens: Optional[int] = None,
+         disable_overlap_schedule: bool = False,
      ):
          self.model_type = model_type
          self.is_generation = model_type == "generation"
+         enable_dp_attention = dp_size > 1
+
+         spec_kwargs = {}
+         if speculative_draft_model_path:
+             spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
+             spec_kwargs["speculative_algorithm"] = speculative_algorithm
+             spec_kwargs["speculative_num_steps"] = speculative_num_steps
+             spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
+             spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens
+
          self.engine = Engine(
              model_path=model_path,
              tp_size=tp_size,
              dtype=get_dtype_str(torch_dtype),
              port=port,
-             mem_fraction_static=0.65,
-             trust_remote_code=False,
+             mem_fraction_static=mem_fraction_static,
+             trust_remote_code=trust_remote_code,
              is_embedding=not self.is_generation,
              lora_paths=lora_paths,
              max_loras_per_batch=max_loras_per_batch,
              lora_backend=lora_backend,
              disable_cuda_graph=disable_cuda_graph,
              disable_radix_cache=disable_radix_cache,
+             chunked_prefill_size=chunked_prefill_size,
+             enable_dp_attention=enable_dp_attention,
+             dp_size=dp_size,
+             tokenizer_path=tokenizer_path,
+             enable_ep_moe=enable_ep_moe,
+             disable_overlap_schedule=disable_overlap_schedule,
+             cuda_graph_max_bs=4,
+             **spec_kwargs,
          )
-         self.tokenizer = get_tokenizer(model_path)
+
+         if tokenizer_path is None:
+             self.tokenizer = get_tokenizer(
+                 model_path, trust_remote_code=trust_remote_code
+             )
+         else:
+             self.tokenizer = None

      def forward(
          self,
          prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-         max_new_tokens=8,
-         lora_paths=None,
+         max_new_tokens: int = 8,
+         lora_paths: Optional[List[str]] = None,
+         logprob_start_len: int = 0,
+         top_k: Optional[int] = None,
+         token_ids_logprob: Optional[List[int]] = None,
      ):
          if self.is_generation:
-             # the return value contains logprobs from prefill
-             output_strs = []
-             top_input_logprobs = []
-             top_output_logprobs = []
-             sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-             for i, prompt in enumerate(prompts):
-                 response = self.engine.generate(
-                     prompt,
-                     lora_path=lora_paths[i] if lora_paths else None,
-                     sampling_params=sampling_params,
-                     return_logprob=True,
-                     logprob_start_len=0,
-                     top_logprobs_num=NUM_TOP_LOGPROBS,
-                 )
-                 output_strs.append(response["text"])
-                 top_input_logprobs.append(
-                     [
-                         [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
-                         for x in response["meta_info"]["input_top_logprobs"][1:]
-                     ]
-                     + [
-                         [
-                             tup[0]
-                             for tup in response["meta_info"]["output_top_logprobs"][0][
-                                 :NUM_TOP_LOGPROBS
-                             ]
-                         ]
-                     ]
-                 )
-                 top_output_logprobs.append(
-                     [
-                         [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
-                         for x in response["meta_info"]["output_top_logprobs"]
-                     ]
-                 )
-
-             return ModelOutput(
-                 output_strs=output_strs,
-                 top_input_logprobs=top_input_logprobs,
-                 top_output_logprobs=top_output_logprobs,
+             return self.forward_generation_raw(
+                 engine=self.engine,
+                 prompts=prompts,
+                 max_new_tokens=max_new_tokens,
+                 lora_paths=lora_paths,
+                 logprob_start_len=logprob_start_len,
+                 top_k=top_k,
+                 token_ids_logprob=token_ids_logprob,
              )
          else:
              response = self.engine.encode(prompts)
@@ -362,18 +432,11 @@ class SRTRunner:
          only return output strings and no logprobs
          """
          if self.is_generation:
-             # the return value contains logprobs from prefill
-             output_strs = []
-             sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-             response = self.engine.generate(
-                 prompts,
-                 lora_path=lora_paths if lora_paths else None,
-                 sampling_params=sampling_params,
-             )
-             output_strs = [r["text"] for r in response]
-
-             return ModelOutput(
-                 output_strs=output_strs,
+             return self.batch_forward_generation_raw(
+                 engine=self.engine,
+                 prompts=prompts,
+                 max_new_tokens=max_new_tokens,
+                 lora_paths=lora_paths,
              )
          else:
              response = self.engine.encode(prompts)
@@ -391,6 +454,157 @@ class SRTRunner:
          self.engine.shutdown()
          del self.engine

+     @staticmethod
+     def forward_generation_raw(
+         engine: Engine,
+         prompts: Union[List[str], List[torch.Tensor]],
+         max_new_tokens: int = 8,
+         lora_paths: Optional[List[str]] = None,
+         logprob_start_len: int = 0,
+         top_k: Optional[int] = None,
+         token_ids_logprob: Optional[List[int]] = None,
+     ):
+         # the return value contains logprobs from prefill
+         output_strs = []
+         output_ids = []
+         # Input logprobs. Note that the last item in input logprob is equivalent to
+         # the first item in the output logprob.
+         top_input_logprobs = []
+         input_token_logprobs_lst = []
+         top_output_logprobs = []
+         output_token_logprobs_lst = []
+         top_output_logprob_idx = []
+         if token_ids_logprob is not None:
+             token_ids_input_logprobs = []
+             token_ids_output_logprobs = []
+         else:
+             token_ids_input_logprobs = token_ids_output_logprobs = None
+
+         sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+         if top_k:
+             sampling_params["top_k"] = top_k
+
+         for i, prompt in enumerate(prompts):
+             response = engine.generate(
+                 prompt,
+                 lora_path=lora_paths[i] if lora_paths else None,
+                 sampling_params=sampling_params,
+                 return_logprob=True,
+                 logprob_start_len=logprob_start_len,
+                 top_logprobs_num=NUM_TOP_LOGPROBS,
+                 token_ids_logprob=token_ids_logprob,
+             )
+             text = response["text"]
+
+             # Check if the text is empty or only whitespace.
+             if not text.strip():
+                 raise ValueError(
+                     "Received an empty text response. Please verify your input or model configuration."
+                 )
+             output_strs.append(text)
+             # output_ids.append(response["output_ids"])
+
+             input_token_logprobs = response["meta_info"]["input_token_logprobs"]
+             output_token_logprobs = response["meta_info"]["output_token_logprobs"]
+             # print(i, input_token_logprobs)
+             # print(i, output_token_logprobs)
+             logprobs = response["meta_info"]["input_top_logprobs"]
+             if token_ids_logprob is not None:
+                 input_token_ids_logprobs = response["meta_info"][
+                     "input_token_ids_logprobs"
+                 ][1:]
+             else:
+                 input_token_ids_logprobs = None
+
+             num_prompt_tokens = response["meta_info"]["prompt_tokens"]
+             assert len(input_token_logprobs) == num_prompt_tokens - logprob_start_len
+             assert len(logprobs) == num_prompt_tokens - logprob_start_len
+
+             # The first token logprob has no meaning in sglang.
+             input_token_logprobs = input_token_logprobs[1:]
+             logprobs = logprobs[1:]
+             assert len(input_token_logprobs) == len(logprobs)
+
+             input_token_logprobs_lst.append(
+                 input_token_logprobs + [output_token_logprobs[0]]
+             )
+             output_token_logprobs_lst.append(output_token_logprobs)
+
+             top_input_logprobs.append(
+                 [[tup[0] for tup in x[:NUM_TOP_LOGPROBS]] for x in logprobs]
+                 + [
+                     [
+                         tup[0]
+                         for tup in response["meta_info"]["output_top_logprobs"][0][
+                             :NUM_TOP_LOGPROBS
+                         ]
+                     ]
+                 ]
+             )
+             top_output_logprobs.append(
+                 [
+                     [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                     for x in response["meta_info"]["output_top_logprobs"]
+                 ]
+             )
+             top_output_logprob_idx.append(
+                 [
+                     [tup[1] for tup in x[:NUM_TOP_LOGPROBS]]
+                     for x in response["meta_info"]["output_top_logprobs"]
+                 ]
+             )
+             if token_ids_logprob is not None:
+                 token_ids_input_logprobs.append(
+                     [[tup[0] for tup in x] for x in input_token_ids_logprobs]
+                     + [
+                         [
+                             tup[0]
+                             for tup in response["meta_info"][
+                                 "output_token_ids_logprobs"
+                             ][0]
+                         ]
+                     ]
+                 )
+                 token_ids_output_logprobs.append(
+                     [
+                         [tup[0] for tup in x]
+                         for x in response["meta_info"]["output_token_ids_logprobs"]
+                     ]
+                 )
+
+         return ModelOutput(
+             output_strs=output_strs,
+             output_ids=output_ids,
+             top_input_logprobs=top_input_logprobs,
+             top_output_logprobs=top_output_logprobs,
+             input_token_logprobs_lst=input_token_logprobs_lst,
+             output_token_logprobs_lst=output_token_logprobs_lst,
+             top_output_logprob_idx=top_output_logprob_idx,
+             token_ids_input_logprobs=token_ids_input_logprobs,
+             token_ids_output_logprobs=token_ids_output_logprobs,
+         )
+
+     @staticmethod
+     def batch_forward_generation_raw(
+         prompts: Union[List[str], List[torch.Tensor]],
+         max_new_tokens,
+         lora_paths,
+         engine,
+     ):
+         # the return value contains logprobs from prefill
+         output_strs = []
+         sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+         response = engine.generate(
+             prompts,
+             lora_path=lora_paths if lora_paths else None,
+             sampling_params=sampling_params,
+         )
+         output_strs = [r["text"] for r in response]
+
+         return ModelOutput(
+             output_strs=output_strs,
+         )
+

  def monkey_patch_gemma2_sdpa():
      """
@@ -405,3 +619,52 @@ def monkey_patch_gemma2_sdpa():
          return config

      setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)
+
+
+ def check_close_model_outputs(
+     hf_outputs: ModelOutput,
+     srt_outputs: ModelOutput,
+     prefill_tolerance: float,
+     decode_tolerance: float,
+     rouge_l_tolerance: float,
+     debug_text: str = "",
+     check_logprobs: bool = True,
+ ):
+     # Compare output strings
+     print(f"{hf_outputs.output_strs=}")
+     print(f"{srt_outputs.output_strs=}")
+     rouge_l_scores = calculate_rouge_l(hf_outputs.output_strs, srt_outputs.output_strs)
+     print(f"{rouge_l_scores=}")
+     assert all(
+         score >= rouge_l_tolerance for score in rouge_l_scores
+     ), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
+
+     if check_logprobs:
+         for i in range(len(hf_outputs.output_strs)):
+             # Compare input logprobs
+             hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
+             srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
+             input_len = hf_logprobs.shape[0]
+             print(
+                 "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
+             )
+             if input_len <= 100:
+                 assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
+                     f"prefill logprobs are not all close with {debug_text} "
+                     f"prefill_tolerance={prefill_tolerance}."
+                     f"{hf_logprobs=}, {srt_logprobs=}"
+                 )
+
+             # Compare output logprobs
+             hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
+             srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
+
+             print(
+                 "decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
+             )
+             if input_len <= 100:
+                 assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
+                     f"decode logprobs are not all close with {debug_text} "
+                     f"decode_tolerance={decode_tolerance}."
+                     f"{hf_logprobs=}, {srt_logprobs=}"
+                 )
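
For context on how these refactored test helpers fit together: HFRunner and SRTRunner now expose matching forward() signatures (including the new token_ids_logprob argument), and the new check_close_model_outputs compares their outputs by ROUGE-L score and by prefill/decode logprob differences. The following is a minimal usage sketch based on the signatures visible in this diff; the model path, prompt, and tolerance values are illustrative assumptions, not values taken from the package, and it assumes both runners support the context-manager protocol used elsewhere in sglang's tests.

import torch

from sglang.test.runners import HFRunner, SRTRunner, check_close_model_outputs

MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"  # assumed example model, not from the diff
PROMPTS = ["The capital of France is"]  # assumed example prompt

# Greedy-decode the prompts with the HuggingFace reference runner.
with HFRunner(MODEL_PATH, torch_dtype=torch.float16, model_type="generation") as hf_runner:
    hf_outputs = hf_runner.forward(PROMPTS, max_new_tokens=8)

# Decode the same prompts with the SGLang runtime engine.
with SRTRunner(MODEL_PATH, torch_dtype=torch.float16, model_type="generation") as srt_runner:
    srt_outputs = srt_runner.forward(PROMPTS, max_new_tokens=8)

# Compare output strings (ROUGE-L) and top logprobs; tolerances here are illustrative.
check_close_model_outputs(
    hf_outputs,
    srt_outputs,
    prefill_tolerance=0.2,
    decode_tolerance=0.2,
    rouge_l_tolerance=1.0,
)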