sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (110)
  1. sglang/bench_offline_throughput.py +55 -2
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +4 -3
  4. sglang/bench_serving.py +13 -0
  5. sglang/check_env.py +1 -1
  6. sglang/launch_server.py +3 -2
  7. sglang/srt/_custom_ops.py +118 -0
  8. sglang/srt/configs/device_config.py +17 -0
  9. sglang/srt/configs/load_config.py +84 -0
  10. sglang/srt/configs/model_config.py +161 -4
  11. sglang/srt/configs/qwen2vl.py +5 -8
  12. sglang/srt/constrained/outlines_backend.py +6 -1
  13. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  14. sglang/srt/distributed/__init__.py +3 -0
  15. sglang/srt/distributed/communication_op.py +34 -0
  16. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  17. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  19. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  20. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  21. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  22. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  24. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  25. sglang/srt/distributed/parallel_state.py +1275 -0
  26. sglang/srt/distributed/utils.py +223 -0
  27. sglang/srt/hf_transformers_utils.py +37 -1
  28. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  29. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  30. sglang/srt/layers/fused_moe_patch.py +20 -11
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/logits_processor.py +17 -3
  33. sglang/srt/layers/quantization/__init__.py +34 -0
  34. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  35. sglang/srt/lora/lora.py +1 -1
  36. sglang/srt/managers/data_parallel_controller.py +7 -11
  37. sglang/srt/managers/detokenizer_manager.py +7 -4
  38. sglang/srt/managers/image_processor.py +1 -1
  39. sglang/srt/managers/io_struct.py +48 -12
  40. sglang/srt/managers/schedule_batch.py +42 -36
  41. sglang/srt/managers/schedule_policy.py +7 -4
  42. sglang/srt/managers/scheduler.py +111 -46
  43. sglang/srt/managers/session_controller.py +0 -3
  44. sglang/srt/managers/tokenizer_manager.py +169 -100
  45. sglang/srt/managers/tp_worker.py +36 -3
  46. sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
  47. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  48. sglang/srt/model_executor/forward_batch_info.py +9 -4
  49. sglang/srt/model_executor/model_runner.py +136 -150
  50. sglang/srt/model_loader/__init__.py +34 -0
  51. sglang/srt/model_loader/loader.py +1139 -0
  52. sglang/srt/model_loader/utils.py +41 -0
  53. sglang/srt/model_loader/weight_utils.py +640 -0
  54. sglang/srt/models/baichuan.py +9 -10
  55. sglang/srt/models/chatglm.py +6 -15
  56. sglang/srt/models/commandr.py +2 -3
  57. sglang/srt/models/dbrx.py +2 -3
  58. sglang/srt/models/deepseek.py +4 -11
  59. sglang/srt/models/deepseek_v2.py +3 -11
  60. sglang/srt/models/exaone.py +2 -3
  61. sglang/srt/models/gemma.py +2 -6
  62. sglang/srt/models/gemma2.py +3 -14
  63. sglang/srt/models/gemma2_reward.py +0 -1
  64. sglang/srt/models/gpt2.py +5 -12
  65. sglang/srt/models/gpt_bigcode.py +6 -22
  66. sglang/srt/models/grok.py +14 -51
  67. sglang/srt/models/internlm2.py +2 -3
  68. sglang/srt/models/internlm2_reward.py +0 -1
  69. sglang/srt/models/llama.py +97 -27
  70. sglang/srt/models/llama_classification.py +1 -2
  71. sglang/srt/models/llama_embedding.py +1 -2
  72. sglang/srt/models/llama_reward.py +2 -3
  73. sglang/srt/models/llava.py +10 -12
  74. sglang/srt/models/llavavid.py +1 -2
  75. sglang/srt/models/minicpm.py +4 -7
  76. sglang/srt/models/minicpm3.py +6 -19
  77. sglang/srt/models/mixtral.py +12 -5
  78. sglang/srt/models/mixtral_quant.py +2 -3
  79. sglang/srt/models/mllama.py +3 -7
  80. sglang/srt/models/olmo.py +2 -8
  81. sglang/srt/models/olmo2.py +391 -0
  82. sglang/srt/models/olmoe.py +3 -5
  83. sglang/srt/models/phi3_small.py +8 -8
  84. sglang/srt/models/qwen.py +2 -3
  85. sglang/srt/models/qwen2.py +10 -9
  86. sglang/srt/models/qwen2_moe.py +4 -11
  87. sglang/srt/models/qwen2_vl.py +12 -9
  88. sglang/srt/models/registry.py +99 -0
  89. sglang/srt/models/stablelm.py +2 -3
  90. sglang/srt/models/torch_native_llama.py +6 -12
  91. sglang/srt/models/xverse.py +2 -4
  92. sglang/srt/models/xverse_moe.py +4 -11
  93. sglang/srt/models/yivl.py +2 -3
  94. sglang/srt/openai_api/adapter.py +10 -6
  95. sglang/srt/openai_api/protocol.py +1 -0
  96. sglang/srt/server.py +303 -204
  97. sglang/srt/server_args.py +65 -31
  98. sglang/srt/utils.py +253 -48
  99. sglang/test/test_utils.py +27 -7
  100. sglang/utils.py +2 -2
  101. sglang/version.py +1 -1
  102. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
  103. sglang-0.4.0.dist-info/RECORD +184 -0
  104. sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
  105. sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
  106. sglang/srt/layers/fused_moe_grok/layer.py +0 -630
  107. sglang-0.3.6.post2.dist-info/RECORD +0 -164
  108. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  109. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  110. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
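The hunks below repeat the same three mechanical changes across the model files: default_weight_loader is now imported from sglang's own sglang.srt.model_loader.weight_utils (a module added in this release) instead of from vllm, the unused cache_config constructor argument is dropped, and the logits processor is called with the lm_head module (or the tied embed_tokens module) rather than with its .weight tensor. The following is a minimal, self-contained sketch of that last calling-convention change; the toy classes are illustrative stand-ins, not sglang's actual LogitsProcessor or ParallelLMHead.

# Minimal sketch of the logits-processor change repeated in the hunks below.
# ToyLogitsProcessor / ToyModelForCausalLM are stand-ins, not sglang classes.
import torch
from torch import nn


class ToyLogitsProcessor(nn.Module):
    # 0.4.0 style: receives the head *module*, so it can read the weight
    # (and, in the real code, any quantization state) from the module itself.
    def forward(self, hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
        return hidden_states @ lm_head.weight.t()


class ToyModelForCausalLM(nn.Module):
    def __init__(self, hidden_size=16, vocab_size=32, tie_word_embeddings=True):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.tie_word_embeddings = tie_word_embeddings
        if not tie_word_embeddings:
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.logits_processor = ToyLogitsProcessor()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Old convention: lm_head_weight = self.lm_head.weight
        #                 return self.logits_processor(..., lm_head_weight, ...)
        # New convention (as in the hunks below): pass the module itself.
        lm_head = self.embed_tokens if self.tie_word_embeddings else self.lm_head
        return self.logits_processor(hidden_states, lm_head)


if __name__ == "__main__":
    print(ToyModelForCausalLM()(torch.randn(2, 16)).shape)  # torch.Size([2, 32])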
sglang/srt/models/minicpm.py CHANGED
@@ -20,7 +20,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -37,6 +36,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class MiniCPMMLP(nn.Module):
@@ -275,7 +275,6 @@ class MiniCPMForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -308,12 +307,10 @@ class MiniCPMForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/minicpm3.py CHANGED
@@ -27,7 +27,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import is_flashinfer_available
 
 if is_flashinfer_available():
@@ -105,7 +105,6 @@ class MiniCPM3Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -249,7 +248,6 @@ class MiniCPM3AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -406,7 +404,6 @@ class MiniCPM3DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -430,7 +427,6 @@ class MiniCPM3DecoderLayer(nn.Module):
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
             )
@@ -449,7 +445,6 @@ class MiniCPM3DecoderLayer(nn.Module):
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
                 max_position_embeddings=max_position_embeddings,
-                cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
             )
@@ -498,7 +493,6 @@ class MiniCPM3Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
@@ -512,9 +506,7 @@ class MiniCPM3Model(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                MiniCPM3DecoderLayer(
-                    config, i, cache_config=cache_config, quant_config=quant_config
-                )
+                MiniCPM3DecoderLayer(config, i, quant_config=quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -549,7 +541,6 @@ class MiniCPM3ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -557,9 +548,7 @@ class MiniCPM3ForCausalLM(nn.Module):
 
         self.num_experts = getattr(self.config, "num_experts", 0)
         self.quant_config = quant_config
-        self.model = MiniCPM3Model(
-            config, cache_config=cache_config, quant_config=quant_config
-        )
+        self.model = MiniCPM3Model(config, quant_config=quant_config)
         # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         if not self.config.tie_word_embeddings:
             self.lm_head = ParallelLMHead(
@@ -585,12 +574,10 @@ class MiniCPM3ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
-            lm_head_weight = self.model.embed_tokens.weight
+            lm_head = self.model.embed_tokens
         else:
-            lm_head_weight = self.lm_head.weight
-        return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, forward_batch
-        )
+            lm_head = self.lm_head
+        return self.logits_processor(input_ids, hidden_states, lm_head, forward_batch)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
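For the minicpm.py and minicpm3.py hunks above (and the files that follow), the import change is purely a relocation: default_weight_loader now lives in the new sglang.srt.model_loader.weight_utils module (listed above with +640 lines) rather than in vllm.model_executor.model_loader.weight_utils. The sketch below shows roughly how such a loader is used inside a load_weights method; toy_default_weight_loader is an assumption-level stand-in (a plain tensor copy), not the actual sglang implementation.

# Rough sketch of how a default_weight_loader-style helper is used by the
# load_weights methods in these hunks. Toy code, not sglang's weight_utils.
from typing import Iterable, Tuple

import torch
from torch import nn


def toy_default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
    # Copy the checkpoint tensor into the model parameter in place.
    assert param.data.shape == loaded_weight.shape
    param.data.copy_(loaded_weight)


def load_weights(model: nn.Module, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        if name not in params_dict:  # e.g. skipped extra biases, as in the hunks
            continue
        param = params_dict[name]
        # Real modules may attach a custom `weight_loader` attribute; fall back
        # to the default copy otherwise.
        weight_loader = getattr(param, "weight_loader", toy_default_weight_loader)
        weight_loader(param, loaded_weight)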
sglang/srt/models/mixtral.py CHANGED
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import MixtralConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
@@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class MixtralMoE(nn.Module):
@@ -291,7 +291,6 @@ class MixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -310,7 +309,7 @@ class MixtralForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -340,7 +339,9 @@ class MixtralForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if (
+                    name.endswith(".bias") or name.endswith("_bias")
+                ) and name not in params_dict:
                     continue
 
                 param = params_dict[name]
@@ -354,6 +355,10 @@ class MixtralForCausalLM(nn.Module):
                         continue
                     name = name.replace(weight_name, param_name)
 
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
@@ -366,7 +371,9 @@ class MixtralForCausalLM(nn.Module):
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
+                    if (
+                        name.endswith(".bias") or name.endswith("_bias")
+                    ) and name not in params_dict:
                         continue
                     # Skip loading kv_scale from ckpts towards new design.
                     if name.endswith(".kv_scale") and name not in params_dict:
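The three Mixtral load_weights hunks above extend the GPTQ bias skip so that names ending in _bias (in addition to .bias) are ignored when absent from params_dict, and the same guard is now also applied inside the expert-weights branch. Below, the condition is factored into a helper for readability; should_skip_bias is a name introduced here, the diff inlines the expression at each call site.

# The skip condition added in the Mixtral hunks above, pulled out into a helper.
from typing import Dict

import torch


def should_skip_bias(name: str, params_dict: Dict[str, torch.nn.Parameter]) -> bool:
    # Skip loading extra bias for GPTQ models: both ".bias" and "_bias"
    # suffixed names are ignored when the model has no matching parameter.
    return (name.endswith(".bias") or name.endswith("_bias")) and name not in params_dict


if __name__ == "__main__":
    params = {"model.layers.0.self_attn.qkv_proj.weight": torch.nn.Parameter(torch.empty(1))}
    print(should_skip_bias("model.layers.0.self_attn.qkv_proj.bias", params))    # True
    print(should_skip_bias("model.layers.0.self_attn.qkv_proj_bias", params))    # True
    print(should_skip_bias("model.layers.0.self_attn.qkv_proj.weight", params))  # False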
sglang/srt/models/mixtral_quant.py CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -45,6 +44,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class MixtralMLP(nn.Module):
@@ -324,7 +324,6 @@ class QuantMixtralForCausalLM(nn.Module):
         self,
         config: MixtralConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -343,7 +342,7 @@ class QuantMixtralForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/mllama.py CHANGED
@@ -15,7 +15,6 @@ from transformers.models.mllama.modeling_mllama import (
     _prepare_aspect_ratio_attention_mask,
 )
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.layernorm import RMSNorm
@@ -34,6 +33,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
 
 
@@ -654,7 +654,6 @@ class MllamaTextModel(nn.Module):
         self,
         config: config_mllama.MllamaTextConfig,
         quant_config: Optional[QuantizationConfig],
-        cache_config=None,
     ):
         super().__init__()
         self.padding_id = config.pad_token_id
@@ -732,11 +731,10 @@ class MllamaForCausalLM(nn.Module):
         self,
         config: config_mllama.MllamaTextConfig,
         quant_config: Optional[QuantizationConfig],
-        cache_config=None,
     ):
         super().__init__()
         self.vocab_size = config.vocab_size
-        self.model = MllamaTextModel(config, cache_config, quant_config)
+        self.model = MllamaTextModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size,
             config.hidden_size,
@@ -772,7 +770,6 @@ class MllamaForConditionalGeneration(nn.Module):
         self,
         config: config_mllama.MllamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ):
         super().__init__()
         self.vocab_size = config.text_config.vocab_size
@@ -787,7 +784,6 @@ class MllamaForConditionalGeneration(nn.Module):
         self.vision_model = MllamaVisionModel(config.vision_config)
         self.language_model = MllamaForCausalLM(
             config.text_config,
-            cache_config=cache_config,
             quant_config=quant_config,
         )
         self.multi_modal_projector = nn.Linear(
@@ -966,7 +962,7 @@ class MllamaForConditionalGeneration(nn.Module):
             skip_cross_attention=skip_cross_attention,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.language_model.lm_head, forward_batch
        )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/olmo.py CHANGED
@@ -22,7 +22,6 @@ from torch import nn
 from transformers import OlmoConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
 
 
@@ -274,7 +274,6 @@ class OlmoForCausalLM(nn.Module):
     def __init__(
         self,
         config: OlmoConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -306,7 +305,7 @@ class OlmoForCausalLM(nn.Module):
             input_embeds=input_embeds,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -326,11 +325,6 @@ class OlmoForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
-            # With tie_word_embeddings, we can skip lm_head.weight
-            # The weight might appear unnecessarily in the files if the model is
-            # processed with quantization, LoRA, fine-tuning, etc.
-            if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
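Every load_weights method touched in these hunks keeps the same stacked_params_mapping pattern visible in the olmo.py context lines above: checkpoint names for the separate q/k/v and gate/up projections are remapped onto fused parameters plus a shard id before the weight loader runs. Below is a small sketch of that remapping; the mapping entries follow the usual Llama-family convention and should be read as an assumption, not a quote from this diff.

# Sketch of the stacked_params_mapping remapping used by these load_weights
# methods. The exact triples below are assumed, not taken from this diff.
STACKED_PARAMS_MAPPING = [
    # (param_name, weight_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def remap(name: str):
    # Returns (target_param_name, shard_id), or (name, None) if the weight
    # does not belong to a fused parameter.
    for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None


if __name__ == "__main__":
    print(remap("model.layers.0.self_attn.q_proj.weight"))
    # ('model.layers.0.self_attn.qkv_proj.weight', 'q')
    print(remap("model.layers.0.mlp.up_proj.weight"))
    # ('model.layers.0.mlp.gate_up_proj.weight', 1)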