sglang 0.4.1.post7__py3-none-any.whl → 0.4.2.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/bench_offline_throughput.py +17 -11
  2. sglang/bench_one_batch.py +14 -6
  3. sglang/bench_serving.py +47 -44
  4. sglang/lang/chat_template.py +31 -0
  5. sglang/srt/configs/load_config.py +1 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
  7. sglang/srt/entrypoints/engine.py +5 -2
  8. sglang/srt/entrypoints/http_server.py +24 -0
  9. sglang/srt/function_call_parser.py +494 -0
  10. sglang/srt/layers/activation.py +5 -5
  11. sglang/srt/layers/attention/triton_ops/prefill_attention.py +6 -0
  12. sglang/srt/layers/attention/vision.py +243 -40
  13. sglang/srt/layers/dp_attention.py +3 -1
  14. sglang/srt/layers/layernorm.py +5 -5
  15. sglang/srt/layers/linear.py +24 -9
  16. sglang/srt/layers/logits_processor.py +1 -1
  17. sglang/srt/layers/moe/ep_moe/layer.py +20 -12
  18. sglang/srt/layers/moe/fused_moe_native.py +17 -3
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  20. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
  22. sglang/srt/layers/parameter.py +16 -7
  23. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  25. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  27. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  29. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  31. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/fp8.py +11 -1
  33. sglang/srt/layers/rotary_embedding.py +34 -13
  34. sglang/srt/layers/sampler.py +33 -10
  35. sglang/srt/layers/torchao_utils.py +12 -6
  36. sglang/srt/managers/detokenizer_manager.py +1 -0
  37. sglang/srt/managers/image_processor.py +77 -38
  38. sglang/srt/managers/io_struct.py +36 -5
  39. sglang/srt/managers/schedule_batch.py +31 -25
  40. sglang/srt/managers/scheduler.py +78 -38
  41. sglang/srt/managers/tokenizer_manager.py +4 -0
  42. sglang/srt/mem_cache/base_prefix_cache.py +4 -0
  43. sglang/srt/mem_cache/chunk_cache.py +3 -0
  44. sglang/srt/mem_cache/radix_cache.py +30 -1
  45. sglang/srt/model_executor/cuda_graph_runner.py +23 -25
  46. sglang/srt/model_executor/forward_batch_info.py +5 -7
  47. sglang/srt/model_executor/model_runner.py +7 -4
  48. sglang/srt/model_loader/loader.py +75 -0
  49. sglang/srt/model_loader/weight_utils.py +91 -5
  50. sglang/srt/models/commandr.py +14 -2
  51. sglang/srt/models/dbrx.py +9 -1
  52. sglang/srt/models/deepseek_v2.py +3 -3
  53. sglang/srt/models/gemma2.py +9 -1
  54. sglang/srt/models/grok.py +1 -0
  55. sglang/srt/models/minicpm3.py +3 -3
  56. sglang/srt/models/minicpmv.py +129 -76
  57. sglang/srt/models/mllama.py +16 -56
  58. sglang/srt/models/qwen2.py +4 -1
  59. sglang/srt/models/qwen2_vl.py +18 -8
  60. sglang/srt/models/torch_native_llama.py +17 -4
  61. sglang/srt/openai_api/adapter.py +139 -37
  62. sglang/srt/openai_api/protocol.py +5 -4
  63. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  64. sglang/srt/sampling/sampling_batch_info.py +4 -14
  65. sglang/srt/server.py +2 -2
  66. sglang/srt/server_args.py +26 -1
  67. sglang/srt/speculative/eagle_utils.py +37 -15
  68. sglang/srt/speculative/eagle_worker.py +11 -13
  69. sglang/srt/utils.py +62 -67
  70. sglang/test/test_programs.py +1 -0
  71. sglang/test/test_utils.py +81 -22
  72. sglang/utils.py +42 -0
  73. sglang/version.py +1 -1
  74. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/METADATA +8 -8
  75. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/RECORD +78 -67
  76. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/LICENSE +0 -0
  77. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/WHEEL +0 -0
  78. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/mllama.py CHANGED
@@ -17,6 +17,7 @@ from transformers.models.mllama.modeling_mllama import (
  import sglang.srt.distributed.parallel_state as ps
  from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.activation import get_act_fn
+ from sglang.srt.layers.attention.vision import VisionAttention
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
      ColumnParallelLinear,
@@ -145,61 +146,6 @@ class MllamaPrecomputedPositionEmbedding(nn.Module):
          return hidden_state


- class MllamaVisionSdpaAttention(nn.Module):
-     def __init__(self, config: config_mllama.MllamaVisionConfig):
-         super().__init__()
-
-         model_parallel_size = get_tensor_model_parallel_world_size()
-         self.embed_dim = config.hidden_size
-         self.num_heads = config.attention_heads
-         self.head_dim = config.hidden_size // config.attention_heads
-         self.num_local_heads = self.num_heads // model_parallel_size
-         self.q_size = self.num_local_heads * self.head_dim
-         self.kv_size = self.num_local_heads * self.head_dim
-
-         self.qkv_proj = QKVParallelLinear(
-             self.embed_dim,
-             self.head_dim,
-             self.num_heads,
-             bias=False,
-         )
-         self.o_proj = RowParallelLinear(
-             self.num_heads * self.head_dim,
-             self.embed_dim,
-             bias=False,
-             input_is_parallel=True,
-         )
-
-     def forward(
-         self,
-         hidden_state: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         qkv, _ = self.qkv_proj(hidden_state)
-         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-         q = q.view(
-             q.shape[0], q.shape[1], self.num_local_heads, self.head_dim
-         ).transpose(1, 2)
-         k = k.view(
-             k.shape[0], k.shape[1], self.num_local_heads, self.head_dim
-         ).transpose(1, 2)
-         v = v.view(
-             v.shape[0], v.shape[1], self.num_local_heads, self.head_dim
-         ).transpose(1, 2)
-
-         # TODO: remove padding in image encoder
-         attn_output = F.scaled_dot_product_attention(
-             q, k, v, attn_mask=attention_mask, dropout_p=0.0
-         )
-
-         attn_output = attn_output.transpose(1, 2).contiguous()
-         attn_output = attn_output.reshape(
-             attn_output.shape[0], attn_output.shape[1], -1
-         )
-         output, _ = self.o_proj(attn_output)
-         return output
-
-
  class MllamaVisionMLP(nn.Module):
      def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
          super().__init__()
@@ -237,7 +183,17 @@ class MllamaVisionEncoderLayer(nn.Module):
          self.is_gated = is_gated
          self.intermediate_size = config.intermediate_size

-         self.self_attn = MllamaVisionSdpaAttention(config)
+         self.self_attn = VisionAttention(
+             self.hidden_size,
+             self.num_attention_heads,
+             self.hidden_size,
+             use_qkv_parallel=True,
+             quant_config=None,
+             dropout=0.0,
+             use_context_forward=False,
+             use_full_precision_softmax=False,
+             flatten_batch=False,
+         )
          self.mlp = MllamaVisionMLP(config)

          self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
@@ -992,6 +948,10 @@ class MllamaForConditionalGeneration(nn.Module):
                      weight_loader(param, loaded_weight, shard_id)
                      break
              else:
+                 if "vision_model" in name:
+                     # adapt to VisionAttention
+                     name = name.replace("self_attn.o_proj", "self_attn.proj")
+
                  param = params_dict.pop(name)
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)
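
The weight-name remap in the last hunk is what keeps existing Llama-3.2-Vision checkpoints loadable after the module swap: checkpoints still store the vision tower's output projection under `self_attn.o_proj`, while the shared `VisionAttention` module registers it as `self_attn.proj` (per the replace call above). A minimal sketch of the same renaming pass, using hypothetical checkpoint keys rather than sglang's loader:

    import torch

    # Hypothetical checkpoint entries; the real loader streams these from safetensors shards.
    checkpoint = {
        "vision_model.layers.0.self_attn.o_proj.weight": torch.randn(8, 8),
        "language_model.layers.0.self_attn.o_proj.weight": torch.randn(8, 8),
    }

    def remap_vision_keys(weights):
        # Rename only vision-tower keys, mirroring the branch added above.
        for name, tensor in weights.items():
            if "vision_model" in name:
                name = name.replace("self_attn.o_proj", "self_attn.proj")
            yield name, tensor

    print([name for name, _ in remap_vision_keys(checkpoint)])
    # ['vision_model.layers.0.self_attn.proj.weight',
    #  'language_model.layers.0.self_attn.o_proj.weight']
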
sglang/srt/models/qwen2.py CHANGED
@@ -249,7 +249,10 @@ class Qwen2Model(nn.Module):
          self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

      def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-         return self.embed_tokens(input_ids)
+         if hasattr(self.config, "scale_emb"):
+             return self.embed_tokens(input_ids) * self.config.scale_emb
+         else:
+             return self.embed_tokens(input_ids)

      def forward(
          self,
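
`get_input_embeddings` now honors an optional `scale_emb` factor, so models that reuse `Qwen2Model` but scale their token embeddings (the MiniCPM-V integration elsewhere in this release is the likely consumer) produce the expected hidden-state magnitude; configs without the attribute keep the old path. A toy sketch of the branch, with a stand-in config object:

    import torch
    import torch.nn as nn

    class ToyConfig:
        scale_emb = 12.0  # hypothetical value; real models read this from config.json

    config = ToyConfig()
    embed_tokens = nn.Embedding(10, 4)
    input_ids = torch.tensor([1, 2, 3])

    # Mirrors Qwen2Model.get_input_embeddings after this change.
    if hasattr(config, "scale_emb"):
        embeds = embed_tokens(input_ids) * config.scale_emb
    else:
        embeds = embed_tokens(input_ids)
    print(embeds.shape)  # torch.Size([3, 4])
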
sglang/srt/models/qwen2_vl.py CHANGED
@@ -30,12 +30,10 @@ import numpy as np
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from einops import rearrange, repeat
+ from einops import rearrange
  from vllm.model_executor.layers.activation import QuickGELU

  from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
- from sglang.srt.distributed import parallel_state
- from sglang.srt.distributed import utils as dist_utils
  from sglang.srt.hf_transformers_utils import get_processor
  from sglang.srt.layers.attention.vision import VisionAttention
  from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -118,6 +116,7 @@ class Qwen2VisionBlock(nn.Module):
          mlp_ratio: float,
          act_layer: Type[nn.Module] = QuickGELU,
          norm_layer: Type[nn.Module] = None,
+         attn_implementation: Optional[str] = "sdpa",
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
@@ -126,12 +125,24 @@
          self.norm1 = norm_layer(dim)
          self.norm2 = norm_layer(dim)
          mlp_hidden_dim = int(dim * mlp_ratio)
+         if attn_implementation == "sdpa":
+             use_context_forward = False
+             use_full_precision_softmax = False
+         elif attn_implementation == "flash_attention_2":
+             use_full_precision_softmax = False
+             use_context_forward = True
+         elif attn_implementation == "eager":
+             use_full_precision_softmax = True
+             use_context_forward = False

          self.attn = VisionAttention(
              embed_dim=dim,
              num_heads=num_heads,
              projection_size=dim,
              use_qkv_parallel=False,
+             use_context_forward=use_context_forward,
+             use_full_precision_softmax=use_full_precision_softmax,
+             flatten_batch=True,
              quant_config=quant_config,
          )
          self.mlp = Qwen2VisionMLP(
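
The new `attn_implementation` argument is folded into two `VisionAttention` flags; judging by the names, `use_context_forward` selects the var-len context-attention path and `use_full_precision_softmax` forces an eager full-precision softmax, with batched SDPA as the default when both are off. A small helper with the same mapping (the explicit error for unknown names is an addition of this sketch; the code above handles only the three values shown):

    def vision_attn_flags(attn_implementation: str) -> tuple[bool, bool]:
        """Return (use_context_forward, use_full_precision_softmax)."""
        if attn_implementation == "sdpa":
            return False, False
        if attn_implementation == "flash_attention_2":
            return True, False
        if attn_implementation == "eager":
            return False, True
        raise ValueError(f"unknown attn_implementation: {attn_implementation!r}")

    print(vision_attn_flags("sdpa"))               # (False, False)
    print(vision_attn_flags("flash_attention_2"))  # (True, False)
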
@@ -286,7 +297,6 @@ class Qwen2VisionTransformer(nn.Module):
          norm_layer = partial(nn.LayerNorm, eps=norm_eps)
          head_dim = embed_dim // num_heads
          self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2)
-
          self.blocks = nn.ModuleList(
              [
                  Qwen2VisionBlock(
@@ -294,6 +304,7 @@
                      num_heads=num_heads,
                      mlp_ratio=mlp_ratio,
                      norm_layer=norm_layer,
+                     attn_implementation="sdpa",
                      quant_config=quant_config,
                  )
                  for _ in range(depth)
@@ -482,10 +493,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                  opensource models), the shape will be `(3, seq_len)`,
                  otherwise it will be `(seq_len,).
                  (Use input_metadata.mrope_positions to replace it)
-             pixel_values: Pixel values to be fed to a model.
-                 `None` if no images are passed.
-             image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
-                 `None` if no images are passed.
          """
          if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
              positions = forward_batch.mrope_positions
@@ -540,15 +547,18 @@
                  num_image_tokens = self.calculate_num_image_tokens(
                      image_grid_thws[idx]
                  )
+
                  left_idx = start_idx + (image_offset - prefix_len)
                  right_idx = (
                      start_idx + (image_offset - prefix_len) + num_image_tokens
                  )
+
                  inputs_embeds[left_idx:right_idx] = image_embeds[
                      image_embeds_offset : image_embeds_offset + num_image_tokens
                  ]
                  image_embeds_offset += num_image_tokens

+         input_ids = None
          hidden_states = self.model(
              input_ids=input_ids,
              positions=positions,
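
The last hunk mostly adds whitespace, but the `input_ids = None` line matters: once the precomputed image embeddings have been copied into `inputs_embeds` at the placeholder positions, the language model should consume those embeddings directly rather than re-embedding the token ids. A stripped-down sketch of the index arithmetic with toy tensors (the variable values are illustrative):

    import torch

    hidden = 8
    inputs_embeds = torch.zeros(10, hidden)   # embeddings for a 10-token prompt
    image_embeds = torch.ones(4, hidden)      # 4 vision tokens for one image

    start_idx, prefix_len = 0, 0              # no prefix-cache hit in this toy case
    image_offset, num_image_tokens = 3, 4     # placeholder tokens start at position 3
    image_embeds_offset = 0

    left_idx = start_idx + (image_offset - prefix_len)
    right_idx = left_idx + num_image_tokens
    inputs_embeds[left_idx:right_idx] = image_embeds[
        image_embeds_offset : image_embeds_offset + num_image_tokens
    ]
    image_embeds_offset += num_image_tokens

    print(inputs_embeds[:, 0])  # positions 3..6 now carry the vision embeddings
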
sglang/srt/models/torch_native_llama.py CHANGED
@@ -460,7 +460,12 @@ class TorchNativeLlamaForCausalLM(nn.Module):
          params_dict = dict(self.named_parameters())
          return len(params_dict)

-     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+     def load_weights_to_module(
+         self,
+         fqn: str,
+         weights: Iterable[Tuple[str, torch.Tensor]],
+     ):
+         """Load weights onto submodule pointed by path `fqn`."""
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              (".qkv_proj", ".q_proj", "q"),
@@ -469,7 +474,8 @@
              (".gate_up_proj", ".gate_proj", 0),
              (".gate_up_proj", ".up_proj", 1),
          ]
-         params_dict = dict(self.named_parameters())
+         module = self.get_submodule(fqn)
+         params_dict = dict(module.named_parameters(prefix=fqn, recurse=False))

          for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -486,7 +492,7 @@
                  continue
              name = name.replace(weight_name, param_name)
              # Skip loading extra bias for GPTQ models.
-             if name.endswith(".bias") and name not in params_dict:
+             if name.endswith(".bias") or name not in params_dict:
                  continue
              param = params_dict[name]
              weight_loader = param.weight_loader
@@ -494,12 +500,19 @@
                  break
              else:
                  # Skip loading extra bias for GPTQ models.
-                 if name.endswith(".bias") and name not in params_dict:
+                 if name.endswith(".bias") or name not in params_dict:
                      continue
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)

+     def load_weights(
+         self,
+         weights: Iterable[Tuple[str, torch.Tensor]],
+     ):
+         """Load weights onto the full model."""
+         self.load_weights_to_module("", weights)
+

  class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
      pass
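
`load_weights_to_module` narrows loading to a single submodule: `get_submodule(fqn)` resolves a dotted path, and `named_parameters(prefix=fqn, recurse=False)` yields only that module's own parameters under the same fully qualified names the checkpoint uses, so the existing name-matching loop keeps working, while the new `load_weights` simply delegates with an empty path. A toy illustration of those two torch.nn calls:

    import torch.nn as nn

    model = nn.Sequential()
    model.add_module("layers", nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)))

    fqn = "layers.1"
    module = model.get_submodule(fqn)
    params = dict(module.named_parameters(prefix=fqn, recurse=False))
    print(sorted(params))  # ['layers.1.bias', 'layers.1.weight']
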
sglang/srt/openai_api/adapter.py CHANGED
@@ -20,7 +20,7 @@ import os
  import time
  import uuid
  from http import HTTPStatus
- from typing import Dict, List
+ from typing import Dict, List, Optional

  from fastapi import HTTPException, Request, UploadFile
  from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -40,6 +40,7 @@ from sglang.srt.conversation import (
      generate_chat_conv,
      register_conv_template,
  )
+ from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
  from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
  from sglang.srt.openai_api.protocol import (
      BatchRequest,
@@ -71,7 +72,6 @@ from sglang.srt.openai_api.protocol import (
      TopLogprob,
      UsageInfo,
  )
- from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
  from sglang.utils import get_exception_traceback

  logger = logging.getLogger(__name__)
@@ -309,6 +309,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
              ret,
              to_file=True,
              cache_report=tokenizer_manager.server_args.enable_cache_report,
+             tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
          )
      else:
          responses = v1_generate_response(
@@ -877,9 +878,6 @@
      tools = None
      if request.tools and request.tool_choice != "none":
          request.skip_special_tokens = False
-         if request.stream:
-             logger.warning("Streaming is not supported with tools.")
-             request.stream = False
          if not isinstance(request.tool_choice, str):
              tools = [
                  item.function.model_dump()
@@ -908,12 +906,26 @@
              openai_compatible_messages = openai_compatible_messages[:-1]
          else:
              assistant_prefix = None
-         prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
-             openai_compatible_messages,
-             tokenize=True,
-             add_generation_prompt=True,
-             tools=tools,
-         )
+
+         try:
+             prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                 openai_compatible_messages,
+                 tokenize=True,
+                 add_generation_prompt=True,
+                 tools=tools,
+             )
+         except:
+             # This except branch will be triggered when the chosen model
+             # has a different tools input format that is not compatiable
+             # with openAI's apply_chat_template tool_call format, like Mistral.
+             tools = [t if "function" in t else {"function": t} for t in tools]
+             prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                 openai_compatible_messages,
+                 tokenize=True,
+                 add_generation_prompt=True,
+                 tools=tools,
+             )
+
          if assistant_prefix:
              prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
          stop = request.stop
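
The fallback re-invokes `apply_chat_template` after wrapping each tool schema as `{"function": {...}}`, because some chat templates (Mistral's, per the comment) expect the nested OpenAI-style shape rather than the bare function dict produced by `item.function.model_dump()`. The normalization on its own, with a sample schema:

    tools = [
        {
            "name": "get_weather",
            "description": "Look up the weather for a city.",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        }
    ]

    # Same normalization as the except-branch above; entries already nested are left alone.
    tools = [t if "function" in t else {"function": t} for t in tools]
    print(list(tools[0]))  # ['function']
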
@@ -1005,7 +1017,9 @@
      return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


- def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
+ def v1_chat_generate_response(
+     request, ret, to_file=False, cache_report=False, tool_call_parser=None
+ ):
      choices = []

      for idx, ret_item in enumerate(ret):
@@ -1066,12 +1080,13 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
              if finish_reason == "stop":
                  finish_reason = "tool_calls"
              try:
-                 text, call_info_list = parse_tool_response(text, tools)  # noqa
+                 parser = FunctionCallParser(tools, tool_call_parser)
+                 full_normal_text, call_info_list = parser.parse_non_stream(text)
                  tool_calls = [
                      ToolCall(
-                         id=str(call_info[0]),
+                         id=str(call_info.tool_index),
                          function=FunctionResponse(
-                             name=call_info[1], arguments=call_info[2]
+                             name=call_info.name, arguments=call_info.parameters
                          ),
                      )
                      for call_info in call_info_list
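
Non-streaming tool calls now go through the new `FunctionCallParser` (added in `sglang/srt/function_call_parser.py`) instead of `parse_tool_response`, and the parsed items are objects with `tool_index`, `name`, and `parameters` attributes rather than positional tuples. A small stand-in sketch of the response assembly (the dataclass below only imitates the attribute names used in the hunk; it is not sglang's class):

    from dataclasses import dataclass

    @dataclass
    class ToolCallItem:
        # Stand-in for the items FunctionCallParser.parse_non_stream returns.
        tool_index: int
        name: str
        parameters: str

    call_info_list = [ToolCallItem(0, "get_weather", '{"city": "Paris"}')]

    tool_calls = [
        {
            "id": str(call_info.tool_index),
            "function": {"name": call_info.name, "arguments": call_info.parameters},
        }
        for call_info in call_info_list
    ]
    print(tool_calls)
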
@@ -1172,6 +1187,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
      adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

      if adapted_request.stream:
+         parser_dict = {}

          async def generate_stream_resp():
              is_firsts = {}
@@ -1184,6 +1200,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                  adapted_request, raw_request
              ):
                  index = content.get("index", 0)
+                 text = content["text"]

                  is_first = is_firsts.get(index, True)
                  stream_buffer = stream_buffers.get(index, "")
@@ -1263,29 +1280,111 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):

                  text = content["text"]
                  delta = text[len(stream_buffer) :]
-                 stream_buffer = stream_buffer + delta
-                 choice_data = ChatCompletionResponseStreamChoice(
-                     index=index,
-                     delta=DeltaMessage(content=delta),
-                     finish_reason=(finish_reason["type"] if finish_reason else ""),
-                     matched_stop=(
-                         finish_reason["matched"]
-                         if finish_reason and "matched" in finish_reason
-                         else None
-                     ),
-                     logprobs=choice_logprobs,
-                 )
-                 chunk = ChatCompletionStreamResponse(
-                     id=content["meta_info"]["id"],
-                     choices=[choice_data],
-                     model=request.model,
-                 )
+                 new_stream_buffer = stream_buffer + delta

-                 is_firsts[index] = is_first
-                 stream_buffers[index] = stream_buffer
-                 n_prev_tokens[index] = n_prev_token
+                 if request.tool_choice != "none" and request.tools:
+                     if index not in parser_dict:
+                         parser_dict[index] = FunctionCallParser(
+                             tools=request.tools,
+                             tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
+                         )
+                     parser = parser_dict[index]
+
+                     # parse_increment => returns (normal_text, calls)
+                     normal_text, calls = parser.parse_stream_chunk(delta)
+
+                     # 1) if there's normal_text, output it as normal content
+                     if normal_text:
+                         choice_data = ChatCompletionResponseStreamChoice(
+                             index=index,
+                             delta=DeltaMessage(content=normal_text),
+                             finish_reason=(
+                                 finish_reason["type"] if finish_reason else ""
+                             ),
+                         )
+                         chunk = ChatCompletionStreamResponse(
+                             id=content["meta_info"]["id"],
+                             choices=[choice_data],
+                             model=request.model,
+                         )
+                         yield f"data: {chunk.model_dump_json()}\n\n"
+
+                     # 2) if we found calls, we output them as separate chunk(s)
+                     for call_item in calls:
+                         # transform call_item -> FunctionResponse + ToolCall
+
+                         if (
+                             content["meta_info"]["finish_reason"]
+                             and content["meta_info"]["finish_reason"]["type"]
+                             == "stop"
+                         ):
+                             latest_delta_len = 0
+                             if isinstance(call_item.parameters, str):
+                                 latest_delta_len = len(call_item.parameters)
+
+                             expected_call = json.dumps(
+                                 parser.multi_format_parser.detectors[0]
+                                 .prev_tool_call_arr[index]
+                                 .get("arguments", {}),
+                                 ensure_ascii=False,
+                             )
+                             actual_call = parser.multi_format_parser.detectors[
+                                 0
+                             ].streamed_args_for_tool[index]
+                             if latest_delta_len > 0:
+                                 actual_call = actual_call[:-latest_delta_len]
+                             remaining_call = expected_call.replace(
+                                 actual_call, "", 1
+                             )
+                             call_item.parameters = remaining_call
+
+                         tool_call = ToolCall(
+                             id=str(call_item.tool_index),
+                             function=FunctionResponse(
+                                 name=call_item.name,
+                                 arguments=call_item.parameters,
+                             ),
+                         )
+                         choice_data = ChatCompletionResponseStreamChoice(
+                             index=index,
+                             delta=DeltaMessage(
+                                 role="assistant", tool_calls=[tool_call]
+                             ),
+                             finish_reason="tool_call",
+                         )
+                         chunk = ChatCompletionStreamResponse(
+                             id=content["meta_info"]["id"],
+                             choices=[choice_data],
+                             model=request.model,
+                         )
+                         yield f"data: {chunk.model_dump_json()}\n\n"

-                 yield f"data: {chunk.model_dump_json()}\n\n"
+                     stream_buffers[index] = new_stream_buffer
+                     is_firsts[index] = is_first
+
+                 else:
+                     # No tool calls => just treat this as normal text
+                     choice_data = ChatCompletionResponseStreamChoice(
+                         index=index,
+                         delta=DeltaMessage(content=delta),
+                         finish_reason=(
+                             finish_reason["type"] if finish_reason else ""
+                         ),
+                         matched_stop=(
+                             finish_reason["matched"]
+                             if finish_reason and "matched" in finish_reason
+                             else None
+                         ),
+                         logprobs=choice_logprobs,
+                     )
+                     chunk = ChatCompletionStreamResponse(
+                         id=content["meta_info"]["id"],
+                         choices=[choice_data],
+                         model=request.model,
+                     )
+                     yield f"data: {chunk.model_dump_json()}\n\n"
+                     stream_buffers[index] = new_stream_buffer
+                     is_firsts[index] = is_first
              if request.stream_options and request.stream_options.include_usage:
                  total_prompt_tokens = sum(
                      tokens
@@ -1333,7 +1432,10 @@
          ret = [ret]

      response = v1_chat_generate_response(
-         request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
+         request,
+         ret,
+         cache_report=tokenizer_manager.server_args.enable_cache_report,
+         tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
      )

      return response
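
With the old `request.stream = False` downgrade removed and the streaming branch now emitting `tool_calls` deltas, an OpenAI-compatible client can stream tool calls from sglang directly. A hedged client-side sketch, assuming a server launched with a tool-call parser configured (the `tool_call_parser` server argument referenced above) and the `openai` Python package; the URL, port, and model name are placeholders:

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

    tools = [{
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the weather for a city.",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        },
    }]

    stream = client.chat.completions.create(
        model="default",  # placeholder model name
        messages=[{"role": "user", "content": "What is the weather in Paris?"}],
        tools=tools,
        stream=True,
    )

    name, arguments = None, []
    for chunk in stream:
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="")
        for call in delta.tool_calls or []:
            fn = call.function
            if fn and fn.name:
                name = fn.name
            if fn and fn.arguments:
                arguments.append(fn.arguments)
    print(name, "".join(arguments))
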
sglang/srt/openai_api/protocol.py CHANGED
@@ -262,7 +262,7 @@ class Function(BaseModel):
      """Function descriptions."""

      description: Optional[str] = Field(default=None, examples=[None])
-     name: str
+     name: Optional[str] = None
      parameters: Optional[object] = None


@@ -276,7 +276,7 @@ class Tool(BaseModel):
  class ToolChoiceFuncName(BaseModel):
      """The name of tool choice function."""

-     name: str
+     name: Optional[str] = None


  class ToolChoice(BaseModel):
@@ -329,8 +329,8 @@ class ChatCompletionRequest(BaseModel):
  class FunctionResponse(BaseModel):
      """Function response."""

-     name: str
-     arguments: str
+     name: Optional[str] = None
+     arguments: Optional[str] = None


  class ToolCall(BaseModel):
@@ -367,6 +367,7 @@ class ChatCompletionResponse(BaseModel):
  class DeltaMessage(BaseModel):
      role: Optional[str] = None
      content: Optional[str] = None
+     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


  class ChatCompletionResponseStreamChoice(BaseModel):
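
Relaxing `name`/`arguments` to `Optional` and adding `tool_calls` to `DeltaMessage` is what allows streaming to emit partial tool-call fragments: a later chunk may carry only another slice of the JSON arguments with no function name, which a required `name: str` would reject at validation time. A minimal pydantic sketch of the same pattern (the class names here are illustrative, not the ones above):

    from typing import List, Optional
    from pydantic import BaseModel, Field

    class FunctionFragment(BaseModel):
        name: Optional[str] = None
        arguments: Optional[str] = None

    class ToolCallFragment(BaseModel):
        id: Optional[str] = None
        function: FunctionFragment

    class Delta(BaseModel):
        role: Optional[str] = None
        content: Optional[str] = None
        tool_calls: Optional[List[ToolCallFragment]] = Field(default=None, examples=[None])

    # A follow-up chunk that only streams more argument text, with no name:
    chunk = Delta(tool_calls=[ToolCallFragment(function=FunctionFragment(arguments='"Paris"}'))])
    print(chunk.model_dump_json(exclude_none=True))
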
sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py CHANGED
@@ -3,11 +3,16 @@ from typing import List
  import torch

  from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
- from sglang.srt.utils import is_cuda_available
+ from sglang.srt.utils import get_compiler_backend

- is_cuda = is_cuda_available()
- if is_cuda:
-     from sgl_kernel import sampling_scaling_penalties
+
+ @torch.compile(dynamic=True, backend=get_compiler_backend())
+ def apply_scaling_penalties(logits, scaling_penalties):
+     logits[:] = torch.where(
+         logits > 0,
+         logits / scaling_penalties,
+         logits * scaling_penalties,
+     )


  class BatchedRepetitionPenalizer(_BatchedPenalizer):
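
Both call sites (here and in `SamplingBatchInfo` below) now share this one helper instead of the optional `sgl_kernel` kernel; `torch.compile` with the backend chosen by `get_compiler_backend()` is expected to fuse the element-wise update. The formula divides positive logits by the penalty and multiplies negative ones, so a penalized token becomes less likely regardless of the sign of its logit. A standalone numeric check of that formula (decorator omitted, since it only affects compilation):

    import torch

    def apply_scaling_penalties(logits, scaling_penalties):
        # Same in-place update as the compiled helper above.
        logits[:] = torch.where(
            logits > 0,
            logits / scaling_penalties,
            logits * scaling_penalties,
        )

    logits = torch.tensor([[2.0, -2.0, 0.5]])
    penalties = torch.tensor([[2.0, 2.0, 2.0]])
    apply_scaling_penalties(logits, penalties)
    print(logits)  # tensor([[ 1.0000, -4.0000,  0.2500]])
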
@@ -61,16 +66,8 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
          self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

      def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-         if is_cuda:
-             return sampling_scaling_penalties(
-                 logits, self.cumulated_repetition_penalties
-             )
-         else:
-             return torch.where(
-                 logits > 0,
-                 logits / self.cumulated_repetition_penalties,
-                 logits * self.cumulated_repetition_penalties,
-             )
+         apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
+         return logits

      def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
          self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -7,14 +7,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

  import torch

- from sglang.srt.utils import is_cuda_available
-
- is_cuda = is_cuda_available()
- if is_cuda:
-     from sgl_kernel import sampling_scaling_penalties
-
  import sglang.srt.sampling.penaltylib as penaltylib
  from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+ from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
+     apply_scaling_penalties,
+ )

  logger = logging.getLogger(__name__)

@@ -386,14 +383,7 @@ class SamplingBatchInfo:

          # repetition
          if self.scaling_penalties is not None:
-             if is_cuda:
-                 logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
-             else:
-                 logits[:] = torch.where(
-                     logits > 0,
-                     logits / self.scaling_penalties,
-                     logits * self.scaling_penalties,
-                 )
+             apply_scaling_penalties(logits, self.scaling_penalties)

          # Apply regex vocab_mask
          if self.vocab_mask is not None:
sglang/srt/server.py CHANGED
@@ -12,7 +12,7 @@
  # limitations under the License.
  # ==============================================================================

- # Some shortcuts for backward compatbility.
+ # Some shortcuts for backward compatibility.
  # They will be removed in new versions.
  from sglang.srt.entrypoints.engine import Engine
- from sglang.srt.entrypoints.http_server import launch_server
+ from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server