sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/models/dbrx.py
@@ -0,0 +1,406 @@
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/dbrx.py#L1
+# coding=utf-8
+from typing import Iterable, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.layers.fused_moe import fused_moe
+from vllm.model_executor.layers.linear import (
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.utils import set_weight_attrs
+from vllm.transformers_utils.configs.dbrx import DbrxConfig
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+class DbrxRouter(nn.Module):
+    """A Router implementation for DBRX that returns logits for each expert
+    per token.
+    """
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = config.ffn_config.moe_num_experts
+        self.d_model = config.d_model
+        self.layer = ReplicatedLinear(
+            self.d_model,
+            self.num_total_experts,
+            bias=False,
+            params_dtype=params_dtype,
+            quant_config=None,
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        router_logits, _ = self.layer(hidden_states)
+        return router_logits
+
+
+class DbrxExperts(nn.Module):
+    """A tensor-parallel MoE implementation for DBRX.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        config: DbrxConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        params_dtype: Optional[torch.dtype] = None,
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_total_experts = config.ffn_config.moe_num_experts
+        self.top_k = config.ffn_config.moe_top_k
+        self.d_model = config.d_model
+        self.intermediate_size = config.ffn_config.ffn_hidden_size // self.tp_size
+
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+        self.router = DbrxRouter(config, self.params_dtype)
+        self.ws = nn.Parameter(
+            torch.empty(
+                self.num_total_experts,
+                2 * self.intermediate_size,
+                self.d_model,
+                device="cuda",
+                dtype=self.params_dtype,
+            )
+        )
+        self.w2s = nn.Parameter(
+            torch.empty(
+                self.num_total_experts,
+                self.d_model,
+                self.intermediate_size,
+                device="cuda",
+                dtype=self.params_dtype,
+            )
+        )
+
+        set_weight_attrs(
+            self.ws,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+        set_weight_attrs(
+            self.w2s,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+
+    def weight_loader(
+        self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str
+    ):
+        tp_rank = get_tensor_model_parallel_rank()
+        param_data = param.data
+        shard_size = self.intermediate_size
+        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+        # DBRX uses GLU for each experts.
+        # GLU has 3 linear layers: w1, v1 and w2.
+        if weight_name.endswith("w1"):
+            loaded_weight = torch.reshape(
+                loaded_weight,
+                [-1, self.intermediate_size * self.tp_size, self.d_model],
+            )
+            param_data[:, 0:shard_size, :] = loaded_weight[:, shard, :]
+        if weight_name.endswith("v1"):
+            loaded_weight = torch.reshape(
+                loaded_weight,
+                [-1, self.intermediate_size * self.tp_size, self.d_model],
+            )
+            param_data[:, shard_size : 2 * shard_size, :] = loaded_weight[:, shard, :]
+        if weight_name.endswith("w2"):
+            loaded_weight = torch.reshape(
+                loaded_weight,
+                [-1, self.intermediate_size * self.tp_size, self.d_model],
+            ).transpose(1, 2)
+            param_data[:] = loaded_weight[:, :, shard]
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        num_tokens, hidden_size = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.d_model)
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.router(hidden_states)
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.ws,
+            self.w2s,
+            router_logits,
+            self.top_k,
+            renormalize=True,
+            inplace=True,
+        )
+
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states.view(num_tokens, hidden_size)
+
+
+class DbrxAttention(nn.Module):
+    def __init__(
+        self,
+        config: DbrxConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        self.total_num_heads = config.n_heads
+        self.head_dim = self.d_model // self.total_num_heads
+        self.total_num_kv_heads = config.attn_config.kv_n_heads
+        self.clip_qkv = config.attn_config.clip_qkv
+        self.rope_theta = config.attn_config.rope_theta
+        self.max_position = config.max_seq_len
+
+        # pylint: disable=invalid-name
+        self.Wqkv = QKVParallelLinear(
+            self.d_model,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.out_proj = RowParallelLinear(
+            self.d_model,
+            self.d_model,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position,
+            base=int(self.rope_theta),
+            is_neox_style=True,
+        )
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+        self.tp_size = tp_world_size
+        assert self.total_num_heads % tp_world_size == 0
+        self.num_heads = self.total_num_heads // tp_world_size
+        if self.total_num_kv_heads >= tp_world_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_world_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_world_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.Wqkv(hidden_states)
+        if self.clip_qkv is not None:
+            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(position_ids, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        hidden_states, _ = self.out_proj(attn_output)
+        return hidden_states
+
+
+class DbrxFusedNormAttention(nn.Module):
+    def __init__(
+        self,
+        config: DbrxConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.d_model = config.d_model
+        self.attn = DbrxAttention(config, layer_id, quant_config=quant_config)
+        self.norm_1 = nn.LayerNorm(self.d_model)
+        self.norm_2 = nn.LayerNorm(self.d_model)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.norm_1(hidden_states)
+        x = self.attn(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+        hidden_states = residual + x
+        residual = hidden_states
+        hidden_states = self.norm_2(hidden_states)
+        return hidden_states, residual
+
+
+class DbrxBlock(nn.Module):
+    def __init__(
+        self,
+        config: DbrxConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.norm_attn_norm = DbrxFusedNormAttention(
+            config, layer_id, quant_config=quant_config
+        )
+        self.ffn = DbrxExperts(config, quant_config=quant_config)
+
+    def forward(
+        self,
+        position_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states, residual = self.norm_attn_norm(
+            position_ids=position_ids,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+        hidden_states = self.ffn(hidden_states)
+        hidden_states = hidden_states + residual
+        return hidden_states
+
+
+class DbrxModel(nn.Module):
+    def __init__(
+        self,
+        config: DbrxConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.wte = VocabParallelEmbedding(
+            config.vocab_size,
+            config.d_model,
+        )
+        self.blocks = nn.ModuleList(
+            [
+                DbrxBlock(config, i, quant_config=quant_config)
+                for i in range(config.n_layers)
+            ]
+        )
+        self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5)
+        for module in self.modules():
+            if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
+                # Remove the bias term in Linear and LayerNorm.
+                module.register_parameter("bias", None)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.wte(input_ids)
+        else:
+            hidden_states = input_embeds
+        for i in range(len(self.blocks)):
+            block = self.blocks[i]
+            hidden_states = block(position_ids, hidden_states, input_metadata)
+        hidden_states = self.norm_f(hidden_states)
+        return hidden_states
+
+
+class DbrxForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: DbrxConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.unpadded_vocab_size = config.vocab_size
+        self.transformer = DbrxModel(config, quant_config=quant_config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.d_model,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        expert_params_mapping = [
+            (
+                "ws" if weight_name in ["w1", "v1"] else "w2s",
+                f"experts.mlp.{weight_name}",
+            )
+            for weight_name in ["w1", "v1", "w2"]
+        ]
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in weights:
+            for param_name, weight_name in expert_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, weight_name)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = DbrxForCausalLM
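
Note: the DbrxExperts docstring above says each expert's GLU weights are sharded across tensor-parallel ranks before the fused MoE kernel runs. The standalone sketch below (not part of the package; dimensions and variable names are made up for illustration, and TP groups / set_weight_attrs are omitted) shows the slicing that weight_loader performs: w1 and v1 land in the two halves of ws, while w2 is stored transposed in w2s, each cut down to the local rank's slice of the FFN dimension.

import torch

num_experts, d_model, intermediate, tp_size, tp_rank = 4, 8, 16, 2, 0
shard = intermediate // tp_size  # per-rank slice of the FFN dimension

# Local parameters, shaped as in DbrxExperts.__init__ (CPU here instead of "cuda").
ws = torch.empty(num_experts, 2 * shard, d_model)   # w1 and v1 stacked
w2s = torch.empty(num_experts, d_model, shard)      # w2, stored transposed

# Checkpoint tensors, flattened over experts as the reshape in weight_loader expects.
w1 = torch.randn(num_experts * intermediate, d_model)
v1 = torch.randn(num_experts * intermediate, d_model)
w2 = torch.randn(num_experts * intermediate, d_model)

rows = slice(tp_rank * shard, (tp_rank + 1) * shard)
ws[:, :shard, :] = w1.reshape(num_experts, intermediate, d_model)[:, rows, :]
ws[:, shard:, :] = v1.reshape(num_experts, intermediate, d_model)[:, rows, :]
w2s[:] = w2.reshape(num_experts, intermediate, d_model).transpose(1, 2)[:, :, rows]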
sglang/srt/models/gemma.py
@@ -1,32 +1,28 @@
 # Adapted from:
-# https://github.com/vllm-project/vllm/blob/d65fac2738f0287a41955b45df76a2d5a919bff6/vllm/model_executor/models/gemma.py
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/gemma.py#L1
 """Inference-only Gemma model compatible with HuggingFace weights."""
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import LoRAConfig
-from vllm.model_executor.input_metadata import InputMetadata
+from vllm.config import CacheConfig, LoRAConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (
-    LinearMethodBase,
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.parallel_utils.parallel_state import (
-    get_tensor_model_parallel_world_size,
-)
-from vllm.model_executor.weight_utils import (
-    default_weight_loader,
-    hf_model_weights_iterator,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class GemmaMLP(nn.Module):
@@ -34,17 +30,20 @@ class GemmaMLP(nn.Module):
         self,
         hidden_size: int,
         intermediate_size: int,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, linear_method=linear_method
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
         )
         self.act_fn = GeluAndMul()
 
@@ -65,7 +64,7 @@ class GemmaAttention(nn.Module):
         layer_id: int = 0,
         max_position_embeddings: int = 8192,
         rope_theta: float = 10000,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -95,13 +94,13 @@ class GemmaAttention(nn.Module):
             self.total_num_heads,
             self.total_num_kv_heads,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
        )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
 
         self.rotary_emb = get_rope(
@@ -138,7 +137,7 @@ class GemmaDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int = 0,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -150,12 +149,12 @@ class GemmaDecoderLayer(nn.Module):
             layer_id=layer_id,
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.mlp = GemmaMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
-            linear_method=linear_method,
+            quant_config=quant_config,
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -191,7 +190,7 @@ class GemmaModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -202,7 +201,7 @@ class GemmaModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                GemmaDecoderLayer(config, i, linear_method)
+                GemmaDecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -263,14 +262,15 @@ class GemmaForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        linear_method: Optional[LinearMethodBase] = None,
+        quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         del lora_config  # Unused.
         super().__init__()
         self.config = config
-        self.linear_method = linear_method
-        self.model = GemmaModel(config, linear_method)
+        self.quant_config = quant_config
+        self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
@@ -286,13 +286,7 @@ class GemmaForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -303,9 +297,7 @@ class GemmaForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
@@ -318,6 +310,10 @@ class GemmaForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # lm_head is not used in vllm as it is tied with embed_token.
+                # To prevent errors, skip loading lm_head.weight.
+                if "lm_head.weight" in name:
+                    continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
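
The gemma.py hunks above switch load_weights from loading a checkpoint path itself to consuming (name, tensor) pairs supplied by the model loader; per-shard checkpoint names are remapped onto fused parameters via stacked_params_mapping, and lm_head.weight is skipped because it is tied to the embedding. Below is a minimal, self-contained sketch of that name resolution, not code from the package; only the q_proj mapping entry is visible in the hunk, and the full file lists the remaining shards in the same style.

stacked_params_mapping = [
    # (param_name, shard_name, shard_id); only the entry visible in the hunk above.
    ("qkv_proj", "q_proj", "q"),
]

def resolve(name: str):
    """Return (target_param_name, shard_id) for one checkpoint tensor name."""
    if "lm_head.weight" in name:
        return None, None  # tied with embed_tokens, so loading it is skipped
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            # e.g. "...self_attn.q_proj.weight" -> "...self_attn.qkv_proj.weight"
            return name.replace(shard_name, param_name), shard_id
    return name, None  # falls through to default_weight_loader

print(resolve("model.layers.0.self_attn.q_proj.weight"))  # ('...qkv_proj.weight', 'q')
print(resolve("lm_head.weight"))                          # (None, None)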