sglang 0.4.1.post4__py3-none-any.whl → 0.4.1.post6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +18 -1
- sglang/lang/interpreter.py +71 -1
- sglang/lang/ir.py +2 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/chatglm.py +78 -0
- sglang/srt/configs/dbrx.py +279 -0
- sglang/srt/configs/model_config.py +16 -7
- sglang/srt/hf_transformers_utils.py +9 -14
- sglang/srt/layers/attention/__init__.py +8 -1
- sglang/srt/layers/attention/flashinfer_backend.py +21 -5
- sglang/srt/layers/linear.py +89 -47
- sglang/srt/layers/logits_processor.py +6 -6
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +16 -5
- sglang/srt/layers/moe/fused_moe_triton/layer.py +39 -12
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +439 -0
- sglang/srt/layers/quantization/__init__.py +5 -2
- sglang/srt/layers/quantization/fp8.py +107 -53
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +174 -0
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/vocab_parallel_embedding.py +16 -3
- sglang/srt/managers/cache_controller.py +307 -0
- sglang/srt/managers/configure_logging.py +43 -0
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/detokenizer_manager.py +0 -2
- sglang/srt/managers/io_struct.py +29 -13
- sglang/srt/managers/schedule_batch.py +7 -1
- sglang/srt/managers/scheduler.py +58 -15
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +109 -45
- sglang/srt/mem_cache/memory_pool.py +313 -53
- sglang/srt/metrics/collector.py +32 -35
- sglang/srt/model_executor/cuda_graph_runner.py +14 -7
- sglang/srt/model_executor/forward_batch_info.py +20 -15
- sglang/srt/model_executor/model_runner.py +53 -10
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/grok.py +25 -16
- sglang/srt/models/llama.py +46 -4
- sglang/srt/models/qwen2.py +11 -0
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +15 -5
- sglang/srt/sampling/sampling_params.py +1 -1
- sglang/srt/server.py +125 -69
- sglang/srt/server_args.py +39 -19
- sglang/srt/speculative/eagle_utils.py +93 -85
- sglang/srt/speculative/eagle_worker.py +48 -33
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +61 -5
- sglang/test/test_programs.py +23 -1
- sglang/test/test_utils.py +36 -7
- sglang/version.py +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/METADATA +16 -15
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/RECORD +61 -51
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/WHEEL +1 -1
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post4.dist-info → sglang-0.4.1.post6.dist-info}/top_level.txt +0 -0
sglang/srt/server.py
CHANGED
@@ -31,6 +31,8 @@ from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
 
 import torch
 
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
 
@@ -52,11 +54,14 @@ from sglang.srt.managers.data_parallel_controller import (
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
+    ConfigureLoggingReq,
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
     InitWeightsUpdateGroupReqInput,
     OpenSessionReqInput,
+    ReleaseMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -127,14 +132,12 @@ async def health() -> Response:
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""
 
+    sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
+
     if tokenizer_manager.is_generation:
-        gri = GenerateReqInput(
-            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-        )
+        gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params)
     else:
-        gri = EmbeddingReqInput(
-            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-        )
+        gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params)
 
     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
@@ -159,12 +162,68 @@ async def get_model_info():
 @app.get("/get_server_info")
 async def get_server_info():
     return {
-        **dataclasses.asdict(tokenizer_manager.server_args),
+        **dataclasses.asdict(tokenizer_manager.server_args),
         **scheduler_info,
         "version": __version__,
     }
 
 
+# fastapi implicitly converts json in the request to obj (dataclass)
+@app.api_route("/generate", methods=["POST", "PUT"])
+@time_func_latency
+async def generate_request(obj: GenerateReqInput, request: Request):
+    """Handle a generate request."""
+    if obj.stream:
+
+        async def stream_results() -> AsyncIterator[bytes]:
+            try:
+                async for out in tokenizer_manager.generate_request(obj, request):
+                    yield b"data: " + orjson.dumps(
+                        out, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
+            except ValueError as e:
+                out = {"error": {"message": str(e)}}
+                yield b"data: " + orjson.dumps(
+                    out, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            yield b"data: [DONE]\n\n"
+
+        return StreamingResponse(
+            stream_results(),
+            media_type="text/event-stream",
+            background=tokenizer_manager.create_abort_task(obj),
+        )
+    else:
+        try:
+            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+            return ret
+        except ValueError as e:
+            logger.error(f"Error: {e}")
+            return _create_error_response(e)
+
+
+@app.api_route("/encode", methods=["POST", "PUT"])
+@time_func_latency
+async def encode_request(obj: EmbeddingReqInput, request: Request):
+    """Handle an embedding request."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/classify", methods=["POST", "PUT"])
+@time_func_latency
+async def classify_request(obj: EmbeddingReqInput, request: Request):
+    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
+    try:
+        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        return _create_error_response(e)
+
+
 @app.post("/flush_cache")
 async def flush_cache():
     """Flush the radix cache."""
@@ -176,8 +235,7 @@ async def flush_cache():
     )
 
 
-@app.get("/start_profile")
-@app.post("/start_profile")
+@app.api_route("/start_profile", methods=["GET", "POST"])
 async def start_profile_async():
     """Start profiling."""
     tokenizer_manager.start_profile()
@@ -187,8 +245,7 @@ async def start_profile_async():
     )
 
 
-@app.get("/stop_profile")
-@app.post("/stop_profile")
+@app.api_route("/stop_profile", methods=["GET", "POST"])
 async def stop_profile_async():
     """Stop profiling."""
     tokenizer_manager.stop_profile()
@@ -257,6 +314,28 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
         return _create_error_response(e)
 
 
+@app.api_route("/release_memory_occupation", methods=["GET", "POST"])
+async def release_memory_occupation(
+    obj: ReleaseMemoryOccupationReqInput, request: Request
+):
+    """Release GPU occupation temporarily"""
+    try:
+        await tokenizer_manager.release_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
+@app.api_route("/resume_memory_occupation", methods=["GET", "POST"])
+async def resume_memory_occupation(
+    obj: ResumeMemoryOccupationReqInput, request: Request
+):
+    """Resume GPU occupation"""
+    try:
+        await tokenizer_manager.resume_memory_occupation(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
@@ -281,60 +360,11 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         return _create_error_response(e)
 
 
-# fastapi implicitly converts json in the request to obj (dataclass)
-@app.api_route("/generate", methods=["POST", "PUT"])
-@time_func_latency
-async def generate_request(obj: GenerateReqInput, request: Request):
-    """Handle a generate request."""
-    if obj.stream:
-
-        async def stream_results() -> AsyncIterator[bytes]:
-            try:
-                async for out in tokenizer_manager.generate_request(obj, request):
-                    yield b"data: " + orjson.dumps(
-                        out, option=orjson.OPT_NON_STR_KEYS
-                    ) + b"\n\n"
-            except ValueError as e:
-                out = {"error": {"message": str(e)}}
-                yield b"data: " + orjson.dumps(
-                    out, option=orjson.OPT_NON_STR_KEYS
-                ) + b"\n\n"
-            yield b"data: [DONE]\n\n"
-
-        return StreamingResponse(
-            stream_results(),
-            media_type="text/event-stream",
-            background=tokenizer_manager.create_abort_task(obj),
-        )
-    else:
-        try:
-            ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-            return ret
-        except ValueError as e:
-            logger.error(f"Error: {e}")
-            return _create_error_response(e)
-
-
-@app.api_route("/encode", methods=["POST", "PUT"])
-@time_func_latency
-async def encode_request(obj: EmbeddingReqInput, request: Request):
-    """Handle an embedding request."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
-
-
-@app.api_route("/classify", methods=["POST", "PUT"])
-@time_func_latency
-async def classify_request(obj: EmbeddingReqInput, request: Request):
-    """Handle a reward model request. Now the arguments and return values are the same as embedding models."""
-    try:
-        ret = await tokenizer_manager.generate_request(obj, request).__anext__()
-        return ret
-    except ValueError as e:
-        return _create_error_response(e)
+@app.api_route("/configure_logging", methods=["GET", "POST"])
+async def configure_logging(obj: ConfigureLoggingReq, request: Request):
+    """Close the session"""
+    tokenizer_manager.configure_logging(obj)
+    return Response(status_code=200)
 
 
 ##### OpenAI-compatible API endpoints #####
@@ -440,6 +470,10 @@ def launch_engine(
         server_args.model_path, server_args.tokenizer_path
     )
 
+    memory_saver_adapter = TorchMemorySaverAdapter.create(
+        enable=server_args.enable_memory_saver
+    )
+
     if server_args.dp_size == 1:
         # Launch tensor parallel scheduler processes
         scheduler_procs = []
@@ -456,7 +490,8 @@ def launch_engine(
                 target=run_scheduler_process,
                 args=(server_args, port_args, gpu_id, tp_rank, None, writer),
             )
-            proc.start()
+            with memory_saver_adapter.configure_subprocess():
+                proc.start()
             scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)
 
@@ -473,7 +508,8 @@ def launch_engine(
             target=run_data_parallel_controller_process,
             args=(server_args, port_args, writer),
         )
-        proc.start()
+        with memory_saver_adapter.configure_subprocess():
+            proc.start()
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
@@ -546,7 +582,12 @@ def launch_server(
 
     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup,
+        args=(
+            server_args,
+            pipe_finish_writer,
+            tokenizer_manager.image_token_id,
+        ),
     )
     t.start()
 
@@ -608,6 +649,9 @@ def _set_envs_and_config(server_args: ServerArgs):
     # The child processes will send SIGQUIT to this process when any error happens
     # This process then clean up the whole process tree
     def sigquit_handler(signum, frame):
+        logger.error(
+            "Received sigquit from a child proces. It usually means the child failed."
+        )
         kill_process_tree(os.getpid())
 
     signal.signal(signal.SIGQUIT, sigquit_handler)
@@ -616,7 +660,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)
 
 
-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, image_token_text):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -891,6 +935,18 @@ class Engine:
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
 
+    def release_memory_occupation(self):
+        """Release GPU occupation temporarily"""
+        obj = ReleaseMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(tokenizer_manager.release_memory_occupation(obj, None))
+
+    def resume_memory_occupation(self):
+        """Resume GPU occupation"""
+        obj = ResumeMemoryOccupationReqInput()
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(tokenizer_manager.resume_memory_occupation(obj, None))
+
 
 class Runtime:
     """
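Taken together, the server.py changes add a pair of memory-occupation hooks: HTTP routes plus `Engine` methods, backed by `TorchMemorySaverAdapter` and gated on the new `--enable-memory-saver` flag, so the engine can temporarily give up GPU memory and take it back later. Below is a minimal sketch of driving the offline `Engine` API; the model path is a placeholder and the surrounding workflow (what runs while memory is released) is an assumption, not part of this diff.

```python
# Sketch only: exercising the release/resume hooks added in 0.4.1.post6.
# The model path and the work done while memory is released are assumptions.
import sglang as sgl

engine = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    enable_memory_saver=True,  # new ServerArgs flag required for these hooks
)

print(engine.generate("Hello", {"max_new_tokens": 8}))

# Temporarily hand GPU memory back (e.g. to a co-located training job) ...
engine.release_memory_occupation()

# ... then re-occupy it before serving again.
engine.resume_memory_occupation()
print(engine.generate("Hello again", {"max_new_tokens": 8}))
```

The same two operations are reachable over HTTP via the new `/release_memory_occupation` and `/resume_memory_occupation` routes.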
sglang/srt/server_args.py
CHANGED
@@ -23,7 +23,6 @@ from typing import List, Optional
 import torch
 
 from sglang.srt.hf_transformers_utils import check_gguf_file
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
     get_hpu_memory_capacity,
@@ -32,6 +31,7 @@ from sglang.srt.utils import (
     is_hip,
     is_ipv6,
     is_port_available,
+    nullable_str,
 )
 
 logger = logging.getLogger(__name__)
@@ -47,6 +47,7 @@ class ServerArgs:
     trust_remote_code: bool = True
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
+    quantization_param_path: nullable_str = None
     quantization: Optional[str] = None
     context_length: Optional[int] = None
     device: str = "cuda"
@@ -55,7 +56,6 @@ class ServerArgs:
     is_embedding: bool = False
     revision: Optional[str] = None
     skip_tokenizer_init: bool = False
-    return_token_ids: bool = False
 
     # Port for the HTTP server
     host: str = "127.0.0.1"
@@ -91,7 +91,7 @@ class ServerArgs:
 
     # API related
     api_key: Optional[str] = None
-    file_storage_pth: str = "SGLang_storage"
+    file_storage_pth: str = "sglang_storage"
    enable_cache_report: bool = False
 
     # Data parallelism
@@ -148,6 +148,7 @@ class ServerArgs:
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
+    cuda_graph_bs: Optional[List[int]] = None
     torchao_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
@@ -155,6 +156,7 @@ class ServerArgs:
     triton_attention_num_kv_splits: int = 8
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
+    enable_memory_saver: bool = False
 
     def __post_init__(self):
         # Set missing default values
@@ -295,6 +297,11 @@ class ServerArgs:
             "tokenizer if available, and 'slow' will "
             "always use the slow tokenizer.",
         )
+        parser.add_argument(
+            "--skip-tokenizer-init",
+            action="store_true",
+            help="If set, skip init tokenizer and pass input_ids in generate request",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -345,8 +352,17 @@ class ServerArgs:
             "--kv-cache-dtype",
             type=str,
             default=ServerArgs.kv_cache_dtype,
-            choices=["auto", "fp8_e5m2"],
-            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" is supported for CUDA 11.8+.',
+            choices=["auto", "fp8_e5m2", "fp8_e4m3"],
+            help='Data type for kv cache storage. "auto" will use model data type. "fp8_e5m2" and "fp8_e4m3" is supported for CUDA 11.8+.',
+        )
+        parser.add_argument(
+            "--quantization-param-path",
+            type=nullable_str,
+            default=None,
+            help="Path to the JSON file containing the KV cache "
+            "scaling factors. This should generally be supplied, when "
+            "KV cache dtype is FP8. Otherwise, KV cache scaling factors "
+            "default to 1.0, which may cause accuracy issues. ",
         )
         parser.add_argument(
             "--quantization",
@@ -361,6 +377,8 @@ class ServerArgs:
                 "awq_marlin",
                 "bitsandbytes",
                 "gguf",
+                "modelopt",
+                "w8a8_int8",
             ],
             help="The quantization method.",
         )
@@ -402,18 +420,6 @@ class ServerArgs:
             "name, a tag name, or a commit id. If unspecified, will use "
             "the default version.",
         )
-        parser.add_argument(
-            "--skip-tokenizer-init",
-            action="store_true",
-            help="If set, skip init tokenizer and pass input_ids in generate request",
-        )
-        parser.add_argument(
-            "--return-token-ids",
-            action="store_true",
-            default=ServerArgs.return_token_ids,
-            help="Whether to return token IDs in the output, this may introduce additional overhead.",
-        )
-
         # Memory and scheduling
         parser.add_argument(
             "--mem-fraction-static",
@@ -549,7 +555,7 @@ class ServerArgs:
             "--decode-log-interval",
             type=int,
             default=ServerArgs.decode_log_interval,
-            help="The log interval of decode batch",
+            help="The log interval of decode batch.",
         )
 
         # API related
@@ -802,6 +808,12 @@ class ServerArgs:
             default=ServerArgs.cuda_graph_max_bs,
             help="Set the maximum batch size for cuda graph.",
        )
+        parser.add_argument(
+            "--cuda-graph-bs",
+            type=int,
+            nargs="+",
+            help="Set the list of batch sizes for cuda graph.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
@@ -843,6 +855,11 @@ class ServerArgs:
             action="store_true",
             help="Delete the model checkpoint after loading the model.",
         )
+        parser.add_argument(
+            "--enable-memory-saver",
+            action="store_true",
+            help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -920,7 +937,10 @@ class PortArgs:
         while True:
             if is_port_available(port):
                 break
-            port += 42
+            if port < 60000:
+                port += 42
+            else:
+                port -= 43
 
         return PortArgs(
             tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
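The server_args.py changes mostly surface as new `ServerArgs` fields and CLI flags: `--quantization-param-path` for FP8 KV-cache scaling factors, the `fp8_e4m3` KV-cache dtype, the `modelopt` and `w8a8_int8` quantization choices, `--cuda-graph-bs` for an explicit CUDA-graph batch-size list, and `--enable-memory-saver`. A rough sketch of parsing them through the `add_cli_args`/`from_cli_args` helpers follows; the model path, scales-file name, and batch sizes are illustrative placeholders, not values taken from this diff.

```python
# Sketch only: building a ServerArgs with the flags added in 0.4.1.post6.
# Paths and batch sizes below are placeholders, not values from this diff.
import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args(
    [
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",    # placeholder model
        "--kv-cache-dtype", "fp8_e4m3",                        # newly accepted choice
        "--quantization-param-path", "kv_cache_scales.json",   # hypothetical scales file
        "--cuda-graph-bs", "1", "2", "4", "8",                 # explicit CUDA graph batch sizes
        "--enable-memory-saver",                               # backs release/resume endpoints
    ]
)
server_args = ServerArgs.from_cli_args(args)
print(server_args.cuda_graph_bs, server_args.enable_memory_saver)
```

Note that `--skip-tokenizer-init` and `--return-token-ids` are only moved or removed in the argument parser; `skip_tokenizer_init` itself remains a `ServerArgs` field.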