sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. sglang/bench_serving.py +23 -3
  2. sglang/srt/configs/deepseekvl2.py +10 -1
  3. sglang/srt/configs/model_config.py +5 -16
  4. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  5. sglang/srt/distributed/parallel_state.py +32 -5
  6. sglang/srt/entrypoints/http_server.py +7 -1
  7. sglang/srt/entrypoints/verl_engine.py +2 -0
  8. sglang/srt/function_call_parser.py +0 -1
  9. sglang/srt/layers/attention/flashattention_backend.py +218 -79
  10. sglang/srt/layers/dp_attention.py +12 -1
  11. sglang/srt/layers/moe/topk.py +30 -3
  12. sglang/srt/layers/quantization/__init__.py +134 -165
  13. sglang/srt/layers/quantization/awq.py +200 -0
  14. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  15. sglang/srt/layers/quantization/gptq.py +30 -40
  16. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  17. sglang/srt/layers/rotary_embedding.py +12 -0
  18. sglang/srt/lora/backend/base_backend.py +4 -4
  19. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  20. sglang/srt/lora/backend/triton_backend.py +5 -8
  21. sglang/srt/lora/layers.py +19 -33
  22. sglang/srt/lora/lora_manager.py +20 -7
  23. sglang/srt/lora/mem_pool.py +12 -6
  24. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  25. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  26. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  27. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  28. sglang/srt/lora/utils.py +6 -0
  29. sglang/srt/managers/io_struct.py +4 -2
  30. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  31. sglang/srt/managers/schedule_batch.py +1 -0
  32. sglang/srt/managers/scheduler.py +25 -19
  33. sglang/srt/managers/tokenizer_manager.py +0 -1
  34. sglang/srt/managers/tp_worker.py +3 -0
  35. sglang/srt/model_executor/cuda_graph_runner.py +9 -8
  36. sglang/srt/model_executor/model_runner.py +9 -6
  37. sglang/srt/model_loader/loader.py +11 -1
  38. sglang/srt/model_loader/weight_utils.py +6 -3
  39. sglang/srt/models/clip.py +563 -0
  40. sglang/srt/models/deepseek_janus_pro.py +2 -2
  41. sglang/srt/models/deepseek_v2.py +151 -26
  42. sglang/srt/models/gemma3_causal.py +12 -2
  43. sglang/srt/models/gemma3_mm.py +6 -0
  44. sglang/srt/openai_api/adapter.py +88 -87
  45. sglang/srt/openai_api/protocol.py +10 -5
  46. sglang/srt/patch_torch.py +71 -0
  47. sglang/srt/server_args.py +21 -11
  48. sglang/srt/speculative/eagle_worker.py +1 -1
  49. sglang/srt/utils.py +33 -0
  50. sglang/test/runners.py +27 -2
  51. sglang/test/test_utils.py +1 -1
  52. sglang/version.py +1 -1
  53. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +8 -4
  54. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +57 -53
  55. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +0 -0
  56. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/licenses/LICENSE +0 -0
  57. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -39,6 +39,8 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    tp_all_gather,
+    tp_reduce_scatter,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -71,7 +73,7 @@ from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda, is_cuda_available, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
 _is_cuda = is_cuda()
@@ -278,7 +280,11 @@ class DeepseekV2MoE(nn.Module):
         topk_weights = torch.empty(
             (0, self.top_k), dtype=torch.float32, device=hidden_states.device
         )
-        if forward_mode is not None and not forward_mode.is_idle():
+        if (
+            forward_mode is not None
+            and not forward_mode.is_idle()
+            and hidden_states.shape[0] > 0
+        ):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
             if self.n_shared_experts is not None:
@@ -649,6 +655,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
+        self.attention_backend = global_server_args_dict["attention_backend"]
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

     def no_absorb(self, forward_batch: ForwardBatch) -> bool:
@@ -661,6 +668,9 @@
                 and not forward_batch.forward_mode.is_draft_extend()
                 and sum(forward_batch.extend_prefix_lens_cpu) == 0
             )
+        elif self.attention_backend == "fa3":
+            # Flash Attention: Keep absorbing for all extend/decode
+            return False
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             return (
@@ -969,6 +979,14 @@ class DeepseekV2DecoderLayer(nn.Module):
         is_nextn: bool = False,
         prefix: str = "",
     ) -> None:
+
+        def is_sparse_layer(l: int):
+            return (
+                config.n_routed_experts is not None
+                and l >= config.first_k_dense_replace
+                and l % config.moe_layer_freq == 0
+            )
+
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
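
The `is_sparse_layer` helper introduced above selects the routed-expert (MoE) layers: any layer at or beyond `first_k_dense_replace` whose index is a multiple of `moe_layer_freq`. A minimal standalone sketch of that rule follows; the config numbers are made up for illustration, not taken from a shipped DeepSeek checkpoint.

# Illustrative only: config values below are invented for the example.
def is_sparse_layer(layer_id, n_routed_experts, first_k_dense_replace, moe_layer_freq):
    return (
        n_routed_experts is not None
        and layer_id >= first_k_dense_replace
        and layer_id % moe_layer_freq == 0
    )

sparse_layers = [
    l
    for l in range(8)
    if is_sparse_layer(l, n_routed_experts=64, first_k_dense_replace=1, moe_layer_freq=1)
]
print(sparse_layers)  # [1, 2, 3, 4, 5, 6, 7]: every layer after the first dense layer
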
@@ -977,6 +995,8 @@
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
         self.dp_size = get_attention_dp_size()
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()

         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
@@ -1019,16 +1039,13 @@
                 prefix=add_prefix("self_attn", prefix),
             )

-        if is_nextn or (
-            config.n_routed_experts is not None
-            and layer_id >= config.first_k_dense_replace
-            and layer_id % config.moe_layer_freq == 0
-        ):
+        if is_nextn or is_sparse_layer(layer_id):
             self.mlp = DeepseekV2MoE(
                 config=config,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
+            self.is_sparse = True
         else:
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
@@ -1037,6 +1054,14 @@
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
             )
+            self.is_sparse = False
+
+        self.input_is_scattered = (
+            is_sparse_layer(layer_id - 1)
+            and global_server_args_dict["enable_deepep_moe"]
+        )
+        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
+
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -1049,6 +1074,23 @@
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
+            return self.forward_deepep(
+                positions, hidden_states, forward_batch, residual
+            )
+        else:
+            return self.forward_normal(
+                positions, hidden_states, forward_batch, residual
+            )
+
+    def forward_normal(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+
         if hidden_states.shape[0] == 0:
             residual = hidden_states
         else:
@@ -1065,29 +1107,35 @@
             forward_batch=forward_batch,
         )

+        if self.attn_tp_size != 1 and self.input_is_scattered:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+            residual, local_residual = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                residual,
+            )
+            tp_all_gather(
+                list(residual.tensor_split(self.attn_tp_size)), local_residual
+            )
+
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
             if self.dp_size != 1:
-                if global_server_args_dict["enable_deepep_moe"] and isinstance(
-                    self.mlp, DeepseekV2MoE
-                ):
-                    if hidden_states.shape[0] != 0:
-                        hidden_states, residual = self.post_attention_layernorm(
-                            hidden_states, residual
-                        )
-                    hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
-                    return hidden_states, residual
-                else:
-                    if get_attention_tp_rank() == 0:
-                        hidden_states += residual
-                    hidden_states, local_hidden_states = (
-                        forward_batch.gathered_buffer,
-                        hidden_states,
-                    )
-                    dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-                    dp_scatter(residual, hidden_states, forward_batch)
-                    hidden_states = self.post_attention_layernorm(hidden_states)
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
+                dp_scatter(residual, hidden_states, forward_batch)
+                hidden_states = self.post_attention_layernorm(hidden_states)
             else:
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                 hidden_states, residual = self.post_attention_layernorm(
@@ -1101,6 +1149,7 @@
         # Fully Connected
         hidden_states = self.mlp(hidden_states)

+        # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter
         if self.dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
@@ -1113,6 +1162,82 @@

         return hidden_states, residual

+    def forward_deepep(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+
+        if hidden_states.shape[0] == 0:
+            residual = hidden_states
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        if self.attn_tp_size != 1 and self.input_is_scattered:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        if self.attn_tp_size != 1:
+            if self.input_is_scattered:
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                tp_reduce_scatter(hidden_states, tensor_list)
+                if hidden_states.shape[0] != 0:
+                    hidden_states, residual = self.post_attention_layernorm(
+                        hidden_states, residual
+                    )
+            else:
+                if self.attn_tp_rank == 0:
+                    hidden_states += residual
+                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
+                hidden_states = tensor_list[self.attn_tp_rank]
+                tp_reduce_scatter(hidden_states, tensor_list)
+                residual = hidden_states
+                if hidden_states.shape[0] != 0:
+                    hidden_states = self.post_attention_layernorm(hidden_states)
+        else:
+            if hidden_states.shape[0] != 0:
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual
+                )
+        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+
+        if self.is_last_layer and self.attn_tp_size != 1:
+            hidden_states, local_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            tp_all_gather(
+                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
+            )
+            residual, local_residual = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                residual,
+            )
+            tp_all_gather(
+                list(residual.tensor_split(self.attn_tp_size)), local_residual
+            )
+
+        return hidden_states, residual
+

 class DeepseekV2Model(nn.Module):

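The new `forward_deepep` path repeatedly calls `tensor_split(self.attn_tp_size)` on a gathered buffer and either all-gathers into those shards or keeps only the local rank's shard before `tp_reduce_scatter`. Below is a single-process sketch of that shard bookkeeping using plain `torch` and no collectives; `tp_size` and `tp_rank` are illustrative stand-ins for the attention tensor-parallel group, not sglang APIs.

import torch

# Hypothetical stand-ins for the attention tensor-parallel group.
tp_size, tp_rank = 4, 1

# "Gathered" buffer holding the tokens of every rank, plus its split view.
gathered = torch.arange(8 * 3, dtype=torch.float32).reshape(8, 3)
shards = list(gathered.tensor_split(tp_size))  # 4 views of 2 tokens each

# Scatter-style bookkeeping: each rank keeps only its own shard, which is
# what precedes tp_reduce_scatter in forward_deepep.
local = shards[tp_rank]
print(local.shape)  # torch.Size([2, 3])

# Gather-style bookkeeping: because tensor_split returns views, writing a
# local shard back through the list fills the shared buffer in place,
# mirroring (within one process) what an all-gather into the buffer does.
shards[tp_rank].copy_(local * 10)
print(gathered[2:4])  # rows 2-3 updated in place
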
sglang/srt/models/gemma3_causal.py
@@ -47,6 +47,12 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.utils import add_prefix, make_layers


+# Aligned with HF's implementation, using sliding window inclusive with the last token
+# SGLang assumes exclusive
+def get_attention_sliding_window_size(config):
+    return config.sliding_window - 1
+
+
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/gemma3.py
 def extract_layer_index(prefix: str) -> int:
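
Per the comment above, HF's Gemma 3 `sliding_window` counts the current token, while SGLang's backends expect the window to cover only the preceding tokens, hence the `- 1`. A small sketch of the off-by-one, with an illustrative window value:

# Illustrative only: window=4 under the HF (inclusive) convention equals
# window=3 under the exclusive convention SGLang's backends assume.
hf_sliding_window = 4
sglang_sliding_window = hf_sliding_window - 1  # what get_attention_sliding_window_size returns

def visible_positions(query_pos: int, window_exclusive: int) -> list[int]:
    # Causal positions the query may attend to when only `window_exclusive`
    # previous tokens stay visible.
    return [p for p in range(query_pos + 1) if query_pos - p <= window_exclusive]

print(visible_positions(10, sglang_sliding_window))  # [7, 8, 9, 10] -> 4 tokens, as HF intends
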
@@ -170,7 +176,7 @@ class Gemma3Attention(nn.Module):
             self.rope_scaling = {"rope_type": "default"}
             # FIXME(mick): idk why vllm does this
             # self.sliding_window = config.interleaved_sliding_window
-            self.sliding_window = config.sliding_window
+            self.sliding_window = get_attention_sliding_window_size(config)
         else:
             # Global attention. Use the values in config.json.
             self.rope_theta = config.rope_theta
@@ -184,6 +190,8 @@
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
             logit_cap=getattr(self.config, "attn_logit_softcapping", None),
+            # Module must also define `get_attention_sliding_window_size` to correctly initialize
+            # attention backend in `ForwardBatch`.
             sliding_window_size=self.sliding_window,
             prefix=add_prefix("attn", prefix),
         )
@@ -609,6 +617,9 @@ class Gemma3ForCausalLM(PreTrainedModel):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens

+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
+
     def dtype(self) -> torch.dtype:
         return next(self.parameters()).dtype

@@ -621,7 +632,6 @@ class Gemma3ForCausalLM(PreTrainedModel):
         input_embeds: torch.Tensor = None,
         **kwargs,
     ) -> LogitsProcessor:
-
         hidden_states = self.model(
             input_ids, positions, forward_batch, input_embeds, **kwargs
         )
sglang/srt/models/gemma3_mm.py
@@ -268,6 +268,12 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.language_model.get_input_embeddings()

+    def get_attention_sliding_window_size(self):
+        """
+        This value is used to initialize attention backends in `ForwardBatch`.
+        """
+        return self.language_model.get_attention_sliding_window_size()
+
     def get_image_feature(self, image_input: MultimodalInputs):
         """
         Projects the last hidden state from the vision model into language model space.
sglang/srt/openai_api/adapter.py
@@ -20,7 +20,7 @@ import os
 import time
 import uuid
 from http import HTTPStatus
-from typing import Any, Dict, List, Set
+from typing import Dict, List

 from fastapi import HTTPException, Request, UploadFile
 from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -645,7 +645,7 @@ def v1_generate_response(
                 "index": 0,
                 "text": text,
                 "logprobs": logprobs,
-                "finish_reason": (finish_reason["type"] if finish_reason else ""),
+                "finish_reason": finish_reason["type"] if finish_reason else None,
                 "matched_stop": (
                     finish_reason["matched"]
                     if finish_reason and "matched" in finish_reason
@@ -657,7 +657,7 @@
             index=idx,
             text=text,
             logprobs=logprobs,
-            finish_reason=(finish_reason["type"] if finish_reason else ""),
+            finish_reason=finish_reason["type"] if finish_reason else None,
             matched_stop=(
                 finish_reason["matched"]
                 if finish_reason and "matched" in finish_reason
@@ -805,7 +805,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                         index=index,
                         text=delta,
                         logprobs=logprobs,
-                        finish_reason=(finish_reason["type"] if finish_reason else ""),
+                        finish_reason=finish_reason["type"] if finish_reason else None,
                         matched_stop=(
                             finish_reason["matched"]
                             if finish_reason and "matched" in finish_reason
@@ -1119,7 +1119,9 @@ def v1_chat_generate_response(
         if logprobs:
             logprobs = to_openai_style_logprobs(
                 output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"],
-                output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"],
+                output_top_logprobs=ret_item["meta_info"].get(
+                    "output_top_logprobs", None
+                ),
             )
             token_logprobs = []
             for token_idx, (token, logprob) in enumerate(
@@ -1216,7 +1218,7 @@
                 "reasoning_content": reasoning_text if reasoning_text else None,
             },
             "logprobs": choice_logprobs.model_dump() if choice_logprobs else None,
-            "finish_reason": (finish_reason["type"] if finish_reason else ""),
+            "finish_reason": finish_reason["type"] if finish_reason else None,
             "matched_stop": (
                 finish_reason["matched"]
                 if finish_reason and "matched" in finish_reason
@@ -1233,7 +1235,7 @@
                 reasoning_content=reasoning_text if reasoning_text else None,
             ),
             logprobs=choice_logprobs,
-            finish_reason=(finish_reason["type"] if finish_reason else ""),
+            finish_reason=finish_reason["type"] if finish_reason else None,
             matched_stop=(
                 finish_reason["matched"]
                 if finish_reason and "matched" in finish_reason
@@ -1329,9 +1331,9 @@
                            output_token_logprobs=content["meta_info"][
                                "output_token_logprobs"
                            ][n_prev_token:],
-                            output_top_logprobs=content["meta_info"][
-                                "output_top_logprobs"
-                            ][n_prev_token:],
+                            output_top_logprobs=content["meta_info"].get(
+                                "output_top_logprobs", []
+                            )[n_prev_token:],
                        )

                        n_prev_token = len(
@@ -1377,23 +1379,11 @@
                    if is_first:
                        # First chunk with role
                        is_first = False
-                        if (
-                            tokenizer_manager.server_args.reasoning_parser
-                            and request.separate_reasoning
-                        ):
-                            delta = DeltaMessage(
-                                role="assistant", reasoning_content=None
-                            )
-                        else:
-                            delta = DeltaMessage(role="assistant", content=None)
+                        delta = DeltaMessage(role="assistant")
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=index,
                            delta=delta,
-                            finish_reason=(
-                                None
-                                if finish_reason_type and len(finish_reason_type) == 0
-                                else finish_reason_type
-                            ),
+                            finish_reason=finish_reason_type,
                            matched_stop=(
                                finish_reason["matched"]
                                if finish_reason and "matched" in finish_reason
@@ -1434,12 +1424,7 @@
                                        reasoning_text if reasoning_text else None
                                    )
                                ),
-                                finish_reason=(
-                                    None
-                                    if finish_reason_type
-                                    and len(finish_reason_type) == 0
-                                    else finish_reason_type
-                                ),
+                                finish_reason=finish_reason_type,
                            )
                            chunk = ChatCompletionStreamResponse(
                                id=content["meta_info"]["id"],
@@ -1471,12 +1456,7 @@
                                delta=DeltaMessage(
                                    content=normal_text if normal_text else None
                                ),
-                                finish_reason=(
-                                    None
-                                    if finish_reason_type
-                                    and len(finish_reason_type) == 0
-                                    else finish_reason_type
-                                ),
+                                finish_reason=finish_reason_type,
                            )
                            chunk = ChatCompletionStreamResponse(
                                id=content["meta_info"]["id"],
@@ -1490,11 +1470,7 @@
                            for call_item in calls:
                                # transform call_item -> FunctionResponse + ToolCall

-                                if (
-                                    content["meta_info"]["finish_reason"]
-                                    and content["meta_info"]["finish_reason"]["type"]
-                                    == "stop"
-                                ):
+                                if finish_reason_type == "stop":
                                    latest_delta_len = 0
                                    if isinstance(call_item.parameters, str):
                                        latest_delta_len = len(call_item.parameters)
@@ -1515,6 +1491,8 @@
                                    )
                                    call_item.parameters = remaining_call

+                                    finish_reason_type = "tool_calls"
+
                                tool_call = ToolCall(
                                    id=str(call_item.tool_index),
                                    function=FunctionResponse(
@@ -1524,10 +1502,13 @@
                                )
                                choice_data = ChatCompletionResponseStreamChoice(
                                    index=index,
-                                    delta=DeltaMessage(
-                                        role="assistant", tool_calls=[tool_call]
-                                    ),
-                                    finish_reason="tool_call",
+                                    delta=DeltaMessage(tool_calls=[tool_call]),
+                                    finish_reason=(
+                                        None
+                                        if request.stream_options
+                                        and request.stream_options.include_usage
+                                        else finish_reason_type
+                                    ),  # additional chunk will be return
                                )
                                chunk = ChatCompletionStreamResponse(
                                    id=content["meta_info"]["id"],
@@ -1542,30 +1523,44 @@

                        else:
                            # No tool calls => just treat this as normal text
-                            choice_data = ChatCompletionResponseStreamChoice(
-                                index=index,
-                                delta=DeltaMessage(content=delta if delta else None),
-                                finish_reason=(
-                                    None
-                                    if finish_reason_type and len(finish_reason_type) == 0
-                                    else finish_reason_type
-                                ),
-                                matched_stop=(
-                                    finish_reason["matched"]
-                                    if finish_reason and "matched" in finish_reason
-                                    else None
-                                ),
-                                logprobs=choice_logprobs,
-                            )
-                            chunk = ChatCompletionStreamResponse(
-                                id=content["meta_info"]["id"],
-                                created=created,
-                                choices=[choice_data],
-                                model=request.model,
-                            )
-                            yield f"data: {chunk.model_dump_json()}\n\n"
-                            stream_buffers[index] = new_stream_buffer
-                            is_firsts[index] = is_first
+                            if delta or not (
+                                request.stream_options
+                                and request.stream_options.include_usage
+                            ):
+                                choice_data = ChatCompletionResponseStreamChoice(
+                                    index=index,
+                                    delta=DeltaMessage(content=delta if delta else None),
+                                    finish_reason=(
+                                        None
+                                        if request.stream_options
+                                        and request.stream_options.include_usage
+                                        else finish_reason_type
+                                    ),
+                                    matched_stop=(
+                                        finish_reason["matched"]
+                                        if finish_reason and "matched" in finish_reason
+                                        else None
+                                    ),
+                                    logprobs=choice_logprobs,
+                                )
+                                chunk = ChatCompletionStreamResponse(
+                                    id=content["meta_info"]["id"],
+                                    created=created,
+                                    choices=[choice_data],
+                                    model=request.model,
+                                )
+                                yield f"data: {chunk.model_dump_json()}\n\n"
+                                stream_buffers[index] = new_stream_buffer
+                                is_firsts[index] = is_first
+                            if finish_reason_type == "stop" and request.tool_choice != "none":
+                                parser = FunctionCallParser(
+                                    tools=request.tools,
+                                    tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
+                                )
+                                if parser.has_tool_call(new_stream_buffer):
+                                    # if the stream ends with empty string after tool calls
+                                    finish_reason_type = "tool_calls"
+
                if request.stream_options and request.stream_options.include_usage:
                    total_prompt_tokens = sum(
                        tokens
@@ -1590,17 +1585,22 @@
                        prompt_tokens_details=prompt_tokens_details,
                    )

-                    final_usage_chunk = ChatCompletionStreamResponse(
-                        id=content["meta_info"]["id"],
-                        created=created,
-                        choices=[],
-                        model=request.model,
-                        usage=usage,
-                    )
-                    final_usage_data = final_usage_chunk.model_dump_json(
-                        exclude_none=True
-                    )
-                    yield f"data: {final_usage_data}\n\n"
+                else:
+                    usage = None
+                final_usage_chunk = ChatCompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    created=created,
+                    choices=[
+                        ChatCompletionResponseStreamChoice(
+                            index=index,
+                            delta=DeltaMessage(),
+                            finish_reason=finish_reason_type,
+                        )
+                    ],
+                    model=request.model,
+                    usage=usage,
+                )
+                yield f"data: {final_usage_chunk.model_dump_json()}\n\n"
            except ValueError as e:
                error = create_streaming_error_response(str(e))
                yield f"data: {error}\n\n"
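
With these adapter changes, a streaming client sees `finish_reason: null` on intermediate chunks, `"tool_calls"` when the turn ended in a tool call, and a final chunk that always carries the finish reason, with `usage` attached only when `stream_options.include_usage` is set. Below is a rough client-side sketch, assuming the usual `data:`/`[DONE]` SSE framing of OpenAI-compatible servers; the endpoint URL and model name are placeholders.

import json
import requests

# Placeholder endpoint and model name for a locally running OpenAI-compatible server.
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "hi"}],
        "stream": True,
        "stream_options": {"include_usage": True},
    },
    stream=True,
)

for line in resp.iter_lines():
    if not line or not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    chunk = json.loads(payload)
    for choice in chunk.get("choices", []):
        # Intermediate chunks carry finish_reason = null; the final choice chunk
        # reports e.g. "stop", "length", or "tool_calls".
        if choice.get("finish_reason"):
            print("finished:", choice["finish_reason"])
    if chunk.get("usage"):  # only present when include_usage was requested
        print("usage:", chunk["usage"])
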
@@ -1653,18 +1653,19 @@ def v1_embedding_request(all_requests, tokenizer_manager):
        elif isinstance(prompt, list) and isinstance(
            prompt[0], MultimodalEmbeddingInput
        ):
-            assert (
-                chat_template_name is not None
-            ), "chat_template_name is required for multimodal inputs"
            texts = []
            images = []
            for item in prompt:
-                texts.append(item.text if item.text is not None else None)
+                # TODO simply use padding for text, we should use a better way to handle this
+                texts.append(item.text if item.text is not None else "padding")
                images.append(item.image if item.image is not None else None)
-            convs = generate_embedding_convs(texts, images, chat_template_name)
            generate_prompts = []
-            for conv in convs:
-                generate_prompts.append(conv.get_prompt())
+            if chat_template_name is not None:
+                convs = generate_embedding_convs(texts, images, chat_template_name)
+                for conv in convs:
+                    generate_prompts.append(conv.get_prompt())
+            else:
+                generate_prompts = texts
            if len(generate_prompts) == 1:
                prompt_kwargs = {"text": generate_prompts[0], "image_data": images[0]}
            else: