sglang 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sglang/bench_latency.py +6 -4
  2. sglang/bench_serving.py +46 -22
  3. sglang/lang/compiler.py +2 -2
  4. sglang/lang/ir.py +3 -3
  5. sglang/srt/constrained/base_tool_cache.py +1 -1
  6. sglang/srt/constrained/fsm_cache.py +12 -2
  7. sglang/srt/layers/activation.py +33 -0
  8. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  9. sglang/srt/layers/extend_attention.py +6 -1
  10. sglang/srt/layers/layernorm.py +65 -0
  11. sglang/srt/layers/logits_processor.py +5 -0
  12. sglang/srt/layers/pooler.py +50 -0
  13. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  14. sglang/srt/layers/radix_attention.py +2 -2
  15. sglang/srt/managers/detokenizer_manager.py +31 -9
  16. sglang/srt/managers/io_struct.py +63 -0
  17. sglang/srt/managers/policy_scheduler.py +173 -25
  18. sglang/srt/managers/schedule_batch.py +110 -87
  19. sglang/srt/managers/tokenizer_manager.py +193 -111
  20. sglang/srt/managers/tp_worker.py +289 -352
  21. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  22. sglang/srt/mem_cache/chunk_cache.py +43 -20
  23. sglang/srt/mem_cache/memory_pool.py +2 -2
  24. sglang/srt/mem_cache/radix_cache.py +74 -40
  25. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  26. sglang/srt/model_executor/forward_batch_info.py +168 -105
  27. sglang/srt/model_executor/model_runner.py +24 -37
  28. sglang/srt/models/gemma2.py +0 -1
  29. sglang/srt/models/internlm2.py +2 -7
  30. sglang/srt/models/llama2.py +4 -4
  31. sglang/srt/models/llama_embedding.py +88 -0
  32. sglang/srt/models/qwen2_moe.py +0 -11
  33. sglang/srt/openai_api/adapter.py +155 -27
  34. sglang/srt/openai_api/protocol.py +37 -1
  35. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  36. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  37. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  39. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  40. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  41. sglang/srt/sampling_params.py +31 -4
  42. sglang/srt/server.py +69 -15
  43. sglang/srt/server_args.py +26 -19
  44. sglang/srt/utils.py +31 -13
  45. sglang/test/run_eval.py +10 -1
  46. sglang/test/runners.py +63 -63
  47. sglang/test/simple_eval_humaneval.py +2 -8
  48. sglang/test/simple_eval_mgsm.py +203 -0
  49. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  50. sglang/test/test_layernorm.py +60 -0
  51. sglang/test/test_programs.py +4 -2
  52. sglang/test/test_utils.py +20 -2
  53. sglang/utils.py +0 -1
  54. sglang/version.py +1 -1
  55. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
  56. sglang-0.2.12.dist-info/RECORD +112 -0
  57. sglang/srt/layers/linear.py +0 -884
  58. sglang/srt/layers/quantization/__init__.py +0 -64
  59. sglang/srt/layers/quantization/fp8.py +0 -677
  60. sglang-0.2.11.dist-info/RECORD +0 -102
  61. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  62. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  63. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py

@@ -16,13 +16,17 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import List
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
 
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
 
 class ForwardMode(IntEnum):
     # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
@@ -39,25 +43,33 @@ class InputMetadata:
 
     forward_mode: ForwardMode
     batch_size: int
-    total_num_tokens: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
-    positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
     token_to_kv_pool: BaseTokenToKVPool
 
-    # For extend
-    extend_seq_lens: torch.Tensor
-    extend_start_loc: torch.Tensor
-    extend_no_prefix: bool
-
     # Output location of the KV cache
-    out_cache_loc: torch.Tensor = None
+    out_cache_loc: torch.Tensor
+
+    total_num_tokens: int = None
+
+    # Position information
+    positions: torch.Tensor = None
+
+    # For extend
+    extend_seq_lens: torch.Tensor = None
+    extend_start_loc: torch.Tensor = None
+    extend_no_prefix: bool = None
 
     # Output options
     return_logprob: bool = False
     top_logprobs_nums: List[int] = None
 
+    # For multimodal
+    pixel_values: List[torch.Tensor] = None
+    image_sizes: List[List[int]] = None
+    image_offsets: List[int] = None
+
 
     # Trition attention backend
    triton_max_seq_len: int = 0
     triton_max_extend_len: int = 0
@@ -70,107 +82,171 @@ class InputMetadata:
     flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
     flashinfer_use_ragged: bool = False
 
-    @classmethod
-    def create(
-        cls,
-        model_runner,
-        forward_mode,
-        req_pool_indices,
-        seq_lens,
-        prefix_lens,
-        position_ids_offsets,
-        out_cache_loc,
-        top_logprobs_nums=None,
-        return_logprob=False,
-        skip_flashinfer_init=False,
-    ):
-        flashinfer_use_ragged = False
-        if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
-            if forward_mode != ForwardMode.DECODE and int(torch.sum(seq_lens)) > 4096:
-                flashinfer_use_ragged = True
-            init_flashinfer_args(
-                forward_mode,
-                model_runner,
-                req_pool_indices,
-                seq_lens,
-                prefix_lens,
-                model_runner.flashinfer_decode_wrapper,
-                flashinfer_use_ragged,
+    def init_multimuldal_info(self, batch: ScheduleBatch):
+        reqs = batch.reqs
+        self.pixel_values = [r.pixel_values for r in reqs]
+        self.image_sizes = [r.image_size for r in reqs]
+        self.image_offsets = [
+            (
+                (r.image_offset - len(r.prefix_indices))
+                if r.image_offset is not None
+                else 0
             )
+            for r in reqs
+        ]
 
-        batch_size = len(req_pool_indices)
+    def compute_positions(self, batch: ScheduleBatch):
+        position_ids_offsets = batch.position_ids_offsets
 
-        if forward_mode == ForwardMode.DECODE:
-            positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-            extend_seq_lens = extend_start_loc = extend_no_prefix = None
-            if not model_runner.server_args.disable_flashinfer:
-                # This variable is not needed in this case,
-                # we do not compute it to make it compatbile with cuda graph.
-                total_num_tokens = None
+        if self.forward_mode == ForwardMode.DECODE:
+            if True:
+                self.positions = self.seq_lens - 1
             else:
-                total_num_tokens = int(torch.sum(seq_lens))
+                # Deprecated
+                self.positions = (self.seq_lens - 1) + position_ids_offsets
         else:
-            seq_lens_cpu = seq_lens.cpu().numpy()
-            prefix_lens_cpu = prefix_lens.cpu().numpy()
-            position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
-            positions = torch.tensor(
-                np.concatenate(
-                    [
-                        np.arange(
-                            prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
-                            seq_lens_cpu[i] + position_ids_offsets_cpu[i],
-                        )
-                        for i in range(batch_size)
-                    ],
-                    axis=0,
-                ),
-                device="cuda",
-            )
-            extend_seq_lens = seq_lens - prefix_lens
-            extend_start_loc = torch.zeros_like(seq_lens)
-            extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0)
-            extend_no_prefix = torch.all(prefix_lens == 0)
-            total_num_tokens = int(torch.sum(seq_lens))
+            if True:
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(len(req.prefix_indices), len(req.fill_ids))
+                            for req in batch.reqs
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+            else:
+                # Deprecated
+                position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+                self.positions = torch.tensor(
+                    np.concatenate(
+                        [
+                            np.arange(
+                                len(req.prefix_indices) + position_ids_offsets_cpu[i],
+                                len(req.fill_ids) + position_ids_offsets_cpu[i],
+                            )
+                            for i, req in enumerate(batch.reqs)
+                        ],
+                        axis=0,
+                    ),
+                    device="cuda",
+                )
+
+        # Positions should be in long type
+        self.positions = self.positions.to(torch.int64)
+
+    def compute_extend_infos(self, batch: ScheduleBatch):
+        if self.forward_mode == ForwardMode.DECODE:
+            self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None
+        else:
+            extend_lens_cpu = [
+                len(r.fill_ids) - len(r.prefix_indices) for r in batch.reqs
+            ]
+            self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+            self.extend_start_loc = torch.zeros_like(self.seq_lens)
+            self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+            self.extend_no_prefix = all(len(r.prefix_indices) == 0 for r in batch.reqs)
 
+    @classmethod
+    def from_schedule_batch(
+        cls,
+        model_runner: "ModelRunner",
+        batch: ScheduleBatch,
+        forward_mode: ForwardMode,
+    ):
         ret = cls(
             forward_mode=forward_mode,
-            batch_size=batch_size,
-            total_num_tokens=total_num_tokens,
-            req_pool_indices=req_pool_indices,
-            seq_lens=seq_lens,
-            positions=positions,
+            batch_size=batch.batch_size(),
+            req_pool_indices=batch.req_pool_indices,
+            seq_lens=batch.seq_lens,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
-            out_cache_loc=out_cache_loc,
-            extend_seq_lens=extend_seq_lens,
-            extend_start_loc=extend_start_loc,
-            extend_no_prefix=extend_no_prefix,
-            return_logprob=return_logprob,
-            top_logprobs_nums=top_logprobs_nums,
-            flashinfer_prefill_wrapper_ragged=model_runner.flashinfer_prefill_wrapper_ragged,
-            flashinfer_prefill_wrapper_paged=model_runner.flashinfer_prefill_wrapper_paged,
-            flashinfer_decode_wrapper=model_runner.flashinfer_decode_wrapper,
-            flashinfer_use_ragged=flashinfer_use_ragged,
+            out_cache_loc=batch.out_cache_loc,
+            return_logprob=batch.return_logprob,
+            top_logprobs_nums=batch.top_logprobs_nums,
         )
 
+        ret.compute_positions(batch)
+
+        ret.compute_extend_infos(batch)
+
+        if (
+            forward_mode != ForwardMode.DECODE
+            or model_runner.server_args.disable_flashinfer
+        ):
+            ret.total_num_tokens = int(torch.sum(ret.seq_lens))
+
+        if forward_mode != ForwardMode.DECODE:
+            ret.init_multimuldal_info(batch)
+
+        prefix_lens = None
+        if forward_mode != ForwardMode.DECODE:
+            prefix_lens = torch.tensor(
+                [len(r.prefix_indices) for r in batch.reqs], device="cuda"
+            )
+
         if model_runner.server_args.disable_flashinfer:
-            (
-                ret.triton_max_seq_len,
-                ret.triton_max_extend_len,
-                ret.triton_start_loc,
-                ret.triton_prefix_lens,
-            ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+            ret.init_triton_args(batch, prefix_lens)
+
+        flashinfer_use_ragged = False
+        if not model_runner.server_args.disable_flashinfer:
+            if (
+                forward_mode != ForwardMode.DECODE
+                and int(torch.sum(ret.seq_lens)) > 4096
+            ):
+                flashinfer_use_ragged = True
+            ret.init_flashinfer_handlers(
+                model_runner, prefix_lens, flashinfer_use_ragged
+            )
 
         return ret
 
+    def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
+        """Init auxiliary variables for triton attention backend."""
+        self.triton_max_seq_len = int(torch.max(self.seq_lens))
+        self.triton_prefix_lens = prefix_lens
+        self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
+        self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
+
+        if self.forward_mode == ForwardMode.DECODE:
+            self.triton_max_extend_len = None
+        else:
+            extend_seq_lens = self.seq_lens - prefix_lens
+            self.triton_max_extend_len = int(torch.max(extend_seq_lens))
+
+    def init_flashinfer_handlers(
+        self, model_runner, prefix_lens, flashinfer_use_ragged
+    ):
+        update_flashinfer_indices(
+            self.forward_mode,
+            model_runner,
+            self.req_pool_indices,
+            self.seq_lens,
+            prefix_lens,
+            flashinfer_use_ragged=flashinfer_use_ragged,
+        )
+
+        (
+            self.flashinfer_prefill_wrapper_ragged,
+            self.flashinfer_prefill_wrapper_paged,
+            self.flashinfer_decode_wrapper,
+            self.flashinfer_use_ragged,
+        ) = (
+            model_runner.flashinfer_prefill_wrapper_ragged,
+            model_runner.flashinfer_prefill_wrapper_paged,
+            model_runner.flashinfer_decode_wrapper,
+            flashinfer_use_ragged,
+        )
+
 
-def init_flashinfer_args(
+def update_flashinfer_indices(
     forward_mode,
     model_runner,
     req_pool_indices,
     seq_lens,
     prefix_lens,
-    flashinfer_decode_wrapper,
+    flashinfer_decode_wrapper=None,
     flashinfer_use_ragged=False,
 ):
     """Init auxiliary variables for FlashInfer attention backend."""
@@ -178,7 +254,6 @@ def init_flashinfer_args(
     num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
     head_dim = model_runner.model_config.head_dim
     batch_size = len(req_pool_indices)
-    total_num_tokens = int(torch.sum(seq_lens))
 
     if flashinfer_use_ragged:
         paged_kernel_lens = prefix_lens
@@ -201,6 +276,10 @@
     kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
     if forward_mode == ForwardMode.DECODE:
+        # CUDA graph uses different flashinfer_decode_wrapper
+        if flashinfer_decode_wrapper is None:
+            flashinfer_decode_wrapper = model_runner.flashinfer_decode_wrapper
+
         flashinfer_decode_wrapper.end_forward()
         flashinfer_decode_wrapper.begin_forward(
             kv_indptr,
@@ -238,19 +317,3 @@ def init_flashinfer_args(
             head_dim,
             1,
         )
-
-
-def init_triton_args(forward_mode, seq_lens, prefix_lens):
-    """Init auxiliary variables for triton attention backend."""
-    batch_size = len(seq_lens)
-    max_seq_len = int(torch.max(seq_lens))
-    start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-    start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-
-    if forward_mode == ForwardMode.DECODE:
-        max_extend_len = None
-    else:
-        extend_seq_lens = seq_lens - prefix_lens
-        max_extend_len = int(torch.max(extend_seq_lens))
-
-    return max_seq_len, max_extend_len, start_loc, prefix_lens
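The net effect of the forward_batch_info.py changes is that InputMetadata is now built from a ScheduleBatch instead of a long keyword argument list. A minimal sketch of the new call path, assuming model_runner and batch already exist as they do at the real call sites in model_runner.py below:

    # Sketch only: mirrors the call sites shown in the model_runner.py hunks below.
    from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata

    input_metadata = InputMetadata.from_schedule_batch(
        model_runner,        # supplies req_to_token_pool, token_to_kv_pool, flashinfer wrappers
        batch,               # a ScheduleBatch; positions and extend infos are derived from batch.reqs
        ForwardMode.DECODE,  # EXTEND additionally fills multimodal info and prefix lengths
    )
    output = model_runner.model.forward(
        batch.input_ids, input_metadata.positions, input_metadata
    )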
sglang/srt/model_executor/model_runner.py

@@ -52,6 +52,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetad
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
+    is_generation_model,
     is_llama3_405b_fp8,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
@@ -130,10 +131,12 @@ class ModelRunner:
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.init_flash_infer()
+        self.init_flashinfer()
 
-        # Capture cuda graphs
-        self.init_cuda_graphs()
+        if self.is_generation:
+            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
+            # Capture cuda graphs
+            self.init_cuda_graphs()
 
     def load_model(self):
         logger.info(
@@ -184,6 +187,10 @@ class ModelRunner:
             scheduler_config=None,
             cache_config=None,
         )
+        self.is_generation = is_generation_model(
+            self.model_config.hf_config.architectures
+        )
+
         logger.info(
             f"[gpu={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "
@@ -287,7 +294,7 @@ class ModelRunner:
         c = a @ b
         return c
 
-    def init_flash_infer(self):
+    def init_flashinfer(self):
         if self.server_args.disable_flashinfer:
             self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = None
@@ -350,33 +357,18 @@ class ModelRunner:
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.DECODE,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, ForwardMode.DECODE
         )
+
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
         )
 
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            top_logprobs_nums=batch.top_logprobs_nums,
-            return_logprob=batch.return_logprob,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, forward_mode=ForwardMode.EXTEND
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -384,24 +376,16 @@
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.create(
-            self,
-            forward_mode=ForwardMode.EXTEND,
-            req_pool_indices=batch.req_pool_indices,
-            seq_lens=batch.seq_lens,
-            prefix_lens=batch.prefix_lens,
-            position_ids_offsets=batch.position_ids_offsets,
-            out_cache_loc=batch.out_cache_loc,
-            return_logprob=batch.return_logprob,
-            top_logprobs_nums=batch.top_logprobs_nums,
+        input_metadata = InputMetadata.from_schedule_batch(
+            self, batch, forward_mode=ForwardMode.EXTEND
         )
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
             input_metadata,
-            batch.pixel_values,
-            batch.image_sizes,
-            batch.image_offsets,
+            input_metadata.pixel_values,
+            input_metadata.image_sizes,
+            input_metadata.image_offsets,
         )
 
     def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
@@ -429,8 +413,10 @@ def import_model_classes():
                     entry, list
                 ):  # To support multiple model classes in one module
                     for tmp in entry:
+                        assert tmp.__name__ not in model_arch_name_to_cls
                        model_arch_name_to_cls[tmp.__name__] = tmp
                 else:
+                    assert entry.__name__ not in model_arch_name_to_cls
                     model_arch_name_to_cls[entry.__name__] = entry
 
             # compat: some models such as chatglm has incorrect class set in config.json
@@ -440,6 +426,7 @@ def import_model_classes():
             ):
                 for remap in module.EntryClassRemapping:
                     if isinstance(remap, tuple) and len(remap) == 2:
+                        assert remap[0] not in model_arch_name_to_cls
                         model_arch_name_to_cls[remap[0]] = remap[1]
 
     return model_arch_name_to_cls
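Alongside the InputMetadata rework, import_model_classes() now asserts that no two modules register the same architecture name. A condensed sketch of that registry convention (register / register_remapping are illustrative names, not real helpers in the package):

    # Illustrative condensation of the registry logic in the hunks above.
    model_arch_name_to_cls = {}

    def register(entry):
        # A module's EntryClass may be a single class or a list of classes.
        for cls in entry if isinstance(entry, list) else [entry]:
            assert cls.__name__ not in model_arch_name_to_cls  # new in 0.2.12: fail loudly on duplicates
            model_arch_name_to_cls[cls.__name__] = cls

    def register_remapping(remapping):
        # EntryClassRemapping, e.g. [("MistralModel", LlamaEmbeddingModel)] in llama_embedding.py below.
        for arch_name, cls in remapping:
            assert arch_name not in model_arch_name_to_cls
            model_arch_name_to_cls[arch_name] = cls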
sglang/srt/models/gemma2.py

@@ -38,7 +38,6 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.sampling_metadata import SamplingMetadata
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
sglang/srt/models/internlm2.py

@@ -23,8 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -38,13 +36,14 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
 class InternLM2MLP(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -74,7 +73,6 @@ class InternLM2MLP(nn.Module):
 
 
 class InternLM2Attention(nn.Module):
-
     def __init__(
         self,
         hidden_size: int,
@@ -150,7 +148,6 @@ class InternLM2Attention(nn.Module):
 
 
 class InternLMDecoderLayer(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,
@@ -207,7 +204,6 @@
 
 
 class InternLM2Model(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,
@@ -254,7 +250,6 @@
 
 
 class InternLM2ForCausalLM(nn.Module):
-
     def __init__(
         self,
         config: PretrainedConfig,
sglang/srt/models/llama2.py

@@ -24,8 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -39,7 +37,9 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
@@ -310,7 +310,7 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> LogitProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
sglang/srt/models/llama_embedding.py (new file)

@@ -0,0 +1,88 @@
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel
+
+
+class LlamaEmbeddingModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config=None,
+        cache_config=None,
+        efficient_weight_load=False,
+    ) -> None:
+        super().__init__()
+        self.model = LlamaModel(config, quant_config=quant_config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> EmbeddingPoolerOutput:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.pooler(hidden_states, input_metadata)
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
+    ):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.model.named_parameters())
+
+        def load_weights_per_param(name, loaded_weight):
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                return
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                return
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    return
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    return
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+        if name is None or loaded_weight is None:
+            for name, loaded_weight in weights:
+                load_weights_per_param(name, loaded_weight)
+        else:
+            load_weights_per_param(name, loaded_weight)
+
+
+EntryClass = LlamaEmbeddingModel
+# compat: e5-mistral model.config class == MistralModel
+EntryClassRemapping = [("MistralModel", LlamaEmbeddingModel)]
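Rough usage sketch for the new embedding entry point: the model reuses LlamaModel's forward pass but returns pooled, normalized last-token embeddings instead of logits. The `embeddings` field name on EmbeddingPoolerOutput is an assumption here, since the new pooler.py is not shown in this excerpt:

    # Hypothetical usage; input_ids/positions/input_metadata come from the normal extend path.
    model = LlamaEmbeddingModel(config)                               # config: a transformers LlamaConfig
    pooler_out = model.forward(input_ids, positions, input_metadata)  # EmbeddingPoolerOutput
    vectors = pooler_out.embeddings                                   # assumed field: L2-normalized last-token states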