sglang-0.2.12-py3-none-any.whl → sglang-0.2.13-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +3 -2
  3. sglang/global_config.py +1 -1
  4. sglang/lang/backend/runtime_endpoint.py +60 -49
  5. sglang/lang/interpreter.py +4 -2
  6. sglang/lang/ir.py +13 -4
  7. sglang/srt/constrained/jump_forward.py +13 -2
  8. sglang/srt/layers/activation.py +0 -1
  9. sglang/srt/layers/extend_attention.py +3 -1
  10. sglang/srt/layers/fused_moe/__init__.py +1 -0
  11. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  12. sglang/srt/layers/fused_moe/layer.py +587 -0
  13. sglang/srt/layers/logits_processor.py +4 -4
  14. sglang/srt/layers/radix_attention.py +38 -14
  15. sglang/srt/managers/schedule_batch.py +9 -14
  16. sglang/srt/managers/tokenizer_manager.py +1 -1
  17. sglang/srt/managers/tp_worker.py +1 -7
  18. sglang/srt/model_executor/cuda_graph_runner.py +48 -17
  19. sglang/srt/model_executor/forward_batch_info.py +132 -58
  20. sglang/srt/model_executor/model_runner.py +61 -28
  21. sglang/srt/models/chatglm.py +2 -2
  22. sglang/srt/models/commandr.py +1 -1
  23. sglang/srt/models/deepseek.py +2 -2
  24. sglang/srt/models/deepseek_v2.py +7 -6
  25. sglang/srt/models/gemma.py +1 -1
  26. sglang/srt/models/gemma2.py +11 -5
  27. sglang/srt/models/grok.py +50 -396
  28. sglang/srt/models/minicpm.py +2 -2
  29. sglang/srt/models/mixtral.py +56 -254
  30. sglang/srt/models/mixtral_quant.py +1 -4
  31. sglang/srt/models/qwen.py +2 -2
  32. sglang/srt/models/qwen2.py +2 -2
  33. sglang/srt/models/qwen2_moe.py +2 -2
  34. sglang/srt/models/stablelm.py +1 -1
  35. sglang/srt/openai_api/adapter.py +32 -21
  36. sglang/srt/sampling_params.py +0 -4
  37. sglang/srt/server.py +23 -15
  38. sglang/srt/server_args.py +7 -1
  39. sglang/srt/utils.py +1 -2
  40. sglang/test/runners.py +18 -10
  41. sglang/test/test_programs.py +32 -5
  42. sglang/test/test_utils.py +5 -1
  43. sglang/version.py +1 -1
  44. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/METADATA +12 -4
  45. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/RECORD +48 -48
  46. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  47. sglang/srt/model_loader/model_loader.py +0 -292
  48. sglang/srt/model_loader/utils.py +0 -275
  49. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  50. {sglang-0.2.12.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
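The headline changes in this release are the promotion of fused_moe into a package (sglang/srt/layers/fused_moe/), large simplifications to the Grok and Mixtral models, and the removal of the vendored model loader (sglang/srt/model_loader/), whose two deleted files make up the hunks below. As a minimal sketch for confirming which side of this diff is installed — assuming sglang re-exports __version__ from sglang/version.py (the file bumped +1/-1 in this diff), as in other releases:

# Minimal check of the installed package version across this upgrade.
# Assumes sglang exposes __version__ from sglang/version.py.
import sglang

print(f"installed sglang version: {sglang.__version__}")  # "0.2.13" after upgrade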
sglang/srt/model_loader/model_loader.py (deleted)
@@ -1,292 +0,0 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
- # temporarily adapted from https://github.com/vllm-project/vllm/blob/10383887e03412196a2689b9398290719c4797bf/vllm/model_executor/model_loader/loader.py
- # FIXME: in progress of refactoring the model loader
-
- import glob
- import os
- import re
- from typing import Any, Dict, Generator, List, Optional, Tuple, Type
-
- import torch
- from torch import nn
- from tqdm import tqdm
- from vllm.config import (
-     CacheConfig,
-     DeviceConfig,
-     LoadConfig,
-     LoadFormat,
-     LoRAConfig,
-     ModelConfig,
-     MultiModalConfig,
-     ParallelConfig,
-     SchedulerConfig,
- )
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
- from vllm.model_executor.model_loader.utils import (
-     get_model_architecture,
-     set_default_torch_dtype,
- )
- from vllm.platforms import current_platform
-
- from sglang.srt.model_loader.utils import (
-     download_safetensors_index_file_from_hf,
-     download_weights_from_hf,
-     filter_duplicate_safetensors_files,
-     get_quant_config,
-     safetensors_weights_iterator,
- )
-
-
- def _get_quantization_config(
-     model_config: ModelConfig, load_config: LoadConfig
- ) -> Optional[QuantizationConfig]:
-     """Get the quantization config."""
-     if model_config.quantization is not None:
-         quant_config = get_quant_config(model_config, load_config)
-         capability = current_platform.get_device_capability()
-         capability = capability[0] * 10 + capability[1]
-         if capability < quant_config.get_min_capability():
-             raise ValueError(
-                 f"The quantization method {model_config.quantization} is not "
-                 "supported for the current GPU. "
-                 f"Minimum capability: {quant_config.get_min_capability()}. "
-                 f"Current capability: {capability}."
-             )
-         supported_dtypes = quant_config.get_supported_act_dtypes()
-         if model_config.dtype not in supported_dtypes:
-             raise ValueError(
-                 f"{model_config.dtype} is not supported for quantization "
-                 f"method {model_config.quantization}. Supported dtypes: "
-                 f"{supported_dtypes}"
-             )
-         return quant_config
-     return None
-
-
- def _get_model_initialization_kwargs(
-     model_class: Type[nn.Module],
-     lora_config: Optional[LoRAConfig],
-     multimodal_config: Optional[MultiModalConfig],
- ) -> Dict[str, Any]:
-     """Get extra kwargs for model initialization."""
-     extra_kwargs: Dict[str, Any] = {}
-
-     assert lora_config is None
-     assert multimodal_config is None
-
-     return extra_kwargs
-
-
- def _initialize_model(
-     model_config: ModelConfig,
-     load_config: LoadConfig,
-     lora_config: Optional[LoRAConfig],
-     multimodal_config: Optional[MultiModalConfig],
-     cache_config: CacheConfig,
- ) -> nn.Module:
-     """Initialize a model with the given configurations."""
-     model_class = get_model_architecture(model_config)[0]
-     quant_config = _get_quantization_config(model_config, load_config)
-
-     return model_class(
-         config=model_config.hf_config,
-         cache_config=cache_config,
-         quant_config=quant_config,
-         efficient_weight_load=True,
-         **_get_model_initialization_kwargs(model_class, lora_config, multimodal_config),
-     )
-
-
- class ModelLoader:
-     """Model loader that can load different file types from disk."""
-
-     def __init__(self, load_config: LoadConfig):
-         self.load_config = load_config
-
-     def _prepare_weights(
-         self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
-     ) -> Tuple[str, List[str], bool]:
-         """Prepare weights for the model.
-
-         If the model is not local, it will be downloaded."""
-
-         is_local = os.path.isdir(model_name_or_path)
-         load_format = self.load_config.load_format
-         use_safetensors = False
-         # Some quantized models use .pt files for storing the weights.
-         if load_format == LoadFormat.AUTO:
-             allow_patterns = ["*.safetensors", "*.bin"]
-         elif load_format == LoadFormat.SAFETENSORS:
-             use_safetensors = True
-             allow_patterns = ["*.safetensors"]
-         elif load_format == LoadFormat.PT:
-             allow_patterns = ["*.pt"]
-         elif load_format == LoadFormat.NPCACHE:
-             allow_patterns = ["*.bin"]
-         else:
-             raise ValueError(f"Unknown load_format: {load_format}")
-
-         if fall_back_to_pt:
-             allow_patterns += ["*.pt"]
-
-         if not is_local:
-             hf_folder = download_weights_from_hf(
-                 model_name_or_path,
-                 self.load_config.download_dir,
-                 allow_patterns,
-                 revision,
-             )
-         else:
-             hf_folder = model_name_or_path
-
-         hf_weights_files: List[str] = []
-         for pattern in allow_patterns:
-             hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
-             if len(hf_weights_files) > 0:
-                 if pattern == "*.safetensors":
-                     use_safetensors = True
-                 break
-
-         if use_safetensors:
-             # For models like Mistral-7B-Instruct-v0.3
-             # there are both sharded safetensors files and a consolidated
-             # safetensors file. Using both breaks.
-             # Here, we download the `model.safetensors.index.json` and filter
-             # any files not found in the index.
-             if not is_local:
-                 download_safetensors_index_file_from_hf(
-                     model_name_or_path, self.load_config.download_dir, revision
-                 )
-             hf_weights_files = filter_duplicate_safetensors_files(
-                 hf_weights_files, hf_folder
-             )
-         else:
-             hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
-
-         if len(hf_weights_files) == 0:
-             raise RuntimeError(
-                 f"Cannot find any model weights with `{model_name_or_path}`"
-             )
-
-         return hf_folder, hf_weights_files, use_safetensors
-
-     def _get_weights_iterator(
-         self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
-     ) -> Generator[Tuple[str, torch.Tensor], None, None]:
-         """Get an iterator for the model weights based on the load format."""
-         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
-             model_name_or_path, revision, fall_back_to_pt
-         )
-         if self.load_config.load_format == LoadFormat.NPCACHE:
-             # Currently np_cache only support *.bin checkpoints
-             assert use_safetensors is False
-             weights_iterator = np_cache_weights_iterator(
-                 model_name_or_path,
-                 self.load_config.download_dir,
-                 hf_folder,
-                 hf_weights_files,
-             )
-         elif use_safetensors:
-             weights_iterator = safetensors_weights_iterator(hf_weights_files)
-         else:
-             weights_iterator = pt_weights_iterator(hf_weights_files)
-
-         return weights_iterator
-
-     def load_model(
-         self,
-         *,
-         model_config: ModelConfig,
-         device_config: DeviceConfig,
-         lora_config: Optional[LoRAConfig],
-         multimodal_config: Optional[MultiModalConfig],
-         parallel_config: ParallelConfig,
-         scheduler_config: SchedulerConfig,
-         cache_config: CacheConfig,
-     ) -> nn.Module:
-         with set_default_torch_dtype(model_config.dtype):
-             with torch.device(device_config.device):
-                 model = _initialize_model(
-                     model_config,
-                     self.load_config,
-                     lora_config,
-                     multimodal_config,
-                     cache_config,
-                 )
-             weights = self._get_weights_iterator(
-                 model_config.model,
-                 model_config.revision,
-                 fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
-             )
-
-             modules = {}
-             for name, module in model.named_modules():
-                 modules[name] = module
-
-             def apply_quant_method(module):
-                 quant_method = getattr(module, "quant_method", None)
-                 if quant_method is not None:
-                     # print("before apply quant", module.weight, module.weight.dtype)
-                     quant_method.process_weights_after_loading(module)
-                     # print("after apply quant", module.weight, module.weight.dtype)
-                 # FIXME: Remove this after Mixtral is updated
-                 # to use quant_method.
-                 if hasattr(module, "process_weights_after_loading"):
-                     module.process_weights_after_loading()
-
-             if torch.cuda.current_device() == 0:
-                 weights = tqdm(
-                     weights, total=model.get_num_params() * 1.5, desc="load model"
-                 )
-
-             num_shard = {}
-             num_loaded = {}
-             for name, loaded_weight in weights:
-                 model.load_weights(None, name, loaded_weight)
-                 module_name, shard_num = model.get_module_name(name)
-                 num_shard[module_name] = shard_num
-                 if module_name not in num_loaded:
-                     num_loaded[module_name] = 1
-                 else:
-                     num_loaded[module_name] += 1
-                 if num_loaded[module_name] == num_shard[module_name]:
-                     apply_quant_method(modules[module_name])
-
-         return model.eval()
-
-
- def get_model(
-     *,
-     model_config: ModelConfig,
-     load_config: LoadConfig,
-     device_config: DeviceConfig,
-     parallel_config: ParallelConfig,
-     scheduler_config: SchedulerConfig,
-     lora_config: Optional[LoRAConfig],
-     multimodal_config: Optional[MultiModalConfig],
-     cache_config: CacheConfig,
- ) -> nn.Module:
-     loader = ModelLoader(load_config)
-     return loader.load_model(
-         model_config=model_config,
-         device_config=device_config,
-         lora_config=lora_config,
-         multimodal_config=multimodal_config,
-         parallel_config=parallel_config,
-         scheduler_config=scheduler_config,
-         cache_config=cache_config,
-     )
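The core pattern in the deleted load_model above — counting weight shards per module and running quantization post-processing only after a module's last shard arrives — reduces to the minimal sketch below. The callables (module_of, load_weight, finalize_module) are illustrative stand-ins, not sglang or vLLM APIs:

# Illustrative reduction of the shard-tracking loop from the deleted load_model:
# finalize a module (e.g. quant_method.process_weights_after_loading) only once
# all of its weight shards have been loaded.
from typing import Callable, Dict, Iterable, Tuple

def load_with_shard_tracking(
    weights: Iterable[Tuple[str, object]],        # (weight name, tensor) pairs
    module_of: Callable[[str], Tuple[str, int]],  # weight name -> (module name, total shards)
    load_weight: Callable[[str, object], None],   # hypothetical per-weight loader
    finalize_module: Callable[[str], None],       # hypothetical post-load hook
) -> None:
    num_shard: Dict[str, int] = {}
    num_loaded: Dict[str, int] = {}
    for name, tensor in weights:
        load_weight(name, tensor)
        module_name, shard_num = module_of(name)
        num_shard[module_name] = shard_num
        num_loaded[module_name] = num_loaded.get(module_name, 0) + 1
        if num_loaded[module_name] == num_shard[module_name]:
            finalize_module(module_name)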
sglang/srt/model_loader/utils.py (deleted)
@@ -1,275 +0,0 @@
- """
- Copyright 2023-2024 SGLang Team
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
-     http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
- """
-
- # temporarily adapted from vLLM
- # FIXME: in progress of refactoring the model loader
- """Utilities for selecting and loading models."""
- import contextlib
- import fnmatch
- import hashlib
- import json
- import logging
- import os
- import tempfile
- from typing import Any, Generator, Iterable, List, Optional, Tuple, Type
-
- import filelock
- import huggingface_hub.constants
- import torch
- from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
- from safetensors.torch import load_file, safe_open, save_file
- from torch import nn
- from tqdm.auto import tqdm
- from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
- from vllm.config import LoadConfig, ModelConfig
- from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
-
- from sglang.srt.layers.quantization import get_quantization_config
-
- logger = logging.getLogger(__name__)
- temp_dir = tempfile.gettempdir()
-
-
- @contextlib.contextmanager
- def set_default_torch_dtype(dtype: torch.dtype):
-     """Sets the default torch dtype to the given dtype."""
-     old_dtype = torch.get_default_dtype()
-     torch.set_default_dtype(dtype)
-     yield
-     torch.set_default_dtype(old_dtype)
-
-
- def get_model_architecture(model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
-     architectures = getattr(model_config.hf_config, "architectures", [])
-     # Special handling for quantized Mixtral.
-     # FIXME(woosuk): This is a temporary hack.
-     if (
-         model_config.quantization is not None
-         and model_config.quantization != "fp8"
-         and "MixtralForCausalLM" in architectures
-     ):
-         architectures = ["QuantMixtralForCausalLM"]
-
-     for arch in architectures:
-         model_cls = ModelRegistry.load_model_cls(arch)
-         if model_cls is not None:
-             return (model_cls, arch)
-     raise ValueError(
-         f"Model architectures {architectures} are not supported for now. "
-         f"Supported architectures: {ModelRegistry.get_supported_archs()}"
-     )
-
-
- class DisabledTqdm(tqdm):
-
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, **kwargs, disable=True)
-
-
- def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None):
-     lock_dir = cache_dir or temp_dir
-     os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
-     model_name = model_name_or_path.replace("/", "-")
-     hash_name = hashlib.sha256(model_name.encode()).hexdigest()
-     # add hash to avoid conflict with old users' lock files
-     lock_file_name = hash_name + model_name + ".lock"
-     # mode 0o666 is required for the filelock to be shared across users
-     lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name), mode=0o666)
-     return lock
-
-
- def download_weights_from_hf(
-     model_name_or_path: str,
-     cache_dir: Optional[str],
-     allow_patterns: List[str],
-     revision: Optional[str] = None,
- ) -> str:
-     """Download model weights from Hugging Face Hub.
-
-     Args:
-         model_name_or_path (str): The model name or path.
-         cache_dir (Optional[str]): The cache directory to store the model
-             weights. If None, will use HF defaults.
-         allow_patterns (List[str]): The allowed patterns for the
-             weight files. Files matched by any of the patterns will be
-             downloaded.
-         revision (Optional[str]): The revision of the model.
-
-     Returns:
-         str: The path to the downloaded model weights.
-     """
-     if not huggingface_hub.constants.HF_HUB_OFFLINE:
-         # Before we download we look at that is available:
-         fs = HfFileSystem()
-         file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
-
-         # depending on what is available we download different things
-         for pattern in allow_patterns:
-             matching = fnmatch.filter(file_list, pattern)
-             if len(matching) > 0:
-                 allow_patterns = [pattern]
-                 break
-
-     logger.info("Using model weights format %s", allow_patterns)
-     # Use file lock to prevent multiple processes from
-     # downloading the same model weights at the same time.
-     with get_lock(model_name_or_path, cache_dir):
-         hf_folder = snapshot_download(
-             model_name_or_path,
-             allow_patterns=allow_patterns,
-             cache_dir=cache_dir,
-             tqdm_class=DisabledTqdm,
-             revision=revision,
-             local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
-         )
-     return hf_folder
-
-
- def download_safetensors_index_file_from_hf(
-     model_name_or_path: str,
-     cache_dir: Optional[str],
-     revision: Optional[str] = None,
- ) -> None:
-     """Download hf safetensors index file from Hugging Face Hub.
-
-     Args:
-         model_name_or_path (str): The model name or path.
-         cache_dir (Optional[str]): The cache directory to store the model
-             weights. If None, will use HF defaults.
-         revision (Optional[str]): The revision of the model.
-     """
-     # Use file lock to prevent multiple processes from
-     # downloading the same model weights at the same time.
-     with get_lock(model_name_or_path, cache_dir):
-         try:
-             # Download the safetensors index file.
-             hf_hub_download(
-                 repo_id=model_name_or_path,
-                 filename=SAFE_WEIGHTS_INDEX_NAME,
-                 cache_dir=cache_dir,
-                 revision=revision,
-                 local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
-             )
-         # If file not found on remote or locally, we should not fail since
-         # only some models will have SAFE_WEIGHTS_INDEX_NAME.
-         except huggingface_hub.utils.EntryNotFoundError:
-             logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME)
-         except huggingface_hub.utils.LocalEntryNotFoundError:
-             logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME)
-
-
- # For models like Mistral-7B-v0.3, there are both sharded
- # safetensors files and a consolidated safetensors file.
- # Passing both of these to the weight loader functionality breaks.
- # So, we use the SAFE_WEIGHTS_INDEX_NAME to
- # look up which safetensors files should be used.
- def filter_duplicate_safetensors_files(
-     hf_weights_files: List[str], hf_folder: str
- ) -> List[str]:
-     # model.safetensors.index.json is a mapping from keys in the
-     # torch state_dict to safetensors file holding that weight.
-     index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME)
-     if not os.path.isfile(index_file_name):
-         return hf_weights_files
-
-     # Iterate through the weight_map (weight_name: safetensors files)
-     # to identify weights that we should use.
-     with open(index_file_name) as index_file:
-         weight_map = json.load(index_file)["weight_map"]
-         weight_files_in_index = set()
-         for weight_name in weight_map:
-             weight_files_in_index.add(os.path.join(hf_folder, weight_map[weight_name]))
-     # Filter out any fields that are not found in the index file.
-     hf_weights_files = [f for f in hf_weights_files if f in weight_files_in_index]
-     return hf_weights_files
-
-
- def safetensors_weights_iterator(
-     hf_weights_files: List[str],
- ) -> Generator[Tuple[str, torch.Tensor], None, None]:
-     """Iterate over the weights in the model safetensor files."""
-     for st_file in hf_weights_files:
-         with safe_open(st_file, framework="pt") as f:
-             for name in f.keys():  # noqa: SIM118
-                 param = f.get_tensor(name)
-                 yield name, param
-
-
- def get_quant_config(
-     model_config: ModelConfig, load_config: LoadConfig
- ) -> QuantizationConfig:
-     quant_cls = get_quantization_config(model_config.quantization)
-     # Read the quantization config from the HF model config, if available.
-     hf_quant_config = getattr(model_config.hf_config, "quantization_config", None)
-     if hf_quant_config is None:
-         # compressed-tensors uses a compressions_config
-         hf_quant_config = getattr(model_config.hf_config, "compression_config", None)
-     if hf_quant_config is not None:
-         return quant_cls.from_config(hf_quant_config)
-     # In case of bitsandbytes/QLoRA, get quant config from the adapter model.
-     if model_config.quantization == "bitsandbytes":
-         if (
-             not load_config.model_loader_extra_config
-             or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config
-         ):
-             return quant_cls.from_config({"adapter_name_or_path": ""})
-         model_name_or_path = load_config.model_loader_extra_config[
-             "qlora_adapter_name_or_path"
-         ]
-
-     else:
-         model_name_or_path = model_config.model
-     is_local = os.path.isdir(model_name_or_path)
-     if not is_local:
-         # Download the config files.
-         with get_lock(model_name_or_path, load_config.download_dir):
-             hf_folder = snapshot_download(
-                 model_name_or_path,
-                 revision=model_config.revision,
-                 allow_patterns="*.json",
-                 cache_dir=load_config.download_dir,
-                 local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
-                 tqdm_class=DisabledTqdm,
-             )
-     else:
-         hf_folder = model_name_or_path
-
-     possible_config_filenames = quant_cls.get_config_filenames()
-
-     # If the quantization config is not found, use the default config.
-     if not possible_config_filenames:
-         return quant_cls()
-
-     config_files = glob.glob(os.path.join(hf_folder, "*.json"))
-
-     quant_config_files = [
-         f for f in config_files if any(f.endswith(x) for x in possible_config_filenames)
-     ]
-     if len(quant_config_files) == 0:
-         raise ValueError(f"Cannot find the config file for {model_config.quantization}")
-     if len(quant_config_files) > 1:
-         raise ValueError(
-             f"Found multiple config files for {model_config.quantization}: "
-             f"{quant_config_files}"
-         )
-
-     quant_config_file = quant_config_files[0]
-     with open(quant_config_file, "r") as f:
-         config = json.load(f)
-
-     if model_config.quantization == "bitsandbytes":
-         config["adapter_name_or_path"] = model_name_or_path
-
-     return quant_cls.from_config(config)
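This second deleted file supplied the download and filtering helpers the loader used. Its safetensors-index filtering is the piece most worth isolating: keep only the shard files that model.safetensors.index.json references, so a repository's extra consolidated safetensors file is not loaded a second time. A self-contained sketch of that same logic, with shards_listed_in_index as a hypothetical name:

# Standalone sketch of the index-based filtering done by the deleted
# filter_duplicate_safetensors_files above.
import json
import os
from typing import List

def shards_listed_in_index(hf_folder: str, candidates: List[str]) -> List[str]:
    index_path = os.path.join(hf_folder, "model.safetensors.index.json")
    if not os.path.isfile(index_path):
        return candidates  # no index file: keep everything that was found
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
    # weight_map maps state_dict keys to the shard file holding each weight.
    referenced = {os.path.join(hf_folder, shard) for shard in weight_map.values()}
    return [f for f in candidates if f in referenced]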