sglang 0.3.0__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. sglang/bench_latency.py +17 -8
  2. sglang/bench_serving.py +33 -38
  3. sglang/global_config.py +5 -17
  4. sglang/lang/backend/runtime_endpoint.py +5 -2
  5. sglang/lang/interpreter.py +1 -4
  6. sglang/launch_server.py +3 -6
  7. sglang/launch_server_llavavid.py +7 -8
  8. sglang/srt/{model_config.py → configs/model_config.py} +5 -0
  9. sglang/srt/constrained/__init__.py +2 -0
  10. sglang/srt/constrained/fsm_cache.py +33 -38
  11. sglang/srt/constrained/jump_forward.py +0 -1
  12. sglang/srt/conversation.py +4 -1
  13. sglang/srt/hf_transformers_utils.py +1 -3
  14. sglang/srt/layers/activation.py +12 -0
  15. sglang/srt/layers/attention_backend.py +480 -0
  16. sglang/srt/layers/flashinfer_utils.py +235 -0
  17. sglang/srt/layers/fused_moe/layer.py +27 -7
  18. sglang/srt/layers/layernorm.py +12 -0
  19. sglang/srt/layers/logits_processor.py +64 -77
  20. sglang/srt/layers/radix_attention.py +11 -161
  21. sglang/srt/layers/sampler.py +38 -122
  22. sglang/srt/layers/torchao_utils.py +75 -0
  23. sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
  24. sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
  25. sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
  26. sglang/srt/lora/lora.py +403 -0
  27. sglang/srt/lora/lora_config.py +43 -0
  28. sglang/srt/lora/lora_manager.py +259 -0
  29. sglang/srt/managers/controller_multi.py +1 -5
  30. sglang/srt/managers/controller_single.py +0 -5
  31. sglang/srt/managers/io_struct.py +16 -1
  32. sglang/srt/managers/policy_scheduler.py +122 -5
  33. sglang/srt/managers/schedule_batch.py +105 -71
  34. sglang/srt/managers/tokenizer_manager.py +17 -8
  35. sglang/srt/managers/tp_worker.py +188 -121
  36. sglang/srt/model_executor/cuda_graph_runner.py +69 -133
  37. sglang/srt/model_executor/forward_batch_info.py +35 -312
  38. sglang/srt/model_executor/model_runner.py +123 -154
  39. sglang/srt/models/baichuan.py +416 -0
  40. sglang/srt/models/chatglm.py +1 -5
  41. sglang/srt/models/commandr.py +1 -5
  42. sglang/srt/models/dbrx.py +1 -5
  43. sglang/srt/models/deepseek.py +1 -5
  44. sglang/srt/models/deepseek_v2.py +7 -6
  45. sglang/srt/models/exaone.py +1 -5
  46. sglang/srt/models/gemma.py +1 -5
  47. sglang/srt/models/gemma2.py +1 -5
  48. sglang/srt/models/gpt_bigcode.py +1 -5
  49. sglang/srt/models/grok.py +1 -5
  50. sglang/srt/models/internlm2.py +1 -5
  51. sglang/srt/models/llama.py +51 -5
  52. sglang/srt/models/llama_classification.py +1 -20
  53. sglang/srt/models/llava.py +30 -5
  54. sglang/srt/models/llavavid.py +2 -2
  55. sglang/srt/models/minicpm.py +1 -5
  56. sglang/srt/models/minicpm3.py +669 -0
  57. sglang/srt/models/mixtral.py +6 -5
  58. sglang/srt/models/mixtral_quant.py +1 -5
  59. sglang/srt/models/olmoe.py +415 -0
  60. sglang/srt/models/qwen.py +1 -5
  61. sglang/srt/models/qwen2.py +1 -5
  62. sglang/srt/models/qwen2_moe.py +6 -5
  63. sglang/srt/models/stablelm.py +1 -5
  64. sglang/srt/models/xverse.py +375 -0
  65. sglang/srt/models/xverse_moe.py +445 -0
  66. sglang/srt/openai_api/adapter.py +65 -46
  67. sglang/srt/openai_api/protocol.py +11 -3
  68. sglang/srt/sampling/sampling_batch_info.py +46 -80
  69. sglang/srt/server.py +30 -15
  70. sglang/srt/server_args.py +163 -28
  71. sglang/srt/utils.py +19 -51
  72. sglang/test/few_shot_gsm8k.py +132 -0
  73. sglang/test/runners.py +114 -22
  74. sglang/test/test_programs.py +7 -5
  75. sglang/test/test_utils.py +85 -2
  76. sglang/utils.py +32 -37
  77. sglang/version.py +1 -1
  78. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/METADATA +30 -18
  79. sglang-0.3.1.post1.dist-info/RECORD +130 -0
  80. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/WHEEL +1 -1
  81. sglang-0.3.0.dist-info/RECORD +0 -118
  82. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/LICENSE +0 -0
  83. {sglang-0.3.0.dist-info → sglang-0.3.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/xverse_moe.py (new file)
@@ -0,0 +1,445 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ """Inference-only XVERSE MoE model."""
+ from typing import Any, Dict, Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig
+ from vllm.config import CacheConfig
+ from vllm.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_reduce,
+ )
+ from vllm.model_executor.layers.activation import SiluAndMul
+ from vllm.model_executor.layers.fused_moe import fused_moe
+ from vllm.model_executor.layers.layernorm import RMSNorm
+ from vllm.model_executor.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata
+
+
+ class XverseMLP(nn.Module):
+
+     def __init__(
+         self,
+         hidden_size: int,
+         intermediate_size: int,
+         hidden_act: str,
+         quant_config: Optional[QuantizationConfig] = None,
+         reduce_results: bool = True,
+     ) -> None:
+         super().__init__()
+         self.gate_up_proj = MergedColumnParallelLinear(
+             hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+         )
+         self.down_proj = RowParallelLinear(
+             intermediate_size,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+             reduce_results=reduce_results,
+         )
+         if hidden_act != "silu":
+             raise ValueError(
+                 f"Unsupported activation: {hidden_act}. "
+                 "Only silu is supported for now."
+             )
+         self.act_fn = SiluAndMul()
+
+     def forward(self, x):
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class XverseMoE(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.config = config
+         self.rank = get_tensor_model_parallel_rank()
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.n_routed_experts = config.num_experts
+         self.top_k = config.moe_top_k
+         if self.tp_size > self.n_routed_experts:
+             raise ValueError(
+                 f"Tensor parallel size {self.tp_size} is greater than "
+                 f"the number of experts {self.n_routed_experts}."
+             )
+
+         self.experts = nn.ModuleList(
+             [
+                 XverseMLP(
+                     hidden_size=config.hidden_size,
+                     intermediate_size=config.intermediate_size,
+                     hidden_act=config.hidden_act,
+                     quant_config=quant_config,
+                     reduce_results=False,
+                 )
+                 for _ in range(self.n_routed_experts)
+             ]
+         )
+         self.pack_params()
+
+         self.router = ReplicatedLinear(
+             config.hidden_size, self.n_routed_experts, bias=False, quant_config=None
+         )
+
+         if config.num_shared_experts is not None:
+             intermediate_size = config.intermediate_size * config.num_shared_experts
+             self.shared_experts = XverseMLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+                 reduce_results=False,
+             )
+
+     def pack_params(self):
+         w1 = []
+         w2 = []
+         for expert in self.experts:
+             w1.append(expert.gate_up_proj.weight)
+             w2.append(expert.down_proj.weight)
+         self.w1 = torch._utils._flatten_dense_tensors(w1)
+         w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
+         for data, param in zip(w1s, w1):
+             param.data = data
+         self.w1 = self.w1.view(len(w1), *w1s[0].shape)
+
+         self.w2 = torch._utils._flatten_dense_tensors(w2)
+         w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
+         for data, param in zip(w2s, w2):
+             param.data = data
+
+         self.w2 = self.w2.view(len(w2), *w2s[0].shape)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         num_tokens, hidden_dim = hidden_states.shape
+         hidden_states = hidden_states.view(-1, hidden_dim)
+         if self.config.num_shared_experts is not None:
+             shared_output = self.shared_experts(hidden_states)
+         # router_logits: (num_tokens, n_experts)
+         router_logits, _ = self.router(hidden_states)
+         final_hidden_states = fused_moe(
+             hidden_states,
+             self.w1,
+             self.w2,
+             router_logits,
+             self.top_k,
+             renormalize=getattr(self.config, "norm_topk_prob", False),
+             inplace=True,
+         )
+
+         if self.config.num_shared_experts is not None:
+             final_hidden_states = final_hidden_states + shared_output
+         final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+         return final_hidden_states.view(num_tokens, hidden_dim)
+
+
+ class XverseAttention(nn.Module):
+
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         layer_id: int = 0,
+         rope_theta: float = 10000,
+         rope_scaling: Optional[Dict[str, Any]] = None,
+         max_position_embeddings: int = 8192,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         self.head_dim = hidden_size // self.total_num_heads
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+         self.max_position_embeddings = max_position_embeddings
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+         )
+
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+         )
+
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position_embeddings,
+             base=rope_theta,
+             rope_scaling=rope_scaling,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class XverseDecoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         rope_theta = getattr(config, "rope_theta", 10000)
+         rope_scaling = getattr(config, "rope_scaling", None)
+         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+         num_key_value_heads = getattr(
+             config, "num_key_value_heads", config.num_attention_heads
+         )
+         self.self_attn = XverseAttention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             num_kv_heads=num_key_value_heads,
+             layer_id=layer_id,
+             rope_theta=rope_theta,
+             rope_scaling=rope_scaling,
+             max_position_embeddings=max_position_embeddings,
+             cache_config=cache_config,
+             quant_config=quant_config,
+         )
+         if config.num_experts is not None:
+             self.mlp = XverseMoE(config=config, quant_config=quant_config)
+         else:
+             self.mlp = XverseMLP(
+                 hidden_size=config.hidden_size,
+                 intermediate_size=config.intermediate_size,
+                 hidden_act=config.hidden_act,
+                 quant_config=quant_config,
+             )
+         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attention_layernorm = RMSNorm(
+             config.hidden_size, eps=config.rms_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+         residual: Optional[torch.Tensor],
+     ) -> torch.Tensor:
+         # Self Attention
+         if residual is None:
+             residual = hidden_states
+             hidden_states = self.input_layernorm(hidden_states)
+         else:
+             hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+
+         # Fully Connected
+         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+         hidden_states = self.mlp(hidden_states)
+         return hidden_states, residual
+
+
+ class XverseModel(nn.Module):
+
+     fall_back_to_pt_during_load = False
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.layers = nn.ModuleList(
+             [
+                 XverseDecoderLayer(
+                     config, layer_id, cache_config, quant_config=quant_config
+                 )
+                 for layer_id in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         hidden_states = self.embed_tokens(input_ids)
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions, hidden_states, input_metadata, residual
+             )
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+
+
+ class XverseMoeForCausalLM(nn.Module):
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         cache_config: Optional[CacheConfig] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = XverseModel(config, cache_config, quant_config)
+         self.lm_head = ParallelLMHead(
+             config.vocab_size, config.hidden_size, quant_config=quant_config
+         )
+         self.logits_processor = LogitsProcessor(config)
+
+         self.param_dict = dict(self.named_parameters())
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         params_dict = self.param_dict
+
+         for name, loaded_weight in weights:
+             if "rotary_emb.inv_freq" in name:
+                 continue
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 # Skip experts that are not assigned to this worker.
+                 if (
+                     "mlp.experts." in name or "mlp.shared_experts." in name
+                 ) and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 # Skip experts that are not assigned to this worker.
+                 if (
+                     "mlp.experts." in name or "mlp.shared_experts." in name
+                 ) and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+
+ EntryClass = XverseMoeForCausalLM
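
The hunk above is the new sglang/srt/models/xverse_moe.py; the trailing EntryClass = XverseMoeForCausalLM assignment is how sglang's model loader discovers the architecture. As a rough illustration of how such a checkpoint would be served through the frontend API, a minimal sketch (the HF repo id below is an assumption for illustration, not taken from this diff):

    import sglang as sgl


    @sgl.function
    def smoke_test(s):
        s += sgl.user("Say hello in one short sentence.")
        s += sgl.assistant(sgl.gen("answer", max_tokens=32))


    # Hypothetical repo id; any checkpoint whose architecture maps to
    # XverseMoeForCausalLM is picked up through EntryClass at load time.
    runtime = sgl.Runtime(model_path="xverse/XVERSE-MoE-A4.2B-Chat")
    sgl.set_default_backend(runtime)

    state = smoke_test.run()
    print(state["answer"])
    runtime.shutdown()
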
sglang/srt/openai_api/adapter.py
@@ -22,12 +22,19 @@ import os
  import time
  import uuid
  from http import HTTPStatus
- from typing import Dict, List, Optional
+ from typing import Dict, List

  from fastapi import HTTPException, Request, UploadFile
  from fastapi.responses import JSONResponse, StreamingResponse
  from pydantic import ValidationError

+ try:
+     from outlines.fsm.json_schema import convert_json_schema_to_str
+ except ImportError:
+     # Before outlines 0.0.47, convert_json_schema_to_str is under
+     # outlines.integrations.utils
+     from outlines.integrations.utils import convert_json_schema_to_str
+
  from sglang.srt.conversation import (
      Conversation,
      SeparatorStyle,
@@ -88,19 +95,6 @@ file_id_storage: Dict[str, str] = {}
  storage_dir = None


- def format_finish_reason(finish_reason) -> Optional[str]:
-     if finish_reason.startswith("None"):
-         return None
-     elif finish_reason.startswith("FINISH_MATCHED"):
-         return "stop"
-     elif finish_reason.startswith("FINISH_LENGTH"):
-         return "length"
-     elif finish_reason.startswith("FINISH_ABORT"):
-         return "abort"
-     else:
-         return "unknown"
-
-
  def create_error_response(
      message: str,
      err_type: str = "BadRequestError",
@@ -478,7 +472,7 @@ def v1_generate_request(
      first_prompt_type = type(all_requests[0].prompt)
      for request in all_requests:
          assert (
-             type(request.prompt) == first_prompt_type
+             type(request.prompt) is first_prompt_type
          ), "All prompts must be of the same type in file input settings"
          if len(all_requests) > 1 and request.n > 1:
              raise ValueError(
@@ -611,8 +605,10 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                  "index": 0,
                  "text": text,
                  "logprobs": logprobs,
-                 "finish_reason": format_finish_reason(
-                     ret_item["meta_info"]["finish_reason"]
+                 "finish_reason": (
+                     ret_item["meta_info"]["finish_reason"]["type"]
+                     if ret_item["meta_info"]["finish_reason"]
+                     else ""
                  ),
              }
          else:
@@ -620,8 +616,10 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
                  index=idx,
                  text=text,
                  logprobs=logprobs,
-                 finish_reason=format_finish_reason(
-                     ret_item["meta_info"]["finish_reason"]
+                 finish_reason=(
+                     ret_item["meta_info"]["finish_reason"]["type"]
+                     if ret_item["meta_info"]["finish_reason"]
+                     else ""
                  ),
              )

@@ -755,8 +753,10 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                      index=index,
                      text=delta,
                      logprobs=logprobs,
-                     finish_reason=format_finish_reason(
-                         content["meta_info"]["finish_reason"]
+                     finish_reason=(
+                         content["meta_info"]["finish_reason"]["type"]
+                         if content["meta_info"]["finish_reason"]
+                         else ""
                      ),
                  )
                  chunk = CompletionStreamResponse(
@@ -832,6 +832,7 @@ def v1_chat_generate_request(
      return_logprobs = []
      logprob_start_lens = []
      top_logprobs_nums = []
+     modalities_list = []

      # NOTE: with openai API, the prompt's logprobs are always not computed

@@ -864,10 +865,12 @@ def v1_chat_generate_request(
                  )
                  stop = request.stop
                  image_data = None
+                 modalities = []
              else:
                  conv = generate_chat_conv(request, chat_template_name)
                  prompt = conv.get_prompt()
                  image_data = conv.image_data
+                 modalities = conv.modalities
                  stop = conv.stop_str or []
                  if request.stop:
                      if isinstance(request.stop, str):
@@ -880,27 +883,33 @@ def v1_chat_generate_request(
              prompt_ids = request.messages
              stop = request.stop
              image_data = None
+             modalities = []
          input_ids.append(prompt_ids)
          return_logprobs.append(request.logprobs)
          logprob_start_lens.append(-1)
-         top_logprobs_nums.append(request.top_logprobs)
-         sampling_params_list.append(
-             {
-                 "temperature": request.temperature,
-                 "max_new_tokens": request.max_tokens,
-                 "min_new_tokens": request.min_tokens,
-                 "stop": stop,
-                 "stop_token_ids": request.stop_token_ids,
-                 "top_p": request.top_p,
-                 "presence_penalty": request.presence_penalty,
-                 "frequency_penalty": request.frequency_penalty,
-                 "repetition_penalty": request.repetition_penalty,
-                 "regex": request.regex,
-                 "json_schema": request.json_schema,
-                 "n": request.n,
-             }
-         )
+         top_logprobs_nums.append(request.top_logprobs or 0)
+
+         sampling_params = {
+             "temperature": request.temperature,
+             "max_new_tokens": request.max_tokens,
+             "min_new_tokens": request.min_tokens,
+             "stop": stop,
+             "stop_token_ids": request.stop_token_ids,
+             "top_p": request.top_p,
+             "presence_penalty": request.presence_penalty,
+             "frequency_penalty": request.frequency_penalty,
+             "repetition_penalty": request.repetition_penalty,
+             "regex": request.regex,
+             "n": request.n,
+         }
+         if request.response_format and request.response_format.type == "json_schema":
+             sampling_params["json_schema"] = convert_json_schema_to_str(
+                 request.response_format.json_schema.schema_
+             )
+         sampling_params_list.append(sampling_params)
+
          image_data_list.append(image_data)
+         modalities_list.extend(modalities)
      if len(all_requests) == 1:
          input_ids = input_ids[0]
          if isinstance(input_ids, str):
@@ -912,6 +921,7 @@ def v1_chat_generate_request(
          return_logprobs = return_logprobs[0]
          logprob_start_lens = logprob_start_lens[0]
          top_logprobs_nums = top_logprobs_nums[0]
+         modalities_list = modalities_list[:1]
      else:
          if isinstance(input_ids[0], str):
              prompt_kwargs = {"text": input_ids}
@@ -928,6 +938,7 @@ def v1_chat_generate_request(
          stream=all_requests[0].stream,
          return_text_in_logprobs=True,
          rid=request_ids,
+         modalities=modalities_list,
      )
      if len(all_requests) == 1:
          return adapted_request, all_requests[0]
@@ -981,8 +992,10 @@ def v1_chat_generate_response(request, ret, to_file=False):
                  "index": 0,
                  "message": {"role": "assistant", "content": ret_item["text"]},
                  "logprobs": choice_logprobs,
-                 "finish_reason": format_finish_reason(
-                     ret_item["meta_info"]["finish_reason"]
+                 "finish_reason": (
+                     ret_item["meta_info"]["finish_reason"]["type"]
+                     if ret_item["meta_info"]["finish_reason"]
+                     else ""
                  ),
              }
          else:
@@ -990,8 +1003,10 @@ def v1_chat_generate_response(request, ret, to_file=False):
                  index=idx,
                  message=ChatMessage(role="assistant", content=ret_item["text"]),
                  logprobs=choice_logprobs,
-                 finish_reason=format_finish_reason(
-                     ret_item["meta_info"]["finish_reason"]
+                 finish_reason=(
+                     ret_item["meta_info"]["finish_reason"]["type"]
+                     if ret_item["meta_info"]["finish_reason"]
+                     else ""
                  ),
              )

@@ -1116,8 +1131,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                      choice_data = ChatCompletionResponseStreamChoice(
                          index=index,
                          delta=DeltaMessage(role="assistant"),
-                         finish_reason=format_finish_reason(
-                             content["meta_info"]["finish_reason"]
+                         finish_reason=(
+                             content["meta_info"]["finish_reason"]["type"]
+                             if content["meta_info"]["finish_reason"]
+                             else ""
                          ),
                          logprobs=choice_logprobs,
                      )
@@ -1134,8 +1151,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                  choice_data = ChatCompletionResponseStreamChoice(
                      index=index,
                      delta=DeltaMessage(content=delta),
-                     finish_reason=format_finish_reason(
-                         content["meta_info"]["finish_reason"]
+                     finish_reason=(
+                         content["meta_info"]["finish_reason"]["type"]
+                         if content["meta_info"]["finish_reason"]
+                         else ""
                      ),
                      logprobs=choice_logprobs,
                  )
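
The adapter changes above drop the format_finish_reason helper in favor of the finish_reason dict's "type" field and add response_format handling: a json_schema response format is converted to a schema string with convert_json_schema_to_str and passed to the sampler as a decoding constraint. A minimal sketch of exercising that path against the OpenAI-compatible endpoint; the host/port, model name, and schema below are illustrative assumptions, not part of the diff:

    import json

    import requests

    # Assumes a local server started via sglang.launch_server on port 30000.
    url = "http://localhost:30000/v1/chat/completions"

    person_schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        "required": ["name", "age"],
    }

    payload = {
        "model": "default",  # placeholder; sglang serves the model chosen at launch
        "messages": [{"role": "user", "content": "Invent a person and reply as JSON."}],
        "temperature": 0,
        "max_tokens": 128,
        # A named json_schema response_format is turned into a constrained-decoding
        # schema string by v1_chat_generate_request in this release.
        "response_format": {
            "type": "json_schema",
            "json_schema": {"name": "person", "schema": person_schema},
        },
    }

    resp = requests.post(url, json=payload, timeout=120)
    print(json.dumps(resp.json(), indent=2))
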