sglang 0.2.9.post1__py3-none-any.whl → 0.2.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -0
- sglang/api.py +10 -2
- sglang/bench_latency.py +234 -74
- sglang/check_env.py +25 -2
- sglang/global_config.py +0 -1
- sglang/lang/backend/base_backend.py +3 -1
- sglang/lang/backend/openai.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +46 -40
- sglang/lang/choices.py +164 -0
- sglang/lang/interpreter.py +6 -13
- sglang/lang/ir.py +11 -2
- sglang/srt/hf_transformers_utils.py +2 -2
- sglang/srt/layers/extend_attention.py +59 -7
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/radix_attention.py +24 -14
- sglang/srt/layers/token_attention.py +28 -2
- sglang/srt/managers/io_struct.py +9 -4
- sglang/srt/managers/schedule_batch.py +98 -323
- sglang/srt/managers/tokenizer_manager.py +34 -16
- sglang/srt/managers/tp_worker.py +20 -22
- sglang/srt/mem_cache/memory_pool.py +74 -38
- sglang/srt/model_config.py +11 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -3
- sglang/srt/model_executor/forward_batch_info.py +256 -0
- sglang/srt/model_executor/model_runner.py +51 -26
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/commandr.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +199 -17
- sglang/srt/models/gemma.py +1 -1
- sglang/srt/models/gemma2.py +1 -1
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/internlm2.py +1 -1
- sglang/srt/models/llama2.py +1 -1
- sglang/srt/models/llama_classification.py +1 -1
- sglang/srt/models/llava.py +1 -2
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -1
- sglang/srt/models/mixtral.py +1 -1
- sglang/srt/models/mixtral_quant.py +1 -1
- sglang/srt/models/qwen.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/stablelm.py +1 -1
- sglang/srt/openai_api/adapter.py +151 -29
- sglang/srt/openai_api/protocol.py +7 -1
- sglang/srt/server.py +111 -84
- sglang/srt/server_args.py +12 -2
- sglang/srt/utils.py +25 -20
- sglang/test/run_eval.py +21 -10
- sglang/test/runners.py +237 -0
- sglang/test/simple_eval_common.py +12 -12
- sglang/test/simple_eval_gpqa.py +92 -0
- sglang/test/simple_eval_humaneval.py +5 -5
- sglang/test/simple_eval_math.py +72 -0
- sglang/test/test_utils.py +95 -14
- sglang/utils.py +15 -37
- sglang/version.py +1 -1
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/METADATA +59 -48
- sglang-0.2.11.dist-info/RECORD +102 -0
- sglang-0.2.9.post1.dist-info/RECORD +0 -97
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/LICENSE +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/WHEEL +0 -0
- {sglang-0.2.9.post1.dist-info → sglang-0.2.11.dist-info}/top_level.txt +0 -0
sglang/srt/server.py
CHANGED
@@ -28,7 +28,7 @@ import sys
 import threading
 import time
 from http import HTTPStatus
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Union
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -59,6 +59,7 @@ from sglang.srt.openai_api.adapter import (
     v1_batches,
     v1_chat_completions,
     v1_completions,
+    v1_delete_file,
     v1_files_create,
     v1_retrieve_batch,
     v1_retrieve_file,
@@ -67,13 +68,13 @@ from sglang.srt.openai_api.adapter import (
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
-
-    APIKeyValidatorMiddleware,
+    add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
+    set_torch_compile_config,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -158,6 +159,16 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
+@app.get("/v1/models")
+def available_models():
+    """Show available models."""
+    served_model_names = [tokenizer_manager.served_model_name]
+    model_cards = []
+    for served_model_name in served_model_names:
+        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
+    return ModelList(data=model_cards)
+
+
 @app.post("/v1/files")
 async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("batch")):
     return await v1_files_create(
@@ -165,6 +176,12 @@ async def openai_v1_files(file: UploadFile = File(...), purpose: str = Form("bat
     )
 
 
+@app.delete("/v1/files/{file_id}")
+async def delete_file(file_id: str):
+    # https://platform.openai.com/docs/api-reference/files/delete
+    return await v1_delete_file(file_id)
+
+
 @app.post("/v1/batches")
 async def openai_v1_batches(raw_request: Request):
     return await v1_batches(tokenizer_manager, raw_request)
@@ -187,69 +204,11 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)
 
 
-@app.get("/v1/models")
-def available_models():
-    """Show available models."""
-    served_model_names = [tokenizer_manager.served_model_name]
-    model_cards = []
-    for served_model_name in served_model_names:
-        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
-    return ModelList(data=model_cards)
-
-
-def _set_torch_compile_config():
-    # The following configurations are for torch compile optimizations
-    import torch._dynamo.config
-    import torch._inductor.config
-
-    torch._inductor.config.coordinate_descent_tuning = True
-    torch._inductor.config.triton.unique_kernel_names = True
-    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
-
-    # FIXME: tmp workaround
-    torch._dynamo.config.accumulated_cache_size_limit = 256
-
-
-def set_envs_and_config(server_args: ServerArgs):
-    # Set global environments
-    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = "0"
-    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-    # Set ulimit
-    set_ulimit()
-
-    # Enable show time cost for debugging
-    if server_args.show_time_cost:
-        enable_show_time_cost()
-
-    # Disable disk cache
-    if server_args.disable_disk_cache:
-        disable_cache()
-
-    # Fix triton bugs
-    if server_args.tp_size * server_args.dp_size > 1:
-        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
-        maybe_set_triton_cache_manager()
-
-    # Set torch compile config
-    if server_args.enable_torch_compile:
-        _set_torch_compile_config()
-
-    # Set global chat template
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-
-
 def launch_server(
     server_args: ServerArgs,
     model_overide_args: Optional[dict] = None,
     pipe_finish_writer: Optional[mp.connection.Connection] = None,
 ):
-    server_args.check_server_args()
-
     """Launch an HTTP server."""
     global tokenizer_manager
 
@@ -258,16 +217,8 @@ def launch_server(
         format="%(message)s",
     )
 
-
-
-        "flashinfer",
-        "0.1.3",
-        "Please uninstall the old version and "
-        "reinstall the latest version by following the instructions "
-        "at https://docs.flashinfer.ai/installation.html.",
-    )
-
-    set_envs_and_config(server_args)
+    server_args.check_server_args()
+    _set_envs_and_config(server_args)
 
     # Allocate ports
     server_args.port, server_args.additional_ports = allocate_init_ports(
@@ -284,7 +235,7 @@ def launch_server(
     )
     logger.info(f"{server_args=}")
 
-    #
+    # Launch processes for multi-node tensor parallelism
     if server_args.nnodes > 1:
         if server_args.node_rank != 0:
             tp_size_local = server_args.tp_size // server_args.nnodes
@@ -349,8 +300,9 @@ def launch_server(
         sys.exit(1)
     assert proc_controller.is_alive() and proc_detoken.is_alive()
 
-
-
+    # Add api key authorization
+    if server_args.api_key:
+        add_api_key_middleware(app, server_args.api_key)
 
     # Send a warmup request
     t = threading.Thread(
@@ -372,21 +324,74 @@ def launch_server(
     t.join()
 
 
+def _set_envs_and_config(server_args: ServerArgs):
+    # Set global environments
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+    os.environ["NCCL_CUMEM_ENABLE"] = "0"
+    os.environ["NCCL_NVLS_ENABLE"] = "0"
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+    # Set ulimit
+    set_ulimit()
+
+    # Enable show time cost for debugging
+    if server_args.show_time_cost:
+        enable_show_time_cost()
+
+    # Disable disk cache
+    if server_args.disable_disk_cache:
+        disable_cache()
+
+    # Fix triton bugs
+    if server_args.tp_size * server_args.dp_size > 1:
+        # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
+        maybe_set_triton_cache_manager()
+
+    # Set torch compile config
+    if server_args.enable_torch_compile:
+        set_torch_compile_config()
+
+    # Set global chat template
+    if server_args.chat_template:
+        # TODO: replace this with huggingface transformers template
+        load_chat_template_for_openai_api(server_args.chat_template)
+
+    # Check flashinfer version
+    if not server_args.disable_flashinfer:
+        assert_pkg_version(
+            "flashinfer",
+            "0.1.3",
+            "Please uninstall the old version and "
+            "reinstall the latest version by following the instructions "
+            "at https://docs.flashinfer.ai/installation.html.",
+        )
+
+
 def _wait_and_warmup(server_args, pipe_finish_writer):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
-        headers[
+        headers["Authorization"] = f"Bearer {server_args.api_key}"
 
     # Wait until the server is launched
+    success = False
     for _ in range(120):
-        time.sleep(
+        time.sleep(1)
         try:
-            requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            res = requests.get(url + "/get_model_info", timeout=5, headers=headers)
+            assert res.status_code == 200, f"{res}"
+            success = True
             break
-        except requests.exceptions.RequestException:
+        except (AssertionError, requests.exceptions.RequestException) as e:
+            last_traceback = get_exception_traceback()
             pass
 
+    if not success:
+        if pipe_finish_writer is not None:
+            pipe_finish_writer.send(last_traceback)
+        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
+        sys.exit(1)
+
     # Send a warmup request
     try:
         for _ in range(server_args.dp_size):
@@ -402,12 +407,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
                 headers=headers,
                 timeout=600,
             )
-        assert res.status_code == 200
+        assert res.status_code == 200, f"{res}"
     except Exception as e:
+        last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
-            pipe_finish_writer.send(
-        print(f"Initialization failed. warmup error: {
-
+            pipe_finish_writer.send(last_traceback)
+        print(f"Initialization failed. warmup error: {last_traceback}", flush=True)
+        sys.exit(1)
 
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
@@ -481,10 +487,10 @@ class Runtime:
             trust_remote_code=self.server_args.trust_remote_code,
         )
 
-    async def
+    async def async_generate(
         self,
         prompt: str,
-        sampling_params: Dict,
+        sampling_params: Optional[Dict] = None,
     ):
         json_data = {
             "text": prompt,
@@ -507,5 +513,26 @@ class Runtime:
                 yield cur
                 pos += len(cur)
 
+    add_request = async_generate
+
+    def generate(
+        self,
+        prompt: str,
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+    ):
+        json_data = {
+            "text": prompt,
+            "sampling_params": sampling_params,
+            "return_logprob": return_logprob,
+            "top_logprobs_num": top_logprobs_num,
+        }
+        response = requests.post(
+            self.url + "/generate",
+            json=json_data,
+        )
+        return json.dumps(response.json())
+
     def __del__(self):
         self.shutdown()
sglang/srt/server_args.py
CHANGED
@@ -61,7 +61,7 @@ class ServerArgs:
     show_time_cost: bool = False
 
     # Other
-    api_key: str =
+    api_key: Optional[str] = None
     file_storage_pth: str = "SGlang_storage"
 
     # Data parallelism
@@ -80,6 +80,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
+    enable_mla: bool = False
     attention_reduce_in_fp32: bool = False
     efficient_weight_load: bool = False
 
@@ -263,6 +264,7 @@ class ServerArgs:
             help="How conservative the schedule policy is. A larger value means more conservative scheduling. Use a larger value if you see requests being retracted frequently.",
         )
         parser.add_argument(
+            "--tensor-parallel-size",
             "--tp-size",
             type=int,
             default=ServerArgs.tp_size,
@@ -306,7 +308,7 @@ class ServerArgs:
             "--api-key",
             type=str,
             default=ServerArgs.api_key,
-            help="Set API key of the server.",
+            help="Set API key of the server. It is also used in the OpenAI API compatible server.",
         )
         parser.add_argument(
             "--file-storage-pth",
@@ -317,6 +319,7 @@ class ServerArgs:
 
         # Data parallelism
         parser.add_argument(
+            "--data-parallel-size",
             "--dp-size",
             type=int,
             default=ServerArgs.dp_size,
@@ -393,6 +396,11 @@ class ServerArgs:
             action="store_true",
             help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
+        parser.add_argument(
+            "--enable-mla",
+            action="store_true",
+            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+        )
         parser.add_argument(
             "--attention-reduce-in-fp32",
             action="store_true",
@@ -407,6 +415,8 @@ class ServerArgs:
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
+        args.tp_size = args.tensor_parallel_size
+        args.dp_size = args.data_parallel_size
         attrs = [attr.name for attr in dataclasses.fields(cls)]
         return cls(**{attr: getattr(args, attr) for attr in attrs})
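
ServerArgs now accepts the long-form aliases --tensor-parallel-size and --data-parallel-size next to --tp-size/--dp-size, adds --enable-mla for DeepSeek-V2, and from_cli_args copies the alias values back onto tp_size/dp_size. A small parsing sketch; the model path is a placeholder, not taken from the diff:

import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)

# The new long-form aliases and --enable-mla, with a placeholder model path.
args = parser.parse_args(
    [
        "--model-path", "deepseek-ai/DeepSeek-V2-Lite",
        "--tensor-parallel-size", "2",
        "--enable-mla",
    ]
)
server_args = ServerArgs.from_cli_args(args)
print(server_args.tp_size, server_args.enable_mla)  # 2 True

Note that from_cli_args now reads args.tensor_parallel_size and args.data_parallel_size unconditionally, so any parser that feeds it must register both aliases, as add_cli_args does.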
sglang/srt/utils.py
CHANGED
@@ -539,26 +539,6 @@ class CustomCacheManager(FileCacheManager):
                 raise RuntimeError("Could not create or locate cache dir")
 
 
-API_KEY_HEADER_NAME = "X-API-Key"
-
-
-class APIKeyValidatorMiddleware(BaseHTTPMiddleware):
-    def __init__(self, app, api_key: str):
-        super().__init__(app)
-        self.api_key = api_key
-
-    async def dispatch(self, request, call_next):
-        # extract API key from the request headers
-        api_key_header = request.headers.get(API_KEY_HEADER_NAME)
-        if not api_key_header or api_key_header != self.api_key:
-            return JSONResponse(
-                status_code=403,
-                content={"detail": "Invalid API Key"},
-            )
-        response = await call_next(request)
-        return response
-
-
 def get_ip_address(ifname):
     """
     Get the IP address of a network interface.
@@ -642,6 +622,19 @@ def receive_addrs(model_port_args, server_args):
     dist.destroy_process_group()
 
 
+def set_torch_compile_config():
+    # The following configurations are for torch compile optimizations
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+
+    # FIXME: tmp workaround
+    torch._dynamo.config.accumulated_cache_size_limit = 256
+
+
 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -700,3 +693,15 @@ def monkey_patch_vllm_qvk_linear_loader():
         origin_weight_loader(self, param, loaded_weight, loaded_shard_id)
 
     setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
+
+
+def add_api_key_middleware(app, api_key):
+    @app.middleware("http")
+    async def authentication(request, call_next):
+        if request.method == "OPTIONS":
+            return await call_next(request)
+        if request.url.path.startswith("/health"):
+            return await call_next(request)
+        if request.headers.get("Authorization") != "Bearer " + api_key:
+            return JSONResponse(content={"error": "Unauthorized"}, status_code=401)
+        return await call_next(request)
sglang/test/run_eval.py
CHANGED
@@ -10,7 +10,6 @@ import time
 
 from sglang.test.simple_eval_common import (
     ChatCompletionSampler,
-    download_dataset,
     make_report,
     set_ulimit,
 )
@@ -27,14 +26,26 @@ def run_eval(args):
     if args.eval_name == "mmlu":
         from sglang.test.simple_eval_mmlu import MMLUEval
 
-
-
-
-
-
-
-
-
+        filename = "https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv"
+        eval_obj = MMLUEval(filename, args.num_examples, args.num_threads)
+    elif args.eval_name == "math":
+        from sglang.test.simple_eval_math import MathEval
+
+        equality_checker = ChatCompletionSampler(model="gpt-4-turbo")
+
+        filename = (
+            "https://openaipublic.blob.core.windows.net/simple-evals/math_test.csv"
+        )
+        eval_obj = MathEval(
+            filename, equality_checker, args.num_examples, args.num_threads
+        )
+    elif args.eval_name == "gpqa":
+        from sglang.test.simple_eval_gpqa import GPQAEval
+
+        filename = (
+            "https://openaipublic.blob.core.windows.net/simple-evals/gpqa_diamond.csv"
+        )
+        eval_obj = GPQAEval(filename, args.num_examples, args.num_threads)
     elif args.eval_name == "humaneval":
         from sglang.test.simple_eval_humaneval import HumanEval
 
@@ -97,7 +108,7 @@ if __name__ == "__main__":
     )
     parser.add_argument("--eval-name", type=str, default="mmlu")
    parser.add_argument("--num-examples", type=int)
-    parser.add_argument("--num-threads", type=int, default=
+    parser.add_argument("--num-threads", type=int, default=512)
     set_ulimit()
     args = parser.parse_args()
 
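
run_eval now builds its eval objects from hosted simple-evals CSVs and adds gpqa and math alongside mmlu and humaneval; the math eval constructs a gpt-4-turbo ChatCompletionSampler as an equality checker, so an OpenAI-compatible key may be needed for that one. A hypothetical invocation sketch, assuming an sglang server is already listening on the runner's default endpoint and using only the flags visible in this diff:

import subprocess

# Runs the GPQA eval against an already-running sglang server; the sample and
# thread counts are illustrative.
subprocess.run(
    [
        "python", "-m", "sglang.test.run_eval",
        "--eval-name", "gpqa",        # new in 0.2.11: "gpqa" and "math"
        "--num-examples", "8",
        "--num-threads", "16",
    ],
    check=True,
)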