sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -20,11 +20,14 @@ import logging
20
20
  import os
21
21
  import random
22
22
  import tempfile
23
+ from token import OP
23
24
  from typing import List, Literal, Optional, Union
24
25
 
25
26
  from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
26
27
  from sglang.srt.reasoning_parser import ReasoningParser
27
28
  from sglang.srt.utils import (
29
+ LORA_TARGET_ALL_MODULES,
30
+ SUPPORTED_LORA_TARGET_MODULES,
28
31
  configure_ipv6,
29
32
  get_device,
30
33
  get_device_memory_capacity,
@@ -46,30 +49,28 @@ class ServerArgs:
46
49
  tokenizer_path: Optional[str] = None
47
50
  tokenizer_mode: str = "auto"
48
51
  skip_tokenizer_init: bool = False
49
- skip_server_warmup: bool = False
50
52
  load_format: str = "auto"
51
53
  model_loader_extra_config: str = "{}"
52
54
  trust_remote_code: bool = False
53
- dtype: str = "auto"
54
- kv_cache_dtype: str = "auto"
55
- quantization: Optional[str] = None
56
- quantization_param_path: Optional[str] = None
57
55
  context_length: Optional[int] = None
58
- device: Optional[str] = None
59
- served_model_name: Optional[str] = None
60
- chat_template: Optional[str] = None
61
- completion_template: Optional[str] = None
62
56
  is_embedding: bool = False
63
57
  enable_multimodal: Optional[bool] = None
64
58
  revision: Optional[str] = None
65
- hybrid_kvcache_ratio: Optional[float] = None
66
- impl: str = "auto"
59
+ model_impl: str = "auto"
67
60
 
68
- # Port for the HTTP server
61
+ # HTTP server
69
62
  host: str = "127.0.0.1"
70
63
  port: int = 30000
64
+ skip_server_warmup: bool = False
65
+ warmups: Optional[str] = None
71
66
  nccl_port: Optional[int] = None
72
67
 
68
+ # Quantization and data type
69
+ dtype: str = "auto"
70
+ quantization: Optional[str] = None
71
+ quantization_param_path: Optional[str] = None
72
+ kv_cache_dtype: str = "auto"
73
+
73
74
  # Memory and scheduling
74
75
  mem_fraction_static: Optional[float] = None
75
76
  max_running_requests: Optional[int] = None
@@ -80,8 +81,12 @@ class ServerArgs:
80
81
  schedule_conservativeness: float = 1.0
81
82
  cpu_offload_gb: int = 0
82
83
  page_size: int = 1
84
+ hybrid_kvcache_ratio: Optional[float] = None
85
+ swa_full_tokens_ratio: float = 0.8
86
+ disable_hybrid_swa_memory: bool = False
83
87
 
84
- # Other runtime options
88
+ # Runtime options
89
+ device: Optional[str] = None
85
90
  tp_size: int = 1
86
91
  pp_size: int = 1
87
92
  max_micro_batch_size: Optional[int] = None
@@ -104,9 +109,10 @@ class ServerArgs:
104
109
  crash_dump_folder: Optional[str] = None
105
110
  show_time_cost: bool = False
106
111
  enable_metrics: bool = False
112
+ enable_metrics_for_all_schedulers: bool = False
107
113
  bucket_time_to_first_token: Optional[List[float]] = None
108
- bucket_e2e_request_latency: Optional[List[float]] = None
109
114
  bucket_inter_token_latency: Optional[List[float]] = None
115
+ bucket_e2e_request_latency: Optional[List[float]] = None
110
116
  collect_tokens_histogram: bool = False
111
117
  decode_log_interval: int = 40
112
118
  enable_request_time_stats_logging: bool = False
@@ -114,6 +120,9 @@ class ServerArgs:
114
120
 
115
121
  # API related
116
122
  api_key: Optional[str] = None
123
+ served_model_name: Optional[str] = None
124
+ chat_template: Optional[str] = None
125
+ completion_template: Optional[str] = None
117
126
  file_storage_path: str = "sglang_storage"
118
127
  enable_cache_report: bool = False
119
128
  reasoning_parser: Optional[str] = None
@@ -133,6 +142,9 @@ class ServerArgs:
133
142
  preferred_sampling_params: Optional[str] = None
134
143
 
135
144
  # LoRA
145
+ enable_lora: Optional[bool] = None
146
+ max_lora_rank: Optional[int] = None
147
+ lora_target_modules: Optional[Union[set[str], List[str]]] = None
136
148
  lora_paths: Optional[Union[dict[str, str], List[str]]] = None
137
149
  max_loras_per_batch: int = 8
138
150
  lora_backend: str = "triton"
@@ -175,6 +187,14 @@ class ServerArgs:
175
187
  deepep_config: Optional[str] = None
176
188
  moe_dense_tp_size: Optional[int] = None
177
189
 
190
+ # Hierarchical cache
191
+ enable_hierarchical_cache: bool = False
192
+ hicache_ratio: float = 2.0
193
+ hicache_size: int = 0
194
+ hicache_write_policy: str = "write_through_selective"
195
+ hicache_io_backend: str = ""
196
+ hicache_storage_backend: Optional[str] = None
197
+
178
198
  # Double Sparsity
179
199
  enable_double_sparsity: bool = False
180
200
  ds_channel_config_path: Optional[str] = None
@@ -196,7 +216,6 @@ class ServerArgs:
196
216
  disable_custom_all_reduce: bool = False
197
217
  enable_mscclpp: bool = False
198
218
  disable_overlap_schedule: bool = False
199
- disable_overlap_cg_plan: bool = False
200
219
  enable_mixed_chunk: bool = False
201
220
  enable_dp_attention: bool = False
202
221
  enable_dp_lm_head: bool = False
@@ -213,18 +232,12 @@ class ServerArgs:
213
232
  enable_memory_saver: bool = False
214
233
  allow_auto_truncate: bool = False
215
234
  enable_custom_logit_processor: bool = False
216
- enable_hierarchical_cache: bool = False
217
- hicache_ratio: float = 2.0
218
- hicache_size: int = 0
219
- hicache_write_policy: str = "write_through_selective"
220
- hicache_io_backend: str = ""
221
235
  flashinfer_mla_disable_ragged: bool = False
222
236
  disable_shared_experts_fusion: bool = False
223
237
  disable_chunked_prefix_cache: bool = False
224
238
  disable_fast_image_processor: bool = False
225
239
  enable_return_hidden_states: bool = False
226
240
  enable_triton_kernel_moe: bool = False
227
- warmups: Optional[str] = None
228
241
 
229
242
  # Debug tensor dumps
230
243
  debug_tensor_dump_output_folder: Optional[str] = None
@@ -232,7 +245,7 @@ class ServerArgs:
232
245
  debug_tensor_dump_inject: bool = False
233
246
  debug_tensor_dump_prefill_only: bool = False
234
247
 
235
- # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
248
+ # PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
236
249
  disaggregation_mode: str = "null"
237
250
  disaggregation_transfer_backend: str = "mooncake"
238
251
  disaggregation_bootstrap_port: int = 8998
@@ -247,6 +260,10 @@ class ServerArgs:
247
260
  custom_weight_loader: Optional[List[str]] = None
248
261
  weight_loader_disable_mmap: bool = False
249
262
 
263
+ # For PD-Multiplexing
264
+ enable_pdmux: bool = False
265
+ sm_group_num: int = 3
266
+
250
267
  def __post_init__(self):
251
268
  # Expert parallelism
252
269
  if self.enable_ep_moe:
@@ -263,6 +280,7 @@ class ServerArgs:
263
280
  logger.warning(
264
281
  f"Flashinfer MoE is enabled. Shared expert fusion is disabled."
265
282
  )
283
+
266
284
  # Set missing default values
267
285
  if self.tokenizer_path is None:
268
286
  self.tokenizer_path = self.model_path
@@ -323,12 +341,12 @@ class ServerArgs:
323
341
  self.mem_fraction_static = 0.88
324
342
 
325
343
  # Lazy init to avoid circular import
344
+ # Multimodal models need more memory for the image processor
326
345
  from sglang.srt.configs.model_config import ModelConfig
327
346
 
328
- # Multimodal models need more memory for the image processor
329
347
  model_config = ModelConfig.from_server_args(self)
330
348
  if model_config.is_multimodal:
331
- self.mem_fraction_static *= 0.90
349
+ self.adjust_mem_fraction_for_vlm(model_config)
332
350
 
333
351
  # Set chunked prefill size, which depends on the gpu memory capacity
334
352
  if self.chunked_prefill_size is None:
@@ -352,23 +370,6 @@ class ServerArgs:
352
370
  else:
353
371
  self.cuda_graph_max_bs = 80
354
372
 
355
- assert self.moe_dense_tp_size in {
356
- 1,
357
- None,
358
- }, "moe_dense_tp_size only support 1 and None currently"
359
-
360
- if self.attention_backend == "flashmla":
361
- logger.warning(
362
- "FlashMLA only supports a page_size of 64, change page_size to 64."
363
- )
364
- self.page_size = 64
365
-
366
- if self.attention_backend == "cutlass_mla":
367
- logger.warning(
368
- "Cutlass MLA only supports a page_size of 128, change page_size to 128."
369
- )
370
- self.page_size = 128
371
-
372
373
  # Set kernel backends for hpu device
373
374
  if self.device == "hpu":
374
375
  self.attention_backend = "torch_native"
@@ -397,6 +398,18 @@ class ServerArgs:
397
398
  )
398
399
  self.page_size = 128
399
400
 
401
+ if self.attention_backend == "flashmla":
402
+ logger.warning(
403
+ "FlashMLA only supports a page_size of 64, change page_size to 64."
404
+ )
405
+ self.page_size = 64
406
+
407
+ if self.attention_backend == "cutlass_mla":
408
+ logger.warning(
409
+ "Cutlass MLA only supports a page_size of 128, change page_size to 128."
410
+ )
411
+ self.page_size = 128
412
+
400
413
  # Choose grammar backend
401
414
  if self.grammar_backend is None:
402
415
  self.grammar_backend = "xgrammar"
@@ -428,12 +441,6 @@ class ServerArgs:
428
441
  f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
429
442
  )
430
443
 
431
- if self.pp_size > 1:
432
- self.disable_overlap_schedule = True
433
- logger.warning(
434
- "Pipeline parallelism is incompatible with overlap schedule."
435
- )
436
-
437
444
  if self.enable_eplb and (self.expert_distribution_recorder_mode is None):
438
445
  self.expert_distribution_recorder_mode = "stat"
439
446
  logger.info(
@@ -459,6 +466,13 @@ class ServerArgs:
459
466
  elif self.expert_distribution_recorder_mode is not None:
460
467
  self.expert_distribution_recorder_buffer_size = 1000
461
468
 
469
+ # Pipeline parallelism
470
+ if self.pp_size > 1:
471
+ self.disable_overlap_schedule = True
472
+ logger.warning(
473
+ "Pipeline parallelism is incompatible with overlap schedule."
474
+ )
475
+
462
476
  # Speculative Decoding
463
477
  if self.speculative_algorithm == "NEXTN":
464
478
  # NEXTN shares the same implementation of EAGLE
@@ -479,16 +493,23 @@ class ServerArgs:
479
493
  "eagle speculative decoding."
480
494
  )
481
495
 
482
- model_arch = get_model_arch(self)
483
-
484
- # Auto set draft_model_path DeepSeek-V3/R1
496
+ model_arch = self.get_hf_config().architectures[0]
485
497
  if model_arch == "DeepseekV3ForCausalLM":
498
+ # Auto set draft_model_path DeepSeek-V3/R1
486
499
  if self.speculative_draft_model_path is None:
487
500
  self.speculative_draft_model_path = self.model_path
488
501
  else:
489
502
  logger.warning(
490
503
  "DeepSeek MTP does not require setting speculative_draft_model_path."
491
504
  )
505
+ elif "Llama4" in model_arch:
506
+ # TODO: remove this after Llama4 supports in other backends
507
+ if self.attention_backend != "fa3":
508
+ self.attention_backend = "fa3"
509
+ logger.warning(
510
+ "Llama4 requires using fa3 attention backend. "
511
+ "Attention backend is automatically set to fa3."
512
+ )
492
513
 
493
514
  # Auto choose parameters
494
515
  if self.speculative_num_steps is None:
@@ -562,17 +583,9 @@ class ServerArgs:
562
583
  if self.custom_weight_loader is None:
563
584
  self.custom_weight_loader = []
564
585
 
565
- def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
566
- larger_tp = max(decode_tp, prefill_tp)
567
- smaller_tp = min(decode_tp, prefill_tp)
568
- assert larger_tp % smaller_tp == 0, (
569
- "Different tp size is supported only when one tp is multiple of the other. "
570
- f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
571
- )
572
-
573
586
  @staticmethod
574
587
  def add_cli_args(parser: argparse.ArgumentParser):
575
- # Model and port args
588
+ # Model and tokenizer
576
589
  parser.add_argument(
577
590
  "--model-path",
578
591
  "--model",
@@ -586,24 +599,6 @@ class ServerArgs:
586
599
  default=ServerArgs.tokenizer_path,
587
600
  help="The path of the tokenizer.",
588
601
  )
589
- parser.add_argument(
590
- "--host",
591
- type=str,
592
- default=ServerArgs.host,
593
- help="The host of the HTTP server.",
594
- )
595
- parser.add_argument(
596
- "--port",
597
- type=int,
598
- default=ServerArgs.port,
599
- help="The port of the HTTP server.",
600
- )
601
- parser.add_argument(
602
- "--nccl-port",
603
- type=int,
604
- default=ServerArgs.nccl_port,
605
- help="The port for NCCL distributed environment setup. Defaults to a random port.",
606
- )
607
602
  parser.add_argument(
608
603
  "--tokenizer-mode",
609
604
  type=str,
@@ -618,11 +613,6 @@ class ServerArgs:
618
613
  action="store_true",
619
614
  help="If set, skip init tokenizer and pass input_ids in generate request.",
620
615
  )
621
- parser.add_argument(
622
- "--skip-server-warmup",
623
- action="store_true",
624
- help="If set, skip warmup.",
625
- )
626
616
  parser.add_argument(
627
617
  "--load-format",
628
618
  type=str,
@@ -668,6 +658,77 @@ class ServerArgs:
668
658
  action="store_true",
669
659
  help="Whether or not to allow for custom models defined on the Hub in their own modeling files.",
670
660
  )
661
+ parser.add_argument(
662
+ "--context-length",
663
+ type=int,
664
+ default=ServerArgs.context_length,
665
+ help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
666
+ )
667
+ parser.add_argument(
668
+ "--is-embedding",
669
+ action="store_true",
670
+ help="Whether to use a CausalLM as an embedding model.",
671
+ )
672
+ parser.add_argument(
673
+ "--enable-multimodal",
674
+ default=ServerArgs.enable_multimodal,
675
+ action="store_true",
676
+ help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
677
+ )
678
+ parser.add_argument(
679
+ "--revision",
680
+ type=str,
681
+ default=None,
682
+ help="The specific model version to use. It can be a branch "
683
+ "name, a tag name, or a commit id. If unspecified, will use "
684
+ "the default version.",
685
+ )
686
+ parser.add_argument(
687
+ "--model-impl",
688
+ type=str,
689
+ default=ServerArgs.model_impl,
690
+ help="Which implementation of the model to use.\n\n"
691
+ '* "auto" will try to use the SGLang implementation if it exists '
692
+ "and fall back to the Transformers implementation if no SGLang "
693
+ "implementation is available.\n"
694
+ '* "sglang" will use the SGLang model implementation.\n'
695
+ '* "transformers" will use the Transformers model '
696
+ "implementation.\n",
697
+ )
698
+
699
+ # HTTP server
700
+ parser.add_argument(
701
+ "--host",
702
+ type=str,
703
+ default=ServerArgs.host,
704
+ help="The host of the HTTP server.",
705
+ )
706
+ parser.add_argument(
707
+ "--port",
708
+ type=int,
709
+ default=ServerArgs.port,
710
+ help="The port of the HTTP server.",
711
+ )
712
+ parser.add_argument(
713
+ "--skip-server-warmup",
714
+ action="store_true",
715
+ help="If set, skip warmup.",
716
+ )
717
+ parser.add_argument(
718
+ "--warmups",
719
+ type=str,
720
+ required=False,
721
+ help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
722
+ "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
723
+ )
724
+ parser.add_argument(
725
+ "--nccl-port",
726
+ type=int,
727
+ default=ServerArgs.nccl_port,
728
+ help="The port for NCCL distributed environment setup. Defaults to a random port.",
729
+ )
730
+
731
+ # Quantization and data type
671
732
  parser.add_argument(
672
733
  "--dtype",
673
734
  type=str,
@@ -682,13 +743,6 @@ class ServerArgs:
682
743
  '* "float" is shorthand for FP32 precision.\n'
683
744
  '* "float32" for FP32 precision.',
684
745
  )
685
- parser.add_argument(
686
- "--kv-cache-dtype",
687
- type=str,
688
- default=ServerArgs.kv_cache_dtype,
689
- choices=["auto", "fp8_e5m2", "fp8_e4m3"],
690
- help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
691
- )
692
746
  parser.add_argument(
693
747
  "--quantization",
694
748
  type=str,
@@ -704,6 +758,7 @@ class ServerArgs:
704
758
  "gguf",
705
759
  "modelopt",
706
760
  "modelopt_fp4",
761
+ "petit_nvfp4",
707
762
  "w8a8_int8",
708
763
  "w8a8_fp8",
709
764
  "moe_wna16",
@@ -722,65 +777,11 @@ class ServerArgs:
722
777
  "default to 1.0, which may cause accuracy issues. ",
723
778
  )
724
779
  parser.add_argument(
725
- "--context-length",
726
- type=int,
727
- default=ServerArgs.context_length,
728
- help="The model's maximum context length. Defaults to None (will use the value from the model's config.json instead).",
729
- )
730
- parser.add_argument(
731
- "--device",
732
- type=str,
733
- default=ServerArgs.device,
734
- help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
735
- )
736
- parser.add_argument(
737
- "--served-model-name",
738
- type=str,
739
- default=ServerArgs.served_model_name,
740
- help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
741
- )
742
- parser.add_argument(
743
- "--chat-template",
744
- type=str,
745
- default=ServerArgs.chat_template,
746
- help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
747
- )
748
- parser.add_argument(
749
- "--completion-template",
750
- type=str,
751
- default=ServerArgs.completion_template,
752
- help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
753
- )
754
- parser.add_argument(
755
- "--is-embedding",
756
- action="store_true",
757
- help="Whether to use a CausalLM as an embedding model.",
758
- )
759
- parser.add_argument(
760
- "--enable-multimodal",
761
- default=ServerArgs.enable_multimodal,
762
- action="store_true",
763
- help="Enable the multimodal functionality for the served model. If the model being served is not multimodal, nothing will happen",
764
- )
765
- parser.add_argument(
766
- "--revision",
767
- type=str,
768
- default=None,
769
- help="The specific model version to use. It can be a branch "
770
- "name, a tag name, or a commit id. If unspecified, will use "
771
- "the default version.",
772
- )
773
- parser.add_argument(
774
- "--impl",
780
+ "--kv-cache-dtype",
775
781
  type=str,
776
- default=ServerArgs.impl,
777
- help="Which implementation of the model to use.\n\n"
778
- '* "auto" will try to use the SGLang implementation if it exists '
779
- "and fall back to the Transformers implementation if no SGLang "
780
- "implementation is available.\n"
781
- '* "sglang" will use the SGLang model implementation.\n'
782
- '* "transformers" will use the Transformers model '
783
- "implementation.\n",
782
+ default=ServerArgs.kv_cache_dtype,
783
+ choices=["auto", "fp8_e5m2", "fp8_e4m3"],
784
+ help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
784
785
  )
785
786
 
786
787
  # Memory and scheduling
@@ -852,8 +853,26 @@ class ServerArgs:
852
853
  "(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
853
854
  ),
854
855
  )
856
+ parser.add_argument(
857
+ "--swa-full-tokens-ratio",
858
+ type=float,
859
+ default=ServerArgs.swa_full_tokens_ratio,
860
+ help="The ratio of SWA layer KV tokens / full layer KV tokens, regardless of the number of swa:full layers. It should be between 0 and 1. "
861
+ "E.g. 0.5 means if each swa layer has 50 tokens, then each full layer has 100 tokens.",
862
+ )
863
+ parser.add_argument(
864
+ "--disable-hybrid-swa-memory",
865
+ action="store_true",
866
+ help="Disable the hybrid SWA memory.",
867
+ )
855
868
 
856
- # Other runtime options
869
+ # Runtime options
870
+ parser.add_argument(
871
+ "--device",
872
+ type=str,
873
+ default=ServerArgs.device,
874
+ help="The device to use ('cuda', 'xpu', 'hpu', 'npu', 'cpu'). Defaults to auto-detection if not specified.",
875
+ )
857
876
  parser.add_argument(
858
877
  "--tensor-parallel-size",
859
878
  "--tp-size",
@@ -895,7 +914,7 @@ class ServerArgs:
895
914
  "--constrained-json-whitespace-pattern",
896
915
  type=str,
897
916
  default=ServerArgs.constrained_json_whitespace_pattern,
898
- help=r"Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
917
+ help="(outlines backend only) Regex pattern for syntactic whitespaces allowed in JSON constrained output. For example, to allow the model generate consecutive whitespaces, set the pattern to [\n\t ]*",
899
918
  )
900
919
  parser.add_argument(
901
920
  "--watchdog-timeout",
@@ -974,6 +993,13 @@ class ServerArgs:
974
993
  action="store_true",
975
994
  help="Enable log prometheus metrics.",
976
995
  )
996
+ parser.add_argument(
997
+ "--enable-metrics-for-all-schedulers",
998
+ action="store_true",
999
+ help="Enable --enable-metrics-for-all-schedulers when you want schedulers on all TP ranks (not just TP 0) "
1000
+ "to record request metrics separately. This is especially useful when dp_attention is enabled, as "
1001
+ "otherwise all metrics appear to come from TP 0.",
1002
+ )
977
1003
  parser.add_argument(
978
1004
  "--bucket-time-to-first-token",
979
1005
  type=float,
@@ -1001,12 +1027,6 @@ class ServerArgs:
1001
1027
  default=ServerArgs.collect_tokens_histogram,
1002
1028
  help="Collect prompt/generation tokens histogram.",
1003
1029
  )
1004
- parser.add_argument(
1005
- "--kv-events-config",
1006
- type=str,
1007
- default=None,
1008
- help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
1009
- )
1010
1030
  parser.add_argument(
1011
1031
  "--decode-log-interval",
1012
1032
  type=int,
@@ -1019,6 +1039,12 @@ class ServerArgs:
1019
1039
  default=ServerArgs.enable_request_time_stats_logging,
1020
1040
  help="Enable per request time stats logging",
1021
1041
  )
1042
+ parser.add_argument(
1043
+ "--kv-events-config",
1044
+ type=str,
1045
+ default=None,
1046
+ help="Config in json format for NVIDIA dynamo KV event publishing. Publishing will be enabled if this flag is used.",
1047
+ )
1022
1048
 
1023
1049
  # API related
1024
1050
  parser.add_argument(
@@ -1027,6 +1053,24 @@ class ServerArgs:
1027
1053
  default=ServerArgs.api_key,
1028
1054
  help="Set API key of the server. It is also used in the OpenAI API compatible server.",
1029
1055
  )
1056
+ parser.add_argument(
1057
+ "--served-model-name",
1058
+ type=str,
1059
+ default=ServerArgs.served_model_name,
1060
+ help="Override the model name returned by the v1/models endpoint in OpenAI API server.",
1061
+ )
1062
+ parser.add_argument(
1063
+ "--chat-template",
1064
+ type=str,
1065
+ default=ServerArgs.chat_template,
1066
+ help="The builtin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
1067
+ )
1068
+ parser.add_argument(
1069
+ "--completion-template",
1070
+ type=str,
1071
+ default=ServerArgs.completion_template,
1072
+ help="The builtin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. Currently only used for code completion.",
1073
+ )
1030
1074
  parser.add_argument(
1031
1075
  "--file-storage-path",
1032
1076
  type=str,
@@ -1055,6 +1099,7 @@ class ServerArgs:
1055
1099
  "deepseekv3",
1056
1100
  "pythonic",
1057
1101
  "kimi_k2",
1102
+ "qwen3",
1058
1103
  ],
1059
1104
  default=ServerArgs.tool_call_parser,
1060
1105
  help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', 'llama3', 'deepseekv3', 'pythonic', 'kimi_k2', and 'qwen3'.",
@@ -1107,6 +1152,28 @@ class ServerArgs:
1107
1152
  )
1108
1153
 
1109
1154
  # LoRA
1155
+ parser.add_argument(
1156
+ "--enable-lora",
1157
+ default=ServerArgs.enable_lora,
1158
+ action="store_true",
1159
+ help="Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.",
1160
+ )
1161
+ parser.add_argument(
1162
+ "--max-lora-rank",
1163
+ default=ServerArgs.max_lora_rank,
1164
+ type=int,
1165
+ help="The maximum rank of LoRA adapters. If not specified, it will be automatically inferred from the adapters provided in --lora-paths.",
1166
+ )
1167
+ parser.add_argument(
1168
+ "--lora-target-modules",
1169
+ type=str,
1170
+ choices=SUPPORTED_LORA_TARGET_MODULES + [LORA_TARGET_ALL_MODULES],
1171
+ nargs="*",
1172
+ default=None,
1173
+ help="The union set of all target modules where LoRA should be applied. If not specified, "
1174
+ "it will be automatically inferred from the adapters provided in --lora-paths. If 'all' is specified, "
1175
+ "all supported modules will be targeted.",
1176
+ )
1110
1177
  parser.add_argument(
1111
1178
  "--lora-paths",
1112
1179
  type=str,
@@ -1323,6 +1390,46 @@ class ServerArgs:
1323
1390
  help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
1324
1391
  )
1325
1392
 
1393
+ # Hierarchical cache
1394
+ parser.add_argument(
1395
+ "--enable-hierarchical-cache",
1396
+ action="store_true",
1397
+ help="Enable hierarchical cache",
1398
+ )
1399
+ parser.add_argument(
1400
+ "--hicache-ratio",
1401
+ type=float,
1402
+ default=ServerArgs.hicache_ratio,
1403
+ help="The ratio of the size of host KV cache memory pool to the size of device pool.",
1404
+ )
1405
+ parser.add_argument(
1406
+ "--hicache-size",
1407
+ type=int,
1408
+ default=ServerArgs.hicache_size,
1409
+ help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
1410
+ )
1411
+ parser.add_argument(
1412
+ "--hicache-write-policy",
1413
+ type=str,
1414
+ choices=["write_back", "write_through", "write_through_selective"],
1415
+ default=ServerArgs.hicache_write_policy,
1416
+ help="The write policy of hierarchical cache.",
1417
+ )
1418
+ parser.add_argument(
1419
+ "--hicache-io-backend",
1420
+ type=str,
1421
+ choices=["direct", "kernel"],
1422
+ default=ServerArgs.hicache_io_backend,
1423
+ help="The IO backend for KV cache transfer between CPU and GPU",
1424
+ )
1425
+ parser.add_argument(
1426
+ "--hicache-storage-backend",
1427
+ type=str,
1428
+ choices=["file"], # todo, mooncake
1429
+ default=ServerArgs.hicache_storage_backend,
1430
+ help="The storage backend for hierarchical KV cache.",
1431
+ )
1432
+
1326
1433
  # Double Sparsity
1327
1434
  parser.add_argument(
1328
1435
  "--enable-double-sparsity",
@@ -1515,37 +1622,6 @@ class ServerArgs:
1515
1622
  action="store_true",
1516
1623
  help="Enable users to pass custom logit processors to the server (disabled by default for security)",
1517
1624
  )
1518
- parser.add_argument(
1519
- "--enable-hierarchical-cache",
1520
- action="store_true",
1521
- help="Enable hierarchical cache",
1522
- )
1523
- parser.add_argument(
1524
- "--hicache-ratio",
1525
- type=float,
1526
- default=ServerArgs.hicache_ratio,
1527
- help="The ratio of the size of host KV cache memory pool to the size of device pool.",
1528
- )
1529
- parser.add_argument(
1530
- "--hicache-size",
1531
- type=int,
1532
- default=ServerArgs.hicache_size,
1533
- help="The size of host KV cache memory pool in gigabytes, which will override the hicache_ratio if set.",
1534
- )
1535
- parser.add_argument(
1536
- "--hicache-write-policy",
1537
- type=str,
1538
- choices=["write_back", "write_through", "write_through_selective"],
1539
- default=ServerArgs.hicache_write_policy,
1540
- help="The write policy of hierarchical cache.",
1541
- )
1542
- parser.add_argument(
1543
- "--hicache-io-backend",
1544
- type=str,
1545
- choices=["direct", "kernel"],
1546
- default=ServerArgs.hicache_io_backend,
1547
- help="The IO backend for KV cache transfer between CPU and GPU",
1548
- )
1549
1625
  parser.add_argument(
1550
1626
  "--flashinfer-mla-disable-ragged",
1551
1627
  action="store_true",
@@ -1576,13 +1652,6 @@ class ServerArgs:
1576
1652
  action="store_true",
1577
1653
  help="Use triton moe grouped gemm kernel.",
1578
1654
  )
1579
- parser.add_argument(
1580
- "--warmups",
1581
- type=str,
1582
- required=False,
1583
- help="Specify custom warmup functions (csv) to run before server starts eg. --warmups=warmup_name1,warmup_name2 "
1584
- "will run the functions `warmup_name1` and `warmup_name2` specified in warmup.py before the server starts listening for requests",
1585
- )
1586
1655
 
1587
1656
  # Debug tensor dumps
1588
1657
  parser.add_argument(
@@ -1609,7 +1678,7 @@ class ServerArgs:
1609
1678
  help="Only dump the tensors for prefill requests (i.e. batch size > 1).",
1610
1679
  )
1611
1680
 
1612
- # Disaggregation
1681
+ # PD disaggregation
1613
1682
  parser.add_argument(
1614
1683
  "--disaggregation-mode",
1615
1684
  type=str,
@@ -1668,6 +1737,8 @@ class ServerArgs:
1668
1737
  default=None,
1669
1738
  help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
1670
1739
  )
1740
+
1741
+ # Custom weight loader
1671
1742
  parser.add_argument(
1672
1743
  "--custom-weight-loader",
1673
1744
  type=str,
@@ -1675,6 +1746,19 @@ class ServerArgs:
1675
1746
  default=None,
1676
1747
  help="The custom data loader used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
1677
1748
  )
1749
+ parser.add_argument(
1750
+ "--enable-pdmux",
1751
+ action="store_true",
1752
+ help="Enable PD-Multiplexing, PD running on greenctx stream.",
1753
+ )
1754
+
1755
+ # For PD-Multiplexing
1756
+ parser.add_argument(
1757
+ "--sm-group-num",
1758
+ type=int,
1759
+ default=ServerArgs.sm_group_num,
1760
+ help="Number of sm partition groups.",
1761
+ )
1678
1762
  parser.add_argument(
1679
1763
  "--weight-loader-disable-mmap",
1680
1764
  action="store_true",
@@ -1696,6 +1780,17 @@ class ServerArgs:
1696
1780
  else:
1697
1781
  return f"http://{self.host}:{self.port}"
1698
1782
 
1783
+ def get_hf_config(self):
1784
+ kwargs = {}
1785
+ hf_config = get_config(
1786
+ self.model_path,
1787
+ trust_remote_code=self.trust_remote_code,
1788
+ revision=self.revision,
1789
+ model_override_args=json.loads(self.json_model_override_args),
1790
+ **kwargs,
1791
+ )
1792
+ return hf_config
1793
+
1699
1794
  def check_server_args(self):
1700
1795
  assert (
1701
1796
  self.tp_size * self.pp_size
@@ -1720,15 +1815,101 @@ class ServerArgs:
1720
1815
  assert self.base_gpu_id >= 0, "base_gpu_id must be non-negative"
1721
1816
  assert self.gpu_id_step >= 1, "gpu_id_step must be positive"
1722
1817
 
1723
- if isinstance(self.lora_paths, list):
1724
- lora_paths = self.lora_paths
1725
- self.lora_paths = {}
1726
- for lora_path in lora_paths:
1727
- if "=" in lora_path:
1728
- name, path = lora_path.split("=", 1)
1729
- self.lora_paths[name] = path
1730
- else:
1731
- self.lora_paths[lora_path] = lora_path
1818
+ assert self.moe_dense_tp_size in {
1819
+ 1,
1820
+ None,
1821
+ }, "moe_dense_tp_size only supports 1 and None currently"
1822
+
1823
+ self.check_lora_server_args()
1824
+
1825
+ def check_lora_server_args(self):
1826
+ # Enable LoRA if any LoRA paths are provided for backward compatibility.
1827
+ if self.lora_paths:
1828
+ if self.enable_lora is None:
1829
+ self.enable_lora = True
1830
+ logger.info(
1831
+ "--enable-lora is set to True because --lora-paths is provided."
1832
+ )
1833
+ elif self.enable_lora is False:
1834
+ logger.warning(
1835
+ "--enable-lora is set to False, any provided lora_paths will be ignored."
1836
+ )
1837
+
1838
+ if self.enable_lora:
1839
+ # Normalize lora_paths to a dictionary if it is a list.
1840
+ if isinstance(self.lora_paths, list):
1841
+ lora_paths = self.lora_paths
1842
+ self.lora_paths = {}
1843
+ for lora_path in lora_paths:
1844
+ if "=" in lora_path:
1845
+ name, path = lora_path.split("=", 1)
1846
+ self.lora_paths[name] = path
1847
+ else:
1848
+ self.lora_paths[lora_path] = lora_path
1849
+
1850
+ # Expand target modules
1851
+ if self.lora_target_modules:
1852
+ self.lora_target_modules = set(self.lora_target_modules)
1853
+ if "all" in self.lora_target_modules:
1854
+ assert (
1855
+ len(self.lora_target_modules) == 1
1856
+ ), "If 'all' is specified in --lora-target-modules, it should be the only module specified."
1857
+ self.lora_target_modules = set(SUPPORTED_LORA_TARGET_MODULES)
1858
+
1859
+ # Ensure sufficient information is provided for LoRA initialization.
1860
+ assert self.lora_paths or (
1861
+ self.max_lora_rank and self.lora_target_modules
1862
+ ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
1863
+
1864
+ def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
1865
+ larger_tp = max(decode_tp, prefill_tp)
1866
+ smaller_tp = min(decode_tp, prefill_tp)
1867
+ assert larger_tp % smaller_tp == 0, (
1868
+ "Different tp size is supported only when one tp is multiple of the other. "
1869
+ f"decode_tp={decode_tp}, prefill_tp={prefill_tp}"
1870
+ )
1871
+
1872
+ def adjust_mem_fraction_for_vlm(self, model_config):
1873
+ vision_config = getattr(model_config.hf_config, "vision_config", None)
1874
+ if vision_config is None:
1875
+ return
1876
+
1877
+ # Roughly reduce mem_fraction_static based on the parameter count of the ViT
1878
+ original_server_arg_mem_fraction = self.mem_fraction_static
1879
+ # Base mem_fraction_static reduction factor for a regular ViT
1880
+ base_mem_fraction_reduction_ratio = 0.95
1881
+
1882
+ vit_num_layers = getattr(vision_config, "num_hidden_layers", 24)
1883
+ vit_hidden_size = getattr(vision_config, "hidden_size", 1024)
1884
+
1885
+ # baseline ViT params (ViT-L/14)
1886
+ baseline_vit_layers = 24
1887
+ baseline_vit_hidden_size = 1024
1888
+
1889
+ # weight params count
1890
+ current_complexity_score = vit_num_layers * (vit_hidden_size**2)
1891
+ baseline_complexity_score = baseline_vit_layers * (baseline_vit_hidden_size**2)
1892
+ complexity_ratio = (
1893
+ current_complexity_score / baseline_complexity_score
1894
+ if baseline_complexity_score > 0
1895
+ else 1.0
1896
+ )
1897
+
1898
+ # Each time the complexity grows by 100%, adjust the final factor by 10%
1899
+ sensitivity_scale = 0.1
1900
+ dynamic_adjustment_factor = 1.0 - sensitivity_scale * (complexity_ratio - 1.0)
1901
+ dynamic_adjustment_factor = max(0.8, min(1.05, dynamic_adjustment_factor))
1902
+
1903
+ final_overall_factor = (
1904
+ base_mem_fraction_reduction_ratio * dynamic_adjustment_factor
1905
+ )
1906
+ self.mem_fraction_static = (
1907
+ original_server_arg_mem_fraction * final_overall_factor
1908
+ )
1909
+ logger.warning(
1910
+ f"Multimodal model: Dynamically adjusted --mem-fraction-static "
1911
+ f"from: {original_server_arg_mem_fraction:.3f} to: {self.mem_fraction_static:.3f}."
1912
+ )
1732
1913
 
1733
1914
 
1734
1915
  def prepare_server_args(argv: List[str]) -> ServerArgs:
@@ -1773,16 +1954,16 @@ class PortArgs:
1773
1954
  @staticmethod
1774
1955
  def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
1775
1956
  if server_args.nccl_port is None:
1776
- port = server_args.port + random.randint(100, 1000)
1957
+ nccl_port = server_args.port + random.randint(100, 1000)
1777
1958
  while True:
1778
- if is_port_available(port):
1959
+ if is_port_available(nccl_port):
1779
1960
  break
1780
- if port < 60000:
1781
- port += 42
1961
+ if nccl_port < 60000:
1962
+ nccl_port += 42
1782
1963
  else:
1783
- port -= 43
1964
+ nccl_port -= 43
1784
1965
  else:
1785
- port = server_args.nccl_port
1966
+ nccl_port = server_args.nccl_port
1786
1967
 
1787
1968
  if not server_args.enable_dp_attention:
1788
1969
  # Normal case, use IPC within a single node
@@ -1790,7 +1971,7 @@ class PortArgs:
1790
1971
  tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
1791
1972
  scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
1792
1973
  detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
1793
- nccl_port=port,
1974
+ nccl_port=nccl_port,
1794
1975
  rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
1795
1976
  metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
1796
1977
  )
@@ -1820,7 +2001,7 @@ class PortArgs:
1820
2001
  tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
1821
2002
  scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
1822
2003
  detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
1823
- nccl_port=port,
2004
+ nccl_port=nccl_port,
1824
2005
  rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
1825
2006
  metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}",
1826
2007
  )
@@ -1847,31 +2028,13 @@ class DeprecatedAction(argparse.Action):
1847
2028
  raise ValueError(self.help)
1848
2029
 
1849
2030
 
1850
- def get_model_arch(args: ServerArgs):
1851
- hf_config = get_config(
1852
- args.model_path,
1853
- trust_remote_code=args.trust_remote_code,
1854
- revision=args.revision,
1855
- model_override_args=json.loads(args.json_model_override_args),
1856
- )
1857
- return hf_config.architectures[0]
1858
-
1859
-
1860
2031
  def auto_choose_speculative_params(self: ServerArgs):
1861
2032
  """
1862
2033
  Automatically choose the parameters for speculative decoding.
1863
2034
 
1864
2035
  You can tune them on your own models and prompts with scripts/playground/bench_speculative.py
1865
2036
  """
1866
- kwargs = {}
1867
-
1868
- hf_config = get_config(
1869
- self.model_path,
1870
- trust_remote_code=self.trust_remote_code,
1871
- revision=self.revision,
1872
- model_override_args=json.loads(self.json_model_override_args),
1873
- **kwargs,
1874
- )
2037
+ hf_config = self.get_hf_config()
1875
2038
  arch = hf_config.architectures[0]
1876
2039
 
1877
2040
  if arch in ["LlamaForCausalLM"]: