sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -1,8 +1,10 @@
 import logging
+import os
 import time
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
+from huggingface_hub import snapshot_download

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
@@ -20,11 +22,13 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
 from sglang.srt.speculative.eagle_utils import (
     EagleDraftInput,
     EagleVerifyInput,
+    EagleVerifyOutput,
     assign_draft_cache_locs,
     fast_topk,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import get_available_gpu_memory

 logger = logging.getLogger(__name__)

@@ -40,10 +44,31 @@ class EAGLEWorker(TpModelWorker):
         nccl_port: int,
         target_worker: TpModelWorker,
     ):
+        # Override context length with target model's context length
+        server_args.context_length = target_worker.model_runner.model_config.context_len
+        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
+
         # Do not capture cuda graph in `super().__init__()`
         # We will capture it later
         backup_disable_cuda_graph = server_args.disable_cuda_graph
         server_args.disable_cuda_graph = True
+
+        # Lossy optimization by using hot tokens
+        if server_args.speculative_token_map is not None:
+            self.hot_token_id = load_token_map(server_args.speculative_token_map)
+            server_args.json_model_override_args = (
+                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
+            )
+        else:
+            self.hot_token_id = None
+
+        # We share the allocator with a target worker. Draft/target worker
+        # owns its own KV cache.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
+
+        # Init target worker
         super().__init__(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -51,9 +76,10 @@ class EAGLEWorker(TpModelWorker):
             nccl_port=nccl_port,
             dp_rank=dp_rank,
             is_draft_worker=True,
+            req_to_token_pool=self.req_to_token_pool,
+            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
         )
         self.target_worker = target_worker
-        self.finish_extend_len = []

         # Parse arguments
         self.topk = server_args.speculative_eagle_topk
@@ -62,12 +88,20 @@ class EAGLEWorker(TpModelWorker):
             server_args.speculative_algorithm
         )
         self.server_args = server_args
+        self.use_nan_detection = self.server_args.enable_nan_detection
+        self.device = self.model_runner.device
+        self.gpu_id = self.model_runner.gpu_id

         # Share the embedding and lm_head
-        if not self.speculative_algorithm.is_nextn():
-            embed, head = self.target_worker.model_runner.model.get_embed_and_head()
-            self.model_runner.model.set_embed_and_head(embed, head)
-        self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
+        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
+        if self.hot_token_id is not None:
+            head = head.clone()
+            self.hot_token_id = self.hot_token_id.to(head.device)
+            head.data = head.data[self.hot_token_id]
+        self.draft_model_runner.model.set_embed_and_head(embed, head)
+        self.draft_model_runner.server_args.disable_cuda_graph = (
+            backup_disable_cuda_graph
+        )

         # Create multi-step attn backends and cuda graph runners
         if server_args.attention_backend == "flashinfer":
@@ -95,7 +129,7 @@ class EAGLEWorker(TpModelWorker):
                 f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
             )

-        self.model_runner.draft_attn_backend = self.draft_attn_backend
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
         self.init_cuda_graphs()

     def init_cuda_graphs(self):
@@ -106,55 +140,81 @@
             return

         tic = time.time()
-        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+        logger.info(
+            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
-        logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
+        logger.info(
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )

-    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
+    @property
+    def draft_model_runner(self):
+        return self.model_runner
+
+    def forward_batch_speculative_generation(
+        self, batch: ScheduleBatch
+    ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
+        """Run speculative decoding forward.
+
+        NOTE: Many states of batch is modified as you go through. It is not guaranteed
+        the final output batch doesn't have the same state as the input.
+
+        Args:
+            batch: The batch to run forward. The state of the batch is modified as it runs.
+        Returns:
+            A tuple of the final logit output of the target model, next tokens accepeted,
+            the batch id (used for overlap schedule), and number of accepeted tokens.
+        """
+        assert not batch.spec_algorithm.is_none()
         if batch.forward_mode.is_decode():
-            # Draft
-            spec_info: EagleVerifyInput = self.draft(batch)
-
-            # Verify
-            (
-                next_draft_input,
-                logits_output,
-                verified_id,
-                self.finish_extend_len,
-                accept_length_cpu,
-                model_worker_batch,
-            ) = self.verify(batch, spec_info)
-            batch.spec_info = next_draft_input
-            # if it is None, means all requsets are finished
+            spec_info, to_free_cache_loc = self.draft(batch)
+            logits_output, verify_output, model_worker_batch = self.verify(
+                batch, spec_info
+            )
+            # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
+            self.token_to_kv_pool_allocator.free(to_free_cache_loc)
+            # if it is None, means all requests are finished
             if batch.spec_info.verified_id is not None:
                 self.forward_draft_extend_after_decode(batch)
+
             return (
                 logits_output,
-                verified_id,
-                model_worker_batch,
-                sum(accept_length_cpu),
+                verify_output.verified_id,
+                model_worker_batch.bid,
+                sum(verify_output.accept_length_per_req_cpu),
            )

         else:
-            # Forward with the target model and get hidden states.
-            # We need the full hidden states to prefill the KV cache of the draft model.
-            model_worker_batch = batch.get_model_worker_batch()
-            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
-                model_worker_batch
-            )
-
-            # Forward with the draft model.
-            batch.spec_info = EagleDraftInput(
-                hidden_states=logits_output.hidden_states,
-                verified_id=next_token_ids,
+            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
+            self.forward_draft_extend(
+                batch, logits_output.hidden_states, next_token_ids
             )
-            self.forward_draft_extend(batch)
-            return logits_output, next_token_ids, model_worker_batch, 0
+            return logits_output, next_token_ids, bid, 0
+
+    def forward_target_extend(
+        self, batch: ScheduleBatch
+    ) -> Tuple[LogitsProcessorOutput, List[int], int]:
+        """Run the target extend.
+
+        Args:
+            batch: The batch to run. States could be modified.
+
+        Returns:
+            logits_output: The output of logits. It will contain the full hidden states.
+            next_token_ids: Next token ids generated.
+            bid: The model batch ID. Used for overlap schedule.
+        """
+        # Forward with the target model and get hidden states.
+        # We need the full hidden states to prefill the KV cache of the draft model.
+        model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
+        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+            model_worker_batch
+        )
+        return logits_output, next_token_ids, model_worker_batch.bid

     def draft(self, batch: ScheduleBatch):
-        self._set_mem_pool(batch, self.model_runner)
-
         # Parse args
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
@@ -172,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
             self.topk,
             self.speculative_num_steps,
         )
-
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
@@ -180,11 +239,12 @@
         # Get forward batch
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
         can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
             forward_batch
         )
-
         if can_cuda_graph:
             score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                 forward_batch
@@ -192,7 +252,9 @@
         else:
             # Initialize attention backend
             self.draft_attn_backend.init_forward_metadata(forward_batch)
-
+            forward_batch = ForwardBatch.init_new(
+                model_worker_batch, self.draft_model_runner
+            )
             # Run forward steps
             score_list, token_list, parents_list = self.draft_forward(forward_batch)

@@ -209,10 +271,7 @@
             batch.sampling_info.is_all_greedy,
         )

-        # Free cache locations
-        batch.token_to_kv_pool.free(out_cache_loc)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
-        return ret
+        return ret, out_cache_loc

     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
@@ -223,6 +282,8 @@
             spec_info.topk_index,
             spec_info.hidden_states,
         )
+        if self.hot_token_id is not None:
+            topk_index = self.hot_token_id[topk_index]

         # Return values
         score_list: List[torch.Tensor] = []
@@ -260,8 +321,11 @@
             logits_output = self.model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
+            self._detect_nan_if_needed(logits_output)
             probs = torch.softmax(logits_output.next_token_logits, dim=-1)
             topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
+            if self.hot_token_id is not None:
+                topk_index = self.hot_token_id[topk_index]
             hidden_states = logits_output.hidden_states

         return score_list, token_list, parents_list
@@ -274,68 +338,96 @@
         logits_output, _ = self.target_worker.forward_batch_generation(
             model_worker_batch, skip_sample=True
         )
+        self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
-        res = spec_info.verify(batch, logits_output)
+        res: EagleVerifyOutput = spec_info.verify(
+            batch, logits_output, self.token_to_kv_pool_allocator
+        )
+
+        # Post process based on verified outputs.
+        # Pick indices that we care (accepeted)
+        logits_output.next_token_logits = logits_output.next_token_logits[
+            res.accepeted_indices_cpu
+        ]
+        logits_output.hidden_states = logits_output.hidden_states[
+            res.accepeted_indices_cpu
+        ]
+        # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
-        return res + (model_worker_batch,)
+        batch.spec_info = res.draft_input
+
+        return logits_output, res, model_worker_batch

-    def forward_draft_extend(self, batch: ScheduleBatch):
-        self._set_mem_pool(batch, self.model_runner)
+    def forward_draft_extend(
+        self,
+        batch: ScheduleBatch,
+        hidden_states: torch.Tensor,
+        next_token_ids: List[int],
+    ):
+        """Run draft model extend. This API modifies the states of the batch.
+
+        Args:
+            batch: The batch to run.
+            hidden_states: Hidden states from the target model forward
+            next_token_ids: Next token ids generated from the target forward.
+        """
+        batch.spec_info = EagleDraftInput(
+            hidden_states=hidden_states,
+            verified_id=next_token_ids,
+        )
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        self.capture_for_decode(logits_output, forward_batch)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
-
-    def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
-        batch.token_to_kv_pool = runner.token_to_kv_pool
-        batch.req_to_token_pool = runner.req_to_token_pool
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
+        logits_output = self.draft_model_runner.forward(forward_batch)
+        self._detect_nan_if_needed(logits_output)
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info is batch.spec_info
+        self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         seq_lens_backup = batch.seq_lens
-        req_pool_indices_backup = batch.req_pool_indices
-
-        self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
         batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        # We don't need logprob for this extend.
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
-        self.capture_for_decode(logits_output, forward_batch)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
+        logits_output = self.draft_model_runner.forward(forward_batch)
+        self._detect_nan_if_needed(logits_output)
+        assert forward_batch.spec_info is batch.spec_info
+        self.capture_for_decode(logits_output, forward_batch.spec_info)

         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
-        batch.req_pool_indices = req_pool_indices_backup

     def capture_for_decode(
-        self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
+        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
     ):
         probs = torch.softmax(logits_output.next_token_logits, dim=-1)
-        spec_info = forward_batch.spec_info
-        spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
-        spec_info.hidden_states = logits_output.hidden_states
-
-    # Don't support prefix share now.
-    def finish_request(self, reqs: Union[Req, List[Req]]):
-        if not isinstance(reqs, List):
-            reqs = [reqs]
-        for req in reqs:
-            if req.rid not in self.finish_extend_len:
-                continue
-            req_len = (
-                len(req.origin_input_ids)
-                + len(req.output_ids)
-                - self.finish_extend_len[req.rid]
-                - 1
-            )
-            kv_indices = self.model_runner.req_to_token_pool.req_to_token[
-                req.req_pool_idx
-            ][:req_len]
-            self.model_runner.token_to_kv_pool.free(kv_indices)
-            self.model_runner.req_to_token_pool.free(req.req_pool_idx)
+        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
+        draft_input.hidden_states = logits_output.hidden_states
+
+    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
+        if self.use_nan_detection:
+            logits = logits_output.next_token_logits
+            if torch.any(torch.isnan(logits)):
+                logger.warning("Detected errors during sampling! NaN in the logits.")
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
+
+
+def load_token_map(token_map_path: str) -> List[int]:
+    if not os.path.exists(token_map_path):
+        cache_dir = snapshot_download(
+            os.path.dirname(token_map_path),
+            ignore_patterns=["*.bin", "*.safetensors"],
+        )
+        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
+    hot_token_id = torch.load(token_map_path)
+    return torch.tensor(hot_token_id, dtype=torch.int32)
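The hot-token path added above is optional: when server_args.speculative_token_map is set, load_token_map resolves the path (downloading the containing Hugging Face repo when the file is not present locally), loads a list of token ids with torch.load, and the draft lm_head is sliced down to that reduced vocabulary. A minimal sketch of producing a compatible token-map file follows; the file name and the example ids are illustrative, and any Python list of ints saved with torch.save would work the same way.

    import torch

    # Illustrative only: ids of tokens the draft model should keep in its lm_head.
    hot_token_id = [1, 2, 13, 278, 29871, 29889]
    torch.save(hot_token_id, "hot_tokens.pt")  # hypothetical local path

    # Pointing server_args.speculative_token_map at this file makes EAGLEWorker
    # call load_token_map("hot_tokens.pt") and index the cloned head with it.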
sglang/srt/speculative/spec_info.py CHANGED
@@ -5,30 +5,18 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()

-    # NEXTN spec decoding is for DeepSeek V3/R1
-    # currently it's implemented based on EAGLE
-    NEXTN = auto()
-
     def is_none(self):
         return self == SpeculativeAlgorithm.NONE

     def is_eagle(self):
-        return self == SpeculativeAlgorithm.EAGLE or self == SpeculativeAlgorithm.NEXTN
-
-    def is_nextn(self):
-        return self == SpeculativeAlgorithm.NEXTN
+        return self == SpeculativeAlgorithm.EAGLE

     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
-            "NEXTN": SpeculativeAlgorithm.NEXTN,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
             name = name.upper()
         return name_map[name]
-
-
-class SpecInfo:
-    pass
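With NEXTN removed from the enum, the lookup in from_string only recognizes "EAGLE" (case-insensitive) and None. A small illustrative check of the resulting behavior, assuming the import path shown in the diff above:

    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

    algo = SpeculativeAlgorithm.from_string("eagle")  # upper-cased to "EAGLE"
    assert algo.is_eagle() and not algo.is_none()

    none = SpeculativeAlgorithm.from_string(None)  # maps to SpeculativeAlgorithm.NONE
    assert none.is_none()

    # "NEXTN" is no longer a key in name_map, so from_string("nextn") now raises KeyError.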
sglang/srt/utils.py CHANGED
@@ -32,13 +32,15 @@ import socket
 import subprocess
 import sys
 import tempfile
+import threading
 import time
 import warnings
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
+from multiprocessing import Pool
 from multiprocessing.reduction import ForkingPickler
-from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union

 import numpy as np
 import psutil
@@ -311,7 +313,7 @@ def make_layers(
     """Make a list of layers with the given layer function"""
     modules = torch.nn.ModuleList(
         [
-            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}"))
+            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
             for idx in range(num_hidden_layers)
         ]
     )
@@ -480,6 +482,10 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):

 def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
     """Kill the process and all its child processes."""
+    # Remove sigchld handler to avoid spammy logs.
+    if threading.current_thread() is threading.main_thread():
+        signal.signal(signal.SIGCHLD, signal.SIG_DFL)
+
     if parent_pid is None:
         parent_pid = os.getpid()
         include_parent = False
@@ -735,13 +741,6 @@ def pytorch_profile(name, func, *args, data_size=-1):
     return result


-def first_rank_print(*args, **kwargs):
-    if torch.cuda.current_device() == 0:
-        print(*args, **kwargs)
-    else:
-        pass
-
-
 def get_zmq_socket(
     context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
 ):
@@ -1154,9 +1153,9 @@ def set_gpu_proc_affinity(

     if psutil.cpu_count() != psutil.cpu_count(logical=False):
         # HT on
-        upper_cpu_ids = [id for id in range(start_cpu_id, end_cpu_id)]
-        lower_cpu_ids = [id + total_pcores for id in range(start_cpu_id, end_cpu_id)]
-        bind_cpu_ids = list(itertools.chain(upper_cpu_ids, lower_cpu_ids))
+        lower_cpu_ids = [id for id in range(start_cpu_id, end_cpu_id)]
+        upper_cpu_ids = [id + total_pcores for id in range(start_cpu_id, end_cpu_id)]
+        bind_cpu_ids = list(itertools.chain(lower_cpu_ids, upper_cpu_ids))
     else:
         # HT off
         bind_cpu_ids = [id for id in range(start_cpu_id, end_cpu_id)]
@@ -1171,6 +1170,11 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
     return value.lower() in ("true", "1")


+@lru_cache(maxsize=2)
+def disable_request_logging() -> bool:
+    return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
+
+
 @lru_cache(maxsize=8)
 def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
     # Note: cuda_visible_devices is not used, but we keep it as an argument for
@@ -1212,7 +1216,11 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))


-def dataclass_to_string_truncated(data, max_length=2048):
+def dataclass_to_string_truncated(
+    data, max_length=2048, skip_names: Optional[Set[str]] = None
+):
+    if skip_names is None:
+        skip_names = set()
     if isinstance(data, str):
         if len(data) > max_length:
             half_length = max_length // 2
@@ -1231,6 +1239,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
             + ", ".join(
                 f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
                 for k, v in data.items()
+                if k not in skip_names
             )
             + "}"
         )
@@ -1241,6 +1250,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
             + ", ".join(
                 f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
                 for f in fields
+                if f.name not in skip_names
             )
             + ")"
         )
@@ -1289,7 +1299,7 @@ def debug_timing(func):
             tic.record()
             result = func(*args, **kwargs)
             toc.record()
-            torch.cuda.synchronize()  # Ensure all CUDA operations are complete
+            toc.synchronize()  # Wait for the function to complete without synchronizing all ops on the GPU
             elapsed = tic.elapsed_time(toc)
             indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
             num_tokens = len(indices) if indices is not None else 0
@@ -1319,9 +1329,9 @@ def pyspy_dump_schedulers():
         result = subprocess.run(
             cmd, shell=True, capture_output=True, text=True, check=True
         )
-        logger.info(f"Profile for PID {pid}:\n{result.stdout}")
+        logger.error(f"Pyspy dump for PID {pid}:\n{result.stdout}")
     except subprocess.CalledProcessError as e:
-        logger.info(f"Failed to profile PID {pid}. Error: {e.stderr}")
+        logger.error(f"Pyspy failed to dump PID {pid}. Error: {e.stderr}")


 def kill_itself_when_parent_died():
@@ -1383,7 +1393,6 @@ def get_ip() -> str:


 def get_open_port() -> int:
-
     port = os.getenv("SGLANG_PORT")
     if port is not None:
         while True:
@@ -1446,8 +1455,25 @@ def launch_dummy_health_check_server(host, port):
     )


+def create_checksum(directory: str):
+    raise NotImplementedError()
+
+
 def set_cuda_arch():
     if is_flashinfer_available():
         capability = torch.cuda.get_device_capability()
         arch = f"{capability[0]}.{capability[1]}"
         os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"
+
+
+def add_prefix(name: str, prefix: str) -> str:
+    """Add a weight path prefix to a module name.
+
+    Args:
+        name: base module name.
+        prefix: weight prefix str to added to the front of `name` concatenated with `.`.
+
+    Returns:
+        The string `prefix.name` if prefix is non-empty, otherwise just `name`.
+    """
+    return name if not prefix else f"{prefix}.{name}"
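The add_prefix helper added at the end of utils.py replaces the inline f"{prefix}.{idx}" pattern that make_layers used before. A couple of illustrative calls; the module path strings here are made up for the example:

    from sglang.srt.utils import add_prefix

    add_prefix("0", "model.layers")  # -> "model.layers.0"
    add_prefix("lm_head", "")        # -> "lm_head" (no leading dot when the prefix is empty)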
sglang/srt/warmup.py ADDED
@@ -0,0 +1,47 @@
+import logging
+from typing import List
+
+import numpy as np
+import tqdm
+
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
+logger = logging.getLogger(__file__)
+
+_warmup_registry = {}
+
+
+def warmup(name: str) -> callable:
+    def decorator(fn: callable):
+        _warmup_registry[name] = fn
+        return fn
+
+    return decorator
+
+
+async def execute_warmups(warmup_names: List[str], tokenizer_manager: TokenizerManager):
+    for warmup_name in warmup_names:
+        if warmup_name not in _warmup_registry:
+            logger.warning(f"Could not find custom warmup {warmup_name}")
+            continue
+        logger.info(f"Running warmup {warmup_name}")
+        await _warmup_registry[warmup_name](tokenizer_manager)
+
+
+@warmup("voice_chat")
+async def voice_chat(tokenizer_manager: TokenizerManager):
+    # this warms up the fused_moe triton kernels and caches them
+    # if we don't do this we break real time inference for voice chat
+    for i in tqdm.trange(1, 512):
+        size = i * 4
+        generate_req_input = GenerateReqInput(
+            input_ids=(np.random.randint(2**16, size=[size])).tolist(),
+            sampling_params={
+                "max_new_tokens": 30,
+                "temperature": 0.8,
+                "stop_token_ids": [1],
+                "min_p": 0.0,
+            },
+        )
+        await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
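The registry in the new warmup.py lets startup code run named warmup coroutines against the TokenizerManager through execute_warmups. A hedged sketch of registering a custom warmup; the name "short_prompt" and the request body are illustrative, while the decorator and the generate_request call pattern come from the file above:

    from sglang.srt.managers.io_struct import GenerateReqInput
    from sglang.srt.warmup import warmup


    @warmup("short_prompt")
    async def short_prompt(tokenizer_manager):
        # One tiny generation to trigger kernel compilation and caching.
        req = GenerateReqInput(
            text="Hello",
            sampling_params={"max_new_tokens": 8, "temperature": 0.0},
        )
        await tokenizer_manager.generate_request(req, None).__anext__()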
sglang/test/few_shot_gsm8k.py CHANGED
@@ -93,9 +93,11 @@ def run_eval(args):
     tic = time.time()
     states = few_shot_gsm8k.run_batch(
         arguments,
-        temperature=0,
+        temperature=args.temperature if hasattr(args, "temperature") else 0,
         num_threads=args.parallel,
         progress_bar=True,
+        return_logprob=getattr(args, "return_logprob", None),
+        logprob_start_len=getattr(args, "logprob_start_len", None),
     )
     latency = time.time() - tic

@@ -141,5 +143,6 @@ if __name__ == "__main__":
     parser.add_argument("--parallel", type=int, default=128)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)
+    parser.add_argument("--temperature", type=float, default=0.0)
     args = parser.parse_args()
     run_eval(args)