sglang 0.2.9.post1__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +114 -63
- sglang/check_env.py +1 -0
- sglang/lang/backend/runtime_endpoint.py +0 -11
- sglang/srt/hf_transformers_utils.py +2 -2
- sglang/srt/layers/extend_attention.py +59 -7
- sglang/srt/layers/radix_attention.py +22 -9
- sglang/srt/layers/token_attention.py +28 -2
- sglang/srt/managers/io_struct.py +9 -4
- sglang/srt/managers/schedule_batch.py +15 -11
- sglang/srt/managers/tokenizer_manager.py +28 -13
- sglang/srt/mem_cache/memory_pool.py +65 -24
- sglang/srt/model_config.py +11 -0
- sglang/srt/model_executor/model_runner.py +46 -17
- sglang/srt/models/deepseek_v2.py +198 -16
- sglang/srt/openai_api/adapter.py +120 -20
- sglang/srt/openai_api/protocol.py +1 -1
- sglang/srt/server.py +87 -78
- sglang/srt/server_args.py +8 -2
- sglang/srt/utils.py +25 -20
- sglang/test/run_eval.py +21 -10
- sglang/test/runners.py +237 -0
- sglang/test/simple_eval_common.py +12 -12
- sglang/test/simple_eval_gpqa.py +92 -0
- sglang/test/simple_eval_humaneval.py +5 -5
- sglang/test/simple_eval_math.py +72 -0
- sglang/test/test_utils.py +94 -13
- sglang/utils.py +15 -37
- sglang/version.py +1 -1
- {sglang-0.2.9.post1.dist-info → sglang-0.2.10.dist-info}/METADATA +29 -28
- {sglang-0.2.9.post1.dist-info → sglang-0.2.10.dist-info}/RECORD +33 -30
- {sglang-0.2.9.post1.dist-info → sglang-0.2.10.dist-info}/LICENSE +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.10.dist-info}/WHEEL +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.10.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py
CHANGED
@@ -251,7 +251,9 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         if end_point == "/v1/chat/completions":
             responses = v1_chat_generate_response(request, ret, to_file=True)
         else:
-            responses = v1_generate_response(request, ret, to_file=True)
+            responses = v1_generate_response(
+                request, ret, tokenizer_manager, to_file=True
+            )

     except Exception as e:
         error_json = {

@@ -339,6 +341,7 @@ def v1_generate_request(all_requests)
     return_logprobs = []
     top_logprobs_nums = []
     first_prompt_type = type(all_requests[0].prompt)
+
     for request in all_requests:
         prompt = request.prompt
         assert (

@@ -364,7 +367,7 @@ def v1_generate_request(all_requests)
         )
         if len(all_requests) > 1 and request.n > 1:
             raise ValueError(
-                "
+                "Parallel sampling is not supported for completions from files"
             )

     if len(all_requests) == 1:

@@ -377,10 +380,11 @@ def v1_generate_request(all_requests)
         else:
             prompt_kwargs = {"input_ids": prompt}
     else:
-        if isinstance(prompts[0], str):
+        if isinstance(prompts[0], str) or isinstance(propmt[0][0], str):
             prompt_kwargs = {"text": prompts}
         else:
             prompt_kwargs = {"input_ids": prompts}
+
     adapted_request = GenerateReqInput(
         **prompt_kwargs,
         sampling_params=sampling_params_list,
@@ -389,35 +393,52 @@ def v1_generate_request(all_requests)
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
     )
+
     if len(all_requests) == 1:
         return adapted_request, all_requests[0]
     return adapted_request, all_requests


-def v1_generate_response(request, ret, to_file=False):
+def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
     choices = []
     echo = False

-    if (not isinstance(request,
+    if (not isinstance(request, list)) and request.echo:
         # TODO: handle the case propmt is token ids
-        if isinstance(request.prompt, list):
+        if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
+            # for the case of multiple str prompts
             prompts = request.prompt
+        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], list):
+            # for the case of multiple token ids prompts
+            prompts = [
+                tokenizer_manager.tokenizer.decode(prompt, skip_special_tokens=True)
+                for prompt in request.prompt
+            ]
+        elif isinstance(request.prompt, list) and isinstance(request.prompt[0], int):
+            # for the case of single token ids prompt
+            prompts = [
+                tokenizer_manager.tokenizer.decode(
+                    request.prompt, skip_special_tokens=True
+                )
+            ]
         else:
+            # for the case of single str prompt
             prompts = [request.prompt]
         echo = True

     for idx, ret_item in enumerate(ret):
         text = ret_item["text"]
-        if isinstance(request,
+        if isinstance(request, list) and request[idx].echo:
             echo = True
             text = request[idx].prompt + text
-        if (not isinstance(request,
-
+        if (not isinstance(request, list)) and echo:
+            prompt_index = idx // request.n
+            text = prompts[prompt_index] + text

         logprobs = False
-        if isinstance(request,
+        if isinstance(request, list) and request[idx].logprobs:
             logprobs = True
-        elif (not isinstance(request,
+        elif (not isinstance(request, list)) and request.logprobs:
             logprobs = True
         if logprobs:
             if echo:
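The echo handling added above means a completions request may now pass its prompt as token ids (a single list, or a list of lists for a batch) and still get the prompt text prepended, since the adapter decodes the ids through tokenizer_manager. A minimal client-side sketch of that case, assuming a locally launched server on the default port 30000 with no API key, and placeholder model name and token ids:

import requests

# Hedged sketch: echo=True with a token-id prompt; sglang 0.2.10 decodes the ids
# back to text before prepending them to choices[0]["text"].
payload = {
    "model": "default",      # placeholder
    "prompt": [9906, 1917],  # placeholder token ids instead of a string
    "max_tokens": 16,
    "echo": True,
}
resp = requests.post("http://localhost:30000/v1/completions", json=payload)
print(resp.json()["choices"][0]["text"])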
@@ -479,15 +500,16 @@ def v1_generate_response(request, ret, to_file=False):
             responses.append(response)
         return responses
     else:
+        prompt_tokens = sum(item["meta_info"]["prompt_tokens"] for item in ret)
         completion_tokens = sum(item["meta_info"]["completion_tokens"] for item in ret)
         response = CompletionResponse(
             id=ret[0]["meta_info"]["id"],
             model=request.model,
             choices=choices,
             usage=UsageInfo(
-                prompt_tokens=
+                prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
-                total_tokens=
+                total_tokens=prompt_tokens + completion_tokens,
             ),
         )
         return response

@@ -513,8 +535,18 @@ async def v1_completions(tokenizer_manager, raw_request: Request):

                 if not stream_buffer: # The first chunk
                     if request.echo:
+                        if isinstance(request.prompt, str):
+                            # for the case of single str prompts
+                            prompts = request.prompt
+                        elif isinstance(request.prompt, list) and isinstance(
+                            request.prompt[0], int
+                        ):
+                            prompts = tokenizer_manager.tokenizer.decode(
+                                request.prompt, skip_special_tokens=True
+                            )
+
                         # Prepend prompt in response text.
-                        text =
+                        text = prompts + text

                     if request.logprobs:
                         # The first chunk and echo is enabled.

@@ -539,7 +571,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                                 "output_top_logprobs"
                             ][n_prev_token:],
                         )
-
                         n_prev_token = len(
                             content["meta_info"]["output_token_logprobs"]
                         )

@@ -588,7 +619,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     if not isinstance(ret, list):
         ret = [ret]

-    response = v1_generate_response(request, ret)
+    response = v1_generate_response(request, ret, tokenizer_manager)
     return response


@@ -626,7 +657,7 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
             prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
         else:
             # Use the raw prompt and stop strings if the messages is already a string.
-
+            prompt_ids = request.messages
             stop = request.stop
             image_data = None
         input_ids.append(prompt_ids)

@@ -647,12 +678,21 @@ def v1_chat_generate_request(all_requests, tokenizer_manager):
         image_data_list.append(image_data)
     if len(all_requests) == 1:
         input_ids = input_ids[0]
+        if isinstance(input_ids, str):
+            prompt_kwargs = {"text": input_ids}
+        else:
+            prompt_kwargs = {"input_ids": input_ids}
         sampling_params_list = sampling_params_list[0]
         image_data = image_data_list[0]
         return_logprobs = return_logprobs[0]
         top_logprobs_nums = top_logprobs_nums[0]
+    else:
+        if isinstance(input_ids[0], str):
+            prompt_kwargs = {"text": input_ids}
+        else:
+            prompt_kwargs = {"input_ids": input_ids}
     adapted_request = GenerateReqInput(
-
+        **prompt_kwargs,
         image_data=image_data,
         sampling_params=sampling_params_list,
         return_logprob=return_logprobs,

@@ -672,9 +712,9 @@ def v1_chat_generate_response(request, ret, to_file=False):

     for idx, ret_item in enumerate(ret):
         logprobs = False
-        if isinstance(request,
+        if isinstance(request, list) and request[idx].logprobs:
             logprobs = True
-        elif (not isinstance(request,
+        elif (not isinstance(request, list)) and request.logprobs:
             logprobs = True
         if logprobs:
             logprobs = to_openai_style_logprobs(

@@ -779,10 +819,58 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
             is_first = True

             stream_buffer = ""
+            n_prev_token = 0
             try:
                 async for content in tokenizer_manager.generate_request(
                     adapted_request, raw_request
                 ):
+                    prompt_tokens = content["meta_info"]["prompt_tokens"]
+                    completion_tokens = content["meta_info"]["completion_tokens"]
+                    if request.logprobs:
+                        logprobs = to_openai_style_logprobs(
+                            output_token_logprobs=content["meta_info"][
+                                "output_token_logprobs"
+                            ][n_prev_token:],
+                            output_top_logprobs=content["meta_info"][
+                                "output_top_logprobs"
+                            ][n_prev_token:],
+                        )
+
+                        n_prev_token = len(
+                            content["meta_info"]["output_token_logprobs"]
+                        )
+                        token_logprobs = []
+                        for token, logprob in zip(
+                            logprobs.tokens, logprobs.token_logprobs
+                        ):
+                            token_bytes = list(token.encode("utf-8"))
+                            top_logprobs = []
+                            if logprobs.top_logprobs:
+                                for top_token, top_logprob in logprobs.top_logprobs[
+                                    0
+                                ].items():
+                                    top_token_bytes = list(top_token.encode("utf-8"))
+                                    top_logprobs.append(
+                                        TopLogprob(
+                                            token=top_token,
+                                            bytes=top_token_bytes,
+                                            logprob=top_logprob,
+                                        )
+                                    )
+                            token_logprobs.append(
+                                ChatCompletionTokenLogprob(
+                                    token=token,
+                                    bytes=token_bytes,
+                                    logprob=logprob,
+                                    top_logprobs=top_logprobs,
+                                )
+                            )
+
+                        choice_logprobs = ChoiceLogprobs(content=token_logprobs)
+
+                    else:
+                        choice_logprobs = None
+
                     if is_first:
                         # First chunk with role
                         is_first = False

@@ -790,11 +878,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                             index=0,
                             delta=DeltaMessage(role="assistant"),
                             finish_reason=content["meta_info"]["finish_reason"],
+                            logprobs=choice_logprobs,
                         )
                         chunk = ChatCompletionStreamResponse(
                             id=content["meta_info"]["id"],
                             choices=[choice_data],
                             model=request.model,
+                            usage=UsageInfo(
+                                prompt_tokens=prompt_tokens,
+                                completion_tokens=completion_tokens,
+                                total_tokens=prompt_tokens + completion_tokens,
+                            ),
                         )
                         yield f"data: {chunk.model_dump_json()}\n\n"

@@ -805,11 +899,17 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                         index=0,
                         delta=DeltaMessage(content=delta),
                         finish_reason=content["meta_info"]["finish_reason"],
+                        logprobs=choice_logprobs,
                     )
                     chunk = ChatCompletionStreamResponse(
                         id=content["meta_info"]["id"],
                         choices=[choice_data],
                         model=request.model,
+                        usage=UsageInfo(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
                     )
                     yield f"data: {chunk.model_dump_json()}\n\n"
             except ValueError as e:
sglang/srt/openai_api/protocol.py
CHANGED
@@ -278,7 +278,7 @@ class DeltaMessage(BaseModel):
 class ChatCompletionResponseStreamChoice(BaseModel):
     index: int
     delta: DeltaMessage
-    logprobs: Optional[LogProbs] = None
+    logprobs: Optional[Union[LogProbs, ChoiceLogprobs]] = None
     finish_reason: Optional[str] = None

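Together with the adapter changes above, a streaming chat completion now carries a UsageInfo block and, when requested, ChoiceLogprobs content on every chunk. A hedged client sketch that reads those fields off the SSE stream (host, port, model name, and the exact "[DONE]" sentinel are assumptions about a default local deployment):

import json
import requests

payload = {
    "model": "default",  # placeholder
    "messages": [{"role": "user", "content": "Say hi"}],
    "stream": True,
    "logprobs": True,
    "top_logprobs": 2,
}
with requests.post("http://localhost:30000/v1/chat/completions", json=payload, stream=True) as r:
    for line in r.iter_lines():
        if not line or not line.startswith(b"data: ") or line == b"data: [DONE]":
            continue
        chunk = json.loads(line[len(b"data: "):])
        usage = chunk.get("usage")                      # populated on every chunk in 0.2.10
        logprobs = chunk["choices"][0].get("logprobs")  # ChoiceLogprobs-shaped content when logprobs=True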
sglang/srt/server.py
CHANGED
@@ -28,7 +28,7 @@ import sys
 import threading
 import time
 from http import HTTPStatus
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Union

 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

@@ -67,13 +67,13 @@ from sglang.srt.openai_api.adapter import (
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
-
-    APIKeyValidatorMiddleware,
+    add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
+    set_torch_compile_config,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback

@@ -158,6 +158,16 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)


+@app.get("/v1/models")
+def available_models():
+    """Show available models."""
+    served_model_names = [tokenizer_manager.served_model_name]
+    model_cards = []
+    for served_model_name in served_model_names:
+        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
+    return ModelList(data=model_cards)
+
+
 @app.post("/v1/files")
 async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
     return await v1_files_create(
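The endpoint itself is unchanged; it only moves ahead of the /v1/files routes (see the removal in the next hunk). A quick check against a locally launched server (URL is an assumption, and no --api-key is set):

import requests

models = requests.get("http://localhost:30000/v1/models").json()
print([card["id"] for card in models["data"]])  # one ModelCard per served model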
@@ -187,69 +197,11 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)


-@app.get("/v1/models")
-def available_models():
-    """Show available models."""
-    served_model_names = [tokenizer_manager.served_model_name]
-    model_cards = []
-    for served_model_name in served_model_names:
-        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
-    return ModelList(data=model_cards)
-
-
-def _set_torch_compile_config():
-    # The following configurations are for torch compile optimizations
-    import torch._dynamo.config
-    import torch._inductor.config
-
-    torch._inductor.config.coordinate_descent_tuning = True
-    torch._inductor.config.triton.unique_kernel_names = True
-    torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
-
-    # FIXME: tmp workaround
-    torch._dynamo.config.accumulated_cache_size_limit = 256
-
-
-def set_envs_and_config(server_args: ServerArgs):
-    # Set global environments
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = "0"
-    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-    # Set ulimit
-    set_ulimit()
-
-    # Enable show time cost for debugging
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-
-    # Disable disk cache
-    if server_args.disable_disk_cache:
-        disable_cache()
-
-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
-    # Set torch compile config
-    if server_args.enable_torch_compile:
-        _set_torch_compile_config()
-
-    # Set global chat template
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-
-
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
-    server_args.check_server_args()
-
     """Launch an HTTP server."""
     global tokenizer_manager

@@ -258,16 +210,8 @@ def launch_server(
         format="%(message)s",
     )

-
-
-            "flashinfer",
-            "0.1.3",
-            "Please uninstall the old version and "
-            "reinstall the latest version by following the instructions "
-            "at https://docs.flashinfer.ai/installation.html.",
-        )
-
-    set_envs_and_config(server_args)
+    server_args.check_server_args()
+    _set_envs_and_config(server_args)

     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(

@@ -284,7 +228,7 @@ def launch_server(
     )
     logger.info(f"{server_args=}")

-    #
+    # Launch processes for multi-node tensor parallelism
     if server_args.nnodes > 1:
         if server_args.node_rank != 0:
             tp_size_local = server_args.tp_size // server_args.nnodes

@@ -349,8 +293,9 @@ def launch_server(
         sys.exit(1)
     assert proc_controller.is_alive() and proc_detoken.is_alive()

-
-
+    # Add api key authorization
+    if server_args.api_key:
+        add_api_key_middleware(app, server_args.api_key)

     # Send a warmup request
     t = threading.Thread(

@@ -372,15 +317,58 @@ def launch_server(
     t.join()


+def _set_envs_and_config(server_args: ServerArgs):
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+    # Set ulimit
+    set_ulimit()
+
+    # Enable show time cost for debugging
+    if server_args.show_time_cost:
+        enable_show_time_cost()
+
+    # Disable disk cache
+    if server_args.disable_disk_cache:
+        disable_cache()
+
+    # Fix triton bugs
+    if server_args.tp_size * server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+
+    # Set torch compile config
+    if server_args.enable_torch_compile:
+        set_torch_compile_config()
+
+    # Set global chat template
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
+
+    # Check flashinfer version
+    if not server_args.disable_flashinfer:
+        assert_pkg_version(
+            "flashinfer",
+            "0.1.3",
+            "Please uninstall the old version and "
+            "reinstall the latest version by following the instructions "
+            "at https://docs.flashinfer.ai/installation.html.",
+        )
+
+
 def _wait_and_warmup(server_args, pipe_finish_writer):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
-        headers[
+        headers["Authorization"] = f"Bearer {server_args.api_key}"

     # Wait until the server is launched
     for _ in range(120):
-        time.sleep(
+        time.sleep(1)
         try:
             requests.get(url + "/get_model_info", timeout=5, headers=headers)
             break

@@ -481,10 +469,10 @@ class Runtime:
             trust_remote_code=self.server_args.trust_remote_code,
         )

-    async def
+    async def async_generate(
         self,
         prompt: str,
-        sampling_params: Dict,
+        sampling_params: Optional[Dict] = None,
     ):
         json_data = {
             "text": prompt,

@@ -507,5 +495,26 @@ class Runtime:
                 yield cur
                 pos += len(cur)

+    add_request = async_generate
+
+    def generate(
+        self,
+        prompt: str,
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+    ):
+        json_data = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "return_logprob": return_logprob,
+            "top_logprobs_num": top_logprobs_num,
+        }
+        response = requests.post(
+            self.url + "/generate",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()
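The new synchronous Runtime.generate mirrors async_generate by POSTing to the /generate endpoint and returning the JSON-encoded response. A minimal usage sketch, assuming Runtime is importable from the sglang package as in the existing docs, with a placeholder model path and sampling parameters:

import json
from sglang import Runtime

runtime = Runtime(model_path="meta-llama/Meta-Llama-3-8B-Instruct")  # placeholder model
out = runtime.generate(
    "The capital of France is",
    sampling_params={"max_new_tokens": 8},
)
print(json.loads(out)["text"])
runtime.shutdown()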
sglang/srt/server_args.py
CHANGED
@@ -61,7 +61,7 @@ class ServerArgs:
     show_time_cost: bool = False

     # Other
-    api_key: str =
+    api_key: Optional[str] = None
     file_storage_pth: str = "SGlang_storage"

     # Data parallelism

@@ -80,6 +80,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
+    enable_mla: bool = False
     attention_reduce_in_fp32: bool = False
     efficient_weight_load: bool = False


@@ -306,7 +307,7 @@ class ServerArgs:
             "--api-key",
             type=str,
             default=ServerArgs.api_key,
-            help="Set API key of the server.",
+            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
         )
         parser.add_argument(
             "--file-storage-pth",

@@ -393,6 +394,11 @@ class ServerArgs:
             action="store_true",
             help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
+        parser.add_argument(
+            "--enable-mla",
+            action="store_true",
+            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+        )
         parser.add_argument(
             "--attention-reduce-in-fp32",
             action="store_true",
sglang/srt/utils.py
CHANGED
@@ -539,26 +539,6 @@ class CustomCacheManager(FileCacheManager):
             raise RuntimeError("Could not create or locate cache dir")


-API_KEY_HEADER_NAME = "X-API-Key"
-
-
-class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
-    def __init__(self, app, api_key: str):
-        super().__init__(app)
-        self.api_key = api_key
-
-    async def dispatch(self, request, call_next):
-        # extract API key from the request headers
-        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
-        if not api_key_header or api_key_header != self.api_key:
-            return JSONResponse(
-                status_code=403,
-                content={"detail": "Invalid API Key"},
-            )
-        response = await call_next(request)
-        return response
-
-
 def get_ip_address(ifname):
     """
     Get the IP address of a network interface.

@@ -642,6 +622,19 @@ def receive_addrs(model_port_args, server_args):
     dist.destroy_process_group()


+def set_torch_compile_config():
+    # The following configurations are for torch compile optimizations
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
+
+    # FIXME: tmp workaround
+    torch._dynamo.config.accumulated_cache_size_limit = 256
+
+
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)

@@ -700,3 +693,15 @@ def monkey_patch_vllm_qvk_linear_loader():
         origin_weight_loader(self, param, loaded_weight, loaded_shard_id)

     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
+
+
+def add_api_key_middleware(app, api_key):
+    @app.middleware("http")
+    async def authentication(request, call_next):
+        if request.method == "OPTIONS":
+            return await call_next(request)
+        if request.url.path.startswith("/health"):
+            return await call_next(request)
+        if request.headers.get("Authorization") != "Bearer " + api_key:
+            return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
+        return await call_next(request)
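With add_api_key_middleware registered, every route except OPTIONS requests and paths under /health requires a bearer token, replacing the old X-API-Key header check. A client-side sketch matching the warmup request in server.py (host, port, and key are placeholders):

import requests

headers = {"Authorization": "Bearer my-secret-key"}
ok = requests.get("http://localhost:30000/get_model_info", headers=headers)  # authorized
denied = requests.get("http://localhost:30000/get_model_info")               # 401 {"error": "Unauthorized"}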