sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +9 -7
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +1 -0
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +48 -43
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +7 -2
  20. sglang/srt/disaggregation/fake/conn.py +1 -1
  21. sglang/srt/disaggregation/mooncake/conn.py +227 -120
  22. sglang/srt/disaggregation/nixl/conn.py +1 -0
  23. sglang/srt/disaggregation/prefill.py +7 -4
  24. sglang/srt/disaggregation/utils.py +7 -1
  25. sglang/srt/entrypoints/engine.py +17 -2
  26. sglang/srt/entrypoints/http_server.py +17 -2
  27. sglang/srt/function_call_parser.py +2 -2
  28. sglang/srt/layers/attention/flashattention_backend.py +1 -1
  29. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  30. sglang/srt/layers/attention/utils.py +4 -2
  31. sglang/srt/layers/dp_attention.py +71 -21
  32. sglang/srt/layers/layernorm.py +1 -1
  33. sglang/srt/layers/logits_processor.py +46 -11
  34. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  35. sglang/srt/layers/moe/ep_moe/layer.py +1 -1
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  37. sglang/srt/layers/moe/topk.py +1 -1
  38. sglang/srt/layers/quantization/__init__.py +1 -1
  39. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  40. sglang/srt/layers/quantization/deep_gemm.py +72 -71
  41. sglang/srt/layers/quantization/fp8.py +2 -2
  42. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  43. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  44. sglang/srt/layers/sampler.py +0 -4
  45. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  46. sglang/srt/lora/lora_manager.py +1 -1
  47. sglang/srt/lora/mem_pool.py +4 -4
  48. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  49. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  50. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  51. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  52. sglang/srt/lora/utils.py +1 -1
  53. sglang/srt/managers/data_parallel_controller.py +3 -3
  54. sglang/srt/managers/detokenizer_manager.py +21 -8
  55. sglang/srt/managers/io_struct.py +3 -1
  56. sglang/srt/managers/mm_utils.py +1 -1
  57. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  58. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  59. sglang/srt/managers/schedule_batch.py +76 -24
  60. sglang/srt/managers/schedule_policy.py +0 -3
  61. sglang/srt/managers/scheduler.py +113 -88
  62. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  63. sglang/srt/managers/tokenizer_manager.py +133 -34
  64. sglang/srt/managers/tp_worker.py +12 -9
  65. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  66. sglang/srt/mem_cache/memory_pool.py +2 -0
  67. sglang/srt/metrics/collector.py +312 -37
  68. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  69. sglang/srt/model_executor/forward_batch_info.py +1 -1
  70. sglang/srt/model_executor/model_runner.py +19 -14
  71. sglang/srt/models/deepseek_janus_pro.py +2 -2
  72. sglang/srt/models/deepseek_v2.py +23 -20
  73. sglang/srt/models/llama.py +2 -0
  74. sglang/srt/models/llama4.py +5 -6
  75. sglang/srt/models/llava.py +248 -5
  76. sglang/srt/models/mixtral.py +98 -34
  77. sglang/srt/models/pixtral.py +467 -0
  78. sglang/srt/models/roberta.py +1 -1
  79. sglang/srt/models/torch_native_llama.py +1 -1
  80. sglang/srt/openai_api/adapter.py +30 -4
  81. sglang/srt/openai_api/protocol.py +0 -8
  82. sglang/srt/reasoning_parser.py +3 -3
  83. sglang/srt/sampling/custom_logit_processor.py +18 -3
  84. sglang/srt/sampling/sampling_batch_info.py +4 -56
  85. sglang/srt/sampling/sampling_params.py +2 -2
  86. sglang/srt/server_args.py +34 -4
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +7 -7
  89. sglang/srt/speculative/eagle_worker.py +22 -19
  90. sglang/srt/utils.py +6 -5
  91. sglang/test/few_shot_gsm8k.py +2 -2
  92. sglang/test/few_shot_gsm8k_engine.py +2 -2
  93. sglang/test/run_eval.py +2 -2
  94. sglang/test/runners.py +8 -1
  95. sglang/test/send_one.py +13 -3
  96. sglang/test/simple_eval_common.py +1 -1
  97. sglang/test/simple_eval_humaneval.py +1 -1
  98. sglang/test/test_programs.py +5 -5
  99. sglang/test/test_utils.py +89 -14
  100. sglang/utils.py +1 -1
  101. sglang/version.py +1 -1
  102. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
  103. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
  104. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  105. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
  106. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/deep_gemm.py CHANGED
@@ -15,12 +15,9 @@ _ENABLE_JIT_DEEPGEMM = False
 if is_cuda():
     import deep_gemm
     from deep_gemm import get_num_sms
+    from deep_gemm.jit.compiler import get_nvcc_compiler
     from deep_gemm.jit_kernels.gemm import get_best_configs
-    from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
-    from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
-    from deep_gemm.jit_kernels.m_grouped_gemm import (
-        template as deep_gemm_grouped_gemm_template,
-    )
+    from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
     from deep_gemm.jit_kernels.tuner import jit_tuner
 
     sm_version = get_device_sm()
@@ -45,10 +42,25 @@ _COMPILE_WORKERS = get_int_env_var("SGL_JIT_DEEPGEMM_COMPILE_WORKERS", 4)
 _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
 
 # Force redirect deep_gemm cache_dir
-os.environ["DG_CACHE_DIR"] = os.getenv(
-    "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
+os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
+    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )
 
+# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
+# NVRTC may have performance loss with some cases.
+# And NVCC JIT speed is also 9x faster in the ref commit
+_USE_NVRTC_DEFAULT = "0"
+if _ENABLE_JIT_DEEPGEMM:
+    try:
+        get_nvcc_compiler()
+    except:
+        logger.warning(
+            "NVCC Compiler not found, use NVRTC for DeepGEMM JIT "
+            "and may have performance loss with some cases."
+        )
+        _USE_NVRTC_DEFAULT = "1"
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", _USE_NVRTC_DEFAULT)
+
 
 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     global _BUILTIN_M_LIST
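
The two SGL_* environment variables read in the hunk above are consumed once at import time and forwarded to DeepGEMM's own settings. As a rough, non-authoritative sketch (not part of the package), they could be set before sglang is imported; the variable names come from the diff, while the concrete path and value below are placeholders:

import os

# Set these before importing sglang so the deep_gemm module picks them up at import time.
os.environ["SGL_DG_CACHE_DIR"] = "/tmp/deep_gemm_cache"  # forwarded into DG_JIT_CACHE_DIR
os.environ["SGL_DG_USE_NVRTC"] = "0"  # "0": prefer NVCC when available; "1": force NVRTC

import sglang  # noqa: E402  (imported only after the environment is prepared)
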
@@ -103,10 +115,10 @@ _INITIALIZATION_DICT: Dict[Tuple[DeepGemmKernelType, int, int, int], bool] = dic
 def _compile_warning_1():
     if not _IN_PRECOMPILE_STAGE and _IS_FIRST_RANK_ON_NODE:
         logger.warning(
-            "Entering DeepGEMM JIT Pre-Complie session. "
+            "Entering DeepGEMM JIT Pre-Compile session. "
             "And it may takes a long time(Typically 10-20 mins) "
             "if you have not run `sglang.compile_deep_gemm`. "
-            "Recommand to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
+            "It is recommended to run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
             " for pre-compilation to reduce the overhead if you have not run it before. "
             "For example: "
             "`python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code`"
@@ -115,7 +127,7 @@ def _compile_warning_1():
 
 def _compile_warning_2():
     logger.warning(
-        "Entering DeepGEMM JIT Single Kernel Complie session. "
+        "Entering DeepGEMM JIT Single Kernel Compile session. "
        "And it will makes inference throughput becomes flaky. "
        "Please run `sglang.compile_deep_gemm` with same args as `sglang.launch_server`"
        " for pre-compilation to solve this issue. "
@@ -130,10 +142,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    # Auto-tuning with compilation
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -146,24 +166,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedMasked",
+            "GEMM_TYPE": GemmType.GroupedMasked,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
 
 
@@ -173,9 +180,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -188,25 +204,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedContiguous",
+            "GEMM_TYPE": GemmType.GroupedContiguous,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("num_groups", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
 
 
@@ -216,9 +218,20 @@ def _compile_gemm_nt_f8f8bf16_one(
     _: int,  # _ is a dummy parameter to align with other interfaces
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "GEMM_TYPE": GemmType.Normal,
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "NUM_GROUPS": 1,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -232,20 +245,8 @@ def _compile_gemm_nt_f8f8bf16_one(
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
 
 
@@ -298,7 +299,7 @@ def _maybe_compile_deep_gemm_one_type_all(
         logger.info(
             f"Try DeepGEMM JIT Compiling for "
             f"<{kernel_helper.name}> N={n}, K={k}, num_groups={num_groups} with all Ms."
-            f"{' It only takes a litte time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
+            f"{' It only takes a little time (typically 1 sec) if you have run `python3 -m sglang.compile_deep_gemm`. ' if not _IN_PRECOMPILE_STAGE else ''}"
         )
 
         # NOTE(alcanderian): get_num_sms should be change when 2-batch-overlap is introduced
@@ -373,7 +374,7 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
 
     from deep_gemm.jit.runtime import RuntimeCache
 
-    origin_func = RuntimeCache.__getitem__
+    origin_func = RuntimeCache.get
 
     def __patched_func(self, *args, **kwargs):
         ret = origin_func(self, *args, **kwargs)
@@ -385,6 +386,6 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
         )
         return ret
 
-    RuntimeCache.__getitem__ = __patched_func
+    RuntimeCache.get = __patched_func
     yield
-    RuntimeCache.__getitem__ = origin_func
+    RuntimeCache.get = origin_func
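
For readers unfamiliar with the pattern, the `_log_jit_build` hunks above swap the patched attribute from `RuntimeCache.__getitem__` to `RuntimeCache.get`. The self-contained sketch below reproduces the same patch-and-restore idea with a stand-in class; `DemoCache` and `log_cache_miss` are illustrative names, not part of sglang or DeepGEMM:

from contextlib import contextmanager


class DemoCache:
    """Stand-in for a runtime cache; returns None on a miss, like dict.get."""

    def __init__(self):
        self._store = {}

    def get(self, key):
        return self._store.get(key)


@contextmanager
def log_cache_miss(cache_cls):
    origin_func = cache_cls.get

    def patched(self, *args, **kwargs):
        ret = origin_func(self, *args, **kwargs)
        if ret is None:
            print(f"Cache miss for {args}: a JIT build would start here.")
        return ret

    cache_cls.get = patched
    try:
        yield
    finally:
        cache_cls.get = origin_func  # always restore the original method


with log_cache_miss(DemoCache):
    DemoCache().get(("gemm_fp8_fp8_bf16_nt", 4096, 7168))
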

sglang/srt/layers/quantization/fp8.py CHANGED
@@ -235,7 +235,7 @@ class Fp8LinearMethod(LinearMethodBase):
                     f"{input_size_per_partition} is not divisible by "
                     f"weight quantization block_k = {block_k}."
                 )
-            # Required by collum parallel or enabling merged weights
+            # Required by column parallel or enabling merged weights
             if (
                 tp_size > 1 and output_size // output_size_per_partition == tp_size
             ) or len(output_partition_sizes) > 1:
@@ -491,7 +491,7 @@ class Fp8MoEMethod:
                 self.quant_config.weight_block_size[1],
             )
             # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by collum parallel or enabling merged weights
+            # Required by column parallel or enabling merged weights
             if intermediate_size % block_n != 0:
                 raise ValueError(
                     f"The output_size of gate's and up's weight = "

sglang/srt/layers/quantization/fp8_kernel.py CHANGED
@@ -104,7 +104,7 @@ def _per_token_group_quant_fp8(
     y_s_ptr,
     # Stride of input
     y_stride,
-    # Collums of input
+    # Columns of input
     N,
     # Avoid to divide zero
     eps,
@@ -342,7 +342,7 @@ def _static_quant_fp8(
     y_s_repeat_ptr,
     # Stride of input
     y_stride,
-    # Collums of input
+    # Columns of input
     N,
     # Information for float8
     fp8_min,
@@ -794,7 +794,7 @@ def w8a8_block_fp8_matmul(
         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
         # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": block_size[0],

sglang/srt/layers/quantization/int8_kernel.py CHANGED
@@ -76,7 +76,7 @@ def _per_token_group_quant_int8(
     y_s_ptr,
     # Stride of input
     y_stride,
-    # Collums of input
+    # Columns of input
     N,
     # Avoid to divide zero
     eps,
@@ -370,7 +370,7 @@ def w8a8_block_int8_matmul(
         config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
         # Default config
-        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": block_size[0],

sglang/srt/layers/sampler.py CHANGED
@@ -239,10 +239,6 @@ def top_p_normalize_probs_torch(
 
 
 def get_top_logprobs(logprobs: torch.Tensor, top_logprobs_nums: List[int]):
-    assert len(top_logprobs_nums) == logprobs.shape[0], (
-        len(top_logprobs_nums),
-        logprobs.shape[0],
-    )
     max_k = max(top_logprobs_nums)
     ret = logprobs.topk(max_k, dim=1)
     values = ret.values.tolist()

sglang/srt/layers/vocab_parallel_embedding.py CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -214,12 +215,14 @@ class VocabParallelEmbedding(torch.nn.Module):
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
         enable_tp: bool = True,
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__()
@@ -227,9 +230,14 @@ class VocabParallelEmbedding(torch.nn.Module):
 
         self.enable_tp = enable_tp
         if self.enable_tp:
-            tp_rank = get_tensor_model_parallel_rank()
-            self.tp_size = get_tensor_model_parallel_world_size()
+            if use_attn_tp_group:
+                tp_rank = get_attention_tp_rank()
+                self.tp_size = get_attention_tp_size()
+            else:
+                tp_rank = get_tensor_model_parallel_rank()
+                self.tp_size = get_tensor_model_parallel_world_size()
         else:
+            assert use_attn_tp_group is False
             tp_rank = 0
             self.tp_size = 1
 
@@ -519,22 +527,25 @@ class ParallelLMHead(VocabParallelEmbedding):
         self,
         num_embeddings: int,
         embedding_dim: int,
+        *,
         bias: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         org_num_embeddings: Optional[int] = None,
         padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        use_attn_tp_group: bool = False,
         use_presharded_weights: bool = False,
     ):
         super().__init__(
             num_embeddings,
             embedding_dim,
-            params_dtype,
-            org_num_embeddings,
-            padding_size,
-            quant_config,
-            prefix,
+            params_dtype=params_dtype,
+            org_num_embeddings=org_num_embeddings,
+            padding_size=padding_size,
+            quant_config=quant_config,
+            prefix=prefix,
+            use_attn_tp_group=use_attn_tp_group,
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
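
Since everything after `embedding_dim` in `VocabParallelEmbedding.__init__` is now keyword-only (the bare `*` added above), `ParallelLMHead` forwards its arguments by name. The stand-in classes below only illustrate that calling convention; they are simplified placeholders, not the real sglang layers, and the sizes are arbitrary:

class VocabParallelEmbeddingDemo:
    def __init__(self, num_embeddings, embedding_dim, *, params_dtype=None,
                 use_attn_tp_group=False, use_presharded_weights=False):
        # Everything after the bare `*` must be passed by keyword.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.use_attn_tp_group = use_attn_tp_group


class ParallelLMHeadDemo(VocabParallelEmbeddingDemo):
    def __init__(self, num_embeddings, embedding_dim, *, bias=False,
                 params_dtype=None, use_attn_tp_group=False,
                 use_presharded_weights=False):
        # Mirrors the diff: keyword arguments are forwarded explicitly to the parent.
        super().__init__(
            num_embeddings,
            embedding_dim,
            params_dtype=params_dtype,
            use_attn_tp_group=use_attn_tp_group,
            use_presharded_weights=use_presharded_weights,
        )


head = ParallelLMHeadDemo(32000, 4096, use_attn_tp_group=True)  # OK
# ParallelLMHeadDemo(32000, 4096, None, True) would raise TypeError:
# positional arguments are no longer accepted past embedding_dim.
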

sglang/srt/lora/lora_manager.py CHANGED
@@ -100,7 +100,7 @@ class LoRAManager:
             self.configs[name] = LoRAConfig(path)
             self.hf_target_names.update(self.configs[name].target_modules)
 
-        # Target lora weight names for lora_a and lora_b modules repectively.
+        # Target lora weight names for lora_a and lora_b modules respectively.
         # e.g., {("qkv_proj", "q_proj"), ("qkv_proj", "kv_proj")}
         self.lora_weight_names: Set[Tuple[str]] = set(
             [get_stacked_name(module) for module in self.hf_target_names]

sglang/srt/lora/mem_pool.py CHANGED
@@ -50,15 +50,15 @@ class LoRAMemoryPool:
         self.uid_to_buffer_id: Dict[Optional[str], int] = {}
 
         # Buffer idx -> lora uid in memory pool
-        # All uids are initalized as empty strings for empty buffer slots
-        # Here we don't initalize to None since None is a valid uid
+        # All uids are initialized as empty strings for empty buffer slots
+        # Here we don't initialize to None since None is a valid uid
         self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
 
     def get_lora_A_shape(
         self, module_name: str, base_model: torch.nn.Module
     ) -> Tuple[int]:
         """
-        Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+        Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         input_dim, _ = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)
@@ -75,7 +75,7 @@ class LoRAMemoryPool:
         self, module_name: str, base_model: torch.nn.Module
     ) -> Tuple[int]:
         """
-        Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+        Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
         c = get_stacked_multiply(module_name)

sglang/srt/lora/triton_ops/gate_up_lora_b.py CHANGED
@@ -77,7 +77,7 @@ def _gate_up_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    # Iteate to compute the block in output matrix
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(

sglang/srt/lora/triton_ops/qkv_lora_b.py CHANGED
@@ -79,7 +79,7 @@ def _qkv_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    # Iteate to compute the block in output matrix
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(

sglang/srt/lora/triton_ops/sgemm_lora_a.py CHANGED
@@ -67,7 +67,7 @@ def _sgemm_lora_a_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    # Iteate to compute the block in output matrix
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(

sglang/srt/lora/triton_ops/sgemm_lora_b.py CHANGED
@@ -69,7 +69,7 @@ def _sgemm_lora_b_kernel(
         k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1
     )
 
-    # Iteate to compute the block in output matrix
+    # Iterate to compute the block in output matrix
     partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_K)):
         x_tile = tl.load(

sglang/srt/lora/utils.py CHANGED
@@ -79,7 +79,7 @@ def get_hidden_dim(
     module_name: str, config: AutoConfig, base_model: torch.nn.Module
 ) -> Tuple[int]:
     """
-    Given a module_name (might be a stacked name), return the hidden dims of modules's input and output.
+    Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
     """
 
     if hasattr(base_model, "get_hidden_dim"):

sglang/srt/managers/data_parallel_controller.py CHANGED
@@ -17,13 +17,13 @@ import logging
 import multiprocessing as mp
 import signal
 import threading
+import time
 from enum import Enum, auto
 
 import psutil
 import setproctitle
 import zmq
 
-from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
@@ -158,7 +158,7 @@ class DataParallelController:
         # This thread cannot be closed because otherwise the `kill_itself_when_parent_died`
         # function in scheduler.py will kill the scheduler.
         while True:
-            pass
+            time.sleep(30 * 24 * 3600)
 
     def launch_dp_attention_schedulers(self, server_args, port_args):
         self.launch_tensor_parallel_group(server_args, port_args, 0, None)
@@ -210,7 +210,7 @@ class DataParallelController:
             )
             # compute zmq ports for this dp rank
             rank_port_args = PortArgs.init_new(server_args, dp_rank)
-            # Data parallelism resues the tensor parallelism group,
+            # Data parallelism reuses the tensor parallelism group,
             # so all dp ranks should use the same nccl port.
             rank_port_args.nccl_port = port_args.nccl_port
 

sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -28,6 +28,7 @@ from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import (
     BatchEmbeddingOut,
     BatchMultimodalDecodeReq,
+    BatchMultimodalOut,
     BatchStrOut,
     BatchTokenIDOut,
 )
@@ -60,6 +61,8 @@ class DecodeStatus:
     decode_ids: List[int]
     surr_offset: int
     read_offset: int
+    # Offset that's sent to tokenizer for incremental update.
+    sent_offset: int = 0
 
 
 class DetokenizerManager:
@@ -151,7 +154,7 @@ class DetokenizerManager:
                 self.decode_status[rid] = s
             else:
                 s = self.decode_status[rid]
-                s.decode_ids = recv_obj.decode_ids[i]
+                s.decode_ids.extend(recv_obj.decode_ids[i])
 
             read_ids.append(
                 self.trim_matched_stop(
@@ -199,13 +202,15 @@ class DetokenizerManager:
             else:
                 new_text = find_printable_text(new_text)
 
-            output_strs.append(
-                self.trim_matched_stop(
-                    s.decoded_text + new_text,
-                    recv_obj.finished_reasons[i],
-                    recv_obj.no_stop_trim[i],
-                )
+            output_str = self.trim_matched_stop(
+                s.decoded_text + new_text,
+                recv_obj.finished_reasons[i],
+                recv_obj.no_stop_trim[i],
             )
+            # Incrementally send text.
+            incremental_output = output_str[s.sent_offset :]
+            s.sent_offset = len(output_str)
+            output_strs.append(incremental_output)
 
         return BatchStrOut(
             rids=recv_obj.rids,
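
The hunk above makes the detokenizer ship only the text generated since the previous batch, tracked per request via `sent_offset`. A toy version of that slicing logic follows; the names below are illustrative stand-ins, not the real `DecodeStatus` or `DetokenizerManager`:

class DecodeStatusDemo:
    """Stand-in for the per-request DecodeStatus with the new sent_offset field."""

    def __init__(self):
        self.sent_offset = 0


def emit_incremental(status: DecodeStatusDemo, full_output_str: str) -> str:
    # Ship only the suffix produced since the previous batch, then advance the offset.
    incremental = full_output_str[status.sent_offset:]
    status.sent_offset = len(full_output_str)
    return incremental


s = DecodeStatusDemo()
assert emit_incremental(s, "Hello") == "Hello"
assert emit_incremental(s, "Hello, wor") == ", wor"
assert emit_incremental(s, "Hello, world") == "ld"
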
@@ -232,7 +237,15 @@ class DetokenizerManager:
             )
 
     def handle_multimodal_decode_req(self, recv_obj: BatchMultimodalDecodeReq):
-        raise NotImplementedError()
+        outputs = self.tokenizer.detokenize(recv_obj)
+        return BatchMultimodalOut(
+            rids=recv_obj.rids,
+            finished_reasons=recv_obj.finished_reasons,
+            outputs=outputs,
+            prompt_tokens=recv_obj.prompt_tokens,
+            completion_tokens=recv_obj.completion_tokens,
+            cached_tokens=recv_obj.cached_tokens,
+        )
 
 
 class LimitedCapacityDict(OrderedDict):

sglang/srt/managers/io_struct.py CHANGED
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 """
-The definition of objects transfered between different
+The definition of objects transferred between different
 processes (TokenizerManager, DetokenizerManager, Controller).
 """
 
@@ -836,6 +836,8 @@ class ProfileReqInput:
     # the caller doesn't need to run stop_profile.
     num_steps: Optional[int] = None
    activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None
+    with_stack: Optional[bool] = None
+    record_shapes: Optional[bool] = None
 
 
 class ProfileReqType(Enum):

sglang/srt/managers/mm_utils.py CHANGED
@@ -51,7 +51,7 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
         self, input_ids: List[int], mm_inputs: MultimodalInputs
     ) -> List[int]:
         """
-        This function will replace the data-tokens inbetween with pad_values accordingly
+        This function will replace the data-tokens in between with pad_values accordingly
         """
         pad_values = [item.pad_value for item in mm_inputs.mm_items]
         data_token_pairs = self.data_token_id_pairs