sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

@@ -26,10 +26,10 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist

-from sglang.srt import debug_utils
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
     get_tp_group,
@@ -40,6 +40,19 @@ from sglang.srt.distributed import (
     set_mscclpp_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
+from sglang.srt.eplb.eplb_manager import EPLBManager
+from sglang.srt.eplb.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+    set_global_expert_distribution_recorder,
+)
+from sglang.srt.eplb.expert_location import (
+    ExpertLocationMetadata,
+    compute_initial_expert_location_metadata,
+    get_global_expert_location_metadata,
+    set_global_expert_location_metadata,
+)
+from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
@@ -55,35 +68,27 @@ from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.lora.lora_manager import LoRAManager
-from sglang.srt.managers.eplb_manager import EPLBManager
-from sglang.srt.managers.expert_distribution import (
-    ExpertDistributionRecorder,
-    get_global_expert_distribution_recorder,
-    set_global_expert_distribution_recorder,
-)
-from sglang.srt.managers.expert_location import (
-    ExpertLocationMetadata,
-    compute_initial_expert_location_metadata,
-    get_global_expert_location_metadata,
-    set_global_expert_location_metadata,
-)
 from sglang.srt.managers.schedule_batch import (
     GLOBAL_SERVER_ARGS_KEYS,
     global_server_args_dict,
 )
 from sglang.srt.mem_cache.allocator import (
+    AscendPagedTokenToKVPoolAllocator,
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
+    SWATokenToKVPoolAllocator,
     TokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.memory_pool import (
+    AscendMLAPagedTokenToKVPool,
+    AscendTokenToKVPool,
     DoubleSparseTokenToKVPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
+    SWAKVPool,
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-from sglang.srt.model_executor.expert_location_updater import ExpertLocationUpdater
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
@@ -101,6 +106,7 @@ from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,
+    get_cpu_ids_by_node,
     init_custom_process_group,
     is_cuda,
     is_fa3_default_architecture,
@@ -108,6 +114,7 @@ from sglang.srt.utils import (
     is_hip,
     is_hopper_with_cuda_12_3,
     is_no_spec_infer_or_topk_one,
+    is_npu,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
@@ -115,6 +122,7 @@ from sglang.srt.utils import (
 )

 _is_hip = is_hip()
+_is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()

 # Use a small KV cache pool size for tests in CI
@@ -158,7 +166,6 @@ class ModelRunner:
         token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
-        self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.device = server_args.device
         self.gpu_id = gpu_id
@@ -171,6 +178,7 @@ class ModelRunner:
         self.dp_size = server_args.dp_size
         self.pp_rank = pp_rank
         self.pp_size = pp_size
+        self.model_config = model_config
         self.dist_port = nccl_port
         self.server_args = server_args
         self.is_draft_worker = is_draft_worker
@@ -185,6 +193,7 @@ class ModelRunner:
         self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
+        self.is_hybrid = model_config.is_hybrid
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size

@@ -209,6 +218,10 @@ class ModelRunner:
         # CPU offload
         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

+        # Init OpenMP threads binding for CPU
+        if self.device == "cpu":
+            self.init_threads_binding()
+
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()

@@ -223,6 +236,7 @@ class ModelRunner:
         self.support_pp = (
             "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
         )
+        self._model_update_group = {}

     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
@@ -300,11 +314,31 @@ class ModelRunner:
             self.init_cuda_graphs()
         else:
             self.cuda_graph_runner = None
+            self.cuda_graph_mem_usage = 0
         self.init_attention_backend()

         # auxiliary hidden capture mode. TODO: expose this to server args?
         if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
-
+            # load draft config
+            draft_model_config = ModelConfig.from_server_args(
+                server_args,
+                model_path=(server_args.speculative_draft_model_path),
+                is_draft_model=True,
+            )
+
+            try:
+                # get the aux layer from draft model config
+                eagle_config = getattr(
+                    draft_model_config.hf_config, "eagle_config", None
+                )
+                eagle_aux_hidden_state_layer_ids = eagle_config[
+                    "eagle_aux_hidden_state_layer_ids"
+                ]
+            except:
+                # if there is no aux layer, set to None
+                eagle_aux_hidden_state_layer_ids = None
+
+            self.model.set_eagle3_layers_to_capture(eagle_aux_hidden_state_layer_ids)

     def model_specific_adjustment(self):
         server_args = self.server_args
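Note (not part of the package diff): the EAGLE3 branch above reads an optional `eagle_config` attribute from the draft model's Hugging Face config and takes its `eagle_aux_hidden_state_layer_ids` entry, falling back to `None` when the attribute or key is missing. A minimal sketch of that lookup follows; the config object and the layer ids are invented purely for illustration.

```python
# Minimal sketch of the aux-layer lookup above; the key names come from the
# diff, while the config object and the ids [2, 20, 38] are made up.
from types import SimpleNamespace

hf_config = SimpleNamespace(
    eagle_config={"eagle_aux_hidden_state_layer_ids": [2, 20, 38]}
)

try:
    eagle_config = getattr(hf_config, "eagle_config", None)
    aux_layer_ids = eagle_config["eagle_aux_hidden_state_layer_ids"]
except (TypeError, KeyError):
    # No aux layers declared by the draft model.
    aux_layer_ids = None

print(aux_layer_ids)  # [2, 20, 38]
```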
@@ -342,6 +376,8 @@ class ModelRunner:
                     server_args.attention_backend = "fa3"
                 elif _is_hip:
                     server_args.attention_backend = "aiter"
+                elif _is_npu:
+                    server_args.attention_backend = "ascend"
                 else:
                     server_args.attention_backend = (
                         "flashinfer" if is_flashinfer_available() else "triton"
@@ -361,6 +397,8 @@ class ModelRunner:
                         server_args.attention_backend = "aiter"
                     else:
                         server_args.attention_backend = "triton"
+                elif _is_npu:
+                    server_args.attention_backend = "ascend"
                 else:
                     server_args.attention_backend = "triton"
             logger.info(
@@ -375,6 +413,7 @@ class ModelRunner:
                 "triton",
                 "flashmla",
                 "cutlass_mla",
+                "ascend",
             ]:
                 logger.info(
                     f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
@@ -412,11 +451,6 @@ class ModelRunner:
             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

         if self.is_multimodal:
-            self.mem_fraction_static *= 0.90
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
             if not self.is_multimodal_chunked_prefill_supported:
                 server_args.chunked_prefill_size = -1
                 logger.info(
@@ -437,6 +471,10 @@ class ModelRunner:
             if self.model_config.context_len > 8192:
                 self.mem_fraction_static *= 0.85

+        if self.is_hybrid and not server_args.disable_radix_cache:
+            logger.info("Automatically disable radix cache for hybrid cache.")
+            server_args.disable_radix_cache = True
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")

@@ -471,6 +509,19 @@ class ModelRunner:
         set_mscclpp_all_reduce(self.server_args.enable_mscclpp)

         if not self.is_draft_worker:
+            if self.device == "cpu":
+                if _is_cpu_amx_available:
+                    # Bind OpenMP threads to CPU cores
+                    torch.ops.sgl_kernel.init_cpu_threads_env(self.local_omp_cpuid)
+
+                    # Set local size to hint SGLang to use shared memory based AllReduce
+                    os.environ["LOCAL_SIZE"] = str(self.tp_size)
+                    torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+                else:
+                    logger.warning(
+                        "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
+                    )
+
             # Only initialize the distributed environment on the target model worker.
             init_distributed_environment(
                 backend=backend,
@@ -549,6 +600,10 @@ class ModelRunner:
             download_dir=self.server_args.download_dir,
             model_loader_extra_config=self.server_args.model_loader_extra_config,
         )
+        if self.device == "cpu":
+            self.model_config = adjust_config_with_unaligned_cpu_tp(
+                self.model_config, self.load_config, self.tp_size
+            )
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()

@@ -598,12 +653,13 @@ class ModelRunner:
         self.dtype = self.model_config.dtype

         after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
+        self.weight_load_mem_usage = before_avail_memory - after_avail_memory
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
             f"avail mem={after_avail_memory:.2f} GB, "
-            f"mem usage={
+            f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )

         # Handle the case where some ranks do not finish loading.
@@ -718,7 +774,7 @@ class ModelRunner:
         )

         try:
-            self._model_update_group = init_custom_process_group(
+            self._model_update_group[group_name] = init_custom_process_group(
                 backend=backend,
                 init_method=f"tcp://{master_address}:{master_port}",
                 world_size=world_size,
@@ -731,7 +787,7 @@ class ModelRunner:
             logger.error(message)
             return False, message

-    def update_weights_from_distributed(self,
+    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
        """
        Update specific parameter in the model weights online
        through `_model_update_group` process group.
@@ -741,19 +797,34 @@ class ModelRunner:
            dtype: the data type of the parameter to be updated.
            shape: the shape of the parameter to be updated.
        """
-        target_dtype = (
-            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
-        )

-        assert (
-            self._model_update_group
-
+        assert group_name in self._model_update_group, (
+            f"Group {group_name} not in {list(self._model_update_group.keys())}. "
+            "Please call `init_weights_update_group` first."
+        )

        try:
-            weights =
-
-
-
+            weights = []
+            handles = []
+            for name, dtype, shape in zip(names, dtypes, shapes):
+                target_dtype = (
+                    dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+                )
+                weight = torch.empty(shape, dtype=target_dtype, device=self.device)
+                handles.append(
+                    torch.distributed.broadcast(
+                        weight,
+                        src=0,
+                        group=self._model_update_group[group_name],
+                        async_op=True,
+                    )
+                )
+                weights.append((name, weight))
+            for handle in handles:
+                handle.wait()
+
+            self.model.load_weights(weights)
+            return True, f"Succeeded to update parameter online."

        except Exception as e:
            error_msg = (
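Note (not part of the package diff): on the receiving side above, each named parameter is materialized as an empty tensor and filled by an asynchronous broadcast from rank 0 of the named process group. The sketch below shows what a matching sender (for example a trainer that previously joined the same group through `init_weights_update_group`) might look like; the `group` handle and `named_tensors` list are placeholders supplied by the caller, not part of the sglang API.

```python
# Hedged sketch of a rank-0 sender that pairs with update_weights_from_distributed.
from typing import List, Tuple

import torch
import torch.distributed as dist


def push_weights(
    named_tensors: List[Tuple[str, torch.Tensor]],
    group: dist.ProcessGroup,
) -> None:
    # The order must match the (names, dtypes, shapes) lists passed to
    # update_weights_from_distributed on the SGLang side.
    handles = [
        dist.broadcast(tensor, src=0, group=group, async_op=True)
        for _, tensor in named_tensors
    ]
    for handle in handles:
        handle.wait()
```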
@@ -812,8 +883,47 @@ class ModelRunner:
             tp_size=self.tp_size,
             tp_rank=self.tp_rank,
         )
-        self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
-
+        result = self.lora_manager.load_lora_adapters(self.server_args.lora_paths)
+        if result.success:
+            logger.info(
+                f"LoRA manager ready. Loaded LoRA adapters: {', '.join(result.loaded_adapters)}"
+            )
+        else:
+            raise RuntimeError(f"Failed to load LoRA adapters: {result.error_message}")
+
+    def load_lora_adapter(self, lora_name: str, lora_path: str):
+        """Load a new lora adapter from disk or huggingface."""
+
+        logger.info(
+            f"LoRA adapter loading starts: name={lora_name}, path={lora_path}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        result = self.lora_manager.load_lora_adapter(lora_name, lora_path)
+
+        logger.info(
+            f"LoRA adapter loading completes: name={lora_name}, path={lora_path}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        return result
+
+    def unload_lora_adapter(self, lora_name: str):
+        """Unload a lora adapter that was previously loaded during initialization or dynamic loading."""
+
+        logger.info(
+            f"LoRA adapter unloading starts: name={lora_name}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        result = self.lora_manager.unload_lora_adapter(lora_name)
+
+        logger.info(
+            f"LoRA adapter unloading completes: name={lora_name}. "
+            f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
+
+        return result

     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
@@ -852,6 +962,40 @@ class ModelRunner:
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

+    def set_num_token_hybrid(self):
+        if (
+            "Llama4ForConditionalGeneration"
+            in self.model_config.hf_config.architectures
+        ):
+            temp_ratio = (
+                (1 - self.is_hybrid)
+                + self.is_hybrid
+                * self.attention_chunk_size
+                / self.model_config.context_len
+            )
+            self.swa_max_total_num_tokens = (
+                4 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
+            )
+            self.full_max_total_num_tokens = (
+                4 * self.max_total_num_tokens
+                - 12 * self.max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
+            )
+            self.swa_max_total_num_tokens = int(
+                self.swa_max_total_num_tokens
+                // self.server_args.page_size
+                * self.server_args.page_size
+            )
+            self.full_max_total_num_tokens = int(
+                self.full_max_total_num_tokens
+                // self.server_args.page_size
+                * self.server_args.page_size
+            )
+            self.max_total_num_tokens = self.full_max_total_num_tokens
+        else:
+            raise ValueError(
+                f"Unsupported model for hybrid cache: {self.model_config.hf_config.architectures}."
+            )
+
     def init_memory_pool(
         self,
         total_gpu_memory: int,
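Note (not part of the package diff): ignoring the floor divisions, `set_num_token_hybrid` splits the token budget T so that full = 4*T / (3*temp_ratio + 1) and swa = temp_ratio * full, which gives full + 3*swa = 4*T; with temp_ratio = attention_chunk_size / context_len, the sliding-window pool is sized in proportion to the chunk length. A small worked example follows, using made-up numbers rather than sglang defaults.

```python
# Worked example of the hybrid KV budget split above; all numbers are
# illustrative, not sglang defaults.
max_total_num_tokens = 1_000_000
attention_chunk_size = 8_192
context_len = 131_072
page_size = 64
is_hybrid = True

temp_ratio = (1 - is_hybrid) + is_hybrid * attention_chunk_size / context_len
swa = 4 * max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
full = (
    4 * max_total_num_tokens
    - 12 * max_total_num_tokens * temp_ratio // (3 * temp_ratio + 1)
)
# Round both pools down to a whole number of pages.
swa = int(swa // page_size * page_size)
full = int(full // page_size * page_size)

print(temp_ratio)  # 0.0625
print(swa, full)   # 210496 3368384; full + 3*swa is roughly 4x the original budget
```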
@@ -929,6 +1073,10 @@ class ModelRunner:
                 * self.server_args.page_size
             )

+        # create token size for hybrid cache
+        if self.is_hybrid:
+            self.set_num_token_hybrid()
+
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enough memory. Please try to increase --mem-fraction-static."
@@ -959,8 +1107,19 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker

-        if self.use_mla_backend:
-            self.token_to_kv_pool =
+        if self.server_args.attention_backend == "ascend" and not self.use_mla_backend:
+            self.token_to_kv_pool = AscendTokenToKVPool(
+                self.max_total_num_tokens,
+                page_size=self.page_size,
+                dtype=self.kv_cache_dtype,
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+            )
+        elif self.server_args.attention_backend == "ascend" and self.use_mla_backend:
+            self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
@@ -976,22 +1135,25 @@ class ModelRunner:
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
             )
-        elif self.
-            self.token_to_kv_pool =
+        elif self.use_mla_backend:
+            self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
-
-
-                layer_num=
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=(
+                    self.model_config.num_hidden_layers
+                    if not self.is_draft_worker
+                    else self.model_config.hf_config.num_nextn_predict_layers
+                ),  # PP is not compatible with mla backend
                 device=self.device,
-                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
             )
-
-            self.token_to_kv_pool =
+        elif self.server_args.enable_double_sparsity:
+            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
@@ -999,27 +1161,76 @@ class ModelRunner:
                 head_dim=self.model_config.head_dim,
                 layer_num=self.num_effective_layers,
                 device=self.device,
+                heavy_channel_num=self.server_args.ds_heavy_channel_num,
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
             )
-
-
-
-
-            self.
+        else:
+            if self.is_hybrid:
+                self.token_to_kv_pool = SWAKVPool(
+                    size=self.full_max_total_num_tokens,
+                    size_swa=self.swa_max_total_num_tokens,
                     dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    swa_attention_layer_ids=self.model_config.swa_attention_layer_ids,
+                    full_attention_layer_ids=self.model_config.full_attention_layer_ids,
+                    enable_kvcache_transpose=False,
                     device=self.device,
-                kvcache=self.token_to_kv_pool,
                 )
             else:
-            self.
+                self.token_to_kv_pool = MHATokenToKVPool(
                     self.max_total_num_tokens,
                     page_size=self.page_size,
                     dtype=self.kv_cache_dtype,
+                    head_num=self.model_config.get_num_kv_heads(
+                        get_attention_tp_size()
+                    ),
+                    head_dim=self.model_config.head_dim,
+                    layer_num=self.num_effective_layers,
                     device=self.device,
-
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    start_layer=self.start_layer,
+                    end_layer=self.end_layer,
                 )
+
+        if self.token_to_kv_pool_allocator is None:
+            if self.page_size == 1:
+                if self.is_hybrid:
+                    self.token_to_kv_pool_allocator = SWATokenToKVPoolAllocator(
+                        self.full_max_total_num_tokens,
+                        self.swa_max_total_num_tokens,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                    )
+                else:
+                    self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                        self.max_total_num_tokens,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                    )
+            else:
+                if _is_npu:
+                    self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
+                        self.max_total_num_tokens,
+                        page_size=self.page_size,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                    )
+                else:
+                    self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
+                        self.max_total_num_tokens,
+                        page_size=self.page_size,
+                        dtype=self.kv_cache_dtype,
+                        device=self.device,
+                        kvcache=self.token_to_kv_pool,
+                    )
         else:
             assert self.is_draft_worker

@@ -1039,7 +1250,7 @@ class ModelRunner:

     def init_attention_backend(self):
         """Init attention kernel backend."""
-        if self.server_args.enable_two_batch_overlap:
+        if self.server_args.enable_two_batch_overlap and not self.is_draft_worker:
             self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
         else:
             self.attn_backend = self._get_attention_backend()
@@ -1066,6 +1277,10 @@ class ModelRunner:
             from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

             return AiterAttnBackend(self)
+        elif self.server_args.attention_backend == "ascend":
+            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+            return AscendAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert not self.model_config.is_encoder_decoder, (
                 "Cross attention is not supported in the triton attention backend. "
@@ -1141,6 +1356,7 @@ class ModelRunner:
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
+        self.cuda_graph_mem_usage = 0

         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
@@ -1156,11 +1372,36 @@ class ModelRunner:
         )
         self.cuda_graph_runner = CudaGraphRunner(self)
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
+        self.cuda_graph_mem_usage = before_mem - after_mem
         logger.info(
             f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={
+            f"mem usage={self.cuda_graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
         )

+    def init_threads_binding(self):
+        omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
+        if omp_cpuids == "all":
+            cpu_ids_by_node = get_cpu_ids_by_node()
+            n_numa_node = len(cpu_ids_by_node)
+
+            assert self.tp_size <= n_numa_node, (
+                f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
+                f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
+                f"If you need tp_size to be larger than number of numa node, please set the CPU cores for each tp rank via SGLANG_CPU_OMP_THREADS_BIND explicitly. "
+                f"For example, on a machine with 2 numa nodes, where core 0-31 are on numa node 0 and core 32-63 are on numa node 1, "
+                f"it is suggested to use -tp 2 and bind tp rank 0 to core 0-31 and tp rank 1 to core 32-63. "
+                f"This is the default behavior if SGLANG_CPU_OMP_THREADS_BIND is not set and it is the same as setting SGLANG_CPU_OMP_THREADS_BIND=0-31|32-63. "
+                f"If you do need tp_size to be larger than the number of numa nodes, you could set SGLANG_CPU_OMP_THREADS_BIND explicitly for example SGLANG_CPU_OMP_THREADS_BIND=0-15|16-31|32-47|48-63 and run with -tp 4. "
+                f"If you don't want each tp rank to use all the cores on one numa node, you could set for example SGLANG_CPU_OMP_THREADS_BIND=0-15|32-47 and run with -tp 2."
+            )
+            if self.tp_size < n_numa_node:
+                logger.warning(
+                    f"Detected the current machine has {n_numa_node} numa nodes available, but tp_size is set to {self.tp_size}, so only {self.tp_size} numa nodes are used."
+                )
+            self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
+        else:
+            self.local_omp_cpuid = omp_cpuids.split("|")[self.tp_rank]
+
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
         from sglang.srt.model_parallel import tensor_parallel
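Note (not part of the package diff): `init_threads_binding` above either derives one core range per tensor-parallel rank from the NUMA topology, or takes the ranges verbatim from `SGLANG_CPU_OMP_THREADS_BIND`, a `|`-separated list with one entry per rank. A short illustration of that mapping follows; the binding string and tp_size are example values.

```python
# Illustration of how a SGLANG_CPU_OMP_THREADS_BIND value is split per tp rank,
# mirroring the omp_cpuids.split("|")[self.tp_rank] branch above.
import os

os.environ["SGLANG_CPU_OMP_THREADS_BIND"] = "0-31|32-63"
tp_size = 2

per_rank = os.environ["SGLANG_CPU_OMP_THREADS_BIND"].split("|")
assert len(per_rank) >= tp_size, "need one core range per tp rank"
for tp_rank in range(tp_size):
    print(f"tp rank {tp_rank} -> OpenMP threads bound to cores {per_rank[tp_rank]}")
# tp rank 0 -> OpenMP threads bound to cores 0-31
# tp rank 1 -> OpenMP threads bound to cores 32-63
```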
sglang/srt/model_loader/loader.py

@@ -124,6 +124,9 @@ def _get_quantization_config(
         quant_config = get_quant_config(
             model_config, load_config, packed_modules_mapping
         )
+        # (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
+        if quant_config is None:
+            return None
         major, minor = get_device_capability()

         if major is not None and minor is not None:
@@ -534,6 +537,12 @@ class DummyModelLoader(BaseModelLoader):
         model_config: ModelConfig,
         device_config: DeviceConfig,
     ) -> nn.Module:
+
+        if get_bool_env_var("SGL_CPU_QUANTIZATION"):
+            return load_model_with_cpu_quantization(
+                self, model_config=model_config, device_config=device_config
+            )
+
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
                 model = _initialize_model(
@@ -1464,6 +1473,38 @@ class RemoteModelLoader(BaseModelLoader):
         return model.eval()


+def load_model_with_cpu_quantization(
+    self,
+    *,
+    model_config: ModelConfig,
+    device_config: DeviceConfig,
+) -> nn.Module:
+    target_device = torch.device(device_config.device)
+    with set_default_torch_dtype(model_config.dtype):
+        model = _initialize_model(
+            model_config,
+            self.load_config,
+        )
+
+        if not isinstance(self, DummyModelLoader):
+            model.load_weights(self._get_all_weights(model_config, model))
+
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                # When quant methods need to process weights after loading
+                # (for repacking, quantizing, etc), they expect parameters
+                # to be on the global target device. This scope is for the
+                # case where cpu offloading is used, where we will move the
+                # parameters onto device for processing and back off after.
+                with device_loading_context(module, target_device):
+                    quant_method.process_weights_after_loading(module)
+
+        model.to(target_device)
+
+    return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
