sglang 0.1.14__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56)
  1. sglang/__init__.py +55 -2
  2. sglang/api.py +3 -5
  3. sglang/backend/anthropic.py +18 -4
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +1 -0
  8. sglang/lang/chat_template.py +74 -0
  9. sglang/lang/interpreter.py +40 -16
  10. sglang/lang/tracer.py +6 -4
  11. sglang/launch_server.py +2 -1
  12. sglang/srt/constrained/fsm_cache.py +1 -0
  13. sglang/srt/constrained/jump_forward.py +1 -0
  14. sglang/srt/conversation.py +2 -2
  15. sglang/srt/hf_transformers_utils.py +2 -1
  16. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  17. sglang/srt/layers/extend_attention.py +1 -0
  18. sglang/srt/layers/logits_processor.py +114 -54
  19. sglang/srt/layers/radix_attention.py +2 -1
  20. sglang/srt/layers/token_attention.py +1 -0
  21. sglang/srt/managers/detokenizer_manager.py +5 -1
  22. sglang/srt/managers/io_struct.py +12 -0
  23. sglang/srt/managers/router/infer_batch.py +70 -33
  24. sglang/srt/managers/router/manager.py +7 -2
  25. sglang/srt/managers/router/model_rpc.py +116 -73
  26. sglang/srt/managers/router/model_runner.py +111 -167
  27. sglang/srt/managers/router/radix_cache.py +46 -38
  28. sglang/srt/managers/tokenizer_manager.py +56 -11
  29. sglang/srt/memory_pool.py +5 -14
  30. sglang/srt/model_config.py +7 -0
  31. sglang/srt/models/commandr.py +376 -0
  32. sglang/srt/models/dbrx.py +413 -0
  33. sglang/srt/models/dbrx_config.py +281 -0
  34. sglang/srt/models/gemma.py +22 -20
  35. sglang/srt/models/llama2.py +23 -21
  36. sglang/srt/models/llava.py +12 -10
  37. sglang/srt/models/mixtral.py +27 -25
  38. sglang/srt/models/qwen.py +23 -21
  39. sglang/srt/models/qwen2.py +23 -21
  40. sglang/srt/models/stablelm.py +20 -21
  41. sglang/srt/models/yivl.py +6 -5
  42. sglang/srt/openai_api_adapter.py +356 -0
  43. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  44. sglang/srt/sampling_params.py +2 -0
  45. sglang/srt/server.py +68 -447
  46. sglang/srt/server_args.py +76 -49
  47. sglang/srt/utils.py +88 -32
  48. sglang/srt/weight_utils.py +402 -0
  49. sglang/test/test_programs.py +8 -7
  50. sglang/test/test_utils.py +195 -7
  51. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/METADATA +12 -14
  52. sglang-0.1.15.dist-info/RECORD +69 -0
  53. sglang-0.1.14.dist-info/RECORD +0 -64
  54. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
  55. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/WHEEL +0 -0
  56. {sglang-0.1.14.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0
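
Note: the headline addition in this release is sglang/srt/openai_api_adapter.py (item 42), which factors the OpenAI-compatible endpoints out of sglang/srt/server.py (hence the -447 there); the full new file appears at the end of this diff. A minimal sketch of exercising the adapter over HTTP, assuming a server launched with python -m sglang.launch_server --model-path <model> --port 30000 and the usual /v1/chat/completions route; the request fields mirror what v1_chat_completions (shown below) reads:

import requests

# Hit the OpenAI-compatible chat endpoint served by the FastAPI app.
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "default",  # echoed back in the response by the adapter
        "messages": [{"role": "user", "content": "Say this is a test."}],
        "max_tokens": 16,
        "temperature": 0,
    },
)
print(resp.json()["choices"][0]["message"]["content"])
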
sglang/srt/models/qwen2.py CHANGED
@@ -1,34 +1,36 @@
 # Adapted from llama2.py
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.parallel_utils.parallel_state import (
+from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.weight_utils import (
+from sglang.srt.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
 
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+
 Qwen2Config = None
 
 
@@ -38,17 +40,17 @@ class Qwen2MLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, linear_method=linear_method
+            intermediate_size, hidden_size, bias=False, quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -74,7 +76,7 @@ class Qwen2Attention(nn.Module):
         rope_theta: float = 1000000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 32768,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -105,13 +107,13 @@ class Qwen2Attention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -148,7 +150,7 @@ class Qwen2DecoderLayer(nn.Module):
         self,
         config: Qwen2Config,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -163,13 +165,13 @@ class Qwen2DecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.mlp = Qwen2MLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -205,7 +207,7 @@ class Qwen2Model(nn.Module):
     def __init__(
         self,
         config: Qwen2Config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
         super().__init__()
         self.config = config
@@ -217,7 +219,7 @@ class Qwen2Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                Qwen2DecoderLayer(config, i, linear_method)
+                Qwen2DecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -251,12 +253,12 @@ class Qwen2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: Qwen2Config,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = Qwen2Model(config, linear_method)
+        self.quant_config = quant_config
+        self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
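
Note: every model file in this diff makes the same mechanical migration, replacing vLLM's removed linear_method/LinearMethodBase plumbing with quant_config/QuantizationConfig and threading it from the *ForCausalLM constructor down to each parallel linear layer. A minimal illustrative sketch of the new pattern (not from the diff), assuming a vLLM version where QuantizationConfig lives at the import path used above:

from typing import Optional

from torch import nn
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


class TinyProjection(nn.Module):
    # Illustrative only: mirrors how each layer now accepts and forwards
    # quant_config. Passing None means unquantized weights, the same default
    # the old linear_method=None provided.
    def __init__(self, hidden_size: int, quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.proj = RowParallelLinear(
            hidden_size, hidden_size, bias=False, quant_config=quant_config
        )
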
sglang/srt/models/stablelm.py CHANGED
@@ -7,34 +7,35 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding,
     ParallelLMHead,
+    VocabParallelEmbedding,
 )
-from vllm.model_executor.parallel_utils.parallel_state import (
+from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.weight_utils import (
+from sglang.srt.weight_utils import (
     default_weight_loader,
     hf_model_weights_iterator,
 )
 
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
+
 
 class StablelmMLP(nn.Module):
     def __init__(
-        self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
+        self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -44,10 +45,10 @@ class StablelmMLP(nn.Module):
             config.hidden_size,
             [config.intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            config.intermediate_size, config.hidden_size, bias=False
+            config.intermediate_size, config.hidden_size, bias=False, quant_config=quant_config,
         )
         self.act_fn = SiluAndMul()
 
@@ -63,7 +64,7 @@ class StablelmAttention(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -105,13 +106,11 @@ class StablelmAttention(nn.Module):
             self.total_num_heads,
             self.total_num_key_value_heads,
             self.qkv_bias,
-            linear_method=linear_method,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
-            linear_method=linear_method,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -146,11 +145,11 @@ class StablelmDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.self_attn = StablelmAttention(config, layer_id=layer_id)
-        self.mlp = StablelmMLP(config, linear_method)
+        self.mlp = StablelmMLP(config, quant_config=quant_config)
         norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05))
         self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps)
@@ -182,7 +181,7 @@ class StablelmDecoderLayer(nn.Module):
 
 class StableLMEpochModel(nn.Module):
     def __init__(
-        self, config: PretrainedConfig, linear_method: Optional[LinearMethodBase] = None
+        self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.embed_tokens = VocabParallelEmbedding(
@@ -191,7 +190,7 @@ class StableLMEpochModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                StablelmDecoderLayer(config, i, linear_method)
+                StablelmDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -224,12 +223,12 @@ class StableLmForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = StableLMEpochModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
sglang/srt/models/yivl.py CHANGED
@@ -5,16 +5,17 @@ from typing import List, Optional
 
 import torch
 import torch.nn as nn
+from transformers import CLIPVisionModel, LlavaConfig
+from sglang.srt.weight_utils import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)
+
 from sglang.srt.models.llava import (
     LlavaLlamaForCausalLM,
     clip_vision_embed_forward,
     monkey_path_clip_vision_embed_forward,
 )
-from transformers import CLIPVisionModel, LlavaConfig
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
 
 
 class YiVLForCausalLM(LlavaLlamaForCausalLM):
sglang/srt/openai_api_adapter.py ADDED
@@ -0,0 +1,356 @@
+"""Conversion between OpenAI APIs and native SRT APIs"""
+import json
+import os
+
+from fastapi import HTTPException, Request
+from fastapi.responses import StreamingResponse
+
+from sglang.srt.conversation import (
+    Conversation,
+    SeparatorStyle,
+    chat_template_exists,
+    generate_chat_conv,
+    register_conv_template,
+)
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.openai_protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse,
+    ChatMessage,
+    CompletionRequest,
+    CompletionResponse,
+    CompletionResponseChoice,
+    CompletionResponseStreamChoice,
+    CompletionStreamResponse,
+    DeltaMessage,
+    LogProbs,
+    UsageInfo,
+)
+from sglang.srt.utils import jsonify_pydantic_model
+
+
+chat_template_name = None
+
+def load_chat_template_for_openai_api(chat_template_arg):
+    global chat_template_name
+
+    print(f"Use chat template: {chat_template_arg}")
+    if not chat_template_exists(chat_template_arg):
+        if not os.path.exists(chat_template_arg):
+            raise RuntimeError(
+                f"Chat template {chat_template_arg} is not a built-in template name "
+                "or a valid chat template file path."
+            )
+        with open(chat_template_arg, "r") as filep:
+            template = json.load(filep)
+            try:
+                sep_style = SeparatorStyle[template["sep_style"]]
+            except KeyError:
+                raise ValueError(
+                    f"Unknown separator style: {template['sep_style']}"
+                ) from None
+            register_conv_template(
+                Conversation(
+                    name=template["name"],
+                    system_template=template["system"] + "\n{system_message}",
+                    system_message=template.get("system_message", ""),
+                    roles=(template["user"], template["assistant"]),
+                    sep_style=sep_style,
+                    sep=template.get("sep", "\n"),
+                    stop_str=template["stop_str"],
+                ),
+                override=True,
+            )
+        chat_template_name = template["name"]
+    else:
+        chat_template_name = chat_template_arg
+
+
+async def v1_completions(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    request = CompletionRequest(**request_json)
+
+    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
+    assert request.n == 1
+
+    adapted_request = GenerateReqInput(
+        text=request.prompt,
+        sampling_params={
+            "temperature": request.temperature,
+            "max_new_tokens": request.max_tokens,
+            "stop": request.stop,
+            "top_p": request.top_p,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
+            "regex": request.regex,
+        },
+        return_logprob=request.logprobs is not None and request.logprobs > 0,
+        top_logprobs_num=request.logprobs if request.logprobs is not None else 0,
+        return_text_in_logprobs=True,
+        stream=request.stream,
+    )
+    adapted_request.post_init()
+
+    if adapted_request.stream:
+
+        async def generate_stream_resp():
+            stream_buffer = ""
+            n_prev_token = 0
+            async for content in tokenizer_manager.generate_request(adapted_request):
+                text = content["text"]
+                prompt_tokens = content["meta_info"]["prompt_tokens"]
+                completion_tokens = content["meta_info"]["completion_tokens"]
+
+                if not stream_buffer:  # The first chunk
+                    if request.echo:
+                        # Prepend prompt in response text.
+                        text = request.prompt + text
+
+                if request.logprobs:
+                    # The first chunk and echo is enabled.
+                    if not stream_buffer and request.echo:
+                        prefill_token_logprobs = content["meta_info"][
+                            "prefill_token_logprobs"
+                        ]
+                        prefill_top_logprobs = content["meta_info"][
+                            "prefill_top_logprobs"
+                        ]
+                    else:
+                        prefill_token_logprobs = None
+                        prefill_top_logprobs = None
+
+                    logprobs = to_openai_style_logprobs(
+                        prefill_token_logprobs=prefill_token_logprobs,
+                        prefill_top_logprobs=prefill_top_logprobs,
+                        decode_token_logprobs=content["meta_info"][
+                            "decode_token_logprobs"
+                        ][n_prev_token:],
+                        decode_top_logprobs=content["meta_info"]["decode_top_logprobs"][
+                            n_prev_token:
+                        ],
+                    )
+
+                    n_prev_token = len(content["meta_info"]["decode_token_logprobs"])
+                else:
+                    logprobs = None
+
+                delta = text[len(stream_buffer) :]
+                stream_buffer = content["text"]
+                choice_data = CompletionResponseStreamChoice(
+                    index=0,
+                    text=delta,
+                    logprobs=logprobs,
+                    finish_reason=None,
+                )
+                chunk = CompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    object="text_completion",
+                    choices=[choice_data],
+                    model=request.model,
+                    usage=UsageInfo(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=prompt_tokens + completion_tokens,
+                    ),
+                )
+                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+
+    # Non-streaming response.
+    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+    ret = ret[0] if isinstance(ret, list) else ret
+
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    text = ret["text"]
+    if request.echo:
+        text = request.prompt + text
+
+    if request.logprobs:
+        if request.echo:
+            prefill_token_logprobs = ret["meta_info"]["prefill_token_logprobs"]
+            prefill_top_logprobs = ret["meta_info"]["prefill_top_logprobs"]
+        else:
+            prefill_token_logprobs = None
+            prefill_top_logprobs = None
+
+        logprobs = to_openai_style_logprobs(
+            prefill_token_logprobs=prefill_token_logprobs,
+            prefill_top_logprobs=prefill_top_logprobs,
+            decode_token_logprobs=ret["meta_info"]["decode_token_logprobs"],
+            decode_top_logprobs=ret["meta_info"]["decode_top_logprobs"],
+        )
+    else:
+        logprobs = None
+
+    choice_data = CompletionResponseChoice(
+        index=0,
+        text=text,
+        logprobs=logprobs,
+        finish_reason=None,  # TODO(comaniac): Add finish reason.
+    )
+    response = CompletionResponse(
+        id=ret["meta_info"]["id"],
+        model=request.model,
+        choices=[choice_data],
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        ),
+    )
+    return response
+
+
+async def v1_chat_completions(tokenizer_manager, raw_request: Request):
+    request_json = await raw_request.json()
+    request = ChatCompletionRequest(**request_json)
+
+    # TODO: Validate the request and return HTTPStatus.BAD_REQUEST if invalid.
+    assert request.n == 1
+
+    # Prep the data needed for the underlying GenerateReqInput:
+    #  - prompt: The full prompt string.
+    #  - stop: Custom stop tokens.
+    #  - image_data: None or a list of image strings (URLs or base64 strings).
+    #    None skips any image processing in GenerateReqInput.
+    if not isinstance(request.messages, str):
+        # Apply chat template and its stop strings.
+        if chat_template_name is None:
+            prompt = tokenizer_manager.tokenizer.apply_chat_template(
+                request.messages, tokenize=False, add_generation_prompt=True
+            )
+            stop = request.stop
+            image_data = None
+        else:
+            conv = generate_chat_conv(request, chat_template_name)
+            prompt = conv.get_prompt()
+            image_data = conv.image_data
+            stop = conv.stop_str or []
+            if request.stop:
+                if isinstance(request.stop, str):
+                    stop.append(request.stop)
+                else:
+                    stop.extend(request.stop)
+    else:
+        # Use the raw prompt and stop strings if the messages is already a string.
+        prompt = request.messages
+        stop = request.stop
+        image_data = None
+
+    adapted_request = GenerateReqInput(
+        text=prompt,
+        image_data=image_data,
+        sampling_params={
+            "temperature": request.temperature,
+            "max_new_tokens": request.max_tokens,
+            "stop": stop,
+            "top_p": request.top_p,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
+            "regex": request.regex,
+        },
+        stream=request.stream,
+    )
+    adapted_request.post_init()
+
+    if adapted_request.stream:
+
+        async def generate_stream_resp():
+            is_first = True
+
+            stream_buffer = ""
+            async for content in tokenizer_manager.generate_request(adapted_request):
+                if is_first:
+                    # First chunk with role
+                    is_first = False
+                    choice_data = ChatCompletionResponseStreamChoice(
+                        index=0,
+                        delta=DeltaMessage(role="assistant"),
+                        finish_reason=None,
+                    )
+                    chunk = ChatCompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        choices=[choice_data],
+                        model=request.model,
+                    )
+                    yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+
+                text = content["text"]
+                delta = text[len(stream_buffer) :]
+                stream_buffer = text
+                choice_data = ChatCompletionResponseStreamChoice(
+                    index=0, delta=DeltaMessage(content=delta), finish_reason=None
+                )
+                chunk = ChatCompletionStreamResponse(
+                    id=content["meta_info"]["id"],
+                    choices=[choice_data],
+                    model=request.model,
+                )
+                yield f"data: {jsonify_pydantic_model(chunk)}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(generate_stream_resp(), media_type="text/event-stream")
+
+    # Non-streaming response.
+    ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
+    prompt_tokens = ret["meta_info"]["prompt_tokens"]
+    completion_tokens = ret["meta_info"]["completion_tokens"]
+    choice_data = ChatCompletionResponseChoice(
+        index=0,
+        message=ChatMessage(role="assistant", content=ret["text"]),
+        finish_reason=None,  # TODO(comaniac): Add finish reason.
+    )
+    response = ChatCompletionResponse(
+        id=ret["meta_info"]["id"],
+        model=request.model,
+        choices=[choice_data],
+        usage=UsageInfo(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        ),
+    )
+    return response
+
+
+def to_openai_style_logprobs(
+    prefill_token_logprobs=None,
+    decode_token_logprobs=None,
+    prefill_top_logprobs=None,
+    decode_top_logprobs=None,
+):
+    ret_logprobs = LogProbs()
+
+    def append_token_logprobs(token_logprobs):
+        for logprob, _, token_text in token_logprobs:
+            ret_logprobs.tokens.append(token_text)
+            ret_logprobs.token_logprobs.append(logprob)
+
+            # Not Supported yet
+            ret_logprobs.text_offset.append(-1)
+
+    def append_top_logprobs(top_logprobs):
+        for tokens in top_logprobs:
+            if tokens is not None:
+                ret_logprobs.top_logprobs.append(
+                    {token[2]: token[0] for token in tokens}
+                )
+            else:
+                ret_logprobs.top_logprobs.append(None)
+
+    if prefill_token_logprobs is not None:
+        append_token_logprobs(prefill_token_logprobs)
+    if decode_token_logprobs is not None:
+        append_token_logprobs(decode_token_logprobs)
+    if prefill_top_logprobs is not None:
+        append_top_logprobs(prefill_top_logprobs)
+    if decode_top_logprobs is not None:
+        append_top_logprobs(decode_top_logprobs)
+
+    return ret_logprobs
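
Note: load_chat_template_for_openai_api above accepts either a built-in template name or a path to a JSON file. A hedged sketch of such a file, using exactly the keys the loader reads (system_message and sep are optional, since the loader fetches them with .get()); the sep_style value must name a SeparatorStyle member in sglang/srt/conversation.py, and ADD_COLON_SINGLE is assumed here to exist as in the FastChat-derived code that module builds on:

import json

# Write a custom template the loader can register at server startup.
template = {
    "name": "my-template",
    "system": "SYSTEM:",  # becomes system_template, with "\n{system_message}" appended
    "system_message": "You are a helpful assistant.",
    "user": "USER:",  # roles[0]
    "assistant": "ASSISTANT:",  # roles[1]
    "sep_style": "ADD_COLON_SINGLE",  # looked up via SeparatorStyle[...]
    "sep": "\n",
    "stop_str": "USER:",
}

with open("my_template.json", "w") as f:
    json.dump(template, f, indent=2)

# Then point the server at the file, e.g. via the --chat-template argument
# that presumably feeds load_chat_template_for_openai_api at startup.
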