sglang 0.3.5__py3-none-any.whl → 0.3.5.post2__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (54)
  1. sglang/bench_offline_throughput.py +309 -0
  2. sglang/bench_serving.py +148 -24
  3. sglang/srt/configs/model_config.py +5 -2
  4. sglang/srt/constrained/__init__.py +2 -66
  5. sglang/srt/constrained/base_grammar_backend.py +73 -0
  6. sglang/srt/constrained/outlines_backend.py +165 -0
  7. sglang/srt/constrained/outlines_jump_forward.py +182 -0
  8. sglang/srt/constrained/xgrammar_backend.py +150 -0
  9. sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
  10. sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
  11. sglang/srt/layers/fused_moe/fused_moe.py +23 -7
  12. sglang/srt/layers/fused_moe/patch.py +4 -2
  13. sglang/srt/layers/quantization/base_config.py +4 -6
  14. sglang/srt/layers/vocab_parallel_embedding.py +216 -150
  15. sglang/srt/managers/detokenizer_manager.py +0 -14
  16. sglang/srt/managers/io_struct.py +5 -3
  17. sglang/srt/managers/schedule_batch.py +14 -20
  18. sglang/srt/managers/scheduler.py +159 -96
  19. sglang/srt/managers/tokenizer_manager.py +81 -17
  20. sglang/srt/metrics/collector.py +211 -0
  21. sglang/srt/metrics/func_timer.py +108 -0
  22. sglang/srt/mm_utils.py +1 -1
  23. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  24. sglang/srt/model_executor/forward_batch_info.py +7 -3
  25. sglang/srt/model_executor/model_runner.py +6 -2
  26. sglang/srt/models/gemma2_reward.py +69 -0
  27. sglang/srt/models/gpt2.py +31 -37
  28. sglang/srt/models/internlm2_reward.py +62 -0
  29. sglang/srt/models/llama.py +11 -6
  30. sglang/srt/models/llama_reward.py +5 -26
  31. sglang/srt/models/qwen2_vl.py +5 -7
  32. sglang/srt/openai_api/adapter.py +11 -4
  33. sglang/srt/openai_api/protocol.py +29 -26
  34. sglang/srt/sampling/sampling_batch_info.py +2 -3
  35. sglang/srt/sampling/sampling_params.py +2 -16
  36. sglang/srt/server.py +60 -17
  37. sglang/srt/server_args.py +66 -25
  38. sglang/srt/utils.py +120 -0
  39. sglang/test/simple_eval_common.py +1 -1
  40. sglang/test/simple_eval_humaneval.py +2 -2
  41. sglang/test/simple_eval_mgsm.py +2 -2
  42. sglang/test/test_utils.py +21 -7
  43. sglang/utils.py +1 -0
  44. sglang/version.py +1 -1
  45. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/METADATA +12 -8
  46. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/RECORD +49 -45
  47. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/WHEEL +1 -1
  48. sglang/srt/constrained/base_tool_cache.py +0 -65
  49. sglang/srt/constrained/bnf_cache.py +0 -61
  50. sglang/srt/constrained/fsm_cache.py +0 -95
  51. sglang/srt/constrained/grammar.py +0 -190
  52. sglang/srt/constrained/jump_forward.py +0 -203
  53. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/LICENSE +0 -0
  54. {sglang-0.3.5.dist-info → sglang-0.3.5.post2.dist-info}/top_level.txt +0 -0
@@ -122,16 +122,14 @@ class QuantizationConfig(ABC):
  """
  raise NotImplementedError

- def method_has_implemented_embedding(
- method_class: Type[QuantizeMethodBase]) -> bool:
+
+ def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
  """
  Not all quant methods have embedding implemented, so we need to check that
  it exists for our given method. We check this by making sure the function
  has been changed from the base implementation.
  """
- base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
- None)
+ base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
  class_embedding = inspect.getattr_static(method_class, "embedding", None)

- return (class_embedding is not None
- and class_embedding is not base_embedding)
+ return class_embedding is not None and class_embedding is not base_embedding
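The base_config.py hunk above only reformats method_has_implemented_embedding; the check it performs compares the statically resolved "embedding" attribute on the given class against the one on QuantizeMethodBase. A minimal, self-contained sketch of that pattern (the Base, Child, and has_overridden_embedding names below are illustrative, not part of sglang):

    # Hypothetical classes standing in for QuantizeMethodBase and a subclass.
    import inspect

    class Base:
        def embedding(self, x):
            raise NotImplementedError

    class Child(Base):
        def embedding(self, x):
            return x

    def has_overridden_embedding(cls) -> bool:
        # getattr_static avoids triggering descriptors and returns the raw
        # function objects, so an identity check detects an override.
        base_fn = inspect.getattr_static(Base, "embedding", None)
        cls_fn = inspect.getattr_static(cls, "embedding", None)
        return cls_fn is not None and cls_fn is not base_fn

    assert has_overridden_embedding(Child)      # Child overrides embedding
    assert not has_overridden_embedding(Base)   # Base only has the stub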
@@ -27,59 +27,67 @@ DEFAULT_VOCAB_PADDING_SIZE = 64
  class UnquantizedEmbeddingMethod(QuantizeMethodBase):
  """Unquantized method for embeddings."""

- def create_weights(self, layer: torch.nn.Module,
- input_size_per_partition: int,
- output_partition_sizes: List[int], input_size: int,
- output_size: int, params_dtype: torch.dtype,
- **extra_weight_attrs):
+ def create_weights(
+ self,
+ layer: torch.nn.Module,
+ input_size_per_partition: int,
+ output_partition_sizes: List[int],
+ input_size: int,
+ output_size: int,
+ params_dtype: torch.dtype,
+ **extra_weight_attrs,
+ ):
  """Create weights for embedding layer."""
- weight = Parameter(torch.empty(sum(output_partition_sizes),
- input_size_per_partition,
- dtype=params_dtype),
- requires_grad=False)
+ weight = Parameter(
+ torch.empty(
+ sum(output_partition_sizes),
+ input_size_per_partition,
+ dtype=params_dtype,
+ ),
+ requires_grad=False,
+ )
  set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
  layer.register_parameter("weight", weight)
  set_weight_attrs(weight, extra_weight_attrs)

- def apply(self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+ def apply(
+ self,
+ layer: torch.nn.Module,
+ x: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
  return F.linear(x, layer.weight, bias)

- def embedding(self, layer: torch.nn.Module,
- input_: torch.Tensor) -> torch.Tensor:
+ def embedding(self, layer: torch.nn.Module, input_: torch.Tensor) -> torch.Tensor:
  return F.embedding(input_, layer.weight)


- def pad_vocab_size(vocab_size: int,
- pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
+ def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
  """Pad the vocab size to the given value."""
  return ((vocab_size + pad_to - 1) // pad_to) * pad_to


  def vocab_range_from_per_partition_vocab_size(
- per_partition_vocab_size: int,
- rank: int,
- offset: int = 0) -> Sequence[int]:
+ per_partition_vocab_size: int, rank: int, offset: int = 0
+ ) -> Sequence[int]:
  index_f = rank * per_partition_vocab_size
  index_l = index_f + per_partition_vocab_size
  return index_f + offset, index_l + offset


- def vocab_range_from_global_vocab_size(global_vocab_size: int,
- rank: int,
- world_size: int,
- offset: int = 0) -> Sequence[int]:
+ def vocab_range_from_global_vocab_size(
+ global_vocab_size: int, rank: int, world_size: int, offset: int = 0
+ ) -> Sequence[int]:
  per_partition_vocab_size = divide(global_vocab_size, world_size)
- return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
- rank,
- offset=offset)
+ return vocab_range_from_per_partition_vocab_size(
+ per_partition_vocab_size, rank, offset=offset
+ )


  @dataclass
  class VocabParallelEmbeddingShardIndices:
  """Indices for a shard of a vocab parallel embedding."""
+
  padded_org_vocab_start_index: int
  padded_org_vocab_end_index: int
  padded_added_vocab_start_index: int
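A quick worked example of the padding and range arithmetic in the hunk above; pad_vocab_size is copied from the hunk, while the vocab size, tp_size, and rank values are made up for illustration:

    DEFAULT_VOCAB_PADDING_SIZE = 64

    def pad_vocab_size(vocab_size: int, pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
        # Round up to the next multiple of pad_to.
        return ((vocab_size + pad_to - 1) // pad_to) * pad_to

    padded = pad_vocab_size(32003)            # 32064, the next multiple of 64
    tp_size, rank = 4, 1
    per_partition = padded // tp_size         # 8016 rows per rank
    start, end = rank * per_partition, (rank + 1) * per_partition
    print(padded, (start, end))               # 32064 (8016, 16032)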
@@ -100,13 +108,11 @@ class VocabParallelEmbeddingShardIndices:

  @property
  def num_org_elements_padded(self) -> int:
- return (self.padded_org_vocab_end_index -
- self.padded_org_vocab_start_index)
+ return self.padded_org_vocab_end_index - self.padded_org_vocab_start_index

  @property
  def num_added_elements_padded(self) -> int:
- return (self.padded_added_vocab_end_index -
- self.padded_added_vocab_start_index)
+ return self.padded_added_vocab_end_index - self.padded_added_vocab_start_index

  @property
  def num_org_vocab_padding(self) -> int:
@@ -122,17 +128,14 @@ class VocabParallelEmbeddingShardIndices:

  def __post_init__(self):
  # sanity checks
- assert (self.padded_org_vocab_start_index <=
- self.padded_org_vocab_end_index)
- assert (self.padded_added_vocab_start_index <=
- self.padded_added_vocab_end_index)
+ assert self.padded_org_vocab_start_index <= self.padded_org_vocab_end_index
+ assert self.padded_added_vocab_start_index <= self.padded_added_vocab_end_index

  assert self.org_vocab_start_index <= self.org_vocab_end_index
  assert self.added_vocab_start_index <= self.added_vocab_end_index

  assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
- assert (self.added_vocab_start_index <=
- self.padded_added_vocab_start_index)
+ assert self.added_vocab_start_index <= self.padded_added_vocab_start_index
  assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
  assert self.added_vocab_end_index <= self.padded_added_vocab_end_index

@@ -142,20 +145,27 @@ class VocabParallelEmbeddingShardIndices:

  @torch.jit.script
  def get_masked_input_and_mask(
- input_: torch.Tensor, org_vocab_start_index: int,
- org_vocab_end_index: int, num_org_vocab_padding: int,
- added_vocab_start_index: int,
- added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ input_: torch.Tensor,
+ org_vocab_start_index: int,
+ org_vocab_end_index: int,
+ num_org_vocab_padding: int,
+ added_vocab_start_index: int,
+ added_vocab_end_index: int,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
  # torch.jit.script will fuse all of the pointwise ops below
  # into a single kernel, making it very fast
- org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
- org_vocab_end_index)
+ org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
  added_vocab_mask = (input_ >= added_vocab_start_index) & (
- input_ < added_vocab_end_index)
- added_offset = added_vocab_start_index - (
- org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
- valid_offset = (org_vocab_start_index *
- org_vocab_mask) + (added_offset * added_vocab_mask)
+ input_ < added_vocab_end_index
+ )
+ added_offset = (
+ added_vocab_start_index
+ - (org_vocab_end_index - org_vocab_start_index)
+ - num_org_vocab_padding
+ )
+ valid_offset = (org_vocab_start_index * org_vocab_mask) + (
+ added_offset * added_vocab_mask
+ )
  vocab_mask = org_vocab_mask | added_vocab_mask
  input_ = vocab_mask * (input_ - valid_offset)
  return input_, ~vocab_mask
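To see what the reformatted function above computes, here is a small worked example. Assume this rank owns original-vocab ids [4, 8) and added-vocab ids [16, 18) with no padding rows between them; all numbers are illustrative, and the masked_input_and_mask helper below simply restates the hunk's logic without @torch.jit.script:

    import torch

    def masked_input_and_mask(ids, org_start, org_end, num_org_padding,
                              added_start, added_end):
        org_mask = (ids >= org_start) & (ids < org_end)
        added_mask = (ids >= added_start) & (ids < added_end)
        added_offset = added_start - (org_end - org_start) - num_org_padding
        valid_offset = (org_start * org_mask) + (added_offset * added_mask)
        vocab_mask = org_mask | added_mask
        return vocab_mask * (ids - valid_offset), ~vocab_mask

    ids = torch.tensor([3, 5, 16, 9])
    local_rows, out_of_shard = masked_input_and_mask(ids, 4, 8, 0, 16, 18)
    print(local_rows)    # tensor([0, 1, 4, 0]) -> rows in this rank's weight shard
    print(out_of_shard)  # tensor([ True, False, False,  True]) -> zeroed after lookup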
@@ -200,15 +210,17 @@ class VocabParallelEmbedding(torch.nn.Module):
  prefix: full name of the layer in the state dict
  """ # noqa: E501

- def __init__(self,
- num_embeddings: int,
- embedding_dim: int,
- params_dtype: Optional[torch.dtype] = None,
- org_num_embeddings: Optional[int] = None,
- padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = "",
- enable_tp: bool = True):
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ params_dtype: Optional[torch.dtype] = None,
+ org_num_embeddings: Optional[int] = None,
+ padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ enable_tp: bool = True,
+ ):
  super().__init__()

  self.enable_tp = enable_tp
@@ -223,18 +235,22 @@ class VocabParallelEmbedding(torch.nn.Module):
  self.padding_size = padding_size
  self.org_vocab_size = org_num_embeddings or num_embeddings
  num_added_embeddings = num_embeddings - self.org_vocab_size
- self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
- self.padding_size)
+ self.org_vocab_size_padded = pad_vocab_size(
+ self.org_vocab_size, self.padding_size
+ )
  self.num_embeddings_padded = pad_vocab_size(
- self.org_vocab_size_padded + num_added_embeddings,
- self.padding_size)
+ self.org_vocab_size_padded + num_added_embeddings, self.padding_size
+ )
  assert self.org_vocab_size_padded <= self.num_embeddings_padded

- self.shard_indices = self._get_indices(self.num_embeddings_padded,
- self.org_vocab_size_padded,
- self.num_embeddings,
- self.org_vocab_size, tp_rank,
- self.tp_size)
+ self.shard_indices = self._get_indices(
+ self.num_embeddings_padded,
+ self.org_vocab_size_padded,
+ self.num_embeddings,
+ self.org_vocab_size,
+ tp_rank,
+ self.tp_size,
+ )
  self.embedding_dim = embedding_dim

  linear_method = None
@@ -248,11 +264,13 @@ class VocabParallelEmbedding(torch.nn.Module):
  # layer type like ParallelLMHead, this is not important.
  is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
  linear_method_implements_embedding = method_has_implemented_embedding(
- type(linear_method))
+ type(linear_method)
+ )
  if is_embedding_layer and not linear_method_implements_embedding:
  raise NotImplementedError(
  f"The class {type(linear_method).__name__} must implement "
- "the 'embedding' method, see UnquantizedEmbeddingMethod.")
+ "the 'embedding' method, see UnquantizedEmbeddingMethod."
+ )

  self.linear_method: QuantizeMethodBase = linear_method

@@ -260,58 +278,73 @@ class VocabParallelEmbedding(torch.nn.Module):
  params_dtype = torch.get_default_dtype()
  # Divide the weight matrix along the vocaburaly dimension.
  self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
- self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
- self.tp_size)
- assert (self.shard_indices.num_elements_padded ==
- self.num_embeddings_per_partition)
+ self.num_embeddings_per_partition = divide(
+ self.num_embeddings_padded, self.tp_size
+ )
+ assert (
+ self.shard_indices.num_elements_padded == self.num_embeddings_per_partition
+ )
  self.num_org_embeddings_per_partition = (
- self.shard_indices.org_vocab_end_index -
- self.shard_indices.org_vocab_start_index)
+ self.shard_indices.org_vocab_end_index
+ - self.shard_indices.org_vocab_start_index
+ )
  self.num_added_embeddings_per_partition = (
- self.shard_indices.added_vocab_end_index -
- self.shard_indices.added_vocab_start_index)
-
- self.linear_method.create_weights(self,
- self.embedding_dim,
- [self.num_embeddings_per_partition],
- self.embedding_dim,
- self.num_embeddings_padded,
- params_dtype=params_dtype,
- weight_loader=self.weight_loader)
+ self.shard_indices.added_vocab_end_index
+ - self.shard_indices.added_vocab_start_index
+ )
+
+ self.linear_method.create_weights(
+ self,
+ self.embedding_dim,
+ [self.num_embeddings_per_partition],
+ self.embedding_dim,
+ self.num_embeddings_padded,
+ params_dtype=params_dtype,
+ weight_loader=self.weight_loader,
+ )

  @classmethod
- def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
- vocab_size: int, org_vocab_size: int, tp_rank: int,
- tp_size: int) -> VocabParallelEmbeddingShardIndices:
+ def _get_indices(
+ cls,
+ vocab_size_padded: int,
+ org_vocab_size_padded: int,
+ vocab_size: int,
+ org_vocab_size: int,
+ tp_rank: int,
+ tp_size: int,
+ ) -> VocabParallelEmbeddingShardIndices:
  """Get start and end indices for vocab parallel embedding, following the
  layout outlined in the class docstring, based on the given tp_rank and
  tp_size."""
  num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
  padded_org_vocab_start_index, padded_org_vocab_end_index = (
- vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
- tp_size))
+ vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, tp_size)
+ )
  padded_added_vocab_start_index, padded_added_vocab_end_index = (
- vocab_range_from_global_vocab_size(num_added_embeddings_padded,
- tp_rank,
- tp_size,
- offset=org_vocab_size))
+ vocab_range_from_global_vocab_size(
+ num_added_embeddings_padded, tp_rank, tp_size, offset=org_vocab_size
+ )
+ )
  # remove padding
- org_vocab_start_index = min(padded_org_vocab_start_index,
- org_vocab_size)
+ org_vocab_start_index = min(padded_org_vocab_start_index, org_vocab_size)
  org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
- added_vocab_start_index = min(padded_added_vocab_start_index,
- vocab_size)
+ added_vocab_start_index = min(padded_added_vocab_start_index, vocab_size)
  added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
  return VocabParallelEmbeddingShardIndices(
- padded_org_vocab_start_index, padded_org_vocab_end_index,
- padded_added_vocab_start_index, padded_added_vocab_end_index,
- org_vocab_start_index, org_vocab_end_index,
- added_vocab_start_index, added_vocab_end_index)
+ padded_org_vocab_start_index,
+ padded_org_vocab_end_index,
+ padded_added_vocab_start_index,
+ padded_added_vocab_end_index,
+ org_vocab_start_index,
+ org_vocab_end_index,
+ added_vocab_start_index,
+ added_vocab_end_index,
+ )

  def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
  """Get a mapping that can be used to reindex the gathered
  logits for sampling.
-
+
  During sampling, we gather logits from all ranks. The relationship
  of index->token_id will follow the same format as outlined in the class
  docstring. However, after the gather, we want to reindex the final
@@ -326,32 +359,49 @@ class VocabParallelEmbedding(torch.nn.Module):
  added_embeddings: List[int] = []
  padding: List[int] = []
  for tp_rank in range(self.tp_size):
- shard_indices = self._get_indices(self.num_embeddings_padded,
- self.org_vocab_size_padded,
- self.num_embeddings,
- self.org_vocab_size, tp_rank,
- self.tp_size)
+ shard_indices = self._get_indices(
+ self.num_embeddings_padded,
+ self.org_vocab_size_padded,
+ self.num_embeddings,
+ self.org_vocab_size,
+ tp_rank,
+ self.tp_size,
+ )
  range_start = self.num_embeddings_per_partition * tp_rank
  range_end = self.num_embeddings_per_partition * (tp_rank + 1)
  base_embeddings.extend(
- range(range_start,
- range_start + shard_indices.num_org_elements))
+ range(range_start, range_start + shard_indices.num_org_elements)
+ )
  padding.extend(
- range(range_start + shard_indices.num_org_elements,
- range_start + shard_indices.num_org_elements_padded))
+ range(
+ range_start + shard_indices.num_org_elements,
+ range_start + shard_indices.num_org_elements_padded,
+ )
+ )
  added_embeddings.extend(
  range(
  range_start + shard_indices.num_org_elements_padded,
- range_start + shard_indices.num_org_elements_padded +
- shard_indices.num_added_elements))
+ range_start
+ + shard_indices.num_org_elements_padded
+ + shard_indices.num_added_elements,
+ )
+ )
  padding.extend(
  range(
- range_start + shard_indices.num_org_elements_padded +
- shard_indices.num_added_elements,
- range_start + shard_indices.num_org_elements_padded +
- shard_indices.num_added_elements_padded))
- assert (range_start + shard_indices.num_org_elements_padded +
- shard_indices.num_added_elements_padded == range_end)
+ range_start
+ + shard_indices.num_org_elements_padded
+ + shard_indices.num_added_elements,
+ range_start
+ + shard_indices.num_org_elements_padded
+ + shard_indices.num_added_elements_padded,
+ )
+ )
+ assert (
+ range_start
+ + shard_indices.num_org_elements_padded
+ + shard_indices.num_added_elements_padded
+ == range_end
+ )
  ret = base_embeddings + added_embeddings + padding
  assert len(ret) == self.num_embeddings_padded
  return ret
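The reordering that get_sharded_to_full_mapping builds can be seen with a tiny illustrative configuration (tp_size=2, four original tokens, two added tokens, padding_size=2; these numbers are not from sglang). Each rank then holds its original-vocab shard followed by its added-vocab shard, so the gathered logits arrive in the order tok0, tok1, tok4, tok2, tok3, tok5, and the mapping restores plain token order:

    # The mapping below is what the loop above would produce for this toy setup.
    gathered = ["tok0", "tok1", "tok4", "tok2", "tok3", "tok5"]
    mapping = [0, 1, 3, 4, 2, 5]  # base rows of both ranks first, then added rows
    assert [gathered[i] for i in mapping] == [f"tok{i}" for i in range(6)]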
@@ -385,10 +435,14 @@ class VocabParallelEmbedding(torch.nn.Module):
  # If param packed on the same dim we are sharding on, then
  # need to adjust offsets of loaded weight by pack_factor.
  if packed_dim is not None and packed_dim == output_dim:
- packed_factor = param.packed_factor if isinstance(
- param, BasevLLMParameter) else param.pack_factor
- assert loaded_weight.shape[output_dim] == (self.org_vocab_size //
- param.packed_factor)
+ packed_factor = (
+ param.packed_factor
+ if isinstance(param, BasevLLMParameter)
+ else param.pack_factor
+ )
+ assert loaded_weight.shape[output_dim] == (
+ self.org_vocab_size // param.packed_factor
+ )
  start_idx = start_idx // packed_factor
  shard_size = shard_size // packed_factor
  else:
@@ -396,23 +450,24 @@ class VocabParallelEmbedding(torch.nn.Module):

  # Copy the data.
  loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
- param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
- param[loaded_weight.shape[0]:].data.fill_(0)
+ param[: loaded_weight.shape[0]].data.copy_(loaded_weight)
+ param[loaded_weight.shape[0] :].data.fill_(0)

  def forward(self, input_):
  if self.tp_size > 1:
  # Build the mask.
  masked_input, input_mask = get_masked_input_and_mask(
- input_, self.shard_indices.org_vocab_start_index,
+ input_,
+ self.shard_indices.org_vocab_start_index,
  self.shard_indices.org_vocab_end_index,
  self.shard_indices.num_org_vocab_padding,
  self.shard_indices.added_vocab_start_index,
- self.shard_indices.added_vocab_end_index)
+ self.shard_indices.added_vocab_end_index,
+ )
  else:
  masked_input = input_
  # Get the embeddings.
- output_parallel = self.linear_method.embedding(self,
- masked_input.long())
+ output_parallel = self.linear_method.embedding(self, masked_input.long())
  # Mask the output embedding.
  if self.tp_size > 1:
  output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
@@ -426,9 +481,9 @@ class VocabParallelEmbedding(torch.nn.Module):
  s = f"num_embeddings={self.num_embeddings_per_partition}"
  s += f", embedding_dim={self.embedding_dim}"
  s += f", org_vocab_size={self.org_vocab_size}"
- s += f', num_embeddings_padded={self.num_embeddings_padded}'
+ s += f", num_embeddings_padded={self.num_embeddings_padded}"
  if self.enable_tp:
- s += f', tp_size={self.tp_size}'
+ s += f", tp_size={self.tp_size}"
  return s


@@ -448,27 +503,38 @@ class ParallelLMHead(VocabParallelEmbedding):
  padding_size: padding size for the vocabulary.
  """

- def __init__(self,
- num_embeddings: int,
- embedding_dim: int,
- bias: bool = False,
- params_dtype: Optional[torch.dtype] = None,
- org_num_embeddings: Optional[int] = None,
- padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
- quant_config: Optional[QuantizationConfig] = None,
- prefix: str = ""):
- super().__init__(num_embeddings, embedding_dim, params_dtype,
- org_num_embeddings, padding_size, quant_config,
- prefix)
+ def __init__(
+ self,
+ num_embeddings: int,
+ embedding_dim: int,
+ bias: bool = False,
+ params_dtype: Optional[torch.dtype] = None,
+ org_num_embeddings: Optional[int] = None,
+ padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
+ super().__init__(
+ num_embeddings,
+ embedding_dim,
+ params_dtype,
+ org_num_embeddings,
+ padding_size,
+ quant_config,
+ prefix,
+ )
  self.quant_config = quant_config
  if bias:
  self.bias = Parameter(
- torch.empty(self.num_embeddings_per_partition,
- dtype=params_dtype))
- set_weight_attrs(self.bias, {
- "output_dim": 0,
- "weight_loader": self.weight_loader,
- })
+ torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)
+ )
+ set_weight_attrs(
+ self.bias,
+ {
+ "output_dim": 0,
+ "weight_loader": self.weight_loader,
+ },
+ )
  else:
  self.register_parameter("bias", None)

@@ -483,4 +549,4 @@ class ParallelLMHead(VocabParallelEmbedding):

  def forward(self, input_):
  del input_
- raise RuntimeError("LMHead's weights should be used in the sampler.")
+ raise RuntimeError("LMHead's weights should be used in the sampler.")
@@ -100,20 +100,6 @@ class DetokenizerManager:

  if isinstance(recv_obj, BatchEmbeddingOut):
  # If it is embedding model, no detokenization is needed.
- self.send_to_tokenizer.send_pyobj(
- BatchEmbeddingOut(
- rids=recv_obj.rids,
- embeddings=recv_obj.embeddings,
- meta_info=recv_obj.meta_info,
- finished_reason=recv_obj.finished_reason,
- )
- )
- continue
- elif isinstance(recv_obj, UpdateWeightReqOutput):
- # If it is a weight update request, no detokenization is needed.
- self.send_to_tokenizer.send_pyobj(recv_obj)
- continue
- elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
  self.send_to_tokenizer.send_pyobj(recv_obj)
  continue
  else:
@@ -86,8 +86,10 @@ class GenerateReqInput:
  self.parallel_sample_num = self.sampling_params.get("n", 1)
  else: # isinstance(self.sampling_params, list):
  self.parallel_sample_num = self.sampling_params[0].get("n", 1)
- assert all(self.parallel_sample_num == sampling_params.get("n", 1) for sampling_params in self.sampling_params), (
- "The parallel_sample_num should be the same for all samples in sample params.")
+ assert all(
+ self.parallel_sample_num == sampling_params.get("n", 1)
+ for sampling_params in self.sampling_params
+ ), "The parallel_sample_num should be the same for all samples in sample params."

  if self.parallel_sample_num > 1 and self.is_single:
  self.is_single = False
@@ -182,7 +184,7 @@ class TokenizedGenerateReqInput:
  input_text: str
  # The input token ids
  input_ids: List[int]
- # The image input
+ # The image inputs
  image_inputs: dict
  # The sampling parameters
  sampling_params: SamplingParams