sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sglang/bench_serving.py +56 -12
  2. sglang/launch_server.py +2 -0
  3. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
  4. sglang/srt/compilation/backend.py +1 -1
  5. sglang/srt/configs/model_config.py +5 -5
  6. sglang/srt/distributed/parallel_state.py +0 -7
  7. sglang/srt/entrypoints/engine.py +18 -15
  8. sglang/srt/entrypoints/grpc_server.py +0 -1
  9. sglang/srt/entrypoints/http_server.py +75 -94
  10. sglang/srt/environ.py +16 -2
  11. sglang/srt/eplb/expert_distribution.py +30 -0
  12. sglang/srt/function_call/function_call_parser.py +2 -0
  13. sglang/srt/function_call/minimax_m2.py +367 -0
  14. sglang/srt/layers/activation.py +6 -0
  15. sglang/srt/layers/attention/flashattention_backend.py +12 -2
  16. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  17. sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
  18. sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
  19. sglang/srt/layers/attention/utils.py +78 -0
  20. sglang/srt/layers/communicator.py +1 -0
  21. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  22. sglang/srt/layers/layernorm.py +19 -4
  23. sglang/srt/layers/logits_processor.py +5 -0
  24. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  25. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  26. sglang/srt/layers/moe/ep_moe/layer.py +79 -272
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  29. sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
  30. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  31. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  32. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  33. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  34. sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
  35. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  36. sglang/srt/layers/moe/topk.py +4 -4
  37. sglang/srt/layers/moe/utils.py +3 -4
  38. sglang/srt/layers/quantization/__init__.py +3 -5
  39. sglang/srt/layers/quantization/awq.py +0 -3
  40. sglang/srt/layers/quantization/base_config.py +7 -0
  41. sglang/srt/layers/quantization/fp8.py +68 -63
  42. sglang/srt/layers/quantization/gguf.py +566 -0
  43. sglang/srt/layers/quantization/mxfp4.py +30 -38
  44. sglang/srt/layers/quantization/unquant.py +23 -45
  45. sglang/srt/layers/quantization/w4afp8.py +38 -2
  46. sglang/srt/layers/radix_attention.py +5 -2
  47. sglang/srt/layers/rotary_embedding.py +13 -1
  48. sglang/srt/layers/sampler.py +12 -1
  49. sglang/srt/managers/io_struct.py +3 -0
  50. sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
  51. sglang/srt/managers/scheduler.py +21 -15
  52. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  53. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  54. sglang/srt/managers/tokenizer_manager.py +11 -19
  55. sglang/srt/mem_cache/hicache_storage.py +7 -1
  56. sglang/srt/mem_cache/memory_pool.py +82 -0
  57. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  58. sglang/srt/model_executor/forward_batch_info.py +44 -3
  59. sglang/srt/model_executor/model_runner.py +1 -149
  60. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  61. sglang/srt/models/deepseek_v2.py +147 -44
  62. sglang/srt/models/glm4_moe.py +322 -354
  63. sglang/srt/models/glm4_moe_nextn.py +4 -14
  64. sglang/srt/models/glm4v_moe.py +29 -196
  65. sglang/srt/models/minimax_m2.py +922 -0
  66. sglang/srt/models/nvila.py +355 -0
  67. sglang/srt/models/nvila_lite.py +184 -0
  68. sglang/srt/models/qwen2.py +22 -1
  69. sglang/srt/models/qwen3.py +34 -4
  70. sglang/srt/models/qwen3_moe.py +2 -4
  71. sglang/srt/multimodal/processors/base_processor.py +1 -0
  72. sglang/srt/multimodal/processors/glm4v.py +1 -1
  73. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  74. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  75. sglang/srt/parser/reasoning_parser.py +28 -1
  76. sglang/srt/server_args.py +365 -186
  77. sglang/srt/single_batch_overlap.py +2 -7
  78. sglang/srt/utils/common.py +87 -42
  79. sglang/srt/utils/hf_transformers_utils.py +7 -3
  80. sglang/test/test_deterministic.py +235 -12
  81. sglang/test/test_deterministic_utils.py +2 -1
  82. sglang/version.py +1 -1
  83. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
  84. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
  85. sglang/srt/models/vila.py +0 -306
  86. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  87. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  88. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py CHANGED
@@ -88,6 +88,7 @@ class RequestFuncOutput:
     latency: float = 0.0
     ttft: float = 0.0  # Time to first token
     itl: List[float] = field(default_factory=list)  # List of inter-token latencies
+    text_chunks: List[str] = field(default_factory=list)
     prompt_len: int = 0
     error: str = ""
     output_len: int = 0
@@ -258,6 +259,9 @@ async def async_request_openai_completions(
 
                     # Decoding phase
                     else:
+                        output.text_chunks.append(
+                            data["choices"][0]["text"]
+                        )
                         output.itl.append(timestamp - most_recent_timestamp)
 
                     most_recent_timestamp = timestamp
@@ -574,9 +578,8 @@ async def async_request_sglang_generate(
                         num_new_tokens = output_len - last_output_len
                         if num_new_tokens == 0:
                             continue
-                        adjust_itl = (
-                            timestamp - most_recent_timestamp
-                        ) / num_new_tokens
+                        chunk_gap = timestamp - most_recent_timestamp
+                        adjust_itl = chunk_gap / num_new_tokens
                         output.itl.extend([adjust_itl] * num_new_tokens)
 
                         most_recent_timestamp = timestamp
@@ -764,6 +767,7 @@ def get_dataset(args, tokenizer, model_id=None):
             image_content=args.image_content,
             image_format=args.image_format,
             image_resolution=args.image_resolution,
+            backend=args.backend,
         )
     elif args.dataset_name == "generated-shared-prefix":
         assert not tokenize_prompt
@@ -781,6 +785,7 @@ def get_dataset(args, tokenizer, model_id=None):
         input_requests = sample_mmmu_requests(
             num_requests=args.num_prompts,
             processor=processor,
+            backend=args.backend,
             fixed_output_len=args.random_output_len,
             random_sample=True,
         )
@@ -1009,6 +1014,7 @@ async def get_mooncake_request_over_time(
 def sample_mmmu_requests(
     num_requests: int,
     processor: AutoProcessor | AutoTokenizer,
+    backend: str,
     fixed_output_len: Optional[int] = None,
     random_sample: bool = True,
 ) -> List[DatasetRow]:
@@ -1081,7 +1087,7 @@ def sample_mmmu_requests(
         text_prompt = f"Question: {question}\n\nAnswer: "
         output_len = fixed_output_len if fixed_output_len is not None else 256
         data_row = create_mm_data_row(
-            text_prompt, [image], [image_data], output_len, processor
+            text_prompt, [image], [image_data], output_len, processor, backend
         )
         filtered_dataset.append(data_row)
 
@@ -1316,13 +1322,19 @@ def parse_image_resolution(image_resolution: str) -> Tuple[int, int]:
     )
 
 
-def create_mm_data_row(text_prompt, images: list, images_base64, output_len, processor):
+def create_mm_data_row(
+    text_prompt, images: list, images_base64, output_len, processor, backend
+):
     try:
-        content_items = [
-            {"type": "image", "image": {"url": image_base64}}
-            for image_base64 in images_base64
-        ]
-        content_items.append({"type": "text", "text": text_prompt})
+        if type(processor).__name__ == "Phi4MMProcessor":
+            # <|endoftext10|> is the image token used in the phi-4-multimodal model.
+            content_items = text_prompt.replace("image 1", "|endoftext10|")
+        else:
+            content_items = [
+                {"type": "image", "image": {"url": image_base64}}
+                for image_base64 in images_base64
+            ]
+            content_items.append({"type": "text", "text": text_prompt})
         prompt_str = processor.apply_chat_template(
             [{"role": "user", "content": content_items}],
             add_generation_prompt=True,
@@ -1362,8 +1374,16 @@ def create_mm_data_row(text_prompt, images: list, images_base64, output_len, pro
         # Vision tokens = total tokens - text tokens
         vision_prompt_len = prompt_len - text_prompt_len
 
+    use_raw_prompt = backend in [
+        "sglang-oai",
+        "sglang-oai-chat",
+        "vllm",
+        "vllm-chat",
+        "lmdeploy",
+        "lmdeploy-chat",
+    ]
     return DatasetRow(
-        prompt=text_prompt,
+        prompt=text_prompt if use_raw_prompt else prompt_str,
         prompt_len=prompt_len,
         output_len=output_len,
         text_prompt_len=text_prompt_len,
@@ -1382,6 +1402,7 @@ def sample_image_requests(
     image_content: str,
     image_format: str,
     image_resolution: str,
+    backend: str,
 ) -> List[DatasetRow]:
     """Generate requests with images.
 
@@ -1447,6 +1468,7 @@ def sample_image_requests(
             list(images_base64),
             int(output_lens[i]),
             processor,
+            backend,
         )
 
         dataset.append(data_row)
@@ -1607,6 +1629,7 @@ def calculate_metrics(
     dur_s: float,
     tokenizer: PreTrainedTokenizerBase,
     backend: str,
+    accept_length: Optional[float] = None,
 ) -> Tuple[BenchmarkMetrics, List[int]]:
     output_lens: List[int] = []
     retokenized_output_lens: List[int] = []
@@ -1618,6 +1641,14 @@ def calculate_metrics(
     tpots: List[float] = []
     ttfts: List[float] = []
     e2e_latencies: List[float] = []
+    retokenized_itls: List[float] = []
+
+    use_retokenized_itl = (
+        accept_length is not None
+        and accept_length > 0
+        and backend in ("sglang-oai", "sglang-oai-chat")
+    )
+
     for i in range(len(outputs)):
         if outputs[i].success:
             output_len = outputs[i].output_len
@@ -1631,7 +1662,17 @@ def calculate_metrics(
             total_input_vision += input_requests[i].vision_prompt_len
             if output_len > 1:
                 tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
-            itls += outputs[i].itl
+            if use_retokenized_itl:
+                for k, itl in enumerate(outputs[i].itl):
+                    num_tokens = len(
+                        tokenizer.encode(
+                            outputs[i].text_chunks[k], add_special_tokens=False
+                        )
+                    )
+                    adjusted_itl = itl / num_tokens
+                    retokenized_itls.extend([adjusted_itl] * num_tokens)
+            else:
+                itls += outputs[i].itl
             ttfts.append(outputs[i].ttft)
 
             e2e_latencies.append(outputs[i].latency)
@@ -1647,6 +1688,8 @@ def calculate_metrics(
             "on the benchmark arguments.",
             stacklevel=2,
         )
+
+    itls = retokenized_itls if use_retokenized_itl else itls
     metrics = BenchmarkMetrics(
         completed=completed,
         total_input=total_input,
@@ -1910,6 +1953,7 @@ async def benchmark(
         dur_s=benchmark_duration,
         tokenizer=tokenizer,
         backend=backend,
+        accept_length=accept_length,
     )
 
     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
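
Note on the retokenized-ITL path added above: when an accept length is reported (speculative decoding) and the backend is sglang-oai or sglang-oai-chat, calculate_metrics now re-encodes each streamed text chunk and spreads that chunk's latency evenly over its token count instead of using the raw per-chunk ITLs. A minimal standalone sketch of that adjustment, assuming a Hugging Face tokenizer; the function name and inputs here are illustrative, not part of bench_serving.py:

```python
# Hedged sketch of the retokenized inter-token-latency adjustment.
# `chunks` are the streamed text pieces, `gaps` the wall-clock time between chunks.
from transformers import AutoTokenizer

def retokenize_itls(chunks, gaps, tokenizer):
    itls = []
    for text, gap in zip(chunks, gaps):
        num_tokens = len(tokenizer.encode(text, add_special_tokens=False))
        if num_tokens == 0:
            continue  # guard added for this sketch; the benchmark assumes non-empty chunks
        itls.extend([gap / num_tokens] * num_tokens)
    return itls

tokenizer = AutoTokenizer.from_pretrained("gpt2")
# e.g. a chunk that re-tokenizes into 3 tokens over 30 ms yields three 10 ms ITL samples
print(retokenize_itls(["The quick brown", " fox"], [0.030, 0.012], tokenizer))
```
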
sglang/launch_server.py CHANGED
@@ -12,10 +12,12 @@ if __name__ == "__main__":
 
     try:
         if server_args.grpc_mode:
+            # Handle gRPC server
             from sglang.srt.entrypoints.grpc_server import serve_grpc
 
             asyncio.run(serve_grpc(server_args))
         else:
+            # Handle HTTP server
             from sglang.srt.entrypoints.http_server import launch_server
 
             launch_server(server_args)
sglang/srt/batch_invariant_ops/batch_invariant_ops.py CHANGED
@@ -9,6 +9,22 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.layers.deep_gemm_wrapper.configurer import ENABLE_JIT_DEEPGEMM
+from sglang.srt.utils.common import calc_diff, get_bool_env_var
+
+if ENABLE_JIT_DEEPGEMM:
+    import deep_gemm
+
+_ENABLE_MM_DEEPGEMM = get_bool_env_var(
+    "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_DEEPGEMM", "1"
+)
+_ENABLE_MM_COMPARISON_TEST = get_bool_env_var(
+    "SGLANG_BATCH_INVARIANT_OPS_ENABLE_MM_COMPARISON_TEST"
+)
+
+if not _ENABLE_MM_DEEPGEMM:
+    print("Disable DeepGEMM in batch invariant ops. Performance may be suboptimal.")
+
 __all__ = [
     "set_batch_invariant_mode",
     "is_batch_invariant_mode_enabled",
@@ -140,7 +156,7 @@ def matmul_kernel_persistent(
     tl.store(c_ptrs, c, mask=c_mask)
 
 
-def matmul_persistent(
+def _matmul_persistent_triton(
     a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
 ):
     # Check constraints.
@@ -217,6 +233,54 @@
     return c
 
 
+def _matmul_persistent_deepgemm(
+    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
+):
+    M, K = a.shape
+    K, N = b.shape
+    dtype = a.dtype
+    out = torch.empty((M, N), device=a.device, dtype=dtype)
+
+    deep_gemm.bf16_gemm_nn(a, b, out)
+
+    # TODO can this be put in DeepGEMM's `c`?
+    if bias is not None:
+        out += bias
+
+    return out
+
+
+def matmul_persistent(
+    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
+):
+    if (
+        _ENABLE_MM_DEEPGEMM
+        and ENABLE_JIT_DEEPGEMM
+        and (a.dtype == torch.bfloat16)
+        and (b.dtype == torch.bfloat16)
+        and a.is_contiguous()
+        and b.transpose(0, 1).is_contiguous()
+    ):
+        if _ENABLE_MM_COMPARISON_TEST:
+            out_triton = _matmul_persistent_triton(a=a, b=b, bias=bias)
+            out_deepgemm = _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
+            diff = calc_diff(out_triton, out_deepgemm)
+            assert diff < 0.0001, f"{diff=} {out_triton=} {out_deepgemm=}"
+            # can be enabled for debugging
+            # print(
+            #     f"{diff=} "
+            #     f"{(out_triton - out_deepgemm).abs().mean()=} "
+            #     f"{(out_triton - out_deepgemm).abs().sum()=} "
+            #     f"{torch.sum(out_triton != out_deepgemm)=} "
+            # )
+            # print(f"{a=} {b=} {bias=} {out_triton=} {out_deepgemm=}")
+            return out_deepgemm
+
+        return _matmul_persistent_deepgemm(a=a, b=b, bias=bias)
+
+    return _matmul_persistent_triton(a=a, b=b, bias=bias)
+
+
 @triton.jit
 def _log_softmax_kernel(
     input_ptr,
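
For context, the matmul_persistent wrapper above only routes to DeepGEMM when both operands are bf16, `a` is row-major, and `b` is column-major (its transpose is contiguous); everything else falls back to the Triton kernel. A small sketch of that eligibility check, with illustrative tensor shapes:

```python
# Sketch of the dtype/layout gate used by matmul_persistent above.
import torch

a = torch.randn(128, 256, dtype=torch.bfloat16)       # row-major
b = torch.randn(512, 256, dtype=torch.bfloat16).t()   # (256, 512) with column-major storage

deepgemm_eligible = (
    a.dtype == torch.bfloat16
    and b.dtype == torch.bfloat16
    and a.is_contiguous()
    and b.transpose(0, 1).is_contiguous()
)
print(deepgemm_eligible)  # True: this layout would be routed to the DeepGEMM path when enabled
```
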
@@ -495,16 +559,39 @@ def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None =
     return torch.sum(input, dim=dim, keepdim=keepdim, dtype=torch.float32) / n_elems
 
 
+def bmm_batch_invariant(a, b, *, out=None):
+    # Batched matrix multiply: (B, M, K) x (B, K, N) -> (B, M, N)
+    # Process each batch separately with our persistent kernel
+    if a.ndim == 3 and b.ndim == 3:
+        results = []
+        for i in range(a.shape[0]):
+            results.append(matmul_persistent(a[i], b[i]))
+        result = torch.stack(results, dim=0)
+
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+    else:
+        raise ValueError(
+            f"bmm_batch_invariant expects 3D tensors, "
+            f"got shapes {a.shape} and {b.shape}"
+        )
+
+
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None
+_original_torch_bmm = None
 
 
 def is_batch_invariant_mode_enabled():
     return _batch_invariant_MODE
 
 
-def enable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+def enable_batch_invariant_mode(
+    enable_bmm: bool = True,
+):
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     if _batch_invariant_MODE:
         return
 
@@ -517,11 +604,21 @@
     )
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, "CUDA")
 
+    if enable_bmm:
+        _batch_invariant_LIB.impl("aten::bmm", bmm_batch_invariant, "CUDA")
+
+        # Also monkeypatch torch.bmm directly as a fallback
+        _original_torch_bmm = torch.bmm
+        torch.bmm = bmm_batch_invariant
+
 
 def disable_batch_invariant_mode():
-    global _batch_invariant_MODE, _batch_invariant_LIB
+    global _batch_invariant_MODE, _batch_invariant_LIB, _original_torch_bmm
     if _batch_invariant_LIB is not None:
         _batch_invariant_LIB._destroy()
+    if _original_torch_bmm is not None:
+        torch.bmm = _original_torch_bmm
+        _original_torch_bmm = None
     _batch_invariant_MODE = False
     _batch_invariant_LIB = None
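
The enable/disable pair above both registers a CUDA override for aten::bmm through torch.library and monkeypatches torch.bmm directly, saving the original so disable can restore it. A minimal sketch of that save/patch/restore pattern, using a stand-in implementation rather than the sglang kernel:

```python
# Sketch of the torch.bmm save/override/restore pattern used above.
import torch

_original_torch_bmm = None

def _per_batch_bmm(a, b, *, out=None):
    # Stand-in for bmm_batch_invariant: one 2-D matmul per batch element.
    result = torch.stack([a[i] @ b[i] for i in range(a.shape[0])], dim=0)
    if out is not None:
        out.copy_(result)
        return out
    return result

def enable():
    global _original_torch_bmm
    _original_torch_bmm = torch.bmm
    torch.bmm = _per_batch_bmm

def disable():
    global _original_torch_bmm
    if _original_torch_bmm is not None:
        torch.bmm = _original_torch_bmm
        _original_torch_bmm = None

enable()
x, y = torch.randn(2, 3, 4), torch.randn(2, 4, 5)
assert torch.allclose(torch.bmm(x, y), x @ y, atol=1e-6)
disable()
assert torch.bmm is not _per_batch_bmm  # original restored
```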
 
sglang/srt/compilation/backend.py CHANGED
@@ -392,7 +392,7 @@ class SGLangBackend:
         self.configure_post_pass()
 
         self.split_gm, self.piecewise_graphs = split_graph(
-            graph, ["sglang.unified_attention_with_output"]
+            graph, ["sglang.unified_attention_with_output", "sglang.inplace_all_reduce"]
         )
 
         from torch._dynamo.utils import lazy_format_graph_code
sglang/srt/configs/model_config.py CHANGED
@@ -535,7 +535,7 @@ class ModelConfig:
         quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
         return quant_cfg
 
-    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> dict:
+    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> Optional[dict]:
         """Parse ModelOpt quantization config and return the appropriate quant_method."""
         json_quant_configs = quant_config_dict["quantization"]
         quant_algo = json_quant_configs.get("quant_algo", None)
@@ -547,8 +547,7 @@ class ModelConfig:
         elif quant_algo and "FP8" in quant_algo:
             return {"quant_method": "modelopt_fp8"}
         else:
-            # Default to FP8 for backward compatibility
-            return {"quant_method": "modelopt_fp8"}
+            return None
 
     def _is_already_quantized(self) -> bool:
         """Check if the model is already quantized based on config files."""
@@ -806,7 +805,7 @@
 ) -> torch.dtype:
     # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
     # because config.torch_dtype can be None.
-    config_dtype = getattr(config, "torch_dtype", None)
+    config_dtype = getattr(config, "dtype", None)
     if isinstance(config_dtype, str):
         config_dtype = _STR_DTYPE_TO_TORCH_DTYPE.get(config_dtype, None)
     if config_dtype is None:
@@ -915,12 +914,13 @@ multimodal_model_archs = [
     "InternVLChatModel",
     "InternS1ForConditionalGeneration",
     "Phi4MMForCausalLM",
-    "VILAForConditionalGeneration",
     "Step3VLForConditionalGeneration",
     "POINTSV15ChatModel",
     "DotsVLMForCausalLM",
     "DotsOCRForCausalLM",
     "Sarashina2VisionForCausalLM",
+    "NVILAForConditionalGeneration",
+    "NVILALiteForConditionalGeneration",
     "DeepseekOCRForCausalLM",
 ]
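
The net effect of the _parse_modelopt_quant_config change above is that an unrecognized quant_algo now yields None instead of silently defaulting to modelopt_fp8. A hedged sketch of the resulting dispatch; the FP4 branch is inferred from surrounding code that is not shown in the hunk:

```python
# Sketch only: mirrors the quant_algo dispatch after this change.
from typing import Optional

def parse_modelopt_quant(quant_config_dict: dict) -> Optional[dict]:
    quant_algo = quant_config_dict.get("quantization", {}).get("quant_algo")
    if quant_algo and "FP4" in quant_algo:  # assumed earlier branch, not in the hunk
        return {"quant_method": "modelopt_fp4"}
    elif quant_algo and "FP8" in quant_algo:
        return {"quant_method": "modelopt_fp8"}
    else:
        return None  # previously fell back to modelopt_fp8

print(parse_modelopt_quant({"quantization": {"quant_algo": "FP8"}}))   # modelopt_fp8
print(parse_modelopt_quant({"quantization": {"quant_algo": "INT8"}}))  # None
```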
 
sglang/srt/distributed/parallel_state.py CHANGED
@@ -340,17 +340,10 @@ class GroupCoordinator:
         self.qr_comm: Optional[QuickAllReduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
-            if torch_compile is not None and torch_compile:
-                # For piecewise CUDA graph, the requirement for custom allreduce is larger to
-                # avoid illegal cuda memory access.
-                ca_max_size = 256 * 1024 * 1024
-            else:
-                ca_max_size = 8 * 1024 * 1024
             try:
                 self.ca_comm = CustomAllreduce(
                     group=self.cpu_group,
                     device=self.device,
-                    max_size=ca_max_size,
                 )
             except Exception as e:
                 logger.warning(
sglang/srt/entrypoints/engine.py CHANGED
@@ -101,7 +101,7 @@ class Engine(EngineBase):
 
     Note:
     1. The HTTP server, Engine, and TokenizerManager all run in the main process.
-    2. Inter-process communication (IPC) is handled via the ZMQ library, with each process using a different port.
+    2. Inter-process communication is done through IPC (each process uses a different port) via the ZMQ library.
     """
 
     def __init__(self, **kwargs):
@@ -109,6 +109,8 @@ class Engine(EngineBase):
         The arguments of this function is the same as `sglang/srt/server_args.py::ServerArgs`.
         Please refer to `ServerArgs` for the documentation.
         """
+
+        # Parse server_args
         if "server_args" in kwargs:
             # Directly load server_args
             server_args = kwargs["server_args"]
@@ -118,29 +120,28 @@ class Engine(EngineBase):
             # Do not print logs by default
             kwargs["log_level"] = "error"
             server_args = ServerArgs(**kwargs)
+        self.server_args = server_args
+        logger.info(f"{server_args=}")
 
         # Shutdown the subprocesses automatically when the program exits
         atexit.register(self.shutdown)
 
-        # Allocate ports for inter-process communications
-        self.port_args = PortArgs.init_new(server_args)
-        logger.info(f"{server_args=}")
-
         # Launch subprocesses
-        tokenizer_manager, template_manager, scheduler_info = _launch_subprocesses(
-            server_args=server_args,
-            port_args=self.port_args,
+        tokenizer_manager, template_manager, scheduler_info, port_args = (
+            _launch_subprocesses(server_args=server_args)
         )
-        self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.template_manager = template_manager
         self.scheduler_info = scheduler_info
+        self.port_args = port_args
 
+        # Initialize ZMQ sockets
         context = zmq.Context(2)
         self.send_to_rpc = get_zmq_socket(
             context, zmq.DEALER, self.port_args.rpc_ipc_name, True
         )
 
+        # Enable tracing
         if server_args.enable_trace:
             process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
             if server_args.disaggregation_mode == "null":
@@ -672,15 +673,17 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
     if not server_args.enable_symm_mem:
         os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
-    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "8"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
-    # flashinfer uses this environment variable for various kernels from MoE to quant kernels
+
     if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
+        # flashinfer uses this environment variable for various kernels from MoE to quant kernels
         os.environ["TRTLLM_ENABLE_PDL"] = "1"
 
     if os.environ.get("CUTE_DSL_LOG_LEVEL") is None:
         # Default to warning level, to avoid too many logs
         os.environ["CUTE_DSL_LOG_LEVEL"] = "30"
+
     if os.environ.get("CUTE_DSL_LOG_TO_CONSOLE") is None:
         # Need to set log to console, otherwise the log level won't take effect
         os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1"
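
The _set_envs_and_config edits above raise CUDA_DEVICE_MAX_CONNECTIONS to 8 and only set TRTLLM_ENABLE_PDL and the CUTE_DSL_* variables when the user has not already configured them. A small sketch of that "default unless overridden" pattern; the helper name is illustrative:

```python
# Sketch of the env-var handling above; set_default_env is an illustrative helper.
import os

def set_default_env(name: str, value: str) -> None:
    # Only set the variable when the user has not configured it already.
    if os.environ.get(name) is None:
        os.environ[name] = value

if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
    # A user-supplied "0" keeps PDL off; anything else is normalized to "1".
    os.environ["TRTLLM_ENABLE_PDL"] = "1"

set_default_env("CUTE_DSL_LOG_LEVEL", "30")      # warning level unless overridden
set_default_env("CUTE_DSL_LOG_TO_CONSOLE", "1")  # needed for the level to take effect
print({k: os.environ[k] for k in ("TRTLLM_ENABLE_PDL", "CUTE_DSL_LOG_LEVEL", "CUTE_DSL_LOG_TO_CONSOLE")})
```
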
@@ -709,7 +712,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.16.post3",
+            "0.3.16.post4",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
@@ -840,7 +843,7 @@
 
     if os.getenv("SGLANG_BLOCK_NONZERO_RANK_CHILDREN") == "0":
         # When using `Engine` as a Python API, we don't want to block here.
-        return None, None, None
+        return None, None, None, port_args
 
     launch_dummy_health_check_server(
         server_args.host, server_args.port, server_args.enable_metrics
@@ -851,7 +854,7 @@
         logger.error(
             f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
         )
-        return None, None, None
+        return None, None, None, port_args
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -897,4 +900,4 @@
 
     tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
 
-    return tokenizer_manager, template_manager, scheduler_info
+    return tokenizer_manager, template_manager, scheduler_info, port_args
sglang/srt/entrypoints/grpc_server.py CHANGED
@@ -999,7 +999,6 @@ def _wait_and_warmup_grpc(
     # Mark health service as SERVING after warmup completes
     if health_servicer:
         health_servicer.set_serving()
-        logger.info("Health service marked as SERVING")
 
     logger.info("The server is fired up and ready to roll!")