sglang 0.2.13__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (61)
  1. sglang/api.py +6 -0
  2. sglang/bench_latency.py +7 -3
  3. sglang/bench_serving.py +50 -26
  4. sglang/check_env.py +15 -0
  5. sglang/lang/chat_template.py +10 -5
  6. sglang/lang/compiler.py +4 -0
  7. sglang/lang/interpreter.py +1 -0
  8. sglang/lang/ir.py +9 -0
  9. sglang/launch_server.py +8 -1
  10. sglang/srt/constrained/fsm_cache.py +11 -2
  11. sglang/srt/constrained/jump_forward.py +1 -0
  12. sglang/srt/conversation.py +50 -1
  13. sglang/srt/hf_transformers_utils.py +22 -23
  14. sglang/srt/layers/activation.py +100 -1
  15. sglang/srt/layers/decode_attention.py +338 -50
  16. sglang/srt/layers/fused_moe/layer.py +2 -2
  17. sglang/srt/layers/logits_processor.py +56 -19
  18. sglang/srt/layers/radix_attention.py +3 -4
  19. sglang/srt/layers/sampler.py +101 -0
  20. sglang/srt/managers/controller_multi.py +2 -8
  21. sglang/srt/managers/controller_single.py +7 -10
  22. sglang/srt/managers/detokenizer_manager.py +20 -9
  23. sglang/srt/managers/io_struct.py +44 -11
  24. sglang/srt/managers/policy_scheduler.py +5 -2
  25. sglang/srt/managers/schedule_batch.py +46 -166
  26. sglang/srt/managers/tokenizer_manager.py +192 -83
  27. sglang/srt/managers/tp_worker.py +118 -24
  28. sglang/srt/mem_cache/memory_pool.py +82 -8
  29. sglang/srt/mm_utils.py +79 -7
  30. sglang/srt/model_executor/cuda_graph_runner.py +32 -8
  31. sglang/srt/model_executor/forward_batch_info.py +51 -26
  32. sglang/srt/model_executor/model_runner.py +201 -58
  33. sglang/srt/models/gemma2.py +10 -6
  34. sglang/srt/models/gpt_bigcode.py +1 -1
  35. sglang/srt/models/grok.py +11 -1
  36. sglang/srt/models/llama_embedding.py +4 -0
  37. sglang/srt/models/llava.py +176 -59
  38. sglang/srt/models/qwen2.py +9 -3
  39. sglang/srt/openai_api/adapter.py +200 -39
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_batch_info.py +136 -0
  42. sglang/srt/{sampling_params.py → sampling/sampling_params.py} +22 -0
  43. sglang/srt/server.py +92 -57
  44. sglang/srt/server_args.py +43 -15
  45. sglang/srt/utils.py +26 -16
  46. sglang/test/runners.py +22 -30
  47. sglang/test/simple_eval_common.py +9 -10
  48. sglang/test/simple_eval_gpqa.py +2 -1
  49. sglang/test/simple_eval_humaneval.py +2 -2
  50. sglang/test/simple_eval_math.py +2 -1
  51. sglang/test/simple_eval_mmlu.py +2 -1
  52. sglang/test/test_activation.py +55 -0
  53. sglang/test/test_utils.py +36 -53
  54. sglang/version.py +1 -1
  55. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +100 -27
  56. sglang-0.2.14.post1.dist-info/RECORD +114 -0
  57. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
  58. sglang/launch_server_llavavid.py +0 -29
  59. sglang-0.2.13.dist-info/RECORD +0 -112
  60. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
  61. {sglang-0.2.13.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -15,11 +15,11 @@ limitations under the License.
 
 """ModelRunner runs the forward passes of the models."""
 
+import gc
 import importlib
 import importlib.resources
 import logging
 import pkgutil
-import warnings
 from functools import lru_cache
 from typing import Optional, Type
 
@@ -37,7 +37,9 @@ from vllm.distributed import (
     get_tp_group,
     init_distributed_environment,
     initialize_model_parallel,
+    set_custom_all_reduce,
 )
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
@@ -88,22 +90,35 @@ class ModelRunner:
             {
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
-                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
             }
         )
 
+        min_per_gpu_memory = self.init_torch_distributed()
+        self.load_model()
+        self.init_memory_pool(
+            min_per_gpu_memory,
+            server_args.max_num_reqs,
+            server_args.max_total_tokens,
+        )
+        self.init_cublas()
+        self.init_flashinfer()
+        self.init_cuda_graphs()
+
+    def init_torch_distributed(self):
         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
-        logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")
+        logger.info("Init nccl begin.")
 
-        if not server_args.enable_p2p_check:
+        if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
 
-        if server_args.nccl_init_addr:
-            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        if self.server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
         else:
             nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
@@ -112,43 +127,45 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-        self.tp_group = get_tp_group()
-        total_gpu_memory = get_available_gpu_memory(
+        min_per_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
+        self.tp_group = get_tp_group()
 
+        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
+        # so we disable padding in cuda graph.
+        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
+            self.server_args.disable_cuda_graph_padding = True
+            logger.info(
+                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
+            )
+
+        # Check memory for tensor parallelism
         if self.tp_size > 1:
-            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
-            if total_local_gpu_memory < total_gpu_memory * 0.9:
+            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )
 
-        # Load the model and create memory pool
-        self.load_model()
-        self.init_memory_pool(
-            total_gpu_memory,
-            server_args.max_num_reqs,
-            server_args.max_total_tokens,
-        )
-        self.init_cublas()
-        self.init_flashinfer()
-
-        if self.is_generation:
-            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
-            # Capture cuda graphs
-            self.init_cuda_graphs()
+        return min_per_gpu_memory
 
     def load_model(self):
         logger.info(
-            f"[gpu={self.gpu_id}] Load weight begin. "
-            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+            f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+        if torch.cuda.get_device_capability()[0] < 8:
+            logger.info(
+                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+            )
+            self.server_args.dtype = "float16"
+            if torch.cuda.get_device_capability()[1] < 5:
+                raise RuntimeError("SGLang only supports sm75 and above.")
 
         monkey_patch_vllm_dummy_weight_loader()
-        device_config = DeviceConfig()
-        load_config = LoadConfig(load_format=self.server_args.load_format)
-        vllm_model_config = VllmModelConfig(
+        self.device_config = DeviceConfig()
+        self.load_config = LoadConfig(load_format=self.server_args.load_format)
+        self.vllm_model_config = VllmModelConfig(
            model=self.server_args.model_path,
            quantization=self.server_args.quantization,
            tokenizer=None,
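The new load_model path also gates on GPU compute capability: below sm80 it falls back to float16 (no bfloat16 support), and below sm75 it refuses to start. A minimal sketch of that gate, assuming a CUDA build of PyTorch; the helper name resolve_dtype is illustrative, not part of sglang:

import torch

def resolve_dtype(requested: str) -> str:
    # Sketch of the sm80/sm75 gate added in load_model().
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) < (7, 5):
        raise RuntimeError("SGLang only supports sm75 and above.")
    if major < 8:
        # bfloat16 needs sm80+ (Ampere); older cards fall back to float16.
        return "float16"
    return requested

# Example (requires a CUDA device): resolve_dtype("bfloat16")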
@@ -159,43 +176,132 @@ class ModelRunner:
             skip_tokenizer_init=True,
         )
 
+        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
+        # Drop this after Sept, 2024.
         if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
-            vllm_model_config.hf_config.num_key_value_heads = 8
+            self.vllm_model_config.hf_config.num_key_value_heads = 8
             monkey_patch_vllm_qvk_linear_loader()
 
-        self.dtype = vllm_model_config.dtype
+        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
-            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+            self.vllm_model_config.hf_config.update(
+                self.model_config.model_overide_args
+            )
 
         self.model = get_model(
-            model_config=vllm_model_config,
-            device_config=device_config,
-            load_config=load_config,
-            lora_config=None,
-            multimodal_config=None,
+            model_config=self.vllm_model_config,
+            load_config=self.load_config,
+            device_config=self.device_config,
             parallel_config=None,
             scheduler_config=None,
+            lora_config=None,
             cache_config=None,
         )
         self.sliding_window_size = (
-            self.model.get_window_size()
-            if hasattr(self.model, "get_window_size")
+            self.model.get_attention_sliding_window_size()
+            if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
         self.is_generation = is_generation_model(
-            self.model_config.hf_config.architectures
+            self.model_config.hf_config.architectures, self.server_args.is_embedding
         )
 
         logger.info(
-            f"[gpu={self.gpu_id}] Load weight end. "
+            f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
-    def profile_max_num_token(self, total_gpu_memory):
+    def update_weights(self, model_path: str, load_format: str):
+        """Update weights in-place."""
+        from vllm.model_executor.model_loader.loader import (
+            DefaultModelLoader,
+            device_loading_context,
+            get_model_loader,
+        )
+        from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+
+        logger.info(
+            f"Update weights begin. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+        target_device = torch.device(self.device_config.device)
+
+        try:
+            # TODO: Use a better method to check this
+            vllm_model_config = VllmModelConfig(
+                model=model_path,
+                quantization=self.server_args.quantization,
+                tokenizer=None,
+                tokenizer_mode=None,
+                trust_remote_code=self.server_args.trust_remote_code,
+                dtype=self.server_args.dtype,
+                seed=42,
+                skip_tokenizer_init=True,
+            )
+        except Exception as e:
+            logger.error(f"Failed to load model config: {e}")
+            return False, "Failed to update model weights"
+
+        load_config = LoadConfig(load_format=load_format)
+
+        # Only support vllm DefaultModelLoader for now
+        loader = get_model_loader(load_config)
+        if not isinstance(loader, DefaultModelLoader):
+            logger.error("Failed to get weights iterator: Unsupported loader")
+            return False, "Failed to update model weights"
+
+        def get_weight_iter(config):
+            iter = loader._get_weights_iterator(
+                config.model,
+                config.revision,
+                fall_back_to_pt=getattr(
+                    self.model, "fall_back_to_pt_during_load", True
+                ),
+            )
+            return iter
+
+        def model_load_weights(model, iter):
+            model.load_weights(iter)
+            for _, module in self.model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+            return model
+
+        with set_default_torch_dtype(vllm_model_config.dtype):
+            try:
+                iter = get_weight_iter(vllm_model_config)
+            except Exception as e:
+                message = f"Failed to get weights iterator: {e}"
+                logger.error(message)
+                return False, message
+            try:
+                model = model_load_weights(self.model, iter)
+            except Exception as e:
+                message = f"Failed to update weights: {e}. \n Rolling back to original weights"
+                logger.error(message)
+                del iter
+                gc.collect()
+                iter = get_weight_iter(self.vllm_model_config)
+                self.model = model_load_weights(self.model, iter)
+                return False, message
+
+        self.model = model
+        self.server_args.model_path = model_path
+        self.server_args.load_format = load_format
+        self.vllm_model_config = vllm_model_config
+        self.load_config = load_config
+        self.model_config.path = model_path
+
+        logger.info("Update weights end.")
+        return True, "Succeeded to update model weights"
+
+    def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
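The new ModelRunner.update_weights swaps the checkpoint in place and returns a (success, message) pair, rolling back to the previously loaded weights if anything fails. A hedged sketch of how a caller might consume that contract; the runner object and helper name here are assumptions, not from the diff:

def try_hot_swap(runner, model_path: str, load_format: str = "auto") -> bool:
    # update_weights returns (success: bool, message: str); on failure the old
    # weights have already been restored, so the runner stays usable either way.
    success, message = runner.update_weights(model_path, load_format)
    if success:
        print(f"Now serving weights from {model_path}: {message}")
    else:
        print(f"Hot swap rejected, previous weights still active: {message}")
    return success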
@@ -206,7 +312,7 @@ class ModelRunner:
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
                 * self.model_config.num_hidden_layers
-                * torch._utils._element_size(self.dtype)
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:
             cell_size = (
@@ -214,7 +320,7 @@ class ModelRunner:
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
-                * torch._utils._element_size(self.dtype)
+                * torch._utils._element_size(self.kv_cache_dtype)
             )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
@@ -223,12 +329,30 @@ class ModelRunner:
         return max_num_token
 
     def init_memory_pool(
-        self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
+        self,
+        total_gpu_memory: int,
+        max_num_reqs: int = None,
+        max_total_tokens: int = None,
     ):
+        if self.server_args.kv_cache_dtype == "auto":
+            self.kv_cache_dtype = self.dtype
+        elif self.server_args.kv_cache_dtype == "fp8_e5m2":
+            if self.server_args.disable_flashinfer or self.server_args.enable_mla:
+                logger.warning(
+                    "FP8 KV cache is not supported for Triton kernel now, using auto kv cache dtype"
+                )
+                self.kv_cache_dtype = self.dtype
+            else:
+                self.kv_cache_dtype = torch.float8_e5m2
+        else:
+            raise ValueError(
+                f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
+            )
+
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
-                warnings.warn(
+                logging.warning(
                     f"max_total_tokens={max_total_tokens} is larger than the profiled value "
                     f"{self.max_total_num_tokens}. "
                     f"Use the profiled value instead."
@@ -261,7 +385,7 @@ class ModelRunner:
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.dtype,
+                dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
@@ -272,13 +396,13 @@ class ModelRunner:
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
-                dtype=self.dtype,
+                dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
             )
         logger.info(
-            f"[gpu={self.gpu_id}] Memory pool end. "
+            f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
@@ -292,6 +416,7 @@ class ModelRunner:
         return c
 
     def init_flashinfer(self):
+        """Init flashinfer attention kernel wrappers."""
        if self.server_args.disable_flashinfer:
            assert (
                self.sliding_window_size is None
@@ -352,20 +477,29 @@ class ModelRunner:
         )
 
     def init_cuda_graphs(self):
+        """Capture cuda graphs."""
+        if not self.is_generation:
+            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
+            return
+
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 
         if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
             self.cuda_graph_runner = None
             return
 
-        logger.info(
-            f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
-        )
-        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
+        logger.info("Capture cuda graph begin. This can take up to several minutes.")
+
+        if self.server_args.disable_cuda_graph_padding:
+            batch_size_list = list(range(1, 32)) + [64, 128]
+        else:
+            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
         self.cuda_graph_runner = CudaGraphRunner(
             self,
             max_batch_size_to_capture=max(batch_size_list),
             use_torch_compile=self.server_args.enable_torch_compile,
+            disable_padding=self.server_args.disable_cuda_graph_padding,
         )
         try:
             self.cuda_graph_runner.capture(batch_size_list)
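The captured batch sizes now depend on whether cuda-graph padding is disabled (which the multi-node tensor-parallel check earlier in this file can force). A small sketch of the two grids used in this release; the helper name is illustrative:

def cuda_graph_batch_sizes(disable_padding: bool) -> list:
    if disable_padding:
        # No padding: capture every small batch size plus a couple of large ones.
        return list(range(1, 32)) + [64, 128]
    # With padding: a coarser grid is enough, now extended beyond the old cap of 128.
    return [1, 2, 4] + [i * 8 for i in range(1, 21)]

# max(cuda_graph_batch_sizes(False)) == 160; the 0.2.13 grid stopped at 128.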
@@ -401,9 +535,18 @@ class ModelRunner:
             batch,
             forward_mode=ForwardMode.EXTEND,
         )
-        return self.model.forward(
-            batch.input_ids, input_metadata.positions, input_metadata
-        )
+        if self.is_generation:
+            return self.model.forward(
+                batch.input_ids, input_metadata.positions, input_metadata
+            )
+        else:
+            # Only embedding models have get_embedding parameter
+            return self.model.forward(
+                batch.input_ids,
+                input_metadata.positions,
+                input_metadata,
+                get_embedding=True,
+            )
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
@@ -477,4 +620,4 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
 
 
 # Monkey patch model loader
-setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
+setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
sglang/srt/models/gemma2.py CHANGED
@@ -25,7 +25,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 
 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.custom_op import CustomOp
-from vllm.model_executor.layers.activation import GeluAndMul
 
 # from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (
@@ -39,6 +38,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -46,7 +46,7 @@ from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 # Aligned with HF's implementation, using sliding window inclusive with the last token
 # SGLang assumes exclusive
-def get_window_size(config):
+def get_attention_sliding_window_size(config):
     return config.sliding_window - 1
 
 
@@ -135,7 +135,7 @@ class Gemma2MLP(nn.Module):
                 "function. Please set `hidden_act` and `hidden_activation` to "
                 "`gelu_pytorch_tanh`."
             )
-        self.act_fn = GeluAndMul(approximate="tanh")
+        self.act_fn = GeluAndMul()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         gate_up, _ = self.gate_up_proj(x)
@@ -213,7 +213,11 @@ class Gemma2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
-            sliding_window_size=get_window_size(config) if use_sliding_window else None,
+            sliding_window_size=(
+                get_attention_sliding_window_size(config)
+                if use_sliding_window
+                else None
+            ),
             logit_cap=self.config.attn_logit_softcapping,
         )
 
@@ -406,8 +410,8 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
 
-    def get_window_size(self):
-        return get_window_size(self.config)
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
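The rename to get_attention_sliding_window_size keeps the same off-by-one conversion: HF counts the sliding window inclusive of the last token, SGLang exclusive. A one-line sketch, with the config value given only as an example:

def get_attention_sliding_window_size(sliding_window: int) -> int:
    # HF's window includes the current token; SGLang's does not, hence the -1.
    return sliding_window - 1

# e.g. a config with sliding_window=4096 gives SGLang's RadixAttention a window of 4095.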
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import GPTBigCodeConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -33,6 +32,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
sglang/srt/models/grok.py CHANGED
@@ -300,6 +300,9 @@ class Grok1ModelForCausalLM(nn.Module):
 
         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+
+        self.use_presharded_weights = True
+
         warnings.filterwarnings("ignore", category=FutureWarning)
 
     def forward(
@@ -355,6 +358,13 @@ class Grok1ModelForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
 
+                if self.use_presharded_weights:
+                    extra_kwargs = {
+                        "use_presharded_weights": self.use_presharded_weights
+                    }
+                else:
+                    extra_kwargs = {}
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(
@@ -363,7 +373,7 @@
                     weight_name,
                     shard_id=shard_id,
                     expert_id=expert_id,
-                    pre_sharded=get_tensor_model_parallel_world_size() > 1,
+                    **extra_kwargs,
                 )
                 break
             else:
sglang/srt/models/llama_embedding.py CHANGED
@@ -29,7 +29,11 @@ class LlamaEmbeddingModel(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
+        get_embedding: bool = True,
     ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaEmbeddingModel / MistralModel is only used for embedding"
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.pooler(hidden_states, input_metadata)
 
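This assert pairs with the forward_extend change in model_runner.py above: generation models keep the old three-argument call, while embedding models are invoked with get_embedding=True. A minimal sketch of that dispatch, with the model, flag, and metadata supplied by the caller:

def run_extend(model, is_generation: bool, input_ids, positions, input_metadata):
    if is_generation:
        # Causal LMs return logits for the next-token distribution.
        return model.forward(input_ids, positions, input_metadata)
    # Embedding models (e.g. LlamaEmbeddingModel) return pooled embeddings.
    return model.forward(input_ids, positions, input_metadata, get_embedding=True)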