sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +472 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +125 -6
- sglang/check_env.py +3 -6
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +28 -17
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +47 -58
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +16 -13
- sglang/srt/layers/attention/flashinfer_backend.py +106 -54
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +25 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +17 -15
- sglang/srt/layers/logits_processor.py +23 -25
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +98 -27
- sglang/srt/managers/detokenizer_manager.py +13 -15
- sglang/srt/managers/io_struct.py +63 -21
- sglang/srt/managers/schedule_batch.py +154 -59
- sglang/srt/managers/schedule_policy.py +18 -16
- sglang/srt/managers/scheduler.py +278 -109
- sglang/srt/managers/session_controller.py +61 -0
- sglang/srt/managers/tokenizer_manager.py +63 -18
- sglang/srt/managers/tp_worker.py +25 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +63 -25
- sglang/srt/model_executor/forward_batch_info.py +128 -32
- sglang/srt/model_executor/model_runner.py +132 -64
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +162 -59
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +31 -25
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +14 -16
- sglang/srt/models/llavavid.py +14 -16
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +22 -20
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +107 -93
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +19 -17
- sglang/srt/openai_api/protocol.py +14 -16
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +61 -57
- sglang/srt/sampling/sampling_params.py +14 -16
- sglang/srt/server.py +86 -35
- sglang/srt/server_args.py +96 -80
- sglang/srt/utils.py +266 -68
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +38 -20
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +31 -20
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.5.post2.dist-info/RECORD +0 -156
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/utils.py
CHANGED
@@ -1,22 +1,21 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 """Common utilities."""
 
 import base64
 import ipaddress
+import itertools
 import json
 import logging
 import os
@@ -33,7 +32,7 @@ import time
 import warnings
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
 
 import numpy as np
 import psutil
@@ -46,6 +45,8 @@ from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from starlette.routing import Mount
 from torch import nn
+from torch.func import functional_call
+from torch.library import Library
 from torch.profiler import ProfilerActivity, profile, record_function
 from triton.runtime.cache import (
     FileCacheManager,
@@ -71,6 +72,8 @@ def is_flashinfer_available():
     Check whether flashinfer is available.
     As of Oct. 6, 2024, it is only available on NVIDIA GPUs.
     """
+    if os.environ.get("SGLANG_IS_FLASHINFER_AVAILABLE", "true") == "false":
+        return False
     return torch.cuda.is_available() and not is_hip()
 
 
@@ -190,6 +193,94 @@ def get_available_gpu_memory(device, gpu_id, distributed=False):
     return free_gpu_memory / (1 << 30)
 
 
+def is_pin_memory_available() -> bool:
+    return torch.cuda.is_available()
+
+
+_CPU_OFFLOAD_BYTES = 0
+_CPU_OFFLOAD_MAX_BYTES = 0
+
+
+def set_cpu_offload_max_bytes(max_bytes: int) -> None:
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    _CPU_OFFLOAD_BYTES = 0
+    _CPU_OFFLOAD_MAX_BYTES = max_bytes
+
+
+def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
+    device = next(module.parameters()).device
+
+    if device == torch.device("cpu"):
+        return module
+
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+        return module
+
+    pin_memory = is_pin_memory_available()
+    # offload parameters to CPU
+    # use pin_memory if possible, which helps cudagraph capture speed
+    offloaded_parameters = False
+    for p in module.parameters():
+        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+            # we use per-parameter offloading
+            # one module might have some parameters offloaded and some not
+            break
+
+        # `torch.empty_like` does not support `pin_memory` argument
+        cpu_data = torch.empty_strided(
+            size=p.data.size(),
+            stride=p.data.stride(),
+            dtype=p.data.dtype,
+            layout=p.data.layout,
+            device="cpu",
+            pin_memory=pin_memory,
+        )
+        cpu_data.copy_(p.data)
+        p.data = cpu_data
+        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+        offloaded_parameters = True
+
+    if offloaded_parameters:
+        original_forward = module.forward
+
+        def forward(*args, **kwargs):
+            module.forward = original_forward
+            device_state = {
+                # here we blindly call `to(device)`
+                # if the parameter is already on the device, it will be a no-op
+                k: v.to(device, non_blocking=True)
+                for k, v in module.state_dict().items()
+            }
+            output = functional_call(module, device_state, args=args, kwargs=kwargs)
+            module.forward = forward
+            return output
+
+        module.forward = forward
+
+    return module
+
+
+class LayerFn(Protocol):
+
+    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module: ...
+
+
+def make_layers(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str = "",
+) -> Tuple[int, int, torch.nn.ModuleList]:
+    """Make a list of layers with the given layer function"""
+    modules = torch.nn.ModuleList(
+        [
+            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=f"{prefix}.{idx}"))
+            for idx in range(num_hidden_layers)
+        ]
+    )
+    return modules
+
+
 def set_random_seed(seed: int) -> None:
     """Set the random seed for all libraries."""
     random.seed(seed)
@@ -330,6 +421,7 @@ def suppress_other_loggers():
     )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.ERROR)
+    logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)
 
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"
@@ -394,6 +486,27 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
         pass
 
 
+def monkey_patch_vllm_model_config():
+    from vllm.config import ModelConfig
+
+    if not hasattr(ModelConfig, "_resolve_task"):
+        return
+
+    def _resolve_task(
+        self,
+        task_option,
+        hf_config,
+    ):
+        supported_tasks = {
+            "generate": True,
+            "embedding": False,
+        }
+        selected_task = "generate"
+        return supported_tasks, selected_task
+
+    setattr(ModelConfig, "_resolve_task", _resolve_task)
+
+
 def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     """
     Monkey patch the slow p2p access check in vllm.
@@ -405,57 +518,6 @@ def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     setattr(tgt, "gpu_p2p_access_check", lambda *arg, **kwargs: True)
 
 
-def monkey_patch_vllm_dummy_weight_loader():
-    """
-    Monkey patch the dummy weight loader in vllm to call process_weights_after_loading.
-    """
-
-    from vllm.model_executor.model_loader.loader import (
-        CacheConfig,
-        DeviceConfig,
-        DummyModelLoader,
-        LoRAConfig,
-        ModelConfig,
-        ParallelConfig,
-        SchedulerConfig,
-        _initialize_model,
-        initialize_dummy_weights,
-        nn,
-        set_default_torch_dtype,
-    )
-
-    def load_model(
-        self,
-        *,
-        model_config: ModelConfig,
-        device_config: DeviceConfig,
-        lora_config: Optional[LoRAConfig],
-        parallel_config: ParallelConfig,
-        scheduler_config: SchedulerConfig,
-        cache_config: CacheConfig,
-    ) -> nn.Module:
-        with set_default_torch_dtype(model_config.dtype):
-            with torch.device(device_config.device):
-                model = _initialize_model(
-                    model_config,
-                    self.load_config,
-                    lora_config,
-                    cache_config,
-                )
-
-            for _, module in model.named_modules():
-                quant_method = getattr(module, "quant_method", None)
-                if quant_method is not None:
-                    quant_method.process_weights_after_loading(module)
-
-            # NOTE(woosuk): For accurate performance evaluation, we assign
-            # random values to the weights.
-            initialize_dummy_weights(model)
-        return model.eval()
-
-    setattr(DummyModelLoader, "load_model", load_model)
-
-
 vllm_all_gather_backup = None
 
 
@@ -794,7 +856,48 @@ def add_prometheus_middleware(app):
     app.routes.append(metrics_route)
 
 
-def get_gpu_memory_capacity():
+def bind_port(port):
+    """Bind to a specific port, assuming it's available."""
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)  # Allows address reuse
+    sock.bind(("", port))
+    sock.listen(1)
+    return sock
+
+
+def get_amdgpu_memory_capacity():
+    try:
+        # Run rocm-smi and capture the output
+        result = subprocess.run(
+            ["rocm-smi --showmeminfo vram | grep 'Total Memory' | awk '{print $NF}'"],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            shell=True,
+            text=True,
+        )
+        if result.returncode != 0:
+            raise RuntimeError(f"rocm-smi error: {result.stderr.strip()}")
+
+        # Parse the output to extract memory values in MiB
+        memory_values = [
+            float(mem) / 1024 / 1024
+            for mem in result.stdout.strip().split("\n")
+            if re.match(r"^\d+(\.\d+)?$", mem.strip())
+        ]
+
+        if not memory_values:
+            raise ValueError("No GPU memory values found.")
+
+        # Return the minimum memory value
+        return min(memory_values)
+
+    except FileNotFoundError:
+        raise RuntimeError(
+            "rocm-smi not found. Ensure AMD ROCm drivers are installed and accessible."
+        )
+
+
+def get_nvgpu_memory_capacity():
     try:
         # Run nvidia-smi and capture the output
         result = subprocess.run(
@@ -824,3 +927,98 @@ def get_gpu_memory_capacity():
         raise RuntimeError(
            "nvidia-smi not found. Ensure NVIDIA drivers are installed and accessible."
        )
+
+
+def crash_on_warnings():
+    # Crash on warning if we are running CI tests
+    return os.getenv("SGLANG_IS_IN_CI", "false").lower() == "true"
+
+
+def get_device_name(device_id: int = 0) -> str:
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        return torch.cuda.get_device_name(device_id)
+
+    if hasattr(torch, "hip") and torch.hip.is_available():
+        return torch.hip.get_device_name(device_id)
+
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        return torch.xpu.get_device_name(device_id)
+
+    if hasattr(torch, "hpu") and torch.hpu.is_available():
+        return torch.hpu.get_device_name(device_id)
+
+
+sglang_lib = Library("sglang", "FRAGMENT")  # noqa
+
+
+def direct_register_custom_op(
+    op_name: str,
+    op_func: Callable,
+    mutates_args: List[str],
+    fake_impl: Optional[Callable] = None,
+    target_lib: Optional[Library] = None,
+):
+    """
+    `torch.library.custom_op` can have significant overhead because it
+    needs to consider complicated dispatching logic. This function
+    directly registers a custom op and dispatches it to the CUDA backend.
+    See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5
+    for more details.
+
+    By default, the custom op is registered to the vLLM library. If you
+    want to register it to a different library, you can pass the library
+    object to the `target_lib` argument.
+
+    IMPORTANT: the lifetime of the operator is tied to the lifetime of the
+    library object. If you want to bind the operator to a different library,
+    make sure the library object is alive when the operator is used.
+    """
+    import torch.library
+
+    if hasattr(torch.library, "infer_schema"):
+        schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args)
+    else:
+        # for pytorch 2.4
+        import torch._custom_op.impl
+
+        schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args)
+
+    my_lib = target_lib or sglang_lib
+    my_lib.define(op_name + schema_str)
+    my_lib.impl(op_name, op_func, "CUDA")
+    if fake_impl is not None:
+        my_lib._register_fake(op_name, fake_impl)
+
+
+def gpu_proc_affinity(
+    tp_size: int,
+    nnodes: int,
+    gpu_id: int,
+):
+    # current process
+    pid = os.getpid()
+    p = psutil.Process(pid)
+
+    tp_size_per_node = tp_size // nnodes
+
+    # total physical cores
+    total_pcores = psutil.cpu_count(logical=False)
+    # physical cores per TP (N.B. more Cores than GPUs on node)
+    num_cores_bind = total_pcores // tp_size_per_node
+
+    # able to handle multiple DP per node
+    start_cpu_id = (gpu_id * num_cores_bind) % total_pcores
+    end_cpu_id = start_cpu_id + num_cores_bind
+
+    if psutil.cpu_count() != psutil.cpu_count(logical=False):
+        # HT on
+        upper_cpu_ids = [id for id in range(start_cpu_id, end_cpu_id)]
+        lower_cpu_ids = [id + total_pcores for id in range(start_cpu_id, end_cpu_id)]
+        bind_cpu_ids = list(itertools.chain(upper_cpu_ids, lower_cpu_ids))
+    else:
+        # HT off
+        bind_cpu_ids = [id for id in range(start_cpu_id, end_cpu_id)]
+
+    # set cpu_affinity to current process
+    p.cpu_affinity(bind_cpu_ids)
+    logger.info(f"Process {pid} gpu_id {gpu_id} is running on CPUs: {p.cpu_affinity()}")
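Of the additions above, the CPU-offload helpers are the most self-contained: set_cpu_offload_max_bytes sets a global byte budget and maybe_offload_to_cpu parks a module's parameters in (pinned) host memory until that budget is exhausted, re-materializing them on the original device inside a patched forward. The snippet below is an illustrative sketch, not code from the package; it assumes a CUDA device and uses an arbitrary 4 GiB budget. In the wheel itself the wrapping is driven through the new make_layers helper rather than by user code.

import torch

from sglang.srt.utils import maybe_offload_to_cpu, set_cpu_offload_max_bytes

# Arbitrary example budget: allow up to 4 GiB of parameters in pinned host RAM.
set_cpu_offload_max_bytes(4 * 1024**3)

# The wrapped module keeps CPU copies of its weights; its forward() temporarily
# copies them back to the original device via torch.func.functional_call.
layer = maybe_offload_to_cpu(torch.nn.Linear(4096, 4096).cuda())
out = layer(torch.randn(2, 4096, device="cuda"))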
sglang/test/few_shot_gsm8k.py
CHANGED
@@ -48,9 +48,13 @@ def run_eval(args):
     # Select backend
     set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}"))
 
-    # Read data
-    url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
-    filename = download_and_cache_file(url)
+    if args.data_path is None:
+        # Read data
+        url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+        filename = download_and_cache_file(url)
+    else:
+        filename = args.data_path
+
     lines = list(read_jsonl(filename))
 
     # Construct prompts
@@ -131,7 +135,7 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--num-shots", type=int, default=5)
-    parser.add_argument("--data-path", type=str, default="test.jsonl")
+    parser.add_argument("--data-path", type=str)
     parser.add_argument("--num-questions", type=int, default=200)
     parser.add_argument("--max-new-tokens", type=int, default=512)
     parser.add_argument("--parallel", type=int, default=128)
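The change above makes --data-path optional: when it is omitted, run_eval downloads and caches the upstream GSM8K test split instead of requiring a local test.jsonl. Below is a hedged sketch of driving it programmatically; the field names are taken from the parser above, a running SGLang endpoint is assumed, and the full script may expect additional fields.

import argparse

from sglang.test.few_shot_gsm8k import run_eval

args = argparse.Namespace(
    host="http://127.0.0.1",
    port=30000,
    num_shots=5,
    num_questions=20,
    max_new_tokens=512,
    parallel=16,
    data_path=None,  # None -> download and cache the upstream test.jsonl
)
# run_eval(args)  # uncomment with a live sglang server at host:port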
sglang/test/runners.py
CHANGED
@@ -1,17 +1,16 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 import json
 import multiprocessing as mp
@@ -58,6 +57,28 @@ def get_top_logprobs(logits, k):
     return logprobs
 
 
+def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
+    from sentence_transformers import SentenceTransformer
+    from sentence_transformers.util import is_sentence_transformer_model
+
+    if is_sentence_transformer_model(model_path):
+        model = SentenceTransformer(
+            model_path,
+            model_kwargs={"torch_dtype": torch_dtype},
+        )
+    else:  # if no pre-trained sentence-transformers model
+        from sentence_transformers import models
+
+        word_embedding_model = models.Transformer(model_path).to(dtype=torch_dtype)
+        pooling_model = models.Pooling(
+            word_embedding_model.get_word_embedding_dimension(),
+            pooling_mode="lasttoken",
+        )
+        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
+
+    return model.cuda()
+
+
 @dataclass
 class ModelOutput:
     output_strs: List[str] = None
@@ -114,12 +135,9 @@ class HFRunner:
                 low_cpu_mem_usage=True,
             ).cuda()
         elif self.model_type == "embedding":
-            from sentence_transformers import SentenceTransformer
-
-            self.model = SentenceTransformer(
-                model_path,
-                model_kwargs={"torch_dtype": torch_dtype},
-            ).cuda()
+            self.model = _get_sentence_transformer_embedding_model(
+                model_path, torch_dtype
+            )
         elif self.model_type == "reward":
             from transformers import AutoModelForSequenceClassification
 

sglang/test/srt/sampling/penaltylib/utils.py
CHANGED
@@ -1,7 +1,7 @@
 import dataclasses
 import enum
-import typing
 import unittest
+from typing import Dict, List, Optional, Set, Tuple, Type
 
 import torch
 
@@ -16,7 +16,7 @@ from sglang.srt.sampling.penaltylib.orchestrator import (
 class MockSamplingParams:
     frequency_penalty: float = 0.0
     min_new_tokens: int = 0
-    stop_token_ids: typing.List[int] = None
+    stop_token_ids: List[int] = None
     presence_penalty: float = 0.0
     repetition_penalty: float = 1.0
 
@@ -24,12 +24,12 @@ class MockSamplingParams:
 @dataclasses.dataclass
 class MockTokenizer:
     eos_token_id: int
-    additional_stop_token_ids: typing.Optional[typing.List[int]] = None
+    additional_stop_token_ids: Optional[List[int]] = None
 
 
 @dataclasses.dataclass
 class MockReq:
-    origin_input_ids: typing.List[int]
+    origin_input_ids: List[int]
     sampling_params: MockSamplingParams
     tokenizer: MockTokenizer
 
@@ -42,8 +42,8 @@ class StepType(enum.Enum):
 @dataclasses.dataclass
 class Step:
     type: StepType
-    token_ids: typing.List[int]
-    expected_tensors: typing.Dict[str, torch.Tensor]
+    token_ids: List[int]
+    expected_tensors: Dict[str, torch.Tensor]
     # assume initial logits are all 1
     expected_logits: torch.Tensor
 
@@ -52,7 +52,7 @@ class Step:
 class Subject:
     sampling_params: MockSamplingParams
     # first step must be input, which will be converted to Req
-    steps: typing.List[Step]
+    steps: List[Step]
    eos_token_id: int = -1
 
     def __post_init__(self):
@@ -66,7 +66,7 @@ class Subject:
                 f"Expected tensors keys must be the same for all steps. Got {self.steps[i].expected_tensors.keys()} for key={i} and {self.steps[0].expected_tensors.keys()}"
             )
 
-    def tensor_keys(self, i: int = 0) -> typing.Set[str]:
+    def tensor_keys(self, i: int = 0) -> Set[str]:
         return set(self.steps[i].expected_tensors.keys())
 
     def to_req(self) -> MockReq:
@@ -80,7 +80,7 @@ class Subject:
 @dataclasses.dataclass
 class Case:
     enabled: bool
-    test_subjects: typing.List[Subject]
+    test_subjects: List[Subject]
 
     def __post_init__(self):
         # each test_subjects.steps should have the same expected_tensors.keys()
@@ -90,12 +90,12 @@ class Case:
                 f"Expected tensors keys must be the same for all test_subjects. Got {self.test_subjects[i].tensor_keys()} for key={i} and {self.test_subjects[0].tensor_keys()}"
             )
 
-    def tensor_keys(self, i: int = 0) -> typing.List[str]:
+    def tensor_keys(self, i: int = 0) -> List[str]:
         return set(self.test_subjects[i].tensor_keys())
 
 
 class BaseBatchedPenalizerTest(unittest.TestCase):
-    Penalizer: typing.Type[_BatchedPenalizer]
+    Penalizer: Type[_BatchedPenalizer]
     device = "cuda"
     vocab_size = 5
 
@@ -115,7 +115,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
         """
         return torch.tensor(data, **kwargs, device=self.device)
 
-    def create_test_subjects(self) -> typing.List[Subject]:
+    def create_test_subjects(self) -> List[Subject]:
         raise NotImplementedError()
 
     def create_test_cases(self):
@@ -127,7 +127,7 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
 
     def _create_penalizer(
         self, case: Case
-    ) -> typing.Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
+    ) -> Tuple[BatchedPenalizerOrchestrator, _BatchedPenalizer]:
         orchestrator = BatchedPenalizerOrchestrator(
             vocab_size=self.vocab_size,
             batch=_BatchLike(reqs=[subject.to_req() for subject in case.test_subjects]),
@@ -287,22 +287,24 @@ class BaseBatchedPenalizerTest(unittest.TestCase):
                 if i < len(subject.steps)
             ]
 
-            inputs: typing.List[typing.List[int]] = []
-            outputs: typing.List[typing.List[int]] = []
+            inputs: List[List[int]] = []
+            outputs: List[List[int]] = []
             for subject in filtered_subjects:
                 step = subject.steps[i]
                 if step.type == StepType.INPUT:
-                    inputs.append(step.token_ids)
-                    outputs.append([])
+                    raise NotImplementedError()
                 else:
                     inputs.append([])
                     outputs.append(step.token_ids)
 
-            if any(inputs):
-                orchestrator.cumulate_input_tokens(inputs)
-
             if any(outputs):
-                orchestrator.cumulate_output_tokens(outputs)
+                for j in range(max(len(x) for x in outputs)):
+                    tmp_outputs = torch.tensor(
+                        [x[j] for x in outputs],
+                        dtype=torch.int32,
+                        device=orchestrator.device,
+                    )
+                    orchestrator.cumulate_output_tokens(tmp_outputs)
 
             if penalizer.is_required():
                 self.assertTrue(penalizer.is_prepared())
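The last hunk tracks an API change in BatchedPenalizerOrchestrator: the test no longer cumulates input tokens at all, and cumulate_output_tokens is now fed one tensor per decode step holding the newest token id of every request in the batch, instead of per-request Python lists. A hedged sketch of that calling convention follows, with invented token ids; the actual orchestrator call is left commented out (it would be the object built by _create_penalizer above).

import torch

# Batch of 3 requests observed over 2 decode steps; ids are made up.
step_token_ids = [
    [101, 7, 42],  # newest token of each request after step 1
    [102, 8, 43],  # newest token of each request after step 2
]
for ids in step_token_ids:
    batch_ids = torch.tensor(ids, dtype=torch.int32, device="cuda")
    # orchestrator.cumulate_output_tokens(batch_ids)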
|