sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +151 -40
  4. sglang/bench_serving.py +46 -22
  5. sglang/check_env.py +24 -2
  6. sglang/global_config.py +0 -1
  7. sglang/lang/backend/base_backend.py +3 -1
  8. sglang/lang/backend/openai.py +8 -3
  9. sglang/lang/backend/runtime_endpoint.py +46 -29
  10. sglang/lang/choices.py +164 -0
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +6 -13
  13. sglang/lang/ir.py +14 -5
  14. sglang/srt/constrained/base_tool_cache.py +1 -1
  15. sglang/srt/constrained/fsm_cache.py +12 -2
  16. sglang/srt/layers/activation.py +33 -0
  17. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  18. sglang/srt/layers/extend_attention.py +6 -1
  19. sglang/srt/layers/layernorm.py +65 -0
  20. sglang/srt/layers/logits_processor.py +6 -1
  21. sglang/srt/layers/pooler.py +50 -0
  22. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  23. sglang/srt/layers/radix_attention.py +4 -7
  24. sglang/srt/managers/detokenizer_manager.py +31 -9
  25. sglang/srt/managers/io_struct.py +63 -0
  26. sglang/srt/managers/policy_scheduler.py +173 -25
  27. sglang/srt/managers/schedule_batch.py +174 -380
  28. sglang/srt/managers/tokenizer_manager.py +197 -112
  29. sglang/srt/managers/tp_worker.py +299 -364
  30. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  31. sglang/srt/mem_cache/chunk_cache.py +43 -20
  32. sglang/srt/mem_cache/memory_pool.py +10 -15
  33. sglang/srt/mem_cache/radix_cache.py +74 -40
  34. sglang/srt/model_executor/cuda_graph_runner.py +27 -12
  35. sglang/srt/model_executor/forward_batch_info.py +319 -0
  36. sglang/srt/model_executor/model_runner.py +30 -47
  37. sglang/srt/models/chatglm.py +1 -1
  38. sglang/srt/models/commandr.py +1 -1
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/deepseek.py +1 -1
  41. sglang/srt/models/deepseek_v2.py +1 -1
  42. sglang/srt/models/gemma.py +1 -1
  43. sglang/srt/models/gemma2.py +1 -2
  44. sglang/srt/models/gpt_bigcode.py +1 -1
  45. sglang/srt/models/grok.py +1 -1
  46. sglang/srt/models/internlm2.py +3 -8
  47. sglang/srt/models/llama2.py +5 -5
  48. sglang/srt/models/llama_classification.py +1 -1
  49. sglang/srt/models/llama_embedding.py +88 -0
  50. sglang/srt/models/llava.py +1 -2
  51. sglang/srt/models/llavavid.py +1 -2
  52. sglang/srt/models/minicpm.py +1 -1
  53. sglang/srt/models/mixtral.py +1 -1
  54. sglang/srt/models/mixtral_quant.py +1 -1
  55. sglang/srt/models/qwen.py +1 -1
  56. sglang/srt/models/qwen2.py +1 -1
  57. sglang/srt/models/qwen2_moe.py +1 -12
  58. sglang/srt/models/stablelm.py +1 -1
  59. sglang/srt/openai_api/adapter.py +189 -39
  60. sglang/srt/openai_api/protocol.py +43 -1
  61. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  62. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  63. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  64. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  65. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  66. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  67. sglang/srt/sampling_params.py +31 -4
  68. sglang/srt/server.py +93 -21
  69. sglang/srt/server_args.py +30 -19
  70. sglang/srt/utils.py +31 -13
  71. sglang/test/run_eval.py +10 -1
  72. sglang/test/runners.py +63 -63
  73. sglang/test/simple_eval_humaneval.py +2 -8
  74. sglang/test/simple_eval_mgsm.py +203 -0
  75. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  76. sglang/test/test_layernorm.py +60 -0
  77. sglang/test/test_programs.py +4 -2
  78. sglang/test/test_utils.py +21 -3
  79. sglang/utils.py +0 -1
  80. sglang/version.py +1 -1
  81. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
  82. sglang-0.2.12.dist-info/RECORD +112 -0
  83. sglang/srt/layers/linear.py +0 -884
  84. sglang/srt/layers/quantization/__init__.py +0 -64
  85. sglang/srt/layers/quantization/fp8.py +0 -677
  86. sglang-0.2.10.dist-info/RECORD +0 -100
  87. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  88. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  89. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama_embedding.py ADDED
@@ -0,0 +1,88 @@
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import LlamaConfig
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+ from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel
+
+
+ class LlamaEmbeddingModel(nn.Module):
+     def __init__(
+         self,
+         config: LlamaConfig,
+         quant_config=None,
+         cache_config=None,
+         efficient_weight_load=False,
+     ) -> None:
+         super().__init__()
+         self.model = LlamaModel(config, quant_config=quant_config)
+         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> EmbeddingPoolerOutput:
+         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+         return self.pooler(hidden_states, input_metadata)
+
+     def load_weights(
+         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+     ):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         params_dict = dict(self.model.named_parameters())
+
+         def load_weights_per_param(name, loaded_weight):
+             if "rotary_emb.inv_freq" in name or "projector" in name:
+                 return
+             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                 # Models trained using ColossalAI may include these tensors in
+                 # the checkpoint. Skip them.
+                 return
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     return
+                 if name.startswith("model.vision_tower") and name not in params_dict:
+                     return
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+         if name is None or loaded_weight is None:
+             for name, loaded_weight in weights:
+                 load_weights_per_param(name, loaded_weight)
+         else:
+             load_weights_per_param(name, loaded_weight)
+
+
+ EntryClass = LlamaEmbeddingModel
+ # compat: e5-mistral model.config class == MistralModel
+ EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
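The new embedding entry class reuses LlamaModel and delegates pooling to the Pooler layer added in sglang/srt/layers/pooler.py (entry 21 in the file list), configured for last-token pooling with normalization. As a rough illustrative sketch of what that pooling step computes (not the actual Pooler API; the helper name and packed tensor layout are assumptions):

import torch
import torch.nn.functional as F

def last_token_pool(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
    # hidden_states: (sum(seq_lens), hidden_size), sequences packed back to back
    # seq_lens: (batch,) int64 tensor with the number of tokens in each sequence
    last_token_idx = torch.cumsum(seq_lens, dim=0) - 1  # index of each sequence's final token
    pooled = hidden_states[last_token_idx]              # (batch, hidden_size)
    return F.normalize(pooled, p=2, dim=-1)             # unit-length embeddings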
sglang/srt/models/llava.py CHANGED
@@ -32,13 +32,12 @@ from vllm.config import CacheConfig
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.managers.schedule_batch import ForwardMode
  from sglang.srt.mm_utils import (
  get_anyres_image_grid_shape,
  unpad_image,
  unpad_image_shape,
  )
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
  from sglang.srt.models.llama2 import LlamaForCausalLM
  from sglang.srt.models.mistral import MistralForCausalLM
  from sglang.srt.models.qwen2 import Qwen2ForCausalLM
sglang/srt/models/llavavid.py CHANGED
@@ -26,13 +26,12 @@ from vllm.config import CacheConfig
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

- from sglang.srt.managers.schedule_batch import ForwardMode
  from sglang.srt.mm_utils import (
  get_anyres_image_grid_shape,
  unpad_image,
  unpad_image_shape,
  )
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
  from sglang.srt.models.llama2 import LlamaForCausalLM


sglang/srt/models/minicpm.py CHANGED
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  class MiniCPMMLP(nn.Module):
sglang/srt/models/mixtral.py CHANGED
@@ -50,7 +50,7 @@ from vllm.utils import print_warning_once

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  class MixtralMoE(nn.Module):
sglang/srt/models/mixtral_quant.py CHANGED
@@ -45,7 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  class MixtralMLP(nn.Module):
sglang/srt/models/qwen.py CHANGED
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  class QWenMLP(nn.Module):
sglang/srt/models/qwen2.py CHANGED
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata

  Qwen2Config = None

sglang/srt/models/qwen2_moe.py CHANGED
@@ -46,12 +46,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
  VocabParallelEmbedding,
  )
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader
- from vllm.model_executor.sampling_metadata import SamplingMetadata
- from vllm.sequence import IntermediateTensors, SamplerOutput

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  class Qwen2MoeMLP(nn.Module):
@@ -368,7 +366,6 @@ class Qwen2MoeForCausalLM(nn.Module):
  config.vocab_size, config.hidden_size, quant_config=quant_config
  )
  self.logits_processor = LogitsProcessor(config)
- self.sampler = Sampler()

  @torch.no_grad()
  def forward(
@@ -394,14 +391,6 @@
  )
  return logits

- def sample(
- self,
- logits: Optional[torch.Tensor],
- sampling_metadata: SamplingMetadata,
- ) -> Optional[SamplerOutput]:
- next_tokens = self.sampler(logits, sampling_metadata)
- return next_tokens
-
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [
  # (param_name, shard_name, shard_id)
sglang/srt/models/stablelm.py CHANGED
@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.radix_attention import RadixAttention
- from sglang.srt.model_executor.model_runner import InputMetadata
+ from sglang.srt.model_executor.forward_batch_info import InputMetadata


  class StablelmMLP(nn.Module):
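Every model hunk above makes the same one-line change: InputMetadata (plus ForwardMode for the LLaVA-style vision models) now comes from the new sglang.srt.model_executor.forward_batch_info module (entry 35 in the file list) instead of model_runner or managers.schedule_batch. Out-of-tree model integrations written against 0.2.10 would need the equivalent import update, e.g.:

from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata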
sglang/srt/openai_api/adapter.py CHANGED
@@ -34,7 +34,7 @@ from sglang.srt.conversation import (
  generate_chat_conv,
  register_conv_template,
  )
- from sglang.srt.managers.io_struct import GenerateReqInput
+ from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
  from sglang.srt.openai_api.protocol import (
  BatchRequest,
  BatchResponse,
@@ -52,7 +52,11 @@ from sglang.srt.openai_api.protocol import (
  CompletionResponseStreamChoice,
  CompletionStreamResponse,
  DeltaMessage,
+ EmbeddingObject,
+ EmbeddingRequest,
+ EmbeddingResponse,
  ErrorResponse,
+ FileDeleteResponse,
  FileRequest,
  FileResponse,
  LogProbs,
@@ -73,7 +77,7 @@ class FileMetadata:
  batch_storage: Dict[str, BatchResponse] = {}
  file_id_request: Dict[str, FileMetadata] = {}
  file_id_response: Dict[str, FileResponse] = {}
- # map file id to file path in SGlang backend
+ # map file id to file path in SGLang backend
  file_id_storage: Dict[str, str] = {}


@@ -81,6 +85,19 @@ file_id_storage: Dict[str, str] = {}
  storage_dir = None


+ def format_finish_reason(finish_reason) -> Optional[str]:
+ if finish_reason.startswith("None"):
+ return None
+ elif finish_reason.startswith("FINISH_MATCHED"):
+ return "stop"
+ elif finish_reason.startswith("FINISH_LENGTH"):
+ return "length"
+ elif finish_reason.startswith("FINISH_ABORT"):
+ return "abort"
+ else:
+ return "unknown"
+
+
  def create_error_response(
  message: str,
  err_type: str = "BadRequestError",
@@ -174,6 +191,20 @@ async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str
  return {"error": "Invalid input", "details": e.errors()}


+ async def v1_delete_file(file_id: str):
+ # Retrieve the file job from the in-memory storage
+ file_response = file_id_response.get(file_id)
+ if file_response is None:
+ raise HTTPException(status_code=404, detail="File not found")
+ file_path = file_id_storage.get(file_id)
+ if file_path is None:
+ raise HTTPException(status_code=404, detail="File not found")
+ os.remove(file_path)
+ del file_id_response[file_id]
+ del file_id_storage[file_id]
+ return FileDeleteResponse(id=file_id, deleted=True)
+
+
  async def v1_batches(tokenizer_manager, raw_request: Request):
  try:
  body = await raw_request.json()
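v1_delete_file mirrors the OpenAI Files API delete operation against the adapter's in-memory file registry. Assuming the accompanying server.py change (entry 68 in the file list) exposes it as DELETE /v1/files/{file_id}, a client call might look like the sketch below; the file id and port are placeholders:

import requests

resp = requests.delete("http://localhost:30000/v1/files/file-abc123")
print(resp.json())  # expected to echo the id with deleted=True, per FileDeleteResponse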
@@ -287,6 +318,13 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
  retrieve_batch = batch_storage[batch_id]
  retrieve_batch.output_file_id = output_file_id
  file_id_storage[output_file_id] = output_file_path
+ file_id_response[output_file_id] = FileResponse(
+ id=output_file_id,
+ bytes=os.path.getsize(output_file_path),
+ created_at=int(time.time()),
+ filename=f"{output_file_id}.jsonl",
+ purpose="batch_result",
+ )
  # Update batch status to "completed"
  retrieve_batch.status = "completed"
  retrieve_batch.completed_at = int(time.time())
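Registering a FileResponse for the batch output file means the result can later be fetched back through the files API (v1_retrieve_file_content appears in a later hunk header). A hedged sketch of polling a batch and downloading its output, assuming the usual OpenAI-style routes are mounted; ids and port are placeholders:

import requests

base = "http://localhost:30000"
batch = requests.get(f"{base}/v1/batches/batch_abc123").json()
if batch["status"] == "completed":
    # one JSONL line per request in the batch
    print(requests.get(f"{base}/v1/files/{batch['output_file_id']}/content").text)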
@@ -297,7 +335,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
  }

  except Exception as e:
- print("error in SGlang:", e)
+ print("error in SGLang:", e)
  # Update batch status to "failed"
  retrieve_batch = batch_storage[batch_id]
  retrieve_batch.status = "failed"
@@ -335,7 +373,6 @@ async def v1_retrieve_file_content(file_id: str):


  def v1_generate_request(all_requests):
-
  prompts = []
  sampling_params_list = []
  return_logprobs = []
@@ -356,10 +393,13 @@
  {
  "temperature": request.temperature,
  "max_new_tokens": request.max_tokens,
+ "min_new_tokens": request.min_tokens,
  "stop": request.stop,
+ "stop_token_ids": request.stop_token_ids,
  "top_p": request.top_p,
  "presence_penalty": request.presence_penalty,
  "frequency_penalty": request.frequency_penalty,
+ "repetition_penalty": request.repetition_penalty,
  "regex": request.regex,
  "n": request.n,
  "ignore_eos": request.ignore_eos,
@@ -380,7 +420,7 @@
  else:
  prompt_kwargs = {"input_ids": prompt}
  else:
- if isinstance(prompts[0], str) or isinstance(propmt[0][0], str):
+ if isinstance(prompts[0], str):
  prompt_kwargs = {"text": prompts}
  else:
  prompt_kwargs = {"input_ids": prompts}
@@ -463,14 +503,18 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
  "index": 0,
  "text": text,
  "logprobs": logprobs,
- "finish_reason": ret_item["meta_info"]["finish_reason"],
+ "finish_reason": format_finish_reason(
+ ret_item["meta_info"]["finish_reason"]
+ ),
  }
  else:
  choice_data = CompletionResponseChoice(
  index=idx,
  text=text,
  logprobs=logprobs,
- finish_reason=ret_item["meta_info"]["finish_reason"],
+ finish_reason=format_finish_reason(
+ ret_item["meta_info"]["finish_reason"]
+ ),
  )

  choices.append(choice_data)
@@ -500,7 +544,9 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
  responses.append(response)
  return responses
  else:
- prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
+ prompt_tokens = sum(
+ ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+ )
  completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
  response = CompletionResponse(
  id=ret[0]["meta_info"]["id"],
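The reworked sum counts each prompt once per request instead of once per generated choice; the same fix is applied to the chat-completions usage block further down. A small worked example of the stride-n accounting, with made-up token counts:

# one request with n=3 returns three items that all report the same 10-token prompt
ret = [{"meta_info": {"prompt_tokens": 10, "completion_tokens": c}} for c in (5, 7, 6)]
n = 3
old_prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)                   # 30 (over-counted)
new_prompt_tokens = sum(ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), n))  # 10
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)               # 18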
@@ -583,20 +629,34 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
  index=0,
  text=delta,
  logprobs=logprobs,
- finish_reason=content["meta_info"]["finish_reason"],
+ finish_reason=format_finish_reason(
+ content["meta_info"]["finish_reason"]
+ ),
  )
  chunk = CompletionStreamResponse(
  id=content["meta_info"]["id"],
  object="text_completion",
  choices=[choice_data],
  model=request.model,
- usage=UsageInfo(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
- ),
  )
  yield f"data: {chunk.model_dump_json()}\n\n"
+ if request.stream_options and request.stream_options.include_usage:
+ usage = UsageInfo(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ )
+
+ final_usage_chunk = CompletionStreamResponse(
+ id=str(uuid.uuid4().hex),
+ choices=[],
+ model=request.model,
+ usage=usage,
+ )
+ final_usage_data = final_usage_chunk.model_dump_json(
+ exclude_unset=True, exclude_none=True
+ )
+ yield f"data: {final_usage_data}\n\n"
  except ValueError as e:
  error = create_streaming_error_response(str(e))
  yield f"data: {error}\n\n"
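Per-chunk usage reporting is dropped; instead, a single trailing chunk with empty choices carries the totals when the client opts in via stream_options. A hedged sketch of consuming that final chunk from the streaming completions route (URL and model name are placeholders):

import json
import requests

payload = {
    "model": "placeholder-model",
    "prompt": "Hello",
    "stream": True,
    "stream_options": {"include_usage": True},
}
with requests.post("http://localhost:30000/v1/completions", json=payload, stream=True) as r:
    for line in r.iter_lines():
        if not line or not line.startswith(b"data: "):
            continue
        data = line[len(b"data: "):]
        if data == b"[DONE]":
            break
        chunk = json.loads(data)
        if chunk.get("usage"):  # the final, choices-less usage chunk
            print(chunk["usage"])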
@@ -624,7 +684,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):


  def v1_chat_generate_request(all_requests, tokenizer_manager):
-
  input_ids = []
  sampling_params_list = []
  image_data_list = []
@@ -667,10 +726,13 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
  {
  "temperature": request.temperature,
  "max_new_tokens": request.max_tokens,
+ "min_new_tokens": request.min_tokens,
  "stop": stop,
+ "stop_token_ids": request.stop_token_ids,
  "top_p": request.top_p,
  "presence_penalty": request.presence_penalty,
  "frequency_penalty": request.frequency_penalty,
+ "repetition_penalty": request.repetition_penalty,
  "regex": request.regex,
  "n": request.n,
  }
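Both the completions and chat-completions adapters now forward min_tokens, stop_token_ids, and repetition_penalty from the request into SGLang's sampling parameters (backed by the new penaltylib files, entries 61-66 in the file list). A hedged example against the OpenAI-compatible chat route; model name, port, and token id are placeholders:

import requests

payload = {
    "model": "placeholder-model",
    "messages": [{"role": "user", "content": "Write a haiku about GPUs."}],
    "max_tokens": 64,
    "min_tokens": 8,             # forwarded as min_new_tokens
    "repetition_penalty": 1.05,  # forwarded as repetition_penalty
    "stop_token_ids": [2],       # placeholder token id
}
resp = requests.post("http://localhost:30000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])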
@@ -707,8 +769,6 @@

  def v1_chat_generate_response(request, ret, to_file=False):
  choices = []
- total_prompt_tokens = 0
- total_completion_tokens = 0

  for idx, ret_item in enumerate(ret):
  logprobs = False
@@ -747,8 +807,6 @@ def v1_chat_generate_response(request, ret, to_file=False):
  choice_logprobs = ChoiceLogprobs(content=token_logprobs)
  else:
  choice_logprobs = None
- prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
- completion_tokens = ret_item["meta_info"]["completion_tokens"]

  if to_file:
  # to make the choice data json serializable
@@ -756,19 +814,22 @@ def v1_chat_generate_response(request, ret, to_file=False):
  "index": 0,
  "message": {"role": "assistant", "content": ret_item["text"]},
  "logprobs": choice_logprobs,
- "finish_reason": ret_item["meta_info"]["finish_reason"],
+ "finish_reason": format_finish_reason(
+ ret_item["meta_info"]["finish_reason"]
+ ),
  }
  else:
  choice_data = ChatCompletionResponseChoice(
  index=idx,
  message=ChatMessage(role="assistant", content=ret_item["text"]),
  logprobs=choice_logprobs,
- finish_reason=ret_item["meta_info"]["finish_reason"],
+ finish_reason=format_finish_reason(
+ ret_item["meta_info"]["finish_reason"]
+ ),
  )

  choices.append(choice_data)
- total_prompt_tokens += prompt_tokens
- total_completion_tokens += completion_tokens
+
  if to_file:
  responses = []

@@ -795,14 +856,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
  responses.append(response)
  return responses
  else:
+ prompt_tokens = sum(
+ ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+ )
+ completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
  response = ChatCompletionResponse(
  id=ret[0]["meta_info"]["id"],
  model=request.model,
  choices=choices,
  usage=UsageInfo(
- prompt_tokens=total_prompt_tokens,
- completion_tokens=total_completion_tokens,
- total_tokens=total_prompt_tokens + total_completion_tokens,
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
  ),
  )
  return response
@@ -877,18 +942,15 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  choice_data = ChatCompletionResponseStreamChoice(
  index=0,
  delta=DeltaMessage(role="assistant"),
- finish_reason=content["meta_info"]["finish_reason"],
+ finish_reason=format_finish_reason(
+ content["meta_info"]["finish_reason"]
+ ),
  logprobs=choice_logprobs,
  )
  chunk = ChatCompletionStreamResponse(
  id=content["meta_info"]["id"],
  choices=[choice_data],
  model=request.model,
- usage=UsageInfo(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
- ),
  )
  yield f"data: {chunk.model_dump_json()}\n\n"

@@ -898,20 +960,34 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  choice_data = ChatCompletionResponseStreamChoice(
  index=0,
  delta=DeltaMessage(content=delta),
- finish_reason=content["meta_info"]["finish_reason"],
+ finish_reason=format_finish_reason(
+ content["meta_info"]["finish_reason"]
+ ),
  logprobs=choice_logprobs,
  )
  chunk = ChatCompletionStreamResponse(
  id=content["meta_info"]["id"],
  choices=[choice_data],
  model=request.model,
- usage=UsageInfo(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=prompt_tokens + completion_tokens,
- ),
  )
  yield f"data: {chunk.model_dump_json()}\n\n"
+ if request.stream_options and request.stream_options.include_usage:
+ usage = UsageInfo(
+ prompt_tokens=prompt_tokens,
+ completion_tokens=completion_tokens,
+ total_tokens=prompt_tokens + completion_tokens,
+ )
+
+ final_usage_chunk = ChatCompletionStreamResponse(
+ id=str(uuid.uuid4().hex),
+ choices=[],
+ model=request.model,
+ usage=usage,
+ )
+ final_usage_data = final_usage_chunk.model_dump_json(
+ exclude_unset=True, exclude_none=True
+ )
+ yield f"data: {final_usage_data}\n\n"
  except ValueError as e:
  error = create_streaming_error_response(str(e))
  yield f"data: {error}\n\n"
@@ -930,7 +1006,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  ).__anext__()
  except ValueError as e:
  return create_error_response(str(e))
-
  if not isinstance(ret, list):
  ret = [ret]

@@ -939,6 +1014,81 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  return response


+ def v1_embedding_request(all_requests, tokenizer_manager):
+ prompts = []
+ sampling_params_list = []
+ first_prompt_type = type(all_requests[0].input)
+
+ for request in all_requests:
+ prompt = request.input
+ assert (
+ type(prompt) == first_prompt_type
+ ), "All prompts must be of the same type in file input settings"
+ prompts.append(prompt)
+
+ if len(all_requests) == 1:
+ prompt = prompts[0]
+ if isinstance(prompt, str) or isinstance(prompt[0], str):
+ prompt_kwargs = {"text": prompt}
+ else:
+ prompt_kwargs = {"input_ids": prompt}
+ else:
+ if isinstance(prompts[0], str) or isinstance(propmt[0][0], str):
+ prompt_kwargs = {"text": prompts}
+ else:
+ prompt_kwargs = {"input_ids": prompts}
+
+ adapted_request = EmbeddingReqInput(
+ **prompt_kwargs,
+ )
+
+ if len(all_requests) == 1:
+ return adapted_request, all_requests[0]
+ return adapted_request, all_requests
+
+
+ def v1_embedding_response(ret, model_path, to_file=False):
+ embedding_objects = []
+ prompt_tokens = 0
+ for idx, ret_item in enumerate(ret):
+ embedding_objects.append(
+ EmbeddingObject(
+ embedding=ret[idx]["embedding"],
+ index=idx,
+ )
+ )
+ prompt_tokens += ret[idx]["meta_info"]["prompt_tokens"]
+
+ return EmbeddingResponse(
+ data=embedding_objects,
+ model=model_path,
+ usage=UsageInfo(
+ prompt_tokens=prompt_tokens,
+ total_tokens=prompt_tokens,
+ ),
+ )
+
+
+ async def v1_embeddings(tokenizer_manager, raw_request: Request):
+ request_json = await raw_request.json()
+ all_requests = [EmbeddingRequest(**request_json)]
+ adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)
+
+ try:
+ ret = await tokenizer_manager.generate_request(
+ adapted_request, raw_request
+ ).__anext__()
+ except ValueError as e:
+ return create_error_response(str(e))
+
+ if not isinstance(ret, list):
+ ret = [ret]
+
+ response = v1_embedding_response(ret, tokenizer_manager.model_path)
+
+ return response
+
+
  def to_openai_style_logprobs(
  input_token_logprobs=None,
  output_token_logprobs=None,
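These three functions give the adapter an OpenAI-style embeddings path on top of the new EmbeddingReqInput plumbing (entries 25 and 28 in the file list) and the pooling-capable Llama embedding model above. Assuming server.py mounts v1_embeddings at POST /v1/embeddings, a minimal client call could look like this; model name and port are placeholders:

import requests

payload = {
    "model": "intfloat/e5-mistral-7b-instruct",  # placeholder; see the MistralModel remapping above
    "input": "The quick brown fox",
}
resp = requests.post("http://localhost:30000/v1/embeddings", json=payload)
body = resp.json()
print(len(body["data"][0]["embedding"]), body["usage"]["prompt_tokens"])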