sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
@@ -51,13 +51,14 @@ class ModelConfig:
|
|
51
51
|
self.quantization = quantization
|
52
52
|
|
53
53
|
# Parse args
|
54
|
+
self.maybe_pull_model_tokenizer_from_remote()
|
54
55
|
self.model_override_args = json.loads(model_override_args)
|
55
56
|
kwargs = {}
|
56
57
|
if override_config_file and override_config_file.strip():
|
57
58
|
kwargs["_configuration_file"] = override_config_file.strip()
|
58
59
|
|
59
60
|
self.hf_config = get_config(
|
60
|
-
model_path,
|
61
|
+
self.model_path,
|
61
62
|
trust_remote_code=trust_remote_code,
|
62
63
|
revision=revision,
|
63
64
|
model_override_args=self.model_override_args,
|
@@ -134,6 +135,11 @@ class ModelConfig:
|
|
134
135
|
self.attention_arch = AttentionArch.MLA
|
135
136
|
self.kv_lora_rank = self.hf_config.kv_lora_rank
|
136
137
|
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
|
138
|
+
elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures:
|
139
|
+
self.head_dim = 256
|
140
|
+
self.attention_arch = AttentionArch.MLA
|
141
|
+
self.kv_lora_rank = self.hf_text_config.kv_lora_rank
|
142
|
+
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
|
137
143
|
else:
|
138
144
|
self.attention_arch = AttentionArch.MHA
|
139
145
|
|
@@ -318,6 +324,29 @@ class ModelConfig:
|
|
318
324
|
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
|
319
325
|
return eos_ids
|
320
326
|
|
327
|
+
def maybe_pull_model_tokenizer_from_remote(self) -> None:
    """
    Pull the model config files to a temporary directory when
    ``self.model_path`` is a remote URL.

    For a remote path, files matching ``*config.json`` are pulled through a
    remote connector, ``self.model_weights`` is set to the original remote
    URL, and ``self.model_path`` is replaced by the connector's local
    directory. For a non-remote path nothing is changed.
    """
    from sglang.srt.connector import create_remote_connector
    from sglang.srt.utils import is_remote_url

    # Guard clause: local paths need no work. (The original checked
    # is_remote_url twice in a row; one check is sufficient.)
    if not is_remote_url(self.model_path):
        return

    logger.info("Pulling model configs from remote...")
    # BaseConnector implements __del__() to clean up the local dir.
    # Since config files need to exist all the time, so we DO NOT use
    # with statement to avoid closing the client.
    # NOTE(review): `client` is only a local name here; confirm that
    # something keeps the connector alive for as long as the config files
    # are needed.
    client = create_remote_connector(self.model_path)
    client.pull_files(allow_pattern=["*config.json"])
    self.model_weights = self.model_path
    self.model_path = client.get_local_dir()
|
349
|
+
|
321
350
|
|
322
351
|
def get_hf_text_config(config: PretrainedConfig):
|
323
352
|
"""Get the "sub" config relevant to llm for multi modal models.
|
@@ -338,6 +367,8 @@ def get_hf_text_config(config: PretrainedConfig):
|
|
338
367
|
# if transformers config doesn't align with this assumption.
|
339
368
|
assert hasattr(config.text_config, "num_attention_heads")
|
340
369
|
return config.text_config
|
370
|
+
if hasattr(config, "language_config"):
|
371
|
+
return config.language_config
|
341
372
|
else:
|
342
373
|
return config
|
343
374
|
|
@@ -367,9 +398,13 @@ def _get_and_verify_dtype(
|
|
367
398
|
dtype = dtype.lower()
|
368
399
|
if dtype == "auto":
|
369
400
|
if config_dtype == torch.float32:
|
370
|
-
if config.model_type
|
401
|
+
if config.model_type.startswith("gemma"):
|
402
|
+
if config.model_type == "gemma":
|
403
|
+
gemma_version = ""
|
404
|
+
else:
|
405
|
+
gemma_version = config.model_type[5]
|
371
406
|
logger.info(
|
372
|
-
"For Gemma
|
407
|
+
f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
|
373
408
|
"of float16 by default. Please specify `dtype` if you "
|
374
409
|
"want to use float16."
|
375
410
|
)
|
@@ -418,6 +453,8 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
|
|
418
453
|
or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
|
419
454
|
or "InternLM2ForRewardModel" in model_architectures
|
420
455
|
or "Qwen2ForRewardModel" in model_architectures
|
456
|
+
or "Qwen2ForSequenceClassification" in model_architectures
|
457
|
+
or "CLIPModel" in model_architectures
|
421
458
|
):
|
422
459
|
return False
|
423
460
|
else:
|
@@ -425,17 +462,21 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
|
|
425
462
|
|
426
463
|
|
427
464
|
multimodal_model_archs = [
|
465
|
+
"DeepseekVL2ForCausalLM",
|
466
|
+
"Gemma3ForConditionalGeneration",
|
467
|
+
"Grok1VForCausalLM",
|
468
|
+
"Grok1AForCausalLM",
|
428
469
|
"LlavaLlamaForCausalLM",
|
429
|
-
"LlavaQwenForCausalLM",
|
430
470
|
"LlavaMistralForCausalLM",
|
471
|
+
"LlavaQwenForCausalLM",
|
431
472
|
"LlavaVidForCausalLM",
|
432
|
-
"
|
433
|
-
"
|
473
|
+
"MiniCPMO",
|
474
|
+
"MiniCPMV",
|
475
|
+
"MultiModalityCausalLM",
|
434
476
|
"MllamaForConditionalGeneration",
|
435
477
|
"Qwen2VLForConditionalGeneration",
|
436
478
|
"Qwen2_5_VLForConditionalGeneration",
|
437
|
-
"
|
438
|
-
"MultiModalityCausalLM",
|
479
|
+
"CLIPModel",
|
439
480
|
]
|
440
481
|
|
441
482
|
|
@@ -0,0 +1,25 @@
|
|
1
|
+
from typing import Type
|
2
|
+
|
3
|
+
from transformers import (
|
4
|
+
AutoImageProcessor,
|
5
|
+
AutoProcessor,
|
6
|
+
BaseImageProcessor,
|
7
|
+
PretrainedConfig,
|
8
|
+
ProcessorMixin,
|
9
|
+
)
|
10
|
+
|
11
|
+
|
12
|
+
def register_image_processor(
    config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor]
):
    """
    Register a customized HF image processor class for ``config``.

    The class is handed to ``AutoImageProcessor.register`` as the third
    positional argument with ``exist_ok=True`` (intended to take the place
    of the stock HF implementation -- confirm against the installed
    transformers version).
    """
    register = AutoImageProcessor.register
    register(config, None, image_processor, None, exist_ok=True)
|
19
|
+
|
20
|
+
|
21
|
+
def register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]):
    """
    Register a customized HF processor class for ``config``.

    Delegates to ``AutoProcessor.register`` with ``exist_ok=True`` (intended
    to take the place of the stock HF implementation).
    """
    register = AutoProcessor.register
    register(config, processor, exist_ok=True)
|
@@ -0,0 +1,51 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
|
3
|
+
import enum
|
4
|
+
import logging
|
5
|
+
|
6
|
+
from sglang.srt.connector.base_connector import (
|
7
|
+
BaseConnector,
|
8
|
+
BaseFileConnector,
|
9
|
+
BaseKVConnector,
|
10
|
+
)
|
11
|
+
from sglang.srt.connector.redis import RedisConnector
|
12
|
+
from sglang.srt.connector.s3 import S3Connector
|
13
|
+
from sglang.srt.utils import parse_connector_type
|
14
|
+
|
15
|
+
logger = logging.getLogger(__name__)
|
16
|
+
|
17
|
+
|
18
|
+
class ConnectorType(str, enum.Enum):
    """Kinds of remote connector, as reported by ``get_connector_type``."""

    # Connectors derived from BaseFileConnector.
    FS = "filesystem"
    # Connectors derived from BaseKVConnector.
    KV = "KV"
|
21
|
+
|
22
|
+
|
23
|
+
def create_remote_connector(url, device="cpu") -> BaseConnector:
    """
    Build the connector that matches the connector type of ``url``.

    Args:
        url: Remote model location; its type is obtained from
            ``parse_connector_type``.
        device: Device forwarded to the connector where the constructor
            accepts one.

    Returns:
        A ``RedisConnector`` for type ``"redis"`` or an ``S3Connector`` for
        type ``"s3"``.

    Raises:
        ValueError: If the parsed connector type is neither of the above.
    """
    connector_type = parse_connector_type(url)
    if connector_type == "redis":
        # Forward the caller's device; previously it was silently dropped
        # and the connector always fell back to its own default.
        return RedisConnector(url, device)
    if connector_type == "s3":
        # Built from the URL alone; no device is passed for S3 here.
        return S3Connector(url)
    raise ValueError(f"Invalid connector type: {url}")
|
31
|
+
|
32
|
+
|
33
|
+
def get_connector_type(client: BaseConnector) -> ConnectorType:
    """
    Classify ``client`` by its base class.

    Returns:
        ``ConnectorType.KV`` for a ``BaseKVConnector`` instance, otherwise
        ``ConnectorType.FS`` for a ``BaseFileConnector`` instance.

    Raises:
        ValueError: If ``client`` is an instance of neither base class.
    """
    # KV is tested first, exactly as in the original if-chain.
    dispatch = (
        (BaseKVConnector, ConnectorType.KV),
        (BaseFileConnector, ConnectorType.FS),
    )
    for base_cls, kind in dispatch:
        if isinstance(client, base_cls):
            return kind

    raise ValueError(f"Invalid connector type: {client}")
|
40
|
+
|
41
|
+
|
42
|
+
__all__ = [
|
43
|
+
"BaseConnector",
|
44
|
+
"BaseFileConnector",
|
45
|
+
"BaseKVConnector",
|
46
|
+
"RedisConnector",
|
47
|
+
"S3Connector",
|
48
|
+
"ConnectorType",
|
49
|
+
"create_remote_connector",
|
50
|
+
"get_connector_type",
|
51
|
+
]
|
@@ -0,0 +1,112 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
|
3
|
+
import os
|
4
|
+
import shutil
|
5
|
+
import signal
|
6
|
+
import tempfile
|
7
|
+
from abc import ABC, abstractmethod
|
8
|
+
from typing import Generator, List, Optional, Tuple
|
9
|
+
|
10
|
+
import torch
|
11
|
+
|
12
|
+
|
13
|
+
class BaseConnector(ABC):
    """
    Base class for remote model connectors.

    For fs connector such as s3:
        <connector_type>://<path>/<filename>

    For kv connector such as redis:
        <connector_type>://<host>:<port>/<model_name>/keys/<key>
        <connector_type>://<host>:<port>/<model_name>/files/<filename>

    Every instance owns a temporary local directory (``get_local_dir()``)
    that is removed by ``close()``. ``close()`` is triggered by the context
    manager exit, by ``__del__``, and by SIGINT/SIGTERM handlers installed
    in ``__init__``.
    """

    def __init__(self, url: str, device: torch.device = "cpu"):
        self.url = url
        self.device = device
        self.closed = False
        self.local_dir = tempfile.mkdtemp()
        # Chain onto whatever is currently installed so the temp dir is
        # removed on SIGINT/SIGTERM before the previous behaviour runs.
        for sig in (signal.SIGINT, signal.SIGTERM):
            existing_handler = signal.getsignal(sig)
            signal.signal(sig, self._close_by_signal(existing_handler))

    def get_local_dir(self):
        """Return the temporary directory owned by this connector."""
        return self.local_dir

    @abstractmethod
    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Yield ``(name, tensor)`` pairs for the given rank."""
        raise NotImplementedError()

    @abstractmethod
    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        """Download remote files into the local directory."""
        raise NotImplementedError()

    def close(self):
        """Remove the temporary directory; safe to call more than once."""
        # getattr: __del__ can run on an instance whose __init__ never got
        # far enough to set these attributes (e.g. a subclass constructor
        # that raised early). Treat such an instance as already closed.
        if getattr(self, "closed", True):
            return

        self.closed = True
        local_dir = getattr(self, "local_dir", None)
        if local_dir and os.path.exists(local_dir):
            shutil.rmtree(local_dir)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __del__(self):
        self.close()

    def _close_by_signal(self, existing_handler=None):
        """
        Build a signal handler that closes this connector and then defers
        to ``existing_handler``.

        ``existing_handler`` is whatever ``signal.getsignal`` returned: a
        callable, ``signal.SIG_DFL``, ``signal.SIG_IGN`` or ``None``.
        """

        def new_handler(signum, frame):
            self.close()
            if callable(existing_handler):
                existing_handler(signum, frame)
            elif existing_handler == signal.SIG_DFL:
                # The previous disposition was the OS default (e.g. SIGTERM
                # terminates the process). Restore it and re-deliver the
                # signal so installing this handler does not swallow it.
                signal.signal(signum, signal.SIG_DFL)
                os.kill(os.getpid(), signum)
            # SIG_IGN / None: nothing to chain to. These values are not
            # callable, so they must not be invoked.

        return new_handler
|
74
|
+
|
75
|
+
|
76
|
+
class BaseKVConnector(BaseConnector):
    """Abstract interface for key-value style connectors (e.g. redis)."""

    @abstractmethod
    def get(self, key: str) -> Optional[torch.Tensor]:
        """Fetch the tensor stored under ``key``, or ``None``."""
        raise NotImplementedError()

    @abstractmethod
    def getstr(self, key: str) -> Optional[str]:
        """Fetch the string stored under ``key``, or ``None``."""
        raise NotImplementedError()

    @abstractmethod
    def set(self, key: str, obj: torch.Tensor) -> None:
        """Store the tensor ``obj`` under ``key``."""
        raise NotImplementedError()

    @abstractmethod
    def setstr(self, key: str, obj: str) -> None:
        """Store the string ``obj`` under ``key``."""
        raise NotImplementedError()

    @abstractmethod
    def list(self, prefix: str) -> List[str]:
        """Return the keys that start with ``prefix``."""
        raise NotImplementedError()
|
97
|
+
|
98
|
+
|
99
|
+
class BaseFileConnector(BaseConnector):
    """Abstract interface for file-system style connectors (e.g. s3)."""

    @abstractmethod
    def glob(self, allow_pattern: Optional[List[str]] = None) -> List[str]:
        """
        List full file names from remote fs path and filter by allow pattern.

        Args:
            allow_pattern: A list of patterns of which files to pull.

        Returns:
            list[str]: List of full paths allowed by the pattern
        """
        # The annotation used to be `str`, contradicting both the
        # documented contract above (a list of patterns) and the S3
        # implementation, which takes an optional list.
        raise NotImplementedError()
|
@@ -0,0 +1,85 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from typing import Generator, List, Optional, Tuple
|
5
|
+
from urllib.parse import urlparse
|
6
|
+
|
7
|
+
import torch
|
8
|
+
|
9
|
+
from sglang.srt.connector import BaseKVConnector
|
10
|
+
from sglang.srt.connector.serde import create_serde
|
11
|
+
from sglang.srt.connector.utils import pull_files_from_db
|
12
|
+
|
13
|
+
logger = logging.getLogger(__name__)
|
14
|
+
|
15
|
+
|
16
|
+
class RedisConnector(BaseKVConnector):
    """
    KV connector backed by a Redis server.

    The host and port come from the URL; the URL path (without the leading
    ``/``) is used as the model name. Tensors are converted to and from
    bytes with the "safe" serde pair.
    """

    def __init__(self, url: str, device: torch.device = "cpu"):
        import redis

        super().__init__(url, device)
        parsed_url = urlparse(url)
        self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
        self.model_name = parsed_url.path.lstrip("/")
        # TODO: more serde options
        self.s, self.d = create_serde("safe")

    def get(self, key: str) -> Optional[torch.Tensor]:
        """Return the deserialized tensor stored at ``key``, or ``None`` if absent."""
        val = self.connection.get(key)

        if val is None:
            logger.error("Key %s not found", key)
            return None

        return self.d.from_bytes(val)

    def getstr(self, key: str) -> Optional[str]:
        """Return the UTF-8 decoded value stored at ``key``, or ``None`` if absent."""
        val = self.connection.get(key)
        if val is None:
            logger.error("Key %s not found", key)
            return None

        return val.decode("utf-8")

    def set(self, key: str, tensor: torch.Tensor) -> None:
        """Serialize ``tensor`` and store it at ``key``."""
        assert tensor is not None
        self.connection.set(key, self.s.to_bytes(tensor))

    def setstr(self, key: str, obj: str) -> None:
        """Store the string ``obj`` at ``key``."""
        self.connection.set(key, obj)

    def list(self, prefix: str) -> List[str]:
        """Return every key matching ``{prefix}*``, decoded as UTF-8."""
        cursor = 0
        all_keys: List[bytes] = []

        # SCAN is cursor based: keep going until the server hands back 0.
        while True:
            ret: Tuple[int, List[bytes]] = self.connection.scan(
                cursor=cursor, match=f"{prefix}*"
            )  # type: ignore
            cursor, keys = ret
            all_keys.extend(keys)
            if cursor == 0:
                break

        return [key.decode("utf-8") for key in all_keys]

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, Optional[torch.Tensor]], None, None]:
        """
        Yield ``(name, tensor)`` for every key under
        ``<model_name>/keys/rank_<rank>/``, with that prefix stripped from
        the name. The value is whatever ``get`` returns (``None`` if the
        key is no longer present).
        """
        # The return annotation used to say `bytes`; the values come from
        # get(), which returns deserialized tensors.
        prefix = f"{self.model_name}/keys/rank_{rank}/"
        for key in self.list(prefix):
            val = self.get(key)
            yield key.removeprefix(prefix), val

    def pull_files(
        self,
        allow_pattern: Optional[List[str]] = None,
        ignore_pattern: Optional[List[str]] = None,
    ) -> None:
        """Write the model's stored files into the local directory."""
        pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)

    def close(self):
        """Close the Redis connection, then remove the local directory."""
        # close() is also reached from __del__, which runs even when
        # __init__ raised (e.g. `import redis` failed). In that case the
        # attributes below do not exist yet and there is nothing to release.
        if not hasattr(self, "closed"):
            return
        connection = getattr(self, "connection", None)
        if connection is not None:
            connection.close()
        super().close()
|
@@ -0,0 +1,122 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
|
3
|
+
import fnmatch
|
4
|
+
import os
|
5
|
+
from pathlib import Path
|
6
|
+
from typing import Generator, Optional, Tuple
|
7
|
+
|
8
|
+
import torch
|
9
|
+
|
10
|
+
from sglang.srt.connector import BaseFileConnector
|
11
|
+
|
12
|
+
|
13
|
+
def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
|
14
|
+
return [
|
15
|
+
path
|
16
|
+
for path in paths
|
17
|
+
if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
|
18
|
+
]
|
19
|
+
|
20
|
+
|
21
|
+
def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
|
22
|
+
return [
|
23
|
+
path
|
24
|
+
for path in paths
|
25
|
+
if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
|
26
|
+
]
|
27
|
+
|
28
|
+
|
29
|
+
def list_files(
    s3,
    path: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> tuple[str, str, list[str]]:
    """
    List files from S3 path and filter by pattern.

    Args:
        s3: S3 client to use.
        path: The S3 path to list from.
        allow_pattern: A list of patterns of which files to pull.
        ignore_pattern: A list of patterns of which files not to pull.

    Returns:
        tuple[str, str, list[str]]: A tuple where:
            - The first element is the bucket name
            - The second element is string represent the bucket
              and the prefix as a dir like string
            - The third element is a list of files allowed or
              disallowed by pattern
    """
    parts = path.removeprefix("s3://").split("/")
    prefix = "/".join(parts[1:])
    bucket_name = parts[0]

    # list_objects_v2 returns one page per call, so follow the continuation
    # token until the listing is complete. A single call silently truncated
    # the result for prefixes holding more keys than one page.
    request = {"Bucket": bucket_name, "Prefix": prefix}
    paths: list[str] = []
    while True:
        objects = s3.list_objects_v2(**request)
        paths.extend(obj["Key"] for obj in objects.get("Contents", []))
        next_token = objects.get("NextContinuationToken")
        if not objects.get("IsTruncated") or not next_token:
            break
        request["ContinuationToken"] = next_token

    # Keys ending in "/" are directory placeholders, not files.
    paths = _filter_ignore(paths, ["*/"])
    if allow_pattern is not None:
        paths = _filter_allow(paths, allow_pattern)

    if ignore_pattern is not None:
        paths = _filter_ignore(paths, ignore_pattern)

    return bucket_name, prefix, paths
|
67
|
+
|
68
|
+
|
69
|
+
class S3Connector(BaseFileConnector):
    """File connector that lists and downloads model files from S3 via boto3."""

    def __init__(self, url: str, device: torch.device = "cpu") -> None:
        import boto3

        # `device` is accepted (default unchanged) so this constructor has
        # the same shape as the other connectors' constructors.
        super().__init__(url, device)
        self.client = boto3.client("s3")

    def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
        """Return ``s3://`` URLs of the objects under ``self.url`` that match ``allow_pattern``."""
        bucket_name, _, paths = list_files(
            self.client, path=self.url, allow_pattern=allow_pattern
        )
        return [f"s3://{bucket_name}/{path}" for path in paths]

    def pull_files(
        self,
        allow_pattern: Optional[list[str]] = None,
        ignore_pattern: Optional[list[str]] = None,
    ) -> None:
        """
        Pull files from S3 storage into the temporary directory.

        Args:
            allow_pattern: A list of patterns of which files to pull.
            ignore_pattern: A list of patterns of which files not to pull.

        """
        bucket_name, base_dir, files = list_files(
            self.client, self.url, allow_pattern, ignore_pattern
        )

        for file in files:
            # When the URL has no trailing "/", stripping the prefix leaves
            # "/name". os.path.join would treat that as an absolute path
            # and discard self.local_dir, so drop the leading separators.
            relative_path = file.removeprefix(base_dir).lstrip("/")
            destination_file = os.path.join(self.local_dir, relative_path)
            os.makedirs(Path(destination_file).parent, exist_ok=True)
            self.client.download_file(bucket_name, file, destination_file)

    def weight_iterator(
        self, rank: int = 0
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Iterate over the weights in the remote ``*.safetensors`` files (``rank`` is not used)."""
        from sglang.srt.model_loader.weight_utils import (
            runai_safetensors_weights_iterator,
        )

        # only support safetensor files now
        hf_weights_files = self.glob(allow_pattern=["*.safetensors"])
        return runai_safetensors_weights_iterator(hf_weights_files)

    def close(self):
        """Close the boto3 client, then remove the local directory."""
        # close() is also reached from __del__, which runs even when
        # __init__ raised (e.g. `import boto3` failed) before any attribute
        # was set. Nothing to release in that case.
        if not hasattr(self, "closed"):
            return
        client = getattr(self, "client", None)
        if client is not None:
            client.close()
        super().close()
|
@@ -0,0 +1,31 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
|
3
|
+
# inspired by LMCache
|
4
|
+
from typing import Optional, Tuple
|
5
|
+
|
6
|
+
import torch
|
7
|
+
|
8
|
+
from sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer
|
9
|
+
from sglang.srt.connector.serde.serde import Deserializer, Serializer
|
10
|
+
|
11
|
+
|
12
|
+
def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
    """
    Build the serializer/deserializer pair for ``serde_type``.

    Only ``"safe"`` is recognised.

    Raises:
        ValueError: For any other ``serde_type``.
    """
    if serde_type != "safe":
        raise ValueError(f"Unknown serde type: {serde_type}")

    # NOTE(review): the deserializer is built with torch.uint8, and
    # SafeDeserializer converts loaded tensors to its dtype -- confirm this
    # is the intended dtype for stored weights.
    serializer = SafeSerializer()
    deserializer = SafeDeserializer(torch.uint8)
    return serializer, deserializer
|
23
|
+
|
24
|
+
|
25
|
+
__all__ = [
|
26
|
+
"Serializer",
|
27
|
+
"Deserializer",
|
28
|
+
"SafeSerializer",
|
29
|
+
"SafeDeserializer",
|
30
|
+
"create_serde",
|
31
|
+
]
|
@@ -0,0 +1,29 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
|
3
|
+
from typing import Union
|
4
|
+
|
5
|
+
import torch
|
6
|
+
from safetensors.torch import load, save
|
7
|
+
|
8
|
+
from sglang.srt.connector.serde.serde import Deserializer, Serializer
|
9
|
+
|
10
|
+
|
11
|
+
class SafeSerializer(Serializer):
    """Serializer that encodes a tensor with safetensors under the key ``"tensor_bytes"``."""

    def __init__(self):
        super().__init__()

    def to_bytes(self, t: torch.Tensor) -> bytes:
        """Move ``t`` to CPU, make it contiguous, and return its safetensors bytes."""
        host_tensor = t.cpu().contiguous()
        payload = {"tensor_bytes": host_tensor}
        return save(payload)
|
18
|
+
|
19
|
+
|
20
|
+
class SafeDeserializer(Deserializer):
    """Deserializer for payloads written with the ``"tensor_bytes"`` safetensors key."""

    def __init__(self, dtype):
        super().__init__(dtype)

    def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        """Load the ``"tensor_bytes"`` entry from ``b`` and convert it to ``self.dtype``."""
        tensors = load(bytes(b))
        stored = tensors["tensor_bytes"]
        return stored.to(dtype=self.dtype)

    def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
        """Deserialize ``b``; delegates to ``from_bytes_normal``."""
        return self.from_bytes_normal(b)
|
@@ -0,0 +1,43 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
|
3
|
+
import abc
|
4
|
+
from abc import ABC, abstractmethod
|
5
|
+
|
6
|
+
import torch
|
7
|
+
|
8
|
+
|
9
|
+
class Serializer(ABC):
    """Abstract interface for turning a pytorch tensor into bytes."""

    @abstractmethod
    def to_bytes(self, t: torch.Tensor) -> bytes:
        """
        Serialize a pytorch tensor to bytes.

        The result is expected to carry the tensor's data together with its
        metadata (shape, dtype, etc.).

        Args:
            t: The tensor to serialize; it may live on any device and have
                any shape or dtype.

        Returns:
            The serialized bytes.
        """
        raise NotImplementedError
|
25
|
+
|
26
|
+
|
27
|
+
class Deserializer(metaclass=abc.ABCMeta):
    """Abstract interface for rebuilding a pytorch tensor from bytes."""

    def __init__(self, dtype):
        # Kept on the instance for subclasses to use when decoding.
        self.dtype = dtype

    @abstractmethod
    def from_bytes(self, bs: bytes) -> torch.Tensor:
        """
        Deserialize a pytorch tensor from bytes.

        Args:
            bs: A stream of bytes.

        Returns:
            The deserialized pytorch tensor.
        """
        raise NotImplementedError
|
@@ -0,0 +1,35 @@
|
|
1
|
+
# SPDX-License-Identifier: Apache-2.0
|
2
|
+
|
3
|
+
import os
|
4
|
+
from pathlib import Path
|
5
|
+
from typing import Optional
|
6
|
+
from urllib.parse import urlparse
|
7
|
+
|
8
|
+
from sglang.srt.connector import BaseConnector
|
9
|
+
|
10
|
+
|
11
|
+
def parse_model_name(url: str) -> str:
    """
    Return the path component of ``url`` with leading slashes removed.

    Only used for db connector.
    """
    url_path = urlparse(url).path
    return url_path.lstrip("/")
|
18
|
+
|
19
|
+
|
20
|
+
def pull_files_from_db(
    connector: "BaseConnector",
    model_name: str,
    allow_pattern: Optional[list[str]] = None,
    ignore_pattern: Optional[list[str]] = None,
) -> None:
    """
    Copy every file stored under ``<model_name>/files/`` into the
    connector's local directory, preserving relative sub-paths.

    Args:
        connector: Connector providing ``get_local_dir()``, ``list(prefix)``
            and ``getstr(key)``.
        model_name: Model name used to build the key prefix.
        allow_pattern: Accepted for interface parity; not applied here.
        ignore_pattern: Accepted for interface parity; not applied here.
    """
    prefix = f"{model_name}/files/"
    root_dir = connector.get_local_dir()

    for file in connector.list(prefix):
        # Always resolve against the connector's root directory. The
        # previous code overwrote its root variable with each file's parent
        # directory, so files listed after a nested one were written into
        # the wrong folder.
        destination_file = os.path.join(root_dir, file.removeprefix(prefix))
        os.makedirs(Path(destination_file).parent, exist_ok=True)
        with open(destination_file, "wb") as f:
            f.write(connector.getstr(file).encode("utf-8"))
|