sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176) hide show
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,11 @@ import torch
22
22
  from transformers import PretrainedConfig
23
23
 
24
24
  from sglang.srt.hf_transformers_utils import get_config, get_context_length
25
- from sglang.srt.layers.quantization import QUANTIZATION_METHODS
25
+ from sglang.srt.layers.quantization import (
26
+ BASE_QUANTIZATION_METHODS,
27
+ QUANTIZATION_METHODS,
28
+ VLLM_AVAILABLE,
29
+ )
26
30
  from sglang.srt.utils import get_bool_env_var, is_hip
27
31
 
28
32
  logger = logging.getLogger(__name__)
@@ -51,13 +55,14 @@ class ModelConfig:
51
55
  self.quantization = quantization
52
56
 
53
57
  # Parse args
58
+ self.maybe_pull_model_tokenizer_from_remote()
54
59
  self.model_override_args = json.loads(model_override_args)
55
60
  kwargs = {}
56
61
  if override_config_file and override_config_file.strip():
57
62
  kwargs["_configuration_file"] = override_config_file.strip()
58
63
 
59
64
  self.hf_config = get_config(
60
- model_path,
65
+ self.model_path,
61
66
  trust_remote_code=trust_remote_code,
62
67
  revision=revision,
63
68
  model_override_args=self.model_override_args,
@@ -134,6 +139,11 @@ class ModelConfig:
134
139
  self.attention_arch = AttentionArch.MLA
135
140
  self.kv_lora_rank = self.hf_config.kv_lora_rank
136
141
  self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
142
+ elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures:
143
+ self.head_dim = 256
144
+ self.attention_arch = AttentionArch.MLA
145
+ self.kv_lora_rank = self.hf_text_config.kv_lora_rank
146
+ self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
137
147
  else:
138
148
  self.attention_arch = AttentionArch.MHA
139
149
 
@@ -229,7 +239,12 @@ class ModelConfig:
229
239
 
230
240
  # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
231
241
  def _verify_quantization(self) -> None:
232
- supported_quantization = [*QUANTIZATION_METHODS]
242
+ # Select supported quantization methods based on vllm availability
243
+ if VLLM_AVAILABLE:
244
+ supported_quantization = [*QUANTIZATION_METHODS]
245
+ else:
246
+ supported_quantization = [*BASE_QUANTIZATION_METHODS]
247
+
233
248
  rocm_supported_quantization = [
234
249
  "awq",
235
250
  "gptq",
@@ -267,7 +282,11 @@ class ModelConfig:
267
282
  quant_method = quant_cfg.get("quant_method", "").lower()
268
283
 
269
284
  # Detect which checkpoint is it
270
- for _, method in QUANTIZATION_METHODS.items():
285
+ # Only iterate through currently available quantization methods
286
+ available_methods = (
287
+ QUANTIZATION_METHODS if VLLM_AVAILABLE else BASE_QUANTIZATION_METHODS
288
+ )
289
+ for _, method in available_methods.items():
271
290
  quantization_override = method.override_quantization_method(
272
291
  quant_cfg, self.quantization
273
292
  )
@@ -318,6 +337,29 @@ class ModelConfig:
318
337
  eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
319
338
  return eos_ids
320
339
 
340
+ def maybe_pull_model_tokenizer_from_remote(self) -> None:
341
+ """
342
+ Pull the model config files to a temporary
343
+ directory in case of remote.
344
+
345
+ Args:
346
+ model: The model name or path.
347
+
348
+ """
349
+ from sglang.srt.connector import create_remote_connector
350
+ from sglang.srt.utils import is_remote_url
351
+
352
+ if is_remote_url(self.model_path):
353
+ logger.info("Pulling model configs from remote...")
354
+ # BaseConnector implements __del__() to clean up the local dir.
355
+ # Since config files need to exist all the time, so we DO NOT use
356
+ # with statement to avoid closing the client.
357
+ client = create_remote_connector(self.model_path)
358
+ if is_remote_url(self.model_path):
359
+ client.pull_files(allow_pattern=["*config.json"])
360
+ self.model_weights = self.model_path
361
+ self.model_path = client.get_local_dir()
362
+
321
363
 
322
364
  def get_hf_text_config(config: PretrainedConfig):
323
365
  """Get the "sub" config relevant to llm for multi modal models.
@@ -338,6 +380,8 @@ def get_hf_text_config(config: PretrainedConfig):
338
380
  # if transformers config doesn't align with this assumption.
339
381
  assert hasattr(config.text_config, "num_attention_heads")
340
382
  return config.text_config
383
+ if hasattr(config, "language_config"):
384
+ return config.language_config
341
385
  else:
342
386
  return config
343
387
 
@@ -367,9 +411,13 @@ def _get_and_verify_dtype(
367
411
  dtype = dtype.lower()
368
412
  if dtype == "auto":
369
413
  if config_dtype == torch.float32:
370
- if config.model_type == "gemma2":
414
+ if config.model_type.startswith("gemma"):
415
+ if config.model_type == "gemma":
416
+ gemma_version = ""
417
+ else:
418
+ gemma_version = config.model_type[5]
371
419
  logger.info(
372
- "For Gemma 2, we downcast float32 to bfloat16 instead "
420
+ f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
373
421
  "of float16 by default. Please specify `dtype` if you "
374
422
  "want to use float16."
375
423
  )
@@ -418,6 +466,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
418
466
  or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
419
467
  or "InternLM2ForRewardModel" in model_architectures
420
468
  or "Qwen2ForRewardModel" in model_architectures
469
+ or "Qwen2ForSequenceClassification" in model_architectures
421
470
  ):
422
471
  return False
423
472
  else:
@@ -425,17 +474,20 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
425
474
 
426
475
 
427
476
  multimodal_model_archs = [
477
+ "DeepseekVL2ForCausalLM",
478
+ "Gemma3ForConditionalGeneration",
479
+ "Grok1VForCausalLM",
480
+ "Grok1AForCausalLM",
428
481
  "LlavaLlamaForCausalLM",
429
- "LlavaQwenForCausalLM",
430
482
  "LlavaMistralForCausalLM",
483
+ "LlavaQwenForCausalLM",
431
484
  "LlavaVidForCausalLM",
432
- "Grok1VForCausalLM",
433
- "Grok1AForCausalLM",
485
+ "MiniCPMO",
486
+ "MiniCPMV",
487
+ "MultiModalityCausalLM",
434
488
  "MllamaForConditionalGeneration",
435
489
  "Qwen2VLForConditionalGeneration",
436
490
  "Qwen2_5_VLForConditionalGeneration",
437
- "MiniCPMV",
438
- "MultiModalityCausalLM",
439
491
  ]
440
492
 
441
493
 
@@ -0,0 +1,25 @@
1
+ from typing import Type
2
+
3
+ from transformers import (
4
+ AutoImageProcessor,
5
+ AutoProcessor,
6
+ BaseImageProcessor,
7
+ PretrainedConfig,
8
+ ProcessorMixin,
9
+ )
10
+
11
+
12
+ def register_image_processor(
13
+ config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor]
14
+ ):
15
+ """
16
+ register customized hf image processor while removing hf impl
17
+ """
18
+ AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)
19
+
20
+
21
+ def register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]):
22
+ """
23
+ register customized hf processor while removing hf impl
24
+ """
25
+ AutoProcessor.register(config, processor, exist_ok=True)
@@ -0,0 +1,51 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import enum
4
+ import logging
5
+
6
+ from sglang.srt.connector.base_connector import (
7
+ BaseConnector,
8
+ BaseFileConnector,
9
+ BaseKVConnector,
10
+ )
11
+ from sglang.srt.connector.redis import RedisConnector
12
+ from sglang.srt.connector.s3 import S3Connector
13
+ from sglang.srt.utils import parse_connector_type
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class ConnectorType(str, enum.Enum):
19
+ FS = "filesystem"
20
+ KV = "KV"
21
+
22
+
23
+ def create_remote_connector(url, device="cpu") -> BaseConnector:
24
+ connector_type = parse_connector_type(url)
25
+ if connector_type == "redis":
26
+ return RedisConnector(url)
27
+ elif connector_type == "s3":
28
+ return S3Connector(url)
29
+ else:
30
+ raise ValueError(f"Invalid connector type: {url}")
31
+
32
+
33
+ def get_connector_type(client: BaseConnector) -> ConnectorType:
34
+ if isinstance(client, BaseKVConnector):
35
+ return ConnectorType.KV
36
+ if isinstance(client, BaseFileConnector):
37
+ return ConnectorType.FS
38
+
39
+ raise ValueError(f"Invalid connector type: {client}")
40
+
41
+
42
+ __all__ = [
43
+ "BaseConnector",
44
+ "BaseFileConnector",
45
+ "BaseKVConnector",
46
+ "RedisConnector",
47
+ "S3Connector",
48
+ "ConnectorType",
49
+ "create_remote_connector",
50
+ "get_connector_type",
51
+ ]
@@ -0,0 +1,112 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import os
4
+ import shutil
5
+ import signal
6
+ import tempfile
7
+ from abc import ABC, abstractmethod
8
+ from typing import Generator, List, Optional, Tuple
9
+
10
+ import torch
11
+
12
+
13
+ class BaseConnector(ABC):
14
+ """
15
+ For fs connector such as s3:
16
+ <connector_type>://<path>/<filename>
17
+
18
+ For kv connector such as redis:
19
+ <connector_type>://<host>:<port>/<model_name>/keys/<key>
20
+ <connector_type://<host>:<port>/<model_name>/files/<filename>
21
+ """
22
+
23
+ def __init__(self, url: str, device: torch.device = "cpu"):
24
+ self.url = url
25
+ self.device = device
26
+ self.closed = False
27
+ self.local_dir = tempfile.mkdtemp()
28
+ for sig in (signal.SIGINT, signal.SIGTERM):
29
+ existing_handler = signal.getsignal(sig)
30
+ signal.signal(sig, self._close_by_signal(existing_handler))
31
+
32
+ def get_local_dir(self):
33
+ return self.local_dir
34
+
35
+ @abstractmethod
36
+ def weight_iterator(
37
+ self, rank: int = 0
38
+ ) -> Generator[Tuple[str, torch.Tensor], None, None]:
39
+ raise NotImplementedError()
40
+
41
+ @abstractmethod
42
+ def pull_files(
43
+ self,
44
+ allow_pattern: Optional[List[str]] = None,
45
+ ignore_pattern: Optional[List[str]] = None,
46
+ ) -> None:
47
+ raise NotImplementedError()
48
+
49
+ def close(self):
50
+ if self.closed:
51
+ return
52
+
53
+ self.closed = True
54
+ if os.path.exists(self.local_dir):
55
+ shutil.rmtree(self.local_dir)
56
+
57
+ def __enter__(self):
58
+ return self
59
+
60
+ def __exit__(self, exc_type, exc_value, traceback):
61
+ self.close()
62
+
63
+ def __del__(self):
64
+ self.close()
65
+
66
+ def _close_by_signal(self, existing_handler=None):
67
+
68
+ def new_handler(signum, frame):
69
+ self.close()
70
+ if existing_handler:
71
+ existing_handler(signum, frame)
72
+
73
+ return new_handler
74
+
75
+
76
+ class BaseKVConnector(BaseConnector):
77
+
78
+ @abstractmethod
79
+ def get(self, key: str) -> Optional[torch.Tensor]:
80
+ raise NotImplementedError()
81
+
82
+ @abstractmethod
83
+ def getstr(self, key: str) -> Optional[str]:
84
+ raise NotImplementedError()
85
+
86
+ @abstractmethod
87
+ def set(self, key: str, obj: torch.Tensor) -> None:
88
+ raise NotImplementedError()
89
+
90
+ @abstractmethod
91
+ def setstr(self, key: str, obj: str) -> None:
92
+ raise NotImplementedError()
93
+
94
+ @abstractmethod
95
+ def list(self, prefix: str) -> List[str]:
96
+ raise NotImplementedError()
97
+
98
+
99
+ class BaseFileConnector(BaseConnector):
100
+ """
101
+ List full file names from remote fs path and filter by allow pattern.
102
+
103
+ Args:
104
+ allow_pattern: A list of patterns of which files to pull.
105
+
106
+ Returns:
107
+ list[str]: List of full paths allowed by the pattern
108
+ """
109
+
110
+ @abstractmethod
111
+ def glob(self, allow_pattern: str) -> List[str]:
112
+ raise NotImplementedError()
@@ -0,0 +1,85 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import logging
4
+ from typing import Generator, List, Optional, Tuple
5
+ from urllib.parse import urlparse
6
+
7
+ import torch
8
+
9
+ from sglang.srt.connector import BaseKVConnector
10
+ from sglang.srt.connector.serde import create_serde
11
+ from sglang.srt.connector.utils import pull_files_from_db
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
+ class RedisConnector(BaseKVConnector):
17
+
18
+ def __init__(self, url: str, device: torch.device = "cpu"):
19
+ import redis
20
+
21
+ super().__init__(url, device)
22
+ parsed_url = urlparse(url)
23
+ self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
24
+ self.model_name = parsed_url.path.lstrip("/")
25
+ # TODO: more serde options
26
+ self.s, self.d = create_serde("safe")
27
+
28
+ def get(self, key: str) -> Optional[torch.Tensor]:
29
+ val = self.connection.get(key)
30
+
31
+ if val is None:
32
+ logger.error("Key %s not found", key)
33
+ return None
34
+
35
+ return self.d.from_bytes(val)
36
+
37
+ def getstr(self, key: str) -> Optional[str]:
38
+ val = self.connection.get(key)
39
+ if val is None:
40
+ logger.error("Key %s not found", key)
41
+ return None
42
+
43
+ return val.decode("utf-8")
44
+
45
+ def set(self, key: str, tensor: torch.Tensor) -> None:
46
+ assert tensor is not None
47
+ self.connection.set(key, self.s.to_bytes(tensor))
48
+
49
+ def setstr(self, key: str, obj: str) -> None:
50
+ self.connection.set(key, obj)
51
+
52
+ def list(self, prefix: str) -> List[str]:
53
+ cursor = 0
54
+ all_keys: List[bytes] = []
55
+
56
+ while True:
57
+ ret: Tuple[int, List[bytes]] = self.connection.scan(
58
+ cursor=cursor, match=f"{prefix}*"
59
+ ) # type: ignore
60
+ cursor, keys = ret
61
+ all_keys.extend(keys)
62
+ if cursor == 0:
63
+ break
64
+
65
+ return [key.decode("utf-8") for key in all_keys]
66
+
67
+ def weight_iterator(
68
+ self, rank: int = 0
69
+ ) -> Generator[Tuple[str, bytes], None, None]:
70
+ keys = self.list(f"{self.model_name}/keys/rank_{rank}/")
71
+ for key in keys:
72
+ val = self.get(key)
73
+ key = key.removeprefix(f"{self.model_name}/keys/rank_{rank}/")
74
+ yield key, val
75
+
76
+ def pull_files(
77
+ self,
78
+ allow_pattern: Optional[List[str]] = None,
79
+ ignore_pattern: Optional[List[str]] = None,
80
+ ) -> None:
81
+ pull_files_from_db(self, self.model_name, allow_pattern, ignore_pattern)
82
+
83
+ def close(self):
84
+ self.connection.close()
85
+ super().close()
@@ -0,0 +1,122 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import fnmatch
4
+ import os
5
+ from pathlib import Path
6
+ from typing import Generator, Optional, Tuple
7
+
8
+ import torch
9
+
10
+ from sglang.srt.connector import BaseFileConnector
11
+
12
+
13
+ def _filter_allow(paths: list[str], patterns: list[str]) -> list[str]:
14
+ return [
15
+ path
16
+ for path in paths
17
+ if any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
18
+ ]
19
+
20
+
21
+ def _filter_ignore(paths: list[str], patterns: list[str]) -> list[str]:
22
+ return [
23
+ path
24
+ for path in paths
25
+ if not any(fnmatch.fnmatch(path, pattern) for pattern in patterns)
26
+ ]
27
+
28
+
29
+ def list_files(
30
+ s3,
31
+ path: str,
32
+ allow_pattern: Optional[list[str]] = None,
33
+ ignore_pattern: Optional[list[str]] = None,
34
+ ) -> tuple[str, str, list[str]]:
35
+ """
36
+ List files from S3 path and filter by pattern.
37
+
38
+ Args:
39
+ s3: S3 client to use.
40
+ path: The S3 path to list from.
41
+ allow_pattern: A list of patterns of which files to pull.
42
+ ignore_pattern: A list of patterns of which files not to pull.
43
+
44
+ Returns:
45
+ tuple[str, str, list[str]]: A tuple where:
46
+ - The first element is the bucket name
47
+ - The second element is string represent the bucket
48
+ and the prefix as a dir like string
49
+ - The third element is a list of files allowed or
50
+ disallowed by pattern
51
+ """
52
+ parts = path.removeprefix("s3://").split("/")
53
+ prefix = "/".join(parts[1:])
54
+ bucket_name = parts[0]
55
+
56
+ objects = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
57
+ paths = [obj["Key"] for obj in objects.get("Contents", [])]
58
+
59
+ paths = _filter_ignore(paths, ["*/"])
60
+ if allow_pattern is not None:
61
+ paths = _filter_allow(paths, allow_pattern)
62
+
63
+ if ignore_pattern is not None:
64
+ paths = _filter_ignore(paths, ignore_pattern)
65
+
66
+ return bucket_name, prefix, paths
67
+
68
+
69
+ class S3Connector(BaseFileConnector):
70
+
71
+ def __init__(self, url: str) -> None:
72
+ import boto3
73
+
74
+ super().__init__(url)
75
+ self.client = boto3.client("s3")
76
+
77
+ def glob(self, allow_pattern: Optional[list[str]] = None) -> list[str]:
78
+ bucket_name, _, paths = list_files(
79
+ self.client, path=self.url, allow_pattern=allow_pattern
80
+ )
81
+ return [f"s3://{bucket_name}/{path}" for path in paths]
82
+
83
+ def pull_files(
84
+ self,
85
+ allow_pattern: Optional[list[str]] = None,
86
+ ignore_pattern: Optional[list[str]] = None,
87
+ ) -> None:
88
+ """
89
+ Pull files from S3 storage into the temporary directory.
90
+
91
+ Args:
92
+ s3_model_path: The S3 path of the model.
93
+ allow_pattern: A list of patterns of which files to pull.
94
+ ignore_pattern: A list of patterns of which files not to pull.
95
+
96
+ """
97
+ bucket_name, base_dir, files = list_files(
98
+ self.client, self.url, allow_pattern, ignore_pattern
99
+ )
100
+ if len(files) == 0:
101
+ return
102
+
103
+ for file in files:
104
+ destination_file = os.path.join(self.local_dir, file.removeprefix(base_dir))
105
+ local_dir = Path(destination_file).parent
106
+ os.makedirs(local_dir, exist_ok=True)
107
+ self.client.download_file(bucket_name, file, destination_file)
108
+
109
+ def weight_iterator(
110
+ self, rank: int = 0
111
+ ) -> Generator[Tuple[str, torch.Tensor], None, None]:
112
+ from sglang.srt.model_loader.weight_utils import (
113
+ runai_safetensors_weights_iterator,
114
+ )
115
+
116
+ # only support safetensor files now
117
+ hf_weights_files = self.glob(allow_pattern=["*.safetensors"])
118
+ return runai_safetensors_weights_iterator(hf_weights_files)
119
+
120
+ def close(self):
121
+ self.client.close()
122
+ super().close()
@@ -0,0 +1,31 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ # inspired by LMCache
4
+ from typing import Optional, Tuple
5
+
6
+ import torch
7
+
8
+ from sglang.srt.connector.serde.safe_serde import SafeDeserializer, SafeSerializer
9
+ from sglang.srt.connector.serde.serde import Deserializer, Serializer
10
+
11
+
12
+ def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
13
+ s: Optional[Serializer] = None
14
+ d: Optional[Deserializer] = None
15
+
16
+ if serde_type == "safe":
17
+ s = SafeSerializer()
18
+ d = SafeDeserializer(torch.uint8)
19
+ else:
20
+ raise ValueError(f"Unknown serde type: {serde_type}")
21
+
22
+ return s, d
23
+
24
+
25
+ __all__ = [
26
+ "Serializer",
27
+ "Deserializer",
28
+ "SafeSerializer",
29
+ "SafeDeserializer",
30
+ "create_serde",
31
+ ]
@@ -0,0 +1,29 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ from typing import Union
4
+
5
+ import torch
6
+ from safetensors.torch import load, save
7
+
8
+ from sglang.srt.connector.serde.serde import Deserializer, Serializer
9
+
10
+
11
+ class SafeSerializer(Serializer):
12
+
13
+ def __init__(self):
14
+ super().__init__()
15
+
16
+ def to_bytes(self, t: torch.Tensor) -> bytes:
17
+ return save({"tensor_bytes": t.cpu().contiguous()})
18
+
19
+
20
+ class SafeDeserializer(Deserializer):
21
+
22
+ def __init__(self, dtype):
23
+ super().__init__(dtype)
24
+
25
+ def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
26
+ return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)
27
+
28
+ def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
29
+ return self.from_bytes_normal(b)
@@ -0,0 +1,43 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import abc
4
+ from abc import ABC, abstractmethod
5
+
6
+ import torch
7
+
8
+
9
+ class Serializer(ABC):
10
+
11
+ @abstractmethod
12
+ def to_bytes(self, t: torch.Tensor) -> bytes:
13
+ """
14
+ Serialize a pytorch tensor to bytes. The serialized bytes should contain
15
+ both the data and the metadata (shape, dtype, etc.) of the tensor.
16
+
17
+ Input:
18
+ t: the input pytorch tensor, can be on any device, in any shape,
19
+ with any dtype
20
+
21
+ Returns:
22
+ bytes: the serialized bytes
23
+ """
24
+ raise NotImplementedError
25
+
26
+
27
+ class Deserializer(metaclass=abc.ABCMeta):
28
+
29
+ def __init__(self, dtype):
30
+ self.dtype = dtype
31
+
32
+ @abstractmethod
33
+ def from_bytes(self, bs: bytes) -> torch.Tensor:
34
+ """
35
+ Deserialize a pytorch tensor from bytes.
36
+
37
+ Input:
38
+ bytes: a stream of bytes
39
+
40
+ Output:
41
+ torch.Tensor: the deserialized pytorch tensor
42
+ """
43
+ raise NotImplementedError
@@ -0,0 +1,35 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Optional
6
+ from urllib.parse import urlparse
7
+
8
+ from sglang.srt.connector import BaseConnector
9
+
10
+
11
+ def parse_model_name(url: str) -> str:
12
+ """
13
+ Parse the model name from the url.
14
+ Only used for db connector
15
+ """
16
+ parsed_url = urlparse(url)
17
+ return parsed_url.path.lstrip("/")
18
+
19
+
20
+ def pull_files_from_db(
21
+ connector: BaseConnector,
22
+ model_name: str,
23
+ allow_pattern: Optional[list[str]] = None,
24
+ ignore_pattern: Optional[list[str]] = None,
25
+ ) -> None:
26
+ prefix = f"{model_name}/files/"
27
+ local_dir = connector.get_local_dir()
28
+ files = connector.list(prefix)
29
+
30
+ for file in files:
31
+ destination_file = os.path.join(local_dir, file.removeprefix(prefix))
32
+ local_dir = Path(destination_file).parent
33
+ os.makedirs(local_dir, exist_ok=True)
34
+ with open(destination_file, "wb") as f:
35
+ f.write(connector.getstr(file).encode("utf-8"))