sglang 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. sglang/__init__.py +55 -2
  2. sglang/api.py +3 -5
  3. sglang/backend/anthropic.py +33 -13
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +1 -0
  8. sglang/lang/chat_template.py +74 -0
  9. sglang/lang/interpreter.py +40 -16
  10. sglang/lang/ir.py +1 -1
  11. sglang/lang/tracer.py +6 -4
  12. sglang/launch_server.py +2 -1
  13. sglang/srt/constrained/fsm_cache.py +15 -3
  14. sglang/srt/constrained/jump_forward.py +1 -0
  15. sglang/srt/conversation.py +2 -2
  16. sglang/srt/hf_transformers_utils.py +2 -1
  17. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  18. sglang/srt/layers/extend_attention.py +1 -0
  19. sglang/srt/layers/logits_processor.py +114 -54
  20. sglang/srt/layers/radix_attention.py +2 -1
  21. sglang/srt/layers/token_attention.py +1 -0
  22. sglang/srt/managers/detokenizer_manager.py +5 -1
  23. sglang/srt/managers/io_struct.py +12 -0
  24. sglang/srt/managers/router/infer_batch.py +70 -33
  25. sglang/srt/managers/router/manager.py +7 -2
  26. sglang/srt/managers/router/model_rpc.py +116 -73
  27. sglang/srt/managers/router/model_runner.py +121 -155
  28. sglang/srt/managers/router/radix_cache.py +46 -38
  29. sglang/srt/managers/tokenizer_manager.py +56 -11
  30. sglang/srt/memory_pool.py +5 -14
  31. sglang/srt/model_config.py +7 -0
  32. sglang/srt/models/commandr.py +376 -0
  33. sglang/srt/models/dbrx.py +413 -0
  34. sglang/srt/models/dbrx_config.py +281 -0
  35. sglang/srt/models/gemma.py +22 -20
  36. sglang/srt/models/llama2.py +23 -21
  37. sglang/srt/models/llava.py +12 -10
  38. sglang/srt/models/mixtral.py +27 -25
  39. sglang/srt/models/qwen.py +23 -21
  40. sglang/srt/models/qwen2.py +23 -21
  41. sglang/srt/models/stablelm.py +292 -0
  42. sglang/srt/models/yivl.py +6 -5
  43. sglang/srt/openai_api_adapter.py +356 -0
  44. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  45. sglang/srt/sampling_params.py +2 -0
  46. sglang/srt/server.py +68 -439
  47. sglang/srt/server_args.py +76 -49
  48. sglang/srt/utils.py +88 -32
  49. sglang/srt/weight_utils.py +402 -0
  50. sglang/test/test_programs.py +8 -7
  51. sglang/test/test_utils.py +196 -8
  52. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/METADATA +13 -15
  53. sglang-0.1.15.dist-info/RECORD +69 -0
  54. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/WHEEL +1 -1
  55. sglang-0.1.13.dist-info/RECORD +0 -63
  56. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
  57. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -1,6 +1,7 @@
  import asyncio
  import concurrent.futures
  import dataclasses
+ import logging
  import multiprocessing as mp
  import os
  from typing import List
@@ -10,6 +11,7 @@ import transformers
  import uvloop
  import zmq
  import zmq.asyncio
+
  from sglang.srt.hf_transformers_utils import (
      get_config,
      get_context_length,
@@ -30,13 +32,14 @@ from sglang.srt.utils import get_exception_traceback, is_multimodal_model, load_

  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

+ logger = logging.getLogger(__name__)
+

  @dataclasses.dataclass
  class ReqState:
      out_list: List
      finished: bool
      event: asyncio.Event
-     lock: asyncio.Lock


  global global_processor
@@ -174,18 +177,26 @@ class TokenizerManager:
                  sampling_params=sampling_params,
                  return_logprob=obj.return_logprob,
                  logprob_start_len=obj.logprob_start_len,
+                 top_logprobs_num=obj.top_logprobs_num,
                  stream=obj.stream,
              )
              self.send_to_router.send_pyobj(tokenized_obj)

-             lock = asyncio.Lock()
              event = asyncio.Event()
-             state = ReqState([], False, event, lock)
+             state = ReqState([], False, event)
              self.rid_to_state[rid] = state

              while True:
                  await event.wait()
-                 yield state.out_list[-1]
+                 out = self.convert_logprob_style(state.out_list[-1],
+                                                  obj.return_logprob,
+                                                  obj.top_logprobs_num,
+                                                  obj.return_text_in_logprobs)
+
+                 if self.server_args.log_requests and state.finished:
+                     logger.info(f"in={obj.text}, out={out}")
+
+                 yield out
                  state.out_list = []
                  if state.finished:
                      del self.rid_to_state[rid]
@@ -217,13 +228,13 @@ class TokenizerManager:
                      sampling_params=sampling_params,
                      return_logprob=obj.return_logprob[i],
                      logprob_start_len=obj.logprob_start_len[i],
+                     top_logprobs_num=obj.top_logprobs_num[i],
                      stream=obj.stream,
                  )
                  self.send_to_router.send_pyobj(tokenized_obj)

-                 lock = asyncio.Lock()
                  event = asyncio.Event()
-                 state = ReqState([], False, event, lock)
+                 state = ReqState([], False, event)
                  self.rid_to_state[rid] = state

              output_list = []
@@ -231,16 +242,16 @@ class TokenizerManager:
                  rid = obj.rid[i]
                  state = self.rid_to_state[rid]
                  await state.event.wait()
-                 output_list.append(state.out_list[-1])
+                 output_list.append(
+                     self.convert_logprob_style(state.out_list[-1],
+                                                obj.return_logprob[i],
+                                                obj.top_logprobs_num[i],
+                                                obj.return_text_in_logprobs))
                  assert state.finished
                  del self.rid_to_state[rid]

              yield output_list

-     async def detokenize(self, obj: DetokenizeReqInput):
-         token_texts = self.tokenizer.convert_ids_to_tokens(obj.input_ids)
-         return [t.decode() if isinstance(t, bytes) else t for t in token_texts]
-
      async def flush_cache(self):
          flush_cache_req = FlushCacheReq()
          self.send_to_router.send_pyobj(flush_cache_req)
@@ -267,3 +278,37 @@ class TokenizerManager:
                      state.event.set()
              else:
                  raise ValueError(f"Invalid object: {recv_obj}")
+
+     def convert_logprob_style(self, ret, return_logprob, top_logprobs_num, return_text_in_logprobs):
+         if return_logprob:
+             ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens(
+                 ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs
+             )
+             ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens(
+                 ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
+             )
+             if top_logprobs_num > 0:
+                 ret["meta_info"]["prefill_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                     ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                 )
+                 ret["meta_info"]["decode_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                     ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                 )
+         return ret
+
+     def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
+         if not decode_to_text:
+             return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
+
+         token_ids = [tid for _, tid in token_logprobs]
+         token_texts = self.tokenizer.batch_decode(token_ids)
+         return [
+             (logprob, token_id, token_text)
+             for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
+         ]
+
+     def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
+         for i, t in enumerate(top_logprobs):
+             if t:
+                 top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
+         return top_logprobs
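Note on the logprob plumbing above: each entry of meta_info is normalized by detokenize_logprob_tokens into a (logprob, token_id, token_text) tuple, with the text decoded only when return_text_in_logprobs is set. The standalone sketch below mirrors that conversion outside the server; FakeTokenizer and the sample values are hypothetical stand-ins for the HF tokenizer and for real model output, not part of the package.

from typing import List

class FakeTokenizer:
    """Stand-in for the HF tokenizer; batch_decode maps ids to strings."""
    def batch_decode(self, token_ids: List[int]) -> List[str]:
        return [f"<tok{t}>" for t in token_ids]

def detokenize_logprob_tokens(tokenizer, token_logprobs, decode_to_text: bool):
    # Mirrors the shape produced by TokenizerManager.detokenize_logprob_tokens:
    # keep the (logprob, id) pair and append either None or the decoded text.
    if not decode_to_text:
        return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
    token_ids = [tid for _, tid in token_logprobs]
    token_texts = tokenizer.batch_decode(token_ids)
    return [
        (logprob, token_id, token_text)
        for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
    ]

if __name__ == "__main__":
    pairs = [(-0.11, 3), (-1.92, 87)]  # hypothetical (logprob, token_id) pairs
    print(detokenize_logprob_tokens(FakeTokenizer(), pairs, decode_to_text=True))
    # [(-0.11, 3, '<tok3>'), (-1.92, 87, '<tok87>')]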
sglang/srt/memory_pool.py CHANGED
@@ -31,9 +31,6 @@ class ReqToTokenPool:
          self.can_use_mem_size += free_index.shape[0]
          self.mem_state[free_index] = 1

-         # if self.can_use_mem_size == len(self.mem_state):
-         #     print(f"ReqToTokenPool: freed all. size = {self.can_use_mem_size}.")
-
      def clear(self):
          self.mem_state.fill_(1)
          self.can_use_mem_size = len(self.mem_state)
@@ -42,7 +39,7 @@
  class TokenToKVPool:
      def __init__(self, size, dtype, head_num, head_dim, layer_num):
          self.mem_state = torch.zeros((size,), dtype=torch.int16, device="cuda")
-         self.alloc_ct = 0
+         self.total_ref_ct = 0

          # [size, key/value, head_num, head_dim] for each layer
          self.kv_data = [
@@ -83,9 +80,6 @@
          self.add_refs(select_index)
          return select_index.to(torch.int32), start_loc, start_loc + need_size

-     def free(self, free_index):
-         return self.decrease_refs(free_index)
-
      def used_size(self):
          return len(torch.nonzero(self.mem_state).squeeze(1))

@@ -93,20 +87,17 @@
          return torch.sum(self.mem_state == 0).item()

      def add_refs(self, token_index: torch.Tensor):
-         self.alloc_ct += len(token_index)
+         self.total_ref_ct += len(token_index)
          self.mem_state[token_index] += 1

-     def decrease_refs(self, token_index: torch.Tensor):
-         self.alloc_ct -= len(token_index)
+     def dec_refs(self, token_index: torch.Tensor):
+         self.total_ref_ct -= len(token_index)
          self.mem_state[token_index] -= 1

          num_freed = torch.sum(self.mem_state[token_index] == 0)

-         # if self.alloc_ct == 0:
-         #     print(f"TokenToKVPool: freed all. size = {len(self.mem_state)}.")
-
          return num_freed

      def clear(self):
          self.mem_state.fill_(0)
-         self.alloc_ct = 0
+         self.total_ref_ct = 0
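The rename from alloc_ct/decrease_refs to total_ref_ct/dec_refs keeps the same reference-counting scheme: mem_state holds one reference count per KV slot, and dec_refs reports how many slots just dropped to zero. Below is a minimal CPU-only sketch of that bookkeeping; the KV tensors and CUDA allocation are omitted, and TinyRefCountPool is a hypothetical name used only for illustration.

import torch

class TinyRefCountPool:
    """Minimal sketch of the reference-counted slot pool; sizes are illustrative."""
    def __init__(self, size: int):
        self.mem_state = torch.zeros((size,), dtype=torch.int16)  # per-slot ref counts
        self.total_ref_ct = 0

    def add_refs(self, token_index: torch.Tensor):
        self.total_ref_ct += len(token_index)
        self.mem_state[token_index] += 1

    def dec_refs(self, token_index: torch.Tensor) -> int:
        self.total_ref_ct -= len(token_index)
        self.mem_state[token_index] -= 1
        # Slots whose count just reached zero are free to reuse.
        return int(torch.sum(self.mem_state[token_index] == 0))

pool = TinyRefCountPool(8)
idx = torch.tensor([0, 1, 2])
pool.add_refs(idx)                 # counts for slots 0, 1, 2 become 1
pool.add_refs(torch.tensor([2]))   # slot 2 is now shared by two requests
print(pool.dec_refs(idx))          # 2 -> slots 0 and 1 freed; slot 2 still referenced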
sglang/srt/model_config.py CHANGED
@@ -29,6 +29,13 @@ class ModelConfig:
          )
          self.num_attention_heads = self.hf_config.num_attention_heads
          self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
+
+         # for Dbrx and MPT models
+         if self.hf_config.model_type in ["dbrx", "mpt"]:
+             self.num_key_value_heads = getattr(
+                 self.hf_config.attn_config, "kv_n_heads", None
+             )
+
          if self.num_key_value_heads is None:
              self.num_key_value_heads = self.num_attention_heads
          self.hidden_size = self.hf_config.hidden_size
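DBRX- and MPT-style configs publish their KV-head count under attn_config.kv_n_heads rather than a top-level num_key_value_heads, which is what the added branch reads before falling back to num_attention_heads (plain multi-head attention). A small sketch of the same fallback chain on a toy config object; resolve_num_kv_heads and the numbers are illustrative, not part of the package.

from types import SimpleNamespace

def resolve_num_kv_heads(hf_config):
    # Mirrors the order used by ModelConfig: top-level field first,
    # then attn_config.kv_n_heads for dbrx/mpt, then num_attention_heads.
    num_kv = getattr(hf_config, "num_key_value_heads", None)
    if getattr(hf_config, "model_type", None) in ["dbrx", "mpt"]:
        num_kv = getattr(hf_config.attn_config, "kv_n_heads", None)
    if num_kv is None:
        num_kv = hf_config.num_attention_heads
    return num_kv

# Hypothetical DBRX-style config: KV heads live under attn_config.
cfg = SimpleNamespace(
    model_type="dbrx",
    num_attention_heads=48,
    attn_config=SimpleNamespace(kv_n_heads=8),
)
print(resolve_num_kv_heads(cfg))  # 8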
sglang/srt/models/commandr.py ADDED
@@ -0,0 +1,376 @@
+ # coding=utf-8
+ # Copyright 2024 Cohere and the HuggingFace Inc. team. All rights reserved.
+ #
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+ # and OPT implementations in this library. It has been modified from its
+ # original forms to accommodate minor architectural differences compared
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # This file is based on the LLama model definition file in transformers
+ """PyTorch Cohere model."""
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.utils.checkpoint
+ from torch import nn
+ from torch.nn.parameter import Parameter
+ from transformers import PretrainedConfig
+ from vllm.model_executor.layers.activation import SiluAndMul
+ from vllm.model_executor.layers.linear import (
+     MergedColumnParallelLinear,
+     QKVParallelLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import (
+     QuantizationConfig)
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+ from vllm.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+ )
+ from vllm.model_executor.utils import set_weight_attrs
+ from sglang.srt.weight_utils import (
+     default_weight_loader,
+     hf_model_weights_iterator,
+ )
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.managers.router.model_runner import InputMetadata
+
+
+ @torch.compile
+ def layer_norm_func(hidden_states, weight, variance_epsilon):
+     input_dtype = hidden_states.dtype
+     hidden_states = hidden_states.to(torch.float32)
+     mean = hidden_states.mean(-1, keepdim=True)
+     variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
+     hidden_states = (hidden_states - mean) * torch.rsqrt(variance + variance_epsilon)
+     hidden_states = weight.to(torch.float32) * hidden_states
+     return hidden_states.to(input_dtype)
+
+
+ class LayerNorm(nn.Module):
+     def __init__(self, param_shape=None, eps=1e-5):
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(param_shape))
+         self.variance_epsilon = eps
+         set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
+
+     def forward(self, hidden_states, residuals=None):
+         hidden_states = layer_norm_func(
+             hidden_states, self.weight, self.variance_epsilon
+         )
+         return hidden_states, residuals
+
+     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+         tp_rank = get_tensor_model_parallel_rank()
+         shard_dim = 0 if param.dim() != 1 else None
+         param_data = param.data
+         if shard_dim is not None:
+             shard_size = param_data.shape[shard_dim]
+             start_idx = tp_rank * shard_size
+             loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
+         assert param_data.shape == loaded_weight.shape
+         param_data.copy_(loaded_weight)
+
+
+ # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
+ class CohereMLP(nn.Module):
+     def __init__(
+         self,
+         config,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_up_proj = MergedColumnParallelLinear(
+             self.hidden_size,
+             [self.intermediate_size] * 2,
+             bias=False,
+             quant_config=quant_config,
+         )
+         self.down_proj = RowParallelLinear(
+             self.intermediate_size,
+             self.hidden_size,
+             bias=False,
+             quant_config=quant_config,
+         )
+         self.act_fn = SiluAndMul()
+
+     def forward(self, x):
+         gate_up, _ = self.gate_up_proj(x)
+         x = self.act_fn(gate_up)
+         x, _ = self.down_proj(x)
+         return x
+
+
+ class CohereAttention(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         tp_size = get_tensor_model_parallel_world_size()
+         self.config = config
+         self.attention_dropout = config.attention_dropout
+         self.hidden_size = config.hidden_size
+         self.total_num_heads = config.num_attention_heads
+         self.num_heads = self.total_num_heads // tp_size
+         self.head_dim = self.hidden_size // self.total_num_heads
+         self.total_num_kv_heads = config.num_key_value_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.max_position_embeddings = getattr(
+             config, "model_max_length", None
+         ) or getattr(config, "max_position_embeddings", 8192)
+         self.rope_theta = config.rope_theta
+         self.rope_scaling = getattr(config, "rope_scaling", None)
+         self.use_qk_norm = getattr(config, "use_qk_norm", False)
+         self.qkv_proj = QKVParallelLinear(
+             self.hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+         )
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             self.hidden_size,
+             bias=False,
+             quant_config=quant_config,
+         )
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=self.max_position_embeddings,
+             base=self.rope_theta,
+             rope_scaling=self.rope_scaling,
+             is_neox_style=False,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+         )
+         if self.use_qk_norm:
+             self.q_norm = LayerNorm(
+                 param_shape=(self.num_heads, self.head_dim), eps=config.layer_norm_eps
+             )
+             self.k_norm = LayerNorm(
+                 param_shape=(self.num_kv_heads, self.head_dim),
+                 eps=config.layer_norm_eps,
+             )
+
+     def _apply_qk_norm(self, q, k):
+         q = q.view(*q.shape[:-1], -1, self.head_dim)
+         k = k.view(*k.shape[:-1], -1, self.head_dim)
+         q, _ = self.q_norm(q)
+         k, _ = self.k_norm(k)
+         q = q.view(*q.shape[:-2], -1)
+         k = k.view(*k.shape[:-2], -1)
+         return q, k
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         if self.use_qk_norm:
+             q, k = self._apply_qk_norm(q, k)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class CohereDecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+
+         self.self_attn = CohereAttention(
+             config, layer_id=layer_id, quant_config=quant_config
+         )
+
+         self.mlp = CohereMLP(config, quant_config=quant_config)
+         self.input_layernorm = LayerNorm(
+             param_shape=(config.hidden_size), eps=config.layer_norm_eps
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+         residual: Optional[torch.Tensor],
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         # Self Attention
+         residual = hidden_states
+         hidden_states, residual = self.input_layernorm(hidden_states, residual)
+         hidden_states_attention = self.self_attn(
+             positions=positions,
+             hidden_states=hidden_states,
+             input_metadata=input_metadata,
+         )
+         hidden_states_mlp = self.mlp(hidden_states)
+         # Add everything together
+         hidden_states = residual + hidden_states_attention + hidden_states_mlp
+
+         return hidden_states, residual
+
+
+ class CohereModel(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.config = config
+         self.vocab_size = config.vocab_size
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size, config.hidden_size
+         )
+         self.layers = nn.ModuleList(
+             [
+                 CohereDecoderLayer(config, i, quant_config=quant_config)
+                 for i in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = LayerNorm(
+             param_shape=(config.hidden_size), eps=config.layer_norm_eps
+         )
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         hidden_states = self.embed_tokens(input_ids)
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 input_metadata,
+                 residual,
+             )
+         hidden_states, _ = self.norm(hidden_states, residual)
+         return hidden_states
+
+
+ class CohereForCausalLM(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.logits_processor = LogitsProcessor(config)
+         self.model = CohereModel(config, quant_config)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         hidden_states = self.model(
+             input_ids,
+             positions,
+             input_metadata,
+         )
+         return self.logits_processor(
+             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
+         )
+
+     def load_weights(
+         self,
+         model_name_or_path: str,
+         cache_dir: Optional[str] = None,
+         load_format: str = "auto",
+         revision: Optional[str] = None,
+     ):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+         params_dict = dict(self.named_parameters())
+         loaded_params = set()
+         for name, loaded_weight in hf_model_weights_iterator(
+             model_name_or_path, cache_dir, load_format, revision
+         ):
+             for param_name, shard_name, shard_id in stacked_params_mapping:
+                 if shard_name not in name:
+                     continue
+                 name = name.replace(shard_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # lm_head is not used in vllm as it is tied with embed_token.
+                 # To prevent errors, skip loading lm_head.weight.
+                 if "lm_head.weight" in name:
+                     continue
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+             loaded_params.add(name)
+
+
+ EntryClass = CohereForCausalLM
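The new Command-R implementation above applies a single shared pre-norm per decoder layer and a parallel residual: attention and MLP both consume the same normalized hidden states, and their outputs are summed with the residual; the LM head is tied to embed_tokens. The sketch below only illustrates that block structure in plain PyTorch. nn.MultiheadAttention and the ungated MLP are simplifications standing in for RadixAttention and the gated CohereMLP; this is not the package's code.

import torch
from torch import nn

class ParallelResidualBlock(nn.Module):
    """Simplified sketch of the Command-R style block: one pre-norm,
    attention and MLP computed in parallel, outputs summed with the residual."""
    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.SiLU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        h = self.norm(x)                      # single shared pre-norm
        attn_out, _ = self.attn(h, h, h)      # stand-in for RadixAttention
        mlp_out = self.mlp(h)                 # stand-in for CohereMLP
        return residual + attn_out + mlp_out  # parallel residual sum

x = torch.randn(1, 5, 64)
print(ParallelResidualBlock(64, 4)(x).shape)  # torch.Size([1, 5, 64])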