sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +234 -74
- sglang/check_env.py +25 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -40
- sglang/lang/choices.py +164 -0
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +11 -2
- sglang/srt/hf_transformers_utils.py +2 -2
- sglang/srt/layers/extend_attention.py +59 -7
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/radix_attention.py +24 -14
- sglang/srt/layers/token_attention.py +28 -2
- sglang/srt/managers/io_struct.py +9 -4
- sglang/srt/managers/schedule_batch.py +98 -323
- sglang/srt/managers/tokenizer_manager.py +34 -16
- sglang/srt/managers/tp_worker.py +20 -22
- sglang/srt/mem_cache/memory_pool.py +74 -38
- sglang/srt/model_config.py +11 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -3
- sglang/srt/model_executor/forward_batch_info.py +256 -0
- sglang/srt/model_executor/model_runner.py +51 -26
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +199 -17
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +1 -1
- sglang/srt/models/llama2.py +1 -1
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +151 -29
- sglang/srt/openai_api/protocol.py +7 -1
- sglang/srt/server.py +111 -84
- sglang/srt/server_args.py +12 -2
- sglang/srt/utils.py +25 -20
- sglang/test/run_eval.py +21 -10
- sglang/test/runners.py +237 -0
- sglang/test/simple_eval_common.py +12 -12
- sglang/test/simple_eval_gpqa.py +92 -0
- sglang/test/simple_eval_humaneval.py +5 -5
- sglang/test/simple_eval_math.py +72 -0
- sglang/test/test_utils.py +95 -14
- sglang/utils.py +15 -37
- sglang/version.py +1 -1
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
- sglang-0.2.11.dist-info/RECORD +102 -0
- sglang-0.2.9.post1.dist-info/RECORD +0 -97
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
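Before relying on any of the API changes listed above, it can help to confirm which of the two wheels is actually installed. A quick optional check, assuming a standard pip install (this snippet is illustrative and not part of the diff):

from importlib.metadata import version

# Passes silently when the 0.2.11 wheel is the one in the environment.
assert version("sglang") == "0.2.11", version("sglang")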
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -35,7 +35,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class GPTBigCodeAttention(nn.Module):
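The same one-line import move repeats in every model file that follows: InputMetadata now lives in the new sglang.srt.model_executor.forward_batch_info module (the old import path is truncated in this rendering), and the llava/llavavid diffs below additionally pull ForwardMode from the same place. A minimal sketch of code written against 0.2.11, using only the names visible on the + lines; the stub function itself is illustrative, not part of the package:

from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata


def forward_stub(input_metadata: InputMetadata, forward_mode: ForwardMode) -> None:
    # Model forward passes keep receiving InputMetadata; only the import
    # location of these classes changed in this release.
    ...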
sglang/srt/models/grok.py
CHANGED
@@ -52,7 +52,7 @@ from vllm.utils import print_warning_once
 from sglang.srt.layers.fused_moe import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 use_fused = True
 
sglang/srt/models/internlm2.py
CHANGED
@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class InternLM2MLP(nn.Module):
sglang/srt/models/llama2.py
CHANGED
@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class LlamaMLP(nn.Module):
sglang/srt/models/llama_classification.py
CHANGED
@@ -25,7 +25,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel
 
 
sglang/srt/models/llava.py
CHANGED
@@ -32,13 +32,12 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.managers.schedule_batch import ForwardMode
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.models.llama2 import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
sglang/srt/models/llavavid.py
CHANGED
@@ -26,13 +26,12 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.managers.schedule_batch import ForwardMode
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.models.llama2 import LlamaForCausalLM
 
 
sglang/srt/models/minicpm.py
CHANGED
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class MiniCPMMLP(nn.Module):
sglang/srt/models/mixtral.py
CHANGED
@@ -50,7 +50,7 @@ from vllm.utils import print_warning_once
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class MixtralMoE(nn.Module):
sglang/srt/models/mixtral_quant.py
CHANGED
@@ -45,7 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class MixtralMLP(nn.Module):
sglang/srt/models/qwen.py
CHANGED
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class QWenMLP(nn.Module):
sglang/srt/models/qwen2.py
CHANGED
@@ -39,7 +39,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 Qwen2Config = None
 
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class Qwen2MoeMLP(nn.Module):
sglang/srt/models/stablelm.py
CHANGED
@@ -40,7 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class StablelmMLP(nn.Module):
sglang/srt/openai_api/adapter.py
CHANGED
@@ -53,6 +53,7 @@ from sglang.srt.openai_api.protocol import (
     CompletionStreamResponse,
     DeltaMessage,
     ErrorResponse,
+    FileDeleteResponse,
     FileRequest,
     FileResponse,
     LogProbs,
@@ -174,6 +175,20 @@ async def v1_files_create(file: UploadFile, purpose: str, file_storage_pth: str
         return {"error": "Invalid input", "details": e.errors()}
 
 
+async def v1_delete_file(file_id: str):
+    # Retrieve the file job from the in-memory storage
+    file_response = file_id_response.get(file_id)
+    if file_response is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    file_path = file_id_storage.get(file_id)
+    if file_path is None:
+        raise HTTPException(status_code=404, detail="File not found")
+    os.remove(file_path)
+    del file_id_response[file_id]
+    del file_id_storage[file_id]
+    return FileDeleteResponse(id=file_id, deleted=True)
+
+
 async def v1_batches(tokenizer_manager, raw_request: Request):
     try:
         body = await raw_request.json()
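A hedged sketch of exercising the new delete handler directly; the HTTP route wiring lives in sglang/srt/server.py, which is not shown in this excerpt, and the file id below is hypothetical:

import asyncio

from sglang.srt.openai_api.adapter import v1_delete_file


async def delete_uploaded_file(file_id: str):
    # Returns FileDeleteResponse(id=..., deleted=True) when the id was registered
    # through the files API earlier; otherwise the handler raises HTTPException(404).
    return await v1_delete_file(file_id)


# asyncio.run(delete_uploaded_file("file-abc123"))  # hypothetical id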
@@ -251,7 +266,9 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             if end_point == "/v1/chat/completions":
                 responses = v1_chat_generate_response(request, ret, to_file=True)
             else:
-                responses = v1_generate_response(
+                responses = v1_generate_response(
+                    request, ret, tokenizer_manager, to_file=True
+                )
 
         except Exception as e:
             error_json = {
@@ -285,6 +302,13 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         retrieve_batch = batch_storage[batch_id]
         retrieve_batch.output_file_id = output_file_id
         file_id_storage[output_file_id] = output_file_path
+        file_id_response[output_file_id] = FileResponse(
+            id=output_file_id,
+            bytes=os.path.getsize(output_file_path),
+            created_at=int(time.time()),
+            filename=f"{output_file_id}.jsonl",
+            purpose="batch_result",
+        )
         # Update batch status to "completed"
         retrieve_batch.status = "completed"
         retrieve_batch.completed_at = int(time.time())
@@ -339,6 +363,7 @@ def v1_generate_request(all_requests):
     return_logprobs = []
     top_logprobs_nums = []
     first_prompt_type = type(all_requests[0].prompt)
+
     for request in all_requests:
         prompt = request.prompt
         assert (
@@ -364,7 +389,7 @@ def v1_generate_request(all_requests):
         )
         if len(all_requests) > 1 and request.n > 1:
             raise ValueError(
-                "
+                "Parallel sampling is not supported for completions from files"
             )
 
     if len(all_requests) == 1:
@@ -381,6 +406,7 @@ def v1_generate_request(all_requests):
         prompt_kwargs = {"text": prompts}
     else:
         prompt_kwargs = {"input_ids": prompts}
+
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         sampling_params=sampling_params_list,
@@ -389,35 +415,52 @@
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
+
     if len(all_requests) == 1:
         return adapted_request, all_requests[0]
     return adapted_request, all_requests
 
 
-def v1_generate_response(request, ret, to_file=False):
+def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
     choices = []
     echo = False
 
-    if (not isinstance(request,
+    if (not isinstance(request, list)) and request.echo:
         # TODO: handle the case propmt is token ids
-        if isinstance(request.prompt, list):
+        if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
+            # for the case of multiple str prompts
             prompts = request.prompt
+        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
+            # for the case of multiple token ids prompts
+            prompts = [
+                tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
+                for prompt in request.prompt
+            ]
+        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
+            # for the case of single token ids prompt
+            prompts = [
+                tokenizer_manager.tokenizer.decode(
+                    request.prompt, skip_special_tokens=True
+                )
+            ]
         else:
+            # for the case of single str prompt
             prompts = [request.prompt]
         echo = True
 
     for idx, ret_item in enumerate(ret):
         text = ret_item["text"]
-        if isinstance(request,
+        if isinstance(request, list) and request[idx].echo:
             echo = True
             text = request[idx].prompt + text
-        if (not isinstance(request,
-
+        if (not isinstance(request, list)) and echo:
+            prompt_index = idx // request.n
+            text = prompts[prompt_index] + text
 
         logprobs = False
-        if isinstance(request,
+        if isinstance(request, list) and request[idx].logprobs:
             logprobs = True
-        elif (not isinstance(request,
+        elif (not isinstance(request, list)) and request.logprobs:
             logprobs = True
         if logprobs:
             if echo:
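The echo branches above boil down to one question: which string should be prepended for each returned sample? A standalone restatement of that normalization, where decode stands in for tokenizer_manager.tokenizer.decode and the helper name is ours, not part of the package:

from typing import Callable, List, Union

Prompt = Union[str, List[str], List[int], List[List[int]]]


def prompts_for_echo(prompt: Prompt, decode: Callable[[List[int]], str]) -> List[str]:
    # Mirrors the branches added to v1_generate_response above.
    if isinstance(prompt, list) and prompt and isinstance(prompt[0], str):
        return list(prompt)  # multiple string prompts
    if isinstance(prompt, list) and prompt and isinstance(prompt[0], list):
        return [decode(p) for p in prompt]  # multiple token-id prompts
    if isinstance(prompt, list) and prompt and isinstance(prompt[0], int):
        return [decode(prompt)]  # a single token-id prompt
    return [prompt]  # a single string prompt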
@@ -479,15 +522,18 @@ def v1_generate_response(request, ret, to_file=False):
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = CompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=
+                prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
-                total_tokens=
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
         return response
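A worked example of the usage arithmetic above: ret holds one entry per generated sample, so with request.n parallel samples per prompt the prompt tokens are summed with a stride of n to avoid counting each prompt n times, while completion tokens are summed over every sample. The numbers below are made up for illustration:

ret = [
    {"meta_info": {"prompt_tokens": 10, "completion_tokens": 4}},
    {"meta_info": {"prompt_tokens": 10, "completion_tokens": 6}},  # same prompt, 2nd sample
]
n = 2
prompt_tokens = sum(ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), n))
completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
assert (prompt_tokens, completion_tokens) == (10, 10)  # not 20 prompt tokens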
@@ -513,8 +559,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 
                     if not stream_buffer: # The first chunk
                         if request.echo:
+                            if isinstance(request.prompt, str):
+                                # for the case of single str prompts
+                                prompts = request.prompt
+                            elif isinstance(request.prompt, list) and isinstance(
+                                request.prompt[0], int
+                            ):
+                                prompts = tokenizer_manager.tokenizer.decode(
+                                    request.prompt, skip_special_tokens=True
+                                )
+
                             # Prepend prompt in response text.
-                            text =
+                            text = prompts + text
 
                     if request.logprobs:
                         # The first chunk and echo is enabled.
@@ -539,7 +595,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                                     "output_top_logprobs"
                                 ][n_prev_token:],
                             )
-
                             n_prev_token = len(
                                 content["meta_info"]["output_token_logprobs"]
                             )
@@ -588,7 +643,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]
 
-    response = v1_generate_response(request, ret)
+    response = v1_generate_response(request, ret, tokenizer_manager)
     return response
 
 
@@ -626,7 +681,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
         else:
             # Use the raw prompt and stop strings if the messages is already a string.
-
+            prompt_ids = request.messages
             stop = request.stop
         image_data = None
         input_ids.append(prompt_ids)
@@ -647,12 +702,21 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         image_data_list.append(image_data)
     if len(all_requests) == 1:
         input_ids = input_ids[0]
+        if isinstance(input_ids, str):
+            prompt_kwargs = {"text": input_ids}
+        else:
+            prompt_kwargs = {"input_ids": input_ids}
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
+    else:
+        if isinstance(input_ids[0], str):
+            prompt_kwargs = {"text": input_ids}
+        else:
+            prompt_kwargs = {"input_ids": input_ids}
     adapted_request = GenerateReqInput(
-
+        **prompt_kwargs,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,
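The prompt_kwargs dispatch added above decides whether the request is forwarded to GenerateReqInput as raw text or as token ids. A compact standalone restatement (the helper and its name are illustrative only, not part of the package):

from typing import Dict, List, Union

ChatInput = Union[str, List[int], List[str], List[List[int]]]


def build_prompt_kwargs(input_ids: ChatInput, batched: bool) -> Dict[str, ChatInput]:
    # Single request: input_ids is one string or one id list.
    # Batched request: input_ids is a list of strings or a list of id lists.
    head = input_ids if not batched else input_ids[0]
    return {"text": input_ids} if isinstance(head, str) else {"input_ids": input_ids}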
@@ -667,14 +731,12 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
 
 def v1_chat_generate_response(request, ret, to_file=False):
     choices = []
-    total_prompt_tokens = 0
-    total_completion_tokens = 0
 
     for idx, ret_item in enumerate(ret):
         logprobs = False
-        if isinstance(request,
+        if isinstance(request, list) and request[idx].logprobs:
             logprobs = True
-        elif (not isinstance(request,
+        elif (not isinstance(request, list)) and request.logprobs:
             logprobs = True
         if logprobs:
             logprobs = to_openai_style_logprobs(
@@ -707,8 +769,6 @@ def v1_chat_generate_response(request, ret, to_file=False):
             choice_logprobs = ChoiceLogprobs(content=token_logprobs)
         else:
             choice_logprobs = None
-        prompt_tokens = ret_item["meta_info"]["prompt_tokens"]
-        completion_tokens = ret_item["meta_info"]["completion_tokens"]
 
         if to_file:
             # to make the choice data json serializable
@@ -727,8 +787,7 @@ def v1_chat_generate_response(request, ret, to_file=False):
         )
 
         choices.append(choice_data)
-
-        total_completion_tokens += completion_tokens
+
     if to_file:
         responses = []
 
@@ -755,14 +814,18 @@ def v1_chat_generate_response(request, ret, to_file=False):
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(
+            ret[i]["meta_info"]["prompt_tokens"] for i in range(0, len(ret), request.n)
+        )
+        completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
        response = ChatCompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=
-                completion_tokens=
-                total_tokens=
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
         return response
@@ -779,10 +842,58 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             is_first = True
 
             stream_buffer = ""
+            n_prev_token = 0
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
+                    prompt_tokens = content["meta_info"]["prompt_tokens"]
+                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    if request.logprobs:
+                        logprobs = to_openai_style_logprobs(
+                            output_token_logprobs=content["meta_info"][
+                                "output_token_logprobs"
+                            ][n_prev_token:],
+                            output_top_logprobs=content["meta_info"][
+                                "output_top_logprobs"
+                            ][n_prev_token:],
+                        )
+
+                        n_prev_token = len(
+                            content["meta_info"]["output_token_logprobs"]
+                        )
+                        token_logprobs = []
+                        for token, logprob in zip(
+                            logprobs.tokens, logprobs.token_logprobs
+                        ):
+                            token_bytes = list(token.encode("utf-8"))
+                            top_logprobs = []
+                            if logprobs.top_logprobs:
+                                for top_token, top_logprob in logprobs.top_logprobs[
+                                    0
+                                ].items():
+                                    top_token_bytes = list(top_token.encode("utf-8"))
+                                    top_logprobs.append(
+                                        TopLogprob(
+                                            token=top_token,
+                                            bytes=top_token_bytes,
+                                            logprob=top_logprob,
+                                        )
+                                    )
+                            token_logprobs.append(
+                                ChatCompletionTokenLogprob(
+                                    token=token,
+                                    bytes=token_bytes,
+                                    logprob=logprob,
+                                    top_logprobs=top_logprobs,
+                                )
+                            )
+
+                        choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+
+                    else:
+                        choice_logprobs = None
+
                     if is_first:
                         # First chunk with role
                         is_first = False
@@ -790,11 +901,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                             index=0,
                             delta=DeltaMessage(role="assistant"),
                             finish_reason=content["meta_info"]["finish_reason"],
+                            logprobs=choice_logprobs,
                         )
                         chunk = ChatCompletionStreamResponse(
                             id=content["meta_info"]["id"],
                             choices=[choice_data],
                             model=request.model,
+                            usage=UsageInfo(
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                                total_tokens=prompt_tokens + completion_tokens,
+                            ),
                         )
                         yield f"data: {chunk.model_dump_json()}\n\n"
 
@@ -805,11 +922,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         index=0,
                         delta=DeltaMessage(content=delta),
                         finish_reason=content["meta_info"]["finish_reason"],
+                        logprobs=choice_logprobs,
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
+                        usage=UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
                     )
                     yield f"data: {chunk.model_dump_json()}\n\n"
             except ValueError as e:
@@ -830,7 +953,6 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
         ).__anext__()
     except ValueError as e:
         return create_error_response(str(e))
-
    if not isinstance(ret, list):
         ret = [ret]
 
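Client-side, the streamed chat chunks shown above now carry logprobs (when requested) and per-chunk usage. A hedged sketch of reading them, assuming an sglang server built from this wheel is running at localhost:30000, a reasonably recent openai Python client that exposes chunk.usage, and a placeholder model name:

import openai

client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

stream = client.chat.completions.create(
    model="default",  # placeholder model name
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    logprobs=True,
)
for chunk in stream:
    choice = chunk.choices[0]
    if choice.logprobs and choice.logprobs.content:
        for tok in choice.logprobs.content:
            print(tok.token, tok.logprob)
    if chunk.usage is not None:  # per-chunk usage added in this diff
        print("tokens so far:", chunk.usage.total_tokens)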
sglang/srt/openai_api/protocol.py
CHANGED
@@ -95,6 +95,12 @@ class FileResponse(BaseModel):
     purpose: str
 
 
+class FileDeleteResponse(BaseModel):
+    id: str
+    object: str = "file"
+    deleted: bool
+
+
 class BatchRequest(BaseModel):
     input_file_id: (
         str  # The ID of an uploaded file that contains requests for the new batch
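A minimal sketch of how the new response model serializes; this re-declares the class standalone for illustration, while the real one lives in sglang.srt.openai_api.protocol:

from pydantic import BaseModel


class FileDeleteResponse(BaseModel):
    id: str
    object: str = "file"
    deleted: bool


print(FileDeleteResponse(id="file-abc123", deleted=True).model_dump_json())
# -> {"id":"file-abc123","object":"file","deleted":true}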
@@ -278,7 +284,7 @@ class DeltaMessage(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: Optional[str] = None
 
 