sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -66,7 +66,7 @@ from sglang.version import __version__
 
 __all__ += ["__version__"]
 
-# SGL Backends
+# SGLang Backends
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.utils import LazyImport
 
sglang/bench_one_batch.py CHANGED
@@ -111,8 +111,12 @@ def load_model(server_args, port_args, tp_rank):
     model_config = ModelConfig(
         server_args.model_path,
         trust_remote_code=server_args.trust_remote_code,
+        revision=server_args.revision,
         context_length=server_args.context_length,
         model_override_args=server_args.json_model_override_args,
+        is_embedding=server_args.is_embedding,
+        dtype=server_args.dtype,
+        quantization=server_args.quantization,
     )
     model_runner = ModelRunner(
         model_config=model_config,
sglang/bench_serving.py CHANGED
@@ -51,6 +51,7 @@ class RequestFuncInput:
     prompt_len: int
     output_len: int
     model: str
+    lora_name: str
    extra_request_body: Dict[str, Any]
 
 
@@ -319,6 +320,7 @@ async def async_request_sglang_generate(
                 "ignore_eos": not args.disable_ignore_eos,
             },
             "stream": not args.disable_stream,
+            "lora_path": request_func_input.lora_name,
             **request_func_input.extra_request_body,
         }
         headers = {}
@@ -884,6 +886,7 @@ async def benchmark(
     request_rate: float,
     max_concurrency: Optional[int],
     disable_tqdm: bool,
+    lora_name: str,
     extra_request_body: Dict[str, Any],
     profile: bool,
 ):
@@ -909,6 +912,7 @@ async def benchmark(
         api_url=api_url,
         prompt_len=test_prompt_len,
         output_len=test_output_len,
+        lora_name=lora_name,
         extra_request_body=extra_request_body,
     )
     test_output = await request_func(request_func_input=test_input)
@@ -942,6 +946,7 @@ async def benchmark(
             api_url=api_url,
             prompt_len=prompt_len,
             output_len=output_len,
+            lora_name=lora_name,
            extra_request_body=extra_request_body,
         )
         tasks.append(
@@ -1247,6 +1252,7 @@ def run_benchmark(args_: argparse.Namespace):
                request_rate=args.request_rate,
                max_concurrency=args.max_concurrency,
                disable_tqdm=args.disable_tqdm,
+               lora_name=args.lora_name,
                extra_request_body=extra_request_body,
                profile=args.profile,
            )
@@ -1267,6 +1273,7 @@ def run_benchmark(args_: argparse.Namespace):
                request_rate=rate,
                max_concurrency=args.max_concurrency,
                disable_tqdm=args.disable_tqdm,
+               lora_name=args.lora_name,
                extra_request_body=extra_request_body,
                profile=args.profile,
            )
@@ -1451,5 +1458,11 @@ if __name__ == "__main__":
         help="Use Torch Profiler. The endpoint must be launched with "
         "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
     )
+    parser.add_argument(
+        "--lora-name",
+        type=str,
+        default=None,
+        help="The name of LoRA adapter",
+    )
     args = parser.parse_args()
     run_benchmark(args)
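
The new --lora-name flag threads a single adapter name from the CLI through RequestFuncInput into the request body as "lora_path". A minimal sketch of the resulting payload shape; the build_generate_payload helper and the prompt/adapter values are illustrative, not part of bench_serving.py:

import json
from typing import Optional

def build_generate_payload(prompt: str, max_new_tokens: int, lora_name: Optional[str]) -> dict:
    # Mirrors the /generate payload built in async_request_sglang_generate above.
    return {
        "text": prompt,
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": max_new_tokens,
            "ignore_eos": True,
        },
        "stream": True,
        "lora_path": lora_name,  # new in 0.4.0: selects the LoRA adapter on the server
    }

print(json.dumps(build_generate_payload("Hello", 32, "my-adapter"), indent=2))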
sglang/check_env.py CHANGED
@@ -9,7 +9,7 @@ from collections import OrderedDict, defaultdict
 
 import torch
 
-# List of packages to check versions for
+# List of packages to check versions
 PACKAGE_LIST = [
     "sglang",
     "flashinfer",
sglang/srt/_custom_ops.py ADDED
@@ -0,0 +1,118 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py
+import contextlib
+import functools
+import importlib
+import logging
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+import torch
+import torch.library
+
+from sglang.srt.utils import is_hpu
+
+logger = logging.getLogger(__name__)
+
+if not is_hpu():
+    try:
+        import custom_ar
+    except ImportError as e:
+        logger.warning("Failed to import from custom_ar with %r", e)
+
+
+def hint_on_error(fn):
+
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        try:
+            return fn(*args, **kwargs)
+
+        except NotImplementedError as e:
+            msg = (
+                "Error in calling custom op %s: %s\n"
+                "Not implemented or built, mostly likely because the current current device "
+                "does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set "
+                "incorrectly while building)"
+            )
+            logger.error(msg, fn.__name__, e)
+            raise NotImplementedError(msg % (fn.__name__, e)) from e
+        except AttributeError as e:
+            msg = (
+                "Error in calling custom op %s: %s\n"
+                "Possibly you have built or installed an obsolete version of vllm.\n"
+                "Please try a clean build and install of vllm,"
+                "or remove old built files such as vllm/*cpython*.so and build/ ."
+            )
+            logger.error(msg, fn.__name__, e)
+            raise e
+
+    return wrapper
+
+
+# custom ar
+def init_custom_ar(
+    ipc_tensors: List[torch.Tensor],
+    rank_data: torch.Tensor,
+    rank: int,
+    full_nvlink: bool,
+) -> int:
+    return torch.ops._C_vllm_ar.init_custom_ar(
+        ipc_tensors, rank_data, rank, full_nvlink
+    )
+
+
+def all_reduce(
+    fa: int,
+    inp: torch.Tensor,
+    out: torch.Tensor,
+    reg_buffer: int,
+    reg_buffer_sz_bytes: int,
+) -> None:
+    torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+
+
+def dispose(fa: int) -> None:
+    torch.ops._C_vllm_ar.dispose(fa)
+
+
+def meta_size() -> int:
+    return torch.ops._C_vllm_ar.meta_size()
+
+
+def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
+    return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors)
+
+
+def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+    return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa)
+
+
+def register_graph_buffers(
+    fa: int, handles: List[List[int]], offsets: List[List[int]]
+) -> None:
+    torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets)
+
+
+# temporary fix for https://github.com/vllm-project/vllm/issues/5456
+# TODO: remove this in v0.6.0
+names_and_values = globals()
+names_and_values_to_update = {}
+# prepare variables to avoid dict size change during iteration
+k, v, arg = None, None, None
+fn_type = type(lambda x: x)
+for k, v in names_and_values.items():
+    # find functions that are defined in this file and have torch.Tensor
+    # in their annotations. `arg == "torch.Tensor"` is used to handle
+    # the case when users use `import __annotations__` to turn type
+    # hints into strings.
+    if (
+        isinstance(v, fn_type)
+        and v.__code__.co_filename == __file__
+        and any(
+            arg is torch.Tensor or arg == "torch.Tensor"
+            for arg in v.__annotations__.values()
+        )
+    ):
+        names_and_values_to_update[k] = hint_on_error(v)
+
+names_and_values.update(names_and_values_to_update)
+del names_and_values_to_update, names_and_values, v, k, fn_type
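
The hint_on_error wrapper above is applied at import time to every function in the module whose annotations mention torch.Tensor, so a missing or mismatched custom-op build surfaces as a descriptive error rather than a bare NotImplementedError. A self-contained sketch of the same decorator pattern; fake_op and the message text are illustrative, not part of the package:

import functools

def hint_on_error(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except NotImplementedError as e:
            # Re-raise with a hint about the likely cause (unsupported device or bad build).
            raise NotImplementedError(
                f"Error in calling custom op {fn.__name__}: {e}. "
                "The kernel may not be built for this device."
            ) from e
    return wrapper

@hint_on_error
def fake_op(x: "torch.Tensor") -> "torch.Tensor":  # stand-in for a real custom op
    raise NotImplementedError("kernel not compiled")

try:
    fake_op(None)
except NotImplementedError as e:
    print(e)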
sglang/srt/configs/device_config.py ADDED
@@ -0,0 +1,17 @@
+import logging
+from typing import Optional
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class DeviceConfig:
+    device: Optional[torch.device]
+
+    def __init__(self, device: str = "cuda") -> None:
+        if device in ["cuda", "xpu", "hpu"]:
+            self.device_type = device
+        else:
+            raise RuntimeError(f"Not supported device type: {device}")
+        self.device = torch.device(self.device_type)
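
A quick usage sketch for the new DeviceConfig, assuming sglang 0.4.0.post1 is installed; "cpu" is passed only to demonstrate the rejection path:

from sglang.srt.configs.device_config import DeviceConfig

cfg = DeviceConfig("cuda")
print(cfg.device_type, cfg.device)  # "cuda" and the corresponding torch.device

try:
    DeviceConfig("cpu")  # not in ["cuda", "xpu", "hpu"], so this raises
except RuntimeError as e:
    print(e)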
sglang/srt/configs/load_config.py ADDED
@@ -0,0 +1,84 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+import enum
+import json
+import logging
+from dataclasses import dataclass, field
+from typing import List, Optional, Union
+
+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+
+
+class LoadFormat(str, enum.Enum):
+    AUTO = "auto"
+    PT = "pt"
+    SAFETENSORS = "safetensors"
+    NPCACHE = "npcache"
+    DUMMY = "dummy"
+    SHARDED_STATE = "sharded_state"
+    GGUF = "gguf"
+    BITSANDBYTES = "bitsandbytes"
+    MISTRAL = "mistral"
+
+
+@dataclass
+class LoadConfig:
+    """
+    download_dir: Directory to download and load the weights, default to the
+        default cache directory of huggingface.
+    load_format: The format of the model weights to load:
+        "auto" will try to load the weights in the safetensors format and
+            fall back to the pytorch bin format if safetensors format is
+            not available.
+        "pt" will load the weights in the pytorch bin format.
+        "safetensors" will load the weights in the safetensors format.
+        "npcache" will load the weights in pytorch format and store
+            a numpy cache to speed up the loading.
+        "dummy" will initialize the weights with random values, which is
+            mainly for profiling.
+        "bitsandbytes" will load nf4 type weights.
+    ignore_patterns: The list of patterns to ignore when loading the model.
+        Default to "original/**/*" to avoid repeated loading of llama's
+        checkpoints.
+
+    """
+
+    load_format: Union[str, LoadFormat] = LoadFormat.AUTO
+    download_dir: Optional[str] = None
+    model_loader_extra_config: Optional[Union[str, dict]] = field(default_factory=dict)
+    ignore_patterns: Optional[Union[List[str], str]] = None
+
+    def __post_init__(self):
+        model_loader_extra_config = self.model_loader_extra_config or {}
+        if isinstance(model_loader_extra_config, str):
+            self.model_loader_extra_config = json.loads(model_loader_extra_config)
+        self._verify_load_format()
+
+        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
+            logger.info(
+                "Ignoring the following patterns when downloading weights: %s",
+                self.ignore_patterns,
+            )
+        else:
+            self.ignore_patterns = ["original/**/*"]
+
+    def _verify_load_format(self) -> None:
+        if not isinstance(self.load_format, str):
+            return
+
+        load_format = self.load_format.lower()
+        self.load_format = LoadFormat(load_format)
+
+        rocm_not_supported_load_format: List[str] = []
+        if is_hip() and load_format in rocm_not_supported_load_format:
+            rocm_supported_load_format = [
+                f
+                for f in LoadFormat.__members__
+                if (f not in rocm_not_supported_load_format)
+            ]
+            raise ValueError(
+                f"load format '{load_format}' is not supported in ROCm. "
+                f"Supported load formats are "
+                f"{rocm_supported_load_format}"
+            )
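
A usage sketch for the new LoadConfig, assuming sglang 0.4.0.post1 is installed; the "some_key" extra-config entry is a placeholder:

from sglang.srt.configs.load_config import LoadConfig, LoadFormat

# A JSON string passed as model_loader_extra_config is parsed into a dict in __post_init__.
cfg = LoadConfig(load_format="safetensors", model_loader_extra_config='{"some_key": 1}')
print(cfg.load_format is LoadFormat.SAFETENSORS)  # True: strings are normalized to the enum
print(cfg.model_loader_extra_config)              # {'some_key': 1}
print(cfg.ignore_patterns)                        # defaults to ['original/**/*']

try:
    LoadConfig(load_format="not-a-format")        # rejected by LoadFormat(...)
except ValueError as e:
    print(e)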
sglang/srt/configs/model_config.py CHANGED
@@ -15,12 +15,14 @@
 import json
 import logging
 from enum import IntEnum, auto
-from typing import List, Optional
+from typing import List, Optional, Union
 
+import torch
 from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
-from sglang.srt.utils import get_bool_env_var
+from sglang.srt.layers.quantization import QUANTIZATION_METHODS
+from sglang.srt.utils import get_bool_env_var, is_hip
 
 logger = logging.getLogger(__name__)
 
@@ -33,17 +35,22 @@ class AttentionArch(IntEnum):
 class ModelConfig:
     def __init__(
         self,
-        path: str,
+        model_path: str,
         trust_remote_code: bool = True,
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
         is_embedding: Optional[bool] = None,
+        dtype: str = "auto",
+        quantization: Optional[str] = None,
     ) -> None:
+        self.model_path = model_path
+        self.revision = revision
+        self.quantization = quantization
         # Parse args
         self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
-            path,
+            model_path,
             trust_remote_code=trust_remote_code,
             revision=revision,
             model_override_args=self.model_override_args,
@@ -56,6 +63,7 @@ class ModelConfig:
         )
         self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+        self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
         # Derive context length
         derived_context_len = get_context_length(self.hf_text_config)
@@ -116,6 +124,8 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
 
+        self._verify_quantization()
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -174,6 +184,86 @@ class ModelConfig:
         # parallel size so each GPU has at least one KV head.
         return max(1, total_num_kv_heads // tensor_parallel_size)
 
+    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+    def _parse_quant_hf_config(self):
+        quant_cfg = getattr(self.hf_config, "quantization_config", None)
+        if quant_cfg is None:
+            # compressed-tensors uses a "compression_config" key
+            quant_cfg = getattr(self.hf_config, "compression_config", None)
+        return quant_cfg
+
+    # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+    def _verify_quantization(self) -> None:
+        supported_quantization = [*QUANTIZATION_METHODS]
+        rocm_supported_quantization = [
+            "awq",
+            "gptq",
+            "fp8",
+            "compressed_tensors",
+            "compressed-tensors",
+            "fbgemm_fp8",
+        ]
+        optimized_quantization_methods = [
+            "fp8",
+            "marlin",
+            "modelopt",
+            "gptq_marlin_24",
+            "gptq_marlin",
+            "awq_marlin",
+            "fbgemm_fp8",
+            "compressed_tensors",
+            "compressed-tensors",
+            "experts_int8",
+        ]
+        if self.quantization is not None:
+            self.quantization = self.quantization.lower()
+
+        # Parse quantization method from the HF model config, if available.
+        quant_cfg = self._parse_quant_hf_config()
+
+        if quant_cfg is not None:
+            quant_method = quant_cfg.get("quant_method", "").lower()
+
+            # Detect which checkpoint is it
+            for _, method in QUANTIZATION_METHODS.items():
+                quantization_override = method.override_quantization_method(
+                    quant_cfg, self.quantization
+                )
+                if quantization_override:
+                    quant_method = quantization_override
+                    self.quantization = quantization_override
+                    break
+
+            # Verify quantization configurations.
+            if self.quantization is None:
+                self.quantization = quant_method
+            elif self.quantization != quant_method:
+                raise ValueError(
+                    "Quantization method specified in the model config "
+                    f"({quant_method}) does not match the quantization "
+                    f"method specified in the `quantization` argument "
+                    f"({self.quantization})."
+                )
+
+        if self.quantization is not None:
+            if self.quantization not in supported_quantization:
+                raise ValueError(
+                    f"Unknown quantization method: {self.quantization}. Must "
+                    f"be one of {supported_quantization}."
+                )
+            if is_hip() and self.quantization not in rocm_supported_quantization:
+                raise ValueError(
+                    f"{self.quantization} quantization is currently not "
+                    f"supported in ROCm."
+                )
+            if self.quantization not in optimized_quantization_methods:
+                logger.warning(
+                    "%s quantization is not fully "
+                    "optimized yet. The speed can be slower than "
+                    "non-quantized models.",
+                    self.quantization,
+                )
+
 
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
@@ -183,6 +273,9 @@ def get_hf_text_config(config: PretrainedConfig):
     if class_name.startswith("Llava") and class_name.endswith("ForCausalLM"):
         # We support non-hf version of llava models, so we do not want to
         # read the wrong values from the unused default text_config.
+        # NOTE(HandH1998): We set `torch_dtype` of config to `torch.float16` for the weights, as
+        # `torch.float16` is default used for image features in `python/sglang/srt/models/llava.py`.
+        setattr(config, "torch_dtype", torch.float16)
         return config
 
     if hasattr(config, "text_config"):
@@ -195,6 +288,70 @@ def get_hf_text_config(config: PretrainedConfig):
         return config
 
 
+# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+_STR_DTYPE_TO_TORCH_DTYPE = {
+    "half": torch.float16,
+    "float16": torch.float16,
+    "float": torch.float32,
+    "float32": torch.float32,
+    "bfloat16": torch.bfloat16,
+}
+
+
+# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
+def _get_and_verify_dtype(
+    config: PretrainedConfig,
+    dtype: Union[str, torch.dtype],
+) -> torch.dtype:
+    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
+    # because config.torch_dtype can be None.
+    config_dtype = getattr(config, "torch_dtype", None)
+    if config_dtype is None:
+        config_dtype = torch.float32
+
+    if isinstance(dtype, str):
+        dtype = dtype.lower()
+        if dtype == "auto":
+            if config_dtype == torch.float32:
+                if config.model_type == "gemma2":
+                    logger.info(
+                        "For Gemma 2, we downcast float32 to bfloat16 instead "
+                        "of float16 by default. Please specify `dtype` if you "
+                        "want to use float16."
+                    )
+                    torch_dtype = torch.bfloat16
+                else:
+                    # Following the common practice, we use float16 for float32
+                    # models.
+                    torch_dtype = torch.float16
+            else:
+                torch_dtype = config_dtype
+        else:
+            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
+                raise ValueError(f"Unknown dtype: {dtype}")
+            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
+    elif isinstance(dtype, torch.dtype):
+        torch_dtype = dtype
+    else:
+        raise ValueError(f"Unknown dtype: {dtype}")
+
+    # Verify the dtype.
+    if torch_dtype != config_dtype:
+        if torch_dtype == torch.float32:
+            # Upcasting to float32 is allowed.
+            logger.info("Upcasting %s to %s.", config_dtype, torch_dtype)
+            pass
+        elif config_dtype == torch.float32:
+            # Downcasting from float32 to float16 or bfloat16 is allowed.
+            logger.info("Downcasting %s to %s.", config_dtype, torch_dtype)
+            pass
+        else:
+            # Casting between float16 and bfloat16 is allowed with a warning.
+            logger.warning("Casting %s to %s.", config_dtype, torch_dtype)
+
+    return torch_dtype
+
+
 def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
     # We have two ways to determine whether a model is a generative model.
     # 1. Check the model architectue
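
With this change, ModelConfig takes the checkpoint path as model_path and accepts dtype and quantization directly, mirroring the call added to bench_one_batch.py above. A sketch of the new constructor surface; the model id is a placeholder and the call fetches the Hugging Face config, so it needs network access or a local cache:

from sglang.srt.configs.model_config import ModelConfig

config = ModelConfig(
    model_path="Qwen/Qwen2-0.5B-Instruct",  # placeholder model id
    trust_remote_code=True,
    model_override_args="{}",
    dtype="bfloat16",        # resolved by _get_and_verify_dtype
    quantization=None,       # checked against QUANTIZATION_METHODS by _verify_quantization
)
print(config.dtype)  # torch.bfloat16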
sglang/srt/configs/qwen2vl.py CHANGED
@@ -121,13 +121,10 @@ class Qwen2VLConfig(PretrainedConfig):
         self.attention_dropout = attention_dropout
         self.rope_scaling = rope_scaling
 
-        # NOTE: the following section from original transformers config
-        # for Qwen2-VL is commented out to address rope config loading issue
-        #
-        # if self.rope_scaling is not None and "type" in self.rope_scaling:
-        #     if self.rope_scaling["type"] == "mrope":
-        #         self.rope_scaling["type"] = "default"
-        #     self.rope_scaling["rope_type"] = self.rope_scaling["type"]
-        # rope_config_validation(self)
+        # NOTE(HandH1998): This is necessary for configuring the `rope_type`` of qwen2vl models after removing dependencies on vllm.
+        if self.rope_scaling is not None and "type" in self.rope_scaling:
+            if self.rope_scaling["type"] == "mrope":
+                self.rope_scaling["type"] = "default"
+            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
 
         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
sglang/srt/constrained/outlines_backend.py CHANGED
@@ -42,6 +42,7 @@ class OutlinesGrammar(BaseGrammarObject):
         self.guide = guide
         self.jump_forward_map = jump_forward_map
         self.state = 0
+        self.finished = False
 
     def accept_token(self, token: int):
         self.state = self.guide.get_next_state(self.state, token)
@@ -84,6 +85,10 @@ class OutlinesGrammar(BaseGrammarObject):
     ) -> torch.Tensor:
         return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
 
+    @staticmethod
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask
+
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
         tokens = torch.tensor(
             self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
@@ -152,7 +157,12 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
            raise ValueError(f"Invalid key_type: {key_type}")
 
        try:
-            guide = RegexGuide(regex, self.outlines_tokenizer)
+            if hasattr(RegexGuide, "from_regex"):
+                # outlines >= 0.1.1
+                guide = RegexGuide.from_regex(regex, self.outlines_tokenizer)
+            else:
+                # outlines <= 0.0.46
+                guide = RegexGuide(regex, self.outlines_tokenizer)
        except interegular.patterns.InvalidSyntax as e:
            logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
            return None
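
The RegexGuide change above is a feature-detection shim: newer outlines releases build guides through a classmethod, older ones through the constructor. A generic sketch of the same pattern; build_guide and the guide_cls parameter are illustrative, the backend itself does this check inline:

def build_guide(guide_cls, regex: str, tokenizer):
    if hasattr(guide_cls, "from_regex"):
        # outlines >= 0.1.1 exposes a classmethod constructor
        return guide_cls.from_regex(regex, tokenizer)
    # outlines <= 0.0.46 takes the regex directly in __init__
    return guide_cls(regex, tokenizer)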
sglang/srt/constrained/outlines_jump_forward.py CHANGED
@@ -23,7 +23,14 @@ from collections import defaultdict
 import interegular
 from interegular import InvalidSyntax
 from outlines.caching import cache as disk_cache
-from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
+
+try:
+    # outlines >= 0.1.0
+    from outlines_core.fsm.outlines_core_rs import FSMInfo
+    from outlines_core.fsm.regex import make_byte_level_fsm, make_deterministic_fsm
+except ImportError:
+    # outlines <= 0.0.46
+    from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
 
 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
 
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -45,6 +45,7 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
+        self.finished = False
 
     def accept_token(self, token: int):
         assert self.matcher.accept_token(token)
@@ -85,12 +86,11 @@ class XGrammarGrammar(BaseGrammarObject):
         self.matcher.fill_next_token_bitmask(vocab_mask, idx)
 
     @staticmethod
-    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        if vocab_mask.device.type != logits.device.type:
-            # vocab_mask must then be on the same device as logits
-            # when applying the token bitmask, so we check and move if needed
-            vocab_mask = vocab_mask.to(logits.device)
+    def move_vocab_mask(vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return vocab_mask.to(device, non_blocking=True)
 
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
         apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
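
The mask-handling API is now split into three steps: fill the per-request mask, move it to the logits device with move_vocab_mask (a no-op for the outlines backend, a non-blocking copy for xgrammar), and apply it in place with apply_vocab_mask. A sketch of the expected call order; the apply_grammar_mask helper is illustrative, and the bool-mask allocation shown matches the outlines backend (xgrammar allocates its own bitmask layout):

import torch

def apply_grammar_mask(grammar, logits: torch.Tensor, idx: int) -> None:
    batch_size, vocab_size = logits.shape
    vocab_mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=logits.device)
    grammar.fill_vocab_mask(vocab_mask, idx)                         # fill the mask for request idx
    vocab_mask = grammar.move_vocab_mask(vocab_mask, logits.device)  # no-op or non-blocking copy
    grammar.apply_vocab_mask(logits, vocab_mask)                     # in-place masking of the logits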
sglang/srt/distributed/__init__.py ADDED
@@ -0,0 +1,3 @@
+from .communication_op import *
+from .parallel_state import *
+from .utils import *
sglang/srt/distributed/communication_op.py ADDED
@@ -0,0 +1,34 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py
+from typing import Any, Dict, Optional, Union
+
+import torch
+import torch.distributed
+
+from .parallel_state import get_tp_group
+
+
+def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
+    """All-reduce the input tensor across model parallel group."""
+    return get_tp_group().all_reduce(input_)
+
+
+def tensor_model_parallel_all_gather(
+    input_: torch.Tensor, dim: int = -1
+) -> torch.Tensor:
+    """All-gather the input tensor across model parallel group."""
+    return get_tp_group().all_gather(input_, dim)
+
+
+def tensor_model_parallel_gather(
+    input_: torch.Tensor, dst: int = 0, dim: int = -1
+) -> Optional[torch.Tensor]:
+    """Gather the input tensor across model parallel group."""
+    return get_tp_group().gather(input_, dst, dim)
+
+
+def broadcast_tensor_dict(
+    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0
+):
+    if not torch.distributed.is_initialized():
+        return tensor_dict
+    return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
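
These helpers route collectives through the tensor-parallel group created in parallel_state; broadcast_tensor_dict degrades to a no-op when torch.distributed is not initialized. A usage sketch, assuming the distributed state has already been set up by the model runner at startup:

import torch
from sglang.srt.distributed import (
    broadcast_tensor_dict,
    tensor_model_parallel_all_reduce,
)

def row_parallel_forward(partial_out: torch.Tensor) -> torch.Tensor:
    # Each TP rank holds a partial matmul result; sum the partials across the TP group.
    return tensor_model_parallel_all_reduce(partial_out)

# Without torch.distributed initialized, this simply returns its input dict.
print(broadcast_tensor_dict({"step": torch.tensor(0)}))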