sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
@@ -51,13 +51,14 @@ class ModelConfig:
51
51
  self.quantization = quantization
52
52
 
53
53
  # Parse args
54
+ self.maybe_pull_model_tokenizer_from_remote()
54
55
  self.model_override_args = json.loads(model_override_args)
55
56
  kwargs = {}
56
57
  if override_config_file and override_config_file.strip():
57
58
  kwargs["_configuration_file"] = override_config_file.strip()
58
59
 
59
60
  self.hf_config = get_config(
60
- model_path,
61
+ self.model_path,
61
62
  trust_remote_code=trust_remote_code,
62
63
  revision=revision,
63
64
  model_override_args=self.model_override_args,
@@ -134,6 +135,11 @@ class ModelConfig:
134
135
  self.attention_arch = AttentionArch.MLA
135
136
  self.kv_lora_rank = self.hf_config.kv_lora_rank
136
137
  self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
138
+ elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures:
139
+ self.head_dim = 256
140
+ self.attention_arch = AttentionArch.MLA
141
+ self.kv_lora_rank = self.hf_text_config.kv_lora_rank
142
+ self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
137
143
  else:
138
144
  self.attention_arch = AttentionArch.MHA
139
145
 
@@ -318,6 +324,29 @@ class ModelConfig:
318
324
  eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
319
325
  return eos_ids
320
326
 
327
+ def maybe_pull_model_tokenizer_from_remote(self) -> None:
328
+ """
329
+ Pull the model config files to a temporary
330
+ directory in case of remote.
331
+
332
+ Args:
333
+ model: The model name or path.
334
+
335
+ """
336
+ from sglang.srt.connector import create_remote_connector
337
+ from sglang.srt.utils import is_remote_url
338
+
339
+ if is_remote_url(self.model_path):
340
+ logger.info("Pulling model configs from remote...")
341
+ # BaseConnector implements __del__() to clean up the local dir.
342
+ # Config files need to exist all the time, so we DO NOT use a
343
+ # with statement to avoid closing the client.
344
+ client = create_remote_connector(self.model_path)
345
+ if is_remote_url(self.model_path):
346
+ client.pull_files(allow_pattern=["*config.json"])
347
+ self.model_weights = self.model_path
348
+ self.model_path = client.get_local_dir()
349
+
321
350
 
322
351
  def get_hf_text_config(config: PretrainedConfig):
323
352
  """Get the "sub" config relevant to llm for multi modal models.
@@ -338,6 +367,8 @@ def get_hf_text_config(config: PretrainedConfig):
338
367
  # if transformers config doesn't align with this assumption.
339
368
  assert hasattr(config.text_config, "num_attention_heads")
340
369
  return config.text_config
370
+ if hasattr(config, "language_config"):
371
+ return config.language_config
341
372
  else:
342
373
  return config
343
374
 
@@ -367,9 +398,13 @@ def _get_and_verify_dtype(
367
398
  dtype = dtype.lower()
368
399
  if dtype == "auto":
369
400
  if config_dtype == torch.float32:
370
- if config.model_type == "gemma2":
401
+ if config.model_type.startswith("gemma"):
402
+ if config.model_type == "gemma":
403
+ gemma_version = ""
404
+ else:
405
+ gemma_version = config.model_type[5]
371
406
  logger.info(
372
- "For Gemma 2, we downcast float32 to bfloat16 instead "
407
+ f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
373
408
  "of float16 by default. Please specify `dtype` if you "
374
409
  "want to use float16."
375
410
  )
@@ -418,6 +453,8 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
418
453
  or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
419
454
  or "InternLM2ForRewardModel" in model_architectures
420
455
  or "Qwen2ForRewardModel" in model_architectures
456
+ or "Qwen2ForSequenceClassification" in model_architectures
457
+ or "CLIPModel" in model_architectures
421
458
  ):
422
459
  return False
423
460
  else:
@@ -425,17 +462,21 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
425
462
 
426
463
 
427
464
  multimodal_model_archs = [
465
+ "DeepseekVL2ForCausalLM",
466
+ "Gemma3ForConditionalGeneration",
467
+ "Grok1VForCausalLM",
468
+ "Grok1AForCausalLM",
428
469
  "LlavaLlamaForCausalLM",
429
- "LlavaQwenForCausalLM",
430
470
  "LlavaMistralForCausalLM",
471
+ "LlavaQwenForCausalLM",
431
472
  "LlavaVidForCausalLM",
432
- "Grok1VForCausalLM",
433
- "Grok1AForCausalLM",
473
+ "MiniCPMO",
474
+ "MiniCPMV",
475
+ "MultiModalityCausalLM",
434
476
  "MllamaForConditionalGeneration",
435
477
  "Qwen2VLForConditionalGeneration",
436
478
  "Qwen2_5_VLForConditionalGeneration",
437
- "MiniCPMV",
438
- "MultiModalityCausalLM",
479
+ "CLIPModel",
439
480
  ]
440
481
 
441
482
 
@@ -0,0 +1,25 @@
1
+ from typing import Type
2
+
3
+ from transformers import (
4
+ AutoImageProcessor,
5
+ AutoProcessor,
6
+ BaseImageProcessor,
7
+ PretrainedConfig,
8
+ ProcessorMixin,
9
+ )
10
+
11
+
12
+ def register_image_processor(
13
+ config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor]
14
+ ):
15
+ """
16
+ register customized hf image processor while removing hf impl
17
+ """
18
+ AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)
19
+
20
+
21
+ def register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]):
22
+ """
23
+ register customized hf processor while removing hf impl
24
+ """
25
+ AutoProcessor.register(config, processor, exist_ok=True)
@@ -0,0 +1,51 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import enum
4
+ import logging
5
+
6
+ from sglang.srt.connector.base_connector import (
7
+ BaseConnector,
8
+ BaseFileConnector,
9
+ BaseKVConnector,
10
+ )
11
+ from sglang.srt.connector.redis import RedisConnector
12
+ from sglang.srt.connector.s3 import S3Connector
13
+ from sglang.srt.utils import parse_connector_type
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class ConnectorType(str, enum.Enum):
19
+ FS = "filesystem"
20
+ KV = "KV"
21
+
22
+
23
+ def create_remote_connector(url, device="cpu") -> BaseConnector:
24
+ connector_type = parse_connector_type(url)
25
+ if connector_type == "redis":
26
+ return RedisConnector(url)
27
+ elif connector_type == "s3":
28
+ return S3Connector(url)
29
+ else:
30
+ raise ValueError(f"Invalid connector type: {url}")
31
+
32
+
33
+ def get_connector_type(client: BaseConnector) -> ConnectorType:
34
+ if isinstance(client, BaseKVConnector):
35
+ return ConnectorType.KV
36
+ if isinstance(client, BaseFileConnector):
37
+ return ConnectorType.FS
38
+
39
+ raise ValueError(f"Invalid connector type: {client}")
40
+
41
+
42
+ __all__ = [
43
+ "BaseConnector",
44
+ "BaseFileConnector",
45
+ "BaseKVConnector",
46
+ "RedisConnector",
47
+ "S3Connector",
48
+ "ConnectorType",
49
+ "create_remote_connector",
50
+ "get_connector_type",
51
+ ]
@@ -0,0 +1,112 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import os
4
+ import shutil
5
+ import signal
6
+ import tempfile
7
+ from abc import ABC, abstractmethod
8
+ from typing import Generator, List, Optional, Tuple
9
+
10
+ import torch
11
+
12
+
13
+ class BaseConnector(ABC):
14
+ """
15
+ For fs connector such as s3:
16
+ <connector_type>://<path>/<filename>
17
+
18
+ For kv connector such as redis:
19
+ <connector_type>://<host>:<port>/<model_name>/keys/<key>
20
+ <connector_type>://<host>:<port>/<model_name>/files/<filename>
21
+ """
22
+
23
+ def __init__(self, url: str, device: torch.device = "cpu"):
24
+ self.url = url
25
+ self.device = device
26
+ self.closed = False
27
+ self.local_dir = tempfile.mkdtemp()
28
+ for sig in (signal.SIGINT, signal.SIGTERM):
29
+ existing_handler = signal.getsignal(sig)
30
+ signal.signal(sig, self._close_by_signal(existing_handler))
31
+
32
+ def get_local_dir(self):
33
+ return self.local_dir
34
+
35
+ @abstractmethod
36
+ def weight_iterator(
37
+ self, rank: int = 0
38
+ ) -> Generator[Tuple[str, torch.Tensor], None, None]:
39
+ raise NotImplementedError()
40
+
41
+ @abstractmethod
42
+ def pull_files(
43
+ self,
44
+ allow_pattern: Optional[List[str]] = None,
45
+ ignore_pattern: Optional[List[str]] = None,
46
+ ) -> None:
47
+ raise NotImplementedError()
48
+
49
+ def close(self):
50
+ if self.closed:
51
+ return
52
+
53
+ self.closed = True
54
+ if os.path.exists(self.local_dir):
55
+ shutil.rmtree(self.local_dir)
56
+
57
+ def __enter__(self):
58
+ return self
59
+
60
+ def __exit__(self, exc_type, exc_value, traceback):
61
+ self.close()
62
+
63
+ def __del__(self):
64
+ self.close()
65
+
66
+ def _close_by_signal(self, existing_handler=None):
67
+
68
+ def new_handler(signum, frame):
69
+ self.close()
70
+ if existing_handler:
71
+ existing_handler(signum, frame)
72
+
73
+ return new_handler
74
+
75
+
76
+ class BaseKVConnector(BaseConnector):
77
+
78
+ @abstractmethod
79
+ def get(self, key: str) -> Optional[torch.Tensor]:
80
+ raise NotImplementedError()
81
+
82
+ @abstractmethod
83
+ def getstr(self, key: str) -> Optional[str]:
84
+ raise NotImplementedError()
85
+
86
+ @abstractmethod
87
+ def set(self, key: str, obj: torch.Tensor) -> None:
88
+ raise NotImplementedError()
89
+
90
+ @abstractmethod
91
+ def setstr(self, key: str, obj: str) -> None:
92
+ raise NotImplementedError()
93
+
94
+ @abstractmethod
95
+ def list(self, prefix: str) -> List[str]:
96
+ raise NotImplementedError()
97
+
98
+
99
+ class BaseFileConnector(BaseConnector):
100
+ """
101
+ List full file names from remote fs path and filter by allow pattern.
102
+
103
+ Args:
104
+ allow_pattern: A list of patterns of which files to pull.
105
+
106
+ Returns:
107
+ list[str]: List of full paths allowed by the pattern
108
+ """
109
+
110
+ @abstractmethod
111
+ def glob(self, allow_pattern: str) -> List[str]:
112
+ raise NotImplementedError()
@@ -0,0 +1,85 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import logging
4
+ from typing import Generator, List, Optional, Tuple
5
+ from urllib.parse import urlparse
6
+
7
+ import torch
8
+
9
+ from sglang.srt.connector import BaseKVConnector
10
+ from sglang.srt.connector.serde import create_serde
11
+ from sglang.srt.connector.utils import pull_files_from_db
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class RedisConnector(BaseKVConnector):
17
+
18
+ def __init__(self, url: str, device: torch.device = "cpu"):
19
+ import redis
20
+
21
+ super().__init__(url, device)
22
+ parsed_url = urlparse(url)
23
+ self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
24
+ self.model_name = parsed_url.path.lstrip("/")
25
+ # TODO: more serde options
26
+ self.s, self.d = create_serde("safe")
27
+
28
+ def get(self, key: str) -> Optional[torch.Tensor]:
29
+ val = self.connection.get(key)
30
+
31
+ if val is None:
32
+ logger.error("Key %s not found", key)
33
+ return None
34
+
35
+ return self.d.from_bytes(val)
36
+
37
+ def getstr(self, key: str) -> Optional[str]:
38
+ val = self.connection.get(key)
39
+ if val is None:
40
+ logger.error("Key %s not found", key)
41
+ return None
42
+
43
+ return val.decode("utf-8")
44
+
45
+ def set(self, key: str, tensor: torch.Tensor) -> None:
46
+ assert tensor is not None
47
+ self.connection.set(key, self.s.to_bytes(tensor))
48
+
49
+ def setstr(self, key: str, obj: str) -> None:
50
+ self.connection.set(key, obj)
51
+
52
+ def list(self, prefix: str) -> List[str]:
53
+ cursor = 0
54
+ all_keys: List[bytes] = []
55
+
56
+ while True:
57
+ ret: Tuple[int, List[bytes]] = self.connection.scan(
58
+ cursor=cursor, match=f"{prefix}*"
59
+ ) # type: ignore
60
+ cursor, keys = ret
61
+ all_keys.extend(keys)
62
+ if cursor == 0:
63
+ break
64
+
65
+ return [key.decode("utf-8") for key in all_keys]
66
+
67
+ def weight_iterator(
68
+ self, rank: int = 0
69
+ ) -> Generator[Tuple[str, bytes], None, None]:
70
+ keys = self.list(f"{self.model_name}/keys/rank_{rank}/")
71
+ for key in keys:
72
+ val = self.get(key)
73
+ key = key.removeprefix(f"{self.model_name}/keys/rank_{rank}/")
74
+ yield key, val
75
+
76
+ def pull_files(
77
+ self,
78
+ allow_pattern: Optional[List[str]] = None,
79
+ ignore_pattern: Optional[List[str]] = None,
80
+ ) -> None:
81
+ pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)
82
+
83
+ def close(self):
84
+ self.connection.close()
85
+ super().close()
@@ -0,0 +1,122 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import fnmatch
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Generator, Optional, Tuple
7
+
8
+ import torch
9
+
10
+ from sglang.srt.connector import BaseFileConnector
11
+
12
+
13
+ def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
14
+ return [
15
+ path
16
+ for path in paths
17
+ if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
18
+ ]
19
+
20
+
21
+ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
22
+ return [
23
+ path
24
+ for path in paths
25
+ if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
26
+ ]
27
+
28
+
29
+ def list_files(
30
+ s3,
31
+ path: str,
32
+ allow_pattern: Optional[list[str]] = None,
33
+ ignore_pattern: Optional[list[str]] = None,
34
+ ) -> tuple[str, str, list[str]]:
35
+ """
36
+ List files from S3 path and filter by pattern.
37
+
38
+ Args:
39
+ s3: S3 client to use.
40
+ path: The S3 path to list from.
41
+ allow_pattern: A list of patterns of which files to pull.
42
+ ignore_pattern: A list of patterns of which files not to pull.
43
+
44
+ Returns:
45
+ tuple[str, str, list[str]]: A tuple where:
46
+ - The first element is the bucket name
47
+ - The second element is the key prefix inside the bucket,
48
+ as a directory-like string
49
+ - The third element is a list of files allowed or
50
+ disallowed by pattern
51
+ """
52
+ parts = path.removeprefix("s3://").split("/")
53
+ prefix = "/".join(parts[1:])
54
+ bucket_name = parts[0]
55
+
56
+ objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
57
+ paths = [obj["Key"] for obj in objects.get("Contents", [])]
58
+
59
+ paths = _filter_ignore(paths, ["*/"])
60
+ if allow_pattern is not None:
61
+ paths = _filter_allow(paths, allow_pattern)
62
+
63
+ if ignore_pattern is not None:
64
+ paths = _filter_ignore(paths, ignore_pattern)
65
+
66
+ return bucket_name, prefix, paths
67
+
68
+
69
+ class S3Connector(BaseFileConnector):
70
+
71
+ def __init__(self, url: str) -> None:
72
+ import boto3
73
+
74
+ super().__init__(url)
75
+ self.client = boto3.client("s3")
76
+
77
+ def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
78
+ bucket_name, _, paths = list_files(
79
+ self.client, path=self.url, allow_pattern=allow_pattern
80
+ )
81
+ return [f"s3://{bucket_name}/{path}" for path in paths]
82
+
83
+ def pull_files(
84
+ self,
85
+ allow_pattern: Optional[list[str]] = None,
86
+ ignore_pattern: Optional[list[str]] = None,
87
+ ) -> None:
88
+ """
89
+ Pull files from S3 storage into the temporary directory.
90
+
91
+ Args:
92
+ s3_model_path: The S3 path of the model.
93
+ allow_pattern: A list of patterns of which files to pull.
94
+ ignore_pattern: A list of patterns of which files not to pull.
95
+
96
+ """
97
+ bucket_name, base_dir, files = list_files(
98
+ self.client, self.url, allow_pattern, ignore_pattern
99
+ )
100
+ if len(files) == 0:
101
+ return
102
+
103
+ for file in files:
104
+ destination_file = os.path.join(self.local_dir, file.removeprefix(base_dir))
105
+ local_dir = Path(destination_file).parent
106
+ os.makedirs(local_dir, exist_ok=True)
107
+ self.client.download_file(bucket_name, file, destination_file)
108
+
109
+ def weight_iterator(
110
+ self, rank: int = 0
111
+ ) -> Generator[Tuple[str, torch.Tensor], None, None]:
112
+ from sglang.srt.model_loader.weight_utils import (
113
+ runai_safetensors_weights_iterator,
114
+ )
115
+
116
+ # only support safetensor files now
117
+ hf_weights_files = self.glob(allow_pattern=["*.safetensors"])
118
+ return runai_safetensors_weights_iterator(hf_weights_files)
119
+
120
+ def close(self):
121
+ self.client.close()
122
+ super().close()
@@ -0,0 +1,31 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # inspired by LMCache
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer
9
+ from sglang.srt.connector.serde.serde import Deserializer, Serializer
10
+
11
+
12
+ def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
13
+ s: Optional[Serializer] = None
14
+ d: Optional[Deserializer] = None
15
+
16
+ if serde_type == "safe":
17
+ s = SafeSerializer()
18
+ d = SafeDeserializer(torch.uint8)
19
+ else:
20
+ raise ValueError(f"Unknown serde type: {serde_type}")
21
+
22
+ return s, d
23
+
24
+
25
+ __all__ = [
26
+ "Serializer",
27
+ "Deserializer",
28
+ "SafeSerializer",
29
+ "SafeDeserializer",
30
+ "create_serde",
31
+ ]
@@ -0,0 +1,29 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import Union
4
+
5
+ import torch
6
+ from safetensors.torch import load, save
7
+
8
+ from sglang.srt.connector.serde.serde import Deserializer, Serializer
9
+
10
+
11
+ class SafeSerializer(Serializer):
12
+
13
+ def __init__(self):
14
+ super().__init__()
15
+
16
+ def to_bytes(self, t: torch.Tensor) -> bytes:
17
+ return save({"tensor_bytes": t.cpu().contiguous()})
18
+
19
+
20
+ class SafeDeserializer(Deserializer):
21
+
22
+ def __init__(self, dtype):
23
+ super().__init__(dtype)
24
+
25
+ def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
26
+ return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)
27
+
28
+ def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
29
+ return self.from_bytes_normal(b)
@@ -0,0 +1,43 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import abc
4
+ from abc import ABC, abstractmethod
5
+
6
+ import torch
7
+
8
+
9
+ class Serializer(ABC):
10
+
11
+ @abstractmethod
12
+ def to_bytes(self, t: torch.Tensor) -> bytes:
13
+ """
14
+ Serialize a pytorch tensor to bytes. The serialized bytes should contain
15
+ both the data and the metadata (shape, dtype, etc.) of the tensor.
16
+
17
+ Input:
18
+ t: the input pytorch tensor, can be on any device, in any shape,
19
+ with any dtype
20
+
21
+ Returns:
22
+ bytes: the serialized bytes
23
+ """
24
+ raise NotImplementedError
25
+
26
+
27
+ class Deserializer(metaclass=abc.ABCMeta):
28
+
29
+ def __init__(self, dtype):
30
+ self.dtype = dtype
31
+
32
+ @abstractmethod
33
+ def from_bytes(self, bs: bytes) -> torch.Tensor:
34
+ """
35
+ Deserialize a pytorch tensor from bytes.
36
+
37
+ Input:
38
+ bs: a stream of bytes
39
+
40
+ Output:
41
+ torch.Tensor: the deserialized pytorch tensor
42
+ """
43
+ raise NotImplementedError
@@ -0,0 +1,35 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Optional
6
+ from urllib.parse import urlparse
7
+
8
+ from sglang.srt.connector import BaseConnector
9
+
10
+
11
+ def parse_model_name(url: str) -> str:
12
+ """
13
+ Parse the model name from the url.
14
+ Only used for db connector
15
+ """
16
+ parsed_url = urlparse(url)
17
+ return parsed_url.path.lstrip("/")
18
+
19
+
20
+ def pull_files_from_db(
21
+ connector: BaseConnector,
22
+ model_name: str,
23
+ allow_pattern: Optional[list[str]] = None,
24
+ ignore_pattern: Optional[list[str]] = None,
25
+ ) -> None:
26
+ prefix = f"{model_name}/files/"
27
+ local_dir = connector.get_local_dir()
28
+ files = connector.list(prefix)
29
+
30
+ for file in files:
31
+ destination_file = os.path.join(local_dir, file.removeprefix(prefix))
32
+ local_dir = Path(destination_file).parent
33
+ os.makedirs(local_dir, exist_ok=True)
34
+ with open(destination_file, "wb") as f:
35
+ f.write(connector.getstr(file).encode("utf-8"))