sglang 0.3.5__py3-none-any.whl → 0.3.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +309 -0
- sglang/bench_serving.py +148 -24
- sglang/srt/configs/model_config.py +5 -2
- sglang/srt/constrained/__init__.py +2 -66
- sglang/srt/constrained/base_grammar_backend.py +73 -0
- sglang/srt/constrained/outlines_backend.py +165 -0
- sglang/srt/constrained/outlines_jump_forward.py +182 -0
- sglang/srt/constrained/xgrammar_backend.py +150 -0
- sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
- sglang/srt/layers/fused_moe/fused_moe.py +23 -7
- sglang/srt/layers/fused_moe/patch.py +4 -2
- sglang/srt/layers/quantization/base_config.py +4 -6
- sglang/srt/layers/vocab_parallel_embedding.py +216 -150
- sglang/srt/managers/detokenizer_manager.py +0 -14
- sglang/srt/managers/io_struct.py +5 -3
- sglang/srt/managers/schedule_batch.py +14 -20
- sglang/srt/managers/scheduler.py +159 -96
- sglang/srt/managers/tokenizer_manager.py +81 -17
- sglang/srt/metrics/collector.py +211 -0
- sglang/srt/metrics/func_timer.py +108 -0
- sglang/srt/mm_utils.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +2 -2
- sglang/srt/model_executor/forward_batch_info.py +7 -3
- sglang/srt/model_executor/model_runner.py +6 -2
- sglang/srt/models/gemma2_reward.py +69 -0
- sglang/srt/models/gpt2.py +31 -37
- sglang/srt/models/internlm2_reward.py +62 -0
- sglang/srt/models/llama.py +11 -6
- sglang/srt/models/llama_reward.py +5 -26
- sglang/srt/models/qwen2_vl.py +5 -7
- sglang/srt/openai_api/adapter.py +11 -4
- sglang/srt/openai_api/protocol.py +29 -26
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/sampling/sampling_params.py +2 -16
- sglang/srt/server.py +60 -17
- sglang/srt/server_args.py +66 -25
- sglang/srt/utils.py +120 -0
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_mgsm.py +2 -2
- sglang/test/test_utils.py +21 -7
- sglang/utils.py +1 -0
- sglang/version.py +1 -1
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/METADATA +12 -8
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/RECORD +49 -45
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/WHEEL +1 -1
- sglang/srt/constrained/base_tool_cache.py +0 -65
- sglang/srt/constrained/bnf_cache.py +0 -61
- sglang/srt/constrained/fsm_cache.py +0 -95
- sglang/srt/constrained/grammar.py +0 -190
- sglang/srt/constrained/jump_forward.py +0 -203
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/base_config.py CHANGED

```diff
@@ -122,16 +122,14 @@ class QuantizationConfig(ABC):
         """
         raise NotImplementedError
 
-def method_has_implemented_embedding(
-        method_class: Type[QuantizeMethodBase]) -> bool:
+
+def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
     """
     Not all quant methods have embedding implemented, so we need to check that
     it exists for our given method. We check this by making sure the function
     has been changed from the base implementation.
     """
-    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
-                                            None)
+    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
     class_embedding = inspect.getattr_static(method_class, "embedding", None)
 
-    return (class_embedding is not None
-            and class_embedding is not base_embedding)
+    return class_embedding is not None and class_embedding is not base_embedding
```
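The hunk above is a formatting-only rewrite of the `method_has_implemented_embedding` helper. For reference, a minimal standalone sketch of the override check it performs, using simplified stand-in classes (`DummyQuantMethod` is hypothetical, not part of the package): `inspect.getattr_static` reads an attribute without triggering descriptors, so comparing it against the base-class attribute tells whether a subclass actually overrides `embedding`.

```python
import inspect


class QuantizeMethodBase:
    """Simplified stand-in for the real base class."""

    def embedding(self, layer, input_):
        raise NotImplementedError


class UnquantizedEmbeddingMethod(QuantizeMethodBase):
    """Overrides `embedding`, so the check returns True."""

    def embedding(self, layer, input_):
        return input_


class DummyQuantMethod(QuantizeMethodBase):
    """Hypothetical method class that does not override `embedding`."""


def method_has_implemented_embedding(method_class) -> bool:
    # getattr_static avoids invoking descriptors; identity comparison against
    # the base attribute detects a genuine override.
    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
    class_embedding = inspect.getattr_static(method_class, "embedding", None)
    return class_embedding is not None and class_embedding is not base_embedding


print(method_has_implemented_embedding(UnquantizedEmbeddingMethod))  # True
print(method_has_implemented_embedding(DummyQuantMethod))  # False
```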
sglang/srt/layers/vocab_parallel_embedding.py CHANGED

```diff
@@ -27,59 +27,67 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
 
-    def create_weights(self, layer: torch.nn.Module,
-                       input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
-                       output_size: int, params_dtype: torch.dtype,
-                       **extra_weight_attrs):
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
         """Create weights for embedding layer."""
-        weight = Parameter(torch.empty(sum(output_partition_sizes),
-                                       input_size_per_partition,
-                                       dtype=params_dtype),
-                           requires_grad=False)
+        weight = Parameter(
+            torch.empty(
+                sum(output_partition_sizes),
+                input_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
         layer.register_parameter("weight", weight)
         set_weight_attrs(weight, extra_weight_attrs)
 
-    def apply(self,
-              layer: torch.nn.Module,
-              x: torch.Tensor,
-              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return F.linear(x, layer.weight, bias)
 
-    def embedding(self, layer: torch.nn.Module,
-                  input_: torch.Tensor) -> torch.Tensor:
+    def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
         return F.embedding(input_, layer.weight)
 
 
-def pad_vocab_size(vocab_size: int,
-                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
+def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
     """Pad the vocab size to the given value."""
     return ((vocab_size + pad_to - 1) // pad_to) * pad_to
 
 
 def vocab_range_from_per_partition_vocab_size(
-        per_partition_vocab_size: int,
-        rank: int,
-        offset: int = 0) -> Sequence[int]:
+    per_partition_vocab_size: int, rank: int, offset: int = 0
+) -> Sequence[int]:
     index_f = rank * per_partition_vocab_size
     index_l = index_f + per_partition_vocab_size
     return index_f + offset, index_l + offset
 
 
-def vocab_range_from_global_vocab_size(global_vocab_size: int,
-                                       rank: int,
-                                       world_size: int,
-                                       offset: int = 0) -> Sequence[int]:
+def vocab_range_from_global_vocab_size(
+    global_vocab_size: int, rank: int, world_size: int, offset: int = 0
+) -> Sequence[int]:
     per_partition_vocab_size = divide(global_vocab_size, world_size)
-    return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
-                                                     rank,
-                                                     offset=offset)
+    return vocab_range_from_per_partition_vocab_size(
+        per_partition_vocab_size, rank, offset=offset
+    )
 
 
 @dataclass
 class VocabParallelEmbeddingShardIndices:
     """Indices for a shard of a vocab parallel embedding."""
+
     padded_org_vocab_start_index: int
     padded_org_vocab_end_index: int
     padded_added_vocab_start_index: int
```
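The padding and shard-range helpers above are pure functions, so their behavior is easy to check in isolation. A small worked example, restated standalone (the package's `divide` helper is replaced here by plain integer division under the assumption that the padded size splits evenly across ranks):

```python
DEFAULT_VOCAB_PADDING_SIZE = 64


def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
    # Round up to the next multiple of pad_to.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


def vocab_range_from_global_vocab_size(global_vocab_size, rank, world_size, offset=0):
    per_partition = global_vocab_size // world_size  # assumes an even split
    start = rank * per_partition + offset
    return start, start + per_partition


padded = pad_vocab_size(32003)  # -> 32064, a multiple of 64
for rank in range(4):           # tp_size = 4
    print(rank, vocab_range_from_global_vocab_size(padded, rank, world_size=4))
# 0 (0, 8016) / 1 (8016, 16032) / 2 (16032, 24048) / 3 (24048, 32064)
```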
```diff
@@ -100,13 +108,11 @@ class VocabParallelEmbeddingShardIndices:
 
     @property
     def num_org_elements_padded(self) -> int:
-        return (self.padded_org_vocab_end_index -
-                self.padded_org_vocab_start_index)
+        return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index
 
     @property
     def num_added_elements_padded(self) -> int:
-        return (self.padded_added_vocab_end_index -
-                self.padded_added_vocab_start_index)
+        return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index
 
     @property
     def num_org_vocab_padding(self) -> int:
```
```diff
@@ -122,17 +128,14 @@ class VocabParallelEmbeddingShardIndices:
 
     def __post_init__(self):
         # sanity checks
-        assert (self.padded_org_vocab_start_index <=
-                self.padded_org_vocab_end_index)
-        assert (self.padded_added_vocab_start_index <=
-                self.padded_added_vocab_end_index)
+        assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index
+        assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index
 
         assert self.org_vocab_start_index <= self.org_vocab_end_index
         assert self.added_vocab_start_index <= self.added_vocab_end_index
 
         assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
-        assert (self.added_vocab_start_index <=
-                self.padded_added_vocab_start_index)
+        assert self.added_vocab_start_index <= self.padded_added_vocab_start_index
         assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
         assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
 
```
```diff
@@ -142,20 +145,27 @@ class VocabParallelEmbeddingShardIndices:
 
 @torch.jit.script
 def get_masked_input_and_mask(
-        input_: torch.Tensor, org_vocab_start_index: int,
-        org_vocab_end_index: int, num_org_vocab_padding: int,
-        added_vocab_start_index: int,
-        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    input_: torch.Tensor,
+    org_vocab_start_index: int,
+    org_vocab_end_index: int,
+    num_org_vocab_padding: int,
+    added_vocab_start_index: int,
+    added_vocab_end_index: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
     # torch.jit.script will fuse all of the pointwise ops below
     # into a single kernel, making it very fast
-    org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
-                                                          org_vocab_end_index)
+    org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
     added_vocab_mask = (input_ >= added_vocab_start_index) & (
-        input_ < added_vocab_end_index)
-    added_offset = added_vocab_start_index - (
-        org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
-    valid_offset = (org_vocab_start_index *
-                    org_vocab_mask) + (added_offset * added_vocab_mask)
+        input_ < added_vocab_end_index
+    )
+    added_offset = (
+        added_vocab_start_index
+        - (org_vocab_end_index - org_vocab_start_index)
+        - num_org_vocab_padding
+    )
+    valid_offset = (org_vocab_start_index * org_vocab_mask) + (
+        added_offset * added_vocab_mask
+    )
     vocab_mask = org_vocab_mask | added_vocab_mask
     input_ = vocab_mask * (input_ - valid_offset)
     return input_, ~vocab_mask
```
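Since `get_masked_input_and_mask` is only reformatted here, a tiny numerical demonstration may help. The shard layout below is made up for illustration: this rank owns original-vocab ids [4, 8), followed by two padding slots, followed by added-vocab ids [100, 102); everything else is mapped to local slot 0 and flagged as out-of-vocabulary.

```python
import torch


def get_masked_input_and_mask(input_, org_start, org_end, num_org_pad, added_start, added_end):
    # Same arithmetic as the hunk above, without @torch.jit.script for brevity.
    org_vocab_mask = (input_ >= org_start) & (input_ < org_end)
    added_vocab_mask = (input_ >= added_start) & (input_ < added_end)
    added_offset = added_start - (org_end - org_start) - num_org_pad
    valid_offset = (org_start * org_vocab_mask) + (added_offset * added_vocab_mask)
    vocab_mask = org_vocab_mask | added_vocab_mask
    return vocab_mask * (input_ - valid_offset), ~vocab_mask


tokens = torch.tensor([3, 4, 7, 100, 101, 50])
local_ids, oov_mask = get_masked_input_and_mask(tokens, 4, 8, 2, 100, 102)
print(local_ids)  # tensor([0, 0, 3, 6, 7, 0]): local slots 0-3 org, 4-5 padding, 6-7 added
print(oov_mask)   # tensor([ True, False, False, False, False,  True])
```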
```diff
@@ -200,15 +210,17 @@ class VocabParallelEmbedding(torch.nn.Module):
         prefix: full name of the layer in the state dict
     """  # noqa: E501
 
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 params_dtype: Optional[torch.dtype] = None,
-                 org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "",
-                 enable_tp: bool = True):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        params_dtype: Optional[torch.dtype] = None,
+        org_num_embeddings: Optional[int] = None,
+        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        enable_tp: bool = True,
+    ):
         super().__init__()
 
         self.enable_tp = enable_tp
```
```diff
@@ -223,18 +235,22 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
         num_added_embeddings = num_embeddings - self.org_vocab_size
-        self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
-                                                    self.padding_size)
+        self.org_vocab_size_padded = pad_vocab_size(
+            self.org_vocab_size, self.padding_size
+        )
         self.num_embeddings_padded = pad_vocab_size(
-            self.org_vocab_size_padded + num_added_embeddings,
-            self.padding_size)
+            self.org_vocab_size_padded + num_added_embeddings, self.padding_size
+        )
         assert self.org_vocab_size_padded <= self.num_embeddings_padded
 
-        self.shard_indices = self._get_indices(self.num_embeddings_padded,
-                                               self.org_vocab_size_padded,
-                                               self.num_embeddings,
-                                               self.org_vocab_size, tp_rank,
-                                               self.tp_size)
+        self.shard_indices = self._get_indices(
+            self.num_embeddings_padded,
+            self.org_vocab_size_padded,
+            self.num_embeddings,
+            self.org_vocab_size,
+            tp_rank,
+            self.tp_size,
+        )
         self.embedding_dim = embedding_dim
 
         linear_method = None
```
```diff
@@ -248,11 +264,13 @@ class VocabParallelEmbedding(torch.nn.Module):
         # layer type like ParallelLMHead, this is not important.
         is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
         linear_method_implements_embedding = method_has_implemented_embedding(
-            type(linear_method))
+            type(linear_method)
+        )
         if is_embedding_layer and not linear_method_implements_embedding:
             raise NotImplementedError(
                 f"The class {type(linear_method).__name__} must implement "
-                "the 'embedding' method, see UnquantizedEmbeddingMethod.")
+                "the 'embedding' method, see UnquantizedEmbeddingMethod."
+            )
 
         self.linear_method: QuantizeMethodBase = linear_method
 
```
```diff
@@ -260,58 +278,73 @@ class VocabParallelEmbedding(torch.nn.Module):
             params_dtype = torch.get_default_dtype()
         # Divide the weight matrix along the vocaburaly dimension.
         self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
-        self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
-                                                   self.tp_size)
-        assert (self.shard_indices.num_elements_padded ==
-                self.num_embeddings_per_partition)
+        self.num_embeddings_per_partition = divide(
+            self.num_embeddings_padded, self.tp_size
+        )
+        assert (
+            self.shard_indices.num_elements_padded == self.num_embeddings_per_partition
+        )
         self.num_org_embeddings_per_partition = (
-            self.shard_indices.org_vocab_end_index -
-            self.shard_indices.org_vocab_start_index)
+            self.shard_indices.org_vocab_end_index
+            - self.shard_indices.org_vocab_start_index
+        )
         self.num_added_embeddings_per_partition = (
-            self.shard_indices.added_vocab_end_index -
-            self.shard_indices.added_vocab_start_index)
-
-        self.linear_method.create_weights(self,
-                                          self.embedding_dim,
-                                          [self.num_embeddings_per_partition],
-                                          self.embedding_dim,
-                                          self.num_embeddings_padded,
-                                          params_dtype=params_dtype,
-                                          weight_loader=self.weight_loader)
+            self.shard_indices.added_vocab_end_index
+            - self.shard_indices.added_vocab_start_index
+        )
+
+        self.linear_method.create_weights(
+            self,
+            self.embedding_dim,
+            [self.num_embeddings_per_partition],
+            self.embedding_dim,
+            self.num_embeddings_padded,
+            params_dtype=params_dtype,
+            weight_loader=self.weight_loader,
+        )
 
     @classmethod
-    def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
-                     vocab_size: int, org_vocab_size: int, tp_rank: int,
-                     tp_size: int) -> VocabParallelEmbeddingShardIndices:
+    def _get_indices(
+        cls,
+        vocab_size_padded: int,
+        org_vocab_size_padded: int,
+        vocab_size: int,
+        org_vocab_size: int,
+        tp_rank: int,
+        tp_size: int,
+    ) -> VocabParallelEmbeddingShardIndices:
         """Get start and end indices for vocab parallel embedding, following the
         layout outlined in the class docstring, based on the given tp_rank and
         tp_size."""
         num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
         padded_org_vocab_start_index, padded_org_vocab_end_index = (
-            vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
-                                               tp_size))
+            vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, tp_size)
+        )
         padded_added_vocab_start_index, padded_added_vocab_end_index = (
-            vocab_range_from_global_vocab_size(num_added_embeddings_padded,
-                                               tp_rank,
-                                               tp_size,
-                                               offset=org_vocab_size))
+            vocab_range_from_global_vocab_size(
+                num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
+            )
+        )
         # remove padding
-        org_vocab_start_index = min(padded_org_vocab_start_index,
-                                    org_vocab_size)
+        org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
         org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
-        added_vocab_start_index = min(padded_added_vocab_start_index,
-                                      vocab_size)
+        added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
         added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
         return VocabParallelEmbeddingShardIndices(
-            padded_org_vocab_start_index, padded_org_vocab_end_index,
-            padded_added_vocab_start_index, padded_added_vocab_end_index,
-            org_vocab_start_index, org_vocab_end_index,
-            added_vocab_start_index, added_vocab_end_index)
+            padded_org_vocab_start_index,
+            padded_org_vocab_end_index,
+            padded_added_vocab_start_index,
+            padded_added_vocab_end_index,
+            org_vocab_start_index,
+            org_vocab_end_index,
+            added_vocab_start_index,
+            added_vocab_end_index,
+        )
 
     def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
         """Get a mapping that can be used to reindex the gathered
         logits for sampling.
-
+
         During sampling, we gather logits from all ranks. The relationship
         of index->token_id will follow the same format as outlined in the class
         docstring. However, after the gather, we want to reindex the final
```
```diff
@@ -326,32 +359,49 @@ class VocabParallelEmbedding(torch.nn.Module):
         added_embeddings: List[int] = []
         padding: List[int] = []
         for tp_rank in range(self.tp_size):
-            shard_indices = self._get_indices(self.num_embeddings_padded,
-                                              self.org_vocab_size_padded,
-                                              self.num_embeddings,
-                                              self.org_vocab_size, tp_rank,
-                                              self.tp_size)
+            shard_indices = self._get_indices(
+                self.num_embeddings_padded,
+                self.org_vocab_size_padded,
+                self.num_embeddings,
+                self.org_vocab_size,
+                tp_rank,
+                self.tp_size,
+            )
             range_start = self.num_embeddings_per_partition * tp_rank
             range_end = self.num_embeddings_per_partition * (tp_rank + 1)
             base_embeddings.extend(
-                range(range_start,
-                      range_start + shard_indices.num_org_elements))
+                range(range_start, range_start + shard_indices.num_org_elements)
+            )
             padding.extend(
-                range(range_start + shard_indices.num_org_elements,
-                      range_start + shard_indices.num_org_elements_padded))
+                range(
+                    range_start + shard_indices.num_org_elements,
+                    range_start + shard_indices.num_org_elements_padded,
+                )
+            )
             added_embeddings.extend(
                 range(
                     range_start + shard_indices.num_org_elements_padded,
-                    range_start + shard_indices.num_org_elements_padded +
-                    shard_indices.num_added_elements))
+                    range_start
+                    + shard_indices.num_org_elements_padded
+                    + shard_indices.num_added_elements,
+                )
+            )
             padding.extend(
                 range(
-                    range_start + shard_indices.num_org_elements_padded +
-                    shard_indices.num_added_elements,
-                    range_start + shard_indices.num_org_elements_padded +
-                    shard_indices.num_added_elements_padded))
-            assert (range_start + shard_indices.num_org_elements_padded +
-                    shard_indices.num_added_elements_padded == range_end)
+                    range_start
+                    + shard_indices.num_org_elements_padded
+                    + shard_indices.num_added_elements,
+                    range_start
+                    + shard_indices.num_org_elements_padded
+                    + shard_indices.num_added_elements_padded,
+                )
+            )
+            assert (
+                range_start
+                + shard_indices.num_org_elements_padded
+                + shard_indices.num_added_elements_padded
+                == range_end
+            )
         ret = base_embeddings + added_embeddings + padding
         assert len(ret) == self.num_embeddings_padded
         return ret
```
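The loop above builds the sharded-to-full mapping by concatenating, per rank, the original-vocab slots, then the added-vocab slots, then all padding slots. A toy illustration with made-up per-rank counts (3 original elements plus 1 padding slot, 1 added element plus 1 padding slot, across 2 ranks) mirrors that layout:

```python
tp_size = 2
num_org_elements, num_org_elements_padded = 3, 4      # org slots + 1 padding slot
num_added_elements, num_added_elements_padded = 1, 2  # added slots + 1 padding slot
per_partition = num_org_elements_padded + num_added_elements_padded  # 6 slots per rank

base_embeddings, added_embeddings, padding = [], [], []
for tp_rank in range(tp_size):
    range_start = per_partition * tp_rank
    base_embeddings.extend(range(range_start, range_start + num_org_elements))
    padding.extend(range(range_start + num_org_elements, range_start + num_org_elements_padded))
    added_embeddings.extend(
        range(
            range_start + num_org_elements_padded,
            range_start + num_org_elements_padded + num_added_elements,
        )
    )
    padding.extend(
        range(
            range_start + num_org_elements_padded + num_added_elements,
            range_start + per_partition,
        )
    )

print(base_embeddings + added_embeddings + padding)
# [0, 1, 2, 6, 7, 8, 4, 10, 3, 5, 9, 11]
```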
```diff
@@ -385,10 +435,14 @@ class VocabParallelEmbedding(torch.nn.Module):
             # If param packed on the same dim we are sharding on, then
             # need to adjust offsets of loaded weight by pack_factor.
             if packed_dim is not None and packed_dim == output_dim:
-                packed_factor = param.packed_factor if isinstance(
-                    param, BasevLLMParameter) else param.pack_factor
-                assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
-                                                           param.packed_factor)
+                packed_factor = (
+                    param.packed_factor
+                    if isinstance(param, BasevLLMParameter)
+                    else param.pack_factor
+                )
+                assert loaded_weight.shape[output_dim] == (
+                    self.org_vocab_size // param.packed_factor
+                )
                 start_idx = start_idx // packed_factor
                 shard_size = shard_size // packed_factor
             else:
```
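For context on the packed-weight branch being reformatted above, the offset adjustment can be sketched with assumed numbers (for example 4-bit values packed eight to an int32; the sizes are illustrative and not taken from the package):

```python
# Illustrative only: when the packed dim equals the sharded output dim, shard
# offsets computed in unpacked elements are converted to packed units.
pack_factor = 8                      # e.g. 4-bit weights packed into int32 (assumption)
start_idx, shard_size = 32064, 8016  # this rank's slice, in unpacked elements
start_idx = start_idx // pack_factor    # -> 4008 packed rows
shard_size = shard_size // pack_factor  # -> 1002 packed rows
print(start_idx, shard_size)
```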
```diff
@@ -396,23 +450,24 @@ class VocabParallelEmbedding(torch.nn.Module):
 
         # Copy the data.
         loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
-        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
-        param[loaded_weight.shape[0]:].data.fill_(0)
+        param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
+        param[loaded_weight.shape[0] :].data.fill_(0)
 
     def forward(self, input_):
         if self.tp_size > 1:
             # Build the mask.
             masked_input, input_mask = get_masked_input_and_mask(
-                input_, self.shard_indices.org_vocab_start_index,
+                input_,
+                self.shard_indices.org_vocab_start_index,
                 self.shard_indices.org_vocab_end_index,
                 self.shard_indices.num_org_vocab_padding,
                 self.shard_indices.added_vocab_start_index,
-                self.shard_indices.added_vocab_end_index)
+                self.shard_indices.added_vocab_end_index,
+            )
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = self.linear_method.embedding(self,
-                                                       masked_input.long())
+        output_parallel = self.linear_method.embedding(self, masked_input.long())
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
```
```diff
@@ -426,9 +481,9 @@ class VocabParallelEmbedding(torch.nn.Module):
         s = f"num_embeddings={self.num_embeddings_per_partition}"
         s += f", embedding_dim={self.embedding_dim}"
         s += f", org_vocab_size={self.org_vocab_size}"
-        s += f', num_embeddings_padded={self.num_embeddings_padded}'
+        s += f", num_embeddings_padded={self.num_embeddings_padded}"
         if self.enable_tp:
-            s += f', tp_size={self.tp_size}'
+            s += f", tp_size={self.tp_size}"
         return s
 
 
```
```diff
@@ -448,27 +503,38 @@ class ParallelLMHead(VocabParallelEmbedding):
         padding_size: padding size for the vocabulary.
     """
 
-    def __init__(self,
-                 num_embeddings: int,
-                 embedding_dim: int,
-                 bias: bool = False,
-                 params_dtype: Optional[torch.dtype] = None,
-                 org_num_embeddings: Optional[int] = None,
-                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = ""):
-        super().__init__(num_embeddings, embedding_dim, params_dtype,
-                         org_num_embeddings, padding_size, quant_config,
-                         prefix)
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        bias: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        org_num_embeddings: Optional[int] = None,
+        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__(
+            num_embeddings,
+            embedding_dim,
+            params_dtype,
+            org_num_embeddings,
+            padding_size,
+            quant_config,
+            prefix,
+        )
         self.quant_config = quant_config
         if bias:
             self.bias = Parameter(
-                torch.empty(self.num_embeddings_per_partition,
-                            dtype=params_dtype))
-            set_weight_attrs(self.bias, {
-                "output_dim": 0,
-                "weight_loader": self.weight_loader
-            })
+                torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)
+            )
+            set_weight_attrs(
+                self.bias,
+                {
+                    "output_dim": 0,
+                    "weight_loader": self.weight_loader,
+                },
+            )
         else:
             self.register_parameter("bias", None)
 
```
```diff
@@ -483,4 +549,4 @@ class ParallelLMHead(VocabParallelEmbedding):
 
     def forward(self, input_):
         del input_
-        raise RuntimeError("LMHead's weights should be used in the sampler.")
+        raise RuntimeError("LMHead's weights should be used in the sampler.")
```
sglang/srt/managers/detokenizer_manager.py CHANGED

```diff
@@ -100,20 +100,6 @@ class DetokenizerManager:
 
             if isinstance(recv_obj, BatchEmbeddingOut):
                 # If it is embedding model, no detokenization is needed.
-                self.send_to_tokenizer.send_pyobj(
-                    BatchEmbeddingOut(
-                        rids=recv_obj.rids,
-                        embeddings=recv_obj.embeddings,
-                        meta_info=recv_obj.meta_info,
-                        finished_reason=recv_obj.finished_reason,
-                    )
-                )
-                continue
-            elif isinstance(recv_obj, UpdateWeightReqOutput):
-                # If it is a weight update request, no detokenization is needed.
-                self.send_to_tokenizer.send_pyobj(recv_obj)
-                continue
-            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                 self.send_to_tokenizer.send_pyobj(recv_obj)
                 continue
             else:
```
sglang/srt/managers/io_struct.py CHANGED
```diff
@@ -86,8 +86,10 @@ class GenerateReqInput:
             self.parallel_sample_num = self.sampling_params.get("n", 1)
         else:  # isinstance(self.sampling_params, list):
             self.parallel_sample_num = self.sampling_params[0].get("n", 1)
-            assert all(self.parallel_sample_num == sampling_params.get("n", 1)
-                       for sampling_params in self.sampling_params), "The parallel_sample_num should be the same for all samples in sample params."
+            assert all(
+                self.parallel_sample_num == sampling_params.get("n", 1)
+                for sampling_params in self.sampling_params
+            ), "The parallel_sample_num should be the same for all samples in sample params."
 
         if self.parallel_sample_num > 1 and self.is_single:
             self.is_single = False
```
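The reformatted assertion above requires every entry of a per-request `sampling_params` list to ask for the same number of parallel samples. A quick sketch with a hypothetical payload:

```python
sampling_params = [
    {"n": 2, "temperature": 0.7},
    {"n": 2, "temperature": 0.0},
]
parallel_sample_num = sampling_params[0].get("n", 1)
assert all(
    parallel_sample_num == sp.get("n", 1) for sp in sampling_params
), "The parallel_sample_num should be the same for all samples in sample params."
```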
```diff
@@ -182,7 +184,7 @@ class TokenizedGenerateReqInput:
     input_text: str
     # The input token ids
     input_ids: List[int]
-    # The image
+    # The image inputs
     image_inputs: dict
     # The sampling parameters
     sampling_params: SamplingParams
```