sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/bench_serving.py +18 -1
  3. sglang/lang/interpreter.py +71 -1
  4. sglang/lang/ir.py +2 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/chatglm.py +78 -0
  7. sglang/srt/configs/dbrx.py +279 -0
  8. sglang/srt/configs/model_config.py +1 -1
  9. sglang/srt/hf_transformers_utils.py +9 -14
  10. sglang/srt/layers/attention/__init__.py +22 -6
  11. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  12. sglang/srt/layers/attention/flashinfer_backend.py +215 -83
  13. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  14. sglang/srt/layers/attention/triton_backend.py +20 -11
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  16. sglang/srt/layers/linear.py +159 -55
  17. sglang/srt/layers/logits_processor.py +170 -215
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
  41. sglang/srt/layers/parameter.py +431 -0
  42. sglang/srt/layers/quantization/__init__.py +3 -2
  43. sglang/srt/layers/quantization/fp8.py +3 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  45. sglang/srt/layers/sampler.py +57 -21
  46. sglang/srt/layers/torchao_utils.py +17 -3
  47. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  48. sglang/srt/managers/cache_controller.py +307 -0
  49. sglang/srt/managers/data_parallel_controller.py +2 -0
  50. sglang/srt/managers/io_struct.py +1 -2
  51. sglang/srt/managers/schedule_batch.py +33 -3
  52. sglang/srt/managers/schedule_policy.py +159 -90
  53. sglang/srt/managers/scheduler.py +68 -28
  54. sglang/srt/managers/session_controller.py +1 -1
  55. sglang/srt/managers/tokenizer_manager.py +27 -21
  56. sglang/srt/managers/tp_worker.py +16 -4
  57. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  58. sglang/srt/mem_cache/memory_pool.py +206 -1
  59. sglang/srt/metrics/collector.py +22 -30
  60. sglang/srt/model_executor/cuda_graph_runner.py +129 -77
  61. sglang/srt/model_executor/forward_batch_info.py +51 -21
  62. sglang/srt/model_executor/model_runner.py +72 -64
  63. sglang/srt/models/chatglm.py +1 -1
  64. sglang/srt/models/dbrx.py +1 -1
  65. sglang/srt/models/deepseek_v2.py +34 -7
  66. sglang/srt/models/grok.py +109 -29
  67. sglang/srt/models/llama.py +9 -2
  68. sglang/srt/openai_api/adapter.py +0 -17
  69. sglang/srt/openai_api/protocol.py +3 -3
  70. sglang/srt/sampling/sampling_batch_info.py +22 -0
  71. sglang/srt/sampling/sampling_params.py +9 -1
  72. sglang/srt/server.py +20 -13
  73. sglang/srt/server_args.py +120 -58
  74. sglang/srt/speculative/build_eagle_tree.py +347 -0
  75. sglang/srt/speculative/eagle_utils.py +626 -0
  76. sglang/srt/speculative/eagle_worker.py +184 -0
  77. sglang/srt/speculative/spec_info.py +5 -0
  78. sglang/srt/utils.py +47 -7
  79. sglang/test/test_programs.py +23 -1
  80. sglang/test/test_utils.py +36 -7
  81. sglang/version.py +1 -1
  82. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
  83. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
  84. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  85. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  86. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/parameter.py
@@ -0,0 +1,431 @@
+"""
+Adapted from vLLM (0.6.4.post1).
+https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/parameter.py
+"""
+
+import logging
+from fractions import Fraction
+from typing import Callable, Optional, Union
+
+import torch
+from torch.nn import Parameter
+from vllm.distributed import get_tensor_model_parallel_rank
+
+__all__ = [
+    "BasevLLMParameter",
+    "PackedvLLMParameter",
+    "PerTensorScaleParameter",
+    "ModelWeightParameter",
+    "ChannelQuantScaleParameter",
+    "GroupQuantScaleParameter",
+    "PackedColumnParameter",
+    "RowvLLMParameter",
+]
+
+logger = logging.getLogger(__name__)
+
+
+class BasevLLMParameter(Parameter):
+    """
+    Base parameter for vLLM linear layers. Extends the torch.nn.parameter
+    by taking in a linear weight loader. Will copy the loaded weight
+    into the parameter when the provided weight loader is called.
+    """
+
+    def __new__(cls, data: torch.Tensor, **kwargs):
+
+        return super().__new__(cls, data=data, requires_grad=False)
+
+    def __init__(self, data: torch.Tensor, weight_loader: Callable):
+        """
+        Initialize the BasevLLMParameter
+
+        :param data: torch tensor with the parameter data
+        :param weight_loader: weight loader callable
+
+        :returns: a torch.nn.parameter
+        """
+
+        self._weight_loader = weight_loader
+
+    @property
+    def weight_loader(self):
+        return self._weight_loader
+
+    def _assert_and_load(self, loaded_weight: torch.Tensor):
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+
+    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
+        self._assert_and_load(loaded_weight)
+
+    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
+        self._assert_and_load(loaded_weight)
+
+    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
+        self._assert_and_load(loaded_weight)
+
+    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+        self._assert_and_load(loaded_weight)
+
+
+class _ColumnvLLMParameter(BasevLLMParameter):
+    """
+    Private class defining weight loading functionality
+    (load_merged_column_weight, load_qkv_weight)
+    for parameters being loaded into linear layers with column
+    parallelism. This includes QKV and MLP layers which are
+    not already fused on disk. Requires an output dimension
+    to be defined. Called within the weight loader of
+    each of the column parallel linear layers.
+    """
+
+    def __init__(self, output_dim: int, **kwargs):
+        self._output_dim = output_dim
+        super().__init__(**kwargs)
+
+    @property
+    def output_dim(self):
+        return self._output_dim
+
+    def load_column_parallel_weight(self, loaded_weight: torch.Tensor):
+        tp_rank = get_tensor_model_parallel_rank()
+        shard_size = self.data.shape[self.output_dim]
+        loaded_weight = loaded_weight.narrow(
+            self.output_dim, tp_rank * shard_size, shard_size
+        )
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+
+    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
+
+        shard_offset = kwargs.get("shard_offset")
+        shard_size = kwargs.get("shard_size")
+        use_presharded_weights = kwargs.get("use_presharded_weights")
+        if (
+            isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
+            and self.packed_dim == self.output_dim
+        ):
+            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
+                shard_offset=shard_offset, shard_size=shard_size
+            )
+
+        param_data = self.data
+
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
+        if not use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                self.output_dim, tp_rank * shard_size, shard_size
+            )
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
+
+        shard_offset = kwargs.get("shard_offset")
+        shard_size = kwargs.get("shard_size")
+        shard_id = kwargs.get("shard_id")
+        num_heads = kwargs.get("num_heads")
+
+        if (
+            isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
+            and self.output_dim == self.packed_dim
+        ):
+            shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
+                shard_offset=shard_offset, shard_size=shard_size
+            )
+
+        param_data = self.data
+        tp_rank = get_tensor_model_parallel_rank()
+        shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
+        param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
+        loaded_weight = loaded_weight.narrow(
+            self.output_dim, shard_id * shard_size, shard_size
+        )
+
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+
+class RowvLLMParameter(BasevLLMParameter):
+    """
+    Parameter class defining weight_loading functionality
+    (load_row_parallel_weight) for parameters being loaded
+    into linear layers with row parallel functionality.
+    Requires an input_dim to be defined.
+    """
+
+    def __init__(self, input_dim: int, **kwargs):
+        self._input_dim = input_dim
+        super().__init__(**kwargs)
+
+    @property
+    def input_dim(self):
+        return self._input_dim
+
+    def load_row_parallel_weight(self, loaded_weight: torch.Tensor, **kwargs):
+        use_presharded_weights = kwargs.get("use_presharded_weights")
+        tp_rank = get_tensor_model_parallel_rank()
+        shard_size = self.data.shape[self.input_dim]
+        if not use_presharded_weights:
+            loaded_weight = loaded_weight.narrow(
+                self.input_dim, tp_rank * shard_size, shard_size
+            )
+
+        if len(loaded_weight.shape) == 0:
+            loaded_weight = loaded_weight.reshape(1)
+
+        assert self.data.shape == loaded_weight.shape
+        self.data.copy_(loaded_weight)
+
+
+class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
+    """
+    Parameter class for linear layer weights. Uses both column and
+    row parallelism.
+    """
+
+    pass
+
+
+class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
+    """
+    Parameter class for weight scales loaded for weights with
+    grouped quantization. Uses both column and row parallelism.
+    """
+
+    pass
+
+
+class ChannelQuantScaleParameter(_ColumnvLLMParameter):
+    """
+    Parameter class for weight scales loaded for weights with
+    channel-wise quantization. Equivalent to _ColumnvLLMParameter.
+    """
+
+    pass
+
+
+class PerTensorScaleParameter(BasevLLMParameter):
+    """
+    Parameter class for scales where the number of scales is
+    equivalent to the number of logical matrices in fused linear
+    layers (e.g. for QKV, there are 3 scales loaded from disk).
+    This is relevant to weights with per-tensor quantization.
+    Adds functionality to map the scalers to a shard during
+    weight loading.
+
+    Note: additional parameter manipulation may be handled
+    for each quantization config specifically, within
+    process_weights_after_loading
+    """
+
+    def __init__(self, **kwargs):
+        self.qkv_idxs = {"q": 0, "k": 1, "v": 2}
+        super().__init__(**kwargs)
+
+    def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
+        if isinstance(shard_id, int):
+            return shard_id
+
+        # if not int, assume shard_id for qkv
+        # map to int and return
+        assert isinstance(shard_id, str)
+        assert shard_id in self.qkv_idxs
+        return self.qkv_idxs[shard_id]
+
+    # For row parallel layers, no sharding needed
+    # load weight into parameter as is
+    def load_row_parallel_weight(self, *args, **kwargs):
+        super().load_row_parallel_weight(*args, **kwargs)
+
+    def load_merged_column_weight(self, *args, **kwargs):
+        self._load_into_shard_id(*args, **kwargs)
+
+    def load_qkv_weight(self, *args, **kwargs):
+        self._load_into_shard_id(*args, **kwargs)
+
+    def load_column_parallel_weight(self, *args, **kwargs):
+        super().load_row_parallel_weight(*args, **kwargs)
+
+    def _load_into_shard_id(
+        self, loaded_weight: torch.Tensor, shard_id: Union[str, int], **kwargs
+    ):
+        """
+        Slice the parameter data based on the shard id for
+        loading.
+        """
+
+        param_data = self.data
+        shard_id = self._shard_id_as_int(shard_id)
+
+        # AutoFP8 scales do not have a shape
+        # compressed-tensors scales do have a shape
+        if len(loaded_weight.shape) != 0:
+            assert loaded_weight.shape[0] == 1
+            loaded_weight = loaded_weight[0]
+
+        param_data = param_data[shard_id]
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+
+class PackedColumnParameter(_ColumnvLLMParameter):
+    """
+    Parameter for model parameters which are packed on disk
+    and support column parallelism only. See PackedvLLMParameter
+    for more details on the packed properties.
+    """
+
+    def __init__(
+        self,
+        packed_factor: Union[int, Fraction],
+        packed_dim: int,
+        marlin_tile_size: Optional[int] = None,
+        **kwargs
+    ):
+        self._packed_factor = packed_factor
+        self._packed_dim = packed_dim
+        self._marlin_tile_size = marlin_tile_size
+        super().__init__(**kwargs)
+
+    @property
+    def packed_dim(self):
+        return self._packed_dim
+
+    @property
+    def packed_factor(self):
+        return self._packed_factor
+
+    @property
+    def marlin_tile_size(self):
+        return self._marlin_tile_size
+
+    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
+        return _adjust_shard_indexes_for_packing(
+            shard_size=shard_size,
+            shard_offset=shard_offset,
+            packed_factor=self.packed_factor,
+            marlin_tile_size=self.marlin_tile_size,
+        )
+
+
+class PackedvLLMParameter(ModelWeightParameter):
+    """
+    Parameter for model weights which are packed on disk.
+    Example: GPTQ Marlin weights are int4 or int8, packed into int32.
+    Extends the ModelWeightParameter to take in the
+    packed factor, the packed dimension, and optionally, marlin
+    tile size for marlin kernels. Adjusts the shard_size and
+    shard_offset for fused linear layers model weight loading
+    by accounting for packing and optionally, marlin tile size.
+    """
+
+    def __init__(
+        self,
+        packed_factor: Union[int, Fraction],
+        packed_dim: int,
+        marlin_tile_size: Optional[int] = None,
+        **kwargs
+    ):
+        self._packed_factor = packed_factor
+        self._packed_dim = packed_dim
+        self._marlin_tile_size = marlin_tile_size
+        super().__init__(**kwargs)
+
+    @property
+    def packed_dim(self):
+        return self._packed_dim
+
+    @property
+    def packed_factor(self):
+        return self._packed_factor
+
+    @property
+    def marlin_tile_size(self):
+        return self._marlin_tile_size
+
+    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
+        return _adjust_shard_indexes_for_packing(
+            shard_size=shard_size,
+            shard_offset=shard_offset,
+            packed_factor=self.packed_factor,
+            marlin_tile_size=self.marlin_tile_size,
+        )
+
+
+def permute_param_layout_(
+    param: BasevLLMParameter, input_dim: int, output_dim: int, **kwargs
+) -> BasevLLMParameter:
+    """
+    Permute a parameter's layout to the specified input and output dimensions,
+    useful for forcing the parameter into a known layout, for example, if I need
+    a packed (quantized) weight matrix to be in the layout
+        {input_dim = 0, output_dim = 1, packed_dim = 0}
+    then I can call:
+        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+    to ensure x is in the correct layout (permuting it to the correct layout if
+    required, asserting if it cannot get it to the correct layout)
+    """
+
+    curr_input_dim = getattr(param, "input_dim", None)
+    curr_output_dim = getattr(param, "output_dim", None)
+
+    if curr_input_dim is None or curr_output_dim is None:
+        assert param.data.dim() == 2, (
+            "permute_param_layout_ only supports 2D parameters when either "
+            "input_dim or output_dim is not set"
+        )
+
+    # if one of the dimensions is not set, set it to the opposite of the other
+    # we can only do this since we asserted the parameter is 2D above
+    if curr_input_dim is None:
+        assert curr_output_dim is not None, "either input or output dim must be set"
+        curr_input_dim = (curr_output_dim + 1) % 2
+    if curr_output_dim is None:
+        assert curr_input_dim is not None, "either input or output dim must be set"
+        curr_output_dim = (curr_input_dim + 1) % 2

+    # create permutation from the current layout to the layout with
+    # self.input_dim at input_dim and self.output_dim at output_dim preserving
+    # other dimensions
+    perm = [
+        i for i in range(param.data.dim()) if i not in [curr_input_dim, curr_output_dim]
+    ]
+    perm.insert(input_dim, curr_input_dim)
+    perm.insert(output_dim, curr_output_dim)
+
+    if "packed_dim" in kwargs:
+        assert (
+            hasattr(param, "packed_dim")
+            and param.packed_dim == perm[kwargs["packed_dim"]]
+        ), "permute_param_layout_ currently doesn't support repacking"
+
+    param.data = param.data.permute(*perm)
+    if hasattr(param, "_input_dim"):
+        param._input_dim = input_dim
+    if hasattr(param, "_output_dim"):
+        param._output_dim = output_dim
+    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
+        param._packed_dim = kwargs["packed_dim"]
+
+    return param
+
+
+def _adjust_shard_indexes_for_marlin(shard_size, shard_offset, marlin_tile_size):
+    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
+
+
+def _adjust_shard_indexes_for_packing(
+    shard_size, shard_offset, packed_factor, marlin_tile_size
+):
+    shard_size = shard_size // packed_factor
+    shard_offset = shard_offset // packed_factor
+    if marlin_tile_size is not None:
+        return _adjust_shard_indexes_for_marlin(
+            shard_size=shard_size,
+            shard_offset=shard_offset,
+            marlin_tile_size=marlin_tile_size,
+        )
+    return shard_size, shard_offset
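
The loaders in the new parameter.py differ mainly in which dimension they shard and how they compute the per-rank slice. Below is a minimal, self-contained sketch of the column-parallel arithmetic that load_column_parallel_weight performs, written in plain PyTorch with the tensor-parallel rank passed in explicitly instead of coming from vllm.distributed.get_tensor_model_parallel_rank(); the helper name shard_column_parallel is illustrative and not part of the wheel.

# Example (not part of the diff): the per-rank slice taken by a column-parallel loader.
import torch

def shard_column_parallel(full_weight: torch.Tensor,
                          param_shape: torch.Size,
                          output_dim: int,
                          tp_rank: int) -> torch.Tensor:
    # Each rank keeps a contiguous slice of size param_shape[output_dim]
    # along the output dimension, offset by its rank.
    shard_size = param_shape[output_dim]
    return full_weight.narrow(output_dim, tp_rank * shard_size, shard_size)

# An 8x4 checkpoint weight split over 2 ranks along output_dim=0.
full = torch.arange(32, dtype=torch.float32).reshape(8, 4)
local = torch.empty(4, 4)                       # per-rank parameter data
shard = shard_column_parallel(full, local.shape, output_dim=0, tp_rank=1)
assert shard.shape == local.shape               # mirrors _assert_and_load
local.copy_(shard)                              # then the data is copied in place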
sglang/srt/layers/quantization/__init__.py
@@ -1,8 +1,7 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
 
-from typing import Callable, Dict, Optional, Type
+from typing import Dict, Type
 
-import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
@@ -23,6 +22,7 @@ from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 
 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "aqlm": AQLMConfig,
@@ -32,6 +32,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
+    "modelopt": ModelOptFp8Config,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
     "gptq_marlin": GPTQMarlinConfig,
sglang/srt/layers/quantization/fp8.py
@@ -25,9 +25,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     per_tensor_dequantize,
     requantize_with_max_scale,
 )
-from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
 
 from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
+from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -280,9 +280,9 @@ class Fp8LinearMethod(LinearMethodBase):
                 weight_scale=layer.weight_scale_inv,
                 input_scale=None,
             )
-            layer.weight = torch.nn.Parameter(weight, require_grad=False)
+            layer.weight = torch.nn.Parameter(weight, requires_grad=False)
             layer.weight_scale_inv = torch.nn.Parameter(
-                weight_scale, require_grad=False
+                weight_scale, requires_grad=False
            )
             layer.input_scale = None
             return
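
The second fp8.py hunk is a keyword fix rather than a behavior change: torch.nn.Parameter only accepts requires_grad, so the old spelling would raise a TypeError whenever that code path ran (it appears to be the block-wise FP8 branch, given the weight_scale_inv context). A quick stand-alone check:

# Example (not part of the diff): the keyword accepted by torch.nn.Parameter.
import torch

w = torch.zeros(2, 2)
p = torch.nn.Parameter(w, requires_grad=False)  # corrected spelling: works
try:
    torch.nn.Parameter(w, require_grad=False)   # old spelling from 0.4.1.post3
except TypeError as exc:
    print("TypeError:", exc)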
sglang/srt/layers/quantization/modelopt_quant.py
@@ -0,0 +1,174 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
+
+import logging
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    apply_fp8_linear,
+    cutlass_fp8_supported,
+    requantize_with_max_scale,
+)
+from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
+
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+
+# Initialize logger for the module
+logger = logging.getLogger(__name__)
+
+# Supported activation schemes for the current configuration
+ACTIVATION_SCHEMES = ["static"]
+
+
+class ModelOptFp8Config(QuantizationConfig):
+    """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
+
+    def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
+        """
+        Args:
+            is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
+        """
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        if is_checkpoint_fp8_serialized:
+            logger.warning(
+                "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
+            )
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "modelopt"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 89  # Minimum hardware capability (e.g., Hopper GPUs).
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
+        quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo")
+
+        if "FP8" not in quant_method:
+            raise ValueError(
+                "ModelOpt only supports static FP8 quantization in SGLang. "
+                "Check the `hf_quant_config.json` file for your model's configuration."
+            )
+
+        return cls(is_checkpoint_fp8_serialized=True)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        return ModelOptFp8LinearMethod(self) if isinstance(layer, LinearBase) else None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ModelOptFp8LinearMethod(LinearMethodBase):
+    """Linear method for ModelOpt static FP8 quantization.
+
+    Supports loading FP8 checkpoints with static weight and activation scales.
+    Future support may include dynamic scales.
+
+    **Limitations**:
+    1. Only supports per-tensor quantization due to `torch._scaled_mm` limitations.
+    2. Only supports the `float8_e4m3fn` data type.
+
+    Args:
+        quant_config (ModelOptFp8Config): The ModelOpt quantization configuration.
+    """
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        super().__init__()
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        """Creates and registers weights, weight scales, and input scales for FP8 quantization."""
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+
+        # Set layer attributes
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+
+        # Register weight
+        layer.register_parameter(
+            "weight",
+            ModelWeightParameter(
+                data=torch.empty(
+                    output_size_per_partition,
+                    input_size_per_partition,
+                    dtype=weight_dtype,
+                ),
+                input_dim=1,
+                output_dim=0,
+                weight_loader=weight_loader,
+            ),
+        )
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # Register weight and input scales
+            for scale_name in ["weight_scale", "input_scale"]:
+                layer.register_parameter(
+                    scale_name,
+                    PerTensorScaleParameter(
+                        data=torch.full(
+                            (len(output_partition_sizes),),
+                            torch.finfo(torch.float32).min,
+                            dtype=torch.float32,
+                        ),
+                        weight_loader=weight_loader,
+                    ),
+                )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Requantizes weights after loading using the maximum scale."""
+        max_w_scale, quantized_weight = requantize_with_max_scale(
+            layer.weight, layer.weight_scale, layer.logical_widths
+        )
+        layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
+        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Applies FP8 linear transformation."""
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=layer.input_scale,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+        )
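
To see how the new config plugs in, here is a hedged sketch. It requires an sglang install that includes this diff, and the quantization dict below only mirrors the minimal shape that from_config reads from hf_quant_config.json; it is not copied from a real checkpoint.

# Example (not part of the diff): constructing ModelOptFp8Config by hand.
import torch
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config

cfg = ModelOptFp8Config.from_config({"quantization": {"quant_algo": "FP8"}})
print(cfg.get_name(), cfg.is_checkpoint_fp8_serialized)  # modelopt True

# get_quant_method attaches ModelOptFp8LinearMethod only to vLLM LinearBase
# layers; a plain torch.nn.Linear is left unquantized (returns None).
print(cfg.get_quant_method(torch.nn.Linear(16, 16), prefix=""))  # None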