sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +3 -6
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +6 -2
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +105 -6
- sglang/srt/disaggregation/mini_lb.py +74 -9
- sglang/srt/disaggregation/mooncake/conn.py +33 -63
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +137 -17
- sglang/srt/disaggregation/utils.py +32 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +883 -209
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +20 -5
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +17 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -88
- sglang/srt/layers/quantization/fp8_utils.py +189 -264
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +9 -8
- sglang/srt/layers/sampler.py +7 -12
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +15 -4
- sglang/srt/managers/scheduler.py +28 -77
- sglang/srt/managers/tokenizer_manager.py +116 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +41 -29
- sglang/srt/mem_cache/memory_pool.py +38 -15
- sglang/srt/model_executor/cuda_graph_runner.py +15 -10
- sglang/srt/model_executor/model_runner.py +39 -31
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +292 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +31 -203
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +86 -72
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +6 -14
- sglang/srt/utils.py +62 -6
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/memory_pool.py

@@ -286,8 +286,12 @@ class MHATokenToKVPool(KVCache):
             self.get_key_buffer(i).nbytes for i in range(self.layer_num)
         ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
         kv_item_lens = [
-            self.get_key_buffer(i)[0].nbytes for i in range(self.layer_num)
-        ] + [self.get_value_buffer(i)[0].nbytes for i in range(self.layer_num)]
+            self.get_key_buffer(i)[0].nbytes * self.page_size
+            for i in range(self.layer_num)
+        ] + [
+            self.get_value_buffer(i)[0].nbytes * self.page_size
+            for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     # Todo: different memory layout
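For orientation (editorial sketch, not part of the packaged diff): with paging, the per-item length reported to the disaggregation transfer path now covers a whole page of token slots rather than a single slot, which is what the * self.page_size factor above encodes. A minimal restatement of the arithmetic with made-up shapes:

import torch

# Hypothetical buffer shapes, for illustration only.
head_num, head_dim, page_size = 8, 128, 16
k_buffer = torch.empty(1024, head_num, head_dim, dtype=torch.float16)

per_token_item = k_buffer[0].nbytes              # 8 * 128 * 2 bytes = 2048
per_page_item = k_buffer[0].nbytes * page_size   # 2048 * 16 = 32768, the new item length
print(per_token_item, per_page_item)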
@@ -414,6 +418,7 @@ class MLATokenToKVPool(KVCache):
         enable_memory_saver: bool,
     ):
         self.size = size
+        self.page_size = page_size
         self.dtype = dtype
         self.device = device
         if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
@@ -441,13 +446,28 @@ class MLATokenToKVPool(KVCache):
         ]
 
         self.layer_transfer_counter = None
+        self.page_size = page_size
+
+        kv_size = self.get_kv_size_bytes()
+        logger.info(
+            f"KV Cache is allocated. #tokens: {size}, KV size: {kv_size / GB:.2f} GB"
+        )
+
+    def get_kv_size_bytes(self):
+        assert hasattr(self, "kv_buffer")
+        kv_size_bytes = 0
+        for kv_cache in self.kv_buffer:
+            kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
+        return kv_size_bytes
 
     # for disagg
     def get_contiguous_buf_infos(self):
         # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
         kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
         kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
-        kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
+        kv_item_lens = [
+            self.kv_buffer[i][0].nbytes * self.page_size for i in range(self.layer_num)
+        ]
         return kv_data_ptrs, kv_data_lens, kv_item_lens
 
     def get_key_buffer(self, layer_id: int):
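As a quick sanity check of the new accounting (editorial example, not from the package): get_kv_size_bytes sums numel * itemsize over every layer's buffer, e.g.:

import numpy as np
import torch

# Hypothetical MLA cache: 4 layers, 1024 token slots, 576 values per slot, fp16.
kv_buffer = [torch.empty(1024, 1, 576, dtype=torch.float16) for _ in range(4)]

kv_size_bytes = 0
for kv_cache in kv_buffer:
    kv_size_bytes += np.prod(kv_cache.shape) * kv_cache.dtype.itemsize
print(kv_size_bytes)  # 4 * 1024 * 576 * 2 = 4,718,592 bytes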
@@ -616,26 +636,27 @@ class HostKVCache(abc.ABC):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         pin_memory: bool,
         device: str,
         page_size: int,
     ):
-        assert (
-            host_to_device_ratio >= 1
-        ), "The host memory should be larger than the device memory with the current protocol"
-        # todo, other ways of configuring the size
-
         self.device_pool = device_pool
-        self.
+        self.dtype = device_pool.store_dtype
         self.pin_memory = pin_memory
         self.device = device
         self.page_size = page_size
-
-
+        self.size_per_token = self.get_size_per_token()
+        if host_size > 0:
+            self.size = int(host_size * 1e9 // self.size_per_token)
+        else:
+            self.size = int(device_pool.size * host_to_device_ratio)
         # Align the host memory pool size to the page size
         self.size = self.size - (self.size % self.page_size)
-
-
+
+        assert (
+            self.size > device_pool.size
+        ), "The host memory should be larger than the device memory with the current protocol"
 
         # Verify there is enough available host memory.
         host_mem = psutil.virtual_memory()
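For orientation (editorial sketch, not part of the packaged diff): the host KV cache can now be sized by an explicit host_size given in gigabytes, falling back to the old host_to_device_ratio heuristic when host_size is zero, with the result rounded down to a page boundary. A standalone restatement of that arithmetic with made-up numbers:

def host_pool_tokens(
    host_size_gb: float,
    size_per_token: int,
    device_pool_tokens: int,
    host_to_device_ratio: float,
    page_size: int,
) -> int:
    # Illustrative paraphrase of the sizing logic in HostKVCache.__init__.
    if host_size_gb > 0:
        size = int(host_size_gb * 1e9 // size_per_token)
    else:
        size = int(device_pool_tokens * host_to_device_ratio)
    # Align the host memory pool size to the page size.
    return size - (size % page_size)

# Example: 32 GB of host memory, 160 KiB per token, page size 64 -> 195264 tokens.
print(host_pool_tokens(32, 160 * 1024, 100_000, 3.0, 64))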
@@ -787,12 +808,13 @@ class MHATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MHATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
         )
 
     def get_size_per_token(self):
@@ -861,12 +883,13 @@ class MLATokenToKVPoolHost(HostKVCache):
         self,
         device_pool: MLATokenToKVPool,
         host_to_device_ratio: float,
+        host_size: int,
         page_size: int,
         pin_memory: bool = True,
         device: str = "cpu",
     ):
         super().__init__(
-            device_pool, host_to_device_ratio, pin_memory, device, page_size
+            device_pool, host_to_device_ratio, host_size, pin_memory, device, page_size
        )
 
     def get_size_per_token(self):
sglang/srt/model_executor/cuda_graph_runner.py

@@ -35,13 +35,17 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
-from sglang.srt.utils import
-
-
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    get_device_memory_capacity,
+    is_hip,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
+_is_hip = is_hip()
+
 
 def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
@@ -129,7 +133,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
                 list(range(1, 9)) + list(range(10, 33, 2)) + list(range(40, 161, 16))
             )
 
-
+        gpu_mem = get_device_memory_capacity()
+        if gpu_mem is not None and gpu_mem > 81920:
             capture_bs += list(range(160, 257, 8))
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -140,12 +145,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         ]
 
     capture_bs = list(sorted(set(capture_bs)))
-
-
-
-
-
-    ]
+
+    assert len(capture_bs) > 0 and capture_bs[0] > 0
+    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
+    if server_args.cuda_graph_max_bs:
+        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
     compile_bs = (
         [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
         if server_args.enable_torch_compile
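To restate the filtering above in one place (editorial sketch, not the packaged code): candidate capture batch sizes are deduplicated and sorted, capped by the request pool size, and optionally capped by the cuda_graph_max_bs server argument:

def pick_capture_bs(candidates, req_pool_size, cuda_graph_max_bs=None):
    # Illustrative paraphrase of the filtering step in get_batch_sizes_to_capture.
    capture_bs = sorted(set(candidates))
    assert len(capture_bs) > 0 and capture_bs[0] > 0
    capture_bs = [bs for bs in capture_bs if bs <= req_pool_size]
    if cuda_graph_max_bs:
        capture_bs = [bs for bs in capture_bs if bs <= cuda_graph_max_bs]
    return capture_bs

# Example: a pool of 24 requests and --cuda-graph-max-bs 16.
print(pick_capture_bs(list(range(1, 9)) + list(range(10, 33, 2)), 24, 16))
# [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16]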
@@ -186,6 +190,7 @@ class CudaGraphRunner:
 
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
+
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
sglang/srt/model_executor/model_runner.py

@@ -42,6 +42,10 @@ from sglang.srt.layers.dp_attention import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
+from sglang.srt.layers.quantization.deep_gemm import (
+    _ENABLE_JIT_DEEPGEMM,
+    update_deep_gemm_config,
+)
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
@@ -73,6 +77,7 @@ from sglang.srt.utils import (
     MultiprocessingSerializer,
     enable_show_time_cost,
     get_available_gpu_memory,
+    get_bool_env_var,
     init_custom_process_group,
     is_cuda,
     is_fa3_default_architecture,
@@ -127,10 +132,7 @@ class ModelRunner:
         self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
-        self.use_mla_backend = (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not server_args.disable_mla
-        )
+        self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
 
         # Model-specific adjustment
@@ -139,18 +141,12 @@ class ModelRunner:
         if server_args.show_time_cost:
             enable_show_time_cost()
 
-        if server_args.disable_outlines_disk_cache:
-            from outlines.caching import disable_cache
-
-            disable_cache()
-
         # Global vars
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
                 "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-                "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
|
|
160
156
|
"device": server_args.device,
|
161
157
|
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
|
162
158
|
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
|
163
|
-
"enable_flashmla": server_args.enable_flashmla,
|
164
159
|
"disable_radix_cache": server_args.disable_radix_cache,
|
165
160
|
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
161
|
+
"moe_dense_tp_size": server_args.moe_dense_tp_size,
|
166
162
|
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
|
167
163
|
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
|
168
164
|
"n_share_experts_fusion": server_args.n_share_experts_fusion,
|
169
|
-
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
|
170
165
|
"disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
|
171
166
|
"use_mla_backend": self.use_mla_backend,
|
172
167
|
}
|
@@ -178,6 +173,10 @@ class ModelRunner:
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
+        # Update deep gemm configure
+        if _ENABLE_JIT_DEEPGEMM:
+            update_deep_gemm_config(gpu_id, server_args)
+
         # If it is a draft model tp_group can be different.
         self.initialize(min_per_gpu_memory)
 
@@ -229,16 +228,17 @@ class ModelRunner:
     def model_specific_adjustment(self):
         server_args = self.server_args
 
-        if server_args.
-
-
-
-
-
-
-
-
-
+        if server_args.attention_backend is None:
+            """
+            We auto select the fastest attention backend according to the current offering
+            1. Models with MHA Architecture (e.g: Llama, QWen)
+                1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
+                1.2 In other cases, we will use flashinfer if available, otherwise use triton.
+            2. Models with MLA Architecture and using FA3
+                2.1 We will use FA3 backend on hopper.
+                2.2 Otherwise, we will use triton backend.
+            """
+
             if not self.use_mla_backend:
                 if (
                     is_hopper_with_cuda_12_3()
@@ -251,9 +251,7 @@ class ModelRunner:
                         "flashinfer" if is_flashinfer_available() else "triton"
                     )
                 else:
-                    if is_hopper_with_cuda_12_3()
-                        server_args
-                    ):
+                    if is_hopper_with_cuda_12_3():
                         server_args.attention_backend = "fa3"
                     else:
                         server_args.attention_backend = "triton"
@@ -263,7 +261,12 @@ class ModelRunner:
         elif self.use_mla_backend:
             # TODO: add MLA optimization on CPU
             if server_args.device != "cpu":
-                if server_args.attention_backend in [
+                if server_args.attention_backend in [
+                    "flashinfer",
+                    "fa3",
+                    "triton",
+                    "flashmla",
+                ]:
                     logger.info(
                         f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )
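A condensed, editorial restatement of the auto-selection described in the docstring above (the boolean flags here are placeholders, and the real code also consults speculative-decoding topk and page size):

def pick_attention_backend(use_mla: bool, on_hopper: bool, flashinfer_available: bool) -> str:
    # Rough paraphrase of the backend auto-selection sketched in the docstring.
    if not use_mla:
        # MHA models: prefer FA3 on Hopper, otherwise flashinfer, otherwise triton.
        if on_hopper:
            return "fa3"
        return "flashinfer" if flashinfer_available else "triton"
    # MLA models: FA3 on Hopper, otherwise triton.
    return "fa3" if on_hopper else "triton"

print(pick_attention_backend(use_mla=False, on_hopper=True, flashinfer_available=True))  # fa3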
@@ -320,7 +323,6 @@ class ModelRunner:
             logger.info(f"DeepEP is turned on. DeepEP mode: {server_args.deepep_mode}")
 
         if not self.use_mla_backend:
-            logger.info("Disable chunked prefix cache for non-MLA backend.")
             server_args.disable_chunked_prefix_cache = True
         elif self.page_size > 1:
             logger.info("Disable chunked prefix cache when page size > 1.")
@@ -387,10 +389,16 @@ class ModelRunner:
         local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if self.tp_size > 1:
             if min_per_gpu_memory < local_gpu_memory * 0.9:
-
-
-
-
+                if get_bool_env_var("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
+                    logger.warning(
+                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
+                    )
+                else:
+                    raise ValueError(
+                        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                        f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
+                    )
 
         logger.info(
             f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
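For reference (editorial sketch, not part of the packaged code): the imbalance check above now degrades from a hard error to a warning when the SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK environment variable is set, which the following standalone snippet paraphrases:

import os

def check_tp_memory_balance(min_per_gpu_memory: float, local_gpu_memory: float) -> None:
    # Illustrative paraphrase of the check done after init_torch_distributed.
    if min_per_gpu_memory >= local_gpu_memory * 0.9:
        return
    msg = (
        "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
        f"{min_per_gpu_memory=}, {local_gpu_memory=}"
    )
    if os.environ.get("SGL_DISABLE_TP_MEMORY_INBALANCE_CHECK"):
        print("WARNING:", msg)
    else:
        raise ValueError(msg)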