sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py
@@ -0,0 +1,738 @@
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
+ """Inference-only Grok1 model."""
+ from typing import Iterable, List, Optional, Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import tqdm
+ from torch import nn
+ from transformers import PretrainedConfig
+ from vllm import _custom_ops as ops
+ from vllm.config import CacheConfig
+ from vllm.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     tensor_model_parallel_all_reduce,
+ )
+ from vllm.model_executor.layers.layernorm import RMSNorm
+ from vllm.model_executor.layers.linear import (
+     QKVParallelLinear,
+     ReplicatedLinear,
+     RowParallelLinear,
+ )
+ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+ from vllm.model_executor.layers.quantization.fp8 import Fp8Config
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.loader import DefaultModelLoader
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+ from vllm.model_executor.utils import set_weight_attrs
+ from vllm.utils import print_warning_once
+
+ from sglang.srt.layers.fused_moe import fused_moe
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.managers.controller.model_runner import InputMetadata
+
+ use_fused = True
+
+
+ class Grok1MLP(nn.Module):
+     def __init__(
+         self,
+         num_experts: int,
+         hidden_size: int,
+         intermediate_size: int,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.num_experts = num_experts
+         self.ffn_dim = intermediate_size
+         self.hidden_dim = hidden_size
+
+         self.w1 = ReplicatedLinear(
+             self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+         )
+         self.w2 = ReplicatedLinear(
+             self.ffn_dim, self.hidden_dim, bias=False, quant_config=quant_config
+         )
+         self.w3 = ReplicatedLinear(
+             self.hidden_dim, self.ffn_dim, bias=False, quant_config=quant_config
+         )
+
+         self.act_fn = nn.GELU()
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         w1_out, _ = self.w1(hidden_states)
+         w1_out = self.act_fn(w1_out)
+         w3_out, _ = self.w3(hidden_states)
+         current_hidden_states = w1_out * w3_out
+         current_hidden_states, _ = self.w2(current_hidden_states)
+         return current_hidden_states
+
+
+ class Grok1MoEUnfused(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.config = config
+         self.rank = get_tensor_model_parallel_rank()
+         self.tp_size = get_tensor_model_parallel_world_size()
+         self.num_total_experts = config.num_local_experts
+         self.top_k = config.num_experts_per_tok
+         if self.tp_size > self.num_total_experts:
+             raise ValueError(
+                 f"Tensor parallel size {self.tp_size} is greater than "
+                 f"the number of experts {self.num_total_experts}."
+             )
+         # Split experts equally between ranks
+         self.expert_indicies = np.array_split(
+             range(self.num_total_experts), self.tp_size
+         )[self.rank].tolist()
+         if not self.expert_indicies:
+             raise ValueError(f"Rank {self.rank} has no experts assigned to it.")
+
+         self.experts = nn.ModuleList(
+             [
+                 (
+                     Grok1MLP(
+                         self.num_total_experts,
+                         config.hidden_size,
+                         config.intermediate_size,
+                         quant_config=quant_config,
+                     )
+                     if idx in self.expert_indicies
+                     else None
+                 )
+                 for idx in range(self.num_total_experts)
+             ]
+         )
+         self.gate = ReplicatedLinear(
+             config.hidden_size, self.num_total_experts, bias=False, quant_config=None
+         )
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         router_logits, _ = self.gate(hidden_states)
+         router_logits = 30 * F.tanh(router_logits / 30)
+
+         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+         routing_weights, selected_experts = torch.topk(
+             routing_weights, self.top_k, dim=-1
+         )
+         routing_weights = routing_weights.to(hidden_states.dtype)
+         hidden_dim = hidden_states.shape[1]
+
+         final_hidden_states = torch.zeros(
+             (hidden_states.shape[0], hidden_dim),
+             dtype=hidden_states.dtype,
+             device=hidden_states.device,
+         )
+         expert_mask = torch.nn.functional.one_hot(
+             selected_experts, num_classes=self.num_total_experts
+         ).permute(2, 1, 0)
+
+         for expert_idx in self.expert_indicies:
+             expert_layer = self.experts[expert_idx]
+             idx, top_x = torch.where(expert_mask[expert_idx])
+
+             if top_x.shape[0] == 0:
+                 continue
+
+             # in torch it is faster to index using lists than torch tensors
+             top_x_list = top_x.tolist()
+             idx_list = idx.tolist()
+
+             # Index the correct hidden states and compute the expert hidden state for
+             # the current expert. We need to make sure to multiply the output hidden
+             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+             current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
+             current_hidden_states = (
+                 expert_layer(current_state)
+                 * routing_weights[top_x_list, idx_list, None]
+             )
+
+             # However `index_add_` only supports torch tensors for indexing, so we'll use
+             # the `top_x` tensor here.
+             final_hidden_states.index_add_(0, top_x, current_hidden_states)
+
+         return tensor_model_parallel_all_reduce(final_hidden_states)
+
+
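# The routing arithmetic that Grok1MoEUnfused.forward above implements, as a
# standalone sketch with made-up shapes (illustrative only, not part of the
# package): soft-cap the router logits with 30 * tanh(x / 30), softmax,
# pick the top-k experts per token, and build a per-expert dispatch mask.
import torch
import torch.nn.functional as F

num_tokens, num_experts, top_k = 4, 8, 2
router_logits = torch.randn(num_tokens, num_experts)

capped = 30 * torch.tanh(router_logits / 30)                 # logit soft cap
routing_weights = F.softmax(capped, dim=-1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

# (num_experts, top_k, num_tokens): which tokens each expert must process
expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)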
+ class Grok1MoE(nn.Module):
+     """A tensor-parallel MoE implementation for Grok1 that shards each expert
+     across all ranks.
+
+     Each expert's weights are sharded across all ranks and a fused MoE
+     kernel is used for the forward pass, and finally we reduce the outputs
+     across ranks.
+     """
+
+     def __init__(
+         self,
+         num_experts: int,
+         top_k: int,
+         hidden_size: int,
+         intermediate_size: int,
+         params_dtype: Optional[torch.dtype] = None,
+         tp_size: Optional[int] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.tp_size = tp_size or get_tensor_model_parallel_world_size()
+         self.num_total_experts = num_experts
+         self.top_k = top_k
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size // self.tp_size
+         self.quant_config = quant_config
+
+         # FIXME(pcmoritz): Make this more general to support different
+         # quantization schemes
+         self.use_fp8 = isinstance(quant_config, Fp8Config)
+
+         if params_dtype is None:
+             params_dtype = torch.get_default_dtype()
+         self.params_dtype = params_dtype
+
+         # Gate always runs at half / full precision for now.
+         self.gate = ReplicatedLinear(
+             self.hidden_size,
+             self.num_total_experts,
+             bias=False,
+             params_dtype=self.params_dtype,
+             quant_config=None,
+         )
+
+         if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
+             params_dtype = torch.float8_e4m3fn
+
+         self.w13_weight = nn.Parameter(
+             torch.empty(
+                 self.num_total_experts,
+                 2 * self.intermediate_size,
+                 self.hidden_size,
+                 dtype=params_dtype,
+             )
+         )
+         self.w2_weight = nn.Parameter(
+             torch.empty(
+                 self.num_total_experts,
+                 self.hidden_size,
+                 self.intermediate_size,
+                 dtype=params_dtype,
+             )
+         )
+
+         set_weight_attrs(
+             self.w13_weight,
+             {
+                 "weight_loader": self.weight_loader,
+             },
+         )
+         set_weight_attrs(
+             self.w2_weight,
+             {
+                 "weight_loader": self.weight_loader,
+             },
+         )
+
+         # Used for fp8.
+         self.w13_scale = None
+         self.w2_scale = None
+         self.a13_scale = None
+         self.a2_scale = None
+
+         if self.use_fp8:
+             # WEIGHT_SCALE (for fp8)
+             self.w13_scale = nn.Parameter(
+                 torch.ones(self.num_total_experts, dtype=torch.float32),
+                 requires_grad=False,
+             )
+             self.w2_scale = nn.Parameter(
+                 torch.ones(self.num_total_experts, dtype=torch.float32),
+                 requires_grad=False,
+             )
+
+             # If loading an fp8 checkpoint, pass the weight loaders.
+             # If loading an fp16 checkpoint, do not (we will quantize in
+             # process_weights_after_loading()).
+             if quant_config.is_checkpoint_fp8_serialized:
+                 set_weight_attrs(
+                     self.w13_scale,
+                     {
+                         "weight_loader": self.weight_loader,
+                     },
+                 )
+                 set_weight_attrs(
+                     self.w2_scale,
+                     {
+                         "weight_loader": self.weight_loader,
+                     },
+                 )
+
+             # ACT_SCALE (for fp8)
+             if quant_config.activation_scheme == "static":
+                 if not quant_config.is_checkpoint_fp8_serialized:
+                     raise ValueError(
+                         "Found static activation scheme for checkpoint that "
+                         "was not serialized fp8."
+                     )
+                 self.a13_scale = nn.Parameter(
+                     torch.zeros(self.num_total_experts, dtype=torch.float32),
+                     requires_grad=False,
+                 )
+                 self.a2_scale = nn.Parameter(
+                     torch.zeros(self.num_total_experts, dtype=torch.float32),
+                     requires_grad=False,
+                 )
+
+                 set_weight_attrs(
+                     self.a13_scale,
+                     {
+                         "weight_loader": self.weight_loader,
+                     },
+                 )
+                 set_weight_attrs(
+                     self.a2_scale,
+                     {
+                         "weight_loader": self.weight_loader,
+                     },
+                 )
+
+     def weight_loader(
+         self,
+         param: nn.Parameter,
+         loaded_weight: torch.Tensor,
+         weight_name: str,
+         expert_id: int,
+         pre_sharded: bool,
+     ):
+         param_data = param.data
+         shard_size = self.intermediate_size
+         if pre_sharded:
+             # The weight is already sharded. Read the full shard.
+             shard = slice(None)
+         else:
+             tp_rank = get_tensor_model_parallel_rank()
+             shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+         if weight_name.endswith("w1.weight"):
+             param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+         if weight_name.endswith("w3.weight"):
+             param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
+                 shard, :
+             ]
+         if weight_name.endswith("w2.weight"):
+             param_data[expert_id, :, :] = loaded_weight[:, shard]
+         if "act_scale" in weight_name or "weight_scale" in weight_name:
+             param_data[expert_id] = loaded_weight
+
+     def process_weights_after_loading(self):
+         # Fp8 is the only case where we need to process after loading.
+         if not self.use_fp8:
+             return
+
+         # If checkpoint is fp16, quantize here.
+         if not self.quant_config.is_checkpoint_fp8_serialized:
+             w13_weight = torch.empty_like(
+                 self.w13_weight.data, dtype=torch.float8_e4m3fn
+             )
+             w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
+             for expert in range(self.num_total_experts):
+                 w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
+                     self.w13_weight.data[expert, :, :]
+                 )
+                 w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
+                     self.w2_weight.data[expert, :, :]
+                 )
+             self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
+             self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
+
+         # If checkpoint is fp8 + static, cleanup act_scales.
+         # Since state_dict has an act_scale per expert but our kernels
+         # are passed one act_scale shared across all experts.
+         elif self.quant_config.activation_scheme == "static":
+             if self.a13_scale is None or self.a2_scale is None:
+                 raise ValueError(
+                     "QuantConfig has static quantization, but found "
+                     "activation scales are None."
+                 )
+
+             if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
+                 print_warning_once(
+                     "Found act_scales that are not equal for fp8 MoE layer. "
+                     "Using the maximum across experts for each layer. "
+                 )
+
+             self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
+             self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         num_tokens, hidden_size = hidden_states.shape
+         hidden_states = hidden_states.view(-1, self.hidden_size)
+         # router_logits: (num_tokens, n_experts)
+         router_logits, _ = self.gate(hidden_states)
+         final_hidden_states = fused_moe(
+             hidden_states,
+             self.w13_weight,
+             self.w2_weight,
+             router_logits,
+             self.top_k,
+             renormalize=False,
+             inplace=True,
+             use_fp8=self.use_fp8,
+             w1_scale=self.w13_scale,
+             w2_scale=self.w2_scale,
+             a1_scale=self.a13_scale,
+             a2_scale=self.a2_scale,
+         )
+
+         if self.tp_size > 1:
+             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+         return final_hidden_states.view(num_tokens, hidden_size)
+
+
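# A toy, single-expert illustration (made-up sizes, not part of the package) of
# how Grok1MoE.weight_loader above packs a checkpoint's w1 and w3 into the
# stacked w13_weight and takes this rank's slice of the intermediate dimension
# when the checkpoint is not pre-sharded.
import torch

hidden_size, intermediate_size, tp_size, tp_rank = 8, 16, 2, 0
shard_size = intermediate_size // tp_size
w13_weight = torch.empty(2 * shard_size, hidden_size)         # one expert's [w1; w3]

w1 = torch.randn(intermediate_size, hidden_size)              # full checkpoint tensors
w3 = torch.randn(intermediate_size, hidden_size)
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)

w13_weight[0:shard_size, :] = w1[shard, :]                    # first half from w1
w13_weight[shard_size : 2 * shard_size, :] = w3[shard, :]     # second half from w3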
+ class Grok1Attention(nn.Module):
+     def __init__(
+         self,
+         hidden_size: int,
+         num_heads: int,
+         num_kv_heads: int,
+         layer_id: int = 0,
+         max_position: int = 4096 * 32,
+         rope_theta: float = 10000,
+         logit_cap: float = 30,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = hidden_size
+         tp_size = get_tensor_model_parallel_world_size()
+         self.total_num_heads = num_heads
+         assert self.total_num_heads % tp_size == 0
+         self.num_heads = self.total_num_heads // tp_size
+         self.total_num_kv_heads = num_kv_heads
+         if self.total_num_kv_heads >= tp_size:
+             # Number of KV heads is greater than TP size, so we partition
+             # the KV heads across multiple tensor parallel GPUs.
+             assert self.total_num_kv_heads % tp_size == 0
+         else:
+             # Number of KV heads is less than TP size, so we replicate
+             # the KV heads across multiple tensor parallel GPUs.
+             assert tp_size % self.total_num_kv_heads == 0
+         self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+         self.head_dim = 128
+         self.q_size = self.num_heads * self.head_dim
+         self.kv_size = self.num_kv_heads * self.head_dim
+         self.scaling = self.head_dim**-0.5
+         self.rope_theta = rope_theta
+
+         self.qkv_proj = QKVParallelLinear(
+             hidden_size,
+             self.head_dim,
+             self.total_num_heads,
+             self.total_num_kv_heads,
+             bias=False,
+             quant_config=quant_config,
+         )
+
+         self.o_proj = RowParallelLinear(
+             self.total_num_heads * self.head_dim,
+             hidden_size,
+             bias=False,
+             quant_config=quant_config,
+         )
+         self.rotary_emb = get_rope(
+             self.head_dim,
+             rotary_dim=self.head_dim,
+             max_position=max_position,
+             base=int(self.rope_theta),
+             is_neox_style=True,
+         )
+         self.attn = RadixAttention(
+             self.num_heads,
+             self.head_dim,
+             self.scaling,
+             num_kv_heads=self.num_kv_heads,
+             layer_id=layer_id,
+             logit_cap=logit_cap,
+         )
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         qkv, _ = self.qkv_proj(hidden_states)
+         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+         q, k = self.rotary_emb(positions, q, k)
+         attn_output = self.attn(q, k, v, input_metadata)
+         output, _ = self.o_proj(attn_output)
+         return output
+
+
+ class Grok1DecoderLayer(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         # Requires transformers > 4.32.0
+         rope_theta = getattr(config, "rope_theta", 10000)
+         self.self_attn = Grok1Attention(
+             hidden_size=self.hidden_size,
+             num_heads=config.num_attention_heads,
+             max_position=config.max_position_embeddings,
+             num_kv_heads=config.num_key_value_heads,
+             layer_id=layer_id,
+             rope_theta=rope_theta,
+             quant_config=quant_config,
+         )
+         if use_fused:
+             self.block_sparse_moe = Grok1MoE(
+                 num_experts=config.num_local_experts,
+                 top_k=config.num_experts_per_tok,
+                 hidden_size=config.hidden_size,
+                 intermediate_size=config.intermediate_size,
+                 quant_config=quant_config,
+             )
+         else:
+             self.block_sparse_moe = Grok1MoEUnfused(
+                 config=config, quant_config=quant_config
+             )
+         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+         self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         positions: torch.Tensor,
+         hidden_states: torch.Tensor,
+         input_metadata: InputMetadata,
+     ) -> torch.Tensor:
+         hidden_states = (
+             self.post_attn_norm(
+                 self.self_attn(
+                     positions=positions,
+                     hidden_states=self.pre_attn_norm(hidden_states),
+                     input_metadata=input_metadata,
+                 )
+             )
+             + hidden_states
+         )
+
+         hidden_states = (
+             self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
+             + hidden_states
+         )
+
+         return hidden_states
+
+
+ class Grok1Model(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.padding_idx = config.pad_token_id
+         self.vocab_size = config.vocab_size
+
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.layers = nn.ModuleList(
+             [
+                 Grok1DecoderLayer(config, i, quant_config=quant_config)
+                 for i in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+         hidden_states.mul_(self.config.embedding_multiplier_scale)
+
+         for i in range(len(self.layers)):
+             hidden_states = self.layers[i](positions, hidden_states, input_metadata)
+
+         hidden_states = self.norm(hidden_states)
+         hidden_states.mul_(self.config.output_multiplier_scale)
+         return hidden_states
+
+
+ class Grok1ModelForCausalLM(nn.Module):
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config: Optional[CacheConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = Grok1Model(config, quant_config=quant_config)
+         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+         self.logits_processor = LogitsProcessor(config)
+
+         # Monkey patch _prepare_weights to load pre-sharded weights
+         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         input_metadata: InputMetadata,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+         return self.logits_processor(
+             input_ids, hidden_states, self.lm_head.weight, input_metadata
+         )
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+         ]
+
+         if use_fused:
+             expert_params_mapping = (
+                 [
+                     # These are the weight scales for the experts
+                     # (param_name, weight_name, expert_id)
+                     (
+                         "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+                         f"experts.{expert_id}.{weight_name}.weight_scale",
+                         expert_id,
+                     )
+                     for expert_id in range(self.config.num_local_experts)
+                     for weight_name in ["w1", "w2", "w3"]
+                 ]
+                 + [
+                     # These are the weights for the experts
+                     # (param_name, weight_name, expert_id)
+                     (
+                         "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
+                         f"experts.{expert_id}.{weight_name}.weight",
+                         expert_id,
+                     )
+                     for expert_id in range(self.config.num_local_experts)
+                     for weight_name in ["w1", "w2", "w3"]
+                 ]
+                 + [
+                     # These are the activation scales for the experts
+                     # (param_name, weight_name, expert_id)
+                     (
+                         "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
+                         f"experts.{expert_id}.{weight_name}.act_scale",
+                         expert_id,
+                     )
+                     for expert_id in range(self.config.num_local_experts)
+                     for weight_name in ["w1", "w2", "w3"]
+                 ]
+             )
+         else:
+             expert_params_mapping = []
+
+         params_dict = dict(self.named_parameters())
+         if get_tensor_model_parallel_rank() == 0:
+             weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
+         for name, loaded_weight in weights:
+             # print(get_tensor_model_parallel_rank(), name)
+             if "rotary_emb.inv_freq" in name:
+                 continue
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 for param_name, weight_name, expert_id in expert_params_mapping:
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+                     param = params_dict[name]
+                     weight_loader = param.weight_loader
+                     weight_loader(
+                         param,
+                         loaded_weight,
+                         weight_name,
+                         expert_id=expert_id,
+                         pre_sharded=get_tensor_model_parallel_world_size() > 1,
+                     )
+                     break
+                 else:
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+                     param = params_dict[name]
+                     weight_loader = getattr(
+                         param, "weight_loader", default_weight_loader
+                     )
+                     weight_loader(param, loaded_weight)
+
+
+ def all_close_1d(x: torch.Tensor) -> bool:
+     assert len(x.shape) == 1
+     return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
+
+
+ old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
+
+
+ def _prepare_presharded_weights(
+     self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
+ ) -> Tuple[str, List[str], bool]:
+     import glob
+     import os
+
+     if get_tensor_model_parallel_world_size() == 1:
+         return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)
+
+     tp_rank = get_tensor_model_parallel_rank()
+     allow_patterns = [f"*-{tp_rank:03d}.bin"]
+
+     hf_folder = model_name_or_path
+
+     hf_weights_files: List[str] = []
+     for pattern in allow_patterns:
+         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+     use_safetensors = False
+
+     return hf_folder, hf_weights_files, use_safetensors
+
+
+ EntryClass = Grok1ModelForCausalLM
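
For reference, the monkey-patched _prepare_presharded_weights above selects one weight file per tensor-parallel rank purely by filename suffix. A minimal sketch of that matching follows; the file names are hypothetical, not taken from the package.

import fnmatch

# Hypothetical checkpoint directory for a tensor-parallel size of 2.
files = ["model-000.bin", "model-001.bin", "config.json"]

tp_rank = 0
pattern = f"*-{tp_rank:03d}.bin"         # same pattern as allow_patterns above
print(fnmatch.filter(files, pattern))    # ['model-000.bin']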