sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff shows the changes between publicly available package versions as published to a supported public registry. It is provided for informational purposes only.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py
@@ -18,6 +18,7 @@ processes (TokenizerManager, DetokenizerManager, Scheduler).
 
 import copy
 import uuid
+from abc import ABC
 from dataclasses import dataclass, field
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
@@ -35,10 +36,33 @@ else:
     Image = Any
 
 
+# Parameters for a session
+@dataclass
+class BaseReq(ABC):
+    rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)
+
+    def regenerate_rid(self):
+        """Generate a new request ID and return it."""
+        if isinstance(self.rid, list):
+            self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
+        else:
+            self.rid = uuid.uuid4().hex
+        return self.rid
+
+
+@dataclass
+class BaseBatchReq(ABC):
+    rids: Optional[List[str]] = field(default=None, kw_only=True)
+
+    def regenerate_rids(self):
+        """Generate new request IDs and return them."""
+        self.rids = [uuid.uuid4().hex for _ in range(len(self.rids))]
+        return self.rids
+
+
 @dataclass
 class SessionParams:
     id: Optional[str] = None
-    rid: Optional[str] = None
     offset: Optional[int] = None
     replace: Optional[bool] = None
     drop_previous_output: Optional[bool] = None
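
Note: the two new abstract bases above centralize the request-id handling that each io struct previously duplicated. A minimal stand-alone sketch of the pattern, assuming Python 3.10+ for kw_only dataclass fields (the EchoReq subclass is hypothetical, for illustration only):

import uuid
from abc import ABC
from dataclasses import dataclass, field
from typing import List, Optional, Union


@dataclass
class BaseReq(ABC):
    # Keyword-only, so subclasses may still declare their own positional fields.
    rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)

    def regenerate_rid(self):
        """Generate a new request ID (or one per entry for list-valued rids)."""
        if isinstance(self.rid, list):
            self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))]
        else:
            self.rid = uuid.uuid4().hex
        return self.rid


@dataclass
class EchoReq(BaseReq):  # hypothetical subclass, for illustration only
    text: str = ""


req = EchoReq(text="ping")
print(req.rid)               # None until assigned
print(req.regenerate_rid())  # a fresh 32-character hex id
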
@@ -62,7 +86,7 @@ MultimodalDataInputFormat = Union[
 
 
 @dataclass
-class GenerateReqInput:
+class GenerateReqInput(BaseReq):
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
     # The token ids for text; one can specify either text or input_ids
@@ -82,8 +106,6 @@ class GenerateReqInput:
     audio_data: Optional[MultimodalDataInputFormat] = None
     # The sampling_params. See descriptions below.
     sampling_params: Optional[Union[List[Dict], Dict]] = None
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
     # If return logprobs, the start location in the prompt for returning logprobs.
@@ -132,17 +154,20 @@ class GenerateReqInput:
     # Conversation id used for tracking requests
     conversation_id: Optional[str] = None
 
-    # Label for the request
-    label: Optional[str] = None
-
     # Priority for the request
     priority: Optional[int] = None
 
-    # Image gen grpc migration
-    return_bytes: bool = False
+    # Extra key for classifying the request (e.g. cache_salt)
+    extra_key: Optional[Union[List[str], str]] = None
 
-    # For customer metric labels
-    customer_labels: Optional[Dict[str, str]] = None
+    # Whether to disallow logging for this request (e.g. due to ZDR)
+    no_logs: bool = False
+
+    # For custom metric labels
+    custom_labels: Optional[Dict[str, str]] = None
+
+    # (Internal) Whether to return bytes for image generation
+    return_bytes: bool = False
 
     def contains_mm_input(self) -> bool:
         return (
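
Note: the label/customer_labels fields are replaced by extra_key, no_logs, and custom_labels above. A hedged sketch of how a request might carry them, assuming sglang 0.5.3rc2 is installed; only field names shown in this diff are relied on:

from sglang.srt.managers.io_struct import GenerateReqInput

req = GenerateReqInput(
    text="Hello, world",
    sampling_params={"max_new_tokens": 8},
    extra_key="tenant-a",              # extra cache-classification key (e.g. cache_salt)
    no_logs=True,                      # opt this request out of request logging (ZDR)
    custom_labels={"team": "search"},  # surfaced as custom metric labels
)
req.regenerate_rid()  # rid handling now comes from BaseReq
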
@@ -485,11 +510,6 @@ class GenerateReqInput:
         ):
             raise ValueError("Session params must be a dict or a list of dicts.")
 
-    def regenerate_rid(self):
-        """Generate a new request ID and return it."""
-        self.rid = uuid.uuid4().hex
-        return self.rid
-
     def __getitem__(self, i):
         return GenerateReqInput(
             text=self.text[i] if self.text is not None else None,
@@ -542,16 +562,16 @@ class GenerateReqInput:
                 self.data_parallel_rank if self.data_parallel_rank is not None else None
             ),
             conversation_id=self.conversation_id,
-            label=self.label,
             priority=self.priority,
+            extra_key=self.extra_key,
+            no_logs=self.no_logs,
+            custom_labels=self.custom_labels,
             return_bytes=self.return_bytes,
         )
 
 
 @dataclass
-class TokenizedGenerateReqInput:
-    # The request id
-    rid: str
+class TokenizedGenerateReqInput(BaseReq):
     # The input text
     input_text: str
     # The input token ids
@@ -570,6 +590,7 @@ class TokenizedGenerateReqInput:
     token_ids_logprob: List[int]
     # Whether to stream output
     stream: bool
+
     # Whether to return hidden states
     return_hidden_states: bool = False
 
@@ -596,24 +617,24 @@ class TokenizedGenerateReqInput:
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
 
-    # For dp balance
-    dp_balance_id: int = -1
-
-    # Label for the request
-    label: Optional[str] = None
-
     # Priority for the request
     priority: Optional[int] = None
 
-    # Image gen grpc migration
-    return_bytes: bool = False
+    # Extra key for classifying the request (e.g. cache_salt)
+    extra_key: Optional[str] = None
+
+    # Whether to disallow logging for this request (e.g. due to ZDR)
+    no_logs: bool = False
 
     # tracing context
     trace_context: Optional[Dict] = None
 
+    # (Internal) Whether to return bytes for image generation
+    return_bytes: bool = False
+
 
 @dataclass
-class BatchTokenizedGenerateReqInput:
+class BatchTokenizedGenerateReqInput(BaseBatchReq):
     # The batch of tokenized requests
     batch: List[TokenizedGenerateReqInput]
 
@@ -628,7 +649,7 @@ class BatchTokenizedGenerateReqInput:
 
 
 @dataclass
-class EmbeddingReqInput:
+class EmbeddingReqInput(BaseReq):
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[List[str]], List[str], str]] = None
     # The image input. It can be an image instance, file name, URL, or base64 encoded string.
@@ -644,8 +665,6 @@ class EmbeddingReqInput:
     audio_data: Optional[MultimodalDataInputFormat] = None
     # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
-    # The request id.
-    rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Optional[Union[List[Dict], Dict]] = None
     # Dummy input embeds for compatibility
@@ -656,6 +675,8 @@ class EmbeddingReqInput:
     modalities: Optional[List[str]] = None
     # For cross-encoder requests
     is_cross_encoder_request: bool = False
+    # Priority for the request
+    priority: Optional[int] = None
 
     # For background responses (OpenAI responses API)
     background: bool = False
@@ -714,10 +735,6 @@ class EmbeddingReqInput:
         for i in range(self.batch_size):
             self.sampling_params[i]["max_new_tokens"] = 0
 
-    def regenerate_rid(self):
-        self.rid = uuid.uuid4().hex
-        return self.rid
-
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
@@ -746,9 +763,7 @@ class EmbeddingReqInput:
 
 
 @dataclass
-class TokenizedEmbeddingReqInput:
-    # The request id
-    rid: str
+class TokenizedEmbeddingReqInput(BaseReq):
     # The input text
     input_text: str
     # The input token ids
@@ -761,12 +776,12 @@ class TokenizedEmbeddingReqInput:
     sampling_params: SamplingParams
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
-    # For dp balance
-    dp_balance_id: int = -1
+    # Priority for the request
+    priority: Optional[int] = None
 
 
 @dataclass
-class BatchTokenizedEmbeddingReqInput:
+class BatchTokenizedEmbeddingReqInput(BaseBatchReq):
     # The batch of tokenized embedding requests
     batch: List[TokenizedEmbeddingReqInput]
 
@@ -781,9 +796,7 @@ class BatchTokenizedEmbeddingReqInput:
 
 
 @dataclass
-class BatchTokenIDOut:
-    # The request id
-    rids: List[str]
+class BatchTokenIDOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[BaseFinishReason]
     # For incremental decoding
@@ -828,7 +841,7 @@ class BatchTokenIDOut:
 
 
 @dataclass
-class BatchMultimodalDecodeReq:
+class BatchMultimodalDecodeReq(BaseBatchReq):
     decoded_ids: List[int]
     input_token_logprobs_val: List[float]
     input_token_logprobs_idx: List[int]
@@ -840,8 +853,6 @@ class BatchMultimodalDecodeReq:
     image_resolutions: List[List[int]]
     resize_image_resolutions: List[List[int]]
 
-    # The request id
-    rids: List[str]
     finished_reasons: List[BaseFinishReason]
 
     # Token counts
@@ -857,9 +868,7 @@ class BatchMultimodalDecodeReq:
 
 
 @dataclass
-class BatchStrOut:
-    # The request id
-    rids: List[str]
+class BatchStrOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[dict]
     # The output decoded strings
@@ -895,9 +904,7 @@ class BatchStrOut:
 
 
 @dataclass
-class BatchMultimodalOut:
-    # The request id
-    rids: List[str]
+class BatchMultimodalOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[dict]
     decoded_ids: List[List[int]]
@@ -922,9 +929,7 @@ class BatchMultimodalOut:
 
 
 @dataclass
-class BatchEmbeddingOut:
-    # The request id
-    rids: List[str]
+class BatchEmbeddingOutput(BaseBatchReq):
     # The finish reason
     finished_reasons: List[BaseFinishReason]
     # The output embedding
@@ -938,27 +943,27 @@ class BatchEmbeddingOut:
 
 
 @dataclass
-class ClearHiCacheReqInput:
+class ClearHiCacheReqInput(BaseReq):
     pass
 
 
 @dataclass
-class ClearHiCacheReqOutput:
+class ClearHiCacheReqOutput(BaseReq):
     success: bool
 
 
 @dataclass
-class FlushCacheReqInput:
+class FlushCacheReqInput(BaseReq):
     pass
 
 
 @dataclass
-class FlushCacheReqOutput:
+class FlushCacheReqOutput(BaseReq):
     success: bool
 
 
 @dataclass
-class UpdateWeightFromDiskReqInput:
+class UpdateWeightFromDiskReqInput(BaseReq):
     # The model path with the new weights
     model_path: str
     # The format to load the weights
@@ -976,7 +981,7 @@ class UpdateWeightFromDiskReqInput:
 
 
 @dataclass
-class UpdateWeightFromDiskReqOutput:
+class UpdateWeightFromDiskReqOutput(BaseReq):
     success: bool
     message: str
     # Number of paused requests during weight sync.
@@ -984,7 +989,7 @@ class UpdateWeightFromDiskReqOutput:
 
 
 @dataclass
-class UpdateWeightsFromDistributedReqInput:
+class UpdateWeightsFromDistributedReqInput(BaseReq):
     names: List[str]
     dtypes: List[str]
     shapes: List[List[int]]
@@ -999,13 +1004,13 @@ class UpdateWeightsFromDistributedReqInput:
 
 
 @dataclass
-class UpdateWeightsFromDistributedReqOutput:
+class UpdateWeightsFromDistributedReqOutput(BaseReq):
     success: bool
     message: str
 
 
 @dataclass
-class UpdateWeightsFromTensorReqInput:
+class UpdateWeightsFromTensorReqInput(BaseReq):
     """Update model weights from tensor input.
 
     - Tensors are serialized for transmission
@@ -1024,13 +1029,13 @@ class UpdateWeightsFromTensorReqInput:
 
 
 @dataclass
-class UpdateWeightsFromTensorReqOutput:
+class UpdateWeightsFromTensorReqOutput(BaseReq):
     success: bool
     message: str
 
 
 @dataclass
-class InitWeightsSendGroupForRemoteInstanceReqInput:
+class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
     # The master address
     master_address: str
     # The ports for each rank's communication group
@@ -1046,13 +1051,13 @@ class InitWeightsSendGroupForRemoteInstanceReqInput:
 
 
 @dataclass
-class InitWeightsSendGroupForRemoteInstanceReqOutput:
+class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
     success: bool
     message: str
 
 
 @dataclass
-class SendWeightsToRemoteInstanceReqInput:
+class SendWeightsToRemoteInstanceReqInput(BaseReq):
     # The master address
     master_address: str
     # The ports for each rank's communication group
@@ -1062,13 +1067,13 @@ class SendWeightsToRemoteInstanceReqInput:
 
 
 @dataclass
-class SendWeightsToRemoteInstanceReqOutput:
+class SendWeightsToRemoteInstanceReqOutput(BaseReq):
     success: bool
     message: str
 
 
 @dataclass
-class InitWeightsUpdateGroupReqInput:
+class InitWeightsUpdateGroupReqInput(BaseReq):
     # The master address
     master_address: str
     # The master port
@@ -1084,13 +1089,24 @@ class InitWeightsUpdateGroupReqInput:
 
 
 @dataclass
-class InitWeightsUpdateGroupReqOutput:
+class InitWeightsUpdateGroupReqOutput(BaseReq):
+    success: bool
+    message: str
+
+
+@dataclass
+class DestroyWeightsUpdateGroupReqInput(BaseReq):
+    group_name: str = "weight_update_group"
+
+
+@dataclass
+class DestroyWeightsUpdateGroupReqOutput(BaseReq):
     success: bool
     message: str
 
 
 @dataclass
-class UpdateWeightVersionReqInput:
+class UpdateWeightVersionReqInput(BaseReq):
     # The new weight version
     new_version: str
     # Whether to abort all running requests before updating
@@ -1098,89 +1114,87 @@ class UpdateWeightVersionReqInput:
 
 
 @dataclass
-class GetWeightsByNameReqInput:
+class GetWeightsByNameReqInput(BaseReq):
     name: str
     truncate_size: int = 100
 
 
 @dataclass
-class GetWeightsByNameReqOutput:
+class GetWeightsByNameReqOutput(BaseReq):
     parameter: list
 
 
 @dataclass
-class ReleaseMemoryOccupationReqInput:
+class ReleaseMemoryOccupationReqInput(BaseReq):
     # Optional tags to identify the memory region, which is primarily used for RL
     # Currently we only support `weights` and `kv_cache`
     tags: Optional[List[str]] = None
 
 
 @dataclass
-class ReleaseMemoryOccupationReqOutput:
+class ReleaseMemoryOccupationReqOutput(BaseReq):
     pass
 
 
 @dataclass
-class ResumeMemoryOccupationReqInput:
+class ResumeMemoryOccupationReqInput(BaseReq):
     # Optional tags to identify the memory region, which is primarily used for RL
     # Currently we only support `weights` and `kv_cache`
     tags: Optional[List[str]] = None
 
 
 @dataclass
-class ResumeMemoryOccupationReqOutput:
+class ResumeMemoryOccupationReqOutput(BaseReq):
     pass
 
 
 @dataclass
-class SlowDownReqInput:
+class SlowDownReqInput(BaseReq):
     forward_sleep_time: Optional[float]
 
 
 @dataclass
-class SlowDownReqOutput:
+class SlowDownReqOutput(BaseReq):
     pass
 
 
 @dataclass
-class AbortReq:
-    # The request id
-    rid: str = ""
+class AbortReq(BaseReq):
     # Whether to abort all requests
     abort_all: bool = False
     # The finished reason data
     finished_reason: Optional[Dict[str, Any]] = None
     abort_reason: Optional[str] = None
-    # used in MultiTokenzierManager mode
-    rids: Optional[Union[List[str], str]] = None
 
     def __post_init__(self):
-        self.rids = self.rid
+        # FIXME: This is a hack to keep the same with the old code
+        if self.rid is None:
+            self.rid = ""
 
 
 @dataclass
-class GetInternalStateReq:
+class GetInternalStateReq(BaseReq):
     pass
 
 
 @dataclass
-class GetInternalStateReqOutput:
+class GetInternalStateReqOutput(BaseReq):
     internal_state: Dict[Any, Any]
 
 
 @dataclass
-class SetInternalStateReq:
+class SetInternalStateReq(BaseReq):
     server_args: Dict[str, Any]
 
 
 @dataclass
-class SetInternalStateReqOutput:
+class SetInternalStateReqOutput(BaseReq):
     updated: bool
     server_args: Dict[str, Any]
 
 
 @dataclass
-class ProfileReqInput:
+class ProfileReqInput(BaseReq):
     # The output directory
     output_dir: Optional[str] = None
     # If set, it profile as many as this number of steps.
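
Note: AbortReq no longer declares its own rid/rids; the rid now comes from BaseReq, and __post_init__ coerces an unset rid back to the old default of an empty string. A self-contained sketch of that behavior, using simplified stand-ins rather than the real sglang classes:

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union


@dataclass
class BaseReq:  # simplified stand-in for sglang's BaseReq
    rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True)


@dataclass
class AbortReq(BaseReq):
    abort_all: bool = False
    finished_reason: Optional[Dict[str, Any]] = None
    abort_reason: Optional[str] = None

    def __post_init__(self):
        # Preserve the old default: an unset rid becomes "" rather than None.
        if self.rid is None:
            self.rid = ""


assert AbortReq().rid == ""              # old callers relying on rid="" keep working
assert AbortReq(rid="abc").rid == "abc"  # explicit ids pass through unchanged
assert AbortReq(abort_all=True).abort_all is True
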
@@ -1200,7 +1214,7 @@ class ProfileReqType(Enum):
 
 
 @dataclass
-class ProfileReq:
+class ProfileReq(BaseReq):
     type: ProfileReqType
     output_dir: Optional[str] = None
     start_step: Optional[int] = None
@@ -1213,18 +1227,18 @@ class ProfileReq:
 
 
 @dataclass
-class ProfileReqOutput:
+class ProfileReqOutput(BaseReq):
     success: bool
     message: str
 
 
 @dataclass
-class FreezeGCReq:
+class FreezeGCReq(BaseReq):
     pass
 
 
 @dataclass
-class ConfigureLoggingReq:
+class ConfigureLoggingReq(BaseReq):
     log_requests: Optional[bool] = None
     log_requests_level: Optional[int] = None
     dump_requests_folder: Optional[str] = None
@@ -1233,35 +1247,39 @@ class ConfigureLoggingReq:
 
 
 @dataclass
-class OpenSessionReqInput:
+class OpenSessionReqInput(BaseReq):
     capacity_of_str_len: int
     session_id: Optional[str] = None
 
 
 @dataclass
-class CloseSessionReqInput:
+class CloseSessionReqInput(BaseReq):
     session_id: str
 
 
 @dataclass
-class OpenSessionReqOutput:
+class OpenSessionReqOutput(BaseReq):
     session_id: Optional[str]
     success: bool
 
 
 @dataclass
-class HealthCheckOutput:
+class HealthCheckOutput(BaseReq):
     pass
 
 
-class ExpertDistributionReq(Enum):
+class ExpertDistributionReqType(Enum):
     START_RECORD = 1
     STOP_RECORD = 2
     DUMP_RECORD = 3
 
 
+class ExpertDistributionReq(BaseReq):
+    action: ExpertDistributionReqType
+
+
 @dataclass
-class ExpertDistributionReqOutput:
+class ExpertDistributionReqOutput(BaseReq):
     pass
 
 
@@ -1279,7 +1297,7 @@ class Tool:
 
 
 @dataclass
-class ParseFunctionCallReq:
+class ParseFunctionCallReq(BaseReq):
     text: str  # The text to parse.
     tools: List[Tool] = field(
         default_factory=list
@@ -1290,31 +1308,31 @@ class ParseFunctionCallReq:
 
 
 @dataclass
-class SeparateReasoningReqInput:
+class SeparateReasoningReqInput(BaseReq):
     text: str  # The text to parse.
     reasoning_parser: str  # Specify the parser type, e.g., "deepseek-r1".
 
 
 @dataclass
-class VertexGenerateReqInput:
+class VertexGenerateReqInput(BaseReq):
     instances: List[dict]
     parameters: Optional[dict] = None
 
 
 @dataclass
-class RpcReqInput:
+class RpcReqInput(BaseReq):
     method: str
     parameters: Optional[Dict] = None
 
 
 @dataclass
-class RpcReqOutput:
+class RpcReqOutput(BaseReq):
     success: bool
     message: str
 
 
 @dataclass
-class LoadLoRAAdapterReqInput:
+class LoadLoRAAdapterReqInput(BaseReq):
     # The name of the lora module to newly loaded.
     lora_name: str
     # The path of loading.
@@ -1334,7 +1352,7 @@ class LoadLoRAAdapterReqInput:
 
 
 @dataclass
-class UnloadLoRAAdapterReqInput:
+class UnloadLoRAAdapterReqInput(BaseReq):
     # The name of lora module to unload.
     lora_name: str
     # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
@@ -1348,23 +1366,23 @@ class UnloadLoRAAdapterReqInput:
 
 
 @dataclass
-class LoRAUpdateResult:
+class LoRAUpdateOutput(BaseReq):
     success: bool
     error_message: Optional[str] = None
     loaded_adapters: Optional[Dict[str, LoRARef]] = None
 
 
-LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
+LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateOutput
 
 
 @dataclass
-class MultiTokenizerRegisterReq:
-    rids: Optional[Union[List[str], str]] = None
+class MultiTokenizerRegisterReq(BaseBatchReq):
     ipc_name: Optional[str] = None
 
 
 @dataclass
 class MultiTokenizerWrapper:
+    # FIXME(lsyin): remove this
     worker_id: int
     obj: Optional[Any] = None
 
@@ -1375,17 +1393,17 @@ class BlockReqType(Enum):
 
 
 @dataclass
-class BlockReqInput:
+class BlockReqInput(BaseReq):
     type: BlockReqType
 
 
 @dataclass
-class GetLoadReqInput:
+class GetLoadReqInput(BaseReq):
     pass
 
 
 @dataclass
-class GetLoadReqOutput:
+class GetLoadReqOutput(BaseReq):
     dp_rank: int
     num_reqs: int
     num_waiting_reqs: int
@@ -1393,5 +1411,31 @@ class GetLoadReqOutput:
 
 
 @dataclass
-class WatchLoadUpdateReq:
+class WatchLoadUpdateReq(BaseReq):
     loads: List[GetLoadReqOutput]
+
+
+def _check_all_req_types():
+    """A helper function to check all request types are defined in this file."""
+    import inspect
+    import sys
+
+    all_classes = inspect.getmembers(sys.modules[__name__], inspect.isclass)
+    for class_type in all_classes:
+        # check its name
+        name = class_type[0]
+        is_io_struct = (
+            name.endswith("Req") or name.endswith("Input") or name.endswith("Output")
+        )
+        is_base_req = issubclass(class_type[1], BaseReq) or issubclass(
+            class_type[1], BaseBatchReq
+        )
+        if is_io_struct and not is_base_req:
+            raise ValueError(f"{name} is not a subclass of BaseReq or BaseBatchReq.")
+        if is_base_req and not is_io_struct:
+            raise ValueError(
+                f"{name} is a subclass of BaseReq but not follow the naming convention."
+            )
+
+
+_check_all_req_types()
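
Note: the module now ends with a _check_all_req_types() call that enforces a naming convention: any class named *Req, *Input, or *Output must derive from BaseReq or BaseBatchReq, and any such subclass must use one of those suffixes. A self-contained sketch of the same check applied to toy classes (all class names below are illustrative stand-ins, not sglang's):

import inspect
import sys
from dataclasses import dataclass


@dataclass
class BaseReq:  # simplified stand-in
    pass


@dataclass
class BaseBatchReq:  # simplified stand-in
    pass


@dataclass
class PingReqInput(BaseReq):  # compliant: name suffix and base class match
    pass


@dataclass
class PingResult:  # also fine: neither the suffix nor the base class match
    pass


def check_all_req_types(module_name: str) -> None:
    # Same rule as the new sglang helper, applied to the given module's classes.
    for name, cls in inspect.getmembers(sys.modules[module_name], inspect.isclass):
        is_io_struct = name.endswith(("Req", "Input", "Output"))
        is_base_req = issubclass(cls, (BaseReq, BaseBatchReq))
        if is_io_struct and not is_base_req:
            raise ValueError(f"{name} is not a subclass of BaseReq or BaseBatchReq.")
        if is_base_req and not is_io_struct:
            raise ValueError(f"{name} does not follow the naming convention.")


check_all_req_types(__name__)  # passes; rename PingReqInput to PingPayload to see it fail
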