sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,372 @@
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral_quant.py#L1
+"""Inference-only Mixtral model."""
+from typing import Iterable, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import MixtralConfig
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+class MixtralMLP(nn.Module):
+    def __init__(
+        self,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.num_experts = num_experts
+        self.ffn_dim = intermediate_size
+        self.hidden_dim = hidden_size
+
+        self.w1 = ReplicatedLinear(
+            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+        )
+        self.w2 = ReplicatedLinear(
+            self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
+        )
+        self.w3 = ReplicatedLinear(
+            self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+        )
+
+        # TODO: Use vllm's SiluAndMul
+        self.act_fn = nn.SiLU()
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        w1_out, _ = self.w1(hidden_states)
+        w1_out = self.act_fn(w1_out)
+        w3_out, _ = self.w3(hidden_states)
+        current_hidden_states = w1_out * w3_out
+        current_hidden_states, _ = self.w2(current_hidden_states)
+        return current_hidden_states
+
+
+class MixtralMoE(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.rank = get_tensor_model_parallel_rank()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = config.num_local_experts
+        self.top_k = config.num_experts_per_tok
+        if self.tp_size > self.num_total_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {self.num_total_experts}."
+            )
+        # Split experts equally between ranks
+        self.expert_indicies = np.array_split(
+            range(self.num_total_experts), self.tp_size
+        )[self.rank].tolist()
+        if not self.expert_indicies:
+            raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
+
+        self.experts = nn.ModuleList(
+            [
+                (
+                    MixtralMLP(
+                        self.num_total_experts,
+                        config.hidden_size,
+                        config.intermediate_size,
+                        quant_config=quant_config,
+                    )
+                    if idx in self.expert_indicies
+                    else None
+                )
+                for idx in range(self.num_total_experts)
+            ]
+        )
+        self.gate = ReplicatedLinear(
+            config.hidden_size, self.num_total_experts, bias=False, quant_config=None
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        router_logits, _ = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        routing_weights, selected_experts = torch.topk(
+            routing_weights, self.top_k, dim=-1
+        )
+        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
+
+        final_hidden_states = None
+        for expert_idx in self.expert_indicies:
+            expert_layer = self.experts[expert_idx]
+            expert_mask = selected_experts == expert_idx
+            expert_weights = (routing_weights * expert_mask).sum(dim=-1, keepdim=True)
+
+            current_hidden_states = expert_layer(hidden_states).mul_(expert_weights)
+            if final_hidden_states is None:
+                final_hidden_states = current_hidden_states
+            else:
+                final_hidden_states.add_(current_hidden_states)
+
+        return tensor_model_parallel_all_reduce(final_hidden_states)
+
+
+class MixtralAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        max_position: int = 4096 * 32,
+        rope_theta: float = 10000,
+        quant_config: Optional[QuantizationConfig] = None,
+        sliding_window: Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.sliding_window = sliding_window
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position,
+            base=int(self.rope_theta),
+            is_neox_style=True,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class MixtralDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        # Requires transformers > 4.32.0
+        rope_theta = getattr(config, "rope_theta", 10000)
+        self.self_attn = MixtralAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            max_position=config.max_position_embeddings,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            sliding_window=config.sliding_window,
+            quant_config=quant_config,
+        )
+        self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.block_sparse_moe(hidden_states)
+        return hidden_states, residual
+
+
+class MixtralModel(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                MixtralDecoderLayer(config, i, quant_config=quant_config)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions, hidden_states, input_metadata, residual
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class QuantMixtralForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: MixtralConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = MixtralModel(config, quant_config=quant_config)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip experts that are not assigned to this worker.
+                if "block_sparse_moe.experts." in name and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = QuantMixtralForCausalLM
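
Editor's note: the routing in MixtralMoE.forward above is the standard Mixtral top-k gating, namely a softmax over the router logits, keeping the top_k experts per token, renormalizing their weights, and summing the weighted expert outputs (followed by an all-reduce, since each tensor-parallel rank only holds a slice of the experts). Below is a minimal single-process sketch of that gating math, using plain nn.Linear layers in place of ReplicatedLinear, made-up sizes, and no tensor parallelism; it illustrates the arithmetic only, not the file's API.

    import torch
    import torch.nn.functional as F
    from torch import nn

    hidden_size, num_experts, top_k = 16, 4, 2
    experts = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_experts))
    gate = nn.Linear(hidden_size, num_experts, bias=False)

    x = torch.randn(5, hidden_size)  # 5 tokens

    # Softmax over router logits, then keep the top-k experts per token.
    routing_weights = F.softmax(gate(x), dim=-1)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights /= routing_weights.sum(dim=-1, keepdim=True)  # renormalize over kept experts

    # Weighted sum of the selected experts' outputs; every expert is local here,
    # so no tensor_model_parallel_all_reduce is needed.
    out = torch.zeros_like(x)
    for expert_idx, expert in enumerate(experts):
        mask = selected_experts == expert_idx                        # (tokens, top_k)
        weight = (routing_weights * mask).sum(dim=-1, keepdim=True)  # 0 for unselected tokens
        out += expert(x) * weight

The per-expert Python loop above mirrors only the quantized path in this file; the unquantized Mixtral path in this release (sglang/srt/models/mixtral.py together with the new sglang/srt/layers/fused_moe.py) presumably replaces it with a fused kernel.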
sglang/srt/models/qwen.py CHANGED
@@ -1,31 +1,30 @@
-from typing import Any, Dict, List, Optional, Tuple
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/qwen.py#L1
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_world_size,
-)
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class QWenMLP(nn.Module):
@@ -34,7 +33,7 @@ class QWenMLP(nn.Module):
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str = "silu",
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -42,14 +41,14 @@ class QWenMLP(nn.Module):
             2 * [intermediate_size],
             bias=False,
             gather_output=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             intermediate_size,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -74,7 +73,7 @@ class QWenAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -90,14 +89,14 @@ class QWenAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             bias=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.c_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            linear_method=linear_method,
+            quant_config=quant_config,
        )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -130,7 +129,12 @@ class QWenAttention(nn.Module):
 
 
 class QWenBlock(nn.Module):
-    def __init__(self, config: PretrainedConfig, layer_id, linear_method=None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
@@ -143,7 +147,7 @@ class QWenBlock(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             layer_id=layer_id,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@@ -151,7 +155,7 @@ class QWenBlock(nn.Module):
         self.mlp = QWenMLP(
             config.hidden_size,
             config.intermediate_size // 2,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
     def forward(
@@ -179,7 +183,11 @@ class QWenBlock(nn.Module):
 
 
 class QWenModel(nn.Module):
-    def __init__(self, config: PretrainedConfig, linear_method=None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
@@ -191,7 +199,7 @@ class QWenModel(nn.Module):
         )
         self.h = nn.ModuleList(
             [
-                QWenBlock(config, i, linear_method=linear_method)
+                QWenBlock(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -216,10 +224,15 @@ class QWenModel(nn.Module):
 
 
 class QWenLMHeadModel(nn.Module):
-    def __init__(self, config: PretrainedConfig, linear_method=None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ):
         super().__init__()
         self.config = config
-        self.transformer = QWenModel(config, linear_method=linear_method)
+        self.transformer = QWenModel(config, quant_config=quant_config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -236,22 +249,14 @@ class QWenLMHeadModel(nn.Module):
         )
         return next_tokens
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "w2", 0),
            ("gate_up_proj", "w1", 1),
         ]
         params_dict = dict(self.named_parameters())
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
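
Editor's note: the load_weights change in the hunk above (mirrored in mixtral_quant.py and the other model files in this diff) swaps the old path-based hf_model_weights_iterator signature for vLLM's newer iterator of (name, tensor) pairs, and stacked_params_mapping is what folds separate checkpoint tensors (w1/w2 here, q_proj/k_proj/v_proj in the attention case) into one fused parameter. Below is a toy, self-contained sketch of just the name-remapping part of that loop; a plain dict stands in for params_dict, nothing is loaded onto real parameters, and the checkpoint names are illustrative.

    from typing import Dict, Iterable, List, Tuple

    import torch

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id) -- same convention as the diff above
        ("gate_up_proj", "w2", 0),
        ("gate_up_proj", "w1", 1),
    ]

    def remap(weights: Iterable[Tuple[str, torch.Tensor]]) -> Dict[str, List[Tuple[int, torch.Tensor]]]:
        """Group checkpoint tensors under the fused parameter names they would load into."""
        fused: Dict[str, List[Tuple[int, torch.Tensor]]] = {}
        for name, tensor in weights:
            if "rotary_emb.inv_freq" in name:  # skipped exactly as in load_weights
                continue
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name in name:
                    fused.setdefault(name.replace(shard_name, param_name), []).append((shard_id, tensor))
                    break
            else:
                fused.setdefault(name, []).append((0, tensor))
        return fused

    # Example: both MLP shards end up under the single fused gate_up_proj parameter.
    ckpt = [
        ("transformer.h.0.mlp.w1.weight", torch.zeros(4, 8)),
        ("transformer.h.0.mlp.w2.weight", torch.zeros(4, 8)),
    ]
    print(remap(ckpt).keys())  # dict_keys(['transformer.h.0.mlp.gate_up_proj.weight'])

In the real code the remapped name is looked up in params_dict and each shard is handed to that parameter's weight_loader with its shard id, which is how the separate checkpoint tensors land in the correct slice of the fused weight.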