sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/models/mixtral.py
@@ -1,141 +1,269 @@
 # Adapted from
-# https://github.com/vllm-project/vllm/blob/d0215a58e78572d91dadafe9d832a2db89b09a13/vllm/model_executor/models/mixtral.py#L1
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Mixtral model."""
-from typing import List, Optional, Tuple
+from typing import Iterable, Optional, Tuple

 import numpy as np
 import torch
 import torch.nn.functional as F
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from transformers import MixtralConfig
+from vllm import _custom_ops as ops
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.parallel_utils.communication_op import (
-    tensor_model_parallel_all_reduce,
-)
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.utils import print_warning_once

+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+class MixtralMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Mixtral that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """

-class MixtralMLP(nn.Module):
     def __init__(
         self,
         num_experts: int,
+        top_k: int,
         hidden_size: int,
         intermediate_size: int,
-        linear_method: Optional[LinearMethodBase] = None,
-    ) -> None:
+        params_dtype: Optional[torch.dtype] = None,
+        tp_size: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
-        self.num_experts = num_experts
-        self.ffn_dim = intermediate_size
-        self.hidden_dim = hidden_size
+        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
+        self.num_total_experts = num_experts
+        self.top_k = top_k
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size // self.tp_size
+        self.quant_config = quant_config
+
+        # FIXME(pcmoritz): Make this more general to support different
+        # quantization schemes
+        self.use_fp8 = isinstance(quant_config, Fp8Config)
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype

-        self.w1 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
+        # Gate always runs at half / full precision for now.
+        self.gate = ReplicatedLinear(
+            self.hidden_size,
+            self.num_total_experts,
+            bias=False,
+            params_dtype=self.params_dtype,
+            quant_config=None,
         )
-        self.w2 = ReplicatedLinear(
-            self.ffn_dim, self.hidden_dim, bias=False, linear_method=linear_method
+
+        if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
+            params_dtype = torch.float8_e4m3fn
+
+        self.w13_weight = nn.Parameter(
+            torch.empty(
+                self.num_total_experts,
+                2 * self.intermediate_size,
+                self.hidden_size,
+                dtype=params_dtype,
+            )
         )
-        self.w3 = ReplicatedLinear(
-            self.hidden_dim, self.ffn_dim, bias=False, linear_method=linear_method
+        self.w2_weight = nn.Parameter(
+            torch.empty(
+                self.num_total_experts,
+                self.hidden_size,
+                self.intermediate_size,
+                dtype=params_dtype,
+            )
         )

-        # TODO: Use vllm's SiluAndMul
-        self.act_fn = nn.SiLU()
+        set_weight_attrs(
+            self.w13_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+        set_weight_attrs(
+            self.w2_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        w1_out, _ = self.w1(hidden_states)
-        w1_out = self.act_fn(w1_out)
-        w3_out, _ = self.w3(hidden_states)
-        current_hidden_states = w1_out * w3_out
-        current_hidden_states, _ = self.w2(current_hidden_states)
-        return current_hidden_states
+        # Used for fp8.
+        self.w13_scale = None
+        self.w2_scale = None
+        self.a13_scale = None
+        self.a2_scale = None
+
+        if self.use_fp8:
+            # WEIGHT_SCALE (for fp8)
+            self.w13_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )
+            self.w2_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )

+            # If loading fp8 checkpoint, pass the weight loaders.
+            # If loading an fp16 checkpoint, do not (we will quantize in
+            # process_weights_after_loading()
+            if quant_config.is_checkpoint_fp8_serialized:
+                set_weight_attrs(
+                    self.w13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.w2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )

-class MixtralMoE(nn.Module):
-    def __init__(
+            # ACT_SCALE (for fp8)
+            if quant_config.activation_scheme == "static":
+                if not quant_config.is_checkpoint_fp8_serialized:
+                    raise ValueError(
+                        "Found static activation scheme for checkpoint that "
+                        "was not serialized fp8."
+                    )
+                self.a13_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                self.a2_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+
+                set_weight_attrs(
+                    self.a13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.a2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+
+    def weight_loader(
         self,
-        config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        expert_id: int,
     ):
-        super().__init__()
-        self.config = config
-        self.rank = get_tensor_model_parallel_rank()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.num_total_experts = config.num_local_experts
-        self.top_k = config.num_experts_per_tok
-        if self.tp_size > self.num_total_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {self.num_total_experts}."
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        if weight_name.endswith("w1.weight"):
+            param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+        if weight_name.endswith("w3.weight"):
+            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
+                shard, :
+            ]
+        if weight_name.endswith("w2.weight"):
+            param_data[expert_id, :, :] = loaded_weight[:, shard]
+        if "act_scale" in weight_name or "weight_scale" in weight_name:
+            param_data[expert_id] = loaded_weight
+
+    def process_weights_after_loading(self):
+        # Fp8 is the only case where we need to process after loading.
+        if not self.use_fp8:
+            return
+
+        # If checkpoint is fp16, quantize here.
+        if not self.quant_config.is_checkpoint_fp8_serialized:
+            w13_weight = torch.empty_like(
+                self.w13_weight.data, dtype=torch.float8_e4m3fn
             )
-        # Split experts equally between ranks
-        self.expert_indicies = np.array_split(
-            range(self.num_total_experts), self.tp_size
-        )[self.rank].tolist()
-        if not self.expert_indicies:
-            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
-
-        self.experts = nn.ModuleList(
-            [
-                (
-                    MixtralMLP(
-                        self.num_total_experts,
-                        config.hidden_size,
-                        config.intermediate_size,
-                        linear_method=linear_method,
-                    )
-                    if idx in self.expert_indicies
-                    else None
+            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
+            for expert in range(self.num_total_experts):
+                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
+                    self.w13_weight.data[expert, :, :]
                 )
-                for idx in range(self.num_total_experts)
-            ]
-        )
-        self.gate = ReplicatedLinear(
-            config.hidden_size, self.num_total_experts, bias=False, linear_method=None
-        )
+                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
+                    self.w2_weight.data[expert, :, :]
+                )
+            self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
+            self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
+
+        # If checkpoint is fp8 + static, cleanup act_scales.
+        # Since state_dict has an act_scale per expert but our kernels
+        # are passed one act_scale shared across all experts.
+        elif self.quant_config.activation_scheme == "static":
+            if self.a13_scale is None or self.a2_scale is None:
+                raise ValueError(
+                    "QuantConfig has static quantization, but found "
+                    "activation scales are None."
+                )
+
+            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
+                print_warning_once(
+                    "Found act_scales that are not equal for fp8 MoE layer. "
+                    "Using the maximum across experts for each layer. "
+                )
+
+            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
+            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
+        # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-
-        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-        routing_weights, selected_experts = torch.topk(
-            routing_weights, self.top_k, dim=-1
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+            use_fp8=self.use_fp8,
+            w1_scale=self.w13_scale,
+            w2_scale=self.w2_scale,
+            a1_scale=self.a13_scale,
+            a2_scale=self.a2_scale,
         )
-        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

-        final_hidden_states = None
-        for expert_idx in self.expert_indicies:
-            expert_layer = self.experts[expert_idx]
-            expert_mask = selected_experts == expert_idx
-            expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

-            current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
-            if final_hidden_states is None:
-                final_hidden_states = current_hidden_states
-            else:
-                final_hidden_states.add_(current_hidden_states)
-
-        return tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states.view(num_tokens, hidden_size)


 class MixtralAttention(nn.Module):
@@ -147,7 +275,7 @@ class MixtralAttention(nn.Module):
         layer_id: int = 0,
         max_position: int = 4096 * 32,
         rope_theta: float = 10000,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
@@ -179,13 +307,13 @@ class MixtralAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -221,7 +349,7 @@ class MixtralDecoderLayer(nn.Module):
         self,
         config: MixtralConfig,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -235,9 +363,15 @@ class MixtralDecoderLayer(nn.Module):
             layer_id=layer_id,
             rope_theta=rope_theta,
             sliding_window=config.sliding_window,
-            linear_method=linear_method,
+            quant_config=quant_config,
+        )
+        self.block_sparse_moe = MixtralMoE(
+            num_experts=config.num_local_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
         )
-        self.block_sparse_moe = MixtralMoE(config=config, linear_method=linear_method)
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
@@ -272,7 +406,7 @@ class MixtralModel(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.padding_idx = config.pad_token_id
@@ -285,7 +419,7 @@ class MixtralModel(nn.Module):
         # config.num_hidden_layers=16
         self.layers = nn.ModuleList(
             [
-                MixtralDecoderLayer(config, i, linear_method=linear_method)
+                MixtralDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -316,12 +450,13 @@ class MixtralForCausalLM(nn.Module):
     def __init__(
         self,
         config: MixtralConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = MixtralModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)

@@ -337,13 +472,7 @@ class MixtralForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )

-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -351,16 +480,47 @@ class MixtralForCausalLM(nn.Module):
             ("qkv_proj", "v_proj", "v"),
         ]

+        expert_params_mapping = (
+            [
+                # These are the weight scales for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+                    f"experts.{expert_id}.{weight_name}.weight_scale",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+            + [
+                # These are the weights for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
+                    f"experts.{expert_id}.{weight_name}.weight",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+            + [
+                # These are the activation scales for the experts
+                # (param_name, weight_name, expert_id)
+                (
+                    "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
+                    f"experts.{expert_id}.{weight_name}.act_scale",
+                    expert_id,
+                )
+                for expert_id in range(self.config.num_local_experts)
+                for weight_name in ["w1", "w2", "w3"]
+            ]
+        )
+
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path,
-            cache_dir,
-            load_format,
-            revision,
-            fall_back_to_pt=False,
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -373,15 +533,30 @@ class MixtralForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip experts that are not assigned to this worker.
-                if "block_sparse_moe.experts." in name and name not in params_dict:
-                    continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader", default_weight_loader)
-                weight_loader(param, loaded_weight)
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param, loaded_weight, weight_name, expert_id=expert_id
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))


 EntryClass = MixtralForCausalLM
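
Note on the new MoE path above: MixtralMoE now shards every expert across all tensor-parallel ranks, stacks each expert's w1/w3 shards into a single w13 buffer that the fused kernel consumes, and all-reduces the fused output. The standalone sketch below (plain PyTorch with toy sizes; it is not code from the package, and the tensor names are illustrative) mirrors the slicing arithmetic of MixtralMoE.weight_loader so the layout of w13_weight and w2_weight is easier to follow.

    # Sketch of per-rank expert sharding, assuming toy dimensions.
    import torch

    hidden_size = 16          # model hidden dimension (toy value)
    intermediate_size = 32    # full expert FFN width (toy value)
    tp_size = 4               # number of tensor-parallel ranks
    shard_size = intermediate_size // tp_size

    # Fake per-expert checkpoint tensors shaped like Mixtral's w1/w2/w3.
    w1 = torch.randn(intermediate_size, hidden_size)
    w2 = torch.randn(hidden_size, intermediate_size)
    w3 = torch.randn(intermediate_size, hidden_size)

    for tp_rank in range(tp_size):
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
        # w1 and w3 are stacked into one w13 buffer: rows [0, shard_size) hold
        # the w1 shard, rows [shard_size, 2*shard_size) hold the w3 shard.
        w13_local = torch.empty(2 * shard_size, hidden_size)
        w13_local[0:shard_size, :] = w1[shard, :]
        w13_local[shard_size : 2 * shard_size, :] = w3[shard, :]
        # w2 is sharded along its input (column) dimension instead.
        w2_local = w2[:, shard]
        print(tp_rank, tuple(w13_local.shape), tuple(w2_local.shape))

Because every rank holds a slice of every expert, the forward pass no longer loops over locally assigned experts; it calls the fused MoE kernel once and combines the partial results with a single tensor-parallel all-reduce.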