sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/configs/model_config.py
@@ -22,7 +22,11 @@ import torch
 from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
-from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.layers.quantization import (
+    BASE_QUANTIZATION_METHODS,
+    QUANTIZATION_METHODS,
+    VLLM_AVAILABLE,
+)
 from sglang.srt.utils import get_bool_env_var, is_hip
 
 logger = logging.getLogger(__name__)
@@ -51,13 +55,14 @@ class ModelConfig:
         self.quantization = quantization
 
         # Parse args
+        self.maybe_pull_model_tokenizer_from_remote()
         self.model_override_args = json.loads(model_override_args)
         kwargs = {}
         if override_config_file and override_config_file.strip():
             kwargs["_configuration_file"] = override_config_file.strip()
 
         self.hf_config = get_config(
-            model_path,
+            self.model_path,
             trust_remote_code=trust_remote_code,
             revision=revision,
             model_override_args=self.model_override_args,
@@ -134,6 +139,11 @@ class ModelConfig:
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_config.kv_lora_rank
             self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures:
+            self.head_dim = 256
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_text_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
         else:
             self.attention_arch = AttentionArch.MHA
 
@@ -229,7 +239,12 @@ class ModelConfig:
 
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
     def _verify_quantization(self) -> None:
-        supported_quantization = [*QUANTIZATION_METHODS]
+        # Select supported quantization methods based on vllm availability
+        if VLLM_AVAILABLE:
+            supported_quantization = [*QUANTIZATION_METHODS]
+        else:
+            supported_quantization = [*BASE_QUANTIZATION_METHODS]
+
         rocm_supported_quantization = [
             "awq",
             "gptq",
@@ -267,7 +282,11 @@ class ModelConfig:
             quant_method = quant_cfg.get("quant_method", "").lower()
 
             # Detect which checkpoint is it
-            for _, method in QUANTIZATION_METHODS.items():
+            # Only iterate through currently available quantization methods
+            available_methods = (
+                QUANTIZATION_METHODS if VLLM_AVAILABLE else BASE_QUANTIZATION_METHODS
+            )
+            for _, method in available_methods.items():
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization
                 )
@@ -318,6 +337,29 @@ class ModelConfig:
         eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
         return eos_ids
 
+    def maybe_pull_model_tokenizer_from_remote(self) -> None:
+        """
+        Pull the model config files to a temporary
+        directory in case of remote.
+
+        Args:
+            model: The model name or path.
+
+        """
+        from sglang.srt.connector import create_remote_connector
+        from sglang.srt.utils import is_remote_url
+
+        if is_remote_url(self.model_path):
+            logger.info("Pulling model configs from remote...")
+            # BaseConnector implements __del__() to clean up the local dir.
+            # Since config files need to exist all the time, so we DO NOT use
+            # with statement to avoid closing the client.
+            client = create_remote_connector(self.model_path)
+            if is_remote_url(self.model_path):
+                client.pull_files(allow_pattern=["*config.json"])
+                self.model_weights = self.model_path
+                self.model_path = client.get_local_dir()
+
 
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
@@ -338,6 +380,8 @@ def get_hf_text_config(config: PretrainedConfig):
         # if transformers config doesn't align with this assumption.
         assert hasattr(config.text_config, "num_attention_heads")
         return config.text_config
+    if hasattr(config, "language_config"):
+        return config.language_config
     else:
         return config
 
@@ -367,9 +411,13 @@ def _get_and_verify_dtype(
     dtype = dtype.lower()
     if dtype == "auto":
         if config_dtype == torch.float32:
-            if config.model_type == "gemma2":
+            if config.model_type.startswith("gemma"):
+                if config.model_type == "gemma":
+                    gemma_version = ""
+                else:
+                    gemma_version = config.model_type[5]
                 logger.info(
-                    "For Gemma 2, we downcast float32 to bfloat16 instead "
+                    f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                     "of float16 by default. Please specify `dtype` if you "
                     "want to use float16."
                 )
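A quick illustration of the version extraction above (illustrative snippet, not part of the diff): indexing past the five-character "gemma" prefix yields the version digit.

for model_type in ("gemma", "gemma2", "gemma3"):
    gemma_version = "" if model_type == "gemma" else model_type[5]
    # "gemma" -> "", "gemma2" -> "2", "gemma3" -> "3"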
@@ -418,6 +466,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = False
         or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
         or "InternLM2ForRewardModel" in model_architectures
         or "Qwen2ForRewardModel" in model_architectures
+        or "Qwen2ForSequenceClassification" in model_architectures
     ):
         return False
     else:
@@ -425,17 +474,20 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = False
 
 
 multimodal_model_archs = [
+    "DeepseekVL2ForCausalLM",
+    "Gemma3ForConditionalGeneration",
+    "Grok1VForCausalLM",
+    "Grok1AForCausalLM",
     "LlavaLlamaForCausalLM",
-    "LlavaQwenForCausalLM",
     "LlavaMistralForCausalLM",
+    "LlavaQwenForCausalLM",
     "LlavaVidForCausalLM",
-    "Grok1VForCausalLM",
-    "Grok1AForCausalLM",
+    "MiniCPMO",
+    "MiniCPMV",
+    "MultiModalityCausalLM",
     "MllamaForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
-    "MiniCPMV",
-    "MultiModalityCausalLM",
 ]
 
 
|
@@ -0,0 +1,25 @@
|
|
1
|
+
from typing import Type
|
2
|
+
|
3
|
+
from transformers import (
|
4
|
+
AutoImageProcessor,
|
5
|
+
AutoProcessor,
|
6
|
+
BaseImageProcessor,
|
7
|
+
PretrainedConfig,
|
8
|
+
ProcessorMixin,
|
9
|
+
)
|
10
|
+
|
11
|
+
|
12
|
+
def register_image_processor(
|
13
|
+
config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor]
|
14
|
+
):
|
15
|
+
"""
|
16
|
+
register customized hf image processor while removing hf impl
|
17
|
+
"""
|
18
|
+
AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)
|
19
|
+
|
20
|
+
|
21
|
+
def register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]):
|
22
|
+
"""
|
23
|
+
register customized hf processor while removing hf impl
|
24
|
+
"""
|
25
|
+
AutoProcessor.register(config, processor, exist_ok=True)
|
sglang/srt/connector/__init__.py (new file)
@@ -0,0 +1,51 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import enum
+import logging
+
+from sglang.srt.connector.base_connector import (
+    BaseConnector,
+    BaseFileConnector,
+    BaseKVConnector,
+)
+from sglang.srt.connector.redis import RedisConnector
+from sglang.srt.connector.s3 import S3Connector
+from sglang.srt.utils import parse_connector_type
+
+logger = logging.getLogger(__name__)
+
+
+class ConnectorType(str, enum.Enum):
+    FS = "filesystem"
+    KV = "KV"
+
+
+def create_remote_connector(url, device="cpu") -> BaseConnector:
+    connector_type = parse_connector_type(url)
+    if connector_type == "redis":
+        return RedisConnector(url)
+    elif connector_type == "s3":
+        return S3Connector(url)
+    else:
+        raise ValueError(f"Invalid connector type: {url}")
+
+
+def get_connector_type(client: BaseConnector) -> ConnectorType:
+    if isinstance(client, BaseKVConnector):
+        return ConnectorType.KV
+    if isinstance(client, BaseFileConnector):
+        return ConnectorType.FS
+
+    raise ValueError(f"Invalid connector type: {client}")
+
+
+__all__ = [
+    "BaseConnector",
+    "BaseFileConnector",
+    "BaseKVConnector",
+    "RedisConnector",
+    "S3Connector",
+    "ConnectorType",
+    "create_remote_connector",
+    "get_connector_type",
+]
sglang/srt/connector/base_connector.py (new file)
@@ -0,0 +1,112 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import shutil
+import signal
+import tempfile
+from abc import ABC, abstractmethod
+from typing import Generator, List, Optional, Tuple
+
+import torch
+
+
+class BaseConnector(ABC):
+    """
+    For fs connector such as s3:
+    <connector_type>://<path>/<filename>
+
+    For kv connector such as redis:
+    <connector_type>://<host>:<port>/<model_name>/keys/<key>
+    <connector_type://<host>:<port>/<model_name>/files/<filename>
+    """
+
+    def __init__(self, url: str, device: torch.device = "cpu"):
+        self.url = url
+        self.device = device
+        self.closed = False
+        self.local_dir = tempfile.mkdtemp()
+        for sig in (signal.SIGINT, signal.SIGTERM):
+            existing_handler = signal.getsignal(sig)
+            signal.signal(sig, self._close_by_signal(existing_handler))
+
+    def get_local_dir(self):
+        return self.local_dir
+
+    @abstractmethod
+    def weight_iterator(
+        self, rank: int = 0
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def pull_files(
+        self,
+        allow_pattern: Optional[List[str]] = None,
+        ignore_pattern: Optional[List[str]] = None,
+    ) -> None:
+        raise NotImplementedError()
+
+    def close(self):
+        if self.closed:
+            return
+
+        self.closed = True
+        if os.path.exists(self.local_dir):
+            shutil.rmtree(self.local_dir)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.close()
+
+    def __del__(self):
+        self.close()
+
+    def _close_by_signal(self, existing_handler=None):
+
+        def new_handler(signum, frame):
+            self.close()
+            if existing_handler:
+                existing_handler(signum, frame)
+
+        return new_handler
+
+
+class BaseKVConnector(BaseConnector):
+
+    @abstractmethod
+    def get(self, key: str) -> Optional[torch.Tensor]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def getstr(self, key: str) -> Optional[str]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def set(self, key: str, obj: torch.Tensor) -> None:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def setstr(self, key: str, obj: str) -> None:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def list(self, prefix: str) -> List[str]:
+        raise NotImplementedError()
+
+
+class BaseFileConnector(BaseConnector):
+    """
+    List full file names from remote fs path and filter by allow pattern.
+
+    Args:
+        allow_pattern: A list of patterns of which files to pull.
+
+    Returns:
+        list[str]: List of full paths allowed by the pattern
+    """
+
+    @abstractmethod
+    def glob(self, allow_pattern: str) -> List[str]:
+        raise NotImplementedError()
sglang/srt/connector/redis.py (new file)
@@ -0,0 +1,85 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import Generator, List, Optional, Tuple
+from urllib.parse import urlparse
+
+import torch
+
+from sglang.srt.connector import BaseKVConnector
+from sglang.srt.connector.serde import create_serde
+from sglang.srt.connector.utils import pull_files_from_db
+
+logger = logging.getLogger(__name__)
+
+
+class RedisConnector(BaseKVConnector):
+
+    def __init__(self, url: str, device: torch.device = "cpu"):
+        import redis
+
+        super().__init__(url, device)
+        parsed_url = urlparse(url)
+        self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
+        self.model_name = parsed_url.path.lstrip("/")
+        # TODO: more serde options
+        self.s, self.d = create_serde("safe")
+
+    def get(self, key: str) -> Optional[torch.Tensor]:
+        val = self.connection.get(key)
+
+        if val is None:
+            logger.error("Key %s not found", key)
+            return None
+
+        return self.d.from_bytes(val)
+
+    def getstr(self, key: str) -> Optional[str]:
+        val = self.connection.get(key)
+        if val is None:
+            logger.error("Key %s not found", key)
+            return None
+
+        return val.decode("utf-8")
+
+    def set(self, key: str, tensor: torch.Tensor) -> None:
+        assert tensor is not None
+        self.connection.set(key, self.s.to_bytes(tensor))
+
+    def setstr(self, key: str, obj: str) -> None:
+        self.connection.set(key, obj)
+
+    def list(self, prefix: str) -> List[str]:
+        cursor = 0
+        all_keys: List[bytes] = []
+
+        while True:
+            ret: Tuple[int, List[bytes]] = self.connection.scan(
+                cursor=cursor, match=f"{prefix}*"
+            )  # type: ignore
+            cursor, keys = ret
+            all_keys.extend(keys)
+            if cursor == 0:
+                break
+
+        return [key.decode("utf-8") for key in all_keys]
+
+    def weight_iterator(
+        self, rank: int = 0
+    ) -> Generator[Tuple[str, bytes], None, None]:
+        keys = self.list(f"{self.model_name}/keys/rank_{rank}/")
+        for key in keys:
+            val = self.get(key)
+            key = key.removeprefix(f"{self.model_name}/keys/rank_{rank}/")
+            yield key, val
+
+    def pull_files(
+        self,
+        allow_pattern: Optional[List[str]] = None,
+        ignore_pattern: Optional[List[str]] = None,
+    ) -> None:
+        pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)
+
+    def close(self):
+        self.connection.close()
+        super().close()
sglang/srt/connector/s3.py (new file)
@@ -0,0 +1,122 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import fnmatch
+import os
+from pathlib import Path
+from typing import Generator, Optional, Tuple
+
+import torch
+
+from sglang.srt.connector import BaseFileConnector
+
+
+def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
+    return [
+        path
+        for path in paths
+        if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
+    ]
+
+
+def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
+    return [
+        path
+        for path in paths
+        if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
+    ]
+
+
+def list_files(
+    s3,
+    path: str,
+    allow_pattern: Optional[list[str]] = None,
+    ignore_pattern: Optional[list[str]] = None,
+) -> tuple[str, str, list[str]]:
+    """
+    List files from S3 path and filter by pattern.
+
+    Args:
+        s3: S3 client to use.
+        path: The S3 path to list from.
+        allow_pattern: A list of patterns of which files to pull.
+        ignore_pattern: A list of patterns of which files not to pull.
+
+    Returns:
+        tuple[str, str, list[str]]: A tuple where:
+            - The first element is the bucket name
+            - The second element is string represent the bucket
+              and the prefix as a dir like string
+            - The third element is a list of files allowed or
+              disallowed by pattern
+    """
+    parts = path.removeprefix("s3://").split("/")
+    prefix = "/".join(parts[1:])
+    bucket_name = parts[0]
+
+    objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
+    paths = [obj["Key"] for obj in objects.get("Contents", [])]
+
+    paths = _filter_ignore(paths, ["*/"])
+    if allow_pattern is not None:
+        paths = _filter_allow(paths, allow_pattern)
+
+    if ignore_pattern is not None:
+        paths = _filter_ignore(paths, ignore_pattern)
+
+    return bucket_name, prefix, paths
+
+
+class S3Connector(BaseFileConnector):
+
+    def __init__(self, url: str) -> None:
+        import boto3
+
+        super().__init__(url)
+        self.client = boto3.client("s3")
+
+    def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
+        bucket_name, _, paths = list_files(
+            self.client, path=self.url, allow_pattern=allow_pattern
+        )
+        return [f"s3://{bucket_name}/{path}" for path in paths]
+
+    def pull_files(
+        self,
+        allow_pattern: Optional[list[str]] = None,
+        ignore_pattern: Optional[list[str]] = None,
+    ) -> None:
+        """
+        Pull files from S3 storage into the temporary directory.
+
+        Args:
+            s3_model_path: The S3 path of the model.
+            allow_pattern: A list of patterns of which files to pull.
+            ignore_pattern: A list of patterns of which files not to pull.
+
+        """
+        bucket_name, base_dir, files = list_files(
+            self.client, self.url, allow_pattern, ignore_pattern
+        )
+        if len(files) == 0:
+            return
+
+        for file in files:
+            destination_file = os.path.join(self.local_dir, file.removeprefix(base_dir))
+            local_dir = Path(destination_file).parent
+            os.makedirs(local_dir, exist_ok=True)
+            self.client.download_file(bucket_name, file, destination_file)
+
+    def weight_iterator(
+        self, rank: int = 0
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        from sglang.srt.model_loader.weight_utils import (
+            runai_safetensors_weights_iterator,
+        )
+
+        # only support safetensor files now
+        hf_weights_files = self.glob(allow_pattern=["*.safetensors"])
+        return runai_safetensors_weights_iterator(hf_weights_files)
+
+    def close(self):
+        self.client.close()
+        super().close()
sglang/srt/connector/serde/__init__.py (new file)
@@ -0,0 +1,31 @@
+# SPDX-License-Identifier: Apache-2.0
+
+# inspired by LMCache
+from typing import Optional, Tuple
+
+import torch
+
+from sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer
+from sglang.srt.connector.serde.serde import Deserializer, Serializer
+
+
+def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
+    s: Optional[Serializer] = None
+    d: Optional[Deserializer] = None
+
+    if serde_type == "safe":
+        s = SafeSerializer()
+        d = SafeDeserializer(torch.uint8)
+    else:
+        raise ValueError(f"Unknown serde type: {serde_type}")
+
+    return s, d
+
+
+__all__ = [
+    "Serializer",
+    "Deserializer",
+    "SafeSerializer",
+    "SafeDeserializer",
+    "create_serde",
+]
sglang/srt/connector/serde/safe_serde.py (new file)
@@ -0,0 +1,29 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Union
+
+import torch
+from safetensors.torch import load, save
+
+from sglang.srt.connector.serde.serde import Deserializer, Serializer
+
+
+class SafeSerializer(Serializer):
+
+    def __init__(self):
+        super().__init__()
+
+    def to_bytes(self, t: torch.Tensor) -> bytes:
+        return save({"tensor_bytes": t.cpu().contiguous()})
+
+
+class SafeDeserializer(Deserializer):
+
+    def __init__(self, dtype):
+        super().__init__(dtype)
+
+    def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
+        return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)
+
+    def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
+        return self.from_bytes_normal(b)
sglang/srt/connector/serde/serde.py (new file)
@@ -0,0 +1,43 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import abc
+from abc import ABC, abstractmethod
+
+import torch
+
+
+class Serializer(ABC):
+
+    @abstractmethod
+    def to_bytes(self, t: torch.Tensor) -> bytes:
+        """
+        Serialize a pytorch tensor to bytes. The serialized bytes should contain
+        both the data and the metadata (shape, dtype, etc.) of the tensor.
+
+        Input:
+            t: the input pytorch tensor, can be on any device, in any shape,
+               with any dtype
+
+        Returns:
+            bytes: the serialized bytes
+        """
+        raise NotImplementedError
+
+
+class Deserializer(metaclass=abc.ABCMeta):
+
+    def __init__(self, dtype):
+        self.dtype = dtype
+
+    @abstractmethod
+    def from_bytes(self, bs: bytes) -> torch.Tensor:
+        """
+        Deserialize a pytorch tensor from bytes.
+
+        Input:
+            bytes: a stream of bytes
+
+        Output:
+            torch.Tensor: the deserialized pytorch tensor
+        """
+        raise NotImplementedError
sglang/srt/connector/utils.py (new file)
@@ -0,0 +1,35 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from pathlib import Path
+from typing import Optional
+from urllib.parse import urlparse
+
+from sglang.srt.connector import BaseConnector
+
+
+def parse_model_name(url: str) -> str:
+    """
+    Parse the model name from the url.
+    Only used for db connector
+    """
+    parsed_url = urlparse(url)
+    return parsed_url.path.lstrip("/")
+
+
+def pull_files_from_db(
+    connector: BaseConnector,
+    model_name: str,
+    allow_pattern: Optional[list[str]] = None,
+    ignore_pattern: Optional[list[str]] = None,
+) -> None:
+    prefix = f"{model_name}/files/"
+    local_dir = connector.get_local_dir()
+    files = connector.list(prefix)
+
+    for file in files:
+        destination_file = os.path.join(local_dir, file.removeprefix(prefix))
+        local_dir = Path(destination_file).parent
+        os.makedirs(local_dir, exist_ok=True)
+        with open(destination_file, "wb") as f:
+            f.write(connector.getstr(file).encode("utf-8"))