sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
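
Most of the churn above comes from a module reorganization: the scheduler stack moves from sglang.srt.managers.router to sglang.srt.managers.controller (items 29-37, with the old router files removed in items 74-78), and openai_protocol.py moves from sglang/srt/managers up to sglang/srt (item 62). Code that imported the old paths needs the new ones; a small sketch of the rename, using symbols that appear in the diffs below (the old-path lines are shown as comments for comparison):

    # 0.1.14
    # from sglang.srt.managers.router.model_runner import InputMetadata
    # from sglang.srt.managers.openai_protocol import CompletionRequest

    # 0.1.21
    from sglang.srt.managers.controller.model_runner import InputMetadata
    from sglang.srt.openai_protocol import CompletionRequest
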
sglang/srt/models/stablelm.py CHANGED
@@ -1,40 +1,38 @@
-# This code is based on:
-# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/stablelm.py
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/stablelm.py#L1
 """Inference-only StableLM-2 (https://huggingface.co/stabilityai/stablelm-2-1_6b)
 model compatible with HuggingFace weights."""
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
     ParallelLMHead,
+    VocabParallelEmbedding,
 )
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_world_size,
-)
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class StablelmMLP(nn.Module):
     def __init__(
-        self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -44,10 +42,13 @@ class StablelmMLP(nn.Module):
             config.hidden_size,
             [config.intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            config.intermediate_size, config.hidden_size, bias=False
+            config.intermediate_size,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = SiluAndMul()
 
@@ -63,7 +64,7 @@ class StablelmAttention(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -105,13 +106,11 @@ class StablelmAttention(nn.Module):
             self.total_num_heads,
             self.total_num_key_value_heads,
             self.qkv_bias,
-            linear_method=linear_method,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
-            linear_method=linear_method,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -146,11 +145,11 @@ class StablelmDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.self_attn = StablelmAttention(config, layer_id=layer_id)
-        self.mlp = StablelmMLP(config, linear_method)
+        self.mlp = StablelmMLP(config, quant_config=quant_config)
         norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
@@ -182,7 +181,9 @@ class StablelmDecoderLayer(nn.Module):
 
 class StableLMEpochModel(nn.Module):
     def __init__(
-        self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
@@ -191,7 +192,7 @@ class StableLMEpochModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                StablelmDecoderLayer(config, i, linear_method)
+                StablelmDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -224,12 +225,13 @@ class StableLmForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = StableLMEpochModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
@@ -245,13 +247,7 @@ class StableLmForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -261,9 +257,7 @@ class StableLmForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
        params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
sglang/srt/models/yivl.py CHANGED
@@ -1,42 +1,38 @@
 """Inference-only Yi-VL model."""
 
-import os
-from typing import List, Optional
+from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
+from transformers import CLIPVisionModel, LlavaConfig
+from vllm.config import CacheConfig
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
 from sglang.srt.models.llava import (
     LlavaLlamaForCausalLM,
-    clip_vision_embed_forward,
     monkey_path_clip_vision_embed_forward,
 )
-from transformers import CLIPVisionModel, LlavaConfig
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 
 class YiVLForCausalLM(LlavaLlamaForCausalLM):
-    def __init__(self, *args, **kwargs):
-        self.config = kwargs["config"]
-        super().__init__(self.config)
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__(config, quant_config, cache_config)
 
         self.multi_modal_projector = YiVLMultiModalProjector(self.config)
         self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
             "./", ""
         )  # Everything after "./"
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
         self.vision_tower = CLIPVisionModel.from_pretrained(
-            model_name_or_path,
+            self.config._name_or_path,
             torch_dtype=torch.float16,
             subfolder=self.vision_tower_subfolder,
         ).cuda()
@@ -70,9 +66,8 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
             "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
         }
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        weights = list(weights)
+        for name, loaded_weight in weights:
             if "projector" in name or "vision_tower" in name:
                 for weight_name, param_name in projector_weights.items():
                     if weight_name in name:
@@ -82,9 +77,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
                 weight_loader(param, loaded_weight)
 
         # load language model
-        self.language_model.load_weights(
-            model_name_or_path, cache_dir, load_format, revision
-        )
+        self.language_model.load_weights(weights)
 
         monkey_path_clip_vision_embed_forward()
 
@@ -105,7 +98,7 @@ class YiVLMultiModalProjector(nn.Module):
 
     def forward(self, image_features):
         hidden_states = self.linear_1(image_features)
-        hidden_state = self.ln_1(hidden_states)
+        hidden_states = self.ln_1(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.linear_2(hidden_states)
         hidden_states = self.ln_2(hidden_states)
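
Two details in the yivl.py hunks are easy to miss. The weights iterable is now consumed twice, first to pick out projector and vision-tower tensors and then again by self.language_model.load_weights, so it is materialized with list(weights) up front; and the projector's forward previously assigned the ln_1 output to a misspelled hidden_state variable, so the layer norm was effectively skipped, which this release fixes. A standalone illustration of the iterator pitfall, not taken from the package:

    def weight_stream():
        yield "a", 1
        yield "b", 2

    ws = weight_stream()
    first = [name for name, _ in ws]   # consumes the generator
    second = [name for name, _ in ws]  # empty: a generator can be iterated only once
    assert first == ["a", "b"] and second == []

    # Materializing first, as yivl.py now does, lets both passes see every weight.
    ws = list(weight_stream())
    assert [n for n, _ in ws] == ["a", "b"] and [n for n, _ in ws] == ["a", "b"]
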
sglang/srt/openai_api_adapter.py ADDED
@@ -0,0 +1,411 @@
+"""Conversion between OpenAI APIs and native SRT APIs"""
+
+import asyncio
+import json
+import os
+from http import HTTPStatus
+
+from fastapi import Request
+from fastapi.responses import JSONResponse, StreamingResponse
+
+from sglang.srt.conversation import (
+    Conversation,
+    SeparatorStyle,
+    chat_template_exists,
+    generate_chat_conv,
+    register_conv_template,
+)
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.openai_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse,
+    ChatMessage,
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseChoice,
+    CompletionResponseStreamChoice,
+    CompletionStreamResponse,
+    DeltaMessage,
+    ErrorResponse,
+    LogProbs,
+    UsageInfo,
+)
+
+chat_template_name = None
+
+
+def create_error_response(
+    message: str,
+    err_type: str = "BadRequestError",
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+):
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    return JSONResponse(content=error.model_dump(), status_code=error.code)
+
+
+def create_streaming_error_response(
+    message: str,
+    err_type: str = "BadRequestError",
+    status_code: HTTPStatus = HTTPStatus.BAD_REQUEST,
+) -> str:
+    error = ErrorResponse(message=message, type=err_type, code=status_code.value)
+    json_str = json.dumps({"error": error.model_dump()})
+    return json_str
+
+
+def load_chat_template_for_openai_api(chat_template_arg):
+    global chat_template_name
+
+    print(f"Use chat template: {chat_template_arg}")
+    if not chat_template_exists(chat_template_arg):
+        if not os.path.exists(chat_template_arg):
+            raise RuntimeError(
+                f"Chat template {chat_template_arg} is not a built-in template name "
+                "or a valid chat template file path."
+            )
+        with open(chat_template_arg, "r") as filep:
+            template = json.load(filep)
+            try:
+                sep_style = SeparatorStyle[template["sep_style"]]
+            except KeyError:
+                raise ValueError(
+                    f"Unknown separator style: {template['sep_style']}"
+                ) from None
+            register_conv_template(
+                Conversation(
+                    name=template["name"],
+                    system_template=template["system"] + "\n{system_message}",
+                    system_message=template.get("system_message", ""),
+                    roles=(template["user"], template["assistant"]),
+                    sep_style=sep_style,
+                    sep=template.get("sep", "\n"),
+                    stop_str=template["stop_str"],
+                ),
+                override=True,
+            )
+            chat_template_name = template["name"]
+    else:
+        chat_template_name = chat_template_arg
+
+
+async def v1_completions(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    request = CompletionRequest(**request_json)
+
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")
+
+    adapted_request = GenerateReqInput(
+        text=request.prompt,
+        sampling_params={
+            "temperature": request.temperature,
+            "max_new_tokens": request.max_tokens,
+            "stop": request.stop,
+            "top_p": request.top_p,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
+            "regex": request.regex,
+        },
+        return_logprob=request.logprobs is not None and request.logprobs > 0,
+        top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
+        return_text_in_logprobs=True,
+        stream=request.stream,
+    )
+
+    if adapted_request.stream:
+
+        async def generate_stream_resp():
+            stream_buffer = ""
+            n_prev_token = 0
+            try:
+                async for content in tokenizer_manager.generate_request(
+                    adapted_request, raw_request
+                ):
+                    text = content["text"]
+                    prompt_tokens = content["meta_info"]["prompt_tokens"]
+                    completion_tokens = content["meta_info"]["completion_tokens"]
+
+                    if not stream_buffer:  # The first chunk
+                        if request.echo:
+                            # Prepend prompt in response text.
+                            text = request.prompt + text
+
+                    if request.logprobs:
+                        # The first chunk and echo is enabled.
+                        if not stream_buffer and request.echo:
+                            prefill_token_logprobs = content["meta_info"][
+                                "prefill_token_logprobs"
+                            ]
+                            prefill_top_logprobs = content["meta_info"][
+                                "prefill_top_logprobs"
+                            ]
+                        else:
+                            prefill_token_logprobs = None
+                            prefill_top_logprobs = None
+
+                        logprobs = to_openai_style_logprobs(
+                            prefill_token_logprobs=prefill_token_logprobs,
+                            prefill_top_logprobs=prefill_top_logprobs,
+                            decode_token_logprobs=content["meta_info"][
+                                "decode_token_logprobs"
+                            ][n_prev_token:],
+                            decode_top_logprobs=content["meta_info"][
+                                "decode_top_logprobs"
+                            ][n_prev_token:],
+                        )
+
+                        n_prev_token = len(
+                            content["meta_info"]["decode_token_logprobs"]
+                        )
+                    else:
+                        logprobs = None
+
+                    delta = text[len(stream_buffer) :]
+                    stream_buffer = stream_buffer + delta
+                    choice_data = CompletionResponseStreamChoice(
+                        index=0,
+                        text=delta,
+                        logprobs=logprobs,
+                        finish_reason=content["meta_info"]["finish_reason"],
+                    )
+                    chunk = CompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        object="text_completion",
+                        choices=[choice_data],
+                        model=request.model,
+                        usage=UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
+                    )
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )
+
+    # Non-streaming response.
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
+    ret = ret[0] if isinstance(ret, list) else ret
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    text = ret["text"]
+    if request.echo:
+        text = request.prompt + text
+
+    if request.logprobs:
+        if request.echo:
+            prefill_token_logprobs = ret["meta_info"]["prefill_token_logprobs"]
+            prefill_top_logprobs = ret["meta_info"]["prefill_top_logprobs"]
+        else:
+            prefill_token_logprobs = None
+            prefill_top_logprobs = None
+
+        logprobs = to_openai_style_logprobs(
+            prefill_token_logprobs=prefill_token_logprobs,
+            prefill_top_logprobs=prefill_top_logprobs,
+            decode_token_logprobs=ret["meta_info"]["decode_token_logprobs"],
+            decode_top_logprobs=ret["meta_info"]["decode_top_logprobs"],
+        )
+    else:
+        logprobs = None
+
+    choice_data = CompletionResponseChoice(
+        index=0,
+        text=text,
+        logprobs=logprobs,
+        finish_reason=ret["meta_info"]["finish_reason"],
+    )
+    response = CompletionResponse(
+        id=ret["meta_info"]["id"],
+        model=request.model,
+        choices=[choice_data],
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        ),
+    )
+    return response
+
+
+async def v1_chat_completions(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    request = ChatCompletionRequest(**request_json)
+
+    if request.n != 1:
+        return create_error_response("n != 1 is not supported")
+
+    # Prep the data needed for the underlying GenerateReqInput:
+    #  - prompt: The full prompt string.
+    #  - stop: Custom stop tokens.
+    #  - image_data: None or a list of image strings (URLs or base64 strings).
+    #    None skips any image processing in GenerateReqInput.
+    if not isinstance(request.messages, str):
+        # Apply chat template and its stop strings.
+        if chat_template_name is None:
+            prompt = tokenizer_manager.tokenizer.apply_chat_template(
+                request.messages, tokenize=False, add_generation_prompt=True
+            )
+            stop = request.stop
+            image_data = None
+        else:
+            conv = generate_chat_conv(request, chat_template_name)
+            prompt = conv.get_prompt()
+            image_data = conv.image_data
+            stop = conv.stop_str or []
+            if request.stop:
+                if isinstance(request.stop, str):
+                    stop.append(request.stop)
+                else:
+                    stop.extend(request.stop)
+    else:
+        # Use the raw prompt and stop strings if the messages is already a string.
+        prompt = request.messages
+        stop = request.stop
+        image_data = None
+
+    adapted_request = GenerateReqInput(
+        text=prompt,
+        image_data=image_data,
+        sampling_params={
+            "temperature": request.temperature,
+            "max_new_tokens": request.max_tokens,
+            "stop": stop,
+            "top_p": request.top_p,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
+            "regex": request.regex,
+        },
+        stream=request.stream,
+    )
+
+    if adapted_request.stream:
+
+        async def generate_stream_resp():
+            is_first = True
+
+            stream_buffer = ""
+            try:
+                async for content in tokenizer_manager.generate_request(
+                    adapted_request, raw_request
+                ):
+                    if is_first:
+                        # First chunk with role
+                        is_first = False
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=0,
+                            delta=DeltaMessage(role="assistant"),
+                            finish_reason=content["meta_info"]["finish_reason"],
+                        )
+                        chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            choices=[choice_data],
+                            model=request.model,
+                        )
+                        yield f"data: {chunk.model_dump_json()}\n\n"
+
+                    text = content["text"]
+                    delta = text[len(stream_buffer) :]
+                    stream_buffer = stream_buffer + delta
+                    choice_data = ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(content=delta),
+                        finish_reason=content["meta_info"]["finish_reason"],
+                    )
+                    chunk = ChatCompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        choices=[choice_data],
+                        model=request.model,
+                    )
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+            except ValueError as e:
+                error = create_streaming_error_response(str(e))
+                yield f"data: {error}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(
+            generate_stream_resp(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(adapted_request),
+        )
+
+    # Non-streaming response.
+    try:
+        ret = await tokenizer_manager.generate_request(
+            adapted_request, raw_request
+        ).__anext__()
+    except ValueError as e:
+        return create_error_response(str(e))
+
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    choice_data = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(role="assistant", content=ret["text"]),
+        finish_reason=ret["meta_info"]["finish_reason"],
+    )
+    response = ChatCompletionResponse(
+        id=ret["meta_info"]["id"],
+        model=request.model,
+        choices=[choice_data],
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        ),
+    )
+    return response
+
+
+def to_openai_style_logprobs(
+    prefill_token_logprobs=None,
+    decode_token_logprobs=None,
+    prefill_top_logprobs=None,
+    decode_top_logprobs=None,
+):
+    ret_logprobs = LogProbs()
+
+    def append_token_logprobs(token_logprobs):
+        for logprob, _, token_text in token_logprobs:
+            ret_logprobs.tokens.append(token_text)
+            ret_logprobs.token_logprobs.append(logprob)
+
+            # Not supported yet
+            ret_logprobs.text_offset.append(-1)
+
+    def append_top_logprobs(top_logprobs):
+        for tokens in top_logprobs:
+            if tokens is not None:
+                ret_logprobs.top_logprobs.append(
+                    {token[2]: token[0] for token in tokens}
+                )
+            else:
+                ret_logprobs.top_logprobs.append(None)
+
+    if prefill_token_logprobs is not None:
+        append_token_logprobs(prefill_token_logprobs)
+    if decode_token_logprobs is not None:
+        append_token_logprobs(decode_token_logprobs)
+    if prefill_top_logprobs is not None:
+        append_top_logprobs(prefill_top_logprobs)
+    if decode_top_logprobs is not None:
+        append_top_logprobs(decode_top_logprobs)
+
+    return ret_logprobs
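
The adapter above backs the server's OpenAI-compatible completions and chat-completions routes; the logic previously lived inline in sglang/srt/server.py, which shrinks accordingly in this release (item 64). A minimal client-side sketch against a locally running sglang server; the address, port, and model name are assumptions for illustration, not values taken from the diff:

    import requests

    base = "http://localhost:30000"  # assumed local sglang server address

    # Chat completion: messages are rendered through the chat template configured on the server.
    resp = requests.post(
        f"{base}/v1/chat/completions",
        json={
            "model": "default",  # placeholder model name
            "messages": [{"role": "user", "content": "Say hello in one word."}],
            "max_tokens": 16,
            "temperature": 0,
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])

    # Plain completion with logprobs, which the adapter converts via to_openai_style_logprobs.
    resp = requests.post(
        f"{base}/v1/completions",
        json={"model": "default", "prompt": "1 + 1 =", "max_tokens": 4, "logprobs": 2},
    )
    print(resp.json()["choices"][0]["logprobs"])
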