sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
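
Both new model files reproduced below (sglang/srt/models/llavavid.py and sglang/srt/models/minicpm.py) end by exporting their model class as a module-level EntryClass. The loader-side code that consumes this attribute is not part of this diff; the snippet below is only a hypothetical sketch of how a runtime could collect such entries. The helper name discover_entry_classes and the scanning logic are illustrative assumptions, not sglang's actual implementation.

# Hypothetical sketch only: collect the EntryClass exported by each module
# under sglang/srt/models/. The scanning logic and helper name are
# assumptions for illustration, not sglang's real loader.
import importlib
import pkgutil


def discover_entry_classes(package_name: str = "sglang.srt.models") -> dict:
    """Map module name -> EntryClass for every model module that exports one."""
    package = importlib.import_module(package_name)
    registry = {}
    for module_info in pkgutil.iter_modules(package.__path__):
        module = importlib.import_module(f"{package_name}.{module_info.name}")
        entry = getattr(module, "EntryClass", None)  # e.g. MiniCPMForCausalLM
        if entry is not None:
            registry[module_info.name] = entry
    return registry
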
sglang/srt/models/llavavid.py (new file, +298 lines)
@@ -0,0 +1,298 @@
+"""Inference-only LLaVa video model compatible with HuggingFace weights."""
+
+from typing import Iterable, List, Optional, Tuple
+
+import numpy as np
+import torch
+from torch import nn
+from transformers import CLIPVisionModel, LlavaConfig
+from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
+from vllm.config import CacheConfig
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.managers.controller.infer_batch import ForwardMode
+from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.mm_utils import (
+    get_anyres_image_grid_shape,
+    unpad_image,
+    unpad_image_shape,
+)
+from sglang.srt.models.llama2 import LlamaForCausalLM
+
+
+class LlavaVidForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config: LlavaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.vision_tower = None
+        self.config.vision_config.hidden_size = config.mm_hidden_size
+        self.config.text_config.hidden_size = config.hidden_size
+        self.multi_modal_projector = LlavaMultiModalProjector(config)
+        self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2)
+        self.resampler = nn.AvgPool2d(
+            kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
+        )
+        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
+        self.num_frames = getattr(self.config, "num_frames", 16)
+        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
+            self.language_model.model.image_newline = nn.Parameter(
+                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
+            )
+
+    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
+        new_image_feature_len = self.image_feature_len
+        # now only support spatial_unpad + anyres
+        # if self.mm_patch_merge_type.startswith("spatial"):
+        #     height = width = self.num_patches_per_side
+        #     if pt_shape[0] > 1:
+        #         if self.image_aspect_ratio == "anyres":
+        #             num_patch_width, num_patch_height = get_anyres_image_grid_shape(
+        #                 image_size,
+        #                 self.image_grid_pinpoints,
+        #                 self.vision_tower.config.image_size,
+        #             )
+        #         if "unpad" in self.mm_patch_merge_type:
+        #             h = num_patch_height * height
+        #             w = num_patch_width * width
+        #             new_h, new_w = unpad_image_shape(h, w, image_size)
+        #             new_image_feature_len += new_h * (new_w + 1)
+
+        pad_ids = pad_value * (
+            (new_image_feature_len + len(pad_value)) // len(pad_value)
+        )
+        offset = input_ids.index(self.config.image_token_index)
+        # old_len + pad_len - 1, because we need to remove image_token_id
+        new_input_ids = (
+            input_ids[:offset]
+            + pad_ids[:new_image_feature_len]
+            + input_ids[offset + 1 :]
+        )
+        return new_input_ids, offset
+
+    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
+        # NOTE: This is not memory efficient. (output_hidden_states=True) will save all the hidden states.
+
+        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
+        if self.vision_feature_select_strategy in ["default", "patch"]:
+            selected_image_feature = selected_image_feature[:, 1:]
+        elif self.vision_feature_select_strategy == "full":
+            selected_image_feature = selected_image_feature
+        else:
+            raise ValueError(
+                f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
+            )
+
+        height = width = self.num_patches_per_side
+        num_of_frames = selected_image_feature.shape[0]
+        selected_image_feature = selected_image_feature.view(
+            num_of_frames, height, width, -1
+        )
+        selected_image_feature = selected_image_feature.permute(0, 3, 1, 2).contiguous()
+        selected_image_feature = (
+            self.resampler(selected_image_feature)
+            .flatten(2)
+            .transpose(1, 2)
+            .contiguous()
+        )
+
+        image_features = self.multi_modal_projector(selected_image_feature)
+
+        return image_features
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        pixel_values: Optional[List[Optional[np.array]]] = None,
+        image_sizes: Optional[List[List[int]]] = None,
+        image_offsets: Optional[List[int]] = None,
+    ) -> torch.Tensor:
+        if input_metadata.forward_mode == ForwardMode.EXTEND:
+            bs = input_metadata.batch_size
+
+            # Embed text input
+            input_embeds = self.language_model.model.embed_tokens(input_ids)
+
+            # Embed vision input
+            need_vision = (
+                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
+                .cpu()
+                .numpy()
+            )
+            # FIXME: We need to subtract the length of the system prompt
+            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
+            need_vision = need_vision & has_pixel
+
+            if need_vision.any():
+                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
+                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
+
+                ########## Encode Image ########
+
+                if pixel_values[0].ndim == 4:
+                    # llava-hd: BS, num_patch, C=3, H=336, W=336, num_patch obtained from process_images
+                    np.concatenate(pixel_values, axis=0)
+                    # ndim=4
+                    concat_images = torch.tensor(
+                        np.concatenate(pixel_values, axis=0),
+                        device=self.vision_tower.device,
+                    )
+                    # image_features = self.encode_images(concat_images)
+                    # split_sizes = [image.shape[0] for image in pixel_values]
+                    # image_features = torch.split(image_features, split_sizes, dim=0)
+                    image_features = self.encode_images(
+                        concat_images
+                    )  # , prompts)#, image_counts, long_video=long_video)
+                    split_sizes = [image.shape[0] for image in pixel_values]
+                    image_features = torch.split(image_features, split_sizes, dim=0)
+
+                    # hd image_features: BS, num_patch, 576, 4096
+                else:
+                    # normal pixel: BS, C=3, H=336, W=336
+                    pixel_values = torch.tensor(
+                        np.array(pixel_values), device=self.vision_tower.device
+                    )
+                    image_features = self.encode_images(pixel_values)
+                    # image_features: BS, 576, 4096
+
+                new_image_features = []
+                for image_idx, image_feature in enumerate(image_features):
+                    new_image_features.append(image_feature.flatten(0, 1))
+                image_features = new_image_features
+
+                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+                pt = 0
+                for i in range(bs):
+                    if not need_vision[i]:
+                        continue
+
+                    start_idx = extend_start_loc_cpu[i]
+                    pad_len, pad_dim = image_features[pt].shape  # 576, 4096
+                    dim = input_embeds.shape[1]
+                    assert (
+                        pad_dim == dim
+                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
+                    # Fill in the placeholder for the image
+                    try:
+                        input_embeds[
+                            start_idx
+                            + image_offsets[i] : start_idx
+                            + image_offsets[i]
+                            + pad_len
+                        ] = image_features[pt]
+                    except RuntimeError as e:
+                        print(f"RuntimeError in llava image encoding: {e}")
+                        print(input_embeds.shape)
+                        print(start_idx, image_offsets[i])
+                    pt += 1
+
+            return self.language_model(
+                input_ids, positions, input_metadata, input_embeds=input_embeds
+            )
+        elif input_metadata.forward_mode == ForwardMode.DECODE:
+            return self.language_model(input_ids, positions, input_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        # load clip vision model by cfg['mm_vision_tower']:
+        # huggingface_name or path_of_clip_relative_to_llava_model_dir
+        vision_path = self.config.mm_vision_tower
+        self.vision_tower = CLIPVisionModel.from_pretrained(
+            vision_path, torch_dtype=torch.float16
+        ).cuda()
+        self.vision_tower.eval()
+
+        self.vision_feature_layer = self.config.mm_vision_select_layer
+        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
+        self.image_size = self.vision_tower.config.image_size
+        self.patch_size = self.vision_tower.config.patch_size
+
+        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
+        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
+        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
+
+        print(f"target_frames: {self.num_frames}")
+        self.image_feature_len = self.num_frames * int(
+            (self.image_size / self.patch_size / self.mm_spatial_pool_stride) ** 2
+        )
+        if self.vision_feature_select_strategy == "patch":
+            pass
+        elif self.vision_feature_select_strategy == "cls_patch":
+            self.image_feature_len += 1
+        else:
+            raise ValueError(f"Unexpected select feature: {self.select_feature}")
+
+        # load mm_projector
+        projector_weights = {
+            "model.mm_projector.0": "multi_modal_projector.linear_1",
+            "model.mm_projector.2": "multi_modal_projector.linear_2",
+            "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
+            "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
+            "model.vision_tower.vision_tower": "vision_tower",  # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
+        }
+        params_dict = dict(self.named_parameters())
+        weights = list(weights)
+        for name, loaded_weight in weights:
+            # FIXME: why are the projector weights loaded twice?
+            if "projector" in name or "vision_tower" in name:
+                for weight_name, param_name in projector_weights.items():
+                    if weight_name in name:
+                        name = name.replace(weight_name, param_name)
+                if name in params_dict:
+                    param = params_dict[name]
+                else:
+                    print(f"Warning: {name} not found in the model")
+                    continue
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+        # load language model
+        self.language_model.load_weights(weights)
+
+        monkey_path_clip_vision_embed_forward()
+
+    @property
+    def num_patches_per_side(self):
+        return self.image_size // self.patch_size
+
+
+first_call = True
+
+
+def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+    batch_size = pixel_values.shape[0]
+
+    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
+    global first_call
+    if first_call:
+        self.patch_embedding.cpu().float()
+        first_call = False
+    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
+    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
+
+    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+    embeddings = embeddings + self.position_embedding(self.position_ids)
+    return embeddings
+
+
+def monkey_path_clip_vision_embed_forward():
+    import transformers
+
+    setattr(
+        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
+        "forward",
+        clip_vision_embed_forward,
+    )
+
+
+EntryClass = LlavaVidForCausalLM
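
A quick check of the placeholder-length formula above: load_weights sets image_feature_len = num_frames * int((image_size / patch_size / mm_spatial_pool_stride) ** 2). The defaults in this file are num_frames=16 and mm_spatial_pool_stride=2; the vision-tower geometry comes from config.mm_vision_tower at load time, so the CLIP ViT-L/14-336 numbers below (image_size=336, patch_size=14) are an assumption used only to make the arithmetic concrete.

# Worked example of the image_feature_len formula in load_weights above.
# num_frames=16 and mm_spatial_pool_stride=2 are this file's defaults;
# image_size=336 and patch_size=14 assume a CLIP ViT-L/14-336 vision tower.
image_size, patch_size = 336, 14
mm_spatial_pool_stride = 2
num_frames = 16

patches_per_side = image_size // patch_size                    # 24
pooled_per_side = patches_per_side // mm_spatial_pool_stride   # 12
tokens_per_frame = pooled_per_side**2                          # 144
image_feature_len = num_frames * tokens_per_frame              # 144 * 16
print(image_feature_len)  # 2304 placeholder tokens per video
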
sglang/srt/models/minicpm.py (new file, +366 lines)
@@ -0,0 +1,366 @@
+"""Inference-only MiniCPM model compatible with HuggingFace weights."""
+
+import math
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.controller.model_runner import InputMetadata
+
+
+class MiniCPMMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(x)
+        return x
+
+
+class MiniCPMAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        # set rope as fp32 instead of bf16
+        self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache()
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        orig_dtype = q.dtype
+        q, k = q.float(), k.float()
+        q, k = self.rotary_emb(positions, q, k)
+        q, k = q.to(orig_dtype), k.to(orig_dtype)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class MiniCPMDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.self_attn = MiniCPMAttention(
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+        )
+        self.mlp = MiniCPMMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+        hidden_states = residual + hidden_states * (
+            self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
+        )
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states * (
+            self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
+        )
+
+        return hidden_states, None
+
+
+class MiniCPMModel(nn.Module):
+    def __init__(
+        self,
+        config,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            self.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                MiniCPMDecoderLayer(config, i, quant_config=quant_config)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids) * self.config.scale_emb
+        else:
+            hidden_states = input_embeds
+        residual = None
+
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                input_metadata,
+                residual,
+            )
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class MiniCPMForCausalLM(nn.Module):
+    def __init__(
+        self,
+        config,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        self.num_experts = getattr(self.config, "num_experts", 0)
+        self.quant_config = quant_config
+        self.model = MiniCPMModel(config, quant_config=quant_config)
+        # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if not self.config.tie_word_embeddings:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+            )
+
+        self.scale_width = self.config.hidden_size / self.config.dim_model_base
+
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is not None:
+            input_embeds = input_embeds * self.config.scale_emb
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = hidden_states / self.scale_width
+        if self.config.tie_word_embeddings:
+            lm_head_weight = self.model.embed_tokens.weight
+        else:
+            lm_head_weight = self.lm_head.weight
+        return self.logits_processor(
+            input_ids, hidden_states, lm_head_weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        expert_params_mapping = [
+            # (param_name, weight_name, expert_id)
+            (
+                "ws" if weight_name in ["w1", "w3"] else "w2s",
+                f"experts.{expert_id}.{weight_name}.weight",
+                expert_id,
+            )
+            for expert_id in range(self.num_experts)
+            for weight_name in ["w1", "w2", "w3"]
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for param_name, weight_name, expert_id in expert_params_mapping:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param, loaded_weight, weight_name, expert_id=expert_id
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = MiniCPMForCausalLM
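
MiniCPM departs from a plain Llama-style stack in three scalings, all visible in the file above: token embeddings are multiplied by config.scale_emb, each residual branch (attention and MLP) is scaled by scale_depth / sqrt(num_hidden_layers), and the final hidden states are divided by scale_width = hidden_size / dim_model_base before the logits processor. The numbers below are placeholders chosen only to make the arithmetic concrete; real checkpoints ship their own config values.

import math

# Placeholder config values for illustration only; actual MiniCPM checkpoints
# define their own scale_emb, scale_depth, dim_model_base, and sizes.
hidden_size = 2304
dim_model_base = 256
num_hidden_layers = 40
scale_emb = 12.0
scale_depth = 1.4

residual_scale = scale_depth / math.sqrt(num_hidden_layers)  # multiplies each attn/MLP output
scale_width = hidden_size / dim_model_base                   # divides hidden states before lm_head

print(round(residual_scale, 4))  # 0.2214
print(scale_width)               # 9.0
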