sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
@@ -36,6 +36,7 @@ from sglang.srt.managers.schedule_batch import ImageInputs
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
+ from sglang.srt.utils import add_prefix


  class ColumnParallelConv2dPatch(torch.nn.Module):
@@ -147,7 +148,12 @@ class MllamaPrecomputedPositionEmbedding(nn.Module):


  class MllamaVisionMLP(nn.Module):
- def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
+ def __init__(
+ self,
+ config,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
+ ):
  super().__init__()
  self.config = config
  self.activation_fn = get_act_fn(config.hidden_act)
@@ -156,12 +162,14 @@ class MllamaVisionMLP(nn.Module):
  config.intermediate_size,
  bias=True,
  quant_config=quant_config,
+ prefix=add_prefix("fc1", prefix),
  )
  self.fc2 = RowParallelLinear(
  config.intermediate_size,
  config.hidden_size,
  bias=True,
  quant_config=quant_config,
+ prefix=add_prefix("fc2", prefix),
  )

  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -174,7 +182,10 @@ class MllamaVisionMLP(nn.Module):

  class MllamaVisionEncoderLayer(nn.Module):
  def __init__(
- self, config: config_mllama.MllamaVisionConfig, is_gated: bool = False
+ self,
+ config: config_mllama.MllamaVisionConfig,
+ is_gated: bool = False,
+ prefix: str = "",
  ):
  super().__init__()

@@ -193,8 +204,9 @@ class MllamaVisionEncoderLayer(nn.Module):
  use_context_forward=False,
  use_full_precision_softmax=False,
  flatten_batch=False,
+ prefix=add_prefix("self_attn", prefix),
  )
- self.mlp = MllamaVisionMLP(config)
+ self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))

  self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
  self.post_attention_layernorm = nn.LayerNorm(
@@ -235,11 +247,17 @@ class MllamaVisionEncoder(nn.Module):
  num_layers=32,
  is_gated=False,
  output_hidden_states=None,
+ prefix: str = "",
  ):
  super().__init__()
  self.config = config
  self.layers = nn.ModuleList(
- [MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)]
+ [
+ MllamaVisionEncoderLayer(
+ config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
+ )
+ for i in range(num_layers)
+ ]
  )
  self.output_hidden_states = output_hidden_states or []

@@ -265,7 +283,7 @@ class MllamaVisionEncoder(nn.Module):


  class MllamaVisionModel(nn.Module):
- def __init__(self, config: config_mllama.MllamaVisionConfig):
+ def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
  super().__init__()
  self.image_size = config.image_size
  self.patch_size = config.patch_size
@@ -305,9 +323,13 @@ class MllamaVisionModel(nn.Module):
  config.num_hidden_layers,
  is_gated=False,
  output_hidden_states=config.intermediate_layers_indices,
+ prefix=add_prefix("transformer", prefix),
  )
  self.global_transformer = MllamaVisionEncoder(
- config, config.num_global_layers, is_gated=True
+ config,
+ config.num_global_layers,
+ is_gated=True,
+ prefix=add_prefix("global_transformer", prefix),
  )

  def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
@@ -464,6 +486,7 @@ class MllamaTextCrossAttention(nn.Module):
  config: Optional[config_mllama.MllamaTextConfig] = None,
  layer_id: Optional[int] = None,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ):
  super().__init__()
  self.config = config
@@ -489,6 +512,7 @@ class MllamaTextCrossAttention(nn.Module):
  self.num_key_value_heads,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("qkv_proj", prefix),
  )
  self.o_proj = RowParallelLinear(
  self.num_heads * self.head_dim,
@@ -496,6 +520,7 @@ class MllamaTextCrossAttention(nn.Module):
  bias=False,
  input_is_parallel=True,
  quant_config=quant_config,
+ prefix=add_prefix("o_proj", prefix),
  )
  # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
  # use huggingface's instead
@@ -510,6 +535,7 @@ class MllamaTextCrossAttention(nn.Module):
  self.num_local_key_value_heads,
  layer_id=layer_id,
  is_cross_attention=True,
+ prefix=add_prefix("attn", prefix),
  )

  def forward(
@@ -551,6 +577,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
  config: config_mllama.MllamaTextConfig,
  layer_id: int,
  quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.layer_id = layer_id
@@ -558,6 +585,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
  config=config,
  layer_id=layer_id,
  quant_config=quant_config,
+ prefix=add_prefix("cross_attn", prefix),
  )

  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -568,6 +596,7 @@ class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
  intermediate_size=config.intermediate_size,
  hidden_act=config.hidden_act,
  quant_config=quant_config,
+ prefix=add_prefix("mlp", prefix),
  )
  self.post_attention_layernorm = RMSNorm(
  config.hidden_size, eps=config.rms_norm_eps
@@ -610,12 +639,15 @@ class MllamaTextModel(nn.Module):
  self,
  config: config_mllama.MllamaTextConfig,
  quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
  ):
  super().__init__()
  self.padding_id = config.pad_token_id
  self.vocab_size = config.vocab_size
  self.embed_tokens = VocabParallelEmbedding(
- config.vocab_size + 8, config.hidden_size
+ config.vocab_size + 8,
+ config.hidden_size,
+ prefix=add_prefix("embed_tokens", prefix),
  )
  self.cross_attention_layers = config.cross_attention_layers

@@ -624,14 +656,20 @@ class MllamaTextModel(nn.Module):
  if layer_id in self.cross_attention_layers:
  layers.append(
  MllamaCrossAttentionDecoderLayer(
- config, layer_id, quant_config=quant_config
+ config,
+ layer_id,
+ quant_config=quant_config,
+ prefix=add_prefix(f"layers.{layer_id}", prefix),
  )
  )
  else:
  # TODO: force LlamaDecoderLayer to config.attention_bias=False
  layers.append(
  LlamaDecoderLayer(
- config, quant_config=quant_config, layer_id=layer_id
+ config,
+ quant_config=quant_config,
+ layer_id=layer_id,
+ prefix=add_prefix(f"layers.{layer_id}", prefix),
  )
  )

@@ -687,16 +725,20 @@ class MllamaForCausalLM(nn.Module):
  self,
  config: config_mllama.MllamaTextConfig,
  quant_config: Optional[QuantizationConfig],
+ prefix: str = "",
  ):
  super().__init__()
  self.vocab_size = config.vocab_size
- self.model = MllamaTextModel(config, quant_config)
+ self.model = MllamaTextModel(
+ config, quant_config, prefix=add_prefix("model", prefix)
+ )
  self.lm_head = ParallelLMHead(
  config.vocab_size,
  config.hidden_size,
  org_num_embeddings=config.vocab_size,
  padding_size=DEFAULT_VOCAB_PADDING_SIZE,
  quant_config=quant_config,
+ prefix=add_prefix("lm_head", prefix),
  )

  def forward(
@@ -726,6 +768,7 @@ class MllamaForConditionalGeneration(nn.Module):
  self,
  config: config_mllama.MllamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ):
  super().__init__()
  self.vocab_size = config.text_config.vocab_size
@@ -737,10 +780,13 @@ class MllamaForConditionalGeneration(nn.Module):
  )
  self.image_size = config.vision_config.image_size

- self.vision_model = MllamaVisionModel(config.vision_config)
+ self.vision_model = MllamaVisionModel(
+ config.vision_config, prefix=add_prefix("vision_model", prefix)
+ )
  self.language_model = MllamaForCausalLM(
  config.text_config,
  quant_config=quant_config,
+ prefix=add_prefix("language_model", prefix),
  )
  self.multi_modal_projector = nn.Linear(
  config.vision_config.vision_output_dim,
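
The common thread in these model-file changes is that every constructor now accepts a `prefix` argument and threads a dotted, fully qualified module name down to each submodule via `add_prefix` (newly imported from `sglang.srt.utils`), presumably so that per-weight-name logic such as quantization configs keyed by parameter name can identify each layer. The helper's implementation is not shown in this diff; a minimal sketch of what such a dotted-name joiner presumably looks like:

```python
def add_prefix(name: str, prefix: str) -> str:
    """Join a child module name onto a dotted parent prefix.

    Sketch only -- the real sglang.srt.utils.add_prefix may differ in details.
    """
    return name if not prefix else f"{prefix}.{name}"


# How the constructors above compose names level by level:
assert add_prefix("qkv_proj", add_prefix("self_attn", "model.layers.0")) == (
    "model.layers.0.self_attn.qkv_proj"
)
```
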
sglang/srt/models/olmo.py CHANGED
@@ -38,7 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import make_layers
+ from sglang.srt.utils import add_prefix, make_layers


  class OlmoAttention(nn.Module):
@@ -53,6 +53,7 @@ class OlmoAttention(nn.Module):
  config: OlmoConfig,
  layer_id: int = 0,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ):
  super().__init__()
  self.config = config
@@ -75,6 +76,7 @@ class OlmoAttention(nn.Module):
  self.head_dim,
  self.total_num_heads,
  bias=config.attention_bias,
+ prefix=add_prefix("qkv_proj", prefix),
  )

  # Rotary embeddings.
@@ -91,6 +93,7 @@ class OlmoAttention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_heads,
  layer_id=layer_id,
+ prefix=add_prefix("attn", prefix),
  )

  # Attention output projection.
@@ -98,6 +101,7 @@ class OlmoAttention(nn.Module):
  self.hidden_size,
  self.hidden_size,
  bias=config.attention_bias,
+ prefix=add_prefix("o_proj", prefix),
  )

  def forward(
@@ -127,6 +131,7 @@ class OlmoMLP(nn.Module):
  self,
  config: OlmoConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ):
  super().__init__()
  self.config = config
@@ -139,6 +144,7 @@ class OlmoMLP(nn.Module):
  [self.intermediate_size] * 2,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("gate_up_proj", prefix),
  )

  # Activation function.
@@ -150,6 +156,7 @@ class OlmoMLP(nn.Module):
  self.hidden_size,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("down_proj", prefix),
  )

  def forward(
@@ -174,13 +181,23 @@ class OlmoDecoderLayer(nn.Module):
  config: OlmoConfig,
  layer_id: int = 0,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ):
  super().__init__()
  # Attention block.
- self.self_attn = OlmoAttention(config, layer_id, quant_config)
+ self.self_attn = OlmoAttention(
+ config,
+ layer_id,
+ quant_config,
+ prefix=add_prefix("self_attn", prefix),
+ )

  # MLP block.
- self.mlp = OlmoMLP(config, quant_config)
+ self.mlp = OlmoMLP(
+ config,
+ quant_config,
+ prefix=add_prefix("mlp", prefix),
+ )

  # LayerNorm
  self.input_layernorm = nn.LayerNorm(
@@ -213,13 +230,18 @@ class OlmoDecoderLayer(nn.Module):
  class OlmoModel(nn.Module):

  def __init__(
- self, config: OlmoConfig, quant_config: Optional[QuantizationConfig] = None
+ self,
+ config: OlmoConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ):
  super().__init__()
  self.config = config

  self.embed_tokens = VocabParallelEmbedding(
- config.vocab_size, config.hidden_size
+ config.vocab_size,
+ config.hidden_size,
+ prefix=add_prefix("embed_tokens", prefix),
  )
  self.layers = make_layers(
  config.num_hidden_layers,
@@ -227,7 +249,9 @@ class OlmoModel(nn.Module):
  layer_id=idx,
  config=config,
  quant_config=quant_config,
+ prefix=prefix,
  ),
+ prefix=add_prefix("layers", prefix),
  )
  self.norm = nn.LayerNorm(
  config.hidden_size, elementwise_affine=False, bias=False
@@ -275,10 +299,11 @@ class OlmoForCausalLM(nn.Module):
  self,
  config: OlmoConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ):
  super().__init__()
  self.config = config
- self.model = OlmoModel(config, quant_config)
+ self.model = OlmoModel(config, quant_config, prefix=add_prefix("model", prefix))
  if config.tie_word_embeddings:
  self.lm_head = self.model.embed_tokens
  else:
@@ -288,6 +313,7 @@ class OlmoForCausalLM(nn.Module):
  config.hidden_size,
  org_num_embeddings=config.vocab_size,
  quant_config=quant_config,
+ prefix=add_prefix("lm_head", prefix),
  )
  self.logits_processor = LogitsProcessor(config)

@@ -325,6 +351,8 @@ class OlmoForCausalLM(nn.Module):
  # Models trained using ColossalAI may include these tensors in
  # the checkpoint. Skip them.
  continue
+ if self.config.tie_word_embeddings and "lm_head.weight" in name:
+ continue
  for param_name, weight_name, shard_id in stacked_params_mapping:
  if weight_name not in name:
  continue
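
Beyond the prefix threading, the one behavioral change in `olmo.py` is in `load_weights`: when `config.tie_word_embeddings` is set, any checkpoint entry named `lm_head.weight` is now skipped. With tied weights, `self.lm_head` is the same object as `self.model.embed_tokens`, so there is no separate output-projection parameter to load. A plain-PyTorch illustration of the idea (toy names, not sglang code):

```python
import torch
import torch.nn as nn


class TinyTiedLM(nn.Module):
    def __init__(self, vocab: int = 100, hidden: int = 16, tie_word_embeddings: bool = True):
        super().__init__()
        self.tie_word_embeddings = tie_word_embeddings
        self.embed_tokens = nn.Embedding(vocab, hidden)
        # With tying, the output head simply reuses the input embedding matrix.
        self.lm_head = (
            self.embed_tokens if tie_word_embeddings else nn.Linear(hidden, vocab, bias=False)
        )

    def load_weights(self, weights):
        params = dict(self.named_parameters())
        for name, tensor in weights:
            if self.tie_word_embeddings and "lm_head.weight" in name:
                continue  # already covered by embed_tokens.weight
            params[name].data.copy_(tensor)


# Some checkpoints ship a redundant lm_head.weight even for tied models.
ckpt = [
    ("embed_tokens.weight", torch.randn(100, 16)),
    ("lm_head.weight", torch.randn(100, 16)),
]
TinyTiedLM().load_weights(ckpt)  # the redundant entry is skipped, no KeyError
```
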
sglang/srt/models/olmo2.py CHANGED
@@ -45,7 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import make_layers
+ from sglang.srt.utils import add_prefix, make_layers


  class Olmo2Attention(nn.Module):
@@ -60,28 +60,29 @@ class Olmo2Attention(nn.Module):
  config: PretrainedConfig,
  layer_id: int = 0,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ):
  super().__init__()
  self.config = config
  self.hidden_size = config.hidden_size
- tp_size = get_tensor_model_parallel_world_size()
+ self.tp_size = get_tensor_model_parallel_world_size()
  self.total_num_heads = config.num_attention_heads

  assert self.hidden_size % self.total_num_heads == 0
- assert self.total_num_heads % tp_size == 0
+ assert self.total_num_heads % self.tp_size == 0

- self.num_heads = self.total_num_heads // tp_size
+ self.num_heads = self.total_num_heads // self.tp_size
  self.total_num_kv_heads = self.config.num_key_value_heads

- if self.total_num_kv_heads >= tp_size:
+ if self.total_num_kv_heads >= self.tp_size:
  # Number of KV heads is greater than TP size, so we partition
  # the KV heads across multiple tensor parallel GPUs.
- assert self.total_num_kv_heads % tp_size == 0
+ assert self.total_num_kv_heads % self.tp_size == 0
  else:
  # Number of KV heads is less than TP size, so we replicate
  # the KV heads across multiple tensor parallel GPUs.
- assert tp_size % self.total_num_kv_heads == 0
- self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+ assert self.tp_size % self.total_num_kv_heads == 0
+ self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)

  self.head_dim = self.hidden_size // self.total_num_heads
  self.max_position_embeddings = config.max_position_embeddings
@@ -93,6 +94,8 @@ class Olmo2Attention(nn.Module):
  self.head_dim,
  self.total_num_heads,
  bias=config.attention_bias,
+ quant_config=quant_config,
+ prefix=add_prefix("qkv_proj", prefix),
  )
  self.tp_rank = get_tensor_model_parallel_rank()

@@ -115,6 +118,7 @@ class Olmo2Attention(nn.Module):
  self.scaling,
  num_kv_heads=self.num_kv_heads,
  layer_id=layer_id,
+ prefix=add_prefix("attn", prefix),
  )

  # Attention output projection.
@@ -122,6 +126,8 @@ class Olmo2Attention(nn.Module):
  self.head_dim * self.total_num_heads,
  self.hidden_size,
  bias=config.attention_bias,
+ quant_config=quant_config,
+ prefix=add_prefix("o_proj", prefix),
  )

  def _apply_qk_norm(
@@ -164,6 +170,7 @@ class Olmo2MLP(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ):
  super().__init__()
  self.config = config
@@ -176,6 +183,7 @@ class Olmo2MLP(nn.Module):
  [self.intermediate_size] * 2,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("gate_up_proj", prefix),
  )

  # Activation function.
@@ -187,6 +195,7 @@ class Olmo2MLP(nn.Module):
  self.hidden_size,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("down_proj", prefix),
  )

  def forward(
@@ -211,13 +220,16 @@ class Olmo2DecoderLayer(nn.Module):
  config: PretrainedConfig,
  layer_id: int = 0,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ):
  super().__init__()
  # Attention block.
- self.self_attn = Olmo2Attention(config, layer_id, quant_config)
+ self.self_attn = Olmo2Attention(
+ config, layer_id, quant_config, prefix=add_prefix("self_attn", prefix)
+ )

  # MLP block.
- self.mlp = Olmo2MLP(config, quant_config)
+ self.mlp = Olmo2MLP(config, quant_config, prefix=add_prefix("mlp", prefix))

  # RMSNorm
  self.post_attention_layernorm = RMSNorm(
@@ -254,12 +266,15 @@ class Olmo2Model(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ):
  super().__init__()
  self.config = config

  self.embed_tokens = VocabParallelEmbedding(
- config.vocab_size, config.hidden_size
+ config.vocab_size,
+ config.hidden_size,
+ prefix=add_prefix("embed_tokens", prefix),
  )
  self.layers = make_layers(
  config.num_hidden_layers,
@@ -267,7 +282,9 @@ class Olmo2Model(nn.Module):
  layer_id=idx,
  config=config,
  quant_config=quant_config,
+ prefix=prefix,
  ),
+ prefix=add_prefix("layers", prefix),
  )
  self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

@@ -313,10 +330,13 @@ class Olmo2ForCausalLM(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ):
  super().__init__()
  self.config = config
- self.model = Olmo2Model(config, quant_config)
+ self.model = Olmo2Model(
+ config, quant_config, prefix=add_prefix("model", prefix)
+ )
  if config.tie_word_embeddings:
  self.lm_head = self.model.embed_tokens
  else:
@@ -326,6 +346,7 @@ class Olmo2ForCausalLM(nn.Module):
  config.hidden_size,
  org_num_embeddings=config.vocab_size,
  quant_config=quant_config,
+ prefix=add_prefix("lm_head", prefix),
  )
  self.logits_processor = LogitsProcessor(config)

@@ -343,7 +364,7 @@ class Olmo2ForCausalLM(nn.Module):
  input_embeds=input_embeds,
  )
  return self.logits_processor(
- input_ids, hidden_states, self.lm_head.weight, forward_batch
+ input_ids, hidden_states, self.lm_head, forward_batch
  )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
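
A related detail in `olmo2.py`: the forward pass now hands the logits processor the `lm_head` module itself instead of its raw `weight` tensor. Passing the module keeps details such as weight tying or a quantized/parallel head behind a single interface. A hypothetical plain-PyTorch sketch of that idea (illustrative names only, not sglang's `LogitsProcessor` API):

```python
import torch
import torch.nn as nn


def compute_logits(hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
    # A tied head is an Embedding, so logits come from its transposed weight;
    # any other head module defines its own projection.
    if isinstance(lm_head, nn.Embedding):
        return hidden_states @ lm_head.weight.t()
    return lm_head(hidden_states)


embed = nn.Embedding(100, 16)                # toy vocab of 100, hidden size 16
hidden = torch.randn(2, 16)
print(compute_logits(hidden, embed).shape)   # torch.Size([2, 100])
```
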
sglang/srt/models/olmoe.py CHANGED
@@ -41,7 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import make_layers, print_warning_once
+ from sglang.srt.utils import add_prefix, make_layers, print_warning_once


  class OlmoeMoE(nn.Module):
@@ -69,7 +69,11 @@ class OlmoeMoE(nn.Module):

  # Gate always runs at half / full precision for now.
  self.gate = ReplicatedLinear(
- hidden_size, num_experts, bias=False, quant_config=None
+ hidden_size,
+ num_experts,
+ bias=False,
+ quant_config=None,
+ prefix=add_prefix("gate", prefix),
  )

  self.experts = FusedMoE(
@@ -81,6 +85,7 @@ class OlmoeMoE(nn.Module):
  renormalize=False,
  quant_config=quant_config,
  tp_size=tp_size,
+ prefix=add_prefix("experts", prefix),
  )

  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -107,6 +112,7 @@ class OlmoeAttention(nn.Module):
  rope_scaling: Optional[Dict[str, Any]] = None,
  max_position_embeddings: int = 4096,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.hidden_size = hidden_size
@@ -138,6 +144,7 @@ class OlmoeAttention(nn.Module):
  self.total_num_kv_heads,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("qkv_proj", prefix),
  )
  self.q_norm = RMSNorm(hidden_size, eps=1e-5)
  self.k_norm = RMSNorm(hidden_size, eps=1e-5)
@@ -146,6 +153,7 @@ class OlmoeAttention(nn.Module):
  hidden_size,
  bias=False,
  quant_config=quant_config,
+ prefix=add_prefix("o_proj", prefix),
  )

  self.rotary_emb = get_rope(
@@ -162,6 +170,7 @@ class OlmoeAttention(nn.Module):
  self.scaling,
  layer_id=layer_id,
  num_kv_heads=self.num_kv_heads,
+ prefix=add_prefix("attn", prefix),
  )

  def forward(
@@ -186,6 +195,7 @@ class OlmoeDecoderLayer(nn.Module):
  config: PretrainedConfig,
  layer_id: int = 0,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.hidden_size = config.hidden_size
@@ -202,6 +212,7 @@ class OlmoeDecoderLayer(nn.Module):
  rope_scaling=rope_scaling,
  max_position_embeddings=max_position_embeddings,
  quant_config=quant_config,
+ prefix=add_prefix("self_attn", prefix),
  )

  self.mlp = OlmoeMoE(
@@ -210,6 +221,7 @@ class OlmoeDecoderLayer(nn.Module):
  hidden_size=config.hidden_size,
  intermediate_size=config.intermediate_size,
  quant_config=quant_config,
+ prefix=add_prefix("mlp", prefix),
  )
  self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
  self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
@@ -246,6 +258,7 @@ class OlmoeModel(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.padding_idx = config.pad_token_id
@@ -254,6 +267,7 @@ class OlmoeModel(nn.Module):
  self.embed_tokens = VocabParallelEmbedding(
  config.vocab_size,
  config.hidden_size,
+ prefix=add_prefix("embed_tokens", prefix),
  )
  self.layers = make_layers(
  config.num_hidden_layers,
@@ -261,7 +275,9 @@ class OlmoeModel(nn.Module):
  config=config,
  quant_config=quant_config,
  layer_id=idx,
+ prefix=prefix,
  ),
+ prefix=add_prefix("layers", prefix),
  )
  self.norm = RMSNorm(config.hidden_size, eps=1e-5)

@@ -294,13 +310,19 @@ class OlmoeForCausalLM(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ prefix: str = "",
  ) -> None:
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.model = OlmoeModel(config, quant_config)
+ self.model = OlmoeModel(
+ config, quant_config, prefix=add_prefix("model", prefix)
+ )
  self.lm_head = ParallelLMHead(
- config.vocab_size, config.hidden_size, quant_config=quant_config
+ config.vocab_size,
+ config.hidden_size,
+ quant_config=quant_config,
+ prefix=add_prefix("lm_head", prefix),
  )
  self.logits_processor = LogitsProcessor(config)