sglang 0.3.4.post1__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +92 -0
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +76 -15
  11. sglang/srt/constrained/__init__.py +18 -0
  12. sglang/srt/constrained/bnf_cache.py +61 -0
  13. sglang/srt/constrained/fsm_cache.py +10 -3
  14. sglang/srt/constrained/grammar.py +190 -0
  15. sglang/srt/hf_transformers_utils.py +20 -5
  16. sglang/srt/layers/attention/flashinfer_backend.py +5 -5
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
  18. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  19. sglang/srt/layers/fused_moe/fused_moe.py +4 -3
  20. sglang/srt/layers/fused_moe/layer.py +28 -0
  21. sglang/srt/layers/logits_processor.py +5 -5
  22. sglang/srt/layers/quantization/base_config.py +16 -1
  23. sglang/srt/layers/rotary_embedding.py +15 -48
  24. sglang/srt/layers/sampler.py +51 -39
  25. sglang/srt/layers/vocab_parallel_embedding.py +486 -0
  26. sglang/srt/managers/data_parallel_controller.py +8 -7
  27. sglang/srt/managers/detokenizer_manager.py +11 -9
  28. sglang/srt/managers/image_processor.py +4 -3
  29. sglang/srt/managers/io_struct.py +80 -78
  30. sglang/srt/managers/schedule_batch.py +46 -52
  31. sglang/srt/managers/schedule_policy.py +24 -13
  32. sglang/srt/managers/scheduler.py +145 -82
  33. sglang/srt/managers/tokenizer_manager.py +236 -334
  34. sglang/srt/managers/tp_worker.py +5 -5
  35. sglang/srt/managers/tp_worker_overlap_thread.py +58 -21
  36. sglang/srt/mem_cache/flush_cache.py +1 -1
  37. sglang/srt/mem_cache/memory_pool.py +10 -3
  38. sglang/srt/model_executor/cuda_graph_runner.py +34 -23
  39. sglang/srt/model_executor/forward_batch_info.py +6 -9
  40. sglang/srt/model_executor/model_runner.py +10 -19
  41. sglang/srt/models/baichuan.py +4 -4
  42. sglang/srt/models/chatglm.py +4 -4
  43. sglang/srt/models/commandr.py +1 -1
  44. sglang/srt/models/dbrx.py +5 -5
  45. sglang/srt/models/deepseek.py +4 -4
  46. sglang/srt/models/deepseek_v2.py +4 -4
  47. sglang/srt/models/exaone.py +4 -4
  48. sglang/srt/models/gemma.py +1 -1
  49. sglang/srt/models/gemma2.py +1 -1
  50. sglang/srt/models/gpt2.py +287 -0
  51. sglang/srt/models/gpt_bigcode.py +1 -1
  52. sglang/srt/models/grok.py +4 -4
  53. sglang/srt/models/internlm2.py +4 -4
  54. sglang/srt/models/llama.py +15 -7
  55. sglang/srt/models/llama_embedding.py +2 -10
  56. sglang/srt/models/llama_reward.py +5 -0
  57. sglang/srt/models/minicpm.py +4 -4
  58. sglang/srt/models/minicpm3.py +4 -4
  59. sglang/srt/models/mixtral.py +7 -5
  60. sglang/srt/models/mixtral_quant.py +4 -4
  61. sglang/srt/models/mllama.py +5 -5
  62. sglang/srt/models/olmo.py +4 -4
  63. sglang/srt/models/olmoe.py +4 -4
  64. sglang/srt/models/qwen.py +4 -4
  65. sglang/srt/models/qwen2.py +4 -4
  66. sglang/srt/models/qwen2_moe.py +4 -4
  67. sglang/srt/models/qwen2_vl.py +4 -8
  68. sglang/srt/models/stablelm.py +4 -4
  69. sglang/srt/models/torch_native_llama.py +4 -4
  70. sglang/srt/models/xverse.py +4 -4
  71. sglang/srt/models/xverse_moe.py +4 -4
  72. sglang/srt/openai_api/adapter.py +52 -66
  73. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  74. sglang/srt/sampling/sampling_batch_info.py +7 -13
  75. sglang/srt/sampling/sampling_params.py +5 -7
  76. sglang/srt/server.py +41 -33
  77. sglang/srt/server_args.py +34 -5
  78. sglang/srt/utils.py +40 -56
  79. sglang/test/run_eval.py +2 -0
  80. sglang/test/runners.py +2 -1
  81. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  82. sglang/test/test_utils.py +151 -6
  83. sglang/utils.py +62 -1
  84. sglang/version.py +1 -1
  85. sglang-0.3.5.dist-info/METADATA +344 -0
  86. sglang-0.3.5.dist-info/RECORD +152 -0
  87. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
  88. sglang-0.3.4.post1.dist-info/METADATA +0 -900
  89. sglang-0.3.4.post1.dist-info/RECORD +0 -148
  90. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.4.post1.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py CHANGED
@@ -15,7 +15,6 @@ limitations under the License.
 
 """A tensor parallel worker."""
 
-import json
 import logging
 from typing import Optional
 
@@ -26,7 +25,7 @@ from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_a
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import broadcast_pyobj, is_multimodal_model, set_random_seed
+from sglang.srt.utils import broadcast_pyobj, set_random_seed
 
 logger = logging.getLogger(__name__)
 
@@ -48,9 +47,10 @@ class TpModelWorker:
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
-            server_args.trust_remote_code,
+            trust_remote_code=server_args.trust_remote_code,
             context_length=server_args.context_length,
-            model_override_args=json.loads(server_args.json_model_override_args),
+            model_override_args=server_args.json_model_override_args,
+            is_embedding=server_args.is_embedding,
         )
         self.model_runner = ModelRunner(
             model_config=self.model_config,
@@ -64,7 +64,7 @@ class TpModelWorker:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.model_config.hf_config.architectures):
+            if self.model_config.is_multimodal:
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
sglang/srt/managers/tp_worker_overlap_thread.py CHANGED
@@ -32,6 +32,15 @@ from sglang.srt.server_args import ServerArgs
 logger = logging.getLogger(__name__)
 
 
+@torch.compile(dynamic=True)
+def resolve_future_token_ids(input_ids, future_token_ids_map):
+    input_ids[:] = torch.where(
+        input_ids < 0,
+        future_token_ids_map[torch.clamp(-input_ids, min=0)],
+        input_ids,
+    )
+
+
 class TpModelWorkerClient:
     """A tensor parallel model worker."""
 
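The new resolve_future_token_ids helper is the heart of the overlap scheduler's placeholder scheme: token ids that have not been computed yet are stored as negative indices into a preallocated map and rewritten in place once the forward pass that produces them has run. A minimal sketch of the resolution step (tensor sizes and values here are illustrative, not taken from sglang):

    import torch

    # Slots 1..3 of the map are filled by an earlier batch's results.
    future_token_ids_map = torch.zeros(16, dtype=torch.int32)
    future_token_ids_map[1:4] = torch.tensor([101, 102, 103], dtype=torch.int32)

    # Negative ids -1, -2, -3 reference map slots 1..3; positive ids are real.
    input_ids = torch.tensor([5, -1, -2, 7, -3], dtype=torch.int32)
    input_ids[:] = torch.where(
        input_ids < 0,
        future_token_ids_map[torch.clamp(-input_ids, min=0)],
        input_ids,
    )
    print(input_ids.tolist())  # [5, 101, 102, 7, 103]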
@@ -94,46 +103,69 @@ class TpModelWorkerClient:
         while True:
             self.has_inflight_batch = False
             model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            if not model_worker_batch:
+                break
             self.has_inflight_batch = True
             self.launch_event = threading.Event()
 
             # Resolve future tokens in the input
             input_ids = model_worker_batch.input_ids
-            input_ids[:] = torch.where(
-                input_ids < 0,
-                self.future_token_ids_map[torch.clamp(-input_ids, min=0)],
-                input_ids,
-            )
+            resolve_future_token_ids(input_ids, self.future_token_ids_map)
 
             # Run forward
             logits_output, next_token_ids = self.worker.forward_batch_generation(
                 model_worker_batch
             )
-            self.launch_event.set()
 
             # Update the future token ids map
             bs = len(model_worker_batch.seq_lens)
-            future_next_token_ids = torch.arange(
-                -(future_token_ids_ct + bs),
-                -(future_token_ids_ct),
-                dtype=torch.int32,
-                device=self.device,
-            )
-            self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
-                torch.int32
-            )
-
+            self.future_token_ids_map[
+                future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
+            ] = next_token_ids
+
+            # Copy results to the CPU
+            if model_worker_batch.return_logprob:
+                logits_output.next_token_logprobs = logits_output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=self.device),
+                    next_token_ids,
+                ].to("cpu", non_blocking=True)
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.to(
+                            "cpu", non_blocking=True
+                        )
+                    )
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_event = torch.cuda.Event(blocking=True)
             copy_event.record()
-            self.copy_queue.put((copy_event, next_token_ids))
+
+            self.launch_event.set()
+            self.copy_queue.put((copy_event, logits_output, next_token_ids))
 
     def copy_thread_func(self):
         while True:
-            copy_event, next_token_ids = self.copy_queue.get()
+            copy_event, logits_output, next_token_ids = self.copy_queue.get()
+            if not copy_event:
+                break
             while not copy_event.query():
                 time.sleep(1e-5)
-            self.output_queue.put((None, next_token_ids.tolist()))
+
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
+                    )
+
+            self.output_queue.put((logits_output, next_token_ids.tolist()))
 
     def resulve_batch_result(self, bid: int):
         logits_output, next_token_ids = self.output_queue.get()
@@ -149,8 +181,9 @@ class TpModelWorkerClient:
         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
         future_next_token_ids = torch.arange(
-            -(self.future_token_ids_ct + bs),
-            -(self.future_token_ids_ct),
+            -(self.future_token_ids_ct + 1),
+            -(self.future_token_ids_ct + 1 + bs),
+            -1,
             dtype=torch.int32,
             device=self.device,
         )
@@ -170,3 +203,7 @@ class TpModelWorkerClient:
             recv_req.model_path, recv_req.load_format
         )
         return success, message
+
+    def __delete__(self):
+        self.input_queue.put((None, None))
+        self.copy_queue.put((None, None, None))
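The switch from negative-arange bookkeeping to 1-based slots keeps the map's writer and reader in agreement: the scheduler hands out placeholders -(ct+1) .. -(ct+bs), and the forward thread later writes the real ids into slots ct+1 .. ct+bs, exactly the slice in the hunk above. A small check of that correspondence (ct and bs values are arbitrary):

    import torch

    ct, bs = 4, 3  # arbitrary counter and batch size for illustration
    placeholders = torch.arange(-(ct + 1), -(ct + 1 + bs), -1, dtype=torch.int32)
    print(placeholders.tolist())  # [-5, -6, -7]
    slots = -placeholders         # [5, 6, 7]
    # ...which is the slice the forward thread writes:
    # future_token_ids_map[ct + 1 : ct + bs + 1] = next_token_ids
    assert slots.tolist() == list(range(ct + 1, ct + bs + 1))

One caveat on the shutdown hunk: __delete__ as written is Python's descriptor-protocol hook rather than the __del__ finalizer, so these (None, ...) sentinels are only delivered if the method is called explicitly.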
sglang/srt/mem_cache/flush_cache.py CHANGED
@@ -29,5 +29,5 @@ if __name__ == "__main__":
     parser.add_argument("--url", type=str, default="http://localhost:30000")
     args = parser.parse_args()
 
-    response = requests.get(args.url + "/flush_cache")
+    response = requests.post(args.url + "/flush_cache")
     assert response.status_code == 200
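The flush script now issues a POST, which suggests the server route changed verbs in this release; callers hitting the endpoint directly against a 0.3.5 server should mirror that:

    import requests

    # /flush_cache is called with POST as of this version.
    response = requests.post("http://localhost:30000/flush_cache")
    assert response.status_code == 200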
sglang/srt/mem_cache/memory_pool.py CHANGED
@@ -38,7 +38,7 @@ class ReqToTokenPool:
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.empty(
+        self.req_to_token = torch.zeros(
             (size, max_context_len), dtype=torch.int32, device=device
         )
         self.free_slots = list(range(size))
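The torch.empty-to-torch.zeros change trades a small initialization cost for deterministic contents: empty returns whatever bytes the allocator hands back, so any slot read before being written holds garbage. A quick illustration:

    import torch

    # torch.empty leaves the buffer uninitialized (arbitrary values);
    # torch.zeros guarantees unwritten req-to-token entries read as 0.
    uninitialized = torch.empty((2, 4), dtype=torch.int32)  # contents undefined
    initialized = torch.zeros((2, 4), dtype=torch.int32)
    assert (initialized == 0).all()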
@@ -51,7 +51,7 @@ class ReqToTokenPool:
             self.write = self.write_without_records
 
     def write(self, indices, values):
-        # Keep the signature for type checking, will be initialized during runtime
+        # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
 
     def available_size(self):
@@ -223,7 +223,6 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
-        if cache_v.dtype != self.dtype:
             cache_v = cache_v.to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
@@ -233,6 +232,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
             self.v_buffer[layer_id][loc] = cache_v
 
 
+# This compiled version is slower in the unit test
+# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
+@torch.compile(dynamic=True)
+def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
+    dst_1[loc] = src_1.to(dtype).view(store_dtype)
+    dst_2[loc] = src_2.to(dtype).view(store_dtype)
+
+
 class MLATokenToKVPool(BaseTokenToKVPool):
 
     def __init__(
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -92,6 +92,11 @@ def set_torch_compile_config():
     torch._dynamo.config.accumulated_cache_size_limit = 1024
 
 
+@torch.compile(dynamic=True)
+def clamp_position(seq_lens):
+    return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
+
+
 class CudaGraphRunner:
     """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
 
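Hoisting the position computation into a module-level clamp_position function lets torch.compile cache one compiled artifact instead of re-tracing the expression inside the capture closure. The function itself is tiny; eagerly it computes each sequence's decode position (the last token, with zero-length padded rows clamped to 0):

    import torch

    # Eager equivalent of clamp_position: position = seq_len - 1,
    # clamped so padded rows (seq_len == 0) map to position 0.
    seq_lens = torch.tensor([5, 1, 0], dtype=torch.int32)
    positions = torch.clamp(seq_lens - 1, min=0).to(torch.int64)
    print(positions.tolist())  # [4, 0, 0]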
@@ -108,19 +113,21 @@ class CudaGraphRunner:
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
 
         # Batch sizes to capture
-        if self.model_runner.server_args.disable_cuda_graph_padding:
+        if model_runner.server_args.disable_cuda_graph_padding:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
         self.capture_bs = [
-            bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
+            bs
+            for bs in self.capture_bs
+            if bs <= model_runner.req_to_token_pool.size
+            and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
         self.compile_bs = (
             [
                 bs
                 for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.max_torch_compile_bs
+                if bs <= self.model_runner.server_args.torch_compile_max_bs
             ]
             if self.use_torch_compile
             else []
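Capture sizes are now bounded by the new cuda_graph_max_bs server argument in addition to the request pool size, and the torch.compile bound was renamed from max_torch_compile_bs to torch_compile_max_bs. A sketch of the resulting filter (the pool size and cap below are stand-ins, not sglang defaults):

    # Default capture candidates: 1, 2, 4, then multiples of 8 up to 160.
    capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]

    req_pool_size = 64       # stand-in for model_runner.req_to_token_pool.size
    cuda_graph_max_bs = 32   # stand-in for server_args.cuda_graph_max_bs
    capture_bs = [
        bs for bs in capture_bs
        if bs <= req_pool_size and bs <= cuda_graph_max_bs
    ]
    print(capture_bs)  # [1, 2, 4, 8, 16, 24, 32]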
@@ -129,6 +136,7 @@ class CudaGraphRunner:
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
@@ -147,6 +155,7 @@ class CudaGraphRunner:
             (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
         )
         self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+        self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
 
         if self.is_encoder_decoder:
             # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
@@ -228,6 +237,7 @@ class CudaGraphRunner:
             encoder_lens = None
 
         seq_lens_sum = seq_lens.sum().item()
+        mrope_positions = self.mrope_positions[:, :bs]
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
  self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
@@ -253,9 +263,11 @@ class CudaGraphRunner:
253
263
  encoder_lens=encoder_lens,
254
264
  return_logprob=False,
255
265
  top_logprobs_nums=[0] * bs,
256
- positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
266
+ positions=clamp_position(seq_lens),
267
+ mrope_positions=mrope_positions,
257
268
  )
258
- return forward(input_ids, forward_batch.positions, forward_batch)
269
+ logits_output = forward(input_ids, forward_batch.positions, forward_batch)
270
+ return logits_output.next_token_logits
259
271
 
260
272
  for _ in range(2):
261
273
  torch.cuda.synchronize()
@@ -286,7 +298,7 @@ class CudaGraphRunner:
         index = bisect.bisect_left(self.capture_bs, raw_bs)
         bs = self.capture_bs[index]
         if bs != raw_bs:
-            self.seq_lens.fill_(self.seq_len_fill_value)
+            self.seq_lens.fill_(1)
             self.out_cache_loc.zero_()
 
         # Common inputs
@@ -296,35 +308,30 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
+        if forward_batch.mrope_positions is not None:
+            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
             self.req_pool_indices,
             self.seq_lens,
-            forward_batch.seq_lens_sum,
+            forward_batch.seq_lens_sum + (bs - raw_bs),
             self.encoder_lens,
         )
 
         # Replay
         self.graphs[bs].replay()
-        logits_output = self.output_buffers[bs]
-
-        # Unpad
-        if bs != raw_bs:
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=logits_output.next_token_logits[:raw_bs],
-                next_token_logprobs=None,
-                normalized_prompt_logprobs=None,
-                input_token_logprobs=None,
-                input_top_logprobs=None,
-                output_top_logprobs=None,
-            )
+        next_token_logits = self.output_buffers[bs][:raw_bs]
 
         # Extract logprobs
         if forward_batch.return_logprob:
-            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
-                logits_output.next_token_logits, dim=-1
+            next_token_logprobs = torch.nn.functional.log_softmax(
+                next_token_logits, dim=-1
+            )
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=next_token_logits,
+                next_token_logprobs=next_token_logprobs,
             )
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
@@ -333,7 +340,11 @@ class CudaGraphRunner:
                     top_logprobs_nums=forward_batch.top_logprobs_nums,
                 )
                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    logits_output.next_token_logprobs, logits_metadata
+                    next_token_logprobs, logits_metadata
                 )[1]
+        else:
+            logits_output = LogitsProcessorOutput(
+                next_token_logits=next_token_logits,
+            )
 
         return logits_output
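Replay still pads a raw batch up to the nearest captured size with bisect, but the padded rows now carry seq_len 1 (hence fill_(1) above and the seq_lens_sum + (bs - raw_bs) correction passed to the attention backend), and unpadding reduces to slicing the output buffer. The size lookup itself:

    import bisect

    # Find the smallest captured batch size that fits the raw batch.
    capture_bs = [1, 2, 4, 8, 16, 24, 32]
    raw_bs = 5
    bs = capture_bs[bisect.bisect_left(capture_bs, raw_bs)]
    print(bs)  # 8 -> three dummy rows, each contributing seq_len 1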
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -142,11 +142,12 @@ class ForwardBatch:
                 int(self.seq_lens[i]),
             )
         elif self.forward_mode.is_extend():
+            extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
             for i, image_inputs in enumerate(batch.image_inputs):
                 extend_start_loc, extend_seq_len, extend_prefix_len = (
-                    self.extend_start_loc[i],
-                    self.extend_seq_lens[i],
-                    self.extend_prefix_lens[i],
+                    extend_start_loc_cpu[i],
+                    batch.extend_seq_lens[i],
+                    batch.extend_prefix_lens[i],
                 )
                 if image_inputs is None:
                     # text only
@@ -160,20 +161,16 @@ class ForwardBatch:
                     ] * 3
                     mrope_position_delta = 0
                 else:
+                    # TODO: current qwen2-vl do not support radix cache since mrope position calculation
                     mrope_positions, mrope_position_delta = (
                         MRotaryEmbedding.get_input_positions(
                             input_tokens=self.input_ids[
                                 extend_start_loc : extend_start_loc + extend_seq_len
-                            ].tolist(),
+                            ],
                             image_grid_thw=image_inputs.image_grid_thws,
-                            video_grid_thw=None,
-                            image_token_id=hf_config.image_token_id,
-                            video_token_id=hf_config.video_token_id,
                             vision_start_token_id=hf_config.vision_start_token_id,
-                            vision_end_token_id=hf_config.vision_end_token_id,
                             spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                             context_len=0,
-                            extend_prefix_len=extend_prefix_len.item(),
                         )
                     )
                     mrope_positions_list[i] = mrope_positions
sglang/srt/model_executor/model_runner.py CHANGED
@@ -59,11 +59,6 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
-    is_attention_free_model,
-    is_embedding_model,
-    is_generation_model,
-    is_multimodal_model,
-    model_has_inner_state,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
 )
@@ -93,9 +88,8 @@ class ModelRunner:
         self.tp_size = tp_size
         self.dist_port = nccl_port
         self.server_args = server_args
-        self.is_multimodal_model = is_multimodal_model(
-            self.model_config.hf_config.architectures
-        )
+        self.is_generation = model_config.is_generation
+        self.is_multimodal = model_config.is_multimodal
 
         # Model-specific adjustment
         if (
@@ -119,17 +113,17 @@ class ModelRunner:
                 self.server_args.ds_heavy_channel_type
             )
 
-        if self.is_multimodal_model:
+        if self.is_multimodal:
             logger.warning(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
            )
             server_args.chunked_prefill_size = None
-            server_args.mem_fraction_static *= 0.95
-            # TODO: qwen2-vl does not support cuda graph now, set disable-graph=True automatically
+            self.mem_fraction_static *= 0.95
+            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
-                server_args.disable_cuda_graph = True
+                server_args.disable_radix_cache = True
 
         # Global vars
         if server_args.show_time_cost:
@@ -270,9 +264,6 @@ class ModelRunner:
             if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
-        self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures, self.server_args.is_embedding
-        )
 
         logger.info(
             f"Load weight end. "
@@ -679,7 +670,7 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
 
 # Monkey patch model loader
 setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
-setattr(ModelRegistry, "is_multimodal_model", is_multimodal_model)
-setattr(ModelRegistry, "is_attention_free_model", is_attention_free_model)
-setattr(ModelRegistry, "model_has_inner_state", model_has_inner_state)
-setattr(ModelRegistry, "is_embedding_model", is_embedding_model)
+setattr(ModelRegistry, "is_multimodal_model", lambda model_architectures: False)
+setattr(ModelRegistry, "is_attention_free_model", lambda model_architectures: False)
+setattr(ModelRegistry, "model_has_inner_state", lambda model_architectures: False)
+setattr(ModelRegistry, "is_embedding_model", lambda model_architectures: False)
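The helper predicates removed from sglang.srt.utils are not re-implemented; sglang now answers these questions from its own ModelConfig (see the is_generation / is_multimodal fields used earlier in this file), so vLLM's registry hooks are stubbed out with constant lambdas. The pattern, sketched with a stand-in class:

    class ModelRegistry:  # stand-in for vLLM's ModelRegistry
        pass

    for name in (
        "is_multimodal_model",
        "is_attention_free_model",
        "model_has_inner_state",
        "is_embedding_model",
    ):
        # Every hook answers False; sglang's ModelConfig decides instead.
        setattr(ModelRegistry, name, lambda model_architectures: False)

    print(ModelRegistry.is_multimodal_model(["LlamaForCausalLM"]))  # False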
sglang/srt/models/baichuan.py CHANGED
@@ -34,10 +34,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -45,6 +41,10 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
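The same mechanical swap repeats across essentially every model file in this release (entries 41-71 in the list above): the vocab-parallel embedding layers are now imported from sglang's own copy, added in this version as sglang/srt/layers/vocab_parallel_embedding.py (the new 486-line file), instead of from vLLM. After the change, a model file pulls the layers like this:

    # Before (0.3.4.post1):
    # from vllm.model_executor.layers.vocab_parallel_embedding import (
    #     ParallelLMHead,
    #     VocabParallelEmbedding,
    # )
    # After (0.3.5): the same class names, vendored into sglang.
    from sglang.srt.layers.vocab_parallel_embedding import (
        ParallelLMHead,
        VocabParallelEmbedding,
    )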
sglang/srt/models/chatglm.py CHANGED
@@ -24,10 +24,6 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs import ChatGLMConfig
 
@@ -41,6 +37,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 LoraConfig = None
sglang/srt/models/commandr.py CHANGED
@@ -50,7 +50,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -62,6 +61,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs
 
sglang/srt/models/dbrx.py CHANGED
@@ -27,11 +27,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    DEFAULT_VOCAB_PADDING_SIZE,
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
@@ -43,6 +38,11 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs
 
sglang/srt/models/deepseek.py CHANGED
@@ -28,10 +28,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -45,6 +41,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/deepseek_v2.py CHANGED
@@ -27,10 +27,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -44,6 +40,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available
sglang/srt/models/exaone.py CHANGED
@@ -23,10 +23,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
@@ -39,6 +35,10 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/gemma.py CHANGED
@@ -24,7 +24,6 @@ from transformers import PretrainedConfig
 from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import GeluAndMul
@@ -37,6 +36,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
sglang/srt/models/gemma2.py CHANGED
@@ -24,7 +24,6 @@ from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import GeluAndMul
@@ -37,6 +36,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 