sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/dots_vlm_vit.py
@@ -0,0 +1,337 @@
+import logging
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint
+from torch.nn import LayerNorm
+from transformers.modeling_utils import PreTrainedModel
+
+from sglang.srt.configs.dots_vlm import DotsVisionConfig
+from sglang.srt.distributed import parallel_state
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+
+class VisionRotaryEmbedding(nn.Module):
+    def __init__(self, dim: int, theta: float = 10000.0) -> None:
+        super().__init__()
+        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+    def forward(self, seqlen: int) -> torch.Tensor:
+        seq = torch.arange(
+            seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype
+        )
+        freqs = torch.outer(seq, self.inv_freq)
+        return freqs
+
+
+class PatchMerger(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        context_dim: int,
+        spatial_merge_size: int = 2,
+        pre_norm="layernorm",
+        init_merger_std=None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = context_dim * (spatial_merge_size**2)
+        self.pre_norm = pre_norm
+        if self.pre_norm == "layernorm":
+            self.ln_q = LayerNorm(context_dim, eps=1e-6)
+        elif self.pre_norm == "rmsnorm":
+            self.ln_q = RMSNorm(context_dim, eps=1e-6)
+        else:
+            logger.warning(f"no norm in patch merger: {self.pre_norm}")
+
+        self.mlp = nn.Sequential(
+            nn.Linear(self.hidden_size, self.hidden_size),
+            nn.GELU(),
+            nn.Linear(self.hidden_size, dim),
+        )
+
+        if init_merger_std is not None:
+            nn.init.normal_(self.mlp[0].weight, mean=0.0, std=init_merger_std)
+            nn.init.zeros_(self.mlp[0].bias)
+            nn.init.normal_(self.mlp[2].weight, mean=0.0, std=init_merger_std)
+            nn.init.zeros_(self.mlp[2].bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.pre_norm:
+            x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
+        else:
+            x = self.mlp(x.view(-1, self.hidden_size))
+        return x
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(dim))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+    def extra_repr(self) -> str:
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+    def _norm(self, x: torch.Tensor) -> torch.Tensor:
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+
+class DotsSwiGLUFFN(nn.Module):
+    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        hidden_features = config.intermediate_size
+        in_features = config.embed_dim
+        bias = config.use_bias
+
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.fc2 = nn.Linear(hidden_features, in_features, bias=bias)
+        self.fc3 = nn.Linear(in_features, hidden_features, bias=bias)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.silu(self.fc1(x)) * self.fc3(x)
+        x = self.fc2(x)
+        return x
+
+
+class DotsPatchEmbed(nn.Module):
+    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.num_channels = config.num_channels
+        self.patch_size = config.patch_size
+        self.temporal_patch_size = config.temporal_patch_size
+        self.embed_dim = config.embed_dim
+        self.config = config
+        self.proj = nn.Conv2d(
+            config.num_channels,
+            config.embed_dim,
+            kernel_size=(config.patch_size, config.patch_size),
+            stride=(config.patch_size, config.patch_size),
+        )
+        self.norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
+        x = x.view(
+            -1,
+            self.num_channels,
+            self.temporal_patch_size,
+            self.patch_size,
+            self.patch_size,
+        )[:, :, 0]
+        x = self.proj(x).view(-1, self.embed_dim)
+        x = self.norm(x)
+        return x
+
+
+class DotsViTPreprocessor(nn.Module):
+    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.patch_h = config.patch_size
+        self.patch_w = config.patch_size
+        self.embed_dim = config.embed_dim
+        self.config = config
+        self.patchifier = DotsPatchEmbed(config, quant_config)
+
+    def forward(self, x: torch.Tensor, grid_thw=None) -> torch.Tensor:
+        tokens = self.patchifier(x, grid_thw)
+        return tokens
+
+
+class DotsVisionBlock(nn.Module):
+    def __init__(
+        self,
+        config: DotsVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        attn_implementation: str = "flash_attention_2",
+    ):
+        super().__init__()
+        if attn_implementation == "flash_attention_2":
+            qkv_backend = "fa3"
+            softmax_in_single_precision = False
+        else:
+            raise RuntimeError("Unimplemented")
+        self.attn = VisionAttention(
+            embed_dim=config.embed_dim,
+            num_heads=config.num_attention_heads,
+            projection_size=config.embed_dim,
+            use_qkv_parallel=True,
+            qkv_backend=qkv_backend,
+            softmax_in_single_precision=softmax_in_single_precision,
+            flatten_batch=True,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+            num_dummy_heads=config.num_dummy_heads,
+            qkv_bias=config.use_bias,
+            proj_bias=config.use_bias,
+        )
+        self.norm1 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+        self.mlp = DotsSwiGLUFFN(config, quant_config)
+        self.norm2 = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
+        hidden_states = hidden_states + self.attn(
+            self.norm1(hidden_states),
+            cu_seqlens=cu_seqlens,
+            position_embeddings=rotary_pos_emb,
+        )
+        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
+        return hidden_states
+
+
+class DotsVisionTransformer(PreTrainedModel):
+    def __init__(
+        self,
+        config: DotsVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__(config)
+        self.config = config
+        self._update_vision_config()
+        self.spatial_merge_size = config.spatial_merge_size
+
+        self.patch_embed = DotsViTPreprocessor(config, quant_config)
+        self._init_weights(self.patch_embed.patchifier.proj)
+
+        head_dim = config.embed_dim // config.num_attention_heads
+
+        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
+
+        _num_hidden_layers = config.num_hidden_layers
+        self.blocks = nn.ModuleList(
+            [
+                DotsVisionBlock(
+                    config, quant_config, f"blocks.{i}", config.attn_implementation
+                )
+                for i in range(_num_hidden_layers)
+            ]
+        )
+
+        if self.config.post_norm:
+            self.post_trunk_norm = RMSNorm(config.embed_dim, eps=config.rms_norm_eps)
+
+        self.merger = PatchMerger(
+            dim=config.hidden_size,
+            context_dim=config.embed_dim,
+            spatial_merge_size=config.spatial_merge_size,
+            init_merger_std=self.config.init_merger_std,
+            quant_config=quant_config,
+        )
+
+        self.gradient_checkpointing = False
+
+    def _update_vision_config(self):
+        """update vision config to support tp"""
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        num_heads = self.config.num_attention_heads
+        head_dim = self.config.embed_dim // num_heads
+        num_dummy_heads = 0
+
+        if num_heads % world_size != 0:
+            num_dummy_heads = (
+                (num_heads + world_size) // world_size
+            ) * world_size - num_heads
+
+        setattr(self.config, "head_dim", head_dim)
+        setattr(self.config, "num_dummy_heads", num_dummy_heads)
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+    @property
+    def dtype(self) -> torch.dtype:
+        return self.blocks[0].mlp.fc2.weight.dtype
+
+    @property
+    def device(self) -> torch.device:
+        return self.blocks[0].mlp.fc2.weight.device
+
+    def get_pos_ids_by_grid(self, grid_thw):
+        pos_ids = []
+        for t, h, w in grid_thw:
+            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
+            hpos_ids = hpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
+            hpos_ids = hpos_ids.flatten()
+
+            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
+            wpos_ids = wpos_ids.reshape(
+                h // self.spatial_merge_size,
+                self.spatial_merge_size,
+                w // self.spatial_merge_size,
+                self.spatial_merge_size,
+            )
+            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
+            wpos_ids = wpos_ids.flatten()
+            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
+
+        return pos_ids
+
+    def rot_pos_emb(self, grid_thw):
+        pos_ids = self.get_pos_ids_by_grid(grid_thw)
+        pos_ids = torch.cat(pos_ids, dim=0)
+        max_grid_size = grid_thw[:, 1:].max()
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
+        return rotary_pos_emb
+
+    def calc_cos_sin(self, rotary_pos_emb):
+        cos = rotary_pos_emb.cos()
+        sin = rotary_pos_emb.sin()
+        cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+        sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
+        rotary_pos_emb = (cos, sin)
+        return rotary_pos_emb
+
+    def forward(
+        self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, bf16=True
+    ) -> torch.Tensor:
+        if bf16:
+            hidden_states = hidden_states.bfloat16()
+        hidden_states = self.patch_embed(hidden_states, grid_thw)
+
+        rotary_pos_emb = self.rot_pos_emb(grid_thw)
+        rotary_pos_emb = self.calc_cos_sin(rotary_pos_emb)
+
+        cu_seqlens = torch.repeat_interleave(
+            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
+        ).cumsum(
+            dim=0,
+            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
+        )
+        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
+
+        for blk in self.blocks:
+            hidden_states = blk(
+                hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+            )
+
+        if self.config.post_norm:
+            hidden_states = self.post_trunk_norm(hidden_states)
+
+        hidden_states = self.merger(hidden_states)
+        return hidden_states
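
For orientation, a minimal sketch in plain PyTorch (using a made-up grid_thw, not values from the diff) of how the forward pass above turns per-image (t, h, w) grids into the cu_seqlens boundaries that each DotsVisionBlock receives for variable-length attention:

import torch
import torch.nn.functional as F

# Hypothetical batch: one 8x12 image and one 4x4 image, one temporal frame each.
grid_thw = torch.tensor([[1, 8, 12], [1, 4, 4]])

# Per-image token counts (h * w, repeated t times) become cumulative boundaries.
seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), value=0)
print(cu_seqlens.tolist())  # [0, 96, 112]

# The rotary table only needs entries up to the largest spatial extent in the batch.
max_grid_size = int(grid_thw[:, 1:].max())
print(max_grid_size)  # 12
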
sglang/srt/models/ernie4.py
@@ -92,7 +92,7 @@ class Ernie4Moe(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=config.moe_num_experts,
             top_k=config.moe_k,
             hidden_size=config.hidden_size,
sglang/srt/models/gemma3n_mm.py
@@ -499,7 +499,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def should_apply_lora(self, module_name: str) -> bool:
         return bool(self.lora_pattern.match(module_name))
 
-    def get_hidden_dim(self, module_name):
+    def get_hidden_dim(self, module_name, layer_idx):
         # return input_dim, output_dim
         if module_name == "qkv_proj":
             return (
sglang/srt/models/glm4_moe.py
@@ -429,7 +429,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             routed_scaling_factor=self.routed_scaling_factor,
         )
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
            num_experts=config.n_routed_experts
            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
sglang/srt/models/glm4v.py
@@ -93,9 +93,8 @@ class Glm4vVisionBlock(Qwen2_5_VisionBlock):
             quant_config=quant_config,
             prefix=prefix,
             num_dummy_heads=config.num_dummy_heads,
+            rms_norm_eps=config.rms_norm_eps,
         )
-        self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         self.mlp = Glm4vVisionMLP(
             config.hidden_size,
@@ -498,6 +497,9 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
         self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         pixel_values = torch.cat(
             [item.feature.squeeze(0) for item in items], dim=0
sglang/srt/models/glm4v_moe.py
@@ -74,6 +74,9 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
         self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
 
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
     def determine_num_fused_shared_experts(
         self, architecture: str = "Glm4MoeForCausalLM"
     ):
sglang/srt/models/gpt_oss.py
@@ -121,7 +121,7 @@ class GptOssSparseMoeBlock(nn.Module):
         )
 
         self.top_k = config.num_experts_per_tok
-        experts_type = get_moe_impl_class()
+        experts_type = get_moe_impl_class(quant_config)
         extra_kwargs = {}
         if experts_type.__name__ == "FusedMoE":
             quant_config_name = (
sglang/srt/models/llama4.py
@@ -423,6 +423,12 @@ class Llama4DecoderLayer(nn.Module):
             return self.config.num_local_experts > 0
         return (layer_id + 1) % self.config.interleave_moe_layer_step == 0
 
+    def get_intermediate_size(self) -> int:
+        if isinstance(self.feed_forward, Llama4MoE):
+            return self.config.intermediate_size
+        else:
+            return self.config.intermediate_size_mlp
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -540,6 +546,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
+    def get_layers(self):
+        return self.model.layers
+
     def _init_model(
         self,
         config: Llama4TextConfig,
sglang/srt/models/llama_eagle3.py
@@ -109,6 +109,16 @@ class LlamaModel(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
+
+        self.is_mrope_enabled = (
+            hasattr(config, "rope_scaling")
+            and config.rope_scaling is not None
+            and "mrope_section" in config.rope_scaling
+        )
+        # fix rope_scaling for qwen2.5-vl
+        if self.is_mrope_enabled:
+            config.rope_scaling["rope_type"] = "default"
+
         self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
@@ -144,6 +154,9 @@ class LlamaModel(nn.Module):
         else:
             embeds = input_embeds
 
+        if self.is_mrope_enabled:
+            positions = forward_batch.mrope_positions
+
         hidden_states = forward_batch.spec_info.hidden_states
         if hidden_states.shape[-1] != embeds.shape[-1]:
             hidden_states = self.fc(hidden_states)
sglang/srt/models/longcat_flash.py
@@ -260,7 +260,7 @@ class LongcatFlashMoE(nn.Module):
         )
         self.topk.forward = self.topk.forward_native
 
-        self.experts = get_moe_impl_class()(
+        self.experts = get_moe_impl_class(quant_config)(
             num_experts=self.num_experts,
             top_k=self.top_k,
             layer_id=self.layer_id,
@@ -853,7 +853,7 @@ class LongcatFlashForCausalLM(nn.Module):
 
         # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
-        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
sglang/srt/models/mllama4.py
@@ -961,5 +961,30 @@ class Llama4ForConditionalGeneration(nn.Module):
     def set_embed(self, embed):
         return self.language_model.set_embed(embed)
 
+    def get_hidden_dim(self, module_name, layer_idx):
+        # return input_dim, output_dim
+        if module_name == "qkv_proj":
+            return (
+                self.config.hidden_size,
+                self.config.head_dim
+                * (
+                    self.config.num_attention_heads
+                    + self.config.num_key_value_heads * 2
+                ),
+            )
+        elif module_name == "o_proj":
+            return (
+                self.config.head_dim * self.config.num_attention_heads,
+                self.config.hidden_size,
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size * 2
+        elif module_name == "down_proj":
+            decoder_layer = self.language_model.get_layers()[layer_idx]
+            intermediate_size = decoder_layer.get_intermediate_size()
+            return intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
 
 EntryClass = Llama4ForConditionalGeneration
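
As a worked illustration of the (input_dim, output_dim) contract that the comment in get_hidden_dim above describes, here is the shape arithmetic spelled out with invented config values (these numbers are purely illustrative, not Llama 4's actual configuration):

# Invented config values, purely to illustrate the shape arithmetic.
hidden_size = 4096
head_dim = 128
num_attention_heads = 32
num_key_value_heads = 8

# qkv_proj: input is the hidden size; output stacks Q plus two KV groups.
qkv_dims = (hidden_size, head_dim * (num_attention_heads + num_key_value_heads * 2))
print(qkv_dims)  # (4096, 6144)

# o_proj: the attention output projection maps back to the hidden size.
o_dims = (head_dim * num_attention_heads, hidden_size)
print(o_dims)  # (4096, 4096)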