sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +49 -7
  2. sglang/lang/chat_template.py +24 -0
  3. sglang/srt/_custom_ops.py +59 -92
  4. sglang/srt/configs/model_config.py +5 -0
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/conversation.py +29 -4
  7. sglang/srt/custom_op.py +5 -0
  8. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  9. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/layers/attention/flashattention_backend.py +678 -83
  12. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  14. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  16. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  17. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  18. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
  30. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  31. sglang/srt/layers/moe/topk.py +49 -3
  32. sglang/srt/layers/quantization/__init__.py +5 -1
  33. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  35. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  36. sglang/srt/layers/quantization/fp8.py +3 -1
  37. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  38. sglang/srt/layers/quantization/moe_wna16.py +503 -0
  39. sglang/srt/layers/quantization/utils.py +1 -1
  40. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  41. sglang/srt/layers/radix_attention.py +2 -0
  42. sglang/srt/layers/rotary_embedding.py +63 -12
  43. sglang/srt/managers/cache_controller.py +34 -11
  44. sglang/srt/managers/mm_utils.py +202 -156
  45. sglang/srt/managers/multimodal_processor.py +0 -2
  46. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  47. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  48. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  49. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  50. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  51. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  52. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  53. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  54. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  55. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  56. sglang/srt/managers/schedule_batch.py +185 -128
  57. sglang/srt/managers/scheduler.py +4 -4
  58. sglang/srt/managers/tokenizer_manager.py +1 -1
  59. sglang/srt/managers/utils.py +1 -6
  60. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  61. sglang/srt/mem_cache/memory_pool.py +72 -6
  62. sglang/srt/mem_cache/paged_allocator.py +39 -0
  63. sglang/srt/metrics/collector.py +23 -53
  64. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  65. sglang/srt/model_executor/forward_batch_info.py +10 -10
  66. sglang/srt/model_executor/model_runner.py +60 -57
  67. sglang/srt/model_loader/loader.py +8 -0
  68. sglang/srt/models/clip.py +12 -7
  69. sglang/srt/models/deepseek_janus_pro.py +10 -15
  70. sglang/srt/models/deepseek_v2.py +212 -121
  71. sglang/srt/models/deepseek_vl2.py +105 -104
  72. sglang/srt/models/gemma3_mm.py +14 -80
  73. sglang/srt/models/llama.py +16 -5
  74. sglang/srt/models/llama4.py +420 -0
  75. sglang/srt/models/llava.py +31 -19
  76. sglang/srt/models/llavavid.py +16 -7
  77. sglang/srt/models/minicpmo.py +63 -147
  78. sglang/srt/models/minicpmv.py +17 -27
  79. sglang/srt/models/mllama.py +29 -14
  80. sglang/srt/models/mllama4.py +154 -0
  81. sglang/srt/models/qwen2.py +9 -6
  82. sglang/srt/models/qwen2_5_vl.py +21 -31
  83. sglang/srt/models/qwen2_vl.py +20 -21
  84. sglang/srt/openai_api/adapter.py +18 -6
  85. sglang/srt/platforms/interface.py +371 -0
  86. sglang/srt/server_args.py +99 -14
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  88. sglang/srt/speculative/eagle_utils.py +140 -28
  89. sglang/srt/speculative/eagle_worker.py +93 -24
  90. sglang/srt/utils.py +104 -51
  91. sglang/test/test_custom_ops.py +55 -0
  92. sglang/test/test_utils.py +13 -26
  93. sglang/utils.py +2 -2
  94. sglang/version.py +1 -1
  95. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
  96. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
  97. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_vl2.py

@@ -11,7 +11,11 @@ from sglang.srt.configs.deepseekvl2 import (
 )
 from sglang.srt.layers.linear import ReplicatedLinear
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternImageTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
@@ -150,7 +154,6 @@ class DeepseekVL2MlpProjector(nn.Module):
         return x
 
 
-# todo
 class DeepseekVL2ForCausalLM(nn.Module):
 
     def __init__(
@@ -215,32 +218,15 @@ class DeepseekVL2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         **kwargs: object,
     ):
-        input_embeds = self.language_model.model.embed_tokens(input_ids)
-        if (
-            forward_batch.forward_mode.is_extend()
-            and forward_batch.contains_image_inputs()
-        ):
-            extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-            extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
-            for idx, image in enumerate(forward_batch.mm_inputs):
-                if image is None:
-                    continue
-                start_idx = extend_start_loc_cpu[idx]
-                end_idx = start_idx + extend_seq_lens_cpu[idx]
-                images_emb_mask = image.images_emb_mask.to(device="cuda")
-                image_features = self.get_image_feature(image)
-                input_embeds[start_idx:end_idx] = input_embeds[
-                    start_idx:end_idx
-                ].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
-
-        outputs = self.language_model.forward(
+        hs = general_mm_embed_routine(
             input_ids=input_ids,
             positions=positions,
             forward_batch=forward_batch,
-            input_embeds=input_embeds,
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.language_model,
         )
 
-        return outputs
+        return hs
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
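The forward path above no longer hand-rolls the embed-and-scatter logic per model; the VLMs touched in this release all delegate to `general_mm_embed_routine`. Below is a minimal sketch of the pattern that helper centralizes, under a single-placeholder-id assumption; every name other than `get_image_feature` is illustrative, not sglang's actual internals.

```python
import torch
from torch import nn

def mm_embed_sketch(
    input_ids: torch.Tensor,      # [num_tokens]; image positions hold placeholder ids
    placeholder_id: int,          # illustrative; sglang tracks real pad values per request
    embed_tokens: nn.Embedding,
    image_embedding_func,         # e.g. get_image_feature(items) -> [num_image_tokens, D]
    image_items: list,
) -> torch.Tensor:
    """Embed text tokens, then overwrite the placeholder slots with image embeddings."""
    mask = input_ids == placeholder_id
    # Clamp so out-of-vocab placeholder ids survive the embedding lookup.
    safe_ids = input_ids.clamp(0, embed_tokens.num_embeddings - 1)
    inputs_embeds = embed_tokens(safe_ids)
    if mask.any():
        # Image embeddings must supply exactly one row per placeholder position.
        image_embeds = image_embedding_func(image_items)
        inputs_embeds[mask] = image_embeds.to(inputs_embeds.dtype)
    return inputs_embeds
```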
@@ -263,94 +249,109 @@ class DeepseekVL2ForCausalLM(nn.Module):
             weights_loader(param, loaded_weight)
 
     def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
-        return input_ids
-
-    def get_image_feature(self, image_input: MultimodalInputs):
-        pixel_values = image_input.pixel_values.type(
-            next(self.vision.parameters()).dtype
-        ).to(device=next(self.vision.parameters()).device)
-        image_feature = self.vision.forward_features(pixel_values)
-        images_embeds = self.projector(image_feature)
-        _, hw, n_dim = images_embeds.shape
-        h = w = int(hw**0.5)
-        tile_index = 0
+        helper = MultiModalityDataPaddingPatternImageTokens(
+            image_token_id=image_inputs.im_token_id
+        )
+        return helper.pad_input_tokens(input_ids, image_inputs)
+
+    def get_image_feature(self, items: List[MultimodalDataItem]):
+
+        images_spatial_crop = torch.cat(
+            [item.image_spatial_crop for item in items], dim=0
+        )
+
+        assert images_spatial_crop.dim() == 3
+
+        # TODO: can it be batched ?
         images_in_this_batch = []
-        images_spatial_crop = image_input.image_spatial_crop
-        for jdx in range(images_spatial_crop.shape[1]):
-            num_width_tiles, num_height_tiles = images_spatial_crop[0, jdx]
-            if num_width_tiles == 0 or num_height_tiles == 0:
-                break
-            num_tiles_in_image = num_width_tiles * num_height_tiles
-
-            # [hw, D]
-            global_features = images_embeds[tile_index]
-
-            # [num_height_tiles * num_width_tiles, hw, D]
-            local_features = images_embeds[
-                tile_index + 1 : tile_index + 1 + num_tiles_in_image
-            ]
-            tile_index += num_tiles_in_image + 1
-
-            # format global and local features
-            # ----------------- global view add newline -----------------
-            # [hw, D] -> [h, w, D]
-            global_features = global_features.view(h, w, n_dim)
-
-            # [D] -> [h, 1, D]
-            new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
-
-            # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
-            global_features = torch.cat([global_features, new_lines_in_global], dim=1)
-
-            # [h, w + 1, D] -> [h * (w + 1), D]
-            global_features = global_features.view(-1, n_dim)
-
-            # ----------------- local view add newline -----------------
-            # [num_height_tiles * num_width_tiles, h * w, D] ->
-            # [num_height_tiles * h, num_width_tiles * w, D]
-            local_features = rearrange(
-                local_features,
-                "(th tw) (h w) d -> (th h) (tw w) d",
-                th=num_height_tiles,
-                tw=num_width_tiles,
-                h=h,
-                w=w,
+        for item in items:
+            assert item.pixel_values.dim() == 4
+            image_feature = self.vision.forward_features(
+                item.pixel_values.type(next(self.vision.parameters()).dtype).to(
+                    device=next(self.vision.parameters()).device
+                )
             )
+            images_embeds = self.projector(image_feature)
+            _, hw, n_dim = images_embeds.shape
+            h = w = int(hw**0.5)
+            tile_index = 0
+            for jdx in range(item.image_spatial_crop.shape[1]):
+                num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
+                if num_width_tiles == 0 or num_height_tiles == 0:
+                    break
+                num_tiles_in_image = num_width_tiles * num_height_tiles
+
+                # [hw, D]
+                global_features = images_embeds[tile_index]
+
+                # [num_height_tiles * num_width_tiles, hw, D]
+                local_features = images_embeds[
+                    tile_index + 1 : tile_index + 1 + num_tiles_in_image
+                ]
+                tile_index += num_tiles_in_image + 1
 
-            # [D] -> [num_height_tiles * h, 1, D]
-            new_lines_in_local = repeat(
-                self.image_newline,
-                "d -> (th h) 1 d",
-                th=num_height_tiles,
-                h=h,
-            )
+                # format global and local features
+                # ----------------- global view add newline -----------------
+                # [hw, D] -> [h, w, D]
+                global_features = global_features.view(h, w, n_dim)
+
+                # [D] -> [h, 1, D]
+                new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
 
-            # [num_height_tiles * h, num_width_tiles * w + 1, D]
-            local_features = torch.cat([local_features, new_lines_in_local], dim=1)
-
-            # [num_height_tiles * h, num_width_tiles * w + 1, D]
-            # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
-            local_features = local_features.view(-1, n_dim)
-
-            # merge global and local tiles
-            if self.global_view_pos == "head":
-                global_local_features = torch.cat(
-                    [
-                        global_features,
-                        self.view_seperator[None, :],
-                        local_features,
-                    ]
+                # cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
+                global_features = torch.cat(
+                    [global_features, new_lines_in_global], dim=1
                 )
-            else:
-                global_local_features = torch.cat(
-                    [
-                        local_features,
-                        self.view_seperator[None, :],
-                        global_features,
-                    ]
+
+                # [h, w + 1, D] -> [h * (w + 1), D]
+                global_features = global_features.view(-1, n_dim)
+
+                # ----------------- local view add newline -----------------
+                # [num_height_tiles * num_width_tiles, h * w, D] ->
+                # [num_height_tiles * h, num_width_tiles * w, D]
+                local_features = rearrange(
+                    local_features,
+                    "(th tw) (h w) d -> (th h) (tw w) d",
+                    th=num_height_tiles,
+                    tw=num_width_tiles,
+                    h=h,
+                    w=w,
                 )
 
-            images_in_this_batch.append(global_local_features)
+                # [D] -> [num_height_tiles * h, 1, D]
+                new_lines_in_local = repeat(
+                    self.image_newline,
+                    "d -> (th h) 1 d",
+                    th=num_height_tiles,
+                    h=h,
+                )
+
+                # [num_height_tiles * h, num_width_tiles * w + 1, D]
+                local_features = torch.cat([local_features, new_lines_in_local], dim=1)
+
+                # [num_height_tiles * h, num_width_tiles * w + 1, D]
+                # --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
+                local_features = local_features.view(-1, n_dim)
+
+                # merge global and local tiles
+                if self.global_view_pos == "head":
+                    global_local_features = torch.cat(
+                        [
+                            global_features,
+                            self.view_seperator[None, :],
+                            local_features,
+                        ]
+                    )
+                else:
+                    global_local_features = torch.cat(
+                        [
+                            local_features,
+                            self.view_seperator[None, :],
+                            global_features,
+                        ]
+                    )
+
+                images_in_this_batch.append(global_local_features)
 
         return torch.cat(images_in_this_batch, dim=0)
 
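The per-tile logic is unchanged in substance; it just moved inside a per-item loop. The load-bearing step is the einops pattern that folds a flat, row-major list of tiles back into a 2-D tile grid before newline embeddings are appended. A self-contained toy example with arbitrary sizes:

```python
import torch
from einops import rearrange, repeat

th, tw = 2, 3       # tile grid: 2 rows x 3 columns (toy values)
h, w, d = 4, 4, 8   # each tile is an h*w token map with feature dim d

local = torch.randn(th * tw, h * w, d)  # [num_tiles, hw, D], tiles in row-major order
# Stitch tiles into one (th*h) x (tw*w) feature grid.
grid = rearrange(local, "(th tw) (h w) d -> (th h) (tw w) d", th=th, tw=tw, h=h, w=w)
assert grid.shape == (th * h, tw * w, d)

# Append a learned "newline" vector after each pixel row, as the model does.
newline = torch.randn(d)
grid = torch.cat([grid, repeat(newline, "d -> r 1 d", r=th * h)], dim=1)
assert grid.shape == (th * h, tw * w + 1, d)
```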
sglang/srt/models/gemma3_mm.py

@@ -21,14 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
 
 import torch
 from torch import nn
-from transformers import (
-    AutoModel,
-    BatchFeature,
-    Gemma3Config,
-    Gemma3Processor,
-    PreTrainedModel,
-)
-from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
+from transformers import AutoModel, Gemma3Config, PreTrainedModel
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
@@ -38,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.managers.schedule_batch import (
+    MultimodalDataItem,
+    MultimodalInputs,
+    flatten_nested_list,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -274,17 +271,16 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         """
         return self.language_model.get_attention_sliding_window_size()
 
-    def get_image_feature(self, image_input: MultimodalInputs):
+    def get_image_feature(self, items: List[MultimodalDataItem]):
         """
         Projects the last hidden state from the vision model into language model space.
 
-        Args:
-            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
-                The tensors corresponding to the input images.
         Returns:
             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
         """
-        pixel_values = image_input.pixel_values
+        pixel_values = torch.stack(
+            flatten_nested_list([item.pixel_values for item in items]), dim=0
+        )
         pixel_values = pixel_values.to("cuda")
         pixel_values = pixel_values.to(dtype=self.language_model.dtype())
 
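`flatten_nested_list` lets `get_image_feature` accept per-item pixel values that may themselves be lists (for example, several crops or frames per item) before stacking them into one batch. A plausible stand-in for the imported helper, for illustration only:

```python
def flatten_nested_list(xs):
    """Recursively flatten arbitrarily nested lists (sketch of the helper's behavior)."""
    out = []
    for x in xs:
        if isinstance(x, list):
            out.extend(flatten_nested_list(x))
        else:
            out.append(x)
    return out

assert flatten_nested_list([[1, 2], [3, [4]]]) == [1, 2, 3, 4]
```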
@@ -292,61 +288,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         image_features = self.multi_modal_projector(vision_outputs)
         return image_features
 
-    def embed_mm_inputs(
-        self,
-        input_ids: torch.Tensor,
-        forward_batch: ForwardBatch,
-        image_input: MultimodalInputs,
-    ) -> torch.Tensor:
-        if input_ids is None:
-            raise ValueError("Unimplemented")
-        # boolean-masking image tokens
-        special_image_mask = torch.isin(
-            input_ids,
-            torch.tensor(image_input.pad_values, device=input_ids.device),
-        ).unsqueeze(-1)
-        num_image_tokens_in_input_ids = special_image_mask.sum()
-
-        inputs_embeds = None
-        if num_image_tokens_in_input_ids == 0:
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-            return inputs_embeds
-        else:
-            # print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
-            image_features = self.get_image_feature(image_input.pixel_values)
-
-            # print(f"image tokens from image embeddings: {image_features.numel()}")
-            num_image_tokens_in_embedding = (
-                image_features.shape[0] * image_features.shape[1]
-            )
-
-            if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
-                num_image = num_image_tokens_in_input_ids // image_features.shape[1]
-                image_features = image_features[:num_image, :]
-                logger.warning(
-                    f"Number of images does not match number of special image tokens in the input text. "
-                    f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
-                    "tokens from image embeddings."
-                )
-
-            # Important: clamp after extracting original image boundaries
-            input_ids.clamp_(min=0, max=self.vocab_size - 1)
-
-            inputs_embeds = self.get_input_embeddings()(input_ids)
-
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
-                inputs_embeds.device
-            )
-
-            image_features = image_features.to(
-                inputs_embeds.device, inputs_embeds.dtype
-            )
-            inputs_embeds = inputs_embeds.masked_scatter(
-                special_image_mask, image_features
-            )
-
-        return inputs_embeds
-
     @torch.no_grad()
     def forward(
         self,
@@ -405,22 +346,15 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
         else:
             llm_input_ids = input_ids
 
-        inputs_embeds = general_mm_embed_routine(
+        hs = general_mm_embed_routine(
             input_ids=llm_input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        outputs = self.language_model(
-            input_ids=None,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-            **kwargs,
         )
 
-        return outputs
+        return hs
 
     def tie_weights(self):
         return self.language_model.tie_weights()
sglang/srt/models/llama.py

@@ -17,7 +17,7 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -63,6 +63,7 @@ class LlamaMLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        reduce_results: bool = True,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -78,6 +79,7 @@ class LlamaMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
         )
         if hidden_act != "silu":
             raise ValueError(
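The new `reduce_results` flag is threaded into the `down_proj` row-parallel linear. Under tensor parallelism, a row-parallel layer normally all-reduces its rank-local partial output; passing `reduce_results=False` defers that collective so a caller can sum several partial outputs and reduce once. A sketch of the intended usage; the MoE pairing is an assumption inferred from the `llama4.py` files added in this release, and `all_reduce` stands in for sglang's `tensor_model_parallel_all_reduce`:

```python
import torch

def moe_block_sketch(shared_mlp, routed_experts, x: torch.Tensor, all_reduce):
    # Both branches were built with reduce_results=False, so each returns a
    # rank-local partial sum rather than an already-reduced tensor.
    partial = shared_mlp(x) + routed_experts(x)
    # One collective over the combined result instead of one per branch.
    return all_reduce(partial)
```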
@@ -281,7 +283,7 @@ class LlamaModel(nn.Module):
         self.layers = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: LlamaDecoderLayer(
-                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
             ),
             prefix="model.layers",
         )
@@ -375,9 +377,7 @@ class LlamaForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = LlamaModel(
-            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
-        )
+        self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
         # Llama 3.2 1B Instruct set tie_word_embeddings to True
         # Llama 3.1 8B Instruct set tie_word_embeddings to False
         if self.config.tie_word_embeddings:
@@ -402,6 +402,14 @@
 
         self.capture_aux_hidden_states = False
 
+    def _init_model(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return LlamaModel(config, quant_config=quant_config, prefix=prefix)
+
     @torch.no_grad()
     def forward(
         self,
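Extracting `_init_model` turns backbone construction into an overridable hook, so a subclass can reuse the rest of `LlamaForCausalLM.__init__` (weight tying, lm_head, pooler) while swapping the model class. A sketch of the intended override pattern; the subclass name is hypothetical, and the Llama4 use case is inferred from the `llama4.py` file added in this release rather than quoted from it:

```python
class MyLlamaVariant(LlamaForCausalLM):
    def _init_model(self, config, quant_config=None, prefix: str = ""):
        # Return any LlamaModel-compatible backbone here; the base __init__
        # then wires lm_head, weight tying, and pooling around it unchanged.
        return LlamaModel(config, quant_config=quant_config, prefix=prefix)
```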
@@ -428,6 +436,9 @@ class LlamaForCausalLM(nn.Module):
         else:
             return self.pooler(hidden_states, forward_batch)
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     def get_hidden_dim(self, module_name):
         # return input_dim, output_dim
         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
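Exposing `get_input_embeddings` gives multimodal wrappers a uniform handle on the text embedding table when a Llama backbone serves as the `language_model` of a VLM. An illustrative access pattern only, not sglang's code:

```python
import torch
from torch import nn

def embed_text_tokens(language_model, input_ids: torch.Tensor) -> torch.Tensor:
    # Works for any backbone exposing the accessor added above.
    embed_tokens: nn.Embedding = language_model.get_input_embeddings()
    return embed_tokens(input_ids)
```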