sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +30 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/backend/runtime_endpoint.py +18 -14
  6. sglang/bench_latency.py +317 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +41 -6
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +6 -2
  11. sglang/lang/ir.py +74 -28
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +14 -6
  15. sglang/srt/constrained/fsm_cache.py +6 -3
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +2 -0
  19. sglang/srt/hf_transformers_utils.py +68 -9
  20. sglang/srt/layers/extend_attention.py +2 -1
  21. sglang/srt/layers/fused_moe.py +280 -169
  22. sglang/srt/layers/logits_processor.py +106 -42
  23. sglang/srt/layers/radix_attention.py +53 -29
  24. sglang/srt/layers/token_attention.py +4 -1
  25. sglang/srt/managers/controller/dp_worker.py +6 -3
  26. sglang/srt/managers/controller/infer_batch.py +144 -69
  27. sglang/srt/managers/controller/manager_multi.py +5 -5
  28. sglang/srt/managers/controller/manager_single.py +9 -4
  29. sglang/srt/managers/controller/model_runner.py +167 -55
  30. sglang/srt/managers/controller/radix_cache.py +4 -0
  31. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  32. sglang/srt/managers/controller/tp_worker.py +156 -134
  33. sglang/srt/managers/detokenizer_manager.py +19 -21
  34. sglang/srt/managers/io_struct.py +11 -5
  35. sglang/srt/managers/tokenizer_manager.py +16 -14
  36. sglang/srt/model_config.py +89 -4
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +2 -2
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/gemma.py +5 -1
  41. sglang/srt/models/gemma2.py +436 -0
  42. sglang/srt/models/grok.py +204 -137
  43. sglang/srt/models/llama2.py +12 -5
  44. sglang/srt/models/llama_classification.py +107 -0
  45. sglang/srt/models/llava.py +11 -8
  46. sglang/srt/models/llavavid.py +1 -1
  47. sglang/srt/models/minicpm.py +373 -0
  48. sglang/srt/models/mixtral.py +164 -115
  49. sglang/srt/models/mixtral_quant.py +0 -1
  50. sglang/srt/models/qwen.py +1 -1
  51. sglang/srt/models/qwen2.py +1 -1
  52. sglang/srt/models/qwen2_moe.py +454 -0
  53. sglang/srt/models/stablelm.py +1 -1
  54. sglang/srt/models/yivl.py +2 -2
  55. sglang/srt/openai_api_adapter.py +35 -25
  56. sglang/srt/openai_protocol.py +2 -2
  57. sglang/srt/server.py +69 -19
  58. sglang/srt/server_args.py +76 -43
  59. sglang/srt/utils.py +177 -35
  60. sglang/test/test_programs.py +28 -10
  61. sglang/utils.py +4 -3
  62. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
  63. sglang-0.1.19.dist-info/RECORD +81 -0
  64. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
  65. sglang/srt/managers/router/infer_batch.py +0 -596
  66. sglang/srt/managers/router/manager.py +0 -82
  67. sglang/srt/managers/router/model_rpc.py +0 -818
  68. sglang/srt/managers/router/model_runner.py +0 -445
  69. sglang/srt/managers/router/radix_cache.py +0 -267
  70. sglang/srt/managers/router/scheduler.py +0 -59
  71. sglang-0.1.17.dist-info/RECORD +0 -81
  72. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  73. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/models/grok.py CHANGED
@@ -1,7 +1,7 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
-from typing import Iterable, Optional, Tuple, List
+from typing import Iterable, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -9,7 +9,6 @@ import torch.nn.functional as F
 import tqdm
 from torch import nn
 from transformers import PretrainedConfig
-
 from vllm import _custom_ops as ops
 from vllm.config import CacheConfig
 from vllm.distributed import (
@@ -35,12 +34,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.utils import print_warning_once
 
-from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.fused_moe import fused_moe
+from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata
 
-
 use_fused = True
 
 
@@ -134,9 +132,12 @@ class Grok1MoEUnfused(nn.Module):
 
         final_hidden_states = torch.zeros(
             (hidden_states.shape[0], hidden_dim),
-            dtype=hidden_states.dtype, device=hidden_states.device
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
         )
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_total_experts).permute(2, 1, 0)
+        expert_mask = torch.nn.functional.one_hot(
+            selected_experts, num_classes=self.num_total_experts
+        ).permute(2, 1, 0)
 
         for expert_idx in self.expert_indicies:
             expert_layer = self.experts[expert_idx]
@@ -153,7 +154,10 @@ class Grok1MoEUnfused(nn.Module):
             # the current expert. We need to make sure to multiply the output hidden
             # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
             current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
-            current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None]
+            current_hidden_states = (
+                expert_layer(current_state)
+                * routing_weights[top_x_list, idx_list, None]
+            )
 
             # However `index_add_` only support torch tensors for indexing so we'll use
             # the `top_x` tensor here.
@@ -198,32 +202,46 @@ class Grok1MoE(nn.Module):
         self.params_dtype = params_dtype
 
         # Gate always runs at half / full precision for now.
-        self.gate = ReplicatedLinear(self.hidden_size,
-                                     self.num_total_experts,
-                                     bias=False,
-                                     params_dtype=self.params_dtype,
-                                     quant_config=None)
+        self.gate = ReplicatedLinear(
+            self.hidden_size,
+            self.num_total_experts,
+            bias=False,
+            params_dtype=self.params_dtype,
+            quant_config=None,
+        )
 
         if self.use_fp8 and self.quant_config.is_checkpoint_fp8_serialized:
             params_dtype = torch.float8_e4m3fn
 
         self.w13_weight = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        2 * self.intermediate_size,
-                        self.hidden_size,
-                        dtype=params_dtype))
+            torch.empty(
+                self.num_total_experts,
+                2 * self.intermediate_size,
+                self.hidden_size,
+                dtype=params_dtype,
+            )
+        )
         self.w2_weight = nn.Parameter(
-            torch.empty(self.num_total_experts,
-                        self.hidden_size,
-                        self.intermediate_size,
-                        dtype=params_dtype))
-
-        set_weight_attrs(self.w13_weight, {
-            "weight_loader": self.weight_loader,
-        })
-        set_weight_attrs(self.w2_weight, {
-            "weight_loader": self.weight_loader,
-        })
+            torch.empty(
+                self.num_total_experts,
+                self.hidden_size,
+                self.intermediate_size,
+                dtype=params_dtype,
+            )
+        )
+
+        set_weight_attrs(
+            self.w13_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
+        set_weight_attrs(
+            self.w2_weight,
+            {
+                "weight_loader": self.weight_loader,
+            },
+        )
 
         # Used for fp8.
         self.w13_scale = None
@@ -233,46 +251,69 @@ class Grok1MoE(nn.Module):
 
         if self.use_fp8:
             # WEIGHT_SCALE (for fp8)
-            self.w13_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                     dtype=torch.float32),
-                                          requires_grad=False)
-            self.w2_scale = nn.Parameter(torch.ones(self.num_total_experts,
-                                                    dtype=torch.float32),
-                                         requires_grad=False)
+            self.w13_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )
+            self.w2_scale = nn.Parameter(
+                torch.ones(self.num_total_experts, dtype=torch.float32),
+                requires_grad=False,
+            )
 
             # If loading fp8 checkpoint, pass the weight loaders.
             # If loading an fp16 checkpoint, do not (we will quantize in
             # process_weights_after_loading()
             if quant_config.is_checkpoint_fp8_serialized:
-                set_weight_attrs(self.w13_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-                set_weight_attrs(self.w2_scale, {
-                    "weight_loader": self.weight_loader,
-                })
+                set_weight_attrs(
+                    self.w13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.w2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
 
             # ACT_SCALE (for fp8)
             if quant_config.activation_scheme == "static":
                 if not quant_config.is_checkpoint_fp8_serialized:
                     raise ValueError(
                         "Found static activation scheme for checkpoint that "
-                        "was not serialized fp8.")
-                self.a13_scale = nn.Parameter(torch.zeros(
-                    self.num_total_experts, dtype=torch.float32),
-                                              requires_grad=False)
-                self.a2_scale = nn.Parameter(torch.zeros(
-                    self.num_total_experts, dtype=torch.float32),
-                                             requires_grad=False)
-
-                set_weight_attrs(self.a13_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-                set_weight_attrs(self.a2_scale, {
-                    "weight_loader": self.weight_loader,
-                })
-
-    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
-                      weight_name: str, expert_id: int, pre_sharded: bool):
+                        "was not serialized fp8."
+                    )
+                self.a13_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+                self.a2_scale = nn.Parameter(
+                    torch.zeros(self.num_total_experts, dtype=torch.float32),
+                    requires_grad=False,
+                )
+
+                set_weight_attrs(
+                    self.a13_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+                set_weight_attrs(
+                    self.a2_scale,
+                    {
+                        "weight_loader": self.weight_loader,
+                    },
+                )
+
+    def weight_loader(
+        self,
+        param: nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        expert_id: int,
+        pre_sharded: bool,
+    ):
         param_data = param.data
         shard_size = self.intermediate_size
         if pre_sharded:
@@ -284,8 +325,9 @@ class Grok1MoE(nn.Module):
         if weight_name.endswith("w1.weight"):
             param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
         if weight_name.endswith("w3.weight"):
-            param_data[expert_id,
-                       shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+            param_data[expert_id, shard_size : 2 * shard_size, :] = loaded_weight[
+                shard, :
+            ]
         if weight_name.endswith("w2.weight"):
             param_data[expert_id, :, :] = loaded_weight[:, shard]
         if "act_scale" in weight_name or "weight_scale" in weight_name:
@@ -298,17 +340,17 @@ class Grok1MoE(nn.Module):
 
         # If checkpoint is fp16, quantize here.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(self.w13_weight.data,
-                                          dtype=torch.float8_e4m3fn)
-            w2_weight = torch.empty_like(self.w2_weight.data,
-                                         dtype=torch.float8_e4m3fn)
+            w13_weight = torch.empty_like(
+                self.w13_weight.data, dtype=torch.float8_e4m3fn
+            )
+            w2_weight = torch.empty_like(self.w2_weight.data, dtype=torch.float8_e4m3fn)
             for expert in range(self.num_total_experts):
-                w13_weight[expert, :, :], self.w13_scale[
-                    expert] = ops.scaled_fp8_quant(
-                        self.w13_weight.data[expert, :, :])
-                w2_weight[expert, :, :], self.w2_scale[
-                    expert] = ops.scaled_fp8_quant(
-                        self.w2_weight.data[expert, :, :])
+                w13_weight[expert, :, :], self.w13_scale[expert] = ops.scaled_fp8_quant(
+                    self.w13_weight.data[expert, :, :]
+                )
+                w2_weight[expert, :, :], self.w2_scale[expert] = ops.scaled_fp8_quant(
+                    self.w2_weight.data[expert, :, :]
+                )
             self.w13_weight = nn.Parameter(w13_weight, requires_grad=False)
             self.w2_weight = nn.Parameter(w2_weight, requires_grad=False)
 
@@ -319,40 +361,40 @@ class Grok1MoE(nn.Module):
             if self.a13_scale is None or self.a2_scale is None:
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
-                    "activation scales are None.")
+                    "activation scales are None."
+                )
 
-            if (not all_close_1d(self.a13_scale)
-                    or not all_close_1d(self.a2_scale)):
+            if not all_close_1d(self.a13_scale) or not all_close_1d(self.a2_scale):
                 print_warning_once(
                     "Found act_scales that are not equal for fp8 MoE layer. "
-                    "Using the maximum across experts for each layer. ")
+                    "Using the maximum across experts for each layer. "
+                )
 
-            self.a13_scale = nn.Parameter(self.a13_scale.max(),
-                                          requires_grad=False)
-            self.a2_scale = nn.Parameter(self.a2_scale.max(),
-                                         requires_grad=False)
+            self.a13_scale = nn.Parameter(self.a13_scale.max(), requires_grad=False)
+            self.a2_scale = nn.Parameter(self.a2_scale.max(), requires_grad=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         num_tokens, hidden_size = hidden_states.shape
         hidden_states = hidden_states.view(-1, self.hidden_size)
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        final_hidden_states = fused_moe(hidden_states,
-                                        self.w13_weight,
-                                        self.w2_weight,
-                                        router_logits,
-                                        self.top_k,
-                                        renormalize=False,
-                                        inplace=True,
-                                        use_fp8=self.use_fp8,
-                                        w1_scale=self.w13_scale,
-                                        w2_scale=self.w2_scale,
-                                        a1_scale=self.a13_scale,
-                                        a2_scale=self.a2_scale)
+        final_hidden_states = fused_moe(
+            hidden_states,
+            self.w13_weight,
+            self.w2_weight,
+            router_logits,
+            self.top_k,
+            renormalize=False,
+            inplace=True,
+            use_fp8=self.use_fp8,
+            w1_scale=self.w13_scale,
+            w2_scale=self.w2_scale,
+            a1_scale=self.a13_scale,
+            a2_scale=self.a2_scale,
+        )
 
         if self.tp_size > 1:
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
 
         return final_hidden_states.view(num_tokens, hidden_size)
 
@@ -462,10 +504,12 @@ class Grok1DecoderLayer(nn.Module):
                 top_k=config.num_experts_per_tok,
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
-                quant_config=quant_config)
+                quant_config=quant_config,
+            )
         else:
             self.block_sparse_moe = Grok1MoEUnfused(
-                config=config, quant_config=quant_config)
+                config=config, quant_config=quant_config
+            )
         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -477,13 +521,21 @@ class Grok1DecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
+        hidden_states = (
+            self.post_attn_norm(
+                self.self_attn(
+                    positions=positions,
+                    hidden_states=self.pre_attn_norm(hidden_states),
+                    input_metadata=input_metadata,
+                )
+            )
+            + hidden_states
+        )
 
-        hidden_states = self.post_attn_norm(self.self_attn(
-            positions=positions, hidden_states=self.pre_attn_norm(hidden_states),
-            input_metadata=input_metadata,
-        )) + hidden_states
-
-        hidden_states = self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states))) + hidden_states
+        hidden_states = (
+            self.post_moe_norm(self.block_sparse_moe(self.pre_moe_norm(hidden_states)))
+            + hidden_states
+        )
 
         return hidden_states
 
@@ -525,9 +577,7 @@ class Grok1Model(nn.Module):
         hidden_states.mul_(self.config.embedding_multiplier_scale)
 
         for i in range(len(self.layers)):
-            hidden_states = self.layers[i](
-                positions, hidden_states, input_metadata
-            )
+            hidden_states = self.layers[i](positions, hidden_states, input_metadata)
 
         hidden_states = self.norm(hidden_states)
         hidden_states.mul_(self.config.output_multiplier_scale)
@@ -572,28 +622,41 @@ class Grok1ModelForCausalLM(nn.Module):
         ]
 
         if use_fused:
-            expert_params_mapping = [
-                # These are the weight scales for the experts
-                # (param_name, weight_name, expert_id)
-                ("w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
-                 f"experts.{expert_id}.{weight_name}.weight_scale", expert_id)
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ] + [
-                # These are the weights for the experts
-                # (param_name, weight_name, expert_id)
-                ("w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
-                 f"experts.{expert_id}.{weight_name}.weight", expert_id)
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ] + [
-                # These are the activation scales for the experts
-                # (param_name, weight_name, expert_id)
-                ("a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
-                 f"experts.{expert_id}.{weight_name}.act_scale", expert_id)
-                for expert_id in range(self.config.num_local_experts)
-                for weight_name in ["w1", "w2", "w3"]
-            ]
+            expert_params_mapping = (
+                [
+                    # These are the weight scales for the experts
+                    # (param_name, weight_name, expert_id)
+                    (
+                        "w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
+                        f"experts.{expert_id}.{weight_name}.weight_scale",
+                        expert_id,
+                    )
+                    for expert_id in range(self.config.num_local_experts)
+                    for weight_name in ["w1", "w2", "w3"]
+                ]
+                + [
+                    # These are the weights for the experts
+                    # (param_name, weight_name, expert_id)
+                    (
+                        "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
+                        f"experts.{expert_id}.{weight_name}.weight",
+                        expert_id,
+                    )
+                    for expert_id in range(self.config.num_local_experts)
+                    for weight_name in ["w1", "w2", "w3"]
+                ]
+                + [
+                    # These are the activation scales for the experts
+                    # (param_name, weight_name, expert_id)
+                    (
+                        "a13_scale" if weight_name in ["w1", "w3"] else "a2_scale",
+                        f"experts.{expert_id}.{weight_name}.act_scale",
+                        expert_id,
+                    )
+                    for expert_id in range(self.config.num_local_experts)
+                    for weight_name in ["w1", "w2", "w3"]
+                ]
+            )
         else:
             expert_params_mapping = []
 
@@ -601,11 +664,11 @@ class Grok1ModelForCausalLM(nn.Module):
         if get_tensor_model_parallel_rank() == 0:
             weights = tqdm.tqdm(weights, total=int(len(params_dict) * 3.4))
         for name, loaded_weight in weights:
-            #print(get_tensor_model_parallel_rank(), name)
+            # print(get_tensor_model_parallel_rank(), name)
             if "rotary_emb.inv_freq" in name:
                 continue
 
-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -623,19 +686,22 @@ class Grok1ModelForCausalLM(nn.Module):
                     name = name.replace(weight_name, param_name)
                     param = params_dict[name]
                     weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  weight_name,
-                                  expert_id=expert_id,
-                                  pre_sharded=get_tensor_model_parallel_world_size() > 1)
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        weight_name,
+                        expert_id=expert_id,
+                        pre_sharded=get_tensor_model_parallel_world_size() > 1,
+                    )
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
                     param = params_dict[name]
-                    weight_loader = getattr(param, "weight_loader",
-                                            default_weight_loader)
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
                     weight_loader(param, loaded_weight)
 
 
@@ -645,10 +711,11 @@ def all_close_1d(x: torch.Tensor) -> bool:
 
 
 old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
-def _prepare_presharded_weights(self,
-                                model_name_or_path: str,
-                                revision: Optional[str],
-                                fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
+
+
+def _prepare_presharded_weights(
+    self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
+) -> Tuple[str, List[str], bool]:
     import glob
     import os
 
@@ -668,4 +735,4 @@ def _prepare_presharded_weights(self,
     return hf_folder, hf_weights_files, use_safetensors
 
 
-EntryClass = Grok1ModelForCausalLM
+EntryClass = Grok1ModelForCausalLM
sglang/srt/models/llama2.py CHANGED
@@ -1,7 +1,8 @@
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Optional, Tuple, Iterable
+
+from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 import tqdm
@@ -10,7 +11,7 @@ from transformers import LlamaConfig
 from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size
+    get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -76,6 +77,7 @@ class LlamaAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -123,6 +125,7 @@ class LlamaAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -158,9 +161,12 @@ class LlamaDecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         if rope_scaling is not None and getattr(
-                config, "original_max_position_embeddings", None):
-            rope_scaling["original_max_position_embeddings"] = (
-                config.original_max_position_embeddings)
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling[
+                "original_max_position_embeddings"
+            ] = config.original_max_position_embeddings
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
@@ -169,6 +175,7 @@ class LlamaDecoderLayer(nn.Module):
             layer_id=layer_id,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
         )
sglang/srt/models/llama_classification.py ADDED
@@ -0,0 +1,107 @@
+from typing import Iterable, Optional, Tuple
+
+import torch
+import tqdm
+from torch import nn
+from transformers import LlamaConfig
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_rank
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.managers.controller.model_runner import InputMetadata
+from sglang.srt.models.llama2 import LlamaModel
+
+
+class LlamaForClassification(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = LlamaModel(config, quant_config=quant_config)
+
+        self.classification_head = nn.Linear(
+            config.hidden_size, config.classification_out_size
+        )
+        self.eos_token_id = config.eos_token_id
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        is_eos_token = input_ids == self.eos_token_id
+        hidden_states = hidden_states[is_eos_token]
+        scores = self.classification_head(hidden_states)
+
+        if scores.shape[0] != input_metadata.batch_size:
+            print("Warning: the EOS tokens are missing in some sentences.")
+            scores = torch.ones(
+                (input_metadata.batch_size, self.config.classification_out_size)
+            ).to(input_ids.device)
+
+        return LogitProcessorOutput(
+            next_token_logits=scores,
+            next_token_logprobs=scores,
+            normalized_prompt_logprobs=scores,
+            prefill_token_logprobs=torch.ones_like(input_ids),
+            prefill_top_logprobs=None,
+            decode_top_logprobs=None,
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        if get_tensor_model_parallel_rank() == 0:
+            weights = tqdm.tqdm(weights, total=int(len(params_dict) * 1.5))
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if "lm_head" in name:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name.startswith("model.vision_tower") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = LlamaForClassification