sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/check_env.py +3 -3
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/kimi_vl.py +38 -0
  5. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  6. sglang/srt/configs/model_config.py +15 -0
  7. sglang/srt/conversation.py +122 -1
  8. sglang/srt/entrypoints/engine.py +44 -22
  9. sglang/srt/function_call_parser.py +97 -0
  10. sglang/srt/hf_transformers_utils.py +2 -0
  11. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  12. sglang/srt/layers/attention/flashinfer_backend.py +107 -82
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
  14. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  15. sglang/srt/layers/dp_attention.py +5 -2
  16. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -6
  22. sglang/srt/layers/quantization/__init__.py +2 -2
  23. sglang/srt/layers/quantization/deep_gemm.py +1 -1
  24. sglang/srt/layers/utils.py +35 -0
  25. sglang/srt/lora/layers.py +35 -9
  26. sglang/srt/lora/lora_manager.py +84 -35
  27. sglang/srt/managers/data_parallel_controller.py +52 -34
  28. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  29. sglang/srt/managers/schedule_batch.py +25 -15
  30. sglang/srt/managers/scheduler.py +263 -59
  31. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  32. sglang/srt/managers/tp_worker.py +51 -16
  33. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  34. sglang/srt/mem_cache/memory_pool.py +70 -36
  35. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  36. sglang/srt/model_executor/forward_batch_info.py +31 -1
  37. sglang/srt/model_executor/model_runner.py +115 -57
  38. sglang/srt/models/deepseek_nextn.py +1 -257
  39. sglang/srt/models/deepseek_v2.py +78 -18
  40. sglang/srt/models/kimi_vl.py +308 -0
  41. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  42. sglang/srt/models/llama.py +92 -30
  43. sglang/srt/models/llama4.py +2 -1
  44. sglang/srt/models/llama_eagle.py +4 -1
  45. sglang/srt/models/llama_eagle3.py +4 -1
  46. sglang/srt/models/qwen2_moe.py +8 -3
  47. sglang/srt/models/qwen2_vl.py +0 -12
  48. sglang/srt/models/qwen3_moe.py +8 -3
  49. sglang/srt/openai_api/adapter.py +34 -22
  50. sglang/srt/openai_api/protocol.py +11 -1
  51. sglang/srt/server_args.py +67 -22
  52. sglang/srt/speculative/eagle_worker.py +3 -2
  53. sglang/srt/utils.py +88 -9
  54. sglang/test/runners.py +4 -0
  55. sglang/test/test_utils.py +29 -0
  56. sglang/version.py +1 -1
  57. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
  58. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +61 -51
  59. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
  60. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
  61. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama.py

@@ -17,13 +17,14 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
 from transformers import LlamaConfig
 
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
@@ -39,11 +40,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
@@ -275,21 +277,31 @@ class LlamaModel(nn.Module):
         self.config = config
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("embed_tokens", prefix),
-        )
-        self.layers = make_layers(
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: LlamaDecoderLayer(
-                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix="model.layers",
         )
 
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
         self.layers_to_capture = []
 
     def forward(
@@ -298,14 +310,23 @@
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
-        if input_embeds is None:
-            hidden_states = self.embed_tokens(input_ids)
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
         else:
-            hidden_states = input_embeds
-        residual = None
+            assert pp_proxy_tensors is not None
+            # FIXME(@ying): reduce the number of proxy tensors by not fusing layer norms
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+            deferred_norm = None
+
         aux_hidden_states = []
-        for i in range(len(self.layers)):
+        for i in range(self.start_layer, self.end_layer):
             if i in self.layers_to_capture:
                 aux_hidden_states.append(hidden_states + residual)
             layer = self.layers[i]
@@ -315,7 +336,16 @@
                 forward_batch,
                 residual,
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
 
         if len(aux_hidden_states) == 0:
             return hidden_states
@@ -376,6 +406,7 @@ class LlamaForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+        self.pp_group = get_pp_group()
         self.config = config
         self.quant_config = quant_config
         self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
@@ -419,23 +450,41 @@
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
         aux_hidden_states = None
         if self.capture_aux_hidden_states:
-            hidden_states, aux_hidden_states = self.model(
-                input_ids, positions, forward_batch, input_embeds
-            )
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
         else:
-            hidden_states = self.model(
-                input_ids, positions, forward_batch, input_embeds
-            )
+            return hidden_states
 
-        if not get_embedding:
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
-            )
-        else:
-            return self.pooler(hidden_states, forward_batch)
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
@@ -491,6 +540,16 @@
         params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -637,6 +696,9 @@
         self.model.load_kv_cache_scales(quantization_param_path)
 
     def set_eagle3_layers_to_capture(self):
+        if not self.pp_group.is_last_rank:
+            return
+
         self.capture_aux_hidden_states = True
         num_layers = self.config.num_hidden_layers
         self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
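
Note: taken together, the llama.py hunks above wire LlamaForCausalLM for pipeline parallelism: the first PP rank builds embed_tokens, the last rank keeps the final RMSNorm and the logits processor, missing submodules are replaced by PPMissingLayer, intermediate ranks exchange hidden_states/residual through PPProxyTensors, and load_weights skips weights whose layer id falls outside [start_layer, end_layer). The sketch below is a self-contained toy version of that pattern, not sglang code; ToyPPGroup and ToyStage are invented names and the residual stream is omitted.

# Illustrative sketch only: mimics the rank-conditional construction and the
# proxy-tensor hand-off shown in the diff, using plain torch modules.
import torch
from torch import nn


class ToyPPGroup:
    def __init__(self, rank: int, world_size: int):
        self.is_first_rank = rank == 0
        self.is_last_rank = rank == world_size - 1


class ToyStage(nn.Module):
    """One pipeline stage: embedding only on the first rank, norm only on the last."""

    def __init__(self, pp_group: ToyPPGroup, hidden: int = 16, vocab: int = 32):
        super().__init__()
        self.pp_group = pp_group
        self.embed = nn.Embedding(vocab, hidden) if pp_group.is_first_rank else None
        self.layer = nn.Linear(hidden, hidden)  # stands in for this rank's decoder layers
        self.norm = nn.LayerNorm(hidden) if pp_group.is_last_rank else None

    def forward(self, input_ids=None, proxy=None):
        if self.pp_group.is_first_rank:
            hidden_states = self.embed(input_ids)
        else:
            assert proxy is not None, "non-first ranks consume the previous rank's output"
            hidden_states = proxy["hidden_states"]
        hidden_states = self.layer(hidden_states)
        if not self.pp_group.is_last_rank:
            # Hand off to the next rank instead of producing logits.
            return {"hidden_states": hidden_states}
        return self.norm(hidden_states)


# Two "ranks" chained in-process to show the hand-off.
first = ToyStage(ToyPPGroup(rank=0, world_size=2))
last = ToyStage(ToyPPGroup(rank=1, world_size=2))
proxy = first(input_ids=torch.tensor([[1, 2, 3]]))
print(last(proxy=proxy).shape)  # torch.Size([1, 3, 16])
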
sglang/srt/models/llama4.py

@@ -46,7 +46,7 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
 
@@ -431,6 +431,7 @@ class Llama4Model(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
sglang/srt/models/llama_eagle.py

@@ -25,13 +25,14 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
+from sglang.srt.distributed import get_pp_group
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM
 
 
@@ -86,6 +87,7 @@ class LlamaModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
@@ -118,6 +120,7 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
         nn.Module.__init__(self)
         self.config = config
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()
         self.model = LlamaModel(
             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
         )
sglang/srt/models/llama_eagle3.py

@@ -25,6 +25,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
+from sglang.srt.distributed import get_pp_group
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
 
 
@@ -118,6 +119,7 @@ class LlamaModel(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
         if input_embeds is None:
             embeds = self.embed_tokens(input_ids)
@@ -155,6 +157,7 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM):
         nn.Module.__init__(self)
         self.config = config
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()
 
         if self.config.num_hidden_layers != 1:
             raise ValueError("EAGLE3 currently only supports 1 layer")
sglang/srt/models/qwen2_moe.py

@@ -36,6 +36,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -45,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
@@ -108,12 +110,13 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )
 
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        self.experts = MoEImpl(
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
@@ -427,7 +430,9 @@ class Qwen2MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
sglang/srt/models/qwen2_vl.py

@@ -442,18 +442,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         "up_proj": ("gate_up_proj", 1),
     }
 
-    def calculate_num_image_tokens(self, image_grid_thw: Tuple[int, int, int]):
-        processor = cached_get_processor(self.config._name_or_path)
-        grid_t, grid_h, grid_w = image_grid_thw
-        num_image_tokens = (
-            grid_t
-            * grid_h
-            * grid_w
-            // processor.image_processor.merge_size
-            // processor.image_processor.merge_size
-        )
-        return num_image_tokens
-
     def __init__(
         self,
         config: Qwen2VLConfig,
sglang/srt/models/qwen3_moe.py

@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -48,6 +49,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
@@ -73,12 +75,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_experts}."
             )
 
-        self.experts = FusedMoE(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        self.experts = MoEImpl(
             num_experts=config.num_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
@@ -356,7 +359,9 @@ class Qwen3MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
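
Note: the qwen2_moe.py and qwen3_moe.py hunks make the same change: the expert container is picked at construction time, EPMoE when the server was launched with expert parallelism (global_server_args_dict["enable_ep_moe"]) and FusedMoE otherwise, and load_weights asks the same class for its expert parameter mapping so checkpoint names resolve against the parameters that were actually registered. A minimal sketch of that switch, with stub classes standing in for the real implementations:

# Illustrative sketch only: EPMoEStub, FusedMoEStub, and server_args are stand-ins
# for sglang's EPMoE, FusedMoE, and global_server_args_dict.
from torch import nn


class FusedMoEStub(nn.Module):
    """Stand-in for the tensor-parallel fused-MoE implementation."""

    @classmethod
    def make_expert_params_mapping(cls, ckpt_gate_proj_name, ckpt_down_proj_name,
                                   ckpt_up_proj_name, num_experts):
        # Placeholder: the real method maps checkpoint tensor names to the
        # parameters this implementation registers.
        return [(ckpt_gate_proj_name, e) for e in range(num_experts)]


class EPMoEStub(FusedMoEStub):
    """Stand-in for the expert-parallel implementation (whole experts per rank)."""


server_args = {"enable_ep_moe": True}

# Construction-time switch, as in Qwen2MoeSparseMoeBlock / Qwen3MoeSparseMoeBlock:
MoEImpl = EPMoEStub if server_args["enable_ep_moe"] else FusedMoEStub
experts = MoEImpl()

# load_weights consults the same class, as in the load_weights hunks:
expert_params_mapping = MoEImpl.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=8,
)
print(type(experts).__name__, len(expert_params_mapping))  # EPMoEStub 8
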
sglang/srt/openai_api/adapter.py

@@ -36,6 +36,7 @@ from sglang.srt.conversation import (
     chat_template_exists,
     generate_chat_conv,
     generate_embedding_convs,
+    get_conv_template_by_model_path,
     register_conv_template,
 )
 from sglang.srt.function_call_parser import FunctionCallParser
@@ -163,10 +164,14 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
     else:
         chat_template_name = chat_template_arg
 
-# Check chat-template
-# TODO:
-# 1. Do not import any code from sglang.lang
-# 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.
+
+def guess_chat_template_name_from_model_path(model_path):
+    global chat_template_name
+    chat_template_name = get_conv_template_by_model_path(model_path)
+    if chat_template_name is not None:
+        logger.info(
+            f"Infer the chat template name from the model path and obtain the result: {chat_template_name}."
+        )
 
 
 async def v1_files_create(
@@ -894,6 +899,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     return response
 
 
+def _get_enable_thinking_from_request(request_obj):
+    """Extracts the 'enable_thinking' flag from request chat_template_kwargs.
+
+    Args:
+        request_obj: The request object (or an item from a list of requests).
+
+    Returns:
+        The boolean value of 'enable_thinking' if found and not True, otherwise True.
+    """
+    if (
+        hasattr(request_obj, "chat_template_kwargs")
+        and request_obj.chat_template_kwargs
+        and request_obj.chat_template_kwargs.get("enable_thinking") is not None
+    ):
+        return request_obj.chat_template_kwargs.get("enable_thinking")
+    return True
+
+
 
 def v1_chat_generate_request(
     all_requests: List[ChatCompletionRequest],
@@ -1099,7 +1122,7 @@ def v1_chat_generate_request(
 
         sampling_params = {
             "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens,
+            "max_new_tokens": request.max_tokens or request.max_completion_tokens,
             "min_new_tokens": request.min_tokens,
             "stop": stop,
             "stop_token_ids": request.stop_token_ids,
@@ -1258,31 +1281,16 @@ def v1_chat_generate_response(
         tool_calls = None
         text = ret_item["text"]
 
-        enable_thinking = True
         if isinstance(request, list):
             tool_choice = request[idx].tool_choice
             tools = request[idx].tools
             separate_reasoning = request[idx].separate_reasoning
-
-            if (
-                request[idx].chat_template_kwargs
-                and request[idx].chat_template_kwargs.get("enable_thinking") is not None
-            ):
-                enable_thinking = request[idx].chat_template_kwargs.get(
-                    "enable_thinking", True
-                )
+            enable_thinking = _get_enable_thinking_from_request(request[idx])
         else:
             tool_choice = request.tool_choice
             tools = request.tools
             separate_reasoning = request.separate_reasoning
-
-            if (
-                request.chat_template_kwargs
-                and request.chat_template_kwargs.get("enable_thinking") is not None
-            ):
-                enable_thinking = request.chat_template_kwargs.get(
-                    "enable_thinking", True
-                )
+            enable_thinking = _get_enable_thinking_from_request(request)
 
         reasoning_text = None
         if reasoning_parser and separate_reasoning and enable_thinking:
@@ -1521,9 +1529,12 @@
                 delta = text[len(stream_buffer) :]
                 new_stream_buffer = stream_buffer + delta
 
+                enable_thinking = _get_enable_thinking_from_request(request)
+
                 if (
                     tokenizer_manager.server_args.reasoning_parser
                     and request.separate_reasoning
+                    and enable_thinking
                 ):
                     if index not in reasoning_parser_dict:
                         reasoning_parser_dict[index] = ReasoningParser(
@@ -1613,6 +1624,7 @@
 
                         tool_call = ToolCall(
                             id=str(call_item.tool_index),
+                            index=call_item.tool_index,
                             function=FunctionResponse(
                                 name=call_item.name,
                                 arguments=call_item.parameters,
sglang/srt/openai_api/protocol.py

@@ -320,7 +320,16 @@ class ChatCompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: bool = False
     top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] = None
+    max_tokens: Optional[int] = Field(
+        default=None,
+        deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
+        description="The maximum number of tokens that can be generated in the chat completion. ",
+    )
+    max_completion_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of completion tokens for a chat completion request, "
+        "including visible output tokens and reasoning tokens. Input tokens are not included. ",
+    )
     n: int = 1
     presence_penalty: float = 0.0
     response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
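
Note: combined with the adapter change above ("max_new_tokens": request.max_tokens or request.max_completion_tokens), this adds OpenAI's max_completion_tokens field while keeping the deprecated max_tokens working; when both are set, max_tokens still wins, and because the fallback uses a Python "or", an explicit max_tokens=0 also falls through to max_completion_tokens. A minimal sketch of that precedence; _Req is a trimmed-down stand-in, not the full ChatCompletionRequest:

# Illustrative sketch only: a reduced request model showing the precedence the
# adapter applies when building sampling params.
from typing import Optional

from pydantic import BaseModel


class _Req(BaseModel):
    max_tokens: Optional[int] = None           # deprecated in the schema above
    max_completion_tokens: Optional[int] = None


def resolve_max_new_tokens(req: _Req) -> Optional[int]:
    # Same expression as in v1_chat_generate_request.
    return req.max_tokens or req.max_completion_tokens


print(resolve_max_new_tokens(_Req(max_completion_tokens=256)))                 # 256
print(resolve_max_new_tokens(_Req(max_tokens=64, max_completion_tokens=256)))  # 64
print(resolve_max_new_tokens(_Req()))                                          # None
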
@@ -380,6 +389,7 @@ class ToolCall(BaseModel):
     """Tool call response."""
 
     id: str
+    index: Optional[int] = None
     type: Literal["function"] = "function"
     function: FunctionResponse
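
Note: the index field added here, and attached to each streamed ToolCall in the adapter hunk above, is what OpenAI-compatible clients key on when reassembling tool-call arguments that arrive across multiple streaming chunks. A hypothetical client-side sketch of that accumulation (the chunk payloads below are made up):

# Illustrative sketch only: hypothetical streamed tool-call deltas in the
# OpenAI-compatible format; a client concatenates argument fragments by index.
chunks = [
    {"index": 0, "id": "0", "function": {"name": "get_weather", "arguments": '{"city": "'}},
    {"index": 0, "function": {"arguments": 'Paris"}'}},
    {"index": 1, "id": "1", "function": {"name": "get_time", "arguments": "{}"}},
]

calls = {}
for delta in chunks:
    call = calls.setdefault(delta["index"], {"name": None, "arguments": ""})
    fn = delta.get("function", {})
    if fn.get("name"):
        call["name"] = fn["name"]
    call["arguments"] += fn.get("arguments", "")

print(calls[0])  # {'name': 'get_weather', 'arguments': '{"city": "Paris"}'}
print(calls[1])  # {'name': 'get_time', 'arguments': '{}'}
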
395