sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff compares publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only and reflects only the changes between those published versions.
Files changed (74)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +1 -0
  3. sglang/bench_serving.py +9 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/aio_rwlock.py +100 -0
  9. sglang/srt/configs/model_config.py +8 -1
  10. sglang/srt/constrained/xgrammar_backend.py +4 -1
  11. sglang/srt/layers/attention/flashinfer_backend.py +51 -5
  12. sglang/srt/layers/attention/triton_backend.py +16 -25
  13. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  14. sglang/srt/layers/linear.py +20 -2
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
  17. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  18. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  19. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
  20. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
  21. sglang/srt/layers/moe/topk.py +191 -0
  22. sglang/srt/layers/quantization/__init__.py +5 -50
  23. sglang/srt/layers/quantization/fp8.py +221 -36
  24. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  25. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  26. sglang/srt/layers/radix_attention.py +8 -1
  27. sglang/srt/layers/sampler.py +27 -5
  28. sglang/srt/layers/torchao_utils.py +31 -0
  29. sglang/srt/managers/detokenizer_manager.py +37 -17
  30. sglang/srt/managers/io_struct.py +39 -10
  31. sglang/srt/managers/schedule_batch.py +54 -34
  32. sglang/srt/managers/schedule_policy.py +64 -5
  33. sglang/srt/managers/scheduler.py +171 -136
  34. sglang/srt/managers/tokenizer_manager.py +184 -133
  35. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  36. sglang/srt/mem_cache/chunk_cache.py +2 -2
  37. sglang/srt/mem_cache/memory_pool.py +15 -8
  38. sglang/srt/mem_cache/radix_cache.py +12 -2
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -11
  40. sglang/srt/model_executor/model_runner.py +28 -14
  41. sglang/srt/model_parallel.py +66 -5
  42. sglang/srt/models/dbrx.py +1 -1
  43. sglang/srt/models/deepseek.py +1 -1
  44. sglang/srt/models/deepseek_v2.py +67 -18
  45. sglang/srt/models/gemma2.py +34 -0
  46. sglang/srt/models/gemma2_reward.py +0 -1
  47. sglang/srt/models/granite.py +517 -0
  48. sglang/srt/models/grok.py +73 -9
  49. sglang/srt/models/llama.py +22 -0
  50. sglang/srt/models/llama_classification.py +11 -23
  51. sglang/srt/models/llama_reward.py +0 -2
  52. sglang/srt/models/llava.py +37 -14
  53. sglang/srt/models/mixtral.py +2 -2
  54. sglang/srt/models/olmoe.py +1 -1
  55. sglang/srt/models/qwen2.py +20 -0
  56. sglang/srt/models/qwen2_moe.py +1 -1
  57. sglang/srt/models/xverse_moe.py +1 -1
  58. sglang/srt/openai_api/adapter.py +8 -0
  59. sglang/srt/openai_api/protocol.py +9 -4
  60. sglang/srt/server.py +2 -1
  61. sglang/srt/server_args.py +19 -9
  62. sglang/srt/utils.py +40 -54
  63. sglang/test/test_block_fp8.py +341 -0
  64. sglang/test/test_utils.py +3 -2
  65. sglang/utils.py +10 -3
  66. sglang/version.py +1 -1
  67. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
  68. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
  69. sglang/srt/layers/fused_moe_patch.py +0 -133
  70. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  71. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  72. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  73. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/granite.py (new file)
@@ -0,0 +1,517 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
+"""Inference-only Granite model compatible with HuggingFace weights."""
+
+import logging
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import GraniteConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.rotary_embedding import get_rope
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
+
+
+class GraniteMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.gate_up_proj",
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.down_proj",
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class GraniteAttention(nn.Module):
+    def __init__(
+        self,
+        config: GraniteConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+        self.head_dim = getattr(
+            config, "head_dim", self.hidden_size // self.total_num_heads
+        )
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = config.attention_multiplier
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class GraniteDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: GraniteConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        self.residual_multiplier = config.residual_multiplier
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.self_attn = GraniteAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+        )
+        self.mlp = GraniteMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+        hidden_states = (
+            self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+            * self.residual_multiplier
+        )  # multiplier for Maximal Update Parameterization
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states) * self.residual_multiplier
+        return hidden_states, residual
+
+
+class GraniteModel(nn.Module):
+    def __init__(
+        self,
+        config: GraniteConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        self.layers = nn.ModuleList(
+            [
+                GraniteDecoderLayer(
+                    config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                )
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        hidden_states *= self.config.embedding_multiplier
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class GraniteForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: GraniteConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = GraniteModel(config, quant_config=quant_config)
+        # If tie_word_embeddings == True, then input and output embeddings are
+        # the same tensor. Enforce during object creation so that weights will
+        # load correctly even if the LM head weights don't have a separate entry
+        # in the state dict.
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, quant_config=quant_config
+        )
+        if self.config.tie_word_embeddings:
+            self.lm_head.tie_weights(self.model.embed_tokens)
+
+        # Granite logit scaling factors are applied via division, but
+        # LogitsProcessor expects a multiplicative factor.
+        if hasattr(config, "logits_scaling"):
+            logit_scale = 1.0 / config.logits_scaling
+        else:
+            logit_scale = None
+        self.logits_processor = LogitsProcessor(config, logit_scale=logit_scale)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        if not get_embedding:
+            logits_processor_output: LogitsProcessorOutput = self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+            return logits_processor_output
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def get_hidden_dim(self, module_name):
+        # return input_dim, output_dim
+        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
+            return self.config.hidden_size, self.config.hidden_size
+        elif module_name in ["kv_proj"]:
+            return self.config.hidden_size, self.config.hidden_size // (
+                self.config.num_attention_heads // self.config.num_key_value_heads
+            )
+        elif module_name == "gate_up_proj":
+            return self.config.hidden_size, self.config.intermediate_size
+        elif module_name == "down_proj":
+            return self.config.intermediate_size, self.config.hidden_size
+        else:
+            raise NotImplementedError()
+
+    def get_module_name(self, name):
+        params_mapping = {
+            "q_proj": "qkv_proj",
+            "k_proj": "qkv_proj",
+            "v_proj": "qkv_proj",
+            "gate_proj": "gate_up_proj",
+            "up_proj": "gate_up_proj",
+        }
+        return params_mapping.get(name, name)
+
+    def get_module_name_from_weight_name(self, name):
+        for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
+            if weight_name in name:
+                return (
+                    name.replace(weight_name, param_name)[: -len(".weight")],
+                    num_shard,
+                )
+        return name[: -len(".weight")], 1
+
+    def get_num_params(self):
+        params_dict = dict(self.named_parameters())
+        return len(params_dict)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+            if "lm_head.weight" in name and self.config.tie_word_embeddings:
+                # Input and output embeddings are tied, so the output embeddings
+                # may not be present in the checkpoint. We assume that the input
+                # embeddings are always present in the checkpoint.
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # This block only runs if the preceding for loop doesn't find
+                # a match for `name` in `stacked_params_mapping`.
+
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip loading kv_scale from ckpts towards new design.
+                if name.endswith(".kv_scale") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100, tp_size: int = 1
+    ) -> Optional[torch.Tensor]:
+        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
+
+        Only used for unit test with an unoptimized performance.
+        For optimized performance, please use torch.save and torch.load.
+        """
+        try:
+            if name == "lm_head.weight" and self.config.tie_word_embeddings:
+                logger.info(
+                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
+                )
+                return (
+                    self.model.embed_tokens.weight.cpu()
+                    .to(torch.float32)
+                    .numpy()
+                    .tolist()[:truncate_size]
+                )
+
+            mapped_name = name
+            mapped_shard_id = None
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name in name:
+                    mapped_name = name.replace(weight_name, param_name)
+                    mapped_shard_id = shard_id
+                    break
+            params_dict = dict(self.named_parameters())
+            param = params_dict[mapped_name]
+            if mapped_shard_id is not None:
+                if mapped_shard_id in ["q", "k", "v"]:
+                    num_heads = self.config.num_attention_heads // tp_size
+                    num_kv_heads = self.config.num_key_value_heads // tp_size
+                    head_dim = (
+                        self.config.hidden_size // self.config.num_attention_heads
+                    )
+                    if mapped_shard_id == "q":
+                        offset = 0
+                        size = num_heads * head_dim
+                    elif mapped_shard_id == "k":
+                        offset = num_heads * head_dim
+                        size = num_kv_heads * head_dim
+                    elif mapped_shard_id == "v":
+                        offset = (num_heads + num_kv_heads) * head_dim
+                        size = num_kv_heads * head_dim
+                    weight = param.data.narrow(0, offset, size)
+                elif mapped_shard_id in [0, 1]:
+                    intermediate_size = self.config.intermediate_size
+                    slice_size = intermediate_size // tp_size
+                    if mapped_shard_id == 0:  # gate_proj
+                        offset = 0
+                        size = slice_size
+                    elif mapped_shard_id == 1:  # up_proj
+                        offset = slice_size
+                        size = slice_size
+
+                    weight = param.data.narrow(0, offset, size)
+                else:
+                    weight = param.data
+            else:
+                weight = param.data
+            if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
+                gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
+                torch.distributed.all_gather(gathered_weights, weight)
+                weight = torch.cat(gathered_weights, dim=1)
+            return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+
+        except Exception:
+            logger.error(
+                f"Error getting weights by name {name} in GraniteForCausalLM: {get_exception_traceback()}"
+            )
+            return None
+
+
+EntryClass = [GraniteForCausalLM]
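
Note: the new module registers GraniteForCausalLM through EntryClass, which is how sglang discovers model implementations, so 0.4.1 can load IBM Granite checkpoints like any other supported architecture. A minimal usage sketch, assuming the offline sgl.Engine API available in this release; the checkpoint name is only an example and any HuggingFace Granite causal-LM path should work:

import sglang as sgl

# Example Granite checkpoint (illustrative path, substitute your own).
llm = sgl.Engine(model_path="ibm-granite/granite-3.0-8b-instruct")

# Sampling parameters are passed as a plain dict.
outputs = llm.generate(
    ["The capital of France is"],
    {"temperature": 0.8, "max_new_tokens": 32},
)
print(outputs[0]["text"])

llm.shutdown()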