sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama.py

@@ -17,7 +17,7 @@
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -129,6 +129,8 @@ class LlamaAttention(nn.Module):
         self.head_dim = getattr(
             config, "head_dim", self.hidden_size // self.total_num_heads
         )
+        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
@@ -154,7 +156,7 @@ class LlamaAttention(nn.Module):
 
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.head_dim,
+            rotary_dim=self.rotary_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
@@ -285,6 +287,7 @@ class LlamaModel(nn.Module):
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.layers_to_capture = []
 
     def forward(
         self,
@@ -292,13 +295,16 @@
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
             hidden_states = input_embeds
         residual = None
+        aux_hidden_states = []
         for i in range(len(self.layers)):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(hidden_states + residual)
             layer = self.layers[i]
             hidden_states, residual = layer(
                 positions,
@@ -307,7 +313,11 @@ class LlamaModel(nn.Module):
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
-        return hidden_states
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
@@ -335,7 +345,6 @@ class LlamaModel(nn.Module):
 
 
 class LlamaForCausalLM(nn.Module):
-
     # BitandBytes specific attributes
     default_bitsandbytes_target_modules = [
         ".gate_proj.",
@@ -391,6 +400,8 @@ class LlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".up_proj", 1),
         ]
 
+        self.capture_aux_hidden_states = False
+
     @torch.no_grad()
     def forward(
         self,
@@ -400,10 +411,19 @@ class LlamaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
     ) -> LogitsProcessorOutput:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = self.model(
+                input_ids, positions, forward_batch, input_embeds
+            )
+        else:
+            hidden_states = self.model(
+                input_ids, positions, forward_batch, input_embeds
+            )
+
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -586,9 +606,29 @@ class LlamaForCausalLM(nn.Module):
         torch.cuda.empty_cache()
         torch.cuda.synchronize()
 
+    def get_embed(self):
+        return self.model.embed_tokens.weight
+
+    def set_embed(self, embed):
+        # NOTE: If draft hidden size != target hidden size, the embed weight cannot be shared for EAGLE3
+        if (
+            hasattr(self.config, "target_hidden_size")
+            and self.config.target_hidden_size != self.config.hidden_size
+        ):
+            return
+        del self.model.embed_tokens.weight
+        self.model.embed_tokens.weight = embed
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
 
+    def set_eagle3_layers_to_capture(self):
+        self.capture_aux_hidden_states = True
+        num_layers = self.config.num_hidden_layers
+        self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+
 
 class Phi3ForCausalLM(LlamaForCausalLM):
     pass
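
Two of the llama.py changes above are worth a note. The attention module now derives rotary_dim from an optional partial_rotary_factor in the config, so models that apply rotary embeddings to only part of each head get the correct RoPE width, and the new set_eagle3_layers_to_capture hook marks an early, a middle, and a late decoder layer whose residual-stream inputs are returned as auxiliary hidden states for the EAGLE3 draft model. A minimal sketch of both rules as standalone helpers (the helper names are illustrative only and are not part of the sglang API):

# Hedged sketch: standalone helpers mirroring the logic added in llama.py above.
# rotary_dim_for and eagle3_capture_layers are hypothetical names for illustration.

def rotary_dim_for(head_dim: int, partial_rotary_factor: float = 1.0) -> int:
    # rotary_dim = int(partial_rotary_factor * head_dim); the default 1.0 keeps full-head RoPE
    return int(partial_rotary_factor * head_dim)

def eagle3_capture_layers(num_hidden_layers: int) -> list:
    # one early, one middle, and one late layer, matching set_eagle3_layers_to_capture
    return [2, num_hidden_layers // 2, num_hidden_layers - 3]

print(rotary_dim_for(128, 0.5))    # 64: only half of each head dimension is rotated
print(eagle3_capture_layers(32))   # [2, 16, 29] for a 32-layer Llama config

During the forward pass, the residual-stream input of each captured layer is appended to aux_hidden_states and returned alongside the final hidden states, as shown in the LlamaModel.forward hunk above.
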
sglang/srt/models/llama_eagle.py

@@ -134,6 +134,7 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
             )
 
         self.logits_processor = LogitsProcessor(config)
+        self.capture_aux_hidden_states = False
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
sglang/srt/models/llama_eagle3.py (new file)

@@ -0,0 +1,196 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from sglang.srt.utils import add_prefix
+
+# Adapted from
+# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
+"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
+
+
+class LlamaDecoderLayer(LlamaDecoderLayer):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, layer_id, quant_config, prefix)
+
+        # override qkv
+        self.self_attn.qkv_proj = QKVParallelLinear(
+            2 * self.hidden_size,
+            self.self_attn.head_dim,
+            self.self_attn.total_num_heads,
+            self.self_attn.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+
+        self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        embeds: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        residual = hidden_states
+        embeds = self.input_layernorm(embeds)
+        hidden_states = self.hidden_norm(hidden_states)
+
+        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        # Self Attention
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+
+        # Fully Connected
+        hidden_states = self.mlp(hidden_states)
+
+        return hidden_states, residual
+
+
+class LlamaModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
+        if hasattr(config, "target_hidden_size"):
+            self.fc = torch.nn.Linear(config.target_hidden_size * 3, config.hidden_size)
+        else:
+            self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size)
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            embeds = self.embed_tokens(input_ids)
+        else:
+            embeds = input_embeds
+
+        hidden_states = forward_batch.spec_info.hidden_states
+        if hidden_states.shape[-1] != embeds.shape[-1]:
+            hidden_states = self.fc(hidden_states)
+
+        residual = None
+        hidden_states, residual = self.midlayer(
+            positions,
+            embeds,
+            hidden_states,
+            forward_batch,
+            residual,
+        )
+
+        hidden_states_to_logits, hidden_states_to_aux = self.norm(
+            hidden_states, residual
+        )
+
+        # For draft decode, we capture the hidden state before norm
+        return hidden_states_to_logits, [hidden_states_to_aux]
+
+
+class LlamaForCausalLMEagle3(LlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.quant_config = quant_config
+
+        if self.config.num_hidden_layers != 1:
+            raise ValueError("EAGLE3 currently only supports 1 layer")
+
+        self.model = LlamaModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        # Llama 3.2 1B Instruct set tie_word_embeddings to True
+        # Llama 3.1 8B Instruct set tie_word_embeddings to False
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.draft_vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+
+        self.logits_processor = LogitsProcessor(config)
+        self.capture_aux_hidden_states = True
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        for name, loaded_weight in weights:
+            if "d2t" in name:
+                # d2t stores diffs between draft id and target id
+                self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0])
+
+            if "d2t" not in name and "t2d" not in name and "lm_head" not in name:
+                new_name = f"model.{name}"
+                super().load_weights([(new_name, loaded_weight)])
+            elif "lm_head" in name:
+                super().load_weights([(name, loaded_weight)])
+
+    def get_hot_token_id(self):
+        return self.hot_token_id
+
+
+EntryClass = [LlamaForCausalLMEagle3]
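
On the d2t weight handled in load_weights above: it stores, for each draft-vocabulary index, the offset to the corresponding id in the target model's vocabulary, and hot_token_id is recovered by adding the index itself. A toy illustration with made-up offsets (not real checkpoint data):

# Toy sketch of the d2t -> hot_token_id mapping used by LlamaForCausalLMEagle3.load_weights.
# The offset values below are invented purely for illustration.
import torch

d2t = torch.tensor([5, 5, 9, 12])                # offset from draft id i to its target-vocab id
hot_token_id = d2t + torch.arange(d2t.shape[0])  # same formula as in load_weights above
print(hot_token_id)                              # tensor([ 5,  6, 11, 15])

# The draft head produces logits over config.draft_vocab_size entries; hot_token_id maps each
# draft logit position back to a token id in the target model's full vocabulary.
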
sglang/srt/models/llava.py

@@ -31,7 +31,7 @@ from transformers import (
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
@@ -46,7 +46,7 @@ from sglang.srt.utils import add_prefix
 
 
 class LlavaBaseForCausalLM(nn.Module):
-    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
 
         # hardcode for spatial_unpad + anyres
@@ -134,7 +134,7 @@ class LlavaBaseForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        image_inputs = forward_batch.image_inputs
+        image_inputs = forward_batch.mm_inputs
 
         if forward_batch.forward_mode.is_extend():
             # Clamp input ids. This is because the input_ids for the image tokens are
sglang/srt/models/llavavid.py

@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -57,7 +57,7 @@ class LlavaVidForCausalLM(nn.Module):
             torch.empty(config.text_config.hidden_size, dtype=torch.float16)
         )
 
-    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         pad_values = image_inputs.pad_values
         new_image_feature_len = self.image_feature_len
 
@@ -112,7 +112,7 @@ class LlavaVidForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        image_inputs = forward_batch.image_inputs
+        image_inputs = forward_batch.mm_inputs
         if forward_batch.forward_mode.is_extend():
             bs = forward_batch.batch_size