sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
Diff excerpts (three of the 205 changed files):

sglang/srt/managers/session_controller.py

@@ -35,12 +35,12 @@ class SessionReqNode:
         for req_node in self.childs:
             req_node.clear(req_dict)

-        if self.req.finished_reason == None:
+        if self.req.finished_reason is None:
             self.req.to_abort = True
         del req_dict[self.req.rid]

     def abort(self):
-        if self.req.finished_reason == None:
+        if self.req.finished_reason is None:
             self.req.to_abort = True

     def __str__(self):
@@ -132,6 +132,10 @@ class Session:
             lora_path=req.lora_path,
             session_id=self.session_id,
             custom_logit_processor=req.custom_logit_processor,
+            stream=req.stream,
+            return_logprob=req.return_logprob,
+            top_logprobs_num=req.top_logprobs_num,
+            token_ids_logprob=req.token_ids_logprob,
         )
         if last_req is not None:
             new_req.image_inputs = last_req.image_inputs
sglang/srt/managers/tokenizer_manager.py

@@ -16,6 +16,7 @@
 import asyncio
 import copy
 import dataclasses
+import json
 import logging
 import os
 import pickle
@@ -24,9 +25,21 @@ import sys
 import threading
 import time
 import uuid
+from collections import deque
 from datetime import datetime
 from http import HTTPStatus
-from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Awaitable,
+    Deque,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+)

 import fastapi
 import uvloop
@@ -44,6 +57,7 @@ from sglang.srt.managers.image_processor import (
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
+    BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
     CloseSessionReqInput,
@@ -51,13 +65,18 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
+    GetInternalStateReq,
+    GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
     GetWeightsByNameReqOutput,
+    HealthCheckOutput,
     InitWeightsUpdateGroupReqInput,
     InitWeightsUpdateGroupReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ProfileReqOutput,
+    ProfileReqType,
     ReleaseMemoryOccupationReqInput,
     ReleaseMemoryOccupationReqOutput,
     ResumeMemoryOccupationReqInput,
@@ -98,7 +117,10 @@ class ReqState:

     # For metrics
     created_time: float
-    first_token_time: Optional[float] = None
+    finished_time: float = 0.0
+    first_token_time: float = 0.0
+    last_time: float = 0.0
+    last_completion_tokens: int = 1

     # For streaming output
     last_output_offset: int = 0
@@ -113,11 +135,10 @@ class TokenizerManager:
         port_args: PortArgs,
     ):
         # Parse args
-
         self.server_args = server_args
         self.enable_metrics = server_args.enable_metrics
         self.log_requests = server_args.log_requests
-        self.log_requests_level = 0
+        self.log_requests_level = server_args.log_requests_level

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
@@ -143,6 +164,7 @@ class TokenizerManager:
         )

         self.is_generation = self.model_config.is_generation
+        self.is_image_gen = self.model_config.is_image_gen
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id

@@ -178,9 +200,12 @@ class TokenizerManager:
         # Store states
         self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.gracefully_exit = False
+        self.last_receive_tstamp = 0
         self.dump_requests_folder = ""  # By default do not dump
         self.dump_requests_threshold = 1000
         self.dump_request_list: List[Tuple] = []
+        self.log_request_metadata = self.get_log_request_metadata()

         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
@@ -192,8 +217,19 @@ class TokenizerManager:
         # For session info
         self.session_futures = {}  # session_id -> asyncio event

-        # Others
-        self.gracefully_exit = False
+        # Set after scheduler is initialized
+        self.max_req_input_len = None
+
+        # Metrics
+        if self.enable_metrics:
+            self.metrics_collector = TokenizerMetricsCollector(
+                labels={
+                    "model_name": self.server_args.served_model_name,
+                    # TODO: Add lora name/path in the future,
+                },
+            )
+
+        # Communicators
         self.init_weights_update_group_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
@@ -212,22 +248,23 @@ class TokenizerManager:
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        # Set after scheduler is initialized
-        self.max_req_input_len = None
-
-        # Metrics
-        if self.enable_metrics:
-            self.metrics_collector = TokenizerMetricsCollector(
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
-            )
+        self.start_profile_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
+        self.get_internal_state_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )

         self._result_dispatcher = TypeBasedDispatcher(
             [
                 (
-                    (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut),
+                    (
+                        BatchStrOut,
+                        BatchEmbeddingOut,
+                        BatchTokenIDOut,
+                        BatchMultimodalOut,
+                    ),
                     self._handle_batch_output,
                 ),
                 (OpenSessionReqOutput, self._handle_open_session_req_output),
@@ -259,6 +296,15 @@ class TokenizerManager:
                     ResumeMemoryOccupationReqOutput,
                     self.resume_memory_occupation_communicator.handle_recv,
                 ),
+                (
+                    ProfileReqOutput,
+                    self.start_profile_communicator.handle_recv,
+                ),
+                (
+                    GetInternalStateReqOutput,
+                    self.get_internal_state_communicator.handle_recv,
+                ),
+                (HealthCheckOutput, lambda x: None),
             ]
         )

@@ -280,9 +326,9 @@ class TokenizerManager:
         obj.normalize_batch_and_arguments()

         if self.log_requests:
-            max_length = 2048 if self.log_requests_level == 0 else 1 << 30
+            max_length, skip_names, _ = self.log_request_metadata
             logger.info(
-                f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}"
+                f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
             )

         async with self.model_update_lock.reader_lock:
@@ -336,6 +382,7 @@ class TokenizerManager:
         return_logprob = obj.return_logprob
         logprob_start_len = obj.logprob_start_len
         top_logprobs_num = obj.top_logprobs_num
+        token_ids_logprob = obj.token_ids_logprob
         session_params = (
             SessionParams(**obj.session_params) if obj.session_params else None
         )
@@ -378,11 +425,13 @@ class TokenizerManager:
                 return_logprob,
                 logprob_start_len,
                 top_logprobs_num,
+                token_ids_logprob,
                 obj.stream,
                 lora_path=obj.lora_path,
                 input_embeds=input_embeds,
                 session_params=session_params,
                 custom_logit_processor=obj.custom_logit_processor,
+                return_hidden_states=obj.return_hidden_states,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -400,8 +449,7 @@ class TokenizerManager:
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
-        event = asyncio.Event()
-        state = ReqState([], False, event, obj, created_time=created_time)
+        state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
         self.send_to_scheduler.send_pyobj(tokenized_obj)

@@ -419,7 +467,10 @@ class TokenizerManager:
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
                     self.abort_request(obj.rid)
-                    raise ValueError(f"Abort request {obj.rid}")
+                    raise ValueError(
+                        "Request is disconnected from the client side. "
+                        f"Abort request {obj.rid}"
+                    )
                 continue

             out = state.out_list[-1]
@@ -427,8 +478,11 @@ class TokenizerManager:
             state.out_list = []
             if state.finished:
                 if self.log_requests:
-                    max_length = 2048 if self.log_requests_level == 0 else 1 << 30
-                    msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}"
+                    max_length, skip_names, out_skip_names = self.log_request_metadata
+                    if self.model_config.is_multimodal_gen:
+                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
+                    else:
+                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                     logger.info(msg)
                 del self.rid_to_state[obj.rid]

@@ -451,7 +505,10 @@ class TokenizerManager:
             else:
                 if request is not None and await request.is_disconnected():
                     self.abort_request(obj.rid)
-                    raise ValueError(f"Abort request {obj.rid}")
+                    raise ValueError(
+                        "Request is disconnected from the client side. "
+                        f"Abort request {obj.rid}"
+                    )

     async def _handle_batch_request(
         self,
@@ -542,12 +599,25 @@ class TokenizerManager:
         req = AbortReq(rid)
         self.send_to_scheduler.send_pyobj(req)

-    def start_profile(self):
-        req = ProfileReq.START_PROFILE
-        self.send_to_scheduler.send_pyobj(req)
+    async def start_profile(
+        self,
+        output_dir: Optional[str] = None,
+        num_steps: Optional[int] = None,
+        activities: Optional[List[str]] = None,
+    ):
+        req = ProfileReq(
+            type=ProfileReqType.START_PROFILE,
+            output_dir=output_dir,
+            num_steps=num_steps,
+            activities=activities,
+        )
+        result = (await self.start_profile_communicator(req))[0]
+        if not result.success:
+            raise RuntimeError(result.message)
+        return result

     def stop_profile(self):
-        req = ProfileReq.STOP_PROFILE
+        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
         self.send_to_scheduler.send_pyobj(req)

     async def update_weights_from_disk(
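start_profile is now an awaitable request/response call rather than fire-and-forget. A minimal usage sketch, assuming a live TokenizerManager instance (named tokenizer_manager here) and illustrative argument values:

```python
# Usage sketch only: `tokenizer_manager` stands in for a running
# TokenizerManager; the argument values below are illustrative.
result = await tokenizer_manager.start_profile(
    output_dir="/tmp/sglang_profile",  # where the profiler trace is written
    num_steps=10,                      # stop automatically after 10 steps
    activities=["CPU", "GPU"],         # which profiler activities to record
)
tokenizer_manager.stop_profile()       # stop_profile remains fire-and-forget
```

Because the call goes through start_profile_communicator (fan-out of dp_size), it returns only after every data-parallel rank replies, and it re-raises the scheduler's failure message as a RuntimeError.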
@@ -580,7 +650,7 @@ class TokenizerManager:
                 self.server_args.model_path = obj.model_path
                 self.server_args.load_format = obj.load_format
                 self.model_path = obj.model_path
-            return result.success, result.message
+            return result.success, result.message, result.num_paused_requests
         else:  # self.server_args.dp_size > 1
             self.model_update_tmp = []
             result = await self.model_update_result
@@ -592,7 +662,8 @@ class TokenizerManager:
                 self.model_path = obj.model_path
             all_message = [r.message for r in result]
             all_message = " | ".join(all_message)
-            return all_success, all_message
+            all_paused_requests = [r.num_paused_requests for r in result]
+            return all_success, all_message, all_paused_requests

     async def init_weights_update_group(
         self,
@@ -687,6 +758,46 @@ class TokenizerManager:
     ):
         await self.send_to_scheduler.send_pyobj(obj)

+    async def get_internal_state(self) -> Dict[Any, Any]:
+        req = GetInternalStateReq()
+        res: List[GetInternalStateReqOutput] = (
+            await self.get_internal_state_communicator(req)
+        )
+        return res[0].internal_state
+
+    def get_log_request_metadata(self):
+        max_length = None
+        skip_names = None
+        out_skip_names = None
+        if self.log_requests:
+            if self.log_requests_level == 0:
+                max_length = 1 << 30
+                skip_names = set(
+                    [
+                        "text",
+                        "input_ids",
+                        "input_embeds",
+                        "image_data",
+                        "audio_data",
+                        "lora_path",
+                    ]
+                )
+                out_skip_names = set(
+                    [
+                        "text",
+                        "output_ids",
+                    ]
+                )
+            elif self.log_requests_level == 1:
+                max_length = 2048
+            elif self.log_requests_level == 2:
+                max_length = 1 << 30
+            else:
+                raise ValueError(
+                    f"Invalid --log-requests-level: {self.log_requests_level=}"
+                )
+        return max_length, skip_names, out_skip_names
+
     def configure_logging(self, obj: ConfigureLoggingReq):
         if obj.log_requests is not None:
             self.log_requests = obj.log_requests
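For orientation, the mapping implemented above: --log-requests-level 0 logs at full length but skips bulky fields (text, input_ids, input_embeds, image_data, audio_data, lora_path on inputs; text, output_ids on outputs), level 1 truncates every logged field to 2048 characters, and level 2 logs everything in full; any other value raises a ValueError.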
@@ -697,6 +808,7 @@ class TokenizerManager:
         if obj.dump_requests_threshold is not None:
             self.dump_requests_threshold = obj.dump_requests_threshold
         logging.info(f"Config logging: {obj=}")
+        self.log_request_metadata = self.get_log_request_metadata()

     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
@@ -761,15 +873,20 @@ class TokenizerManager:
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
+            self.last_receive_tstamp = time.time()

     def _handle_batch_output(
-        self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
+        self,
+        recv_obj: Union[
+            BatchStrOut, BatchEmbeddingOut, BatchMultimodalOut, BatchTokenIDOut
+        ],
     ):
         for i, rid in enumerate(recv_obj.rids):
             state = self.rid_to_state.get(rid, None)
             if state is None:
                 continue

+            # Build meta_info and return value
             meta_info = {
                 "id": rid,
                 "finish_reason": recv_obj.finished_reasons[i],
@@ -780,14 +897,12 @@ class TokenizerManager:
                 self.convert_logprob_style(
                     meta_info,
                     state.obj.top_logprobs_num,
+                    state.obj.token_ids_logprob,
                     state.obj.return_text_in_logprobs,
                     recv_obj,
                     i,
                 )

-            if self.server_args.speculative_algorithm:
-                meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
-
             if not isinstance(recv_obj, BatchEmbeddingOut):
                 meta_info.update(
                     {
@@ -796,10 +911,7 @@ class TokenizerManager:
                     }
                 )

-            if (
-                hasattr(recv_obj, "output_hidden_states")
-                and len(recv_obj.output_hidden_states[i]) > 0
-            ):
+            if getattr(recv_obj, "output_hidden_states", None):
                 meta_info["hidden_states"] = recv_obj.output_hidden_states[i]

             if isinstance(recv_obj, BatchStrOut):
@@ -808,10 +920,20 @@ class TokenizerManager:
                     "meta_info": meta_info,
                 }
             elif isinstance(recv_obj, BatchTokenIDOut):
+                if self.server_args.stream_output and state.obj.stream:
+                    output_token_ids = recv_obj.output_ids[i][
+                        state.last_output_offset :
+                    ]
+                    state.last_output_offset = len(recv_obj.output_ids[i])
+                else:
+                    output_token_ids = recv_obj.output_ids[i]
+
                 out_dict = {
-                    "token_ids": recv_obj.output_ids[i],
+                    "output_ids": output_token_ids,
                     "meta_info": meta_info,
                 }
+            elif isinstance(recv_obj, BatchMultimodalOut):
+                raise NotImplementedError()
             else:
                 assert isinstance(recv_obj, BatchEmbeddingOut)
                 out_dict = {
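Two caller-visible effects of the hunk above: the raw output key is renamed from "token_ids" to "output_ids", and when token streaming is active (the stream_output server arg plus a streaming request), each chunk now carries only the token ids produced since the previous chunk. A toy sketch of the offset bookkeeping, with invented values:

```python
# The detokenizer reports cumulative output ids; a per-request offset
# (ReqState.last_output_offset) tracks what the client has already seen.
output_ids = [11, 42, 7, 19]      # all ids produced so far (invented)
last_output_offset = 2            # ids delivered in earlier chunks

new_chunk = output_ids[last_output_offset:]  # -> [7, 19]
last_output_offset = len(output_ids)         # -> 4; next chunk starts here
```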
@@ -819,10 +941,17 @@ class TokenizerManager:
                     "meta_info": meta_info,
                 }

-            state.out_list.append(out_dict)
             state.finished = recv_obj.finished_reasons[i] is not None
+            if state.finished:
+                if self.server_args.speculative_algorithm:
+                    meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
+                state.finished_time = time.time()
+                meta_info["e2e_latency"] = state.finished_time - state.created_time
+
+            state.out_list.append(out_dict)
             state.event.set()

+            # Log metrics and dump
             if self.enable_metrics and state.obj.log_metrics:
                 self.collect_metrics(state, recv_obj, i)
             if self.dump_requests_folder and state.finished and state.obj.log_metrics:
@@ -832,6 +961,7 @@ class TokenizerManager:
         self,
         meta_info: dict,
         top_logprobs_num: int,
+        token_ids_logprob: List[int],
         return_text_in_logprobs: bool,
         recv_obj: BatchStrOut,
         recv_obj_index: int,
@@ -859,6 +989,20 @@ class TokenizerManager:
                 return_text_in_logprobs,
             )

+        if token_ids_logprob is not None:
+            meta_info["input_token_ids_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.input_token_ids_logprobs_val[recv_obj_index],
+                recv_obj.input_token_ids_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
+            )
+            meta_info["output_token_ids_logprobs"] = (
+                self.detokenize_top_logprobs_tokens(
+                    recv_obj.output_token_ids_logprobs_val[recv_obj_index],
+                    recv_obj.output_token_ids_logprobs_idx[recv_obj_index],
+                    return_text_in_logprobs,
+                )
+            )
+
     def detokenize_logprob_tokens(
         self,
         token_logprobs_val: List[float],
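A hypothetical client-side sketch of the feature these fields serve: token_ids_logprob asks the server to report logprobs for a caller-chosen set of token ids, which surface in meta_info as input_token_ids_logprobs / output_token_ids_logprobs. The endpoint shape is the usual /generate API; the concrete values below are invented:

```python
import requests

# Assumes a local sglang server; the prompt and token ids are illustrative,
# and return_logprob enables the logprob machinery.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 8},
        "return_logprob": True,
        "token_ids_logprob": [42, 1000, 2025],
    },
)
meta = resp.json()["meta_info"]
print(meta["output_token_ids_logprobs"])  # per-position logprobs of those ids
```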
@@ -902,34 +1046,30 @@ class TokenizerManager:
                 else 0
             )

-        if state.first_token_time is None:
-            state.first_token_time = time.time()
+        if state.first_token_time == 0.0:
+            state.first_token_time = state.last_time = time.time()
+            state.last_completion_tokens = completion_tokens
             self.metrics_collector.observe_time_to_first_token(
                 state.first_token_time - state.created_time
             )
         else:
-            if completion_tokens >= 2:
-                # Compute time_per_output_token for the streaming case
-                self.metrics_collector.observe_time_per_output_token(
-                    (time.time() - state.first_token_time) / (completion_tokens - 1)
+            num_new_tokens = completion_tokens - state.last_completion_tokens
+            if num_new_tokens:
+                new_time = time.time()
+                interval = new_time - state.last_time
+                self.metrics_collector.observe_inter_token_latency(
+                    interval,
+                    num_new_tokens,
                 )
+                state.last_time = new_time
+                state.last_completion_tokens = completion_tokens

         if state.finished:
             self.metrics_collector.observe_one_finished_request(
-                recv_obj.prompt_tokens[i], completion_tokens
-            )
-            self.metrics_collector.observe_e2e_request_latency(
-                time.time() - state.created_time
+                recv_obj.prompt_tokens[i],
+                completion_tokens,
+                state.finished_time - state.created_time,
             )
-            # Compute time_per_output_token for the non-streaming case
-            if (
-                hasattr(state.obj, "stream")
-                and not state.obj.stream
-                and completion_tokens >= 1
-            ):
-                self.metrics_collector.observe_time_per_output_token(
-                    (time.time() - state.created_time) / completion_tokens
-                )

     def dump_requests(self, state: ReqState, out_dict: dict):
         self.dump_request_list.append(
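To make the metric change concrete, a small worked example with invented numbers: instead of deriving one time-per-output-token figure per request, every poll that yields new tokens now records an (interval, token count) observation, which keeps the average honest when several tokens land in a single scheduler tick (e.g. with speculative decoding):

```python
# (interval_seconds, new_tokens) observations for one request, invented.
observations = [(0.031, 1), (0.062, 2), (0.030, 1)]

total_time = sum(dt for dt, _ in observations)   # 0.123 s
total_tokens = sum(n for _, n in observations)   # 4 tokens
# mean inter-token latency: 30.8 ms
print(f"{total_time / total_tokens * 1e3:.1f} ms")
```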
@@ -984,7 +1124,7 @@ async def print_exception_wrapper(func):


 class SignalHandler:
-    def __init__(self, tokenizer_manager):
+    def __init__(self, tokenizer_manager: TokenizerManager):
         self.tokenizer_manager = tokenizer_manager

     def signal_handler(self, signum=None, frame=None):
@@ -998,22 +1138,38 @@ T = TypeVar("T")


 class _Communicator(Generic[T]):
+    """Note: The communicator now only run up to 1 in-flight request at any time."""
+
     def __init__(self, sender, fan_out: int):
         self._sender = sender
         self._fan_out = fan_out
-        self._result_future: Optional[asyncio.Future] = None
+        self._result_event: Optional[asyncio.Event] = None
         self._result_values: Optional[List[T]] = None
+        self._ready_queue: Deque[asyncio.Future] = deque()

     async def __call__(self, obj):
-        self._sender.send_pyobj(obj)
-        self._result_future = asyncio.Future()
+        ready_event = asyncio.Event()
+        if self._result_event is not None or len(self._ready_queue) > 0:
+            self._ready_queue.append(ready_event)
+            await ready_event.wait()
+            assert self._result_event is None
+            assert self._result_values is None
+
+        if obj:
+            self._sender.send_pyobj(obj)
+
+        self._result_event = asyncio.Event()
         self._result_values = []
-        await self._result_future
+        await self._result_event.wait()
         result_values = self._result_values
-        self._result_future = self._result_values = None
+        self._result_event = self._result_values = None
+
+        if len(self._ready_queue) > 0:
+            self._ready_queue.popleft().set()
+
         return result_values

     def handle_recv(self, recv_obj: T):
         self._result_values.append(recv_obj)
         if len(self._result_values) == self._fan_out:
-            self._result_future.set_result(None)
+            self._result_event.set()
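A condensed, standalone sketch of the scheme _Communicator now implements (names shortened; this is not the library code): callers are serialized through a deque of asyncio.Event tickets so at most one request is in flight, and each call completes once fan_out replies (one per data-parallel rank) have been collected.

```python
import asyncio
from collections import deque

class MiniCommunicator:
    """Serializes callers; each call waits for `fan_out` replies."""

    def __init__(self, send, fan_out: int):
        self._send, self._fan_out = send, fan_out
        self._event = None     # set once all replies for the active call arrive
        self._values = None    # replies collected for the active call
        self._queue = deque()  # tickets of callers waiting their turn

    async def __call__(self, obj):
        ticket = asyncio.Event()
        if self._event is not None or self._queue:
            self._queue.append(ticket)
            await ticket.wait()               # wait until it is our turn
        self._send(obj)
        self._event, self._values = asyncio.Event(), []
        await self._event.wait()              # wait for fan_out replies
        values, self._event, self._values = self._values, None, None
        if self._queue:
            self._queue.popleft().set()       # hand off to the next caller
        return values

    def handle_recv(self, obj):
        self._values.append(obj)
        if len(self._values) == self._fan_out:
            self._event.set()
```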
sglang/srt/managers/tp_worker.py

@@ -15,10 +15,13 @@

 import logging
 import threading
-from typing import Optional
+from typing import Optional, Tuple
+
+import torch

 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
@@ -27,6 +30,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
@@ -46,6 +50,8 @@ class TpModelWorker:
         dp_rank: Optional[int],
         nccl_port: int,
         is_draft_worker: bool = False,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.tp_rank = tp_rank
@@ -74,6 +80,8 @@ class TpModelWorker:
             nccl_port=nccl_port,
             server_args=server_args,
             is_draft_worker=is_draft_worker,
+            req_to_token_pool=req_to_token_pool,
+            token_to_kv_pool_allocator=token_to_kv_pool_allocator,
         )
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
@@ -151,7 +159,7 @@ class TpModelWorker:
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
-            self.model_runner.token_to_kv_pool,
+            self.model_runner.token_to_kv_pool_allocator,
         )

     def forward_batch_generation(
@@ -159,7 +167,7 @@ class TpModelWorker:
         model_worker_batch: ModelWorkerBatch,
         launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ):
+    ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
         if launch_done:
@@ -205,7 +213,10 @@ class TpModelWorker:

     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
         success, message = self.model_runner.update_weights_from_tensor(
-            MultiprocessingSerializer.deserialize(recv_req.serialized_named_tensors)
+            named_tensors=MultiprocessingSerializer.deserialize(
+                recv_req.serialized_named_tensors
+            ),
+            load_format=recv_req.load_format,
         )
         return success, message
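The new optional pool arguments let a second TpModelWorker share an existing worker's request and KV-cache pools instead of allocating its own, which is presumably what the reworked speculative-decoding path (eagle_worker.py changes heavily in the same release) uses them for. A hypothetical wiring sketch, not code from the diff:

```python
# Hypothetical: construct a draft worker that reuses the target worker's
# pools; constructor arguments other than the pools are illustrative.
target_worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
req_pool, kv_allocator = target_worker.get_memory_pool()

draft_worker = TpModelWorker(
    server_args, gpu_id, tp_rank, dp_rank, nccl_port,
    is_draft_worker=True,
    req_to_token_pool=req_pool,               # share, do not re-allocate
    token_to_kv_pool_allocator=kv_allocator,  # note: an allocator now, not the raw pool
)
```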