sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +1 -0
  3. sglang/bench_serving.py +9 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/aio_rwlock.py +100 -0
  9. sglang/srt/configs/model_config.py +8 -1
  10. sglang/srt/constrained/xgrammar_backend.py +4 -1
  11. sglang/srt/layers/attention/flashinfer_backend.py +51 -5
  12. sglang/srt/layers/attention/triton_backend.py +16 -25
  13. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  14. sglang/srt/layers/linear.py +20 -2
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
  17. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  18. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  19. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
  20. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
  21. sglang/srt/layers/moe/topk.py +191 -0
  22. sglang/srt/layers/quantization/__init__.py +5 -50
  23. sglang/srt/layers/quantization/fp8.py +221 -36
  24. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  25. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  26. sglang/srt/layers/radix_attention.py +8 -1
  27. sglang/srt/layers/sampler.py +27 -5
  28. sglang/srt/layers/torchao_utils.py +31 -0
  29. sglang/srt/managers/detokenizer_manager.py +37 -17
  30. sglang/srt/managers/io_struct.py +39 -10
  31. sglang/srt/managers/schedule_batch.py +54 -34
  32. sglang/srt/managers/schedule_policy.py +64 -5
  33. sglang/srt/managers/scheduler.py +171 -136
  34. sglang/srt/managers/tokenizer_manager.py +184 -133
  35. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  36. sglang/srt/mem_cache/chunk_cache.py +2 -2
  37. sglang/srt/mem_cache/memory_pool.py +15 -8
  38. sglang/srt/mem_cache/radix_cache.py +12 -2
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -11
  40. sglang/srt/model_executor/model_runner.py +28 -14
  41. sglang/srt/model_parallel.py +66 -5
  42. sglang/srt/models/dbrx.py +1 -1
  43. sglang/srt/models/deepseek.py +1 -1
  44. sglang/srt/models/deepseek_v2.py +67 -18
  45. sglang/srt/models/gemma2.py +34 -0
  46. sglang/srt/models/gemma2_reward.py +0 -1
  47. sglang/srt/models/granite.py +517 -0
  48. sglang/srt/models/grok.py +73 -9
  49. sglang/srt/models/llama.py +22 -0
  50. sglang/srt/models/llama_classification.py +11 -23
  51. sglang/srt/models/llama_reward.py +0 -2
  52. sglang/srt/models/llava.py +37 -14
  53. sglang/srt/models/mixtral.py +2 -2
  54. sglang/srt/models/olmoe.py +1 -1
  55. sglang/srt/models/qwen2.py +20 -0
  56. sglang/srt/models/qwen2_moe.py +1 -1
  57. sglang/srt/models/xverse_moe.py +1 -1
  58. sglang/srt/openai_api/adapter.py +8 -0
  59. sglang/srt/openai_api/protocol.py +9 -4
  60. sglang/srt/server.py +2 -1
  61. sglang/srt/server_args.py +19 -9
  62. sglang/srt/utils.py +40 -54
  63. sglang/test/test_block_fp8.py +341 -0
  64. sglang/test/test_utils.py +3 -2
  65. sglang/utils.py +10 -3
  66. sglang/version.py +1 -1
  67. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
  68. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
  69. sglang/srt/layers/fused_moe_patch.py +0 -133
  70. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  71. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  72. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  73. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -20,15 +20,17 @@ from contextlib import contextmanager
  from typing import TYPE_CHECKING, Callable

  import torch
+ import tqdm
+ from vllm.distributed import get_tensor_model_parallel_rank
  from vllm.distributed.parallel_state import graph_capture
  from vllm.model_executor.custom_op import CustomOp

- from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
  from sglang.srt.layers.logits_processor import (
      LogitsMetadata,
      LogitsProcessor,
      LogitsProcessorOutput,
  )
+ from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
  from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather

@@ -127,7 +129,7 @@ class CudaGraphRunner:

          # Batch sizes to capture
          if model_runner.server_args.disable_cuda_graph_padding:
-             self.capture_bs = list(range(1, 32)) + [64, 128]
+             self.capture_bs = list(range(1, 33)) + [64, 128]
          else:
              self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]

@@ -255,7 +257,12 @@ class CudaGraphRunner:
      def capture(self):
          with graph_capture() as graph_capture_context:
              self.stream = graph_capture_context.stream
-             for bs in self.capture_bs:
+             capture_bs = (
+                 tqdm.tqdm(self.capture_bs)
+                 if get_tensor_model_parallel_rank() == 0
+                 else self.capture_bs
+             )
+             for bs in capture_bs:
                  with patch_model(
                      self.model_runner.model,
                      bs in self.compile_bs,
@@ -387,8 +394,14 @@

          # Extract logprobs
          if forward_batch.return_logprob:
-             next_token_logprobs = torch.nn.functional.log_softmax(
-                 next_token_logits, dim=-1
+             logits_metadata = LogitsMetadata(
+                 forward_mode=ForwardMode.DECODE,
+                 top_logprobs_nums=forward_batch.top_logprobs_nums,
+             )
+             next_token_logprobs = (
+                 LogitsProcessor.compute_temp_top_p_normalized_logprobs(
+                     next_token_logits, logits_metadata
+                 )
              )
              logits_output = LogitsProcessorOutput(
                  next_token_logits=next_token_logits,
@@ -396,13 +409,14 @@
              )
              return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
              if return_top_logprob:
-                 logits_metadata = LogitsMetadata(
-                     forward_mode=ForwardMode.DECODE,
-                     top_logprobs_nums=forward_batch.top_logprobs_nums,
-                 )
-                 logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                 (
+                     logits_output.output_top_logprobs_val,
+                     logits_output.output_top_logprobs_idx,
+                 ) = LogitsProcessor.get_top_logprobs(
                      next_token_logprobs, logits_metadata
-                 )[1]
+                 )[
+                     2:4
+                 ]
          else:
              logits_output = LogitsProcessorOutput(
                  next_token_logits=next_token_logits,
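The capture() hunk above wraps the batch-size loop in a tqdm progress bar only on tensor-parallel rank 0. A minimal standalone sketch of that pattern, assuming an initialized vllm distributed environment; the helper name is illustrative and not part of sglang:

import tqdm
from vllm.distributed import get_tensor_model_parallel_rank

def run_with_rank0_progress(items):
    # Show a tqdm progress bar only on tensor-parallel rank 0 so that a
    # multi-GPU launch prints one bar instead of one bar per rank.
    iterator = tqdm.tqdm(items) if get_tensor_model_parallel_rank() == 0 else items
    for item in iterator:
        yield item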
sglang/srt/model_executor/model_runner.py CHANGED
@@ -95,6 +95,12 @@ class ModelRunner:
          ):
              logger.info("MLA optimization is turned on. Use triton backend.")
              self.server_args.attention_backend = "triton"
+         # FIXME(HandH1998)
+         if (
+             "DeepseekV3ForCausalLM" in self.model_config.hf_config.architectures
+             and not self.server_args.disable_cuda_graph
+         ):
+             self.server_args.disable_cuda_graph = True

          if self.server_args.enable_double_sparsity:
              logger.info(
@@ -111,17 +117,20 @@
          )

          if self.is_multimodal:
-             server_args.chunked_prefill_size = -1
              self.mem_fraction_static *= 0.95
-             logger.info(
-                 f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static} "
-                 f"and turn off chunked prefill "
-                 f"because this is a multimodal model."
-             )
+             if self.model_config.hf_config.architectures == [
+                 "MllamaForConditionalGeneration"
+             ]:
+                 logger.info("Automatically turn off --chunked-prefill-size for mllama.")
+                 server_args.chunked_prefill_size = -1
              # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
              if self.model_config.hf_config.architectures == [
                  "Qwen2VLForConditionalGeneration"
              ]:
+                 logger.info(
+                     "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+                 )
+                 server_args.chunked_prefill_size = -1
                  server_args.disable_radix_cache = True

          # Global vars
@@ -154,6 +163,11 @@ class ModelRunner:
          self.sampler = Sampler()
          self.load_model()

+         # Apply torchao quantization
+         apply_torchao_config_to_model(
+             self.model, global_server_args_dict["torchao_config"]
+         )
+
          # Apply torch TP if the model supports it
          supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
          if self.tp_size > 1 and supports_torch_tp:
@@ -162,10 +176,6 @@
          else:
              self.torch_tp_applied = False

-         apply_torchao_config_to_model(
-             self.model, global_server_args_dict["torchao_config"]
-         )
-
          # Init memory pool and attention backends
          if server_args.lora_paths is not None:
              self.init_lora_manager()
@@ -242,20 +252,22 @@
          if torch.cuda.get_device_capability()[1] < 5:
              raise RuntimeError("SGLang only supports sm75 and above.")

-         # Prepare the vllm model config
+         # Prepare the model config
          self.load_config = LoadConfig(
              load_format=self.server_args.load_format,
              download_dir=self.server_args.download_dir,
          )
-
          if self.server_args.load_format == "gguf":
              monkey_patch_vllm_gguf_config()
+
+         # Load the model
          self.model = get_model(
              model_config=self.model_config,
              load_config=self.load_config,
              device_config=DeviceConfig(self.device),
          )

+         # Parse other args
          self.sliding_window_size = (
              self.model.get_attention_sliding_window_size()
              if hasattr(self.model, "get_attention_sliding_window_size")
@@ -270,8 +282,10 @@
              f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
          )

-     def update_weights_from_disk(self, model_path: str, load_format: str):
-         """Update engine weights online from disk."""
+     def update_weights_from_disk(
+         self, model_path: str, load_format: str
+     ) -> tuple[bool, str]:
+         """Update engine weights in-place from the disk."""
          from sglang.srt.model_loader.loader import (
              DefaultModelLoader,
              device_loading_context,
sglang/srt/model_parallel.py CHANGED
@@ -2,18 +2,18 @@
  Common utilities for torch model parallelism.
  """

- from typing import Optional
+ from typing import Optional, Sequence

  import torch
+ import torch.nn as nn
  from torch.distributed.device_mesh import DeviceMesh

  try:
-     from torch.distributed.tensor import DTensor, Shard
+     import torch.distributed.tensor as dt
  except ImportError:
      # torch 2.4 or older
-     from torch.distributed._tensor import DTensor, Shard
+     import torch.distributed._tensor as dt

- from torch.distributed._functional_collectives import AsyncCollectiveTensor
  from torch.distributed.tensor.parallel import (
      ColwiseParallel,
      RowwiseParallel,
@@ -21,6 +21,50 @@ from torch.distributed.tensor.parallel import (
  )


+ def _shard_tensor(
+     full_tensor: torch.Tensor,
+     device_mesh: DeviceMesh,
+     placements: Sequence[dt.Shard],
+ ) -> "dt.DTensor":
+     """
+     Locally shards a full tensor based on indicated sharding arrangement, and
+     returns a DTensor containing the local shard.
+
+     .. warning:: This is a private API that is subject to change. It skips the
+         communication otherwise required by `distribute_tensor`. It is only
+         applicable to cases where all ranks have the same `full_tensor`. For
+         example, in distributed inference all ranks load from the same
+         checkpoint. This API will not check for data equality between ranks, it
+         is thus user's responsibility to ensure the `full_tensor` is the same
+         across ranks.
+
+     Args:
+         full_tensor (torch.Tensor): the full tensor to be sharded.
+         device_mesh (:class:`DeviceMesh`): DeviceMesh to place the
+             DTensor. Must have same dimension as the number of placements.
+         placements (Sequence[:class:`Shard`]): the placements that
+             describes how to place the local tensor on DeviceMesh.
+
+     Returns:
+         A :class:`DTensor` object with the shard as its local tensor.
+
+     Examples:
+         >>> # xdoctest: +SKIP("need world_size and rank")
+         >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
+         >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
+         >>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(1)])
+     """
+     shape, offset = dt._utils.compute_local_shape_and_global_offset(
+         full_tensor.shape, device_mesh, placements
+     )
+     slices = [
+         slice(cur_offset, cur_offset + cur_shape)
+         for cur_shape, cur_offset in zip(shape, offset)
+     ]
+     local_tensor = full_tensor[slices]
+     return dt.DTensor.from_local(local_tensor, device_mesh, placements)
+
+
  class ColwiseParallelSharded(ColwiseParallel):
      """
      A version of ColwiseParallel where the local weight has been already
@@ -34,7 +78,7 @@ class ColwiseParallelSharded(ColwiseParallel):
          # means Colwise as Linear is input * weight^T + bias, where
          # weight would become Shard(1)
          for name, param in module.named_parameters():
-             dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+             dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)])
              dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
              module.register_parameter(name, dist_param)

@@ -47,6 +91,23 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
      AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
      """

+     def _partition_linear_fn(self, name, module, device_mesh):
+         # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
+         # means Rowwise as nn.Linear is input * weight^T + bias, where
+         # weight would become Shard(0)
+         module.register_parameter(
+             "weight",
+             nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])),
+         )
+         if getattr(module, "bias", None) is not None:
+             # The Linear module has bias
+             module.register_parameter(
+                 "bias",
+                 nn.Parameter(
+                     dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()])
+                 ),
+             )
+
      @staticmethod
      def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
          outputs = super(
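For reference, a small single-process illustration of what the Shard(1) placement on a Linear weight amounts to; plain tensor slicing with made-up sizes, no DTensor or distributed setup involved:

import torch

# torch.nn.Linear(in_features=8, out_features=4) stores its weight as (4, 8).
weight = torch.randn(4, 8)

# Row-wise tensor parallelism splits the input dimension, i.e. dim 1 of the
# weight (Shard(1) above), so with tp_size=2 each rank keeps a (4, 4) slice.
tp_size, rank = 2, 0
cols = weight.shape[1] // tp_size
local_weight = weight[:, rank * cols : (rank + 1) * cols]
assert local_weight.shape == (4, 4)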
sglang/srt/models/dbrx.py CHANGED
@@ -27,13 +27,13 @@ from vllm.distributed (
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.transformers_utils.configs.dbrx import DbrxConfig

- from sglang.srt.layers.fused_moe_triton import fused_moe
  from sglang.srt.layers.linear import (
      QKVParallelLinear,
      ReplicatedLinear,
      RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.fused_moe_triton import fused_moe
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/deepseek.py CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed (
  from vllm.model_executor.layers.rotary_embedding import get_rope

  from sglang.srt.layers.activation import SiluAndMul
- from sglang.srt.layers.fused_moe_triton import fused_moe
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
      MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear (
      RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.fused_moe_triton import fused_moe
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/deepseek_v2.py CHANGED
@@ -19,6 +19,7 @@
  from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
+ import torch.nn.functional as F
  from torch import nn
  from transformers import PretrainedConfig
  from vllm import _custom_ops as ops
@@ -31,8 +32,6 @@ from vllm.distributed (
  from vllm.model_executor.layers.rotary_embedding import get_rope

  from sglang.srt.layers.activation import SiluAndMul
- from sglang.srt.layers.ep_moe.layer import EPMoE
- from sglang.srt.layers.fused_moe_triton import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
      ColumnParallelLinear,
@@ -41,7 +40,13 @@ from sglang.srt.layers.linear (
      RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.quantization.fp8_utils import (
+     block_quant_to_tensor_quant,
+     input_to_float8,
+ )
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
@@ -90,6 +95,24 @@ class DeepseekV2MLP(nn.Module):
          return x


+ class MoEGate(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.weight = nn.Parameter(
+             torch.empty((config.n_routed_experts, config.hidden_size))
+         )
+         if config.topk_method == "noaux_tc":
+             self.e_score_correction_bias = nn.Parameter(
+                 torch.empty((config.n_routed_experts))
+             )
+         else:
+             self.e_score_correction_bias = None
+
+     def forward(self, hidden_states):
+         logits = F.linear(hidden_states, self.weight, None)
+         return logits
+
+
  class DeepseekV2MoE(nn.Module):

      def __init__(
@@ -114,6 +137,8 @@ class DeepseekV2MoE(nn.Module):
                  "Only silu is supported for now."
              )

+         self.gate = MoEGate(config=config)
+
          MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
          self.experts = MoEImpl(
              num_experts=config.n_routed_experts,
@@ -125,11 +150,9 @@ class DeepseekV2MoE(nn.Module):
              use_grouped_topk=True,
              num_expert_group=config.n_group,
              topk_group=config.topk_group,
+             correction_bias=self.gate.e_score_correction_bias,
          )

-         self.gate = ReplicatedLinear(
-             config.hidden_size, config.n_routed_experts, bias=False, quant_config=None
-         )
          if config.n_shared_experts is not None:
              intermediate_size = config.moe_intermediate_size * config.n_shared_experts
              self.shared_experts = DeepseekV2MLP(
@@ -146,7 +169,7 @@ class DeepseekV2MoE(nn.Module):
          if self.n_shared_experts is not None:
              shared_output = self.shared_experts(hidden_states)
          # router_logits: (num_tokens, n_experts)
-         router_logits, _ = self.gate(hidden_states)
+         router_logits = self.gate(hidden_states)
          final_hidden_states = (
              self.experts(hidden_states=hidden_states, router_logits=router_logits)
              * self.routed_scaling_factor
@@ -167,15 +190,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
      return 0.1 * mscale * math.log(scale) + 1.0


- def input_to_float8(x, dtype=torch.float8_e4m3fn):
-     finfo = torch.finfo(dtype)
-     min_val, max_val = x.aminmax()
-     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-     scale = finfo.max / amax
-     x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
-     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
-
-
  class DeepseekV2Attention(nn.Module):

      def __init__(
@@ -439,7 +453,10 @@ class DeepseekV2AttentionMLA(nn.Module):
              quant_config=quant_config,
          )
          self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-         rope_scaling["rope_type"] = "deepseek_yarn"
+
+         if rope_scaling:
+             rope_scaling["rope_type"] = "deepseek_yarn"
+
          self.rotary_emb = get_rope(
              qk_rope_head_dim,
              rotary_dim=qk_rope_head_dim,
@@ -454,6 +471,8 @@ class DeepseekV2AttentionMLA(nn.Module):
              scaling_factor = rope_scaling["factor"]
              mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
              self.scaling = self.scaling * mscale * mscale
+         else:
+             self.rotary_emb.forward = self.rotary_emb.forward_native

          self.attn_mqa = RadixAttention(
              self.num_local_heads,
@@ -845,6 +864,16 @@ class DeepseekV2ForCausalLM(nn.Module):

          params_dict = dict(self.named_parameters())
          for name, loaded_weight in weights:
+             # TODO(HandH1998): Modify it when nextn is supported.
+             if hasattr(self.config, "num_nextn_predict_layers"):
+                 num_nextn_layers = self.config.num_nextn_predict_layers
+                 if num_nextn_layers > 0 and name.startswith("model.layers"):
+                     name_list = name.split(".")
+                     if (
+                         len(name_list) >= 3
+                         and int(name_list[2]) >= self.config.num_hidden_layers
+                     ):
+                         continue
              if "rotary_emb.inv_freq" in name:
                  continue
              for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -909,13 +938,33 @@ class DeepseekV2ForCausalLM(nn.Module):
                      ).T
                  else:
                      w = self_attn.kv_b_proj.weight
+                 # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                 # This may affect the accuracy of fp8 model.
+                 if (
+                     hasattr(self.quant_config, "weight_block_size")
+                     and w.dtype == torch.float8_e4m3fn
+                 ):
+                     weight_block_size = self.quant_config.weight_block_size
+                     if weight_block_size is not None:
+                         assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                         w, scale = block_quant_to_tensor_quant(
+                             w, self_attn.kv_b_proj.weight_scale_inv, weight_block_size
+                         )
+                         self_attn.w_scale = scale
                  w_kc, w_vc = w.unflatten(
                      0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                  ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                  self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                  self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-                 if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                 if (
+                     hasattr(self_attn.kv_b_proj, "weight_scale")
+                     and self_attn.w_scale is None
+                 ):
                      self_attn.w_scale = self_attn.kv_b_proj.weight_scale


- EntryClass = DeepseekV2ForCausalLM
+ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
+     pass
+
+
+ EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
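The input_to_float8 helper shown as removed above now lives in sglang.srt.layers.quantization.fp8_utils (per the new import in the same file). A small round-trip sketch that reuses the function body exactly as it appears in the removed hunk, with made-up input data; requires a PyTorch build with float8 support:

import torch

def input_to_float8(x, dtype=torch.float8_e4m3fn):
    # Per-tensor fp8 quantization, same body as the helper removed above.
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()

x = torch.randn(4, 8)
x_fp8, inv_scale = input_to_float8(x)
# Dequantize by multiplying with the returned reciprocal scale.
x_approx = x_fp8.to(torch.float32) * inv_scale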
sglang/srt/models/gemma2.py CHANGED
@@ -355,6 +355,40 @@ class Gemma2ForCausalLM(nn.Module):
              input_ids, hidden_states, self.model.embed_tokens, forward_batch
          )

+     def get_hidden_dim(self, module_name):
+         # return input_dim, output_dim
+         if module_name in ["q_proj", "qkv_proj"]:
+             return (
+                 self.config.hidden_size,
+                 self.config.head_dim * self.config.num_attention_heads,
+             )
+         elif module_name in ["o_proj"]:
+             return (
+                 self.config.head_dim * self.config.num_attention_heads,
+                 self.config.hidden_size,
+             )
+         elif module_name in ["kv_proj"]:
+             return (
+                 self.config.hidden_size,
+                 self.config.head_dim * self.config.num_key_value_heads,
+             )
+         elif module_name == "gate_up_proj":
+             return self.config.hidden_size, self.config.intermediate_size
+         elif module_name == "down_proj":
+             return self.config.intermediate_size, self.config.hidden_size
+         else:
+             raise NotImplementedError()
+
+     def get_module_name(self, name):
+         params_mapping = {
+             "q_proj": "qkv_proj",
+             "k_proj": "qkv_proj",
+             "v_proj": "qkv_proj",
+             "gate_proj": "gate_up_proj",
+             "up_proj": "gate_up_proj",
+         }
+         return params_mapping.get(name, name)
+
      def get_attention_sliding_window_size(self):
          return get_attention_sliding_window_size(self.config)

sglang/srt/models/gemma2_reward.py CHANGED
@@ -32,7 +32,6 @@ class Gemma2ForSequenceClassification(nn.Module):
      ) -> None:
          super().__init__()
          self.config = config
-         self.torchao_config = None
          self.quant_config = quant_config
          self.num_labels = config.num_labels
          self.model = Gemma2Model(config, quant_config=quant_config)