sglang 0.2.14.post1__py3-none-any.whl → 0.2.14.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -23,6 +23,7 @@ import multiprocessing as mp
 import os
 from typing import Dict, List, Optional, Tuple, Union
 
+import fastapi
 import numpy as np
 import transformers
 import uvloop
@@ -96,21 +97,18 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
-
        self.is_generation = is_generation_model(
            self.hf_config.architectures, self.server_args.is_embedding
        )
-
-        if server_args.context_length is not None:
-            self.context_len = server_args.context_length
-        else:
-            self.context_len = get_context_length(self.hf_config)
+        self.context_len = server_args.context_length or get_context_length(
+            self.hf_config
+        )
 
         # Create tokenizer
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.model_path):
+            if is_multimodal_model(self.hf_config.architectures):
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
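
The context-length setup in the hunk above is collapsed into Python's `or` fallback. A minimal sketch of the equivalence (values are hypothetical); the two forms differ only for falsy non-None values such as 0, which is not a meaningful context length anyway:

```python
# Illustrative only: how the `or` fallback behaves (values are made up).
def resolve_context_len(cli_value, hf_default):
    # Old form: fall back only when the CLI value is exactly None.
    if cli_value is not None:
        explicit = cli_value
    else:
        explicit = hf_default
    # New form: fall back whenever cli_value is falsy (None, 0, ...).
    compact = cli_value or hf_default
    return explicit, compact

print(resolve_context_len(None, 4096))  # (4096, 4096)
print(resolve_context_len(8192, 4096))  # (8192, 8192)
```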
@@ -118,6 +116,9 @@ class TokenizerManager:
                 )
                 self.tokenizer = self.processor.tokenizer
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+                # We want to parallelize the image pre-processing so we
+                # create an executor for it
                 self.executor = concurrent.futures.ProcessPoolExecutor(
                     initializer=init_global_processor,
                     mp_context=mp.get_context("fork"),
@@ -134,12 +135,14 @@ class TokenizerManager:
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
 
-        # for update model weights
+        # For update model weights
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None
 
     async def generate_request(
-        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request=None
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
             self.create_handle_loop()
@@ -160,7 +163,7 @@ class TokenizerManager:
     async def _handle_single_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
-        request,
+        request: Optional[fastapi.Request] = None,
         index: Optional[int] = None,
         is_cache_for_prefill: Optional[bool] = False,
     ):
@@ -182,8 +185,8 @@ class TokenizerManager:
            )
 
            if self.is_generation:
-                pixel_values, image_hash, image_size = await self._get_pixel_values(
-                    obj.image_data
+                pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
+                    obj.image_data if not_use_index else obj.image_data[index]
                 )
                 return_logprob = (
                     obj.return_logprob if not_use_index else obj.return_logprob[index]
@@ -195,7 +198,6 @@ class TokenizerManager:
                 )
                 if return_logprob and logprob_start_len == -1:
                     logprob_start_len = len(input_ids) - 1
-
                 top_logprobs_num = (
                     obj.top_logprobs_num
                     if not_use_index
@@ -238,13 +240,14 @@ class TokenizerManager:
 
            sampling_params = SamplingParams(**obj.sampling_params[0])
            sampling_params.max_new_tokens = 0
-            pixel_values, image_hash, image_size = await self._get_pixel_values(
+            pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
                obj.image_data[0]
            )
            return_logprob = obj.return_logprob[0]
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]
 
+        # Send to the controller
        if self.is_generation:
            if return_logprob and logprob_start_len == -1:
                logprob_start_len = len(input_ids) - 1
@@ -253,8 +256,8 @@ class TokenizerManager:
                input_text,
                input_ids,
                pixel_values,
-                image_hash,
-                image_size,
+                image_hashes,
+                image_sizes,
                sampling_params,
                return_logprob,
                logprob_start_len,
@@ -268,24 +271,24 @@ class TokenizerManager:
                input_ids,
                sampling_params,
            )
-
        self.send_to_router.send_pyobj(tokenized_obj)
 
+        # Recv results
        event = asyncio.Event()
        state = ReqState([], False, event)
        self.rid_to_state[rid] = state
        if not is_cache_for_prefill:
-            async for response in self._wait_for_response(
-                event, state, obj, rid, request
-            ):
+            async for response in self._wait_for_response(state, obj, rid, request):
                yield response
        else:
            assert self.is_generation
-            await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
+            await self._wait_for_cache_prefill_response(state, obj, rid, request)
            yield input_ids
 
    async def _handle_batch_request(
-        self, obj: Union[GenerateReqInput, EmbeddingReqInput], request
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
    ):
        batch_size = obj.batch_size
        if self.is_generation:
@@ -340,8 +343,8 @@ class TokenizerManager:
                if self.is_generation:
                    if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
                        obj.logprob_start_len[index] = len(input_ids) - 1
-                    pixel_values, image_hash, image_size = await self._get_pixel_values(
-                        obj.image_data[index]
+                    pixel_values, image_hashes, image_sizes = (
+                        await self._get_pixel_values(obj.image_data[index])
                    )
 
                    tokenized_obj = TokenizedGenerateReqInput(
@@ -349,8 +352,8 @@ class TokenizerManager:
                        input_text,
                        input_ids,
                        pixel_values,
-                        image_hash,
-                        image_size,
+                        image_hashes,
+                        image_sizes,
                        sampling_params,
                        obj.return_logprob[index],
                        obj.logprob_start_len[index],
@@ -372,7 +375,6 @@ class TokenizerManager:
 
                generators.append(
                    self._wait_for_response(
-                        event,
                        state,
                        obj,
                        rid,
@@ -388,6 +390,7 @@ class TokenizerManager:
        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
        output_list = [None] * len(tasks)
 
+        # Recv results
        while tasks:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
 
@@ -426,25 +429,18 @@ class TokenizerManager:
        sampling_params.verify()
        return sampling_params
 
-    async def _get_pixel_values(self, image_data):
-        if image_data is None:
-            return None, None, None
-        else:
-            return await self._get_pixel_values_internal(image_data)
-
    async def _wait_for_response(
        self,
-        event: asyncio.Event,
        state: ReqState,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        rid: str,
-        request,
-        index: int = None,
+        request: Optional[fastapi.Request] = None,
+        index: Optional[int] = None,
        response_index: int = 0,
    ):
        while True:
            try:
-                await asyncio.wait_for(event.wait(), timeout=4)
+                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in [obj.rid] if obj.is_single else obj.rid:
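
`_wait_for_response` now reads the event directly from the `ReqState` it already receives, so the separate `event` parameter is gone, and the typed `fastapi.Request` makes the disconnect check explicit. The underlying wait-then-poll pattern, sketched standalone (this helper is illustrative, not part of sglang):

```python
import asyncio

async def wait_with_disconnect_check(event: asyncio.Event, is_disconnected, timeout=4.0):
    # Wake up every `timeout` seconds; if the client has gone away, stop waiting
    # so the caller can abort the request upstream.
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=timeout)
            return True  # a result is ready
        except asyncio.TimeoutError:
            if is_disconnected is not None and await is_disconnected():
                return False  # client disconnected
```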
@@ -478,16 +474,15 @@ class TokenizerManager:
                    yield out
                    break
 
-            event.clear()
+            state.event.clear()
            yield out
 
    async def _wait_for_cache_prefill_response(
        self,
-        event: asyncio.Event,
        state: ReqState,
        obj: GenerateReqInput,
        rid: str,
-        request,
+        request: Optional[fastapi.Request] = None,
    ):
        while True:
            try:
@@ -514,7 +509,9 @@ class TokenizerManager:
        req = AbortReq(rid)
        self.send_to_router.send_pyobj(req)
 
-    async def update_weights(self, obj: UpdateWeightReqInput, request):
+    async def update_weights(
+        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
+    ):
        if self.to_create_loop:
            self.create_handle_loop()
 
@@ -659,12 +656,11 @@ class TokenizerManager:
            )
            return top_logprobs
 
-    async def _get_pixel_values_internal(self, image_data, aspect_ratio=None):
-        aspect_ratio = (
-            getattr(self.hf_config, "image_aspect_ratio", None)
-            if aspect_ratio is None
-            else aspect_ratio
-        )
+    async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
+        if not image_data:
+            return None, None, None
+
+        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints
            if hasattr(self.hf_config, "image_grid_pinpoints")
@@ -673,35 +669,42 @@ class TokenizerManager:
        )
 
        if isinstance(image_data, list) and len(image_data) > 0:
-            pixel_values, image_hash, image_size = [], [], []
+            # Multiple images
            if len(image_data) > 1:
                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+                pixel_values, image_hashes, image_sizes = [], [], []
                for img_data in image_data:
                    pixel_v, image_h, image_s = await self._process_single_image(
                        img_data, aspect_ratio, grid_pinpoints
                    )
                    pixel_values.append(pixel_v)
-                    image_hash.append(image_h)
-                    image_size.append(image_s)
-                pixel_values = np.stack(pixel_values, axis=0)
+                    image_hashes.append(image_h)
+                    image_sizes.append(image_s)
+
+                if isinstance(pixel_values[0], np.ndarray):
+                    pixel_values = np.stack(pixel_values, axis=0)
            else:
+                # A single image
                pixel_values, image_hash, image_size = await self._process_single_image(
                    image_data[0], aspect_ratio, grid_pinpoints
                )
-                image_hash = [image_hash]
-                image_size = [image_size]
+                image_hashes = [image_hash]
+                image_sizes = [image_size]
        elif isinstance(image_data, str):
+            # A single image
            pixel_values, image_hash, image_size = await self._process_single_image(
                image_data, aspect_ratio, grid_pinpoints
            )
-            image_hash = [image_hash]
-            image_size = [image_size]
+            image_hashes = [image_hash]
+            image_sizes = [image_size]
        else:
-            pixel_values, image_hash, image_size = None, None, None
+            raise ValueError(f"Invalid image data: {image_data}")
 
-        return pixel_values, image_hash, image_size
+        return pixel_values, image_hashes, image_sizes
 
-    async def _process_single_image(self, image_data, aspect_ratio, grid_pinpoints):
+    async def _process_single_image(
+        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
+    ):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
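
In the multi-image branch above, `np.stack` is now applied only when the per-image pixel values are NumPy arrays, presumably because some processors return nested lists instead. A standalone sketch of that guard with dummy data (the helper name is made up):

```python
import numpy as np

def maybe_stack(pixel_values_list):
    # Stack into one batched array only when entries are plain ndarrays;
    # otherwise keep the Python list as-is.
    if pixel_values_list and isinstance(pixel_values_list[0], np.ndarray):
        return np.stack(pixel_values_list, axis=0)
    return pixel_values_list

print(maybe_stack([np.zeros((3, 4, 4)), np.ones((3, 4, 4))]).shape)  # (2, 3, 4, 4)
print(type(maybe_stack([[1, 2], [3, 4]])))  # <class 'list'>
```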
@@ -732,12 +735,16 @@ def init_global_processor(server_args: ServerArgs):
 
 
 def _process_single_image_task(
-    image_data, image_aspect_ratio=None, image_grid_pinpoints=None, processor=None
+    image_data: Union[str, bytes],
+    image_aspect_ratio: Optional[str] = None,
+    image_grid_pinpoints: Optional[str] = None,
+    processor=None,
 ):
     try:
         processor = processor or global_processor
         image, image_size = load_image(image_data)
         if image_size is not None:
+            # It is a video with multiple images
             image_hash = hash(image_data)
             pixel_values = processor.image_processor(image)["pixel_values"]
             for _ in range(len(pixel_values)):
@@ -745,6 +752,7 @@ def _process_single_image_task(
             pixel_values = np.stack(pixel_values, axis=0)
             return pixel_values, image_hash, image_size
         else:
+            # It is an image
             image_hash = hash(image_data)
             if image_aspect_ratio == "pad":
                 image = expand2square(
@@ -754,13 +762,18 @@ def _process_single_image_task(
                 pixel_values = processor.image_processor(image.convert("RGB"))[
                     "pixel_values"
                 ][0]
-            elif image_aspect_ratio == "anyres" or "anyres_max" in image_aspect_ratio:
+            elif image_aspect_ratio == "anyres" or (
+                image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
+            ):
                 pixel_values = process_anyres_image(
                     image, processor.image_processor, image_grid_pinpoints
                 )
             else:
                 pixel_values = processor.image_processor(image)["pixel_values"][0]
-            pixel_values = pixel_values.astype(np.float16)
+
+            if isinstance(pixel_values, np.ndarray):
+                pixel_values = pixel_values.astype(np.float16)
+
             return pixel_values, image_hash, image.size
     except Exception:
         logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
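
Two defensive fixes in this hunk: the `anyres_max` substring test is guarded against a missing (`None`) aspect ratio, and the float16 cast only runs when the processor actually returned a NumPy array. A minimal reproduction of the first guard:

```python
def is_anyres(image_aspect_ratio):
    # The bare `"anyres_max" in image_aspect_ratio` raises
    # "TypeError: argument of type 'NoneType' is not iterable" when the
    # config has no image_aspect_ratio; the extra None check avoids that.
    return image_aspect_ratio == "anyres" or (
        image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
    )

print(is_anyres("anyres"))        # True
print(is_anyres("anyres_max_9"))  # True
print(is_anyres(None))            # False, no TypeError
```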
@@ -108,7 +108,7 @@ class ModelTpServer:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(server_args.model_path):
+            if is_multimodal_model(self.model_config.hf_config.architectures):
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -333,26 +333,24 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             req.pixel_values = recv_req.pixel_values
             if req.pixel_values is not None:
-                image_hash = (
-                    hash(tuple(recv_req.image_hash))
-                    if isinstance(recv_req.image_hash, list)
-                    else recv_req.image_hash
-                )
+                # Use image hash as fake token_ids, which is then used
+                # for prefix matching
+                image_hash = hash(tuple(recv_req.image_hashes))
                 req.pad_value = [
                     (image_hash) % self.model_config.vocab_size,
                     (image_hash >> 16) % self.model_config.vocab_size,
                     (image_hash >> 32) % self.model_config.vocab_size,
                     (image_hash >> 64) % self.model_config.vocab_size,
                 ]
-                req.image_size = recv_req.image_size
+                req.image_sizes = recv_req.image_sizes
                 (
                     req.origin_input_ids,
-                    req.image_offset,
+                    req.image_offsets,
                 ) = self.model_runner.model.pad_input_ids(
                     req.origin_input_ids_unpadded,
                     req.pad_value,
-                    req.pixel_values.shape,
-                    req.image_size,
+                    req.pixel_values,
+                    req.image_sizes,
                 )
             req.return_logprob = recv_req.return_logprob
             req.logprob_start_len = recv_req.logprob_start_len
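
As the new comment says, the combined image hash acts as fake token ids for radix-cache prefix matching: it is sliced into four ids modulo the vocabulary size. A self-contained sketch of that slicing (the vocab size here is a placeholder; the real value comes from the model config):

```python
def pad_values_from_hashes(image_hashes, vocab_size=32000):
    # Combine per-image hashes into one value, then derive four ids that
    # stand in for the image region during prefix matching.
    image_hash = hash(tuple(image_hashes))
    return [
        image_hash % vocab_size,
        (image_hash >> 16) % vocab_size,
        (image_hash >> 32) % vocab_size,
        (image_hash >> 64) % vocab_size,
    ]

print(pad_values_from_hashes([12345, 67890]))  # four ids in [0, vocab_size)
```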
@@ -368,6 +366,7 @@ class ModelTpServer:
                     req.jump_forward_map = self.jump_forward_cache.query(
                         computed_regex_string
                     )
+
             # Init regex fsm
             elif req.sampling_params.regex is not None:
                 req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
@@ -16,7 +16,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
@@ -58,6 +58,7 @@ class InputMetadata:
 
     # For extend
     extend_seq_lens: torch.Tensor = None
+    extend_prefix_lens: torch.Tensor = None
     extend_start_loc: torch.Tensor = None
     extend_no_prefix: bool = None
 
@@ -69,8 +70,8 @@ class InputMetadata:
 
     # For multimodal
     pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
+    image_sizes: List[List[List[int]]] = None
+    image_offsets: List[List[int]] = None
 
     # Trition attention backend
     triton_max_seq_len: int = 0
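
The widened type hints encode one entry per request, each holding one entry per image, so multi-image requests are represented uniformly. For example, a batch of two requests might look like this (the concrete numbers are made up):

```python
image_sizes = [
    [[336, 336], [672, 336]],  # request 0: two images, one size pair each
    [[336, 336]],              # request 1: one image
]
image_offsets = [
    [5, 600],                  # request 0: token offset of each image
    [7],                       # request 1
]
```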
@@ -87,20 +88,8 @@ class InputMetadata:
     def init_multimuldal_info(self, batch: ScheduleBatch):
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_size for r in reqs]
-        self.image_offsets = []
-        for r in reqs:
-            if isinstance(r.image_offset, list):
-                self.image_offsets.append(
-                    [
-                        (image_offset - len(r.prefix_indices))
-                        for image_offset in r.image_offset
-                    ]
-                )
-            elif isinstance(r.image_offset, int):
-                self.image_offsets.append(r.image_offset - len(r.prefix_indices))
-            elif r.image_offset is None:
-                self.image_offsets.append(0)
+        self.image_sizes = [r.image_sizes for r in reqs]
+        self.image_offsets = [r.image_offsets for r in reqs]
 
     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
@@ -153,6 +142,7 @@ class InputMetadata:
             for i, r in enumerate(batch.reqs)
         ]
         self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
         self.extend_start_loc = torch.zeros_like(self.seq_lens)
         self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
         self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
@@ -238,10 +228,10 @@ class InputMetadata:
         prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
-        if self.forward_mode != ForwardMode.DECODE:
-            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
-        else:
+        if self.forward_mode == ForwardMode.DECODE:
             prefix_lens = None
+        else:
+            prefix_lens = self.extend_prefix_lens
 
         update_flashinfer_indices(
             self.forward_mode,
@@ -50,7 +50,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_config import AttentionArch, ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -69,7 +69,7 @@ logger = logging.getLogger(__name__)
 class ModelRunner:
     def __init__(
         self,
-        model_config,
+        model_config: ModelConfig,
         mem_fraction_static: float,
         gpu_id: int,
         tp_rank: int,
@@ -85,7 +85,9 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        self.is_multimodal_model = is_multimodal_model(
+            self.model_config.hf_config.architectures
+        )
         global_server_args_dict.update(
             {
                 "disable_flashinfer": server_args.disable_flashinfer,
@@ -95,6 +97,13 @@ class ModelRunner:
             }
         )
 
+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
+
         min_per_gpu_memory = self.init_torch_distributed()
         self.load_model()
         self.init_memory_pool(
@@ -507,9 +516,9 @@ class ModelRunner:
             raise Exception(
                 f"Capture cuda graph failed: {e}\n"
                 "Possible solutions:\n"
-                "1. disable torch compile by not using --enable-torch-compile\n"
-                "2. disable cuda graph by --disable-cuda-graph\n"
-                "3. set --mem-fraction-static to a smaller value\n"
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
                 "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
             )
 
@@ -17,7 +17,7 @@ limitations under the License.
 # Adapted from
 # https://github.com/THUDM/ChatGLM2-6B
 """Inference-only ChatGLM model compatible with THUDM weights."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 from torch import nn
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -34,6 +33,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -60,7 +60,7 @@ class GemmaMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-        self.act_fn = GeluAndMul()
+        self.act_fn = GeluAndMul("none")
 
     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
@@ -22,11 +22,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-
-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.custom_op import CustomOp
-
-# from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -39,6 +34,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -50,52 +46,6 @@ def get_attention_sliding_window_size(config):
     return config.sliding_window - 1
 
 
-class GemmaRMSNorm(CustomOp):
-    """RMS normalization for Gemma.
-
-    Two differences from the above RMSNorm:
-        1. x * (1 + w) instead of x * w.
-        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        eps: float = 1e-6,
-    ) -> None:
-        super().__init__()
-        self.weight = nn.Parameter(torch.zeros(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward_native(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        """PyTorch-native implementation equivalent to forward()."""
-        orig_dtype = x.dtype
-        if residual is not None:
-            x = x + residual
-            residual = x
-
-        x = x.float()
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + self.variance_epsilon)
-        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
-        # See https://github.com/huggingface/transformers/pull/29402
-        x = x * (1.0 + self.weight.float())
-        x = x.to(orig_dtype)
-        return x if residual is None else (x, residual)
-
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        # from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
-        return self.forward_native(x, residual)
-
-
 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
sglang/srt/models/grok.py CHANGED
@@ -273,9 +273,9 @@ class Grok1Model(nn.Module):
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
+            hidden_states.mul_(self.config.embedding_multiplier_scale)
         else:
             hidden_states = input_embeds
-        hidden_states.mul_(self.config.embedding_multiplier_scale)
 
         for i in range(len(self.layers)):
             hidden_states = self.layers[i](positions, hidden_states, input_metadata)
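
The Grok change moves the `embedding_multiplier_scale` multiplication inside the `input_embeds is None` branch, so the scale is applied only to embeddings the model looks up itself and no longer to caller-supplied `input_embeds`. A toy before/after (the scale value is made up):

```python
import torch

SCALE = 78.38  # hypothetical embedding_multiplier_scale

def old_forward(embed_lookup, input_embeds=None):
    h = embed_lookup() if input_embeds is None else input_embeds
    h.mul_(SCALE)  # also scales embeddings supplied by the caller
    return h

def new_forward(embed_lookup, input_embeds=None):
    if input_embeds is None:
        h = embed_lookup()
        h.mul_(SCALE)  # scale only embeddings looked up from token ids
    else:
        h = input_embeds
    return h

print(old_forward(lambda: torch.ones(2), torch.ones(2)))  # scaled by SCALE
print(new_forward(lambda: torch.ones(2), torch.ones(2)))  # tensor([1., 1.])
```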
@@ -284,7 +284,7 @@ class Grok1Model(nn.Module):
         return hidden_states
 
 
-class Grok1ModelForCausalLM(nn.Module):
+class Grok1ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -415,4 +415,10 @@ def _prepare_presharded_weights(
     return hf_folder, hf_weights_files, use_safetensors
 
 
-EntryClass = Grok1ModelForCausalLM
+class Grok1ModelForCausalLM(Grok1ForCausalLM):
+    """An alias for backward-compatbility."""
+
+    pass
+
+
+EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]
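
Renaming the entry class could break configs whose `architectures` field still says `Grok1ModelForCausalLM`, so an empty subclass keeps the old name resolvable and `EntryClass` now lists both. A generic sketch of the pattern (the registry below is hypothetical, not sglang's actual loader):

```python
class Grok1ForCausalLM:  # stand-in for the real model class
    pass

class Grok1ModelForCausalLM(Grok1ForCausalLM):
    """Alias so checkpoints/configs that still name the old class keep loading."""
    pass

# Hypothetical registry keyed by architecture name, covering both spellings.
MODEL_REGISTRY = {cls.__name__: cls for cls in (Grok1ForCausalLM, Grok1ModelForCausalLM)}
print(MODEL_REGISTRY["Grok1ModelForCausalLM"] is Grok1ModelForCausalLM)  # True
```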