sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -16,6 +16,7 @@
 from __future__ import annotations

 import bisect
+import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable

@@ -81,7 +82,9 @@ def patch_model(
             # tp_group.ca_comm = None
             yield torch.compile(
                 torch.no_grad()(model.forward),
-                mode=
+                mode=os.environ.get(
+                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
+                ),
                 dynamic=False,
             )
         else:
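The torch.compile mode used when compiling the model is now read from the SGLANG_TORCH_COMPILE_MODE environment variable, with the previous hard-coded default kept as the fallback. A minimal sketch of overriding it; "max-autotune" is a standard torch.compile mode, chosen here purely as an example:

import os

# Must be set before the server/engine process compiles the model.
# The default used by sglang is "max-autotune-no-cudagraphs", as shown in the hunk above.
os.environ["SGLANG_TORCH_COMPILE_MODE"] = "max-autotune"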
@@ -117,24 +120,21 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         else:
             capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
     else:
-
+        # Since speculative decoding requires more cuda graph memory, we
+        # capture less.
+        capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]

-
-
+    if _is_hip:
+        capture_bs += [i * 8 for i in range(21, 33)]

     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.
-        capture_bs
-
-
-            capture_bs
-            + [model_runner.req_to_token_pool.size - 1]
-            + [model_runner.req_to_token_pool.size]
-        )
-        )
-        )
+        capture_bs += [model_runner.req_to_token_pool.size - 1] + [
+            model_runner.req_to_token_pool.size
+        ]

+    capture_bs = list(sorted(set(capture_bs)))
     capture_bs = [
         bs
         for bs in capture_bs
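For reference, the capture list built by the speculative-decoding branch above works out to the following batch sizes (a quick check that assumes req_to_token_pool.size is large enough that no extra values get appended):

capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
capture_bs = list(sorted(set(capture_bs)))
print(capture_bs)
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 64, 96, 128, 160]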
@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
@@ -220,7 +221,19 @@ class CudaGraphRunner:
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)

         # Speculative_inference
-        if
+        if (
+            model_runner.spec_algorithm.is_eagle3()
+            and not model_runner.is_draft_worker
+        ):
+            self.hidden_states = torch.zeros(
+                (
+                    self.max_num_token,
+                    3 * self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            self.model_runner.model.set_eagle3_layers_to_capture()
+        elif model_runner.spec_algorithm.is_eagle():
             self.hidden_states = torch.zeros(
                 (self.max_num_token, self.model_runner.model_config.hidden_size),
                 dtype=self.model_runner.dtype,
@@ -233,8 +246,8 @@ class CudaGraphRunner:
             )
         else:
             self.encoder_lens = None
-
-
+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_bs * self.dp_size * self.num_tokens_per_bs,
@@ -276,7 +289,7 @@ class CudaGraphRunner:
         self.model_runner.token_to_kv_pool.capture_mode = False

     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             total_global_tokens = sum(forward_batch.global_num_tokens_cpu)

             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
@@ -357,7 +370,7 @@ class CudaGraphRunner:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]

-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
@@ -459,7 +472,7 @@ class CudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs

         # Pad
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             index = bisect.bisect_left(
                 self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
@@ -485,7 +498,7 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)

         if hasattr(forward_batch.spec_info, "hidden_states"):
@@ -508,7 +521,9 @@ class CudaGraphRunner:
         self.raw_num_token = raw_num_token
         self.bs = bs

-    def replay(
+    def replay(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
             self.replay_prepare(forward_batch)
         else:
sglang/srt/model_executor/forward_batch_info.py

@@ -33,6 +33,7 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional, Union

+import numpy as np
 import torch
 import triton
 import triton.language as tl
@@ -42,7 +43,7 @@ from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import
+    from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
     from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -175,7 +176,7 @@ class ForwardBatch:
     extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

     # For multimodal
-
+    mm_inputs: Optional[List[MultimodalInputs]] = None

     # Encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -241,7 +242,7 @@ class ForwardBatch:
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             out_cache_loc=batch.out_cache_loc,
-
+            mm_inputs=batch.multimodal_inputs,
             encoder_cached=batch.encoder_cached,
             encoder_lens=batch.encoder_lens,
             encoder_lens_cpu=batch.encoder_lens_cpu,
@@ -331,6 +332,53 @@ class ForwardBatch:

         return ret

+    def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
+        """
+        Merge all image inputs in the batch into a single MultiModalInputs object.
+
+        Returns:
+            if none, current batch contains no image input
+
+        """
+        if not self.mm_inputs or all(x is None for x in self.mm_inputs):
+            return None
+
+        # Filter out None values
+        valid_inputs = [x for x in self.mm_inputs if x is not None]
+
+        # Start with the first valid image input
+        merged = valid_inputs[0]
+
+        # Merge remaining inputs
+        for mm_input in valid_inputs[1:]:
+            merged.merge(mm_input)
+
+        if isinstance(merged.pixel_values, np.ndarray):
+            merged.pixel_values = torch.from_numpy(merged.pixel_values)
+        if isinstance(merged.audio_features, np.ndarray):
+            merged.audio_features = torch.from_numpy(merged.audio_features)
+
+        return merged
+
+    def contains_image_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_image_inputs()
+            for mm_input in self.mm_inputs
+        )
+
+    def contains_audio_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_audio_inputs()
+            for mm_input in self.mm_inputs
+        )
+
+    def contains_mm_inputs(self) -> bool:
+        return self.contains_audio_inputs() or self.contains_image_inputs()
+
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
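A short illustrative sketch of how the new helpers above are meant to be consumed; `forward_batch` here stands for any already-constructed ForwardBatch and is an assumption of this sketch, not part of the diff:

# Illustrative only: `forward_batch` is an existing ForwardBatch instance.
if forward_batch.contains_mm_inputs():
    # Fold the per-request MultimodalInputs entries into one object; text-only
    # requests contribute a None entry and are skipped.
    merged = forward_batch.merge_mm_inputs()
    pixel_values = merged.pixel_values  # a torch.Tensor after the np.ndarray conversion above
else:
    merged = None  # text-only batch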
@@ -341,8 +389,8 @@ class ForwardBatch:
             for i, _ in enumerate(mrope_positions_list):
                 mrope_position_delta = (
                     0
-                    if batch.
-                    else batch.
+                    if batch.multimodal_inputs[i] is None
+                    else batch.multimodal_inputs[i].mrope_position_delta
                 )
                 mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
                     mrope_position_delta,
@@ -351,13 +399,13 @@ class ForwardBatch:
                 )
         elif self.forward_mode.is_extend():
             extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i,
+            for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
                 extend_start_loc, extend_seq_len, extend_prefix_len = (
                     extend_start_loc_cpu[i],
                     batch.extend_seq_lens[i],
                     batch.extend_prefix_lens[i],
                 )
-                if
+                if multimodal_inputs is None:
                     # text only
                     mrope_positions = [
                         [
@@ -374,16 +422,25 @@ class ForwardBatch:
                             input_tokens=self.input_ids[
                                 extend_start_loc : extend_start_loc + extend_seq_len
                             ],
-                            image_grid_thw=
+                            image_grid_thw=multimodal_inputs.image_grid_thws,
+                            video_grid_thw=multimodal_inputs.video_grid_thws,
+                            image_token_id=multimodal_inputs.im_token_id,
+                            video_token_id=multimodal_inputs.video_token_id,
                             vision_start_token_id=hf_config.vision_start_token_id,
+                            vision_end_token_id=hf_config.vision_end_token_id,
                             spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                             context_len=0,
+                            seq_len=len(self.input_ids),
+                            second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
+                            tokens_per_second=hf_config.vision_config.tokens_per_second,
                         )
                     )
-                    batch.
+                    batch.multimodal_inputs[i].mrope_position_delta = (
+                        mrope_position_delta
+                    )
                 mrope_positions_list[i] = mrope_positions

-        self.mrope_positions = torch.
+        self.mrope_positions = torch.cat(
             [torch.tensor(pos, device=device) for pos in mrope_positions_list],
             axis=1,
         )
@@ -449,7 +506,7 @@ def compute_position_kernel(
 def compute_position_torch(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
 ):
-    positions = torch.
+    positions = torch.cat(
         [
             torch.arange(
                 prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
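compute_position_torch builds one position range per request and concatenates them; a tiny self-contained example of that pattern, with arbitrarily chosen lengths:

import torch

# Two requests: prefix lengths 3 and 0, extend lengths 2 and 4.
extend_prefix_lens = torch.tensor([3, 0])
extend_seq_lens = torch.tensor([2, 4])

positions = torch.cat(
    [
        torch.arange(int(p), int(p) + int(e))
        for p, e in zip(extend_prefix_lens, extend_seq_lens)
    ]
)
print(positions)  # tensor([3, 4, 0, 1, 2, 3])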
sglang/srt/model_executor/model_runner.py

@@ -64,6 +64,7 @@ from sglang.srt.model_loader.loader import (
 )
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -145,10 +146,12 @@ class ModelRunner:
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
+                "enable_deepep_moe": server_args.enable_deepep_moe,
                 "device": server_args.device,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
@@ -187,9 +190,6 @@ class ModelRunner:
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
             self.apply_torch_tp()
-            self.torch_tp_applied = True
-        else:
-            self.torch_tp_applied = False

         # Init lora
         if server_args.lora_paths is not None:
@@ -209,6 +209,10 @@ class ModelRunner:
         self.cuda_graph_runner = None
         self.init_attention_backend()

+        # auxiliary hidden capture mode. TODO: expose this to server args?
+        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
+            self.model.set_eagle3_layers_to_capture()
+
     def model_specific_adjustment(self):
         server_args = self.server_args

@@ -223,6 +227,13 @@ class ModelRunner:
                 "MLA optimization is turned on. Use flashinfer mla backend."
             )
             server_args.attention_backend = "flashinfer_mla"
+        elif server_args.enable_flashmla:
+            logger.info("MLA optimization is turned on. Use flashmla decode.")
+            server_args.attention_backend = "flashmla"
+        elif server_args.attention_backend == "fa3":
+            logger.info(
+                f"MLA optimization is turned on. Use flash attention 3 backend."
+            )
         else:
             logger.info("MLA optimization is turned on. Use triton backend.")
             server_args.attention_backend = "triton"
@@ -254,18 +265,38 @@ class ModelRunner:

         if self.model_config.hf_config.architectures == [
             "Qwen2VLForConditionalGeneration"
+        ] or self.model_config.hf_config.architectures == [
+            "Qwen2_5_VLForConditionalGeneration"
         ]:
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+            # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
             logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for
+                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
             )
             server_args.chunked_prefill_size = -1
             server_args.disable_radix_cache = True

+        if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
+            # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
+            )
+            server_args.chunked_prefill_size = -1
+            server_args.disable_radix_cache = True
+
+        if server_args.enable_deepep_moe:
+            logger.info("DeepEP is turned on.")
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")

-
+        try:
+            torch.get_device_module(self.device).set_device(self.gpu_id)
+        except Exception:
+            logger.warning(
+                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
+            )
+            raise
+
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
@@ -606,6 +637,8 @@ class ModelRunner:
             load_config=self.load_config,
             dtype=self.dtype,
             lora_backend=self.server_args.lora_backend,
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
         )
         logger.info("LoRA manager ready.")

@@ -840,6 +873,23 @@ class ModelRunner:
             )

             self.attn_backend = FlashInferMLAAttnBackend(self)
+        elif self.server_args.attention_backend == "flashmla":
+            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
+
+            self.attn_backend = FlashMLABackend(self)
+        elif self.server_args.attention_backend == "fa3":
+            assert torch.cuda.get_device_capability()[0] >= 9, (
+                "FlashAttention v3 Backend requires SM>=90. "
+                "Please use `--attention-backend flashinfer`."
+            )
+            logger.warning(
+                "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
+            )
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionBackend,
+            )
+
+            self.attn_backend = FlashAttentionBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
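The new backend branches above are reached through the existing --attention-backend server argument (plus the new enable_flashmla flag). A minimal launch sketch; the model path is only a placeholder and is not taken from this diff, and FA3 additionally requires an SM90+ GPU per the assert above:

import subprocess

subprocess.run(
    [
        "python", "-m", "sglang.launch_server",
        "--model-path", "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        "--attention-backend", "fa3",  # new FlashAttention v3 backend (beta)
    ],
    check=True,
)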
@@ -1009,6 +1059,22 @@ class ModelRunner:
             return False
         return rope_scaling.get("type", None) == "mrope"

+    def save_remote_model(self, url: str):
+        from sglang.srt.model_loader.loader import RemoteModelLoader
+
+        logger.info(f"Saving model to {url}")
+        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)
+
+    def save_sharded_model(
+        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
+    ):
+        from sglang.srt.model_loader.loader import ShardedStateLoader
+
+        logger.info(
+            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
+        )
+        ShardedStateLoader.save_model(self.model, path, pattern, max_size)
+

 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
     params_dict = dict(model.named_parameters())
@@ -1018,8 +1084,9 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):

 def _unwrap_tensor(tensor, tp_rank):
     if isinstance(tensor, LocalSerializedTensor):
-
-
+        monkey_patch_torch_reductions()
+        tensor = tensor.get(tp_rank)
+    return tensor.to(torch.cuda.current_device())


 @dataclass
sglang/srt/model_loader/loader.py

@@ -9,11 +9,11 @@ import json
 import logging
 import math
 import os
+import time
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast

-import gguf
 import huggingface_hub
 import numpy as np
 import torch
@@ -25,6 +25,12 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.connector import (
+    ConnectorType,
+    create_remote_connector,
+    get_connector_type,
+)
+from sglang.srt.connector.utils import parse_model_name
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -46,6 +52,7 @@ from sglang.srt.model_loader.weight_utils import (
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
+    set_runai_streamer_env,
 )
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -194,7 +201,7 @@ class DefaultModelLoader(BaseModelLoader):
     def _maybe_download_from_modelscope(
         self, model: str, revision: Optional[str]
     ) -> Optional[str]:
-        """Download model from ModelScope hub if
+        """Download model from ModelScope hub if SGLANG_USE_MODELSCOPE is True.

         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
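The corrected docstring above names the SGLANG_USE_MODELSCOPE environment variable; a minimal sketch of enabling the ModelScope download path (the truthy value shown is an assumption about how the flag is parsed, not taken from this diff):

import os

# Enable resolving model weights from the ModelScope hub instead of Hugging Face.
# Must be set before the loader resolves the model path.
os.environ["SGLANG_USE_MODELSCOPE"] = "true"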
@@ -490,7 +497,7 @@ class ShardedStateLoader(BaseModelLoader):
     Model loader that directly loads each worker's model state dict, which
     enables a fast load path for large tensor-parallel models where each worker
     only needs to read its own shard rather than the entire checkpoint. See
-    `examples/save_sharded_state.py` for creating a sharded checkpoint.
+    `examples/runtime/engine/save_sharded_state.py` for creating a sharded checkpoint.
     """

     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
@@ -1147,6 +1154,17 @@ class GGUFModelLoader(BaseModelLoader):
         See "Standardized tensor names" in
         https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
         """
+
+        # only load the gguf module when needed
+        try:
+            import gguf
+
+            # FIXME: add version check for gguf
+        except ImportError as err:
+            raise ImportError(
+                "Please install gguf via `pip install gguf` to use gguf quantizer."
+            ) from err
+
         config = model_config.hf_config
         model_type = config.model_type
         # hack: ggufs have a different name than transformers
@@ -1204,6 +1222,153 @@ class GGUFModelLoader(BaseModelLoader):
         return model


+class RemoteModelLoader(BaseModelLoader):
+    """Model loader that can load Tensors from remote database."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # TODO @DellCurry: move to s3 connector only
+        set_runai_streamer_env(load_config)
+
+    def _get_weights_iterator_kv(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.KV
+        rank = get_tensor_model_parallel_rank()
+        return client.weight_iterator(rank)
+
+    def _get_weights_iterator_fs(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.FS
+        return client.weight_iterator()
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass
+
+    @staticmethod
+    def save_model(
+        model: torch.nn.Module,
+        model_path: str,
+        url: str,
+    ) -> None:
+        with create_remote_connector(url) as client:
+            assert get_connector_type(client) == ConnectorType.KV
+            model_name = parse_model_name(url)
+            rank = get_tensor_model_parallel_rank()
+            state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+            for key, tensor in state_dict.items():
+                r_key = f"{model_name}/keys/rank_{rank}/{key}"
+                client.set(r_key, tensor)
+
+            for root, _, files in os.walk(model_path):
+                for file_name in files:
+                    # ignore hidden files
+                    if file_name.startswith("."):
+                        continue
+                    if os.path.splitext(file_name)[1] not in (
+                        ".bin",
+                        ".pt",
+                        ".safetensors",
+                    ):
+                        file_path = os.path.join(root, file_name)
+                        with open(file_path, encoding="utf-8") as file:
+                            file_content = file.read()
+                            f_key = f"{model_name}/files/{file_name}"
+                            client.setstr(f_key, file_content)
+
+    def _load_model_from_remote_kv(self, model: nn.Module, client):
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                quant_method.process_weights_after_loading(module)
+        weights_iterator = self._get_weights_iterator_kv(client)
+        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+        for key, tensor in weights_iterator:
+            # If loading with LoRA enabled, additional padding may
+            # be added to certain parameters. We only load into a
+            # narrowed view of the parameter data.
+            param_data = state_dict[key].data
+            param_shape = state_dict[key].shape
+            for dim, size in enumerate(tensor.shape):
+                if size < param_shape[dim]:
+                    param_data = param_data.narrow(dim, 0, size)
+            if tensor.shape != param_shape:
+                logger.warning(
+                    "loading tensor of shape %s into " "parameter '%s' of shape %s",
+                    tensor.shape,
+                    key,
+                    param_shape,
+                )
+            param_data.copy_(tensor)
+            state_dict.pop(key)
+        if state_dict:
+            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")

+    def _load_model_from_remote_fs(
+        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+
+        target_device = torch.device(device_config.device)
+        with set_default_torch_dtype(model_config.dtype):
+            model.load_weights(self._get_weights_iterator_fs(client))
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    # When quant methods need to process weights after loading
+                    # (for repacking, quantizing, etc), they expect parameters
+                    # to be on the global target device. This scope is for the
+                    # case where cpu offloading is used, where we will move the
+                    # parameters onto device for processing and back off after.
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        logger.info("Loading weights from remote storage ...")
+        start = time.perf_counter()
+        load_config = self.load_config
+
+        assert load_config.load_format == LoadFormat.REMOTE, (
+            f"Model loader {self.load_config.load_format} is not supported for "
+            f"load format {load_config.load_format}"
+        )
+
+        model_weights = model_config.model_path
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config)
+                for _, module in model.named_modules():
+                    quant_method = getattr(module, "quant_method", None)
+                    if quant_method is not None:
+                        quant_method.process_weights_after_loading(module)
+
+            with create_remote_connector(model_weights, device_config.device) as client:
+                connector_type = get_connector_type(client)
+                if connector_type == ConnectorType.KV:
+                    self._load_model_from_remote_kv(model, client)
+                elif connector_type == ConnectorType.FS:
+                    self._load_model_from_remote_fs(
+                        model, client, model_config, device_config
+                    )
+
+        end = time.perf_counter()
+        logger.info("Loaded weights from remote storage in %.2f seconds.", end - start)
+        return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""

@@ -1225,4 +1390,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.LAYERED:
         return LayeredModelLoader(load_config)

+    if load_config.load_format == LoadFormat.REMOTE:
+        return RemoteModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
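Together with the new connector package listed above (sglang/srt/connector/redis.py and s3.py), the intended flow appears to be: push an already-loaded model's tensors to a remote store with RemoteModelLoader.save_model, then select that store via LoadFormat.REMOTE. A rough sketch under those assumptions; the URL scheme, the reachable endpoint, and the LoadConfig construction are placeholders inferred from the file list, not documented by this diff:

from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.model_loader.loader import RemoteModelLoader, get_model_loader

# Hypothetical remote endpoint; the scheme is assumed from the new redis connector.
url = "redis://localhost:6379/my-model"

# 1) After a model has been loaded once, push its shards and auxiliary files.
#    `model` and `model_path` would come from an existing ModelRunner in this sketch.
# RemoteModelLoader.save_model(model, model_path, url)

# 2) Later, the remote path is selected through the load format.
loader = get_model_loader(LoadConfig(load_format=LoadFormat.REMOTE))
assert isinstance(loader, RemoteModelLoader)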