sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -15,6 +15,7 @@ from sglang.api import (
     get_server_info,
     image,
     select,
+    separate_reasoning,
     set_default_backend,
     system,
     system_begin,
@@ -54,6 +55,7 @@ __all__ = [
    "get_server_info",
    "image",
    "select",
+   "separate_reasoning",
    "set_default_backend",
    "system",
    "system_begin",
sglang/api.py CHANGED
@@ -15,6 +15,7 @@ from sglang.lang.ir import (
     SglRoleBegin,
     SglRoleEnd,
     SglSelect,
+    SglSeparateReasoning,
     SglVideo,
 )
 
@@ -277,3 +278,9 @@ def assistant_begin():
 
 def assistant_end():
     return SglRoleEnd("assistant")
+
+
+def separate_reasoning(
+    expr: Optional[SglExpr] = None, model_type: Optional[str] = None
+):
+    return SglExprList([expr, SglSeparateReasoning(model_type, expr=expr)])
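The new `separate_reasoning` primitive wraps a generation expression and, once it finishes, splits the model's reasoning trace from the final answer, storing the trace under `<name>_reasoning_content` (see `SglSeparateReasoning` in `sglang/lang/ir.py` below). A minimal usage sketch, assuming the usual frontend program shape (the function and question are hypothetical; `model_type` must be one that `ReasoningParser` recognizes, e.g. `"deepseek-r1"`):

    import sglang as sgl

    @sgl.function
    def qa(s, question):
        s += sgl.user(question)
        s += sgl.assistant(
            sgl.separate_reasoning(sgl.gen("answer"), model_type="deepseek-r1")
        )

    state = qa.run(question="What is 1 + 1?")
    print(state["answer"])                    # final answer, reasoning stripped
    print(state["answer_reasoning_content"])  # the separated reasoning trace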
sglang/bench_one_batch.py CHANGED
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
     configure_logger,
     get_bool_env_var,
     kill_process_tree,
+    require_mlp_sync,
+    require_mlp_tp_gather,
     set_gpu_proc_affinity,
     suppress_other_loggers,
 )
@@ -243,7 +245,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
-    _maybe_prepare_dp_attn_batch(batch, model_runner)
+    _maybe_prepare_mlp_sync_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output, _ = model_runner.forward(forward_batch)
@@ -255,7 +257,7 @@ def decode(input_token_ids, batch, model_runner):
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
-    _maybe_prepare_dp_attn_batch(batch, model_runner)
+    _maybe_prepare_mlp_sync_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output, _ = model_runner.forward(forward_batch)
@@ -263,18 +265,18 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits
 
 
-def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
-    if model_runner.server_args.enable_dp_attention:
-        Scheduler.prepare_dp_attn_batch_raw(
+def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
+    if require_mlp_sync(model_runner.server_args):
+        Scheduler.prepare_mlp_sync_batch_raw(
            batch,
            dp_size=model_runner.server_args.dp_size,
            attn_tp_size=1,
-           moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
            tp_cpu_group=model_runner.tp_group.cpu_group,
            get_idle_batch=None,
            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
            spec_algorithm=SpeculativeAlgorithm.NONE,
            speculative_num_draft_tokens=None,
+           require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
        )
 
 
sglang/bench_serving.py CHANGED
@@ -399,7 +399,7 @@ async def async_request_sglang_generate(
                         # NOTE: Some completion API might have a last
                         # usage summary response without a token so we
                         # want to check a token was generated
-                        if data["text"]:
+                        if "text" in data and data["text"]:
                             timestamp = time.perf_counter()
                             generated_text = data["text"]
                             output_len = data["meta_info"]["completion_tokens"]
sglang/lang/interpreter.py CHANGED
@@ -26,6 +26,7 @@ from sglang.lang.ir import (
     SglRoleBegin,
     SglRoleEnd,
     SglSelect,
+    SglSeparateReasoning,
     SglVariable,
     SglVarScopeBegin,
     SglVarScopeEnd,
@@ -472,6 +473,8 @@ class StreamExecutor:
                 self._execute_concatenate_and_append_kv_cache(other)
             else:
                 self._execute_concatenate_and_append_text(other)
+        elif isinstance(other, SglSeparateReasoning):
+            self._execute_separate_reasoning(other)
         else:
             raise ValueError(f"Unknown type: {type(other)}")
 
@@ -724,8 +727,44 @@ class StreamExecutor:
         src_rids = [state.stream_executor.sid for state in expr.states]
         self.backend.concatenate_and_append(src_rids, self.sid)
 
+    def _execute_separate_reasoning(self, expr: SglSeparateReasoning):
+        if self.stream:
+            # separate reasoning for stream is not supported
+            return
+
+        if (
+            self.cur_role == "assistant"
+            and self.num_api_spec_tokens is not None
+            and self.backend.is_chat_model
+        ):
+            # Execute the stored lazy generation calls
+            self.backend.role_end_generate(self)
+
+        from sglang.srt.reasoning_parser import ReasoningParser
+
+        reasoning_parser = ReasoningParser(expr.model_type)
+        other = expr.expr
+        if not other:
+            return
+        elif isinstance(other, SglGen) or isinstance(other, SglSelect):
+            cur_text = self.get_var(other.name)
+            reasoning, normal_text = reasoning_parser.parse_non_stream(cur_text)
+            reasoning_name = expr.process_name_for_reasoning(other.name)
+            self.set_var(other.name, normal_text)
+            self.set_var(reasoning_name, reasoning)
+            # the variable is ready to be used
+            self.variable_event[reasoning_name].set()
+            self.text_ = self.text_[: self.cur_role_begin_pos] + normal_text
+        elif isinstance(other, SglExprList):
+            for x in other.expr_list:
+                self._execute_separate_reasoning(
+                    SglSeparateReasoning(expr.model_type, x)
+                )
+
     def _init_var_event(self, expr):
-        if isinstance(expr, (SglGen, SglSelect, SglVarScopeBegin)):
+        if isinstance(
+            expr, (SglGen, SglSelect, SglVarScopeBegin, SglSeparateReasoning)
+        ):
             self.variable_event[expr.name] = threading.Event()
             if self.stream:
                 self.stream_var_event[expr.name] = threading.Event()
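The actual split is delegated to `ReasoningParser.parse_non_stream`, whose return order `(reasoning, normal_text)` is visible in the hunk above. A small sketch of that call in isolation (the input string is invented, and the tag format assumes a DeepSeek-R1-style model that emits `<think>...</think>`):

    from sglang.srt.reasoning_parser import ReasoningParser

    parser = ReasoningParser(model_type="deepseek-r1")
    text = "<think>2 times 2 is 4.</think>The answer is 4."
    reasoning, normal_text = parser.parse_non_stream(text)
    # reasoning   -> "2 times 2 is 4."
    # normal_text -> "The answer is 4."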
sglang/lang/ir.py CHANGED
@@ -606,3 +606,30 @@ class SglCommitLazy(SglExpr):
 
     def __repr__(self):
         return "CommitLazy()"
+
+
+class SglSeparateReasoning(SglExpr):
+    def __init__(self, model_type: str, expr: SglExpr):
+        super().__init__()
+        self.model_type = model_type
+
+        self.expr = expr
+        self.name = None
+        self._process_expr(expr)
+
+    def process_name_for_reasoning(self, name):
+        if not name:
+            raise ValueError("name must be provided")
+        return f"{name}_reasoning_content"
+
+    def _process_expr(self, expr):
+        if isinstance(expr, SglGen):
+            self.name = self.process_name_for_reasoning(expr.name)
+        elif isinstance(expr, SglSelect):
+            self.name = self.process_name_for_reasoning(expr.name)
+        elif isinstance(expr, SglExprList):
+            for x in expr.expr_list:
+                self._process_expr(x)
+
+    def __repr__(self):
+        return f"SeparateReasoning(model_type={self.model_type}, name={self.name})"
sglang/math_utils.py ADDED
@@ -0,0 +1,8 @@
+# COPIED FROM DeepGEMM
+def align(x: int, y: int) -> int:
+    return ceil_div(x, y) * y
+
+
+# COPIED FROM DeepGEMM
+def ceil_div(x: int, y: int) -> int:
+    return (x + y - 1) // y
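Both helpers are plain integer arithmetic: `ceil_div(x, y)` computes the ceiling of `x / y` without going through floats, and `align(x, y)` rounds `x` up to the next multiple of `y`. For example:

    ceil_div(10, 4)  # 3, since (10 + 4 - 1) // 4 == 13 // 4 == 3
    align(10, 4)     # 12, the smallest multiple of 4 that is >= 10
    align(16, 4)     # 16, already aligned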
sglang/srt/_custom_ops.py CHANGED
@@ -4,7 +4,7 @@ from typing import List, Tuple
 
 import torch
 
-from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
+from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
 
 logger = logging.getLogger(__name__)
 use_vllm_custom_allreduce = get_bool_env_var(
@@ -25,7 +25,7 @@ if not is_hpu():
         logger.warning("Failed to import from custom_ar with %r", e)
 
 
-if not is_hip():
+if not is_hip() and not is_npu():
     if use_vllm_custom_allreduce:
         custom_op = torch.ops._C_custom_ar
     else:
sglang/srt/code_completion_parser.py CHANGED
@@ -15,12 +15,10 @@
 
 
 import dataclasses
-import json
 import logging
-import os
 from enum import auto
 
-from sglang.srt.openai_api.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import CompletionRequest
 
 logger = logging.getLogger(__name__)
 completion_template_name = None
@@ -57,46 +55,6 @@ class CompletionTemplate:
 completion_templates: dict[str, CompletionTemplate] = {}
 
 
-def load_completion_template_for_openai_api(completion_template_arg):
-    global completion_template_name
-
-    logger.info(
-        f"Use completion template for the OpenAI-compatible API server: {completion_template_arg}"
-    )
-
-    if not completion_template_exists(completion_template_arg):
-        if not os.path.exists(completion_template_arg):
-            raise RuntimeError(
-                f"Completion template {completion_template_arg} is not a built-in template name "
-                "or a valid completion template file path."
-            )
-
-        assert completion_template_arg.endswith(
-            ".json"
-        ), "unrecognized format of completion template file"
-        with open(completion_template_arg, "r") as filep:
-            template = json.load(filep)
-            try:
-                fim_position = FimPosition[template["fim_position"]]
-            except KeyError:
-                raise ValueError(
-                    f"Unknown fim position: {template['fim_position']}"
-                ) from None
-            register_completion_template(
-                CompletionTemplate(
-                    name=template["name"],
-                    fim_begin_token=template["fim_begin_token"],
-                    fim_middle_token=template["fim_middle_token"],
-                    fim_end_token=template["fim_end_token"],
-                    fim_position=fim_position,
-                ),
-                override=True,
-            )
-            completion_template_name = template["name"]
-    else:
-        completion_template_name = completion_template_arg
-
-
 def register_completion_template(template: CompletionTemplate, override: bool = False):
     """Register a new completion template."""
     if not override:
@@ -116,7 +74,7 @@ def is_completion_template_defined() -> bool:
     return completion_template_name is not None
 
 
-def generate_completion_prompt_from_request(request: ChatCompletionRequest) -> str:
+def generate_completion_prompt_from_request(request: CompletionRequest) -> str:
     global completion_template_name
     if request.suffix == "":
         return request.prompt
sglang/srt/configs/model_config.py CHANGED
@@ -550,6 +550,11 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
         or "Qwen2ForRewardModel" in model_architectures
         or "Qwen2ForSequenceClassification" in model_architectures
         or "CLIPModel" in model_architectures
+        or "BertModel" in model_architectures
+        or "Contriever" in model_architectures
+        or "BertForSequenceClassification" in model_architectures
+        or "XLMRobertaModel" in model_architectures
+        or "XLMRobertaForSequenceClassification" in model_architectures
     ):
         return False
     else:
@@ -578,6 +583,7 @@ multimodal_model_archs = [
    "KimiVLForConditionalGeneration",
    "InternVLChatModel",
    "Phi4MMForCausalLM",
+   "VILAForConditionalGeneration",
 ]
 
 
sglang/srt/constants.py ADDED
@@ -0,0 +1,3 @@
+# GPU Memory Types
+GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"
+GPU_MEMORY_TYPE_WEIGHTS = "weights"
sglang/srt/conversation.py CHANGED
@@ -11,7 +11,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Conversation chat templates."""
+"""Conversation chat templates.
+
+This module provides conversation template definitions, data structures, and utilities
+for managing chat templates across different model types in SGLang.
+
+Key components:
+- Conversation class: Defines the structure and behavior of chat templates
+- SeparatorStyle enum: Different conversation formatting styles
+- Template registry: Functions to register and retrieve templates by name or model path
+- Built-in templates: Pre-defined templates for popular models
+"""
 
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
@@ -20,7 +30,7 @@ import re
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
-from sglang.srt.openai_api.protocol import ChatCompletionRequest
+from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
 from sglang.srt.utils import read_system_prompt_from_file
 
 
@@ -618,7 +628,7 @@ def generate_chat_conv(
 
 
 # llama2 template
-# reference: https://huggingface.co/blog/codellama#conversational-instructions
+# reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 # reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212
 register_conv_template(
     Conversation(
@@ -983,3 +993,9 @@ def match_devstral(model_path: str):
 def match_phi_4_mm(model_path: str):
     if "phi-4-multimodal" in model_path.lower():
         return "phi-4-mm"
+
+
+@register_conv_template_matching_function
+def match_vila(model_path: str):
+    if re.search(r"vila", model_path, re.IGNORECASE):
+        return "chatml"
sglang/srt/custom_op.py CHANGED
@@ -1,9 +1,11 @@
 from torch import nn
 
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_is_cpu = is_cpu()
+_is_cpu_amx_available = cpu_has_amx_support()
 
 
 class CustomOp(nn.Module):
@@ -75,5 +77,7 @@ class CustomOp(nn.Module):
             return self.forward_cuda
         elif _is_hip:
             return self.forward_hip
+        elif _is_cpu and _is_cpu_amx_available:
+            return self.forward_cpu
         else:
             return self.forward_native
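For context, `CustomOp` resolves one of its `forward_*` variants once (via the branch chain above) and routes every call through the chosen method. A hypothetical subclass, just to illustrate the contract; the op itself is not from the package:

    import torch
    from sglang.srt.custom_op import CustomOp

    class DoubleOp(CustomOp):  # illustrative only
        def forward_native(self, x: torch.Tensor) -> torch.Tensor:
            # portable fallback, also used on CPU without AMX support
            return x + x

        def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
            return x + x  # a real op would dispatch to a fused kernel here

        def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
            # taken only when is_cpu() and cpu_has_amx_support() both hold
            return x + x

    op = DoubleOp()
    op(torch.ones(2))  # tensor([2., 2.]) regardless of the branch chosen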
sglang/srt/disaggregation/base/__init__.py CHANGED
@@ -1,4 +1,4 @@
-from .conn import (
+from sglang.srt.disaggregation.base.conn import (
     BaseKVBootstrapServer,
     BaseKVManager,
     BaseKVReceiver,
sglang/srt/disaggregation/base/conn.py CHANGED
@@ -1,23 +1,32 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
 import numpy.typing as npt
 
-from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
 
+if TYPE_CHECKING:
+    from sglang.srt.disaggregation.utils import DisaggregationMode
+
 
 class KVArgs:
     engine_rank: int
-    kv_data_ptrs: list[int]
-    kv_data_lens: list[int]
-    kv_item_lens: list[int]
-    aux_data_ptrs: list[int]
-    aux_data_lens: list[int]
-    aux_item_lens: list[int]
+    kv_data_ptrs: List[int]
+    kv_data_lens: List[int]
+    kv_item_lens: List[int]
+    aux_data_ptrs: List[int]
+    aux_data_lens: List[int]
+    aux_item_lens: List[int]
     ib_device: str
+    ib_traffic_class: str
     gpu_id: int
+    # for different tp
+    decode_tp_size: int
+    # for pp prefill
+    prefill_pp_size: int
 
 
 class KVPoll:
@@ -45,7 +54,12 @@ class BaseKVSender(ABC):
 
     @abstractmethod
     def __init__(
-        self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int
+        self,
+        mgr: BaseKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: int,
+        dest_tp_ranks: List[int],
+        pp_rank: int,
     ): ...
 
     @abstractmethod
@@ -56,7 +70,7 @@ class BaseKVSender(ABC):
         ...
 
     @abstractmethod
-    def send(self, kv_indices: npt.NDArray[np.int64]):
+    def send(self, kv_indices: npt.NDArray[np.int32]):
        """
        Send the kv cache at the given kv indices to the decoder server
        """
@@ -88,7 +102,7 @@ class BaseKVReceiver(ABC):
     ): ...
 
     @abstractmethod
-    def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
+    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
        """
        Notify the prefill server about the kv indices and aux index
        """
sglang/srt/disaggregation/common/__init__.py CHANGED
@@ -1 +1,5 @@
-from .conn import CommonKVBootstrapServer, CommonKVManager, CommonKVReceiver
+from sglang.srt.disaggregation.common.conn import (
+    CommonKVBootstrapServer,
+    CommonKVManager,
+    CommonKVReceiver,
+)
sglang/srt/disaggregation/common/utils.py ADDED
@@ -0,0 +1,42 @@
+import threading
+from collections import deque
+from typing import List, Tuple
+
+import numpy as np
+import numpy.typing as npt
+
+
+class FastQueue:
+    def __init__(self):
+        self._buf = deque()
+        self._cond = threading.Condition()
+
+    def put(self, item):
+        with self._cond:
+            self._buf.append(item)
+            # wake up one thread blocked in get()
+            self._cond.notify()
+
+    def get(self):
+        with self._cond:
+            # if the queue is empty, block until notified
+            while not self._buf:
+                self._cond.wait()
+            return self._buf.popleft()
+
+
+def group_concurrent_contiguous(
+    src_indices: npt.NDArray[np.int32], dst_indices: npt.NDArray[np.int32]
+) -> Tuple[List[npt.NDArray[np.int32]], List[npt.NDArray[np.int32]]]:
+    """Vectorised NumPy implementation."""
+    if src_indices.size == 0:
+        return [], []
+
+    brk = np.where((np.diff(src_indices) != 1) | (np.diff(dst_indices) != 1))[0] + 1
+    src_groups = np.split(src_indices, brk)
+    dst_groups = np.split(dst_indices, brk)
+
+    src_groups = [g.tolist() for g in src_groups]
+    dst_groups = [g.tolist() for g in dst_groups]
+
+    return src_groups, dst_groups
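`group_concurrent_contiguous` splits a pair of index arrays wherever either sequence stops being contiguous, so each returned group can be moved as a single contiguous block. A worked example (note that despite the `NDArray` return annotation, the groups come back as plain Python lists via `tolist()`):

    import numpy as np

    src = np.array([0, 1, 2, 5, 6], dtype=np.int32)
    dst = np.array([10, 11, 12, 20, 21], dtype=np.int32)
    src_groups, dst_groups = group_concurrent_contiguous(src, dst)
    # src_groups == [[0, 1, 2], [5, 6]]      (break where 2 -> 5 jumps)
    # dst_groups == [[10, 11, 12], [20, 21]]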