sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/weight_utils.py
CHANGED
@@ -25,10 +25,10 @@ import filelock
 import gguf
 import huggingface_hub.constants
 import numpy as np
+import safetensors.torch
 import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
-from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm

 from sglang.srt.configs.load_config import LoadConfig
@@ -62,7 +62,6 @@ enable_hf_transfer()


 class DisabledTqdm(tqdm):
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs, disable=True)

@@ -121,7 +120,7 @@ def convert_bin_to_safetensor_file(
     )

     # check if the tensors are the same
-    reloaded = load_file(sf_filename)
+    reloaded = safetensors.torch.load_file(sf_filename)
     for k in loaded:
         pt_tensor = loaded[k]
         sf_tensor = reloaded[k]
@@ -133,7 +132,6 @@ def convert_bin_to_safetensor_file(
 def get_quant_config(
     model_config: ModelConfig, load_config: LoadConfig
 ) -> QuantizationConfig:
-
     quant_cls = get_quantization_config(model_config.quantization)

     # GGUF doesn't have config file
@@ -402,15 +400,34 @@ def np_cache_weights_iterator(
         yield name, torch.from_numpy(param)


+def decrypt(fn, key):
+    raise NotImplementedError()
+
+
+def safetensors_encrypted_weights_iterator(
+    hf_weights_files: List[str],
+    is_all_weights_sharded: bool = False,
+    decryption_key: Optional[str] = None,
+):
+    raise NotImplementedError()
+
+
 def safetensors_weights_iterator(
     hf_weights_files: List[str],
     is_all_weights_sharded: bool = False,
+    decryption_key: Optional[str] = None,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files.

     If is_all_weights_sharded is True, it uses more optimize read by reading an
     entire file instead of reading each tensor one by one.
     """
+    if decryption_key:
+        yield from safetensors_encrypted_weights_iterator(
+            hf_weights_files, is_all_weights_sharded, decryption_key
+        )
+        return
+
     enable_tqdm = (
         not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
     )
@@ -420,15 +437,9 @@ def safetensors_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        if not is_all_weights_sharded:
-            with safe_open(st_file, framework="pt") as f:
-                for name in f.keys():
-                    param = f.get_tensor(name)
-                    yield name, param
-        else:
-            result = load_file(st_file, device="cpu")
-            for name, param in result.items():
-                yield name, param
+        result = safetensors.torch.load_file(st_file, device="cpu")
+        for name, param in result.items():
+            yield name, param


 def pt_weights_iterator(
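Taken together, the two hunks above collapse the old safe_open / load_file split into a single safetensors.torch.load_file path and add a decryption_key parameter that short-circuits into safetensors_encrypted_weights_iterator, which is still a stub in this release. A minimal usage sketch, with placeholder shard names rather than a real checkpoint:

# Usage sketch only; the shard file names are placeholders.
from sglang.srt.model_loader.weight_utils import safetensors_weights_iterator

shards = ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]

# Default path: each shard is loaded whole on CPU and its tensors are yielded.
for name, tensor in safetensors_weights_iterator(shards):
    print(name, tuple(tensor.shape))

# Passing decryption_key now routes to safetensors_encrypted_weights_iterator,
# which raises NotImplementedError in 0.4.3.post4, so leave it unset for now.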
@@ -644,9 +655,20 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
         return remapped_name

     possible_scale_names = [".k_scale", ".v_scale"]
+    modelopt_scale_names = [".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"]
     for scale_name in possible_scale_names:
         if name.endswith(scale_name):
-            remapped_name = name.replace(scale_name, f".attn{scale_name}")
+            # Check and remap the name based on modelopt scale names
+            if any(
+                modelopt_scale_name in name
+                for modelopt_scale_name in modelopt_scale_names
+            ):
+                remapped_name = name.replace(
+                    f".self_attn.{scale_name[1]}_proj{scale_name}",
+                    f".self_attn.attn{scale_name}",
+                )
+            else:
+                remapped_name = name.replace(scale_name, f".attn{scale_name}")
             if remapped_name not in params_dict:
                 print_warning_once(
                     f"Found {scale_name} in the checkpoint (e.g. {name}), "
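For reference, a worked example of the two rename paths in the new maybe_remap_kv_scale_name branch; the checkpoint names are illustrative, not taken from a real model. In both cases the remapped name is then checked against params_dict, with the pre-existing warning path unchanged.

# Worked example (illustrative names only): both styles converge on ".attn.k_scale".
modelopt_name = "model.layers.0.self_attn.k_proj.k_scale"
legacy_name = "model.layers.0.self_attn.k_scale"
# New branch: ".self_attn.k_proj.k_scale" -> ".self_attn.attn.k_scale"
assert modelopt_name.replace(".self_attn.k_proj.k_scale", ".self_attn.attn.k_scale") \
    == "model.layers.0.self_attn.attn.k_scale"
# Old branch (unchanged): ".k_scale" -> ".attn.k_scale"
assert legacy_name.replace(".k_scale", ".attn.k_scale") \
    == "model.layers.0.self_attn.attn.k_scale"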
sglang/srt/models/baichuan.py
CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix


 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -80,13 +81,22 @@ class BaiChuanMLP(nn.Module):
         intermediate_size: int,
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
-            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -114,6 +124,7 @@ class BaiChuanAttention(nn.Module):
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id: int = 0,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -167,6 +178,7 @@ class BaiChuanAttention(nn.Module):
                 scaling,
                 num_kv_heads=self.num_kv_heads,
                 layer_id=layer_id,
+                prefix=add_prefix("attn", prefix),
             )
         else:
             self.rotary_emb = get_rope(
@@ -182,6 +194,7 @@ class BaiChuanAttention(nn.Module):
                 self.scaling,
                 num_kv_heads=self.num_kv_heads,
                 layer_id=layer_id,
+                prefix=add_prefix("attn", prefix),
             )

     def forward(
@@ -207,6 +220,7 @@ class BaiChuanDecoderLayer(nn.Module):
         position_embedding: str,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -220,12 +234,14 @@ class BaiChuanDecoderLayer(nn.Module):
             layer_id=layer_id,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )
         self.mlp = BaiChuanMLP(
             hidden_size=self.hidden_size,
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
             quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = RMSNorm(
@@ -264,6 +280,7 @@ class BaiChuanModel(nn.Module):
         config: PretrainedConfig,
         position_embedding: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -281,6 +298,7 @@ class BaiChuanModel(nn.Module):
                     layer_id=i,
                     position_embedding=position_embedding,
                     quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(config.num_hidden_layers)
             ]
@@ -330,18 +348,24 @@ class BaiChuanBaseForCausalLM(nn.Module):
         config: PretrainedConfig,
         position_embedding: str,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()

         self.config = config

         self.quant_config = quant_config
-        self.model = BaiChuanModel(config, position_embedding, quant_config)
+        self.model = BaiChuanModel(
+            config, position_embedding, quant_config, prefix=add_prefix("model", prefix)
+        )
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
             self.lm_head = ParallelLMHead(
-                config.vocab_size, config.hidden_size, quant_config=quant_config
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
             )
         self.logits_processor = LogitsProcessor(config)

@@ -404,11 +428,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         if config.hidden_size == 4096:  # baichuan2 7b
-            super().__init__(config, "ROPE", quant_config)
+            super().__init__(config, "ROPE", quant_config, prefix=prefix)
         else:  # baichuan 13b, baichuan2 13b
-            super().__init__(config, "ALIBI", quant_config)
+            super().__init__(config, "ALIBI", quant_config, prefix=prefix)


 EntryClass = [BaichuanForCausalLM]
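The pattern in this and the remaining model files is the same: every constructor gains a prefix argument and forwards it to its children via add_prefix, presumably so quantization configs and attention backends can identify submodules by their fully qualified parameter names. A minimal sketch that is consistent with how add_prefix is called in these hunks, not necessarily the actual sglang.srt.utils implementation:

# Sketch only: join a child name onto an optional parent prefix with a dot.
def add_prefix(name: str, prefix: str) -> str:
    return name if not prefix else f"{prefix}.{name}"

# Composed through BaiChuanBaseForCausalLM -> BaiChuanModel -> layer -> MLP,
# a submodule ends up with a name such as "model.layers.0.mlp.gate_up_proj":
assert add_prefix(
    "gate_up_proj", add_prefix("mlp", add_prefix("layers.0", add_prefix("model", "")))
) == "model.layers.0.mlp.gate_up_proj"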
sglang/srt/models/chatglm.py
CHANGED
@@ -41,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix

 LoraConfig = None

@@ -51,6 +52,7 @@ class GLMAttention(nn.Module):
         config,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -85,12 +87,14 @@ class GLMAttention(nn.Module):
             self.total_num_kv_heads,
             bias=config.add_bias_linear or config.add_qkv_bias,
             quant_config=quant_config,
+            prefix=add_prefix("query_key_value", prefix),
         )
         self.dense = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             config.hidden_size,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=add_prefix("dense", prefix),
         )

         # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141
@@ -109,6 +113,7 @@ class GLMAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -142,6 +147,7 @@ class GLMMLP(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()

@@ -153,6 +159,7 @@ class GLMMLP(nn.Module):
             [config.ffn_hidden_size] * 2,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=add_prefix("dense_h_to_4h", prefix),
         )

         self.activation_func = SiluAndMul()
@@ -163,6 +170,7 @@ class GLMMLP(nn.Module):
             config.hidden_size,
             bias=config.add_bias_linear,
             quant_config=quant_config,
+            prefix=add_prefix("dense_4h_to_h", prefix),
         )

     def forward(self, hidden_states):
@@ -186,6 +194,7 @@ class GLMBlock(nn.Module):
         config,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.apply_residual_connection_post_layernorm = (
@@ -201,7 +210,9 @@ class GLMBlock(nn.Module):
         )

         # Self attention.
-        self.self_attention = GLMAttention(config, layer_id, quant_config)
+        self.self_attention = GLMAttention(
+            config, layer_id, quant_config, prefix=add_prefix("self_attention", prefix)
+        )
         self.hidden_dropout = config.hidden_dropout

         # Layernorm on the attention output
@@ -210,7 +221,7 @@ class GLMBlock(nn.Module):
         )

         # MLP
-        self.mlp = GLMMLP(config, quant_config)
+        self.mlp = GLMMLP(config, quant_config, prefix=add_prefix("mlp", prefix))

     def forward(
         self,
@@ -257,6 +268,7 @@ class GLMTransformer(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.post_layer_norm = config.post_layer_norm
@@ -266,7 +278,15 @@ class GLMTransformer(nn.Module):

         # Transformer layers.
         self.layers = nn.ModuleList(
-            [GLMBlock(config, i, quant_config) for i in range(self.num_layers)]
+            [
+                GLMBlock(
+                    config,
+                    i,
+                    quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
+                for i in range(self.num_layers)
+            ]
         )

         if self.post_layer_norm:
@@ -301,19 +321,28 @@ class ChatGLMM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()

         self.embedding = VocabParallelEmbedding(
-            config.padded_vocab_size, config.hidden_size
+            config.padded_vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embedding", prefix),
         )

         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
         self.kv_channels = config.kv_channels
-        self.encoder = GLMTransformer(config, quant_config)
+        self.encoder = GLMTransformer(
+            config, quant_config, add_prefix("encoder", prefix)
+        )

-        self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)
+        self.output_layer = ParallelLMHead(
+            config.padded_vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("output_layer", prefix),
+        )

     def forward(
         self,
@@ -351,12 +380,15 @@ class ChatGLMForCausalLM(nn.Module):
         self,
         config: ChatGLMConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config: ChatGLMConfig = config
         self.quant_config = quant_config
         self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
-        self.transformer = ChatGLMM(config, quant_config)
+        self.transformer = ChatGLMM(
+            config, quant_config, prefix=add_prefix("transformer", prefix)
+        )
         self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config)

sglang/srt/models/commandr.py
CHANGED
@@ -65,7 +65,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
 )
-from sglang.srt.utils import get_compiler_backend, set_weight_attrs
+from sglang.srt.utils import add_prefix, get_compiler_backend, set_weight_attrs


 @torch.compile(backend=get_compiler_backend())
@@ -110,6 +110,7 @@ class CohereMLP(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -120,12 +121,14 @@ class CohereMLP(nn.Module):
             [self.intermediate_size] * 2,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             self.intermediate_size,
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
         )
         self.act_fn = SiluAndMul()

@@ -142,6 +145,7 @@ class CohereAttention(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         tp_size = get_tensor_model_parallel_world_size()
@@ -177,12 +181,14 @@ class CohereAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             self.hidden_size,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -198,6 +204,7 @@ class CohereAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )
         if self.use_qk_norm:
             self.q_norm = LayerNorm(
@@ -239,15 +246,23 @@ class CohereDecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = config.hidden_size

         self.self_attn = CohereAttention(
-            config, layer_id=layer_id, quant_config=quant_config
+            config,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
         )

-        self.mlp = CohereMLP(config, quant_config=quant_config)
+        self.mlp = CohereMLP(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
         self.input_layernorm = LayerNorm(
             param_shape=(config.hidden_size), eps=config.layer_norm_eps
         )
@@ -279,6 +294,7 @@ class CohereModel(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
@@ -288,7 +304,12 @@ class CohereModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                CohereDecoderLayer(config, i, quant_config=quant_config)
+                CohereDecoderLayer(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{i}", prefix),
+                )
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -321,12 +342,15 @@ class CohereForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
-        self.model = CohereModel(config, quant_config)
+        self.model = CohereModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )

     @torch.no_grad()
     def forward(
sglang/srt/models/dbrx.py
CHANGED
@@ -46,7 +46,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     maybe_remap_kv_scale_name,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import add_prefix, set_weight_attrs


 class DbrxRouter(nn.Module):
@@ -58,6 +58,7 @@ class DbrxRouter(nn.Module):
         self,
         config: DbrxConfig,
         params_dtype: Optional[torch.dtype] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -89,6 +90,7 @@ class DbrxExperts(nn.Module):
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
         params_dtype: Optional[torch.dtype] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -189,6 +191,7 @@ class DbrxAttention(nn.Module):
         config: DbrxConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.d_model = config.d_model
@@ -207,12 +210,14 @@ class DbrxAttention(nn.Module):
             self.total_num_kv_heads,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("Wqkv", prefix),
         )
         self.out_proj = RowParallelLinear(
             self.d_model,
             self.d_model,
             bias=False,
             quant_config=quant_config,
+            prefix=add_prefix("out_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,
@@ -244,6 +249,7 @@ class DbrxAttention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_id,
+            prefix=add_prefix("attn", prefix),
         )

     def forward(
@@ -268,10 +274,16 @@ class DbrxFusedNormAttention(nn.Module):
         config: DbrxConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.d_model = config.d_model
-        self.attn = DbrxAttention(config, layer_id, quant_config=quant_config)
+        self.attn = DbrxAttention(
+            config,
+            layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
         self.norm_1 = nn.LayerNorm(self.d_model)
         self.norm_2 = nn.LayerNorm(self.d_model)

@@ -300,10 +312,14 @@ class DbrxBlock(nn.Module):
         config: DbrxConfig,
         layer_id: int = 0,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.norm_attn_norm = DbrxFusedNormAttention(
-            config, layer_id, quant_config=quant_config
+            config,
+            layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("norm_attn_norm", prefix),
         )
         self.ffn = DbrxExperts(config, quant_config=quant_config)

@@ -328,6 +344,7 @@ class DbrxModel(nn.Module):
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.wte = VocabParallelEmbedding(
@@ -336,7 +353,12 @@ class DbrxModel(nn.Module):
         )
         self.blocks = nn.ModuleList(
             [
-                DbrxBlock(config, i, quant_config=quant_config)
+                DbrxBlock(
+                    config,
+                    i,
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"blocks.{i}", prefix),
+                )
                 for i in range(config.n_layers)
             ]
         )
@@ -369,17 +391,21 @@ class DbrxForCausalLM(nn.Module):
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.unpadded_vocab_size = config.vocab_size
-        self.transformer = DbrxModel(config, quant_config=quant_config)
+        self.transformer = DbrxModel(
+            config, quant_config=quant_config, prefix=add_prefix("transformer", prefix)
+        )
         self.lm_head = ParallelLMHead(
             config.vocab_size,
             config.d_model,
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+            prefix=add_prefix("lm_head", prefix),
         )
         self.logits_processor = LogitsProcessor(config)
